summaryrefslogtreecommitdiff
path: root/src/Example.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Example.hs')
-rw-r--r--src/Example.hs82
1 files changed, 26 insertions, 56 deletions
diff --git a/src/Example.hs b/src/Example.hs
index 424351c..6fd19cd 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -6,6 +6,7 @@ import AST
import AST.Pretty
import CHAD
import Data
+import Language
import Simplify
@@ -51,46 +52,24 @@ descr1 a b = DTop `DPush` (t, a) `DPush` (t, b)
-- one "v2" (x1 * x5)
-- one "v1" x5)
ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
-ex1 =
- bin (OAdd STF32)
- (bin (OMul STF32)
- (EVar ext (STScal STF32) (IS IZ))
- (EVar ext (STScal STF32) IZ))
- (EVar ext (STScal STF32) (IS IZ))
+ex1 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
+ x * y + x
-- x y |- let z = x + y in z * (z + x)
ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
-ex2 =
- ELet ext (bin (OAdd STF32) (EVar ext (STScal STF32) (IS IZ))
- (EVar ext (STScal STF32) IZ)) $
- bin (OMul STF32)
- (EVar ext (STScal STF32) IZ)
- (bin (OAdd STF32)
- (EVar ext (STScal STF32) IZ)
- (EVar ext (STScal STF32) (IS (IS IZ))))
+ex2 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
+ let_ (x + y) $ \z ->
+ z * (z + x)
-- x y |- if x < y then 2 * x else 3 + x
ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
-ex3 =
- ECase ext (EOp ext OIf (bin (OLt STF32) (EVar ext (STScal STF32) (IS IZ))
- (EVar ext (STScal STF32) IZ)))
- (bin (OMul STF32) (EConst ext STF32 2.0)
- (EVar ext (STScal STF32) (IS (IS IZ))))
- (bin (OAdd STF32) (EConst ext STF32 3.0)
- (EVar ext (STScal STF32) (IS (IS IZ))))
+ex3 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
+ if_ (x .< y) (2 * x) (3 * x)
-- x y |- if x < y then 2 * x + y * y else 3 + x
ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
-ex4 =
- ECase ext (EOp ext OIf (bin (OLt STF32) (EVar ext (STScal STF32) (IS IZ))
- (EVar ext (STScal STF32) IZ)))
- (bin (OAdd STF32)
- (bin (OMul STF32) (EConst ext STF32 2.0)
- (EVar ext (STScal STF32) (IS (IS IZ))))
- (bin (OMul STF32) (EVar ext (STScal STF32) (IS IZ))
- (EVar ext (STScal STF32) (IS IZ))))
- (bin (OAdd STF32) (EConst ext STF32 3.0)
- (EVar ext (STScal STF32) (IS (IS IZ))))
+ex4 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
+ if_ (x .< y) (2 * x + y * y) (3 + x)
senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)]
senv5 = STScal STF32 `SCons` STEither (STScal STF32) (STScal STF32) `SCons` SNil
@@ -101,13 +80,9 @@ descr5 a b = DTop `DPush` (STEither (STScal STF32) (STScal STF32), a) `DPush` (S
-- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)}
ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
-ex5 =
- ECase ext (EVar ext (STEither (STScal STF32) (STScal STF32)) (IS IZ))
- (bin (OMul STF32) (EVar ext (STScal STF32) IZ)
- (EVar ext (STScal STF32) (IS IZ)))
- (bin (OMul STF32) (EVar ext (STScal STF32) IZ)
- (bin (OAdd STF32) (EVar ext (STScal STF32) (IS IZ))
- (EConst ext STF32 1.0)))
+ex5 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
+ case_ x (\a -> a * y)
+ (\b -> b * (y + 1))
senv6 :: SList STy [TScal TI64, TScal TF32]
senv6 = STScal STI64 `SCons` STScal STF32 `SCons` SNil
@@ -119,13 +94,10 @@ descr6 = DTop `DPush` (STScal STF32, SMerge) `DPush` (STScal STI64, SMerge)
-- b = build1 n (\i. let c = idx0 a in c * c)
-- in idx0 (b ! 3)
ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)
-ex6 =
- ELet ext (EUnit ext (EVar ext (STScal STF32) (IS IZ))) $
- ELet ext (EBuild1 ext (EVar ext tIx (IS IZ)) $
- ELet ext (EIdx0 ext (EVar ext (STArr SZ (STScal STF32)) (IS IZ))) $
- bin (OMul STF32) (EVar ext (STScal STF32) IZ)
- (EVar ext (STScal STF32) IZ)) $
- (EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (STScal STF32)) IZ) (EConst ext STI64 3)))
+ex6 = scopeCheck $ lambda $ \x -> lambda $ \n -> body $
+ let_ (unit x) $ \a ->
+ let_ (build1 n (\_ -> let_ (idx0 a) $ \c -> c * c)) $ \b ->
+ idx0 (b .! 3)
type R = TScal TF32
@@ -154,19 +126,17 @@ descr7 =
-- x3 = fst p3 * x + snd p3
-- in x3
ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R
-ex7 =
+ex7 = scopeCheck $ lambda $ \pars123 -> lambda $ \input -> body $
let tR = STScal STF32
tpair = STPair tR tR
- layer :: STy p -> Idx env p -> Idx env R -> Ex env R
- layer parst@(STPair t (STPair (STScal STF32) (STScal STF32))) pars inp =
- ELet ext (ESnd ext (EVar ext parst pars)) $
- ELet ext (EFst ext (EVar ext parst (IS pars))) $
- ELet ext (bin (OAdd STF32) (bin (OMul STF32) (EFst ext (EVar ext tpair (IS IZ)))
- (EVar ext tR (IS (IS inp))))
- (ESnd ext (EVar ext tpair (IS IZ)))) $
- layer t (IS IZ) IZ
- layer STNil _ inp = EVar ext tR inp
+ layer :: STy p -> SExpr p -> SExpr R -> SExpr R
+ layer (STPair t (STPair (STScal STF32) (STScal STF32))) pars inp | Dict <- styKnown t =
+ let_ (snd_ pars) $ \par ->
+ let_ (fst_ pars) $ \restpars ->
+ let_ (fst_ par * inp + snd_ par) $ \res ->
+ layer t restpars res
+ layer STNil _ inp = inp
layer _ _ _ = error "Invalid layer inputs"
- in layer (STPair (STPair (STPair STNil tpair) tpair) tpair) (IS IZ) IZ
+ in layer (STPair (STPair (STPair STNil tpair) tpair) tpair) pars123 input