aboutsummaryrefslogtreecommitdiff
path: root/src/AST
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST')
-rw-r--r--src/AST/Count.hs48
-rw-r--r--src/AST/Pretty.hs16
-rw-r--r--src/AST/SplitLets.hs6
3 files changed, 39 insertions, 31 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index bc02417..ac8634e 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -321,13 +321,7 @@ projectSmallerSubstruc topsbig topssmall ex = case (topsbig, topssmall) of
(s@SsMaybe{}, SsFull) -> projectSmallerSubstruc s (SsMaybe SsFull) ex
(SsFull, s@SsMaybe{}) -> projectSmallerSubstruc (SsMaybe SsFull) s ex
- (SsArr s1, SsArr s2)
- | STArr n t <- typeOf ex ->
- elet ex $
- EBuild ext n (EShape ext (evar IZ)) $
- projectSmallerSubstruc s1 s2
- (EIdx ext (EVar ext (STArr n t) (IS IZ))
- (EVar ext (tTup (sreplicate n tIx)) IZ))
+ (SsArr s1, SsArr s2) -> emap (projectSmallerSubstruc s1 s2 (evar IZ)) ex
(s@SsArr{}, SsFull) -> projectSmallerSubstruc s (SsArr SsFull) ex
(SsFull, s@SsArr{}) -> projectSmallerSubstruc (SsArr SsFull) s ex
@@ -560,22 +554,23 @@ occCountX initialS topexpr k = case topexpr of
EMap ext (mka (OccPush env' () s1)) (mkb env')
EFold1Inner _ commut a b c ->
- occCountX SsFull a $ \env1''' mka ->
- withSome (scaleMany (Some env1''')) $ \env1'' ->
- occEnvPop' env1'' $ \env1' s2 ->
- occEnvPop' env1' $ \env1 s1 ->
- let s0 = case s of
+ occCountX SsFull a $ \env1'' mka ->
+ occEnvPop' env1'' $ \env1' s1' ->
+ let s1 = case s1' of
+ SsNone -> Some SsNone
+ SsPair' s1'a s1'b -> Some s1'a <> Some s1'b
+ s0 = case s of
SsNone -> Some SsNone
SsArr' s' -> Some s' in
- withSome (Some s1 <> Some s2 <> s0) $ \sElt ->
+ withSome (s1 <> s0) $ \sElt ->
occCountX sElt b $ \env2 mkb ->
- occCountX (SsArr sElt) c $ \env3 mkc ->
- withSome (Some env1 <> Some env2 <> Some env3) $ \env ->
+ occCountX (SsArr sElt) c $ \env3 mkc ->
+ withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
k env $ \env' ->
projectSmallerSubstruc (SsArr sElt) s $
EFold1Inner ext commut
(projectSmallerSubstruc SsFull sElt $
- mka (OccPush (OccPush env' () sElt) () sElt))
+ mka (OccPush env' () (SsPair sElt sElt)))
(mkb env') (mkc env')
ESum1Inner _ e -> handleReduction (ESum1Inner ext) e
@@ -638,6 +633,20 @@ occCountX initialS topexpr k = case topexpr of
withSome (Some env1 <> Some env2) $ \env ->
k env $ \env' ->
use (mkb env') $ mka env'
+ SsArr' (SsPair' SsNone s2) ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX (SsArr s2) b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $
+ emap (EPair ext (ENil ext) (evar IZ)) (mkb env')
+ SsArr' (SsPair' s1 SsNone) ->
+ occCountX (SsArr s1) a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mkb env') $
+ emap (EPair ext (evar IZ) (ENil ext)) (mka env')
SsArr' (SsPair' s1 s2) ->
occCountX (SsArr s1) a $ \env1 mka ->
occCountX (SsArr s2) b $ \env2 mkb ->
@@ -665,7 +674,7 @@ occCountX initialS topexpr k = case topexpr of
elet (mapExt (\_ -> ext) e3) $
EPair ext
(EShape ext (evar IZ))
- (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy (WCopy WSink)) e1)))
+ (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy WSink) e1)))
(mapExt (\_ -> ext) (weakenExpr WSink e2))
(evar IZ))
in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex ->
@@ -675,15 +684,14 @@ occCountX initialS topexpr k = case topexpr of
-- If at least some of the additional stores are required, we need to keep this a mapAccum
SsPair' _ (SsArr' sB) ->
-- TODO: propagate usage of primals
- occCountX (SsPair SsFull sB) e1 $ \env1_2' mka ->
- occEnvPop' env1_2' $ \env1_1' _ ->
+ occCountX (SsPair SsFull sB) e1 $ \env1_1' mka ->
occEnvPop' env1_1' $ \env1' _ ->
occCountX SsFull e2 $ \env2 mkb ->
occCountX SsFull e3 $ \env3 mkc ->
withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
k env $ \env' ->
projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $
- EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull))
+ EFold1InnerD1 ext cm (mka (OccPush env' () SsFull))
(mkb env') (mkc env')
EFold1InnerD2 _ cm ef ebog ed ->
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 2c51b85..bbcfd9e 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -206,21 +206,20 @@ ppExpr' d val expr = case expr of
EMap _ a b -> do
let STArr _ t1 = typeOf b
- name <- genNameIfUsedIn' "i" t1 IZ a
+ name <- genNameIfUsedIn t1 IZ a
a' <- ppExpr' 0 (Const name `SCons` val) a
b' <- ppExpr' 11 val b
return $ ppParen (d > 0) $
ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b']
EFold1Inner _ cm a b c -> do
- name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
- name2 <- genNameIfUsedIn (typeOf a) IZ a
- a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
+ name <- genNameIfUsedIn (STPair (typeOf a) (typeOf a)) IZ a
+ a' <- ppExpr' 0 (Const name `SCons` val) a
b' <- ppExpr' 11 val b
c' <- ppExpr' 11 val c
let opname = "fold1i" ++ ppCommut cm
return $ ppParen (d > 10) $
- ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
+ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c']
ESum1Inner _ e -> do
e' <- ppExpr' 11 val e
@@ -254,14 +253,13 @@ ppExpr' d val expr = case expr of
return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2']
EFold1InnerD1 _ cm a b c -> do
- name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a
- name2 <- genNameIfUsedIn (typeOf b) IZ a
- a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
+ name <- genNameIfUsedIn (STPair (typeOf b) (typeOf b)) IZ a
+ a' <- ppExpr' 0 (Const name `SCons` val) a
b' <- ppExpr' 11 val b
c' <- ppExpr' 11 val c
let opname = "fold1iD1" ++ ppCommut cm
return $ ppParen (d > 10) $
- ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
+ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c']
EFold1InnerD2 _ cm ef ebog ed -> do
let STArr _ tB = typeOf ebog
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
index d276e44..267dd87 100644
--- a/src/AST/SplitLets.hs
+++ b/src/AST/SplitLets.hs
@@ -34,10 +34,10 @@ splitLets' = \sub -> \case
in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c)
EFold1Inner x cm a b c ->
let STArr _ t1 = typeOf c
- in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
+ in EFold1Inner x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c)
EFold1InnerD1 x cm a b c ->
let STArr _ t1 = typeOf c
- in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
+ in EFold1InnerD1 x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c)
EFold1InnerD2 x cm a b c ->
let STArr _ tB = typeOf b
STArr _ t2 = typeOf c
@@ -56,12 +56,14 @@ splitLets' = \sub -> \case
ELInr x t e -> ELInr x t (splitLets' sub e)
EConstArr x n t a -> EConstArr x n t a
EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b)
+ EMap x a b -> EMap x (splitLets' (sinkF sub) a) (splitLets' sub b)
ESum1Inner x e -> ESum1Inner x (splitLets' sub e)
EUnit x e -> EUnit x (splitLets' sub e)
EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b)
EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e)
EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e)
EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b)
+ EZip x a b -> EZip x (splitLets' sub a) (splitLets' sub b)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (splitLets' sub e)
EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b)