1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
|
{-# LANGUAGE DataKinds #-}
-- {-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- {-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Main where
import Data.Bifunctor
-- import qualified Data.Dependent.Map as DMap
-- import Data.Dependent.Map (DMap)
import Data.List (intercalate)
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Hedgehog.Main
import Array
import AST
import AST.Pretty
import CHAD
import CHAD.Types
import Data
import qualified Example
import ForwardAD
import Interpreter
import Interpreter.Rep
import Language
import Simplify
-- | Type-level map sending every entry of an environment to the storage
-- tag "merge". The result has the same length as the input; it is used as
-- the storage list of the 'Descr' built in 'diffCHAD'.
type family MapMerge env where
  MapMerge '[] = '[]
  MapMerge (x : xs) = "merge" : MapMerge xs
-- | Proof that selecting the "accum" entries of an environment whose
-- storage list is all "merge" (see 'MapMerge') gives the empty list.
mapMergeNoAccum :: SList f env -> Select env (MapMerge env) "accum" :~: '[]
mapMergeNoAccum SNil = Refl
mapMergeNoAccum (SCons _ rest) =
  case mapMergeNoAccum rest of
    Refl -> Refl
-- | Proof that selecting the "merge" entries of an environment whose
-- storage list is all "merge" (see 'MapMerge') gives back the whole
-- environment.
mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env
mapMergeOnlyMerge SNil = Refl
mapMergeOnlyMerge (SCons _ rest) =
  case mapMergeOnlyMerge rest of
    Refl -> Refl
-- | Map every type in an environment to its primal ('d1') type.
primalEnv :: SList STy env' -> SList STy (D1E env')
primalEnv = \case
  SNil -> SNil
  SCons t rest -> SCons (d1 t) (primalEnv rest)
-- | How to simplify a differentiated term in 'diffCHAD':
-- @SimplIters n@ runs 'simplifyN' with @n@, while 'SimplFix' runs
-- 'simplifyFix' (presumably iterating until nothing changes).
data SimplIters
  = SimplIters Int
  | SimplFix
  deriving (Show)
-- | Reverse-differentiate a scalar-valued term with CHAD.
--
-- Every free variable gets the "merge" storage (see 'MapMerge'), so the
-- "accum" part of the environment is empty (see 'mapMergeNoAccum'). The
-- derivative program is frozen with an incoming output cotangent of @1.0@
-- and then simplified as requested by the 'SimplIters' argument. The
-- resulting term pairs the primal output with the tuple of input
-- cotangents.
diffCHAD :: SimplIters -> SList STy env -> Ex env (TScal TF64)
         -> Ex (D1E env) (TPair (TScal TF64) (Tup (D2E env)))
diffCHAD = \simplIters env term ->
  case (mapMergeNoAccum env, mapMergeOnlyMerge env, envKnown (primalEnv env)) of
    (Refl, Refl, Dict) ->
      let storage = mergeStorage env
          frozen = freezeRet storage (drev storage term) (EConst ext STF64 1.0)
      in case simplIters of
           SimplIters n -> simplifyN n frozen
           SimplFix -> simplifyFix frozen
  where
    -- Descriptor that marks every environment entry with SMerge storage.
    mergeStorage :: SList STy env' -> Descr env' (MapMerge env')
    mergeStorage SNil = DTop
    mergeStorage (SCons t rest) = DPush (mergeStorage rest) (t, SMerge)
-- | Differentiate a scalar-valued term with CHAD and evaluate the derivative
-- program on the given input.
--
-- In addition to the gradient, also returns the pretty-printed differentiated term.
-- The result is @(prettyTerm, (primalOutput, inputCotangents))@.
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD = \simplIters env term input ->
  case (mapMergeNoAccum env, mapMergeOnlyMerge env) of
    (Refl, Refl) ->
      let derivTerm = diffCHAD simplIters env term
          (primalOut, cotanTuple) = interpretOpen False (toPrimalE env input) derivTerm
      in ( ppExpr (primalEnv env) derivTerm
         , (primalOut, unTup vUnpair (d2e env) (Value cotanTuple)) )
  where
    -- Convert every input value to its D1 (primal) representation.
    toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env')
    toPrimalE SNil SNil = SNil
    toPrimalE (SCons t ts) (SCons (Value x) xs) =
      SCons (Value (toPrimal t x)) (toPrimalE ts xs)

    -- Structural conversion of a single value to its primal form.
    toPrimal :: STy t -> Rep t -> Rep (D1 t)
    toPrimal ty = case ty of
      STNil -> id
      STPair t1 t2 -> bimap (toPrimal t1) (toPrimal t2)
      STEither t1 t2 -> bimap (toPrimal t1) (toPrimal t2)
      STMaybe t -> fmap (toPrimal t)
      STArr _ t -> fmap (toPrimal t)
      STScal _ -> id
      STAccum{} -> error "Accumulators not allowed in input program"
-- | Like 'gradientByCHAD', but converts the CHAD cotangent of each input
-- into the tangent representation ('TanE'). The result can then be compared
-- directly with 'gradientByForward'.
--
-- In addition to the gradient, also returns the pretty-printed differentiated term.
gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env)))
gradientByCHAD' = \simplIters env term input -> second (second (toTanE env input)) $ gradientByCHAD simplIters env term input
  where
    -- Pointwise conversion of the input cotangents, guided by the primal inputs.
    toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
    toTanE SNil SNil SNil = SNil
    toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
      Value (toTan t p x) `SCons` toTanE env primal inp

    -- Convert a cotangent to a tangent that has the same structure as the
    -- primal value. Sparse zero cotangents (Left (), Nothing, empty arrays)
    -- are expanded into a full zero tangent with 'zeroTan'. That way the
    -- scalar list lines up with the forward-mode tangent.
    toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
    toTan typ primal der = case typ of
      STNil -> der
      STPair t1 t2 -> case der of
        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
        Right (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
      STEither t1 t2 -> case der of
        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
        Right d -> case (primal, d) of
          (Left p, Left d') -> Left (toTan t1 p d')
          (Right p, Right d') -> Right (toTan t2 p d')
          _ -> error "Primal and cotangent disagree on Either alternative"
      STMaybe t -> case (primal, der) of
        (Just p, Just d) -> Just (toTan t p d)
        -- A Nothing cotangent for a Just primal is a (sparse) zero. Keep
        -- the Just structure, as the other cases do for their zeros.
        (Just p, Nothing) -> Just (zeroTan t p)
        (Nothing, _) -> Nothing
      STArr _ t
        | shapeSize (arrayShape der) == 0 ->
            arrayMap (zeroTan t) primal
        | arrayShape primal == arrayShape der ->
            arrayGenerateLin (arrayShape primal) $ \i ->
              toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
        | otherwise ->
            error "Primal and cotangent disagree on array shape"
      STScal sty -> case sty of
        STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
      STAccum{} -> error "Accumulators not allowed in input program"
-- | Reference gradient: reverse derivative obtained from forward-mode AD
-- ('drevByFwd'), seeded with an output cotangent of 1.0.
gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward env term = \input -> drevByFwd env term input 1.0
-- | Approximate equality of two doubles. It holds when the absolute difference
-- is below 1e-5, or when both magnitudes exceed 1e-4 and the difference,
-- relative to the smaller magnitude, is below 1e-5.
closeIsh :: Double -> Double -> Bool
closeIsh a b
  | absDiff < 1e-5 = True
  | otherwise = smaller > 1e-4 && absDiff / smaller < 1e-5
  where
    absDiff = abs (a - b)
    smaller = min (abs a) (abs b)
-- | Generate a random shape of rank @n@. Each dimension starts in [0, 10].
-- Every dimension is then divided by @size `div` 100 + 1@, which keeps
-- large arrays from dominating test time.
genShape :: SNat n -> Gen (Shape n)
genShape n = naiveShape n >>= \sh ->
  let divisor = shapeSize sh `div` 100 + 1
  in return (divideDims sh divisor)
  where
    -- Independent dimensions, each drawn from [0, 10].
    naiveShape :: SNat m -> Gen (Shape m)
    naiveShape SZ = return ShNil
    naiveShape (SS m) = ShCons <$> naiveShape m <*> Gen.integral (Range.linear 0 10)

    -- Integer-divide every dimension by the given factor.
    divideDims :: Shape m -> Int -> Shape m
    divideDims ShNil _ = ShNil
    divideDims (ShCons rest d) f = ShCons (divideDims rest f) (d `div` f)
-- | Generate an array of the given shape whose elements are drawn
-- independently with 'genValue'.
genArray :: STy a -> Shape n -> Gen (Value (TArr n a))
genArray t sh = fmap Value (arrayGenerateLinM sh (const (fmap unValue (genValue t))))
-- | Generate a random value of the given type. Numbers come from [-10, 10]
-- with origin 0. Arrays get their shape from 'genShape'. Accumulators
-- cannot be generated.
genValue :: STy a -> Gen (Value a)
genValue ty = case ty of
  STNil -> pure (Value ())
  STPair t1 t2 -> liftV2 (,) <$> genValue t1 <*> genValue t2
  STEither t1 t2 -> Gen.choice
    [ liftV Left <$> genValue t1
    , liftV Right <$> genValue t2 ]
  STMaybe t -> Gen.choice
    [ pure (Value Nothing)
    , liftV Just <$> genValue t ]
  STArr n t -> genShape n >>= genArray t
  STScal sty -> case sty of
    STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STI32 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
    STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
    STBool -> Gen.choice [pure (Value False), pure (Value True)]
  STAccum{} -> error "Cannot generate inputs for accumulators"
-- | Generate one random value per environment entry (see 'genValue').
genEnv :: SList STy env -> Gen (SList Value env)
genEnv = \case
  SNil -> pure SNil
  SCons t rest -> SCons <$> genValue t <*> genEnv rest
-- data TemplateVar n = TemplateVar (SNat n) String
-- deriving (Show)
-- data Template t where
-- TpShape :: TemplateVar n -> STy t -> Template (TArr n t)
-- TpAny :: STy t -> Template t
-- TpPair :: Template a -> Template b -> Template (TPair a b)
-- deriving instance Show (Template t)
-- data ShapeConstraint n = ShapeAtLeast (Shape n)
-- deriving (Show)
-- genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t)
-- genTemplate = _
-- genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env)
-- genEnvTemplateExact shapes env = _
-- genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env)
-- genEnvTemplate constrs env = do
-- shapes <- DMap.traverseWithKey _ constrs
-- genEnvTemplateExact shapes env
-- | Render an environment of values as a comma-separated, bracketed list.
showEnv :: SList STy env -> SList Value env -> String
showEnv = \env vals -> '[' : intercalate ", " (entries env vals) ++ "]"
  where
    -- One rendered string per environment entry, in order.
    entries :: SList STy env -> SList Value env -> [String]
    entries SNil SNil = []
    entries (SCons t ts) (SCons (Value x) xs) = showValue 0 t x "" : entries ts xs
-- | 'adTestGen' with inputs generated from the statically known environment.
adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property
adTest expr = adTestGen expr (genEnv (knownEnv @env))
-- adTestTp :: forall env. KnownEnv env
-- => DMap TemplateVar ShapeConstraint -> SList Template env
-- -> Ex env (TScal TF64) -> Property
-- adTestTp envConstrs envTp = adTestGen (genEnvTemplate envConstrs envTp)
-- | Property: for inputs drawn from the generator, the following must all agree.
--
-- * The unsimplified and fully simplified CHAD terms give matching outputs
--   and gradients.
-- * The simplified CHAD output matches the interpreter's primal output.
-- * The simplified CHAD gradient matches the forward-mode reference gradient.
-- * Twenty simplifier iterations give the same term as 'SimplFix'.
adTestGen :: forall env. KnownEnv env
          => Ex env (TScal TF64) -> Gen (SList Value env) -> Property
adTestGen expr envGenerator = property $ do
  let env = knownEnv @env
  input <- forAllWith (showEnv env) envGenerator
  let outPrimal = interpretOpen False input expr
      gradFwd = gradientByForward knownEnv expr input
      (ppdterm, (outChad, gradCHAD)) = gradientByCHAD' (SimplIters 0) knownEnv expr input
      (ppdterm_S, (outChad_S, gradCHAD_S)) = gradientByCHAD' SimplFix knownEnv expr input
      (ppdterm_S20, _) = gradientByCHAD' (SimplIters 20) knownEnv expr input
      scFwd = envScalars env gradFwd
      scCHAD = envScalars env gradCHAD
      scCHAD_S = envScalars env gradCHAD_S
  annotate (concat (unSList (\t -> ppTy 0 t ++ " -> ") env) ++ ppTy 0 (typeOf expr))
  annotate (ppExpr knownEnv expr)
  annotate ppdterm
  annotate ppdterm_S
  diff ppdterm_S20 (==) ppdterm_S
  diff outChad closeIsh outChad_S
  diff outPrimal closeIsh outChad_S
  diff scCHAD closeIshList scCHAD_S
  diff scFwd closeIshList scCHAD_S
  where
    -- Element-wise 'closeIsh' that also requires equal lengths. A plain
    -- zipWith silently truncates, so gradients with a different number of
    -- scalars would otherwise compare equal.
    closeIshList :: [Double] -> [Double] -> Bool
    closeIshList xs ys = length xs == length ys && and (zipWith closeIsh xs ys)

    -- Flatten the tangent of every input into one list of scalars.
    envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
    envScalars SNil SNil = []
    envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs
-- | Sum of a vector, written as a build that copies the input and is then summed.
term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
term_build1_sum =
  fromNamed (lambda #x (body
    (idx0 (sum1i
      (build (SS SZ) (shape #x) (#idx :-> #x ! #idx))))))
-- | Two scalar inputs combined through let-bound pairs, which exercises
-- pair construction and projection in the derivative.
term_pairs :: Ex [TScal TF64, TScal TF64] (TScal TF64)
term_pairs =
  fromNamed (lambda #x (lambda #y (body
    (let_ #p (pair #x #y)
      (let_ #q (pair (snd_ #p * fst_ #p + #y) #x)
        (fst_ #q * #x + snd_ #q * fst_ #p))))))
-- | The "AD" test group. Each entry runs 'adTest' or 'adTestGen' on a
-- small program. Groups are checked with 'checkSequential'.
tests :: IO Bool
tests = checkSequential $ Group "AD"
  [("id", adTest $ fromNamed $ lambda #x $ body $ #x)
  ,("idx0", adTest $ fromNamed $ lambda #x $ body $ idx0 #x)
  ,("sum-vec", adTest $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x))
  ,("sum-replicate", adTest $ fromNamed $ lambda #x $ body $
      idx0 $ sum1i $ replicate1i 10 #x)
  ,("pairs", adTest term_pairs)
  ,("build0 const", adTest $ fromNamed $ lambda @(TScal TF64) #x $ body $
      idx0 $ build SZ nil $ #idx :-> const_ 0.0)
  ,("build0", adTest $ fromNamed $ lambda @(TArr N0 _) #x $ body $
      idx0 $
      build SZ (shape #x) $ #idx :-> #x ! #idx)
  -- GHCi snippet for inspecting the differentiated term:
  -- :hindentstr ppExpr knownEnv $ diffCHAD (SimplIters 20) knownEnv term_build1_sum
  ,("build1-sum", adTest term_build1_sum)
  ,("build2-sum", adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $
      idx0 $ sum1i . sum1i $
      build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx)
  -- ,("build-sum", adTest $ fromNamed $ lambda #x $ body $
  --     idx0 $ sum1i . sum1i $
  --     build (SS (SS SZ)) (pair (pair nil 2) 3) $ #idx :->
  --       oper OToFl64 $ snd_ (fst_ #idx) + snd_ #idx)
  -- The neural network example needs inputs with mutually consistent
  -- shapes, so it uses a custom generator instead of 'genEnv'.
  ,("neural", adTestGen Example.neural $ do
      let tR = STScal STF64
      -- A layer is a (weights, biases) pair: an nout x nin matrix and an
      -- nout-vector.
      let genLayer nin nout =
            liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
                       <*> genArray tR (ShNil `ShCons` nout)
      nin <- Gen.integral (Range.linear 1 10)
      n1 <- Gen.integral (Range.linear 1 10)
      n2 <- Gen.integral (Range.linear 1 10)
      input <- genArray tR (ShNil `ShCons` nin)
      lay1 <- genLayer nin n1
      lay2 <- genLayer n1 n2
      lay3 <- genArray tR (ShNil `ShCons` n2)
      -- Environment order: the input comes first, then the layers from
      -- last to first.
      return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil))
  ]
-- | Run all test groups through hedgehog's 'defaultMain'.
main :: IO ()
main = defaultMain groups
  where
    groups = [tests]
|