1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where
import AST
import AST.Sparse.Types
import Data
-- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by
-- expanding them into their concrete implementations. Also ensure that
-- 'EAccum' has a dense sparsity.
--
-- Every subexpression is traversed, including the info argument of
-- 'EZero' / 'EDeepZero' and all three operands of 'EAccum', so that no monoid
-- construct survives anywhere in the result.
unMonoid :: Ex env t -> Ex env t
unMonoid = \case
  -- The info expressions may themselves contain monoid constructs, so they
  -- are cleaned before being spliced into the expansion.
  EZero _ t e -> zero t (unMonoid e)
  EDeepZero _ t e -> deepZero t (unMonoid e)
  EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
  EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)
  EVar _ t i -> EVar ext t i
  ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
  EPair _ a b -> EPair ext (unMonoid a) (unMonoid b)
  EFst _ e -> EFst ext (unMonoid e)
  ESnd _ e -> ESnd ext (unMonoid e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext t (unMonoid e)
  EInr _ t e -> EInr ext t (unMonoid e)
  ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b)
  ENothing _ t -> ENothing ext t
  EJust _ e -> EJust ext (unMonoid e)
  EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
  ELNil _ t1 t2 -> ELNil ext t1 t2
  ELInl _ t e -> ELInl ext t (unMonoid e)
  ELInr _ t e -> ELInr ext t (unMonoid e)
  ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)
  EConstArr _ n t x -> EConstArr ext n t x
  EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
  EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
  ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
  EUnit _ e -> EUnit ext (unMonoid e)
  EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
  EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
  EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
  EConst _ t x -> EConst ext t x
  EIdx0 _ e -> EIdx0 ext (unMonoid e)
  EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
  EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b)
  EShape _ e -> EShape ext (unMonoid e)
  EOp _ op e -> EOp ext op (unMonoid e)
  ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
  ERecompute _ e -> ERecompute ext (unMonoid e)
  EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
  EAccum _ t p eidx sp eval eacc ->
    -- Each operand is cleaned exactly once, before the sparse decomposition:
    --  * 'accumulateSparse' may splice the value expression straight into its
    --    output (e.g. as the scrutinee of the 'emaybe' for 'SpSparse'), where
    --    the callback below never sees it, so it must already be clean;
    --  * the callback may be invoked once per leaf of the sparsity structure,
    --    so cleaning the index and accumulator inside it would redo that work
    --    for every leaf.
    -- What 'accumulateSparse' and 'acPrjCompose' wrap around these operands
    -- is built from variables, projections, pairs, lets, case analyses and
    -- array builds/indexing only, so their results need no further cleaning.
    let eidx' = unMonoid eidx
        eacc' = unMonoid eacc
    in accumulateSparse (acPrjTy p t) sp (unMonoid eval) $ \w prj2 idx2 val2 ->
         acPrjCompose SAID p (weakenExpr w eidx') prj2 idx2 $ \prj' idx' ->
           EAccum ext t prj' idx' (spDense (acPrjTy prj' t)) val2 (weakenExpr w eacc')
  EError _ t s -> EError ext t s
-- | Build the additive zero of monoid type @t@.
--
-- The second argument is the 'ZeroInfo' for @t@: for pairs it is split
-- componentwise, for arrays it is mapped elementwise (each element being the
-- info for one zero element), for 'SMTNil' it is let-bound and then ignored,
-- and for 'SMTLEither', 'SMTMaybe' and scalars it is not used at all.
zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
zero ty info = case ty of
  SMTNil -> elet info $ ENil ext
  SMTPair t1 t2 ->
    ELet ext info $
      let infoVar = EVar ext (typeOf info) IZ
      in EPair ext (zero t1 (EFst ext infoVar)) (zero t2 (ESnd ext infoVar))
  SMTLEither t1 t2 -> ELNil ext (fromSMTy t1) (fromSMTy t2)
  SMTMaybe elt -> ENothing ext (fromSMTy elt)
  SMTArr _ elt -> emap (zero elt (EVar ext (tZeroInfo elt) IZ)) info
  SMTScal st -> case st of
    STI32 -> EConst ext STI32 0
    STI64 -> EConst ext STI64 0
    STF32 -> EConst ext STF32 0.0
    STF64 -> EConst ext STF64 0.0
-- | Build a zero of monoid type @t@ that mirrors the structure recorded in
-- the 'DeepZeroInfo' argument.
--
-- Unlike 'zero', the sum-like types keep their shape: an 'SMTLEither' info
-- yields an 'ELNil' / 'ELInl' / 'ELInr' matching the info's constructor, and
-- an 'SMTMaybe' info yields 'ENothing' or 'EJust', in both cases with zeroed
-- contents. Pairs and arrays recurse componentwise / elementwise; 'SMTNil'
-- let-binds the info and returns 'ENil'; scalars ignore it.
deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t
deepZero ty info = case ty of
  SMTNil -> elet info $ ENil ext
  SMTPair t1 t2 ->
    ELet ext info $
      let infoVar = EVar ext (typeOf info) IZ
      in EPair ext (deepZero t1 (EFst ext infoVar)) (deepZero t2 (ESnd ext infoVar))
  SMTLEither t1 t2 ->
    elcase info
      (ELNil ext (fromSMTy t1) (fromSMTy t2))
      (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ)))
      (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ)))
  SMTMaybe elt ->
    emaybe info
      (ENothing ext (fromSMTy elt))
      (EJust ext (deepZero elt (evar IZ)))
  SMTArr _ elt -> emap (deepZero elt (evar IZ)) info
  SMTScal st -> case st of
    STI32 -> EConst ext STI32 0
    STI64 -> EConst ext STI64 0
    STF32 -> EConst ext STF32 0.0
    STF64 -> EConst ext STF64 0.0
-- | Add two values of monoid type @t@.
--
-- Operands are let-bound (or used as a scrutinee) rather than discarded even
-- when their value is not needed, e.g. both sides of 'SMTNil', or the
-- untouched side when one 'SMTLEither' operand is 'ELNil'. This keeps any
-- effects they contain. 'SMTNil', 'SMTPair' and 'SMTLEither' bind @a@
-- before @b@; 'SMTMaybe' binds @b@ first and then scrutinises @a@.
plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
-- don't destroy the effects! (both operands are always bound, see above)
plus SMTNil a b = elet a $ elet (weakenExpr WSink b) $ ENil ext
plus (SMTPair t1 t2) a b =
  -- After the two lets: IS IZ = a, IZ = b; add componentwise.
  let t = STPair (fromSMTy t1) (fromSMTy t2)
  in ELet ext a $
     ELet ext (weakenExpr WSink b) $
       EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
                          (EFst ext (EVar ext t IZ)))
                 (plus t2 (ESnd ext (EVar ext t (IS IZ)))
                          (ESnd ext (EVar ext t IZ)))
plus (SMTLEither t1 t2) a b =
  -- After the two lets: IS IZ = a, IZ = b.
  -- a = LNil: result is b.
  -- a = Inl x (now IZ = x, IS IZ = b, IS (IS IZ) = a):
  --   b = LNil -> a; b = Inl y (IZ = y, IS IZ = x) -> Inl (x + y);
  --   b = Inr _ -> runtime error (mismatched sides).
  -- a = Inr x: symmetric.
  let t = STLEither (fromSMTy t1) (fromSMTy t2)
  in ELet ext a $
     ELet ext (weakenExpr WSink b) $
       ELCase ext (EVar ext t (IS IZ))
         (EVar ext t IZ)
         (ELCase ext (EVar ext t (IS IZ))
            (EVar ext t (IS (IS IZ)))
            (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ)))
            (EError ext t "plus l+r"))
         (ELCase ext (EVar ext t (IS IZ))
            (EVar ext t (IS (IS IZ)))
            (EError ext t "plus r+l")
            (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ))))
plus (SMTMaybe t) a b =
  -- Bind b (IZ), then case on a:
  --   a = Nothing -> b;
  --   a = Just x (IZ = x, IS IZ = b) -> Just (case b of Nothing -> x;
  --                                       Just y -> x + y).
  ELet ext b $
    EMaybe ext
      (EVar ext (STMaybe (fromSMTy t)) IZ)
      (EJust ext
         (EMaybe ext
            (EVar ext (fromSMTy t) IZ)
            (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
            (EVar ext (STMaybe (fromSMTy t)) (IS IZ))))
      (weakenExpr WSink a)
plus (SMTArr _ t) a b =
  -- Elementwise: inside the zip body IS IZ is the element of a, IZ that of b.
  ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
           a b
plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)
-- | Build a value of monoid type @t@ that is zero everywhere except at the
-- position selected by the projection, where it holds @arg@.
--
-- @idx@ is the sparse index ('AcIdxS') for the projection. Per step, as
-- consumed below:
--
--   * 'SAPFst' / 'SAPSnd': a pair of the index for the projected component
--     and the 'ZeroInfo' of the other component (used to build its zero).
--   * 'SAPLeft' / 'SAPRight' / 'SAPJust': passed through unchanged.
--   * 'SAPArrIdx': @((i, zi), rest)@, where @i@ is the target element index,
--     @zi@ is an array whose shape is the result shape and whose elements are
--     the 'ZeroInfo' of the non-target elements, and @rest@ is the index for
--     the element projection.
--   * 'SAPHere': not used.
onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t
onehot typ topprj idx arg = case (typ, topprj) of
  (_, SAPHere) ->
    ELet ext arg $
      EVar ext (fromSMTy typ) IZ
  (SMTPair t1 t2, SAPFst prj) ->
    -- Bind idx (then IS IZ after the next let), bind the one-hot of the first
    -- component (IZ), and pair it with the zero of the second component.
    ELet ext idx $
      let tidx = typeOf idx in
      ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
        let toh = fromSMTy t1 in
        EPair ext (EVar ext toh IZ)
                  (zero t2 (ESnd ext (EVar ext tidx (IS IZ))))
  (SMTPair t1 t2, SAPSnd prj) ->
    -- Mirror image of the 'SAPFst' case.
    ELet ext idx $
      let tidx = typeOf idx in
      ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
        let toh = fromSMTy t2 in
        EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ))))
                  (EVar ext toh IZ)
  (SMTLEither t1 t2, SAPLeft prj) ->
    ELInl ext (fromSMTy t2) (onehot t1 prj idx arg)
  (SMTLEither t1 t2, SAPRight prj) ->
    ELInr ext (fromSMTy t1) (onehot t2 prj idx arg)
  (SMTMaybe t1, SAPJust prj) ->
    EJust ext (onehot t1 prj idx arg)
  (SMTArr n t1, SAPArrIdx prj) ->
    -- Note: here 'tidx' is the type of the element index tuple, not of idx.
    -- Inside the build body: IZ = current element index, IS IZ = bound idx.
    -- The target element gets the nested one-hot; every other element gets
    -- a zero built from the corresponding entry of the zero-info array.
    let tidx = tTup (sreplicate n tIx)
    in ELet ext idx $
         EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $
           eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
               (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
               (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $
                  zero t1 (EVar ext (tZeroInfo t1) IZ))
-- | Split an accumulation of a (possibly sparse) value into accumulations of
-- dense pieces.
--
-- @arg@ has type @t'@, the sparse representation of monoid type @t@
-- described by the 'Sparse' witness. While building the result, @accum@ is
-- invoked once for every leaf of the sparsity structure. Each call receives:
--
--   * a weakening from @env@ into the environment at that point;
--   * the projection from @t@ down to the leaf;
--   * the dense index ('AcIdxD') of the leaf;
--   * the leaf value.
--
-- The returned 'TNil' expression places those calls under whatever case
-- analyses and array builds are needed to reach each leaf.
accumulateSparse
  :: SMTy t -> Sparse t t' -> Ex env t'
  -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil)
  -> Ex env TNil
accumulateSparse topty topsp arg accum = case (topty, topsp) of
  -- Already dense: a single accumulation at the root, with an empty index.
  (_, s) | Just Refl <- isDense topty s ->
    accum WId SAPHere (ENil ext) arg
  (SMTScal _, SpScal) ->
    accum WId SAPHere (ENil ext) arg  -- should be handled by isDense already, but meh
  -- Optional value: accumulate nothing on Nothing, recurse on the payload
  -- (bound at IZ, hence the extra 'WPop') on Just.
  (_, SpSparse s) ->
    emaybe arg
      (ENil ext)
      (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w)))
  -- NOTE(review): 'arg' is dropped here without being evaluated, unlike
  -- 'plus SMTNil', which let-binds its operands to keep their effects.
  -- Fine if absent values are always effect-free; confirm.
  (_, SpAbsent) ->
    ENil ext
  -- Pair: accumulate the first component (result let-bound, hence
  -- 'weakenExpr WSink' / 'WSink' for the second), then the second.
  -- 'w1' is the weakening supplied by 'eunPair' for the split components.
  (SMTPair t1 t2, SpPair s1 s2) ->
    eunPair arg $ \w1 e1 e2 ->
      elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $
        accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj))
  (SMTLEither t1 t2, SpLEither s1 s2) ->
    elcase arg
      (ENil ext)
      (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj)))
      (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj)))
  (SMTMaybe t, SpMaybe s) ->
    emaybe arg
      (ENil ext)
      (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj)))
  -- Array: bind the array, then build an array of 'TNil' results over its
  -- shape in which each element accumulates the corresponding element of
  -- 'arg'. In the build body IZ is the element index and IS IZ the array.
  -- 'w @> IZ' relocates that element index into the callback's environment,
  -- where it is prepended to the element's dense index.
  (SMTArr n t, SpArr s) ->
    let tn = tTup (sreplicate n tIx) in
    elet arg $
      elet (EBuild ext n (EShape ext (evar IZ)) $
              accumulateSparse t s
                (EIdx ext (evar (IS IZ)) (EVar ext tn IZ))
                (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $
        ENil ext
-- | Compose two accumulator projections: the first goes from @a@ down to
-- @b@, the second from @b@ down to @c@. The continuation receives the
-- combined projection from @a@ to @c@ together with its index expression.
--
-- With dense indices ('SAID'), pair, sum and maybe steps contribute nothing
-- to the index, so it is threaded through. Array steps contribute a head
-- component that is re-attached.
--
-- With sparse indices ('SAIS'), pair steps additionally carry an entry for
-- the sibling component. The outer index is let-bound, the recursion runs
-- on the relevant half, and the pair is rebuilt around the composed index.
acPrjCompose
  :: SAIDense dense
  -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a)
  -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b)
  -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r
-- The first projection is exhausted, so the composite is just the second one.
-- The first index is discarded here; presumably it carries no data at
-- 'SAPHere' (confirm against the 'AcIdx' definition).
acPrjCompose _ SAPHere _ prjB idxB cont = cont prjB idxB
-- Dense pair steps: the index passes through untouched.
acPrjCompose SAID (SAPFst rest) idxA prjB idxB cont =
  acPrjCompose SAID rest idxA prjB idxB $ \prjAC idxAC ->
    cont (SAPFst prjAC) idxAC
acPrjCompose SAID (SAPSnd rest) idxA prjB idxB cont =
  acPrjCompose SAID rest idxA prjB idxB $ \prjAC idxAC ->
    cont (SAPSnd prjAC) idxAC
-- Sparse pair steps: recurse on the projected half of the let-bound index,
-- then put the sibling half back next to the composed index.
acPrjCompose SAIS (SAPFst rest) idxA prjB idxB cont
  | Dict <- styKnown (typeOf idxA) =
      acPrjCompose SAIS rest (efst (evar IZ)) prjB (weakenExpr WSink idxB) $ \prjAC idxAC ->
        cont (SAPFst prjAC) (elet idxA $ EPair ext idxAC (esnd (evar IZ)))
acPrjCompose SAIS (SAPSnd rest) idxA prjB idxB cont
  | Dict <- styKnown (typeOf idxA) =
      acPrjCompose SAIS rest (esnd (evar IZ)) prjB (weakenExpr WSink idxB) $ \prjAC idxAC ->
        cont (SAPSnd prjAC) (elet idxA $ EPair ext (efst (evar IZ)) idxAC)
-- Sum and maybe steps carry no index data for either density.
acPrjCompose d (SAPLeft rest) idxA prjB idxB cont =
  acPrjCompose d rest idxA prjB idxB $ \prjAC idxAC ->
    cont (SAPLeft prjAC) idxAC
acPrjCompose d (SAPRight rest) idxA prjB idxB cont =
  acPrjCompose d rest idxA prjB idxB $ \prjAC idxAC ->
    cont (SAPRight prjAC) idxAC
acPrjCompose d (SAPJust rest) idxA prjB idxB cont =
  acPrjCompose d rest idxA prjB idxB $ \prjAC idxAC ->
    cont (SAPJust prjAC) idxAC
-- Array steps: the index is (head for this step, index for the element
-- projection). Recurse on the tail and re-attach the head. The dense and
-- sparse clauses share a body but stay separate, presumably so that 'AcIdx'
-- can reduce for a known density.
acPrjCompose SAID (SAPArrIdx rest) idxA prjB idxB cont
  | Dict <- styKnown (typeOf idxA) =
      acPrjCompose SAID rest (esnd (evar IZ)) prjB (weakenExpr WSink idxB) $ \prjAC idxAC ->
        cont (SAPArrIdx prjAC) (elet idxA $ EPair ext (efst (evar IZ)) idxAC)
acPrjCompose SAIS (SAPArrIdx rest) idxA prjB idxB cont
  | Dict <- styKnown (typeOf idxA) =
      acPrjCompose SAIS rest (esnd (evar IZ)) prjB (weakenExpr WSink idxB) $ \prjAC idxAC ->
        cont (SAPArrIdx prjAC) (elet idxA $ EPair ext (efst (evar IZ)) idxAC)
|