aboutsummaryrefslogtreecommitdiff
path: root/src/AST
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-10-30 15:58:08 +0100
committerTom Smeding <tom@tomsmeding.com>2025-10-30 15:58:08 +0100
commit4c9ae47dd5bbd27b1acb6dc5d4a55657ac1f026f (patch)
treee371c4962f1beee96cc68d55accffab16e18b97a /src/AST
parent4d456e4d34b1e4fb3725051d1b8a0c376b704692 (diff)
Simplify foldD2 to not sum x0 contributions
Diffstat (limited to 'src/AST')
-rw-r--r--src/AST/Count.hs38
-rw-r--r--src/AST/Pretty.hs11
-rw-r--r--src/AST/SplitLets.hs8
-rw-r--r--src/AST/UnMonoid.hs2
4 files changed, 21 insertions, 38 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 66b4e0b..296c021 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -515,9 +515,8 @@ occCountX initialS topexpr k = case topexpr of
EConstArr _ n t x ->
case s of
SsNone -> k OccEnd (\_ -> ENil ext)
- SsArr SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext))
- SsArr SsFull -> k OccEnd (\_ -> EConstArr ext n t x)
- SsFull -> occCountX (SsArr SsFull) topexpr k
+ SsArr' SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext))
+ SsArr' SsFull -> k OccEnd (\_ -> EConstArr ext n t x)
EBuild _ n a b ->
case s of
@@ -533,7 +532,7 @@ occCountX initialS topexpr k = case topexpr of
weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $
ENil ext) $
ENil ext
- SsArr s' ->
+ SsArr' s' ->
occCountX SsFull a $ \env1 mka ->
occCountX s' b $ \env2'' mkb ->
withSome (scaleMany (Some env2'')) $ \env2' ->
@@ -543,7 +542,6 @@ occCountX initialS topexpr k = case topexpr of
EBuild ext n (mka env') $
elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))
- SsFull -> occCountX (SsArr SsFull) topexpr k
EFold1Inner _ commut a b c ->
occCountX SsFull a $ \env1''' mka ->
@@ -552,8 +550,7 @@ occCountX initialS topexpr k = case topexpr of
occEnvPop' env1' $ \env1 s1 ->
let s0 = case s of
SsNone -> Some SsNone
- SsArr s' -> Some s'
- SsFull -> Some SsFull in
+ SsArr' s' -> Some s' in
withSome (Some s1 <> Some s2 <> s0) $ \sElt ->
occCountX sElt b $ \env2 mkb ->
occCountX (SsArr sElt) c $ \env3 mkc ->
@@ -573,11 +570,10 @@ occCountX initialS topexpr k = case topexpr of
occCountX SsNone e $ \env mke ->
k env $ \env' ->
use (mke env') $ ENil ext
- SsArr s' ->
+ SsArr' s' ->
occCountX s' e $ \env mke ->
k env $ \env' ->
EUnit ext (mke env')
- SsFull -> occCountX (SsArr SsFull) topexpr k
EReplicate1Inner _ a b ->
case s of
@@ -587,13 +583,12 @@ occCountX initialS topexpr k = case topexpr of
withSome (Some env1 <> Some env2) $ \env ->
k env $ \env' ->
use (mka env') $ use (mkb env') $ ENil ext
- SsArr s' ->
+ SsArr' s' ->
occCountX SsFull a $ \env1 mka ->
occCountX (SsArr s') b $ \env2 mkb ->
withSome (Some env1 <> Some env2) $ \env ->
k env $ \env' ->
EReplicate1Inner ext (mka env') (mkb env')
- SsFull -> occCountX (SsArr SsFull) topexpr k
EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e
EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e
@@ -654,23 +649,18 @@ occCountX initialS topexpr k = case topexpr of
EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull))
(mkb env') (mkc env')
- EFold1InnerD2 _ cm ef ez eplus ebog ed ->
+ EFold1InnerD2 _ cm ef ebog ed ->
-- TODO: propagate usage of duals
occCountX SsFull ef $ \env1_2' mkef ->
occEnvPop' env1_2' $ \env1_1' _ ->
occEnvPop' env1_1' $ \env1' sB ->
- occCountX SsFull ez $ \env2' mkez ->
- occCountX SsFull eplus $ \env3_2' mkeplus ->
- occEnvPop' env3_2' $ \env3_1' _ ->
- occEnvPop' env3_1' $ \env3' _ ->
- occCountX (SsArr sB) ebog $ \env4 mkebog ->
- occCountX SsFull ed $ \env5 mked ->
- withSome (scaleMany (Some env1') <> scaleMany (Some env2') <> scaleMany (Some env3') <> Some env4 <> Some env5) $ \env ->
+ occCountX (SsArr sB) ebog $ \env2 mkebog ->
+ occCountX SsFull ed $ \env3 mked ->
+ withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
k env $ \env' ->
projectSmallerSubstruc SsFull s $
EFold1InnerD2 ext cm
(mkef (OccPush (OccPush env' () sB) () SsFull))
- (mkez env') (mkeplus (OccPush (OccPush env' () SsFull) () SsFull))
(mkebog env') (mked env')
EConst _ t x ->
@@ -692,13 +682,12 @@ occCountX initialS topexpr k = case topexpr of
withSome (Some env1 <> Some env2) $ \env ->
k env $ \env' ->
use (mka env') $ use (mkb env') $ ENil ext
- SsArr s' ->
+ SsArr' s' ->
occCountX (SsArr s') a $ \env1 mka ->
occCountX SsFull b $ \env2 mkb ->
withSome (Some env1 <> Some env2) $ \env ->
k env $ \env' ->
EIdx1 ext (mka env') (mkb env')
- SsFull -> occCountX (SsArr SsFull) topexpr k
EIdx _ a b ->
case s of
@@ -863,16 +852,15 @@ occCountX initialS topexpr k = case topexpr of
occCountX SsNone e $ \env mke ->
k env $ \env' ->
use (mke env') $ ENil ext
- SsArr SsNone ->
+ SsArr' SsNone ->
occCountX (SsArr SsNone) e $ \env mke ->
k env $ \env' ->
elet (mke env') $
EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext)
- SsArr SsFull ->
+ SsArr' SsFull ->
occCountX (SsArr SsFull) e $ \env mke ->
k env $ \env' ->
reduce (mke env')
- SsFull -> occCountX (SsArr SsFull) topexpr k
deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 67197f9..68fc629 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -250,23 +250,18 @@ ppExpr' d val expr = case expr of
return $ ppParen (d > 10) $
ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
- EFold1InnerD2 _ cm ef ez eplus ebog ed -> do
+ EFold1InnerD2 _ cm ef ebog ed -> do
let STArr _ tB = typeOf ebog
- t2 = typeOf ez
+ STArr _ t2 = typeOf ed
namef1 <- genNameIfUsedIn tB (IS IZ) ef
namef2 <- genNameIfUsedIn t2 IZ ef
ef' <- ppExpr' 0 (Const namef2 `SCons` Const namef1 `SCons` val) ef
- ez' <- ppExpr' 11 val ez
- namep1 <- genNameIfUsedIn t2 (IS IZ) eplus
- namep2 <- genNameIfUsedIn t2 IZ eplus
- eplus' <- ppExpr' 0 (Const namep2 `SCons` Const namep1 `SCons` val) eplus
ebog' <- ppExpr' 11 val ebog
ed' <- ppExpr' 11 val ed
let opname = "fold1iD2" ++ ppCommut cm
return $ ppParen (d > 10) $
ppApp (annotate AHighlight (ppString opname) <> ppX expr)
- [ppLam [ppString namef1, ppString namef2] ef', ez'
- ,ppLam [ppString namep1, ppString namep2] eplus', ebog', ed']
+ [ppLam [ppString namef1, ppString namef2] ef', ebog', ed']
EConst _ ty v
| Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
index 73c1c67..d276e44 100644
--- a/src/AST/SplitLets.hs
+++ b/src/AST/SplitLets.hs
@@ -38,10 +38,10 @@ splitLets' = \sub -> \case
EFold1InnerD1 x cm a b c ->
let STArr _ t1 = typeOf c
in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
- EFold1InnerD2 x cm a b c d e ->
- let t2 = typeOf b
- STArr _ tB = typeOf d
- in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (split2 sub t2 t2 c) (splitLets' sub d) (splitLets' sub e)
+ EFold1InnerD2 x cm a b c ->
+ let STArr _ tB = typeOf b
+ STArr _ t2 = typeOf c
+ in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (splitLets' sub c)
EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b)
EFst x e -> EFst x (splitLets' sub e)
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index e5a9708..a22b73f 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -46,7 +46,7 @@ unMonoid = \case
EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b)
EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
- EFold1InnerD2 _ cm a b c d e -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid d) (unMonoid e)
+ EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
EConst _ t x -> EConst ext t x
EIdx0 _ e -> EIdx0 ext (unMonoid e)
EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)