| 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
 | {-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Main where
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Int (Int64)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import qualified Data.Text as T
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Test.Framework
import Array
import AST hiding ((.>))
import AST.Pretty
import AST.UnMonoid
import CHAD.Top
import CHAD.Types
import CHAD.Types.ToTan
import Compile
import qualified Example
import qualified Example.GMM as Example
import Example.Types
import ForwardAD
import ForwardAD.DualNumbers
import Interpreter
import Interpreter.Rep
import Language
import Simplify
-- | A value packaged together with its type, so that it can be printed.
data TypedValue t = TypedValue (STy t) (Rep t)
instance Show (TypedValue t) where
  showsPrec prec (TypedValue ty val) = showValue prec ty val
-- | A whole environment of values packaged with their types, for printing.
data TypedEnv env = TypedEnv (SList STy env) (SList Value env)
instance Show (TypedEnv env) where
  show (TypedEnv tys vals) = showEnv tys vals
-- | Forget the types of a 'TypedEnv', keeping only the values.
unTypedEnv :: TypedEnv env -> SList Value env
unTypedEnv = \(TypedEnv _ vals) -> vals
-- | Simplifier policy: run a fixed number of passes, or run to a fixpoint.
data SimplIters = SimplIters Int | SimplFix
  deriving (Show)
-- | Simplify a term according to the requested 'SimplIters' policy.
simplifyIters :: SimplIters -> SList STy env -> Ex env t -> Ex env t
simplifyIters iters env
  -- Matching on the dictionary brings 'KnownEnv env' into scope for the
  -- simplifier entry points below.
  | Dict <- envKnown env
  = case iters of
      SimplFix -> simplifyFix
      SimplIters count -> simplifyN count
-- | Reverse-mode gradient of a scalar term via CHAD. The differentiated term
-- is simplified as requested and then interpreted on the given input.
-- In addition to the primal output and the gradient, this also returns the
-- pretty-printed differentiated term.
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD simplIters env term input =
  (ppExpr env dterm, (primal, unTup vUnpair (d2e env) (Value cotangents)))
  where
    -- The constant 1.0 is let-bound in front of the CHAD output (presumably
    -- the incoming cotangent of the scalar result).
    seeded = ELet ext (EConst ext STF64 1.0) (chad' defaultConfig env term)
    dterm = simplifyIters simplIters env seeded
    (primal, cotangents) = interpretOpen False env input dterm
-- | Like 'gradientByCHAD', but with the gradient converted to tangent form
-- with 'toTanE'. Also returns the pretty-printed differentiated term.
gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env)))
gradientByCHAD' simplIters env term input =
  fmap (fmap (toTanE env input)) (gradientByCHAD simplIters env term input)
-- | Gradient of a scalar term obtained from a forward-mode AD artifact via
-- 'drevByFwd'. The trailing 1.0 is presumably the output sensitivity.
gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env)
gradientByForward art = \input -> drevByFwd art input 1.0
-- | Generate input tangents for this primal: every floating-point scalar is
-- paired with a tangent drawn from [-1, 1]. Integer and boolean scalars are
-- returned unchanged (they carry no tangent).
extendDN :: STy t -> Rep t -> Gen (Rep (DN t))
extendDN STNil () = pure ()
extendDN (STPair ta tb) (a, b) = (,) <$> extendDN ta a <*> extendDN tb b
extendDN (STEither ta _) (Left a) = Left <$> extendDN ta a
extendDN (STEither _ tb) (Right b) = Right <$> extendDN tb b
extendDN (STLEither _ _) Nothing = pure Nothing
extendDN (STLEither ta _) (Just (Left a)) = Just . Left <$> extendDN ta a
extendDN (STLEither _ tb) (Just (Right b)) = Just . Right <$> extendDN tb b
extendDN (STMaybe _) Nothing = pure Nothing
extendDN (STMaybe te) (Just e) = Just <$> extendDN te e
extendDN (STArr _ te) elems = traverse (extendDN te) elems
extendDN (STScal STF32) x = Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \dx -> pure (x, dx)
extendDN (STScal STF64) x = Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \dx -> pure (x, dx)
extendDN (STScal STI32) x = pure x
extendDN (STScal STI64) x = pure x
extendDN (STScal STBool) x = pure x
extendDN (STAccum _) _ = error "Accumulators not supported in input program"
-- | 'extendDN' over a whole environment, entry by entry from the head.
extendDNE :: SList STy env -> SList Value env -> Gen (SList Value (DNE env))
extendDNE SNil SNil = pure SNil
extendDNE (ty `SCons` tys) (Value v `SCons` vs) =
  SCons <$> fmap Value (extendDN ty v) <*> extendDNE tys vs
-- | Approximate equality of two 'Double's with tolerance @h@.
--
-- The values are close when any of these hold:
--
-- * they are exactly equal. This also makes equal infinities close, which
--   the subtraction-based checks cannot detect because @inf - inf@ is NaN;
-- * their absolute difference is below @h@;
-- * both magnitudes exceed @10*h@ and their difference relative to the
--   smaller magnitude is below @h@.
--
-- NaN is never close to anything, including another NaN.
closeIsh' :: Double -> Double -> Double -> Bool
closeIsh' h a b =
  a == b || absDiff < h || (scale > 10 * h && absDiff / scale < h)
  where
    absDiff = abs (a - b)
    scale = min (abs a) (abs b)
-- | 'closeIsh'' with the default tolerance @1e-5@.
closeIsh :: Double -> Double -> Bool
closeIsh = closeIsh' 1e-5
-- | Structural approximate equality of two values of type @t@.
--
-- Floating-point scalars are compared with 'closeIsh'' at tolerance @tol@.
-- Integers and booleans must be equal. Either/LEither/Maybe values must use
-- the same constructor, and arrays must have the same shape. Accumulators
-- cannot be compared.
closeIshT' :: Double -> STy t -> Rep t -> Rep t -> Bool
closeIshT' _ STNil () () = True
closeIshT' tol (STPair ta tb) (a1, b1) (a2, b2) =
  closeIshT' tol ta a1 a2 && closeIshT' tol tb b1 b2
closeIshT' tol (STEither ta tb) u v = case (u, v) of
  (Left a1, Left a2) -> closeIshT' tol ta a1 a2
  (Right b1, Right b2) -> closeIshT' tol tb b1 b2
  _ -> False
closeIshT' tol (STLEither ta tb) u v = case (u, v) of
  (Nothing, Nothing) -> True
  (Just (Left a1), Just (Left a2)) -> closeIshT' tol ta a1 a2
  (Just (Right b1), Just (Right b2)) -> closeIshT' tol tb b1 b2
  _ -> False
closeIshT' tol (STMaybe te) u v = case (u, v) of
  (Nothing, Nothing) -> True
  (Just e1, Just e2) -> closeIshT' tol te e1 e2
  _ -> False
closeIshT' tol (STArr _ te) arr1 arr2 =
  arrayShape arr1 == arrayShape arr2
    && all (uncurry (closeIshT' tol te)) (zip (arrayToList arr1) (arrayToList arr2))
closeIshT' tol (STScal sty) x y = case sty of
  STI32 -> x == y
  STI64 -> x == y
  STF32 -> closeIsh' tol (realToFrac x) (realToFrac y)
  STF64 -> closeIsh' tol x y
  STBool -> x == y
closeIshT' _ STAccum{} _ _ = error "closeIshT': Cannot compare accumulators"
-- | 'closeIshT'' with the default tolerance @1e-5@.
closeIshT :: STy t -> Rep t -> Rep t -> Bool
closeIshT ty = closeIshT' 1e-5 ty
-- | Entry-wise 'closeIshT' over two environments of the same types.
closeIshE :: SList STy t -> SList Value t -> SList Value t -> Bool
closeIshE SNil SNil SNil = True
closeIshE (ty `SCons` tys) (Value x `SCons` rest1) (Value y `SCons` rest2) =
  closeIshT ty x y && closeIshE tys rest1 rest2
-- | Pairs up the per-dimension constraints of a multi-dimensional array (see
-- 'DimNames'). Declared @infixl@ so that @a :$ b :$ c@ parses as
-- @(a :$ b) :$ c@, which is the left-nested form that 'DimNames' produces.
data a :$ b = a :$ b deriving (Show) ; infixl :$
-- | A constraint on one generated integer: either an integer scalar input or
-- one array dimension.
--
-- The type index is just a marker that helps typed holes show what (type of)
-- argument this template constraint belongs to.
--
-- Occurrences of the same non-empty name are meant to receive the same value.
-- The generators record the first value chosen for a name in a name map and
-- reuse it afterwards.
data TplConstr a = C String  -- ^ name; @""@ means anonymous
                     Int     -- ^ minimum value to generate
                 | NC  -- ^ no constraints
-- | Constraints for the dimensions of an @n@-dimensional array:
--
-- * @()@ for rank 0;
-- * a single 'TplConstr' for rank 1;
-- * otherwise, the constraints of the leading dimensions followed (via ':$')
--   by the constraint of the last dimension.
type family DimNames n where
  DimNames Z = ()
  DimNames (S Z) = TplConstr (S Z)
  DimNames (S n) = DimNames n :$ TplConstr (S n)
-- | The template type for a value of type @t@. Arrays constrain their
-- dimensions, pairs take one template per component, and 32/64-bit integer
-- scalars take a 'TplConstr'. Every other type has no constraints (@()@).
type family Tpl t where
  Tpl (TArr n t) = DimNames n
  Tpl (TPair a b) = (Tpl a, Tpl b)
  Tpl (TScal TI32) = TplConstr TI32
  Tpl (TScal TI64) = TplConstr TI64
  -- If you add equations here, don't forget to update genValue! It currently
  -- just emptyTpl's things out.
  Tpl _ = ()
-- | Separates the templates of consecutive environment entries in
-- 'TemplateE'. Declared @infixl@, so @a :& b :& c@ parses as @(a :& b) :& c@.
data a :& b = a :& b deriving (Show) ; infixl :&
-- | Template for a whole environment. The head of the list is the rightmost
-- component, e.g. @TemplateE '[a, b, c] = (Tpl c :& Tpl b) :& Tpl a@.
-- A single-entry environment uses that entry's 'Tpl' with no ':&'.
type family TemplateE env where
  TemplateE '[] = ()
  TemplateE '[t] = Tpl t
  TemplateE (t : ts) = TemplateE ts :& Tpl t
-- | The unconstrained 'DimNames' for an @n@-dimensional array: 'NC' for every
-- dimension.
emptyDimNames :: SNat n -> DimNames n
emptyDimNames sn = case sn of
  SZ -> ()
  SS SZ -> NC
  SS m@SS{} -> emptyDimNames m :$ NC
-- | The template that imposes no constraints on a value of the given type.
--
-- Integer scalars and array dimensions get 'NC', and pairs recurse
-- componentwise. Every other type falls under the catch-all equation of
-- 'Tpl', so its template is @()@.
emptyTpl :: STy t -> Tpl t
emptyTpl (STArr n _) = emptyDimNames n
emptyTpl (STPair a b) = (emptyTpl a, emptyTpl b)
emptyTpl (STScal STI32) = NC
emptyTpl (STScal STI64) = NC
emptyTpl (STScal STF32) = ()
emptyTpl (STScal STF64) = ()
emptyTpl (STScal STBool) = ()
-- Handling these explicitly lets 'genValue' and 'emptyTemplateE' work on
-- inputs that are (or contain) units, sums and maybes.
emptyTpl STNil = ()
emptyTpl STEither{} = ()
emptyTpl STLEither{} = ()
emptyTpl STMaybe{} = ()
emptyTpl STAccum{} = error "emptyTpl: accumulators cannot appear in generated inputs"
-- | The unconstrained template for a whole environment: 'emptyTpl' for every
-- entry, assembled in the layout 'TemplateE' expects.
emptyTemplateE :: SList STy env -> TemplateE env
emptyTemplateE tys = case tys of
  SNil -> ()
  ty `SCons` SNil -> emptyTpl ty
  ty `SCons` rest@SCons{} -> emptyTemplateE rest :& emptyTpl ty
-- | Generate a shape for an @n@-dimensional array that satisfies the given
-- per-dimension constraints.
--
-- Each dimension is first drawn from @[lo, lo+10]@ (@lo = 0@ for 'NC'), or
-- reused from the name map for named constraints. If the array would have
-- 100 or more elements, dimensions are then divided by
-- @size `div` 100 + 1@ to keep inputs small:
--
-- * unconstrained dimensions are divided freely;
-- * anonymous dimensions never go below their minimum;
-- * named dimensions are never scaled. Their value is shared through the name
--   map with every other occurrence of the same name (other arrays and named
--   integer inputs), and scaling only this array's copy would break that
--   equality.
genShape :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
genShape = \n tpl -> do
  sh <- genShapeNaive n tpl
  let sz = shapeSize sh
      factor = sz `div` 100 + 1
  return (shapeDiv sh tpl factor)
  where
    -- Draw every dimension according to its constraint, without any scaling.
    genShapeNaive :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
    genShapeNaive SZ () = return ShNil
    genShapeNaive (SS SZ) name = ShCons ShNil <$> genNamedDim name
    genShapeNaive (SS n@SS{}) (tpl :$ name) = ShCons <$> genShapeNaive n tpl <*> genNamedDim name
    -- Draw one dimension. A named constraint reuses the value recorded for its
    -- name, or draws one and records it.
    genNamedDim :: TplConstr a -> StateT (Map String Int) Gen Int
    genNamedDim NC = genDim 0
    genNamedDim (C "" lo) = genDim lo
    genNamedDim (C name lo) = gets (Map.lookup name) >>= \case
      Nothing -> do
        dim <- genDim lo
        modify (Map.insert name dim)
        return dim
      Just dim -> return dim
    genDim :: Int -> StateT (Map String Int) Gen Int
    genDim lo = Gen.integral (Range.linear lo (lo+10))
    -- Scale every dimension down by the factor, as far as its constraint allows.
    shapeDiv :: Shape n -> DimNames n -> Int -> Shape n
    shapeDiv ShNil _ _ = ShNil
    shapeDiv (ShNil       `ShCons` n) c          f = ShNil             `ShCons` scaleDim c n f
    shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ c) f = shapeDiv sh tpl f `ShCons` scaleDim c n f
    -- Scaling rule for a single dimension @n@ under constraint @c@ with factor @f@.
    scaleDim :: TplConstr a -> Int -> Int -> Int
    scaleDim NC n f = n `div` f
    scaleDim (C "" lo) n f = max lo (n `div` f)
    scaleDim (C _ _) n _ = n
-- | Generate an array of the given shape. Every element is drawn independently
-- with an unconstrained template and a fresh (empty) name map.
genArray :: STy a -> Shape n -> Gen (Value (TArr n a))
genArray ty sh = Value <$> arrayGenerateLinM sh (const element)
  where
    element = unValue <$> evalStateT (genValue ty (emptyTpl ty)) mempty
-- | Generate a random value of the given type that satisfies the template.
--
-- The state maps constraint names (see 'C') to the integer chosen for them,
-- so that all occurrences of a name within one generated environment agree.
-- Only arrays, pairs and integer scalars consult the template. The payloads
-- of Either/LEither/Maybe values are generated with 'emptyTpl'.
genValue :: forall t. STy t -> Tpl t -> StateT (Map String Int) Gen (Value t)
genValue topty tpl = case topty of
  STNil -> return (Value ())
  STPair a b -> liftV2 (,) <$> genValue a (fst tpl) <*> genValue b (snd tpl)
  STEither a b -> Gen.choice [liftV Left <$> genValue a (emptyTpl a)
                             ,liftV Right <$> genValue b (emptyTpl b)]
  STLEither a b -> Gen.frequency [(1, pure (Value Nothing))
                                 ,(8, liftV (Just . Left) <$> genValue a (emptyTpl a))
                                 ,(8, liftV (Just . Right) <$> genValue b (emptyTpl b))]
  STMaybe t -> Gen.choice [return (Value Nothing)
                          ,liftV Just <$> genValue t (emptyTpl t)]
  STArr n t -> genShape n tpl >>= lift . genArray t
  STScal sty -> case sty of
    STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STI32 -> genInt
    STI64 -> genInt
    STBool -> Gen.choice [return (Value False), return (Value True)]
  STAccum{} -> error "Cannot generate inputs for accumulators"
  where
    -- Generate an integer scalar:
    --   * 'NC': drawn from [-10, 10];
    --   * @C "" lo@ (anonymous): drawn from [lo, max 10 (lo + 10)] and not
    --     recorded in the name map, so anonymous constraints never alias
    --     each other (matching how 'genShape' treats anonymous dimensions);
    --   * @C name lo@: reuse the value already chosen for @name@, or draw one
    --     as in the anonymous case and record it.
    -- The shrink origin is clamped into the range as @max 0 lo@ (the upper
    -- bound is always at least 10). Hedgehog shrinks towards the origin, so
    -- an origin below @lo@ would let shrinking go under the minimum.
    genInt :: (Integral (Rep t), Tpl t ~ TplConstr _q) => StateT (Map String Int) Gen (Value t)
    genInt = do
      let gen lo = Gen.integral (Range.linearFrom (max 0 lo) lo (max 10 (lo + 10)))
      val <- case tpl of
               NC -> gen (-10)
               C "" lo -> fromIntegral @Int @(Rep t) <$> gen lo
               C name lo -> gets (Map.lookup name) >>= \case
                              Nothing -> do
                                val <- fromIntegral @Int @(Rep t) <$> gen lo
                                modify (Map.insert name (fromIntegral @(Rep t) @Int val))
                                return val
                              Just val -> return (fromIntegral @Int @(Rep t) val)
      return (Value val)
-- | Generate values for a whole environment from its template.
--
-- Entries are generated head-first (the head of @env@ is the rightmost
-- component of the 'TemplateE'). The name map is threaded through, so named
-- constraints are shared across entries. A single-entry environment's
-- template is that entry's 'Tpl' directly, without ':&'.
genEnv :: SList STy env -> TemplateE env -> StateT (Map String Int) Gen (SList Value env)
genEnv SNil () = return SNil
genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil
genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl
-- | Compiler-vs-interpreter test on inputs with no template constraints.
compileTest :: forall env t. KnownEnv env => TestName -> Ex env t -> TestTree
compileTest name = compileTestTp name (emptyTemplateE (knownEnv @env))
-- | Compiler-vs-interpreter test on inputs generated from a template.
compileTestTp :: KnownEnv env => TestName -> TemplateE env -> Ex env t -> TestTree
compileTestTp name tmpl expr = compileTestGen name expr inputs
  where
    inputs = evalStateT (genEnv knownEnv tmpl) mempty
-- | Check that the compiled version of @expr@ agrees with the interpreter
-- (tolerance 1e-8 on floating-point parts) on inputs from the generator.
compileTestGen :: KnownEnv env => TestName -> Ex env t -> Gen (SList Value env) -> TestTree
compileTestGen name expr envGenerator =
  withCompiled env expr $ \compiled ->
    testProperty name $ property $ do
      input <- forAllWith (showEnv env) envGenerator
      let interpreted = interpretOpen False env input expr
      native <- evalIO (compiled input)
      diff (TypedValue ty interpreted) agrees (TypedValue ty native)
  where
    env = knownEnv
    ty = typeOf expr
    agrees (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 ty x y
-- | AD test on inputs generated with no template constraints.
adTest :: forall env. KnownEnv env => TestName -> Ex env R -> TestTree
adTest name term = adTestCon name (\_ -> True) term
-- | AD test on unconstrained inputs, discarding those that fail the predicate.
adTestCon :: forall env. KnownEnv env => TestName -> (SList Value env -> Bool) -> Ex env R -> TestTree
adTestCon name constr term = adTestGen name term (Gen.filter constr unconstrained)
  where
    env = knownEnv @env
    unconstrained = evalStateT (genEnv env (emptyTemplateE env)) mempty
-- | AD test on inputs generated from the given template.
adTestTp :: forall env. KnownEnv env
         => TestName -> TemplateE env -> Ex env R -> TestTree
adTestTp name tmpl term = adTestGen name term generated
  where
    generated = evalStateT (genEnv (knownEnv @env) tmpl) mempty
-- | Full AD test group for a scalar-valued term, on inputs from the generator:
--
-- * compiled vs. interpreted primal, before and after simplification
--   ('adTestGenPrimal');
-- * compiled vs. interpreted dual-number forward AD ('adTestGenFwd');
-- * CHAD gradients checked against forward-mode AD, both with
--   'defaultConfig' and with @chcSetAccum defaultConfig@ ('adTestGenChad').
adTestGen :: forall env. KnownEnv env
          => TestName -> Ex env R -> Gen (SList Value env) -> TestTree
adTestGen name expr envGenerator =
  let env = knownEnv @env
      -- Simplified once, then shared by the compiled simplified primal and all
      -- derivative tests.
      exprS = simplifyFix expr
  in withCompiled env expr $ \primalfun ->
     withCompiled env exprS $ \primalSfun ->
     testGroupCollapse name
       [adTestGenPrimal env envGenerator expr exprS primalfun primalSfun
       ,adTestGenFwd env envGenerator exprS
       ,testGroup "chad"
          [adTestGenChad "default" defaultConfig env envGenerator expr exprS primalSfun
          ,adTestGenChad "accum" (chcSetAccum defaultConfig) env envGenerator expr exprS primalSfun]]
-- | Check that compiling the primal term, both as given and simplified, gives
-- the same result as interpreting it (tolerance 1e-8).
adTestGenPrimal :: SList STy env -> Gen (SList Value env)
                -> Ex env R -> Ex env R
                -> (SList Value env -> IO Double) -> (SList Value env -> IO Double)
                -> TestTree
adTestGenPrimal env envGenerator expr exprS primalfun primalSfun =
  testProperty "compile primal" $ property $ do
    input <- forAllWith (showEnv env) envGenerator
    let matches term compiled = do
          fromCompiled <- evalIO (compiled input)
          diff (interpretOpen False env input term) (closeIsh' 1e-8) fromCompiled
    matches expr primalfun
    matches exprS primalSfun
-- | Check that the compiled dual-number forward-AD term agrees with the
-- interpreted one (tolerance 1e-8) on both the primal and the tangent output.
-- Random input tangents come from 'extendDNE'.
adTestGenFwd :: SList STy env -> Gen (SList Value env)
             -> Ex env R
             -> TestTree
adTestGenFwd env envGenerator exprS =
  withCompiled (dne env) dnTerm $ \compiledDN ->
    testProperty "compile fwdAD" $ property $ do
      primalIn <- forAllWith (showEnv env) envGenerator
      dualIn <- forAllWith (showEnv (dne env)) $ extendDNE env primalIn
      let (interpPrimal, interpTangent) = interpretOpen False (dne env) dualIn dnTerm
      (compPrimal, compTangent) <- evalIO (compiledDN dualIn)
      diff interpPrimal (closeIsh' 1e-8) compPrimal
      diff interpTangent (closeIsh' 1e-8) compTangent
  where
    dnTerm = dfwdDN exprS
-- | Test CHAD reverse-mode AD under the given configuration.
--
-- Six derivative terms are built: CHAD of the original term (@dtermChad*@)
-- and of the pre-simplified term (@dtermSChad*@), each in three variants:
-- unsimplified (@0@), simplified to a fixpoint (@S@), and simplified,
-- 'unMonoid'ed and simplified again (@SUS@). The property then:
--
-- * checks that 'simplifyFix' yields the same term as 20 rounds of
--   'simplifyN', i.e. the simplifier converges within that budget;
-- * checks every variant's primal output against the compiled simplified
--   primal (@primalSfun@);
-- * converts every variant's gradient to tangent form and compares it with
--   the gradient from forward-mode AD ('gradientByForward').
--
-- The SUS variant of the pre-simplified term is also compiled and checked in
-- the same way.
adTestGenChad :: forall env. String -> CHADConfig -> SList STy env -> Gen (SList Value env)
              -> Ex env R -> Ex env R
              -> (SList Value env -> IO Double)
              -> TestTree
adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- envKnown env =
  -- The constant 1.0 is let-bound in front of each CHAD output (presumably
  -- the cotangent of the scalar result).
  let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env expr
      dtermChadS = simplifyFix dtermChad0
      dtermChadSUS = simplifyFix $ unMonoid dtermChadS
      dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env exprS
      dtermSChadS = simplifyFix dtermSChad0
      dtermSChadSUS = simplifyFix $ unMonoid dtermSChadS
  in
  withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC ->
  withCompiled env dtermSChadSUS $ \dcompSChadSUS ->
    testProperty testname $ property $ do
      annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
      -- check simplifier convergence; pack Text for less GC pressure (these values are retained for some reason)
      diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0)))
      diff (T.pack (ppExpr env dtermChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermChad0)))
      diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0)))
      diff (T.pack (ppExpr env dtermSChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermSChad0)))
      input <- forAllWith (showEnv env) envGenerator
      outPrimal <- evalIO $ primalSfun input
      let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env)
          unpackGrad = unTup vUnpair (d2e env) . Value
      let tansFwd = TypedEnv (tanenv env) $ gradientByForward fwdartifactC input
      let (outChad0   , gradChad0)    = second unpackGrad $ interpretOpen False env input dtermChad0
          (outChadS   , gradChadS)    = second unpackGrad $ interpretOpen False env input dtermChadS
          (outChadSUS , gradChadSUS)  = second unpackGrad $ interpretOpen False env input dtermChadSUS
          (outSChad0  , gradSChad0)   = second unpackGrad $ interpretOpen False env input dtermSChad0
          (outSChadS  , gradSChadS)   = second unpackGrad $ interpretOpen False env input dtermSChadS
          (outSChadSUS, gradSChadSUS) = second unpackGrad $ interpretOpen False env input dtermSChadSUS
          tansChad     = TypedEnv (tanenv env) $ toTanE env input gradChad0
          tansChadS    = TypedEnv (tanenv env) $ toTanE env input gradChadS
          tansChadSUS  = TypedEnv (tanenv env) $ toTanE env input gradChadSUS
          tansSChad    = TypedEnv (tanenv env) $ toTanE env input gradSChad0
          tansSChadS   = TypedEnv (tanenv env) $ toTanE env input gradSChadS
          tansSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradSChadSUS
      (outCompSChadSUS, gradCompSChadSUS) <- second unpackGrad <$> evalIO (dcompSChadSUS input)
      let tansCompSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadSUS
      -- annotate (showEnv (d2e env) gradChad0)
      -- annotate (showEnv (d2e env) gradChadS)
      -- annotate (ppExpr knownEnv expr)
      -- annotate (ppExpr env dtermChad0)
      -- annotate (ppExpr env dtermChadS)
      -- dtermSChadSUS is exactly simplifyFix (unMonoid dtermSChadS); reuse it
      -- instead of running the simplifier again.
      annotate (ppExpr env dtermSChadSUS)
      diff outChad0        closeIsh outPrimal
      diff outChadS        closeIsh outPrimal
      diff outChadSUS      closeIsh outPrimal
      diff outSChad0       closeIsh outPrimal
      diff outSChadS       closeIsh outPrimal
      diff outSChadSUS     closeIsh outPrimal
      diff outCompSChadSUS closeIsh outPrimal
      let closeIshE' e1 e2 = closeIshE (tanenv env) (unTypedEnv e1) (unTypedEnv e2)
      diff tansChad         closeIshE' tansFwd
      diff tansChadS        closeIshE' tansFwd
      diff tansChadSUS      closeIshE' tansFwd
      diff tansSChad        closeIshE' tansFwd
      diff tansSChadS       closeIshE' tansFwd
      diff tansSChadSUS     closeIshE' tansFwd
      diff tansCompSChadSUS closeIshE' tansFwd
-- | Compile @expr@ for the given environment as a resource (via
-- 'withResource') and hand the compiled function to the test-building
-- continuation. The release action does nothing.
withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
withCompiled env expr = withResource (compile env expr) (const (pure ()))
-- | Random inputs for 'Example.gmmObjective'.
--
-- Sizes @kN@, @kD@, @kK@ are drawn from [1, 8], and the arrays are shaped
-- from them: alpha is a length-@kK@ vector; M and Q are @kK@ x @kD@; L is
-- @kK@ x @kD*(kD-1)/2@ (presumably the strictly lower-triangular entries of
-- each of the @kK@ factors — verify against 'Example.gmmObjective'); X is
-- @kN@ x @kD@.
-- The scalar constants come first in the environment (head = k3) and the
-- three sizes come last.
-- Note: the order of the random draws below determines the generated values;
-- reordering them changes which inputs a given seed produces.
gen_gmm :: Gen (SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64])
gen_gmm = do
  -- The input ranges here are completely arbitrary.
  let tR = STScal STF64
  kN <- Gen.integral (Range.linear 1 8)
  kD <- Gen.integral (Range.linear 1 8)
  kK <- Gen.integral (Range.linear 1 8)
  let i2i64 = fromIntegral @Int @Int64
  valpha <- genArray tR (ShNil `ShCons` kK)
  vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
  vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
  vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
  vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD)
  -- vgamma is not an input itself; it only feeds into k2.
  vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
  vm <- Gen.integral (Range.linear 0 5)
  -- k1 = 0.5 * N * D * log(2*pi), k2 = 0.5 * gamma^2.
  let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
      k2 = 0.5 * vgamma * vgamma
      k3 = 0.42  -- don't feel like multigammaing today
  return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons`
          Value vm `SCons` vX `SCons`
          vL `SCons` vQ `SCons` vM `SCons` valpha `SCons`
          Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
          SNil)
-- | Random inputs for 'Example.neural'.
--
-- Sizes @nin@, @n1@, @n2@ are drawn from [1, 10]. A layer from @nin@ to
-- @nout@ is a pair of a @nout@ x @nin@ matrix and a length-@nout@ vector.
-- The environment, from the head, contains:
--
-- * the input vector (length @nin@);
-- * a final vector of length @n2@;
-- * the second layer (@n1@ -> @n2@);
-- * the first layer (@nin@ -> @n1@).
--
-- Note: the order of the random draws determines the generated values.
gen_neural :: Gen (SList Value [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)])
gen_neural = do
  let tR = STScal STF64
  let genLayer nin nout =
        liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
                   <*> genArray tR (ShNil `ShCons` nout)
  nin <- Gen.integral (Range.linear 1 10)
  n1 <- Gen.integral (Range.linear 1 10)
  n2 <- Gen.integral (Range.linear 1 10)
  input <- genArray tR (ShNil `ShCons` nin)
  lay1 <- genLayer nin n1
  lay2 <- genLayer n1 n2
  lay3 <- genArray tR (ShNil `ShCons` n2)
  return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)
-- | Sum of a vector, computed after copying it element-wise with an explicit
-- rank-1 'build' over its shape.
term_build1_sum :: Ex '[TVec R] R
term_build1_sum = fromNamed $ lambda #x $ body $
  idx0 $ sum1i $
    build (SS SZ) (shape #x) $ #idx :-> #x ! #idx
-- | A scalar function of two reals whose computation is routed through pairs:
-- it builds a pair, builds another pair from its projections, and combines
-- the components.
term_pairs :: Ex [R, R] R
term_pairs = fromNamed $ lambda #x $ lambda #y $ body $
  let_ #p (pair #x #y) $
  let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $
    fst_ #q * #x + snd_ #q * fst_ #p
-- | Reads only elements 2, 3 and 4 of the input (through the copy @#arr@), so
-- most of the gradient is zero. Those indices must be in bounds, so inputs
-- should have length at least 5.
term_sparse :: Ex '[TVec R] R
term_sparse = fromNamed $ lambda #inp $ body $
  let_ #n (snd_ (shape #inp)) $
  let_ #arr (build1 #n (#i :-> #inp ! pair nil #i)) $
  let_ #a (build1 #n (#i :-> #arr ! pair nil 2)) $
  let_ #b (build1 #n (#i :-> #arr ! pair nil 3)) $
  let_ #c (build1 #n (#i :-> #arr ! pair nil 4)) $
    idx0 (sum1i #a) + idx0 (sum1i #b) + idx0 (sum1i #c)
-- | Sum over a rank-1 'build' whose body branches on the index. Position 0
-- reads @q[0]@; every other position yields 1.0, because the inner condition
-- @#j .== #j@ always holds. Used as a simplifier regression test.
term_regression_simpl1 :: Ex '[TVec R] R
term_regression_simpl1 = fromNamed $ lambda #q $ body $
  idx0 $ sum1i $ build (SS SZ) (shape #q) $ #idx :->
    let_ #j (snd_ #idx) $
      if_ (#j .== 0)
        (#q ! pair nil 0)
        (if_ (#j .== #j) 1.0 2.0)
-- | Sum of the entries of the matrix-vector product @mat * vec@. The vector is
-- indexed up to the matrix's last dimension (its width), so it must be at
-- least that long.
term_mulmatvec :: Ex [TVec R, TMat R] R
term_mulmatvec = fromNamed $ lambda #mat $ lambda #vec $ body $
  idx0 $ sum1i $
    let_ #hei (snd_ (fst_ (shape #mat))) $
    let_ #wid (snd_ (shape #mat)) $
      build1 #hei $ #i :->
        idx0 (sum1i (build1 #wid $ #j :->
                       #mat ! pair (pair nil #i) #j * #vec ! pair nil #j))
-- | Rebinds and shares arrays through lets, pairs and conditionals.
--
-- With @n = min k (length a)@, @b@ holds the first @n@ elements of @a@. @p@
-- pairs @b@ with either @a@ (n odd) or @a@ incremented by one (n even); only
-- the second component of @p@ is used in the result. The result is either
-- @sum b@ (when @n `mod` 3 == 1@) or @sum b * sum (2*b)@.
-- Presumably @k@ should be non-negative; its template in the AD tests uses
-- @C "" 0@.
term_arr_rebind :: Ex '[I64, TVec R] R
term_arr_rebind = fromNamed $ lambda #a $ lambda #k $ body $
  let_ #n (if_ (#k .< length_ #a) #k (length_ #a)) $
  let_ #b (build1 #n (#i :-> #a ! pair nil #i)) $
  let_ #p (if_ (#n `mod_` 2 .== 1)
             (pair #a #b)
             (pair (map_ (#x :-> #x + 1) #a) #b)) $
    if_ (#n `mod_` 3 .== 1)
      (idx0 (sum1i (snd_ #p)))
      (let_ #b' (snd_ #p) $
         idx0 (sum1i #b') * idx0 (sum1i (map_ (#x :-> 2 * #x) #b')))
-- This simplifies away to a pointless test, but is helpful for debugging what
-- term_arr_rebind is supposed to test in a REPL.
-- It copies @a@ scaled by 5, rebinds the copy twice, and sums it.
term_arr_rebind_simple :: Ex '[TVec R] R
term_arr_rebind_simple = fromNamed $ lambda #a $ body $
  let_ #b (build1 (length_ #a) (#i :-> 5 * (#a ! pair nil #i))) $
  let_ #c #b $
  let_ #d #c $
   idx0 (sum1i #d)
-- | Compiler-vs-interpreter tests for terms that create accumulators with
-- 'with' and update them with 'accum', for scalar, pair, Maybe-of-pair and
-- vector accumulators.
tests_Compile :: TestTree
tests_Compile = testGroup "Compile"
  [compileTest "accum f64" $ fromNamed $ lambda #b $ lambda #x $ body $
      with @R 0.0 $ #ac :->
        if_ #b (accum SAPHere nil #x #ac)
               nil
  ,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $
      with @(TPair R R) (pair 0.0 0.0) $ #ac :->
        let_ #_ (if_ #b (accum (SAPFst SAPHere) (pair nil nil) 3.0 #ac) nil) $
        let_ #_ (accum SAPHere nil #x #ac) $
        let_ #_ (accum (SAPSnd SAPHere) (pair nil nil) 4.0 #ac) $
          nil
  ,compileTest "accum (Maybe (f64,f64))" $ fromNamed $ lambda #b $ lambda #x $ body $
      with @(TMaybe (TPair R R)) nothing $ #ac :->
        let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) (pair nil nil) 3.0 #ac) nil) $
        let_ #_ (accum SAPHere nil #x #ac) $
        let_ #_ (accum (SAPJust (SAPSnd SAPHere)) (pair nil nil) 4.0 #ac) $
          nil
  -- Template: () for #b, and C "" 3 so that #x has length at least 3. The
  -- accumulator has #x's length and appears to be updated at index 2 (the
  -- `pair nil 2` in the SAPArrIdx argument).
  ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda @(TVec R) #x $ body $
      let_ #len (snd_ (shape #x)) $
      with @(TVec R) (build1 #len (#_ :-> 0)) $ #ac :->
        let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair (pair nil 2) (build1 #len (#_ :-> nil))) nil) 6.0 #ac)
                        nil) $
        let_ #_ (accum SAPHere nil #x #ac) $
          nil
  ]
-- | AD tests: each entry runs the full 'adTestGen' group (compiled primal,
-- forward AD, CHAD checked against forward AD) on one term. Entries with a
-- template or generator constrain the random inputs to what the term needs.
tests_AD :: TestTree
tests_AD = testGroup "AD"
  [adTest "id" $ fromNamed $ lambda #x $ body $ #x
  ,adTest "idx0" $ fromNamed $ lambda #x $ body $ idx0 #x
  ,adTest "operators" $ fromNamed $ lambda #x $ lambda #y $ body $
      let_ #i (round_ #x) $
      let_ #j (round_ #y) $
      let_ #a1 (#x + #y) $
      let_ #a2 (#x - #y) $
      let_ #a3 (#x * #y) $
      let_ #a4 (#x / (#y * #y + 1)) $
      let_ #b1 (#i + #j) $
      let_ #b2 (#i - #j) $
      let_ #b3 (#i * #j) $
      let_ #b4 (#i `idiv` (#j * #j + 1)) $
        #a1 + #a2 + #a3 + #a4 +
        toFloat_ (#b1 + #b2 + #b3 + #b4)
  ,adTest "order-of-operations" $ fromNamed $ body $
      toFloat_ (3 * (3 `idiv` 2))  -- Compile had a pretty-printing bug at some point
  ,adTest "sum-vec" $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x)
  ,adTest "sum-replicate" $ fromNamed $ lambda #x $ body $
      idx0 $ sum1i $ replicate1i 10 #x
  ,adTest "pairs" term_pairs
  ,adTest "build0 const" $ fromNamed $ lambda @R #x $ body $
      idx0 $ build SZ nil $ #idx :-> const_ 0.0
  ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $
      idx0 $
        build SZ (shape #x) $ #idx :-> #x ! #idx
  ,adTest "build1-sum" term_build1_sum
  ,adTest "build2-sum" $ fromNamed $ lambda @(TMat _) #x $ body $
      idx0 $ sum1i . sum1i $
        build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx
  -- The predicates skip inputs whose last dimension is empty (presumably the
  -- dimension that maximum1i / minimum1i reduce over).
  ,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
      fromNamed $ lambda @(TMat R) #x $ body $
        idx0 $ sum1i $ maximum1i #x
  ,adTestCon "minimum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
      fromNamed $ lambda @(TMat R) #x $ body $
        idx0 $ sum1i $ minimum1i #x
  ,adTest "unused" $ fromNamed $ lambda @(TVec R) #x $ body $
      let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
        42
  -- term_sparse reads elements 2..4, so the input needs length >= 5.
  ,adTestTp "sparse" (C "" 5) term_sparse
  -- Regression test for a simplifier bug (89b78d4)
  ,adTestTp "regression-simpl1" (C "" 1) term_regression_simpl1
  -- Regression test for refcounts when indexing in nested arrays.
  -- Both dimensions are at least 1 because #L is indexed at (0, 0).
  ,adTestTp "regression-idx1" (C "" 1 :$ C "" 1) $
    fromNamed $ lambda @(TMat R) #L $ body $
      if_ (const_ @TI64 1 .> 0)
        (idx0 $ sum1i (build1 1 $ #_ :->
           idx0 (sum1i (build1 1 $ #_ :->
             #L ! pair (pair nil 0) 0 * #L ! pair (pair nil 0) 0))))
        42
  ,adTest "arr-rebind-simple" term_arr_rebind_simple
  -- Unconstrained vector length; the integer k is generated with minimum 0.
  ,adTestTp "arr-rebind" (NC :& C "" 0) term_arr_rebind
  ,adTestGen "neural" Example.neural gen_neural
  -- Non-empty vector, so that maximum1i has an element to select.
  ,adTestTp "logsumexp" (C "" 1) $
      fromNamed $ lambda @(TVec _) #vec $ body $
      let_ #m (maximum1i #vec) $
        log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m
  -- The shared name "n" ties the matrix width to the vector length.
  ,adTestTp "mulmatvec" ((NC :$ C "n" 0) :& C "n" 0) term_mulmatvec
  ,adTestGen "gmm-wrong" (Example.gmmObjective True) gen_gmm
  ,adTestGen "gmm" (Example.gmmObjective False) gen_gmm
  ]
-- | Run the compiler tests followed by the AD tests.
main :: IO ()
main = defaultMain (testGroup "All" [tests_Compile, tests_AD])
 |