diff options
Diffstat (limited to 'src/Example.hs')
-rw-r--r-- | src/Example.hs | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/src/Example.hs b/src/Example.hs index 0bc18fb..d0405af 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -4,6 +4,7 @@ {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} module Example where import Array @@ -164,8 +165,7 @@ type TMat = TArr (S (S Z)) neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $ - let layer :: (Lookup "wei" env ~ TMat R, Lookup "bias" env ~ TVec R, Lookup "x" env ~ TVec R) => NExpr env (TVec R) - layer = + let layer = lambda @(TMat R) #wei $ lambda @(TVec R) #bias $ lambda @(TVec R) #x $ body $ -- prod = wei `matmul` x let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :-> #wei ! #idx * #x ! pair nil (snd_ #idx)) $ @@ -174,8 +174,8 @@ neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda # let_ #out (#prod ! #idx + #bias ! #idx) $ if_ (#out .<= const_ 0) (const_ 0) #out - in let_ #x1 (let_ #wei (fst_ #layer1) $ let_ #bias (snd_ #layer1) $ let_ #x #input $ layer) $ - let_ #x2 (let_ #wei (fst_ #layer2) $ let_ #bias (snd_ #layer2) $ let_ #x #x1 $ layer) $ + in let_ #x1 (inline layer (SNil .$ fst_ #layer1 .$ snd_ #layer1 .$ #input)) $ + let_ #x2 (inline layer (SNil .$ fst_ #layer2 .$ snd_ #layer2 .$ #x1)) $ let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $ #x3 ! nil |