summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-08-30 19:23:23 +0200
committerTom Smeding <tom@tomsmeding.com>2024-08-30 19:23:23 +0200
commit33d66f27f9eb3658d366fee71f0d0a0c5748a0e3 (patch)
tree21e8039b80377b72800f3173bd9038b78e141e31
parent8b047ff11ebd4715647bfc041a190f72dcf4d5a9 (diff)
Implement weakenExpr using subst
This saves one traversal function.
-rw-r--r--src/AST.hs63
-rw-r--r--src/Simplify.hs46
2 files changed, 45 insertions, 64 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 2267672..90baaf0 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -185,27 +185,54 @@ unSScalTy = \case
STF64 -> TF64
STBool -> TBool
-weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
-weakenExpr w = \case
- EVar x t i -> EVar x t (w @> i)
- ELet x rhs body -> ELet x (weakenExpr w rhs) (weakenExpr (WCopy w) body)
- EPair x e1 e2 -> EPair x (weakenExpr w e1) (weakenExpr w e2)
- EFst x e -> EFst x (weakenExpr w e)
- ESnd x e -> ESnd x (weakenExpr w e)
+subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t
+subst1 repl = subst $ \x t -> \case IZ -> repl
+ IS i -> EVar x t i
+
+subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a)
+ -> Expr x env t -> Expr x env' t
+subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId
+
+subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a)
+ -> env' :> envOut
+ -> Expr x env t
+ -> Expr x envOut t
+subst' f w = \case
+ EVar x t i -> f x t w i
+ ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body)
+ EPair x a b -> EPair x (subst' f w a) (subst' f w b)
+ EFst x e -> EFst x (subst' f w e)
+ ESnd x e -> ESnd x (subst' f w e)
ENil x -> ENil x
- EInl x t e -> EInl x t (weakenExpr w e)
- EInr x t e -> EInr x t (weakenExpr w e)
- ECase x e1 e2 e3 -> ECase x (weakenExpr w e1) (weakenExpr (WCopy w) e2) (weakenExpr (WCopy w) e3)
- EBuild1 x e1 e2 -> EBuild1 x (weakenExpr w e1) (weakenExpr (WCopy w) e2)
- EBuild x es e -> EBuild x (weakenExpr w <$> es) (weakenExpr (wcopyN (vecLength es) w) e)
- EFold1 x e1 e2 -> EFold1 x (weakenExpr (WCopy (WCopy w)) e1) (weakenExpr w e2)
+ EInl x t e -> EInl x t (subst' f w e)
+ EInr x t e -> EInr x t (subst' f w e)
+ ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
+ EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b)
+ EBuild x es e -> EBuild x (fmap (subst' f w) es) (subst' (sinkFN (vecLength es) f) (wcopyN (vecLength es) w) e)
+ EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)
EConst x t v -> EConst x t v
- EIdx1 x e1 e2 -> EIdx1 x (weakenExpr w e1) (weakenExpr w e2)
- EIdx x e1 es -> EIdx x (weakenExpr w e1) (weakenExpr w <$> es)
- EOp x op e -> EOp x op (weakenExpr w e)
- EWith e1 e2 -> EWith (weakenExpr w e1) (weakenExpr (WCopy w) e2)
- EAccum e1 e2 e3 -> EAccum (weakenExpr w e1) (weakenExpr w e2) (weakenExpr w e3)
+ EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
+ EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es)
+ EOp x op e -> EOp x op (subst' f w e)
+ EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
+ EAccum e1 e2 e3 -> EAccum (subst' f w e1) (subst' f w e2) (subst' f w e3)
EError t s -> EError t s
+ where
+ sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
+ -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
+ sinkF f' x' t w' = \case
+ IZ -> EVar x' t (w' @> IZ)
+ IS i -> f' x' t (WPop w') i
+
+ sinkFN :: SNat n
+ -> (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
+ -> x t -> STy t -> (ConsN n TIx env' :> env2) -> Idx (ConsN n TIx env) t -> Expr x env2 t
+ sinkFN SZ f' x t w' i = f' x t w' i
+ sinkFN (SS _) _ x t w' IZ = EVar x t (w' @> IZ)
+ sinkFN (SS n) f' x t w' (IS i) = sinkFN n f' x t (WPop w') i
+
+weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
+weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i))
wsinkN :: SNat n -> env :> ConsN n TIx env
wsinkN SZ = WId
diff --git a/src/Simplify.hs b/src/Simplify.hs
index a5f90b3..39b3afd 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -88,52 +88,6 @@ cheapExpr = \case
EConst{} -> True
_ -> False
-subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t
-subst1 repl = subst $ \x t -> \case IZ -> repl
- IS i -> EVar x t i
-
-subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a)
- -> Expr x env t -> Expr x env' t
-subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId
-
-subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a)
- -> env' :> envOut
- -> Expr x env t
- -> Expr x envOut t
-subst' f w = \case
- EVar x t i -> f x t w i
- ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body)
- EPair x a b -> EPair x (subst' f w a) (subst' f w b)
- EFst x e -> EFst x (subst' f w e)
- ESnd x e -> ESnd x (subst' f w e)
- ENil x -> ENil x
- EInl x t e -> EInl x t (subst' f w e)
- EInr x t e -> EInr x t (subst' f w e)
- ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
- EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b)
- EBuild x es e -> EBuild x (fmap (subst' f w) es) (subst' (sinkFN (vecLength es) f) (wcopyN (vecLength es) w) e)
- EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)
- EConst x t v -> EConst x t v
- EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
- EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es)
- EOp x op e -> EOp x op (subst' f w e)
- EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
- EAccum e1 e2 e3 -> EAccum (subst' f w e1) (subst' f w e2) (subst' f w e3)
- EError t s -> EError t s
- where
- sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
- -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
- sinkF f' x' t w' = \case
- IZ -> EVar x' t (w' @> IZ)
- IS i -> f' x' t (WPop w') i
-
- sinkFN :: SNat n
- -> (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
- -> x t -> STy t -> (ConsN n TIx env' :> env2) -> Idx (ConsN n TIx env) t -> Expr x env2 t
- sinkFN SZ f' x t w' i = f' x t w' i
- sinkFN (SS _) _ x t w' IZ = EVar x t (w' @> IZ)
- sinkFN (SS n) f' x t w' (IS i) = sinkFN n f' x t (WPop w') i
-
-- | This can be made more precise by tracking (and not counting) adds on
-- locally eliminated accumulators.
hasAdds :: Expr x env t -> Bool