From d1b2e2c3a3cdaf49ff5e4bae6fe9b0612c3779c2 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 00:00:11 +0200 Subject: Tests pass, should check if output is sensible --- src/Analysis/Identity.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src/Analysis') diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 4501c32..2fd321d 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -307,11 +307,11 @@ idana env expr = case expr of let res = VIPair v2 x2 pure (res, EWith res t e1' e2') - EAccum _ t prj e1 e2 e3 -> do + EAccum _ t prj e1 sp e2 e3 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 (_, e3') <- idana env e3 - pure (VINil, EAccum VINil t prj e1' e2' e3') + pure (VINil, EAccum VINil t prj e1' sp e2' e3') EZero _ t e1 -> do -- Approximate the result of EZero to be independent from the zero info -- cgit v1.2.3-70-g09d2 From 62639875102decae2bb96b3847ae48db5d1f8fd0 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 10:09:56 +0200 Subject: Complete pattern matches --- src/AST/Count.hs | 1 + src/AST/SplitLets.hs | 1 + src/Analysis/Identity.hs | 7 +++++++ src/ForwardAD/DualNumbers.hs | 1 + 4 files changed, 10 insertions(+) (limited to 'src/Analysis') diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 05be524..ca4d7ab 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -136,6 +136,7 @@ occCountGeneral onehot unpush alter many = go WId EWith _ _ a b -> re a <> re1 b EAccum _ _ _ a _ b e -> re a <> re b <> re e EZero _ _ e -> re e + EDeepZero _ _ e -> re e EPlus _ _ a b -> re a <> re b EOneHot _ _ _ a b -> re a <> re b EError{} -> mempty diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 2dad17a..dcaf82f 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -65,6 +65,7 @@ splitLets' = \sub -> \case EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2) EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3) EZero x t ezi -> EZero x t (splitLets' sub ezi) + EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi) EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b) EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b) EError x t s -> EError x t s diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 2fd321d..b54946b 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -320,6 +320,13 @@ idana env expr = case expr of res <- genIds (fromSMTy t) pure (res, EZero res t e1') + EDeepZero _ t e1 -> do + -- Approximate the result of EDeepZero to be independent from the zero info + -- expression; not quite true for shape variables + (_, e1') <- idana env e1 + res <- genIds (fromSMTy t) + pure (res, EDeepZero res t e1') + EPlus _ t e1 e2 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index a6d5ec8..3ab08af 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -190,6 +190,7 @@ dfwdDN = \case EWith{} -> err_accum EAccum{} -> err_accum + EDeepZero{} -> err_monoid EZero{} -> err_monoid EPlus{} -> err_monoid EOneHot{} -> err_monoid -- cgit v1.2.3-70-g09d2