diff options
-rw-r--r-- | chad-fast.cabal | 1 | ||||
-rw-r--r-- | src/Example.hs | 78 | ||||
-rw-r--r-- | src/Language.hs | 121 | ||||
-rw-r--r-- | src/Language/AST.hs | 196 | ||||
-rw-r--r-- | src/Language/Tag.hs | 22 |
5 files changed, 199 insertions, 219 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index 0c9170c..290329b 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -22,7 +22,6 @@ library Example Language Language.AST - Language.Tag Lemmas -- PreludeCu Simplify diff --git a/src/Example.hs b/src/Example.hs index 6fd19cd..4130f47 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -1,5 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE TypeOperators #-} module Example where import AST @@ -52,66 +54,60 @@ descr1 a b = DTop `DPush` (t, a) `DPush` (t, b) -- one "v2" (x1 * x5) -- one "v1" x5) ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex1 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ - x * y + x +ex1 = fromNamed $ lambda #x $ lambda #y $ body $ + #x * #y + #x -- x y |- let z = x + y in z * (z + x) ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex2 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ - let_ (x + y) $ \z -> - z * (z + x) +ex2 = fromNamed $ lambda #x $ lambda #y $ body $ + let_ #z (#x + #y) $ + #z * (#z + #x) -- x y |- if x < y then 2 * x else 3 + x ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex3 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ - if_ (x .< y) (2 * x) (3 * x) +ex3 = fromNamed $ lambda #x $ lambda #y $ body $ + if_ (#x .< #y) (2 * #x) (3 * #x) -- x y |- if x < y then 2 * x + y * y else 3 + x ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex4 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ - if_ (x .< y) (2 * x + y * y) (3 + x) +ex4 = fromNamed $ lambda #x $ lambda #y $ body $ + if_ (#x .< #y) (2 * #x + #y * #y) (3 + #x) senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)] -senv5 = STScal STF32 `SCons` STEither (STScal STF32) (STScal STF32) `SCons` SNil +senv5 = knownEnv descr5 :: Storage a -> Storage b -> Descr [TScal TF32, TEither (TScal TF32) (TScal TF32)] [b, a] -descr5 a b = DTop `DPush` (STEither (STScal STF32) (STScal STF32), a) `DPush` (STScal STF32, b) +descr5 a b = DTop `DPush` (knownTy, a) `DPush` (knownTy, b) -- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)} ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32) -ex5 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $ - case_ x (\a -> a * y) - (\b -> b * (y + 1)) +ex5 = fromNamed $ lambda #x $ lambda #y $ body $ + case_ #x (#a :-> #a * #y) + (#b :-> #b * (#y + 1)) senv6 :: SList STy [TScal TI64, TScal TF32] -senv6 = STScal STI64 `SCons` STScal STF32 `SCons` SNil +senv6 = knownEnv descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"] -descr6 = DTop `DPush` (STScal STF32, SMerge) `DPush` (STScal STI64, SMerge) +descr6 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge) -- x:R n:I |- let a = unit x -- b = build1 n (\i. let c = idx0 a in c * c) -- in idx0 (b ! 3) ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32) -ex6 = scopeCheck $ lambda $ \x -> lambda $ \n -> body $ - let_ (unit x) $ \a -> - let_ (build1 n (\_ -> let_ (idx0 a) $ \c -> c * c)) $ \b -> - idx0 (b .! 3) +ex6 = fromNamed $ lambda #x $ lambda #n $ body $ + let_ #a (unit #x) $ + let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $ + idx0 (#b .! 3) type R = TScal TF32 senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] -senv7 = - let tR = STScal STF32 - tpair = STPair tR tR - in tR `SCons` STPair (STPair (STPair STNil tpair) tpair) tpair `SCons` SNil +senv7 = knownEnv descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"] -descr7 = - let tR = STScal STF32 - tpair = STPair tR tR - in DTop `DPush` (STPair (STPair (STPair STNil tpair) tpair) tpair, SMerge) `DPush` (tR, SMerge) +descr7 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge) -- A "neural network" except it's just scalars, not matrices. -- ps:((((), (R,R)), (R,R)), (R,R)) x:R @@ -126,17 +122,21 @@ descr7 = -- x3 = fst p3 * x + snd p3 -- in x3 ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R -ex7 = scopeCheck $ lambda $ \pars123 -> lambda $ \input -> body $ +ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $ let tR = STScal STF32 tpair = STPair tR tR - layer :: STy p -> SExpr p -> SExpr R -> SExpr R - layer (STPair t (STPair (STScal STF32) (STScal STF32))) pars inp | Dict <- styKnown t = - let_ (snd_ pars) $ \par -> - let_ (fst_ pars) $ \restpars -> - let_ (fst_ par * inp + snd_ par) $ \res -> - layer t restpars res - layer STNil _ inp = inp - layer _ _ _ = error "Invalid layer inputs" - - in layer (STPair (STPair (STPair STNil tpair) tpair) tpair) pars123 input + layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ TScal TF32) + => STy p -> NExpr env R + layer (STPair t (STPair (STScal STF32) (STScal STF32))) | Dict <- styKnown t = + let_ #par (snd_ #parstup) $ + let_ #restpars (fst_ #parstup) $ + let_ #inp (fst_ #par * #inp + snd_ #par) $ + let_ #parstup #restpars $ + layer t + layer STNil = #inp + layer _ = error "Invalid layer inputs" + + in let_ #parstup #pars123 $ + let_ #inp #input $ + layer (STPair (STPair (STPair STNil tpair) tpair) tpair) diff --git a/src/Language.hs b/src/Language.hs index f4719bf..58a7070 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -1,10 +1,11 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE ExplicitForAll #-} +{-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE TypeOperators #-} module Language ( - scopeCheck, - SExpr, + fromNamed, + NExpr, module Language, + Lookup, ) where import AST @@ -12,101 +13,97 @@ import Data import Language.AST -lambda :: forall a args t. KnownTy a => Var a -> SFun args t -> SFun (Append args '[a]) t -lambda var (SFun args e) = SFun (sappend args (var `SCons` SNil)) e - -body :: SExpr t -> SFun '[] t -body e = SFun SNil e - - data a :-> b = a :-> b deriving (Show) -infix 0 :-> +infixr 0 :-> + + +body :: NExpr env t -> NFun env env t +body = NBody +lambda :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t +lambda = NLam -TODO --- TODO: should give SExpr an environment index of kind '[(Symbol, Ty)]. Then --- the IsLabel instance for SExpr (but not the one for Var!) can check that the --- type in the named environment matches the locally expected type. -let_ :: KnownTy a => Var a -> SExpr a -> SExpr t -> SExpr t -let_ var rhs e = SELet rhs (Lambda var e) +let_ :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t +let_ = NELet -pair :: SExpr a -> SExpr b -> SExpr (TPair a b) -pair = SEPair +pair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b) +pair = NEPair -fst_ :: SExpr (TPair a b) -> SExpr a -fst_ = SEFst +fst_ :: NExpr env (TPair a b) -> NExpr env a +fst_ = NEFst -snd_ :: SExpr (TPair a b) -> SExpr b -snd_ = SESnd +snd_ :: NExpr env (TPair a b) -> NExpr env b +snd_ = NESnd -nil :: SExpr TNil -nil = SENil +nil :: NExpr env TNil +nil = NENil -inl :: STy b -> SExpr a -> SExpr (TEither a b) -inl = SEInl +inl :: STy b -> NExpr env a -> NExpr env (TEither a b) +inl = NEInl -inr :: STy a -> SExpr b -> SExpr (TEither a b) -inr = SEInr +inr :: STy a -> NExpr env b -> NExpr env (TEither a b) +inr = NEInr -case_ :: (KnownTy a, KnownTy b) - => SExpr (TEither a b) -> (Var a :-> SExpr c) -> (Var b :-> SExpr c) -> SExpr c -case_ e (v1 :-> e1) (v2 :-> e2) = SECase e (Lambda v1 e1) (Lambda v2 e2) +case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c +case_ e (v1 :-> e1) (v2 :-> e2) = NECase e v1 e1 v2 e2 -build1 :: SExpr TIx -> (Var TIx :-> SExpr t) -> SExpr (TArr (S Z) t) -build1 a (v :-> b) = SEBuild1 a (Lambda v b) +build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t) +build1 a (v :-> b) = NEBuild1 a v b -build :: SNat n -> SExpr (Tup (Replicate n TIx)) -> (Var (Tup (Replicate n TIx)) :-> SExpr t) -> SExpr (TArr n t) -build n a (v :-> b) = SEBuild n a (Lambda v b) +build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t) +build n a (v :-> b) = NEBuild n a v b -fold1 :: KnownTy t => (SExpr t -> SExpr t -> SExpr t) -> SExpr (TArr (S n) t) -> SExpr (TArr n t) -fold1 f e = SEFold1 (mkLambda2 (f, e) f) e +fold1 :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) +fold1 (v1 :-> v2 :-> e1) e2 = NEFold1 v1 v2 e1 e2 -unit :: SExpr t -> SExpr (TArr Z t) -unit = SEUnit +unit :: NExpr env t -> NExpr env (TArr Z t) +unit = NEUnit -const_ :: KnownScalTy t => ScalRep t -> SExpr (TScal t) +const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t) const_ x = let ty = knownScalTy in case scalRepIsShow ty of - Dict -> SEConst ty x + Dict -> NEConst ty x -idx0 :: SExpr (TArr Z t) -> SExpr t -idx0 = SEIdx0 +idx0 :: NExpr env (TArr Z t) -> NExpr env t +idx0 = NEIdx0 -(.!) :: SExpr (TArr (S n) t) -> SExpr TIx -> SExpr (TArr n t) -(.!) = SEIdx1 +(.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) +(.!) = NEIdx1 -(!) :: SNat n -> SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx)) -> SExpr t -(!) = SEIdx +(!) :: SNat n -> NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t +(!) = NEIdx -shape :: SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx)) -shape = SEShape +shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) +shape = NEShape -oper :: SOp a t -> SExpr a -> SExpr t -oper = SEOp +oper :: SOp a t -> NExpr env a -> NExpr env t +oper = NEOp -error_ :: KnownTy t => String -> SExpr t -error_ s = SEError knownTy s +error_ :: KnownTy t => String -> NExpr env t +error_ s = NEError knownTy s -(.==) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool) +(.==) :: KnownScalTy st => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) a .== b = oper (OEq knownScalTy) (pair a b) -(.<) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool) +(.<) :: KnownScalTy st => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) a .< b = oper (OLt knownScalTy) (pair a b) -(.>) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool) +(.>) :: KnownScalTy st => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) (.>) = flip (.<) -(.<=) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool) +(.<=) :: KnownScalTy st => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) a .<= b = oper (OLe knownScalTy) (pair a b) -(.>=) :: KnownScalTy st => SExpr (TScal st) -> SExpr (TScal st) -> SExpr (TScal TBool) +(.>=) :: KnownScalTy st => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) (.>=) = flip (.<=) -not_ :: SExpr (TScal TBool) -> SExpr (TScal TBool) +not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) not_ = oper ONot -if_ :: SExpr (TScal TBool) -> SExpr t -> SExpr t -> SExpr t -if_ e a b = case_ (oper OIf e) (\_ -> a) (\_ -> b) +-- | The "_" variables in scope are unusable and should be ignored. With a +-- weakening function on NExprs they could be hidden. +if_ :: NExpr env (TScal TBool) -> NExpr ('("_", TNil) : env) t -> NExpr ('("_", TNil) : env) t -> NExpr env t +if_ e a b = case_ (oper OIf e) (#_ :-> a) (#_ :-> b) diff --git a/src/Language/AST.hs b/src/Language/AST.hs index f31f249..511723a 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -1,143 +1,149 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE TypeApplications #-} module Language.AST where -import Data.Proxy +import Data.Kind (Type) import Data.Type.Equality import GHC.OverloadedLabels -import GHC.TypeLits (symbolVal, KnownSymbol) +import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(Text)) import AST import Data -data SExpr t where +type NExpr :: [(Symbol, Ty)] -> Ty -> Type +data NExpr env t where -- lambda calculus - SEVar :: Var t -> SExpr t - SELet :: SExpr a -> Lambda a (SExpr t) -> SExpr t + NEVar :: Lookup name env ~ t => Var name t -> NExpr env t + NELet :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t -- base types - SEPair :: SExpr a -> SExpr b -> SExpr (TPair a b) - SEFst :: SExpr (TPair a b) -> SExpr a - SESnd :: SExpr (TPair a b) -> SExpr b - SENil :: SExpr TNil - SEInl :: STy b -> SExpr a -> SExpr (TEither a b) - SEInr :: STy a -> SExpr b -> SExpr (TEither a b) - SECase :: SExpr (TEither a b) -> Lambda a (SExpr c) -> Lambda b (SExpr c) -> SExpr c + NEPair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b) + NEFst :: NExpr env (TPair a b) -> NExpr env a + NESnd :: NExpr env (TPair a b) -> NExpr env b + NENil :: NExpr env TNil + NEInl :: STy b -> NExpr env a -> NExpr env (TEither a b) + NEInr :: STy a -> NExpr env b -> NExpr env (TEither a b) + NECase :: NExpr env (TEither a b) -> Var name1 a -> NExpr ('(name1, a) : env) c -> Var name2 b -> NExpr ('(name2, b) : env) c -> NExpr env c -- array operations - SEBuild1 :: SExpr TIx -> Lambda TIx (SExpr t) -> SExpr (TArr (S Z) t) - SEBuild :: SNat n -> SExpr (Tup (Replicate n TIx)) -> Lambda (Tup (Replicate n TIx)) (SExpr t) -> SExpr (TArr n t) - SEFold1 :: Lambda t (Lambda t (SExpr t)) -> SExpr (TArr (S n) t) -> SExpr (TArr n t) - SEUnit :: SExpr t -> SExpr (TArr Z t) + NEBuild1 :: NExpr env TIx -> Var name TIx -> NExpr ('(name, TIx) : env) t -> NExpr env (TArr (S Z) t) + NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t) + NEFold1 :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) + NEUnit :: NExpr env t -> NExpr env (TArr Z t) -- expression operations - SEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> SExpr (TScal t) - SEIdx0 :: SExpr (TArr Z t) -> SExpr t - SEIdx1 :: SExpr (TArr (S n) t) -> SExpr TIx -> SExpr (TArr n t) - SEIdx :: SNat n -> SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx)) -> SExpr t - SEShape :: SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx)) - SEOp :: SOp a t -> SExpr a -> SExpr t + NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t) + NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t + NEIdx1 :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) + NEIdx :: SNat n -> NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t + NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) + NEOp :: SOp a t -> NExpr env a -> NExpr env t -- partiality - SEError :: STy a -> String -> SExpr a -deriving instance Show (SExpr t) + NEError :: STy a -> String -> NExpr env a +deriving instance Show (NExpr env t) -data Var a = Var (STy a) String - deriving (Show) +type family Lookup name env where + Lookup "_" _ = TypeError (Text "Attempt to use variable with name '_'") + Lookup name ('(name, t) : env) = t + Lookup name (_ : env) = Lookup name env -data Lambda a b = Lambda (Var a) b +data Var name t = Var (SSymbol name) (STy t) deriving (Show) -mkLambda :: KnownTy a => String -> (SExpr a -> f t) -> Lambda a (f t) -mkLambda name f = mkLambda' (Var knownTy name) f - -mkLambda' :: Var a -> (SExpr a -> f t) -> Lambda a (f t) -mkLambda' var f = Lambda var (f (SEVar var)) - -mkLambda2 :: (KnownTy a, KnownTy b) - => String -> String -> (SExpr a -> SExpr b -> f t) -> Lambda a (Lambda b (f t)) -mkLambda2 name1 name2 f = mkLambda2' (Var knownTy name1) (Var knownTy name2) f - -mkLambda2' :: Var a -> Var b -> (SExpr a -> SExpr b -> f t) -> Lambda a (Lambda b (f t)) -mkLambda2' var1 var2 f = Lambda var1 (Lambda var2 (f (SEVar var1) (SEVar var2))) - -instance (t ~ TScal st, KnownScalTy st, Num (ScalRep st)) => Num (SExpr t) where - a + b = SEOp (OAdd knownScalTy) (SEPair a b) - a * b = SEOp (OMul knownScalTy) (SEPair a b) - negate e = SEOp (ONeg knownScalTy) e - abs = error "abs undefined for SExpr" - signum = error "signum undefined for SExpr" +instance (t ~ TScal st, KnownScalTy st, Num (ScalRep st)) => Num (NExpr env t) where + a + b = NEOp (OAdd knownScalTy) (NEPair a b) + a * b = NEOp (OMul knownScalTy) (NEPair a b) + negate e = NEOp (ONeg knownScalTy) e + abs = error "abs undefined for NExpr" + signum = error "signum undefined for NExpr" fromInteger = let ty = knownScalTy in case scalRepIsShow ty of - Dict -> SEConst ty . fromInteger + Dict -> NEConst ty . fromInteger + +instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) where + fromLabel = Var symbolSing knownTy + +instance (KnownTy t, KnownSymbol name, Lookup name env ~ t) => IsLabel name (NExpr env t) where + fromLabel = NEVar (fromLabel @name) + +data NEnv env where + NTop :: NEnv '[] + NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env) -instance (KnownTy t, KnownSymbol name) => IsLabel name (Var t) where - fromLabel = Var knownTy (symbolVal (Proxy @name)) +data NFun env env' t where + NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t + NBody :: NExpr env t -> NFun env env t -instance (KnownTy t, KnownSymbol name) => IsLabel name (SExpr t) where - fromLabel = SEVar (fromLabel @name) +type family UnName env where + UnName '[] = '[] + UnName ('(name, t) : env) = t : UnName env -data SFun args t = SFun (SList Var args) (SExpr t) +fromNamed :: NFun '[] env t -> Ex (UnName env) t +fromNamed = fromNamedFun NTop -scopeCheck :: SFun env t -> Ex env t -scopeCheck (SFun args e) = scopeCheckExpr args e +fromNamedFun :: NEnv env -> NFun env env' t -> Ex (UnName env') t +fromNamedFun env (NLam var fun) = fromNamedFun (env `NPush` var) fun +fromNamedFun env (NBody e) = fromNamedExpr env e -scopeCheckExpr :: forall env t. SList Var env -> SExpr t -> Ex env t -scopeCheckExpr val = \case - SEVar tag@(Var ty _) - | Just idx <- find tag val -> EVar ext ty idx +fromNamedExpr :: forall env t. NEnv env -> NExpr env t -> Ex (UnName env) t +fromNamedExpr val = \case + NEVar var@(Var _ ty) + | Just idx <- find var val -> EVar ext ty idx | otherwise -> error "Variable out of scope in conversion from surface \ \expression to De Bruijn expression" - SELet a b -> ELet ext (go a) (lambda val b) - - SEPair a b -> EPair ext (go a) (go b) - SEFst e -> EFst ext (go e) - SESnd e -> ESnd ext (go e) - SENil -> ENil ext - SEInl t e -> EInl ext t (go e) - SEInr t e -> EInr ext t (go e) - SECase e a b -> ECase ext (go e) (lambda val a) (lambda val b) - - SEBuild1 a b -> EBuild1 ext (go a) (lambda val b) - SEBuild n a b -> EBuild ext n (go a) (lambda val b) - SEFold1 a b -> EFold1 ext (lambda2 val a) (go b) - SEUnit e -> EUnit ext (go e) - - SEConst t x -> EConst ext t x - SEIdx0 e -> EIdx0 ext (go e) - SEIdx1 a b -> EIdx1 ext (go a) (go b) - SEIdx n a b -> EIdx ext n (go a) (go b) - SEShape e -> EShape ext (go e) - SEOp op e -> EOp ext op (go e) - - SEError t s -> EError t s + NELet n a b -> ELet ext (go a) (lambda val n b) + + NEPair a b -> EPair ext (go a) (go b) + NEFst e -> EFst ext (go e) + NESnd e -> ESnd ext (go e) + NENil -> ENil ext + NEInl t e -> EInl ext t (go e) + NEInr t e -> EInr ext t (go e) + NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b) + + NEBuild1 a n b -> EBuild1 ext (go a) (lambda val n b) + NEBuild k a n b -> EBuild ext k (go a) (lambda val n b) + NEFold1 n1 n2 a b -> EFold1 ext (lambda2 val n1 n2 a) (go b) + NEUnit e -> EUnit ext (go e) + + NEConst t x -> EConst ext t x + NEIdx0 e -> EIdx0 ext (go e) + NEIdx1 a b -> EIdx1 ext (go a) (go b) + NEIdx n a b -> EIdx ext n (go a) (go b) + NEShape e -> EShape ext (go e) + NEOp op e -> EOp ext op (go e) + + NEError t s -> EError t s where - go :: SExpr t' -> Ex env t' - go = scopeCheckExpr val + go :: NExpr env t' -> Ex (UnName env) t' + go = fromNamedExpr val - find :: Var t' -> SList Var env' -> Maybe (Idx env' t') - find _ SNil = Nothing - find tag@(Var ty s) (Var ty' s' `SCons` val') - | s == s' + find :: Var name t' -> NEnv env' -> Maybe (Idx (UnName env') t') + find _ NTop = Nothing + find var@(Var s ty) (val' `NPush` Var s' ty') + | Just Refl <- testEquality s s' , Just Refl <- testEquality ty ty' = Just IZ | otherwise - = IS <$> find tag val' + = IS <$> find var val' - lambda :: SList Var env' -> Lambda a (SExpr b) -> Ex (a : env') b - lambda val' (Lambda tag e) = scopeCheckExpr (tag `SCons` val') e + lambda :: NEnv env' -> Var name a -> NExpr ('(name, a) : env') b -> Ex (a : UnName env') b + lambda val' var e = fromNamedExpr (val' `NPush` var) e - lambda2 :: SList Var env' -> Lambda a (Lambda b (SExpr c)) -> Ex (a : b : env') c - lambda2 val' (Lambda tag (Lambda tag' e)) = scopeCheckExpr (tag `SCons` tag' `SCons` val') e + lambda2 :: NEnv env' -> Var name1 a -> Var name2 b -> NExpr ('(name2, b) : '(name1, a) : env') c -> Ex (b : a : UnName env') c + lambda2 val' var1 var2 e = fromNamedExpr (val' `NPush` var1 `NPush` var2) e diff --git a/src/Language/Tag.hs b/src/Language/Tag.hs deleted file mode 100644 index 9356073..0000000 --- a/src/Language/Tag.hs +++ /dev/null @@ -1,22 +0,0 @@ -{-# LANGUAGE BangPatterns #-} -module Language.Tag ( - Tag(..), genTag, -) where - -import Data.IORef -import System.IO.Unsafe - -import AST - - -data Tag t = Tag (STy t) Int - deriving (Show) - -{-# NOINLINE tagCounter #-} -tagCounter :: IORef Int -tagCounter = unsafePerformIO $ newIORef 1 - -{-# NOINLINE genTag #-} -genTag :: handle -> STy t -> Tag t -genTag !_ ty = - unsafePerformIO $ Tag ty <$> atomicModifyIORef' tagCounter (\i -> (succ i, i)) |