From 81d88dbc430ca6ec8390636f8b7162887b390873 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 3 Nov 2025 23:09:37 +0100 Subject: WIP map + zip --- src/AST/Count.hs | 49 +++++++++++++++++++++++++++++++++++++++++++------ src/AST/Pretty.hs | 15 ++++++++++++++- src/AST/UnMonoid.hs | 2 ++ src/AST/Weaken.hs | 2 +- 4 files changed, 60 insertions(+), 8 deletions(-) (limited to 'src/AST') diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 296c021..bc02417 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -523,9 +523,8 @@ occCountX initialS topexpr k = case topexpr of SsNone -> occCountX SsFull a $ \env1 mka -> occCountX SsNone b $ \env2'' mkb -> - withSome (scaleMany (Some env2'')) $ \env2' -> - occEnvPop' env2' $ \env2 s2 -> - withSome (Some env1 <> Some env2) $ \env -> + occEnvPop' env2'' $ \env2' s2 -> + withSome (Some env1 <> scaleMany (Some env2')) $ \env -> k env $ \env' -> use (EBuild ext n (mka env') $ use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ @@ -535,14 +534,31 @@ occCountX initialS topexpr k = case topexpr of SsArr' s' -> occCountX SsFull a $ \env1 mka -> occCountX s' b $ \env2'' mkb -> - withSome (scaleMany (Some env2'')) $ \env2' -> - occEnvPop' env2' $ \env2 s2 -> - withSome (Some env1 <> Some env2) $ \env -> + occEnvPop' env2'' $ \env2' s2 -> + withSome (Some env1 <> scaleMany (Some env2')) $ \env -> k env $ \env' -> EBuild ext n (mka env') $ elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ weakenExpr (WCopy WSink) (mkb (OccPush env' () s2)) + EMap _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1 -> + occCountX (SsArr s1) b $ \env2 mkb -> + withSome (scaleMany (Some env1') <> Some env2) $ \env -> + k env $ \env' -> + use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $ + ENil ext + SsArr' s' -> + occCountX s' a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1 -> + occCountX (SsArr s1) b $ \env2 mkb -> + withSome (scaleMany (Some env1') <> Some env2) $ \env -> + k env $ \env' -> + EMap ext (mka (OccPush env' () s1)) (mkb env') + EFold1Inner _ commut a b c -> occCountX SsFull a $ \env1''' mka -> withSome (scaleMany (Some env1''')) $ \env1'' -> @@ -608,6 +624,27 @@ occCountX initialS topexpr k = case topexpr of k env $ \env' -> EReshape ext n (mkesh env') (mke env') + EZip _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' SsNone -> + occCountX (SsArr SsNone) a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkb env') $ mka env' + SsArr' (SsPair' s1 s2) -> + occCountX (SsArr s1) a $ \env1 mka -> + occCountX (SsArr s2) b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EZip ext (mka env') (mkb env') + EFold1InnerD1 _ cm e1 e2 e3 -> case s of -- If nothing is necessary, we can execute a fold and then proceed to ignore it diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 68fc629..2c51b85 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -204,6 +204,14 @@ ppExpr' d val expr = case expr of <> hardline <> e') (ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e']) + EMap _ a b -> do + let STArr _ t1 = typeOf b + name <- genNameIfUsedIn' "i" t1 IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a + b' <- ppExpr' 11 val b + return $ ppParen (d > 0) $ + ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b'] + EFold1Inner _ cm a b c -> do name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a name2 <- genNameIfUsedIn (typeOf a) IZ a @@ -238,7 +246,12 @@ ppExpr' d val expr = case expr of EReshape _ n esh e -> do esh' <- ppExpr' 11 val esh e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr <+> esh' <+> e' + return $ ppParen (d > 10) $ ppApp (ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr) [esh', e'] + + EZip _ e1 e2 -> do + e1' <- ppExpr' 11 val e1 + e2' <- ppExpr' 11 val e2 + return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2'] EFold1InnerD1 _ cm a b c -> do name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index a22b73f..1712ba5 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -38,6 +38,7 @@ unMonoid = \case ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c) EConstArr _ n t x -> EConstArr ext n t x EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) + EMap _ a b -> EMap ext (unMonoid a) (unMonoid b) EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c) ESum1Inner _ e -> ESum1Inner ext (unMonoid e) EUnit _ e -> EUnit ext (unMonoid e) @@ -45,6 +46,7 @@ unMonoid = \case EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) + EZip _ a b -> EZip ext (unMonoid a) (unMonoid b) EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) EConst _ t x -> EConst ext t x diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index 3a97fd1..f0820b8 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -129,7 +129,7 @@ wCopies bs w = let bs' = slistMap (\_ -> Const ()) bs in WStack bs' bs' WId w -wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env +wRaiseAbove :: SList f env1 -> proxy env -> env1 :> Append env1 env wRaiseAbove SNil _ = WClosed wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env) -- cgit v1.2.3-70-g09d2