summaryrefslogtreecommitdiff
path: root/src/AST.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-28 22:40:41 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-28 22:40:41 +0100
commitc06b4bd71a94601d467b509a26c08020d1fbd794 (patch)
treeb16981c769231ef4af2c3ec5f002a01f857d95c6 /src/AST.hs
parenta3ba3bdc5c2f9606a0b98cdf53183841cca07eac (diff)
Pass around an accumMap (but it's empty still)
Diffstat (limited to 'src/AST.hs')
-rw-r--r--src/AST.hs45
1 files changed, 36 insertions, 9 deletions
diff --git a/src/AST.hs b/src/AST.hs
index c8377de..652d003 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -246,6 +246,42 @@ extOf = \case
EOneHot x _ _ _ _ -> x
EError x _ _ -> x
+mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t
+mapExt f = \case
+ EVar x t i -> EVar (f x) t i
+ ELet x rhs body -> ELet (f x) (mapExt f rhs) (mapExt f body)
+ EPair x a b -> EPair (f x) (mapExt f a) (mapExt f b)
+ EFst x e -> EFst (f x) (mapExt f e)
+ ESnd x e -> ESnd (f x) (mapExt f e)
+ ENil x -> ENil (f x)
+ EInl x t e -> EInl (f x) t (mapExt f e)
+ EInr x t e -> EInr (f x) t (mapExt f e)
+ ECase x e a b -> ECase (f x) (mapExt f e) (mapExt f a) (mapExt f b)
+ ENothing x t -> ENothing (f x) t
+ EJust x e -> EJust (f x) (mapExt f e)
+ EMaybe x a b e -> EMaybe (f x) (mapExt f a) (mapExt f b) (mapExt f e)
+ EConstArr x n t a -> EConstArr (f x) n t a
+ EBuild x n a b -> EBuild (f x) n (mapExt f a) (mapExt f b)
+ EFold1Inner x cm a b c -> EFold1Inner (f x) cm (mapExt f a) (mapExt f b) (mapExt f c)
+ ESum1Inner x e -> ESum1Inner (f x) (mapExt f e)
+ EUnit x e -> EUnit (f x) (mapExt f e)
+ EReplicate1Inner x a b -> EReplicate1Inner (f x) (mapExt f a) (mapExt f b)
+ EMaximum1Inner x e -> EMaximum1Inner (f x) (mapExt f e)
+ EMinimum1Inner x e -> EMinimum1Inner (f x) (mapExt f e)
+ EConst x t v -> EConst (f x) t v
+ EIdx0 x e -> EIdx0 (f x) (mapExt f e)
+ EIdx1 x a b -> EIdx1 (f x) (mapExt f a) (mapExt f b)
+ EIdx x e es -> EIdx (f x) (mapExt f e) (mapExt f es)
+ EShape x e -> EShape (f x) (mapExt f e)
+ EOp x op e -> EOp (f x) op (mapExt f e)
+ ECustom x s t p a b c e1 e2 -> ECustom (f x) s t p (mapExt f a) (mapExt f b) (mapExt f c) (mapExt f e1) (mapExt f e2)
+ EWith x t e1 e2 -> EWith (f x) t (mapExt f e1) (mapExt f e2)
+ EAccum x t p e1 e2 e3 -> EAccum (f x) t p (mapExt f e1) (mapExt f e2) (mapExt f e3)
+ EZero x t -> EZero (f x) t
+ EPlus x t a b -> EPlus (f x) t (mapExt f a) (mapExt f b)
+ EOneHot x t p a b -> EOneHot (f x) t p (mapExt f a) (mapExt f b)
+ EError x t s -> EError (f x) t s
+
subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t
subst1 repl = subst $ \x t -> \case IZ -> repl
IS i -> EVar x t i
@@ -302,15 +338,6 @@ subst' f w = \case
weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i))
-slistIdx :: SList f list -> Idx list t -> f t
-slistIdx (SCons x _) IZ = x
-slistIdx (SCons _ list) (IS i) = slistIdx list i
-slistIdx SNil i = case i of {}
-
-idx2int :: Idx env t -> Int
-idx2int IZ = 0
-idx2int (IS n) = 1 + idx2int n
-
class KnownScalTy t where knownScalTy :: SScalTy t
instance KnownScalTy TI32 where knownScalTy = STI32
instance KnownScalTy TI64 where knownScalTy = STI64