summaryrefslogtreecommitdiff
path: root/test/Main.hs
blob: d3e55b38669066b036549934a85aae51a7a33c31 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Main where

import Data.Bifunctor
import Data.List (intercalate)
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Hedgehog.Main

import Array
import AST
import AST.Pretty
import CHAD.Top
import CHAD.Types
import Data
import qualified Example
import ForwardAD
import Interpreter
import Interpreter.Rep
import Language
import Simplify


-- | How much simplification to apply to a differentiated term before it is
-- interpreted (see 'gradientByCHAD').
data SimplIters
  = SimplIters Int -- ^ Iteration count handed to 'simplifyN'.
  | SimplFix       -- ^ Use 'simplifyFix' instead of a fixed iteration count.
  deriving (Show)

-- | Differentiate @term@ with CHAD and run the result on @input@.
--
-- The derivative program is @chad' env term@ wrapped in an 'ELet' that binds
-- the constant @1.0@ (presumably the initial cotangent of the scalar result;
-- TODO confirm against "CHAD.Top"). It is then simplified according to the
-- 'SimplIters' argument and interpreted.
--
-- Returns the pretty-printed differentiated term, paired with the primal
-- output and the gradient unpacked into one 'Value' per environment entry.
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD simplIters env term input =
  (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad)))
  where
    -- The differentiated term before any simplification.
    seeded = ELet ext (EConst ext STF64 1.0) $ chad' env term

    -- Matching on 'Dict' brings the instance evidence produced by
    -- @envKnown env@ into scope (presumably needed by the simplifier).
    dterm
      | Dict <- envKnown env =
          case simplIters of
            SimplIters n -> simplifyN n seeded
            SimplFix -> simplifyFix seeded

    (out, grad) = interpretOpen False input dterm

-- | Like 'gradientByCHAD', but with the gradient converted from the 'D2E'
-- representation into the 'TanE' representation, using the primal input to
-- expand cotangents that are stored in a compact "zero" form.
--
-- Still returns the pretty-printed differentiated term and the primal output
-- alongside the gradient.
gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env)))
gradientByCHAD' simplIters env term input =
  (second . second) (toTanE env input) (gradientByCHAD simplIters env term input)
  where
    -- Convert the gradient entry by entry, pairing each cotangent with the
    -- primal value of the same environment slot.
    toTanE :: SList STy env' -> SList Value env' -> SList Value (D2E env') -> SList Value (TanE env')
    toTanE SNil SNil SNil = SNil
    toTanE (ty `SCons` tys) (Value p `SCons` ps) (Value d `SCons` ds) =
      Value (toTan ty p d) `SCons` toTanE tys ps ds

    -- Convert one cotangent to a tangent, guided by the type and the primal.
    toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
    toTan STNil _ der = der
    toTan (STPair t1 t2) primal der =
      case der of
        -- @Left ()@ stands for a zero cotangent: build zeros shaped like the primal.
        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
        Right (dl, dr) -> bimap (\pl -> toTan t1 pl dl) (\pr -> toTan t2 pr dr) primal
    toTan (STEither t1 t2) primal der =
      case der of
        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
        Right d ->
          case (primal, d) of
            (Left p, Left d') -> Left (toTan t1 p d')
            (Right p, Right d') -> Right (toTan t2 p d')
            _ -> error "Primal and cotangent disagree on Either alternative"
    -- NOTE(review): a 'Nothing' cotangent combined with a 'Just' primal gives
    -- 'Nothing' here rather than a zero tangent, unlike the pair, either and
    -- array cases. Confirm that this is intended.
    toTan (STMaybe t) primal der = toTan t <$> primal <*> der
    toTan (STArr _ t) primal der
      -- A zero-size cotangent array is treated as an all-zero cotangent.
      | shapeSize (arrayShape der) == 0 =
          arrayMap (zeroTan t) primal
      | arrayShape primal == arrayShape der =
          arrayGenerateLin (arrayShape primal) $ \i ->
            toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
      | otherwise =
          error "Primal and cotangent disagree on array shape"
    toTan (STScal sty) _ der =
      -- Every scalar case returns the cotangent unchanged; the match is
      -- spelled out per scalar type, as in the original.
      case sty of
        STI32 -> der
        STI64 -> der
        STF32 -> der
        STF64 -> der
        STBool -> der
    toTan STAccum{} _ _ = error "Accumulators not allowed in input program"

-- | Gradient of @term@ at @input@ as computed by 'drevByFwd', seeded with
-- @1.0@. 'adTestGen' compares the CHAD gradient against this result.
gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward env term input = drevByFwd env term input seed
  where
    seed = 1.0

-- | Approximate equality on 'Double's.
--
-- Two numbers are close when their absolute difference is below @1e-5@, or
-- when the smaller magnitude exceeds @1e-4@ and the difference relative to
-- that magnitude is below @1e-5@.
closeIsh :: Double -> Double -> Bool
closeIsh a b
  | delta < 1e-5 = True
  | scale > 1e-4 = delta / scale < 1e-5
  | otherwise = False
  where
    delta = abs (a - b)
    scale = min (abs a) (abs b)

-- | Generate a random array shape of the given rank.
--
-- Each dimension is first drawn from @0..10@. All dimensions are then divided
-- by @size `div` 100 + 1@, where @size@ is the element count of the drawn
-- shape, so that large shapes are scaled down (presumably to keep generated
-- arrays small).
genShape :: SNat n -> Gen (Shape n)
genShape rank = do
  naive <- genShapeNaive rank
  let factor = shapeSize naive `div` 100 + 1
  pure (shapeDiv naive factor)
  where
    -- One independently drawn extent per dimension.
    genShapeNaive :: SNat m -> Gen (Shape m)
    genShapeNaive = \case
      SZ -> pure ShNil
      SS k -> ShCons <$> genShapeNaive k <*> Gen.integral (Range.linear 0 10)

    -- Integer-divide every extent of the shape by @f@.
    shapeDiv :: Shape m -> Int -> Shape m
    shapeDiv shape f = case shape of
      ShNil -> ShNil
      rest `ShCons` dim -> shapeDiv rest f `ShCons` (dim `div` f)

-- | Generate an array of the given shape whose elements are each drawn with
-- 'genValue' for the element type.
genArray :: STy a -> Shape n -> Gen (Value (TArr n a))
genArray elemTy shape =
  fmap Value (arrayGenerateLinM shape (const (fmap unValue (genValue elemTy))))

-- | Generate a random value of the given type.
--
-- Sums and 'Maybe's choose between their alternatives with 'Gen.choice',
-- arrays get a shape from 'genShape', and numeric scalars are drawn from
-- @[-10, 10]@ with origin @0@. Accumulator types are rejected with an 'error'.
genValue :: STy a -> Gen (Value a)
genValue STNil = pure (Value ())
genValue (STPair a b) = liftV2 (,) <$> genValue a <*> genValue b
genValue (STEither a b) =
  Gen.choice
    [ liftV Left <$> genValue a
    , liftV Right <$> genValue b
    ]
genValue (STMaybe t) =
  Gen.choice
    [ pure (Value Nothing)
    , liftV Just <$> genValue t
    ]
genValue (STArr n t) = genShape n >>= genArray t
genValue (STScal sty) =
  case sty of
    STI32 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
    STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
    STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STBool -> Gen.choice [pure (Value False), pure (Value True)]
genValue STAccum{} = error "Cannot generate inputs for accumulators"

-- | Generate one random value (via 'genValue') for every type in the
-- environment.
genEnv :: SList STy env -> Gen (SList Value env)
genEnv = \case
  SNil -> pure SNil
  ty `SCons` rest -> SCons <$> genValue ty <*> genEnv rest

-- | Render an environment of values as a bracketed, comma-separated list,
-- using 'showValue' with the matching type for every entry.
showEnv :: SList STy env -> SList Value env -> String
showEnv env vals = concat ["[", intercalate ", " (render env vals), "]"]
  where
    render :: SList STy env' -> SList Value env' -> [String]
    render SNil SNil = []
    render (ty `SCons` tys) (Value v `SCons` vs) = showValue 0 ty v "" : render tys vs

-- | 'adTestGen' with the default input generator: 'genEnv' applied to the
-- statically known environment of the term.
adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property
adTest expr = adTestGen expr (genEnv (knownEnv @env))

-- | Property checking that CHAD differentiation of @expr@ is consistent.
--
-- For an input drawn from @envGenerator@ the property checks that:
--
-- * the term simplified 20 times pretty-prints identically to the term
--   simplified with 'SimplFix';
-- * the primal output of the unsimplified and of the simplified derivative
--   program both agree with directly interpreting @expr@;
-- * the gradient of the unsimplified derivative program and the gradient from
--   'gradientByForward' both agree with the simplified gradient.
--
-- Numbers are compared with 'closeIsh'. Gradients are flattened to lists of
-- scalars first, and must have the same number of scalars to be considered
-- equal.
adTestGen :: forall env. KnownEnv env
          => Ex env (TScal TF64) -> Gen (SList Value env) -> Property
adTestGen expr envGenerator = property $ do
  let env = knownEnv @env
  input <- forAllWith (showEnv env) envGenerator
  let outPrimal = interpretOpen False input expr
      gradFwd = gradientByForward env expr input
      (ppdterm, (outChad, gradCHAD)) = gradientByCHAD' (SimplIters 0) env expr input
      (ppdterm_S, (outChad_S, gradCHAD_S)) = gradientByCHAD' SimplFix env expr input
      (ppdterm_S20, _) = gradientByCHAD' (SimplIters 20) env expr input
      scFwd = envScalars env gradFwd
      scCHAD = envScalars env gradCHAD
      scCHAD_S = envScalars env gradCHAD_S
  -- Context shown on failure: the type of the term, the term itself, and
  -- both differentiated terms.
  annotate (concat (unSList (\t -> ppTy 0 t ++ " -> ") env) ++ ppTy 0 (typeOf expr))
  annotate (ppExpr env expr)
  annotate ppdterm
  annotate ppdterm_S
  diff ppdterm_S20 (==) ppdterm_S
  diff outChad closeIsh outChad_S
  diff outPrimal closeIsh outChad_S
  diff scCHAD allClose scCHAD_S
  diff scFwd allClose scCHAD_S
  where
    -- Flatten a whole tangent environment into its scalar components.
    envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
    envScalars SNil SNil = []
    envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs

    -- Element-wise 'closeIsh'. The explicit length check is required because
    -- 'zipWith' stops at the shorter list, which would otherwise let a
    -- gradient with missing or extra components pass unnoticed.
    allClose :: [Double] -> [Double] -> Bool
    allClose xs ys = length xs == length ys && and (zipWith closeIsh xs ys)

-- | Test term over a one-dimensional array @x@: builds an array of the same
-- shape as @x@ by indexing into @x@, sums it with 'sum1i' and extracts the
-- resulting scalar with 'idx0'.
term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
term_build1_sum =
  fromNamed $
    lambda #x $
      body $
        idx0 (sum1i (build (SS SZ) (shape #x) (#idx :-> #x ! #idx)))

-- | Test term over two scalars @x@ and @y@ that routes them through pairs:
-- @p = (x, y)@, @q = (snd p * fst p + y, x)@, and the result is
-- @fst q * x + snd q * fst p@.
term_pairs :: Ex [TScal TF64, TScal TF64] (TScal TF64)
term_pairs =
  fromNamed $ lambda #x $ lambda #y $ body $
    let_ #p (pair #x #y)
      (let_ #q (pair (snd_ #p * fst_ #p + #y) #x)
        (fst_ #q * #x + snd_ #q * fst_ #p))

-- | The "AD" property group, run with Hedgehog's 'checkSequential'.
--
-- Every entry pairs a test name with a property built by 'adTest' (inputs
-- from the default 'genEnv' generator) or by 'adTestGen' (custom input
-- generator).
tests :: IO Bool
tests = checkSequential $ Group "AD"
  [("id", adTest $ fromNamed $ lambda #x $ body $ #x)

  ,("idx0", adTest $ fromNamed $ lambda #x $ body $ idx0 #x)

  ,("sum-vec", adTest $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x))

  ,("sum-replicate", adTest $ fromNamed $ lambda #x $ body $
      idx0 $ sum1i $ replicate1i 10 #x)

  ,("pairs", adTest term_pairs)

  -- The type applications on 'lambda' fix the argument type where the body
  -- alone presumably does not determine it.
  ,("build0 const", adTest $ fromNamed $ lambda @(TScal TF64) #x $ body $
      idx0 $ build SZ nil $ #idx :-> const_ 0.0)

  ,("build0", adTest $ fromNamed $ lambda @(TArr N0 _) #x $ body $
      idx0 $
        build SZ (shape #x) $ #idx :-> #x ! #idx)

  ,("build1-sum", adTest term_build1_sum)

  ,("build2-sum", adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $
      idx0 $ sum1i . sum1i $
        build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx)

  -- ,("build-sum", adTest $ fromNamed $ lambda #x $ body $
  --     idx0 $ sum1i . sum1i $
  --       build (SS (SS SZ)) (pair (pair nil 2) 3) $ #idx :->
  --         oper OToFl64 $ snd_ (fst_ #idx) + snd_ #idx)

  -- Uses a dedicated generator instead of 'genEnv': 'genEnv' draws every
  -- array shape independently, whereas here the array extents are tied
  -- together through the shared sizes nin, n1 and n2 (presumably required by
  -- 'Example.neural' -- confirm there).
  ,("neural", adTestGen Example.neural $ do
      let tR = STScal STF64
      -- A layer is a pair of a two-dimensional array with extents nout and
      -- nin and a one-dimensional array with extent nout.
      let genLayer nin nout =
            liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
                       <*> genArray tR (ShNil `ShCons` nout)
      nin <- Gen.integral (Range.linear 1 10)
      n1 <- Gen.integral (Range.linear 1 10)
      n2 <- Gen.integral (Range.linear 1 10)
      input <- genArray tR (ShNil `ShCons` nin)
      lay1 <- genLayer nin n1
      lay2 <- genLayer n1 n2
      lay3 <- genArray tR (ShNil `ShCons` n2)
      -- NOTE(review): the order of entries must match the environment of
      -- 'Example.neural'; verify against its type.
      return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil))
  ]

-- | Entry point of the test executable: hands the property groups to
-- Hedgehog's 'defaultMain'.
main :: IO ()
main = defaultMain groups
  where
    groups = [tests]