diff options
Diffstat (limited to 'src/Example.hs')
-rw-r--r-- | src/Example.hs | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/src/Example.hs b/src/Example.hs index 4130f47..d1d04e3 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -140,3 +140,23 @@ ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $ in let_ #parstup #pars123 $ let_ #inp #input $ layer (STPair (STPair (STPair STNil tpair) tpair) tpair) + +type TVec = TArr (S Z) +type TMat = TArr (S (S Z)) + +neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R +neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $ + let layer :: (Lookup "wei" env ~ TMat R, Lookup "bias" env ~ TVec R, Lookup "x" env ~ TVec R) => NExpr env (TVec R) + layer = + -- prod = wei `matmul` x + let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :-> + #wei ! #idx * #x ! pair nil (snd_ #idx)) $ + -- relu (prod + bias) + build (SS SZ) (shape #prod) $ #idx :-> + let_ #out (#prod ! #idx + #bias ! #idx) $ + if_ (#out .<= const_ 0) (const_ 0) #out + + in let_ #x1 (let_ #wei (fst_ #layer1) $ let_ #bias (snd_ #layer1) $ let_ #x #input $ layer) $ + let_ #x2 (let_ #wei (fst_ #layer2) $ let_ #bias (snd_ #layer2) $ let_ #x #x1 $ layer) $ + let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $ + #x3 ! nil |