aboutsummaryrefslogtreecommitdiff
path: root/src/AST.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-10-30 15:56:35 +0100
committerTom Smeding <tom@tomsmeding.com>2025-10-30 15:56:35 +0100
commit4d456e4d34b1e4fb3725051d1b8a0c376b704692 (patch)
tree1385217efcc0b58ddb028e707e6a5a36b884ed65 /src/AST.hs
parent0e8e59c5f9af547cf1b79b9bae892e32700ace56 (diff)
Implement reshape
Diffstat (limited to 'src/AST.hs')
-rw-r--r--src/AST.hs5
1 files changed, 5 insertions, 0 deletions
diff --git a/src/AST.hs b/src/AST.hs
index f7b63cf..7549ff0 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -69,6 +69,7 @@ data Expr x env t where
EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)
EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
+ EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t)
-- MapAccum-like (is it real mapaccum? If so, rename)
EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative
@@ -233,6 +234,7 @@ typeOf = \case
EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t
EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t
EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb)
EFold1InnerD2 _ _ _ e2 _ e4 _ | t2 <- typeOf e2, STArr sn _ <- typeOf e4 -> STPair t2 (STArr sn t2)
@@ -283,6 +285,7 @@ extOf = \case
EReplicate1Inner x _ _ -> x
EMaximum1Inner x _ -> x
EMinimum1Inner x _ -> x
+ EReshape x _ _ _ -> x
EFold1InnerD1 x _ _ _ _ -> x
EFold1InnerD2 x _ _ _ _ _ _ -> x
EConst x _ _ -> x
@@ -331,6 +334,7 @@ travExt f = \case
EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b
EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e
EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e
+ EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b
EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f d <*> travExt f e
EConst x t v -> EConst <$> f x <*> pure t <*> pure v
@@ -392,6 +396,7 @@ subst' f w = \case
EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b)
EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e)
EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)
+ EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b)
EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) c) (subst' f w d) (subst' f w e)
EConst x t v -> EConst x t v