diff options
Diffstat (limited to 'src/AST')
| -rw-r--r-- | src/AST/Count.hs | 68 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 34 | ||||
| -rw-r--r-- | src/AST/SplitLets.hs | 36 | ||||
| -rw-r--r-- | src/AST/UnMonoid.hs | 2 |
4 files changed, 119 insertions, 21 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs index bec5a9d..cb363a3 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -356,7 +356,14 @@ pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env) fullOccEnv SNil = OccEnd fullOccEnv (_ `SCons` e) = OccPush (fullOccEnv e) () SsFull --- * s: how much of the result is required +-- In one traversal, count occurrences of variables and determine what parts of +-- expressions are actually used. These two results are computed independently: +-- even if (almost) nothing of a particular term is actually used, variable +-- references in that term still count as usual. +-- +-- In @occCountX s t k@: +-- * s: how much of the result of this term is required +-- * t: the term to analyse -- * k: is passed the actual environment usage of this expression, including -- occurrence counts. The callback reconstructs a new expression in an -- updated "response" environment. The response must be at least as large as @@ -434,8 +441,7 @@ occCountX initialS topexpr k = case topexpr of occEnvPop' env1' $ \env1 s1 -> occEnvPop' env2' $ \env2 s2 -> occCountX (SsEither s1 s2) e $ \env0 mke -> - withSome (Some env1 <||> Some env2) $ \env12 -> - withSome (Some env12 <> Some env0) $ \env -> + withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> k env $ \env' -> ECase ext (mke env') (mka (OccPush env' () s1)) (mkb (OccPush env' () s2)) ENothing _ t -> @@ -459,8 +465,7 @@ occCountX initialS topexpr k = case topexpr of occCountX s b $ \env2' mkb -> occEnvPop' env2' $ \env2 s2 -> occCountX (SsMaybe s2) e $ \env0 mke -> - withSome (Some env1 <||> Some env2) $ \env12 -> - withSome (Some env12 <> Some env0) $ \env -> + withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> k env $ \env' -> EMaybe ext (mka env') (mkb (OccPush env' () s2)) (mke env') ELNil _ t1 t2 -> @@ -497,9 +502,7 @@ occCountX initialS topexpr k = case topexpr of occEnvPop' env2' $ \env2 s1 -> occEnvPop' env3' $ \env3 s2 -> occCountX (SsLEither s1 s2) e $ \env0 mke -> - withSome (Some env1 <||> Some env2) $ \env12 -> - withSome (Some env12 <||> Some env3) $ \env123 -> - withSome (Some env123 <> Some env0) $ \env -> + withSome (Some env0 <> (Some env1 <||> Some env2 <||> Some env3)) $ \env -> k env $ \env' -> ELCase ext (mke env') (mka env') (mkb (OccPush env' () s1)) (mkc (OccPush env' () s2)) @@ -550,16 +553,11 @@ occCountX initialS topexpr k = case topexpr of occCountX (SsArr sElt) c $ \env3 mkc -> withSome (Some env1 <> Some env2 <> Some env3) $ \env -> k env $ \env' -> - let expr = EFold1Inner ext commut - (projectSmallerSubstruc SsFull sElt $ - mka (OccPush (OccPush env' () sElt) () sElt)) - (mkb env') (mkc env') in - case s of - SsNone -> use expr $ ENil ext - SsArr s' -> projectSmallerSubstruc (SsArr sElt) (SsArr s') expr - SsFull -> case testEquality sElt SsFull of - Just Refl -> expr - Nothing -> error "unreachable" + projectSmallerSubstruc (SsArr sElt) s $ + EFold1Inner ext commut + (projectSmallerSubstruc SsFull sElt $ + mka (OccPush (OccPush env' () sElt) () sElt)) + (mkb env') (mkc env') ESum1Inner _ e -> handleReduction (ESum1Inner ext) e @@ -594,6 +592,40 @@ occCountX initialS topexpr k = case topexpr of EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e + EFold1InnerD1 _ cm a b c -> + -- TODO: currently maximally pessimistic on usage here; in particular, + -- usage tracking of the 'tape' stores can be improved + occCountX SsFull a $ \env1_2' mka -> + withSome (scaleMany (Some env1_2')) $ \env1_2 -> + occEnvPop' env1_2 $ \env1_1 _ -> + occEnvPop' env1_1 $ \env1 _ -> + occCountX SsFull b $ \env2 mkb -> + occCountX SsFull c $ \env3 mkc -> + withSome (Some env1 <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s $ + EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull)) + (mkb env') (mkc env') + + EFold1InnerD2 _ cm t2 ef ep ezi ebog ed -> + -- TODO: currently very pessimistic on usage here, can at the very least improve tape usage + occCountX SsFull ef $ \env1_4' mkef -> + withSome (scaleMany (Some env1_4')) $ \env1_4 -> + occEnvPop' env1_4 $ \env1_3 _ -> + occEnvPop' env1_3 $ \env1_2 _ -> + occEnvPop' env1_2 $ \env1_1 _ -> + occEnvPop' env1_1 $ \env1 sTape -> + occCountX SsFull ep $ \env2 mkep -> + occCountX SsFull ezi $ \env3 mkezi -> + occCountX (SsArr (SsPair SsFull sTape)) ebog $ \env4 mkebog -> + occCountX SsFull ed $ \env5 mked -> + withSome (Some env1 <> Some env2 <> Some env3 <> Some env4 <> Some env5) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s $ + EFold1InnerD2 ext cm t2 + (mkef (OccPush (OccPush (OccPush (OccPush env' () sTape) () SsFull) () SsFull) () SsFull)) + (mkep env') (mkezi env') (mkebog env') (mked env') + EConst _ t x -> k OccEnd $ \_ -> case s of diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 9018602..afa62c6 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -72,6 +72,7 @@ genNameIfUsedIn' prefix ty idx ex _ -> return "_" | otherwise = genName' prefix +-- TODO: let this return a type-tagged thing so that name environments are more typed than Const genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t @@ -209,8 +210,7 @@ ppExpr' d val expr = case expr of a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a b' <- ppExpr' 11 val b c' <- ppExpr' 11 val c - let opname = case cm of Commut -> "fold1i(C)" - Noncommut -> "fold1i" + let opname = "fold1i" ++ ppCommut cm return $ ppParen (d > 10) $ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] @@ -235,6 +235,32 @@ ppExpr' d val expr = case expr of e' <- ppExpr' 11 val e return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e' + EFold1InnerD1 _ cm a b c -> do + name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a + name2 <- genNameIfUsedIn (typeOf b) IZ a + a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a + b' <- ppExpr' 11 val b + c' <- ppExpr' 11 val c + let opname = "fold1iD1" ++ ppCommut cm + return $ ppParen (d > 10) $ + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] + + EFold1InnerD2 _ cm t2 ef ep ezi ebog ed -> do + let STArr _ (STPair t1 ttape) = typeOf ebog + name1 <- genNameIfUsedIn ttape (IS (IS (IS IZ))) ef + name2 <- genNameIfUsedIn t1 (IS (IS IZ)) ef + name3 <- genNameIfUsedIn t1 (IS IZ) ef + name4 <- genNameIfUsedIn (fromSMTy t2) IZ ef + ef' <- ppExpr' 0 (Const name4 `SCons` Const name3 `SCons` Const name2 `SCons` Const name1 `SCons` val) ef + ep' <- ppExpr' 11 val ep + ezi' <- ppExpr' 11 val ezi + ebog' <- ppExpr' 11 val ebog + ed' <- ppExpr' 11 val ed + let opname = "fold1iD2" ++ ppCommut cm + return $ ppParen (d > 10) $ + ppApp (annotate AHighlight (ppString opname) <> ppX expr) + [ppLam [ppString name1, ppString name2, ppString name3, ppString name4] ef', ep', ezi', ebog', ed'] + EConst _ ty v | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr @@ -386,6 +412,10 @@ ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s ppSparse (SMTScal _) SpScal = "." +ppCommut :: Commutative -> String +ppCommut Commut = "(C)" +ppCommut Noncommut = "" + ppX :: PrettyX x => Expr x env t -> ADoc ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 82ec1d6..f75e795 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -22,7 +22,7 @@ splitLets = splitLets' (\t i w -> EVar ext t (w @> i)) splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t splitLets' = \sub -> \case EVar _ t i -> sub t i WId - ELet _ (rhs :: Ex env t1) body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) + ELet _ rhs body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) ECase x e a b -> let STEither t1 t2 = typeOf e in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b) @@ -35,6 +35,13 @@ splitLets' = \sub -> \case EFold1Inner x cm a b c -> let STArr _ t1 = typeOf c in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) + EFold1InnerD1 x cm a b c -> + let STArr _ t1 = typeOf c + in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) + EFold1InnerD2 x cm t2 a b c d e -> + let STArr _ t1 = typeOf b + STArr _ (STPair _ ttape) = typeOf d + in EFold1InnerD2 x cm t2 (split4 sub ttape t1 t1 (fromSMTy t2) a) (splitLets' sub b) (splitLets' sub c) (splitLets' sub d) (splitLets' sub e) EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b) EFst x e -> EFst x (splitLets' sub e) @@ -98,6 +105,33 @@ splitLets' = \sub -> \case t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w))))) body + -- TODO: abstract this to splitN lol wtf + split4 :: forall bind1 bind2 bind3 bind4 env' env t. + (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) + -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t + split4 sub tbind1 tbind2 tbind3 tbind4 body = + let (ptrs1, bs1') = split @env' tbind1 + (ptrs2, bs2') = split @(bind1 : env') tbind2 + (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3 + (ptrs4, bs4) = split @(bind3 : bind2 : bind1 : env') tbind4 + bs1 = fst (weakenBindingsE (WSink .> WSink .> WSink) bs1') + bs2 = fst (weakenBindingsE (WSink .> WSink) bs2') + bs3 = fst (weakenBindingsE WSink bs3') + b1 = bindingsBinds bs1 + b2 = bindingsBinds bs2 + b3 = bindingsBinds bs3 + b4 = bindingsBinds bs4 + in letBinds bs1 $ + letBinds (fst (weakenBindingsE ( sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs2)) $ + letBinds (fst (weakenBindingsE ( sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs3)) $ + letBinds (fst (weakenBindingsE (sinkWithBindings bs3 .> sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs4)) $ + splitLets' (\cases _ IZ w -> subPointers ptrs4 (w .> wCopies b4 (wSinks b3 .> wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1)) + _ (IS IZ) w -> subPointers ptrs3 (w .> wSinks b4 .> wCopies b3 (wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink)) + _ (IS (IS IZ)) w -> subPointers ptrs2 (w .> wSinks b4 .> wSinks b3 .> wCopies b2 (wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink .> WSink)) + _ (IS (IS (IS IZ))) w -> subPointers ptrs1 (w .> wSinks b4 .> wSinks b3 .> wSinks b2 .> wCopies b1 (WSink @bind4 .> WSink @bind3 .> WSink @bind2 @(bind1 : env'))) + t (IS (IS (IS (IS i)))) w -> sub t i (WPop @bind1 (WPop @bind2 (WPop @bind3 (WPop @bind4 (wPops b1 (wPops b2 (wPops b3 (wPops b4 w))))))))) + body + type family Split t where Split (TPair a b) = SplitRec (TPair a b) Split _ = '[] diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 48dd709..555f0ec 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -44,6 +44,8 @@ unMonoid = \case EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b) EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) + EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) + EFold1InnerD2 _ cm t2 a b c d e -> EFold1InnerD2 ext cm t2 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid d) (unMonoid e) EConst _ t x -> EConst ext t x EIdx0 _ e -> EIdx0 ext (unMonoid e) EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) |
