summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-01-27 15:08:02 +0100
committerTom Smeding <t.j.smeding@uu.nl>2025-01-27 15:08:02 +0100
commit88fae8c2914b805a733b71de58ab672124e6069c (patch)
treec155fb1a83ace92aab376202ebc8b4b8a919da7c
parent0bdc36d221703e5a2347d3d136d676a86bdb1b6a (diff)
Add ext field to remaining AST constructors
-rw-r--r--src/AST.hs36
-rw-r--r--src/AST/Count.hs10
-rw-r--r--src/AST/Pretty.hs12
-rw-r--r--src/AST/UnMonoid.hs26
-rw-r--r--src/Analysis/Identity.hs2
-rw-r--r--src/CHAD.hs26
-rw-r--r--src/CHAD/Accum.hs2
-rw-r--r--src/CHAD/Top.hs2
-rw-r--r--src/Compile.hs6
-rw-r--r--src/ForwardAD/DualNumbers.hs2
-rw-r--r--src/Interpreter.hs12
-rw-r--r--src/Language/AST.hs2
-rw-r--r--src/Simplify.hs68
13 files changed, 103 insertions, 103 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 333f306..bcbb19a 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -111,17 +111,17 @@ data Expr x env t where
-> Expr x env t
-- accumulation effect
- EWith :: Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
- EAccum :: SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil
+ EWith :: x (TPair a t) -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
+ EAccum :: x TNil -> SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil
-- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil
-- monoidal operations (to be desugared to regular operations after simplification)
- EZero :: STy t -> Expr x env (D2 t)
- EPlus :: STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t)
- EOneHot :: STy t -> SNat i -> Expr x env (AcIdx (D2 t) i) -> Expr x env (AcVal (D2 t) i) -> Expr x env (D2 t)
+ EZero :: x (D2 t) -> STy t -> Expr x env (D2 t)
+ EPlus :: x (D2 t) -> STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t)
+ EOneHot :: x (D2 t) -> STy t -> SNat i -> Expr x env (AcIdx (D2 t) i) -> Expr x env (AcVal (D2 t) i) -> Expr x env (D2 t)
-- partiality
- EError :: STy a -> String -> Expr x env a
+ EError :: x a -> STy a -> String -> Expr x env a
deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)
type Ex = Expr (Const ())
@@ -247,14 +247,14 @@ typeOf = \case
ECustom _ _ _ _ e _ _ _ _ -> typeOf e
- EWith e1 e2 -> STPair (typeOf e2) (typeOf e1)
- EAccum _ _ _ _ -> STNil
+ EWith _ e1 e2 -> STPair (typeOf e2) (typeOf e1)
+ EAccum _ _ _ _ _ -> STNil
- EZero t -> d2 t
- EPlus t _ _ -> d2 t
- EOneHot t _ _ _ -> d2 t
+ EZero _ t -> d2 t
+ EPlus _ t _ _ -> d2 t
+ EOneHot _ t _ _ _ -> d2 t
- EError t _ -> t
+ EError _ t _ -> t
-- unSNat :: SNat n -> Nat
-- unSNat SZ = Z
@@ -322,12 +322,12 @@ subst' f w = \case
EShape x e -> EShape x (subst' f w e)
EOp x op e -> EOp x op (subst' f w e)
ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2)
- EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
- EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3)
- EZero t -> EZero t
- EPlus t a b -> EPlus t (subst' f w a) (subst' f w b)
- EOneHot t i a b -> EOneHot t i (subst' f w a) (subst' f w b)
- EError t s -> EError t s
+ EWith x e1 e2 -> EWith x (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
+ EAccum x i e1 e2 e3 -> EAccum x i (subst' f w e1) (subst' f w e2) (subst' f w e3)
+ EZero x t -> EZero x t
+ EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b)
+ EOneHot x t i a b -> EOneHot x t i (subst' f w a) (subst' f w b)
+ EError x t s -> EError x t s
where
sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
-> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 22a4da6..b7079ff 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -128,11 +128,11 @@ occCountGeneral onehot unpush alter many = go WId
EShape _ e -> re e
EOp _ _ e -> re e
ECustom _ _ _ _ _ _ _ a b -> re a <> re b
- EWith a b -> re a <> re1 b
- EAccum _ a b e -> re a <> re b <> re e
- EZero _ -> mempty
- EPlus _ a b -> re a <> re b
- EOneHot _ _ a b -> re a <> re b
+ EWith _ a b -> re a <> re1 b
+ EAccum _ _ a b e -> re a <> re b <> re e
+ EZero _ _ -> mempty
+ EPlus _ _ a b -> re a <> re b
+ EOneHot _ _ _ a b -> re a <> re b
EError{} -> mempty
where
re :: Monoid (r env') => Expr x env' t'' -> r env'
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 663e9b0..24bacdb 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -203,7 +203,7 @@ ppExpr' d val = \case
. e1' . showString " "
. e2'
- EWith e1 e2 -> do
+ EWith _ e1 e2 -> do
e1' <- ppExpr' 11 val e1
name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2
e2' <- ppExpr' 0 (Const name `SCons` val) e2
@@ -211,27 +211,27 @@ ppExpr' d val = \case
showString "with " . e1' . showString (" (\\" ++ name ++ " -> ")
. e2' . showString ")"
- EAccum i e1 e2 e3 -> do
+ EAccum _ i e1 e2 e3 -> do
e1' <- ppExpr' 11 val e1
e2' <- ppExpr' 11 val e2
e3' <- ppExpr' 11 val e3
return $ showParen (d > 10) $
showString ("accum " ++ show (fromSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3'
- EZero t -> return $ showString ("(zero :: " ++ ppTy 0 t ++ ")")
+ EZero _ t -> return $ showString ("(zero :: " ++ ppTy 0 t ++ ")")
- EPlus _ a b -> do
+ EPlus _ _ a b -> do
a' <- ppExpr' 11 val a
b' <- ppExpr' 11 val b
return $ showParen (d > 10) $ showString "plus " . a' . showString " " . b'
- EOneHot _ i a b -> do
+ EOneHot _ _ i a b -> do
a' <- ppExpr' 11 val a
b' <- ppExpr' 11 val b
return $ showParen (d > 10) $
showString ("onehot " ++ show (fromSNat i) ++ " ") . a' . showString " " . b'
- EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)
+ EError _ _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)
ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS
ppExprLet d val etop = do
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index 8da1e32..c87bed1 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -11,9 +11,9 @@ import Data
unMonoid :: Ex env t -> Ex env t
unMonoid = \case
- EZero t -> zero t
- EPlus t a b -> plus t a b
- EOneHot t i a b -> onehot t i a b
+ EZero _ t -> zero t
+ EPlus _ t a b -> plus t a b
+ EOneHot _ t i a b -> onehot t i a b
EVar _ t i -> EVar ext t i
ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
@@ -42,9 +42,9 @@ unMonoid = \case
EShape _ e -> EShape ext (unMonoid e)
EOp _ op e -> EOp ext op (unMonoid e)
ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
- EWith a b -> EWith (unMonoid a) (unMonoid b)
- EAccum n a b e -> EAccum n (unMonoid a) (unMonoid b) (unMonoid e)
- EError t s -> EError t s
+ EWith _ a b -> EWith ext (unMonoid a) (unMonoid b)
+ EAccum _ n a b e -> EAccum ext n (unMonoid a) (unMonoid b) (unMonoid e)
+ EError _ t s -> EError ext t s
zero :: STy t -> Ex env (D2 t)
zero STNil = ENil ext
@@ -52,7 +52,7 @@ zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2))
zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2))
zero (STMaybe t) = ENothing ext (d2 t)
zero (STArr SZ t) = EUnit ext (zero t)
-zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (EError (d2 t) "empty")
+zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (EError ext (d2 t) "empty")
zero (STScal t) = case t of
STI32 -> ENil ext
STI64 -> ENil ext
@@ -76,9 +76,9 @@ plus (STEither t1 t2) a b =
ECase ext (EVar ext t (IS IZ))
(ECase ext (EVar ext t (IS IZ))
(EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ)))
- (EError t "plus l+r"))
+ (EError ext t "plus l+r"))
(ECase ext (EVar ext t (IS IZ))
- (EError t "plus r+l")
+ (EError ext t "plus r+l")
(EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
plus (STMaybe t) a b =
plusSparse (d2 t) a b $
@@ -130,9 +130,9 @@ onehot t (SS dep) idx val = case t of
(ECase ext (weakenExpr WSink val)
(EPair ext (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ))
(zero t2))
- (EError (STPair (d2 t1) (d2 t2)) "onehot pair l/r"))
+ (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair l/r"))
(ECase ext (weakenExpr WSink val)
- (EError (STPair (d2 t1) (d2 t2)) "onehot pair r/l")
+ (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair r/l")
(EPair ext (zero t1)
(onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ))))
@@ -146,9 +146,9 @@ onehot t (SS dep) idx val = case t of
ECase ext idx
(ECase ext (weakenExpr WSink val)
(EInl ext (d2 t2) (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ)))
- (EError (STEither (d2 t1) (d2 t2)) "onehot either l/r"))
+ (EError ext (STEither (d2 t1) (d2 t2)) "onehot either l/r"))
(ECase ext (weakenExpr WSink val)
- (EError (STEither (d2 t1) (d2 t2)) "onehot either r/l")
+ (EError ext (STEither (d2 t1) (d2 t2)) "onehot either r/l")
(EInr ext (d2 t1) (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ))))
STMaybe t1 -> EJust ext (onehot t1 dep idx val)
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index 5c398d2..e481a77 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -5,8 +5,6 @@ module Analysis.Identity (
identityAnalysis,
) where
-import Data.Functor.Const
-
import AST
import Data
import Util.IdGen
diff --git a/src/CHAD.hs b/src/CHAD.hs
index aa5bd4c..6118e48 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -319,10 +319,10 @@ indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl
------------------------------------ MONOIDS -----------------------------------
zero :: STy t -> Ex env (D2 t)
-zero = EZero
+zero = EZero ext
plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
-plus = EPlus
+plus = EPlus ext
zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))
zeroTup SNil = ENil ext
@@ -594,7 +594,7 @@ drev des = \case
SETop
(EVar ext (d1 t) (conv1Idx i))
(subenvNone (select SMerge des))
- (EAccum SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI)))
+ (EAccum ext SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI)))
Idx2Me tupI ->
Ret BTop
@@ -689,7 +689,7 @@ drev des = \case
(zeroTup (subList (select SMerge des) sub))
(ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ)
(weakenExpr (WCopy (wSinks' @[_,_])) e2)
- (EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr"))
+ (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr"))
(EVar ext (STMaybe (STEither (d2 (typeOf e)) (d2 t2))) IZ))
EInr _ t1 e
@@ -701,7 +701,7 @@ drev des = \case
(EMaybe ext
(zeroTup (subList (select SMerge des) sub))
(ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ)
- (EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl")
+ (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl")
(weakenExpr (WCopy (wSinks' @[_,_])) e2))
(EVar ext (STMaybe (STEither (d2 t1) (d2 (typeOf e)))) IZ))
@@ -820,10 +820,10 @@ drev des = \case
(ELet ext (weakenExpr (WCopy (WCopy WClosed)) du) $
weakenExpr (WCopy (WSink .> WSink)) b2)
- EError t s ->
+ EError _ t s ->
Ret BTop
SETop
- (EError (d1 t) s)
+ (EError ext (d1 t) s)
(subenvNone (select SMerge des))
(ENil ext)
@@ -922,8 +922,8 @@ drev des = \case
subtape
(EReplicate1Inner ext en1 e1)
sub
- (ELet ext (EFold1Inner ext (EPlus eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
- (EZero eltty)
+ (ELet ext (EFold1Inner ext (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
+ (EZero ext eltty)
(EVar ext (STArr (SS ndim) (d2 eltty)) IZ)) $
weakenExpr (WCopy WSink) e2)
@@ -970,7 +970,7 @@ drev des = \case
(EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
(EVar ext (tTup (sreplicate n tIx)) IZ))
sub
- (ELet ext (EOneHot (STArr n eltty) n
+ (ELet ext (EOneHot ext (STArr n eltty) n
(arrIdxToAcIdx (d2 eltty) n $ EVar ext tIxN (IS IZ))
(case n of SZ -> EUnit ext (EVar ext (d2 eltty) IZ)
SS{} | Refl <- lemAcValArrN (d2 eltty) n ->
@@ -1042,7 +1042,7 @@ drev des = \case
(EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ))
(EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ))))))
(EIdx ext (EVar ext (d2 at') (IS (IS IZ))) (EFst ext (EVar ext tIxN (IS IZ))))
- (EZero t)) $
+ (EZero ext t)) $
weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s)
@@ -1073,10 +1073,10 @@ drevScoped des argty argsto expr
SMerge ->
case sub of
SEYes sub' -> RetScoped e0 subtape e1 sub' e2
- SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero argty))
+ SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero ext argty))
SAccum ->
RetScoped e0 subtape e1 sub $
- EWith (EZero argty) $
+ EWith ext (EZero ext argty) $
weakenExpr (autoWeak (#d (auto1 @(D2 t))
&. #body (subList (bindingsBinds e0) subtape)
&. #ac (auto1 @(TAccum (D2 a)))
diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs
index e26f781..659c45f 100644
--- a/src/CHAD/Accum.hs
+++ b/src/CHAD/Accum.hs
@@ -21,7 +21,7 @@ makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex e
makeAccumulators SNil e = e
makeAccumulators (t `SCons` envpro) e =
makeAccumulators envpro $
- EWith (EZero t) e
+ EWith ext (EZero ext t) e
uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
uninvertTup SNil _ e = EPair ext e (ENil ext)
diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs
index 12594f2..d058132 100644
--- a/src/CHAD/Top.hs
+++ b/src/CHAD/Top.hs
@@ -69,7 +69,7 @@ reassembleD2E (des `DPush` (_, SMerge)) e =
EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ))
(EFst ext (ESnd ext (EVar ext (typeOf e) IZ)))))
(ESnd ext (ESnd ext (EVar ext (typeOf e) IZ)))
-reassembleD2E (des `DPush` (t, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero t)
+reassembleD2E (des `DPush` (t, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero ext t)
chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
chad config env (term :: Ex env t)
diff --git a/src/Compile.hs b/src/Compile.hs
index 0db0d0f..05d51c1 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -313,11 +313,11 @@ compile' env = \case
ECustom _ t1 t2 t3 a b c e1 e2 -> error "TODO" -- ECustom ext t1 t2 t3 (compile' a) (compile' b) (compile' c) (compile' e1) (compile' e2)
- EWith a b -> error "TODO" -- EWith (compile' a) (compile' b)
+ EWith _ a b -> error "TODO" -- EWith (compile' a) (compile' b)
- EAccum n a b e -> error "TODO" -- EAccum n (compile' a) (compile' b) (compile' e)
+ EAccum _ n a b e -> error "TODO" -- EAccum n (compile' a) (compile' b) (compile' e)
- EError t s -> do
+ EError _ t s -> do
name <- emitStruct t
-- using 'show' here is wrong, but it's good enough for me.
emit $ SVerbatim $ "fprintf(stderr, \"ERROR: %s\\n\", " ++ show s ++ "); exit(1);"
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index e8b140e..aa35a5b 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -178,7 +178,7 @@ dfwdDN = \case
ELet ext (dfwdDN e1) $
ELet ext (weakenExpr WSink (dfwdDN e2)) $
weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr)
- EError t s -> EError (dn t) s
+ EError _ t s -> EError ext (dn t) s
EWith{} -> err_accum
EAccum{} -> err_accum
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index bb4952c..deb829b 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -134,27 +134,27 @@ interpret'Rec env = \case
e1' <- interpret' env e1
e2' <- interpret' env e2
interpret' (Value e2' `SCons` Value e1' `SCons` SNil) pr
- EWith e1 e2 -> do
+ EWith _ e1 e2 -> do
initval <- interpret' env e1
withAccum (typeOf e1) (typeOf e2) initval $ \accum ->
interpret' (Value accum `SCons` env) e2
- EAccum i e1 e2 e3 -> do
+ EAccum _ i e1 e2 e3 -> do
let STAccum t = typeOf e3
idx <- interpret' env e1
val <- interpret' env e2
accum <- interpret' env e3
accumAddSparse t i accum idx val
- EZero t -> do
+ EZero _ t -> do
return $ zeroD2 t
- EPlus t a b -> do
+ EPlus _ t a b -> do
a' <- interpret' env a
b' <- interpret' env b
return $ addD2s t a' b'
- EOneHot t i a b -> do
+ EOneHot _ t i a b -> do
a' <- interpret' env a
b' <- interpret' env b
return $ onehotD2 i t a' b'
- EError _ s -> error $ "Interpreter: Program threw error: " ++ s
+ EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s
interpretOp :: SOp a t -> Rep a -> Rep t
interpretOp op arg = case op of
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index 8c91d59..022e797 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -205,7 +205,7 @@ fromNamedExpr val = \case
(fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c)
(go e1) (go e2)
- NEError t s -> EError t s
+ NEError t s -> EError ext t s
NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args
where
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 6303716..785e2bd 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -93,33 +93,33 @@ simplify' = \case
-- TODO: constant folding for operations
-- TODO: properly concatenate accum/onehot
- EAccum SZ _ (EOneHot _ i idx val) acc ->
+ EAccum _ SZ _ (EOneHot _ _ i idx val) acc ->
acted $ simplify' $
- EAccum i idx val acc
- EAccum _ _ (EZero _) _ -> (Any True, ENil ext)
- EPlus _ (EZero _) e -> acted $ simplify' e
- EPlus _ e (EZero _) -> acted $ simplify' e
- EOneHot _ SZ _ e -> acted $ simplify' e
+ EAccum ext i idx val acc
+ EAccum _ _ _ (EZero _ _) _ -> (Any True, ENil ext)
+ EPlus _ _ (EZero _ _) e -> acted $ simplify' e
+ EPlus _ _ e (EZero _ _) -> acted $ simplify' e
+ EOneHot _ _ SZ _ e -> acted $ simplify' e
-- equations for plus
- EPlus STNil _ _ -> (Any True, ENil ext)
+ EPlus _ STNil _ _ -> (Any True, ENil ext)
- EPlus (STPair t1 t2) (EJust _ (EPair _ a1 b1)) (EJust _ (EPair _ a2 b2)) ->
- acted $ simplify' $ EJust ext (EPair ext (EPlus t1 a1 a2) (EPlus t2 b1 b2))
- EPlus STPair{} ENothing{} e -> acted $ simplify' e
- EPlus STPair{} e ENothing{} -> acted $ simplify' e
+ EPlus _ (STPair t1 t2) (EJust _ (EPair _ a1 b1)) (EJust _ (EPair _ a2 b2)) ->
+ acted $ simplify' $ EJust ext (EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2))
+ EPlus _ STPair{} ENothing{} e -> acted $ simplify' e
+ EPlus _ STPair{} e ENothing{} -> acted $ simplify' e
- EPlus (STEither t1 _) (EJust _ (EInl _ dt2 a1)) (EJust _ (EInl _ _ a2)) ->
- acted $ simplify' $ EJust ext (EInl ext dt2 (EPlus t1 a1 a2))
- EPlus (STEither _ t2) (EJust _ (EInr _ dt1 b1)) (EJust _ (EInr _ _ b2)) ->
- acted $ simplify' $ EJust ext (EInr ext dt1 (EPlus t2 b1 b2))
- EPlus STEither{} ENothing{} e -> acted $ simplify' e
- EPlus STEither{} e ENothing{} -> acted $ simplify' e
+ EPlus _ (STEither t1 _) (EJust _ (EInl _ dt2 a1)) (EJust _ (EInl _ _ a2)) ->
+ acted $ simplify' $ EJust ext (EInl ext dt2 (EPlus ext t1 a1 a2))
+ EPlus _ (STEither _ t2) (EJust _ (EInr _ dt1 b1)) (EJust _ (EInr _ _ b2)) ->
+ acted $ simplify' $ EJust ext (EInr ext dt1 (EPlus ext t2 b1 b2))
+ EPlus _ STEither{} ENothing{} e -> acted $ simplify' e
+ EPlus _ STEither{} e ENothing{} -> acted $ simplify' e
- EPlus (STMaybe t) (EJust _ e1) (EJust _ e2) ->
- acted $ simplify' $ EJust ext (EPlus t e1 e2)
- EPlus STMaybe{} ENothing{} e -> acted $ simplify' e
- EPlus STMaybe{} e ENothing{} -> acted $ simplify' e
+ EPlus _ (STMaybe t) (EJust _ e1) (EJust _ e2) ->
+ acted $ simplify' $ EJust ext (EPlus ext t e1 e2)
+ EPlus _ STMaybe{} ENothing{} e -> acted $ simplify' e
+ EPlus _ STMaybe{} e ENothing{} -> acted $ simplify' e
-- fallback recursion
EVar _ t i -> pure $ EVar ext t i
@@ -154,12 +154,12 @@ simplify' = \case
<*> (let ?accumInScope = False in simplify' b)
<*> (let ?accumInScope = False in simplify' c)
<*> simplify' e1 <*> simplify' e2
- EWith e1 e2 -> EWith <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2)
- EAccum i e1 e2 e3 -> EAccum i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3
- EZero t -> pure $ EZero t
- EPlus t a b -> EPlus t <$> simplify' a <*> simplify' b
- EOneHot t i a b -> EOneHot t i <$> simplify' a <*> simplify' b
- EError t s -> pure $ EError t s
+ EWith _ e1 e2 -> EWith ext <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2)
+ EAccum _ i e1 e2 e3 -> EAccum ext i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3
+ EZero _ t -> pure $ EZero ext t
+ EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b
+ EOneHot _ t i a b -> EOneHot ext t i <$> simplify' a <*> simplify' b
+ EError _ t s -> pure $ EError ext t s
acted :: (Any, a) -> (Any, a)
acted (_, x) = (Any True, x)
@@ -169,6 +169,8 @@ cheapExpr = \case
EVar{} -> True
ENil{} -> True
EConst{} -> True
+ EFst _ e -> cheapExpr e
+ ESnd _ e -> cheapExpr e
_ -> False
-- | This can be made more precise by tracking (and not counting) adds on
@@ -202,12 +204,12 @@ hasAdds = \case
EIdx _ a b -> hasAdds a || hasAdds b
EShape _ e -> hasAdds e
EOp _ _ e -> hasAdds e
- EWith a b -> hasAdds a || hasAdds b
- EAccum _ _ _ _ -> True
- EZero _ -> False
- EPlus _ a b -> hasAdds a || hasAdds b
- EOneHot _ _ a b -> hasAdds a || hasAdds b
- EError _ _ -> False
+ EWith _ a b -> hasAdds a || hasAdds b
+ EAccum _ _ _ _ _ -> True
+ EZero _ _ -> False
+ EPlus _ _ a b -> hasAdds a || hasAdds b
+ EOneHot _ _ _ a b -> hasAdds a || hasAdds b
+ EError _ _ _ -> False
checkAccumInScope :: SList STy env -> Bool
checkAccumInScope = \case SNil -> False