From 2de6cede93912457babc79bcb0f58c9e6b20f05a Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Sun, 24 Mar 2024 11:19:48 +0100
Subject: Partially working type checker

---
 src/HSVIS/AST.hs        |   2 +
 src/HSVIS/Diagnostic.hs |   2 +-
 src/HSVIS/Typecheck.hs  | 258 +++++++++++++++++++++++++++++++++++++++---------
 3 files changed, 217 insertions(+), 45 deletions(-)

(limited to 'src/HSVIS')

diff --git a/src/HSVIS/AST.hs b/src/HSVIS/AST.hs
index 058f5ac..2986248 100644
--- a/src/HSVIS/AST.hs
+++ b/src/HSVIS/AST.hs
@@ -100,6 +100,8 @@ data Type s
   -- extension point
   | TExt (X Type s) !(E Type s)
 deriving instance (Show (X Type s), Show (E Type s)) => Show (Type s)
+deriving instance (Eq (X Type s), Eq (E Type s)) => Eq (Type s)
+deriving instance (Ord (X Type s), Ord (E Type s)) => Ord (Type s)
 
 data Pattern s
   = PWildcard (X Pattern s)
diff --git a/src/HSVIS/Diagnostic.hs b/src/HSVIS/Diagnostic.hs
index 778fe34..116e4cd 100644
--- a/src/HSVIS/Diagnostic.hs
+++ b/src/HSVIS/Diagnostic.hs
@@ -45,7 +45,7 @@ printDiagnostic :: Diagnostic -> String
 printDiagnostic (Diagnostic sev fp rng@(Range (Pos y1 x1) (Pos y2 x2)) stk srcline msg) =
   let linenum = show (y1 + 1)
       locstr = pretty rng
-      ncarets | y1 == y2  = max 1 (x2 - x1 + 1)
+      ncarets | y1 == y2  = max 1 (x2 - x1)
               | otherwise = length srcline - x1
       caretsuffix | y1 == y2  = ""
                   | otherwise = "..."
diff --git a/src/HSVIS/Typecheck.hs b/src/HSVIS/Typecheck.hs
index ad754cf..f292c0e 100644
--- a/src/HSVIS/Typecheck.hs
+++ b/src/HSVIS/Typecheck.hs
@@ -22,7 +22,7 @@ module HSVIS.Typecheck (
 ) where
 
 import Control.Monad
-import Data.Bifunctor (first, second)
+import Data.Bifunctor (first, second, bimap)
 import Data.Foldable (toList)
 import Data.List (find, inits)
 import Data.Map.Strict (Map)
@@ -32,6 +32,7 @@ import qualified Data.Map.Strict as Map
 import Data.Tuple (swap)
 import Data.Set (Set)
 import qualified Data.Set as Set
+import GHC.Stack
 
 import Debug.Trace
 
@@ -55,7 +56,7 @@ type instance X Pattern StageTC = CType
 type instance X RHS     StageTC = CType
 type instance X Expr    StageTC = CType
 
-data instance E Type StageTC = TUniVar Int deriving (Show)
+data instance E Type StageTC = TUniVar Int deriving (Show, Eq, Ord)
 data instance E Kind StageTC = KUniVar Int deriving (Show, Eq, Ord)
 data instance E TypeSig StageTC deriving (Show)
 
@@ -71,7 +72,7 @@ type CExpr    = Expr    StageTC
 
 data StageTyped
 
-type instance X DataDef StageTyped = TType
+type instance X DataDef StageTyped = TKind
 type instance X FunDef  StageTyped = TType
 type instance X FunEq   StageTyped = ()
 type instance X Kind    StageTyped = ()
@@ -242,24 +243,30 @@ tcTop :: PProgram -> TCM TProgram
 tcTop prog = do
   (cs, prog') <- collectConstraints Just (tcProgram prog)
   (subK, subT) <- solveConstrs cs
-  return $ doneProg subK subT prog'
+  let subK' = Map.map (substFinKind mempty) subK
+      subT' = Map.map (substFinType subK' mempty) subT
+  return $ substFinProg subK' subT' prog'
 
 tcProgram :: PProgram -> TCM CProgram
 tcProgram (Program ddefs1 fdefs1) = do
+  -- kind-check data definitions and collect ensuing kind constraints
   (kconstrs, ddefs2) <- collectConstraints isCEqK $ do
     ks <- mapM prepareDataDef ddefs1
     zipWithM kcDataDef ks ddefs1
 
+  -- solve the kind constraints and finalise data types
   kinduvars <- solveKindVars kconstrs
   let ddefs3 = map (substDdef kinduvars mempty) ddefs2
   modifyTEnv (Map.map (substKind kinduvars))
 
+  -- generate inverse constructors for all data types
   forM_ ddefs3 $ \ddef ->
     modifyICEnv (Map.fromList (generateInvCons ddef) <>)
 
   traceM (unlines (map pretty ddefs3))
 
-  fdefs2 <- mapM tcFunDef fdefs1
+  -- check the function definitions
+  fdefs2 <- tcFunDefBlock fdefs1
 
   return (Program ddefs3 fdefs2)
 
@@ -360,6 +367,23 @@ kcType mdown = \case
       kcType (Just k2) t
     return (TForall k2 n t')  -- not 'k1 -> k2' because the forall is implicit
 
+tcFunDefBlock :: [PFunDef] -> TCM [CFunDef]
+tcFunDefBlock fdefs = do
+  -- generate preliminary unification variables for the functions' types
+  bound <- mapM (\(FunDef _ n _ _) -> (n,) <$> genUniVar (KType ())) fdefs
+  defs' <- forM fdefs $ \def@(FunDef _ name _ _) ->
+    scopeVEnv $ do
+      modifyVEnv (Map.fromList [(n, t) | (n, t) <- bound, n /= name] <>)
+      tcFunDef def
+
+  -- take the actual found types for typechecking the body (and link them
+  -- to the variables generated above)
+  let bound2 = map (\(FunDef ty n _ _) -> (n, ty)) defs'
+  forM_ (zip3 fdefs bound bound2) $ \(fdef, (_, tvar), (_, ty)) ->
+    emit $ CEq ty tvar (extOf fdef)  -- which is expected/observed? which range? /shrug/
+
+  return defs'
+
 tcFunDef :: PFunDef -> TCM CFunDef
 tcFunDef (FunDef rng name msig eqs) = do
   when (not $ allEq (fmap (length . funeqPats) eqs)) $
@@ -369,7 +393,9 @@ tcFunDef (FunDef rng name msig eqs) = do
     TypeSig sig -> kcType (Just (KType ())) sig
     TypeSigExt NoTypeSig -> genUniVar (KType ())
 
-  eqs' <- mapM (tcFunEq typ) eqs
+  eqs' <- scopeVEnv $ do
+    modifyVEnv (Map.insert name typ)  -- allow function to be recursive
+    mapM (tcFunEq typ) eqs
 
   return (FunDef typ name (TypeSig typ) eqs')
 
@@ -386,8 +412,11 @@ tcFunEq down (FunEq rng name pats rhs) = do
 tcPattern :: CType -> PPattern -> TCM CPattern
 tcPattern down = \case
   PWildcard _ -> return $ PWildcard down
+
   PVar _ n -> modifyVEnv (Map.insert n down) >> return (PVar down n)
+
   PAs _ n p -> modifyVEnv (Map.insert n down) >> tcPattern down p
+
   PCon rng n ps ->
     getInvCon n >>= \case
       Just (InvCon tyvars match fields) -> do
@@ -399,6 +428,7 @@ tcPattern down = \case
       Nothing -> do
         raise SError rng $ "Constructor not in scope: " ++ pretty n
         return (PWildcard down)
+
   POp rng p1 op p2 ->
     case op of
       OCons -> do
@@ -411,11 +441,13 @@ tcPattern down = \case
       _ -> do
         raise SError rng $ "Operator is not a constructor: " ++ pretty op
         return (PWildcard down)
+
   PList rng ps -> do
     eltty <- genUniVar (KType ())
     let listty = TList (KType ()) eltty
     emit $ CEq down listty rng
     PList listty <$> mapM (tcPattern eltty) ps
+
   PTup rng ps -> do
     ts <- mapM (\_ -> genUniVar (KType ())) ps
     emit $ CEq down (TTup (KType ()) ts) rng
@@ -438,9 +470,15 @@ tcExpr down = \case
     emit $ CEq down ty rng
     return (ELit ty lit)
 
-  EVar rng n -> EVar <$> getType' rng n <*> pure n
+  EVar rng n -> do
+    ty <- getType' rng n
+    emit $ CEq down ty rng
+    return $ EVar ty n
 
-  ECon rng n -> ECon <$> getType' rng n <*> pure n
+  ECon rng n -> do
+    ty <- getType' rng n
+    emit $ CEq down ty rng
+    return $ EVar ty n
 
   EList rng es -> do
     eltty <- genUniVar (KType ())
@@ -495,17 +533,9 @@ tcExpr down = \case
         (,) <$> tcPattern ty pat <*> tcRHS down rhs
     return $ ECase down e1' alts'
 
-  ELet rng defs body -> do
-    bound <- mapM (\(FunDef _ n _ _) -> (n,) <$> genUniVar (KType ())) defs
-    defs' <- forM defs $ \def@(FunDef _ name _ _) ->
-      scopeVEnv $ do
-        modifyVEnv (Map.fromList [(n, t) | (n, t) <- bound, n /= name] <>)
-        tcFunDef def
-    -- take the actual found types for typechecking the body (and linking them
-    -- to the variables generated above)
+  ELet _ defs body -> do
+    defs' <- tcFunDefBlock defs
     let bound2 = map (\(FunDef ty n _ _) -> (n, ty)) defs'
-    forM_ (zip bound bound2) $ \((_, tvar), (_, ty)) ->
-      emit $ CEq ty tvar rng  -- in which order? which range? /shrug/
     scopeVEnv $ do
       modifyVEnv (Map.fromList bound2 <>)
       body' <- tcExpr down body
@@ -580,51 +610,191 @@ solveKindVars cs = do
     kindSize (KFun () a b) = 1 + kindSize a + kindSize b
     kindSize (KExt () KUniVar{}) = 2
 
+solveTypeVars :: Bag (CType, CType, Range) -> TCM (Map Int CType)
+solveTypeVars cs = do
+  let (asg, errs) =
+        solveConstraints
+          reduce
+          (foldMap pure . typeUniVars)
+          (\m -> substType mempty m mempty)
+          (\case TExt _ (TUniVar v) -> Just v
+                 _ -> Nothing)
+          typeSize
+          (toList cs)
+
+  forM_ errs $ \case
+    UEUnequal t1 t2 rng ->
+      raise SError rng $
+        "Type mismatch:\n\
+        \- " ++ pretty t1 ++ "\n\
+        \- " ++ pretty t2
+    UERecursive uvar t rng ->
+      raise SError rng $
+        "Type cannot be recursive: " ++ pretty (TExt (extOf t) (TUniVar uvar)) ++ " = " ++ pretty t
+
+  return asg
+  where
+    reduce :: CType -> CType -> Range -> (Bag (Int, CType, Range), Bag (CType, CType, Range))
+    reduce lhs rhs rng = case (lhs, rhs) of
+      -- unification variables produce constraints on a unification variable
+      (TExt _ (TUniVar i), TExt _ (TUniVar j)) | i == j -> mempty
+      (TExt _ (TUniVar i), t                 ) -> (pure (i, t, rng), mempty)
+      (t                 , TExt _ (TUniVar i)) -> (pure (i, t, rng), mempty)
+
+      -- if lhs and rhs have equal prefixes, recurse
+      (TApp _ t ts, TApp _ t' ts') -> reduce t t' rng <> foldMap (\(a, b) -> reduce a b rng) (zip ts ts')
+      (TTup _ ts, TTup _ ts') -> foldMap (\(a, b) -> reduce a b rng) (zip ts ts')
+      (TList _ t, TList _ t') -> reduce t t' rng
+      (TFun _ t1 t2, TFun _ t1' t2') -> reduce t1 t1' rng <> reduce t2 t2' rng
+      (TCon _ n1, TCon _ n2) | n1 == n2 -> mempty
+      (TVar _ n1, TVar _ n2) | n1 == n2 -> mempty
+      (TForall _ n1 t1, TForall k n2 t2) ->
+        reduce t1 (substType mempty mempty (Map.singleton n2 (TVar k n1)) t2) rng
+
+      -- otherwise, this is a kind mismatch
+      (k1, k2) -> (mempty, pure (k1, k2, rng))
+
+    typeSize :: CType -> Int
+    typeSize (TApp _ t ts) = typeSize t + sum (map typeSize ts)
+    typeSize (TTup _ ts) = sum (map typeSize ts)
+    typeSize (TList _ t) = 1 + typeSize t
+    typeSize (TFun _ t1 t2) = typeSize t1 + typeSize t2
+    typeSize (TCon _ _) = 1
+    typeSize (TVar _ _) = 1
+    typeSize (TForall _ _ t) = 1 + typeSize t
+    typeSize (TExt _ TUniVar{}) = 2
+
 partitionConstrs :: Foldable t => t Constr -> (Bag (CType, CType, Range), Bag (CKind, CKind, Range))
 partitionConstrs = foldMap $ \case CEq t1 t2 r -> (pure (t1, t2, r), mempty)
                                    CEqK k1 k2 r -> (mempty, pure (k1, k2, r))
 
--- substitute unification variables
-substProg :: Map Int CKind  -- ^ Kind unification variable instantiations
-          -> Map Int CType  -- ^ Type unification variable instantiations
-          -> CProgram
-          -> CProgram
-substProg = error "substProg"
+-------------------- SUBSTITUTION FUNCTIONS --------------------
+-- These take some of:
+-- - an instantiation map for kind unification variables (Map Int {C,T}Kind)
+-- - an instantiation map for type unification variables (Map Int {C,T}Type)
+-- - an instantiation map for type variables (Map Name CType)
+
+substFinProg :: HasCallStack
+             => Map Int TKind -> Map Int TType -> CProgram -> TProgram
+substFinProg mk mt (Program ds fs) = Program (map (substFinDdef mk mt) ds) (map (substFinFdef mk mt) fs)
+
+substFinDdef :: HasCallStack
+             => Map Int TKind -> Map Int TType -> CDataDef -> TDataDef
+substFinDdef mk mt (DataDef k n ps cs) =
+  DataDef (substFinKind mk k) n (map (first (substFinKind mk)) ps) (map (second (map (substFinType mk mt))) cs)
+
+substFinFdef :: HasCallStack
+             => Map Int TKind -> Map Int TType -> CFunDef -> TFunDef
+substFinFdef mk mt (FunDef t n (TypeSig sig) eqs) =
+  FunDef (substFinType mk mt t) n
+         (TypeSig (substFinType mk mt sig))
+         (fmap (substFinFunEq mk mt) eqs)
+
+substFinFunEq :: HasCallStack
+              => Map Int TKind -> Map Int TType -> CFunEq -> TFunEq
+substFinFunEq mk mt (FunEq () n ps rhs) =
+  FunEq () n
+        (map (substFinPattern mk mt) ps)
+        (substFinRHS mk mt rhs)
+
+substFinRHS :: HasCallStack
+            => Map Int TKind -> Map Int TType -> CRHS -> TRHS
+substFinRHS _  _  (Guarded _ _) = error "typecheck: guards unsupported"
+substFinRHS mk mt (Plain t e) = Plain (substFinType mk mt t) (substFinExpr mk mt e)
+
+substFinPattern :: HasCallStack
+                => Map Int TKind -> Map Int TType -> CPattern -> TPattern
+substFinPattern mk mt = go
+  where
+    go (PWildcard t) = PWildcard (goType t)
+    go (PVar t n) = PVar (goType t) n
+    go (PAs t n p) = PAs (goType t) n (go p)
+    go (PCon t n ps) = PCon (goType t) n (map go ps)
+    go (POp t p1 op p2) = POp (goType t) (go p1) op (go p2)
+    go (PList t ps) = PList (goType t) (map go ps)
+    go (PTup t ps) = PTup (goType t) (map go ps)
+
+    goType = substFinType mk mt
+
+substFinExpr :: HasCallStack
+             => Map Int TKind -> Map Int TType -> CExpr -> TExpr
+substFinExpr mk mt = go
+  where
+    go (ELit t lit) = ELit (goType t) lit
+    go (EVar t n) = EVar (goType t) n
+    go (ECon t n) = ECon (goType t) n
+    go (EList t es) = EList (goType t) (map go es)
+    go (ETup t es) = ETup (goType t) (map go es)
+    go (EApp t e1 es) = EApp (goType t) (go e1) (map go es)
+    go (EOp t e1 op e2) = EOp (goType t) (go e1) op (go e2)
+    go (EIf t e1 e2 e3) = EIf (goType t) (go e1) (go e2) (go e3)
+    go (ECase t e1 alts) = ECase (goType t) (go e1) (map (bimap (substFinPattern mk mt) (substFinRHS mk mt)) alts)
+    go (ELet t defs body) = ELet (goType t) (map (substFinFdef mk mt) defs) (go body)
+    go (EError t) = EError (goType t)
+
+    goType = substFinType mk mt
+
+substFinType :: HasCallStack
+             => Map Int TKind  -- ^ kind uvars
+             -> Map Int TType  -- ^ type uvars
+             -> CType -> TType
+substFinType mk mt = go
+  where
+    go (TApp k t ts) = TApp (substFinKind mk k) (go t) (map go ts)
+    go (TTup k ts) = TTup (substFinKind mk k) (map go ts)
+    go (TList k t) = TList (substFinKind mk k) (go t)
+    go (TFun k t1 t2) = TFun (substFinKind mk k) (go t1) (go t2)
+    go (TCon k n) = TCon (substFinKind mk k) n
+    go (TVar k n) = TVar (substFinKind mk k) n
+    go (TForall k n t) = TForall (substFinKind mk k) n (go t)
+    go t@(TExt _ (TUniVar v)) = fromMaybe (error $ "substFinType: unification variables left: " ++ show t)
+                                          (Map.lookup v mt)
+
+substFinKind :: HasCallStack => Map Int TKind -> CKind -> TKind
+substFinKind m = \case
+  KType () -> KType ()
+  KFun () k1 k2 -> KFun () (substFinKind m k1) (substFinKind m k2)
+  k@(KExt () (KUniVar v)) -> fromMaybe (error $ "substFinKind: unification variables left: " ++ show k)
+                                       (Map.lookup v m)
 
--- substitute unification variables
 substDdef :: Map Int CKind -> Map Int CType -> CDataDef -> CDataDef
 substDdef mk mt (DataDef k name pars cons) =
   DataDef (substKind mk k) name
           (map (first (substKind mk)) pars)
           (map (second (map (substType mk mt mempty))) cons)
 
-substType :: Map Int CKind  -- ^ kind uvars
-          -> Map Int CType  -- ^ type uvars
-          -> Map Name CType  -- ^ type variables
-          -> CType -> CType
+substType :: Map Int CKind -> Map Int CType -> Map Name CType -> CType -> CType
 substType mk mt mtv = go
   where
-    go (TApp k t ts) = TApp (substKind mk k) (go t) (map go ts)
-    go (TTup k ts) = TTup (substKind mk k) (map go ts)
-    go (TList k t) = TList (substKind mk k) (go t)
-    go (TFun k t1 t2) = TFun (substKind mk k) (go t1) (go t2)
-    go (TCon k n) = TCon (substKind mk k) n
-    go (TVar k n) = fromMaybe (TVar (substKind mk k) n) (Map.lookup n mtv)
-    go (TForall k n t) = TForall (substKind mk k) n (go t)
-    go (TExt k (TUniVar v)) = fromMaybe (TExt (substKind mk k) (TUniVar v)) (Map.lookup v mt)
-
--- substitute unification variables
+    go (TApp k t ts) = TApp (goKind k) (go t) (map go ts)
+    go (TTup k ts) = TTup (goKind k) (map go ts)
+    go (TList k t) = TList (goKind k) (go t)
+    go (TFun k t1 t2) = TFun (goKind k) (go t1) (go t2)
+    go (TCon k n) = TCon (goKind k) n
+    go (TVar k n) = fromMaybe (TVar (goKind k) n) (Map.lookup n mtv)
+    go (TForall k n t) = TForall (goKind k) n (go t)
+    go (TExt k (TUniVar v)) = fromMaybe (TExt (goKind k) (TUniVar v)) (Map.lookup v mt)
+
+    goKind = substKind mk
+
 substKind :: Map Int CKind -> CKind -> CKind
 substKind m = \case
   KType () -> KType ()
   KFun () k1 k2 -> KFun () (substKind m k1) (substKind m k2)
   k@(KExt () (KUniVar v)) -> fromMaybe k (Map.lookup v m)
 
-doneProg :: Map Int TKind  -- ^ Kind unification variable instantiations
-         -> Map Int TType  -- ^ Type unification variable instantiations
-         -> CProgram
-         -> TProgram
-doneProg = error "doneProg"
+-------------------- END OF SUBSTITUTION FUNCTIONS --------------------
+
+typeUniVars :: CType -> Set Int
+typeUniVars = \case
+  TApp _ t ts -> typeUniVars t <> foldMap typeUniVars ts
+  TTup _ ts -> foldMap typeUniVars ts
+  TList _ t -> typeUniVars t
+  TFun _ t1 t2 -> typeUniVars t1 <> typeUniVars t2
+  TCon _ _ -> mempty
+  TVar _ _ -> mempty
+  TForall _ _ t -> typeUniVars t
+  TExt _ (TUniVar v) -> Set.singleton v
 
 kindUniVars :: CKind -> Set Int
 kindUniVars = \case
-- 
cgit v1.2.3-70-g09d2