aboutsummaryrefslogtreecommitdiff
path: root/src/AST/Count.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-10-30 15:58:08 +0100
committerTom Smeding <tom@tomsmeding.com>2025-10-30 15:58:08 +0100
commit4c9ae47dd5bbd27b1acb6dc5d4a55657ac1f026f (patch)
treee371c4962f1beee96cc68d55accffab16e18b97a /src/AST/Count.hs
parent4d456e4d34b1e4fb3725051d1b8a0c376b704692 (diff)
Simplify foldD2 to not sum x0 contributions
Diffstat (limited to 'src/AST/Count.hs')
-rw-r--r--src/AST/Count.hs38
1 files changed, 13 insertions, 25 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 66b4e0b..296c021 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -515,9 +515,8 @@ occCountX initialS topexpr k = case topexpr of
EConstArr _ n t x ->
case s of
SsNone -> k OccEnd (\_ -> ENil ext)
- SsArr SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext))
- SsArr SsFull -> k OccEnd (\_ -> EConstArr ext n t x)
- SsFull -> occCountX (SsArr SsFull) topexpr k
+ SsArr' SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext))
+ SsArr' SsFull -> k OccEnd (\_ -> EConstArr ext n t x)
EBuild _ n a b ->
case s of
@@ -533,7 +532,7 @@ occCountX initialS topexpr k = case topexpr of
weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $
ENil ext) $
ENil ext
- SsArr s' ->
+ SsArr' s' ->
occCountX SsFull a $ \env1 mka ->
occCountX s' b $ \env2'' mkb ->
withSome (scaleMany (Some env2'')) $ \env2' ->
@@ -543,7 +542,6 @@ occCountX initialS topexpr k = case topexpr of
EBuild ext n (mka env') $
elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))
- SsFull -> occCountX (SsArr SsFull) topexpr k
EFold1Inner _ commut a b c ->
occCountX SsFull a $ \env1''' mka ->
@@ -552,8 +550,7 @@ occCountX initialS topexpr k = case topexpr of
occEnvPop' env1' $ \env1 s1 ->
let s0 = case s of
SsNone -> Some SsNone
- SsArr s' -> Some s'
- SsFull -> Some SsFull in
+ SsArr' s' -> Some s' in
withSome (Some s1 <> Some s2 <> s0) $ \sElt ->
occCountX sElt b $ \env2 mkb ->
occCountX (SsArr sElt) c $ \env3 mkc ->
@@ -573,11 +570,10 @@ occCountX initialS topexpr k = case topexpr of
occCountX SsNone e $ \env mke ->
k env $ \env' ->
use (mke env') $ ENil ext
- SsArr s' ->
+ SsArr' s' ->
occCountX s' e $ \env mke ->
k env $ \env' ->
EUnit ext (mke env')
- SsFull -> occCountX (SsArr SsFull) topexpr k
EReplicate1Inner _ a b ->
case s of
@@ -587,13 +583,12 @@ occCountX initialS topexpr k = case topexpr of
withSome (Some env1 <> Some env2) $ \env ->
k env $ \env' ->
use (mka env') $ use (mkb env') $ ENil ext
- SsArr s' ->
+ SsArr' s' ->
occCountX SsFull a $ \env1 mka ->
occCountX (SsArr s') b $ \env2 mkb ->
withSome (Some env1 <> Some env2) $ \env ->
k env $ \env' ->
EReplicate1Inner ext (mka env') (mkb env')
- SsFull -> occCountX (SsArr SsFull) topexpr k
EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e
EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e
@@ -654,23 +649,18 @@ occCountX initialS topexpr k = case topexpr of
EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull))
(mkb env') (mkc env')
- EFold1InnerD2 _ cm ef ez eplus ebog ed ->
+ EFold1InnerD2 _ cm ef ebog ed ->
-- TODO: propagate usage of duals
occCountX SsFull ef $ \env1_2' mkef ->
occEnvPop' env1_2' $ \env1_1' _ ->
occEnvPop' env1_1' $ \env1' sB ->
- occCountX SsFull ez $ \env2' mkez ->
- occCountX SsFull eplus $ \env3_2' mkeplus ->
- occEnvPop' env3_2' $ \env3_1' _ ->
- occEnvPop' env3_1' $ \env3' _ ->
- occCountX (SsArr sB) ebog $ \env4 mkebog ->
- occCountX SsFull ed $ \env5 mked ->
- withSome (scaleMany (Some env1') <> scaleMany (Some env2') <> scaleMany (Some env3') <> Some env4 <> Some env5) $ \env ->
+ occCountX (SsArr sB) ebog $ \env2 mkebog ->
+ occCountX SsFull ed $ \env3 mked ->
+ withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
k env $ \env' ->
projectSmallerSubstruc SsFull s $
EFold1InnerD2 ext cm
(mkef (OccPush (OccPush env' () sB) () SsFull))
- (mkez env') (mkeplus (OccPush (OccPush env' () SsFull) () SsFull))
(mkebog env') (mked env')
EConst _ t x ->
@@ -692,13 +682,12 @@ occCountX initialS topexpr k = case topexpr of
withSome (Some env1 <> Some env2) $ \env ->
k env $ \env' ->
use (mka env') $ use (mkb env') $ ENil ext
- SsArr s' ->
+ SsArr' s' ->
occCountX (SsArr s') a $ \env1 mka ->
occCountX SsFull b $ \env2 mkb ->
withSome (Some env1 <> Some env2) $ \env ->
k env $ \env' ->
EIdx1 ext (mka env') (mkb env')
- SsFull -> occCountX (SsArr SsFull) topexpr k
EIdx _ a b ->
case s of
@@ -863,16 +852,15 @@ occCountX initialS topexpr k = case topexpr of
occCountX SsNone e $ \env mke ->
k env $ \env' ->
use (mke env') $ ENil ext
- SsArr SsNone ->
+ SsArr' SsNone ->
occCountX (SsArr SsNone) e $ \env mke ->
k env $ \env' ->
elet (mke env') $
EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext)
- SsArr SsFull ->
+ SsArr' SsFull ->
occCountX (SsArr SsFull) e $ \env mke ->
k env $ \env' ->
reduce (mke env')
- SsFull -> occCountX (SsArr SsFull) topexpr k
deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r