1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Main where
import Data.Bifunctor
import qualified Data.Dependent.Map as DMap
import Data.Dependent.Map (DMap)
import Data.List (intercalate)
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Hedgehog.Main
import Array
import AST
import CHAD
import CHAD.Types
import Data
import qualified Example
import ForwardAD
import Interpreter
import Interpreter.Rep
import Language
-- | Replace every entry of a type-level list with the symbol @"merge"@,
-- preserving the length of the list.
type family MapMerge env where
  MapMerge '[] = '[]
  MapMerge (_ ': rest) = "merge" ': MapMerge rest
-- | Proof that selecting the @"accum"@ entries of an environment annotated
-- with 'MapMerge' yields the empty list. Proceeds by induction on the
-- singleton list.
mapMergeNoAccum :: SList f env -> Select env (MapMerge env) "accum" :~: '[]
mapMergeNoAccum = \case
  SNil -> Refl
  SCons _ rest -> case mapMergeNoAccum rest of
    Refl -> Refl
-- | Proof that selecting the @"merge"@ entries of an environment annotated
-- with 'MapMerge' gives back the whole environment. Proceeds by induction on
-- the singleton list.
mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env
mapMergeOnlyMerge = \case
  SNil -> Refl
  SCons _ rest -> case mapMergeOnlyMerge rest of
    Refl -> Refl
-- | Gradient of a scalar ('TF64') program with respect to every variable in
-- its environment, computed with CHAD ('drev').
--
-- Every environment entry is given the 'SMerge' strategy, so by
-- 'mapMergeNoAccum' and 'mapMergeOnlyMerge' no accumulator variables are
-- selected. The reverse program is closed off with 'freezeRet' using the
-- constant @1.0@ and run with 'interpretOpen' on the primal ('D1') inputs.
-- The second component of the result is then split into one value per
-- variable with 'unTup'.
gradientByCHAD :: forall env. SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (D2E env)
gradientByCHAD env term input
  | Refl <- mapMergeNoAccum env
  , Refl <- mapMergeOnlyMerge env
  = let descr = mergeDescr env
        reverseTerm = freezeRet descr (drev descr term) (EConst ext STF64 1.0)
        -- Only the gradient half of the interpreter result is used.
        (_primalOut, gradTuple) = interpretOpen (primalEnv env input) reverseTerm
     in unTup (\(Value (x, y)) -> (Value x, Value y)) (d2e env) (Value gradTuple)
  where
    -- Descriptor tagging every environment variable with 'SMerge'.
    mergeDescr :: SList STy env' -> Descr env' (MapMerge env')
    mergeDescr SNil = DTop
    mergeDescr (SCons t rest) = DPush (mergeDescr rest) (t, SMerge)

    -- Convert each input value to its primal ('D1') representation.
    primalEnv :: SList STy env' -> SList Value env' -> SList Value (D1E env')
    primalEnv SNil SNil = SNil
    primalEnv (SCons t ts) (SCons (Value x) xs) = SCons (Value (primalOf t x)) (primalEnv ts xs)

    primalOf :: STy t -> Rep t -> Rep (D1 t)
    primalOf STNil = id
    primalOf (STPair a b) = bimap (primalOf a) (primalOf b)
    primalOf (STEither a b) = bimap (primalOf a) (primalOf b)
    primalOf (STMaybe t) = fmap (primalOf t)
    primalOf (STArr _ t) = fmap (primalOf t)
    primalOf (STScal _) = id
    primalOf STAccum{} = error "Accumulators not allowed in input program"
-- | The CHAD gradient ('gradientByCHAD') converted to tangent representation
-- ('TanE'), so it can be compared with forward-mode results.
--
-- The original inputs are passed alongside the cotangents because a "zero"
-- cotangent carries no structure. Examples are @Left ()@ for pairs and
-- eithers, @Nothing@ for 'Maybe', and an empty array. In those cases the
-- tangent is rebuilt from the primal value with 'zeroTan'.
gradientByCHAD' :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term input)
  where
    toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
    toTanE SNil SNil SNil = SNil
    toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
      Value (toTan t p x) `SCons` toTanE env primal inp

    -- Combine a primal value and its cotangent into a tangent whose structure
    -- follows the primal.
    toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
    toTan typ primal der = case typ of
      STNil -> der
      STPair t1 t2 -> case der of
        -- 'Left ()' is the zero cotangent for the whole pair.
        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
        Right (d1, d2) -> bimap (\p1 -> toTan t1 p1 d1) (\p2 -> toTan t2 p2 d2) primal
      STEither t1 t2 -> case der of
        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
        Right d -> case (primal, d) of
          (Left p, Left d') -> Left (toTan t1 p d')
          (Right p, Right d') -> Right (toTan t2 p d')
          _ -> error "Primal and cotangent disagree on Either alternative"
      STMaybe t -> case (primal, der) of
        (Just p, Just d) -> Just (toTan t p d)
        -- A 'Nothing' cotangent for a 'Just' primal is a zero cotangent. The
        -- tangent must still be 'Just', otherwise its scalars disappear and
        -- the comparison with forward mode becomes misaligned.
        (Just p, Nothing) -> Just (zeroTan t p)
        (Nothing, _) -> Nothing
      STArr _ t
        -- An empty cotangent array is the zero cotangent.
        | shapeSize (arrayShape der) == 0 ->
            arrayMap (zeroTan t) primal
        | arrayShape primal == arrayShape der ->
            arrayGenerateLin (arrayShape primal) $ \i ->
              toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
        | otherwise ->
            error "Primal and cotangent disagree on array shape"
      STScal sty -> case sty of
        STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
      STAccum{} -> error "Accumulators not allowed in input program"
-- | Gradient of a scalar ('TF64') program computed with forward-mode AD
-- ('drevByFwd'). The @1.0@ passed as the last argument is presumably the
-- output seed. The result is in tangent representation ('TanE'), the same
-- representation as @gradientByCHAD'@.
gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward = \env term input -> drevByFwd env term input 1.0
-- | Approximate equality for gradients.
--
-- Two numbers are close when their absolute difference is below @1e-5@. They
-- are also close when both magnitudes exceed @1e-4@ and the difference,
-- relative to the smaller magnitude, is below @1e-5@. NaN is never close to
-- anything.
closeIsh :: Double -> Double -> Bool
closeIsh a b
  | absDiff < 1e-5 = True
  | otherwise = smaller > 1e-4 && absDiff / smaller < 1e-5
  where
    absDiff = abs (a - b)
    smaller = min (abs a) (abs b)
-- | Random shape of rank @n@.
--
-- Each dimension is first drawn from @[0, 10]@. Every dimension is then
-- divided by @size `div` 100 + 1@, so shapes with a total size of 100 or more
-- are scaled down and smaller ones are left unchanged.
genShape :: SNat n -> Gen (Shape n)
genShape n = scaleDown <$> rawShape n
  where
    scaleDown :: Shape m -> Shape m
    scaleDown sh = divideDims (shapeSize sh `div` 100 + 1) sh

    rawShape :: SNat m -> Gen (Shape m)
    rawShape SZ = pure ShNil
    rawShape (SS m) = ShCons <$> rawShape m <*> Gen.integral (Range.linear 0 10)

    divideDims :: Int -> Shape m -> Shape m
    divideDims _ ShNil = ShNil
    divideDims f (ShCons sh k) = ShCons (divideDims f sh) (k `div` f)
-- | Random generator for a value of the given type.
--
-- Numeric scalars come from ranges with origin 0 spanning @[-10, 10]@.
-- Arrays take their shape from 'genShape' and generate each element
-- independently. Sum types and 'Maybe' choose an alternative with
-- 'Gen.choice'. Accumulator types are rejected with an error.
genValue :: STy a -> Gen (Value a)
genValue ty = case ty of
  STNil -> pure (Value ())
  STPair l r -> liftV2 (,) <$> genValue l <*> genValue r
  STEither l r -> Gen.choice [liftV1 Left <$> genValue l, liftV1 Right <$> genValue r]
  STMaybe e -> Gen.choice [pure (Value Nothing), liftV1 Just <$> genValue e]
  STArr n e -> genShape n >>= \sh -> Value <$> arrayGenerateLinM sh (const (unV <$> genValue e))
  STScal sty -> case sty of
    STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STI32 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
    STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
    STBool -> Gen.choice [pure (Value False), pure (Value True)]
  STAccum{} -> error "Cannot generate inputs for accumulators"
  where
    -- Lift plain functions on representations to functions on 'Value'.
    liftV1 :: (Rep x -> Rep y) -> Value x -> Value y
    liftV1 f (Value v) = Value (f v)

    liftV2 :: (Rep x -> Rep y -> Rep z) -> Value x -> Value y -> Value z
    liftV2 f (Value v) (Value w) = Value (f v w)

    unV :: Value x -> Rep x
    unV (Value v) = v
-- | Generate one random value per environment entry, using 'genValue' on
-- each entry's type.
genEnv :: SList STy env -> Gen (SList Value env)
genEnv = \case
  SNil -> pure SNil
  SCons t rest -> SCons <$> genValue t <*> genEnv rest
-- | A template variable: a rank witness ('SNat') together with a 'String'
-- name. It is used as the key of the shape maps in the template generators
-- below.
data TemplateVar n where
  TemplateVar :: SNat n -> String -> TemplateVar n
  deriving (Show)
-- | Description of the shape of one generated input value.
data Template t where
  -- | An array whose shape comes from the given template variable
  -- (presumably looked up in a shape map), with elements of the given type.
  TpShape :: TemplateVar n -> STy elt -> Template (TArr n elt)
  -- | A value of the given type with no constraints attached.
  TpAny :: STy t -> Template t
  -- | A pair built from two sub-templates.
  TpPair :: Template l -> Template r -> Template (TPair l r)

deriving instance Show (Template t)
-- | Constraint on the shape chosen for a template variable of rank @n@.
-- NOTE(review): presumably "every dimension at least this large"; there is
-- no consumer of it in view yet.
data ShapeConstraint n where
  ShapeAtLeast :: Shape n -> ShapeConstraint n
  deriving (Show)
-- | Generate a value matching a 'Template', presumably taking array shapes
-- for 'TpShape' from the given map of template variables.
--
-- NOTE(review): not implemented. The body is a typed hole, so the module does
-- not compile until it is filled in. Looking up a 'TemplateVar' key in the
-- 'DMap' will need a 'GCompare' instance for 'TemplateVar', which is not
-- visible here.
genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t)
genTemplate = _
-- | Generate one value per environment entry from its 'Template', using the
-- given concrete shapes for template variables.
--
-- This is structural recursion over the template list, mirroring 'genEnv'.
-- Each entry is delegated to 'genTemplate'. The applicative '<*>' keeps
-- Hedgehog's parallel shrinking of the entries, as 'genEnv' does.
genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env)
genEnvTemplateExact shapes = go
  where
    go :: SList Template env' -> Gen (SList Value env')
    go SNil = pure SNil
    go (SCons tpl rest) = SCons <$> genTemplate shapes tpl <*> go rest
-- | Generate an environment from templates. First a concrete shape is chosen
-- for every template variable from its 'ShapeConstraint'; then
-- 'genEnvTemplateExact' is run with those shapes.
--
-- NOTE(review): unfinished. The per-variable generator passed to
-- 'DMap.traverseWithKey' is still a typed hole. It needs the type
-- @forall n. TemplateVar n -> ShapeConstraint n -> Gen (Shape n)@.
genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env)
genEnvTemplate constrs env = do
  shapes <- DMap.traverseWithKey _ constrs
  genEnvTemplateExact shapes env
-- | Render a value of the given type as Haskell-like text, in the style of
-- 'showsPrec'. The 'Int' is the precedence of the surrounding context, so
-- constructor applications and negative numbers are parenthesised when they
-- appear as arguments.
showValue :: Int -> STy t -> Rep t -> ShowS
showValue _ STNil () = showString "()"
showValue _ (STPair a b) (x, y) =
  showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")"
showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " . showValue 11 a x
showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y
showValue _ (STMaybe _) Nothing = showString "Nothing"
showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x
-- Elements are rendered to strings first and then shown through the array's
-- own Show instance (presumably producing quoted strings).
showValue d (STArr _ t) arr = showsPrec d (fmap (\x -> showValue 0 t x "") arr) -- TODO: improve
-- Pass the precedence through so negative scalars in argument position are
-- parenthesised, e.g. "Just (-1.5)" rather than "Just -1.5".
showValue d (STScal sty) x = case sty of
  STF32 -> showsPrec d x
  STF64 -> showsPrec d x
  STI32 -> showsPrec d x
  STI64 -> showsPrec d x
  STBool -> showsPrec d x
showValue _ STAccum{} _ = error "Cannot show accumulators"
-- | Render a whole environment as a bracketed, comma-separated list. Each
-- entry is rendered with 'showValue' at precedence 0.
showEnv :: SList STy env -> SList Value env -> String
showEnv tys vals = concat ["[", intercalate ", " (entries tys vals), "]"]
  where
    entries :: SList STy e -> SList Value e -> [String]
    entries SNil SNil = []
    entries (SCons t ts) (SCons (Value x) xs) = showValue 0 t x "" : entries ts xs
-- | Property: for random inputs, the gradient computed by CHAD
-- (@gradientByCHAD'@) agrees with the one computed by forward-mode AD
-- ('gradientByForward'). Both gradients are flattened to their scalars, which
-- must match in number and, pairwise, up to 'closeIsh'.
adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property
adTest expr = property $ do
  let env = knownEnv @env
  input <- forAllWith (showEnv env) $ genEnv env
  let gradFwd = gradientByForward env expr input
      gradCHAD = gradientByCHAD' env expr input
      scFwd = envScalars env gradFwd
      scCHAD = envScalars env gradCHAD
  diff scCHAD allClose scFwd
  where
    -- 'zipWith' alone silently truncates to the shorter list. That would hide
    -- a structural mismatch between the two gradients, so lengths are checked
    -- explicitly.
    allClose :: [Double] -> [Double] -> Bool
    allClose xs ys = length xs == length ys && and (zipWith closeIsh xs ys)

    -- Flatten a tangent environment into its scalars, in environment order.
    envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
    envScalars SNil SNil = []
    envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs
-- | The "AD" Hedgehog group. Each property compares CHAD against forward-mode
-- AD on one program.
tests :: IO Bool
tests = checkParallel (Group "AD" properties)
  where
    properties =
      [ ("id", adTest (fromNamed (lambda #x (body #x))))
      , ("neural", adTest Example.neural)
      ]
-- | Entry point: run the AD test group through Hedgehog's 'defaultMain'.
main :: IO ()
main = defaultMain (pure tests)
|