diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-05-25 23:34:13 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-05-25 23:34:13 +0200 |
commit | faa9af2ec2e463c1774f54b9e8f0ae3733cdb048 (patch) | |
tree | 46be36e09685054c151bec1ba24c0edf01dd68d9 | |
parent | e9c4cad143d483e29213e9c121574d1d46c2d56a (diff) |
Implement mapExt as travExt
-rw-r--r-- | src/AST.hs | 81 |
1 files changed, 43 insertions, 38 deletions
@@ -20,6 +20,7 @@ module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where import Data.Functor.Const +import Data.Functor.Identity import Data.Kind (Type) import Array @@ -262,44 +263,48 @@ extOf = \case EError x _ _ -> x mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t -mapExt f = \case - EVar x t i -> EVar (f x) t i - ELet x rhs body -> ELet (f x) (mapExt f rhs) (mapExt f body) - EPair x a b -> EPair (f x) (mapExt f a) (mapExt f b) - EFst x e -> EFst (f x) (mapExt f e) - ESnd x e -> ESnd (f x) (mapExt f e) - ENil x -> ENil (f x) - EInl x t e -> EInl (f x) t (mapExt f e) - EInr x t e -> EInr (f x) t (mapExt f e) - ECase x e a b -> ECase (f x) (mapExt f e) (mapExt f a) (mapExt f b) - ENothing x t -> ENothing (f x) t - EJust x e -> EJust (f x) (mapExt f e) - EMaybe x a b e -> EMaybe (f x) (mapExt f a) (mapExt f b) (mapExt f e) - ELNil x t1 t2 -> ELNil (f x) t1 t2 - ELInl x t e -> ELInl (f x) t (mapExt f e) - ELInr x t e -> ELInr (f x) t (mapExt f e) - ELCase x e a b c -> ELCase (f x) (mapExt f e) (mapExt f a) (mapExt f b) (mapExt f c) - EConstArr x n t a -> EConstArr (f x) n t a - EBuild x n a b -> EBuild (f x) n (mapExt f a) (mapExt f b) - EFold1Inner x cm a b c -> EFold1Inner (f x) cm (mapExt f a) (mapExt f b) (mapExt f c) - ESum1Inner x e -> ESum1Inner (f x) (mapExt f e) - EUnit x e -> EUnit (f x) (mapExt f e) - EReplicate1Inner x a b -> EReplicate1Inner (f x) (mapExt f a) (mapExt f b) - EMaximum1Inner x e -> EMaximum1Inner (f x) (mapExt f e) - EMinimum1Inner x e -> EMinimum1Inner (f x) (mapExt f e) - EConst x t v -> EConst (f x) t v - EIdx0 x e -> EIdx0 (f x) (mapExt f e) - EIdx1 x a b -> EIdx1 (f x) (mapExt f a) (mapExt f b) - EIdx x e es -> EIdx (f x) (mapExt f e) (mapExt f es) - EShape x e -> EShape (f x) (mapExt f e) - EOp x op e -> EOp (f x) op (mapExt f e) - ECustom x s t p a b c e1 e2 -> ECustom (f x) s t p (mapExt f a) (mapExt f b) (mapExt f c) (mapExt f e1) (mapExt f e2) - EWith x t e1 e2 -> EWith (f x) t (mapExt f e1) (mapExt f e2) - EAccum x t p e1 e2 e3 -> EAccum (f x) t p (mapExt f e1) (mapExt f e2) (mapExt f e3) - EZero x t e -> EZero (f x) t (mapExt f e) - EPlus x t a b -> EPlus (f x) t (mapExt f a) (mapExt f b) - EOneHot x t p a b -> EOneHot (f x) t p (mapExt f a) (mapExt f b) - EError x t s -> EError (f x) t s +mapExt f = runIdentity . travExt (Identity . f) + +{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-} +travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t) +travExt f = \case + EVar x t i -> EVar <$> f x <*> pure t <*> pure i + ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body + EPair x a b -> EPair <$> f x <*> travExt f a <*> travExt f b + EFst x e -> EFst <$> f x <*> travExt f e + ESnd x e -> ESnd <$> f x <*> travExt f e + ENil x -> ENil <$> f x + EInl x t e -> EInl <$> f x <*> pure t <*> travExt f e + EInr x t e -> EInr <$> f x <*> pure t <*> travExt f e + ECase x e a b -> ECase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b + ENothing x t -> ENothing <$> f x <*> pure t + EJust x e -> EJust <$> f x <*> travExt f e + EMaybe x a b e -> EMaybe <$> f x <*> travExt f a <*> travExt f b <*> travExt f e + ELNil x t1 t2 -> ELNil <$> f x <*> pure t1 <*> pure t2 + ELInl x t e -> ELInl <$> f x <*> pure t <*> travExt f e + ELInr x t e -> ELInr <$> f x <*> pure t <*> travExt f e + ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c + EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a + EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b + EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c + ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e + EUnit x e -> EUnit <$> f x <*> travExt f e + EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b + EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e + EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e + EConst x t v -> EConst <$> f x <*> pure t <*> pure v + EIdx0 x e -> EIdx0 <$> f x <*> travExt f e + EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b + EIdx x e es -> EIdx <$> f x <*> travExt f e <*> travExt f es + EShape x e -> EShape <$> f x <*> travExt f e + EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e + ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2 + EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2 + EAccum x t p e1 e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> travExt f e2 <*> travExt f e3 + EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e + EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b + EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b + EError x t s -> EError <$> f x <*> pure t <*> pure s substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t substInline repl = |