From 955af83f664639701fdbee54718186e07b31d42f Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 28 Oct 2025 11:56:40 +0100 Subject: Better fold D{1,2} primitives --- src/AST/Count.hs | 92 +++++++++++++++++++++++++++++----------------------- src/AST/Pretty.hs | 22 +++++++------ src/AST/SplitLets.hs | 16 ++++----- src/AST/UnMonoid.hs | 2 +- 4 files changed, 73 insertions(+), 59 deletions(-) (limited to 'src/AST') diff --git a/src/AST/Count.hs b/src/AST/Count.hs index d5afb5e..229661f 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -598,53 +598,65 @@ occCountX initialS topexpr k = case topexpr of EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e - EFold1InnerD1{} -> - -- TODO: currently somewhat pessimistic on usage here + EFold1InnerD1 _ cm e1 e2 e3 -> case s of - SsPair' _ (SsArr' (SsPair' _ sTape)) -> - go topexpr sTape (projectSmallerSubstruc (SsPair SsFull (SsArr (SsPair SsFull sTape ))) s) - -- the primed patsyns also consume SsFull, so if the above doesn't match, - -- something along the way is SsNone - _ -> - go topexpr SsNone (projectSmallerSubstruc (SsPair SsFull (SsArr (SsPair SsFull SsNone))) s) - where - go :: Expr x env (TPair (TArr n t1) (TArr (S n) (TPair t1 tape))) - -> Substruc tape tape' - -> (forall env'. Ex env' (TPair (TArr n t1) (TArr (S n) (TPair t1 tape'))) -> Ex env' t') - -> r - go (EFold1InnerD1 _ cm a b c) sTape project = - occCountX (SsPair SsFull sTape) a $ \env1_2' mka -> - withSome (scaleMany (Some env1_2')) $ \env1_2 -> - occEnvPop' env1_2 $ \env1_1 _ -> - occEnvPop' env1_1 $ \env1 _ -> - occCountX SsFull b $ \env2 mkb -> - occCountX SsFull c $ \env3 mkc -> - withSome (Some env1 <> Some env2 <> Some env3) $ \env -> + -- If nothing is necessary, we can execute a fold and then proceed to ignore it + SsNone -> + let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) + (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) + in occCountX SsNone foldex $ \env1 mkfoldex -> k env1 mkfoldex + -- If we don't need the stores, still a fold suffices + SsPair' sP SsNone -> + let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) + (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) + in occCountX sP foldex $ \env1 mkfoldex -> k env1 $ \env' -> EPair ext (mkfoldex env') (ENil ext) + -- If for whatever reason the additional stores themselves are + -- unnecessary but the shape of the array is, then oblige + SsPair' sP (SsArr' SsNone) -> + let STArr sn _ = typeOf e3 + foldex = + elet (mapExt (\_ -> ext) e3) $ + EPair ext + (EShape ext (evar IZ)) + (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy (WCopy WSink)) e1))) + (mapExt (\_ -> ext) (weakenExpr WSink e2)) + (evar IZ)) + in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex -> + k env1 $ \env' -> + eunPair (mkfoldex env') $ \_ eshape earr -> + EPair ext earr (EBuild ext sn eshape (ENil ext)) + -- If at least some of the additional stores are required, we need to keep this a mapAccum + SsPair' _ (SsArr' sB) -> + -- TODO: propagate usage of primals + occCountX (SsPair SsFull sB) e1 $ \env1_2' mka -> + occEnvPop' env1_2' $ \env1_1' _ -> + occEnvPop' env1_1' $ \env1' _ -> + occCountX SsFull e2 $ \env2 mkb -> + occCountX SsFull e3 $ \env3 mkc -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> k env $ \env' -> - -- projectSmallerSubstruc SsFull s $ - project $ + projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $ EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull)) (mkb env') (mkc env') - go _ _ _ = error "impossible" - - EFold1InnerD2 _ cm t2 ef ep ezi ebog ed -> - -- TODO: currently very pessimistic on usage here, can at the very least improve tape usage - occCountX SsFull ef $ \env1_4' mkef -> - withSome (scaleMany (Some env1_4')) $ \env1_4 -> - occEnvPop' env1_4 $ \env1_3 _ -> - occEnvPop' env1_3 $ \env1_2 _ -> - occEnvPop' env1_2 $ \env1_1 _ -> - occEnvPop' env1_1 $ \env1 sTape -> - occCountX SsFull ep $ \env2 mkep -> - occCountX SsFull ezi $ \env3 mkezi -> - occCountX (SsArr (SsPair SsFull sTape)) ebog $ \env4 mkebog -> + + EFold1InnerD2 _ cm ef ez eplus ebog ed -> + -- TODO: propagate usage of duals + occCountX SsFull ef $ \env1_2' mkef -> + occEnvPop' env1_2' $ \env1_1' _ -> + occEnvPop' env1_1' $ \env1' sB -> + occCountX SsFull ez $ \env2' mkez -> + occCountX SsFull eplus $ \env3_2' mkeplus -> + occEnvPop' env3_2' $ \env3_1' _ -> + occEnvPop' env3_1' $ \env3' _ -> + occCountX (SsArr sB) ebog $ \env4 mkebog -> occCountX SsFull ed $ \env5 mked -> - withSome (Some env1 <> Some env2 <> Some env3 <> Some env4 <> Some env5) $ \env -> + withSome (scaleMany (Some env1') <> scaleMany (Some env2') <> scaleMany (Some env3') <> Some env4 <> Some env5) $ \env -> k env $ \env' -> projectSmallerSubstruc SsFull s $ - EFold1InnerD2 ext cm t2 - (mkef (OccPush (OccPush (OccPush (OccPush env' () sTape) () SsFull) () SsFull) () SsFull)) - (mkep env') (mkezi env') (mkebog env') (mked env') + EFold1InnerD2 ext cm + (mkef (OccPush (OccPush env' () sB) () SsFull)) + (mkez env') (mkeplus (OccPush (OccPush env' () SsFull) () SsFull)) + (mkebog env') (mked env') EConst _ t x -> k OccEnd $ \_ -> diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index afa62c6..587328d 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -245,21 +245,23 @@ ppExpr' d val expr = case expr of return $ ppParen (d > 10) $ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] - EFold1InnerD2 _ cm t2 ef ep ezi ebog ed -> do - let STArr _ (STPair t1 ttape) = typeOf ebog - name1 <- genNameIfUsedIn ttape (IS (IS (IS IZ))) ef - name2 <- genNameIfUsedIn t1 (IS (IS IZ)) ef - name3 <- genNameIfUsedIn t1 (IS IZ) ef - name4 <- genNameIfUsedIn (fromSMTy t2) IZ ef - ef' <- ppExpr' 0 (Const name4 `SCons` Const name3 `SCons` Const name2 `SCons` Const name1 `SCons` val) ef - ep' <- ppExpr' 11 val ep - ezi' <- ppExpr' 11 val ezi + EFold1InnerD2 _ cm ef ez eplus ebog ed -> do + let STArr _ tB = typeOf ebog + t2 = typeOf ez + namef1 <- genNameIfUsedIn tB (IS IZ) ef + namef2 <- genNameIfUsedIn t2 IZ ef + ef' <- ppExpr' 0 (Const namef2 `SCons` Const namef1 `SCons` val) ef + ez' <- ppExpr' 11 val ez + namep1 <- genNameIfUsedIn t2 (IS IZ) eplus + namep2 <- genNameIfUsedIn t2 IZ eplus + eplus' <- ppExpr' 0 (Const namep2 `SCons` Const namep1 `SCons` val) eplus ebog' <- ppExpr' 11 val ebog ed' <- ppExpr' 11 val ed let opname = "fold1iD2" ++ ppCommut cm return $ ppParen (d > 10) $ ppApp (annotate AHighlight (ppString opname) <> ppX expr) - [ppLam [ppString name1, ppString name2, ppString name3, ppString name4] ef', ep', ezi', ebog', ed'] + [ppLam [ppString namef1, ppString namef2] ef', ez' + ,ppLam [ppString namep1, ppString namep2] eplus', ebog', ed'] EConst _ ty v | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index f75e795..6034084 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -38,10 +38,10 @@ splitLets' = \sub -> \case EFold1InnerD1 x cm a b c -> let STArr _ t1 = typeOf c in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) - EFold1InnerD2 x cm t2 a b c d e -> - let STArr _ t1 = typeOf b - STArr _ (STPair _ ttape) = typeOf d - in EFold1InnerD2 x cm t2 (split4 sub ttape t1 t1 (fromSMTy t2) a) (splitLets' sub b) (splitLets' sub c) (splitLets' sub d) (splitLets' sub e) + EFold1InnerD2 x cm a b c d e -> + let t2 = typeOf b + STArr _ tB = typeOf d + in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (split2 sub t2 t2 c) (splitLets' sub d) (splitLets' sub e) EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b) EFst x e -> EFst x (splitLets' sub e) @@ -106,10 +106,10 @@ splitLets' = \sub -> \case body -- TODO: abstract this to splitN lol wtf - split4 :: forall bind1 bind2 bind3 bind4 env' env t. - (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) - -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t - split4 sub tbind1 tbind2 tbind3 tbind4 body = + _split4 :: forall bind1 bind2 bind3 bind4 env' env t. + (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) + -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t + _split4 sub tbind1 tbind2 tbind3 tbind4 body = let (ptrs1, bs1') = split @env' tbind1 (ptrs2, bs2') = split @(bind1 : env') tbind2 (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3 diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 555f0ec..6904715 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -45,7 +45,7 @@ unMonoid = \case EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) - EFold1InnerD2 _ cm t2 a b c d e -> EFold1InnerD2 ext cm t2 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid d) (unMonoid e) + EFold1InnerD2 _ cm a b c d e -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid d) (unMonoid e) EConst _ t x -> EConst ext t x EIdx0 _ e -> EIdx0 ext (unMonoid e) EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) -- cgit v1.2.3-70-g09d2