aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/AST/Count.hs22
-rw-r--r--src/Example.hs5
-rw-r--r--src/Simplify.hs15
-rw-r--r--test/Main.hs2
4 files changed, 33 insertions, 11 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index a53822d..ac8634e 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -321,13 +321,7 @@ projectSmallerSubstruc topsbig topssmall ex = case (topsbig, topssmall) of
(s@SsMaybe{}, SsFull) -> projectSmallerSubstruc s (SsMaybe SsFull) ex
(SsFull, s@SsMaybe{}) -> projectSmallerSubstruc (SsMaybe SsFull) s ex
- (SsArr s1, SsArr s2)
- | STArr n t <- typeOf ex ->
- elet ex $
- EBuild ext n (EShape ext (evar IZ)) $
- projectSmallerSubstruc s1 s2
- (EIdx ext (EVar ext (STArr n t) (IS IZ))
- (EVar ext (tTup (sreplicate n tIx)) IZ))
+ (SsArr s1, SsArr s2) -> emap (projectSmallerSubstruc s1 s2 (evar IZ)) ex
(s@SsArr{}, SsFull) -> projectSmallerSubstruc s (SsArr SsFull) ex
(SsFull, s@SsArr{}) -> projectSmallerSubstruc (SsArr SsFull) s ex
@@ -639,6 +633,20 @@ occCountX initialS topexpr k = case topexpr of
withSome (Some env1 <> Some env2) $ \env ->
k env $ \env' ->
use (mkb env') $ mka env'
+ SsArr' (SsPair' SsNone s2) ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX (SsArr s2) b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $
+ emap (EPair ext (ENil ext) (evar IZ)) (mkb env')
+ SsArr' (SsPair' s1 SsNone) ->
+ occCountX (SsArr s1) a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mkb env') $
+ emap (EPair ext (evar IZ) (ENil ext)) (mka env')
SsArr' (SsPair' s1 s2) ->
occCountX (SsArr s1) a $ \env1 mka ->
occCountX (SsArr s2) b $ \env2 mkb ->
diff --git a/src/Example.hs b/src/Example.hs
index 2c51291..e996002 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -34,9 +34,8 @@ pipeline config term
| Dict <- styKnown (d2 (typeOf term)) =
simplifyFix $ pruneExpr knownEnv $
simplifyFix $ unMonoid $
- chad' config knownEnv $
- simplifyFix $
- term
+ simplifyFix $ chad' config knownEnv $
+ simplifyFix $ term
-- :seti -XOverloadedLabels -XPartialTypeSignatures -Wno-partial-type-signatures
pipeline' :: KnownEnv env => CHADConfig -> Ex env t -> IO ()
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 1889adc..19d0c17 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -185,6 +185,21 @@ simplify'Rec = \case
ELet _ e1 (ENil _) | STNil <- typeOf e1 ->
acted $ simplify' e1
+ -- map (\_ -> x) e ~> build (shape e) (\_ -> x)
+ EMap _ e1 e2
+ | Occ Zero Zero <- occCount IZ e1
+ , STArr n _ <- typeOf e2 ->
+ acted $ simplify' $
+ EBuild ext n (EShape ext e2) $
+ subst (\_ t' -> \case IZ -> error "Unused variable was used"
+ IS i -> EVar ext t' (IS i))
+ e1
+
+ -- vertical fusion
+ EMap _ e1 (EMap _ e2 e3) ->
+ acted $ simplify' $
+ EMap ext (ELet ext e2 (weakenExpr (WCopy WSink) e1)) e3
+
-- projection down-commuting
EFst _ (ECase _ e1 e2 e3) ->
acted $ simplify' $
diff --git a/test/Main.hs b/test/Main.hs
index d586973..c2141ee 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -351,7 +351,7 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e
dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env exprS
dtermSChadS = simplifyFix dtermSChad0
dtermSChadSUS = simplifyFix $ unMonoid dtermSChadS
- dtermSChadSUSP = pruneExpr env dtermSChadSUS
+ dtermSChadSUSP = simplifyFix $ pruneExpr env dtermSChadSUS
in
withResource' (do (fun, output) <- makeFwdADArtifactCompile env exprS
when (not (null output)) $