diff options
Diffstat (limited to 'src/Example.hs')
-rw-r--r-- | src/Example.hs | 82 |
1 files changed, 26 insertions, 56 deletions
diff --git a/src/Example.hs b/src/Example.hs index 424351c..6fd19cd 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -6,6 +6,7 @@ import AST import AST.Pretty import CHAD import Data +import Language import Simplify @@ -51,46 +52,24 @@ descr1 a b = DTop `DPush` (t, a) `DPush` (t, b) -- one "v2" (x1 * x5) -- one "v1" x5) ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex1 = - bin (OAdd STF32) - (bin (OMul STF32) - (EVar ext (STScal STF32) (IS IZ)) - (EVar ext (STScal STF32) IZ)) - (EVar ext (STScal STF32) (IS IZ)) +ex1 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ + x * y + x -- x y |- let z = x + y in z * (z + x) ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex2 = - ELet ext (bin (OAdd STF32) (EVar ext (STScal STF32) (IS IZ)) - (EVar ext (STScal STF32) IZ)) $ - bin (OMul STF32) - (EVar ext (STScal STF32) IZ) - (bin (OAdd STF32) - (EVar ext (STScal STF32) IZ) - (EVar ext (STScal STF32) (IS (IS IZ)))) +ex2 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ + let_ (x + y) $ \z -> + z * (z + x) -- x y |- if x < y then 2 * x else 3 + x ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex3 = - ECase ext (EOp ext OIf (bin (OLt STF32) (EVar ext (STScal STF32) (IS IZ)) - (EVar ext (STScal STF32) IZ))) - (bin (OMul STF32) (EConst ext STF32 2.0) - (EVar ext (STScal STF32) (IS (IS IZ)))) - (bin (OAdd STF32) (EConst ext STF32 3.0) - (EVar ext (STScal STF32) (IS (IS IZ)))) +ex3 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ + if_ (x .< y) (2 * x) (3 * x) -- x y |- if x < y then 2 * x + y * y else 3 + x ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex4 = - ECase ext (EOp ext OIf (bin (OLt STF32) (EVar ext (STScal STF32) (IS IZ)) - (EVar ext (STScal STF32) IZ))) - (bin (OAdd STF32) - (bin (OMul STF32) (EConst ext STF32 2.0) - (EVar ext (STScal STF32) (IS (IS IZ)))) - (bin (OMul STF32) (EVar ext (STScal STF32) (IS IZ)) - (EVar ext (STScal STF32) (IS IZ)))) - (bin (OAdd STF32) (EConst ext STF32 3.0) - (EVar ext (STScal STF32) (IS (IS IZ)))) +ex4 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ + if_ (x .< y) (2 * x + y * y) (3 + x) senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)] senv5 = STScal STF32 `SCons` STEither (STScal STF32) (STScal STF32) `SCons` SNil @@ -101,13 +80,9 @@ descr5 a b = DTop `DPush` (STEither (STScal STF32) (STScal STF32), a) `DPush` (S -- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)} ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32) -ex5 = - ECase ext (EVar ext (STEither (STScal STF32) (STScal STF32)) (IS IZ)) - (bin (OMul STF32) (EVar ext (STScal STF32) IZ) - (EVar ext (STScal STF32) (IS IZ))) - (bin (OMul STF32) (EVar ext (STScal STF32) IZ) - (bin (OAdd STF32) (EVar ext (STScal STF32) (IS IZ)) - (EConst ext STF32 1.0))) +ex5 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ + case_ x (\a -> a * y) + (\b -> b * (y + 1)) senv6 :: SList STy [TScal TI64, TScal TF32] senv6 = STScal STI64 `SCons` STScal STF32 `SCons` SNil @@ -119,13 +94,10 @@ descr6 = DTop `DPush` (STScal STF32, SMerge) `DPush` (STScal STI64, SMerge) -- b = build1 n (\i. let c = idx0 a in c * c) -- in idx0 (b ! 3) ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32) -ex6 = - ELet ext (EUnit ext (EVar ext (STScal STF32) (IS IZ))) $ - ELet ext (EBuild1 ext (EVar ext tIx (IS IZ)) $ - ELet ext (EIdx0 ext (EVar ext (STArr SZ (STScal STF32)) (IS IZ))) $ - bin (OMul STF32) (EVar ext (STScal STF32) IZ) - (EVar ext (STScal STF32) IZ)) $ - (EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (STScal STF32)) IZ) (EConst ext STI64 3))) +ex6 = scopeCheck $ lambda $ \x -> lambda $ \n -> body $ + let_ (unit x) $ \a -> + let_ (build1 n (\_ -> let_ (idx0 a) $ \c -> c * c)) $ \b -> + idx0 (b .! 3) type R = TScal TF32 @@ -154,19 +126,17 @@ descr7 = -- x3 = fst p3 * x + snd p3 -- in x3 ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R -ex7 = +ex7 = scopeCheck $ lambda $ \pars123 -> lambda $ \input -> body $ let tR = STScal STF32 tpair = STPair tR tR - layer :: STy p -> Idx env p -> Idx env R -> Ex env R - layer parst@(STPair t (STPair (STScal STF32) (STScal STF32))) pars inp = - ELet ext (ESnd ext (EVar ext parst pars)) $ - ELet ext (EFst ext (EVar ext parst (IS pars))) $ - ELet ext (bin (OAdd STF32) (bin (OMul STF32) (EFst ext (EVar ext tpair (IS IZ))) - (EVar ext tR (IS (IS inp)))) - (ESnd ext (EVar ext tpair (IS IZ)))) $ - layer t (IS IZ) IZ - layer STNil _ inp = EVar ext tR inp + layer :: STy p -> SExpr p -> SExpr R -> SExpr R + layer (STPair t (STPair (STScal STF32) (STScal STF32))) pars inp | Dict <- styKnown t = + let_ (snd_ pars) $ \par -> + let_ (fst_ pars) $ \restpars -> + let_ (fst_ par * inp + snd_ par) $ \res -> + layer t restpars res + layer STNil _ inp = inp layer _ _ _ = error "Invalid layer inputs" - in layer (STPair (STPair (STPair STNil tpair) tpair) tpair) (IS IZ) IZ + in layer (STPair (STPair (STPair STNil tpair) tpair) tpair) pars123 input |