summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/AST/Count.hs1
-rw-r--r--src/AST/SplitLets.hs1
-rw-r--r--src/Analysis/Identity.hs7
-rw-r--r--src/ForwardAD/DualNumbers.hs1
4 files changed, 10 insertions, 0 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 05be524..ca4d7ab 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -136,6 +136,7 @@ occCountGeneral onehot unpush alter many = go WId
EWith _ _ a b -> re a <> re1 b
EAccum _ _ _ a _ b e -> re a <> re b <> re e
EZero _ _ e -> re e
+ EDeepZero _ _ e -> re e
EPlus _ _ a b -> re a <> re b
EOneHot _ _ _ a b -> re a <> re b
EError{} -> mempty
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
index 2dad17a..dcaf82f 100644
--- a/src/AST/SplitLets.hs
+++ b/src/AST/SplitLets.hs
@@ -65,6 +65,7 @@ splitLets' = \sub -> \case
EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2)
EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3)
EZero x t ezi -> EZero x t (splitLets' sub ezi)
+ EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi)
EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b)
EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b)
EError x t s -> EError x t s
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index 2fd321d..b54946b 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -320,6 +320,13 @@ idana env expr = case expr of
res <- genIds (fromSMTy t)
pure (res, EZero res t e1')
+ EDeepZero _ t e1 -> do
+ -- Approximate the result of EDeepZero to be independent from the zero info
+ -- expression; not quite true for shape variables
+ (_, e1') <- idana env e1
+ res <- genIds (fromSMTy t)
+ pure (res, EDeepZero res t e1')
+
EPlus _ t e1 e2 -> do
(_, e1') <- idana env e1
(_, e2') <- idana env e2
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index a6d5ec8..3ab08af 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -190,6 +190,7 @@ dfwdDN = \case
EWith{} -> err_accum
EAccum{} -> err_accum
+ EDeepZero{} -> err_monoid
EZero{} -> err_monoid
EPlus{} -> err_monoid
EOneHot{} -> err_monoid