summaryrefslogtreecommitdiff
path: root/src/AST/Sparse.hs
blob: 34a398febf470ff0aa72cd5f72243c4c8aba8ced (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImpredicativeTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE RankNTypes #-}

{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-}
module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where

import Data.Type.Equality

import AST
import AST.Sparse.Types
import Data (SBool(..))


-- | Add two cotangents of monoid type @t@ that use the same sparsity
-- structure @sp@ (both are represented at type @t'@); the sum uses that same
-- representation.
--
-- Adding an 'SpLEither' Inl to an Inr (or vice versa) produces a runtime
-- 'EError'.
sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t'
-- Absent: the sum carries no information, but the operands are still bound
-- with 'use' rather than dropped, so that their effects are not destroyed
-- (matching the nil case of 'sparsePlusS').
sparsePlus _ SpAbsent e1 e2 = use e1 $ use e2 $ ENil ext
-- Fully dense: ordinary addition.
sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2
-- SpSparse over t has the same representation as SpMaybe over Maybe t, so
-- reuse the SpMaybe clause below.
sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2
-- Pairs: add componentwise.
sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 =
  eunPair e1 $ \w1 e1a e1b ->
  eunPair (weakenExpr w1 e2) $ \w2 e2a e2b ->
    EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a)
              (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b)
-- LEither: e2 is let-bound, then e1 is scrutinised. LNil is the identity,
-- Inl + Inl and Inr + Inr add their contents, and mixing Inl with Inr is an
-- error.
sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 =
  elet e2 $
    elcase (weakenExpr WSink e1)
      (evar IZ)
      (elcase (evar (IS IZ))
        (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ))
        (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ)))
        (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr"))
      (elcase (evar (IS IZ))
        (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ))
        (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll")
        (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ))))
-- Maybe: Nothing is the identity; Just + Just adds the contents.
sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 =
  elet e2 $
    emaybe (weakenExpr WSink e1)
      (evar IZ)
      (emaybe (evar (IS IZ))
        (EJust ext (evar IZ))
        (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ))))
-- Arrays: add elementwise.
sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2
-- Scalars (presumably already caught by the isDense clause above; listed so
-- the match is exhaustive).
sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2


-- | A zero of monoid type @t@ that is trivial to build and refers to no
-- variables (hence is usable in every environment), if such a zero exists.
-- Arrays never have one; a pair has one exactly when both components do.
cheapZero :: SMTy t -> Maybe (forall env. Ex env t)
cheapZero ty = case ty of
  SMTNil -> Just (ENil ext)
  SMTPair ta tb ->
    case cheapZero ta of
      Just za | Just zb <- cheapZero tb -> Just (EPair ext za zb)
      _ -> Nothing
  SMTLEither ta tb -> Just (ELNil ext (fromSMTy ta) (fromSMTy tb))
  SMTMaybe ta -> Just (ENothing ext (fromSMTy ta))
  SMTArr{} -> Nothing
  SMTScal st -> case st of
    STI32 -> Just (EConst ext st 0)
    STI64 -> Just (EConst ext st 0)
    STF32 -> Just (EConst ext st 0.0)
    STF64 -> Just (EConst ext st 0.0)


-- | A way to embed expressions of type @src@ into type @dst@ that works in
-- every environment. The index @req@ records whether the caller demanded
-- one: 'Noinj' only exists at @req ~ False@, so a demanded injection is
-- always an 'Inj'.
data Injection req src dst where
  -- | An actual embedding. It is purposefully also allowed when @req@ is
  -- @False@ so that 'sparsePlusS' can provide injections even if the caller
  -- doesn't require them. This simplifies the sparsePlusS code.
  Inj :: (forall env. Ex env src -> Ex env dst) -> Injection req src dst
  -- | No embedding; possible only when none was requested.
  Noinj :: Injection False src dst

-- | Post-process an injection with @k@, if there is one. An absent injection
-- ('Noinj') stays absent.
withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b'
withInj minj k = case minj of
  Noinj -> Noinj
  Inj f -> Inj (k f)

-- | Combine two injections into a new one using @k@. The result is 'Noinj'
-- as soon as either input is 'Noinj'. The first injection is inspected
-- first, and the second is not inspected at all if the first is absent.
withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2
         -> ((forall e. Ex e a1 -> Ex e b1)
             -> (forall e. Ex e a2 -> Ex e b2)
             -> (forall e'. Ex e' a' -> Ex e' b'))
         -> Injection sp a' b'
withInj2 minj1 minj2 k = case minj1 of
  Noinj -> Noinj
  Inj f -> case minj2 of
    Noinj -> Noinj
    Inj g -> Inj (k f g)

-- | @use x body@ let-binds @x@ and then continues with @body@, which is
-- weakened so that it cannot see the new binding. This keeps @x@ in the
-- program (rather than dropping it) while yielding @body@'s value.
use :: Ex env a -> Ex env b -> Ex env b
use x body = elet x (weakenExpr WSink body)

-- | Add two sparse cotangents of the same monoid type @t@ whose sparsity
-- structures (@Sparse t t1@ and @Sparse t t2@) may differ. Instead of
-- returning a value, this calls the continuation @k@ with:
--
--   * a sparsity structure @Sparse t t3@ chosen for the sum;
--   * an injection from the first operand's representation @t1@ into @t3@,
--     for when the second operand is absent. It is guaranteed to be 'Inj'
--     when the first 'SBool' is 'ST' and may be 'Noinj' otherwise;
--   * the same for the second operand, governed by the second 'SBool';
--   * the addition of two present operands, producing a @t3@.
--
-- Clauses are tried top to bottom, so their order matters. First come
-- normalisations of degenerate or nested sparsity, then the all-dense fast
-- path, then absent operands, then 'SpSparse' wrappers, and finally
-- structural recursion on the type.
--
-- This function produces quadratically-sized code in the presence of nested
-- dynamic sparsity. TODO can this be improved?
sparsePlusS
  :: SBool inj1 -> SBool inj2
  -> SMTy t -> Sparse t t1 -> Sparse t t2
  -> (forall t3. Sparse t t3
              -> Injection inj1 t1 t3  -- only available if first injection is requested (second argument may be absent)
              -> Injection inj2 t2 t3  -- only available if second injection is requested (first argument may be absent)
              -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3)
              -> r)
  -> r
-- nil override (but don't destroy effects!): the TNil monoid carries no
-- information, so the sum is always absent. The operands are still bound
-- with 'use' rather than dropped.
sparsePlusS _ _ SMTNil _ _ k =
  k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext)

-- simplifications
-- Each clause below rewrites one operand's sparsity into a simpler
-- equivalent form and recurses. It then precomposes the resulting
-- injection and plus with a conversion from the original representation.

-- A sparse absent value only records presence, which does not matter for
-- the sum. Treat it as SpAbsent, feeding ENil to the recursive result while
-- still binding the original operand with 'use'.
sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k =
  sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus ->
    k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b)
-- Mirror image of the previous clause, for the second operand.
sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k =
  sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus ->
    k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext))

-- Doubly sparse: collapse Maybe (Maybe x) to Maybe x (Nothing stays
-- Nothing, Just m becomes m) and recurse with a single SpSparse layer.
sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k =
  let ta = applySparse sp1 (fromSMTy t) in
  sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus ->
    k sp3
      (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)))
      minj2
      (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b)
sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k =
  let tb = applySparse sp2 (fromSMTy t) in
  sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus ->
    k sp3
      minj1
      (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
      (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))

-- Sparse LEither: LEither already has an empty value, so map Nothing to
-- LNil and Just x to x, then recurse on the plain SpLEither.
sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k =
  let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in
  sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus ->
    k sp3
      (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
      minj2
      (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b)
sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k =
  let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in
  sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus ->
    k sp3
      minj1
      (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
      (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))

-- Sparse Maybe: collapse Maybe (Maybe x) to Maybe x (Nothing stays
-- Nothing, Just m becomes m) and recurse on the plain SpMaybe.
-- NOTE(review): the injections here use 'evar IZ' whereas the plus
-- functions spell out the binder type via 'EVar'. Both refer to the same
-- Just-bound value.
sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k =
  let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in
  sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus ->
    k sp3
      (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ)))
      minj2
      (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b)
sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k =
  let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in
  sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus ->
    k sp3
      minj1
      (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ)))
      (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
-- A Maybe of a sparse value has the same representation as a sparse Maybe,
-- so reuse the clauses above.
sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k
sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k

-- TODO: sparse of Just is just Maybe

-- dense plus: both operands are fully dense, so use ordinary addition and
-- identity injections.
sparsePlusS _ _ t sp1 sp2 k
  | Just Refl <- isDense t sp1
  , Just Refl <- isDense t sp2
  = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b)

-- handle absents
-- An absent operand contributes nothing, so the result uses the other
-- operand's sparsity. If the absent side's injection is requested, it must
-- produce a zero. Use 'cheapZero' when one exists; otherwise wrap the
-- result in SpSparse so that Nothing can act as the zero.
sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b)
sparsePlusS ST _ t SpAbsent sp2 k
  | Just zero2 <- cheapZero (applySparse sp2 t) =
      k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b)
  | otherwise =
      k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b)

-- Mirror image of the two clauses above, for an absent second operand.
sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a)
sparsePlusS _ ST t sp1 SpAbsent k
  | Just zero1 <- cheapZero (applySparse sp1 t) =
      k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a)
  | otherwise =
      k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a)

-- double sparse yields sparse
-- Recurse on the contents with both injections requested (either side may
-- be Nothing) and wrap the result in SpSparse. Nothing + Nothing is
-- Nothing, and a single Just is injected.
sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k =
  sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
    k (SpSparse sp3)
      (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
      (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
      (\a b ->
        elet b $
          emaybe (weakenExpr WSink a)
            (emaybe (evar IZ)
              (ENothing ext (applySparse sp3 (fromSMTy t)))
              (EJust ext (inj2 (evar IZ))))
            (emaybe (evar (IS IZ))
              (EJust ext (inj1 (evar IZ)))
              (EJust ext (plus (evar (IS IZ)) (evar IZ)))))

-- single sparse can yield non-sparse if the other argument is always present
-- First injection not requested: if the first operand is Nothing, the sum
-- is just the injected second operand, so no extra Maybe layer is needed.
sparsePlusS SF _ t (SpSparse sp1) sp2 k =
  sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus ->
    k sp3 Noinj (Inj inj2)
      (\a b ->
        elet b $
          emaybe (weakenExpr WSink a)
            (inj2 (evar IZ))
            (plus (evar IZ) (evar (IS IZ))))
-- First injection requested: the first operand on its own may be Nothing,
-- so the result must stay SpSparse.
sparsePlusS ST _ t (SpSparse sp1) sp2 k =
  sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
    k (SpSparse sp3)
      (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
      (Inj $ \b -> EJust ext (inj2 b))
      (\a b ->
        elet b $
          emaybe (weakenExpr WSink a)
            (EJust ext (inj2 (evar IZ)))
            (EJust ext (plus (evar IZ) (evar (IS IZ)))))
-- Only the second operand is sparse: swap the operands, reuse the clauses
-- above, then swap the injections and the plus back.
sparsePlusS req1 req2 t sp1 (SpSparse sp2) k =
  sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus ->
    k sp3 inj2 inj1 (flip plus)

-- products
-- Recurse componentwise. A pair injection exists when both component
-- injections do (see 'withInj2').
sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k =
  sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa ->
  sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb ->
    k (SpPair sp3a sp3b)
      (withInj2 minj13a minj13b $ \inj13a inj13b ->
        \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b))
      (withInj2 minj23a minj23b $ \inj23a inj23b ->
        \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b))
      (\x1 x2 ->
        eunPair x1 $ \w1 x1a x1b ->
        eunPair (weakenExpr w1 x2) $ \w2 x2a x2b ->
          EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b))

-- coproducts
-- Both injections are always requested recursively, because adding LNil to
-- an Inl/Inr must inject the present side. Inl + Inr is a runtime error.
sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k =
  sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa ->
  sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb ->
    let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb))
        inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb))
        inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta))
    in
    k (SpLEither sp3a sp3b)
      (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ))))
      (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ))))
      (\x1 x2 ->
        elet x2 $
          elcase (weakenExpr WSink x1)
            (elcase (evar IZ)
              nil
              (inl (inj23a (evar IZ)))
              (inr (inj23b (evar IZ))))
            (elcase (evar (IS IZ))
              (inl (inj13a (evar IZ)))
              (inl (plusa (evar (IS IZ)) (evar IZ)))
              (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr"))
            (elcase (evar (IS IZ))
              (inr (inj13b (evar IZ)))
              (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll")
              (inr (plusb (evar (IS IZ)) (evar IZ)))))

-- maybe
-- Nothing is the identity. Both injections are requested recursively so
-- that a single Just can be injected into the result representation.
sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k =
  sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
    k (SpMaybe sp3)
      (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
      (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
      (\a b ->
        elet b $
          emaybe (weakenExpr WSink a)
            (emaybe (evar IZ)
              (ENothing ext (applySparse sp3 (fromSMTy t)))
              (EJust ext (inj2 (evar IZ))))
            (emaybe (evar (IS IZ))
              (EJust ext (inj1 (evar IZ)))
              (EJust ext (plus (evar (IS IZ)) (evar IZ)))))

-- dense array cotangents simply recurse
-- (elementwise: the injections are mapped with 'emap', and the plus is
-- applied with 'ezipWith')
sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k =
  sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus ->
    k (SpArr sp3)
      (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ)))
      (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ)))
      (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ))
                      (EVar ext (applySparse sp2 (fromSMTy t)) IZ)))

-- scalars (presumably already caught by the dense clause; listed so the
-- match is exhaustive)
sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t))