From 2de6cede93912457babc79bcb0f58c9e6b20f05a Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 24 Mar 2024 11:19:48 +0100 Subject: Partially working type checker --- src/HSVIS/AST.hs | 2 + src/HSVIS/Diagnostic.hs | 2 +- src/HSVIS/Typecheck.hs | 258 +++++++++++++++++++++++++++++++++++++++--------- 3 files changed, 217 insertions(+), 45 deletions(-) (limited to 'src/HSVIS') diff --git a/src/HSVIS/AST.hs b/src/HSVIS/AST.hs index 058f5ac..2986248 100644 --- a/src/HSVIS/AST.hs +++ b/src/HSVIS/AST.hs @@ -100,6 +100,8 @@ data Type s -- extension point | TExt (X Type s) !(E Type s) deriving instance (Show (X Type s), Show (E Type s)) => Show (Type s) +deriving instance (Eq (X Type s), Eq (E Type s)) => Eq (Type s) +deriving instance (Ord (X Type s), Ord (E Type s)) => Ord (Type s) data Pattern s = PWildcard (X Pattern s) diff --git a/src/HSVIS/Diagnostic.hs b/src/HSVIS/Diagnostic.hs index 778fe34..116e4cd 100644 --- a/src/HSVIS/Diagnostic.hs +++ b/src/HSVIS/Diagnostic.hs @@ -45,7 +45,7 @@ printDiagnostic :: Diagnostic -> String printDiagnostic (Diagnostic sev fp rng@(Range (Pos y1 x1) (Pos y2 x2)) stk srcline msg) = let linenum = show (y1 + 1) locstr = pretty rng - ncarets | y1 == y2 = max 1 (x2 - x1 + 1) + ncarets | y1 == y2 = max 1 (x2 - x1) | otherwise = length srcline - x1 caretsuffix | y1 == y2 = "" | otherwise = "..." diff --git a/src/HSVIS/Typecheck.hs b/src/HSVIS/Typecheck.hs index ad754cf..f292c0e 100644 --- a/src/HSVIS/Typecheck.hs +++ b/src/HSVIS/Typecheck.hs @@ -22,7 +22,7 @@ module HSVIS.Typecheck ( ) where import Control.Monad -import Data.Bifunctor (first, second) +import Data.Bifunctor (first, second, bimap) import Data.Foldable (toList) import Data.List (find, inits) import Data.Map.Strict (Map) @@ -32,6 +32,7 @@ import qualified Data.Map.Strict as Map import Data.Tuple (swap) import Data.Set (Set) import qualified Data.Set as Set +import GHC.Stack import Debug.Trace @@ -55,7 +56,7 @@ type instance X Pattern StageTC = CType type instance X RHS StageTC = CType type instance X Expr StageTC = CType -data instance E Type StageTC = TUniVar Int deriving (Show) +data instance E Type StageTC = TUniVar Int deriving (Show, Eq, Ord) data instance E Kind StageTC = KUniVar Int deriving (Show, Eq, Ord) data instance E TypeSig StageTC deriving (Show) @@ -71,7 +72,7 @@ type CExpr = Expr StageTC data StageTyped -type instance X DataDef StageTyped = TType +type instance X DataDef StageTyped = TKind type instance X FunDef StageTyped = TType type instance X FunEq StageTyped = () type instance X Kind StageTyped = () @@ -242,24 +243,30 @@ tcTop :: PProgram -> TCM TProgram tcTop prog = do (cs, prog') <- collectConstraints Just (tcProgram prog) (subK, subT) <- solveConstrs cs - return $ doneProg subK subT prog' + let subK' = Map.map (substFinKind mempty) subK + subT' = Map.map (substFinType subK' mempty) subT + return $ substFinProg subK' subT' prog' tcProgram :: PProgram -> TCM CProgram tcProgram (Program ddefs1 fdefs1) = do + -- kind-check data definitions and collect ensuing kind constraints (kconstrs, ddefs2) <- collectConstraints isCEqK $ do ks <- mapM prepareDataDef ddefs1 zipWithM kcDataDef ks ddefs1 + -- solve the kind constraints and finalise data types kinduvars <- solveKindVars kconstrs let ddefs3 = map (substDdef kinduvars mempty) ddefs2 modifyTEnv (Map.map (substKind kinduvars)) + -- generate inverse constructors for all data types forM_ ddefs3 $ \ddef -> modifyICEnv (Map.fromList (generateInvCons ddef) <>) traceM (unlines (map pretty ddefs3)) - fdefs2 <- mapM tcFunDef fdefs1 + -- check the function definitions + fdefs2 <- tcFunDefBlock fdefs1 return (Program ddefs3 fdefs2) @@ -360,6 +367,23 @@ kcType mdown = \case kcType (Just k2) t return (TForall k2 n t') -- not 'k1 -> k2' because the forall is implicit +tcFunDefBlock :: [PFunDef] -> TCM [CFunDef] +tcFunDefBlock fdefs = do + -- generate preliminary unification variables for the functions' types + bound <- mapM (\(FunDef _ n _ _) -> (n,) <$> genUniVar (KType ())) fdefs + defs' <- forM fdefs $ \def@(FunDef _ name _ _) -> + scopeVEnv $ do + modifyVEnv (Map.fromList [(n, t) | (n, t) <- bound, n /= name] <>) + tcFunDef def + + -- take the actual found types for typechecking the body (and link them + -- to the variables generated above) + let bound2 = map (\(FunDef ty n _ _) -> (n, ty)) defs' + forM_ (zip3 fdefs bound bound2) $ \(fdef, (_, tvar), (_, ty)) -> + emit $ CEq ty tvar (extOf fdef) -- which is expected/observed? which range? /shrug/ + + return defs' + tcFunDef :: PFunDef -> TCM CFunDef tcFunDef (FunDef rng name msig eqs) = do when (not $ allEq (fmap (length . funeqPats) eqs)) $ @@ -369,7 +393,9 @@ tcFunDef (FunDef rng name msig eqs) = do TypeSig sig -> kcType (Just (KType ())) sig TypeSigExt NoTypeSig -> genUniVar (KType ()) - eqs' <- mapM (tcFunEq typ) eqs + eqs' <- scopeVEnv $ do + modifyVEnv (Map.insert name typ) -- allow function to be recursive + mapM (tcFunEq typ) eqs return (FunDef typ name (TypeSig typ) eqs') @@ -386,8 +412,11 @@ tcFunEq down (FunEq rng name pats rhs) = do tcPattern :: CType -> PPattern -> TCM CPattern tcPattern down = \case PWildcard _ -> return $ PWildcard down + PVar _ n -> modifyVEnv (Map.insert n down) >> return (PVar down n) + PAs _ n p -> modifyVEnv (Map.insert n down) >> tcPattern down p + PCon rng n ps -> getInvCon n >>= \case Just (InvCon tyvars match fields) -> do @@ -399,6 +428,7 @@ tcPattern down = \case Nothing -> do raise SError rng $ "Constructor not in scope: " ++ pretty n return (PWildcard down) + POp rng p1 op p2 -> case op of OCons -> do @@ -411,11 +441,13 @@ tcPattern down = \case _ -> do raise SError rng $ "Operator is not a constructor: " ++ pretty op return (PWildcard down) + PList rng ps -> do eltty <- genUniVar (KType ()) let listty = TList (KType ()) eltty emit $ CEq down listty rng PList listty <$> mapM (tcPattern eltty) ps + PTup rng ps -> do ts <- mapM (\_ -> genUniVar (KType ())) ps emit $ CEq down (TTup (KType ()) ts) rng @@ -438,9 +470,15 @@ tcExpr down = \case emit $ CEq down ty rng return (ELit ty lit) - EVar rng n -> EVar <$> getType' rng n <*> pure n + EVar rng n -> do + ty <- getType' rng n + emit $ CEq down ty rng + return $ EVar ty n - ECon rng n -> ECon <$> getType' rng n <*> pure n + ECon rng n -> do + ty <- getType' rng n + emit $ CEq down ty rng + return $ EVar ty n EList rng es -> do eltty <- genUniVar (KType ()) @@ -495,17 +533,9 @@ tcExpr down = \case (,) <$> tcPattern ty pat <*> tcRHS down rhs return $ ECase down e1' alts' - ELet rng defs body -> do - bound <- mapM (\(FunDef _ n _ _) -> (n,) <$> genUniVar (KType ())) defs - defs' <- forM defs $ \def@(FunDef _ name _ _) -> - scopeVEnv $ do - modifyVEnv (Map.fromList [(n, t) | (n, t) <- bound, n /= name] <>) - tcFunDef def - -- take the actual found types for typechecking the body (and linking them - -- to the variables generated above) + ELet _ defs body -> do + defs' <- tcFunDefBlock defs let bound2 = map (\(FunDef ty n _ _) -> (n, ty)) defs' - forM_ (zip bound bound2) $ \((_, tvar), (_, ty)) -> - emit $ CEq ty tvar rng -- in which order? which range? /shrug/ scopeVEnv $ do modifyVEnv (Map.fromList bound2 <>) body' <- tcExpr down body @@ -580,51 +610,191 @@ solveKindVars cs = do kindSize (KFun () a b) = 1 + kindSize a + kindSize b kindSize (KExt () KUniVar{}) = 2 +solveTypeVars :: Bag (CType, CType, Range) -> TCM (Map Int CType) +solveTypeVars cs = do + let (asg, errs) = + solveConstraints + reduce + (foldMap pure . typeUniVars) + (\m -> substType mempty m mempty) + (\case TExt _ (TUniVar v) -> Just v + _ -> Nothing) + typeSize + (toList cs) + + forM_ errs $ \case + UEUnequal t1 t2 rng -> + raise SError rng $ + "Type mismatch:\n\ + \- " ++ pretty t1 ++ "\n\ + \- " ++ pretty t2 + UERecursive uvar t rng -> + raise SError rng $ + "Type cannot be recursive: " ++ pretty (TExt (extOf t) (TUniVar uvar)) ++ " = " ++ pretty t + + return asg + where + reduce :: CType -> CType -> Range -> (Bag (Int, CType, Range), Bag (CType, CType, Range)) + reduce lhs rhs rng = case (lhs, rhs) of + -- unification variables produce constraints on a unification variable + (TExt _ (TUniVar i), TExt _ (TUniVar j)) | i == j -> mempty + (TExt _ (TUniVar i), t ) -> (pure (i, t, rng), mempty) + (t , TExt _ (TUniVar i)) -> (pure (i, t, rng), mempty) + + -- if lhs and rhs have equal prefixes, recurse + (TApp _ t ts, TApp _ t' ts') -> reduce t t' rng <> foldMap (\(a, b) -> reduce a b rng) (zip ts ts') + (TTup _ ts, TTup _ ts') -> foldMap (\(a, b) -> reduce a b rng) (zip ts ts') + (TList _ t, TList _ t') -> reduce t t' rng + (TFun _ t1 t2, TFun _ t1' t2') -> reduce t1 t1' rng <> reduce t2 t2' rng + (TCon _ n1, TCon _ n2) | n1 == n2 -> mempty + (TVar _ n1, TVar _ n2) | n1 == n2 -> mempty + (TForall _ n1 t1, TForall k n2 t2) -> + reduce t1 (substType mempty mempty (Map.singleton n2 (TVar k n1)) t2) rng + + -- otherwise, this is a kind mismatch + (k1, k2) -> (mempty, pure (k1, k2, rng)) + + typeSize :: CType -> Int + typeSize (TApp _ t ts) = typeSize t + sum (map typeSize ts) + typeSize (TTup _ ts) = sum (map typeSize ts) + typeSize (TList _ t) = 1 + typeSize t + typeSize (TFun _ t1 t2) = typeSize t1 + typeSize t2 + typeSize (TCon _ _) = 1 + typeSize (TVar _ _) = 1 + typeSize (TForall _ _ t) = 1 + typeSize t + typeSize (TExt _ TUniVar{}) = 2 + partitionConstrs :: Foldable t => t Constr -> (Bag (CType, CType, Range), Bag (CKind, CKind, Range)) partitionConstrs = foldMap $ \case CEq t1 t2 r -> (pure (t1, t2, r), mempty) CEqK k1 k2 r -> (mempty, pure (k1, k2, r)) --- substitute unification variables -substProg :: Map Int CKind -- ^ Kind unification variable instantiations - -> Map Int CType -- ^ Type unification variable instantiations - -> CProgram - -> CProgram -substProg = error "substProg" +-------------------- SUBSTITUTION FUNCTIONS -------------------- +-- These take some of: +-- - an instantiation map for kind unification variables (Map Int {C,T}Kind) +-- - an instantiation map for type unification variables (Map Int {C,T}Type) +-- - an instantiation map for type variables (Map Name CType) + +substFinProg :: HasCallStack + => Map Int TKind -> Map Int TType -> CProgram -> TProgram +substFinProg mk mt (Program ds fs) = Program (map (substFinDdef mk mt) ds) (map (substFinFdef mk mt) fs) + +substFinDdef :: HasCallStack + => Map Int TKind -> Map Int TType -> CDataDef -> TDataDef +substFinDdef mk mt (DataDef k n ps cs) = + DataDef (substFinKind mk k) n (map (first (substFinKind mk)) ps) (map (second (map (substFinType mk mt))) cs) + +substFinFdef :: HasCallStack + => Map Int TKind -> Map Int TType -> CFunDef -> TFunDef +substFinFdef mk mt (FunDef t n (TypeSig sig) eqs) = + FunDef (substFinType mk mt t) n + (TypeSig (substFinType mk mt sig)) + (fmap (substFinFunEq mk mt) eqs) + +substFinFunEq :: HasCallStack + => Map Int TKind -> Map Int TType -> CFunEq -> TFunEq +substFinFunEq mk mt (FunEq () n ps rhs) = + FunEq () n + (map (substFinPattern mk mt) ps) + (substFinRHS mk mt rhs) + +substFinRHS :: HasCallStack + => Map Int TKind -> Map Int TType -> CRHS -> TRHS +substFinRHS _ _ (Guarded _ _) = error "typecheck: guards unsupported" +substFinRHS mk mt (Plain t e) = Plain (substFinType mk mt t) (substFinExpr mk mt e) + +substFinPattern :: HasCallStack + => Map Int TKind -> Map Int TType -> CPattern -> TPattern +substFinPattern mk mt = go + where + go (PWildcard t) = PWildcard (goType t) + go (PVar t n) = PVar (goType t) n + go (PAs t n p) = PAs (goType t) n (go p) + go (PCon t n ps) = PCon (goType t) n (map go ps) + go (POp t p1 op p2) = POp (goType t) (go p1) op (go p2) + go (PList t ps) = PList (goType t) (map go ps) + go (PTup t ps) = PTup (goType t) (map go ps) + + goType = substFinType mk mt + +substFinExpr :: HasCallStack + => Map Int TKind -> Map Int TType -> CExpr -> TExpr +substFinExpr mk mt = go + where + go (ELit t lit) = ELit (goType t) lit + go (EVar t n) = EVar (goType t) n + go (ECon t n) = ECon (goType t) n + go (EList t es) = EList (goType t) (map go es) + go (ETup t es) = ETup (goType t) (map go es) + go (EApp t e1 es) = EApp (goType t) (go e1) (map go es) + go (EOp t e1 op e2) = EOp (goType t) (go e1) op (go e2) + go (EIf t e1 e2 e3) = EIf (goType t) (go e1) (go e2) (go e3) + go (ECase t e1 alts) = ECase (goType t) (go e1) (map (bimap (substFinPattern mk mt) (substFinRHS mk mt)) alts) + go (ELet t defs body) = ELet (goType t) (map (substFinFdef mk mt) defs) (go body) + go (EError t) = EError (goType t) + + goType = substFinType mk mt + +substFinType :: HasCallStack + => Map Int TKind -- ^ kind uvars + -> Map Int TType -- ^ type uvars + -> CType -> TType +substFinType mk mt = go + where + go (TApp k t ts) = TApp (substFinKind mk k) (go t) (map go ts) + go (TTup k ts) = TTup (substFinKind mk k) (map go ts) + go (TList k t) = TList (substFinKind mk k) (go t) + go (TFun k t1 t2) = TFun (substFinKind mk k) (go t1) (go t2) + go (TCon k n) = TCon (substFinKind mk k) n + go (TVar k n) = TVar (substFinKind mk k) n + go (TForall k n t) = TForall (substFinKind mk k) n (go t) + go t@(TExt _ (TUniVar v)) = fromMaybe (error $ "substFinType: unification variables left: " ++ show t) + (Map.lookup v mt) + +substFinKind :: HasCallStack => Map Int TKind -> CKind -> TKind +substFinKind m = \case + KType () -> KType () + KFun () k1 k2 -> KFun () (substFinKind m k1) (substFinKind m k2) + k@(KExt () (KUniVar v)) -> fromMaybe (error $ "substFinKind: unification variables left: " ++ show k) + (Map.lookup v m) --- substitute unification variables substDdef :: Map Int CKind -> Map Int CType -> CDataDef -> CDataDef substDdef mk mt (DataDef k name pars cons) = DataDef (substKind mk k) name (map (first (substKind mk)) pars) (map (second (map (substType mk mt mempty))) cons) -substType :: Map Int CKind -- ^ kind uvars - -> Map Int CType -- ^ type uvars - -> Map Name CType -- ^ type variables - -> CType -> CType +substType :: Map Int CKind -> Map Int CType -> Map Name CType -> CType -> CType substType mk mt mtv = go where - go (TApp k t ts) = TApp (substKind mk k) (go t) (map go ts) - go (TTup k ts) = TTup (substKind mk k) (map go ts) - go (TList k t) = TList (substKind mk k) (go t) - go (TFun k t1 t2) = TFun (substKind mk k) (go t1) (go t2) - go (TCon k n) = TCon (substKind mk k) n - go (TVar k n) = fromMaybe (TVar (substKind mk k) n) (Map.lookup n mtv) - go (TForall k n t) = TForall (substKind mk k) n (go t) - go (TExt k (TUniVar v)) = fromMaybe (TExt (substKind mk k) (TUniVar v)) (Map.lookup v mt) - --- substitute unification variables + go (TApp k t ts) = TApp (goKind k) (go t) (map go ts) + go (TTup k ts) = TTup (goKind k) (map go ts) + go (TList k t) = TList (goKind k) (go t) + go (TFun k t1 t2) = TFun (goKind k) (go t1) (go t2) + go (TCon k n) = TCon (goKind k) n + go (TVar k n) = fromMaybe (TVar (goKind k) n) (Map.lookup n mtv) + go (TForall k n t) = TForall (goKind k) n (go t) + go (TExt k (TUniVar v)) = fromMaybe (TExt (goKind k) (TUniVar v)) (Map.lookup v mt) + + goKind = substKind mk + substKind :: Map Int CKind -> CKind -> CKind substKind m = \case KType () -> KType () KFun () k1 k2 -> KFun () (substKind m k1) (substKind m k2) k@(KExt () (KUniVar v)) -> fromMaybe k (Map.lookup v m) -doneProg :: Map Int TKind -- ^ Kind unification variable instantiations - -> Map Int TType -- ^ Type unification variable instantiations - -> CProgram - -> TProgram -doneProg = error "doneProg" +-------------------- END OF SUBSTITUTION FUNCTIONS -------------------- + +typeUniVars :: CType -> Set Int +typeUniVars = \case + TApp _ t ts -> typeUniVars t <> foldMap typeUniVars ts + TTup _ ts -> foldMap typeUniVars ts + TList _ t -> typeUniVars t + TFun _ t1 t2 -> typeUniVars t1 <> typeUniVars t2 + TCon _ _ -> mempty + TVar _ _ -> mempty + TForall _ _ t -> typeUniVars t + TExt _ (TUniVar v) -> Set.singleton v kindUniVars :: CKind -> Set Int kindUniVars = \case -- cgit v1.2.3-70-g09d2