1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Main where
import Control.Monad.Trans.Class (lift)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Int (Int64)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import qualified Data.Text as T
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Test.Framework
import Array
import AST
import AST.Pretty
import AST.UnMonoid
import CHAD.Top
import CHAD.Types
import CHAD.Types.ToTan
import Compile
import qualified Example
import qualified Example.GMM as Example
import ForwardAD
import ForwardAD.DualNumbers
import Interpreter
import Interpreter.Rep
import Language
import Simplify
-- | How much simplification 'simplifyIters' applies to a term.
data SimplIters
  = SimplIters Int  -- ^ fixed iteration count, forwarded to 'simplifyN'
  | SimplFix        -- ^ simplify to a fixpoint with 'simplifyFix'
  deriving (Show)
-- | Simplify a term according to a 'SimplIters' policy.
--
-- The environment is only inspected to obtain the 'Dict' returned by
-- 'envKnown' (presumably a 'KnownEnv' dictionary that the simplifier entry
-- points require), which is brought into scope before dispatching.
simplifyIters :: SimplIters -> SList STy env -> Ex env t -> Ex env t
simplifyIters iters env = case (envKnown env, iters) of
  (Dict, SimplIters n) -> simplifyN n
  (Dict, SimplFix) -> simplifyFix
-- | Differentiate a scalar-valued term with CHAD and evaluate the result.
--
-- The CHAD output is wrapped in a let binding the constant @1.0@ (presumably
-- the incoming cotangent of the scalar result), simplified according to
-- @simplIters@, and interpreted on @input@. Returns the pretty-printed
-- differentiated term together with the primal output and the gradient,
-- split into one value per environment entry.
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD simplIters env term input =
  (ppExpr env dterm, (primalOut, unTup vUnpair (d2e env) (Value cotangent)))
  where
    seeded = ELet ext (EConst ext STF64 1.0) (chad' defaultConfig env term)
    dterm = simplifyIters simplIters env seeded
    (primalOut, cotangent) = interpretOpen False input dterm
-- | Like 'gradientByCHAD', but converts the gradient to the tangent
-- representation of the environment with 'toTanE'. The pretty-printed
-- differentiated term is returned as well.
gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env)))
gradientByCHAD' simplIters env term input =
  let (pretty, (out, grad)) = gradientByCHAD simplIters env term input
  in (pretty, (out, toTanE env input grad))
-- | Gradient obtained from a compiled forward-AD artifact via 'drevByFwd',
-- passing @1.0@ as the output sensitivity (presumably the cotangent seed of
-- the scalar result).
gradientByForward :: FwdADArtifact env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward artifact inp = drevByFwd artifact inp outputSeed
  where
    outputSeed = 1.0
-- | Extend a primal value to a dual-number value of type @DN t@.
--
-- Every F32/F64 scalar is paired with a random tangent drawn from
-- @Range.linearFracFrom 0 (-1) 1@. Integer and boolean scalars are passed
-- through unchanged. Structure (pairs, sums, options, arrays) is preserved.
-- Accumulators are rejected with an error.
extendDN :: STy t -> Rep t -> Gen (Rep (DN t))
extendDN STNil () = pure ()
extendDN (STPair ta tb) (l, r) = (,) <$> extendDN ta l <*> extendDN tb r
extendDN (STEither ta _) (Left l) = Left <$> extendDN ta l
extendDN (STEither _ tb) (Right r) = Right <$> extendDN tb r
extendDN (STMaybe _) Nothing = pure Nothing
extendDN (STMaybe te) (Just e) = Just <$> extendDN te e
extendDN (STArr _ te) arr = traverse (extendDN te) arr
-- Floating-point scalars get a tangent; the bind (rather than fmap) is kept
-- deliberately so the generator consumes its seed exactly as before.
extendDN (STScal STF32) x = Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \dx -> pure (x, dx)
extendDN (STScal STF64) x = Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \dx -> pure (x, dx)
extendDN (STScal STI32) x = pure x
extendDN (STScal STI64) x = pure x
extendDN (STScal STBool) x = pure x
extendDN (STAccum _) _ = error "Accumulators not supported in input program"
-- | Apply 'extendDN' to every value of an environment, element by element.
extendDNE :: SList STy env -> SList Value env -> Gen (SList Value (DNE env))
extendDNE SNil SNil = pure SNil
extendDNE (ty `SCons` tys) (Value v `SCons` vs) =
  SCons . Value <$> extendDN ty v <*> extendDNE tys vs
-- | Approximate equality with tolerance @h@.
--
-- Two numbers are close when their absolute difference is below @h@, or when
-- both magnitudes exceed @10*h@ and the difference relative to the smaller
-- magnitude is below @h@.
closeIsh' :: Double -> Double -> Double -> Bool
closeIsh' h a b = absClose || relClose
  where
    delta = abs (a - b)
    scale = min (abs a) (abs b)
    absClose = delta < h
    relClose = scale > 10 * h && delta / scale < h
-- | 'closeIsh'' with a fixed tolerance of @1e-5@.
closeIsh :: Double -> Double -> Bool
closeIsh x y = closeIsh' tolerance x y
  where
    tolerance = 1e-5
-- | Left-nested pairing used to spell templates for multi-dimensional arrays:
-- the constraint for the last dimension of the shape is the right operand.
infixl 9 :$
data a :$ b = a :$ b
  deriving (Show)
-- | Constraint on one array dimension in an input template.
--
-- All dimensions sharing the same non-empty name are drawn once and then
-- reused while generating one environment. The empty name @""@ marks an
-- anonymous dimension with no sharing restrictions, generated independently.
data TplConstr
  = C String  -- ^ dimension name; @""@ means anonymous
      Int     -- ^ minimum value to generate
-- | Per-dimension template constraints for an array of rank @n@.
--
-- Rank 0 carries nothing, rank 1 a single 'TplConstr', and higher ranks a
-- left-nested ':$' chain whose right operand constrains the last dimension
-- of the shape. This is a closed type family: the @S Z@ equation must stay
-- before the general @S n@ one.
type family DimNames n where
  DimNames Z = ()
  DimNames (S Z) = TplConstr
  DimNames (S n) = DimNames n :$ TplConstr
-- | Input-generation template for a value of type @t@. Arrays carry their
-- 'DimNames', pairs carry one template per component, and every other type
-- carries no information.
type family Tpl t where
  Tpl (TArr n t) = DimNames n
  Tpl (TPair a b) = (Tpl a, Tpl b)
  -- If you add equations here, also update genValue (and emptyTpl): for all
  -- types other than arrays and pairs, genValue currently ignores the
  -- template it receives and uses emptyTpl for any nested components.
  Tpl _ = ()
-- | Left-nested pairing of per-variable templates in an environment
-- template; the right operand belongs to the head of the environment list.
infixl 9 :&
data a :& b = a :& b
  deriving (Show)
-- | Template for a whole environment.
--
-- The head of the environment list ends up rightmost, e.g.
-- @TemplateE '[a, b, c]@ reduces to @Tpl c :& Tpl b :& Tpl a@. A singleton
-- environment is just the bare 'Tpl'. Closed type family: the @'[t]@
-- equation must precede the general cons equation.
type family TemplateE env where
  TemplateE '[] = ()
  TemplateE '[t] = Tpl t
  TemplateE (t : ts) = TemplateE ts :& Tpl t
-- | Template for a rank-@n@ array in which every dimension is anonymous with
-- minimum 0.
emptyDimNames :: SNat n -> DimNames n
emptyDimNames rank = case rank of
  SZ -> ()
  SS SZ -> unconstrained
  SS inner@SS{} -> emptyDimNames inner :$ unconstrained
  where
    unconstrained = C "" 0
-- | The template that imposes no constraints: every array dimension is
-- anonymous with minimum 0.
--
-- Total over all type constructors. For every type other than arrays and
-- pairs, 'Tpl' reduces to @()@, so those cases simply return @()@. Before,
-- they fell through to @error "too lazy"@, which only went unnoticed because
-- 'genValue' happens not to force those templates.
emptyTpl :: STy t -> Tpl t
emptyTpl (STArr n _) = emptyDimNames n
emptyTpl (STPair a b) = (emptyTpl a, emptyTpl b)
emptyTpl (STScal _) = ()
emptyTpl STNil = ()
emptyTpl STEither{} = ()
emptyTpl STMaybe{} = ()
emptyTpl STAccum{} = ()
emptyTemplateE :: SList STy env -> TemplateE env
emptyTemplateE SNil = ()
emptyTemplateE (t `SCons` SNil) = emptyTpl t
emptyTemplateE (t `SCons` ts@SCons{}) = emptyTemplateE ts :& emptyTpl t
-- | Generate an array shape of rank @n@ following a per-dimension template.
--
-- Each dimension is first drawn with @Range.linear lo 10@. Named dimensions
-- are drawn once, stored in the 'Map' state, and reused afterwards. To keep
-- arrays small, anonymous dimensions are then divided by
-- @shapeSize sh `div` 100 + 1@.
--
-- Two invariants are maintained by the rescaling step:
--
-- * Named dimensions are never rescaled. Their value is already recorded in
--   the state, and other arrays using the same name must agree with it.
--   Previously, a 10x10 matrix sharing "n" with a length-10 vector was
--   shrunk to 5x5, silently breaking the template.
-- * Anonymous dimensions are never rescaled below their template minimum.
genShape :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
genShape = \n tpl -> do
  sh <- genShapeNaive n tpl
  let sz = shapeSize sh
      factor = sz `div` 100 + 1
  return (shapeDiv n tpl sh factor)
  where
    -- Draw every dimension according to its template constraint.
    genShapeNaive :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
    genShapeNaive SZ () = return ShNil
    genShapeNaive (SS SZ) name = ShCons ShNil <$> genNamedDim name
    genShapeNaive (SS n@SS{}) (tpl :$ name) = ShCons <$> genShapeNaive n tpl <*> genNamedDim name

    -- Anonymous dimensions are fresh each time; named ones are memoised in the state.
    genNamedDim :: TplConstr -> StateT (Map String Int) Gen Int
    genNamedDim (C "" lo) = genDim lo
    genNamedDim (C name lo) = gets (Map.lookup name) >>= \case
      Nothing -> do
        dim <- genDim lo
        modify (Map.insert name dim)
        return dim
      Just dim -> return dim

    genDim :: Int -> StateT (Map String Int) Gen Int
    genDim lo = Gen.integral (Range.linear lo 10)

    -- Rescale the shape dimension-by-dimension, walking the template alongside it.
    shapeDiv :: SNat n -> DimNames n -> Shape n -> Int -> Shape n
    shapeDiv SZ () ShNil _ = ShNil
    shapeDiv (SS SZ) c (ShNil `ShCons` d) f = ShNil `ShCons` divDim c d f
    shapeDiv (SS n@SS{}) (tpl :$ c) (sh `ShCons` d) f = shapeDiv n tpl sh f `ShCons` divDim c d f

    -- Only anonymous dimensions shrink, and never below their minimum.
    divDim :: TplConstr -> Int -> Int -> Int
    divDim (C "" lo) d f = max lo (d `div` f)
    divDim C{} d _ = d
-- | Generate an array of the given shape. Every element is produced
-- independently by 'genValue', using the empty template and a fresh (empty)
-- dimension-name state per element.
genArray :: STy a -> Shape n -> Gen (Value (TArr n a))
genArray elty sh = Value <$> arrayGenerateLinM sh (const genElem)
  where
    genElem = unValue <$> evalStateT (genValue elty (emptyTpl elty)) mempty
-- | Generate a random value of type @t@.
--
-- The template is honoured for arrays (via 'genShape', which records named
-- dimensions in the 'Map' state) and pairs. Sum and optional types ignore
-- the incoming template and use 'emptyTpl' for their payload. Numeric
-- scalars are drawn from [-10, 10] with origin 0. Accumulators cannot be
-- generated.
genValue :: STy t -> Tpl t -> StateT (Map String Int) Gen (Value t)
genValue STNil _ = return (Value ())
genValue (STPair ta tb) tpl =
  liftV2 (,) <$> genValue ta (fst tpl) <*> genValue tb (snd tpl)
genValue (STEither ta tb) _ =
  Gen.choice [ liftV Left <$> genValue ta (emptyTpl ta)
             , liftV Right <$> genValue tb (emptyTpl tb) ]
genValue (STMaybe te) _ =
  Gen.choice [ return (Value Nothing)
             , liftV Just <$> genValue te (emptyTpl te) ]
genValue (STArr rank te) tpl = genShape rank tpl >>= lift . genArray te
genValue (STScal STF32) _ = Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
genValue (STScal STF64) _ = Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
genValue (STScal STI32) _ = Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
genValue (STScal STI64) _ = Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
genValue (STScal STBool) _ = Gen.choice [return (Value False), return (Value True)]
genValue STAccum{} _ = error "Cannot generate inputs for accumulators"
-- | Generate a value for every variable of an environment. The head of the
-- environment is generated first, and the dimension-name state is threaded
-- through all variables, so named dimensions agree across the environment.
-- The rightmost component of a ':&' template belongs to the head.
genEnv :: SList STy env -> TemplateE env -> StateT (Map String Int) Gen (SList Value env)
genEnv (ty `SCons` tys@SCons{}) (tysTpl :& tyTpl) =
  SCons <$> genValue ty tyTpl <*> genEnv tys tysTpl
genEnv (ty `SCons` SNil) tyTpl =
  SCons <$> genValue ty tyTpl <*> pure SNil
genEnv SNil () = return SNil
-- | AD test whose inputs come from the empty template, with no extra filter.
adTest :: forall env. KnownEnv env => TestName -> Ex env (TScal TF64) -> TestTree
adTest name term = adTestCon name (\_ -> True) term
-- | AD test whose inputs come from the empty template, keeping only those
-- inputs that satisfy @constr@ (via 'Gen.filter').
adTestCon :: forall env. KnownEnv env => TestName -> (SList Value env -> Bool) -> Ex env (TScal TF64) -> TestTree
adTestCon name constr term = adTestGen name term (Gen.filter constr inputs)
  where
    tys = knownEnv @env
    inputs :: Gen (SList Value env)
    inputs = evalStateT (genEnv tys (emptyTemplateE tys)) mempty
-- | AD test whose inputs are generated from an explicit environment template.
adTestTp :: forall env. KnownEnv env
         => TestName -> TemplateE env -> Ex env (TScal TF64) -> TestTree
adTestTp name tmpl term = adTestGen name term inputs
  where
    inputs :: Gen (SList Value env)
    inputs = evalStateT (genEnv knownEnv tmpl) mempty
-- | Build the full AD test group for one scalar-valued program.
--
-- For inputs drawn from @envGenerator@ the group checks:
--
-- * "compile primal": the compiled primal agrees with the interpreter, both
--   for the term as given and for its 'simplifyFix'-simplified form;
-- * "compile fwdAD": the interpreted and the compiled dual-number forward
--   derivatives agree;
-- * "chad": CHAD on the original and on the simplified term, with and
--   without simplifying the result, reproduces the primal output and the
--   gradient computed by compiled forward AD. It also checks that
--   'simplifyFix' and @simplifyN 20@ pretty-print identically on the CHAD
--   output.
--
-- Gradients are compared as flat scalar lists. A length mismatch between
-- them, such as a missing or extra gradient component, now fails the test
-- instead of being silently truncated by 'zipWith'.
adTestGen :: forall env. KnownEnv env
          => TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree
adTestGen name expr envGenerator =
  let env = knownEnv @env
      exprS = simplifyFix expr
  in withCompiled env expr $ \primalfun ->
     withCompiled env exprS $ \primalSfun ->
     testGroup name
       [ testProperty "compile primal" $ property $ do
           input <- forAllWith (showEnv env) envGenerator
           let outPrimalI = interpretOpen False input expr
           outPrimalC <- liftIO $ primalfun input
           diff outPrimalI (closeIsh' 1e-8) outPrimalC
           let outPrimalSI = interpretOpen False input exprS
           outPrimalSC <- liftIO $ primalSfun input
           diff outPrimalSI (closeIsh' 1e-8) outPrimalSC
       , withCompiled (dne env) (dfwdDN exprS) $ \dnfun ->
           testProperty "compile fwdAD" $ property $ do
             input <- forAllWith (showEnv env) envGenerator
             dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input
             -- NOTE(review): the interpreter runs dfwdDN on the unsimplified
             -- term, while dnfun was compiled from the simplified one, so this
             -- also cross-checks the simplifier. Presumably intentional.
             let (outDNI1, outDNI2) = interpretOpen False dinput (dfwdDN expr)
             (outDNC1, outDNC2) <- liftIO $ dnfun dinput
             diff outDNI1 (closeIsh' 1e-8) outDNC1
             diff outDNI2 (closeIsh' 1e-8) outDNC2
       , withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC ->
           testProperty "chad" $ property $ do
             annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
             let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
                 dtermChadS = simplifyFix dtermChad0
             let dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS
                 dtermSChadS = simplifyFix dtermSChad0
             -- pack Text for less GC pressure (these values are retained for some reason)
             diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0)))
             diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0)))
             input <- forAllWith (showEnv env) envGenerator
             outPrimal <- liftIO $ primalSfun input
             let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env)
                 unpackGrad = unTup vUnpair (d2e env) . Value
             let scFwd = envScalars env $ gradientByForward fwdartifactC input
             let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0
                 (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS
                 (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0
                 (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS
                 scChad = envScalars env $ toTanE env input gradChad0
                 scChadS = envScalars env $ toTanE env input gradChadS
                 scSChad = envScalars env $ toTanE env input gradSChad0
                 scSChadS = envScalars env $ toTanE env input gradSChadS
             -- Debugging aids; uncomment to see the gradients / terms on failure:
             -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0))
             -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS))
             -- annotate (ppExpr knownEnv expr)
             -- annotate (ppExpr env dtermChad0)
             -- annotate (ppExpr env dtermChadS)
             diff outChad0 closeIsh outPrimal
             diff outChadS closeIsh outPrimal
             diff outSChad0 closeIsh outPrimal
             diff outSChadS closeIsh outPrimal
             diff scChad closeIshList scFwd
             diff scChadS closeIshList scFwd
             diff scSChad closeIshList scFwd
             diff scSChadS closeIshList scFwd
       ]
  where
    -- Flatten the scalar components of a tangent environment into one list.
    envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
    envScalars SNil SNil = []
    envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs

    -- Element-wise 'closeIsh' on two flattened gradients. Lists of different
    -- lengths are never considered close (a bare zipWith would truncate).
    closeIshList :: [Double] -> [Double] -> Bool
    closeIshList xs ys = length xs == length ys && and (zipWith closeIsh xs ys)
-- | Compile @expr@ as a test resource and pass the compiled function to the
-- test-tree builder. Releasing the resource does nothing.
withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
withCompiled env expr = withResource (compile env expr) (const (pure ()))
-- | Sum of a vector, computed by rebuilding it element-wise with 'build'
-- and then summing.
term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
term_build1_sum =
  fromNamed (lambda #x (body
    (idx0 (sum1i
      (build (SS SZ) (shape #x) (#idx :-> #x ! #idx))))))
-- | Scalar program that builds pairs from its two inputs and projects out of
-- them.
term_pairs :: Ex [TScal TF64, TScal TF64] (TScal TF64)
term_pairs =
  fromNamed (lambda #x (lambda #y (body
    (let_ #p (pair #x #y)
      (let_ #q (pair (snd_ #p * fst_ #p + #y) #x)
        (fst_ #q * #x + snd_ #q * fst_ #p))))))
-- | Program whose result depends only on elements 2, 3 and 4 of the input
-- vector (through a copied array), so most of the gradient is zero.
term_sparse :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
term_sparse =
  fromNamed (lambda #inp (body
    (let_ #n (snd_ (shape #inp))
    (let_ #arr (build1 #n (#i :-> #inp ! pair nil #i))
    (let_ #a (build1 #n (#i :-> #arr ! pair nil 2))
    (let_ #b (build1 #n (#i :-> #arr ! pair nil 3))
    (let_ #c (build1 #n (#i :-> #arr ! pair nil 4))
      (idx0 (sum1i #a) + idx0 (sum1i #b) + idx0 (sum1i #c)))))))))
-- | Regression program for a simplifier bug. It branches on the index inside
-- a 'build'; the inner branch's condition compares the index with itself.
term_regression_simpl1 :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
term_regression_simpl1 =
  fromNamed (lambda #q (body
    (idx0 (sum1i (build (SS SZ) (shape #q) (#idx :->
      let_ #j (snd_ #idx)
        (if_ (#j .== 0)
             (#q ! pair nil 0)
             (if_ (#j .== #j) 1.0 2.0))))))))
-- | Sum of the entries of a matrix-vector product. Row @i@ is the dot product
-- of matrix row @i@ with the vector, over the matrix's last dimension.
term_mulmatvec :: Ex [TArr N1 (TScal TF64), TArr N2 (TScal TF64)] (TScal TF64)
term_mulmatvec =
  fromNamed (lambda @(TArr N2 _) #mat (lambda @(TArr N1 _) #vec (body
    (idx0 (sum1i
      (let_ #hei (snd_ (fst_ (shape #mat)))
      (let_ #wid (snd_ (shape #mat))
      (build1 #hei (#i :->
        idx0 (sum1i (build1 #wid (#j :->
          #mat ! pair (pair nil #i) #j * #vec ! pair nil #j))))))))))))
-- | The complete AD test suite. Each entry is one test group built with
-- 'adTest' (empty template), 'adTestCon' (extra input filter), 'adTestTp'
-- (explicit input template) or 'adTestGen' (custom input generator).
tests :: TestTree
tests = testGroup "AD"
  [adTest "id" $ fromNamed $ lambda #x $ body $ #x
  ,adTest "idx0" $ fromNamed $ lambda #x $ body $ idx0 #x
  -- Exercises float and integer arithmetic (the integers via round_/toFloat_).
  ,adTest "operators" $ fromNamed $ lambda #x $ lambda #y $ body $
     let_ #i (round_ #x) $
     let_ #j (round_ #y) $
     let_ #a1 (#x + #y) $
     let_ #a2 (#x - #y) $
     let_ #a3 (#x * #y) $
     let_ #a4 (#x / (#y * #y + 1)) $
     let_ #b1 (#i + #j) $
     let_ #b2 (#i - #j) $
     let_ #b3 (#i * #j) $
     let_ #b4 (#i `idiv` (#j * #j + 1)) $
     #a1 + #a2 + #a3 + #a4 +
       toFloat_ (#b1 + #b2 + #b3 + #b4)
  ,adTest "order-of-operations" $ fromNamed $ body $
     toFloat_ (3 * (3 `idiv` 2)) -- Compile had a pretty-printing bug at some point
  ,adTest "sum-vec" $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x)
  ,adTest "sum-replicate" $ fromNamed $ lambda #x $ body $
     idx0 $ sum1i $ replicate1i 10 #x
  ,adTest "pairs" term_pairs
  ,adTest "build0 const" $ fromNamed $ lambda @(TScal TF64) #x $ body $
     idx0 $ build SZ nil $ #idx :-> const_ 0.0
  ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $
     idx0 $
       build SZ (shape #x) $ #idx :-> #x ! #idx
  ,adTest "build1-sum" term_build1_sum
  ,adTest "build2-sum" $ fromNamed $ lambda @(TArr N2 _) #x $ body $
     idx0 $ sum1i . sum1i $
       build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx
  -- The filter requires the last dimension of the input matrix to be non-empty.
  ,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
     fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
       idx0 $ sum1i $ maximum1i #x
  ,adTestCon "minimum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
     fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
       idx0 $ sum1i $ minimum1i #x
  -- The let-bound array #a is never used by the result.
  ,adTest "unused" $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $
     let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
     42
  -- Minimum length 5, presumably so the fixed indices 2..4 used by
  -- term_sparse are in bounds.
  ,adTestTp "sparse" (C "" 5) term_sparse
  -- Regression test for a simplifier bug (89b78d4)
  ,adTestTp "regression-simpl1" (C "" 1) term_regression_simpl1
  ,adTestGen "neural" Example.neural genNeural
  ,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) genNeural
  -- Non-empty input vector (minimum length 1).
  ,adTestTp "logsumexp" (C "" 1) $
     fromNamed $ lambda @(TArr N1 _) #vec $ body $
       let_ #m (maximum1i #vec) $
       log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m
  -- The matrix's last dimension and the vector's length share the name "n".
  ,adTestTp "mulmatvec" ((C "" 0 :$ C "n" 0) :& C "n" 0) term_mulmatvec
  ,adTestGen "gmm-wrong" (Example.gmmObjective True) genGMM
  ,adTestGen "gmm-wrong-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective True))) genGMM
  ,adTestGen "gmm" (Example.gmmObjective False) genGMM
  ,adTestGen "gmm-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective False))) genGMM
  ]
  where
    -- Inputs for Example.gmmObjective: sizes N, D, K in [1, 8], arrays of the
    -- matching shapes, and the precomputed constants k1, k2, k3. The list
    -- order must match that program's environment (not visible here).
    genGMM = do
      -- The input ranges here are completely arbitrary.
      let tR = STScal STF64
      kN <- Gen.integral (Range.linear 1 8)
      kD <- Gen.integral (Range.linear 1 8)
      kK <- Gen.integral (Range.linear 1 8)
      let i2i64 = fromIntegral @Int @Int64
      valpha <- genArray tR (ShNil `ShCons` kK)
      vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
      vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
      vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
      vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD)
      vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
      vm <- Gen.integral (Range.linear 0 5)
      let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
          k2 = 0.5 * vgamma * vgamma
          k3 = 0.42 -- don't feel like multigammaing today
      return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons`
              Value vm `SCons` vX `SCons`
              vL `SCons` vQ `SCons` vM `SCons` valpha `SCons`
              Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
              SNil)
    -- Inputs for Example.neural: an input vector and three layers with layer
    -- widths in [1, 10]; each of the first two layers is a (matrix, bias) pair.
    genNeural = do
      let tR = STScal STF64
      let genLayer nin nout =
            liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
                       <*> genArray tR (ShNil `ShCons` nout)
      nin <- Gen.integral (Range.linear 1 10)
      n1 <- Gen.integral (Range.linear 1 10)
      n2 <- Gen.integral (Range.linear 1 10)
      input <- genArray tR (ShNil `ShCons` nin)
      lay1 <- genLayer nin n1
      lay2 <- genLayer n1 n2
      lay3 <- genArray tR (ShNil `ShCons` n2)
      return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)
-- | Entry point: run the whole 'tests' tree.
main :: IO ()
main = do
  defaultMain tests
|