diff options
Diffstat (limited to 'src/AST/Count.hs')
-rw-r--r-- | src/AST/Count.hs | 68 |
1 files changed, 36 insertions, 32 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs index a4ff9f2..39d26c2 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -36,6 +36,10 @@ data Occ = Occ { _occLexical :: Count deriving (Eq, Generic) deriving (Semigroup, Monoid) via Generically Occ +instance Show Occ where + showsPrec d (Occ l r) = showParen (d > 10) $ + showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r + -- | One of the two branches is taken (<||>) :: Occ -> Occ -> Occ Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2) @@ -47,9 +51,8 @@ scaleMany (Occ l _) = Occ l Many occCount :: Idx env a -> Expr x env t -> Occ occCount idx = getConst . occCountGeneral - (\i o -> if idx2int i == idx2int idx then Const o else mempty) + (\w i o -> if idx2int i == idx2int (w @> idx) then Const o else mempty) (\(Const o) -> Const o) - (\_ (Const o) -> Const o) (\(Const o1) (Const o2) -> Const (o1 <||> o2)) (\(Const o) -> Const (scaleMany o)) @@ -84,47 +87,48 @@ occEnvPop (OccPush o _) = o occEnvPop OccEnd = OccEnd occCountAll :: Expr x env t -> OccEnv env -occCountAll = occCountGeneral onehotOccEnv occEnvPop occEnvPopN (<||>!) scaleManyOccEnv - where - occEnvPopN :: SNat n -> OccEnv (ConsN n TIx env) -> OccEnv env - occEnvPopN _ OccEnd = OccEnd - occEnvPopN SZ e = e - occEnvPopN (SS n) (OccPush e _) = occEnvPopN n e +occCountAll = occCountGeneral (const onehotOccEnv) occEnvPop (<||>!) scaleManyOccEnv occCountGeneral :: forall r env t x. (forall env'. Monoid (r env')) - => (forall env' a. Idx env' a -> Occ -> r env') -- ^ one-hot + => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env') -- ^ one-hot -> (forall env' a. r (a : env') -> r env') -- ^ unpush - -> (forall env' n. SNat n -> r (ConsN n TIx env') -> r env') -- ^ unpushN -> (forall env'. r env' -> r env' -> r env') -- ^ alternation -> (forall env'. r env' -> r env') -- ^ scale-many -> Expr x env t -> r env -occCountGeneral onehot unpush unpushN alter many = go +occCountGeneral onehot unpush alter many = go WId where - go :: Monoid (r env') => Expr x env' t' -> r env' - go = \case - EVar _ _ i -> onehot i (Occ One One) - ELet _ rhs body -> go rhs <> unpush (go body) - EPair _ a b -> go a <> go b - EFst _ e -> go e - ESnd _ e -> go e + go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env' + go w = \case + EVar _ _ i -> onehot w i (Occ One One) + ELet _ rhs body -> re rhs <> re1 body + EPair _ a b -> re a <> re b + EFst _ e -> re e + ESnd _ e -> re e ENil _ -> mempty - EInl _ _ e -> go e - EInr _ _ e -> go e - ECase _ e a b -> go e <> (unpush (go a) `alter` unpush (go b)) - EBuild1 _ a b -> go a <> many (unpush (go b)) - EBuild _ n a b -> go a <> many (unpushN n (go b)) - EFold1 _ a b -> many (unpush (unpush (go a))) <> go b - EUnit _ e -> go e - EReplicate _ e -> go e + EInl _ _ e -> re e + EInr _ _ e -> re e + ECase _ e a b -> re e <> (re1 a `alter` re1 b) + EBuild1 _ a b -> re a <> many (re1 b) + EBuild _ _ a b -> re a <> many (re1 b) + EFold1 _ a b -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b + EUnit _ e -> re e + -- EReplicate _ e -> re e EConst{} -> mempty - EIdx0 _ e -> go e - EIdx1 _ a b -> go a <> go b - EIdx _ e es -> go e <> foldMap go es - EOp _ _ e -> go e - EWith a b -> go a <> unpush (go b) - EAccum1 a b e -> go a <> go b <> go e + EIdx0 _ e -> re e + EIdx1 _ a b -> re a <> re b + EIdx _ _ a b -> re a <> re b + EShape _ e -> re e + EOp _ _ e -> re e + EWith a b -> re a <> re1 b + EAccum _ a b e -> re a <> re b <> re e EError{} -> mempty + where + re :: Monoid (r env') => Expr x env' t'' -> r env' + re = go w + + re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env' + re1 = unpush . go (WSink .> w) deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r |