{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
-- | Occurrence counting for de Bruijn-indexed 'Expr' terms.
--
-- For a given variable, 'occCount' reports how often it occurs in the
-- program text and a conservative estimate of how often it may be
-- evaluated at runtime.
module AST.Count where

import GHC.Generics (Generic, Generically(..))

import AST
import Data

-- | A saturating occurrence count: zero, exactly one, or more than one.
data Count = Zero | One | Many
  deriving stock (Show, Eq, Ord)

-- | Addition that saturates at 'Many'.
instance Semigroup Count where
  Zero <> n = n
  n <> Zero = n
  _ <> _ = Many

instance Monoid Count where
  mempty = Zero

-- | Occurrence information for a single variable.
data Occ = Occ
  { _occLexical :: Count
    -- ^ Number of occurrences in the program text.
  , _occRuntime :: Count
    -- ^ Conservative bound on how often an occurrence is evaluated at runtime.
  }
  deriving stock (Eq, Show, Generic)
  -- Field-wise '<>' and 'mempty': sequential composition adds both counts.
  deriving (Semigroup, Monoid) via Generically Occ

-- | One of the two branches is taken.
--
-- Lexical occurrences in both branches add up. At runtime only one branch
-- executes, so the larger of the two runtime counts is kept.
(<||>) :: Occ -> Occ -> Occ
Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)

-- | This code is executed many times (e.g. the body of a build or fold).
--
-- Any runtime use becomes 'Many'. A variable that is not used at all stays
-- unused, because zero uses repeated many times is still zero. Lexical
-- counts are unaffected.
scaleMany :: Occ -> Occ
scaleMany (Occ l Zero) = Occ l Zero
scaleMany (Occ l _) = Occ l Many

-- | Count the occurrences of the variable @idx@ in an expression.
--
-- When descending under a binder, @idx@ is weakened (with 'IS', or with
-- 'wsinkN' for 'EBuild') so that it keeps referring to the same variable in
-- the extended environment. Bodies that may run repeatedly are wrapped in
-- 'scaleMany'. The two branches of a 'ECase' are combined with '<||>'.
occCount :: Idx env a -> Expr x env t -> Occ
occCount idx = \case
  -- The variable's type need not match @a@, so indices are compared as
  -- integers rather than as typed indices.
  EVar _ _ i
    | idx2int i == idx2int idx -> Occ One One
    | otherwise -> mempty
  ELet _ rhs body -> occCount idx rhs <> occCount (IS idx) body
  EPair _ a b -> occCount idx a <> occCount idx b
  EFst _ e -> occCount idx e
  ESnd _ e -> occCount idx e
  ENil _ -> mempty
  EInl _ _ e -> occCount idx e
  EInr _ _ e -> occCount idx e
  -- The scrutinee is always evaluated; each branch binds one variable and
  -- exactly one branch runs.
  ECase _ e a b -> occCount idx e <> (occCount (IS idx) a <||> occCount (IS idx) b)
  -- The body binds one variable and may be evaluated many times.
  EBuild1 _ a b -> occCount idx a <> scaleMany (occCount (IS idx) b)
  -- The body sees one extra binding per element of @es@; presumably
  -- 'wsinkN' shifts @idx@ past them (verify against its definition).
  EBuild _ es e -> foldMap (occCount idx) es <> scaleMany (occCount (wsinkN (vecLength es) @> idx) e)
  -- The combining function binds two variables and may run many times.
  EFold1 _ a b -> scaleMany (occCount (IS (IS idx)) a) <> occCount idx b
  EConst{} -> mempty
  EIdx1 _ a b -> occCount idx a <> occCount idx b
  EIdx _ e es -> occCount idx e <> foldMap (occCount idx) es
  EOp _ _ e -> occCount idx e
  EWith a b -> occCount idx a <> occCount (IS idx) b
  EAccum a b e -> occCount idx a <> occCount idx b <> occCount idx e
  EError{} -> mempty