aboutsummaryrefslogtreecommitdiff
path: root/src/AST
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-10-24 23:34:30 +0200
committerTom Smeding <tom@tomsmeding.com>2025-10-24 23:34:30 +0200
commit42176d4a8a0fe7954f17da5c0506721695aa477f (patch)
tree8a29e847faa613e9becf1bccdcaad010187e639b /src/AST
parent7729c45a325fe653421d654ed4c28b040585fce9 (diff)
WIP fold: everything but Compile (slow, but should be sound)
Diffstat (limited to 'src/AST')
-rw-r--r--src/AST/Count.hs68
-rw-r--r--src/AST/Pretty.hs34
-rw-r--r--src/AST/SplitLets.hs36
-rw-r--r--src/AST/UnMonoid.hs2
4 files changed, 119 insertions, 21 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index bec5a9d..cb363a3 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -356,7 +356,14 @@ pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env)
fullOccEnv SNil = OccEnd
fullOccEnv (_ `SCons` e) = OccPush (fullOccEnv e) () SsFull
--- * s: how much of the result is required
+-- In one traversal, count occurrences of variables and determine what parts of
+-- expressions are actually used. These two results are computed independently:
+-- even if (almost) nothing of a particular term is actually used, variable
+-- references in that term still count as usual.
+--
+-- In @occCountX s t k@:
+-- * s: how much of the result of this term is required
+-- * t: the term to analyse
-- * k: is passed the actual environment usage of this expression, including
-- occurrence counts. The callback reconstructs a new expression in an
-- updated "response" environment. The response must be at least as large as
@@ -434,8 +441,7 @@ occCountX initialS topexpr k = case topexpr of
occEnvPop' env1' $ \env1 s1 ->
occEnvPop' env2' $ \env2 s2 ->
occCountX (SsEither s1 s2) e $ \env0 mke ->
- withSome (Some env1 <||> Some env2) $ \env12 ->
- withSome (Some env12 <> Some env0) $ \env ->
+ withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env ->
k env $ \env' ->
ECase ext (mke env') (mka (OccPush env' () s1)) (mkb (OccPush env' () s2))
ENothing _ t ->
@@ -459,8 +465,7 @@ occCountX initialS topexpr k = case topexpr of
occCountX s b $ \env2' mkb ->
occEnvPop' env2' $ \env2 s2 ->
occCountX (SsMaybe s2) e $ \env0 mke ->
- withSome (Some env1 <||> Some env2) $ \env12 ->
- withSome (Some env12 <> Some env0) $ \env ->
+ withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env ->
k env $ \env' ->
EMaybe ext (mka env') (mkb (OccPush env' () s2)) (mke env')
ELNil _ t1 t2 ->
@@ -497,9 +502,7 @@ occCountX initialS topexpr k = case topexpr of
occEnvPop' env2' $ \env2 s1 ->
occEnvPop' env3' $ \env3 s2 ->
occCountX (SsLEither s1 s2) e $ \env0 mke ->
- withSome (Some env1 <||> Some env2) $ \env12 ->
- withSome (Some env12 <||> Some env3) $ \env123 ->
- withSome (Some env123 <> Some env0) $ \env ->
+ withSome (Some env0 <> (Some env1 <||> Some env2 <||> Some env3)) $ \env ->
k env $ \env' ->
ELCase ext (mke env') (mka env') (mkb (OccPush env' () s1)) (mkc (OccPush env' () s2))
@@ -550,16 +553,11 @@ occCountX initialS topexpr k = case topexpr of
occCountX (SsArr sElt) c $ \env3 mkc ->
withSome (Some env1 <> Some env2 <> Some env3) $ \env ->
k env $ \env' ->
- let expr = EFold1Inner ext commut
- (projectSmallerSubstruc SsFull sElt $
- mka (OccPush (OccPush env' () sElt) () sElt))
- (mkb env') (mkc env') in
- case s of
- SsNone -> use expr $ ENil ext
- SsArr s' -> projectSmallerSubstruc (SsArr sElt) (SsArr s') expr
- SsFull -> case testEquality sElt SsFull of
- Just Refl -> expr
- Nothing -> error "unreachable"
+ projectSmallerSubstruc (SsArr sElt) s $
+ EFold1Inner ext commut
+ (projectSmallerSubstruc SsFull sElt $
+ mka (OccPush (OccPush env' () sElt) () sElt))
+ (mkb env') (mkc env')
ESum1Inner _ e -> handleReduction (ESum1Inner ext) e
@@ -594,6 +592,40 @@ occCountX initialS topexpr k = case topexpr of
EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e
EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e
+ EFold1InnerD1 _ cm a b c ->
+ -- TODO: currently maximally pessimistic on usage here; in particular,
+ -- usage tracking of the 'tape' stores can be improved
+ occCountX SsFull a $ \env1_2' mka ->
+ withSome (scaleMany (Some env1_2')) $ \env1_2 ->
+ occEnvPop' env1_2 $ \env1_1 _ ->
+ occEnvPop' env1_1 $ \env1 _ ->
+ occCountX SsFull b $ \env2 mkb ->
+ occCountX SsFull c $ \env3 mkc ->
+ withSome (Some env1 <> Some env2 <> Some env3) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc SsFull s $
+ EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull))
+ (mkb env') (mkc env')
+
+ EFold1InnerD2 _ cm t2 ef ep ezi ebog ed ->
+ -- TODO: currently very pessimistic on usage here, can at the very least improve tape usage
+ occCountX SsFull ef $ \env1_4' mkef ->
+ withSome (scaleMany (Some env1_4')) $ \env1_4 ->
+ occEnvPop' env1_4 $ \env1_3 _ ->
+ occEnvPop' env1_3 $ \env1_2 _ ->
+ occEnvPop' env1_2 $ \env1_1 _ ->
+ occEnvPop' env1_1 $ \env1 sTape ->
+ occCountX SsFull ep $ \env2 mkep ->
+ occCountX SsFull ezi $ \env3 mkezi ->
+ occCountX (SsArr (SsPair SsFull sTape)) ebog $ \env4 mkebog ->
+ occCountX SsFull ed $ \env5 mked ->
+ withSome (Some env1 <> Some env2 <> Some env3 <> Some env4 <> Some env5) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc SsFull s $
+ EFold1InnerD2 ext cm t2
+ (mkef (OccPush (OccPush (OccPush (OccPush env' () sTape) () SsFull) () SsFull) () SsFull))
+ (mkep env') (mkezi env') (mkebog env') (mked env')
+
EConst _ t x ->
k OccEnd $ \_ ->
case s of
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 9018602..afa62c6 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -72,6 +72,7 @@ genNameIfUsedIn' prefix ty idx ex
_ -> return "_"
| otherwise = genName' prefix
+-- TODO: let this return a type-tagged thing so that name environments are more typed than Const
genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t
@@ -209,8 +210,7 @@ ppExpr' d val expr = case expr of
a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
b' <- ppExpr' 11 val b
c' <- ppExpr' 11 val c
- let opname = case cm of Commut -> "fold1i(C)"
- Noncommut -> "fold1i"
+ let opname = "fold1i" ++ ppCommut cm
return $ ppParen (d > 10) $
ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
@@ -235,6 +235,32 @@ ppExpr' d val expr = case expr of
e' <- ppExpr' 11 val e
return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e'
+ EFold1InnerD1 _ cm a b c -> do
+ name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a
+ name2 <- genNameIfUsedIn (typeOf b) IZ a
+ a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
+ b' <- ppExpr' 11 val b
+ c' <- ppExpr' 11 val c
+ let opname = "fold1iD1" ++ ppCommut cm
+ return $ ppParen (d > 10) $
+ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
+
+ EFold1InnerD2 _ cm t2 ef ep ezi ebog ed -> do
+ let STArr _ (STPair t1 ttape) = typeOf ebog
+ name1 <- genNameIfUsedIn ttape (IS (IS (IS IZ))) ef
+ name2 <- genNameIfUsedIn t1 (IS (IS IZ)) ef
+ name3 <- genNameIfUsedIn t1 (IS IZ) ef
+ name4 <- genNameIfUsedIn (fromSMTy t2) IZ ef
+ ef' <- ppExpr' 0 (Const name4 `SCons` Const name3 `SCons` Const name2 `SCons` Const name1 `SCons` val) ef
+ ep' <- ppExpr' 11 val ep
+ ezi' <- ppExpr' 11 val ezi
+ ebog' <- ppExpr' 11 val ebog
+ ed' <- ppExpr' 11 val ed
+ let opname = "fold1iD2" ++ ppCommut cm
+ return $ ppParen (d > 10) $
+ ppApp (annotate AHighlight (ppString opname) <> ppX expr)
+ [ppLam [ppString name1, ppString name2, ppString name3, ppString name4] ef', ep', ezi', ebog', ed']
+
EConst _ ty v
| Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr
@@ -386,6 +412,10 @@ ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s
ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s
ppSparse (SMTScal _) SpScal = "."
+ppCommut :: Commutative -> String
+ppCommut Commut = "(C)"
+ppCommut Noncommut = ""
+
ppX :: PrettyX x => Expr x env t -> ADoc
ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr)
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
index 82ec1d6..f75e795 100644
--- a/src/AST/SplitLets.hs
+++ b/src/AST/SplitLets.hs
@@ -22,7 +22,7 @@ splitLets = splitLets' (\t i w -> EVar ext t (w @> i))
splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t
splitLets' = \sub -> \case
EVar _ t i -> sub t i WId
- ELet _ (rhs :: Ex env t1) body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body)
+ ELet _ rhs body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body)
ECase x e a b ->
let STEither t1 t2 = typeOf e
in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b)
@@ -35,6 +35,13 @@ splitLets' = \sub -> \case
EFold1Inner x cm a b c ->
let STArr _ t1 = typeOf c
in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
+ EFold1InnerD1 x cm a b c ->
+ let STArr _ t1 = typeOf c
+ in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
+ EFold1InnerD2 x cm t2 a b c d e ->
+ let STArr _ t1 = typeOf b
+ STArr _ (STPair _ ttape) = typeOf d
+ in EFold1InnerD2 x cm t2 (split4 sub ttape t1 t1 (fromSMTy t2) a) (splitLets' sub b) (splitLets' sub c) (splitLets' sub d) (splitLets' sub e)
EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b)
EFst x e -> EFst x (splitLets' sub e)
@@ -98,6 +105,33 @@ splitLets' = \sub -> \case
t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w)))))
body
+ -- TODO: abstract this to splitN lol wtf
+ split4 :: forall bind1 bind2 bind3 bind4 env' env t.
+ (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
+ -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t
+ split4 sub tbind1 tbind2 tbind3 tbind4 body =
+ let (ptrs1, bs1') = split @env' tbind1
+ (ptrs2, bs2') = split @(bind1 : env') tbind2
+ (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3
+ (ptrs4, bs4) = split @(bind3 : bind2 : bind1 : env') tbind4
+ bs1 = fst (weakenBindingsE (WSink .> WSink .> WSink) bs1')
+ bs2 = fst (weakenBindingsE (WSink .> WSink) bs2')
+ bs3 = fst (weakenBindingsE WSink bs3')
+ b1 = bindingsBinds bs1
+ b2 = bindingsBinds bs2
+ b3 = bindingsBinds bs3
+ b4 = bindingsBinds bs4
+ in letBinds bs1 $
+ letBinds (fst (weakenBindingsE ( sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs2)) $
+ letBinds (fst (weakenBindingsE ( sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs3)) $
+ letBinds (fst (weakenBindingsE (sinkWithBindings bs3 .> sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs4)) $
+ splitLets' (\cases _ IZ w -> subPointers ptrs4 (w .> wCopies b4 (wSinks b3 .> wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1))
+ _ (IS IZ) w -> subPointers ptrs3 (w .> wSinks b4 .> wCopies b3 (wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink))
+ _ (IS (IS IZ)) w -> subPointers ptrs2 (w .> wSinks b4 .> wSinks b3 .> wCopies b2 (wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink .> WSink))
+ _ (IS (IS (IS IZ))) w -> subPointers ptrs1 (w .> wSinks b4 .> wSinks b3 .> wSinks b2 .> wCopies b1 (WSink @bind4 .> WSink @bind3 .> WSink @bind2 @(bind1 : env')))
+ t (IS (IS (IS (IS i)))) w -> sub t i (WPop @bind1 (WPop @bind2 (WPop @bind3 (WPop @bind4 (wPops b1 (wPops b2 (wPops b3 (wPops b4 w)))))))))
+ body
+
type family Split t where
Split (TPair a b) = SplitRec (TPair a b)
Split _ = '[]
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index 48dd709..555f0ec 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -44,6 +44,8 @@ unMonoid = \case
EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
+ EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
+ EFold1InnerD2 _ cm t2 a b c d e -> EFold1InnerD2 ext cm t2 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid d) (unMonoid e)
EConst _ t x -> EConst ext t x
EIdx0 _ e -> EIdx0 ext (unMonoid e)
EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)