diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-09-05 12:12:57 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-09-05 12:12:57 +0200 |
commit | ff8aa61cfa28f9a8b2b599b7ca6ed9f404d7b377 (patch) | |
tree | fd1a4a7cae714f3922c43dda03d53479477a1d83 /src/AST | |
parent | 5ffb110bb5382b31c1acd3910b2064b36eeb2f77 (diff) |
Generic accumulators
Diffstat (limited to 'src/AST')
-rw-r--r-- | src/AST/Count.hs | 68 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 62 |
2 files changed, 65 insertions, 65 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs index a4ff9f2..39d26c2 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -36,6 +36,10 @@ data Occ = Occ { _occLexical :: Count deriving (Eq, Generic) deriving (Semigroup, Monoid) via Generically Occ +instance Show Occ where + showsPrec d (Occ l r) = showParen (d > 10) $ + showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r + -- | One of the two branches is taken (<||>) :: Occ -> Occ -> Occ Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2) @@ -47,9 +51,8 @@ scaleMany (Occ l _) = Occ l Many occCount :: Idx env a -> Expr x env t -> Occ occCount idx = getConst . occCountGeneral - (\i o -> if idx2int i == idx2int idx then Const o else mempty) + (\w i o -> if idx2int i == idx2int (w @> idx) then Const o else mempty) (\(Const o) -> Const o) - (\_ (Const o) -> Const o) (\(Const o1) (Const o2) -> Const (o1 <||> o2)) (\(Const o) -> Const (scaleMany o)) @@ -84,47 +87,48 @@ occEnvPop (OccPush o _) = o occEnvPop OccEnd = OccEnd occCountAll :: Expr x env t -> OccEnv env -occCountAll = occCountGeneral onehotOccEnv occEnvPop occEnvPopN (<||>!) scaleManyOccEnv - where - occEnvPopN :: SNat n -> OccEnv (ConsN n TIx env) -> OccEnv env - occEnvPopN _ OccEnd = OccEnd - occEnvPopN SZ e = e - occEnvPopN (SS n) (OccPush e _) = occEnvPopN n e +occCountAll = occCountGeneral (const onehotOccEnv) occEnvPop (<||>!) scaleManyOccEnv occCountGeneral :: forall r env t x. (forall env'. Monoid (r env')) - => (forall env' a. Idx env' a -> Occ -> r env') -- ^ one-hot + => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env') -- ^ one-hot -> (forall env' a. r (a : env') -> r env') -- ^ unpush - -> (forall env' n. SNat n -> r (ConsN n TIx env') -> r env') -- ^ unpushN -> (forall env'. r env' -> r env' -> r env') -- ^ alternation -> (forall env'. r env' -> r env') -- ^ scale-many -> Expr x env t -> r env -occCountGeneral onehot unpush unpushN alter many = go +occCountGeneral onehot unpush alter many = go WId where - go :: Monoid (r env') => Expr x env' t' -> r env' - go = \case - EVar _ _ i -> onehot i (Occ One One) - ELet _ rhs body -> go rhs <> unpush (go body) - EPair _ a b -> go a <> go b - EFst _ e -> go e - ESnd _ e -> go e + go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env' + go w = \case + EVar _ _ i -> onehot w i (Occ One One) + ELet _ rhs body -> re rhs <> re1 body + EPair _ a b -> re a <> re b + EFst _ e -> re e + ESnd _ e -> re e ENil _ -> mempty - EInl _ _ e -> go e - EInr _ _ e -> go e - ECase _ e a b -> go e <> (unpush (go a) `alter` unpush (go b)) - EBuild1 _ a b -> go a <> many (unpush (go b)) - EBuild _ n a b -> go a <> many (unpushN n (go b)) - EFold1 _ a b -> many (unpush (unpush (go a))) <> go b - EUnit _ e -> go e - EReplicate _ e -> go e + EInl _ _ e -> re e + EInr _ _ e -> re e + ECase _ e a b -> re e <> (re1 a `alter` re1 b) + EBuild1 _ a b -> re a <> many (re1 b) + EBuild _ _ a b -> re a <> many (re1 b) + EFold1 _ a b -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b + EUnit _ e -> re e + -- EReplicate _ e -> re e EConst{} -> mempty - EIdx0 _ e -> go e - EIdx1 _ a b -> go a <> go b - EIdx _ e es -> go e <> foldMap go es - EOp _ _ e -> go e - EWith a b -> go a <> unpush (go b) - EAccum1 a b e -> go a <> go b <> go e + EIdx0 _ e -> re e + EIdx1 _ a b -> re a <> re b + EIdx _ _ a b -> re a <> re b + EShape _ e -> re e + EOp _ _ e -> re e + EWith a b -> re a <> re1 b + EAccum _ a b e -> re a <> re b <> re e EError{} -> mempty + where + re :: Monoid (r env') => Expr x env' t'' -> r env' + re = go w + + re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env' + re1 = unpush . go (WSink .> w) deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index dbbc021..5610d36 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -1,16 +1,15 @@ -{-# LANGUAGE LambdaCase #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE EmptyCase #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE TupleSections #-} -module AST.Pretty where +{-# LANGUAGE TypeOperators #-} +module AST.Pretty (ppExpr) where import Control.Monad (ap) import Data.List (intersperse) -import Data.Foldable (toList) import Data.Functor.Const import AST @@ -29,10 +28,6 @@ valprj (VPush x _) IZ = x valprj (VPush _ env) (IS i) = valprj env i valprj VTop i = case i of {} -vpushN :: Vec n a -> Val (Const a) env -> Val (Const a) (ConsN n TIx env) -vpushN VNil v = v -vpushN (name :< names) v = VPush (Const name) (vpushN names v) - newtype M a = M { runM :: Int -> (a, Int) } deriving (Functor) instance Applicative M where { pure x = M (\i -> (x, i)) ; (<*>) = ap } @@ -115,12 +110,10 @@ ppExpr' d val = \case EBuild _ n a b -> do a' <- ppExpr' 11 val a - names <- sequence (vecGenerate n (\_ -> genName)) -- TODO generate underscores - e' <- ppExpr' 0 (vpushN names val) b + name <- genNameIfUsedIn (tTup (sreplicate n tIx)) IZ b + e' <- ppExpr' 0 (VPush (Const name) val) b return $ showParen (d > 10) $ - showString "build " . a' . showString " (\\[" - . foldr (.) id (intersperse (showString ",") (map showString (reverse (toList names)))) - . showString ("] -> ") . e' . showString ")" + showString "build " . a' . showString (" (\\" ++ name ++ " -> ") . e' . showString ")" EFold1 _ a b -> do name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a @@ -135,9 +128,9 @@ ppExpr' d val = \case e' <- ppExpr' 11 val e return $ showParen (d > 10) $ showString "unit " . e' - EReplicate _ e -> do - e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "replicate " . e' + -- EReplicate _ e -> do + -- e' <- ppExpr' 11 val e + -- return $ showParen (d > 10) $ showString "replicate " . e' EConst _ ty v -> return $ showString $ case ty of STI32 -> show v ; STI64 -> show v ; STF32 -> show v ; STF64 -> show v ; STBool -> show v @@ -151,14 +144,15 @@ ppExpr' d val = \case b' <- ppExpr' 9 val b return $ showParen (d > 8) $ a' . showString " ! " . b' - EIdx _ e es -> do - e' <- ppExpr' 9 val e - es' <- traverse (ppExpr' 0 val) es + EIdx _ _ a b -> do + a' <- ppExpr' 9 val a + b' <- ppExpr' 10 val b return $ showParen (d > 8) $ - e' . showString " ! " - . showString "[" - . foldr (.) id (intersperse (showString ", ") (reverse (toList es'))) - . showString "]" + a' . showString " !! " . b' + + EShape _ e -> do + e' <- ppExpr' 11 val e + return $ showParen (d > 10) $ showString "shape " . e' EOp _ op (EPair _ a b) | (Infix, ops) <- operator op -> do @@ -175,30 +169,30 @@ ppExpr' d val = \case EWith e1 e2 -> do e1' <- ppExpr' 11 val e1 - let STArr n t = typeOf e1 - name <- genNameIfUsedIn' "ac" (STAccum n t) IZ e2 - e2' <- ppExpr' 11 (VPush (Const name) val) e2 + name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2 + e2' <- ppExpr' 0 (VPush (Const name) val) e2 return $ showParen (d > 10) $ showString "with " . e1' . showString (" (\\" ++ name ++ " -> ") . e2' . showString ")" - EAccum1 e1 e2 e3 -> do + EAccum i e1 e2 e3 -> do e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 e3' <- ppExpr' 11 val e3 return $ showParen (d > 10) $ - showString "accum1 " . e1' . showString " " . e2' . showString " " . e3' + showString ("accum " ++ show (unSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3' EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s) ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS ppExprLet d val etop = do - let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS) + let collect :: SVal env -> Expr x env t -> M ([(String, Occ, ShowS)], ShowS) collect val' (ELet _ rhs body) = do + let occ = occCount IZ body name <- genNameIfUsedIn (typeOf rhs) IZ body rhs' <- ppExpr' 0 val' rhs (binds, core) <- collect (VPush (Const name) val') body - return ((name, rhs') : binds, core) + return ((name, occ, rhs') : binds, core) collect val' e = ([],) <$> ppExpr' 0 val' e (binds, core) <- collect val etop @@ -210,7 +204,9 @@ ppExprLet d val etop = do showString ("let " ++ open) . foldr (.) id (intersperse (showString " ; ") - (map (\(name, rhs) -> showString (name ++ " = ") . rhs) binds)) + (map (\(name, _occ, rhs) -> + showString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") . rhs) + binds)) . showString (close ++ " in ") . core |