summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-05-25 23:34:13 +0200
committerTom Smeding <tom@tomsmeding.com>2025-05-25 23:34:13 +0200
commitfaa9af2ec2e463c1774f54b9e8f0ae3733cdb048 (patch)
tree46be36e09685054c151bec1ba24c0edf01dd68d9
parente9c4cad143d483e29213e9c121574d1d46c2d56a (diff)
Implement mapExt as travExt
-rw-r--r--src/AST.hs81
1 files changed, 43 insertions, 38 deletions
diff --git a/src/AST.hs b/src/AST.hs
index ca66e87..65664fc 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -20,6 +20,7 @@
module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where
import Data.Functor.Const
+import Data.Functor.Identity
import Data.Kind (Type)
import Array
@@ -262,44 +263,48 @@ extOf = \case
EError x _ _ -> x
mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t
-mapExt f = \case
- EVar x t i -> EVar (f x) t i
- ELet x rhs body -> ELet (f x) (mapExt f rhs) (mapExt f body)
- EPair x a b -> EPair (f x) (mapExt f a) (mapExt f b)
- EFst x e -> EFst (f x) (mapExt f e)
- ESnd x e -> ESnd (f x) (mapExt f e)
- ENil x -> ENil (f x)
- EInl x t e -> EInl (f x) t (mapExt f e)
- EInr x t e -> EInr (f x) t (mapExt f e)
- ECase x e a b -> ECase (f x) (mapExt f e) (mapExt f a) (mapExt f b)
- ENothing x t -> ENothing (f x) t
- EJust x e -> EJust (f x) (mapExt f e)
- EMaybe x a b e -> EMaybe (f x) (mapExt f a) (mapExt f b) (mapExt f e)
- ELNil x t1 t2 -> ELNil (f x) t1 t2
- ELInl x t e -> ELInl (f x) t (mapExt f e)
- ELInr x t e -> ELInr (f x) t (mapExt f e)
- ELCase x e a b c -> ELCase (f x) (mapExt f e) (mapExt f a) (mapExt f b) (mapExt f c)
- EConstArr x n t a -> EConstArr (f x) n t a
- EBuild x n a b -> EBuild (f x) n (mapExt f a) (mapExt f b)
- EFold1Inner x cm a b c -> EFold1Inner (f x) cm (mapExt f a) (mapExt f b) (mapExt f c)
- ESum1Inner x e -> ESum1Inner (f x) (mapExt f e)
- EUnit x e -> EUnit (f x) (mapExt f e)
- EReplicate1Inner x a b -> EReplicate1Inner (f x) (mapExt f a) (mapExt f b)
- EMaximum1Inner x e -> EMaximum1Inner (f x) (mapExt f e)
- EMinimum1Inner x e -> EMinimum1Inner (f x) (mapExt f e)
- EConst x t v -> EConst (f x) t v
- EIdx0 x e -> EIdx0 (f x) (mapExt f e)
- EIdx1 x a b -> EIdx1 (f x) (mapExt f a) (mapExt f b)
- EIdx x e es -> EIdx (f x) (mapExt f e) (mapExt f es)
- EShape x e -> EShape (f x) (mapExt f e)
- EOp x op e -> EOp (f x) op (mapExt f e)
- ECustom x s t p a b c e1 e2 -> ECustom (f x) s t p (mapExt f a) (mapExt f b) (mapExt f c) (mapExt f e1) (mapExt f e2)
- EWith x t e1 e2 -> EWith (f x) t (mapExt f e1) (mapExt f e2)
- EAccum x t p e1 e2 e3 -> EAccum (f x) t p (mapExt f e1) (mapExt f e2) (mapExt f e3)
- EZero x t e -> EZero (f x) t (mapExt f e)
- EPlus x t a b -> EPlus (f x) t (mapExt f a) (mapExt f b)
- EOneHot x t p a b -> EOneHot (f x) t p (mapExt f a) (mapExt f b)
- EError x t s -> EError (f x) t s
+mapExt f = runIdentity . travExt (Identity . f)
+
+{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-}
+travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t)
+travExt f = \case
+ EVar x t i -> EVar <$> f x <*> pure t <*> pure i
+ ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body
+ EPair x a b -> EPair <$> f x <*> travExt f a <*> travExt f b
+ EFst x e -> EFst <$> f x <*> travExt f e
+ ESnd x e -> ESnd <$> f x <*> travExt f e
+ ENil x -> ENil <$> f x
+ EInl x t e -> EInl <$> f x <*> pure t <*> travExt f e
+ EInr x t e -> EInr <$> f x <*> pure t <*> travExt f e
+ ECase x e a b -> ECase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b
+ ENothing x t -> ENothing <$> f x <*> pure t
+ EJust x e -> EJust <$> f x <*> travExt f e
+ EMaybe x a b e -> EMaybe <$> f x <*> travExt f a <*> travExt f b <*> travExt f e
+ ELNil x t1 t2 -> ELNil <$> f x <*> pure t1 <*> pure t2
+ ELInl x t e -> ELInl <$> f x <*> pure t <*> travExt f e
+ ELInr x t e -> ELInr <$> f x <*> pure t <*> travExt f e
+ ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c
+ EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a
+ EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b
+ EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
+ ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e
+ EUnit x e -> EUnit <$> f x <*> travExt f e
+ EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b
+ EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e
+ EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e
+ EConst x t v -> EConst <$> f x <*> pure t <*> pure v
+ EIdx0 x e -> EIdx0 <$> f x <*> travExt f e
+ EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b
+ EIdx x e es -> EIdx <$> f x <*> travExt f e <*> travExt f es
+ EShape x e -> EShape <$> f x <*> travExt f e
+ EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e
+ ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2
+ EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2
+ EAccum x t p e1 e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> travExt f e2 <*> travExt f e3
+ EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e
+ EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b
+ EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b
+ EError x t s -> EError <$> f x <*> pure t <*> pure s
substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t
substInline repl =