summaryrefslogtreecommitdiff
path: root/src/Example.hs
blob: 6fd19cdcdeae5ef6d9f714ac06ec22ac1a34794a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
module Example where

import AST
import AST.Pretty
import CHAD
import Data
import Language
import Simplify


-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0)


-- Apply a binary primitive operator: the two operand expressions are packed
-- into a single pair expression, which becomes the operator's argument.
bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
bin op lhs rhs = EOp ext op operands
  where
    operands = EPair ext lhs rhs

-- Singleton type environment with two TScal TF32 entries (the environment
-- shape used by ex1 through ex4).
senv1 :: SList STy [TScal TF32, TScal TF32]
senv1 = SCons f32 (SCons f32 SNil)
  where
    f32 = STScal STF32

-- Descriptor for a two-scalar environment. The first storage argument is
-- pushed first and the second one last, which is why the storage list in the
-- result type reads [b, a].
descr1 :: Storage a -> Storage b
       -> Descr [TScal TF32, TScal TF32] [b, a]
descr1 sa sb = DPush (DPush DTop (f32, sa)) (f32, sb)
  where
    f32 = STScal STF32

-- x y |- x * y + x
--
-- Recorded program for this example (presumably pretty-printed output of the
-- differentiated term; TODO confirm how it was produced):
--
-- let x3 = (x1, x2)
--     x4 = ((*) x3, x1)
--  in ( (+) x4
--     , let x5 = 1.0
--           x6 = Inr (x5, x5)
--        in case x6 of
--             Inl x7 -> return ()
--             Inr x8 ->
--               let x9 = fst x8
--                   x10 = Inr (snd x3 * x9, fst x3 * x9)
--                in case x10 of
--                     Inl x11 -> return ()
--                     Inr x12 ->
--                       let x13 = fst x12
--                        in one "v1" x13 >>= \x14 ->
--                             let x15 = snd x12
--                              in one "v2" x15 >>= \x16 ->
--                     let x17 = snd x8
--                      in one "v1" x17)
--
-- A condensed form of the same program (presumably after simplification):
--
-- ( (x1 * x2) + x1
-- , let x5 = 1.0
--   in do one "v1" (x2 * x5)
--         one "v2" (x1 * x5)
--         one "v1" x5)
ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex1 = scopeCheck (lambda (\x -> lambda (\y -> body (x * y + x))))

-- x y |- let z = x + y in z * (z + x)
--
-- The sum is bound once with let_ and then used twice in the product.
ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex2 =
  scopeCheck $
    lambda $ \x ->
      lambda $ \y ->
        body (let_ (x + y) (\z -> z * (z + x)))

-- x y |- if x < y then 2 * x else 3 * x
--
-- NOTE(review): this header previously read "else 3 + x" (the else-branch
-- that ex4 uses), but the code multiplies. The comment now describes the
-- code; confirm that 3 * x is the intended else-branch.
ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex3 =
  scopeCheck $
    lambda $ \x ->
      lambda $ \y ->
        body (if_ (x .< y) (2 * x) (3 * x))

-- x y |- if x < y then 2 * x + y * y else 3 + x
--
-- Only the then-branch mentions y; the else-branch depends on x alone.
ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex4 =
  scopeCheck $
    lambda $ \x ->
      lambda $ \y ->
        body (if_ (x .< y) (2 * x + y * y) (3 + x))

-- Singleton type environment for ex5: a TF32 scalar followed by an Either of
-- two TF32 scalars.
senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)]
senv5 = SCons scalarTy (SCons (STEither scalarTy scalarTy) SNil)
  where
    scalarTy = STScal STF32

-- Descriptor matching senv5. The Either-typed entry is pushed first with the
-- first storage argument, then the scalar entry with the second one, hence
-- the [b, a] storage list in the result type.
descr5 :: Storage a -> Storage b
       -> Descr [TScal TF32, TEither (TScal TF32) (TScal TF32)] [b, a]
descr5 sa sb = DPush (DPush DTop (sumTy, sa)) (scalarTy, sb)
  where
    scalarTy = STScal STF32
    sumTy = STEither scalarTy scalarTy

-- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)}
--
-- Both branches of the case use the scalar y from the enclosing scope.
ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
ex5 =
  scopeCheck $
    lambda $ \x ->
      lambda $ \y ->
        body (case_ x (\a -> a * y) (\b -> b * (y + 1)))

-- Singleton type environment for ex6: a TI64 scalar followed by a TF32 scalar.
senv6 :: SList STy [TScal TI64, TScal TF32]
senv6 = SCons (STScal STI64) (SCons (STScal STF32) SNil)

-- Descriptor matching senv6, with SMerge storage for both entries. The TF32
-- entry is pushed first, the TI64 entry last.
descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"]
descr6 = DPush (DPush DTop (STScal STF32, SMerge)) (STScal STI64, SMerge)

-- x:R n:I |- let a = unit x
--                b = build1 n (\i. let c = idx0 a in c * c)
--            in idx0 (b ! 3)
--
-- The function passed to build1 ignores its index argument; every element is
-- built from the same value read out of a.
ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)
ex6 =
  scopeCheck $
    lambda $ \x ->
      lambda $ \n ->
        body
          ( let_ (unit x) $ \a ->
              let_ (build1 n (\_ -> let_ (idx0 a) (\c -> c * c))) $ \b ->
                idx0 (b .! 3)
          )

-- Shorthand for the scalar type TScal TF32; the example comments in this
-- module write it as R.
type R = TScal TF32

-- Singleton type environment for ex7: the scalar input first, then the
-- left-nested tuple holding three (R, R) parameter pairs.
senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)]
senv7 = SCons scalarTy (SCons paramsTy SNil)
  where
    scalarTy = STScal STF32
    layerTy = STPair scalarTy scalarTy
    paramsTy = STPair (STPair (STPair STNil layerTy) layerTy) layerTy

-- Descriptor matching senv7, with SMerge storage for both entries. The
-- parameter tuple is pushed first, the scalar input last.
descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"]
descr7 = DPush (DPush DTop (paramsTy, SMerge)) (scalarTy, SMerge)
  where
    scalarTy = STScal STF32
    layerTy = STPair scalarTy scalarTy
    paramsTy = STPair (STPair (STPair STNil layerTy) layerTy) layerTy

-- A "neural network" except it's just scalars, not matrices. Each layer is
-- applied to the result of the previous layer.
-- ps:((((), (R,R)), (R,R)), (R,R)) x:R
-- |- let p1 = snd ps
--        p1' = fst ps
--        x1 = fst p1 * x + snd p1
--        p2 = snd p1'
--        p2' = fst p1'
--        x2 = fst p2 * x1 + snd p2
--        p3 = snd p2'
--        p3' = fst p2'
--        x3 = fst p3 * x2 + snd p3
--    in x3
ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R
ex7 =
  scopeCheck $
    lambda $ \pars123 ->
      lambda $ \input ->
        body (layer allLayersTy pars123 input)
  where
    scalarTy = STScal STF32
    layerTy = STPair scalarTy scalarTy
    -- Type witness for the full parameter tuple; layer recurses over it.
    allLayersTy = STPair (STPair (STPair STNil layerTy) layerTy) layerTy

    -- Peel the outermost (R, R) pair off the parameter tuple, apply it to the
    -- current value, and continue with the remaining parameters. The witness
    -- argument decides how many layers are left; STNil ends the recursion.
    layer :: STy p -> SExpr p -> SExpr R -> SExpr R
    layer (STPair restTy (STPair (STScal STF32) (STScal STF32))) params activation
      | Dict <- styKnown restTy =
          let_ (snd_ params) $ \weights ->
            let_ (fst_ params) $ \restParams ->
              let_ (fst_ weights * activation + snd_ weights) $ \out ->
                layer restTy restParams out
    layer STNil _ activation = activation
    -- Any other witness shape is not a valid parameter tuple.
    layer _ _ _ = error "Invalid layer inputs"