blob: 045ac1cfa357feb50bfeec3a20b7535df1e12808 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
module Main where
import Data.Bifunctor
import Hedgehog
import Hedgehog.Main
import Array
import AST
import CHAD
import CHAD.Types
import Data
import ForwardAD
import Interpreter
import Interpreter.Rep
-- | Replace every entry of a type-level environment by the tag "merge",
-- keeping the length of the list.
type family MapMerge env where
  MapMerge '[] = '[]
  MapMerge (_ ': rest) = "merge" ': MapMerge rest
-- | Evidence that 'Select' over an all-"merge" tag list, asked for "accum",
-- is the empty list. Proved by induction on the environment spine.
mapMergeNoAccum :: SList f env -> Select env (MapMerge env) "accum" :~: '[]
mapMergeNoAccum = \case
  SNil -> Refl
  _ `SCons` rest -> case mapMergeNoAccum rest of
    Refl -> Refl
-- | Evidence that 'Select' over an all-"merge" tag list, asked for "merge",
-- gives back the whole environment. Proved by induction on the spine.
mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env
mapMergeOnlyMerge list = case list of
  SNil -> Refl
  SCons _ rest -> case mapMergeOnlyMerge rest of
    Refl -> Refl
-- | Gradient of a scalar-valued term with respect to every variable in its
-- environment, computed with the CHAD reverse-mode transformation.
--
-- Every environment entry is described as "merge" (see 'MapMerge'). The two
-- proofs 'mapMergeNoAccum' and 'mapMergeOnlyMerge' bring the corresponding
-- type equalities into scope for the rest of the computation.
--
-- The term is transformed with 'drev' and finished with 'freezeRet', passing
-- the constant 1.0 as the last argument (presumably the seed cotangent of the
-- scalar output; confirm against 'freezeRet'). It is then interpreted on the
-- inputs converted by 'toPrimalE'. The gradient half of the result is split
-- back into one 'Value' per environment entry with 'unTup'.
gradientByCHAD :: forall env. SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (D2E env)
gradientByCHAD = \env term input ->
  case (mapMergeNoAccum env, mapMergeOnlyMerge env) of
    (Refl, Refl) ->
      let descr = makeMergeDescr env
          dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0)
          input1 = toPrimalE env input
          -- Only the second component (the gradient) is used.
          (_out, grad) = interpretOpen input1 dterm
      in unTup (\(Value (x, y)) -> (Value x, Value y)) (d2e env) (Value grad)
  where
    -- Descriptor that tags each environment entry with 'SMerge'.
    makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env')
    makeMergeDescr SNil = DTop
    makeMergeDescr (t `SCons` env) = makeMergeDescr env `DPush` (t, SMerge)

    -- Convert every input value to its D1 representation, entry by entry.
    toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env')
    toPrimalE SNil SNil = SNil
    toPrimalE (t `SCons` env) (Value x `SCons` inp) = Value (toPrimal t x) `SCons` toPrimalE env inp

    -- Structural conversion of a single value to its D1 representation.
    -- Accumulator-typed inputs are rejected with a runtime error.
    toPrimal :: STy t -> Rep t -> Rep (D1 t)
    toPrimal = \case
      STNil -> id
      STPair t1 t2 -> bimap (toPrimal t1) (toPrimal t2)
      STEither t1 t2 -> bimap (toPrimal t1) (toPrimal t2)
      STMaybe t -> fmap (toPrimal t)
      STArr _ t -> fmap (toPrimal t)
      STScal _ -> id
      STAccum{} -> error "Accumulators not allowed in input program"
-- | Like 'gradientByCHAD', but converts each cotangent (a 'D2' value) into a
-- tangent ('Tan') shaped like the corresponding primal input.
--
-- This makes the result directly comparable with forward-mode gradients.
-- A zero cotangent is expanded into an explicit zero tangent of the primal's
-- shape via 'zeroTan'. The zero cotangents handled this way are:
--
--   * 'Left ()' for pairs and eithers,
--   * 'Nothing' for maybes,
--   * an empty array for arrays.
gradientByCHAD' :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term input)
  where
    -- Pair every primal input with its cotangent and convert it to a tangent.
    toTanE :: SList STy env' -> SList Value env' -> SList Value (D2E env') -> SList Value (TanE env')
    toTanE SNil SNil SNil = SNil
    toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
      Value (toTan t p x) `SCons` toTanE env primal inp

    -- Convert one cotangent into a tangent, using the primal to fill in the
    -- structure wherever the cotangent is a (sparse) zero.
    toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
    toTan typ primal der = case typ of
      STNil -> der
      STPair t1 t2 -> case der of
        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
        Right (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
      STEither t1 t2 -> case der of
        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
        Right d -> case (primal, d) of
          (Left p, Left d') -> Left (toTan t1 p d')
          (Right p, Right d') -> Right (toTan t2 p d')
          _ -> error "Primal and cotangent disagree on Either alternative"
      -- The tangent must follow the primal's shape. A 'Nothing' cotangent for
      -- a 'Just' primal is a zero and must become 'Just (zeroTan ...)'. Plain
      -- 'liftA2' would drop it to 'Nothing' and lose components.
      STMaybe t -> case (primal, der) of
        (Nothing, _) -> Nothing
        (Just p, Nothing) -> Just (zeroTan t p)
        (Just p, Just d) -> Just (toTan t p d)
      STArr _ t
        -- An empty cotangent array stands for zero.
        | shapeSize (arrayShape der) == 0 ->
            arrayMap (zeroTan t) primal
        | arrayShape primal == arrayShape der ->
            arrayGenerateLin (arrayShape primal) $ \i ->
              toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
        | otherwise ->
            error "Primal and cotangent disagree on array shape"
      -- Matching on the scalar type lets the D2/Tan type families reduce.
      STScal sty -> case sty of
        STI32 -> der
        STI64 -> der
        STF32 -> der
        STF64 -> der
        STBool -> der
      STAccum{} -> error "Accumulators not allowed in input program"
-- | Reference gradient computed by forward-mode AD through 'drevByFwd'.
-- 1.0 is passed as the final argument (presumably the seed for the scalar
-- output).
gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward = \env term input -> drevByFwd env term input 1.0
-- | Approximate equality of two doubles.
--
-- The numbers are close when either:
--
--   * their absolute difference is below 1e-5, or
--   * the smaller magnitude exceeds 1e-4 and the difference relative to that
--     magnitude is below 1e-5.
closeIsh :: Double -> Double -> Bool
closeIsh a b
  | absErr < 1e-5 = True
  | otherwise = scale > 1e-4 && absErr / scale < 1e-5
  where
    absErr = abs (a - b)
    scale = min (abs a) (abs b)
-- | Property: on the given input, the gradient computed by CHAD agrees
-- scalar by scalar with the gradient computed by forward-mode AD.
--
-- Both gradients are flattened to lists of scalars with 'tanScalars'. The
-- lists must have the same length and be pairwise 'closeIsh'.
adTest :: forall env. KnownEnv env => SList Value env -> Ex env (TScal TF64) -> Property
adTest input expr = property $
  let env = knownEnv @env
      gradFwd = gradientByForward env expr input
      gradCHAD = gradientByCHAD' env expr input
      scFwd = envScalars env gradFwd
      scCHAD = envScalars env gradCHAD
  in diff scCHAD allClose scFwd
  where
    -- Compare lengths explicitly: 'zipWith' truncates to the shorter list,
    -- which would let a gradient with missing or extra components pass.
    allClose :: [Double] -> [Double] -> Bool
    allClose xs ys = length xs == length ys && and (zipWith closeIsh xs ys)

    -- Flatten a whole tangent environment into its scalar components.
    envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
    envScalars SNil SNil = []
    envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs
-- | All AD property tests, run in parallel as the Hedgehog group "AD".
--
-- Each entry gives an input environment and a scalar-valued term over it.
tests :: IO Bool
tests = checkParallel $ Group "AD"
  [ -- Identity: the term is the single environment variable itself, so the
    -- gradient should be 1 at any input.
    ("id", adTest (Value 42.0 `SCons` SNil) (EVar ext (STScal STF64) IZ))
  ]
-- | Entry point: hand the AD test group to Hedgehog's 'defaultMain'.
main :: IO ()
main = defaultMain groups
  where
    groups = [tests]
|