summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Example.hs8
-rw-r--r--src/Language.hs6
-rw-r--r--src/Language/AST.hs33
3 files changed, 42 insertions, 5 deletions
diff --git a/src/Example.hs b/src/Example.hs
index 0bc18fb..d0405af 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -4,6 +4,7 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeApplications #-}
module Example where
import Array
@@ -164,8 +165,7 @@ type TMat = TArr (S (S Z))
neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R
neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $
- let layer :: (Lookup "wei" env ~ TMat R, Lookup "bias" env ~ TVec R, Lookup "x" env ~ TVec R) => NExpr env (TVec R)
- layer =
+ let layer = lambda @(TMat R) #wei $ lambda @(TVec R) #bias $ lambda @(TVec R) #x $ body $
-- prod = wei `matmul` x
let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :->
#wei ! #idx * #x ! pair nil (snd_ #idx)) $
@@ -174,8 +174,8 @@ neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #
let_ #out (#prod ! #idx + #bias ! #idx) $
if_ (#out .<= const_ 0) (const_ 0) #out
- in let_ #x1 (let_ #wei (fst_ #layer1) $ let_ #bias (snd_ #layer1) $ let_ #x #input $ layer) $
- let_ #x2 (let_ #wei (fst_ #layer2) $ let_ #bias (snd_ #layer2) $ let_ #x #x1 $ layer) $
+ in let_ #x1 (inline layer (SNil .$ fst_ #layer1 .$ snd_ #layer1 .$ #input)) $
+ let_ #x2 (inline layer (SNil .$ fst_ #layer2 .$ snd_ #layer2 .$ #x1)) $
let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $
#x3 ! nil
diff --git a/src/Language.hs b/src/Language.hs
index 3a4a36c..b61a497 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -26,6 +26,12 @@ body = NBody
lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
lambda = NLam
+inline :: NFun '[] params t -> SList (NExpr env) (UnName params) -> NExpr env t
+inline = inlineNFun
+
+(.$) :: SList f list -> f a -> SList f (a : list)
+(.$) = flip SCons
+
let_ :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t
let_ = NELet
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index 409d24d..3b04bec 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -58,6 +58,9 @@ data NExpr env t where
-- partiality
NEError :: STy a -> String -> NExpr env a
+
+ -- embedded unnamed expressions
+ NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t
deriving instance Show (NExpr env t)
type family Lookup name env where
@@ -85,21 +88,41 @@ instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) whe
instance (KnownTy t, KnownSymbol name, Lookup name env ~ t) => IsLabel name (NExpr env t) where
fromLabel = NEVar (fromLabel @name)
+-- | Innermost variable variable on the outside, on the right.
data NEnv env where
NTop :: NEnv '[]
NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env)
+-- | First (outermost) parameter on the outside, on the left.
+-- * env: environment of this function (grows as you go deeper inside lambdas)
+-- * env': environment of the body of the function
+-- * params: parameters of the function (difference between env and env'), first (outermost) argument at the head of the list
data NFun env env' t where
NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
- NBody :: NExpr env t -> NFun env env t
+ NBody :: NExpr env' t -> NFun env' env' t
type family UnName env where
UnName '[] = '[]
UnName ('(name, t) : env) = t : UnName env
+envFromNEnv :: NEnv env -> SList STy (UnName env)
+envFromNEnv NTop = SNil
+envFromNEnv (NPush env (Var _ t)) = t `SCons` envFromNEnv env
+
+inlineNFun :: NFun '[] envB t -> SList (NExpr env) (UnName envB) -> NExpr env t
+inlineNFun fun args = NEUnnamed (fromNamed fun) args
+
fromNamed :: NFun '[] env t -> Ex (UnName env) t
fromNamed = fromNamedFun NTop
+-- | Some of the parameters have already been put in the environment; some
+-- haven't. Transfer all parameters to the left into the environment.
+--
+-- [] `fromNamedFun` λx y z. E
+-- = []:x `fromNamedFun` λy z. E
+-- = []:x:y `fromNamedFun` λz. E
+-- = []:x:y:z `fromNamedFun` λ. E
+-- = []:x:y:z `fromNamedExpr` E
fromNamedFun :: NEnv env -> NFun env env' t -> Ex (UnName env') t
fromNamedFun env (NLam var fun) = fromNamedFun (env `NPush` var) fun
fromNamedFun env (NBody e) = fromNamedExpr env e
@@ -136,6 +159,8 @@ fromNamedExpr val = \case
NEOp op e -> EOp ext op (go e)
NEError t s -> EError t s
+
+ NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args
where
go :: NExpr env t' -> Ex (UnName env) t'
go = fromNamedExpr val
@@ -154,3 +179,9 @@ fromNamedExpr val = \case
lambda2 :: NEnv env' -> Var name1 a -> Var name2 b -> NExpr ('(name2, b) : '(name1, a) : env') c -> Ex (b : a : UnName env') c
lambda2 val' var1 var2 e = fromNamedExpr (val' `NPush` var1 `NPush` var2) e
+
+ injectWrapLet :: Ex (Append unenv (UnName env)) t -> SList (NExpr env) unenv -> Ex (UnName env) t
+ injectWrapLet e SNil = e
+ injectWrapLet e (arg `SCons` args) =
+ injectWrapLet (ELet ext (weakenExpr (wSinks args) $ fromNamedExpr val arg) e)
+ args