diff options
-rw-r--r-- | src/Example.hs | 8 | ||||
-rw-r--r-- | src/Language.hs | 6 | ||||
-rw-r--r-- | src/Language/AST.hs | 33 |
3 files changed, 42 insertions, 5 deletions
diff --git a/src/Example.hs b/src/Example.hs index 0bc18fb..d0405af 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -4,6 +4,7 @@ {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} module Example where import Array @@ -164,8 +165,7 @@ type TMat = TArr (S (S Z)) neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $ - let layer :: (Lookup "wei" env ~ TMat R, Lookup "bias" env ~ TVec R, Lookup "x" env ~ TVec R) => NExpr env (TVec R) - layer = + let layer = lambda @(TMat R) #wei $ lambda @(TVec R) #bias $ lambda @(TVec R) #x $ body $ -- prod = wei `matmul` x let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :-> #wei ! #idx * #x ! pair nil (snd_ #idx)) $ @@ -174,8 +174,8 @@ neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda # let_ #out (#prod ! #idx + #bias ! #idx) $ if_ (#out .<= const_ 0) (const_ 0) #out - in let_ #x1 (let_ #wei (fst_ #layer1) $ let_ #bias (snd_ #layer1) $ let_ #x #input $ layer) $ - let_ #x2 (let_ #wei (fst_ #layer2) $ let_ #bias (snd_ #layer2) $ let_ #x #x1 $ layer) $ + in let_ #x1 (inline layer (SNil .$ fst_ #layer1 .$ snd_ #layer1 .$ #input)) $ + let_ #x2 (inline layer (SNil .$ fst_ #layer2 .$ snd_ #layer2 .$ #x1)) $ let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $ #x3 ! nil diff --git a/src/Language.hs b/src/Language.hs index 3a4a36c..b61a497 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -26,6 +26,12 @@ body = NBody lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t lambda = NLam +inline :: NFun '[] params t -> SList (NExpr env) (UnName params) -> NExpr env t +inline = inlineNFun + +(.$) :: SList f list -> f a -> SList f (a : list) +(.$) = flip SCons + let_ :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t let_ = NELet diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 409d24d..3b04bec 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -58,6 +58,9 @@ data NExpr env t where -- partiality NEError :: STy a -> String -> NExpr env a + + -- embedded unnamed expressions + NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t deriving instance Show (NExpr env t) type family Lookup name env where @@ -85,21 +88,41 @@ instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) whe instance (KnownTy t, KnownSymbol name, Lookup name env ~ t) => IsLabel name (NExpr env t) where fromLabel = NEVar (fromLabel @name) +-- | Innermost variable variable on the outside, on the right. data NEnv env where NTop :: NEnv '[] NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env) +-- | First (outermost) parameter on the outside, on the left. +-- * env: environment of this function (grows as you go deeper inside lambdas) +-- * env': environment of the body of the function +-- * params: parameters of the function (difference between env and env'), first (outermost) argument at the head of the list data NFun env env' t where NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t - NBody :: NExpr env t -> NFun env env t + NBody :: NExpr env' t -> NFun env' env' t type family UnName env where UnName '[] = '[] UnName ('(name, t) : env) = t : UnName env +envFromNEnv :: NEnv env -> SList STy (UnName env) +envFromNEnv NTop = SNil +envFromNEnv (NPush env (Var _ t)) = t `SCons` envFromNEnv env + +inlineNFun :: NFun '[] envB t -> SList (NExpr env) (UnName envB) -> NExpr env t +inlineNFun fun args = NEUnnamed (fromNamed fun) args + fromNamed :: NFun '[] env t -> Ex (UnName env) t fromNamed = fromNamedFun NTop +-- | Some of the parameters have already been put in the environment; some +-- haven't. Transfer all parameters to the left into the environment. +-- +-- [] `fromNamedFun` λx y z. E +-- = []:x `fromNamedFun` λy z. E +-- = []:x:y `fromNamedFun` λz. E +-- = []:x:y:z `fromNamedFun` λ. E +-- = []:x:y:z `fromNamedExpr` E fromNamedFun :: NEnv env -> NFun env env' t -> Ex (UnName env') t fromNamedFun env (NLam var fun) = fromNamedFun (env `NPush` var) fun fromNamedFun env (NBody e) = fromNamedExpr env e @@ -136,6 +159,8 @@ fromNamedExpr val = \case NEOp op e -> EOp ext op (go e) NEError t s -> EError t s + + NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args where go :: NExpr env t' -> Ex (UnName env) t' go = fromNamedExpr val @@ -154,3 +179,9 @@ fromNamedExpr val = \case lambda2 :: NEnv env' -> Var name1 a -> Var name2 b -> NExpr ('(name2, b) : '(name1, a) : env') c -> Ex (b : a : UnName env') c lambda2 val' var1 var2 e = fromNamedExpr (val' `NPush` var1 `NPush` var2) e + + injectWrapLet :: Ex (Append unenv (UnName env)) t -> SList (NExpr env) unenv -> Ex (UnName env) t + injectWrapLet e SNil = e + injectWrapLet e (arg `SCons` args) = + injectWrapLet (ELet ext (weakenExpr (wSinks args) $ fromNamedExpr val arg) e) + args |