summaryrefslogtreecommitdiff
path: root/test/Main.hs
blob: 045ac1cfa357feb50bfeec3a20b7535df1e12808 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
module Main where

import Data.Bifunctor
import Hedgehog
import Hedgehog.Main

import Array
import AST
import CHAD
import CHAD.Types
import Data
import ForwardAD
import Interpreter
import Interpreter.Rep


-- | Label every entry of an environment with @"merge"@: the result has one
-- @"merge"@ element per element of the input list and nothing else.
type family MapMerge env where
  MapMerge '[] = '[]
  MapMerge (x ': xs) = "merge" ': MapMerge xs

-- | Proof that an environment labelled entirely by 'MapMerge' has no
-- @"accum"@-labelled entries: selecting them yields the empty list.
mapMergeNoAccum :: SList f env -> Select env (MapMerge env) "accum" :~: '[]
mapMergeNoAccum = \case
  SNil -> Refl
  SCons _ rest -> case mapMergeNoAccum rest of
    Refl -> Refl

-- | Proof that selecting the @"merge"@-labelled entries of an environment
-- labelled entirely by 'MapMerge' gives back the whole environment.
mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env
mapMergeOnlyMerge = \case
  SNil -> Refl
  SCons _ rest -> case mapMergeOnlyMerge rest of
    Refl -> Refl

-- | Gradient of a scalar-valued program with respect to every variable in
-- its environment, computed with CHAD ('drev').
--
-- Every environment entry is labelled 'SMerge' (so no accumulators are
-- involved), the inputs are converted to their 'D1' representation, the
-- reverse program is seeded with an output cotangent of @1.0@ and
-- interpreted, and the resulting cotangent tuple is unpacked into one
-- 'Value' per environment entry.
gradientByCHAD :: forall env. SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (D2E env)
gradientByCHAD env term input
  | Refl <- mapMergeNoAccum env
  , Refl <- mapMergeOnlyMerge env
  = let descr = mergeDescr env
        reverseProgram = freezeRet descr (drev descr term) (EConst ext STF64 1.0)
        (_, gradTuple) = interpretOpen (primalEnv env input) reverseProgram
    in unTup (\(Value (a, b)) -> (Value a, Value b)) (d2e env) (Value gradTuple)
  where
    -- Descriptor that labels every environment entry with 'SMerge'.
    mergeDescr :: SList STy env' -> Descr env' (MapMerge env')
    mergeDescr = \case
      SNil -> DTop
      SCons ty rest -> mergeDescr rest `DPush` (ty, SMerge)

    -- Convert every input value to its primal ('D1') representation.
    primalEnv :: SList STy env' -> SList Value env' -> SList Value (D1E env')
    primalEnv SNil SNil = SNil
    primalEnv (SCons ty tys) (SCons (Value v) vs) =
      SCons (Value (toPrimal ty v)) (primalEnv tys vs)

    -- Structural conversion of a single value; accumulators cannot occur in
    -- user input programs.
    toPrimal :: STy t -> Rep t -> Rep (D1 t)
    toPrimal ty = case ty of
      STNil -> id
      STPair ta tb -> bimap (toPrimal ta) (toPrimal tb)
      STEither ta tb -> bimap (toPrimal ta) (toPrimal tb)
      STMaybe te -> fmap (toPrimal te)
      STArr _ te -> fmap (toPrimal te)
      STScal _ -> id
      STAccum{} -> error "Accumulators not allowed in input program"

-- | Gradient computed by CHAD (see 'gradientByCHAD'), converted from
-- cotangent ('D2') representation to tangent ('Tan') representation so it
-- can be compared against 'gradientByForward'.
--
-- The primal input is needed because CHAD can represent a zero cotangent
-- sparsely: @Left ()@ for pairs and sums, a size-0 array for arrays and,
-- presumably, 'Nothing' for a 'TMaybe' whose primal is 'Just'. Such zeros
-- are expanded into a full zero tangent shaped like the primal via
-- 'zeroTan'. A structural disagreement between primal and cotangent is
-- reported with 'error'.
gradientByCHAD' :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term input)
  where
    -- Pointwise 'toTan' over an environment of primals and cotangents.
    toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
    toTanE SNil SNil SNil = SNil
    toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
      Value (toTan t p x) `SCons` toTanE env primal inp

    -- Convert one cotangent to a tangent, using the primal value to fill in
    -- sparsely-represented zeros.
    toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
    toTan typ primal der = case typ of
      STNil -> der
      STPair t1 t2 -> case der of
                        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
                        Right (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
      STEither t1 t2 -> case der of
                          Left () -> bimap (zeroTan t1) (zeroTan t2) primal
                          Right d -> case (primal, d) of
                            (Left p, Left d') -> Left (toTan t1 p d')
                            (Right p, Right d') -> Right (toTan t2 p d')
                            _ -> error "Primal and cotangent disagree on Either alternative"
      STMaybe t -> case (primal, der) of
        (Just p, Just d) -> Just (toTan t p d)
        -- A 'Nothing' cotangent for a present value is expanded to a zero
        -- tangent, like @Left ()@ above. ('liftA2' would collapse it to
        -- 'Nothing', changing the tangent's shape relative to the primal.)
        (Just p, Nothing) -> Just (zeroTan t p)
        (Nothing, Nothing) -> Nothing
        (Nothing, Just _) -> error "Primal and cotangent disagree on Maybe alternative"
      STArr _ t
        -- A size-0 cotangent array stands for zero.
        | shapeSize (arrayShape der) == 0 ->
            arrayMap (zeroTan t) primal
        | arrayShape primal == arrayShape der ->
            arrayGenerateLin (arrayShape primal) $ \i ->
              toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
        | otherwise ->
            error "Primal and cotangent disagree on array shape"
      STScal sty -> case sty of
        STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
      STAccum{} -> error "Accumulators not allowed in input program"

-- | Reference gradient of a scalar-valued program, obtained from the
-- forward-mode implementation ('drevByFwd') with an output sensitivity of
-- @1.0@; yields one tangent per environment entry.
gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward = \env term input ->
  drevByFwd env term input 1.0

-- | Approximate equality used to compare gradients computed by different
-- methods.
--
-- Two values are close when they are exactly equal (this also covers equal
-- infinities, whose difference is NaN), when their absolute difference is
-- below @1e-5@, or when both magnitudes exceed @1e-4@ and their relative
-- difference (against the smaller magnitude) is below @1e-5@. NaN is never
-- close to anything.
closeIsh :: Double -> Double -> Bool
closeIsh a b =
  a == b
    || absDiff < 1e-5
    || (scale > 1e-4 && absDiff / scale < 1e-5)
  where
    absDiff = abs (a - b)
    scale = min (abs a) (abs b)

-- | Property: on the given input environment, the gradient of @expr@
-- computed by CHAD agrees with the one computed by forward-mode AD.
--
-- Both gradients are flattened to scalar lists (via 'tanScalars') and
-- compared element-wise with 'closeIsh'. The lists must also have the same
-- length; plain 'zipWith' would silently ignore the excess elements of the
-- longer list and let a structural mismatch pass.
adTest :: forall env. KnownEnv env => SList Value env -> Ex env (TScal TF64) -> Property
adTest input expr = property $
  let env = knownEnv @env
      gradFwd = gradientByForward env expr input
      gradCHAD = gradientByCHAD' env expr input
      scFwd = envScalars env gradFwd
      scCHAD = envScalars env gradCHAD
  in diff scCHAD closeAll scFwd
  where
    -- Element-wise 'closeIsh' on lists that must have equal length.
    closeAll :: [Double] -> [Double] -> Bool
    closeAll xs ys = length xs == length ys && and (zipWith closeIsh xs ys)

    -- Flatten a tangent environment to its scalars, in environment order.
    envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
    envScalars SNil SNil = []
    envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs

-- | The AD property group. Each entry checks that CHAD and forward-mode
-- differentiation agree on one program.
--
-- "id": the identity on a single F64 variable (the program is just a
-- reference to the environment's only variable), evaluated at 42.
tests :: IO Bool
tests = checkParallel $ Group "AD"
  [ ("id", adTest (Value 42.0 `SCons` SNil) (EVar ext (STScal STF64) IZ))
  ]

-- | Entry point: run the test groups with Hedgehog's default runner.
main :: IO ()
main = defaultMain (pure tests)