1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Main where
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Int (Int64)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Hedgehog.Main
import Array
import AST
import AST.Pretty
import CHAD.Top
import CHAD.Types
import qualified Example
import qualified Example.GMM as Example
import ForwardAD
import Interpreter
import Interpreter.Rep
import Language
import Simplify
-- | How much simplification to apply to a differentiated term before it is
-- interpreted.
data SimplIters
  = SimplIters Int  -- ^ pass the term through 'simplifyN' with this count
  | SimplFix        -- ^ pass the term through 'simplifyFix'
  deriving (Show)
-- | Differentiate @term@ with CHAD, optionally simplify the result, and
-- interpret it on @input@.
--
-- The differentiated term is wrapped in a let-binding of the constant @1.0@
-- (presumably the cotangent seed for the scalar output -- confirm against
-- 'chad''). The result is the pretty-printed (simplified) differentiated
-- term, the primal output, and the gradient unpacked into one value per
-- environment entry.
--
-- In addition to the gradient, also returns the pretty-printed differentiated term.
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD simplIters env term input =
  (ppExpr env simplified, (primalOut, unTup vUnpair (d2e env) (Value packedGrad)))
  where
    -- The raw CHAD output, before any simplification.
    seeded = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term

    -- Matching on 'Dict' brings the environment's known-ness into scope for
    -- the simplifier.
    simplified
      | Dict <- envKnown env =
          case simplIters of
            SimplIters n -> simplifyN n seeded
            SimplFix -> simplifyFix seeded

    (primalOut, packedGrad) = interpretOpen False input simplified
-- | Like 'gradientByCHAD', but converts the gradient from the CHAD cotangent
-- representation ('D2E') to the tangent representation ('TanE'), using the
-- primal input to fill in zeros where the cotangent is a sparse zero.
--
-- In addition to the gradient, also returns the pretty-printed differentiated term.
gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env)))
gradientByCHAD' simplIters env term input =
  second (second (toTanE env input)) $ gradientByCHAD simplIters env term input
  where
    -- Convert every environment entry, pairing each cotangent with its primal.
    toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
    toTanE SNil SNil SNil = SNil
    toTanE (t `SCons` env') (Value p `SCons` primal) (Value x `SCons` inp) =
      Value (toTan t p x) `SCons` toTanE env' primal inp

    -- Convert a single cotangent value, given the primal value of the same type.
    toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
    toTan typ primal der = case typ of
      STNil -> der
      STPair t1 t2 -> case der of
        -- A 'Nothing' cotangent stands for zero: build zeros shaped like the primal.
        Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
        Just (d1, d2) -> bimap (\p1 -> toTan t1 p1 d1) (\p2 -> toTan t2 p2 d2) primal
      STEither t1 t2 -> case der of
        Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
        Just d -> case (primal, d) of
          (Left p, Left d') -> Left (toTan t1 p d')
          (Right p, Right d') -> Right (toTan t2 p d')
          _ -> error "Primal and cotangent disagree on Either alternative"
      STMaybe t -> case (primal, der) of
        (Just p, Just d) -> Just (toTan t p d)
        -- Treat a 'Nothing' cotangent as zero, consistent with the pair and
        -- Either branches, so the tangent keeps the primal's 'Just' shape
        -- instead of collapsing to 'Nothing'.
        (Just p, Nothing) -> Just (zeroTan t p)
        (Nothing, _) -> Nothing
      STArr _ t
        -- An empty cotangent array is treated as zero for the whole array.
        | shapeSize (arrayShape der) == 0 ->
            arrayMap (zeroTan t) primal
        | arrayShape primal == arrayShape der ->
            arrayGenerateLin (arrayShape primal) $ \i ->
              toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
        | otherwise ->
            error "Primal and cotangent disagree on array shape"
      STScal sty -> case sty of
        STI32 -> der
        STI64 -> der
        STF32 -> der
        STF64 -> der
        STBool -> der
      STAccum{} -> error "Accumulators not allowed in input program"
-- | Reference gradient of @term@ at @input@, computed with 'drevByFwd' and
-- a final argument of @1.0@ (presumably the seed for the scalar output --
-- confirm against 'drevByFwd').
gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward env term input = drevByFwd env term input outputSeed
  where
    outputSeed = 1.0
-- | Approximate equality for test comparisons.
--
-- Two values are close when they are exactly equal, when their absolute
-- difference is below @1e-5@, or when the smaller magnitude exceeds @1e-4@
-- and the difference relative to it is below @1e-5@.
--
-- The exact-equality case makes equal infinities compare as close; without
-- it their difference is NaN and every comparison fails. NaN is never close
-- to anything, including itself.
closeIsh :: Double -> Double -> Bool
closeIsh a b
  | a == b = True
  | absDiff < tolerance = True
  | otherwise = scale > 1e-4 && absDiff / scale < tolerance
  where
    tolerance = 1e-5
    absDiff = abs (a - b)
    scale = min (abs a) (abs b)
-- | Left-nested pairing operator used to list per-dimension constraints
-- (see 'DimNames').
data a :$ b = a :$ b
  deriving (Show)

infixl :$
-- | Constraint on a single generated array dimension.
--
-- An empty name means "no restrictions": the dimension is drawn on its own.
-- A non-empty name is looked up in (and recorded into) the name map that
-- 'genShape' threads through generation, so the first draw is reused.
data TplConstr = C
  String  -- ^ name; @""@ means anonymous
  Int     -- ^ minimum value to generate
-- | Dimension template for an array of rank @n@: nothing for rank 0, a
-- single 'TplConstr' for rank 1, and otherwise the template of the lower
-- rank extended on the right (via ':$') with one more 'TplConstr'.
--
-- The equations are tried in order, so the rank-1 case must come before the
-- general successor case.
type family DimNames n where
  DimNames Z = ()
  DimNames (S Z) = TplConstr
  DimNames (S n) = DimNames n :$ TplConstr
-- | Generation template for a value of type @t@: arrays carry their
-- dimension constraints, pairs carry a template per component, and every
-- other type has the trivial template @()@.
type family Tpl t where
  Tpl (TArr n t) = DimNames n
  Tpl (TPair a b) = (Tpl a, Tpl b)
  -- If you add equations here, don't forget to update 'genValue'! For the
  -- types that fall through to this catch-all, it currently ignores the
  -- incoming template and uses 'emptyTpl' for nested values.
  Tpl _ = ()
-- | Left-nested pairing operator used to list per-variable templates of an
-- environment (see 'TemplateE').
data a :& b = a :& b
  deriving (Show)

infixl :&
-- | Generation template for a whole environment: nothing for the empty
-- environment, a bare 'Tpl' for a single entry, and otherwise the template
-- of the tail followed (via ':&') by the template of the head. The head of
-- the type-level list therefore ends up rightmost in the template.
--
-- The equations are tried in order, so the singleton case must come before
-- the general cons case.
type family TemplateE env where
  TemplateE '[] = ()
  TemplateE '[t] = Tpl t
  TemplateE (t : ts) = TemplateE ts :& Tpl t
-- | The unconstrained dimension template for the given rank: every
-- dimension is anonymous with minimum 0.
emptyDimNames :: SNat n -> DimNames n
emptyDimNames = \case
  SZ -> ()
  SS SZ -> unconstrained
  SS lower@SS{} -> emptyDimNames lower :$ unconstrained
  where
    unconstrained = C "" 0
-- | The unconstrained template for a type. Only arrays, pairs and scalars
-- are handled; any other type yields an 'error' value, which only matters
-- if the template is actually forced.
emptyTpl :: STy t -> Tpl t
emptyTpl ty = case ty of
  STArr rank _ -> emptyDimNames rank
  STPair l r -> (emptyTpl l, emptyTpl r)
  STScal _ -> ()
  _ -> error "too lazy"
-- | The unconstrained template for an environment, built from 'emptyTpl'
-- for every entry and laid out as 'TemplateE' prescribes.
emptyTemplateE :: SList STy env -> TemplateE env
emptyTemplateE = \case
  SNil -> ()
  ty `SCons` SNil -> emptyTpl ty
  ty `SCons` rest@SCons{} -> emptyTemplateE rest :& emptyTpl ty
-- | Generate an array shape of the given rank following a dimension
-- template, then scale it down so that arrays do not get too large.
--
-- Each dimension is first drawn from @[lo, 10]@ (or reused from the name
-- map when the template names it). All dimensions are then divided by
-- @size `div` 100 + 1@, where @size@ is the element count of the drawn
-- shape.
--
-- NOTE(review): the name map records dimensions as drawn, before this
-- scaling. Two shapes that share a name may therefore end up with
-- different sizes if they are scaled by different factors, and a scaled
-- dimension may drop below its template minimum -- confirm whether callers
-- rely on either property.
genShape :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
genShape rank names = capSize <$> genShapeNaive rank names
  where
    -- Divide every dimension by a factor derived from the total size.
    capSize :: Shape n -> Shape n
    capSize sh = shapeDiv sh (shapeSize sh `div` 100 + 1)

    -- Draw one dimension per template entry, outermost dimensions first.
    genShapeNaive :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
    genShapeNaive SZ () = pure ShNil
    genShapeNaive (SS SZ) constr = do
      dim <- genNamedDim constr
      pure (ShNil `ShCons` dim)
    genShapeNaive (SS inner@SS{}) (outerNames :$ constr) = do
      outer <- genShapeNaive inner outerNames
      dim <- genNamedDim constr
      pure (outer `ShCons` dim)

    -- Anonymous dimensions are always drawn fresh; named ones are drawn once
    -- and remembered (the minimum is ignored on reuse).
    genNamedDim :: TplConstr -> StateT (Map String Int) Gen Int
    genNamedDim (C name lo)
      | null name = genDim lo
      | otherwise = do
          known <- gets (Map.lookup name)
          case known of
            Just dim -> pure dim
            Nothing -> do
              dim <- genDim lo
              modify (Map.insert name dim)
              pure dim

    genDim :: Int -> StateT (Map String Int) Gen Int
    genDim lo = Gen.integral (Range.linear lo 10)

    shapeDiv :: Shape n -> Int -> Shape n
    shapeDiv ShNil _ = ShNil
    shapeDiv (rest `ShCons` dim) factor = shapeDiv rest factor `ShCons` (dim `div` factor)
-- | Generate an array of the given shape whose elements are each generated
-- independently by 'genValue' with an unconstrained template and a fresh,
-- empty name map.
genArray :: STy a -> Shape n -> Gen (Value (TArr n a))
genArray elemTy sh = Value <$> arrayGenerateLinM sh elemAt
  where
    -- The linear index is irrelevant: every element is drawn the same way.
    elemAt _ = unValue <$> evalStateT (genValue elemTy (emptyTpl elemTy)) mempty
-- | Generate a value of the given type following a template.
--
-- Only pairs and arrays look at the template; sums and optional values
-- generate their payload with 'emptyTpl'. Scalars are drawn from
-- @[-10, 10]@ with origin 0, booleans from both values. Accumulator types
-- cannot be generated and raise an 'error'.
genValue :: STy t -> Tpl t -> StateT (Map String Int) Gen (Value t)
genValue topty tpl = case topty of
  STNil -> pure (Value ())
  STPair a b -> do
    left <- genValue a (fst tpl)
    right <- genValue b (snd tpl)
    pure (liftV2 (,) left right)
  STEither a b ->
    Gen.choice
      [ liftV Left <$> genValue a (emptyTpl a)
      , liftV Right <$> genValue b (emptyTpl b)
      ]
  STMaybe t ->
    Gen.choice
      [ pure (Value Nothing)
      , liftV Just <$> genValue t (emptyTpl t)
      ]
  STArr n t -> do
    sh <- genShape n tpl
    lift (genArray t sh)
  STScal sty -> case sty of
    STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STI32 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
    STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
    STBool -> Gen.choice [pure (Value False), pure (Value True)]
  STAccum{} -> error "Cannot generate inputs for accumulators"
-- | Generate a value for every entry of an environment, head first, with
-- all entries sharing the same name map so that named dimensions agree.
genEnv :: SList STy env -> TemplateE env -> StateT (Map String Int) Gen (SList Value env)
genEnv SNil () = pure SNil
genEnv (ty `SCons` SNil) tpl = do
  val <- genValue ty tpl
  pure (val `SCons` SNil)
genEnv (ty `SCons` rest@SCons{}) (restTpl :& tpl) = do
  val <- genValue ty tpl
  vals <- genEnv rest restTpl
  pure (val `SCons` vals)
-- | AD property for a term, with unconstrained generated inputs.
adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property
adTest = adTestCon (\_ -> True)
-- | AD property for a term, with inputs generated from the unconstrained
-- template and then filtered by the given predicate.
adTestCon :: forall env. KnownEnv env => (SList Value env -> Bool) -> Ex env (TScal TF64) -> Property
adTestCon constr term = adTestGen term (Gen.filter constr unconstrained)
  where
    env = knownEnv @env
    unconstrained = evalStateT (genEnv env (emptyTemplateE env)) mempty
-- | AD property for a term, with inputs generated from an explicit
-- environment template (starting from an empty name map).
adTestTp :: forall env. KnownEnv env
         => TemplateE env -> Ex env (TScal TF64) -> Property
adTestTp tmpl term = adTestGen term templated
  where
    templated = evalStateT (genEnv (knownEnv @env) tmpl) mempty
-- | The core AD property: for inputs drawn from @envGenerator@, check that
--
-- * simplifying to a fixed point and simplifying with 20 iterations
--   pretty-print to the same differentiated term;
-- * the primal output of the CHAD-transformed program (simplified and
--   unsimplified) agrees with directly interpreting the term;
-- * the CHAD gradient (simplified and unsimplified) agrees with the
--   gradient from 'gradientByForward'.
--
-- Floating-point comparisons use 'closeIsh'. Gradients are compared as flat
-- lists of scalars, which must also have the same length: a plain 'zipWith'
-- would silently ignore surplus entries on either side.
adTestGen :: forall env. KnownEnv env
          => Ex env (TScal TF64) -> Gen (SList Value env) -> Property
adTestGen expr envGenerator = property $ do
  let env = knownEnv @env
  input <- forAllWith (showEnv env) envGenerator

  let outPrimal = interpretOpen False input expr
      gradFwd = gradientByForward knownEnv expr input
      (_ppdterm, (outChad, gradCHAD)) = gradientByCHAD' (SimplIters 0) knownEnv expr input
      (ppdterm_S, (outChad_S, gradCHAD_S)) = gradientByCHAD' SimplFix knownEnv expr input
      (ppdterm_S20, _) = gradientByCHAD' (SimplIters 20) knownEnv expr input

      scFwd = envScalars env gradFwd
      scCHAD = envScalars env gradCHAD
      scCHAD_S = envScalars env gradCHAD_S

  -- Show the type of the tested term on failure.
  annotate (concat (unSList (\t -> ppTy 0 t ++ " -> ") env) ++ ppTy 0 (typeOf expr))
  -- annotate (ppExpr knownEnv expr)
  -- annotate ppdterm
  -- annotate ppdterm_S

  diff ppdterm_S (==) ppdterm_S20
  diff outChad_S closeIsh outChad
  diff outChad_S closeIsh outPrimal
  diff scCHAD_S allClose scCHAD
  diff scCHAD_S allClose scFwd
  where
    -- Flatten every tangent in the environment into its scalars.
    envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
    envScalars SNil SNil = []
    envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs

    -- Element-wise closeness; lists of different length are never close.
    allClose :: [Double] -> [Double] -> Bool
    allClose xs ys = length xs == length ys && and (zipWith closeIsh xs ys)
-- | Test term over a rank-1 array @x@: 'build' a rank-1 array with the
-- shape of @x@ whose element at each index is @x@ at that index, then take
-- 'sum1i' of it and read the scalar with 'idx0'.
term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
term_build1_sum = fromNamed $ lambda #x $ body $
  idx0 $ sum1i $
    build (SS SZ) (shape #x) $ #idx :-> #x ! #idx
-- | Test term over two scalars @x@ and @y@ that builds nested pairs and
-- projects out of them again, mixing the components arithmetically.
term_pairs :: Ex [TScal TF64, TScal TF64] (TScal TF64)
term_pairs = fromNamed $ lambda #x $ lambda #y $ body $
  -- p = (x, y)
  let_ #p (pair #x #y) $
  -- q = (snd p * fst p + y, x)
  let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $
    fst_ #q * #x + snd_ #q * fst_ #p
-- | The Hedgehog property group @"AD"@, run with 'checkParallel'. Each entry
-- pairs a name with an AD property built by one of 'adTest', 'adTestCon',
-- 'adTestTp' or 'adTestGen'.
tests :: IO Bool
tests = checkParallel $ Group "AD"
  [("id", adTest $ fromNamed $ lambda #x $ body $ #x)

  ,("idx0", adTest $ fromNamed $ lambda #x $ body $ idx0 #x)

  ,("sum-vec", adTest $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x))

  ,("sum-replicate", adTest $ fromNamed $ lambda #x $ body $
      idx0 $ sum1i $ replicate1i 10 #x)

  ,("pairs", adTest term_pairs)

  ,("build0 const", adTest $ fromNamed $ lambda @(TScal TF64) #x $ body $
      idx0 $ build SZ nil $ #idx :-> const_ 0.0)

  ,("build0", adTest $ fromNamed $ lambda @(TArr N0 _) #x $ body $
      idx0 $
        build SZ (shape #x) $ #idx :-> #x ! #idx)

  ,("build1-sum", adTest term_build1_sum)

  ,("build2-sum", adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $
      idx0 $ sum1i . sum1i $
        build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx)

  -- Only inputs whose last dimension is positive are accepted.
  ,("maximum", adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
      fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
        idx0 $ sum1i $ maximum1i #x)

  -- Only inputs whose last dimension is positive are accepted.
  ,("minimum", adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
      fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
        idx0 $ sum1i $ minimum1i #x)

  -- The let-bound array is never used; the result is the constant 42.
  ,("unused", adTest $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $
      let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
        42)

  -- Hand-written generator: layer sizes are drawn first so that the weight
  -- and bias shapes of consecutive layers line up.
  ,("neural", adTestGen Example.neural $ do
      let tR = STScal STF64
      let genLayer nin nout =
            liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
                       <*> genArray tR (ShNil `ShCons` nout)
      nin <- Gen.integral (Range.linear 1 10)
      n1 <- Gen.integral (Range.linear 1 10)
      n2 <- Gen.integral (Range.linear 1 10)
      input <- genArray tR (ShNil `ShCons` nin)
      lay1 <- genLayer nin n1
      lay2 <- genLayer n1 n2
      lay3 <- genArray tR (ShNil `ShCons` n2)
      return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil))

  -- The template asks for a minimum dimension of 1.
  ,("logsumexp", adTestTp (C "" 1) $
      fromNamed $ lambda @(TArr N1 _) #vec $ body $
        let_ #m (maximum1i #vec) $
          log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m)

  -- The template names the matrix's last dimension and the vector's only
  -- dimension "n", so both are drawn as the same value.
  ,("mulmatvec", adTestTp ((C "" 0 :$ C "n" 0) :& C "n" 0) $
      fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $
        idx0 $ sum1i $
          let_ #hei (snd_ (fst_ (shape #mat))) $
          let_ #wid (snd_ (shape #mat)) $
            build1 #hei $ #i :->
              idx0 (sum1i (build1 #wid $ #j :->
                              #mat ! pair (pair nil #i) #j * #vec ! pair nil #j)))

  -- Both GMM tests run with shrinking disabled.
  ,("gmm-wrong", withShrinks 0 $ adTestGen (Example.gmmObjective True) genGMM)

  ,("gmm", withShrinks 0 $ adTestGen (Example.gmmObjective False) genGMM)
  ]
  where
    -- Generator for the input environment of 'Example.gmmObjective'.
    genGMM = do
      -- The input ranges here are completely arbitrary.
      let tR = STScal STF64
      kN <- Gen.integral (Range.linear 1 10)
      kD <- Gen.integral (Range.linear 1 10)
      kK <- Gen.integral (Range.linear 1 10)
      let i2i64 = fromIntegral @Int @Int64
      valpha <- genArray tR (ShNil `ShCons` kK)
      vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
      vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
      -- Second dimension is kD * (kD - 1) / 2.
      vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
      vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD)
      vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
      vm <- Gen.integral (Range.linear 0 5)
      let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
          k2 = 0.5 * vgamma * vgamma
          k3 = 0.42  -- don't feel like multigammaing today
      return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons`
              Value vm `SCons` vX `SCons`
              vL `SCons` vQ `SCons` vM `SCons` valpha `SCons`
              Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
              SNil)
-- | Entry point: hand the property group to Hedgehog's 'defaultMain'.
main :: IO ()
main = defaultMain suites
  where
    suites = [tests]
|