diff options
Diffstat (limited to 'src/AST/Count.hs')
| -rw-r--r-- | src/AST/Count.hs | 170 |
1 files changed, 0 insertions, 170 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs deleted file mode 100644 index ca4d7ab..0000000 --- a/src/AST/Count.hs +++ /dev/null @@ -1,170 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} -module AST.Count where - -import Data.Functor.Const -import GHC.Generics (Generic, Generically(..)) - -import AST -import AST.Env -import Data - - -data Count = Zero | One | Many - deriving (Show, Eq, Ord) - -instance Semigroup Count where - Zero <> n = n - n <> Zero = n - _ <> _ = Many -instance Monoid Count where - mempty = Zero - -data Occ = Occ { _occLexical :: Count - , _occRuntime :: Count } - deriving (Eq, Generic) - deriving (Semigroup, Monoid) via Generically Occ - -instance Show Occ where - showsPrec d (Occ l r) = showParen (d > 10) $ - showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r - --- | One of the two branches is taken -(<||>) :: Occ -> Occ -> Occ -Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2) - --- | This code is executed many times -scaleMany :: Occ -> Occ -scaleMany (Occ l Zero) = Occ l Zero -scaleMany (Occ l _) = Occ l Many - -occCount :: Idx env a -> Expr x env t -> Occ -occCount idx = - getConst . occCountGeneral - (\w i o -> if idx2int i == idx2int (w @> idx) then Const o else mempty) - (\(Const o) -> Const o) - (\(Const o1) (Const o2) -> Const (o1 <||> o2)) - (\(Const o) -> Const (scaleMany o)) - - -data OccEnv env where - OccEnd :: OccEnv env -- not necessarily top! - OccPush :: OccEnv env -> Occ -> OccEnv (t : env) - -instance Semigroup (OccEnv env) where - OccEnd <> e = e - e <> OccEnd = e - OccPush e o <> OccPush e' o' = OccPush (e <> e') (o <> o') - -instance Monoid (OccEnv env) where - mempty = OccEnd - -onehotOccEnv :: Idx env t -> Occ -> OccEnv env -onehotOccEnv IZ v = OccPush OccEnd v -onehotOccEnv (IS i) v = OccPush (onehotOccEnv i v) mempty - -(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env -OccEnd <||>! e = e -e <||>! OccEnd = e -OccPush e o <||>! OccPush e' o' = OccPush (e <||>! e') (o <||> o') - -scaleManyOccEnv :: OccEnv env -> OccEnv env -scaleManyOccEnv OccEnd = OccEnd -scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o) - -occEnvPop :: OccEnv (t : env) -> OccEnv env -occEnvPop (OccPush o _) = o -occEnvPop OccEnd = OccEnd - -occCountAll :: Expr x env t -> OccEnv env -occCountAll = occCountGeneral (const onehotOccEnv) occEnvPop (<||>!) scaleManyOccEnv - -occCountGeneral :: forall r env t x. - (forall env'. Monoid (r env')) - => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env') -- ^ one-hot - -> (forall env' a. r (a : env') -> r env') -- ^ unpush - -> (forall env'. r env' -> r env' -> r env') -- ^ alternation - -> (forall env'. r env' -> r env') -- ^ scale-many - -> Expr x env t -> r env -occCountGeneral onehot unpush alter many = go WId - where - go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env' - go w = \case - EVar _ _ i -> onehot w i (Occ One One) - ELet _ rhs body -> re rhs <> re1 body - EPair _ a b -> re a <> re b - EFst _ e -> re e - ESnd _ e -> re e - ENil _ -> mempty - EInl _ _ e -> re e - EInr _ _ e -> re e - ECase _ e a b -> re e <> (re1 a `alter` re1 b) - ENothing _ _ -> mempty - EJust _ e -> re e - EMaybe _ a b e -> re a <> re1 b <> re e - ELNil _ _ _ -> mempty - ELInl _ _ e -> re e - ELInr _ _ e -> re e - ELCase _ e a b c -> re e <> (re a `alter` re1 b `alter` re1 c) - EConstArr{} -> mempty - EBuild _ _ a b -> re a <> many (re1 b) - EFold1Inner _ _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c - ESum1Inner _ e -> re e - EUnit _ e -> re e - EReplicate1Inner _ a b -> re a <> re b - EMaximum1Inner _ e -> re e - EMinimum1Inner _ e -> re e - EConst{} -> mempty - EIdx0 _ e -> re e - EIdx1 _ a b -> re a <> re b - EIdx _ a b -> re a <> re b - EShape _ e -> re e - EOp _ _ e -> re e - ECustom _ _ _ _ _ _ _ a b -> re a <> re b - ERecompute _ e -> re e - EWith _ _ a b -> re a <> re1 b - EAccum _ _ _ a _ b e -> re a <> re b <> re e - EZero _ _ e -> re e - EDeepZero _ _ e -> re e - EPlus _ _ a b -> re a <> re b - EOneHot _ _ _ a b -> re a <> re b - EError{} -> mempty - where - re :: Monoid (r env') => Expr x env' t'' -> r env' - re = go w - - re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env' - re1 = unpush . go (WSink .> w) - - -deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r -deleteUnused SNil OccEnd k = k SETop -deleteUnused (_ `SCons` env) OccEnd k = - deleteUnused env OccEnd $ \sub -> k (SENo sub) -deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k = - deleteUnused env occenv $ \sub -> - case count of Zero -> k (SENo sub) - _ -> k (SEYesR sub) - -unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t -unsafeWeakenWithSubenv = \sub -> - subst (\x t i -> case sinkViaSubenv i sub of - Just i' -> EVar x t i' - Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away") - where - sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) - sinkViaSubenv IZ (SEYesR _) = Just IZ - sinkViaSubenv IZ (SENo _) = Nothing - sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub - sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub |
