summaryrefslogtreecommitdiff
path: root/src/Example.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Example.hs')
-rw-r--r--src/Example.hs78
1 files changed, 39 insertions, 39 deletions
diff --git a/src/Example.hs b/src/Example.hs
index 6fd19cd..4130f47 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -1,5 +1,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE OverloadedLabels #-}
+{-# LANGUAGE TypeOperators #-}
module Example where
import AST
@@ -52,66 +54,60 @@ descr1 a b = DTop `DPush` (t, a) `DPush` (t, b)
-- one "v2" (x1 * x5)
-- one "v1" x5)
ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
-ex1 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
- x * y + x
+ex1 = fromNamed $ lambda #x $ lambda #y $ body $
+ #x * #y + #x
-- x y |- let z = x + y in z * (z + x)
ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
-ex2 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
- let_ (x + y) $ \z ->
- z * (z + x)
+ex2 = fromNamed $ lambda #x $ lambda #y $ body $
+ let_ #z (#x + #y) $
+ #z * (#z + #x)
-- x y |- if x < y then 2 * x else 3 + x
ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
-ex3 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
- if_ (x .< y) (2 * x) (3 * x)
+ex3 = fromNamed $ lambda #x $ lambda #y $ body $
+ if_ (#x .< #y) (2 * #x) (3 * #x)
-- x y |- if x < y then 2 * x + y * y else 3 + x
ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
-ex4 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
- if_ (x .< y) (2 * x + y * y) (3 + x)
+ex4 = fromNamed $ lambda #x $ lambda #y $ body $
+ if_ (#x .< #y) (2 * #x + #y * #y) (3 + #x)
senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)]
-senv5 = STScal STF32 `SCons` STEither (STScal STF32) (STScal STF32) `SCons` SNil
+senv5 = knownEnv
descr5 :: Storage a -> Storage b
-> Descr [TScal TF32, TEither (TScal TF32) (TScal TF32)] [b, a]
-descr5 a b = DTop `DPush` (STEither (STScal STF32) (STScal STF32), a) `DPush` (STScal STF32, b)
+descr5 a b = DTop `DPush` (knownTy, a) `DPush` (knownTy, b)
-- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)}
ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
-ex5 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
- case_ x (\a -> a * y)
- (\b -> b * (y + 1))
+ex5 = fromNamed $ lambda #x $ lambda #y $ body $
+ case_ #x (#a :-> #a * #y)
+ (#b :-> #b * (#y + 1))
senv6 :: SList STy [TScal TI64, TScal TF32]
-senv6 = STScal STI64 `SCons` STScal STF32 `SCons` SNil
+senv6 = knownEnv
descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"]
-descr6 = DTop `DPush` (STScal STF32, SMerge) `DPush` (STScal STI64, SMerge)
+descr6 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge)
-- x:R n:I |- let a = unit x
-- b = build1 n (\i. let c = idx0 a in c * c)
-- in idx0 (b ! 3)
ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)
-ex6 = scopeCheck $ lambda $ \x -> lambda $ \n -> body $
- let_ (unit x) $ \a ->
- let_ (build1 n (\_ -> let_ (idx0 a) $ \c -> c * c)) $ \b ->
- idx0 (b .! 3)
+ex6 = fromNamed $ lambda #x $ lambda #n $ body $
+ let_ #a (unit #x) $
+ let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $
+ idx0 (#b .! 3)
type R = TScal TF32
senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)]
-senv7 =
- let tR = STScal STF32
- tpair = STPair tR tR
- in tR `SCons` STPair (STPair (STPair STNil tpair) tpair) tpair `SCons` SNil
+senv7 = knownEnv
descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"]
-descr7 =
- let tR = STScal STF32
- tpair = STPair tR tR
- in DTop `DPush` (STPair (STPair (STPair STNil tpair) tpair) tpair, SMerge) `DPush` (tR, SMerge)
+descr7 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge)
-- A "neural network" except it's just scalars, not matrices.
-- ps:((((), (R,R)), (R,R)), (R,R)) x:R
@@ -126,17 +122,21 @@ descr7 =
-- x3 = fst p3 * x + snd p3
-- in x3
ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R
-ex7 = scopeCheck $ lambda $ \pars123 -> lambda $ \input -> body $
+ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $
let tR = STScal STF32
tpair = STPair tR tR
- layer :: STy p -> SExpr p -> SExpr R -> SExpr R
- layer (STPair t (STPair (STScal STF32) (STScal STF32))) pars inp | Dict <- styKnown t =
- let_ (snd_ pars) $ \par ->
- let_ (fst_ pars) $ \restpars ->
- let_ (fst_ par * inp + snd_ par) $ \res ->
- layer t restpars res
- layer STNil _ inp = inp
- layer _ _ _ = error "Invalid layer inputs"
-
- in layer (STPair (STPair (STPair STNil tpair) tpair) tpair) pars123 input
+ layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ TScal TF32)
+ => STy p -> NExpr env R
+ layer (STPair t (STPair (STScal STF32) (STScal STF32))) | Dict <- styKnown t =
+ let_ #par (snd_ #parstup) $
+ let_ #restpars (fst_ #parstup) $
+ let_ #inp (fst_ #par * #inp + snd_ #par) $
+ let_ #parstup #restpars $
+ layer t
+ layer STNil = #inp
+ layer _ = error "Invalid layer inputs"
+
+ in let_ #parstup #pars123 $
+ let_ #inp #input $
+ layer (STPair (STPair (STPair STNil tpair) tpair) tpair)