summaryrefslogtreecommitdiff
path: root/src/AST
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-09-05 12:12:57 +0200
committerTom Smeding <tom@tomsmeding.com>2024-09-05 12:12:57 +0200
commitff8aa61cfa28f9a8b2b599b7ca6ed9f404d7b377 (patch)
treefd1a4a7cae714f3922c43dda03d53479477a1d83 /src/AST
parent5ffb110bb5382b31c1acd3910b2064b36eeb2f77 (diff)
Generic accumulators
Diffstat (limited to 'src/AST')
-rw-r--r--src/AST/Count.hs68
-rw-r--r--src/AST/Pretty.hs62
2 files changed, 65 insertions, 65 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index a4ff9f2..39d26c2 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -36,6 +36,10 @@ data Occ = Occ { _occLexical :: Count
deriving (Eq, Generic)
deriving (Semigroup, Monoid) via Generically Occ
+instance Show Occ where
+ showsPrec d (Occ l r) = showParen (d > 10) $
+ showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r
+
-- | One of the two branches is taken
(<||>) :: Occ -> Occ -> Occ
Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)
@@ -47,9 +51,8 @@ scaleMany (Occ l _) = Occ l Many
occCount :: Idx env a -> Expr x env t -> Occ
occCount idx =
getConst . occCountGeneral
- (\i o -> if idx2int i == idx2int idx then Const o else mempty)
+ (\w i o -> if idx2int i == idx2int (w @> idx) then Const o else mempty)
(\(Const o) -> Const o)
- (\_ (Const o) -> Const o)
(\(Const o1) (Const o2) -> Const (o1 <||> o2))
(\(Const o) -> Const (scaleMany o))
@@ -84,47 +87,48 @@ occEnvPop (OccPush o _) = o
occEnvPop OccEnd = OccEnd
occCountAll :: Expr x env t -> OccEnv env
-occCountAll = occCountGeneral onehotOccEnv occEnvPop occEnvPopN (<||>!) scaleManyOccEnv
- where
- occEnvPopN :: SNat n -> OccEnv (ConsN n TIx env) -> OccEnv env
- occEnvPopN _ OccEnd = OccEnd
- occEnvPopN SZ e = e
- occEnvPopN (SS n) (OccPush e _) = occEnvPopN n e
+occCountAll = occCountGeneral (const onehotOccEnv) occEnvPop (<||>!) scaleManyOccEnv
occCountGeneral :: forall r env t x.
(forall env'. Monoid (r env'))
- => (forall env' a. Idx env' a -> Occ -> r env') -- ^ one-hot
+ => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env') -- ^ one-hot
-> (forall env' a. r (a : env') -> r env') -- ^ unpush
- -> (forall env' n. SNat n -> r (ConsN n TIx env') -> r env') -- ^ unpushN
-> (forall env'. r env' -> r env' -> r env') -- ^ alternation
-> (forall env'. r env' -> r env') -- ^ scale-many
-> Expr x env t -> r env
-occCountGeneral onehot unpush unpushN alter many = go
+occCountGeneral onehot unpush alter many = go WId
where
- go :: Monoid (r env') => Expr x env' t' -> r env'
- go = \case
- EVar _ _ i -> onehot i (Occ One One)
- ELet _ rhs body -> go rhs <> unpush (go body)
- EPair _ a b -> go a <> go b
- EFst _ e -> go e
- ESnd _ e -> go e
+ go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env'
+ go w = \case
+ EVar _ _ i -> onehot w i (Occ One One)
+ ELet _ rhs body -> re rhs <> re1 body
+ EPair _ a b -> re a <> re b
+ EFst _ e -> re e
+ ESnd _ e -> re e
ENil _ -> mempty
- EInl _ _ e -> go e
- EInr _ _ e -> go e
- ECase _ e a b -> go e <> (unpush (go a) `alter` unpush (go b))
- EBuild1 _ a b -> go a <> many (unpush (go b))
- EBuild _ n a b -> go a <> many (unpushN n (go b))
- EFold1 _ a b -> many (unpush (unpush (go a))) <> go b
- EUnit _ e -> go e
- EReplicate _ e -> go e
+ EInl _ _ e -> re e
+ EInr _ _ e -> re e
+ ECase _ e a b -> re e <> (re1 a `alter` re1 b)
+ EBuild1 _ a b -> re a <> many (re1 b)
+ EBuild _ _ a b -> re a <> many (re1 b)
+ EFold1 _ a b -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b
+ EUnit _ e -> re e
+ -- EReplicate _ e -> re e
EConst{} -> mempty
- EIdx0 _ e -> go e
- EIdx1 _ a b -> go a <> go b
- EIdx _ e es -> go e <> foldMap go es
- EOp _ _ e -> go e
- EWith a b -> go a <> unpush (go b)
- EAccum1 a b e -> go a <> go b <> go e
+ EIdx0 _ e -> re e
+ EIdx1 _ a b -> re a <> re b
+ EIdx _ _ a b -> re a <> re b
+ EShape _ e -> re e
+ EOp _ _ e -> re e
+ EWith a b -> re a <> re1 b
+ EAccum _ a b e -> re a <> re b <> re e
EError{} -> mempty
+ where
+ re :: Monoid (r env') => Expr x env' t'' -> r env'
+ re = go w
+
+ re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env'
+ re1 = unpush . go (WSink .> w)
deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index dbbc021..5610d36 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -1,16 +1,15 @@
-{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE EmptyCase #-}
-{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE TupleSections #-}
-module AST.Pretty where
+{-# LANGUAGE TypeOperators #-}
+module AST.Pretty (ppExpr) where
import Control.Monad (ap)
import Data.List (intersperse)
-import Data.Foldable (toList)
import Data.Functor.Const
import AST
@@ -29,10 +28,6 @@ valprj (VPush x _) IZ = x
valprj (VPush _ env) (IS i) = valprj env i
valprj VTop i = case i of {}
-vpushN :: Vec n a -> Val (Const a) env -> Val (Const a) (ConsN n TIx env)
-vpushN VNil v = v
-vpushN (name :< names) v = VPush (Const name) (vpushN names v)
-
newtype M a = M { runM :: Int -> (a, Int) }
deriving (Functor)
instance Applicative M where { pure x = M (\i -> (x, i)) ; (<*>) = ap }
@@ -115,12 +110,10 @@ ppExpr' d val = \case
EBuild _ n a b -> do
a' <- ppExpr' 11 val a
- names <- sequence (vecGenerate n (\_ -> genName)) -- TODO generate underscores
- e' <- ppExpr' 0 (vpushN names val) b
+ name <- genNameIfUsedIn (tTup (sreplicate n tIx)) IZ b
+ e' <- ppExpr' 0 (VPush (Const name) val) b
return $ showParen (d > 10) $
- showString "build " . a' . showString " (\\["
- . foldr (.) id (intersperse (showString ",") (map showString (reverse (toList names))))
- . showString ("] -> ") . e' . showString ")"
+ showString "build " . a' . showString (" (\\" ++ name ++ " -> ") . e' . showString ")"
EFold1 _ a b -> do
name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
@@ -135,9 +128,9 @@ ppExpr' d val = \case
e' <- ppExpr' 11 val e
return $ showParen (d > 10) $ showString "unit " . e'
- EReplicate _ e -> do
- e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "replicate " . e'
+ -- EReplicate _ e -> do
+ -- e' <- ppExpr' 11 val e
+ -- return $ showParen (d > 10) $ showString "replicate " . e'
EConst _ ty v -> return $ showString $ case ty of
STI32 -> show v ; STI64 -> show v ; STF32 -> show v ; STF64 -> show v ; STBool -> show v
@@ -151,14 +144,15 @@ ppExpr' d val = \case
b' <- ppExpr' 9 val b
return $ showParen (d > 8) $ a' . showString " ! " . b'
- EIdx _ e es -> do
- e' <- ppExpr' 9 val e
- es' <- traverse (ppExpr' 0 val) es
+ EIdx _ _ a b -> do
+ a' <- ppExpr' 9 val a
+ b' <- ppExpr' 10 val b
return $ showParen (d > 8) $
- e' . showString " ! "
- . showString "["
- . foldr (.) id (intersperse (showString ", ") (reverse (toList es')))
- . showString "]"
+ a' . showString " !! " . b'
+
+ EShape _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ showParen (d > 10) $ showString "shape " . e'
EOp _ op (EPair _ a b)
| (Infix, ops) <- operator op -> do
@@ -175,30 +169,30 @@ ppExpr' d val = \case
EWith e1 e2 -> do
e1' <- ppExpr' 11 val e1
- let STArr n t = typeOf e1
- name <- genNameIfUsedIn' "ac" (STAccum n t) IZ e2
- e2' <- ppExpr' 11 (VPush (Const name) val) e2
+ name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2
+ e2' <- ppExpr' 0 (VPush (Const name) val) e2
return $ showParen (d > 10) $
showString "with " . e1' . showString (" (\\" ++ name ++ " -> ")
. e2' . showString ")"
- EAccum1 e1 e2 e3 -> do
+ EAccum i e1 e2 e3 -> do
e1' <- ppExpr' 11 val e1
e2' <- ppExpr' 11 val e2
e3' <- ppExpr' 11 val e3
return $ showParen (d > 10) $
- showString "accum1 " . e1' . showString " " . e2' . showString " " . e3'
+ showString ("accum " ++ show (unSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3'
EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)
ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS
ppExprLet d val etop = do
- let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS)
+ let collect :: SVal env -> Expr x env t -> M ([(String, Occ, ShowS)], ShowS)
collect val' (ELet _ rhs body) = do
+ let occ = occCount IZ body
name <- genNameIfUsedIn (typeOf rhs) IZ body
rhs' <- ppExpr' 0 val' rhs
(binds, core) <- collect (VPush (Const name) val') body
- return ((name, rhs') : binds, core)
+ return ((name, occ, rhs') : binds, core)
collect val' e = ([],) <$> ppExpr' 0 val' e
(binds, core) <- collect val etop
@@ -210,7 +204,9 @@ ppExprLet d val etop = do
showString ("let " ++ open)
. foldr (.) id
(intersperse (showString " ; ")
- (map (\(name, rhs) -> showString (name ++ " = ") . rhs) binds))
+ (map (\(name, _occ, rhs) ->
+ showString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") . rhs)
+ binds))
. showString (close ++ " in ")
. core