summaryrefslogtreecommitdiff
path: root/src/Example.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Example.hs')
-rw-r--r--src/Example.hs8
1 files changed, 4 insertions, 4 deletions
diff --git a/src/Example.hs b/src/Example.hs
index 0bc18fb..d0405af 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -4,6 +4,7 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeApplications #-}
module Example where
import Array
@@ -164,8 +165,7 @@ type TMat = TArr (S (S Z))
neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R
neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $
- let layer :: (Lookup "wei" env ~ TMat R, Lookup "bias" env ~ TVec R, Lookup "x" env ~ TVec R) => NExpr env (TVec R)
- layer =
+ let layer = lambda @(TMat R) #wei $ lambda @(TVec R) #bias $ lambda @(TVec R) #x $ body $
-- prod = wei `matmul` x
let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :->
#wei ! #idx * #x ! pair nil (snd_ #idx)) $
@@ -174,8 +174,8 @@ neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #
let_ #out (#prod ! #idx + #bias ! #idx) $
if_ (#out .<= const_ 0) (const_ 0) #out
- in let_ #x1 (let_ #wei (fst_ #layer1) $ let_ #bias (snd_ #layer1) $ let_ #x #input $ layer) $
- let_ #x2 (let_ #wei (fst_ #layer2) $ let_ #bias (snd_ #layer2) $ let_ #x #x1 $ layer) $
+ in let_ #x1 (inline layer (SNil .$ fst_ #layer1 .$ snd_ #layer1 .$ #input)) $
+ let_ #x2 (inline layer (SNil .$ fst_ #layer2 .$ snd_ #layer2 .$ #x1)) $
let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $
#x3 ! nil