aboutsummaryrefslogtreecommitdiff
path: root/AD.hs
blob: 76fefe440746ce84d8385d6bf0fe44003dc7e6e3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module AD (
    Dual,
    dual,
    ad,
) where

import Data.Bifunctor
import Data.Type.Equality

import AST
import qualified Language as L
import Sink


-- | Dual-number counterpart of a type.
--
-- Each 'Double' becomes a @(primal, tangent)@ pair. 'Int', 'Bool' and '()'
-- carry no tangent and map to themselves. Pairs, arrays (element type only;
-- the shape index @sh@ is kept as-is) and functions are dualised
-- structurally.
type family Dual a where
    Dual Double = (Double, Double)
    Dual Int = Int
    Dual Bool = Bool
    Dual () = ()
    Dual (l, r) = (Dual l, Dual r)
    Dual (Array sh e) = Array sh (Dual e)
    Dual (dom -> cod) = Dual dom -> Dual cod

-- | Relates a source environment to the environment of the differentiated
-- term.
--
-- The 'ETop' tail is the outer environment: those variables keep their
-- original types. Each 'ECons' layer records a variable bound inside the
-- term being differentiated; in the output that variable has type
-- @'Dual' t@.
data DEnv env env' where
    ETop :: DEnv outer outer
    ECons :: Type t -> DEnv rest rest' -> DEnv (t ': rest) (Dual t ': rest')

-- | Term-level witness of 'Dual': maps a type representation to the
-- representation of its dual-number type, following the equations of the
-- 'Dual' type family.
dual :: Type a -> Type (Dual a)
dual = \case
    TDouble -> TPair TDouble TDouble
    TInt -> TInt
    TBool -> TBool
    TNil -> TNil
    TPair l r -> TPair (dual l) (dual r)
    TArray shTy el -> TArray shTy (dual el)
    TFun dom cod -> TFun (dual dom) (dual cod)

-- | Forward-mode automatic differentiation.
--
-- Every variable of @env@ is treated as free. It keeps its original type
-- in the result and is lifted into dual form with 'duale', i.e. with zero
-- tangents.
--
-- Because @'Dual' (a -> b) = 'Dual' a -> 'Dual' b@, applying 'ad' to a
-- function-typed expression gives the most useful result: a function that
-- propagates the tangent supplied with its argument.
ad :: Exp env a -> Exp env (Dual a)
ad e = ad' ETop e

-- | Forward AD relative to a 'DEnv'.
--
-- Variables bound inside the term being differentiated (the 'ECons' part of
-- the environment) are already dual-typed in the output and are referenced
-- directly. Variables from the outer ('ETop') part keep their original type
-- and are lifted with 'duale', so they contribute zero tangents.
--
-- Calls 'error' on a 'Const' primitive that has no forward-mode rule here.
ad' :: DEnv env env' -> Exp env a -> Exp env' (Dual a)
ad' env = \case
    App e1 e2 -> App (ad' env e1) (ad' env e2)
    Lam t e -> Lam (dual t) (ad' (ECons t env) e)
    Var t i ->
        case convIdx env i of
          Left freei -> App (duale t) (Var t freei)
          Right duali -> Var (dual t) duali
    Let e1 e2 -> Let (ad' env e1) (ad' (ECons (typeof e1) env) e2)
    Lit l -> App (duale (literalType l)) (Lit l)
    Cond e1 e2 e3 -> Cond (ad' env e1) (ad' env e2) (ad' env e3)
    -- Integer primitives carry no tangent ('Dual' Int = Int), so each is its
    -- own dual.
    Const CAddI -> Const CAddI
    Const CSubI -> Const CSubI
    Const CMulI -> Const CMulI
    Const CDivI -> Const CDivI
    -- d(x + y) = dx + dy
    Const CAddF -> dualBinopF $ \x dx y dy -> (addF x y, addF dx dy)
    -- d(x - y) = dx - dy
    Const CSubF -> dualBinopF $ \x dx y dy -> (subF x y, subF dx dy)
    -- Product rule: d(x * y) = x * dy + y * dx
    Const CMulF -> dualBinopF $ \x dx y dy ->
        (mulF x y, addF (mulF x dy) (mulF y dx))
    Const _ -> error "AD.ad': no forward-mode derivative implemented for this primitive"
    Pair e1 e2 -> Pair (ad' env e1) (ad' env e2)
    Fst e -> Fst (ad' env e)
    Snd e -> Snd (ad' env e)
    -- Shape types equal their own 'Dual' ('prfDualSht'), so shape arguments
    -- and results pass through unchanged.
    Build sht e1 e2
      | Refl <- prfDualSht sht
      -> Build sht (ad' env e1) (ad' env e2)
    Ifold sht e1 e2 e3
      | Refl <- prfDualSht sht
      -> Ifold sht (ad' env e1) (ad' env e2) (ad' env e3)
    Index e1 e2
      | TArray sht _ <- typeof e1
      , Refl <- prfDualSht sht
      -> Index (ad' env e1) (ad' env e2)
    Shape e
      | TArray sht _ <- typeof e
      , Refl <- prfDualSht sht
      -> Shape (ad' env e)
  where
    -- Build the dual of a binary Double primitive. The result is a lambda over
    -- the dual argument ((x, dx), (y, dy)). The rule receives the components
    -- in the order x, dx, y, dy and returns (primal result, tangent result).
    dualBinopF rule =
        let argTy = TPair (TPair TDouble TDouble) (TPair TDouble TDouble)
            arg = Var argTy Zero
            (primal, tangent) =
                rule (Fst (Fst arg)) (Snd (Fst arg)) (Fst (Snd arg)) (Snd (Snd arg))
        in Lam argTy (Pair primal tangent)

    -- Apply a primal Double primitive to two operand expressions.
    addF a b = App (Const CAddF) (Pair a b)
    subF a b = App (Const CSubF) (Pair a b)
    mulF a b = App (Const CMulF) (Pair a b)

-- | Translate a variable index into the dualised environment.
--
-- Returns 'Right' when the variable was bound inside the differentiated
-- term (an 'ECons' layer), so it now has a 'Dual' type. Returns 'Left' when
-- it lies in the outer 'ETop' part and keeps its original type.
convIdx :: DEnv env env' -> Idx env a -> Either (Idx env' a) (Idx env' (Dual a))
convIdx ETop idx = Left idx
convIdx (ECons _ _) Zero = Right Zero
convIdx (ECons _ rest) (Succ idx) =
    either (Left . Succ) (Right . Succ) (convIdx rest idx)

-- | Lambda lifting a value of type @a@ into its 'Dual' type.
--
-- Every 'Double' gets a zero tangent. A function value is wrapped so that
-- its dual argument is first projected back with 'unduale'. The argument's
-- tangents are therefore discarded, and the result is lifted again with zero
-- tangents.
duale :: Type a -> Exp env (a -> Dual a)
duale rootTy = Lam rootTy (conv rootTy)
  where
    -- Dualise the innermost variable (de Bruijn 'Zero') of type @t@.
    conv :: forall t e. Type t -> Exp (t ': e) (Dual t)
    conv ty = case ty of
        TNil -> Lit LNil
        TInt -> arg
        TBool -> arg
        TDouble -> Pair arg (Lit (LDouble 0))
        TPair ty1 ty2 ->
            Pair (Let (Fst arg) (conv ty1))
                 (Let (Snd arg) (conv ty2))
        TArray _ elTy -> L.map (Lam elTy (conv elTy)) arg
        TFun ty1 ty2 ->
            -- \x' -> duale (f (unduale x')), with f weakened past x'.
            Lam (dual ty1)
                (App (duale ty2)
                     (App (sinkExp1 arg)
                          (App (unduale ty1) (Var (dual ty1) Zero))))
      where
        arg :: Exp (t ': e) t
        arg = Var ty Zero

-- | Lambda projecting a 'Dual' value back to its primal type.
--
-- Every tangent is dropped (a dual 'Double' keeps only its first
-- component). A dual function is wrapped so that its primal argument is
-- lifted with 'duale' (zero tangents) and the result is projected again.
unduale :: Type a -> Exp env (Dual a -> a)
unduale rootTy = Lam (dual rootTy) (conv rootTy)
  where
    -- Project the innermost variable (de Bruijn 'Zero'), of type 'Dual' @t@.
    conv :: forall t e. Type t -> Exp (Dual t ': e) t
    conv ty = case ty of
        TNil -> Lit LNil
        TInt -> arg
        TBool -> arg
        TDouble -> Fst arg
        TPair ty1 ty2 ->
            Pair (Let (Fst arg) (conv ty1))
                 (Let (Snd arg) (conv ty2))
        TArray _ elTy -> L.map (Lam (dual elTy) (conv elTy)) arg
        TFun ty1 ty2 ->
            -- \x -> unduale (f' (duale x)), with f' weakened past x.
            Lam ty1
                (App (unduale ty2)
                     (App (sinkExp1 arg)
                          (App (duale ty1) (Var ty1 Zero))))
      where
        arg :: Exp (Dual t ': e) (Dual t)
        arg = Var (dual ty) Zero

-- | Proof that every shape type is its own 'Dual', by induction on the
-- shape witness. NOTE(review): this relies on shapes being built only from
-- tangent-free components (presumably '()' and 'Int'); confirm against
-- 'ShapeType' in AST.
prfDualSht :: ShapeType sh -> sh :~: Dual sh
prfDualSht = \case
    STZ -> Refl
    STC inner -> case prfDualSht inner of
        Refl -> Refl