From 4c9ae47dd5bbd27b1acb6dc5d4a55657ac1f026f Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 30 Oct 2025 15:58:08 +0100 Subject: Simplify foldD2 to not sum x0 contributions --- src/AST/Count.hs | 38 +++++++++++++------------------------- 1 file changed, 13 insertions(+), 25 deletions(-) (limited to 'src/AST/Count.hs') diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 66b4e0b..296c021 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -515,9 +515,8 @@ occCountX initialS topexpr k = case topexpr of EConstArr _ n t x -> case s of SsNone -> k OccEnd (\_ -> ENil ext) - SsArr SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext)) - SsArr SsFull -> k OccEnd (\_ -> EConstArr ext n t x) - SsFull -> occCountX (SsArr SsFull) topexpr k + SsArr' SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext)) + SsArr' SsFull -> k OccEnd (\_ -> EConstArr ext n t x) EBuild _ n a b -> case s of @@ -533,7 +532,7 @@ occCountX initialS topexpr k = case topexpr of weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $ ENil ext) $ ENil ext - SsArr s' -> + SsArr' s' -> occCountX SsFull a $ \env1 mka -> occCountX s' b $ \env2'' mkb -> withSome (scaleMany (Some env2'')) $ \env2' -> @@ -543,7 +542,6 @@ occCountX initialS topexpr k = case topexpr of EBuild ext n (mka env') $ elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ weakenExpr (WCopy WSink) (mkb (OccPush env' () s2)) - SsFull -> occCountX (SsArr SsFull) topexpr k EFold1Inner _ commut a b c -> occCountX SsFull a $ \env1''' mka -> @@ -552,8 +550,7 @@ occCountX initialS topexpr k = case topexpr of occEnvPop' env1' $ \env1 s1 -> let s0 = case s of SsNone -> Some SsNone - SsArr s' -> Some s' - SsFull -> Some SsFull in + SsArr' s' -> Some s' in withSome (Some s1 <> Some s2 <> s0) $ \sElt -> occCountX sElt b $ \env2 mkb -> occCountX (SsArr sElt) c $ \env3 mkc -> @@ -573,11 +570,10 @@ occCountX initialS topexpr k = case topexpr of occCountX SsNone e $ \env mke -> k env $ \env' -> use (mke env') $ ENil ext - SsArr s' -> + SsArr' s' -> occCountX s' e $ \env mke -> k env $ \env' -> EUnit ext (mke env') - SsFull -> occCountX (SsArr SsFull) topexpr k EReplicate1Inner _ a b -> case s of @@ -587,13 +583,12 @@ occCountX initialS topexpr k = case topexpr of withSome (Some env1 <> Some env2) $ \env -> k env $ \env' -> use (mka env') $ use (mkb env') $ ENil ext - SsArr s' -> + SsArr' s' -> occCountX SsFull a $ \env1 mka -> occCountX (SsArr s') b $ \env2 mkb -> withSome (Some env1 <> Some env2) $ \env -> k env $ \env' -> EReplicate1Inner ext (mka env') (mkb env') - SsFull -> occCountX (SsArr SsFull) topexpr k EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e @@ -654,23 +649,18 @@ occCountX initialS topexpr k = case topexpr of EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull)) (mkb env') (mkc env') - EFold1InnerD2 _ cm ef ez eplus ebog ed -> + EFold1InnerD2 _ cm ef ebog ed -> -- TODO: propagate usage of duals occCountX SsFull ef $ \env1_2' mkef -> occEnvPop' env1_2' $ \env1_1' _ -> occEnvPop' env1_1' $ \env1' sB -> - occCountX SsFull ez $ \env2' mkez -> - occCountX SsFull eplus $ \env3_2' mkeplus -> - occEnvPop' env3_2' $ \env3_1' _ -> - occEnvPop' env3_1' $ \env3' _ -> - occCountX (SsArr sB) ebog $ \env4 mkebog -> - occCountX SsFull ed $ \env5 mked -> - withSome (scaleMany (Some env1') <> scaleMany (Some env2') <> scaleMany (Some env3') <> Some env4 <> Some env5) $ \env -> + occCountX (SsArr sB) ebog $ \env2 mkebog -> + occCountX SsFull ed $ \env3 mked -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> k env $ \env' -> projectSmallerSubstruc SsFull s $ EFold1InnerD2 ext cm (mkef (OccPush (OccPush env' () sB) () SsFull)) - (mkez env') (mkeplus (OccPush (OccPush env' () SsFull) () SsFull)) (mkebog env') (mked env') EConst _ t x -> @@ -692,13 +682,12 @@ occCountX initialS topexpr k = case topexpr of withSome (Some env1 <> Some env2) $ \env -> k env $ \env' -> use (mka env') $ use (mkb env') $ ENil ext - SsArr s' -> + SsArr' s' -> occCountX (SsArr s') a $ \env1 mka -> occCountX SsFull b $ \env2 mkb -> withSome (Some env1 <> Some env2) $ \env -> k env $ \env' -> EIdx1 ext (mka env') (mkb env') - SsFull -> occCountX (SsArr SsFull) topexpr k EIdx _ a b -> case s of @@ -863,16 +852,15 @@ occCountX initialS topexpr k = case topexpr of occCountX SsNone e $ \env mke -> k env $ \env' -> use (mke env') $ ENil ext - SsArr SsNone -> + SsArr' SsNone -> occCountX (SsArr SsNone) e $ \env mke -> k env $ \env' -> elet (mke env') $ EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext) - SsArr SsFull -> + SsArr' SsFull -> occCountX (SsArr SsFull) e $ \env mke -> k env $ \env' -> reduce (mke env') - SsFull -> occCountX (SsArr SsFull) topexpr k deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r -- cgit v1.2.3-70-g09d2