aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-03-24 11:19:48 +0100
committerTom Smeding <tom@tomsmeding.com>2024-03-24 11:19:48 +0100
commit2de6cede93912457babc79bcb0f58c9e6b20f05a (patch)
tree30a1f088a8f73e385e2b296c3bfaf0aac7b6314c
parentdefd0cf1a7620eaecda984a58533661a98595bd3 (diff)
Partially working type checker
-rw-r--r--src/HSVIS/AST.hs2
-rw-r--r--src/HSVIS/Diagnostic.hs2
-rw-r--r--src/HSVIS/Typecheck.hs258
3 files changed, 217 insertions, 45 deletions
diff --git a/src/HSVIS/AST.hs b/src/HSVIS/AST.hs
index 058f5ac..2986248 100644
--- a/src/HSVIS/AST.hs
+++ b/src/HSVIS/AST.hs
@@ -100,6 +100,8 @@ data Type s
-- extension point
| TExt (X Type s) !(E Type s)
deriving instance (Show (X Type s), Show (E Type s)) => Show (Type s)
+deriving instance (Eq (X Type s), Eq (E Type s)) => Eq (Type s)
+deriving instance (Ord (X Type s), Ord (E Type s)) => Ord (Type s)
data Pattern s
= PWildcard (X Pattern s)
diff --git a/src/HSVIS/Diagnostic.hs b/src/HSVIS/Diagnostic.hs
index 778fe34..116e4cd 100644
--- a/src/HSVIS/Diagnostic.hs
+++ b/src/HSVIS/Diagnostic.hs
@@ -45,7 +45,7 @@ printDiagnostic :: Diagnostic -> String
printDiagnostic (Diagnostic sev fp rng@(Range (Pos y1 x1) (Pos y2 x2)) stk srcline msg) =
let linenum = show (y1 + 1)
locstr = pretty rng
- ncarets | y1 == y2 = max 1 (x2 - x1 + 1)
+ ncarets | y1 == y2 = max 1 (x2 - x1)
| otherwise = length srcline - x1
caretsuffix | y1 == y2 = ""
| otherwise = "..."
diff --git a/src/HSVIS/Typecheck.hs b/src/HSVIS/Typecheck.hs
index ad754cf..f292c0e 100644
--- a/src/HSVIS/Typecheck.hs
+++ b/src/HSVIS/Typecheck.hs
@@ -22,7 +22,7 @@ module HSVIS.Typecheck (
) where
import Control.Monad
-import Data.Bifunctor (first, second)
+import Data.Bifunctor (first, second, bimap)
import Data.Foldable (toList)
import Data.List (find, inits)
import Data.Map.Strict (Map)
@@ -32,6 +32,7 @@ import qualified Data.Map.Strict as Map
import Data.Tuple (swap)
import Data.Set (Set)
import qualified Data.Set as Set
+import GHC.Stack
import Debug.Trace
@@ -55,7 +56,7 @@ type instance X Pattern StageTC = CType
type instance X RHS StageTC = CType
type instance X Expr StageTC = CType
-data instance E Type StageTC = TUniVar Int deriving (Show)
+data instance E Type StageTC = TUniVar Int deriving (Show, Eq, Ord)
data instance E Kind StageTC = KUniVar Int deriving (Show, Eq, Ord)
data instance E TypeSig StageTC deriving (Show)
@@ -71,7 +72,7 @@ type CExpr = Expr StageTC
data StageTyped
-type instance X DataDef StageTyped = TType
+type instance X DataDef StageTyped = TKind
type instance X FunDef StageTyped = TType
type instance X FunEq StageTyped = ()
type instance X Kind StageTyped = ()
@@ -242,24 +243,30 @@ tcTop :: PProgram -> TCM TProgram
tcTop prog = do
(cs, prog') <- collectConstraints Just (tcProgram prog)
(subK, subT) <- solveConstrs cs
- return $ doneProg subK subT prog'
+ let subK' = Map.map (substFinKind mempty) subK
+ subT' = Map.map (substFinType subK' mempty) subT
+ return $ substFinProg subK' subT' prog'
tcProgram :: PProgram -> TCM CProgram
tcProgram (Program ddefs1 fdefs1) = do
+ -- kind-check data definitions and collect ensuing kind constraints
(kconstrs, ddefs2) <- collectConstraints isCEqK $ do
ks <- mapM prepareDataDef ddefs1
zipWithM kcDataDef ks ddefs1
+ -- solve the kind constraints and finalise data types
kinduvars <- solveKindVars kconstrs
let ddefs3 = map (substDdef kinduvars mempty) ddefs2
modifyTEnv (Map.map (substKind kinduvars))
+ -- generate inverse constructors for all data types
forM_ ddefs3 $ \ddef ->
modifyICEnv (Map.fromList (generateInvCons ddef) <>)
traceM (unlines (map pretty ddefs3))
- fdefs2 <- mapM tcFunDef fdefs1
+ -- check the function definitions
+ fdefs2 <- tcFunDefBlock fdefs1
return (Program ddefs3 fdefs2)
@@ -360,6 +367,23 @@ kcType mdown = \case
kcType (Just k2) t
return (TForall k2 n t') -- not 'k1 -> k2' because the forall is implicit
+tcFunDefBlock :: [PFunDef] -> TCM [CFunDef]
+tcFunDefBlock fdefs = do
+ -- generate preliminary unification variables for the functions' types
+ bound <- mapM (\(FunDef _ n _ _) -> (n,) <$> genUniVar (KType ())) fdefs
+ defs' <- forM fdefs $ \def@(FunDef _ name _ _) ->
+ scopeVEnv $ do
+ modifyVEnv (Map.fromList [(n, t) | (n, t) <- bound, n /= name] <>)
+ tcFunDef def
+
+ -- take the actual found types for typechecking the body (and link them
+ -- to the variables generated above)
+ let bound2 = map (\(FunDef ty n _ _) -> (n, ty)) defs'
+ forM_ (zip3 fdefs bound bound2) $ \(fdef, (_, tvar), (_, ty)) ->
+ emit $ CEq ty tvar (extOf fdef) -- which is expected/observed? which range? /shrug/
+
+ return defs'
+
tcFunDef :: PFunDef -> TCM CFunDef
tcFunDef (FunDef rng name msig eqs) = do
when (not $ allEq (fmap (length . funeqPats) eqs)) $
@@ -369,7 +393,9 @@ tcFunDef (FunDef rng name msig eqs) = do
TypeSig sig -> kcType (Just (KType ())) sig
TypeSigExt NoTypeSig -> genUniVar (KType ())
- eqs' <- mapM (tcFunEq typ) eqs
+ eqs' <- scopeVEnv $ do
+ modifyVEnv (Map.insert name typ) -- allow function to be recursive
+ mapM (tcFunEq typ) eqs
return (FunDef typ name (TypeSig typ) eqs')
@@ -386,8 +412,11 @@ tcFunEq down (FunEq rng name pats rhs) = do
tcPattern :: CType -> PPattern -> TCM CPattern
tcPattern down = \case
PWildcard _ -> return $ PWildcard down
+
PVar _ n -> modifyVEnv (Map.insert n down) >> return (PVar down n)
+
PAs _ n p -> modifyVEnv (Map.insert n down) >> tcPattern down p
+
PCon rng n ps ->
getInvCon n >>= \case
Just (InvCon tyvars match fields) -> do
@@ -399,6 +428,7 @@ tcPattern down = \case
Nothing -> do
raise SError rng $ "Constructor not in scope: " ++ pretty n
return (PWildcard down)
+
POp rng p1 op p2 ->
case op of
OCons -> do
@@ -411,11 +441,13 @@ tcPattern down = \case
_ -> do
raise SError rng $ "Operator is not a constructor: " ++ pretty op
return (PWildcard down)
+
PList rng ps -> do
eltty <- genUniVar (KType ())
let listty = TList (KType ()) eltty
emit $ CEq down listty rng
PList listty <$> mapM (tcPattern eltty) ps
+
PTup rng ps -> do
ts <- mapM (\_ -> genUniVar (KType ())) ps
emit $ CEq down (TTup (KType ()) ts) rng
@@ -438,9 +470,15 @@ tcExpr down = \case
emit $ CEq down ty rng
return (ELit ty lit)
- EVar rng n -> EVar <$> getType' rng n <*> pure n
+ EVar rng n -> do
+ ty <- getType' rng n
+ emit $ CEq down ty rng
+ return $ EVar ty n
- ECon rng n -> ECon <$> getType' rng n <*> pure n
+ ECon rng n -> do
+ ty <- getType' rng n
+ emit $ CEq down ty rng
+ return $ EVar ty n
EList rng es -> do
eltty <- genUniVar (KType ())
@@ -495,17 +533,9 @@ tcExpr down = \case
(,) <$> tcPattern ty pat <*> tcRHS down rhs
return $ ECase down e1' alts'
- ELet rng defs body -> do
- bound <- mapM (\(FunDef _ n _ _) -> (n,) <$> genUniVar (KType ())) defs
- defs' <- forM defs $ \def@(FunDef _ name _ _) ->
- scopeVEnv $ do
- modifyVEnv (Map.fromList [(n, t) | (n, t) <- bound, n /= name] <>)
- tcFunDef def
- -- take the actual found types for typechecking the body (and linking them
- -- to the variables generated above)
+ ELet _ defs body -> do
+ defs' <- tcFunDefBlock defs
let bound2 = map (\(FunDef ty n _ _) -> (n, ty)) defs'
- forM_ (zip bound bound2) $ \((_, tvar), (_, ty)) ->
- emit $ CEq ty tvar rng -- in which order? which range? /shrug/
scopeVEnv $ do
modifyVEnv (Map.fromList bound2 <>)
body' <- tcExpr down body
@@ -580,51 +610,191 @@ solveKindVars cs = do
kindSize (KFun () a b) = 1 + kindSize a + kindSize b
kindSize (KExt () KUniVar{}) = 2
+solveTypeVars :: Bag (CType, CType, Range) -> TCM (Map Int CType)
+solveTypeVars cs = do
+ let (asg, errs) =
+ solveConstraints
+ reduce
+ (foldMap pure . typeUniVars)
+ (\m -> substType mempty m mempty)
+ (\case TExt _ (TUniVar v) -> Just v
+ _ -> Nothing)
+ typeSize
+ (toList cs)
+
+ forM_ errs $ \case
+ UEUnequal t1 t2 rng ->
+ raise SError rng $
+ "Type mismatch:\n\
+ \- " ++ pretty t1 ++ "\n\
+ \- " ++ pretty t2
+ UERecursive uvar t rng ->
+ raise SError rng $
+ "Type cannot be recursive: " ++ pretty (TExt (extOf t) (TUniVar uvar)) ++ " = " ++ pretty t
+
+ return asg
+ where
+ reduce :: CType -> CType -> Range -> (Bag (Int, CType, Range), Bag (CType, CType, Range))
+ reduce lhs rhs rng = case (lhs, rhs) of
+ -- unification variables produce constraints on a unification variable
+ (TExt _ (TUniVar i), TExt _ (TUniVar j)) | i == j -> mempty
+ (TExt _ (TUniVar i), t ) -> (pure (i, t, rng), mempty)
+ (t , TExt _ (TUniVar i)) -> (pure (i, t, rng), mempty)
+
+ -- if lhs and rhs have equal prefixes, recurse
+ (TApp _ t ts, TApp _ t' ts') -> reduce t t' rng <> foldMap (\(a, b) -> reduce a b rng) (zip ts ts')
+ (TTup _ ts, TTup _ ts') -> foldMap (\(a, b) -> reduce a b rng) (zip ts ts')
+ (TList _ t, TList _ t') -> reduce t t' rng
+ (TFun _ t1 t2, TFun _ t1' t2') -> reduce t1 t1' rng <> reduce t2 t2' rng
+ (TCon _ n1, TCon _ n2) | n1 == n2 -> mempty
+ (TVar _ n1, TVar _ n2) | n1 == n2 -> mempty
+ (TForall _ n1 t1, TForall k n2 t2) ->
+ reduce t1 (substType mempty mempty (Map.singleton n2 (TVar k n1)) t2) rng
+
+ -- otherwise, this is a kind mismatch
+ (k1, k2) -> (mempty, pure (k1, k2, rng))
+
+ typeSize :: CType -> Int
+ typeSize (TApp _ t ts) = typeSize t + sum (map typeSize ts)
+ typeSize (TTup _ ts) = sum (map typeSize ts)
+ typeSize (TList _ t) = 1 + typeSize t
+ typeSize (TFun _ t1 t2) = typeSize t1 + typeSize t2
+ typeSize (TCon _ _) = 1
+ typeSize (TVar _ _) = 1
+ typeSize (TForall _ _ t) = 1 + typeSize t
+ typeSize (TExt _ TUniVar{}) = 2
+
partitionConstrs :: Foldable t => t Constr -> (Bag (CType, CType, Range), Bag (CKind, CKind, Range))
partitionConstrs = foldMap $ \case CEq t1 t2 r -> (pure (t1, t2, r), mempty)
CEqK k1 k2 r -> (mempty, pure (k1, k2, r))
--- substitute unification variables
-substProg :: Map Int CKind -- ^ Kind unification variable instantiations
- -> Map Int CType -- ^ Type unification variable instantiations
- -> CProgram
- -> CProgram
-substProg = error "substProg"
+-------------------- SUBSTITUTION FUNCTIONS --------------------
+-- These take some of:
+-- - an instantiation map for kind unification variables (Map Int {C,T}Kind)
+-- - an instantiation map for type unification variables (Map Int {C,T}Type)
+-- - an instantiation map for type variables (Map Name CType)
+
+substFinProg :: HasCallStack
+ => Map Int TKind -> Map Int TType -> CProgram -> TProgram
+substFinProg mk mt (Program ds fs) = Program (map (substFinDdef mk mt) ds) (map (substFinFdef mk mt) fs)
+
+substFinDdef :: HasCallStack
+ => Map Int TKind -> Map Int TType -> CDataDef -> TDataDef
+substFinDdef mk mt (DataDef k n ps cs) =
+ DataDef (substFinKind mk k) n (map (first (substFinKind mk)) ps) (map (second (map (substFinType mk mt))) cs)
+
+substFinFdef :: HasCallStack
+ => Map Int TKind -> Map Int TType -> CFunDef -> TFunDef
+substFinFdef mk mt (FunDef t n (TypeSig sig) eqs) =
+ FunDef (substFinType mk mt t) n
+ (TypeSig (substFinType mk mt sig))
+ (fmap (substFinFunEq mk mt) eqs)
+
+substFinFunEq :: HasCallStack
+ => Map Int TKind -> Map Int TType -> CFunEq -> TFunEq
+substFinFunEq mk mt (FunEq () n ps rhs) =
+ FunEq () n
+ (map (substFinPattern mk mt) ps)
+ (substFinRHS mk mt rhs)
+
+substFinRHS :: HasCallStack
+ => Map Int TKind -> Map Int TType -> CRHS -> TRHS
+substFinRHS _ _ (Guarded _ _) = error "typecheck: guards unsupported"
+substFinRHS mk mt (Plain t e) = Plain (substFinType mk mt t) (substFinExpr mk mt e)
+
+substFinPattern :: HasCallStack
+ => Map Int TKind -> Map Int TType -> CPattern -> TPattern
+substFinPattern mk mt = go
+ where
+ go (PWildcard t) = PWildcard (goType t)
+ go (PVar t n) = PVar (goType t) n
+ go (PAs t n p) = PAs (goType t) n (go p)
+ go (PCon t n ps) = PCon (goType t) n (map go ps)
+ go (POp t p1 op p2) = POp (goType t) (go p1) op (go p2)
+ go (PList t ps) = PList (goType t) (map go ps)
+ go (PTup t ps) = PTup (goType t) (map go ps)
+
+ goType = substFinType mk mt
+
+substFinExpr :: HasCallStack
+ => Map Int TKind -> Map Int TType -> CExpr -> TExpr
+substFinExpr mk mt = go
+ where
+ go (ELit t lit) = ELit (goType t) lit
+ go (EVar t n) = EVar (goType t) n
+ go (ECon t n) = ECon (goType t) n
+ go (EList t es) = EList (goType t) (map go es)
+ go (ETup t es) = ETup (goType t) (map go es)
+ go (EApp t e1 es) = EApp (goType t) (go e1) (map go es)
+ go (EOp t e1 op e2) = EOp (goType t) (go e1) op (go e2)
+ go (EIf t e1 e2 e3) = EIf (goType t) (go e1) (go e2) (go e3)
+ go (ECase t e1 alts) = ECase (goType t) (go e1) (map (bimap (substFinPattern mk mt) (substFinRHS mk mt)) alts)
+ go (ELet t defs body) = ELet (goType t) (map (substFinFdef mk mt) defs) (go body)
+ go (EError t) = EError (goType t)
+
+ goType = substFinType mk mt
+
+substFinType :: HasCallStack
+ => Map Int TKind -- ^ kind uvars
+ -> Map Int TType -- ^ type uvars
+ -> CType -> TType
+substFinType mk mt = go
+ where
+ go (TApp k t ts) = TApp (substFinKind mk k) (go t) (map go ts)
+ go (TTup k ts) = TTup (substFinKind mk k) (map go ts)
+ go (TList k t) = TList (substFinKind mk k) (go t)
+ go (TFun k t1 t2) = TFun (substFinKind mk k) (go t1) (go t2)
+ go (TCon k n) = TCon (substFinKind mk k) n
+ go (TVar k n) = TVar (substFinKind mk k) n
+ go (TForall k n t) = TForall (substFinKind mk k) n (go t)
+ go t@(TExt _ (TUniVar v)) = fromMaybe (error $ "substFinType: unification variables left: " ++ show t)
+ (Map.lookup v mt)
+
+substFinKind :: HasCallStack => Map Int TKind -> CKind -> TKind
+substFinKind m = \case
+ KType () -> KType ()
+ KFun () k1 k2 -> KFun () (substFinKind m k1) (substFinKind m k2)
+ k@(KExt () (KUniVar v)) -> fromMaybe (error $ "substFinKind: unification variables left: " ++ show k)
+ (Map.lookup v m)
--- substitute unification variables
substDdef :: Map Int CKind -> Map Int CType -> CDataDef -> CDataDef
substDdef mk mt (DataDef k name pars cons) =
DataDef (substKind mk k) name
(map (first (substKind mk)) pars)
(map (second (map (substType mk mt mempty))) cons)
-substType :: Map Int CKind -- ^ kind uvars
- -> Map Int CType -- ^ type uvars
- -> Map Name CType -- ^ type variables
- -> CType -> CType
+substType :: Map Int CKind -> Map Int CType -> Map Name CType -> CType -> CType
substType mk mt mtv = go
where
- go (TApp k t ts) = TApp (substKind mk k) (go t) (map go ts)
- go (TTup k ts) = TTup (substKind mk k) (map go ts)
- go (TList k t) = TList (substKind mk k) (go t)
- go (TFun k t1 t2) = TFun (substKind mk k) (go t1) (go t2)
- go (TCon k n) = TCon (substKind mk k) n
- go (TVar k n) = fromMaybe (TVar (substKind mk k) n) (Map.lookup n mtv)
- go (TForall k n t) = TForall (substKind mk k) n (go t)
- go (TExt k (TUniVar v)) = fromMaybe (TExt (substKind mk k) (TUniVar v)) (Map.lookup v mt)
-
--- substitute unification variables
+ go (TApp k t ts) = TApp (goKind k) (go t) (map go ts)
+ go (TTup k ts) = TTup (goKind k) (map go ts)
+ go (TList k t) = TList (goKind k) (go t)
+ go (TFun k t1 t2) = TFun (goKind k) (go t1) (go t2)
+ go (TCon k n) = TCon (goKind k) n
+ go (TVar k n) = fromMaybe (TVar (goKind k) n) (Map.lookup n mtv)
+ go (TForall k n t) = TForall (goKind k) n (go t)
+ go (TExt k (TUniVar v)) = fromMaybe (TExt (goKind k) (TUniVar v)) (Map.lookup v mt)
+
+ goKind = substKind mk
+
substKind :: Map Int CKind -> CKind -> CKind
substKind m = \case
KType () -> KType ()
KFun () k1 k2 -> KFun () (substKind m k1) (substKind m k2)
k@(KExt () (KUniVar v)) -> fromMaybe k (Map.lookup v m)
-doneProg :: Map Int TKind -- ^ Kind unification variable instantiations
- -> Map Int TType -- ^ Type unification variable instantiations
- -> CProgram
- -> TProgram
-doneProg = error "doneProg"
+-------------------- END OF SUBSTITUTION FUNCTIONS --------------------
+
+typeUniVars :: CType -> Set Int
+typeUniVars = \case
+ TApp _ t ts -> typeUniVars t <> foldMap typeUniVars ts
+ TTup _ ts -> foldMap typeUniVars ts
+ TList _ t -> typeUniVars t
+ TFun _ t1 t2 -> typeUniVars t1 <> typeUniVars t2
+ TCon _ _ -> mempty
+ TVar _ _ -> mempty
+ TForall _ _ t -> typeUniVars t
+ TExt _ (TUniVar v) -> Set.singleton v
kindUniVars :: CKind -> Set Int
kindUniVars = \case