1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
module CHAD.Language (
-- * Named expressions
fromNamed,
NExpr, NFun,
-- * Functions
lambda,
body,
inline,
(.$),
-- * Basic language constructs
let_,
pair, fst_, snd_, nil,
inl, inr, case_,
nothing, just, maybe_,
-- * Array operations
constArr_,
build1, build2, build,
map_,
fold1i, fold1i',
sum1i,
unit,
replicate1i,
maximum1i, minimum1i,
reshape,
fold1iD1, fold1iD1',
fold1iD2,
-- * Scalar operations
-- | Note that 'NExpr' is also an instance of some numeric classes like 'Num' and 'Floating'.
const_,
idx0,
(!),
shape,
length_,
error_,
(.==), (.<), (CHAD.Language..>), (.<=), (.>=),
not_, and_, or_,
mod_, round_, toFloat_, idiv,
-- * Control flow
if_,
-- * Special operations
custom,
recompute,
with, accum, accumS,
oper, oper2,
-- * Helper types
(:->)(..),
-- * Reexports
TIx,
Lookup,
Ex,
Ty(..),
SNat(..), Nat(..), N0, N1, N2, N3,
) where
import GHC.TypeLits (withSomeSSymbol, symbolVal, SSymbol, pattern SSymbol)
import CHAD.Array
import CHAD.AST
import CHAD.AST.Sparse.Types
import CHAD.Data
import CHAD.Drev.Types
import CHAD.Language.AST
infixr 0 :->

-- | Helper type that pairs a binder (on the left) with the thing it scopes
-- over (on the right). Used for the function-like arguments of e.g. 'case_'
-- and 'build'. Being @infixr 0@, several binders can be chained:
-- @v1 :-> v2 :-> e@ parses as @v1 :-> (v2 :-> e)@.
data a :-> b = a :-> b
  deriving Show
-- | Wrap an expression as a function that binds no further arguments
-- ('NBody'); this terminates a chain of 'lambda's. See 'fromNamed' for a
-- usage example.
body :: NExpr env t -> NFun env env t
body expr = NBody expr
-- | Bind one more named argument of a function ('NLam'). The rest of the
-- function is typed in an environment that has the new variable at its head.
-- The argument type @a@ comes first in the @forall@ so that it can be given
-- with a type application. See 'fromNamed' for a usage example.
lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
lambda var rest = NLam var rest
-- | Inline a closed function at this point, passing the given list of
-- expressions as its arguments.
--
-- Although the arguments are a normal 'SList', the @params@ list is reversed
-- with respect to the natural argument order of the function. Use the '(.$)'
-- helper operator to write the arguments in their natural order:
--
-- @
-- let fun = 'lambda' \@(TScal TF64) #x $ 'lambda' \@(TScal TBool) #b $ 'body' $ 'if_' #b #x (#x + 1)
-- in 'inline' fun ('SNil' .$ 16 .$ 'const_' True)
-- @
--
-- The literal @16@ needs no 'const_' because 'NExpr' implements 'Num'.
inline :: NFun '[] params t -> SList (NExpr env) (UnName params) -> NExpr env t
inline fun args = inlineNFun fun args
-- | Snoc-style helper for building the argument list of 'inline': the new
-- element is put at the head of the list, i.e. this is 'SCons' with its
-- arguments swapped. See 'inline' for an example.
(.$) :: SList f list -> f a -> SList f (a : list)
rest .$ new = SCons new rest
-- | A let-binding. The 'Var' argument is the name being bound (the left-hand
-- side), the second argument is the bound expression, and the third is the
-- expression in which the new variable is in scope. For example:
--
-- @
-- 'fromNamed' $ 'lambda' \@(TScal TI64) #a $ 'body' $
--   'let_' #x (#a + 1) $
--     #x * #a
-- @
--
-- This yields an expression of type @'Ex' '[TScal TI64] (TScal TI64)@
-- corresponding to the Haskell code @\\a -> let x = a + 1 in x * a@.
let_ :: forall a t env name. Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t
let_ var rhs scope = NELet var rhs scope
-- | Build a pair expression from its two components ('NEPair').
pair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b)
pair l r = NEPair l r
-- | First component of a pair expression ('NEFst').
fst_ :: NExpr env (TPair a b) -> NExpr env a
fst_ p = NEFst p
-- | Second component of a pair expression ('NESnd').
snd_ :: NExpr env (TPair a b) -> NExpr env b
snd_ p = NESnd p
-- | The expression of type 'TNil' ('NENil'). It is also the innermost
-- component of index and shape tuples, e.g. @'pair' ('pair' 'nil' i) j@.
nil :: NExpr env TNil
nil = NENil
-- | Left injection into a 'TEither' ('NEInl'). The type of the /other/
-- alternative is not determined by the argument, so it is supplied through
-- the 'KnownTy' constraint.
inl :: KnownTy b => NExpr env a -> NExpr env (TEither a b)
inl e = NEInl knownTy e
-- | Right injection into a 'TEither' ('NEInr'). The type of the /other/
-- alternative is not determined by the argument, so it is supplied through
-- the 'KnownTy' constraint.
inr :: KnownTy a => NExpr env b -> NExpr env (TEither a b)
inr e = NEInr knownTy e
-- | Case analysis on a 'TEither' value. The first branch handles the left
-- alternative and the second branch the right one; each branch names the
-- variable that receives the payload. For example, the following expression
-- evaluates to 10 + 1 = 11:
--
-- @
-- 'case_' ('inl' 10)
--   (#x :-> #x + 1)
--   (#y :-> #y * 2)
-- @
case_
  :: NExpr env (TEither a b)
  -> (Var name1 a :-> NExpr ('(name1, a) : env) c)
  -> (Var name2 b :-> NExpr ('(name2, b) : env) c)
  -> NExpr env c
case_ scrut (lvar :-> lbranch) (rvar :-> rbranch) =
  NECase scrut lvar lbranch rvar rbranch
-- | The empty 'TMaybe' value ('NENothing'). The element type cannot be
-- inferred from an argument, so it is supplied through the 'KnownTy'
-- constraint.
nothing :: KnownTy a => NExpr env (TMaybe a)
nothing = NENothing knownTy
-- | Wrap an expression in the non-empty 'TMaybe' alternative ('NEJust').
just :: NExpr env a -> NExpr env (TMaybe a)
just e = NEJust e
-- | Analogue of the 'Prelude.maybe' function in the Haskell Prelude. The
-- arguments are, in order: the result for the @Nothing@ case, the function
-- (as a named variable and a body) for the @Just@ case, and the 'TMaybe'
-- value being eliminated.
--
-- @
-- 'maybe_' 2 (#x :-> #x * 3) (...)
-- @
--
-- will return 2 if @(...)@ is @Nothing@ and @x * 3@ if it is @Just x@.
maybe_
  :: NExpr env b
  -> (Var name a :-> NExpr ('(name, a) : env) b)
  -> NExpr env (TMaybe a)
  -> NExpr env b
maybe_ onNothing (var :-> onJust) scrut = NEMaybe onNothing var onJust scrut
-- | Embed a constant array of scalars in an expression ('NEConstArr'). The
-- rank and scalar type are taken from the 'KnownNat' and 'KnownScalTy'
-- constraints. To construct 'Array' values, see "CHAD.Array".
constArr_ :: forall t n env. (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
constArr_ arr =
  -- Matching on 'Dict' brings the instance evidence produced by
  -- 'scalRepIsShow' into scope for the constructor.
  case scalRepIsShow ty of
    Dict -> NEConstArr knownNat ty arr
  where
    ty = knownScalTy
-- | Special case of 'build' for 1-dimensional arrays: takes the length and a
-- function from a single named index to the element. This produces the array
-- [0.0, 1.0, 2.0]:
--
-- @
-- 'build1' 3 (#i :-> 'toFloat_' #i)
-- @
build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t)
build1 len (ix :-> elt) =
  NEBuild (SS SZ) (pair nil len) #idx $
    -- Bind the user's index variable to the sole component of the index
    -- tuple; the 'NEDrop' removes the internal @#idx@ variable (position 1,
    -- just below the user's variable) from the environment of the body.
    let_ ix (snd_ #idx) (NEDrop (SS SZ) elt)
-- | Special case of 'build' for 2-dimensional arrays: takes the two extents
-- and a function from two named indices to the element. The first extent and
-- first index variable belong to the outer position of the index tuple
-- @(((), i1), i2)@.
build2
  :: NExpr env TIx
  -> NExpr env TIx
  -> (Var name1 TIx :-> Var name2 TIx :-> NExpr ('(name2, TIx) : '(name1, TIx) : env) t)
  -> NExpr env (TArr (S (S Z)) t)
build2 extent1 extent2 (ix1 :-> ix2 :-> elt) =
  NEBuild (SS (SS SZ)) (pair (pair nil extent1) extent2) #idx unpacked
  where
    unpacked =
      let_ ix1 (snd_ (fst_ #idx)) $
        -- 'NEDrop' 'SZ' types @#idx@ in the environment without the
        -- just-bound first index variable on top of it.
        let_ ix2 (NEDrop SZ (snd_ #idx)) $
          -- Remove the internal @#idx@ variable (position 2) from the
          -- environment seen by the user's body.
          NEDrop (SS (SS SZ)) elt
-- | General n-dimensional elementwise array constructor. The element function
-- binds a /single/ variable that holds the whole index tuple. A 3-dimensional
-- index looks like @((((), i1), i2), i3)@; other dimensionalities are
-- analogous. The innermost dimension (i.e. the one whose index varies the
-- fastest in the standard memory layout) is the right-most index, i.e. @i3@
-- in the 3D example. To create a 10-by-10 table of (row, column) pairs:
--
-- @
-- 'build' ('SS' ('SS' 'SZ')) ('pair' ('pair' 'nil' 10) 10)
--   (#idx :-> 'pair' ('snd_' ('fst_' #idx)) ('snd_' #idx))
-- @
--
-- For separately named index variables, see 'build1' and 'build2'.
build
  :: SNat n
  -> NExpr env (Tup (Replicate n TIx))
  -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t)
  -> NExpr env (TArr n t)
build rank sh (ix :-> elt) = NEBuild rank sh ix elt
-- | Elementwise map over an array of any rank ('NEMap'). The function is
-- given as a named variable for the element together with a body.
map_
  :: forall n a b env name. (KnownNat n, KnownTy a)
  => (Var name a :-> NExpr ('(name, a) : env) b)
  -> NExpr env (TArr n a)
  -> NExpr env (TArr n b)
map_ (var :-> fn) arr = NEMap var fn arr
-- | Fold over the innermost dimension of an array, thus reducing its
-- dimensionality by one. The combination function binds two named variables.
--
-- This is a wrapper around 'fold1i'', whose function takes a single
-- pair-typed variable: a fresh variable named @\<name1\>.\<name2\>@ is
-- created for the pair, and the two user-named variables are bound to its
-- components.
fold1i
  :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t)
  -> NExpr env t
  -> NExpr env (TArr (S n) t)
  -> NExpr env (TArr n t)
fold1i (lhs@(Var lhsSym@SSymbol elemTy) :-> rhs@(Var rhsSym@SSymbol _) :-> combine) initial arr =
  withSomeSSymbol (symbolVal lhsSym ++ "." ++ symbolVal rhsSym) $ \(pairSym :: SSymbol name3) ->
    -- These three wrappers presumably provide the type-level evidence about
    -- the fresh name that the binders below require (confirm against their
    -- definitions); the last one is needed because the pair variable is
    -- looked up underneath the binding of the first user variable.
    assertSymbolNotUnderscore pairSym $
      equalityReflexive pairSym $
        assertSymbolDistinct pairSym lhsSym $
          let pairVar = Var pairSym (STPair elemTy elemTy)
          in fold1i'
               ( pairVar :->
                   let_ lhs (fst_ (NEVar pairVar)) $
                     let_ rhs (snd_ (NEVar pairVar)) $
                       -- Hide the pair variable (position 2) from the
                       -- user's body.
                       NEDrop (SS (SS SZ)) combine
               )
               initial
               arr
-- | Direct access to the AST constructor for a fold ('NEFold1Inner'), whose
-- combination function takes /one/ argument: a pair of inputs. Exposed in
-- case it is helpful for testing; 'fold1i' is a convenience wrapper around
-- this function that takes two separate variables instead.
fold1i'
  :: (Var name (TPair t t) :-> NExpr ('(name, TPair t t) : env) t)
  -> NExpr env t
  -> NExpr env (TArr (S n) t)
  -> NExpr env (TArr n t)
fold1i' (pairVar :-> combine) initial arr = NEFold1Inner pairVar combine initial arr
-- | Sum of a numeric array along one dimension ('NESum1Inner'), reducing the
-- rank by one. By analogy with 'fold1i' this is presumably the innermost
-- dimension.
sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
sum1i = NESum1Inner
-- | Wrap a value in a zero-dimensional array ('NEUnit'). See 'idx0' for the
-- opposite direction.
unit :: NExpr env t -> NExpr env (TArr Z t)
unit e = NEUnit e
-- | Replicate a numeric array along a new dimension ('NEReplicate1Inner'),
-- increasing the rank by one. The first argument is presumably the extent of
-- the new (innermost) dimension.
replicate1i :: ScalIsNumeric t ~ True => NExpr env TIx -> NExpr env (TArr n (TScal t)) -> NExpr env (TArr (S n) (TScal t))
replicate1i = NEReplicate1Inner
-- | Maximum of a numeric array along one dimension ('NEMaximum1Inner'),
-- reducing the rank by one. By analogy with 'fold1i' this is presumably the
-- innermost dimension.
maximum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
maximum1i = NEMaximum1Inner
-- | Minimum of a numeric array along one dimension ('NEMinimum1Inner'),
-- reducing the rank by one. By analogy with 'fold1i' this is presumably the
-- innermost dimension.
minimum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
minimum1i = NEMinimum1Inner
-- | Give an array a new rank and shape ('NEReshape'). The arguments are the
-- new rank, the new shape as an index tuple (see 'build'), and the array.
reshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)
reshape rank sh arr = NEReshape rank sh arr
-- | 'fold1iD1'' with a curried combination function, i.e. one that binds two
-- named variables instead of a single pair-typed one.
--
-- A fresh variable named @\<name1\>.\<name2\>@ is created for the pair, and
-- the two user-named variables are bound to its components.
fold1iD1
  :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b))
  -> NExpr env t1
  -> NExpr env (TArr (S n) t1)
  -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
fold1iD1 (lhs@(Var lhsSym@SSymbol elemTy) :-> rhs@(Var rhsSym@SSymbol _) :-> combine) initial arr =
  withSomeSSymbol (symbolVal lhsSym ++ "." ++ symbolVal rhsSym) $ \(pairSym :: SSymbol name3) ->
    -- These three wrappers presumably provide the type-level evidence about
    -- the fresh name that the binders below require (confirm against their
    -- definitions); the last one is needed because the pair variable is
    -- looked up underneath the binding of the first user variable.
    assertSymbolNotUnderscore pairSym $
      equalityReflexive pairSym $
        assertSymbolDistinct pairSym lhsSym $
          let pairVar = Var pairSym (STPair elemTy elemTy)
          in fold1iD1'
               ( pairVar :->
                   let_ lhs (fst_ (NEVar pairVar)) $
                     let_ rhs (snd_ (NEVar pairVar)) $
                       -- Hide the pair variable (position 2) from the
                       -- user's body.
                       NEDrop (SS (SS SZ)) combine
               )
               initial
               arr
-- | Primal of a fold ('NEFold1InnerD1'). The combination function takes a
-- single pair-typed variable and returns, next to the combined value, an
-- extra value of type @b@; the result pairs the folded array with an array
-- of those extra values. Not supported in the input program for reverse
-- differentiation.
fold1iD1'
  :: (Var name (TPair t1 t1) :-> NExpr ('(name, TPair t1 t1) : env) (TPair t1 b))
  -> NExpr env t1
  -> NExpr env (TArr (S n) t1)
  -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
fold1iD1' (pairVar :-> combine) initial arr = NEFold1InnerD1 pairVar combine initial arr
-- | Reverse pass of a fold ('NEFold1InnerD2'). The function binds two named
-- variables, of types @b@ and @t2@, and returns a pair of @t2@ values. Not
-- supported in the input program for reverse differentiation.
fold1iD2
  :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2))
  -> NExpr env (TArr (S n) b)
  -> NExpr env (TArr n t2)
  -> NExpr env (TPair (TArr n t2) (TArr (S n) t2))
fold1iD2 (storedVar :-> ctgVar :-> fn) stored ctg =
  NEFold1InnerD2 storedVar ctgVar fn stored ctg
-- | Embed a scalar constant in an expression ('NEConst'). The scalar type is
-- taken from the 'KnownScalTy' constraint. Numeric literals do not need this
-- function, because 'NExpr' is an instance of numeric classes like 'Num'.
const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t)
const_ val =
  -- Matching on 'Dict' brings the instance evidence produced by
  -- 'scalRepIsShow' into scope for the constructor.
  case scalRepIsShow ty of
    Dict -> NEConst ty val
  where
    ty = knownScalTy
-- | Extract the value from a zero-dimensional array ('NEIdx0'). See 'unit'
-- for the opposite direction.
idx0 :: NExpr env (TArr Z t) -> NExpr env t
idx0 arr = NEIdx0 arr
-- (.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
-- (.!) = NEIdx1
-- infixl 9 .!
-- | Index an array ('NEIdx'). The index is a tuple, just like the argument
-- to the function in 'build'. To index a 2-dimensional array @a@ at row @i@
-- and column @j@, write @a '!' 'pair' ('pair' 'nil' i) j@.
(!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
arr ! ix = NEIdx arr ix
infixl 9 !
-- | The shape of an array as an index tuple ('NEShape'); see 'build' for the
-- layout of such tuples.
shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
shape arr = NEShape arr
-- | Convenience special case of 'shape' for single-dimensional arrays: the
-- only extent is the second component of the shape tuple @((), n)@.
length_ :: NExpr env (TArr N1 t) -> NExpr env TIx
length_ = snd_ . shape
-- | Apply a primitive operation to its argument ('NEOp').
oper :: SOp a t -> NExpr env a -> NExpr env t
oper op arg = NEOp op arg
-- | Apply a primitive operation whose argument is a pair, taking the two
-- components as separate arguments.
oper2 :: SOp (TPair a b) t -> NExpr env a -> NExpr env b -> NExpr env t
oper2 op lhs rhs = NEOp op (NEPair lhs rhs)
-- | An error expression of any known type, carrying the given message
-- ('NEError').
error_ :: KnownTy t => String -> NExpr env t
error_ = NEError knownTy
-- | Attach a custom reverse derivative to a subexpression. Morally, the type
-- of this combinator should be read as follows:
--
-- @
-- custom :: (a -> b -> t)                   -- normal semantics
--        -> (D1 a -> D1 b -> (D1 t, tape))  -- forward pass
--        -> (tape -> D2 t -> D2 b)          -- reverse pass
--        -> a -> b                          -- arguments
--        -> t                               -- result
-- @
--
-- In normal evaluation, or when forward-differentiating, only the first
-- argument is used; the second and third are ignored. When
-- reverse-differentiating using CHAD, however, the /first/ argument is
-- ignored, and the second and third arguments are placed in the forward and
-- the reverse pass of the derivative program, respectively. The @tape@ value
-- may be used to remember primals for the reverse pass.
--
-- The operation has one "inactive" input, of type @a@, to which no
-- derivatives are propagated, and one "active" input, of type @b@, whose
-- derivatives /are/ propagated.
--
-- No accumulators are allowed inside @a@, @b@ and @tape@.
--
-- Note that the three function bodies are typed in an environment holding
-- only their own two variables.
custom
  :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t)
  -> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape))
  -> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b))
  -> NExpr env a
  -> NExpr env b
  -> NExpr env t
custom (p1 :-> p2 :-> primal) (f1 :-> f2 :-> forward) (r1 :-> r2 :-> reverse_) inactive active =
  NECustom
    p1 p2 primal
    f1 f2 forward
    r1 r2 reverse_
    inactive active
-- | Semantically the identity, but when reverse differentiating using CHAD,
-- the contained expression is recomputed in the reverse pass. This is a
-- light-weight form of checkpointing, with the goal of reducing the number
-- of primal values being stored and thus reducing memory use and memory
-- traffic.
--
-- Note that the free variables of the contained expression do still need to
-- be stored, because they are needed to recompute the expression in the
-- reverse pass.
recompute :: NExpr env a -> NExpr env a
recompute e = NERecompute e
-- | Introduce an accumulator with the given initial value, in scope (under
-- the given name) in the body. The result pairs the value of the body with a
-- value of the accumulator's type. The initial value is not allowed to be
-- sparse! See 'CHAD.AST.EWith'. Not supported in the input program for
-- reverse differentiation.
with
  :: forall t a env acname. KnownMTy t
  => NExpr env t
  -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a)
  -> NExpr env (TPair a t)
with initial (acVar :-> scope) = NEWith (knownMTy @t) initial acVar scope
-- | Accumulate to an accumulator. The arguments are the projection into the
-- accumulator's type, the index belonging to that projection, the value to
-- accumulate, and the accumulator. This is the variant of 'accumS' that
-- always uses the dense sparsity ('spDense') for the projected type. Not
-- supported in the input program for reverse differentiation.
accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil
accum prj idx val acc =
  NEAccum knownMTy prj idx (spDense (acPrjTy prj knownMTy)) val acc
-- | Accumulate to an accumulator with additional sparsity: like 'accum', but
-- the value to accumulate has the type described by the given 'Sparse'
-- argument. Not supported in the input program for reverse differentiation.
accumS
  :: KnownMTy t
  => SAcPrj p t a
  -> NExpr env (AcIdxD p t)
  -> Sparse a b
  -> NExpr env b
  -> NExpr env (TAccum t)
  -> NExpr env TNil
accumS prj idx sparsity val acc = NEAccum knownMTy prj idx sparsity val acc
-- | Equality test on numeric scalars ('OEq').
(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
(.==) = oper2 (OEq knownScalTy)
infix 4 .==
-- | Strict less-than test on numeric scalars ('OLt').
(.<) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
(.<) = oper2 (OLt knownScalTy)
infix 4 .<
-- | Strict greater-than test on numeric scalars; this is '.<' with its
-- arguments swapped.
(.>) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
a .> b = b .< a
infix 4 .>
-- | Less-than-or-equal test on numeric scalars ('OLe').
(.<=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
(.<=) = oper2 (OLe knownScalTy)
infix 4 .<=
-- | Greater-than-or-equal test on numeric scalars; this is '.<=' with its
-- arguments swapped.
(.>=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
a .>= b = b .<= a
infix 4 .>=
-- | Boolean negation ('ONot').
not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool)
not_ e = oper ONot e
-- | Boolean conjunction ('OAnd'). Has the same fixity as '&&' when used
-- infix.
and_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool)
and_ a b = oper2 OAnd a b
infixr 3 `and_`
-- | Boolean disjunction ('OOr'). Has the same fixity as '||' when used
-- infix.
or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool)
or_ a b = oper2 OOr a b
infixr 2 `or_`
-- | Modulo operation on integral scalars ('OMod'). Has the same fixity as
-- 'mod' when used infix.
mod_ :: (ScalIsIntegral a ~ True, KnownScalTy a) => NExpr env (TScal a) -> NExpr env (TScal a) -> NExpr env (TScal a)
mod_ a b = oper2 (OMod knownScalTy) a b
infixl 7 `mod_`
-- | Conditional expression. The first alternative is the True case; the
-- second is the False case.
--
-- Implemented as a 'case_' on the either-typed result of the 'OIf'
-- operation. Neither alternative uses the variable bound by its branch, so
-- the branches bind @#_@ and the alternatives are weakened with 'NEDrop'.
if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t
if_ cond onTrue onFalse =
  case_
    (oper OIf cond)
    (#_ :-> NEDrop SZ onTrue)
    (#_ :-> NEDrop SZ onFalse)
-- | Convert a 64-bit float to a 64-bit integer by rounding ('ORound64').
round_ :: NExpr env (TScal TF64) -> NExpr env (TScal TI64)
round_ e = oper ORound64 e
-- | Convert a 64-bit integer to a 64-bit float ('OToFl64').
toFloat_ :: NExpr env (TScal TI64) -> NExpr env (TScal TF64)
toFloat_ e = oper OToFl64 e
-- | Integer division on integral scalars ('OIDiv'). Has the same fixity as
-- 'div' when used infix.
idiv :: (KnownScalTy t, ScalIsIntegral t ~ True) => NExpr env (TScal t) -> NExpr env (TScal t) -> NExpr env (TScal t)
idiv a b = oper2 (OIDiv knownScalTy) a b
infixl 7 `idiv`
|