{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}

-- | Type checking for the terms defined in "AST": 'check' verifies a term
-- against an expected type, 'infer' synthesises a type for a term.
module TypeCheck where

import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.State.Strict
import Control.Monad.Trans.Writer.CPS
import Data.Foldable (toList)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Numeric.Natural

import AST

-- | Type checking monad: accumulates emitted constraints in a 'Bag', threads
-- a 'Natural' counter used to generate fresh names, and can abort with an
-- error message.
newtype TM a = TM (WriterT (Bag Constr) (StateT Natural (Either String)) a)
  deriving stock (Functor)
  deriving newtype (Applicative, Monad)

-- | Binary tree used as an append-only collection: '(<>)' just allocates a
-- 'BTwo' node, and the derived 'Foldable' instance ('toList') visits the
-- elements left to right.
data Bag a
  = BTwo (Bag a) (Bag a)
  | BOne a
  | BZero
  deriving stock (Show, Functor, Foldable)

instance Semigroup (Bag a) where
  (<>) = BTwo

instance Monoid (Bag a) where
  mempty = BZero

-- | Constraints that can be recorded with 'emit'.
data Constr
  = VarEq Name Term
  | LevelLeq Term Term
  deriving (Show)

type Subst = Map Name Term

-- | Run a checking computation with the fresh-name counter starting at 0.
-- On success, returns the emitted constraints (in emission order) and the
-- computation's result.
runTM :: TM a -> Either String ([Constr], a)
runTM (TM m) = do
  (res, cs) <- evalStateT (runWriterT m) 0
  return (toList cs, res)

-- | Return the current counter value and increment it.
genId :: TM Natural
genId = TM (lift (state (\i -> (i, i + 1))))

-- | Produce a fresh name of the form @.N@.
-- NOTE(review): the leading '.' presumably keeps generated names distinct
-- from user-written ones; confirm against the parser.
genName :: TM Name
genName = ("." ++) . show <$> genId

-- | Abort the computation with an error message.
throw :: String -> TM a
throw err = TM (lift (lift (Left err)))

-- | Record a constraint.
emit :: Constr -> TM ()
emit c = TM (tell (BOne c))

-- | A term together with its type, written @term :| type@.
data OfType a b = a :| b
  deriving stock (Show)

infix 1 :|

-- | Check a term against an expected type and return the elaborated term.
--
-- Pairs are only accepted in checking position; the expected type must
-- reduce (by 'whnf') to a Sigma type. Any other term is inferred and its
-- type is unified with the expected one.
check :: Env -> Term -> Term -> TM Term
check env topTerm typ = case topTerm of
  TPair a b ->
    case whnf env typ of
      TSigma name t1 t2 -> do
        a' <- check env a t1
        -- The type of the second component depends on the value of the
        -- first: check against t2 with the bound name replaced by a'
        -- (the same substitution 'infer' uses for TApp and TProj2).
        b' <- check env b (subst name a' t2)
        return (TPair a' b')
      t -> throw $ "Pair expression cannot have type " ++ show t
  _ -> do
    e' :| typ2 <- infer env topTerm
    unify typ typ2
    return e'

-- | Evaluate the type part of the return value to WHNF before returning.
inferW :: Env -> Term -> TM (OfType Term Term)
inferW env term = do
  e :| ty <- infer env term
  return (e :| whnf env ty)

-- | Infer the type of a term, returning the elaborated term paired with its
-- type. The environment maps variable names to their types. Pairs cannot be
-- inferred and must appear in checking position (see 'check').
infer :: Env -> Term -> TM (OfType Term Term)
infer env = \case
  TSet i -> do
    return (TSet i :| TSet (TISucc i))
  TVar n -> do
    case Map.lookup n env of
      Just ty -> return (TVar n :| ty)
      Nothing -> throw $ "Variable out of scope: " ++ n
  TPi x a b -> do
    inferW env a >>= \case
      a' :| TSet lvlA -> do
        inferW (Map.insert x a' env) b >>= \case
          b' :| TSet lvlB -> do
            -- The universe level of the codomain must not depend on x.
            when (x `freeIn` lvlB) $ throw $ "Variable " ++ show x ++ " escapes Pi"
            return (TPi x a' b' :| TSet (TIMax lvlA lvlB))
          _ :| tb -> throw $ "RHS of a Pi not of type Set i, but: " ++ show tb
      _ :| ta -> throw $ "LHS type of a Pi not of type Set i, but: " ++ show ta
  TLam x t e -> do
    inferW env t >>= \case
      t' :| TSet{} -> do
        e' :| te <- inferW (Map.insert x t' env) e
        -- te may legitimately mention x (dependent function type): the
        -- resulting Pi binds x, so there is nothing to escape here.
        return (TLam x t' e' :| TPi x t' te)
      _ :| tt -> throw $ "Lambda variable type not of type Set i, but: " ++ show tt
  TApp a b -> do
    inferW env a >>= \case
      a' :| TPi name t1 t2 -> do
        b' <- check env b t1
        return (TApp a' b' :| subst name b' t2)
      _ :| ta -> throw $ "LHS of application not of Pi type, but: " ++ show ta
  TLift e -> do
    inferW env e >>= \case
      e' :| TSet lvl -> do
        return (TLift e' :| TSet (TISucc lvl))
      _ :| te -> throw $ "Argument to lift not of type Set i, but: " ++ show te
  TLevel -> do
    return (TLevel :| TLevelUniv)
  TLevelUniv -> do
    return (TLevelUniv :| TSet (TISucc TIZero))
  TIZero -> do
    return (TIZero :| TLevel)
  TIMax a b -> do
    -- Both operands go through inferW so their types are compared in WHNF.
    inferW env a >>= \case
      a' :| TLevel -> do
        inferW env b >>= \case
          b' :| TLevel -> do
            return (TIMax a' b' :| TLevel)
          _ :| tb -> throw $ "RHS of imax not of type Level, but: " ++ show tb
      _ :| ta -> throw $ "LHS of imax not of type Level, but: " ++ show ta
  TISucc a -> do
    inferW env a >>= \case
      a' :| TLevel -> do
        return (TISucc a' :| TLevel)
      _ :| ta -> throw $ "Argument of isucc not of type Level, but: " ++ show ta
  TOne -> do
    return (TOne :| TSet TIZero)
  TUnit -> do
    return (TUnit :| TOne)
  TSigma x a b -> do
    inferW env a >>= \case
      a' :| TSet lvlA -> do
        -- The second component's type may refer to x, so x must be in scope.
        inferW (Map.insert x a' env) b >>= \case
          b' :| TSet lvlB -> do
            when (x `freeIn` lvlB) $ throw $ "Variable " ++ show x ++ " escapes Sigma"
            return (TSigma x a' b' :| TSet (TIMax lvlA lvlB))
          _ :| tb -> throw $ "RHS of a Sigma not of type Set i, but: " ++ show tb
      _ :| ta -> throw $ "LHS type of a Sigma not of type Set i, but: " ++ show ta
  TPair{} -> do
    throw "Dependent pair occurring in non-checking position"
  TProj1 e -> do
    inferW env e >>= \case
      e' :| TSigma _name t1 _t2 -> do
        return (TProj1 e' :| t1)
      _ :| t -> throw $ "Argument of proj1 not of Sigma type, but: " ++ show t
  TProj2 e -> do
    inferW env e >>= \case
      e' :| TSigma name _t1 t2 -> do
        return (TProj2 e' :| subst name (TProj1 e') t2)
      _ :| t -> throw $ "Argument of proj2 not of Sigma type, but: " ++ show t

-- | @freeIn x t@ is 'True' when @x@ occurs free in @t@, i.e. at least once
-- outside the scope of a Pi/Lambda/Sigma binder that rebinds the same name.
freeIn :: Name -> Term -> Bool
freeIn target = \case
  TSet a -> rec a
  TVar n -> n == target
  TPi n a b -> binder n a b
  TLam n a b -> binder n a b
  TApp a b -> rec a || rec b
  TLift a -> rec a
  TLevel -> False
  TLevelUniv -> False
  TIZero -> False
  TIMax a b -> rec a || rec b
  TISucc a -> rec a
  TOne -> False
  TUnit -> False
  TSigma n a b -> binder n a b
  TPair a b -> rec a || rec b
  TProj1 a -> rec a
  TProj2 a -> rec a
  where
    rec = freeIn target
    -- The annotation lies outside the binder's scope; the body only counts
    -- when the binder does not shadow the target.
    binder n a b = rec a || (n /= target && rec b)

-- | @subst x r t@ replaces every free occurrence of @x@ in @t@ with @r@.
-- A binder whose name occurs free in @r@ is renamed (by appending primes)
-- before descending into its body, so free variables of @r@ are not captured.
subst :: Name -> Term -> Term -> Term
subst target repl = \case
  TVar n
    | n == target -> repl
    | otherwise -> TVar n
  TSet a -> TSet (rec a)
  TPi n a b -> let (n', b') = underBinder n b in TPi n' (rec a) b'
  TLam n a b -> let (n', b') = underBinder n b in TLam n' (rec a) b'
  TApp a b -> TApp (rec a) (rec b)
  TLift a -> TLift (rec a)
  TLevel -> TLevel
  TLevelUniv -> TLevelUniv
  TIZero -> TIZero
  TIMax a b -> TIMax (rec a) (rec b)
  TISucc a -> TISucc (rec a)
  TOne -> TOne
  TUnit -> TUnit
  TSigma n a b -> let (n', b') = underBinder n b in TSigma n' (rec a) b'
  TPair a b -> TPair (rec a) (rec b)
  TProj1 a -> TProj1 (rec a)
  TProj2 a -> TProj2 (rec a)
  where
    rec = subst target repl
    -- Substitute inside the body of a binder named n; returns the (possibly
    -- renamed) binder name together with the new body.
    underBinder n body
      | n == target || not (target `freeIn` body) = (n, body)
      | n `freeIn` repl =
          let n' = until (isFresh body) (++ "'") (n ++ "'")
          in (n', rec (subst n (TVar n') body))
      | otherwise = (n, rec body)
    -- A renamed binder must not clash with the target, the replacement's
    -- free variables, or the body's free variables.
    isFresh body c =
      c /= target && not (c `freeIn` repl) && not (c `freeIn` body)

-- | Check that two terms are syntactically equal up to renaming of bound
-- variables; fails with a message otherwise. No reduction is performed.
unify :: Term -> Term -> TM ()
unify (TSet a) (TSet b) = unify a b
unify (TVar n) (TVar m) | n == m = return ()
unify (TPi n a b) (TPi m c d) = unify a c >> unifyBinders n b m d
unify (TLam n a b) (TLam m c d) = unify a c >> unifyBinders n b m d
unify (TApp a b) (TApp c d) = unify a c >> unify b d
unify (TLift a) (TLift b) = unify a b
unify TLevel TLevel = return ()
unify TLevelUniv TLevelUniv = return ()
unify TIZero TIZero = return ()
unify (TIMax a b) (TIMax c d) = unify a c >> unify b d
unify (TISucc a) (TISucc b) = unify a b
unify TOne TOne = return ()
unify TUnit TUnit = return ()
unify (TSigma n a b) (TSigma m c d) = unify a c >> unifyBinders n b m d
unify (TPair a b) (TPair c d) = unify a c >> unify b d
unify (TProj1 a) (TProj1 b) = unify a b
unify (TProj2 a) (TProj2 b) = unify a b
unify a b = throw $ "Cannot unify:\n- " ++ show a ++ "\n- " ++ show b

-- | Compare the bodies of two binders (bound as @n@ and @m@) after renaming
-- both bound variables to one fresh name. This keeps a free variable that
-- happens to share a binder's name from being confused with it.
unifyBinders :: Name -> Term -> Name -> Term -> TM ()
unifyBinders n b m d = do
  v <- genName
  unify (subst n (TVar v) b) (subst m (TVar v) d)

-- | Reduce a term to weak head normal form. Applications whose head reduces
-- to a lambda are beta-reduced, and projections whose argument reduces to a
-- pair are contracted. @imax@ is simplified against zero and successor
-- levels.
--
-- The environment maps variables to their /types/ (see 'infer'), not their
-- values, so variables are never unfolded through it.
whnf :: Env -> Term -> Term
whnf env = \case
  TApp f b -> case whnf env f of
    TLam n _ body -> whnf env (subst n b body)
    f' -> TApp f' b
  TIMax a b -> merge (whnf env a) (whnf env b)
  TProj1 e -> case whnf env e of
    TPair a _ -> whnf env a
    e' -> TProj1 e'
  TProj2 e -> case whnf env e of
    TPair _ b -> whnf env b
    e' -> TProj2 e'
  t -> t
  where
    merge TIZero l = l
    merge l TIZero = l
    merge (TISucc l) (TISucc m) = TISucc (merge l m)
    merge l m = TIMax l m