aboutsummaryrefslogtreecommitdiff
path: root/src/AST
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-10-30 15:56:35 +0100
committerTom Smeding <tom@tomsmeding.com>2025-10-30 15:56:35 +0100
commit4d456e4d34b1e4fb3725051d1b8a0c376b704692 (patch)
tree1385217efcc0b58ddb028e707e6a5a36b884ed65 /src/AST
parent0e8e59c5f9af547cf1b79b9bae892e32700ace56 (diff)
Implement reshape
Diffstat (limited to 'src/AST')
-rw-r--r--src/AST/Count.hs15
-rw-r--r--src/AST/Pretty.hs5
-rw-r--r--src/AST/SplitLets.hs1
-rw-r--r--src/AST/UnMonoid.hs1
4 files changed, 22 insertions, 0 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 229661f..66b4e0b 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -598,6 +598,21 @@ occCountX initialS topexpr k = case topexpr of
EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e
EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e
+ EReshape _ n esh e ->
+ case s of
+ SsNone ->
+ occCountX SsNone esh $ \env1 mkesh ->
+ occCountX SsNone e $ \env2 mke ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mkesh env') $ use (mke env') $ ENil ext
+ SsArr' s' ->
+ occCountX SsFull esh $ \env1 mkesh ->
+ occCountX (SsArr s') e $ \env2 mke ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EReshape ext n (mkesh env') (mke env')
+
EFold1InnerD1 _ cm e1 e2 e3 ->
case s of
-- If nothing is necessary, we can execute a fold and then proceed to ignore it
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 587328d..67197f9 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -235,6 +235,11 @@ ppExpr' d val expr = case expr of
e' <- ppExpr' 11 val e
return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e'
+ EReshape _ n esh e -> do
+ esh' <- ppExpr' 11 val esh
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr <+> esh' <+> e'
+
EFold1InnerD1 _ cm a b c -> do
name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a
name2 <- genNameIfUsedIn (typeOf b) IZ a
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
index 6034084..73c1c67 100644
--- a/src/AST/SplitLets.hs
+++ b/src/AST/SplitLets.hs
@@ -61,6 +61,7 @@ splitLets' = \sub -> \case
EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b)
EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e)
EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e)
+ EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (splitLets' sub e)
EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b)
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index 6904715..e5a9708 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -44,6 +44,7 @@ unMonoid = \case
EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
+ EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b)
EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
EFold1InnerD2 _ cm a b c d e -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid d) (unMonoid e)
EConst _ t x -> EConst ext t x