aboutsummaryrefslogtreecommitdiff
path: root/src/AST/Count.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-04 23:09:21 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-04 23:09:21 +0100
commit57779d4303f377004705c8da06a5ac46177950b2 (patch)
tree0407089403d3d5c2de778c1aab7aed8adf2d01c0 /src/AST/Count.hs
parent351667a3ff14c96a8dfe3a2f1dd76b6e1a996542 (diff)
drevLambda works, TODO D[map]HEADmaster
Diffstat (limited to 'src/AST/Count.hs')
-rw-r--r--src/AST/Count.hs26
1 files changed, 13 insertions, 13 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index bc02417..a53822d 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -560,22 +560,23 @@ occCountX initialS topexpr k = case topexpr of
EMap ext (mka (OccPush env' () s1)) (mkb env')
EFold1Inner _ commut a b c ->
- occCountX SsFull a $ \env1''' mka ->
- withSome (scaleMany (Some env1''')) $ \env1'' ->
- occEnvPop' env1'' $ \env1' s2 ->
- occEnvPop' env1' $ \env1 s1 ->
- let s0 = case s of
+ occCountX SsFull a $ \env1'' mka ->
+ occEnvPop' env1'' $ \env1' s1' ->
+ let s1 = case s1' of
+ SsNone -> Some SsNone
+ SsPair' s1'a s1'b -> Some s1'a <> Some s1'b
+ s0 = case s of
SsNone -> Some SsNone
SsArr' s' -> Some s' in
- withSome (Some s1 <> Some s2 <> s0) $ \sElt ->
+ withSome (s1 <> s0) $ \sElt ->
occCountX sElt b $ \env2 mkb ->
- occCountX (SsArr sElt) c $ \env3 mkc ->
- withSome (Some env1 <> Some env2 <> Some env3) $ \env ->
+ occCountX (SsArr sElt) c $ \env3 mkc ->
+ withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
k env $ \env' ->
projectSmallerSubstruc (SsArr sElt) s $
EFold1Inner ext commut
(projectSmallerSubstruc SsFull sElt $
- mka (OccPush (OccPush env' () sElt) () sElt))
+ mka (OccPush env' () (SsPair sElt sElt)))
(mkb env') (mkc env')
ESum1Inner _ e -> handleReduction (ESum1Inner ext) e
@@ -665,7 +666,7 @@ occCountX initialS topexpr k = case topexpr of
elet (mapExt (\_ -> ext) e3) $
EPair ext
(EShape ext (evar IZ))
- (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy (WCopy WSink)) e1)))
+ (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy WSink) e1)))
(mapExt (\_ -> ext) (weakenExpr WSink e2))
(evar IZ))
in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex ->
@@ -675,15 +676,14 @@ occCountX initialS topexpr k = case topexpr of
-- If at least some of the additional stores are required, we need to keep this a mapAccum
SsPair' _ (SsArr' sB) ->
-- TODO: propagate usage of primals
- occCountX (SsPair SsFull sB) e1 $ \env1_2' mka ->
- occEnvPop' env1_2' $ \env1_1' _ ->
+ occCountX (SsPair SsFull sB) e1 $ \env1_1' mka ->
occEnvPop' env1_1' $ \env1' _ ->
occCountX SsFull e2 $ \env2 mkb ->
occCountX SsFull e3 $ \env3 mkc ->
withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
k env $ \env' ->
projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $
- EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull))
+ EFold1InnerD1 ext cm (mka (OccPush env' () SsFull))
(mkb env') (mkc env')
EFold1InnerD2 _ cm ef ebog ed ->