{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
-- | Occurrence counting for 'Expr'.
--
-- For a variable we track two saturating counts: how many times it occurs
-- syntactically ("lexical"), and an upper bound on how many times an
-- occurrence may be evaluated ("runtime"). Code that is evaluated in sequence
-- adds both counts; for alternative branches the lexical counts add but only
-- the larger runtime count is kept; code that may run many times has its
-- runtime count saturated to 'Many'.
module AST.Count where

import Data.Functor.Const
import GHC.Generics (Generic, Generically(..))

import AST
import AST.Env
import Data


-- | A saturating count: zero, exactly one, or more than one.
data Count = Zero | One | Many
  deriving (Show, Eq, Ord)

-- | Saturating addition: 'Zero' is the identity, and combining two non-zero
-- counts yields 'Many'.
instance Semigroup Count where
  Zero <> n = n
  n <> Zero = n
  _ <> _ = Many

instance Monoid Count where
  mempty = Zero

-- | Occurrence information for a single variable. The 'Semigroup' instance
-- (componentwise, via 'Generically') combines code that is evaluated in
-- sequence.
data Occ = Occ
  { _occLexical :: Count  -- ^ number of syntactic occurrences
  , _occRuntime :: Count  -- ^ bound on how often an occurrence may be evaluated
  }
  deriving (Eq, Generic)
  deriving (Semigroup, Monoid) via Generically Occ

-- | Positional rendering (@Occ l r@), rather than the record syntax a derived
-- instance would produce.
instance Show Occ where
  showsPrec d (Occ l r) = showParen (d > 10) $
    showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r

-- | One of the two branches is taken: lexical occurrences add up, but at
-- runtime only the branch with the higher count matters.
(<||>) :: Occ -> Occ -> Occ
Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)

-- | This code is executed many times: any non-zero runtime count becomes
-- 'Many'. The lexical count is unchanged.
scaleMany :: Occ -> Occ
scaleMany (Occ l Zero) = Occ l Zero
scaleMany (Occ l _) = Occ l Many

-- | Occurrences of one particular variable in an expression.
--
-- At each occurrence the target index is weakened (via the accumulated
-- weakening @w@) into the occurrence's scope before the two are compared.
occCount :: Idx env a -> Expr x env t -> Occ
occCount idx =
  getConst . occCountGeneral
    (\w i o -> if idx2int i == idx2int (w @> idx) then Const o else mempty)
    (\(Const o) -> Const o)
    (\(Const o1) (Const o2) -> Const (o1 <||> o2))
    (\(Const o) -> Const (scaleMany o))

-- | Occurrence information for every variable in an environment.
--
-- 'OccPush' records the 'Occ' of the head of the environment list.
-- 'OccEnd' means "no occurrences for any remaining variable". It is the
-- 'mempty' of this type and may appear at any depth, not only at the very
-- top of the environment.
data OccEnv env where
  OccEnd :: OccEnv env  -- not necessarily top!
  OccPush :: OccEnv env -> Occ -> OccEnv (t : env)

-- | Pointwise sequential combination ('<>' on each 'Occ').
instance Semigroup (OccEnv env) where
  OccEnd <> e = e
  e <> OccEnd = e
  OccPush e o <> OccPush e' o' = OccPush (e <> e') (o <> o')

instance Monoid (OccEnv env) where
  mempty = OccEnd

-- | An environment in which only the given variable has the given
-- occurrence; every variable above it gets 'mempty'.
onehotOccEnv :: Idx env t -> Occ -> OccEnv env
onehotOccEnv IZ v = OccPush OccEnd v
onehotOccEnv (IS i) v = OccPush (onehotOccEnv i v) mempty

-- | Pointwise '<||>' on environments.
(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env
-- 'OccEnd' is all-'mempty', and @mempty <||> o == o@, so it is an identity.
OccEnd <||>! e = e
e <||>! OccEnd = e
OccPush e o <||>! OccPush e' o' = OccPush (e <||>! e') (o <||> o')

-- | Pointwise 'scaleMany' on environments.
scaleManyOccEnv :: OccEnv env -> OccEnv env
scaleManyOccEnv OccEnd = OccEnd
scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o)

-- | Drop the entry for the head variable of the environment.
occEnvPop :: OccEnv (t : env) -> OccEnv env
occEnvPop (OccPush o _) = o
occEnvPop OccEnd = OccEnd

-- | Occurrences of every free variable of an expression.
occCountAll :: Expr x env t -> OccEnv env
occCountAll = occCountGeneral (const onehotOccEnv) occEnvPop (<||>!) scaleManyOccEnv

-- | Generic occurrence-counting traversal over an 'Expr'. It is
-- parameterised over the result type @r@, which is indexed by the
-- environment.
--
-- Sub-results are combined as follows:
--
-- * '<>' for code that is evaluated in sequence;
-- * the /alternation/ argument for branches of which only one is taken
--   ('ECase', 'EMaybe');
-- * the /scale-many/ argument is applied to code that may run many times
--   (the bodies of 'EBuild' and 'EFold1Inner').
--
-- The /unpush/ argument removes the variable introduced by a binder from a
-- sub-result.
occCountGeneral :: forall r env t x.
                   (forall env'. Monoid (r env'))
                => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env')  -- ^ one-hot
                -> (forall env' a. r (a : env') -> r env')  -- ^ unpush
                -> (forall env'. r env' -> r env' -> r env')  -- ^ alternation
                -> (forall env'. r env' -> r env')  -- ^ scale-many
                -> Expr x env t -> r env
occCountGeneral onehot unpush alter many = go WId
  where
    -- @w@ weakens the root environment into the current scope; it is
    -- extended with 'WSink' every time we go under a binder.
    go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env'
    go w = \case
        EVar _ _ i -> onehot w i (Occ One One)
        ELet _ rhs body -> re rhs <> re1 body
        EPair _ a b -> re a <> re b
        EFst _ e -> re e
        ESnd _ e -> re e
        ENil _ -> mempty
        EInl _ _ e -> re e
        EInr _ _ e -> re e
        ECase _ e a b -> re e <> (re1 a `alter` re1 b)
        ENothing _ _ -> mempty
        EJust _ e -> re e
        -- @a@ and @b@ are the two alternative branches of the eliminator
        -- (only @b@ binds a variable). As with the arms of 'ECase', only one
        -- of them runs, so they are combined with @alter@ rather than '<>'.
        EMaybe _ a b e -> re e <> (re a `alter` re1 b)
        EConstArr{} -> mempty
        -- The body of a build (under one binder) counts as running many times.
        EBuild _ _ a b -> re a <> many (re1 b)
        -- The combiner @a@ binds two variables and counts as running many times.
        EFold1Inner _ a b c ->
          many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c
        ESum1Inner _ e -> re e
        EUnit _ e -> re e
        EReplicate1Inner _ a b -> re a <> re b
        EConst{} -> mempty
        EIdx0 _ e -> re e
        EIdx1 _ a b -> re a <> re b
        EIdx _ a b -> re a <> re b
        EShape _ e -> re e
        EOp _ _ e -> re e
        EWith a b -> re a <> re1 b
        EAccum _ a b e -> re a <> re b <> re e
        EZero _ -> mempty
        EPlus _ a b -> re a <> re b
        EOneHot _ _ a b -> re a <> re b
        EError{} -> mempty
      where
        -- Recurse into a subexpression in the same scope.
        re :: Monoid (r env') => Expr x env' t'' -> r env'
        re = go w

        -- Recurse into a subexpression under one binder, then discard the
        -- bound variable's entry.
        re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env'
        re1 = unpush . go (WSink .> w)

-- | Build a 'Subenv' that keeps exactly the variables whose runtime
-- occurrence count is non-'Zero', and pass it to the continuation.
-- Variables not covered by the 'OccEnv' (below an 'OccEnd') are dropped.
deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r
deleteUnused SNil OccEnd k = k SETop
deleteUnused (_ `SCons` env) OccEnd k =
  deleteUnused env OccEnd $ \sub -> k (SENo sub)
deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k =
  deleteUnused env occenv $ \sub ->
    case count of
      Zero -> k (SENo sub)
      _ -> k (SEYes sub)

-- | Move an expression into the smaller environment described by a
-- 'Subenv', renumbering its variables accordingly.
--
-- Calls 'error' if the expression references a variable that the 'Subenv'
-- drops.
unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t
unsafeWeakenWithSubenv = \sub ->
  subst (\x t i -> case sinkViaSubenv i sub of
            Just i' -> EVar x t i'
            Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away")
  where
    -- Translate an index through the subenvironment: kept ('SEYes') entries
    -- shift the index, dropped ('SENo') entries are skipped, and an index
    -- that lands on a dropped entry has no image.
    sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t)
    sinkViaSubenv IZ (SEYes _) = Just IZ
    sinkViaSubenv IZ (SENo _) = Nothing
    sinkViaSubenv (IS i) (SEYes sub) = IS <$> sinkViaSubenv i sub
    sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub