diff options
Diffstat (limited to 'src/Example.hs')
-rw-r--r-- | src/Example.hs | 78 |
1 files changed, 39 insertions, 39 deletions
diff --git a/src/Example.hs b/src/Example.hs index 6fd19cd..4130f47 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -1,5 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE TypeOperators #-} module Example where import AST @@ -52,66 +54,60 @@ descr1 a b = DTop `DPush` (t, a) `DPush` (t, b) -- one "v2" (x1 * x5) -- one "v1" x5) ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex1 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ - x * y + x +ex1 = fromNamed $ lambda #x $ lambda #y $ body $ + #x * #y + #x -- x y |- let z = x + y in z * (z + x) ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex2 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ - let_ (x + y) $ \z -> - z * (z + x) +ex2 = fromNamed $ lambda #x $ lambda #y $ body $ + let_ #z (#x + #y) $ + #z * (#z + #x) -- x y |- if x < y then 2 * x else 3 + x ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex3 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ - if_ (x .< y) (2 * x) (3 * x) +ex3 = fromNamed $ lambda #x $ lambda #y $ body $ + if_ (#x .< #y) (2 * #x) (3 * #x) -- x y |- if x < y then 2 * x + y * y else 3 + x ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex4 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ - if_ (x .< y) (2 * x + y * y) (3 + x) +ex4 = fromNamed $ lambda #x $ lambda #y $ body $ + if_ (#x .< #y) (2 * #x + #y * #y) (3 + #x) senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)] -senv5 = STScal STF32 `SCons` STEither (STScal STF32) (STScal STF32) `SCons` SNil +senv5 = knownEnv descr5 :: Storage a -> Storage b -> Descr [TScal TF32, TEither (TScal TF32) (TScal TF32)] [b, a] -descr5 a b = DTop `DPush` (STEither (STScal STF32) (STScal STF32), a) `DPush` (STScal STF32, b) +descr5 a b = DTop `DPush` (knownTy, a) `DPush` (knownTy, b) -- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)} ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32) -ex5 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ - case_ x (\a -> a * y) - (\b -> b * (y + 1)) +ex5 = fromNamed $ lambda #x $ lambda #y $ body $ + case_ #x (#a :-> #a * #y) + (#b :-> #b * (#y + 1)) senv6 :: SList STy [TScal TI64, TScal TF32] -senv6 = STScal STI64 `SCons` STScal STF32 `SCons` SNil +senv6 = knownEnv descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"] -descr6 = DTop `DPush` (STScal STF32, SMerge) `DPush` (STScal STI64, SMerge) +descr6 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge) -- x:R n:I |- let a = unit x -- b = build1 n (\i. let c = idx0 a in c * c) -- in idx0 (b ! 3) ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32) -ex6 = scopeCheck $ lambda $ \x -> lambda $ \n -> body $ - let_ (unit x) $ \a -> - let_ (build1 n (\_ -> let_ (idx0 a) $ \c -> c * c)) $ \b -> - idx0 (b .! 3) +ex6 = fromNamed $ lambda #x $ lambda #n $ body $ + let_ #a (unit #x) $ + let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $ + idx0 (#b .! 3) type R = TScal TF32 senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] -senv7 = - let tR = STScal STF32 - tpair = STPair tR tR - in tR `SCons` STPair (STPair (STPair STNil tpair) tpair) tpair `SCons` SNil +senv7 = knownEnv descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"] -descr7 = - let tR = STScal STF32 - tpair = STPair tR tR - in DTop `DPush` (STPair (STPair (STPair STNil tpair) tpair) tpair, SMerge) `DPush` (tR, SMerge) +descr7 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge) -- A "neural network" except it's just scalars, not matrices. -- ps:((((), (R,R)), (R,R)), (R,R)) x:R @@ -126,17 +122,21 @@ descr7 = -- x3 = fst p3 * x + snd p3 -- in x3 ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R -ex7 = scopeCheck $ lambda $ \pars123 -> lambda $ \input -> body $ +ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $ let tR = STScal STF32 tpair = STPair tR tR - layer :: STy p -> SExpr p -> SExpr R -> SExpr R - layer (STPair t (STPair (STScal STF32) (STScal STF32))) pars inp | Dict <- styKnown t = - let_ (snd_ pars) $ \par -> - let_ (fst_ pars) $ \restpars -> - let_ (fst_ par * inp + snd_ par) $ \res -> - layer t restpars res - layer STNil _ inp = inp - layer _ _ _ = error "Invalid layer inputs" - - in layer (STPair (STPair (STPair STNil tpair) tpair) tpair) pars123 input + layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ TScal TF32) + => STy p -> NExpr env R + layer (STPair t (STPair (STScal STF32) (STScal STF32))) | Dict <- styKnown t = + let_ #par (snd_ #parstup) $ + let_ #restpars (fst_ #parstup) $ + let_ #inp (fst_ #par * #inp + snd_ #par) $ + let_ #parstup #restpars $ + layer t + layer STNil = #inp + layer _ = error "Invalid layer inputs" + + in let_ #parstup #pars123 $ + let_ #inp #input $ + layer (STPair (STPair (STPair STNil tpair) tpair) tpair) |