| 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
 | {-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Main where
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Int (Int64)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Test.Tasty
import Test.Tasty.Hedgehog
import Array
import AST
import AST.Pretty
import AST.UnMonoid
import CHAD.Top
import CHAD.Types
import qualified Example
import qualified Example.GMM as Example
import ForwardAD
import Interpreter
import Interpreter.Rep
import Language
import Simplify
-- | How much simplification to apply to a differentiated term before it is
-- interpreted.
data SimplIters
  = SimplIters Int
    -- ^ Run the simplifier the given number of times (passed to @simplifyN@).
  | SimplFix
    -- ^ Run the simplifier via @simplifyFix@.
  deriving (Show)
-- | Differentiate @term@ with CHAD, optionally simplify the result, and
-- interpret it on @input@.
--
-- Returns the pretty-printed differentiated term together with the value
-- computed by the differentiated program and the unpacked per-variable
-- cotangents.
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD simplIters env term input =
  (ppExpr env simplified, (primalOut, unTup vUnpair (d2e env) (Value cotangents)))
  where
    -- The CHAD-transformed term, with the constant 1.0 let-bound around it
    -- (presumably the initial output cotangent -- confirm against CHAD.Top).
    seeded = ELet ext (EConst ext STF64 1.0) (chad' defaultConfig env term)

    -- The pattern guard brings the 'KnownEnv' dictionary for @env@ into scope.
    simplified
      | Dict <- envKnown env =
          case simplIters of
            SimplFix -> simplifyFix seeded
            SimplIters n -> simplifyN n seeded

    (primalOut, cotangents) = interpretOpen False input simplified
-- | Like 'gradientByCHAD', but converts the (possibly sparse) CHAD cotangents
-- into dense tangents of type 'TanE', using the primal input to fill in zeros.
--
-- In addition to the gradient, also returns the pretty-printed differentiated
-- term.
--
-- A cotangent that is the sparse zero (@Nothing@ for pairs, eithers and
-- maybes; an array of size zero for arrays) is expanded to a zero tangent
-- that has the same structure as the primal value.
gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env)))
gradientByCHAD' = \simplIters env term input -> second (second (toTanE env input)) $ gradientByCHAD simplIters env term input
  where
    -- Convert every entry of the cotangent environment, guided by its primal.
    toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
    toTanE SNil SNil SNil = SNil
    toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
      Value (toTan t p x) `SCons` toTanE env primal inp
    -- Convert a single cotangent value into a dense tangent.
    toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
    toTan typ primal der = case typ of
      STNil -> der
      STPair t1 t2 -> case der of
                        Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
                        Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
      STEither t1 t2 -> case der of
                          Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
                          Just d -> case (primal, d) of
                            (Left p, Left d') -> Left (toTan t1 p d')
                            (Right p, Right d') -> Right (toTan t2 p d')
                            _ -> error "Primal and cotangent disagree on Either alternative"
      -- A zero cotangent ('Nothing') for a 'Just' primal must still yield a
      -- 'Just' tangent, so that the result has the same structure as the
      -- primal (as in the pair, either and array cases).
      STMaybe t -> case (primal, der) of
                     (Nothing, _) -> Nothing
                     (Just p, Nothing) -> Just (zeroTan t p)
                     (Just p, Just d) -> Just (toTan t p d)
      STArr _ t
        -- An array cotangent of size zero stands for the zero cotangent.
        | shapeSize (arrayShape der) == 0 ->
            arrayMap (zeroTan t) primal
        | arrayShape primal == arrayShape der ->
            arrayGenerateLin (arrayShape primal) $ \i ->
              toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
        | otherwise ->
            error "Primal and cotangent disagree on array shape"
      STScal sty -> case sty of
        STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
      STAccum{} -> error "Accumulators not allowed in input program"
-- | Reference gradient of @term@ at @input@, obtained from 'drevByFwd' with
-- the constant @1.0@ as its final argument (presumably the output cotangent
-- seed -- confirm against the ForwardAD module).
gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward env term input = drevByFwd env term input 1.0
-- | Approximate equality on 'Double's.
--
-- Two values are considered close when their absolute difference is below
-- @1e-5@, or when the smaller of their magnitudes exceeds @1e-4@ and the
-- difference relative to that magnitude is below @1e-5@.
closeIsh :: Double -> Double -> Bool
closeIsh a b = absolutelyClose || relativelyClose
  where
    delta = abs (a - b)
    magnitude = min (abs a) (abs b)
    absolutelyClose = delta < 1e-5
    relativelyClose = magnitude > 1e-4 && delta / magnitude < 1e-5
-- | Left-nested snoc-pair; used to list the per-dimension constraints of an
-- array template (see 'DimNames').
data a :$ b = a :$ b
  deriving (Show)

infixl :$
-- | Constraint on a single generated array dimension.
--
-- Dimensions that carry the same non-empty name are generated with the same
-- size within one run of the generator state. An empty name means "no
-- restrictions": the dimension is not tied to any other dimension.
data TplConstr = C String  -- ^ name; @""@ means anonymous
                   Int     -- ^ minimum value to generate
-- | Template for the dimensions of a rank-@n@ array: nothing for rank 0, a
-- single 'TplConstr' for rank 1, and a left-nested ':$' chain of 'TplConstr's
-- for higher ranks. The equations are tried in order, so the rank-1 case takes
-- precedence over the general successor case.
type family DimNames n where
  DimNames Z = ()
  DimNames (S Z) = TplConstr
  DimNames (S n) = DimNames n :$ TplConstr
-- | Generation template for a value of type @t@: dimension constraints for
-- arrays, a pair of templates for pairs, and unit for every other type.
type family Tpl t where
  Tpl (TArr n t) = DimNames n
  Tpl (TPair a b) = (Tpl a, Tpl b)
  -- If you add equations here, don't forget to update genValue! It currently
  -- just emptyTpl's things out.
  Tpl _ = ()
-- | Left-nested snoc-pair; used to combine the templates of the variables of
-- an environment (see 'TemplateE').
data a :& b = a :& b
  deriving (Show)

infixl :&
-- | Generation template for a whole environment: unit for the empty
-- environment, the plain 'Tpl' for a singleton environment, and otherwise the
-- template of the tail combined (via ':&') with the template of the head. The
-- head of the environment thus ends up as the rightmost component.
type family TemplateE env where
  TemplateE '[] = ()
  TemplateE '[t] = Tpl t
  TemplateE (t : ts) = TemplateE ts :& Tpl t
-- | The dimension template of the given rank in which every dimension is
-- anonymous and has minimum size 0.
emptyDimNames :: SNat n -> DimNames n
emptyDimNames rank = case rank of
  SZ -> ()
  SS SZ -> anon
  SS inner@SS{} -> emptyDimNames inner :$ anon
  where
    -- An unnamed dimension without a lower bound above zero.
    anon = C "" 0
-- | The unconstrained template for a value of the given type: every array
-- dimension is anonymous with minimum size 0.
--
-- This function is total: every type that has no dedicated 'Tpl' equation has
-- the unit template, so no case needs to fall back to 'error'.
emptyTpl :: STy t -> Tpl t
emptyTpl (STArr n _) = emptyDimNames n
emptyTpl (STPair a b) = (emptyTpl a, emptyTpl b)
emptyTpl (STScal _) = ()
emptyTpl STNil = ()
emptyTpl STEither{} = ()
emptyTpl STMaybe{} = ()
emptyTpl STAccum{} = ()
-- | The unconstrained template for a whole environment, built from 'emptyTpl'
-- for each of its types.
emptyTemplateE :: SList STy env -> TemplateE env
emptyTemplateE tys = case tys of
  SNil -> ()
  ty `SCons` SNil -> emptyTpl ty
  ty `SCons` rest@SCons{} -> emptyTemplateE rest :& emptyTpl ty
-- | Generate a random array shape of the given rank according to a dimension
-- template.
--
-- Each dimension is drawn from its minimum up to 10. Named dimensions are
-- looked up in (and recorded into) the state map, so that equal names yield
-- equal sizes. Afterwards every dimension is divided by
-- @size `div` 100 + 1@, where @size@ is the total number of elements.
--
-- NOTE(review): the final division is applied after the name map has been
-- updated and ignores the requested minimum, so once the total size reaches
-- 100 a dimension may end up below its minimum or differ from an equally
-- named dimension generated elsewhere.
genShape :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
genShape rank names = do
  naive <- rawShape rank names
  let total = shapeSize naive
      divisor = total `div` 100 + 1
  return (scaleDown divisor naive)
  where
    -- Generate all dimensions without any size cap.
    rawShape :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
    rawShape SZ () = return ShNil
    rawShape (SS SZ) constr = ShCons ShNil <$> dimFor constr
    rawShape (SS inner@SS{}) (outer :$ constr) = ShCons <$> rawShape inner outer <*> dimFor constr

    -- Generate one dimension; named dimensions are memoised in the state.
    dimFor :: TplConstr -> StateT (Map String Int) Gen Int
    dimFor (C "" lo) = freshDim lo
    dimFor (C name lo) = do
      known <- gets (Map.lookup name)
      case known of
        Just dim -> return dim
        Nothing -> do
          dim <- freshDim lo
          modify (Map.insert name dim)
          return dim

    -- Draw a dimension between the given minimum and 10.
    freshDim :: Int -> StateT (Map String Int) Gen Int
    freshDim lo = Gen.integral (Range.linear lo 10)

    -- Divide every dimension of the shape by the given factor.
    scaleDown :: Int -> Shape n -> Shape n
    scaleDown _ ShNil = ShNil
    scaleDown f (sh `ShCons` d) = scaleDown f sh `ShCons` (d `div` f)
-- | Generate an array of the given shape whose elements have the given type.
--
-- Every element is generated independently with the unconstrained template
-- ('emptyTpl') and a fresh, empty dimension-name map.
genArray :: STy a -> Shape n -> Gen (Value (TArr n a))
genArray elemTy shape = Value <$> arrayGenerateLinM shape (const genElem)
  where
    genElem = unValue <$> evalStateT (genValue elemTy (emptyTpl elemTy)) mempty
-- | Generate a random value of the given type according to its template.
--
-- The template is only consulted for arrays (dimension constraints) and for
-- pairs (component templates); the payloads of eithers and maybes are
-- generated with the unconstrained template. Floating-point scalars are drawn
-- from @[-10, 10]@ and integral scalars from @-10@ to @10@, both with origin 0.
-- Accumulator types cannot be generated and raise an error.
genValue :: STy t -> Tpl t -> StateT (Map String Int) Gen (Value t)
genValue STNil _ = return (Value ())
genValue (STPair a b) tpl =
  liftV2 (,) <$> genValue a (fst tpl) <*> genValue b (snd tpl)
genValue (STEither a b) _ =
  Gen.choice
    [ liftV Left <$> genValue a (emptyTpl a)
    , liftV Right <$> genValue b (emptyTpl b)
    ]
genValue (STMaybe t) _ =
  Gen.choice
    [ return (Value Nothing)
    , liftV Just <$> genValue t (emptyTpl t)
    ]
genValue (STArr n t) tpl = do
  sh <- genShape n tpl
  lift (genArray t sh)
genValue (STScal sty) _ = case sty of
  STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
  STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
  STI32 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
  STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
  STBool -> Gen.choice [return (Value False), return (Value True)]
genValue STAccum{} _ = error "Cannot generate inputs for accumulators"
-- | Generate a value for every variable of an environment, following the
-- environment template. The head of the environment is generated first, so
-- dimension names it records are visible to the rest of the environment.
genEnv :: SList STy env -> TemplateE env -> StateT (Map String Int) Gen (SList Value env)
genEnv SNil () = pure SNil
genEnv (ty `SCons` SNil) tpl =
  SCons <$> genValue ty tpl <*> pure SNil
genEnv (ty `SCons` rest@SCons{}) (restTpl :& tpl) =
  SCons <$> genValue ty tpl <*> genEnv rest restTpl
-- | AD property for a term, using unconstrained random inputs.
adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property
adTest term = adTestCon acceptAll term
  where
    -- Every generated environment is acceptable.
    acceptAll _ = True
-- | AD property for a term, using unconstrained random inputs that are
-- filtered by the given predicate.
adTestCon :: forall env. KnownEnv env => (SList Value env -> Bool) -> Ex env (TScal TF64) -> Property
adTestCon constr term = adTestGen term acceptedEnvs
  where
    envTypes = knownEnv @env
    -- Only environments that satisfy the predicate are used.
    acceptedEnvs =
      Gen.filter constr (evalStateT (genEnv envTypes (emptyTemplateE envTypes)) mempty)
-- | AD property for a term, using random inputs generated according to the
-- given environment template.
adTestTp :: forall env. KnownEnv env
         => TemplateE env -> Ex env (TScal TF64) -> Property
adTestTp tmpl term = adTestGen term templatedEnvs
  where
    -- Start from an empty dimension-name map.
    templatedEnvs = evalStateT (genEnv knownEnv tmpl) mempty
-- | The core AD property: for inputs drawn from the given generator, check
-- that
--
-- * the fully simplified differentiated term prints the same as the term
--   after 20 simplifier iterations;
-- * the value computed by the simplified CHAD program is close to the one
--   of the unsimplified CHAD program and to the directly interpreted term;
-- * the simplified CHAD gradient is close to the unsimplified CHAD gradient
--   and to the forward-mode gradient.
--
-- Gradients are compared as flat lists of scalars; the lists must have the
-- same length and be element-wise close according to 'closeIsh'.
adTestGen :: forall env. KnownEnv env
          => Ex env (TScal TF64) -> Gen (SList Value env) -> Property
adTestGen expr envGenerator = property $ do
  let env = knownEnv @env
  input <- forAllWith (showEnv env) envGenerator
  let outPrimal = interpretOpen False input expr
      gradFwd = gradientByForward knownEnv expr input
      (_ppdterm, (outChad, gradCHAD)) = gradientByCHAD' (SimplIters 0) knownEnv expr input
      (ppdterm_S, (outChad_S, gradCHAD_S)) = gradientByCHAD' SimplFix knownEnv expr input
      (ppdterm_S20, _) = gradientByCHAD' (SimplIters 20) knownEnv expr input
      scFwd = envScalars env gradFwd
      scCHAD = envScalars env gradCHAD
      scCHAD_S = envScalars env gradCHAD_S
  annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
  -- annotate (ppExpr knownEnv expr)
  -- annotate ppdterm
  -- annotate ppdterm_S
  diff ppdterm_S (==) ppdterm_S20
  diff outChad_S closeIsh outChad
  diff outChad_S closeIsh outPrimal
  diff scCHAD_S allClose scCHAD
  diff scCHAD_S allClose scFwd
  where
    -- Element-wise closeness. The explicit length check is needed because
    -- 'zipWith' silently drops the surplus elements of the longer list, which
    -- would let structurally different gradients compare as equal.
    allClose :: [Double] -> [Double] -> Bool
    allClose xs ys = length xs == length ys && and (zipWith closeIsh xs ys)
    -- Flatten all tangents of an environment into one list of scalars.
    envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
    envScalars SNil SNil = []
    envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs
-- | Test term: build a rank-1 array with the shape of @x@ whose element at
-- each index is read from @x@ at that index, then sum it.
term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
term_build1_sum =
  fromNamed (lambda #x (body
    (idx0 (sum1i
      (build (SS SZ) (shape #x) (#idx :-> #x ! #idx))))))
-- | Test term over two scalars that constructs and projects from nested
-- pairs.
term_pairs :: Ex [TScal TF64, TScal TF64] (TScal TF64)
term_pairs =
  fromNamed (lambda #x (lambda #y (body
    (let_ #p (pair #x #y)
      (let_ #q (pair (snd_ #p * fst_ #p + #y) #x)
        (fst_ #q * #x + snd_ #q * fst_ #p))))))
-- | Test term that only reads the input at indices 2, 3 and 4.
--
-- The input is first copied element-wise into @arr@; three further arrays of
-- the same length are filled with @arr@ at index 2, 3 and 4 respectively, and
-- the result is the sum of the sums of those three arrays. The input must
-- therefore have at least 5 elements.
term_sparse :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
term_sparse =
  fromNamed (lambda #inp (body
    (let_ #n (snd_ (shape #inp))
      (let_ #arr (build1 #n (#i :-> #inp ! pair nil #i))
        (let_ #a (build1 #n (#i :-> #arr ! pair nil 2))
          (let_ #b (build1 #n (#i :-> #arr ! pair nil 3))
            (let_ #c (build1 #n (#i :-> #arr ! pair nil 4))
              (idx0 (sum1i #a) + idx0 (sum1i #b) + idx0 (sum1i #c)))))))))
-- | The complete "AD" test group. Every entry is a Hedgehog property built
-- with one of the @adTest*@ helpers, which compare the CHAD gradient of a
-- term against its forward-mode gradient on random inputs.
tests :: TestTree
tests = testGroup "AD"
  [testProperty "id" $ adTest $ fromNamed $ lambda #x $ body $ #x
  ,testProperty "idx0" $ adTest $ fromNamed $ lambda #x $ body $ idx0 #x
  ,testProperty "sum-vec" $ adTest $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x)
  ,testProperty "sum-replicate" $ adTest $ fromNamed $ lambda #x $ body $
      idx0 $ sum1i $ replicate1i 10 #x
  ,testProperty "pairs" $ adTest term_pairs
  ,testProperty "build0 const" $ adTest $ fromNamed $ lambda @(TScal TF64) #x $ body $
      idx0 $ build SZ nil $ #idx :-> const_ 0.0
  ,testProperty "build0" $ adTest $ fromNamed $ lambda @(TArr N0 _) #x $ body $
      idx0 $
        build SZ (shape #x) $ #idx :-> #x ! #idx
  ,testProperty "build1-sum" $ adTest term_build1_sum
  ,testProperty "build2-sum" $ adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $
      idx0 $ sum1i . sum1i $
        build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx
  -- The input filter only accepts arrays whose last dimension is non-zero.
  ,testProperty "maximum" $ adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
      fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
        idx0 $ sum1i $ maximum1i #x
  -- Same input filter as for "maximum".
  ,testProperty "minimum" $ adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
      fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
        idx0 $ sum1i $ minimum1i #x
  -- The array built from the input is bound to @a@ but never used.
  ,testProperty "unused" $ adTest $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $
    let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
      42
  -- term_sparse reads indices 2 to 4, hence the minimum length of 5.
  ,testProperty "sparse" $ adTestTp (C "" 5) term_sparse
  ,testProperty "neural" $ adTestGen Example.neural genNeural
  ,testProperty "neural-unMonoid" $ adTestGen (unMonoid (simplifyFix Example.neural)) genNeural
  -- The template requests a vector with at least 1 element.
  ,testProperty "logsumexp" $ adTestTp (C "" 1) $
      fromNamed $ lambda @(TArr N1 _) #vec $ body $
      let_ #m (maximum1i #vec) $
        log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m
  -- The shared name "n" ties the matrix's last dimension to the vector length.
  ,testProperty "mulmatvec" $ adTestTp ((C "" 0 :$ C "n" 0) :& C "n" 0) $
      fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $
      idx0 $ sum1i $
        let_ #hei (snd_ (fst_ (shape #mat))) $
        let_ #wid (snd_ (shape #mat)) $
          build1 #hei $ #i :->
            idx0 (sum1i (build1 #wid $ #j :->
                           #mat ! pair (pair nil #i) #j * #vec ! pair nil #j))
  -- The GMM properties run with the shrink limit set to 0.
  ,testProperty "gmm-wrong" $ withShrinks 0 $ adTestGen (Example.gmmObjective True) genGMM
  ,testProperty "gmm-wrong-unMonoid" $ withShrinks 0 $ adTestGen (unMonoid (simplifyFix (Example.gmmObjective True))) genGMM
  ,testProperty "gmm" $ withShrinks 0 $ adTestGen (Example.gmmObjective False) genGMM
  ,testProperty "gmm-unMonoid" $ withShrinks 0 $ adTestGen (unMonoid (simplifyFix (Example.gmmObjective False))) genGMM
  ]
  where
    -- Input environment for the GMM objective: three sizes (each between 1
    -- and 10), arrays whose shapes are derived from those sizes, and a few
    -- scalar constants.
    genGMM = do
      -- The input ranges here are completely arbitrary.
      let tR = STScal STF64
      kN <- Gen.integral (Range.linear 1 10)
      kD <- Gen.integral (Range.linear 1 10)
      kK <- Gen.integral (Range.linear 1 10)
      let i2i64 = fromIntegral @Int @Int64
      valpha <- genArray tR (ShNil `ShCons` kK)
      vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
      vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
      vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
      vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD)
      vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
      vm <- Gen.integral (Range.linear 0 5)
      let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
          k2 = 0.5 * vgamma * vgamma
          k3 = 0.42  -- don't feel like multigammaing today
      return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons`
              Value vm `SCons` vX `SCons`
              vL `SCons` vQ `SCons` vM `SCons` valpha `SCons`
              Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
              SNil)
    -- Input environment for the neural-network example: an input vector, two
    -- layers that each pair an (nout x nin) array with a length-nout array,
    -- and a final array of length n2. All sizes are between 1 and 10.
    genNeural = do
      let tR = STScal STF64
      let genLayer nin nout =
            liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
                       <*> genArray tR (ShNil `ShCons` nout)
      nin <- Gen.integral (Range.linear 1 10)
      n1 <- Gen.integral (Range.linear 1 10)
      n2 <- Gen.integral (Range.linear 1 10)
      input <- genArray tR (ShNil `ShCons` nin)
      lay1 <- genLayer nin n1
      lay2 <- genLayer n1 n2
      lay3 <- genArray tR (ShNil `ShCons` n2)
      return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)
-- | Entry point of the test executable: runs the 'tests' tree with Tasty's
-- default runner.
main :: IO ()
main = defaultMain tests
 |