aboutsummaryrefslogtreecommitdiff
path: root/src/AST/Count.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/Count.hs')
-rw-r--r--src/AST/Count.hs90
1 files changed, 51 insertions, 39 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index d5afb5e..229661f 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -598,53 +598,65 @@ occCountX initialS topexpr k = case topexpr of
EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e
EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e
- EFold1InnerD1{} ->
- -- TODO: currently somewhat pessimistic on usage here
+ EFold1InnerD1 _ cm e1 e2 e3 ->
case s of
- SsPair' _ (SsArr' (SsPair' _ sTape)) ->
- go topexpr sTape (projectSmallerSubstruc (SsPair SsFull (SsArr (SsPair SsFull sTape ))) s)
- -- the primed patsyns also consume SsFull, so if the above doesn't match,
- -- something along the way is SsNone
- _ ->
- go topexpr SsNone (projectSmallerSubstruc (SsPair SsFull (SsArr (SsPair SsFull SsNone))) s)
- where
- go :: Expr x env (TPair (TArr n t1) (TArr (S n) (TPair t1 tape)))
- -> Substruc tape tape'
- -> (forall env'. Ex env' (TPair (TArr n t1) (TArr (S n) (TPair t1 tape'))) -> Ex env' t')
- -> r
- go (EFold1InnerD1 _ cm a b c) sTape project =
- occCountX (SsPair SsFull sTape) a $ \env1_2' mka ->
- withSome (scaleMany (Some env1_2')) $ \env1_2 ->
- occEnvPop' env1_2 $ \env1_1 _ ->
- occEnvPop' env1_1 $ \env1 _ ->
- occCountX SsFull b $ \env2 mkb ->
- occCountX SsFull c $ \env3 mkc ->
- withSome (Some env1 <> Some env2 <> Some env3) $ \env ->
+ -- If nothing is necessary, we can execute a fold and then proceed to ignore it
+ SsNone ->
+ let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1))
+ (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3)
+ in occCountX SsNone foldex $ \env1 mkfoldex -> k env1 mkfoldex
+ -- If we don't need the stores, still a fold suffices
+ SsPair' sP SsNone ->
+ let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1))
+ (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3)
+ in occCountX sP foldex $ \env1 mkfoldex -> k env1 $ \env' -> EPair ext (mkfoldex env') (ENil ext)
+ -- If for whatever reason the additional stores themselves are
+ -- unnecessary but the shape of the array is, then oblige
+ SsPair' sP (SsArr' SsNone) ->
+ let STArr sn _ = typeOf e3
+ foldex =
+ elet (mapExt (\_ -> ext) e3) $
+ EPair ext
+ (EShape ext (evar IZ))
+ (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy (WCopy WSink)) e1)))
+ (mapExt (\_ -> ext) (weakenExpr WSink e2))
+ (evar IZ))
+ in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex ->
+ k env1 $ \env' ->
+ eunPair (mkfoldex env') $ \_ eshape earr ->
+ EPair ext earr (EBuild ext sn eshape (ENil ext))
+ -- If at least some of the additional stores are required, we need to keep this a mapAccum
+ SsPair' _ (SsArr' sB) ->
+ -- TODO: propagate usage of primals
+ occCountX (SsPair SsFull sB) e1 $ \env1_2' mka ->
+ occEnvPop' env1_2' $ \env1_1' _ ->
+ occEnvPop' env1_1' $ \env1' _ ->
+ occCountX SsFull e2 $ \env2 mkb ->
+ occCountX SsFull e3 $ \env3 mkc ->
+ withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
k env $ \env' ->
- -- projectSmallerSubstruc SsFull s $
- project $
+ projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $
EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull))
(mkb env') (mkc env')
- go _ _ _ = error "impossible"
- EFold1InnerD2 _ cm t2 ef ep ezi ebog ed ->
- -- TODO: currently very pessimistic on usage here, can at the very least improve tape usage
- occCountX SsFull ef $ \env1_4' mkef ->
- withSome (scaleMany (Some env1_4')) $ \env1_4 ->
- occEnvPop' env1_4 $ \env1_3 _ ->
- occEnvPop' env1_3 $ \env1_2 _ ->
- occEnvPop' env1_2 $ \env1_1 _ ->
- occEnvPop' env1_1 $ \env1 sTape ->
- occCountX SsFull ep $ \env2 mkep ->
- occCountX SsFull ezi $ \env3 mkezi ->
- occCountX (SsArr (SsPair SsFull sTape)) ebog $ \env4 mkebog ->
+ EFold1InnerD2 _ cm ef ez eplus ebog ed ->
+ -- TODO: propagate usage of duals
+ occCountX SsFull ef $ \env1_2' mkef ->
+ occEnvPop' env1_2' $ \env1_1' _ ->
+ occEnvPop' env1_1' $ \env1' sB ->
+ occCountX SsFull ez $ \env2' mkez ->
+ occCountX SsFull eplus $ \env3_2' mkeplus ->
+ occEnvPop' env3_2' $ \env3_1' _ ->
+ occEnvPop' env3_1' $ \env3' _ ->
+ occCountX (SsArr sB) ebog $ \env4 mkebog ->
occCountX SsFull ed $ \env5 mked ->
- withSome (Some env1 <> Some env2 <> Some env3 <> Some env4 <> Some env5) $ \env ->
+ withSome (scaleMany (Some env1') <> scaleMany (Some env2') <> scaleMany (Some env3') <> Some env4 <> Some env5) $ \env ->
k env $ \env' ->
projectSmallerSubstruc SsFull s $
- EFold1InnerD2 ext cm t2
- (mkef (OccPush (OccPush (OccPush (OccPush env' () sTape) () SsFull) () SsFull) () SsFull))
- (mkep env') (mkezi env') (mkebog env') (mked env')
+ EFold1InnerD2 ext cm
+ (mkef (OccPush (OccPush env' () sB) () SsFull))
+ (mkez env') (mkeplus (OccPush (OccPush env' () SsFull) () SsFull))
+ (mkebog env') (mked env')
EConst _ t x ->
k OccEnd $ \_ ->