aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--hs-visinter.cabal1
-rw-r--r--src/Data/Map/Monoidal.hs31
-rw-r--r--src/HSVIS/AST.hs4
-rw-r--r--src/HSVIS/Typecheck.hs206
4 files changed, 142 insertions, 100 deletions
diff --git a/hs-visinter.cabal b/hs-visinter.cabal
index 48d3456..118882f 100644
--- a/hs-visinter.cabal
+++ b/hs-visinter.cabal
@@ -11,6 +11,7 @@ library
exposed-modules:
Control.FAlternative
Data.Bag
+ Data.Map.Monoidal
Data.List.NonEmpty.Util
HSVIS.AST
HSVIS.Diagnostic
diff --git a/src/Data/Map/Monoidal.hs b/src/Data/Map/Monoidal.hs
new file mode 100644
index 0000000..7007934
--- /dev/null
+++ b/src/Data/Map/Monoidal.hs
@@ -0,0 +1,31 @@
+{-# LANGUAGE DeriveTraversable #-}
+module Data.Map.Monoidal where
+
+import Data.Map.Strict (Map)
+import qualified Data.Map.Strict as Map
+import Data.Maybe (fromMaybe)
+
+
+newtype MMap k v = MMap (Map k v)
+ deriving (Show, Functor, Foldable, Traversable)
+
+instance (Ord k, Semigroup v) => Semigroup (MMap k v) where
+ MMap m1 <> MMap m2 = MMap (Map.unionWith (<>) m1 m2)
+
+instance (Ord k, Semigroup v) => Monoid (MMap k v) where
+ mempty = MMap Map.empty
+
+fromList :: (Ord k, Semigroup v) => [(k, v)] -> MMap k v
+fromList l = MMap (Map.fromListWith (<>) l)
+
+singleton :: k -> v -> MMap k v
+singleton k v = MMap (Map.singleton k v)
+
+lookup :: (Ord k, Monoid v) => k -> MMap k v -> v
+lookup k (MMap m) = fromMaybe mempty (Map.lookup k m)
+
+lookup' :: Ord k => k -> MMap k v -> Maybe v
+lookup' k (MMap m) = Map.lookup k m
+
+insert :: (Ord k, Semigroup v) => k -> v -> MMap k v -> MMap k v
+insert k v (MMap m) = MMap (Map.insertWith (<>) k v m)
diff --git a/src/HSVIS/AST.hs b/src/HSVIS/AST.hs
index 2986248..91e08eb 100644
--- a/src/HSVIS/AST.hs
+++ b/src/HSVIS/AST.hs
@@ -95,7 +95,7 @@ data Type s
| TFun (X Type s) (Type s) (Type s)
| TCon (X Type s) Name
| TVar (X Type s) Name
- | TForall (X Type s) Name (Type s) -- ^ implicit
+ | TForall (X Type s) Name (Type s) -- ^ implicit; also, not parsed
-- extension point
| TExt (X Type s) !(E Type s)
@@ -161,7 +161,7 @@ instance Pretty (E Type s) => Pretty (Type s) where
prettysPrec _ (TCon _ n) = prettysPrec 11 n
prettysPrec _ (TVar _ n) = prettysPrec 11 n
prettysPrec d (TForall _ n t) = showParen (d > -1) $
- showString "forall " . prettysPrec 11 n . showString "." . prettysPrec (-1) t
+ showString "forall " . prettysPrec 11 n . showString ". " . prettysPrec (-1) t
prettysPrec d (TExt _ e) = prettysPrec d e
instance (Pretty (X Type s), Pretty (E Type s)) => Pretty (DataDef s) where
diff --git a/src/HSVIS/Typecheck.hs b/src/HSVIS/Typecheck.hs
index f292c0e..2dd103f 100644
--- a/src/HSVIS/Typecheck.hs
+++ b/src/HSVIS/Typecheck.hs
@@ -3,6 +3,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE GADTs #-}
@@ -11,6 +12,7 @@
{-# OPTIONS -Wno-unused-top-binds #-}
{-# OPTIONS -Wno-unused-imports #-}
+{-# LANGUAGE DataKinds #-}
@@ -30,6 +32,7 @@ import Data.Maybe (fromMaybe)
import Data.Monoid (Ap(..))
import qualified Data.Map.Strict as Map
import Data.Tuple (swap)
+import Data.Semigroup (First(..))
import Data.Set (Set)
import qualified Data.Set as Set
import GHC.Stack
@@ -38,6 +41,8 @@ import Debug.Trace
import Data.Bag
import Data.List.NonEmpty.Util
+import Data.Map.Monoidal (MMap(..))
+import qualified Data.Map.Monoidal as MMap
import HSVIS.AST
import HSVIS.Parser
import HSVIS.Diagnostic
@@ -153,9 +158,12 @@ instance Monad TCM where
(ds3, cs3, i3, env3, y) = runTCM (g x) ctx i2 env2
in (ds2 <> ds3, cs2 <> cs3, i3, env3, y)
-raise :: Severity -> Range -> String -> TCM ()
-raise sev rng@(Range (Pos y _) _) msg = TCM $ \(fp, source) i env ->
- (pure (Diagnostic sev fp rng [] (lines source !! y) msg), mempty, i, env, ())
+class Monad m => MonadRaise m where
+ raise :: Severity -> Range -> String -> m ()
+
+instance MonadRaise TCM where
+ raise sev rng@(Range (Pos y _) _) msg = TCM $ \(fp, source) i env ->
+ (pure (Diagnostic sev fp rng [] (lines source !! y) msg), mempty, i, env, ())
emit :: Constr -> TCM ()
emit c = TCM $ \_ i env -> (mempty, pure c, i, env, ())
@@ -243,9 +251,7 @@ tcTop :: PProgram -> TCM TProgram
tcTop prog = do
(cs, prog') <- collectConstraints Just (tcProgram prog)
(subK, subT) <- solveConstrs cs
- let subK' = Map.map (substFinKind mempty) subK
- subT' = Map.map (substFinType subK' mempty) subT
- return $ substFinProg subK' subT' prog'
+ return $ finaliseProg (substProg subK subT prog')
tcProgram :: PProgram -> TCM CProgram
tcProgram (Program ddefs1 fdefs1) = do
@@ -296,7 +302,7 @@ kcDataDef (kd, parkinds) (DataDef _ name params cons) = do
cons' <- scopeTEnv $ do
modifyTEnv (Map.fromList (zip params' parkinds) <>)
forM cons $ \(cname, fieldtys) -> do
- fieldtys' <- mapM (kcType (Just (KType ()))) fieldtys
+ fieldtys' <- mapM (kcType KCTMNormal (Just (KType ()))) fieldtys
return (cname, fieldtys')
return (DataDef kd name (zip parkinds params') cons')
@@ -315,57 +321,74 @@ downEqK :: Range -> Maybe CKind -> CKind -> TCM ()
downEqK _ Nothing _ = return ()
downEqK rng (Just k1) k2 = emit $ CEqK k1 k2 rng
+data KCTypeMode ext ret where
+ -- | Kind-check a normal type: out-of-scope type variables are reported as errors.
+ KCTMNormal :: KCTypeMode () CType
+
+ -- | Kind-check an open type: out-of-scope type variables are returned. This
+ -- is used to check function type signatures, which may have an implicit
+ -- forall telescope at the head.
+ KCTMOpen :: KCTypeMode (MMap Name (First CKind)) (CType, Map Name CKind)
+
+-- | Given (maybe) the expected kind of this type, and a type, check it for
+-- kind-correctness.
+kcType :: forall ext ret. KCTypeMode ext ret -> Maybe CKind -> PType -> TCM ret
+kcType KCTMNormal mdown t = snd <$> kcType' KCTMNormal mdown t
+kcType KCTMOpen mdown t = second (\(MMap m) -> Map.map getFirst m) . swap <$> kcType' KCTMOpen mdown t
+
-- | Given (maybe) the expected kind of this type, and a type, check it for
-- kind-correctness.
-kcType :: Maybe CKind -> PType -> TCM CType
-kcType mdown = \case
+kcType' :: forall ext ret. Monoid ext => KCTypeMode ext ret -> Maybe CKind -> PType -> TCM (ext, CType)
+kcType' mode mdown = \case
TApp rng t ts -> do
- t' <- kcType Nothing t
- ts' <- mapM (kcType Nothing) ts
+ (ext1, t') <- kcType' mode Nothing t
+ (ext2, ts') <- sequence <$> mapM (kcType' mode Nothing) ts
retk <- promoteDownK mdown
let expected = foldr (KFun ()) retk (map extOf ts')
emit $ CEqK (extOf t') expected rng
- return (TApp retk t' ts')
+ return (ext1 <> ext2, TApp retk t' ts')
TTup rng ts -> do
- ts' <- mapM (kcType (Just (KType ()))) ts
+ (ext, ts') <- sequence <$> mapM (kcType' mode (Just (KType ()))) ts
forM_ (zip (map extOf ts) ts') $ \(trng, ct) ->
emit $ CEqK (extOf ct) (KType ()) trng
downEqK rng mdown (KType ())
- return (TTup (KType ()) ts')
+ return (ext, TTup (KType ()) ts')
TList rng t -> do
- t' <- kcType (Just (KType ())) t
+ (ext, t') <- kcType' mode (Just (KType ())) t
emit $ CEqK (extOf t') (KType ()) (extOf t)
downEqK rng mdown (KType ())
- return (TList (KType ()) t')
+ return (ext, TList (KType ()) t')
TFun rng t1 t2 -> do
- t1' <- kcType (Just (KType ())) t1
- t2' <- kcType (Just (KType ())) t2
+ (ext1, t1') <- kcType' mode (Just (KType ())) t1
+ (ext2, t2') <- kcType' mode (Just (KType ())) t2
emit $ CEqK (extOf t1') (KType ()) (extOf t1)
emit $ CEqK (extOf t2') (KType ()) (extOf t2)
downEqK rng mdown (KType ())
- return (TFun (KType ()) t1' t2')
+ return (ext1 <> ext2, TFun (KType ()) t1' t2')
TCon rng n -> do
k <- getKind' rng n
downEqK rng mdown k
- return (TCon k n)
+ return (mempty, TCon k n)
TVar rng n -> do
k <- getKind' rng n
downEqK rng mdown k
- return (TVar k n)
+ return (case mode of KCTMNormal -> ()
+ KCTMOpen -> MMap.singleton n (pure k)
+ ,TVar k n)
TForall rng n t -> do -- implicit forall
k1 <- genKUniVar
k2 <- genKUniVar
downEqK rng mdown k2
- t' <- scopeTEnv $ do
+ (ext, t') <- scopeTEnv $ do
modifyTEnv (Map.insert n k1)
- kcType (Just k2) t
- return (TForall k2 n t') -- not 'k1 -> k2' because the forall is implicit
+ kcType' mode (Just k2) t
+ return (ext, TForall k2 n t') -- not 'k1 -> k2' because the forall is implicit
tcFunDefBlock :: [PFunDef] -> TCM [CFunDef]
tcFunDefBlock fdefs = do
@@ -390,7 +413,15 @@ tcFunDef (FunDef rng name msig eqs) = do
raise SError rng "Function equations have differing numbers of arguments"
typ <- case msig of
- TypeSig sig -> kcType (Just (KType ())) sig
+ TypeSig sig -> do
+ (typ, freetvars) <- kcType KCTMOpen (Just (KType ())) sig
+ TODO -- We need to check that these free type variables do not escape.
+ -- Perhaps with levels on unification variables? Associate a level
+ -- to a generated uvar, and increment the global level counter when
+ -- passing below a forall.
+ -- But how do we deal with functions without a type signature
+ -- anyway? We should be able to infer a polymorphic type for them.
+ return $ foldr (\(n, k) -> TForall k n) typ (Map.assocs freetvars)
TypeSigExt NoTypeSig -> genUniVar (KType ())
eqs' <- scopeVEnv $ do
@@ -555,14 +586,15 @@ unfoldFunTy rng n t = do
emit $ CEq expected t rng
return (vars, core)
-solveConstrs :: Bag Constr -> TCM (Map Int CKind, Map Int CType)
+solveConstrs :: MonadRaise m => Bag Constr -> m (Map Int CKind, Map Int CType)
solveConstrs constrs = do
let (tcs, kcs) = partitionConstrs constrs
subK <- solveKindVars kcs
subT <- solveTypeVars tcs
- return (subK, subT)
+ let subT' = Map.map (substType subK mempty mempty) subT
+ return (subK, subT')
-solveKindVars :: Bag (CKind, CKind, Range) -> TCM (Map Int CKind)
+solveKindVars :: MonadRaise m => Bag (CKind, CKind, Range) -> m (Map Int CKind)
solveKindVars cs = do
let (asg, errs) =
solveConstraints
@@ -610,7 +642,7 @@ solveKindVars cs = do
kindSize (KFun () a b) = 1 + kindSize a + kindSize b
kindSize (KExt () KUniVar{}) = 2
-solveTypeVars :: Bag (CType, CType, Range) -> TCM (Map Int CType)
+solveTypeVars :: MonadRaise m => Bag (CType, CType, Range) -> m (Map Int CType)
solveTypeVars cs = do
let (asg, errs) =
solveConstraints
@@ -670,41 +702,43 @@ partitionConstrs = foldMap $ \case CEq t1 t2 r -> (pure (t1, t2, r), mempty)
-------------------- SUBSTITUTION FUNCTIONS --------------------
-- These take some of:
--- - an instantiation map for kind unification variables (Map Int {C,T}Kind)
--- - an instantiation map for type unification variables (Map Int {C,T}Type)
+-- - an instantiation map for kind unification variables (Map Int CKind)
+-- - an instantiation map for type unification variables (Map Int CType)
-- - an instantiation map for type variables (Map Name CType)
-substFinProg :: HasCallStack
- => Map Int TKind -> Map Int TType -> CProgram -> TProgram
-substFinProg mk mt (Program ds fs) = Program (map (substFinDdef mk mt) ds) (map (substFinFdef mk mt) fs)
-
-substFinDdef :: HasCallStack
- => Map Int TKind -> Map Int TType -> CDataDef -> TDataDef
-substFinDdef mk mt (DataDef k n ps cs) =
- DataDef (substFinKind mk k) n (map (first (substFinKind mk)) ps) (map (second (map (substFinType mk mt))) cs)
-
-substFinFdef :: HasCallStack
- => Map Int TKind -> Map Int TType -> CFunDef -> TFunDef
-substFinFdef mk mt (FunDef t n (TypeSig sig) eqs) =
- FunDef (substFinType mk mt t) n
- (TypeSig (substFinType mk mt sig))
- (fmap (substFinFunEq mk mt) eqs)
-
-substFinFunEq :: HasCallStack
- => Map Int TKind -> Map Int TType -> CFunEq -> TFunEq
-substFinFunEq mk mt (FunEq () n ps rhs) =
+substProg :: HasCallStack
+ => Map Int CKind -> Map Int CType -> CProgram -> CProgram
+substProg mk mt (Program ds fs) = Program (map (substDdef mk mt) ds) (map (substFdef mk mt) fs)
+
+substDdef :: HasCallStack
+ => Map Int CKind -> Map Int CType -> CDataDef -> CDataDef
+substDdef mk mt (DataDef k name pars cons) =
+ DataDef (substKind mk k) name
+ (map (first (substKind mk)) pars)
+ (map (second (map (substType mk mt mempty))) cons)
+
+substFdef :: HasCallStack
+ => Map Int CKind -> Map Int CType -> CFunDef -> CFunDef
+substFdef mk mt (FunDef t n (TypeSig sig) eqs) =
+ FunDef (substType mk mt mempty t) n
+ (TypeSig (substType mk mt mempty sig))
+ (fmap (substFunEq mk mt) eqs)
+
+substFunEq :: HasCallStack
+ => Map Int CKind -> Map Int CType -> CFunEq -> CFunEq
+substFunEq mk mt (FunEq () n ps rhs) =
FunEq () n
- (map (substFinPattern mk mt) ps)
- (substFinRHS mk mt rhs)
+ (map (substPattern mk mt) ps)
+ (substRHS mk mt rhs)
-substFinRHS :: HasCallStack
- => Map Int TKind -> Map Int TType -> CRHS -> TRHS
-substFinRHS _ _ (Guarded _ _) = error "typecheck: guards unsupported"
-substFinRHS mk mt (Plain t e) = Plain (substFinType mk mt t) (substFinExpr mk mt e)
+substRHS :: HasCallStack
+ => Map Int CKind -> Map Int CType -> CRHS -> CRHS
+substRHS _ _ (Guarded _ _) = error "typecheck: guards unsupported"
+substRHS mk mt (Plain t e) = Plain (substType mk mt mempty t) (substExpr mk mt e)
-substFinPattern :: HasCallStack
- => Map Int TKind -> Map Int TType -> CPattern -> TPattern
-substFinPattern mk mt = go
+substPattern :: HasCallStack
+ => Map Int CKind -> Map Int CType -> CPattern -> CPattern
+substPattern mk mt = go
where
go (PWildcard t) = PWildcard (goType t)
go (PVar t n) = PVar (goType t) n
@@ -714,11 +748,11 @@ substFinPattern mk mt = go
go (PList t ps) = PList (goType t) (map go ps)
go (PTup t ps) = PTup (goType t) (map go ps)
- goType = substFinType mk mt
+ goType = substType mk mt mempty
-substFinExpr :: HasCallStack
- => Map Int TKind -> Map Int TType -> CExpr -> TExpr
-substFinExpr mk mt = go
+substExpr :: HasCallStack
+ => Map Int CKind -> Map Int CType -> CExpr -> CExpr
+substExpr mk mt = go
where
go (ELit t lit) = ELit (goType t) lit
go (EVar t n) = EVar (goType t) n
@@ -728,42 +762,14 @@ substFinExpr mk mt = go
go (EApp t e1 es) = EApp (goType t) (go e1) (map go es)
go (EOp t e1 op e2) = EOp (goType t) (go e1) op (go e2)
go (EIf t e1 e2 e3) = EIf (goType t) (go e1) (go e2) (go e3)
- go (ECase t e1 alts) = ECase (goType t) (go e1) (map (bimap (substFinPattern mk mt) (substFinRHS mk mt)) alts)
- go (ELet t defs body) = ELet (goType t) (map (substFinFdef mk mt) defs) (go body)
+ go (ECase t e1 alts) = ECase (goType t) (go e1) (map (bimap (substPattern mk mt) (substRHS mk mt)) alts)
+ go (ELet t defs body) = ELet (goType t) (map (substFdef mk mt) defs) (go body)
go (EError t) = EError (goType t)
- goType = substFinType mk mt
-
-substFinType :: HasCallStack
- => Map Int TKind -- ^ kind uvars
- -> Map Int TType -- ^ type uvars
- -> CType -> TType
-substFinType mk mt = go
- where
- go (TApp k t ts) = TApp (substFinKind mk k) (go t) (map go ts)
- go (TTup k ts) = TTup (substFinKind mk k) (map go ts)
- go (TList k t) = TList (substFinKind mk k) (go t)
- go (TFun k t1 t2) = TFun (substFinKind mk k) (go t1) (go t2)
- go (TCon k n) = TCon (substFinKind mk k) n
- go (TVar k n) = TVar (substFinKind mk k) n
- go (TForall k n t) = TForall (substFinKind mk k) n (go t)
- go t@(TExt _ (TUniVar v)) = fromMaybe (error $ "substFinType: unification variables left: " ++ show t)
- (Map.lookup v mt)
-
-substFinKind :: HasCallStack => Map Int TKind -> CKind -> TKind
-substFinKind m = \case
- KType () -> KType ()
- KFun () k1 k2 -> KFun () (substFinKind m k1) (substFinKind m k2)
- k@(KExt () (KUniVar v)) -> fromMaybe (error $ "substFinKind: unification variables left: " ++ show k)
- (Map.lookup v m)
+ goType = substType mk mt mempty
-substDdef :: Map Int CKind -> Map Int CType -> CDataDef -> CDataDef
-substDdef mk mt (DataDef k name pars cons) =
- DataDef (substKind mk k) name
- (map (first (substKind mk)) pars)
- (map (second (map (substType mk mt mempty))) cons)
-
-substType :: Map Int CKind -> Map Int CType -> Map Name CType -> CType -> CType
+substType :: HasCallStack
+ => Map Int CKind -> Map Int CType -> Map Name CType -> CType -> CType
substType mk mt mtv = go
where
go (TApp k t ts) = TApp (goKind k) (go t) (map go ts)
@@ -777,13 +783,17 @@ substType mk mt mtv = go
goKind = substKind mk
-substKind :: Map Int CKind -> CKind -> CKind
+substKind :: HasCallStack
+ => Map Int CKind -> CKind -> CKind
substKind m = \case
KType () -> KType ()
KFun () k1 k2 -> KFun () (substKind m k1) (substKind m k2)
k@(KExt () (KUniVar v)) -> fromMaybe k (Map.lookup v m)
--------------------- END OF SUBSTITUTION FUNCTIONS --------------------
+-------------------- FINALISATION FUNCTIONS --------------------
+-- These report free type unification variables.
+
+-- TODO the finalise* functions
typeUniVars :: CType -> Set Int
typeUniVars = \case