From 4d456e4d34b1e4fb3725051d1b8a0c376b704692 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 30 Oct 2025 15:56:35 +0100 Subject: Implement reshape --- src/AST/Count.hs | 15 +++++++++++++++ src/AST/Pretty.hs | 5 +++++ src/AST/SplitLets.hs | 1 + src/AST/UnMonoid.hs | 1 + 4 files changed, 22 insertions(+) (limited to 'src/AST') diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 229661f..66b4e0b 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -598,6 +598,21 @@ occCountX initialS topexpr k = case topexpr of EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e + EReshape _ n esh e -> + case s of + SsNone -> + occCountX SsNone esh $ \env1 mkesh -> + occCountX SsNone e $ \env2 mke -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkesh env') $ use (mke env') $ ENil ext + SsArr' s' -> + occCountX SsFull esh $ \env1 mkesh -> + occCountX (SsArr s') e $ \env2 mke -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EReshape ext n (mkesh env') (mke env') + EFold1InnerD1 _ cm e1 e2 e3 -> case s of -- If nothing is necessary, we can execute a fold and then proceed to ignore it diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 587328d..67197f9 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -235,6 +235,11 @@ ppExpr' d val expr = case expr of e' <- ppExpr' 11 val e return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e' + EReshape _ n esh e -> do + esh' <- ppExpr' 11 val esh + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr <+> esh' <+> e' + EFold1InnerD1 _ cm a b c -> do name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a name2 <- genNameIfUsedIn (typeOf b) IZ a diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 6034084..73c1c67 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -61,6 +61,7 @@ splitLets' = \sub -> \case EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b) EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e) EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e) + EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (splitLets' sub e) EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b) diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 6904715..e5a9708 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -44,6 +44,7 @@ unMonoid = \case EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b) EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) + EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) EFold1InnerD2 _ cm a b c d e -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid d) (unMonoid e) EConst _ t x -> EConst ext t x -- cgit v1.2.3-70-g09d2