aboutsummaryrefslogtreecommitdiff
path: root/src/AST/Count.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/Count.hs')
-rw-r--r--src/AST/Count.hs49
1 files changed, 43 insertions, 6 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 296c021..bc02417 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -523,9 +523,8 @@ occCountX initialS topexpr k = case topexpr of
SsNone ->
occCountX SsFull a $ \env1 mka ->
occCountX SsNone b $ \env2'' mkb ->
- withSome (scaleMany (Some env2'')) $ \env2' ->
- occEnvPop' env2' $ \env2 s2 ->
- withSome (Some env1 <> Some env2) $ \env ->
+ occEnvPop' env2'' $ \env2' s2 ->
+ withSome (Some env1 <> scaleMany (Some env2')) $ \env ->
k env $ \env' ->
use (EBuild ext n (mka env') $
use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
@@ -535,14 +534,31 @@ occCountX initialS topexpr k = case topexpr of
SsArr' s' ->
occCountX SsFull a $ \env1 mka ->
occCountX s' b $ \env2'' mkb ->
- withSome (scaleMany (Some env2'')) $ \env2' ->
- occEnvPop' env2' $ \env2 s2 ->
- withSome (Some env1 <> Some env2) $ \env ->
+ occEnvPop' env2'' $ \env2' s2 ->
+ withSome (Some env1 <> scaleMany (Some env2')) $ \env ->
k env $ \env' ->
EBuild ext n (mka env') $
elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))
+ EMap _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1'' mka ->
+ occEnvPop' env1'' $ \env1' s1 ->
+ occCountX (SsArr s1) b $ \env2 mkb ->
+ withSome (scaleMany (Some env1') <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $
+ ENil ext
+ SsArr' s' ->
+ occCountX s' a $ \env1'' mka ->
+ occEnvPop' env1'' $ \env1' s1 ->
+ occCountX (SsArr s1) b $ \env2 mkb ->
+ withSome (scaleMany (Some env1') <> Some env2) $ \env ->
+ k env $ \env' ->
+ EMap ext (mka (OccPush env' () s1)) (mkb env')
+
EFold1Inner _ commut a b c ->
occCountX SsFull a $ \env1''' mka ->
withSome (scaleMany (Some env1''')) $ \env1'' ->
@@ -608,6 +624,27 @@ occCountX initialS topexpr k = case topexpr of
k env $ \env' ->
EReshape ext n (mkesh env') (mke env')
+ EZip _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ SsArr' SsNone ->
+ occCountX (SsArr SsNone) a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mkb env') $ mka env'
+ SsArr' (SsPair' s1 s2) ->
+ occCountX (SsArr s1) a $ \env1 mka ->
+ occCountX (SsArr s2) b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EZip ext (mka env') (mkb env')
+
EFold1InnerD1 _ cm e1 e2 e3 ->
case s of
-- If nothing is necessary, we can execute a fold and then proceed to ignore it