aboutsummaryrefslogtreecommitdiff
path: root/src/AST/Count.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/Count.hs')
-rw-r--r--src/AST/Count.hs170
1 files changed, 0 insertions, 170 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
deleted file mode 100644
index ca4d7ab..0000000
--- a/src/AST/Count.hs
+++ /dev/null
@@ -1,170 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingStrategies #-}
-{-# LANGUAGE DerivingVia #-}
-{-# LANGUAGE EmptyCase #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE TypeOperators #-}
-module AST.Count where
-
-import Data.Functor.Const
-import GHC.Generics (Generic, Generically(..))
-
-import AST
-import AST.Env
-import Data
-
-
-data Count = Zero | One | Many
- deriving (Show, Eq, Ord)
-
-instance Semigroup Count where
- Zero <> n = n
- n <> Zero = n
- _ <> _ = Many
-instance Monoid Count where
- mempty = Zero
-
-data Occ = Occ { _occLexical :: Count
- , _occRuntime :: Count }
- deriving (Eq, Generic)
- deriving (Semigroup, Monoid) via Generically Occ
-
-instance Show Occ where
- showsPrec d (Occ l r) = showParen (d > 10) $
- showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r
-
--- | One of the two branches is taken
-(<||>) :: Occ -> Occ -> Occ
-Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)
-
--- | This code is executed many times
-scaleMany :: Occ -> Occ
-scaleMany (Occ l Zero) = Occ l Zero
-scaleMany (Occ l _) = Occ l Many
-
-occCount :: Idx env a -> Expr x env t -> Occ
-occCount idx =
- getConst . occCountGeneral
- (\w i o -> if idx2int i == idx2int (w @> idx) then Const o else mempty)
- (\(Const o) -> Const o)
- (\(Const o1) (Const o2) -> Const (o1 <||> o2))
- (\(Const o) -> Const (scaleMany o))
-
-
-data OccEnv env where
- OccEnd :: OccEnv env -- not necessarily top!
- OccPush :: OccEnv env -> Occ -> OccEnv (t : env)
-
-instance Semigroup (OccEnv env) where
- OccEnd <> e = e
- e <> OccEnd = e
- OccPush e o <> OccPush e' o' = OccPush (e <> e') (o <> o')
-
-instance Monoid (OccEnv env) where
- mempty = OccEnd
-
-onehotOccEnv :: Idx env t -> Occ -> OccEnv env
-onehotOccEnv IZ v = OccPush OccEnd v
-onehotOccEnv (IS i) v = OccPush (onehotOccEnv i v) mempty
-
-(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env
-OccEnd <||>! e = e
-e <||>! OccEnd = e
-OccPush e o <||>! OccPush e' o' = OccPush (e <||>! e') (o <||> o')
-
-scaleManyOccEnv :: OccEnv env -> OccEnv env
-scaleManyOccEnv OccEnd = OccEnd
-scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o)
-
-occEnvPop :: OccEnv (t : env) -> OccEnv env
-occEnvPop (OccPush o _) = o
-occEnvPop OccEnd = OccEnd
-
-occCountAll :: Expr x env t -> OccEnv env
-occCountAll = occCountGeneral (const onehotOccEnv) occEnvPop (<||>!) scaleManyOccEnv
-
-occCountGeneral :: forall r env t x.
- (forall env'. Monoid (r env'))
- => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env') -- ^ one-hot
- -> (forall env' a. r (a : env') -> r env') -- ^ unpush
- -> (forall env'. r env' -> r env' -> r env') -- ^ alternation
- -> (forall env'. r env' -> r env') -- ^ scale-many
- -> Expr x env t -> r env
-occCountGeneral onehot unpush alter many = go WId
- where
- go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env'
- go w = \case
- EVar _ _ i -> onehot w i (Occ One One)
- ELet _ rhs body -> re rhs <> re1 body
- EPair _ a b -> re a <> re b
- EFst _ e -> re e
- ESnd _ e -> re e
- ENil _ -> mempty
- EInl _ _ e -> re e
- EInr _ _ e -> re e
- ECase _ e a b -> re e <> (re1 a `alter` re1 b)
- ENothing _ _ -> mempty
- EJust _ e -> re e
- EMaybe _ a b e -> re a <> re1 b <> re e
- ELNil _ _ _ -> mempty
- ELInl _ _ e -> re e
- ELInr _ _ e -> re e
- ELCase _ e a b c -> re e <> (re a `alter` re1 b `alter` re1 c)
- EConstArr{} -> mempty
- EBuild _ _ a b -> re a <> many (re1 b)
- EFold1Inner _ _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c
- ESum1Inner _ e -> re e
- EUnit _ e -> re e
- EReplicate1Inner _ a b -> re a <> re b
- EMaximum1Inner _ e -> re e
- EMinimum1Inner _ e -> re e
- EConst{} -> mempty
- EIdx0 _ e -> re e
- EIdx1 _ a b -> re a <> re b
- EIdx _ a b -> re a <> re b
- EShape _ e -> re e
- EOp _ _ e -> re e
- ECustom _ _ _ _ _ _ _ a b -> re a <> re b
- ERecompute _ e -> re e
- EWith _ _ a b -> re a <> re1 b
- EAccum _ _ _ a _ b e -> re a <> re b <> re e
- EZero _ _ e -> re e
- EDeepZero _ _ e -> re e
- EPlus _ _ a b -> re a <> re b
- EOneHot _ _ _ a b -> re a <> re b
- EError{} -> mempty
- where
- re :: Monoid (r env') => Expr x env' t'' -> r env'
- re = go w
-
- re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env'
- re1 = unpush . go (WSink .> w)
-
-
-deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r
-deleteUnused SNil OccEnd k = k SETop
-deleteUnused (_ `SCons` env) OccEnd k =
- deleteUnused env OccEnd $ \sub -> k (SENo sub)
-deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k =
- deleteUnused env occenv $ \sub ->
- case count of Zero -> k (SENo sub)
- _ -> k (SEYesR sub)
-
-unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t
-unsafeWeakenWithSubenv = \sub ->
- subst (\x t i -> case sinkViaSubenv i sub of
- Just i' -> EVar x t i'
- Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away")
- where
- sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t)
- sinkViaSubenv IZ (SEYesR _) = Just IZ
- sinkViaSubenv IZ (SENo _) = Nothing
- sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub
- sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub