summaryrefslogtreecommitdiff
path: root/src/Example.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Example.hs')
-rw-r--r--src/Example.hs20
1 files changed, 20 insertions, 0 deletions
diff --git a/src/Example.hs b/src/Example.hs
index 4130f47..d1d04e3 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -140,3 +140,23 @@ ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $
in let_ #parstup #pars123 $
let_ #inp #input $
layer (STPair (STPair (STPair STNil tpair) tpair) tpair)
+
+type TVec = TArr (S Z)
+type TMat = TArr (S (S Z))
+
+neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R
+neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $
+ let layer :: (Lookup "wei" env ~ TMat R, Lookup "bias" env ~ TVec R, Lookup "x" env ~ TVec R) => NExpr env (TVec R)
+ layer =
+ -- prod = wei `matmul` x
+ let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :->
+ #wei ! #idx * #x ! pair nil (snd_ #idx)) $
+ -- relu (prod + bias)
+ build (SS SZ) (shape #prod) $ #idx :->
+ let_ #out (#prod ! #idx + #bias ! #idx) $
+ if_ (#out .<= const_ 0) (const_ 0) #out
+
+ in let_ #x1 (let_ #wei (fst_ #layer1) $ let_ #bias (snd_ #layer1) $ let_ #x #input $ layer) $
+ let_ #x2 (let_ #wei (fst_ #layer2) $ let_ #bias (snd_ #layer2) $ let_ #x #x1 $ layer) $
+ let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $
+ #x3 ! nil