summaryrefslogtreecommitdiff
path: root/src/Example.hs
blob: ee91981577f278fb7be7c0e2b8261856c5e4c1e7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
{-# LANGUAGE DataKinds #-}
module Example where

import Data.Some

import AST
import AST.Pretty
import CHAD
import Simplify


-- | Apply a binary primitive operator to two argument expressions: the
-- operands are combined with 'EPair' and that pair is handed to 'EOp'.
bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
bin op lhs rhs =
  let operands = EPair ext lhs rhs
  in EOp ext op operands

-- | The type environment shared by 'ex1' through 'ex4', as an 'SList' of
-- 'STy' values: two 'TF32' scalars.
senv1 :: SList STy [TScal TF32, TScal TF32]
senv1 = SCons f32 (SCons f32 SNil)
  where
    f32 = STScal STF32

-- | Storage descriptor for the two-scalar environment of 'ex1' through
-- 'ex4'. The first argument is paired with the entry pushed first (the outer
-- variable, reached as @IS IZ@) and the second with the entry pushed last
-- (the innermost variable, 'IZ'). This matches the storage list @[b, a]@ in
-- the result type.
descr1 :: Storage a -> Storage b
       -> Descr [TScal TF32, TScal TF32] [b, a]
descr1 outer inner = DPush (DPush DTop (f32, outer)) (f32, inner)
  where
    f32 = STScal STF32

-- x y |- x * y + x
--
-- let x3 = (x1, x2)
--     x4 = ((*) x3, x1)
--  in ( (+) x4
--     , let x5 = 1.0
--           x6 = Inr (x5, x5)
--        in case x6 of
--             Inl x7 -> return ()
--             Inr x8 ->
--               let x9 = fst x8
--                   x10 = Inr (snd x3 * x9, fst x3 * x9)
--                in case x10 of
--                     Inl x11 -> return ()
--                     Inr x12 ->
--                       let x13 = fst x12
--                        in one "v1" x13 >>= \x14 ->
--                             let x15 = snd x12
--                              in one "v2" x15 >>= \x16 ->
--                     let x17 = snd x8
--                      in one "v1" x17)
--
-- ( (x1 * x2) + x1
-- , let x5 = 1.0
--   in do one "v1" (x2 * x5)
--         one "v2" (x1 * x5)
--         one "v1" x5)
-- | @x y |- x * y + x@, where @y@ is the innermost variable ('IZ') and @x@
-- is the outer one (@IS IZ@).
ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex1 = bin (OAdd STF32) (bin (OMul STF32) x y) x
  where
    x = EVar ext (STScal STF32) (IS IZ)
    y = EVar ext (STScal STF32) IZ

-- | @x y |- let z = x + y in z * (z + x)@.
--
-- Under the 'ELet' one more variable is in scope, so @z@ is 'IZ' and @x@
-- moves from @IS IZ@ to @IS (IS IZ)@.
ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex2 =
  ELet ext (addF32 (f32Var (IS IZ)) (f32Var IZ)) $
    mulF32 (f32Var IZ) (addF32 (f32Var IZ) (f32Var (IS (IS IZ))))
  where
    -- Closed function bindings, so they stay polymorphic in the
    -- environment and can be used both outside and inside the 'ELet'.
    f32Var idx = EVar ext (STScal STF32) idx
    addF32 l r = bin (OAdd STF32) l r
    mulF32 l r = bin (OMul STF32) l r

-- | @x y |- if x < y then 2 * x else 3 + x@.
--
-- The comparison goes through 'OIf' and is scrutinised with 'ECase'. Each
-- branch binds one extra, unused variable, so inside the branches @x@ is
-- @IS (IS IZ)@.
ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex3 = ECase ext cond thenBranch elseBranch
  where
    cond = EOp ext OIf (bin (OLt STF32) (f32Var (IS IZ)) (f32Var IZ))
    thenBranch =
      bin (OMul STF32) (EConst ext STF32 2.0) (f32Var (IS (IS IZ)))
    elseBranch =
      bin (OAdd STF32) (EConst ext STF32 3.0) (f32Var (IS (IS IZ)))
    f32Var idx = EVar ext (STScal STF32) idx

-- | @x y |- if x < y then 2 * x + y * y else 3 + x@.
--
-- Each 'ECase' branch binds one extra, unused variable. Inside the branches
-- @y@ is therefore @IS IZ@ and @x@ is @IS (IS IZ)@.
ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex4 =
  ECase ext
    (EOp ext OIf (lessThan (f32Var (IS IZ)) (f32Var IZ)))
    (plus (times (EConst ext STF32 2.0) x) (times y y))
    (plus (EConst ext STF32 3.0) x)
  where
    -- x and y as seen from inside either branch
    x = f32Var (IS (IS IZ))
    y = f32Var (IS IZ)
    f32Var idx = EVar ext (STScal STF32) idx
    lessThan l r = bin (OLt STF32) l r
    plus l r = bin (OAdd STF32) l r
    times l r = bin (OMul STF32) l r

-- | The type environment of 'ex5', as an 'SList' of 'STy' values. The head
-- (innermost variable) is a 'TF32' scalar; the next entry is a sum of two
-- 'TF32' scalars.
senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)]
senv5 = SCons f32 (SCons (STEither f32 f32) SNil)
  where
    f32 = STScal STF32

-- | Storage descriptor for the environment of 'ex5'. The first argument is
-- paired with the sum-typed entry, which is pushed first (reached as
-- @IS IZ@). The second argument is paired with the scalar entry, which is
-- pushed last ('IZ').
descr5 :: Storage a -> Storage b
       -> Descr [TScal TF32, TEither (TScal TF32) (TScal TF32)] [b, a]
descr5 sumStorage scalarStorage =
  DPush (DPush DTop (STEither f32 f32, sumStorage)) (f32, scalarStorage)
  where
    f32 = STScal STF32

-- | @x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)}@.
--
-- Each 'ECase' branch binds the payload as 'IZ', which shifts @y@ to
-- @IS IZ@ inside the branches.
ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
ex5 = ECase ext scrutinee onLeft onRight
  where
    f32 = STScal STF32
    -- outside the case: x is the sum-typed variable at IS IZ
    scrutinee = EVar ext (STEither f32 f32) (IS IZ)
    -- inside either branch: the payload is IZ and y is IS IZ
    payload = EVar ext f32 IZ
    y = EVar ext f32 (IS IZ)
    onLeft = bin (OMul STF32) payload y
    onRight =
      bin (OMul STF32) payload (bin (OAdd STF32) y (EConst ext STF32 1.0))