From 991dacbec1a84400ab3412811f246b1fc58b0938 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Sun, 27 Oct 2024 15:04:47 +0100
Subject: Better inline syntax for Language

---
 src/Example.hs | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

(limited to 'src/Example.hs')

diff --git a/src/Example.hs b/src/Example.hs
index 0bc18fb..d0405af 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -4,6 +4,7 @@
 {-# LANGUAGE PolyKinds #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeApplications #-}
 module Example where
 
 import Array
@@ -164,8 +165,7 @@ type TMat = TArr (S (S Z))
 
 neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R
 neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $
-  let layer :: (Lookup "wei" env ~ TMat R, Lookup "bias" env ~ TVec R, Lookup "x" env ~ TVec R) => NExpr env (TVec R)
-      layer =
+  let layer = lambda @(TMat R) #wei $ lambda @(TVec R) #bias $ lambda @(TVec R) #x $ body $
         -- prod = wei `matmul` x
         let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :->
                               #wei ! #idx * #x ! pair nil (snd_ #idx)) $
@@ -174,8 +174,8 @@ neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #
           let_ #out (#prod ! #idx + #bias ! #idx) $
           if_ (#out .<= const_ 0) (const_ 0) #out
 
-  in let_ #x1 (let_ #wei (fst_ #layer1) $ let_ #bias (snd_ #layer1) $ let_ #x #input $ layer) $
-     let_ #x2 (let_ #wei (fst_ #layer2) $ let_ #bias (snd_ #layer2) $ let_ #x #x1 $ layer) $
+  in let_ #x1 (inline layer (SNil .$ fst_ #layer1 .$ snd_ #layer1 .$ #input)) $
+     let_ #x2 (inline layer (SNil .$ fst_ #layer2 .$ snd_ #layer2 .$ #x1)) $
      let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $
      #x3 ! nil
 
-- 
cgit v1.2.3-70-g09d2