summaryrefslogtreecommitdiff
path: root/test/Main.hs
blob: e7dda69cd358fcf7cd5d06fbf937e47751608ef7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
{-# LANGUAGE DataKinds #-}
-- {-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- {-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Main where

import Data.Bifunctor
-- import qualified Data.Dependent.Map as DMap
-- import Data.Dependent.Map (DMap)
import Data.List (intercalate)
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Hedgehog.Main

import Array
import AST
import AST.Pretty
import CHAD
import CHAD.Types
import Data
import qualified Example
import ForwardAD
import Interpreter
import Interpreter.Rep
import Language
import Simplify


-- | Type-level map over an environment that replaces every entry, whatever
-- its type, by the storage tag @"merge"@. The result has the same length as
-- the input environment.
type family MapMerge env where
  MapMerge '[] = '[]
  MapMerge (ty ': rest) = "merge" ': MapMerge rest

-- | Proof, by induction on the environment, that selecting the @"accum"@
-- entries under the all-@"merge"@ tagging of 'MapMerge' yields the empty
-- environment.
mapMergeNoAccum :: SList f env -> Select env (MapMerge env) "accum" :~: '[]
mapMergeNoAccum = \case
  SNil -> Refl
  _ `SCons` rest -> case mapMergeNoAccum rest of
    Refl -> Refl

-- | Proof, by induction on the environment, that selecting the @"merge"@
-- entries under the all-@"merge"@ tagging of 'MapMerge' gives back the
-- environment unchanged.
mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env
mapMergeOnlyMerge = \case
  SNil -> Refl
  _ `SCons` rest -> case mapMergeOnlyMerge rest of
    Refl -> Refl

-- | Applies 'd1' to every type singleton of an environment, producing the
-- singleton list for the 'D1E' environment.
primalEnv :: SList STy env' -> SList STy (D1E env')
primalEnv = \case
  SNil -> SNil
  ty `SCons` rest -> d1 ty `SCons` primalEnv rest

-- | Selects how a differentiated term is simplified by 'diffCHAD'.
data SimplIters
  = SimplIters Int  -- ^ pass the term to 'simplifyN' with this count
  | SimplFix        -- ^ pass the term to 'simplifyFix'
  deriving Show

-- | Differentiates a scalar-valued term with 'drev', using a descriptor that
-- stores every environment entry as 'SMerge', closes the result with
-- 'freezeRet' (fed the constant 1.0), and finally simplifies it as requested
-- by the 'SimplIters' argument.
diffCHAD :: SimplIters -> SList STy env -> Ex env (TScal TF64)
         -> Ex (D1E env) (TPair (TScal TF64) (Tup (D2E env)))
diffCHAD simplIters env term
  | Refl <- mapMergeNoAccum env
  , Refl <- mapMergeOnlyMerge env
  , Dict <- envKnown (primalEnv env)
  = let descr = mergeDescr env
        unsimplified = freezeRet descr (drev descr term) (EConst ext STF64 1.0)
    in case simplIters of
         SimplIters n -> simplifyN n unsimplified
         SimplFix -> simplifyFix unsimplified
  where
    -- Descriptor that pushes every entry of the environment with 'SMerge'.
    mergeDescr :: SList STy env' -> Descr env' (MapMerge env')
    mergeDescr = \case
      SNil -> DTop
      ty `SCons` rest -> mergeDescr rest `DPush` (ty, SMerge)

-- Differentiates the term with 'diffCHAD' and interprets the result on the
-- (primal-converted) input. Returns the pretty-printed differentiated term
-- together with the interpreter's output: the primal result and one 'D2'
-- value per environment entry.
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD simplIters env term input
  | Refl <- mapMergeNoAccum env
  , Refl <- mapMergeOnlyMerge env
  = let dterm = diffCHAD simplIters env term
        (primalOut, cotangents) = interpretOpen False (primalInputs env input) dterm
    in (ppExpr (primalEnv env) dterm, (primalOut, unTup vUnpair (d2e env) (Value cotangents)))
  where
    -- Converts every input value to its 'D1' representation.
    primalInputs :: SList STy env' -> SList Value env' -> SList Value (D1E env')
    primalInputs SNil SNil = SNil
    primalInputs (ty `SCons` tys) (Value v `SCons` vs) =
      Value (primalRep ty v) `SCons` primalInputs tys vs

    -- Structural conversion of a single value; nil and scalars are unchanged.
    primalRep :: STy t -> Rep t -> Rep (D1 t)
    primalRep ty v = case ty of
      STNil -> v
      STPair a b -> bimap (primalRep a) (primalRep b) v
      STEither a b -> bimap (primalRep a) (primalRep b) v
      STMaybe a -> fmap (primalRep a) v
      STArr _ a -> fmap (primalRep a) v
      STScal _ -> v
      STAccum{} -> error "Accumulators not allowed in input program"

-- In addition to the gradient, also returns the pretty-printed differentiated term.
--
-- Same as 'gradientByCHAD', except that every 'D2' cotangent is converted to
-- the 'Tan' representation of its environment entry. The conversion consults
-- the primal input so that compact "zero" cotangents can be expanded, via
-- 'zeroTan', into a tangent of the same structure as the primal.
gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env)))
gradientByCHAD' = \simplIters env term input -> second (second (toTanE env input)) $ gradientByCHAD simplIters env term input
  where
    -- Entry-wise conversion of the whole gradient environment.
    toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
    toTanE SNil SNil SNil = SNil
    toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
      Value (toTan t p x) `SCons` toTanE env primal inp

    -- Converts one cotangent, given the type singleton and the primal value.
    -- Calls 'error' if primal and cotangent disagree on an Either alternative
    -- or on a (non-empty) array shape, and on accumulator types.
    toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
    toTan typ primal der = case typ of
      STNil -> der
      STPair t1 t2 -> case der of
                        -- Left () is expanded to a zero tangent per component.
                        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
                        Right (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
      STEither t1 t2 -> case der of
                          Left () -> bimap (zeroTan t1) (zeroTan t2) primal
                          Right d -> case (primal, d) of
                            (Left p, Left d') -> Left (toTan t1 p d')
                            (Right p, Right d') -> Right (toTan t2 p d')
                            _ -> error "Primal and cotangent disagree on Either alternative"
      STMaybe t -> case (primal, der) of
                     (Nothing, _) -> Nothing
                     -- A missing cotangent for a present primal is expanded to
                     -- a zero tangent, in line with how the pair, Either and
                     -- empty-array cases expand their compact cotangents;
                     -- returning Nothing here would drop the entry's scalars.
                     (Just p, Nothing) -> Just (zeroTan t p)
                     (Just p, Just d) -> Just (toTan t p d)
      STArr _ t
        -- An empty cotangent array is expanded to zeros of the primal's shape.
        | shapeSize (arrayShape der) == 0 ->
            arrayMap (zeroTan t) primal
        | arrayShape primal == arrayShape der ->
            arrayGenerateLin (arrayShape primal) $ \i ->
              toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
        | otherwise ->
            error "Primal and cotangent disagree on array shape"
      STScal sty -> case sty of
        STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
      STAccum{} -> error "Accumulators not allowed in input program"

-- | Reference gradient: delegates to 'drevByFwd' with 1.0 as its final
-- argument.
gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward env term input = drevByFwd env term input seed
  where
    seed = 1.0

-- | Approximate equality of two doubles. The values are considered close if
-- they are exactly equal, if their absolute difference is below 1e-5, or if
-- the smaller magnitude exceeds 1e-4 and the difference relative to that
-- magnitude is below 1e-5. NaN is never close to anything.
closeIsh :: Double -> Double -> Bool
closeIsh a b =
  a == b  -- also accepts two equal infinities, whose difference is NaN
    || absDiff < 1e-5
    || (scale > 1e-4 && absDiff / scale < 1e-5)
  where
    absDiff = abs (a - b)
    scale = min (abs a) (abs b)

-- | Generates a shape of the given rank. Every extent is first drawn from
-- 0..10; then all extents are integer-divided by a common factor derived from
-- the total size (size `div` 100 + 1), which scales down large shapes.
genShape :: SNat n -> Gen (Shape n)
genShape rank = do
  candidate <- unbounded rank
  let shrinkFactor = shapeSize candidate `div` 100 + 1
  return (scaleDown shrinkFactor candidate)
  where
    -- One independently generated extent per dimension.
    unbounded :: SNat n -> Gen (Shape n)
    unbounded SZ = return ShNil
    unbounded (SS m) = ShCons <$> unbounded m <*> Gen.integral (Range.linear 0 10)

    -- Integer-divides every extent by the given factor.
    scaleDown :: Int -> Shape n -> Shape n
    scaleDown _ ShNil = ShNil
    scaleDown f (rest `ShCons` extent) = scaleDown f rest `ShCons` (extent `div` f)

-- | Generates an array of the given shape, drawing each element with
-- 'genValue' for the element type.
genArray :: STy a -> Shape n -> Gen (Value (TArr n a))
genArray eltTy sh = fmap Value (arrayGenerateLinM sh (const element))
  where
    element = unValue <$> genValue eltTy

-- | Generates a random value of the type described by the singleton.
-- Floating-point and integral scalars are drawn from -10..10 around the
-- origin 0; array shapes come from 'genShape'. Accumulator types are
-- rejected with 'error'.
genValue :: STy a -> Gen (Value a)
genValue ty = case ty of
  STNil -> return (Value ())
  STPair l r -> liftV2 (,) <$> genValue l <*> genValue r
  STEither l r ->
    Gen.choice
      [ liftV Left <$> genValue l
      , liftV Right <$> genValue r
      ]
  STMaybe inner ->
    Gen.choice
      [ return (Value Nothing)
      , liftV Just <$> genValue inner
      ]
  STArr rank eltTy -> do
    sh <- genShape rank
    genArray eltTy sh
  STScal scalTy -> case scalTy of
    STI32 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
    STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
    STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STBool ->
      Gen.choice
        [ return (Value False)
        , return (Value True)
        ]
  STAccum{} -> error "Cannot generate inputs for accumulators"

-- | Generates one value per environment entry, using 'genValue' for each
-- entry's type.
genEnv :: SList STy env -> Gen (SList Value env)
genEnv = \case
  SNil -> return SNil
  ty `SCons` rest -> SCons <$> genValue ty <*> genEnv rest

-- data TemplateVar n = TemplateVar (SNat n) String
--   deriving (Show)

-- data Template t where
--   TpShape :: TemplateVar n -> STy t -> Template (TArr n t)
--   TpAny :: STy t -> Template t
--   TpPair :: Template a -> Template b -> Template (TPair a b)
-- deriving instance Show (Template t)

-- data ShapeConstraint n = ShapeAtLeast (Shape n)
--   deriving (Show)

-- genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t)
-- genTemplate = _

-- genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env)
-- genEnvTemplateExact shapes env = _

-- genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env)
-- genEnvTemplate constrs env = do
--   shapes <- DMap.traverseWithKey _ constrs
--   genEnvTemplateExact shapes env

-- | Renders an environment of values as a bracketed, comma-separated list,
-- showing each value with 'showValue' at precedence 0.
showEnv :: SList STy env -> SList Value env -> String
showEnv tys vals = concat ["[", intercalate ", " (entries tys vals), "]"]
  where
    entries :: SList STy env -> SList Value env -> [String]
    entries SNil SNil = []
    entries (ty `SCons` tyRest) (Value v `SCons` valRest) =
      showValue 0 ty v "" : entries tyRest valRest

-- | 'adTestGen' with inputs generated by 'genEnv' for the statically known
-- environment.
adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property
adTest expr = adTestGen expr (genEnv (knownEnv @env))

-- adTestTp :: forall env. KnownEnv env
--          => DMap TemplateVar ShapeConstraint -> SList Template env
--          -> Ex env (TScal TF64) -> Property
-- adTestTp envConstrs envTp = adTestGen (genEnvTemplate envConstrs envTp)

-- | Property comparing CHAD against forward-mode differentiation on inputs
-- drawn from the given generator. It checks that
--
-- * the pretty-printed term after 20 simplification iterations equals the one
--   after simplifying to a fixpoint;
-- * the primal result of the unsimplified and the simplified differentiated
--   term, and of the original term, agree up to 'closeIsh';
-- * the gradient scalars of unsimplified CHAD, simplified CHAD and forward
--   mode agree up to 'closeIsh'.
adTestGen :: forall env. KnownEnv env
          => Ex env (TScal TF64) -> Gen (SList Value env) -> Property
adTestGen expr envGenerator = property $ do
  let env = knownEnv @env
  input <- forAllWith (showEnv env) envGenerator
  let outPrimal = interpretOpen False input expr
      gradFwd = gradientByForward env expr input
      (ppdterm, (outChad, gradCHAD)) = gradientByCHAD' (SimplIters 0) env expr input
      (ppdterm_S, (outChad_S, gradCHAD_S)) = gradientByCHAD' SimplFix env expr input
      (ppdterm_S20, _) = gradientByCHAD' (SimplIters 20) env expr input
      scFwd = envScalars env gradFwd
      scCHAD = envScalars env gradCHAD
      scCHAD_S = envScalars env gradCHAD_S
  -- Context shown on failure: the term's type, the term itself, and both
  -- differentiated terms.
  annotate (concat (unSList (\t -> ppTy 0 t ++ " -> ") env) ++ ppTy 0 (typeOf expr))
  annotate (ppExpr env expr)
  annotate ppdterm
  annotate ppdterm_S
  diff ppdterm_S20 (==) ppdterm_S
  diff outChad closeIsh outChad_S
  diff outPrimal closeIsh outChad_S
  diff scCHAD allClose scCHAD_S
  diff scFwd allClose scCHAD_S
  where
    -- Flattens a tangent environment into the list of its scalars.
    envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
    envScalars SNil SNil = []
    envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs

    -- Element-wise 'closeIsh'. The length check matters: 'zipWith' alone
    -- truncates to the shorter list, so gradients with differing numbers of
    -- scalars would otherwise be accepted.
    allClose :: [Double] -> [Double] -> Bool
    allClose xs ys = length xs == length ys && and (zipWith closeIsh xs ys)

-- | Test term: rebuilds the input vector element by element with 'build',
-- sums it with 'sum1i' and extracts the scalar with 'idx0'.
term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
term_build1_sum = fromNamed $ lambda #x $ body $
  idx0 (sum1i (build (SS SZ) (shape #x) (#idx :-> #x ! #idx)))

-- | Test term over two scalars that builds and takes apart nested pairs.
term_pairs :: Ex [TScal TF64, TScal TF64] (TScal TF64)
term_pairs = fromNamed $ lambda #x $ lambda #y $ body $
  let_ #p (pair #x #y)
    (let_ #q (pair (snd_ #p * fst_ #p + #y) #x)
      (fst_ #q * #x + snd_ #q * fst_ #p))

-- | The Hedgehog property group "AD", checked sequentially. Every entry is an
-- 'adTest' / 'adTestGen' property over one term.
tests :: IO Bool
tests = checkSequential $ Group "AD"
  [("id", adTest $ fromNamed $ lambda #x $ body $ #x)

  ,("idx0", adTest $ fromNamed $ lambda #x $ body $ idx0 #x)

  ,("sum-vec", adTest $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x))

  ,("sum-replicate", adTest $ fromNamed $ lambda #x $ body $
      idx0 $ sum1i $ replicate1i 10 #x)

  ,("pairs", adTest term_pairs)

  -- The input is unused here; the type application pins its type.
  ,("build0 const", adTest $ fromNamed $ lambda @(TScal TF64) #x $ body $
      idx0 $ build SZ nil $ #idx :-> const_ 0.0)

  ,("build0", adTest $ fromNamed $ lambda @(TArr N0 _) #x $ body $
      idx0 $
        build SZ (shape #x) $ #idx :-> #x ! #idx)

  -- :hindentstr ppExpr knownEnv $ diffCHAD (SimplIters 20) knownEnv term_build1_sum
  ,("build1-sum", adTest term_build1_sum)

  ,("build2-sum", adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $
      idx0 $ sum1i . sum1i $
        build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx)

  -- ,("build-sum", adTest $ fromNamed $ lambda #x $ body $
  --     idx0 $ sum1i . sum1i $
  --       build (SS (SS SZ)) (pair (pair nil 2) 3) $ #idx :->
  --         oper OToFl64 $ snd_ (fst_ #idx) + snd_ #idx)

  -- Uses a hand-written generator so that the array extents of consecutive
  -- layers fit together; the generic 'genEnv' picks every shape independently.
  ,("neural", adTestGen Example.neural $ do
      let tR = STScal STF64
      -- A layer is a pair of a rank-2 array with extents nout and nin and a
      -- rank-1 array with extent nout (presumably weights and bias; confirm
      -- against Example.neural).
      let genLayer nin nout =
            liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
                       <*> genArray tR (ShNil `ShCons` nout)
      -- All extents are at least 1, so no generated array is empty.
      nin <- Gen.integral (Range.linear 1 10)
      n1 <- Gen.integral (Range.linear 1 10)
      n2 <- Gen.integral (Range.linear 1 10)
      input <- genArray tR (ShNil `ShCons` nin)
      lay1 <- genLayer nin n1
      lay2 <- genLayer n1 n2
      lay3 <- genArray tR (ShNil `ShCons` n2)
      -- The order of the entries has to match the environment of
      -- Example.neural.
      return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil))
  ]

-- | Test-suite entry point: hands the "AD" group to Hedgehog's 'defaultMain'.
main :: IO ()
main = defaultMain testGroups
  where
    testGroups = [tests]