summaryrefslogtreecommitdiff
path: root/src/ForwardAD/DualNumbers.hs
blob: 3587378730ab65757cdd38abf3306bff50e8636c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- I want to bring various type variables into scope using type annotations in
-- patterns, without having to mention every other type parameter of the types
-- in question. Partial type signatures (with '_') make this possible.
{-# LANGUAGE PartialTypeSignatures #-}
{-# OPTIONS -Wno-partial-type-signatures #-}
module ForwardAD.DualNumbers (
  dfwdDN,
  DN, DNS, DNE, dn, dne,
) where

import AST
import Data
import ForwardAD.DualNumbers.Types


-- | Proof, by induction on @n@, that an @n@-tuple of 'TIx' scalars is its own
-- dual-number type.
dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx)
dnPreservesTupIx = \case
  SZ -> Refl
  SS m -> case dnPreservesTupIx m of
    -- The inductive hypothesis for the smaller tuple closes the successor case.
    Refl -> Refl

-- | Carry a variable index over to the dual-number environment. The position
-- in the environment stays the same; only the environment and the variable's
-- type are mapped through 'DNE' / 'DN'.
convIdx :: Idx env t -> Idx (DNE env) (DN t)
convIdx = \case
  IZ -> IZ
  IS rest -> IS (convIdx rest)

-- | Case split on a scalar type.
--
-- The floating-point types ('STF32', 'STF64') select the first continuation,
-- which receives evidence that the type is numeric and 'Fractional' and that
-- its dual-number type is a (primal, tangent) pair.
--
-- All other scalar types ('STI32', 'STI64', 'STBool') select the second
-- continuation, which receives evidence that the type is its own
-- dual-number type.
scalTyCase :: SScalTy t
           -> ((ScalIsNumeric t ~ True, Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t)) => r)
           -> (DN (TScal t) ~ TScal t => r)
           -> r
scalTyCase ty onFloat onOther = case ty of
  STF32  -> onFloat
  STF64  -> onFloat
  STI32  -> onOther
  STI64  -> onOther
  STBool -> onOther

-- | Dual-numbers derivative of a single primitive operation. Given an
-- expression for the dual-number encoding of the operation's argument, build
-- an expression for the dual-number encoding of its result.
--
-- Argument does not need to be duplicable: every branch either uses it once
-- or let-binds it first (see 'unFloat' and 'binFloat' below).
--
-- For floating-point scalar types (the first continuation of 'scalTyCase'), a
-- scalar dual number is a (primal, tangent) pair. For the other scalar types
-- the dual-number type is the type itself, so the original operation is
-- reused unchanged.
dop :: forall a b env. SOp a b -> Ex env (DN a) -> Ex env (DN b)
dop = \case
  -- d(x + y) = dx + dy
  OAdd t -> scalTyCase t
    (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy)))
    (EOp ext (OAdd t))
  -- Product rule: d(x * y) = dx * y + dy * x
  OMul t -> scalTyCase t
    (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x))))
    (EOp ext (OMul t))
  -- d(-x) = -dx
  ONeg t -> scalTyCase t
    (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx)))
    (EOp ext (ONeg t))
  -- Comparisons: the result is computed from the primals only. Tangents are
  -- discarded.
  OLt t -> scalTyCase t
    (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y)))
    (EOp ext (OLt t))
  OLe t -> scalTyCase t
    (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y)))
    (EOp ext (OLe t))
  OEq t -> scalTyCase t
    (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y)))
    (EOp ext (OEq t))
  -- For these operations the argument and result types are their own
  -- dual-number types, so the original operation is applied directly.
  ONot -> EOp ext ONot
  OAnd -> EOp ext OAnd
  OOr -> EOp ext OOr
  OIf -> EOp ext OIf
  -- Only the primal (first component) of the dual-number argument is
  -- rounded. The result is returned as-is, with no tangent component.
  ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg)
  -- The argument is converted unchanged. The resulting STF64 value is paired
  -- with a constant zero tangent.
  OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0)
  where
    -- Build a binary scalar addition node. EOp takes both operands as a pair.
    add :: ScalIsNumeric t ~ True
        => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
    add t a b = EOp ext (OAdd t) (EPair ext a b)

    -- Build a binary scalar multiplication node. EOp takes both operands as a pair.
    mul :: ScalIsNumeric t ~ True
        => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
    mul t a b = EOp ext (OMul t) (EPair ext a b)

    -- Build a unary scalar negation node.
    neg :: ScalIsNumeric t ~ True
        => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
    neg t = EOp ext (ONeg t)

    -- Let-bind a (primal, tangent) dual number once. Its two projections are
    -- passed to the continuation. The continuation is polymorphic in its
    -- environment because it is instantiated under the new let binding.
    unFloat :: DN a ~ TPair a a
            => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (DN b))
            -> Ex env (DN a) -> Ex env (DN b)
    unFloat f e =
      ELet ext e $
        let var = EVar ext (typeOf e) IZ
        in f (EFst ext var, ESnd ext var)

    -- Let-bind a pair of dual numbers ((x, dx), (y, dy)) once. The four
    -- projections are passed to the continuation as (x, dx) and (y, dy).
    binFloat :: (a ~ TPair s s, DN s ~ TPair s s)
             => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (DN b))
             -> Ex env (DN a) -> Ex env (DN b)
    binFloat f e =
      ELet ext e $
        let var = EVar ext (typeOf e) IZ
        in f (EFst ext (EFst ext var), ESnd ext (EFst ext var))
             (EFst ext (ESnd ext var), ESnd ext (ESnd ext var))

-- | Dual-numbers forward-mode AD transformation. An expression of type @t@ in
-- environment @env@ becomes one of type @DN t@ in environment @DNE env@.
--
-- Most constructors are translated structurally: recurse into the
-- subexpressions and map type annotations with 'dn'. Scalar constants, inner
-- sums and primitive operations ('dop') produce actual tangent code.
dfwdDN :: Ex env t -> Ex (DNE env) (DN t)
dfwdDN = \case
  EVar _ t i -> EVar ext (dn t) (convIdx i)
  ELet _ a b -> ELet ext (dfwdDN a) (dfwdDN b)
  EPair _ a b -> EPair ext (dfwdDN a) (dfwdDN b)
  EFst _ e -> EFst ext (dfwdDN e)
  ESnd _ e -> ESnd ext (dfwdDN e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext (dn t) (dfwdDN e)
  EInr _ t e -> EInr ext (dn t) (dfwdDN e)
  ECase _ e a b -> ECase ext (dfwdDN e) (dfwdDN a) (dfwdDN b)
  ENothing _ t -> ENothing ext (dn t)
  EJust _ e -> EJust ext (dfwdDN e)
  EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b)
  -- Float element types: pair every element with a zero tangent.
  -- Other scalar types are their own dual-number type and are copied unchanged.
  EConstArr _ n t x -> scalTyCase t
    (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0))
          (EConstArr ext n t x))
    (EConstArr ext n t x)
  -- 'dnPreservesTupIx' shows that the dual-number type of an n-tuple of TIx
  -- is the tuple itself, so the transformed shape/index expressions keep
  -- their original type.
  EBuild _ n a b
    | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b)
  EFold1Inner _ a b c -> EFold1Inner ext (dfwdDN a) (dfwdDN b) (dfwdDN c)
  -- For float elements, the transformed array holds (primal, tangent) pairs.
  -- It is let-bound once; primals and tangents are summed separately and the
  -- two results are zipped back into an array of pairs.
  -- Non-float elements are summed directly.
  ESum1Inner _ e ->
    let STArr n (STScal t) = typeOf e
        pairty = (STPair (STScal t) (STScal t))
    in scalTyCase t
         (ELet ext (dfwdDN e) $
            ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ))
                                       (EVar ext (STArr n pairty) IZ)))
                 (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ))
                                       (EVar ext (STArr n pairty) IZ))))
         (ESum1Inner ext (dfwdDN e))
  EUnit _ e -> EUnit ext (dfwdDN e)
  EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b)
  -- Float constants get a zero tangent. Other scalar constants are unchanged.
  EConst _ t x -> scalTyCase t
    (EPair ext (EConst ext t x) (EConst ext t 0.0))
    (EConst ext t x)
  EIdx0 _ e -> EIdx0 ext (dfwdDN e)
  EIdx1 _ a b -> EIdx1 ext (dfwdDN a) (dfwdDN b)
  -- As for EBuild: the index tuple is its own dual-number type.
  EIdx _ a b
    | STArr n _ <- typeOf a
    , Refl <- dnPreservesTupIx n
    -> EIdx ext (dfwdDN a) (dfwdDN b)
  -- The shape tuple is its own dual-number type (see 'dnPreservesTupIx').
  EShape _ e
    | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n)
    -> EShape ext (dfwdDN e)
  -- Primitive operations are differentiated by 'dop'.
  EOp _ op e -> dop op (dfwdDN e)
  ECustom _ s t _ du _ e1 e2 ->
    -- TODO: we need a bit of codegen here that projects the primals out from
    -- the dual number result of e1. Note that a non-differentiating code
    -- transformation does not eliminate the need for this, because then the
    -- need just shifts to free variable adaptor code.
    --
    -- NOTE(review): `_ e1` below is a typed hole. This module does not compile
    -- unless typed holes are deferred (-fdefer-typed-holes). The hole stands in
    -- for the primal-projection code described in the TODO above.
    --
    -- The result let-binds (the converted) e1, then the transformed e2
    -- (weakened past the first binding), and runs `du` weakened to see those
    -- two bindings. Presumably `du` is closed over exactly these two
    -- variables; confirm against the ECustom definition.
    ELet ext (_ e1) $
    ELet ext (weakenExpr WSink (dfwdDN e2)) $
      weakenExpr (WCopy (WCopy WClosed)) du
  EError t s -> EError (dn t) s

  -- Accumulator and monoid operations are rejected: each of these cases
  -- evaluates to an `error` call.
  EWith{} -> err_accum
  EAccum{} -> err_accum
  EZero{} -> err_monoid
  EPlus{} -> err_monoid
  EOneHot{} -> err_monoid
  where
    err_accum = error "Accumulator operations unsupported in the source program"
    err_monoid = error "Monoid operations unsupported in the source program"