summaryrefslogtreecommitdiff
path: root/src/AST/UnMonoid.hs
blob: 4b6b52354396ff1b84c3817549b5fe5b3411c5d9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
module AST.UnMonoid where

import AST
import CHAD.Types
import Data


-- | Rewrite away the monoid primitives of the AST.
--
-- 'EZero', 'EPlus' and 'EOneHot' are replaced by ordinary expressions built by
-- 'zero', 'plus' and 'onehot' respectively. Every other constructor is rebuilt
-- with its annotation reset to 'ext' and with 'unMonoid' applied recursively to
-- each of its subexpressions; non-expression fields are copied unchanged.
unMonoid :: Ex env t -> Ex env t
-- Monoid primitives: lowered to plain code.
unMonoid (EZero _ ty) = zero ty
unMonoid (EPlus _ ty lhs rhs) = plus ty (unMonoid lhs) (unMonoid rhs)
unMonoid (EOneHot _ ty depth ix v) = onehot ty depth (unMonoid ix) (unMonoid v)
-- Everything else: structural recursion, re-annotated with 'ext'.
unMonoid (EVar _ ty i) = EVar ext ty i
unMonoid (ELet _ rhs body) = ELet ext (unMonoid rhs) (unMonoid body)
unMonoid (EPair _ l r) = EPair ext (unMonoid l) (unMonoid r)
unMonoid (EFst _ p) = EFst ext (unMonoid p)
unMonoid (ESnd _ p) = ESnd ext (unMonoid p)
unMonoid (ENil _) = ENil ext
unMonoid (EInl _ ty x) = EInl ext ty (unMonoid x)
unMonoid (EInr _ ty x) = EInr ext ty (unMonoid x)
unMonoid (ECase _ scrut onL onR) = ECase ext (unMonoid scrut) (unMonoid onL) (unMonoid onR)
unMonoid (ENothing _ ty) = ENothing ext ty
unMonoid (EJust _ x) = EJust ext (unMonoid x)
unMonoid (EMaybe _ onN onJ scrut) = EMaybe ext (unMonoid onN) (unMonoid onJ) (unMonoid scrut)
unMonoid (EConstArr _ rank ty arr) = EConstArr ext rank ty arr
unMonoid (EBuild _ rank sh f) = EBuild ext rank (unMonoid sh) (unMonoid f)
unMonoid (EFold1Inner _ f z xs) = EFold1Inner ext (unMonoid f) (unMonoid z) (unMonoid xs)
unMonoid (ESum1Inner _ xs) = ESum1Inner ext (unMonoid xs)
unMonoid (EUnit _ x) = EUnit ext (unMonoid x)
unMonoid (EReplicate1Inner _ k xs) = EReplicate1Inner ext (unMonoid k) (unMonoid xs)
unMonoid (EMaximum1Inner _ xs) = EMaximum1Inner ext (unMonoid xs)
unMonoid (EMinimum1Inner _ xs) = EMinimum1Inner ext (unMonoid xs)
unMonoid (EConst _ ty c) = EConst ext ty c
unMonoid (EIdx0 _ arr) = EIdx0 ext (unMonoid arr)
unMonoid (EIdx1 _ arr i) = EIdx1 ext (unMonoid arr) (unMonoid i)
unMonoid (EIdx _ arr i) = EIdx ext (unMonoid arr) (unMonoid i)
unMonoid (EShape _ arr) = EShape ext (unMonoid arr)
unMonoid (EOp _ op arg) = EOp ext op (unMonoid arg)
unMonoid (ECustom _ ta tb tc f g h x y) =
  ECustom ext ta tb tc (unMonoid f) (unMonoid g) (unMonoid h) (unMonoid x) (unMonoid y)
unMonoid (EWith _ x body) = EWith ext (unMonoid x) (unMonoid body)
unMonoid (EAccum _ depth ix v acc) = EAccum ext depth (unMonoid ix) (unMonoid v) (unMonoid acc)
unMonoid (EError _ ty msg) = EError ext ty msg

-- | Build an expression for the zero cotangent of type @D2 t@.
--
-- * pairs, sums and 'Maybe': 'ENothing'.
-- * rank-0 arrays: an 'EUnit' holding the element zero (presumably because a
--   rank-0 shape cannot be made empty).
-- * arrays of higher rank: an 'EBuild' whose shape is all zeros and whose
--   element expression is an 'EError' that is never reached for that shape.
-- * float scalars: the constant @0.0@.
-- * nil, integer and boolean scalars: 'ENil'.
-- * accumulators: rejected with 'error'.
zero :: STy t -> Ex env (D2 t)
zero = \case
  STNil -> ENil ext
  STPair l r -> ENothing ext (STPair (d2 l) (d2 r))
  STEither l r -> ENothing ext (STEither (d2 l) (d2 r))
  STMaybe inner -> ENothing ext (d2 inner)
  -- Must stay before the general array alternative.
  STArr SZ elt -> EUnit ext (zero elt)
  STArr rank elt ->
    let emptyShape = eTup (sreplicate rank (EConst ext STI64 0))
    in EBuild ext rank emptyShape (EError ext (d2 elt) "empty")
  STScal STI32 -> ENil ext
  STScal STI64 -> ENil ext
  STScal STF32 -> EConst ext STF32 0.0
  STScal STF64 -> EConst ext STF64 0.0
  STScal STBool -> ENil ext
  STAccum{} -> error "Accumulators not allowed in input program"

-- | Build an expression that adds two cotangent values of type @D2 t@.
--
-- Per type:
--
-- * nil, integer and boolean scalars: the result is 'ENil' and both argument
--   expressions are dropped.
-- * float scalars: a single 'OAdd' on the pair of arguments.
-- * pairs, sums and 'Maybe': delegated to 'plusSparse', so a 'Nothing' on
--   either side yields the other side. When both are present, pairs are added
--   componentwise; sums must have the same constructor on both sides
--   (otherwise the generated code reaches an 'EError'); 'Maybe' adds the
--   payloads.
-- * arrays: if either operand has an empty shape ('eshapeEmpty') the other
--   operand is returned; otherwise the elements are combined with 'ezipWith'.
--   NOTE(review): this assumes two non-empty operands have equal shapes —
--   confirm what 'ezipWith' does otherwise.
-- * accumulators: rejected with 'error'.
plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
plus STNil _ _ = ENil ext
plus (STPair t1 t2) a b =
  let t = STPair (d2 t1) (d2 t2)
  in plusSparse t a b $
       -- Inside the plusSparse adder: IS IZ is a's pair, IZ is b's pair.
       EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
                          (EFst ext (EVar ext t IZ)))
                 (plus t2 (ESnd ext (EVar ext t (IS IZ)))
                          (ESnd ext (EVar ext t IZ)))
plus (STEither t1 t2) a b =
  let t = STEither (d2 t1) (d2 t2)
  in plusSparse t a b $
       -- Case on a's sum (IS IZ). Once that case has bound its payload, b's
       -- sum is again at IS IZ, so the inner cases scrutinise b. In the
       -- innermost branches a's payload is IS IZ and b's payload is IZ.
       ECase ext (EVar ext t (IS IZ))
         (ECase ext (EVar ext t (IS IZ))
           (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ)))
           -- a is Left but b is Right.
           (EError ext t "plus l+r"))
         (ECase ext (EVar ext t (IS IZ))
           -- a is Right but b is Left.
           (EError ext t "plus r+l")
           (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
plus (STMaybe t) a b =
  -- Inside the plusSparse adder: IS IZ is a's payload, IZ is b's payload.
  plusSparse (d2 t) a b $
    plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)
plus (STArr n t) a b =
  -- Let-bind both operands (a ends up at IS IZ, b at IZ) so that each can be
  -- inspected for emptiness and then reused without duplicating the
  -- expressions.
  ELet ext a $
  ELet ext (weakenExpr WSink b) $
    -- a empty: return b.
    eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ))))
        (EVar ext (STArr n (d2 t)) IZ)
        -- b empty: return a.
        (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ)))
             (EVar ext (STArr n (d2 t)) (IS IZ))
             -- Both non-empty: add elementwise. Presumably ezipWith binds the
             -- element of its first array at IS IZ and of its second at IZ.
             (ezipWith (plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ))
               (EVar ext (STArr n (d2 t)) (IS IZ))
               (EVar ext (STArr n (d2 t)) IZ)))
plus (STScal t) a b = case t of
  STI32 -> ENil ext
  STI64 -> ENil ext
  STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
  STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
  STBool -> ENil ext
plus STAccum{} _ _ = error "Accumulators not allowed in input program"

-- | Add two optional values of type @a@, given an expression @adder@ that adds
-- two present values.
--
-- If @a@ is 'Nothing' the result is @b@. If @b@ is 'Nothing' the result is
-- @Just@ of @a@'s payload. Otherwise the result is @Just adder@, where
-- @adder@ runs with @IS IZ@ bound to @a@'s payload and @IZ@ bound to @b@'s
-- payload.
--
-- Each of the argument expressions @a@ and @b@ occurs exactly once in the
-- generated code: @b@ is let-bound and @a@ is the scrutinee.
plusSparse :: STy a
           -> Ex env (TMaybe a) -> Ex env (TMaybe a)
           -> Ex (a : a : env) a
           -> Ex env (TMaybe a)
plusSparse t a b adder =
  ELet ext b $
    EMaybe ext
      -- a is Nothing: return b (currently IZ).
      (EVar ext (STMaybe t) IZ)
      -- a is Just x: x is now IZ and b is IS IZ.
      (EJust ext
        (EMaybe ext
          -- b is Nothing: the payload is x.
          (EVar ext t IZ)
          -- b is Just y: y is IZ and x is IS IZ. Keep those two binders and
          -- shift the adder's remaining free variables past the let-bound b.
          (weakenExpr (WCopy (WCopy WSink)) adder)
          -- Scrutinee: b.
          (EVar ext (STMaybe t) (IS IZ))))
      -- Scrutinee: a, weakened past the let-bound b.
      (weakenExpr WSink a)

-- | Build a one-hot cotangent of type @D2 t@: @val@ placed at the position
-- described by @idx@, with 'zero' in the other positions this code builds.
--
-- The 'SNat' is the depth to which @idx@ descends into the type. At depth 0,
-- @val@ itself is returned. Otherwise, per type:
--
-- * pairs: at depth 1, @val@ is wrapped in 'Just'. Deeper, @idx@ and @val@ are
--   sums selecting a component; the chosen component recurses and the other
--   component is 'zero'. Mismatched index/value sides produce an 'EError'.
-- * sums: same structure, producing 'EInl' / 'EInr' under 'Just'.
-- * 'Maybe': 'Just' of the one-hot of the payload, with the depth reduced by
--   one and the same @idx@.
-- * arrays: @val@ is a pair of (result shape, values). An array of that shape
--   is built, with each element computed by 'onehotArrayElem'.
-- * nil and scalars: indexing is an error. Accumulators are rejected.
onehot :: STy t -> SNat i -> Ex env (AcIdx (D2 t) i) -> Ex env (AcVal (D2 t) i) -> Ex env (D2 t)
onehot _ SZ _ val = val
onehot t (SS dep) idx val = case t of
  STPair t1 t2 ->
    case dep of
      SZ -> EJust ext val
      SS dep' ->
        let STEither tidx1 tidx2 = typeOf idx
            STEither tval1 tval2 = typeOf val
        in EJust ext $
            -- Case on idx first, then on val. In the innermost branches the
            -- index payload is IS IZ and the value payload is IZ.
            ECase ext idx
              (ECase ext (weakenExpr WSink val)
                (EPair ext (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ))
                           (zero t2))
                (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair l/r"))
              (ECase ext (weakenExpr WSink val)
                (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair r/l")
                (EPair ext (zero t1)
                           (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ))))

  STEither t1 t2 ->
    case dep of
      SZ -> EJust ext val
      SS dep' ->
        let STEither tidx1 tidx2 = typeOf idx
            STEither tval1 tval2 = typeOf val
        in EJust ext $
            -- Same binder layout as the pair alternative above.
            ECase ext idx
              (ECase ext (weakenExpr WSink val)
                (EInl ext (d2 t2) (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ)))
                (EError ext (STEither (d2 t1) (d2 t2)) "onehot either l/r"))
              (ECase ext (weakenExpr WSink val)
                (EError ext (STEither (d2 t1) (d2 t2)) "onehot either r/l")
                (EInr ext (d2 t1) (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ))))

  STMaybe t1 -> EJust ext (onehot t1 dep idx val)

  STArr n t1 ->
    -- Bind val once. The build shape is its first component. Inside the build
    -- body, IZ is the element index and IS IZ is the bound val.
    -- NOTE(review): idx is weakened into the EBuild body, so it is presumably
    -- evaluated once per element; consider hoisting it if it can be costly.
    ELet ext val $
      EBuild ext n (EFst ext (EVar ext (typeOf val) IZ))
                   (onehotArrayElem t1 n (SS dep)
                                    (EVar ext (tTup (sreplicate n tIx)) IZ)
                                    (weakenExpr (WSink .> WSink) idx)
                                    (ESnd ext (EVar ext (typeOf val) (IS IZ))))

  STNil -> error "Cannot index into nil"
  STScal{} -> error "Cannot index into scalar"
  STAccum{} -> error "Accumulators not allowed in input program"

-- | Compute one element of a one-hot array.
--
-- If the element index agrees with the one-hot index on every dimension that
-- the index fixes, the result is the corresponding part of the value;
-- otherwise it is 'zero'. Both index expressions are let-bound first, because
-- 'outsideInIndex' and 'onehotArrayElemRec' use them more than once.
onehotArrayElem
  :: STy t -> SNat n -> SNat i
  -> Ex env (Tup (Replicate n TIx))    -- ^ index of the element being built, as a tuple (split outside-in with 'outsideInIndex' here)
  -> Ex env (AcIdx (TArr n (D2 t)) i)  -- ^ where to put the one-hot
  -> Ex env (AcValArr n (D2 t) i)      -- ^ value to put in the hole
  -> Ex env (D2 t)
onehotArrayElem t n dep eltidx idx val =
  -- After these two lets: IS IZ is the element index tuple, IZ is idx.
  ELet ext eltidx $
  ELet ext (weakenExpr WSink idx) $
    let (cond, elt) = onehotArrayElemRec t n dep
                                         (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ)))
                                         (EVar ext (typeOf idx) IZ)
                                         (weakenExpr (WSink .> WSink) val)
    in eif cond elt (zero t)

-- | Worker for 'onehotArrayElem'. It walks the array dimensions from the outside
-- in, consuming one depth level per dimension, and returns a pair of
--
-- * a boolean expression that holds when the element index agrees with the
--   one-hot index on every dimension visited, and
-- * the element to use when that condition holds.
--
-- If the depth runs out first, the element is read from the value array (over
-- the remaining dimensions) at the remaining element indices. If the
-- dimensions run out first, the rest of the index and the value are handed to
-- 'onehot' on the element type. The element-index list must have exactly one
-- entry per remaining dimension; otherwise this calls 'error'.
--
-- The AcIdx expression is used more than once, so it must be duplicable.
onehotArrayElemRec
  :: STy t -> SNat n -> SNat i
  -> [Ex env TIx]
  -> Ex env (AcIdx (TArr n (D2 t)) i)
  -> Ex env (AcValArr n (D2 t) i)
  -> (Ex env (TScal TBool), Ex env (D2 t))
-- Depth exhausted: plain lookup into the value array.
onehotArrayElemRec _ dims SZ remaining _ arr =
  (EConst ext STBool True, EIdx ext arr (reconstructFromOutsideIn dims remaining))
-- Dimensions exhausted: keep indexing inside the element type.
onehotArrayElemRec elemTy SZ (SS depth) [] ix v =
  (EConst ext STBool True, onehot elemTy depth ix v)
-- Peel the outermost dimension: compare it, then recurse on the rest.
onehotArrayElemRec elemTy (SS dims) (SS depth) (here : rest) ix v =
  let (restOk, element) = onehotArrayElemRec elemTy dims depth rest (ESnd ext ix) v
      hereOk = EOp ext (OEq STI64) (EPair ext here (EFst ext ix))
  in (EOp ext OAnd (EPair ext hereOk restOk), element)
onehotArrayElemRec _ _ _ _ _ _ = error "onehotArrayElemRec: mismatched list length"

-- | Split an index tuple into its components, with the outermost index at the
-- head of the list. The input expression is copied into every component, so
-- it must be duplicable.
outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx]
outsideInIndex rank tuple = peel rank tuple []
  where
    -- The right component of each pair is the innermost remaining index. It
    -- is consed first, so it ends up furthest from the head.
    peel :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] -> [Ex env TIx]
    peel SZ _ done = done
    peel (SS k) tup done = peel k (EFst ext tup) (ESnd ext tup : done)

-- | Inverse of 'outsideInIndex': take a list with the outermost index at the
-- head and rebuild the index tuple with the innermost index on the right.
-- Calls 'error' if the list length differs from the rank.
reconstructFromOutsideIn :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx))
reconstructFromOutsideIn rank outsideIn = assemble rank (reverse outsideIn)
  where
    -- The list passed here has the _innermost_ index at the head.
    assemble :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx))
    assemble SZ [] = ENil ext
    assemble (SS k) (inner : outer) = EPair ext (assemble k outer) inner
    assemble _ _ = error "reconstructFromOutsideIn: mismatched list length"