{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ViewPatterns #-}

-- | Type checking and type inference for the core language defined in "AST".
module TypeCheck (checkProgram, typeCheck, typeInfer) where

import Control.Monad
-- import Control.Monad.Trans.Class
-- import Control.Monad.Trans.State.Strict
-- import Control.Monad.Trans.Writer.CPS
-- import Data.Foldable (toList)
-- import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Numeric.Natural

import AST

-- | Check a sequence of top-level definitions, threading the environment
-- from one definition to the next.
--
-- For every definition, the declared type is inferred first; its kind must be
-- @Set i@ or @Setw i@. The body is then checked against the /elaborated/
-- declared type, and the name is added to the environment together with the
-- elaborated body and type. The first failure aborts with an error message.
checkProgram :: Env -> [Definition] -> Either String Env
checkProgram = foldM $ \env (Definition name (term :| ty)) -> do
  ty' <- runTM (inferW env ty) >>= \case
    elaborated :| TSet _ -> return elaborated
    elaborated :| TSetw _ -> return elaborated
    _ :| kind -> Left $ "Kind of declared type is not Set i or Setw i, but: " ++ show kind
  -- Check against the elaborated type, which is the same type recorded in the
  -- environment below, rather than against the raw surface syntax (which may
  -- still contain annotations or unreduced terms).
  term' <- typeCheck env (term :| ty')
  return (Map.insert name (Just term' :| ty') env)

-- | Check a term against a type, returning the elaborated term.
typeCheck :: Env -> OfType Term Term -> Either String Term
typeCheck env jud = runTM (check env jud)
  -- runTM (check env jud) >>= \case
  --   ([], term') -> return term'
  --   (_, _) -> error "Don't know how to solve constraints yet"

-- | Infer the type of a term, returning the elaborated term and its type.
typeInfer :: Env -> Term -> Either String (OfType Term Term)
typeInfer env term = runTM (infer env term)
  -- runTM (infer env term) >>= \case
  --   ([], jud) -> return jud
  --   (_, _) -> error "Don't know how to solve constraints yet"

-- | Type checking monad. Currently just failure with a message; the commented
-- transformer layers are the planned constraint-collecting version.
newtype TM a = TM ({- WriterT (Bag Constr) (StateT Natural ( -} Either String {- )) -} a)
  deriving stock (Functor)
  deriving newtype (Applicative, Monad)

-- | Unordered collection with O(1) append; intended for collected constraints.
data Bag a = BTwo (Bag a) (Bag a) | BOne a | BZero
  deriving stock (Show, Functor, Foldable)

instance Semigroup (Bag a) where
  (<>) = BTwo

instance Monoid (Bag a) where
  mempty = BZero

-- | Constraints the checker may emit once constraint solving is implemented.
data Constr = VarEq Name Term | LevelLeq Term Term
  deriving stock (Show)

-- | Run a type checking computation.
runTM :: TM a -> Either String a
runTM (TM m) = m
  -- (res, cs) <- evalStateT (runWriterT m) 0
  -- return (toList cs, res)

-- genId :: TM Natural
-- genId = TM (lift (state (\i -> (i, i + 1))))
-- genName :: TM Name
-- genName = ("." ++) . show <$> genId

-- | Abort type checking with the given error message.
throw :: String -> TM a
-- throw err = TM (lift (lift (Left err)))
throw err = TM (Left err)

-- emit :: Constr -> TM ()
-- emit c = TM (tell (BOne c))

-- | Check a term against an expected type, returning the elaborated term.
--
-- Only dependent pairs are handled in checking mode (they are rejected by
-- 'infer'). Every other term is inferred and its type unified with the
-- expected one.
check :: Env -> OfType Term Term -> TM Term
check env (topTerm :| typ) = case topTerm of
  TPair a b ->
    case whnf env typ of
      TSigma name t1 t2 -> do
        a' <- check env (a :| t1)
        -- The type of the second component depends on the value of the
        -- first, so substitute that value in. Binding 'name' as an opaque
        -- variable instead would be wrong for dependent pairs, and would also
        -- shadow any outer variable called 'name' that b refers to.
        b' <- check env (b :| subst name a' t2)
        return (TPair a' b')
      t -> throw $ "Pair expression cannot have type " ++ show t
  _ -> do
    e' :| typ2 <- infer env topTerm
    unify env typ typ2
    return e'

-- | Evaluate the type part of the return value to WHNF before returning.
inferW :: Env -> Term -> TM (OfType Term Term)
inferW env term = do
  e :| ty <- infer env term
  return (e :| whnf env ty)

-- | Infer the type of a term, returning the elaborated term and its type.
infer :: Env -> Term -> TM (OfType Term Term)
infer env = \case
  TSet i -> do
    -- The level of a universe must itself be a well-scoped Level term.
    i' <- check env (i :| TLevel)
    return (TSet i' :| TSet (TISucc i'))
  TSetw i ->
    return (TSetw i :| TSetw (succ i))
  TVar n ->
    case Map.lookup n env of
      Just (_ :| ty) -> return (TVar n :| ty)
      Nothing -> throw $ "Variable out of scope: " ++ n
  TPi x a b ->
    inferW env a >>= \case
      a' :| (argumentKind -> Just kindA) ->
        inferW (Map.insert x (Nothing :| a') env) b >>= \case
          b' :| (argumentKind -> Just kindB) ->
            return (TPi x a' b' :| akLower (kindA <> kindB))
          _ :| tb -> throw $ "RHS of a Pi not of acceptable type, but: " ++ show tb
      _ :| ta -> throw $ "LHS type of a Pi not of acceptable type, but: " ++ show ta
  TLam x t e ->
    inferW env t >>= \case
      -- The domain annotation must be a type, exactly as for the LHS of a Pi.
      t' :| (argumentKind -> Just _) -> do
        e' :| te <- inferW (Map.insert x (Nothing :| t') env) e
        return (TLam x t' e' :| TPi x t' te)
      _ :| tt -> throw $ "Domain of a lambda not of acceptable type, but: " ++ show tt
  TApp a b ->
    inferW env a >>= \case
      a' :| TPi name t1 t2 -> do
        b' <- check env (b :| t1)
        return (TApp a' b' :| subst name b' t2)
      _ :| ta -> throw $ "LHS of application not of Pi type, but: " ++ show ta
  TLift e ->
    inferW env e >>= \case
      e' :| TSet lvl -> return (TLift e' :| TSet (TISucc lvl))
      _ :| te -> throw $ "Argument to lift not of type Set i, but: " ++ show te
  TLevel ->
    return (TLevel :| TLevelUniv)
  TLevelUniv ->
    return (TLevelUniv :| TSet (TISucc TIZero))
  TIZero ->
    return (TIZero :| TLevel)
  TIMax a b ->
    -- Both operands are normalised before matching, so a type that merely
    -- reduces to Level is accepted on either side.
    inferW env a >>= \case
      a' :| TLevel ->
        inferW env b >>= \case
          b' :| TLevel -> return (TIMax a' b' :| TLevel)
          _ :| tb -> throw $ "RHS of imax not of type Level, but: " ++ show tb
      _ :| ta -> throw $ "LHS of imax not of type Level, but: " ++ show ta
  TISucc a ->
    inferW env a >>= \case
      a' :| TLevel -> return (TISucc a' :| TLevel)
      _ :| ta -> throw $ "Argument of isucc not of type Level, but: " ++ show ta
  TAnnot (a :| b) ->
    inferW env b >>= \case
      -- An annotation must be a type before a term can be checked against it.
      b' :| (argumentKind -> Just _) -> do
        a' <- check env (a :| b')
        return (a' :| b')
      _ :| tb -> throw $ "Annotation is not a type, but has type: " ++ show tb
  TOne ->
    return (TOne :| TSet TIZero)
  TUnit ->
    return (TUnit :| TOne)
  TSigma x a b ->
    inferW env a >>= \case
      a' :| (argumentKind -> Just kindA) ->
        inferW (Map.insert x (Nothing :| a') env) b >>= \case
          b' :| (argumentKind -> Just kindB) ->
            return (TSigma x a' b' :| akLower (kindA <> kindB))
          _ :| tb -> throw $ "RHS of a Sigma not of acceptable type, but: " ++ show tb
      _ :| ta -> throw $ "LHS type of a Sigma not of acceptable type, but: " ++ show ta
  TPair{} ->
    throw "Dependent pair occurring in non-checking position"
  TProj1 e ->
    inferW env e >>= \case
      e' :| TSigma _name t1 _t2 -> return (TProj1 e' :| t1)
      _ :| t -> throw $ "Argument of proj1 not of Sigma type, but: " ++ show t
  TProj2 e ->
    inferW env e >>= \case
      e' :| TSigma name _t1 t2 -> return (TProj2 e' :| subst name (TProj1 e') t2)
      _ :| t -> throw $ "Argument of proj2 not of Sigma type, but: " ++ show t

-- | The kind (type of a type) of the domain or codomain of a Pi/Sigma.
data ArgKind = AKSet Term | AKSetw Natural | AKLevelUniv
  deriving stock (Show)

-- | Combine the kinds of the two halves of a Pi/Sigma. Two @Set@s give the
-- @Set@ of the maximum level, and two @LevelUniv@s stay @LevelUniv@. Any
-- other combination lands in @Setw@ at the largest index present, with @Set@
-- and @LevelUniv@ counting as index 0.
instance Semigroup ArgKind where
  AKSet n <> AKSet m = AKSet (TIMax n m)
  AKSet _ <> ak = AKSetw 0 <> ak
  ak <> AKSet _ = ak <> AKSetw 0
  AKLevelUniv <> AKLevelUniv = AKLevelUniv
  AKLevelUniv <> ak = AKSetw 0 <> ak
  ak <> AKLevelUniv = ak <> AKSetw 0
  AKSetw i <> AKSetw j = AKSetw (max i j)

-- | Classify a (WHNF) type of a type; 'Nothing' if it is not a universe.
argumentKind :: Term -> Maybe ArgKind
argumentKind (TSet t) = Just (AKSet t)
argumentKind (TSetw i) = Just (AKSetw i)
argumentKind TLevelUniv = Just AKLevelUniv
argumentKind _ = Nothing

-- | Turn an 'ArgKind' back into the corresponding universe term.
akLower :: ArgKind -> Term
akLower (AKSet t) = TSet t
akLower (AKSetw i) = TSetw i
akLower AKLevelUniv = TLevelUniv

-- | Does the variable occur free in the term? Occurrences under a binder of
-- the same name are bound, not free.
freeIn :: Name -> Term -> Bool
freeIn target = \case
  TSet a -> occurs a
  TSetw _ -> False
  TVar n -> n == target
  TPi n a b -> occurs a || (n /= target && occurs b)
  TLam n a b -> occurs a || (n /= target && occurs b)
  TApp a b -> occurs a || occurs b
  TLift a -> occurs a
  TLevel -> False
  TLevelUniv -> False
  TIZero -> False
  TIMax a b -> occurs a || occurs b
  TISucc a -> occurs a
  TAnnot (a :| b) -> occurs a || occurs b
  TOne -> False
  TUnit -> False
  TSigma n a b -> occurs a || (n /= target && occurs b)
  TPair a b -> occurs a || occurs b
  TProj1 a -> occurs a
  TProj2 a -> occurs a
  where
    occurs = freeIn target

-- | A variant of the given name (obtained by appending primes) that occurs
-- free in none of the given terms.
freshName :: Name -> [Term] -> Name
freshName base avoid = until unused (++ "'") (base ++ "'")
  where
    unused candidate = not (any (candidate `freeIn`) avoid)

-- | Capture-avoiding substitution: @subst x r t@ replaces the free
-- occurrences of @x@ in @t@ by @r@. A binder whose name occurs free in @r@ is
-- renamed when its body mentions @x@, so that @r@'s variables are not
-- captured.
subst :: Name -> Term -> Term -> Term
subst target repl = go
  where
    go = \case
      TVar n | n == target -> repl
      TVar n -> TVar n
      TSet a -> TSet (go a)
      TSetw i -> TSetw i
      TPi n a b -> binder TPi n a b
      TLam n a b -> binder TLam n a b
      TApp a b -> TApp (go a) (go b)
      TLift a -> TLift (go a)
      TLevel -> TLevel
      TLevelUniv -> TLevelUniv
      TIZero -> TIZero
      TIMax a b -> TIMax (go a) (go b)
      TISucc a -> TISucc (go a)
      TAnnot (a :| b) -> TAnnot (go a :| go b)
      TOne -> TOne
      TUnit -> TUnit
      TSigma n a b -> binder TSigma n a b
      TPair a b -> TPair (go a) (go b)
      TProj1 a -> TProj1 (go a)
      TProj2 a -> TProj2 (go a)

    -- Substitute under a binder named n with domain dom and body body.
    binder mk n dom body
      -- target is shadowed: only the domain is in its scope.
      | n == target = mk n (go dom) body
      -- The binder would capture a free variable of repl: rename it first.
      | n `freeIn` repl && target `freeIn` body =
          let n' = freshName n [repl, body]
          in mk n' (go dom) (go (subst n (TVar n') body))
      | otherwise = mk n (go dom) (go body)

-- | Check two terms for definitional equality.
--
-- Both sides are reduced to weak head normal form before being compared
-- structurally, so beta-redexes, projections out of pairs, annotations and
-- trivial level maxima do not cause spurious mismatches. Bodies of binders
-- are compared up to renaming of the bound variable.
unify :: Env -> Term -> Term -> TM ()
unify env lhs rhs = go (whnf env lhs) (whnf env rhs)
  where
    go (TSet a) (TSet b) = unify env a b
    go (TSetw i) (TSetw j) | i == j = return ()
    go (TVar n) (TVar m) | n == m = return ()
    go (TPi n a b) (TPi m c d) = unify env a c >> underBinder n a b m d
    go (TLam n a b) (TLam m c d) = unify env a c >> underBinder n a b m d
    go (TApp a b) (TApp c d) = unify env a c >> unify env b d
    go (TLift a) (TLift b) = unify env a b
    go TLevel TLevel = return ()
    go TLevelUniv TLevelUniv = return ()
    go TIZero TIZero = return ()
    go (TIMax a b) (TIMax c d) = unify env a c >> unify env b d
    go (TISucc a) (TISucc b) = unify env a b
    go TOne TOne = return ()
    go TUnit TUnit = return ()
    go (TSigma n a b) (TSigma m c d) = unify env a c >> underBinder n a b m d
    go (TPair a b) (TPair c d) = unify env a c >> unify env b d
    go (TProj1 a) (TProj1 b) = unify env a b
    go (TProj2 a) (TProj2 b) = unify env a b
    go a b = throw $ "Cannot unify:\n- " ++ show a ++ "\n- " ++ show b

    -- Compare the bodies b (binding n) and d (binding m), both bound over
    -- dom, by renaming both bound variables to one common name. n is reused
    -- unless it occurs free in d, where it would be confused with the bound
    -- variable.
    underBinder n dom b m d =
      let v | n == m || not (n `freeIn` d) = n
            | otherwise = freshName n [b, d]
      in unify (Map.insert v (Nothing :| dom) env) (subst n (TVar v) b) (subst m (TVar v) d)

-- | Reduce a term to weak head normal form.
whnf :: Env -> Term -> Term
whnf env = \case
  TApp f x -> case whnf env f of
    TLam n _ body -> whnf env (subst n x body)
    f' -> TApp f' x
  TIMax a b -> mergeLevels (whnf env a) (whnf env b)
  TProj1 p -> case whnf env p of
    TPair a _ -> whnf env a
    p' -> TProj1 p'
  TProj2 p -> case whnf env p of
    TPair _ b -> whnf env b
    p' -> TProj2 p'
  -- Annotations carry no computational content.
  TAnnot (a :| _) -> whnf env a
  t -> t
  where
    -- TODO all of the properties from https://agda.readthedocs.io/en/v2.6.3/language/universe-levels.html#intrinsic-level-properties
    mergeLevels TIZero l = l
    mergeLevels l TIZero = l
    mergeLevels (TISucc l) (TISucc m) = TISucc (mergeLevels l m)
    mergeLevels l m = TIMax l m