diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-01-27 15:08:02 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-01-27 15:08:02 +0100 |
commit | 88fae8c2914b805a733b71de58ab672124e6069c (patch) | |
tree | c155fb1a83ace92aab376202ebc8b4b8a919da7c | |
parent | 0bdc36d221703e5a2347d3d136d676a86bdb1b6a (diff) |
Add ext field to remaining AST constructors
-rw-r--r-- | src/AST.hs | 36 | ||||
-rw-r--r-- | src/AST/Count.hs | 10 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 12 | ||||
-rw-r--r-- | src/AST/UnMonoid.hs | 26 | ||||
-rw-r--r-- | src/Analysis/Identity.hs | 2 | ||||
-rw-r--r-- | src/CHAD.hs | 26 | ||||
-rw-r--r-- | src/CHAD/Accum.hs | 2 | ||||
-rw-r--r-- | src/CHAD/Top.hs | 2 | ||||
-rw-r--r-- | src/Compile.hs | 6 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 2 | ||||
-rw-r--r-- | src/Interpreter.hs | 12 | ||||
-rw-r--r-- | src/Language/AST.hs | 2 | ||||
-rw-r--r-- | src/Simplify.hs | 68 |
13 files changed, 103 insertions, 103 deletions
@@ -111,17 +111,17 @@ data Expr x env t where -> Expr x env t -- accumulation effect - EWith :: Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) - EAccum :: SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil + EWith :: x (TPair a t) -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) + EAccum :: x TNil -> SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil -- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil -- monoidal operations (to be desugared to regular operations after simplification) - EZero :: STy t -> Expr x env (D2 t) - EPlus :: STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t) - EOneHot :: STy t -> SNat i -> Expr x env (AcIdx (D2 t) i) -> Expr x env (AcVal (D2 t) i) -> Expr x env (D2 t) + EZero :: x (D2 t) -> STy t -> Expr x env (D2 t) + EPlus :: x (D2 t) -> STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t) + EOneHot :: x (D2 t) -> STy t -> SNat i -> Expr x env (AcIdx (D2 t) i) -> Expr x env (AcVal (D2 t) i) -> Expr x env (D2 t) -- partiality - EError :: STy a -> String -> Expr x env a + EError :: x a -> STy a -> String -> Expr x env a deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) type Ex = Expr (Const ()) @@ -247,14 +247,14 @@ typeOf = \case ECustom _ _ _ _ e _ _ _ _ -> typeOf e - EWith e1 e2 -> STPair (typeOf e2) (typeOf e1) - EAccum _ _ _ _ -> STNil + EWith _ e1 e2 -> STPair (typeOf e2) (typeOf e1) + EAccum _ _ _ _ _ -> STNil - EZero t -> d2 t - EPlus t _ _ -> d2 t - EOneHot t _ _ _ -> d2 t + EZero _ t -> d2 t + EPlus _ t _ _ -> d2 t + EOneHot _ t _ _ _ -> d2 t - EError t _ -> t + EError _ t _ -> t -- unSNat :: SNat n -> Nat -- unSNat SZ = Z @@ -322,12 +322,12 @@ subst' f w = \case EShape x e -> EShape x (subst' f w e) EOp x op e -> EOp x op (subst' f w e) ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) - EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) - EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3) - EZero t -> EZero t - EPlus t a b -> EPlus t (subst' f w a) (subst' f w b) - EOneHot t i a b -> EOneHot t i (subst' f w a) (subst' f w b) - EError t s -> EError t s + EWith x e1 e2 -> EWith x (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) + EAccum x i e1 e2 e3 -> EAccum x i (subst' f w e1) (subst' f w e2) (subst' f w e3) + EZero x t -> EZero x t + EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) + EOneHot x t i a b -> EOneHot x t i (subst' f w a) (subst' f w b) + EError x t s -> EError x t s where sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 22a4da6..b7079ff 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -128,11 +128,11 @@ occCountGeneral onehot unpush alter many = go WId EShape _ e -> re e EOp _ _ e -> re e ECustom _ _ _ _ _ _ _ a b -> re a <> re b - EWith a b -> re a <> re1 b - EAccum _ a b e -> re a <> re b <> re e - EZero _ -> mempty - EPlus _ a b -> re a <> re b - EOneHot _ _ a b -> re a <> re b + EWith _ a b -> re a <> re1 b + EAccum _ _ a b e -> re a <> re b <> re e + EZero _ _ -> mempty + EPlus _ _ a b -> re a <> re b + EOneHot _ _ _ a b -> re a <> re b EError{} -> mempty where re :: Monoid (r env') => Expr x env' t'' -> r env' diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 663e9b0..24bacdb 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -203,7 +203,7 @@ ppExpr' d val = \case . e1' . showString " " . e2' - EWith e1 e2 -> do + EWith _ e1 e2 -> do e1' <- ppExpr' 11 val e1 name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2 e2' <- ppExpr' 0 (Const name `SCons` val) e2 @@ -211,27 +211,27 @@ ppExpr' d val = \case showString "with " . e1' . showString (" (\\" ++ name ++ " -> ") . e2' . showString ")" - EAccum i e1 e2 e3 -> do + EAccum _ i e1 e2 e3 -> do e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 e3' <- ppExpr' 11 val e3 return $ showParen (d > 10) $ showString ("accum " ++ show (fromSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3' - EZero t -> return $ showString ("(zero :: " ++ ppTy 0 t ++ ")") + EZero _ t -> return $ showString ("(zero :: " ++ ppTy 0 t ++ ")") - EPlus _ a b -> do + EPlus _ _ a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b return $ showParen (d > 10) $ showString "plus " . a' . showString " " . b' - EOneHot _ i a b -> do + EOneHot _ _ i a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b return $ showParen (d > 10) $ showString ("onehot " ++ show (fromSNat i) ++ " ") . a' . showString " " . b' - EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s) + EError _ _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s) ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS ppExprLet d val etop = do diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 8da1e32..c87bed1 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -11,9 +11,9 @@ import Data unMonoid :: Ex env t -> Ex env t unMonoid = \case - EZero t -> zero t - EPlus t a b -> plus t a b - EOneHot t i a b -> onehot t i a b + EZero _ t -> zero t + EPlus _ t a b -> plus t a b + EOneHot _ t i a b -> onehot t i a b EVar _ t i -> EVar ext t i ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) @@ -42,9 +42,9 @@ unMonoid = \case EShape _ e -> EShape ext (unMonoid e) EOp _ op e -> EOp ext op (unMonoid e) ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) - EWith a b -> EWith (unMonoid a) (unMonoid b) - EAccum n a b e -> EAccum n (unMonoid a) (unMonoid b) (unMonoid e) - EError t s -> EError t s + EWith _ a b -> EWith ext (unMonoid a) (unMonoid b) + EAccum _ n a b e -> EAccum ext n (unMonoid a) (unMonoid b) (unMonoid e) + EError _ t s -> EError ext t s zero :: STy t -> Ex env (D2 t) zero STNil = ENil ext @@ -52,7 +52,7 @@ zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2)) zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2)) zero (STMaybe t) = ENothing ext (d2 t) zero (STArr SZ t) = EUnit ext (zero t) -zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (EError (d2 t) "empty") +zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (EError ext (d2 t) "empty") zero (STScal t) = case t of STI32 -> ENil ext STI64 -> ENil ext @@ -76,9 +76,9 @@ plus (STEither t1 t2) a b = ECase ext (EVar ext t (IS IZ)) (ECase ext (EVar ext t (IS IZ)) (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ))) - (EError t "plus l+r")) + (EError ext t "plus l+r")) (ECase ext (EVar ext t (IS IZ)) - (EError t "plus r+l") + (EError ext t "plus r+l") (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ)))) plus (STMaybe t) a b = plusSparse (d2 t) a b $ @@ -130,9 +130,9 @@ onehot t (SS dep) idx val = case t of (ECase ext (weakenExpr WSink val) (EPair ext (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ)) (zero t2)) - (EError (STPair (d2 t1) (d2 t2)) "onehot pair l/r")) + (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair l/r")) (ECase ext (weakenExpr WSink val) - (EError (STPair (d2 t1) (d2 t2)) "onehot pair r/l") + (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair r/l") (EPair ext (zero t1) (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ)))) @@ -146,9 +146,9 @@ onehot t (SS dep) idx val = case t of ECase ext idx (ECase ext (weakenExpr WSink val) (EInl ext (d2 t2) (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ))) - (EError (STEither (d2 t1) (d2 t2)) "onehot either l/r")) + (EError ext (STEither (d2 t1) (d2 t2)) "onehot either l/r")) (ECase ext (weakenExpr WSink val) - (EError (STEither (d2 t1) (d2 t2)) "onehot either r/l") + (EError ext (STEither (d2 t1) (d2 t2)) "onehot either r/l") (EInr ext (d2 t1) (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ)))) STMaybe t1 -> EJust ext (onehot t1 dep idx val) diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 5c398d2..e481a77 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -5,8 +5,6 @@ module Analysis.Identity ( identityAnalysis, ) where -import Data.Functor.Const - import AST import Data import Util.IdGen diff --git a/src/CHAD.hs b/src/CHAD.hs index aa5bd4c..6118e48 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -319,10 +319,10 @@ indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl ------------------------------------ MONOIDS ----------------------------------- zero :: STy t -> Ex env (D2 t) -zero = EZero +zero = EZero ext plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t) -plus = EPlus +plus = EPlus ext zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) zeroTup SNil = ENil ext @@ -594,7 +594,7 @@ drev des = \case SETop (EVar ext (d1 t) (conv1Idx i)) (subenvNone (select SMerge des)) - (EAccum SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI))) + (EAccum ext SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI))) Idx2Me tupI -> Ret BTop @@ -689,7 +689,7 @@ drev des = \case (zeroTup (subList (select SMerge des) sub)) (ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ) (weakenExpr (WCopy (wSinks' @[_,_])) e2) - (EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")) + (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")) (EVar ext (STMaybe (STEither (d2 (typeOf e)) (d2 t2))) IZ)) EInr _ t1 e @@ -701,7 +701,7 @@ drev des = \case (EMaybe ext (zeroTup (subList (select SMerge des) sub)) (ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ) - (EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") + (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") (weakenExpr (WCopy (wSinks' @[_,_])) e2)) (EVar ext (STMaybe (STEither (d2 t1) (d2 (typeOf e)))) IZ)) @@ -820,10 +820,10 @@ drev des = \case (ELet ext (weakenExpr (WCopy (WCopy WClosed)) du) $ weakenExpr (WCopy (WSink .> WSink)) b2) - EError t s -> + EError _ t s -> Ret BTop SETop - (EError (d1 t) s) + (EError ext (d1 t) s) (subenvNone (select SMerge des)) (ENil ext) @@ -922,8 +922,8 @@ drev des = \case subtape (EReplicate1Inner ext en1 e1) sub - (ELet ext (EFold1Inner ext (EPlus eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) - (EZero eltty) + (ELet ext (EFold1Inner ext (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) + (EZero ext eltty) (EVar ext (STArr (SS ndim) (d2 eltty)) IZ)) $ weakenExpr (WCopy WSink) e2) @@ -970,7 +970,7 @@ drev des = \case (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) sub - (ELet ext (EOneHot (STArr n eltty) n + (ELet ext (EOneHot ext (STArr n eltty) n (arrIdxToAcIdx (d2 eltty) n $ EVar ext tIxN (IS IZ)) (case n of SZ -> EUnit ext (EVar ext (d2 eltty) IZ) SS{} | Refl <- lemAcValArrN (d2 eltty) n -> @@ -1042,7 +1042,7 @@ drev des = \case (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ)) (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ)))))) (EIdx ext (EVar ext (d2 at') (IS (IS IZ))) (EFst ext (EVar ext tIxN (IS IZ)))) - (EZero t)) $ + (EZero ext t)) $ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) @@ -1073,10 +1073,10 @@ drevScoped des argty argsto expr SMerge -> case sub of SEYes sub' -> RetScoped e0 subtape e1 sub' e2 - SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero argty)) + SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero ext argty)) SAccum -> RetScoped e0 subtape e1 sub $ - EWith (EZero argty) $ + EWith ext (EZero ext argty) $ weakenExpr (autoWeak (#d (auto1 @(D2 t)) &. #body (subList (bindingsBinds e0) subtape) &. #ac (auto1 @(TAccum (D2 a))) diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs index e26f781..659c45f 100644 --- a/src/CHAD/Accum.hs +++ b/src/CHAD/Accum.hs @@ -21,7 +21,7 @@ makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex e makeAccumulators SNil e = e makeAccumulators (t `SCons` envpro) e = makeAccumulators envpro $ - EWith (EZero t) e + EWith ext (EZero ext t) e uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) uninvertTup SNil _ e = EPair ext e (ENil ext) diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index 12594f2..d058132 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -69,7 +69,7 @@ reassembleD2E (des `DPush` (_, SMerge)) e = EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ)) (EFst ext (ESnd ext (EVar ext (typeOf e) IZ))))) (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ))) -reassembleD2E (des `DPush` (t, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero t) +reassembleD2E (des `DPush` (t, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero ext t) chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) chad config env (term :: Ex env t) diff --git a/src/Compile.hs b/src/Compile.hs index 0db0d0f..05d51c1 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -313,11 +313,11 @@ compile' env = \case ECustom _ t1 t2 t3 a b c e1 e2 -> error "TODO" -- ECustom ext t1 t2 t3 (compile' a) (compile' b) (compile' c) (compile' e1) (compile' e2) - EWith a b -> error "TODO" -- EWith (compile' a) (compile' b) + EWith _ a b -> error "TODO" -- EWith (compile' a) (compile' b) - EAccum n a b e -> error "TODO" -- EAccum n (compile' a) (compile' b) (compile' e) + EAccum _ n a b e -> error "TODO" -- EAccum n (compile' a) (compile' b) (compile' e) - EError t s -> do + EError _ t s -> do name <- emitStruct t -- using 'show' here is wrong, but it's good enough for me. emit $ SVerbatim $ "fprintf(stderr, \"ERROR: %s\\n\", " ++ show s ++ "); exit(1);" diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index e8b140e..aa35a5b 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -178,7 +178,7 @@ dfwdDN = \case ELet ext (dfwdDN e1) $ ELet ext (weakenExpr WSink (dfwdDN e2)) $ weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr) - EError t s -> EError (dn t) s + EError _ t s -> EError ext (dn t) s EWith{} -> err_accum EAccum{} -> err_accum diff --git a/src/Interpreter.hs b/src/Interpreter.hs index bb4952c..deb829b 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -134,27 +134,27 @@ interpret'Rec env = \case e1' <- interpret' env e1 e2' <- interpret' env e2 interpret' (Value e2' `SCons` Value e1' `SCons` SNil) pr - EWith e1 e2 -> do + EWith _ e1 e2 -> do initval <- interpret' env e1 withAccum (typeOf e1) (typeOf e2) initval $ \accum -> interpret' (Value accum `SCons` env) e2 - EAccum i e1 e2 e3 -> do + EAccum _ i e1 e2 e3 -> do let STAccum t = typeOf e3 idx <- interpret' env e1 val <- interpret' env e2 accum <- interpret' env e3 accumAddSparse t i accum idx val - EZero t -> do + EZero _ t -> do return $ zeroD2 t - EPlus t a b -> do + EPlus _ t a b -> do a' <- interpret' env a b' <- interpret' env b return $ addD2s t a' b' - EOneHot t i a b -> do + EOneHot _ t i a b -> do a' <- interpret' env a b' <- interpret' env b return $ onehotD2 i t a' b' - EError _ s -> error $ "Interpreter: Program threw error: " ++ s + EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s interpretOp :: SOp a t -> Rep a -> Rep t interpretOp op arg = case op of diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 8c91d59..022e797 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -205,7 +205,7 @@ fromNamedExpr val = \case (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c) (go e1) (go e2) - NEError t s -> EError t s + NEError t s -> EError ext t s NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args where diff --git a/src/Simplify.hs b/src/Simplify.hs index 6303716..785e2bd 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -93,33 +93,33 @@ simplify' = \case -- TODO: constant folding for operations -- TODO: properly concatenate accum/onehot - EAccum SZ _ (EOneHot _ i idx val) acc -> + EAccum _ SZ _ (EOneHot _ _ i idx val) acc -> acted $ simplify' $ - EAccum i idx val acc - EAccum _ _ (EZero _) _ -> (Any True, ENil ext) - EPlus _ (EZero _) e -> acted $ simplify' e - EPlus _ e (EZero _) -> acted $ simplify' e - EOneHot _ SZ _ e -> acted $ simplify' e + EAccum ext i idx val acc + EAccum _ _ _ (EZero _ _) _ -> (Any True, ENil ext) + EPlus _ _ (EZero _ _) e -> acted $ simplify' e + EPlus _ _ e (EZero _ _) -> acted $ simplify' e + EOneHot _ _ SZ _ e -> acted $ simplify' e -- equations for plus - EPlus STNil _ _ -> (Any True, ENil ext) + EPlus _ STNil _ _ -> (Any True, ENil ext) - EPlus (STPair t1 t2) (EJust _ (EPair _ a1 b1)) (EJust _ (EPair _ a2 b2)) -> - acted $ simplify' $ EJust ext (EPair ext (EPlus t1 a1 a2) (EPlus t2 b1 b2)) - EPlus STPair{} ENothing{} e -> acted $ simplify' e - EPlus STPair{} e ENothing{} -> acted $ simplify' e + EPlus _ (STPair t1 t2) (EJust _ (EPair _ a1 b1)) (EJust _ (EPair _ a2 b2)) -> + acted $ simplify' $ EJust ext (EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2)) + EPlus _ STPair{} ENothing{} e -> acted $ simplify' e + EPlus _ STPair{} e ENothing{} -> acted $ simplify' e - EPlus (STEither t1 _) (EJust _ (EInl _ dt2 a1)) (EJust _ (EInl _ _ a2)) -> - acted $ simplify' $ EJust ext (EInl ext dt2 (EPlus t1 a1 a2)) - EPlus (STEither _ t2) (EJust _ (EInr _ dt1 b1)) (EJust _ (EInr _ _ b2)) -> - acted $ simplify' $ EJust ext (EInr ext dt1 (EPlus t2 b1 b2)) - EPlus STEither{} ENothing{} e -> acted $ simplify' e - EPlus STEither{} e ENothing{} -> acted $ simplify' e + EPlus _ (STEither t1 _) (EJust _ (EInl _ dt2 a1)) (EJust _ (EInl _ _ a2)) -> + acted $ simplify' $ EJust ext (EInl ext dt2 (EPlus ext t1 a1 a2)) + EPlus _ (STEither _ t2) (EJust _ (EInr _ dt1 b1)) (EJust _ (EInr _ _ b2)) -> + acted $ simplify' $ EJust ext (EInr ext dt1 (EPlus ext t2 b1 b2)) + EPlus _ STEither{} ENothing{} e -> acted $ simplify' e + EPlus _ STEither{} e ENothing{} -> acted $ simplify' e - EPlus (STMaybe t) (EJust _ e1) (EJust _ e2) -> - acted $ simplify' $ EJust ext (EPlus t e1 e2) - EPlus STMaybe{} ENothing{} e -> acted $ simplify' e - EPlus STMaybe{} e ENothing{} -> acted $ simplify' e + EPlus _ (STMaybe t) (EJust _ e1) (EJust _ e2) -> + acted $ simplify' $ EJust ext (EPlus ext t e1 e2) + EPlus _ STMaybe{} ENothing{} e -> acted $ simplify' e + EPlus _ STMaybe{} e ENothing{} -> acted $ simplify' e -- fallback recursion EVar _ t i -> pure $ EVar ext t i @@ -154,12 +154,12 @@ simplify' = \case <*> (let ?accumInScope = False in simplify' b) <*> (let ?accumInScope = False in simplify' c) <*> simplify' e1 <*> simplify' e2 - EWith e1 e2 -> EWith <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2) - EAccum i e1 e2 e3 -> EAccum i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3 - EZero t -> pure $ EZero t - EPlus t a b -> EPlus t <$> simplify' a <*> simplify' b - EOneHot t i a b -> EOneHot t i <$> simplify' a <*> simplify' b - EError t s -> pure $ EError t s + EWith _ e1 e2 -> EWith ext <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2) + EAccum _ i e1 e2 e3 -> EAccum ext i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3 + EZero _ t -> pure $ EZero ext t + EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b + EOneHot _ t i a b -> EOneHot ext t i <$> simplify' a <*> simplify' b + EError _ t s -> pure $ EError ext t s acted :: (Any, a) -> (Any, a) acted (_, x) = (Any True, x) @@ -169,6 +169,8 @@ cheapExpr = \case EVar{} -> True ENil{} -> True EConst{} -> True + EFst _ e -> cheapExpr e + ESnd _ e -> cheapExpr e _ -> False -- | This can be made more precise by tracking (and not counting) adds on @@ -202,12 +204,12 @@ hasAdds = \case EIdx _ a b -> hasAdds a || hasAdds b EShape _ e -> hasAdds e EOp _ _ e -> hasAdds e - EWith a b -> hasAdds a || hasAdds b - EAccum _ _ _ _ -> True - EZero _ -> False - EPlus _ a b -> hasAdds a || hasAdds b - EOneHot _ _ a b -> hasAdds a || hasAdds b - EError _ _ -> False + EWith _ a b -> hasAdds a || hasAdds b + EAccum _ _ _ _ _ -> True + EZero _ _ -> False + EPlus _ _ a b -> hasAdds a || hasAdds b + EOneHot _ _ _ a b -> hasAdds a || hasAdds b + EError _ _ _ -> False checkAccumInScope :: SList STy env -> Bool checkAccumInScope = \case SNil -> False |