-- Source: src/ForwardAD.hs (blob e867d663a54003e08bae776e253f9bdded6504c3)
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module ForwardAD where

import Data.Bifunctor (bimap)
import System.IO.Unsafe

-- import Debug.Trace
-- import AST.Pretty

import Array
import AST
import Compile
import Data
import ForwardAD.DualNumbers
import Interpreter
import Interpreter.Rep


-- | The type of tangents along a type. The structure of the type is kept
-- as-is; only the scalars at the leaves are replaced, as described by 'TanS'.
-- (For these types, the tangent type coincides with the cotangent type.)
type family Tan ty where
  Tan TNil = TNil
  Tan (TScal s) = TanS s
  Tan (TPair l r) = TPair (Tan l) (Tan r)
  Tan (TEither l r) = TEither (Tan l) (Tan r)
  Tan (TMaybe e) = TMaybe (Tan e)
  Tan (TArr rank e) = TArr rank (Tan e)

-- | Tangent type of a scalar type: the floating-point scalars are their own
-- tangent type, every other scalar has the trivial tangent @TNil@.
type family TanS s where
  TanS TF32 = TScal TF32
  TanS TF64 = TScal TF64
  TanS TI32 = TNil
  TanS TI64 = TNil
  TanS TBool = TNil

-- | 'Tan' applied to every type in an environment (a type-level list).
type family TanE tys where
  TanE '[] = '[]
  TanE (ty : rest) = Tan ty : TanE rest

-- | Singleton-level counterpart of 'Tan': from the singleton of a type,
-- compute the singleton of its tangent type.
--
-- Calls 'error' on accumulator types.
tanty :: STy t -> STy (Tan t)
tanty ty = case ty of
  STNil -> STNil
  STPair l r -> STPair (tanty l) (tanty r)
  STEither l r -> STEither (tanty l) (tanty r)
  STMaybe inner -> STMaybe (tanty inner)
  STArr rank elt -> STArr rank (tanty elt)
  -- Only floating-point scalars carry a tangent; see 'TanS'.
  STScal STI32 -> STNil
  STScal STI64 -> STNil
  STScal STF32 -> STScal STF32
  STScal STF64 -> STScal STF64
  STScal STBool -> STNil
  STAccum{} -> error "Accumulators not allowed in input program"

-- | The zero tangent that has the same structure as the given primal value:
-- the same 'Either' \/ 'Maybe' alternatives and the same array contents
-- layout, with @0.0@ at every floating-point scalar and @()@ at every other
-- scalar.
--
-- Calls 'error' on accumulator types.
zeroTan :: STy t -> Rep t -> Rep (Tan t)
zeroTan STNil () = ()
zeroTan (STPair tl tr) (l, r) = (zeroTan tl l, zeroTan tr r)
zeroTan (STEither tl tr) e = either (Left . zeroTan tl) (Right . zeroTan tr) e
zeroTan (STMaybe inner) m = zeroTan inner <$> m
zeroTan (STArr _ elt) arr = zeroTan elt <$> arr
zeroTan (STScal sty) _ = case sty of
  STI32 -> ()
  STI64 -> ()
  STF32 -> 0.0
  STF64 -> 0.0
  STBool -> ()
zeroTan STAccum{} _ = error "Accumulators not allowed in input program"

-- | Flatten a tangent value into the list of its floating-point components,
-- in traversal order. Single-precision components are widened to 'Double';
-- non-floating scalars, @()@ and 'Nothing' contribute no elements.
--
-- Calls 'error' on accumulator types.
tanScalars :: STy t -> Rep (Tan t) -> [Double]
tanScalars STNil () = []
tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y
tanScalars (STEither a _) (Left x) = tanScalars a x
tanScalars (STEither _ b) (Right y) = tanScalars b y
tanScalars (STMaybe _) Nothing = []
tanScalars (STMaybe t) (Just x) = tanScalars t x
-- Fold over the elements directly rather than first building an intermediate
-- array of lists and concatenating that afterwards.
tanScalars (STArr _ t) x = foldMap (tanScalars t) x
tanScalars (STScal STI32) _ = []
tanScalars (STScal STI64) _ = []
tanScalars (STScal STF32) x = [realToFrac x]
tanScalars (STScal STF64) x = [x]
tanScalars (STScal STBool) _ = []
tanScalars STAccum{} _ = error "Accumulators not allowed in input program"

-- | Split a dual-numbers value into its primal part and its tangent part.
--
-- Floating-point scalars are stored as a (primal, tangent) pair in the
-- dual-numbers representation and are returned unchanged; all other scalars
-- are their own primal and get the trivial tangent @()@.
--
-- Calls 'error' on accumulator types.
unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t))
unzipDN STNil _ = ((), ())
unzipDN (STPair tl tr) (dl, dr) = ((pl, pr), (tanl, tanr))
  where
    (pl, tanl) = unzipDN tl dl
    (pr, tanr) = unzipDN tr dr
unzipDN (STEither tl tr) d =
  either (bimap Left Left . unzipDN tl) (bimap Right Right . unzipDN tr) d
unzipDN (STMaybe inner) d =
  maybe (Nothing, Nothing) (bimap Just Just . unzipDN inner) d
unzipDN (STArr _ elt) arr = (arrayMap fst halves, arrayMap snd halves)
  where
    -- Every element split into its (primal, tangent) pair.
    halves = arrayMap (unzipDN elt) arr
unzipDN (STScal STI32) d = (d, ())
unzipDN (STScal STI64) d = (d, ())
unzipDN (STScal STF32) d = d
unzipDN (STScal STF64) d = d
unzipDN (STScal STBool) d = (d, ())
unzipDN STAccum{} _ = error "Accumulators not allowed in input program"

-- | Dot product of two tangent values of the same type, accumulated as a
-- 'Double'. Only floating-point scalars contribute; every other scalar
-- counts as @0.0@.
--
-- The two arguments must have the same structure: differing 'Either' or
-- 'Maybe' alternatives, or differing array shapes, result in a call to
-- 'error'. The one exception is that if either array is empty, the result is
-- @0.0@ without the shapes being compared.
--
-- Calls 'error' on accumulator types.
dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double
dotprodTan STNil _ _ = 0.0
dotprodTan (STPair tl tr) (l1, r1) (l2, r2) =
  dotprodTan tl l1 l2 + dotprodTan tr r1 r2
dotprodTan (STEither tl _) (Left l1) (Left l2) = dotprodTan tl l1 l2
dotprodTan (STEither _ tr) (Right r1) (Right r2) = dotprodTan tr r1 r2
dotprodTan STEither{} _ _ = error "dotprodTan: incompatible Either alternatives"
dotprodTan (STMaybe _) Nothing Nothing = 0.0
dotprodTan (STMaybe inner) (Just v1) (Just v2) = dotprodTan inner v1 v2
dotprodTan STMaybe{} _ _ = error "dotprodTan: incompatible Maybe alternatives"
dotprodTan (STArr _ elt) arr1 arr2
  | shapeSize shape1 == 0 || shapeSize shape2 == 0 = 0.0
  | shape1 == shape2 =
      sum [ dotprodTan elt (arrayIndex arr1 ix) (arrayIndex arr2 ix)
          | ix <- enumShape shape1 ]
  | otherwise = error "dotprodTan: incompatible array shapes"
  where
    shape1 = arrayShape arr1
    shape2 = arrayShape arr2
dotprodTan (STScal STI32) _ _ = 0.0
dotprodTan (STScal STI64) _ _ = 0.0
-- The product is computed in single precision and only then widened.
dotprodTan (STScal STF32) x y = realToFrac @Float @Double (x * y)
dotprodTan (STScal STF64) x y = x * y
dotprodTan (STScal STBool) _ _ = 0.0
dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program"

-- -- Primal expression must be duplicable
-- dnConstE :: STy t -> Ex env t -> Ex env (DN t)
-- dnConstE STNil _ = ENil ext
-- dnConstE (STPair t1 t2) e =
--   -- This creates fst/snd stacks of unbounded size, but let's not care here
--   EPair ext (dnConstE t1 (EFst ext e)) (dnConstE t2 (ESnd ext e))
-- dnConstE (STEither t1 t2) e =
--   ECase ext e
--     (EInl ext (dn t2) (dnConstE t1 (EVar ext t1 IZ)))
--     (EInr ext (dn t1) (dnConstE t2 (EVar ext t2 IZ)))
-- dnConstE (STMaybe t) e =
--   EMaybe ext (ENothing ext (dn t)) (EJust ext (dnConstE t (EVar ext t IZ))) e
-- dnConstE (STArr n t) e =
--   EBuild ext n (EShape ext e)
--     (dnConstE t (EIdx ext n (weakenExpr WSink e) (EVar ext (tTup (sreplicate n tIx)) IZ)))
-- dnConstE (STScal t) e = case t of
--   STI32 -> e
--   STI64 -> e
--   STF32 -> EPair ext e (EConst ext STF32 0.0)
--   STF64 -> EPair ext e (EConst ext STF64 0.0)
--   STBool -> e
-- dnConstE STAccum{} _ = error "Accumulators not allowed in input program"

-- | Embed a primal value into the dual-numbers representation as a constant:
-- every floating-point scalar is paired with a tangent of @0.0@, all other
-- scalars are left unchanged.
--
-- Calls 'error' on accumulator types.
dnConst :: STy t -> Rep t -> Rep (DN t)
dnConst ty = case ty of
  STNil -> \_ -> ()
  STPair tl tr -> bimap (dnConst tl) (dnConst tr)
  STEither tl tr -> either (Left . dnConst tl) (Right . dnConst tr)
  STMaybe inner -> fmap (dnConst inner)
  STArr _ elt -> arrayMap (dnConst elt)
  STScal STI32 -> id
  STScal STI64 -> id
  STScal STF32 -> \x -> (x, 0.0)
  STScal STF64 -> \x -> (x, 0.0)
  STScal STBool -> id
  STAccum{} -> error "Accumulators not allowed in input program"

-- | Reverse-mode differentiation expressed through forward-mode evaluation.
--
-- The argument is a function that computes the forward derivative (a single
-- 'Double') for a particular dual-numbers input of type @DN t@. Given such a
-- function, a 'RevByFwd' computes the gradient with respect to this @t@
-- input, as a value of type @Tan t@.
type RevByFwd t = (Rep (DN t) -> Double) -> Rep (Tan t)

-- | Build the 'RevByFwd' for a primal value by probing with one-hot tangents.
--
-- For every floating-point scalar position in the value, the supplied
-- function is called once on a dual-numbers version of the whole value in
-- which that position has tangent @1.0@ and every other position is a
-- zero-tangent constant (see 'dnConst'). The 'Double' it returns is stored at
-- that position of the result; for single-precision positions it is narrowed
-- to 'Float'. Non-floating scalars yield @()@ without calling the function.
--
-- Calls 'error' on accumulator types.
dnOnehots :: STy t -> Rep t -> RevByFwd t
dnOnehots STNil _ = const ()
dnOnehots (STPair tl tr) (l, r) = \probe ->
  let viaLeft dl = probe (dl, dnConst tr r)
      viaRight dr = probe (dnConst tl l, dr)
  in (dnOnehots tl l viaLeft, dnOnehots tr r viaRight)
dnOnehots (STEither tl _) (Left l) = \probe -> Left (dnOnehots tl l (probe . Left))
dnOnehots (STEither _ tr) (Right r) = \probe -> Right (dnOnehots tr r (probe . Right))
dnOnehots (STMaybe _) Nothing = const Nothing
dnOnehots (STMaybe inner) (Just v) = \probe -> Just (dnOnehots inner v (probe . Just))
dnOnehots (STArr _ elt) arr = \probe ->
  arrayGenerate shape $ \hot ->
    -- @spliced oh@ is the whole array in dual form, with @oh@ at position
    -- @hot@ and zero-tangent constants at every other position.
    let spliced oh = arrayGenerate shape $ \i ->
          if i == hot then oh else dnConst elt (arrayIndex arr i)
    in dnOnehots elt (arrayIndex arr hot) (probe . spliced)
  where
    shape = arrayShape arr
dnOnehots (STScal sty) x = case sty of
  STI32 -> const ()
  STI64 -> const ()
  STF32 -> \probe -> realToFrac @Double @Float (probe (x, 1.0))
  STF64 -> \probe -> probe (x, 1.0)
  STBool -> const ()
dnOnehots STAccum{} _ = error "Accumulators not allowed in input program"

-- | 'dnConst' applied to every value in an environment: the whole
-- environment in dual-numbers form with zero tangents.
dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)
dnConstEnv SNil SNil = SNil
dnConstEnv (SCons ty tys) (SCons (Value v) vs) =
  SCons (Value (dnConst ty v)) (dnConstEnv tys vs)

-- | Environment-level analogue of 'RevByFwd': given a function from a
-- dual-numbers environment to a 'Double', produce an environment holding a
-- tangent value (of type @Tan t@) for every entry of type @t@.
type RevByFwdEnv env = (SList Value (DNE env) -> Double) -> SList Value (TanE env)

-- | 'dnOnehots' lifted to environments. While the one-hot tangents range over
-- one entry of the environment, every other entry is held constant with zero
-- tangents ('dnConst' \/ 'dnConstEnv').
dnOnehotEnvs :: SList STy env -> SList Value env -> RevByFwdEnv env
dnOnehotEnvs SNil SNil = const SNil
dnOnehotEnvs (SCons ty tys) (SCons (Value v) vs) = \probe ->
  let -- Vary the head entry; the tail is constant.
      headGrad = dnOnehots ty v (\oh -> probe (SCons (Value oh) (dnConstEnv tys vs)))
      -- Vary the tail entries; the head is constant.
      tailGrad = dnOnehotEnvs tys vs (\ohs -> probe (SCons (Value (dnConst ty v)) ohs))
  in SCons (Value headGrad) tailGrad

-- | A forward-differentiated program, packaged together with the type
-- information needed to drive it (see 'drevByFwd').
data FwdADArtifact env t
  = FwdADArtifact
      (SList STy env)
      -- ^ Types of the entries of the primal environment
      (STy t)
      -- ^ Type of the primal result
      (SList Value (DNE env) -> Rep (DN t))
      -- ^ The differentiated program: maps a dual-numbers environment to a
      -- dual-numbers result

-- | Build a 'FwdADArtifact' that evaluates the forward derivative of the
-- given expression (as produced by 'dfwdDN') with 'interpretOpen'.
makeFwdADArtifactInterp :: SList STy env -> Ex env t -> FwdADArtifact env t
makeFwdADArtifactInterp env expr = FwdADArtifact env (typeOf expr) evalDual
  where
    -- Differentiate once; the result is shared by all runs of the artifact.
    dualExpr = dfwdDN expr
    evalDual dualEnv = interpretOpen False dualEnv dualExpr

{-# NOINLINE makeFwdADArtifactCompile #-}
-- | Build a 'FwdADArtifact' by passing the forward derivative of the given
-- expression (as produced by 'dfwdDN') to 'compile'.
--
-- The function obtained from 'compile' runs in 'IO'; it is exposed as a pure
-- function by wrapping each call in 'unsafePerformIO'.
-- NOTE(review): this presumably relies on the compiled function being
-- observably pure -- confirm against 'compile'.
makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t)
makeFwdADArtifactCompile env expr = do
  compiled <- compile (dne env) (dfwdDN expr)
  pure (FwdADArtifact env (typeOf expr) (unsafePerformIO . compiled))

-- | 'drevByFwd' on an artifact built with 'makeFwdADArtifactInterp'.
drevByFwdInterp :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
drevByFwdInterp env expr = drevByFwd artifact
  where
    artifact = makeFwdADArtifactInterp env expr

-- | Compute a reverse derivative by repeated forward evaluation.
--
-- For each one-hot dual-numbers environment produced by 'dnOnehotEnvs' on the
-- primal @input@, the artifact is run, the tangent part of its result is
-- extracted with 'unzipDN', and the dot product of that tangent with @dres@
-- ('dotprodTan') becomes the corresponding component of the returned
-- environment of tangents.
drevByFwd :: FwdADArtifact env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
drevByFwd (FwdADArtifact envTys resTy run) input dres =
  dnOnehotEnvs envTys input directional
  where
    -- Debugging aid: prefix the body with
    --   trace (showEnv (dne envTys) dualEnv) $
    directional dualEnv =
      dotprodTan resTy (snd (unzipDN resTy (run dualEnv))) dres