aboutsummaryrefslogtreecommitdiff
path: root/src/AST
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-03 23:09:37 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-03 23:10:23 +0100
commit81d88dbc430ca6ec8390636f8b7162887b390873 (patch)
tree849c126fad3b923c2e5b815aa5c8488907bc2318 /src/AST
parent2ca218d2e97e521bcc49dea8f4774737ba083ede (diff)
WIP map + zip
Diffstat (limited to 'src/AST')
-rw-r--r--src/AST/Count.hs49
-rw-r--r--src/AST/Pretty.hs15
-rw-r--r--src/AST/UnMonoid.hs2
-rw-r--r--src/AST/Weaken.hs2
4 files changed, 60 insertions, 8 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 296c021..bc02417 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -523,9 +523,8 @@ occCountX initialS topexpr k = case topexpr of
SsNone ->
occCountX SsFull a $ \env1 mka ->
occCountX SsNone b $ \env2'' mkb ->
- withSome (scaleMany (Some env2'')) $ \env2' ->
- occEnvPop' env2' $ \env2 s2 ->
- withSome (Some env1 <> Some env2) $ \env ->
+ occEnvPop' env2'' $ \env2' s2 ->
+ withSome (Some env1 <> scaleMany (Some env2')) $ \env ->
k env $ \env' ->
use (EBuild ext n (mka env') $
use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
@@ -535,14 +534,31 @@ occCountX initialS topexpr k = case topexpr of
SsArr' s' ->
occCountX SsFull a $ \env1 mka ->
occCountX s' b $ \env2'' mkb ->
- withSome (scaleMany (Some env2'')) $ \env2' ->
- occEnvPop' env2' $ \env2 s2 ->
- withSome (Some env1 <> Some env2) $ \env ->
+ occEnvPop' env2'' $ \env2' s2 ->
+ withSome (Some env1 <> scaleMany (Some env2')) $ \env ->
k env $ \env' ->
EBuild ext n (mka env') $
elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))
+ EMap _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1'' mka ->
+ occEnvPop' env1'' $ \env1' s1 ->
+ occCountX (SsArr s1) b $ \env2 mkb ->
+ withSome (scaleMany (Some env1') <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $
+ ENil ext
+ SsArr' s' ->
+ occCountX s' a $ \env1'' mka ->
+ occEnvPop' env1'' $ \env1' s1 ->
+ occCountX (SsArr s1) b $ \env2 mkb ->
+ withSome (scaleMany (Some env1') <> Some env2) $ \env ->
+ k env $ \env' ->
+ EMap ext (mka (OccPush env' () s1)) (mkb env')
+
EFold1Inner _ commut a b c ->
occCountX SsFull a $ \env1''' mka ->
withSome (scaleMany (Some env1''')) $ \env1'' ->
@@ -608,6 +624,27 @@ occCountX initialS topexpr k = case topexpr of
k env $ \env' ->
EReshape ext n (mkesh env') (mke env')
+ EZip _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ SsArr' SsNone ->
+ occCountX (SsArr SsNone) a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mkb env') $ mka env'
+ SsArr' (SsPair' s1 s2) ->
+ occCountX (SsArr s1) a $ \env1 mka ->
+ occCountX (SsArr s2) b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EZip ext (mka env') (mkb env')
+
EFold1InnerD1 _ cm e1 e2 e3 ->
case s of
-- If nothing is necessary, we can execute a fold and then proceed to ignore it
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 68fc629..2c51b85 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -204,6 +204,14 @@ ppExpr' d val expr = case expr of
<> hardline <> e')
(ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e'])
+ EMap _ a b -> do
+ let STArr _ t1 = typeOf b
+ name <- genNameIfUsedIn' "i" t1 IZ a
+ a' <- ppExpr' 0 (Const name `SCons` val) a
+ b' <- ppExpr' 11 val b
+ return $ ppParen (d > 0) $
+ ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b']
+
EFold1Inner _ cm a b c -> do
name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
name2 <- genNameIfUsedIn (typeOf a) IZ a
@@ -238,7 +246,12 @@ ppExpr' d val expr = case expr of
EReshape _ n esh e -> do
esh' <- ppExpr' 11 val esh
e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $ ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr <+> esh' <+> e'
+ return $ ppParen (d > 10) $ ppApp (ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr) [esh', e']
+
+ EZip _ e1 e2 -> do
+ e1' <- ppExpr' 11 val e1
+ e2' <- ppExpr' 11 val e2
+ return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2']
EFold1InnerD1 _ cm a b c -> do
name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index a22b73f..1712ba5 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -38,6 +38,7 @@ unMonoid = \case
ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)
EConstArr _ n t x -> EConstArr ext n t x
EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
+ EMap _ a b -> EMap ext (unMonoid a) (unMonoid b)
EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
EUnit _ e -> EUnit ext (unMonoid e)
@@ -45,6 +46,7 @@ unMonoid = \case
EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b)
+ EZip _ a b -> EZip ext (unMonoid a) (unMonoid b)
EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
EConst _ t x -> EConst ext t x
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
index 3a97fd1..f0820b8 100644
--- a/src/AST/Weaken.hs
+++ b/src/AST/Weaken.hs
@@ -129,7 +129,7 @@ wCopies bs w =
let bs' = slistMap (\_ -> Const ()) bs
in WStack bs' bs' WId w
-wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env
+wRaiseAbove :: SList f env1 -> proxy env -> env1 :> Append env1 env
wRaiseAbove SNil _ = WClosed
wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env)