1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- I want to bring various type variables into scope using type annotations in
-- patterns, without having to mention all the other type parameters of the
-- types in question. Partial type signatures (with '_') make this convenient.
{-# LANGUAGE PartialTypeSignatures #-}
{-# OPTIONS -Wno-partial-type-signatures #-}
module ForwardAD.DualNumbers (
dfwdDN,
DN, DNS, DNE, dn, dne,
) where
import AST
import Data
-- | Dual-numbers transformation on types.
--
-- The structural type formers ('TNil', 'TPair', 'TEither', 'TMaybe', 'TArr')
-- are mapped component-wise; scalar types are handed off to 'DNS'.
type family DN t where
  DN TNil          = TNil
  DN (TPair l r)   = TPair (DN l) (DN r)
  DN (TEither l r) = TEither (DN l) (DN r)
  DN (TMaybe x)    = TMaybe (DN x)
  DN (TArr n x)    = TArr n (DN x)
  DN (TScal s)     = DNS s
-- | Dual-numbers transformation on scalar types.
--
-- Floating-point scalars become a pair of two values of the same scalar type
-- (primal first, tangent second; e.g. 'dfwdDN' lifts a constant @x@ to
-- @(x, 0)@). Integral and boolean scalars are left unchanged.
type family DNS t where
  DNS TF32  = TPair (TScal TF32) (TScal TF32)
  DNS TF64  = TPair (TScal TF64) (TScal TF64)
  DNS TI32  = TScal TI32
  DNS TI64  = TScal TI64
  DNS TBool = TScal TBool
-- | Apply 'DN' to every type in a type-level environment list.
type family DNE env where
  DNE '[]         = '[]
  DNE (ty ': tys) = DN ty ': DNE tys
-- | Value-level counterpart of 'DN': compute the type singleton of the
-- transformed type. Accumulator types are rejected with 'error', as they are
-- not expected in source programs.
dn :: STy t -> STy (DN t)
dn STNil           = STNil
dn (STPair l r)    = STPair (dn l) (dn r)
dn (STEither l r)  = STEither (dn l) (dn r)
dn (STMaybe x)     = STMaybe (dn x)
dn (STArr n x)     = STArr n (dn x)
dn (STScal STF32)  = STPair (STScal STF32) (STScal STF32)
dn (STScal STF64)  = STPair (STScal STF64) (STScal STF64)
dn (STScal STI32)  = STScal STI32
dn (STScal STI64)  = STScal STI64
dn (STScal STBool) = STScal STBool
dn STAccum{}       = error "Accum in source program"
-- | Value-level counterpart of 'DNE': apply 'dn' to each type in the list.
dne :: SList STy env -> SList STy (DNE env)
dne = \case
  SNil -> SNil
  ty `SCons` rest -> dn ty `SCons` dne rest
-- | Proof that 'DN' leaves a 'Tup' of @n@ copies of 'TIx' unchanged, by
-- induction on @n@.
dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx)
dnPreservesTupIx = \case
  SZ -> Refl
  SS m -> case dnPreservesTupIx m of
    Refl -> Refl
-- | Transport a variable index into the transformed environment. The position
-- is unchanged because 'DNE' maps the environment element-wise.
convIdx :: Idx env t -> Idx (DNE env) (DN t)
convIdx = \case
  IZ -> IZ
  IS rest -> IS (convIdx rest)
-- | Case split on a scalar type.
--
-- 'STF32' and 'STF64' select the first continuation, which is given evidence
-- that the type is numeric, has a 'Fractional' representation and is mapped
-- to a pair by 'DN'. Every other scalar type selects the second continuation,
-- which is given evidence that 'DN' leaves the scalar type unchanged.
scalTyCase
  :: SScalTy t
  -> ((ScalIsNumeric t ~ True, Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t)) => r)
  -> (DN (TScal t) ~ TScal t => r)
  -> r
scalTyCase ty onFloating onOther = case ty of
  STF32  -> onFloating
  STF64  -> onFloating
  STI32  -> onOther
  STI64  -> onOther
  STBool -> onOther
-- | Dual-numbers version of a primitive operation, applied to the already
-- transformed argument.
--
-- The argument does not need to be duplicable: it is either used at most
-- once, or, where both components of a dual number are needed, it is
-- let-bound first and the components are projected from the bound variable.
dop :: forall a b env. SOp a b -> Ex env (DN a) -> Ex env (DN b)
dop op = case op of
  -- (x, dx) + (y, dy) = (x + y, dx + dy)
  OAdd t -> scalTyCase t
    (binFloat (\(p, dp) (q, dq) -> EPair ext (addEx t p q) (addEx t dp dq)))
    (EOp ext (OAdd t))
  -- Product rule: the tangent of x * y is dx * y + dy * x.
  OMul t -> scalTyCase t
    (binFloat (\(p, dp) (q, dq) ->
      EPair ext (mulEx t p q) (addEx t (mulEx t dp q) (mulEx t dq p))))
    (EOp ext (OMul t))
  -- The tangent of -x is -dx.
  ONeg t -> scalTyCase t
    (unFloat (\(p, dp) -> EPair ext (negEx t p) (negEx t dp)))
    (EOp ext (ONeg t))
  -- Comparisons only inspect the primal components.
  OLt t -> scalTyCase t
    (binFloat (\(p, _) (q, _) -> EOp ext (OLt t) (EPair ext p q)))
    (EOp ext (OLt t))
  OLe t -> scalTyCase t
    (binFloat (\(p, _) (q, _) -> EOp ext (OLe t) (EPair ext p q)))
    (EOp ext (OLe t))
  OEq t -> scalTyCase t
    (binFloat (\(p, _) (q, _) -> EOp ext (OEq t) (EPair ext p q)))
    (EOp ext (OEq t))
  -- These operations are passed through unchanged.
  ONot -> EOp ext ONot
  OAnd -> EOp ext OAnd
  OOr -> EOp ext OOr
  OIf -> EOp ext OIf
  -- ORound64 only sees the primal component of its argument.
  ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg)
  -- The result of OToFl64 is paired with a zero tangent.
  OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0)
  where
    -- Build scalar arithmetic as 'EOp' nodes on a pair of operands.
    addEx :: ScalIsNumeric t ~ True
          => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
    addEx t l r = EOp ext (OAdd t) (EPair ext l r)

    mulEx :: ScalIsNumeric t ~ True
          => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
    mulEx t l r = EOp ext (OMul t) (EPair ext l r)

    negEx :: ScalIsNumeric t ~ True
          => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
    negEx t = EOp ext (ONeg t)

    -- Let-bind a unary float argument and hand its two projections
    -- (primal, tangent) to the continuation.
    unFloat :: DN a ~ TPair a a
            => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (DN b))
            -> Ex env (DN a) -> Ex env (DN b)
    unFloat k arg =
      ELet ext arg $
        let bound = EVar ext (typeOf arg) IZ
        in k (EFst ext bound, ESnd ext bound)

    -- Let-bind a pair of float arguments and hand the (primal, tangent)
    -- projections of each operand to the continuation.
    binFloat :: (a ~ TPair s s, DN s ~ TPair s s)
             => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (DN b))
             -> Ex env (DN a) -> Ex env (DN b)
    binFloat k arg =
      ELet ext arg $
        let bound = EVar ext (typeOf arg) IZ
            lhs = EFst ext bound
            rhs = ESnd ext bound
        in k (EFst ext lhs, ESnd ext lhs) (EFst ext rhs, ESnd ext rhs)
-- | Forward-mode differentiation by the dual-numbers transformation: rewrite a
-- source expression so that every floating-point scalar is paired with a
-- tangent (see 'DN' and 'DNS').
--
-- Most constructs are translated structurally: subterms are transformed
-- recursively, type annotations go through 'dn' and variable indices through
-- 'convIdx'. Float constants (also inside constant arrays) get a zero tangent,
-- primitive operations are handled by 'dop', and summing a float array sums
-- the primal and tangent components separately. Accumulator and monoid
-- operations are not supported in source programs and evaluate to 'error'.
dfwdDN :: Ex env t -> Ex (DNE env) (DN t)
dfwdDN expr = case expr of
  -- Purely structural cases.
  EVar _ ty i -> EVar ext (dn ty) (convIdx i)
  ELet _ rhs body -> ELet ext (dfwdDN rhs) (dfwdDN body)
  EPair _ l r -> EPair ext (dfwdDN l) (dfwdDN r)
  EFst _ e -> EFst ext (dfwdDN e)
  ESnd _ e -> ESnd ext (dfwdDN e)
  ENil _ -> ENil ext
  EInl _ ty e -> EInl ext (dn ty) (dfwdDN e)
  EInr _ ty e -> EInr ext (dn ty) (dfwdDN e)
  ECase _ e1 e2 e3 -> ECase ext (dfwdDN e1) (dfwdDN e2) (dfwdDN e3)
  ENothing _ ty -> ENothing ext (dn ty)
  EJust _ e -> EJust ext (dfwdDN e)
  EMaybe _ e1 e2 e3 -> EMaybe ext (dfwdDN e1) (dfwdDN e2) (dfwdDN e3)

  -- Float constant arrays: pair every element with a zero tangent.
  EConstArr _ n t x -> scalTyCase t
    (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0))
          (EConstArr ext n t x))
    (EConstArr ext n t x)

  EBuild1 _ e1 e2 -> EBuild1 ext (dfwdDN e1) (dfwdDN e2)
  -- The index type of the build is unchanged by DN.
  EBuild _ n e1 e2
    | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN e1) (dfwdDN e2)
  EFold1Inner _ e1 e2 e3 -> EFold1Inner ext (dfwdDN e1) (dfwdDN e2) (dfwdDN e3)

  -- Float sums: let-bind the transformed array, sum its primal and tangent
  -- projections separately and zip the two results back into pairs.
  ESum1Inner _ e ->
    let STArr n (STScal t) = typeOf e
        dualTy = STPair (STScal t) (STScal t)
    in scalTyCase t
         (ELet ext (dfwdDN e) $
            ezip (ESum1Inner ext (emap (EFst ext (EVar ext dualTy IZ))
                                       (EVar ext (STArr n dualTy) IZ)))
                 (ESum1Inner ext (emap (ESnd ext (EVar ext dualTy IZ))
                                       (EVar ext (STArr n dualTy) IZ))))
         (ESum1Inner ext (dfwdDN e))

  EUnit _ e -> EUnit ext (dfwdDN e)
  EReplicate1Inner _ e1 e2 -> EReplicate1Inner ext (dfwdDN e1) (dfwdDN e2)

  -- Float constants get a zero tangent.
  EConst _ t x -> scalTyCase t
    (EPair ext (EConst ext t x) (EConst ext t 0.0))
    (EConst ext t x)

  EIdx0 _ e -> EIdx0 ext (dfwdDN e)
  EIdx1 _ e1 e2 -> EIdx1 ext (dfwdDN e1) (dfwdDN e2)
  EIdx _ e1 e2
    | STArr n _ <- typeOf e1
    , Refl <- dnPreservesTupIx n
    -> EIdx ext (dfwdDN e1) (dfwdDN e2)
  EShape _ e
    | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n)
    -> EShape ext (dfwdDN e)

  -- Primitive operations are differentiated by 'dop'.
  EOp _ op e -> dop op (dfwdDN e)
  EError ty msg -> EError (dn ty) msg

  -- Constructs that may not appear in source programs.
  EWith{} -> unsupportedAccum
  EAccum{} -> unsupportedAccum
  EZero{} -> unsupportedMonoid
  EPlus{} -> unsupportedMonoid
  where
    unsupportedAccum = error "Accumulator operations unsupported in the source program"
    unsupportedMonoid = error "Monoid operations unsupported in the source program"
|