summaryrefslogtreecommitdiff
path: root/test/Main.hs
blob: 34ab5af36492259092c41a93afb9a670250b367f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
module Main where

import Data.Bifunctor
import Data.List (intercalate)
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Hedgehog.Main

import Array
import AST
import CHAD
import CHAD.Types
import Data
import ForwardAD
import Interpreter
import Interpreter.Rep
import Language


-- | Type-level list with the same length as @env@ in which every entry
-- has been replaced by the symbol @"merge"@.
type family MapMerge env where
  MapMerge '[] = '[]
  MapMerge (x ': xs) = "merge" ': MapMerge xs

-- | Witness of the type equality
-- @Select env (MapMerge env) "accum" ~ '[]@, built by recursion over the
-- spine of the given environment list.
mapMergeNoAccum :: SList f env -> Select env (MapMerge env) "accum" :~: '[]
mapMergeNoAccum SNil = Refl
mapMergeNoAccum (SCons _ rest) =
  -- The proof for the tail is exactly what is needed for the whole list.
  case mapMergeNoAccum rest of
    Refl -> Refl

-- | Witness of the type equality
-- @Select env (MapMerge env) "merge" ~ env@, built by recursion over the
-- spine of the given environment list.
mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env
mapMergeOnlyMerge SNil = Refl
mapMergeOnlyMerge (SCons _ rest) =
  -- The proof for the tail is exactly what is needed for the whole list.
  case mapMergeOnlyMerge rest of
    Refl -> Refl

-- | Gradient of a scalar-valued term with respect to its environment,
-- obtained through 'drev'.
--
-- Every environment entry is described as 'SMerge' (see @makeMergeDescr@).
-- The term is passed through 'drev' and 'freezeRet', with the constant
-- @1.0@ as the last argument of 'freezeRet' (presumably the initial
-- cotangent -- confirm against 'freezeRet'). The resulting term is run by
-- 'interpretOpen' on the input converted with @toPrimalE@. Only the second
-- component of the interpreter's result is used; 'unTup' splits it into one
-- 'Value' per environment entry.
gradientByCHAD :: forall env. SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (D2E env)
gradientByCHAD = \env term input ->
  -- Matching on both proofs brings the two 'MapMerge' type equalities into
  -- scope for the code below.
  case (mapMergeNoAccum env, mapMergeOnlyMerge env) of
    (Refl, Refl) ->
      let descr = makeMergeDescr env
          dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0)
          input1 = toPrimalE env input
          -- The first component of the result is not needed here.
          (_out, grad) = interpretOpen input1 dterm
      in unTup (\(Value (x, y)) -> (Value x, Value y)) (d2e env) (Value grad)
  where
    -- Descriptor that pushes every environment entry with the 'SMerge' tag.
    makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env')
    makeMergeDescr SNil = DTop
    makeMergeDescr (t `SCons` env) = makeMergeDescr env `DPush` (t, SMerge)

    -- Apply 'toPrimal' to every value in the environment.
    toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env')
    toPrimalE SNil SNil = SNil
    toPrimalE (t `SCons` env) (Value x `SCons` inp) = Value (toPrimal t x) `SCons` toPrimalE env inp

    -- Convert a value of type @t@ to its representation at type @D1 t@ by
    -- structural recursion. Nil and scalar values are returned unchanged;
    -- accumulator types are rejected with a runtime error.
    toPrimal :: STy t -> Rep t -> Rep (D1 t)
    toPrimal = \case
      STNil -> id
      STPair t1 t2 -> bimap (toPrimal t1) (toPrimal t2)
      STEither t1 t2 -> bimap (toPrimal t1) (toPrimal t2)
      STMaybe t -> fmap (toPrimal t)
      STArr _ t -> fmap (toPrimal t)
      STScal _ -> id
      STAccum{} -> error "Accumulators not allowed in input program"

-- | Like 'gradientByCHAD', but with the result converted entry by entry
-- from the @D2@ representation to the @Tan@ representation (see @toTan@),
-- using the original input as the primal values.
gradientByCHAD' :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term input)
  where
    -- Apply 'toTan' pointwise over the environment; the three lists are
    -- traversed in lockstep.
    toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
    toTanE SNil SNil SNil = SNil
    toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
      Value (toTan t p x) `SCons` toTanE env primal inp

    -- Convert a @D2@ value to a @Tan@ value, guided by the type and the
    -- corresponding primal value.
    toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
    toTan typ primal der = case typ of
      STNil -> der
      STPair t1 t2 -> case der of
                        -- A @Left ()@ derivative becomes 'zeroTan' of both
                        -- primal components.
                        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
                        Right (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
      STEither t1 t2 -> case der of
                          -- A @Left ()@ derivative becomes 'zeroTan' of
                          -- whichever alternative the primal holds.
                          Left () -> bimap (zeroTan t1) (zeroTan t2) primal
                          Right d -> case (primal, d) of
                            (Left p, Left d') -> Left (toTan t1 p d')
                            (Right p, Right d') -> Right (toTan t2 p d')
                            _ -> error "Primal and cotangent disagree on Either alternative"
      -- NOTE(review): this yields Nothing whenever either the primal or the
      -- derivative is Nothing; confirm that a Just primal paired with a
      -- Nothing derivative is meant to give Nothing rather than a zero
      -- tangent.
      STMaybe t -> liftA2 (toTan t) primal der
      STArr _ t
        -- A derivative array whose shape has size 0 becomes 'zeroTan'
        -- mapped over the primal array.
        | shapeSize (arrayShape der) == 0 ->
            arrayMap (zeroTan t) primal
        -- Otherwise the shapes must agree, and elements are converted
        -- pairwise by linear index.
        | arrayShape primal == arrayShape der ->
            arrayGenerateLin (arrayShape primal) $ \i ->
              toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
        | otherwise ->
            error "Primal and cotangent disagree on array shape"
      -- For every scalar type the derivative is returned as is; the match
      -- on @sty@ is what lets the result type reduce in each branch.
      STScal sty -> case sty of
        STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
      STAccum{} -> error "Accumulators not allowed in input program"

-- | Gradient of a scalar-valued term computed by 'drevByFwd', which is
-- called with @1.0@ as its final argument.
gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward = \tys expr vals -> drevByFwd tys expr vals 1.0

-- | Approximate equality of two doubles.
--
-- Two values are considered close when any of the following holds:
--
-- * they are exactly equal (this makes two identical infinities close;
--   their difference is NaN, so the error-based checks below cannot
--   accept them);
-- * their absolute difference is below @1e-5@;
-- * the smaller of their magnitudes exceeds @1e-4@ and the difference
--   relative to that magnitude is below @1e-5@.
--
-- NaN is never close to anything, including another NaN.
closeIsh :: Double -> Double -> Bool
closeIsh a b =
  a == b || absErr < 1e-5 || (scale > 1e-4 && absErr / scale < 1e-5)
  where
    absErr = abs (a - b)
    -- Using the smaller magnitude makes the relative check the stricter
    -- of the two possible choices.
    scale = min (abs a) (abs b)

-- | Generator for shapes of the given rank.
--
-- Each dimension is first drawn from the range 0 to 10; afterwards every
-- dimension is integer-divided by @shapeSize sh `div` 100 + 1@, where @sh@
-- is the shape that was drawn.
genShape :: SNat n -> Gen (Shape n)
genShape rank = shrinkDims <$> rawShape rank
  where
    -- Divide all dimensions by the common factor derived from the size.
    shrinkDims :: Shape n -> Shape n
    shrinkDims sh = scaleDown (shapeSize sh `div` 100 + 1) sh

    -- Draw every dimension independently, without any size limit.
    rawShape :: SNat n -> Gen (Shape n)
    rawShape SZ = pure ShNil
    rawShape (SS m) = ShCons <$> rawShape m <*> Gen.integral (Range.linear 0 10)

    -- Integer-divide each dimension of a shape by @k@.
    scaleDown :: Int -> Shape n -> Shape n
    scaleDown _ ShNil = ShNil
    scaleDown k (ShCons outer dim) = ShCons (scaleDown k outer) (dim `div` k)

-- | Generator for values of the given type.
--
-- Sums and optional values pick one alternative via 'Gen.choice', arrays
-- get a shape from 'genShape' and independently generated elements, and
-- numeric scalars are drawn from the range -10 to 10 with origin 0.
-- Accumulator types cannot be generated and raise a runtime error.
genValue :: STy a -> Gen (Value a)
genValue ty = case ty of
  STNil -> pure (Value ())
  STPair l r -> mapV2 (,) <$> genValue l <*> genValue r
  STEither l r ->
    Gen.choice
      [ mapV Left <$> genValue l
      , mapV Right <$> genValue r
      ]
  STMaybe inner ->
    Gen.choice
      [ pure (Value Nothing)
      , mapV Just <$> genValue inner
      ]
  STArr rank elt ->
    genShape rank >>= \sh ->
      Value <$> arrayGenerateLinM sh (\_ -> unwrap <$> genValue elt)
  STScal STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
  STScal STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
  STScal STI32 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
  STScal STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
  STScal STBool -> Gen.choice [pure (Value False), pure (Value True)]
  STAccum{} -> error "Cannot generate inputs for accumulators"
  where
    -- Project the representation out of a 'Value'.
    unwrap :: Value a -> Rep a
    unwrap (Value x) = x

    -- Lift a unary function on representations to 'Value's.
    mapV :: (Rep a -> Rep b) -> Value a -> Value b
    mapV f (Value x) = Value (f x)

    -- Lift a binary function on representations to 'Value's.
    mapV2 :: (Rep a -> Rep b -> Rep c) -> Value a -> Value b -> Value c
    mapV2 f (Value x) (Value y) = Value (f x y)

-- | Generator for a whole environment: one value per entry of the type
-- list, each produced by 'genValue', head first.
genEnv :: SList STy env -> Gen (SList Value env)
genEnv = \case
  SNil -> pure SNil
  SCons ty rest -> SCons <$> genValue ty <*> genEnv rest

-- | Render a value of the given type, in the manner of 'showsPrec'.
--
-- The first argument is the precedence of the enclosing context:
-- constructor applications (@Left@, @Right@, @Just@) are parenthesised
-- when it exceeds 10, and scalars are rendered with 'showsPrec' at that
-- precedence, so that a negative number under a constructor is shown as
-- @Just (-1.0)@ rather than @Just -1.0@. Accumulator values cannot be
-- shown and raise a runtime error.
showValue :: Int -> STy t -> Rep t -> ShowS
showValue _ STNil () = showString "()"
-- Tuple components are delimited by the tuple syntax, so they restart at
-- precedence 0.
showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")"
showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " . showValue 11 a x
showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y
showValue _ (STMaybe _) Nothing = showString "Nothing"
showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x
showValue d (STArr _ t) arr = showsPrec d (fmap (\x -> showValue 0 t x "") arr)  -- TODO: improve
-- Matching on the scalar type is what provides the 'Show' instance in each
-- branch; the precedence is passed on so negative numbers get parentheses
-- where they need them.
showValue d (STScal sty) x = case sty of
  STF32 -> showsPrec d x
  STF64 -> showsPrec d x
  STI32 -> showsPrec d x
  STI64 -> showsPrec d x
  STBool -> showsPrec d x
showValue _ STAccum{} _ = error "Cannot show accumulators"

-- | Render an environment as a bracketed, comma-separated list, with each
-- entry rendered by 'showValue' at precedence 0.
showEnv :: SList STy env -> SList Value env -> String
showEnv tys vals = concat ["[", intercalate ", " (render tys vals), "]"]
  where
    -- One rendered string per environment entry, in order.
    render :: SList STy env -> SList Value env -> [String]
    render SNil SNil = []
    render (SCons ty tyRest) (SCons (Value v) valRest) =
      showValue 0 ty v "" : render tyRest valRest

-- | Property stating that, on randomly generated inputs, the gradient
-- computed by 'gradientByCHAD'' agrees with the one computed by
-- 'gradientByForward'.
--
-- Both gradients are flattened to lists of scalars with 'tanScalars' and
-- compared elementwise with 'closeIsh'. The two lists must also have the
-- same length: 'zipWith' stops at the shorter list, so without that check
-- a gradient with missing (or extra) scalars would be accepted.
adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property
adTest expr = property $ do
  let env = knownEnv @env
  input <- forAllWith (showEnv env) $ genEnv env
  let gradFwd = gradientByForward env expr input
      gradCHAD = gradientByCHAD' env expr input
      scFwd = envScalars env gradFwd
      scCHAD = envScalars env gradCHAD
  diff scCHAD sameScalars scFwd
  where
    -- Same number of scalars, and all corresponding scalars are close.
    sameScalars :: [Double] -> [Double] -> Bool
    sameScalars xs ys = length xs == length ys && and (zipWith closeIsh xs ys)

    -- Concatenate the scalars of every tangent value in the environment.
    envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
    envScalars SNil SNil = []
    envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs

-- | The @"AD"@ property group, run with 'checkParallel'. It currently
-- holds a single property, @"id"@, which applies 'adTest' to the program
-- that returns its argument @x@.
tests :: IO Bool
tests =
  checkParallel
    (Group "AD"
      [ ("id", adTest (fromNamed (lambda #x (body #x))))
      ])

-- | Entry point: hand all test groups to Hedgehog's 'defaultMain'.
main :: IO ()
main = defaultMain suites
  where
    suites :: [IO Bool]
    suites = [tests]