summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2023-12-21 22:45:10 +0100
committerTom Smeding <tom@tomsmeding.com>2023-12-21 22:45:10 +0100
commit2872a9a18519d41e110c5cea36172935b64edfde (patch)
tree382d3dd9b0aaf546cd6c12075cca98a807976007 /src
Initial
Diffstat (limited to 'src')
-rw-r--r--src/AST.hs34
-rw-r--r--src/TypeCheck.hs253
2 files changed, 287 insertions, 0 deletions
diff --git a/src/AST.hs b/src/AST.hs
new file mode 100644
index 0000000..42967bc
--- /dev/null
+++ b/src/AST.hs
@@ -0,0 +1,34 @@
+module AST where
+
+import Data.Map.Strict (Map)
+
+
+data Nat = Z | S Nat
+ deriving (Show)
+
+type Name = String
+
+data Term
+ = TSet Term -- ^ The n-th universe (n : Level)
+ | TVar Name -- ^ variable
+ | TPi Name Term Term -- ^ Pi: (x : A) -> B
+ | TLam Name Term Term -- ^ λ(x : A). B
+ | TApp Term Term -- ^ application
+
+ | TLift Term -- ^ Γ |- t : Set i ~> Γ |- lift t : Set (Succ i)
+ | TLevel -- ^ The Level type
+ | TLevelUniv -- ^ The LevelUniv type: the type of Level
+ | TIZero -- ^ Level zero
+ | TIMax Term Term -- ^ Maximum of two levels
+ | TISucc Term -- ^ Level + 1
+
+ -- Temporary stuff until we have proper inductive types:
+ | TOne -- ^ The unit type
+ | TUnit -- ^ The element of the unit type
+ | TSigma Name Term Term -- ^ Sigma: (x : A) × B
+ | TPair Term Term -- ^ Dependent pair
+ | TProj1 Term -- ^ First projection of a dependent pair
+ | TProj2 Term -- ^ Second projection of a dependent pair
+ deriving (Show)
+
+type Env = Map Name Term
diff --git a/src/TypeCheck.hs b/src/TypeCheck.hs
new file mode 100644
index 0000000..4c05c4b
--- /dev/null
+++ b/src/TypeCheck.hs
@@ -0,0 +1,253 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE DeriveFoldable #-}
+module TypeCheck where
+
+import Control.Monad
+import Control.Monad.Trans.Class
+import Control.Monad.Trans.State.Strict
+import Control.Monad.Trans.Writer.CPS
+import Data.Foldable (toList)
+import Data.Map.Strict (Map)
+import qualified Data.Map.Strict as Map
+import Numeric.Natural
+
+import AST
+
+
+-- | type checking monad
+newtype TM a = TM (WriterT (Bag Constr) (StateT Natural (Either String)) a)
+ deriving stock (Functor)
+ deriving newtype (Applicative, Monad)
+
+data Bag a = BTwo (Bag a) (Bag a) | BOne a | BZero
+ deriving stock (Show, Functor, Foldable)
+instance Semigroup (Bag a) where (<>) = BTwo
+instance Monoid (Bag a) where mempty = BZero
+
+data Constr = VarEq Name Term
+ | LevelLeq Term Term
+ deriving (Show)
+
+type Subst = Map Name Term
+
+runTM :: TM a -> Either String ([Constr], a)
+runTM (TM m) = do
+ (res, cs) <- evalStateT (runWriterT m) 0
+ return (toList cs, res)
+
+genId :: TM Natural
+genId = TM (lift (state (\i -> (i, i + 1))))
+
+genName :: TM Name
+genName = ("." ++) . show <$> genId
+
+throw :: String -> TM a
+throw err = TM (lift (lift (Left err)))
+
+emit :: Constr -> TM ()
+emit c = TM (tell (BOne c))
+
+data OfType a b = a :| b
+ deriving stock (Show)
+infix 1 :|
+
+check :: Env -> Term -> Term -> TM Term
+check env topTerm typ = case topTerm of
+ TPair a b -> do
+ case whnf env typ of
+ TSigma name t1 t2 -> do
+ a' <- check env a t1
+ b' <- check (Map.insert name t1 env) b t2
+ return (TPair a' b')
+ t -> throw $ "Pair expression cannot have type " ++ show t
+
+ _ -> do
+ e' :| typ2 <- infer env topTerm
+ unify typ typ2
+ return e'
+
+-- | Evaluate the type part of the return value to WHNF before returning.
+inferW :: Env -> Term -> TM (OfType Term Term)
+inferW env term = do
+ e :| ty <- infer env term
+ return (e :| whnf env ty)
+
+infer :: Env -> Term -> TM (OfType Term Term)
+infer env = \case
+ TSet i -> do
+ return (TSet i :| TSet (TISucc i))
+
+ TVar n -> do
+ case Map.lookup n env of
+ Just ty -> return (TVar n :| ty)
+ Nothing -> throw $ "Variable out of scope: " ++ n
+
+ TPi x a b -> do
+ inferW env a >>= \case
+ a' :| TSet lvlA -> do
+ inferW (Map.insert x a' env) b >>= \case
+ b' :| TSet lvlB -> do
+ when (x `freeIn` lvlB) $
+ throw $ "Variable " ++ show x ++ " escapes Pi"
+ return (TPi x a' b' :| TSet (TIMax lvlA lvlB))
+ _ :| tb -> throw $ "RHS of a Pi not of type Set i, but: " ++ show tb
+ _ :| ta -> throw $ "LHS type of a Pi not of type Set i, but: " ++ show ta
+
+ TLam x t e -> do
+ inferW env t >>= \case
+ t' :| TSet{} -> do
+ e' :| te <- inferW (Map.insert x t' env) e
+ when (x `freeIn` te) $
+ throw $ "Variable " ++ show x ++ " escape lambda"
+ return (TLam x t' e' :| TPi x t' te)
+ _ :| tt -> throw $ "Lambda variable type not of type Set i, but: " ++ show tt
+
+ TApp a b -> do
+ inferW env a >>= \case
+ a' :| TPi name t1 t2 -> do
+ b' <- check env b t1
+ return (TApp a' b' :| subst name b' t2)
+ _ :| ta -> throw $ "LHS of application not of Pi type, but: " ++ show ta
+
+ TLift e -> do
+ inferW env e >>= \case
+ e' :| TSet lvl -> do
+ return (TLift e' :| TSet (TISucc lvl))
+ _ :| te -> throw $ "Argument to lift not of type Set i, but: " ++ show te
+
+ TLevel -> do
+ return (TLevel :| TLevelUniv)
+
+ TLevelUniv -> do
+ return (TLevelUniv :| TSet (TISucc TIZero))
+
+ TIZero -> do
+ return (TIZero :| TLevel)
+
+ TIMax a b -> do
+ infer env a >>= \case
+ a' :| TLevel -> do
+ inferW env b >>= \case
+ b' :| TLevel -> do
+ return (TIMax a' b' :| TLevel)
+ _ :| tb -> throw $ "RHS of imax not of type Level, but: " ++ show tb
+ _ :| ta -> throw $ "LHS of imax not of type Level, but: " ++ show ta
+
+ TISucc a -> do
+ inferW env a >>= \case
+ a' :| TLevel -> do
+ return (TISucc a' :| TLevel)
+ _ :| ta -> throw $ "Argument of isucc not of type Level, but: " ++ show ta
+
+ TOne -> do
+ return (TOne :| TSet TIZero)
+
+ TUnit -> do
+ return (TUnit :| TOne)
+
+ TSigma x a b -> do
+ inferW env a >>= \case
+ a' :| TSet lvlA -> do
+ inferW env b >>= \case
+ b' :| TSet lvlB -> do
+ return (TSigma x a' b' :| TSet (TIMax lvlA lvlB))
+ _ :| tb -> throw $ "RHS of a Sigma not of type Set i, but: " ++ show tb
+ _ :| ta -> throw $ "LHS type of a Sigma not of type Set i, but: " ++ show ta
+
+ TPair{} -> do
+ throw "Dependent pair occurring in non-checking position"
+
+ TProj1 e -> do
+ inferW env e >>= \case
+ e' :| TSigma _name t1 _t2 -> do
+ return (TProj1 e' :| t1)
+ _ :| t -> throw $ "Argument of proj1 not of Sigma type, but: " ++ show t
+
+ TProj2 e -> do
+ inferW env e >>= \case
+ e' :| TSigma name _t1 t2 -> do
+ return (TProj2 e' :| subst name (TProj1 e') t2)
+ _ :| t -> throw $ "Argument of proj2 not of Sigma type, but: " ++ show t
+
+freeIn :: Name -> Term -> Bool
+freeIn target = \case
+ TSet a -> rec a
+ TVar n -> n == target
+ TPi n a b -> rec a && (if n == target then True else rec b)
+ TLam n a b -> rec a && (if n == target then True else rec b)
+ TApp a b -> rec a && rec b
+ TLift a -> rec a
+ TLevel -> False
+ TLevelUniv -> False
+ TIZero -> False
+ TIMax a b -> rec a && rec b
+ TISucc a -> rec a
+ TOne -> False
+ TUnit -> False
+ TSigma n a b -> rec a && (if n == target then True else rec b)
+ TPair a b -> rec a && rec b
+ TProj1 a -> rec a
+ TProj2 a -> rec a
+ where
+ rec = freeIn target
+
+subst :: Name -> Term -> Term -> Term
+subst target repl = \case
+ TVar n | n == target -> repl
+
+ TSet a -> TSet (rec a)
+ TVar n -> TVar n
+ TPi n a b -> TPi n (rec a) (if n == target then b else rec b)
+ TLam n a b -> TLam n (rec a) (if n == target then b else rec b)
+ TApp a b -> TApp (rec a) (rec b)
+ TLift a -> TLift (rec a)
+ TLevel -> TLevel
+ TLevelUniv -> TLevelUniv
+ TIZero -> TIZero
+ TIMax a b -> TIMax (rec a) (rec b)
+ TISucc a -> TISucc (rec a)
+ TOne -> TOne
+ TUnit -> TUnit
+ TSigma n a b -> TSigma n (rec a) (if n == target then b else rec b)
+ TPair a b -> TPair (rec a) (rec b)
+ TProj1 a -> TProj1 (rec a)
+ TProj2 a -> TProj2 (rec a)
+ where
+ rec = subst target repl
+
+unify :: Term -> Term -> TM ()
+unify (TSet a) (TSet b) = unify a b
+unify (TVar n) (TVar m) | n == m = return ()
+unify (TPi n a b) (TPi m c d) = unify a c >> unify b (subst m (TVar n) d)
+unify (TLam n a b) (TLam m c d) = unify a c >> unify b (subst m (TVar n) d)
+unify (TLift a) (TLift b) = unify a b
+unify TLevel TLevel = return ()
+unify TLevelUniv TLevelUniv = return ()
+unify TIZero TIZero = return ()
+unify (TIMax a b) (TIMax c d) = unify a c >> unify b d
+unify (TISucc a) (TISucc b) = unify a b
+unify TOne TOne = return ()
+unify TUnit TUnit = return ()
+unify (TSigma n a b) (TSigma m c d) = unify a c >> unify b (subst m (TVar n) d)
+unify (TPair a b) (TPair c d) = unify a c >> unify b d
+unify (TProj1 a) (TProj1 b) = unify a b
+unify (TProj2 a) (TProj2 b) = unify a b
+unify a b = throw $ "Cannot unify:\n- " ++ show a ++ "\n- " ++ show b
+
+whnf :: Env -> Term -> Term
+whnf env = \case
+ TVar n | Just t <- Map.lookup n env -> whnf env t
+ TApp (TLam n _ a) b -> whnf env (subst n b a)
+ TIMax a b -> merge (whnf env a) (whnf env b)
+ where
+ merge TIZero l = l
+ merge l TIZero = l
+ merge (TISucc l) (TISucc m) = TISucc (merge l m)
+ merge l m = TIMax l m
+ TProj1 (TPair a _) -> a
+ TProj2 (TPair _ b) -> b
+
+ t -> t