{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
-- | Counting how often variables occur in an expression, both in the program
-- text and (approximately) at runtime.
module AST.Count where

import Data.Functor.Const
import GHC.Generics (Generic, Generically(..))

import AST
import Data


-- | A saturating count: none, exactly one, or more than one.
data Count = Zero | One | Many
  deriving (Show, Eq, Ord)

-- | Addition of counts; anything beyond one saturates to 'Many'.
instance Semigroup Count where
  Zero <> n = n
  n <> Zero = n
  _ <> _ = Many

instance Monoid Count where
  mempty = Zero

-- | Occurrence information for a variable.
--
-- Sequential composition ('<>', derived via 'Generically') adds both
-- components pointwise.
data Occ = Occ
  { _occLexical :: Count
    -- ^ How often the variable occurs in the program text.
  , _occRuntime :: Count
    -- ^ How often the variable is used at runtime: of two alternatives only
    -- the larger one counts (see '<||>'), and repeated code turns any use
    -- into 'Many' (see 'scaleMany').
  }
  deriving (Eq, Generic)
  deriving (Semigroup, Monoid) via Generically Occ

-- | Combine the occurrences of two alternatives of which only one is taken
-- (for example the two branches of a case). Lexical occurrences add up; at
-- runtime only the branch with the most uses counts.
(<||>) :: Occ -> Occ -> Occ
Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)

-- | Account for code that is executed many times. Lexical occurrences are
-- unchanged and any runtime use becomes 'Many'.
--
-- A variable with no runtime uses stays at 'Zero': zero uses repeated many
-- times is still zero uses. Without this, every variable of the environment
-- that does not occur in a build or fold body would be reported as used
-- 'Many' times, and explicit zero entries in an 'OccEnv' would disagree with
-- 'OccEnd' (which 'scaleManyOccEnv' leaves untouched).
scaleMany :: Occ -> Occ
scaleMany (Occ l Zero) = Occ l Zero
scaleMany (Occ l _) = Occ l Many

-- | Count the occurrences of one particular variable of the environment in an
-- expression.
--
-- This reads the variable's entry out of 'occCountAll'. Its result is
-- indexed by the environment, so binders inside the expression (let, case
-- branches, build and fold bodies, ...) correctly shift the variable's de
-- Bruijn index. Comparing raw indices of the inner and outer environments
-- would confuse the variable with locally bound ones.
occCount :: Idx env a -> Expr x env t -> Occ
occCount idx expr = lookupOcc idx (occCountAll expr)
  where
    -- The entry for a variable; variables past 'OccEnd' have no occurrences.
    lookupOcc :: Idx env' b -> OccEnv env' -> Occ
    lookupOcc _ OccEnd = mempty
    lookupOcc IZ (OccPush _ o) = o
    lookupOcc (IS i) (OccPush e _) = lookupOcc i e

-- | Occurrence information for the variables of an environment. The entry
-- pushed last belongs to the variable at index 'IZ'.
data OccEnv env where
  -- | No occurrences of any (remaining) variable. This can terminate an
  -- environment of any length; it is not necessarily the empty one.
  OccEnd :: OccEnv env
  -- | Occurrences of the variable at index 'IZ', on top of the rest.
  OccPush :: OccEnv env -> Occ -> OccEnv (t : env)

-- | Pointwise sequential composition; 'OccEnd' is the identity.
instance Semigroup (OccEnv env) where
  OccEnd <> e = e
  e <> OccEnd = e
  OccPush e o <> OccPush e' o' = OccPush (e <> e') (o <> o')

instance Monoid (OccEnv env) where
  mempty = OccEnd

-- | An environment in which only the given variable has the given
-- occurrences; all other variables have none.
onehotOccEnv :: Idx env t -> Occ -> OccEnv env
onehotOccEnv IZ v = OccPush OccEnd v
onehotOccEnv (IS i) v = OccPush (onehotOccEnv i v) mempty

-- | Pointwise '<||>' on occurrence environments. 'OccEnd' (no occurrences)
-- is the identity, since @Occ Zero Zero <||> o == o@.
(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env
OccEnd <||>! e = e
e <||>! OccEnd = e
OccPush e o <||>! OccPush e' o' = OccPush (e <||>! e') (o <||> o')

-- | Pointwise 'scaleMany' on occurrence environments.
scaleManyOccEnv :: OccEnv env -> OccEnv env
scaleManyOccEnv OccEnd = OccEnd
scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o)

-- | Count the occurrences of every variable of the environment in an
-- expression.
occCountAll :: Expr x env t -> OccEnv env
occCountAll = occCountGeneral onehotOccEnv unpush unpushN (<||>!) scaleManyOccEnv
  where
    -- Drop the entry of the variable at index 'IZ' (the one bound by the
    -- binder being left).
    unpush :: OccEnv (t : env) -> OccEnv env
    unpush (OccPush e _) = e
    unpush OccEnd = OccEnd

    -- Drop the entries of the @n@ innermost variables.
    unpushN :: SNat n -> OccEnv (ConsN n TIx env) -> OccEnv env
    unpushN _ OccEnd = OccEnd
    unpushN SZ e = e
    unpushN (SS n) (OccPush e _) = unpushN n e

-- | Generic occurrence-counting traversal of an expression. The result type
-- @r@ is indexed by the environment, and the arguments say how results are
-- built and combined:
--
-- * subexpressions that are all evaluated are combined with '<>';
-- * the two branches of a case are combined with the alternation argument;
-- * the bodies of builds and of the fold function are passed through the
--   scale-many argument;
-- * results for a body under binders are brought back to the enclosing
--   environment with the unpush / unpushN arguments.
occCountGeneral
  :: forall r env t x.
     (forall env'. Monoid (r env'))
  => (forall env' a. Idx env' a -> Occ -> r env')
     -- ^ one-hot: the given variable has the given occurrences
  -> (forall env' a. r (a : env') -> r env')
     -- ^ unpush: leave the scope of one binder
  -> (forall env' n. SNat n -> r (ConsN n TIx env') -> r env')
     -- ^ unpushN: leave the scope of @n@ binders
  -> (forall env'. r env' -> r env' -> r env')
     -- ^ alternation: only one of the two is executed
  -> (forall env'. r env' -> r env')
     -- ^ scale-many: this code is executed many times
  -> Expr x env t
  -> r env
occCountGeneral onehot unpush unpushN alter many = go
  where
    -- Polymorphic in the environment: binders extend it, so the recursion
    -- is at different environment indices.
    go :: Monoid (r env') => Expr x env' t' -> r env'
    go = \case
      EVar _ _ i -> onehot i (Occ One One)
      ELet _ rhs body -> go rhs <> unpush (go body)
      EPair _ a b -> go a <> go b
      EFst _ e -> go e
      ESnd _ e -> go e
      ENil _ -> mempty
      EInl _ _ e -> go e
      EInr _ _ e -> go e
      ECase _ e a b -> go e <> (unpush (go a) `alter` unpush (go b))
      EBuild1 _ a b -> go a <> many (unpush (go b))
      EBuild _ es e -> foldMap go es <> many (unpushN (vecLength es) (go e))
      EFold1 _ a b -> many (unpush (unpush (go a))) <> go b
      EConst{} -> mempty
      EIdx0 _ e -> go e
      EIdx1 _ a b -> go a <> go b
      EIdx _ e es -> go e <> foldMap go es
      EOp _ _ e -> go e
      EWith a b -> go a <> unpush (go b)
      EAccum a b e -> go a <> go b <> go e
      EError{} -> mempty