From 4c9ae47dd5bbd27b1acb6dc5d4a55657ac1f026f Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 30 Oct 2025 15:58:08 +0100 Subject: Simplify foldD2 to not sum x0 contributions --- src/AST/Count.hs | 38 +++++++++++++------------------------- src/AST/Pretty.hs | 11 +++-------- src/AST/SplitLets.hs | 8 ++++---- src/AST/UnMonoid.hs | 2 +- 4 files changed, 21 insertions(+), 38 deletions(-) (limited to 'src/AST') diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 66b4e0b..296c021 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -515,9 +515,8 @@ occCountX initialS topexpr k = case topexpr of EConstArr _ n t x -> case s of SsNone -> k OccEnd (\_ -> ENil ext) - SsArr SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext)) - SsArr SsFull -> k OccEnd (\_ -> EConstArr ext n t x) - SsFull -> occCountX (SsArr SsFull) topexpr k + SsArr' SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext)) + SsArr' SsFull -> k OccEnd (\_ -> EConstArr ext n t x) EBuild _ n a b -> case s of @@ -533,7 +532,7 @@ occCountX initialS topexpr k = case topexpr of weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $ ENil ext) $ ENil ext - SsArr s' -> + SsArr' s' -> occCountX SsFull a $ \env1 mka -> occCountX s' b $ \env2'' mkb -> withSome (scaleMany (Some env2'')) $ \env2' -> @@ -543,7 +542,6 @@ occCountX initialS topexpr k = case topexpr of EBuild ext n (mka env') $ elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ weakenExpr (WCopy WSink) (mkb (OccPush env' () s2)) - SsFull -> occCountX (SsArr SsFull) topexpr k EFold1Inner _ commut a b c -> occCountX SsFull a $ \env1''' mka -> @@ -552,8 +550,7 @@ occCountX initialS topexpr k = case topexpr of occEnvPop' env1' $ \env1 s1 -> let s0 = case s of SsNone -> Some SsNone - SsArr s' -> Some s' - SsFull -> Some SsFull in + SsArr' s' -> Some s' in withSome (Some s1 <> Some s2 <> s0) $ \sElt -> occCountX sElt b $ \env2 mkb -> occCountX (SsArr sElt) c $ \env3 mkc -> @@ -573,11 +570,10 @@ occCountX initialS topexpr k = case topexpr of occCountX SsNone e $ \env mke -> k env $ \env' -> use (mke env') $ ENil ext - SsArr s' -> + SsArr' s' -> occCountX s' e $ \env mke -> k env $ \env' -> EUnit ext (mke env') - SsFull -> occCountX (SsArr SsFull) topexpr k EReplicate1Inner _ a b -> case s of @@ -587,13 +583,12 @@ occCountX initialS topexpr k = case topexpr of withSome (Some env1 <> Some env2) $ \env -> k env $ \env' -> use (mka env') $ use (mkb env') $ ENil ext - SsArr s' -> + SsArr' s' -> occCountX SsFull a $ \env1 mka -> occCountX (SsArr s') b $ \env2 mkb -> withSome (Some env1 <> Some env2) $ \env -> k env $ \env' -> EReplicate1Inner ext (mka env') (mkb env') - SsFull -> occCountX (SsArr SsFull) topexpr k EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e @@ -654,23 +649,18 @@ occCountX initialS topexpr k = case topexpr of EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull)) (mkb env') (mkc env') - EFold1InnerD2 _ cm ef ez eplus ebog ed -> + EFold1InnerD2 _ cm ef ebog ed -> -- TODO: propagate usage of duals occCountX SsFull ef $ \env1_2' mkef -> occEnvPop' env1_2' $ \env1_1' _ -> occEnvPop' env1_1' $ \env1' sB -> - occCountX SsFull ez $ \env2' mkez -> - occCountX SsFull eplus $ \env3_2' mkeplus -> - occEnvPop' env3_2' $ \env3_1' _ -> - occEnvPop' env3_1' $ \env3' _ -> - occCountX (SsArr sB) ebog $ \env4 mkebog -> - occCountX SsFull ed $ \env5 mked -> - withSome (scaleMany (Some env1') <> scaleMany (Some env2') <> scaleMany (Some env3') <> Some env4 <> Some env5) $ \env -> + occCountX (SsArr sB) ebog $ \env2 mkebog -> + occCountX SsFull ed $ \env3 mked -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> k env $ \env' -> projectSmallerSubstruc SsFull s $ EFold1InnerD2 ext cm (mkef (OccPush (OccPush env' () sB) () SsFull)) - (mkez env') (mkeplus (OccPush (OccPush env' () SsFull) () SsFull)) (mkebog env') (mked env') EConst _ t x -> @@ -692,13 +682,12 @@ occCountX initialS topexpr k = case topexpr of withSome (Some env1 <> Some env2) $ \env -> k env $ \env' -> use (mka env') $ use (mkb env') $ ENil ext - SsArr s' -> + SsArr' s' -> occCountX (SsArr s') a $ \env1 mka -> occCountX SsFull b $ \env2 mkb -> withSome (Some env1 <> Some env2) $ \env -> k env $ \env' -> EIdx1 ext (mka env') (mkb env') - SsFull -> occCountX (SsArr SsFull) topexpr k EIdx _ a b -> case s of @@ -863,16 +852,15 @@ occCountX initialS topexpr k = case topexpr of occCountX SsNone e $ \env mke -> k env $ \env' -> use (mke env') $ ENil ext - SsArr SsNone -> + SsArr' SsNone -> occCountX (SsArr SsNone) e $ \env mke -> k env $ \env' -> elet (mke env') $ EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext) - SsArr SsFull -> + SsArr' SsFull -> occCountX (SsArr SsFull) e $ \env mke -> k env $ \env' -> reduce (mke env') - SsFull -> occCountX (SsArr SsFull) topexpr k deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 67197f9..68fc629 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -250,23 +250,18 @@ ppExpr' d val expr = case expr of return $ ppParen (d > 10) $ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] - EFold1InnerD2 _ cm ef ez eplus ebog ed -> do + EFold1InnerD2 _ cm ef ebog ed -> do let STArr _ tB = typeOf ebog - t2 = typeOf ez + STArr _ t2 = typeOf ed namef1 <- genNameIfUsedIn tB (IS IZ) ef namef2 <- genNameIfUsedIn t2 IZ ef ef' <- ppExpr' 0 (Const namef2 `SCons` Const namef1 `SCons` val) ef - ez' <- ppExpr' 11 val ez - namep1 <- genNameIfUsedIn t2 (IS IZ) eplus - namep2 <- genNameIfUsedIn t2 IZ eplus - eplus' <- ppExpr' 0 (Const namep2 `SCons` Const namep1 `SCons` val) eplus ebog' <- ppExpr' 11 val ebog ed' <- ppExpr' 11 val ed let opname = "fold1iD2" ++ ppCommut cm return $ ppParen (d > 10) $ ppApp (annotate AHighlight (ppString opname) <> ppX expr) - [ppLam [ppString namef1, ppString namef2] ef', ez' - ,ppLam [ppString namep1, ppString namep2] eplus', ebog', ed'] + [ppLam [ppString namef1, ppString namef2] ef', ebog', ed'] EConst _ ty v | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 73c1c67..d276e44 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -38,10 +38,10 @@ splitLets' = \sub -> \case EFold1InnerD1 x cm a b c -> let STArr _ t1 = typeOf c in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) - EFold1InnerD2 x cm a b c d e -> - let t2 = typeOf b - STArr _ tB = typeOf d - in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (split2 sub t2 t2 c) (splitLets' sub d) (splitLets' sub e) + EFold1InnerD2 x cm a b c -> + let STArr _ tB = typeOf b + STArr _ t2 = typeOf c + in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (splitLets' sub c) EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b) EFst x e -> EFst x (splitLets' sub e) diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index e5a9708..a22b73f 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -46,7 +46,7 @@ unMonoid = \case EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) - EFold1InnerD2 _ cm a b c d e -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid d) (unMonoid e) + EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) EConst _ t x -> EConst ext t x EIdx0 _ e -> EIdx0 ext (unMonoid e) EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) -- cgit v1.2.3-70-g09d2