-- test/Main.hs (blob 39415bb8fa55871d8f9cdf48385d6c1fa35ec0ae)
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE LambdaCase #-}
module Main where

import Data.Bifunctor

import Array
import AST
import CHAD
import CHAD.Types
import Data
import ForwardAD
import Interpreter
import Interpreter.Rep


-- | Replace every element of a type-level list with the tag @"merge"@,
-- keeping the length of the list unchanged.
type family MapMerge env where
  MapMerge '[] = '[]
  MapMerge (_ ': rest) = "merge" ': MapMerge rest

-- | Proof, by induction on the environment, that selecting the
-- @"accum"@ entries from an all-@"merge"@ storage list yields nothing.
mapMergeNoAccum :: SList f env -> Select env (MapMerge env) "accum" :~: '[]
mapMergeNoAccum = \case
  SNil -> Refl
  SCons _ rest -> case mapMergeNoAccum rest of
    Refl -> Refl

-- | Proof, by induction on the environment, that selecting the
-- @"merge"@ entries from an all-@"merge"@ storage list gives back the
-- whole environment.
mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env
mapMergeOnlyMerge = \case
  SNil -> Refl
  SCons _ rest -> case mapMergeOnlyMerge rest of
    Refl -> Refl

-- | Gradient of a scalar-valued term, obtained by CHAD reverse-mode
-- differentiation ('drev') and evaluated with the interpreter.
--
-- Every environment entry gets 'SMerge' storage in the descriptor. The
-- proofs 'mapMergeNoAccum' and 'mapMergeOnlyMerge' tell the type checker
-- that no entry is an accumulator and that all entries are merged.
gradientByCHAD :: forall env. SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (D2E env)
gradientByCHAD env term input
  | Refl <- mapMergeNoAccum env
  , Refl <- mapMergeOnlyMerge env
  = let descriptor = allMerge env
        -- the constant 1.0 is supplied as the cotangent of the scalar result
        reverseTerm = freezeRet descriptor (drev descriptor term) (EConst ext STF64 1.0)
        (_primalResult, cotangents) = interpretOpen (primalEnv env input) reverseTerm
    in unTup (\(Value (l, r)) -> (Value l, Value r)) (d2e env) (Value cotangents)
  where
    -- Descriptor that assigns 'SMerge' storage to every environment entry.
    allMerge :: SList STy env' -> Descr env' (MapMerge env')
    allMerge = \case
      SNil -> DTop
      ty `SCons` rest -> allMerge rest `DPush` (ty, SMerge)

    -- Convert every input value to its primal ('D1') representation.
    primalEnv :: SList STy env' -> SList Value env' -> SList Value (D1E env')
    primalEnv SNil SNil = SNil
    primalEnv (ty `SCons` tys) (Value v `SCons` vs) =
      Value (primalOf ty v) `SCons` primalEnv tys vs

    -- Structural conversion of one value to its primal representation;
    -- accumulator-typed inputs are rejected.
    primalOf :: STy t -> Rep t -> Rep (D1 t)
    primalOf STNil = id
    primalOf (STPair a b) = bimap (primalOf a) (primalOf b)
    primalOf (STEither a b) = bimap (primalOf a) (primalOf b)
    primalOf (STMaybe a) = fmap (primalOf a)
    primalOf (STArr _ a) = fmap (primalOf a)
    primalOf (STScal _) = id
    primalOf STAccum{} = error "Accumulators not allowed in input program"

-- | CHAD gradient of a scalar-valued term, converted from CHAD's cotangent
-- representation ('D2E') to the tangent representation ('TanE').
--
-- The primal input is needed because some cotangents are represented
-- sparsely (e.g. @Left ()@ for pairs and Either, an empty array, or
-- 'Nothing' for Maybe). In those cases a zero tangent with the same shape
-- as the primal is rebuilt with 'zeroTan'.
gradientByCHAD' :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term input)
  where
    -- Convert each environment entry, pairing it with its primal value.
    toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
    toTanE SNil SNil SNil = SNil
    toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
      Value (toTan t p x) `SCons` toTanE env primal inp

    -- Convert one cotangent into a tangent whose structure follows the primal.
    toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
    toTan typ primal der = case typ of
      STNil -> der
      STPair t1 t2 -> case der of
                        Left () -> bimap (zeroTan t1) (zeroTan t2) primal
                        Right (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
      STEither t1 t2 -> case der of
                          Left () -> bimap (zeroTan t1) (zeroTan t2) primal
                          Right d -> case (primal, d) of
                            (Left p, Left d') -> Left (toTan t1 p d')
                            (Right p, Right d') -> Right (toTan t2 p d')
                            _ -> error "Primal and cotangent disagree on Either alternative"
      -- A 'Nothing' cotangent is treated as zero, like @Left ()@ above. A
      -- 'Just' primal must still produce a 'Just' (zero) tangent so the result
      -- has the primal's shape.
      STMaybe t -> case (primal, der) of
        (Nothing, Nothing) -> Nothing
        (Just p, Nothing) -> Just (zeroTan t p)
        (Just p, Just d) -> Just (toTan t p d)
        (Nothing, Just _) -> error "Primal and cotangent disagree on Maybe alternative"
      STArr _ t
        -- an empty cotangent array stands for a zero cotangent
        | shapeSize (arrayShape der) == 0 ->
            arrayMap (zeroTan t) primal
        | arrayShape primal == arrayShape der ->
            arrayGenerateLin (arrayShape primal) $ \i ->
              toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
        | otherwise ->
            error "Primal and cotangent disagree on array shape"
      -- Matching each scalar type separately lets 'D2' and 'Tan' reduce
      -- for that concrete type.
      STScal sty -> case sty of
        STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
      STAccum{} -> error "Accumulators not allowed in input program"

-- | Gradient of a scalar-valued term computed with forward-mode AD
-- ('drevByFwd'), seeding the output with 1.0. The result uses the same
-- 'TanE' representation as 'gradientByCHAD'', so the two can be compared.
gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward = \env term input -> drevByFwd env term input 1.0

-- | Test-suite entry point. It currently performs no checks and exits
-- successfully.
main :: IO ()
main = pure ()