blob: ddae7fe96d31d848d4ecad954be41df3ae4f1946 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
|
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-}
module AST.Sparse where
import Data.Kind (Constraint, Type)
import Data.Type.Equality
import AST
-- | A description of how a value of type @t@ is represented by a value of the
-- (possibly smaller) type @t'@.
data Sparse t t' where
  -- | Wrap the representation chosen by the inner description in a 'TMaybe'.
  SpSparse :: Sparse t t' -> Sparse t (TMaybe t')
  -- | Represent any @t@ by 'TNil'.
  SpAbsent :: Sparse t TNil
  -- | Describe both components of a 'TPair' independently.
  SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b')
  -- | Describe both alternatives of a 'TLEither' independently.
  SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b')
  -- | Describe the contents of a 'TMaybe'; the 'TMaybe' itself is kept.
  SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t')
  -- | Describe the elements of a 'TArr'; the index @n@ is kept unchanged.
  SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t')
  -- | A scalar is represented by itself.
  SpScal :: Sparse (TScal t) (TScal t)
deriving instance Show (Sparse t t')
-- | Type-indexed structures @f@ on which a 'Sparse' description can act.
class ApplySparse f where
  -- | Turn an @f t@ into an @f t'@ as directed by the given 'Sparse'
  -- description relating @t@ to @t'@.
  applySparse :: Sparse t t' -> f t -> f t'
-- | Computes the 'STy' of the representation type selected by a 'Sparse'
-- description, given the 'STy' of the original type.
instance ApplySparse STy where
  applySparse sparsity ty = case (sparsity, ty) of
    -- The first two alternatives do not inspect the type at all.
    (SpSparse inner, _) -> STMaybe (applySparse inner ty)
    (SpAbsent, _) -> STNil
    -- Structural cases: descend into the matching type constructor.
    (SpPair spL spR, STPair tyL tyR) ->
      STPair (applySparse spL tyL) (applySparse spR tyR)
    (SpLEither spL spR, STLEither tyL tyR) ->
      STLEither (applySparse spL tyL) (applySparse spR tyR)
    (SpMaybe inner, STMaybe elt) -> STMaybe (applySparse inner elt)
    (SpArr inner, STArr rank elt) -> STArr rank (applySparse inner elt)
    -- Scalars are their own representation.
    (SpScal, _) -> ty
-- | Computes the 'SMTy' of the representation type selected by a 'Sparse'
-- description, given the 'SMTy' of the original type.
instance ApplySparse SMTy where
  applySparse sparsity ty = case (sparsity, ty) of
    -- The first two alternatives do not inspect the type at all.
    (SpSparse inner, _) -> SMTMaybe (applySparse inner ty)
    (SpAbsent, _) -> SMTNil
    -- Structural cases: descend into the matching type constructor.
    (SpPair spL spR, SMTPair tyL tyR) ->
      SMTPair (applySparse spL tyL) (applySparse spR tyR)
    (SpLEither spL spR, SMTLEither tyL tyR) ->
      SMTLEither (applySparse spL tyL) (applySparse spR tyR)
    (SpMaybe inner, SMTMaybe elt) -> SMTMaybe (applySparse inner elt)
    (SpArr inner, SMTArr rank elt) -> SMTArr rank (applySparse inner elt)
    -- Scalars are their own representation.
    (SpScal, _) -> ty
-- | Binary relations @s@ on types that behave like a subtyping relation:
-- they can be applied, composed, and have a reflexive element.
class IsSubType s where
  -- | The constraint that a type-indexed structure @f@ must satisfy for
  -- 'subtApply' and 'subtFull' to be usable with the relation @s@.
  type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint
  -- | Convert an @f t@ into an @f t'@ along a witness of @s t t'@.
  subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t'
  -- | Transitivity: compose two witnesses.
  subtTrans :: s a b -> s b c -> s a c
  -- | Reflexivity: build the witness relating @t@ to itself from an @f t@.
  subtFull :: IsSubTypeSubject s f => f t -> s t t
-- | Propositional type equality is the trivial subtype relation; it places no
-- requirement on the subject @f@.
instance IsSubType (:~:) where
  type IsSubTypeSubject (:~:) f = ()
  -- Matching on 'Refl' reveals that both indices coincide, so the value can be
  -- returned unchanged.
  subtApply Refl x = x
  subtTrans Refl Refl = Refl
  subtFull = const Refl
-- | 'Sparse' as a subtype relation. The subject is fixed to 'SMTy'.
instance IsSubType Sparse where
  type IsSubTypeSubject Sparse f = f ~ SMTy
  subtApply sparsity ty = applySparse sparsity ty
  -- The alternatives are tried in order; the first two only look at the
  -- second description.
  subtTrans first second = case (first, second) of
    (_, SpSparse inner2) -> SpSparse (subtTrans first inner2)
    (_, SpAbsent) -> SpAbsent
    (SpPair l1 r1, SpPair l2 r2) -> SpPair (subtTrans l1 l2) (subtTrans r1 r2)
    (SpLEither l1 r1, SpLEither l2 r2) -> SpLEither (subtTrans l1 l2) (subtTrans r1 r2)
    (SpSparse inner1, SpMaybe inner2) -> SpSparse (subtTrans inner1 inner2)
    (SpMaybe inner1, SpMaybe inner2) -> SpMaybe (subtTrans inner1 inner2)
    (SpArr inner1, SpArr inner2) -> SpArr (subtTrans inner1 inner2)
    (SpScal, SpScal) -> SpScal
  subtFull ty = spDense ty
-- | The 'Sparse' description that relates a type to itself, built by
-- following the structure of the given 'SMTy'.
spDense :: SMTy t -> Sparse t t
spDense ty = case ty of
  SMTNil -> SpAbsent
  SMTPair left right -> SpPair (spDense left) (spDense right)
  SMTLEither left right -> SpLEither (spDense left) (spDense right)
  SMTMaybe elt -> SpMaybe (spDense elt)
  SMTArr _ elt -> SpArr (spDense elt)
  SMTScal _ -> SpScal
-- | Check whether a 'Sparse' description leaves the type unchanged. On
-- success, returns the proof that the represented type equals the original.
isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t')
isDense ty sparsity = case (ty, sparsity) of
  -- 'SpAbsent' is only dense at the nil type.
  (SMTNil, SpAbsent) -> Just Refl
  (_, SpSparse{}) -> Nothing
  (_, SpAbsent) -> Nothing
  -- Structural cases: dense exactly when every component is dense.
  (SMTPair tyL tyR, SpPair spL spR) -> do
    Refl <- isDense tyL spL
    Refl <- isDense tyR spR
    Just Refl
  (SMTLEither tyL tyR, SpLEither spL spR) -> do
    Refl <- isDense tyL spL
    Refl <- isDense tyR spR
    Just Refl
  (SMTMaybe elt, SpMaybe inner) -> do
    Refl <- isDense elt inner
    Just Refl
  (SMTArr _ elt, SpArr inner) -> do
    Refl <- isDense elt inner
    Just Refl
  (SMTScal _, SpScal) -> Just Refl
-- | Whether every leaf of the description is 'SpAbsent'. A 'SpScal' leaf
-- makes the answer 'False'.
isAbsent :: Sparse t t' -> Bool
isAbsent sparsity = case sparsity of
  SpSparse inner -> isAbsent inner
  SpAbsent -> True
  SpPair left right -> isAbsent left && isAbsent right
  SpLEither left right -> isAbsent left && isAbsent right
  SpMaybe inner -> isAbsent inner
  SpArr inner -> isAbsent inner
  SpScal -> False
-- | Singleton for type-level booleans: the constructor determines the index.
data SBool b where
  -- | Witness that the index is @False@.
  SF :: SBool False
  -- | Witness that the index is @True@.
  ST :: SBool True
deriving instance Show (SBool b)
-- | An optional conversion from expressions of type @a@ to expressions of
-- type @b@. The index @sp@ says whether the conversion is guaranteed to be
-- present: 'Noinj' can only be constructed at index @False@.
data Injection sp a b where
  -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that
  -- 'sparsePlusS' can provide injections even if the caller doesn't require
  -- them. This eliminates pointless checks.
  Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b
  -- | No conversion is available; only possible when @sp@ is @False@.
  Noinj :: Injection False a b
-- | Transform the conversion inside an 'Injection', if there is one.
-- A 'Noinj' stays a 'Noinj' and the transformer is not used.
withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b'
withInj minj transform = case minj of
  Inj conv -> Inj (transform conv)
  Noinj -> Noinj
-- | Combine the conversions of two 'Injection's into a new one. The result is
-- 'Noinj' as soon as either input is 'Noinj'.
withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2
         -> ((forall e. Ex e a1 -> Ex e b1)
             -> (forall e. Ex e a2 -> Ex e b2)
             -> (forall e'. Ex e' a' -> Ex e' b'))
         -> Injection sp a' b'
withInj2 minjA minjB combine = case (minjA, minjB) of
  (Inj convA, Inj convB) -> Inj (combine convA convB)
  (Noinj, _) -> Noinj
  (_, Noinj) -> Noinj
-- | Combine two 'Sparse' descriptions of the same type @t@ (given as an
-- 'SMTy') into a single result description, in continuation-passing style.
-- The continuation receives:
--
-- * the combined 'Sparse' description, relating @t@ to some @t3@;
-- * an injection from each input representation (@t1@, @t2@) into @t3@. When
--   the corresponding 'SBool' request is 'ST' this is necessarily an 'Inj',
--   because 'Noinj' can only be built at index @False@;
-- * a builder that, from an expression of type @t1@ and one of type @t2@,
--   produces the expression of type @t3@ that adds them.
--
-- The equations are tried from top to bottom and later equations rely on the
-- earlier ones having already handled their cases, so their order matters.
--
-- This function produces quadratically-sized code in the presence of nested
-- dynamic sparsity. It can't be helped.
sparsePlusS
  :: SBool inj1 -> SBool inj2
  -> SMTy t -> Sparse t t1 -> Sparse t t2
  -> (forall t3. Sparse t t3
              -> Injection inj1 t1 t3  -- only available if first injection is requested (second argument may be absent)
              -> Injection inj2 t2 t3  -- only available if second injection is requested (first argument may be absent)
              -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3)
              -> r)
  -> r
-- nil override: at the nil type the result is always SpAbsent and every
-- produced expression is simply ENil, whatever the input descriptions are.
sparsePlusS _ _ SMTNil _ _ k = k SpAbsent (Inj $ \_ -> ENil ext) (Inj $ \_ -> ENil ext) (\_ _ -> ENil ext)

-- simplifications
-- A sparse wrapper around SpAbsent is treated as plain SpAbsent; the wrapped
-- input expression is ignored and ENil is supplied in its place.
sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k =
  sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus ->
    k sp3 (withInj minj1 $ \inj1 -> \_ -> inj1 (ENil ext)) minj2 (\_ b -> plus (ENil ext) b)
sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k =
  sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus ->
    k sp3 minj1 (withInj minj2 $ \inj2 -> \_ -> inj2 (ENil ext)) (\a _ -> plus a (ENil ext))

-- Two nested sparse wrappers are collapsed into one: an outer nothing becomes
-- an inner nothing, an outer just yields its (maybe-typed) contents.
sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k =
  let ta = applySparse sp1 (fromSMTy t) in
  sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus ->
    k sp3
      (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)))
      minj2
      (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b)
sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k =
  let tb = applySparse sp2 (fromSMTy t) in
  sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus ->
    k sp3
      minj1
      (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
      (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))

-- A sparse wrapper around an LEither is dropped: the nothing case is mapped
-- to ELNil, the just case yields its contents.
sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k =
  let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in
  sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus ->
    k sp3
      (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
      minj2
      (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b)
sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k =
  let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in
  sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus ->
    k sp3
      minj1
      (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
      (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))

-- A sparse wrapper around a Maybe is dropped: the outer nothing is mapped to
-- the inner nothing, the outer just yields its contents.
sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k =
  let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in
  sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus ->
    k sp3
      (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ)))
      minj2
      (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b)
sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k =
  let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in
  sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus ->
    k sp3
      minj1
      (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ)))
      (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
-- SpMaybe (SpSparse _) and SpSparse (SpMaybe _) have the same representation
-- type, so redirect to the equations just above.
sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k
sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k

-- TODO: sparse of Just is just Maybe

-- dense plus: when both descriptions leave the type unchanged, add directly
-- with EPlus and use identity injections.
sparsePlusS _ _ t sp1 sp2 k
  | Just Refl <- isDense t sp1
  , Just Refl <- isDense t sp2
  = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b)

-- handle absents
-- If no injection from the absent side is requested, the result is just the
-- other side's representation. Otherwise the result is wrapped in SpSparse so
-- that the absent side can be injected as ENothing.
sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\_ b -> b)
sparsePlusS ST _ t SpAbsent sp2 k =
  k (SpSparse sp2) (Inj $ \_ -> ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\_ b -> EJust ext b)
sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a _ -> a)
sparsePlusS _ ST t sp1 SpAbsent k =
  k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \_ -> ENothing ext (applySparse sp1 (fromSMTy t))) (\a _ -> EJust ext a)

-- double sparse yields sparse
-- Both injections are requested from the recursive call (ST ST), so matching
-- on Inj alone is sufficient there. In the plus expression, b is let-bound
-- first, hence the weakening of a and the shifted variable indices.
sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k =
  sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
    k (SpSparse sp3)
      (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
      (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
      (\a b ->
         elet b $
           emaybe (weakenExpr WSink a)
             (emaybe (evar IZ)
                (ENothing ext (applySparse sp3 (fromSMTy t)))
                (EJust ext (inj2 (evar IZ))))
             (emaybe (evar (IS IZ))
                (EJust ext (inj1 (evar IZ)))
                (EJust ext (plus (evar (IS IZ)) (evar IZ)))))

-- single sparse can yield non-sparse if the other argument is always present
sparsePlusS SF _ t (SpSparse sp1) sp2 k =
  sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus ->
    k sp3 Noinj (Inj inj2)
      (\a b ->
         elet b $
           emaybe (weakenExpr WSink a)
             (inj2 (evar IZ))
             (plus (evar IZ) (evar (IS IZ))))
-- If the first injection is requested, the result must stay sparse so that a
-- missing first argument can be injected as ENothing.
sparsePlusS ST _ t (SpSparse sp1) sp2 k =
  sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
    k (SpSparse sp3)
      (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
      (Inj $ \b -> EJust ext (inj2 b))
      (\a b ->
         elet b $
           emaybe (weakenExpr WSink a)
             (EJust ext (inj2 (evar IZ)))
             (EJust ext (plus (evar IZ) (evar (IS IZ)))))
-- Sparse second argument: swap the arguments (and requests), reuse the
-- equations above, and swap the results back.
sparsePlusS req1 req2 t sp1 (SpSparse sp2) k =
  sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus ->
    k sp3 inj2 inj1 (flip plus)

-- products: combine component-wise; an injection for the pair exists only if
-- it exists for both components (see withInj2).
sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k =
  sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa ->
  sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb ->
    k (SpPair sp3a sp3b)
      (withInj2 minj13a minj13b $ \inj13a inj13b ->
         \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b))
      (withInj2 minj23a minj23b $ \inj23a inj23b ->
         \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b))
      (\x1 x2 ->
         eunPair x1 $ \w1 x1a x1b ->
         eunPair (weakenExpr w1 x2) $ \w2 x2a x2b ->
           EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b))

-- coproducts: both injections are always requested from the recursive calls,
-- because either side may turn out to be the only one carrying a value.
-- Adding a left to a right (or vice versa) produces an EError expression.
sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k =
  sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa ->
  sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb ->
  let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb))
      inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb))
      inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta))
  in
  k (SpLEither sp3a sp3b)
    (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ))))
    (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ))))
    (\x1 x2 ->
       elet x2 $
         elcase (weakenExpr WSink x1)
           (elcase (evar IZ)
              nil
              (inl (inj23a (evar IZ)))
              (inr (inj23b (evar IZ))))
           (elcase (evar (IS IZ))
              (inl (inj13a (evar IZ)))
              (inl (plusa (evar (IS IZ)) (evar IZ)))
              (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr"))
           (elcase (evar (IS IZ))
              (inr (inj13b (evar IZ)))
              (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll")
              (inr (plusb (evar (IS IZ)) (evar IZ)))))

-- maybe: same shape as the double-sparse case above, but the result is
-- described by SpMaybe instead of SpSparse.
sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k =
  sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
    k (SpMaybe sp3)
      (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
      (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
      (\a b ->
         elet b $
           emaybe (weakenExpr WSink a)
             (emaybe (evar IZ)
                (ENothing ext (applySparse sp3 (fromSMTy t)))
                (EJust ext (inj2 (evar IZ))))
             (emaybe (evar (IS IZ))
                (EJust ext (inj1 (evar IZ)))
                (EJust ext (plus (evar (IS IZ)) (evar IZ)))))

-- dense array cotangents simply recurse
-- The element-level injections and plus are lifted with emap and ezipWith.
sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k =
  sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus ->
    k (SpArr sp3)
      (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ)))
      (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ)))
      (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ))
                      (EVar ext (applySparse sp2 (fromSMTy t)) IZ)))

-- scalars
sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t))
|