From 955af83f664639701fdbee54718186e07b31d42f Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 28 Oct 2025 11:56:40 +0100 Subject: Better fold D{1,2} primitives --- src/AST/Count.hs | 92 ++++++++++++++++++++++++++++++++------------------------ 1 file changed, 52 insertions(+), 40 deletions(-) (limited to 'src/AST/Count.hs') diff --git a/src/AST/Count.hs b/src/AST/Count.hs index d5afb5e..229661f 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -598,53 +598,65 @@ occCountX initialS topexpr k = case topexpr of EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e - EFold1InnerD1{} -> - -- TODO: currently somewhat pessimistic on usage here + EFold1InnerD1 _ cm e1 e2 e3 -> case s of - SsPair' _ (SsArr' (SsPair' _ sTape)) -> - go topexpr sTape (projectSmallerSubstruc (SsPair SsFull (SsArr (SsPair SsFull sTape ))) s) - -- the primed patsyns also consume SsFull, so if the above doesn't match, - -- something along the way is SsNone - _ -> - go topexpr SsNone (projectSmallerSubstruc (SsPair SsFull (SsArr (SsPair SsFull SsNone))) s) - where - go :: Expr x env (TPair (TArr n t1) (TArr (S n) (TPair t1 tape))) - -> Substruc tape tape' - -> (forall env'. Ex env' (TPair (TArr n t1) (TArr (S n) (TPair t1 tape'))) -> Ex env' t') - -> r - go (EFold1InnerD1 _ cm a b c) sTape project = - occCountX (SsPair SsFull sTape) a $ \env1_2' mka -> - withSome (scaleMany (Some env1_2')) $ \env1_2 -> - occEnvPop' env1_2 $ \env1_1 _ -> - occEnvPop' env1_1 $ \env1 _ -> - occCountX SsFull b $ \env2 mkb -> - occCountX SsFull c $ \env3 mkc -> - withSome (Some env1 <> Some env2 <> Some env3) $ \env -> + -- If nothing is necessary, we can execute a fold and then proceed to ignore it + SsNone -> + let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) + (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) + in occCountX SsNone foldex $ \env1 mkfoldex -> k env1 mkfoldex + -- If we don't need the stores, still a fold suffices + SsPair' sP SsNone -> + let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) + (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) + in occCountX sP foldex $ \env1 mkfoldex -> k env1 $ \env' -> EPair ext (mkfoldex env') (ENil ext) + -- If for whatever reason the additional stores themselves are + -- unnecessary but the shape of the array is, then oblige + SsPair' sP (SsArr' SsNone) -> + let STArr sn _ = typeOf e3 + foldex = + elet (mapExt (\_ -> ext) e3) $ + EPair ext + (EShape ext (evar IZ)) + (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy (WCopy WSink)) e1))) + (mapExt (\_ -> ext) (weakenExpr WSink e2)) + (evar IZ)) + in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex -> + k env1 $ \env' -> + eunPair (mkfoldex env') $ \_ eshape earr -> + EPair ext earr (EBuild ext sn eshape (ENil ext)) + -- If at least some of the additional stores are required, we need to keep this a mapAccum + SsPair' _ (SsArr' sB) -> + -- TODO: propagate usage of primals + occCountX (SsPair SsFull sB) e1 $ \env1_2' mka -> + occEnvPop' env1_2' $ \env1_1' _ -> + occEnvPop' env1_1' $ \env1' _ -> + occCountX SsFull e2 $ \env2 mkb -> + occCountX SsFull e3 $ \env3 mkc -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> k env $ \env' -> - -- projectSmallerSubstruc SsFull s $ - project $ + projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $ EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull)) (mkb env') (mkc env') - go _ _ _ = error "impossible" - - EFold1InnerD2 _ cm t2 ef ep ezi ebog ed -> - -- TODO: currently very pessimistic on usage here, can at the very least improve tape usage - occCountX SsFull ef $ \env1_4' mkef -> - withSome (scaleMany (Some env1_4')) $ \env1_4 -> - occEnvPop' env1_4 $ \env1_3 _ -> - occEnvPop' env1_3 $ \env1_2 _ -> - occEnvPop' env1_2 $ \env1_1 _ -> - occEnvPop' env1_1 $ \env1 sTape -> - occCountX SsFull ep $ \env2 mkep -> - occCountX SsFull ezi $ \env3 mkezi -> - occCountX (SsArr (SsPair SsFull sTape)) ebog $ \env4 mkebog -> + + EFold1InnerD2 _ cm ef ez eplus ebog ed -> + -- TODO: propagate usage of duals + occCountX SsFull ef $ \env1_2' mkef -> + occEnvPop' env1_2' $ \env1_1' _ -> + occEnvPop' env1_1' $ \env1' sB -> + occCountX SsFull ez $ \env2' mkez -> + occCountX SsFull eplus $ \env3_2' mkeplus -> + occEnvPop' env3_2' $ \env3_1' _ -> + occEnvPop' env3_1' $ \env3' _ -> + occCountX (SsArr sB) ebog $ \env4 mkebog -> occCountX SsFull ed $ \env5 mked -> - withSome (Some env1 <> Some env2 <> Some env3 <> Some env4 <> Some env5) $ \env -> + withSome (scaleMany (Some env1') <> scaleMany (Some env2') <> scaleMany (Some env3') <> Some env4 <> Some env5) $ \env -> k env $ \env' -> projectSmallerSubstruc SsFull s $ - EFold1InnerD2 ext cm t2 - (mkef (OccPush (OccPush (OccPush (OccPush env' () sTape) () SsFull) () SsFull) () SsFull)) - (mkep env') (mkezi env') (mkebog env') (mked env') + EFold1InnerD2 ext cm + (mkef (OccPush (OccPush env' () sB) () SsFull)) + (mkez env') (mkeplus (OccPush (OccPush env' () SsFull) () SsFull)) + (mkebog env') (mked env') EConst _ t x -> k OccEnd $ \_ -> -- cgit v1.2.3-70-g09d2