| 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
 | {-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
module Example where
import Array
import AST
import AST.Pretty
import CHAD
import CHAD.Top
import ForwardAD
import Interpreter
import Language
import Simplify
import Debug.Trace
import Example.Types
-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0)
-- | Apply a binary primitive to two operands by first pairing them:
-- @bin op a b@ is the node @EOp ext op (EPair ext a b)@.
bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
bin op lhs = EOp ext op . EPair ext lhs
-- | Environment witness with two @TScal TF32@ entries (the environment type
-- shared by 'ex1' through 'ex4').
senv1 :: SList STy [TScal TF32, TScal TF32]
senv1 = SCons f32 (SCons f32 SNil)
  where f32 = STScal STF32
-- | Descriptor for a two-entry environment of @TScal TF32@ values. Storage
-- @a@ is attached to the first-pushed entry (last in the list) and storage
-- @b@ to the most recently pushed entry (head of the list).
descr1 :: Storage a -> Storage b
       -> Descr [TScal TF32, TScal TF32] [b, a]
descr1 storA storB = DPush (DPush DTop (f32, storA)) (f32, storB)
  where f32 = STScal STF32
-- x y |- x * y + x
--
-- let x3 = (x1, x2)
--     x4 = ((*) x3, x1)
--  in ( (+) x4
--     , let x5 = 1.0
--           x6 = Inr (x5, x5)
--        in case x6 of
--             Inl x7 -> return ()
--             Inr x8 ->
--               let x9 = fst x8
--                   x10 = Inr (snd x3 * x9, fst x3 * x9)
--                in case x10 of
--                     Inl x11 -> return ()
--                     Inr x12 ->
--                       let x13 = fst x12
--                        in one "v1" x13 >>= \x14 ->
--                             let x15 = snd x12
--                              in one "v2" x15 >>= \x16 ->
--                     let x17 = snd x8
--                      in one "v1" x17)
--
-- ( (x1 * x2) + x1
-- , let x5 = 1.0
--   in do one "v1" (x2 * x5)
--         one "v2" (x1 * x5)
--         one "v1" x5)
-- | Example term @x * y + x@ over two f32 scalars @x@ and @y@.
ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex1 = fromNamed (lambda #x (lambda #y (body (#x * #y + #x))))
-- | Example term @x y |- let z = x + y in z * (z + x)@: a let-bound sum
-- that is used twice in the body.
ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex2 =
  fromNamed (lambda #x (lambda #y (body
    (let_ #z (#x + #y) (#z * (#z + #x))))))
-- x y |- if x < y then 2 * x else 3 * x
--
-- NOTE(review): the else-branch is a product (3 * x), unlike ex4's sum
-- (3 + x). If a sum was intended here too, change the code, not this comment.
ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex3 = fromNamed $ lambda #x $ lambda #y $ body $
  if_ (#x .< #y) (2 * #x) (3 * #x)
-- | Example term @x y |- if x < y then 2 * x + y * y else 3 + x@: a
-- conditional whose true branch also depends on @y@.
ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex4 = fromNamed (lambda #x (lambda #y (body
        (if_ (#x .< #y)
             (2 * #x + #y * #y)
             (3 + #x)))))
-- | Environment witness matching the environment of 'ex5': a @TScal TF32@
-- entry (@y@) at the head and a @TEither@ of two f32 scalars (@x@) after it.
-- The value is fully determined by the signature; 'knownEnv' presumably
-- derives it from the type via a class instance (confirm in its definition).
senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)]
senv5 = knownEnv
-- | Descriptor for the environment of 'ex5'. Storage @a@ goes to the
-- first-pushed entry (the @TEither@, last in the list) and storage @b@ to
-- the scalar at the head of the list.
descr5 :: Storage a -> Storage b
       -> Descr [TScal TF32, TEither (TScal TF32) (TScal TF32)] [b, a]
descr5 storA storB = DPush (DPush DTop (knownTy, storA)) (knownTy, storB)
-- | Example term over a sum-typed input:
-- @x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)}@
ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
ex5 = fromNamed (lambda #x (lambda #y (body
        (case_ #x
           (#a :-> #a * #y)
           (#b :-> #b * (#y + 1))))))
-- | Environment witness matching the environment of 'ex6': the @TScal TI64@
-- entry for @n@ at the head and the @TScal TF32@ entry for @x@ after it.
-- 'knownEnv' presumably derives the value from the signature (class-based).
senv6 :: SList STy [TScal TI64, TScal TF32]
senv6 = knownEnv
-- | Descriptor for the environment of 'ex6', using @SMerge@ storage for
-- both entries.
descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"]
descr6 = DPush (DPush DTop (knownTy, SMerge)) (knownTy, SMerge)
-- x:R n:I |- let a = unit x
--                b = build1 n (\i. let c = idx0 a in c * c)
--            in b ! 3
--
-- Indexing b with the one-component index (nil, 3) already yields the
-- scalar result (the term's type is TScal TF32), so no idx0 is applied.
ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)
ex6 = fromNamed $ lambda #x $ lambda #n $ body $
  let_ #a (unit #x) $
  let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $
  #b ! pair nil 3
-- | Environment witness matching the environment of 'ex7': the scalar input
-- @R@ at the head, followed by the nested tuple of three (weight, bias)
-- pairs. 'knownEnv' presumably derives the value from the signature.
senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)]
senv7 = knownEnv
-- | Descriptor for the environment of 'ex7', using @SMerge@ storage for
-- both entries.
descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"]
descr7 = DPush (DPush DTop (knownTy, SMerge)) (knownTy, SMerge)
-- A "neural network" except it's just scalars, not matrices.
-- Each layer computes w * inp + b from one (w, b) pair and feeds its output
-- into the next layer as the new input.
-- ps:((((), (R,R)), (R,R)), (R,R)) x:R
-- |- let p1 = snd ps
--        p1' = fst ps
--        x1 = fst p1 * x + snd p1
--        p2 = snd p1'
--        p2' = fst p1'
--        x2 = fst p2 * x1 + snd p2
--        p3 = snd p2'
--        p3' = fst p2'
--        x3 = fst p3 * x2 + snd p3
--    in x3
ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R
ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $
  let tR = STScal STF64
      tpair = STPair tR tR
      -- Peel the rightmost (weight, bias) pair off #parstup, rebind #inp to
      -- this layer's output, rebind #parstup to the remaining tuple, and
      -- recurse. The STy argument mirrors the type of #parstup and drives
      -- the recursion.
      layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ R)
            => STy p -> NExpr env R
      layer (STPair t (STPair (STScal STF64) (STScal STF64))) | Dict <- styKnown t =
        let_ #par (snd_ #parstup) $
        let_ #restpars (fst_ #parstup) $
        let_ #inp (fst_ #par * #inp + snd_ #par) $
        let_ #parstup #restpars $
        layer t
      -- All pairs consumed: the current #inp is the network's output.
      layer STNil = #inp
      -- Not reached by the call below, whose type is built only from STNil
      -- and pairs of STScal STF64 scalars.
      layer _ = error "Invalid layer inputs"
  in let_ #parstup #pars123 $
     let_ #inp #input $
     layer (STPair (STPair (STPair STNil tpair) tpair) tpair)
-- | A small dense neural network producing a scalar. It has two
-- fully-connected layers with ReLU activation, followed by a dot product
-- with a weight vector.
--
-- Lambda-bound arguments, outermost first:
--
-- * @layer1@ and @layer2@ are (weight matrix, bias vector) pairs.
-- * @layer3@ is the final weight vector.
-- * @input@ is the input vector.
--
-- The type-level environment lists them in reverse order (innermost binding
-- first).
neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R
neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $
  -- One dense layer as a three-argument function (wei, bias, x) that is
  -- reused for both hidden layers below.
  let layer = lambda @(TMat R) #wei $ lambda @(TVec R) #bias $ lambda @(TVec R) #x $ body $
        -- prod = wei `matmul` x
        -- (build over wei's 2-D shape computing wei[i,j] * x[j]; sum1i then
        -- presumably sums out the innermost index j)
        let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :->
                              #wei ! #idx * #x ! pair nil (snd_ #idx)) $
        -- relu (prod + bias)
        build (SS SZ) (shape #prod) $ #idx :->
          let_ #out (#prod ! #idx + #bias ! #idx) $
          if_ (#out .<= const_ 0) (const_ 0) #out
  -- Apply the layer function to each hidden layer's parameters; x3 is the
  -- summed elementwise product of x2 and layer3. Indexing it with the empty
  -- index (nil) presumably extracts the scalar from a rank-0 array.
  in let_ #x1 (inline layer (SNil .$ fst_ #layer1 .$ snd_ #layer1 .$ #input)) $
     let_ #x2 (inline layer (SNil .$ fst_ #layer2 .$ snd_ #layer2 .$ #x1)) $
     let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $
     #x3 ! nil
-- | Gradient with respect to every input of 'neural', in the order that
-- 'neuralGo' fills it in:
--
-- * the (weights, bias) cotangents of layer1;
-- * the (weights, bias) cotangents of layer2;
-- * the layer3 vector;
-- * the input vector.
type NeuralGrad =
  ( (Array N2 Double, Array N1 Double)
  , (Array N2 Double, Array N1 Double)
  , Array N1 Double
  , Array N1 Double
  )
-- | Run 'neural' on a small fixed input and compute its gradient in two
-- independent ways so the results can be compared:
--
-- * CHAD reverse-mode AD, simplified and then interpreted;
-- * dual-numbers forward AD ('drevByFwdInterp').
--
-- Both hidden layers use a 2x2 all-ones weight matrix and a zero bias. The
-- final weight vector and the input are both @[1,1]@. As a debugging aid,
-- the simplified CHAD program is pretty-printed with 'trace' when the result
-- is evaluated.
neuralGo :: (Double  -- primal
            ,NeuralGrad  -- gradient using CHAD
            ,NeuralGrad)  -- gradient using dual-numbers forward AD
neuralGo =
  let lay1 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
      lay2 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
      lay3 = arrayFromList (ShNil `ShCons` 2) [1,1]
      input = arrayFromList (ShNil `ShCons` 2) [1,1]
      -- Argument values in environment order: innermost binding (input) first.
      argument = (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil)
      -- CHAD-transformed program with the constant 1.0 let-bound around it
      -- (presumably the cotangent of the scalar output), then simplified.
      revderiv =
        simplifyN 20 $
          ELet ext (EConst ext STF64 1.0) $
            chad defaultConfig knownEnv neural
      (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of
        (primal', (((((), Just dlay1_1'), Just dlay2_1'), dlay3_1'), dinput_1')) -> (primal', dlay1_1', dlay2_1', dlay3_1', dinput_1')
        -- The layer1/layer2 cotangents are Maybe-wrapped. Fail with a
        -- descriptive message (not a bare 'undefined') if either is absent.
        _ -> error "neuralGo: CHAD gradient for layer1 or layer2 is Nothing"
      (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0
  in trace (ppExpr knownEnv revderiv) $
     (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2))
 |