From 62639875102decae2bb96b3847ae48db5d1f8fd0 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 10:09:56 +0200 Subject: Complete pattern matches --- src/AST/Count.hs | 1 + src/AST/SplitLets.hs | 1 + src/Analysis/Identity.hs | 7 +++++++ src/ForwardAD/DualNumbers.hs | 1 + 4 files changed, 10 insertions(+) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 05be524..ca4d7ab 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -136,6 +136,7 @@ occCountGeneral onehot unpush alter many = go WId EWith _ _ a b -> re a <> re1 b EAccum _ _ _ a _ b e -> re a <> re b <> re e EZero _ _ e -> re e + EDeepZero _ _ e -> re e EPlus _ _ a b -> re a <> re b EOneHot _ _ _ a b -> re a <> re b EError{} -> mempty diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 2dad17a..dcaf82f 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -65,6 +65,7 @@ splitLets' = \sub -> \case EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2) EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3) EZero x t ezi -> EZero x t (splitLets' sub ezi) + EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi) EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b) EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b) EError x t s -> EError x t s diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 2fd321d..b54946b 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -320,6 +320,13 @@ idana env expr = case expr of res <- genIds (fromSMTy t) pure (res, EZero res t e1') + EDeepZero _ t e1 -> do + -- Approximate the result of EDeepZero to be independent from the zero info + -- expression; not quite true for shape variables + (_, e1') <- idana env e1 + res <- genIds (fromSMTy t) + pure (res, EDeepZero res t e1') + EPlus _ t e1 e2 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index a6d5ec8..3ab08af 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -190,6 +190,7 @@ dfwdDN = \case EWith{} -> err_accum EAccum{} -> err_accum + EDeepZero{} -> err_monoid EZero{} -> err_monoid EPlus{} -> err_monoid EOneHot{} -> err_monoid -- cgit v1.2.3-70-g09d2