diff options
| -rw-r--r-- | src/AST.hs | 81 | 
1 files changed, 43 insertions, 38 deletions
| @@ -20,6 +20,7 @@  module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where  import Data.Functor.Const +import Data.Functor.Identity  import Data.Kind (Type)  import Array @@ -262,44 +263,48 @@ extOf = \case    EError x _ _ -> x  mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t -mapExt f = \case -  EVar x t i -> EVar (f x) t i -  ELet x rhs body -> ELet (f x) (mapExt f rhs) (mapExt f body) -  EPair x a b -> EPair (f x) (mapExt f a) (mapExt f b) -  EFst x e -> EFst (f x) (mapExt f e) -  ESnd x e -> ESnd (f x) (mapExt f e) -  ENil x -> ENil (f x) -  EInl x t e -> EInl (f x) t (mapExt f e) -  EInr x t e -> EInr (f x) t (mapExt f e) -  ECase x e a b -> ECase (f x) (mapExt f e) (mapExt f a) (mapExt f b) -  ENothing x t -> ENothing (f x) t -  EJust x e -> EJust (f x) (mapExt f e) -  EMaybe x a b e -> EMaybe (f x) (mapExt f a) (mapExt f b) (mapExt f e) -  ELNil x t1 t2 -> ELNil (f x) t1 t2 -  ELInl x t e -> ELInl (f x) t (mapExt f e) -  ELInr x t e -> ELInr (f x) t (mapExt f e) -  ELCase x e a b c -> ELCase (f x) (mapExt f e) (mapExt f a) (mapExt f b) (mapExt f c) -  EConstArr x n t a -> EConstArr (f x) n t a -  EBuild x n a b -> EBuild (f x) n (mapExt f a) (mapExt f b) -  EFold1Inner x cm a b c -> EFold1Inner (f x) cm (mapExt f a) (mapExt f b) (mapExt f c) -  ESum1Inner x e -> ESum1Inner (f x) (mapExt f e) -  EUnit x e -> EUnit (f x) (mapExt f e) -  EReplicate1Inner x a b -> EReplicate1Inner (f x) (mapExt f a) (mapExt f b) -  EMaximum1Inner x e -> EMaximum1Inner (f x) (mapExt f e) -  EMinimum1Inner x e -> EMinimum1Inner (f x) (mapExt f e) -  EConst x t v -> EConst (f x) t v -  EIdx0 x e -> EIdx0 (f x) (mapExt f e) -  EIdx1 x a b -> EIdx1 (f x) (mapExt f a) (mapExt f b) -  EIdx x e es -> EIdx (f x) (mapExt f e) (mapExt f es) -  EShape x e -> EShape (f x) (mapExt f e) -  EOp x op e -> EOp (f x) op (mapExt f e) -  ECustom x s t p a b c e1 e2 -> ECustom (f x) s t p (mapExt f a) (mapExt f b) (mapExt f c) (mapExt f e1) (mapExt f e2) -  EWith x t e1 e2 -> EWith (f x) t (mapExt f e1) (mapExt f e2) -  EAccum x t p e1 e2 e3 -> EAccum (f x) t p (mapExt f e1) (mapExt f e2) (mapExt f e3) -  EZero x t e -> EZero (f x) t (mapExt f e) -  EPlus x t a b -> EPlus (f x) t (mapExt f a) (mapExt f b) -  EOneHot x t p a b -> EOneHot (f x) t p (mapExt f a) (mapExt f b) -  EError x t s -> EError (f x) t s +mapExt f = runIdentity . travExt (Identity . f) + +{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-} +travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t) +travExt f = \case +  EVar x t i -> EVar <$> f x <*> pure t <*> pure i +  ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body +  EPair x a b -> EPair <$> f x <*> travExt f a <*> travExt f b +  EFst x e -> EFst <$> f x <*> travExt f e +  ESnd x e -> ESnd <$> f x <*> travExt f e +  ENil x -> ENil <$> f x +  EInl x t e -> EInl <$> f x <*> pure t <*> travExt f e +  EInr x t e -> EInr <$> f x <*> pure t <*> travExt f e +  ECase x e a b -> ECase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b +  ENothing x t -> ENothing <$> f x <*> pure t +  EJust x e -> EJust <$> f x <*> travExt f e +  EMaybe x a b e -> EMaybe <$> f x <*> travExt f a <*> travExt f b <*> travExt f e +  ELNil x t1 t2 -> ELNil <$> f x <*> pure t1 <*> pure t2 +  ELInl x t e -> ELInl <$> f x <*> pure t <*> travExt f e +  ELInr x t e -> ELInr <$> f x <*> pure t <*> travExt f e +  ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c +  EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a +  EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b +  EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c +  ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e +  EUnit x e -> EUnit <$> f x <*> travExt f e +  EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b +  EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e +  EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e +  EConst x t v -> EConst <$> f x <*> pure t <*> pure v +  EIdx0 x e -> EIdx0 <$> f x <*> travExt f e +  EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b +  EIdx x e es -> EIdx <$> f x <*> travExt f e <*> travExt f es +  EShape x e -> EShape <$> f x <*> travExt f e +  EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e +  ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2 +  EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2 +  EAccum x t p e1 e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> travExt f e2 <*> travExt f e3 +  EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e +  EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b +  EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b +  EError x t s -> EError <$> f x <*> pure t <*> pure s  substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t  substInline repl = | 
