blob: b7079ff89d841cbf412367f5e7a1a551ffd65c1f (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
module AST.Count where
import Data.Functor.Const
import GHC.Generics (Generic, Generically(..))
import AST
import AST.Env
import Data
-- | A three-valued occurrence count.
--
-- The constructor order is significant: the derived 'Ord' instance yields
-- @Zero < One < Many@.
data Count
  = Zero
  | One
  | Many
  deriving (Show, Eq, Ord)
-- | Combining counts saturates at 'Many': 'Zero' is neutral on either side,
-- and any two non-'Zero' counts combine to 'Many'.
instance Semigroup Count where
  lhs <> rhs = case (lhs, rhs) of
    (Zero, other) -> other
    (other, Zero) -> other
    _ -> Many
-- | 'Zero' is the identity element of '<>' on 'Count'.
instance Monoid Count where
  mempty = Zero
-- | Occurrence information consisting of a lexical count and a runtime count.
--
-- The 'Semigroup' and 'Monoid' instances are obtained through 'Generically',
-- i.e. both fields are combined with the corresponding 'Count' instances.
data Occ = Occ
  { _occLexical :: Count
  , _occRuntime :: Count
  }
  deriving stock (Eq, Generic)
  deriving (Semigroup, Monoid) via Generically Occ
-- | Renders an 'Occ' positionally (@Occ lexical runtime@) instead of using
-- record syntax; the result is parenthesised when the context precedence is
-- above that of function application.
instance Show Occ where
  showsPrec prec (Occ lexical runtime) =
    showParen (prec > appPrec) $
      showString "Occ "
        . showsPrec argPrec lexical
        . showString " "
        . showsPrec argPrec runtime
    where
      appPrec, argPrec :: Int
      appPrec = 10
      argPrec = 11
-- | Combine the occurrences of two alternatives of which only one is taken:
-- the lexical counts are added with '<>', while the runtime count is the
-- larger of the two.
(<||>) :: Occ -> Occ -> Occ
(<||>) (Occ lexA runA) (Occ lexB runB) =
  Occ
    { _occLexical = lexA <> lexB
    , _occRuntime = max runA runB
    }
-- | Adjust an occurrence for code that is executed many times: a runtime
-- count of 'Zero' stays 'Zero', any other runtime count becomes 'Many'.
-- The lexical count is left unchanged.
scaleMany :: Occ -> Occ
scaleMany (Occ lexical runtime) = case runtime of
  Zero -> Occ lexical Zero
  _ -> Occ lexical Many
-- | Compute the occurrence information of a single variable in an
-- expression, by running 'occCountGeneral' with a constant-'Occ' carrier.
occCount :: Idx env a -> Expr x env t -> Occ
occCount idx expr =
  getConst
    ( occCountGeneral
        -- one-hot: only count a variable whose index number equals that of
        -- the (weakened) index we are looking for
        ( \weaken var occ ->
            if idx2int var == idx2int (weaken @> idx)
              then Const occ
              else mempty
        )
        -- unpush: the carrier ignores the environment, so just re-wrap
        (\(Const occ) -> Const occ)
        -- alternation
        (\(Const occL) (Const occR) -> Const (occL <||> occR))
        -- scale-many
        (\(Const occ) -> Const (scaleMany occ))
        expr
    )
-- | Occurrence information for the variables of an environment, indexed by
-- that environment.
data OccEnv env where
  -- | No (further) occurrence information. This is not necessarily the top
  -- of the environment: the constructor is available at every @env@.
  OccEnd :: OccEnv env
  -- | Occurrence information for the variable at the head of the
  -- environment, on top of the information for the remaining variables.
  OccPush :: OccEnv env -> Occ -> OccEnv (t : env)
-- | Pointwise '<>' of the per-variable occurrences. 'OccEnd' is neutral, so
-- the longer of the two environments determines the tail of the result.
instance Semigroup (OccEnv env) where
  (<>) OccEnd other = other
  (<>) other OccEnd = other
  (<>) (OccPush restL occL) (OccPush restR occR) =
    OccPush (restL <> restR) (occL <> occR)
-- | 'OccEnd' is the identity element of '<>' on 'OccEnv'.
instance Monoid (OccEnv env) where
  mempty = OccEnd
-- | Build an 'OccEnv' that records the given 'Occ' at the position of the
-- given index and 'mempty' at every position in front of it; everything
-- behind it is left as 'OccEnd'.
onehotOccEnv :: Idx env t -> Occ -> OccEnv env
onehotOccEnv idx occ = case idx of
  IZ -> OccPush OccEnd occ
  IS deeper -> OccPush (onehotOccEnv deeper occ) mempty
-- | Pointwise '<||>' on occurrence environments. 'OccEnd' on either side
-- returns the other environment unchanged.
(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env
(<||>!) OccEnd other = other
(<||>!) other OccEnd = other
(<||>!) (OccPush restL occL) (OccPush restR occR) =
  OccPush (restL <||>! restR) (occL <||> occR)
-- | Apply 'scaleMany' to every occurrence stored in the environment.
scaleManyOccEnv :: OccEnv env -> OccEnv env
scaleManyOccEnv = \case
  OccEnd -> OccEnd
  OccPush rest occ -> OccPush (scaleManyOccEnv rest) (scaleMany occ)
-- | Drop the entry for the variable at the head of the environment. An
-- 'OccEnd' stays an 'OccEnd' (at the smaller environment).
occEnvPop :: OccEnv (t : env) -> OccEnv env
occEnvPop = \case
  OccPush rest _ -> rest
  OccEnd -> OccEnd
-- | Compute occurrence information for all variables of the environment at
-- once, by running 'occCountGeneral' with 'OccEnv' as the carrier.
occCountAll :: Expr x env t -> OccEnv env
occCountAll expr =
  occCountGeneral
    (\_ idx occ -> onehotOccEnv idx occ) -- the weakening is not needed here
    occEnvPop
    (<||>!)
    scaleManyOccEnv
    expr
-- | Generic occurrence-counting traversal over an expression.
--
-- The result carrier @r@ is indexed by the environment and must be a
-- 'Monoid' at every environment; results of sub-expressions that are all
-- evaluated are combined with '<>'. The four callbacks determine the rest:
--
-- * one-hot: called at every 'EVar' with the weakening from the original
--   environment to the environment at that variable, the variable's index,
--   and @Occ One One@;
-- * unpush: removes the head variable from a result computed under a binder;
-- * alternation: combines the results of two branches of which only one is
--   taken;
-- * scale-many: adjusts the result of code that is executed many times.
occCountGeneral :: forall r env t x.
                   (forall env'. Monoid (r env'))
                => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env')  -- ^ one-hot
                -> (forall env' a. r (a : env') -> r env')  -- ^ unpush
                -> (forall env'. r env' -> r env' -> r env')  -- ^ alternation
                -> (forall env'. r env' -> r env')  -- ^ scale-many
                -> Expr x env t -> r env
occCountGeneral onehot unpush alter many = go WId
  where
    -- @w@ tracks how the current environment extends the original one; it is
    -- extended with 'WSink' for every binder that is passed.
    go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env'
    go w = \case
      EVar _ _ i -> onehot w i (Occ One One)
      ELet _ rhs body -> re rhs <> re1 body
      EPair _ a b -> re a <> re b
      EFst _ e -> re e
      ESnd _ e -> re e
      ENil _ -> mempty
      EInl _ _ e -> re e
      EInr _ _ e -> re e
      ECase _ e a b -> re e <> (re1 a `alter` re1 b)
      ENothing _ _ -> mempty
      EJust _ e -> re e
      -- Only one of the two handlers runs (assuming EMaybe mirrors the
      -- Prelude's 'maybe' eliminator -- TODO confirm), so they are combined
      -- by alternation, consistent with the branches of ECase.
      EMaybe _ a b e -> (re a `alter` re1 b) <> re e
      EConstArr{} -> mempty
      EBuild _ _ a b -> re a <> many (re1 b)
      -- The first argument lives under two binders, hence the double
      -- sink/unpush.
      EFold1Inner _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c
      ESum1Inner _ e -> re e
      EUnit _ e -> re e
      EReplicate1Inner _ a b -> re a <> re b
      EMaximum1Inner _ e -> re e
      EMinimum1Inner _ e -> re e
      EConst{} -> mempty
      EIdx0 _ e -> re e
      EIdx1 _ a b -> re a <> re b
      EIdx _ a b -> re a <> re b
      EShape _ e -> re e
      EOp _ _ e -> re e
      ECustom _ _ _ _ _ _ _ a b -> re a <> re b
      EWith _ a b -> re a <> re1 b
      EAccum _ _ a b e -> re a <> re b <> re e
      EZero _ _ -> mempty
      EPlus _ _ a b -> re a <> re b
      EOneHot _ _ _ a b -> re a <> re b
      EError{} -> mempty
      where
        -- Recurse into a sub-expression in the same environment.
        re :: Monoid (r env') => Expr x env' t'' -> r env'
        re = go w
        -- Recurse into a sub-expression under one binder, then drop the
        -- bound variable from the result again.
        re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env'
        re1 = unpush . go (WSink .> w)
-- | Compute the sub-environment that keeps exactly those variables whose
-- runtime count in the given 'OccEnv' is not 'Zero', and pass it to the
-- continuation. Variables not covered by the 'OccEnv' (i.e. beyond an
-- 'OccEnd') are all dropped.
deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r
deleteUnused SNil OccEnd cont = cont SETop
deleteUnused (_ `SCons` rest) OccEnd cont =
  deleteUnused rest OccEnd (\sub -> cont (SENo sub))
deleteUnused (_ `SCons` rest) (OccPush occs (Occ _ runtime)) cont =
  deleteUnused rest occs $ \sub ->
    case runtime of
      Zero -> cont (SENo sub)
      _ -> cont (SEYes sub)
-- | Move an expression to a sub-environment by re-indexing all of its
-- variables.
--
-- This is unsafe: if the expression refers to a variable that the 'Subenv'
-- drops, the result contains a call to 'error'.
unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t
unsafeWeakenWithSubenv = \subenv ->
  subst (\x t i ->
           maybe
             (error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away")
             (EVar x t)
             (sinkViaSubenv i subenv))
  where
    -- Translate an index into the sub-environment; 'Nothing' if the variable
    -- it points to has been dropped.
    sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t)
    sinkViaSubenv IZ = \case
      SEYes _ -> Just IZ
      SENo _ -> Nothing
    sinkViaSubenv (IS deeper) = \case
      SEYes rest -> IS <$> sinkViaSubenv deeper rest
      SENo rest -> sinkViaSubenv deeper rest
|