summaryrefslogtreecommitdiff
path: root/src/AST/Count.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/Count.hs')
-rw-r--r--src/AST/Count.hs68
1 files changed, 36 insertions, 32 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index a4ff9f2..39d26c2 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -36,6 +36,10 @@ data Occ = Occ { _occLexical :: Count
deriving (Eq, Generic)
deriving (Semigroup, Monoid) via Generically Occ
+instance Show Occ where
+ showsPrec d (Occ l r) = showParen (d > 10) $
+ showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r
+
-- | One of the two branches is taken
(<||>) :: Occ -> Occ -> Occ
Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)
@@ -47,9 +51,8 @@ scaleMany (Occ l _) = Occ l Many
occCount :: Idx env a -> Expr x env t -> Occ
occCount idx =
getConst . occCountGeneral
- (\i o -> if idx2int i == idx2int idx then Const o else mempty)
+ (\w i o -> if idx2int i == idx2int (w @> idx) then Const o else mempty)
(\(Const o) -> Const o)
- (\_ (Const o) -> Const o)
(\(Const o1) (Const o2) -> Const (o1 <||> o2))
(\(Const o) -> Const (scaleMany o))
@@ -84,47 +87,48 @@ occEnvPop (OccPush o _) = o
occEnvPop OccEnd = OccEnd
occCountAll :: Expr x env t -> OccEnv env
-occCountAll = occCountGeneral onehotOccEnv occEnvPop occEnvPopN (<||>!) scaleManyOccEnv
- where
- occEnvPopN :: SNat n -> OccEnv (ConsN n TIx env) -> OccEnv env
- occEnvPopN _ OccEnd = OccEnd
- occEnvPopN SZ e = e
- occEnvPopN (SS n) (OccPush e _) = occEnvPopN n e
+occCountAll = occCountGeneral (const onehotOccEnv) occEnvPop (<||>!) scaleManyOccEnv
occCountGeneral :: forall r env t x.
(forall env'. Monoid (r env'))
- => (forall env' a. Idx env' a -> Occ -> r env') -- ^ one-hot
+ => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env') -- ^ one-hot
-> (forall env' a. r (a : env') -> r env') -- ^ unpush
- -> (forall env' n. SNat n -> r (ConsN n TIx env') -> r env') -- ^ unpushN
-> (forall env'. r env' -> r env' -> r env') -- ^ alternation
-> (forall env'. r env' -> r env') -- ^ scale-many
-> Expr x env t -> r env
-occCountGeneral onehot unpush unpushN alter many = go
+occCountGeneral onehot unpush alter many = go WId
where
- go :: Monoid (r env') => Expr x env' t' -> r env'
- go = \case
- EVar _ _ i -> onehot i (Occ One One)
- ELet _ rhs body -> go rhs <> unpush (go body)
- EPair _ a b -> go a <> go b
- EFst _ e -> go e
- ESnd _ e -> go e
+ go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env'
+ go w = \case
+ EVar _ _ i -> onehot w i (Occ One One)
+ ELet _ rhs body -> re rhs <> re1 body
+ EPair _ a b -> re a <> re b
+ EFst _ e -> re e
+ ESnd _ e -> re e
ENil _ -> mempty
- EInl _ _ e -> go e
- EInr _ _ e -> go e
- ECase _ e a b -> go e <> (unpush (go a) `alter` unpush (go b))
- EBuild1 _ a b -> go a <> many (unpush (go b))
- EBuild _ n a b -> go a <> many (unpushN n (go b))
- EFold1 _ a b -> many (unpush (unpush (go a))) <> go b
- EUnit _ e -> go e
- EReplicate _ e -> go e
+ EInl _ _ e -> re e
+ EInr _ _ e -> re e
+ ECase _ e a b -> re e <> (re1 a `alter` re1 b)
+ EBuild1 _ a b -> re a <> many (re1 b)
+ EBuild _ _ a b -> re a <> many (re1 b)
+ EFold1 _ a b -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b
+ EUnit _ e -> re e
+ -- EReplicate _ e -> re e
EConst{} -> mempty
- EIdx0 _ e -> go e
- EIdx1 _ a b -> go a <> go b
- EIdx _ e es -> go e <> foldMap go es
- EOp _ _ e -> go e
- EWith a b -> go a <> unpush (go b)
- EAccum1 a b e -> go a <> go b <> go e
+ EIdx0 _ e -> re e
+ EIdx1 _ a b -> re a <> re b
+ EIdx _ _ a b -> re a <> re b
+ EShape _ e -> re e
+ EOp _ _ e -> re e
+ EWith a b -> re a <> re1 b
+ EAccum _ a b e -> re a <> re b <> re e
EError{} -> mempty
+ where
+ re :: Monoid (r env') => Expr x env' t'' -> r env'
+ re = go w
+
+ re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env'
+ re1 = unpush . go (WSink .> w)
deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r