1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- I want to bring various type variables into scope using type annotations in
-- patterns, without having to also mention all the other type parameters of
-- the types in question. Partial type signatures (with '_') are useful here.
{-# LANGUAGE PartialTypeSignatures #-}
{-# OPTIONS -Wno-partial-type-signatures #-}
module ForwardAD.DualNumbers (
dfwdDN,
DN, DNS, DNE, dn, dne,
) where
import AST
import Data
import ForwardAD.DualNumbers.Types
-- | Proof that 'DN' maps a 'Tup' of @n@ copies of 'TIx' to itself, by
-- induction on @n@.
dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx)
dnPreservesTupIx = \case
  SZ -> Refl
  SS m -> case dnPreservesTupIx m of
    Refl -> Refl
-- | Transport a de Bruijn index from @env@ into the transformed environment
-- @'DNE' env@, where the referenced type @t@ becomes @'DN' t@. The position of
-- the index is unchanged.
convIdx :: Idx env t -> Idx (DNE env) (DN t)
convIdx = \case
  IZ -> IZ
  IS j -> IS (convIdx j)
-- | Branch on a scalar type.
--
-- The first continuation is taken for 'STF32' and 'STF64'. In it, the type is
-- known to be numeric and floating, its representation is 'Fractional', and
-- 'DN' turns the scalar into a pair.
--
-- The second continuation is taken for every other scalar type. In it, 'DN'
-- leaves the scalar unchanged.
scalTyCase :: SScalTy t
           -> ((ScalIsNumeric t ~ True, ScalIsFloating t ~ True, Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t)) => r)
           -> (DN (TScal t) ~ TScal t => r)
           -> r
scalTyCase ty ifFloating ifOther = case ty of
  STF32 -> ifFloating
  STF64 -> ifFloating
  STI32 -> ifOther
  STI64 -> ifOther
  STBool -> ifOther
-- | For a floating-point scalar type, run the continuation with three facts in
-- scope:
--
-- * the representation is 'Fractional';
-- * 'DN' of the scalar is a pair of the scalar;
-- * the type is numeric.
floatingDual :: ScalIsFloating t ~ True
             => SScalTy t
             -> ((Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t), ScalIsNumeric t ~ True) => r) -> r
floatingDual ty k = case ty of
  STF32 -> k
  STF64 -> k
-- | Dual-number version of a primitive operation.
--
-- It maps an expression for the transformed argument (@'DN' a@) to one for
-- the transformed result (@'DN' b@). On floating-point scalars the result
-- pairs the primal value with its tangent. On other scalar types the
-- operation is applied unchanged.
--
-- Argument does not need to be duplicable. Any rule that inspects the
-- argument more than once first binds it with 'ELet' (see @onSingle@ and
-- @onPair@ below).
dop :: forall a b env. SOp a b -> Ex env (DN a) -> Ex env (DN b)
dop = \case
  OAdd t -> scalTyCase t
    (onPair (\(p, dp) (q, dq) -> EPair ext (plusE t p q) (plusE t dp dq)))
    (EOp ext (OAdd t))
  OMul t -> scalTyCase t
    -- product rule: d(p * q) = dp * q + dq * p
    (onPair (\(p, dp) (q, dq) ->
       EPair ext (timesE t p q) (plusE t (timesE t dp q) (timesE t dq p))))
    (EOp ext (OMul t))
  ONeg t -> scalTyCase t
    (onSingle (\(p, dp) -> EPair ext (negE t p) (negE t dp)))
    (EOp ext (ONeg t))
  -- Comparisons only look at the primal components; tangents are dropped.
  OLt t -> scalTyCase t
    (onPair (\(p, _) (q, _) -> EOp ext (OLt t) (EPair ext p q)))
    (EOp ext (OLt t))
  OLe t -> scalTyCase t
    (onPair (\(p, _) (q, _) -> EOp ext (OLe t) (EPair ext p q)))
    (EOp ext (OLe t))
  OEq t -> scalTyCase t
    (onPair (\(p, _) (q, _) -> EOp ext (OEq t) (EPair ext p q)))
    (EOp ext (OEq t))
  -- Boolean operations are passed through unchanged.
  ONot -> EOp ext ONot
  OAnd -> EOp ext OAnd
  OOr -> EOp ext OOr
  OIf -> EOp ext OIf
  -- Rounding is applied to the primal (first) component only.
  ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg)
  -- The converted value gets a zero tangent.
  OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0)
  -- d(1/p) = -(1/(p*p)) * dp
  ORecip t -> floatingDual t $ onSingle (\(p, dp) ->
    EPair ext (recipE t p) (timesE t (negE t (recipE t (timesE t p p))) dp))
  -- d(exp p) = exp p * dp
  OExp t -> floatingDual t $ onSingle (\(p, dp) ->
    EPair ext (EOp ext (OExp t) p) (timesE t (EOp ext (OExp t) p) dp))
  -- d(log p) = (1/p) * dp
  OLog t -> floatingDual t $ onSingle (\(p, dp) ->
    EPair ext (EOp ext (OLog t) p) (timesE t (recipE t p) dp))
  -- OIDiv presumably only admits non-floating types, so the floating branch
  -- is an empty case.
  OIDiv t -> scalTyCase t
    (case t of {})
    (EOp ext (OIDiv t))
  where
    -- Scalar expression builders for the derivative rules above.
    plusE :: ScalIsNumeric t ~ True
          => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
    plusE t x y = EOp ext (OAdd t) (EPair ext x y)

    timesE :: ScalIsNumeric t ~ True
           => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
    timesE t x y = EOp ext (OMul t) (EPair ext x y)

    negE :: ScalIsNumeric t ~ True
         => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
    negE t x = EOp ext (ONeg t) x

    recipE :: ScalIsFloating t ~ True
           => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
    recipE t x = EOp ext (ORecip t) x

    -- Bind a single dual-number argument once, then pass its first and
    -- second components to @f@.
    -- (@a@, @b@ and @env@ are the scoped type variables of 'dop'.)
    onSingle :: DN a ~ TPair a a
             => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (DN b))
             -> Ex env (DN a) -> Ex env (DN b)
    onSingle f arg =
      ELet ext arg $
        let v = EVar ext (typeOf arg) IZ
        in f (EFst ext v, ESnd ext v)

    -- Bind a pair of dual-number arguments once, then pass the first and
    -- second components of each side to @f@.
    onPair :: (a ~ TPair s s, DN s ~ TPair s s)
           => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (DN b))
           -> Ex env (DN a) -> Ex env (DN b)
    onPair f arg =
      ELet ext arg $
        let v = EVar ext (typeOf arg) IZ
            lhs = EFst ext v
            rhs = ESnd ext v
        in f (EFst ext lhs, ESnd ext lhs) (EFst ext rhs, ESnd ext rhs)
-- | A constant-zero expression of the given numeric scalar type.
zeroScalarConst :: ScalIsNumeric t ~ True => SScalTy t -> Ex env (TScal t)
zeroScalarConst = \case
  STF32 -> EConst ext STF32 0.0
  STF64 -> EConst ext STF64 0.0
  STI32 -> EConst ext STI32 0
  STI64 -> EConst ext STI64 0
-- | Forward-mode differentiation via dual numbers.
--
-- A term in environment @env@ becomes a term over the transformed environment
-- @'DNE' env@, with type @'DN' t@. Floating-point scalars become pairs of a
-- primal value and a tangent (see 'scalTyCase'); other scalar types are left
-- unchanged. Constants get a zero tangent, and primitive operations are
-- handled by 'dop'.
--
-- Accumulator operations ('EWith', 'EAccum') and monoid operations ('EZero',
-- 'EPlus', 'EOneHot') are not supported in the source program and raise
-- 'error'.
dfwdDN :: Ex env t -> Ex (DNE env) (DN t)
dfwdDN = \case
  EVar _ t i -> EVar ext (dn t) (convIdx i)
  ELet _ a b -> ELet ext (dfwdDN a) (dfwdDN b)
  EPair _ a b -> EPair ext (dfwdDN a) (dfwdDN b)
  EFst _ e -> EFst ext (dfwdDN e)
  ESnd _ e -> ESnd ext (dfwdDN e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext (dn t) (dfwdDN e)
  EInr _ t e -> EInr ext (dn t) (dfwdDN e)
  ECase _ e a b -> ECase ext (dfwdDN e) (dfwdDN a) (dfwdDN b)
  ENothing _ t -> ENothing ext (dn t)
  EJust _ e -> EJust ext (dfwdDN e)
  EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b)
  -- Constant arrays of floats: pair every element with a zero tangent.
  EConstArr _ n t x -> scalTyCase t
    (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0))
          (EConstArr ext n t x))
    (EConstArr ext n t x)
  -- 'dnPreservesTupIx' supplies the fact that 'DN' leaves index tuples unchanged.
  EBuild _ n a b
    | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b)
  EFold1Inner _ a b c -> EFold1Inner ext (dfwdDN a) (dfwdDN b) (dfwdDN c)
  -- Summing dual numbers: sum the primals and the tangents separately, then
  -- zip the two results back together.
  ESum1Inner _ e ->
    let STArr n (STScal t) = typeOf e
        pairty = (STPair (STScal t) (STScal t))
    in scalTyCase t
         (ELet ext (dfwdDN e) $
            ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ))
                                       (EVar ext (STArr n pairty) IZ)))
                 (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ))
                                       (EVar ext (STArr n pairty) IZ))))
         (ESum1Inner ext (dfwdDN e))
  EUnit _ e -> EUnit ext (dfwdDN e)
  EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b)
  EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
  EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
  -- Float constants get a zero tangent.
  EConst _ t x -> scalTyCase t
    (EPair ext (EConst ext t x) (EConst ext t 0.0))
    (EConst ext t x)
  EIdx0 _ e -> EIdx0 ext (dfwdDN e)
  EIdx1 _ a b -> EIdx1 ext (dfwdDN a) (dfwdDN b)
  EIdx _ a b
    | STArr n _ <- typeOf a
    , Refl <- dnPreservesTupIx n
    -> EIdx ext (dfwdDN a) (dfwdDN b)
  EShape _ e
    | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n)
    -> EShape ext (dfwdDN e)
  EOp _ op e -> dop op (dfwdDN e)
  ECustom _ _ _ _ pr _ _ e1 e2 ->
    ELet ext (dfwdDN e1) $
    ELet ext (weakenExpr WSink (dfwdDN e2)) $
      weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr)
  EError t s -> EError (dn t) s
  EWith{} -> err_accum
  EAccum{} -> err_accum
  EZero{} -> err_monoid
  EPlus{} -> err_monoid
  EOneHot{} -> err_monoid
  where
    err_accum = error "Accumulator operations unsupported in the source program"
    err_monoid = error "Monoid operations unsupported in the source program"

    -- Derivative of a maximum/minimum reduction over the innermost dimension.
    -- @extremum@ is the reduction to apply (EMaximum1Inner or EMinimum1Inner).
    --
    -- Floating elements:
    --   * SZ is the transformed input, an array of pairs.
    --   * Z is the extremum of SZ's primals.
    --   * The tangent at each result position is the sum of the tangents of
    --     the input elements whose primal equals the extremum; all other
    --     elements contribute zero.
    -- NOTE(review): when several elements tie for the extremum, all of their
    -- tangents are summed. Confirm this is the intended derivative.
    deriv_extremum :: ScalIsNumeric t ~ True
                   => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
                   -> Ex env (TArr (S n) (TScal t)) -> Ex (DNE env) (TArr n (DN (TScal t)))
    deriv_extremum extremum e =
      let STArr (SS n) (STScal t) = typeOf e
          t2 = STPair (STScal t) (STScal t)
          ta2 = STArr (SS n) t2
          tIxN = tTup (sreplicate (SS n) tIx)
      in scalTyCase t
           (ELet ext (dfwdDN e) $
            ELet ext (extremum (emap (EFst ext (EVar ext t2 IZ)) (EVar ext ta2 IZ))) $
            ezip (EVar ext (STArr n (STScal t)) IZ)
                 (ESum1Inner ext
                    {- build (shape SZ) (\i. if fst (SZ ! i) == Z ! tail i then snd (SZ ! i) else zero) -}
                    (EBuild ext (SS n) (EShape ext (EVar ext ta2 (IS IZ))) $
                       ELet ext (EIdx ext (EVar ext ta2 (IS (IS IZ))) (EVar ext tIxN IZ)) $
                       ECase ext
                         (EOp ext OIf (EOp ext (OEq t) (EPair ext
                            (EFst ext (EVar ext t2 IZ))
                            (EIdx ext (EVar ext (STArr n (STScal t)) (IS (IS IZ)))
                                      (EFst ext (EVar ext tIxN (IS IZ)))))))
                         (ESnd ext (EVar ext t2 (IS IZ)))
                         (zeroScalarConst t))))
           -- Non-floating elements: DN is the identity here, so apply the
           -- requested reduction (maximum OR minimum) to the transformed input.
           (extremum (dfwdDN e))
|