aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/APIv1.hs
blob: 4e82130feaca5aa12979e14e557c4f7506067e1a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.APIv1 (
  -- * Expressions and types
  Ex, STy(..), SScalTy(..), Ty(..), ScalTy(..),

  -- * Reverse derivatives (Fast CHAD)
  vjp, vjp',
  D2, D2E, Tup,
  CHADConfig(..),

  -- ** Primal type transform
  -- | The primal type transform is only important when working with special
  -- operations like 'CHAD.Language.custom'.
  D1,

  -- * Forward derivatives (dual numbers)
  jvp, jvpDN,
  Tan, DN, DNE,

  -- * Working with expressions
  interpret, interpret1,
  compile, compile1,
  fullSimplify,
  SList(..), Value(..), Rep,
  KnownEnv(..), KnownTy(..),
) where

import CHAD.AST
import CHAD.AST.Count
import CHAD.AST.UnMonoid
import CHAD.Compile qualified as Compile
import CHAD.Data
import CHAD.Drev.Top
import CHAD.Drev.Types
import CHAD.ForwardAD
import CHAD.ForwardAD.DualNumbers
import CHAD.Interpreter qualified as Interpreter
import CHAD.Simplify
import CHAD.Interpreter.Rep


-- | Reverse derivative (a vector-Jacobian product) of an expression, computed
-- by 'vjp'' with the default CHAD configuration after applying 'chcSetAccum'
-- to it. The type shown here has been simplified under the assumption that
-- 'D1' is the identity.
vjp :: KnownEnv env => Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
vjp term = vjp' accumConfig term
  where
    accumConfig = chcSetAccum defaultConfig

-- | Same as 'vjp', but with an explicitly supplied CHAD configuration (used
-- as given).
vjp' :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
vjp' config term =
  -- Matching on the 'Dict' brings the 'KnownTy' instance for the cotangent
  -- type of the result into scope for the pipeline below.
  case styKnown (d2 (typeOf term)) of
    Dict -> pipeline term
      where
        -- Read right to left: simplify the input, differentiate, simplify
        -- again (need to merge onehots and accums for unMonoid to do its
        -- work), eliminate monoid operations, and do a final full simplify.
        pipeline =
          fullSimplify
            . unMonoid
            . simplifyFix
            . chad' config knownEnv
            . simplifyFix

-- | Forward derivative in dual-numbers form. This is a direct alias for
-- 'dfwdDN': both the environment and the result type are mapped through the
-- dual-numbers type transforms ('DNE' and 'DN' respectively). See 'jvp' for a
-- variant that takes and returns primal and tangent values separately.
jvpDN :: Ex env t -> Ex (DNE env) (DN t)
jvpDN term = dfwdDN term

-- | Compute a forward derivative of an expression with a single free variable
-- of type @s@.
--
-- The resulting expression has two free variables: index 0 is the tangent of
-- the input (of type @'Tan' s@) and index 1 is the primal input (of type @s@).
-- It evaluates to a pair of the primal result and the tangent of the result.
--
-- Implementation: the primal and tangent inputs are zipped into a single
-- dual-numbers value (@ezipDN@), the dual-numbers derivative 'jvpDN' of the
-- term is run on that, and its result is unzipped into a (primal, tangent)
-- pair (@eunzipDN@). The whole expression is then passed through
-- 'fullSimplify'.
--
-- Accumulator types ('STAccum') are not supported: reaching one while
-- traversing @s@ or @t@ calls 'error'. If the constructors of a sum-typed
-- input and of its tangent disagree, the generated expression contains an
-- 'EError' node for that case.
jvp :: forall s t. KnownTy s => Ex '[s] t -> Ex '[Tan s, s] (TPair t (Tan t))
jvp term
  -- Matching on Dict brings the KnownTy instance for 'Tan s' into scope;
  -- presumably this is what satisfies the KnownEnv constraint of
  -- 'fullSimplify' for the environment '[Tan s, s] -- TODO confirm.
  | Dict <- styKnown (tanty (knownTy @s))
  = fullSimplify $
      -- Bind the dual-numbers version of the input (built from indices 0 and 1).
      elet (ezipDN knownTy) $
      -- Bind the dual-numbers result of the differentiated term; the term is
      -- weakened so that its single free variable is the value bound above.
      elet (weakenExpr (WCopy WClosed) (jvpDN term)) $
        -- Split the dual-numbers result into a (primal, tangent) pair.
        eunzipDN (typeOf term)
  where
    -- Build a dual-numbers value of type 'DN s'' from the tangent at index 0
    -- and the primal at index 1, by recursion on the structure of the type.
    ezipDN :: forall env s'. STy s' -> Ex (Tan s' : s' : env) (DN s')
    ezipDN STNil = ENil ext
    -- Pairs: zip each component separately. The recursive result expects the
    -- component's tangent/primal at indices 0/1, so those two variables are
    -- substituted by the corresponding projection of the pair-typed variables;
    -- all deeper variables are passed through unchanged.
    ezipDN (STPair a b) =
      EPair ext (subst (\_ t' -> \case IZ -> EFst ext (EVar ext (STPair (tanty a) (tanty b)) IZ)
                                       IS IZ -> EFst ext (EVar ext (STPair a b) (IS IZ))
                                       IS (IS i) -> EVar ext t' (IS (IS i)))
                       (ezipDN @env a))
                (subst (\_ t' -> \case IZ -> ESnd ext (EVar ext (STPair (tanty a) (tanty b)) IZ)
                                       IS IZ -> ESnd ext (EVar ext (STPair a b) (IS IZ))
                                       IS (IS i) -> EVar ext t' (IS (IS i)))
                       (ezipDN @env b))
    -- Sums: case on the primal first (index 1), then on the tangent. The
    -- tangent is referenced at index 1 in the inner case as well, which is
    -- consistent with each case branch binding its payload as a new index-0
    -- variable. Matching constructors recurse; mismatches yield an EError.
    ezipDN (STEither a b) =
      ecase (EVar ext (STEither a b) (IS IZ))
        (ecase (EVar ext (STEither (tanty a) (tanty b)) (IS IZ))
          (EInl ext (dn b) (ezipDN a))
          (EError ext (STEither (dn a) (dn b)) "jvp zip: either mismatch lr"))
        (ecase (EVar ext (STEither (tanty a) (tanty b)) (IS IZ))
          (EError ext (STEither (dn a) (dn b)) "jvp zip: either mismatch rl")
          (EInr ext (dn a) (ezipDN b)))
    -- Same scheme as STEither, with an additional nil alternative. A nil
    -- primal yields a nil result without inspecting the tangent.
    ezipDN (STLEither a b) =
      elcase (EVar ext (STLEither a b) (IS IZ))
        (ELNil ext (dn a) (dn b))
        (elcase (EVar ext (STLEither (tanty a) (tanty b)) (IS IZ))
          (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch lN")
          (ELInl ext (dn b) (ezipDN a))
          (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch lr"))
        (elcase (EVar ext (STLEither (tanty a) (tanty b)) (IS IZ))
          (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch rN")
          (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch rl")
          (ELInr ext (dn a) (ezipDN b)))
    -- A Nothing primal yields Nothing without inspecting the tangent; a Just
    -- primal with a Nothing tangent is reported as a mismatch.
    ezipDN (STMaybe t) =
      emaybe (EVar ext (STMaybe t) (IS IZ))
        (ENothing ext (dn t))
        (emaybe (EVar ext (STMaybe (tanty t)) (IS IZ))
          (EError ext (STMaybe (dn t)) "jvp zip: maybe mismatch jN")
          (EJust ext (ezipDN t)))
    -- Arrays: combine the primal array (index 1) and the tangent array
    -- (index 0) with 'ezipWith', using the element-level zip as the body.
    ezipDN (STArr n t) =
      ezipWith (ezipDN t)
               (EVar ext (STArr n t) (IS IZ)) (EVar ext (STArr n (tanty t)) IZ)
    -- Scalars: floating-point values become a (primal, tangent) pair; integers
    -- and booleans are represented by the primal value alone.
    ezipDN (STScal st) = case st of
      STF32 -> EPair ext (EVar ext (STScal STF32) (IS IZ)) (EVar ext (tanty (STScal STF32)) IZ)
      STF64 -> EPair ext (EVar ext (STScal STF64) (IS IZ)) (EVar ext (tanty (STScal STF64)) IZ)
      STI32 -> EVar ext (STScal STI32) (IS IZ)
      STI64 -> EVar ext (STScal STI64) (IS IZ)
      STBool -> EVar ext (STScal STBool) (IS IZ)
    ezipDN STAccum{} = error "jvp: Accumulators not supported in source program"

    -- Split the dual-numbers value at index 0 into a pair of its primal part
    -- and its tangent part, by recursion on the structure of the type.
    eunzipDN :: forall env t'. STy t' -> Ex (DN t' : env) (TPair t' (Tan t'))
    eunzipDN STNil = EPair ext (ENil ext) (ENil ext)
    -- Pairs: unzip both components (each applied to the corresponding
    -- projection of the variable at index 0), then regroup the four results
    -- into (primal pair, tangent pair). The weakenings w1/w2 supplied by
    -- 'eunPair' are applied to expressions built before the respective
    -- 'eunPair' so that they remain valid inside its continuation.
    eunzipDN (STPair a b) =
      eunPair (subst0 (EFst ext (EVar ext (STPair (dn a) (dn b)) IZ)) (eunzipDN a)) $ \w1 ea1 ea2 ->
      eunPair (weakenExpr w1 (subst0 (ESnd ext (EVar ext (STPair (dn a) (dn b)) IZ)) (eunzipDN b))) $ \w2 eb1 eb2 ->
        EPair ext (EPair ext (weakenExpr w2 ea1) eb1) (EPair ext (weakenExpr w2 ea2) eb2)
    -- Sums: unzip the payload and re-wrap both halves in the same constructor.
    eunzipDN (STEither a b) =
      ecase (EVar ext (STEither (dn a) (dn b)) IZ)
        (eunPair (eunzipDN a) $ \_ a1 a2 ->
          EPair ext (EInl ext b a1) (EInl ext (tanty b) a2))
        (eunPair (eunzipDN b) $ \_ b1 b2 ->
          EPair ext (EInr ext a b1) (EInr ext (tanty a) b2))
    eunzipDN (STLEither a b) =
      elcase (EVar ext (STLEither (dn a) (dn b)) IZ)
        (EPair ext (ELNil ext a b) (ELNil ext (tanty a) (tanty b)))
        (eunPair (eunzipDN a) $ \_ a1 a2 ->
          EPair ext (ELInl ext b a1) (ELInl ext (tanty b) a2))
        (eunPair (eunzipDN b) $ \_ b1 b2 ->
          EPair ext (ELInr ext a b1) (ELInr ext (tanty a) b2))
    eunzipDN (STMaybe t) =
      emaybe (EVar ext (STMaybe (dn t)) IZ)
        (EPair ext (ENothing ext t) (ENothing ext (tanty t)))
        (eunPair (eunzipDN t) $ \_ e1 e2 ->
          EPair ext (EJust ext e1) (EJust ext e2))
    -- Arrays: unzip every element once and bind the resulting array of pairs,
    -- then map the first and second projections over it to obtain the primal
    -- array and the tangent array.
    eunzipDN (STArr n t) =
      elet (emap (eunzipDN t) (EVar ext (STArr n (dn t)) IZ)) $
        EPair ext (emap (EFst ext (evar IZ)) (evar IZ))
                  (emap (ESnd ext (evar IZ)) (evar IZ))
    -- Scalars: a floating-point dual number already is the (primal, tangent)
    -- pair; integers and booleans get a nil tangent.
    eunzipDN (STScal st) = case st of
      STF32 -> EVar ext (STPair (STScal STF32) (STScal STF32)) IZ
      STF64 -> EVar ext (STPair (STScal STF64) (STScal STF64)) IZ
      STI32 -> EPair ext (EVar ext (STScal STI32) IZ) (ENil ext)
      STI64 -> EPair ext (EVar ext (STScal STI64) IZ) (ENil ext)
      STBool -> EPair ext (EVar ext (STScal STBool) IZ) (ENil ext)
    eunzipDN STAccum{} = error "jvp: Accumulators not supported in source program"

-- | Evaluate an expression with the interpreter, given a value for every
-- variable in its environment. Delegates to 'Interpreter.interpretOpen' with
-- its leading boolean argument fixed to 'False'.
interpret :: KnownEnv env => SList Value env -> Ex env t -> Rep t
interpret values term = Interpreter.interpretOpen False knownEnv values term

-- | Convenience wrapper around 'interpret' for an expression with exactly one
-- free variable: the argument is wrapped into a singleton environment.
interpret1 :: KnownTy s => Rep s -> Ex '[s] t -> Rep t
interpret1 x term = interpret (SCons (Value x) SNil) term

-- | Compile an expression to C, load the resulting shared object into the
-- running program, and return a Haskell function that invokes it on an
-- environment of values. Delegates to 'Compile.compileStderr'.
compile :: KnownEnv env => Ex env t -> IO (SList Value env -> IO (Rep t))
compile term = Compile.compileStderr knownEnv term

-- | Convenience variant of 'compile' for an expression with exactly one free
-- variable: the returned function takes that variable's value directly and
-- wraps it into a singleton environment before calling the compiled code.
compile1 :: KnownTy s => Ex '[s] t -> IO (Rep s -> IO (Rep t))
compile1 term =
  (\run x -> run (SCons (Value x) SNil)) <$> Compile.compileStderr knownEnv term

-- | Simplify an expression: run 'simplifyFix', then 'pruneExpr', then
-- 'simplifyFix' once more. The 'vjp'/'jvp' functions already do this
-- automatically.
fullSimplify :: KnownEnv env => Ex env t -> Ex env t
fullSimplify term = simplifyFix (pruneExpr knownEnv (simplifyFix term))