diff options
-rw-r--r-- | hs-visinter.cabal | 1 | ||||
-rw-r--r-- | src/Data/Map/Monoidal.hs | 31 | ||||
-rw-r--r-- | src/HSVIS/AST.hs | 4 | ||||
-rw-r--r-- | src/HSVIS/Typecheck.hs | 206 |
4 files changed, 142 insertions, 100 deletions
diff --git a/hs-visinter.cabal b/hs-visinter.cabal index 48d3456..118882f 100644 --- a/hs-visinter.cabal +++ b/hs-visinter.cabal @@ -11,6 +11,7 @@ library exposed-modules: Control.FAlternative Data.Bag + Data.Map.Monoidal Data.List.NonEmpty.Util HSVIS.AST HSVIS.Diagnostic diff --git a/src/Data/Map/Monoidal.hs b/src/Data/Map/Monoidal.hs new file mode 100644 index 0000000..7007934 --- /dev/null +++ b/src/Data/Map/Monoidal.hs @@ -0,0 +1,31 @@ +{-# LANGUAGE DeriveTraversable #-} +module Data.Map.Monoidal where + +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as Map +import Data.Maybe (fromMaybe) + + +newtype MMap k v = MMap (Map k v) + deriving (Show, Functor, Foldable, Traversable) + +instance (Ord k, Semigroup v) => Semigroup (MMap k v) where + MMap m1 <> MMap m2 = MMap (Map.unionWith (<>) m1 m2) + +instance (Ord k, Semigroup v) => Monoid (MMap k v) where + mempty = MMap Map.empty + +fromList :: (Ord k, Semigroup v) => [(k, v)] -> MMap k v +fromList l = MMap (Map.fromListWith (<>) l) + +singleton :: k -> v -> MMap k v +singleton k v = MMap (Map.singleton k v) + +lookup :: (Ord k, Monoid v) => k -> MMap k v -> v +lookup k (MMap m) = fromMaybe mempty (Map.lookup k m) + +lookup' :: Ord k => k -> MMap k v -> Maybe v +lookup' k (MMap m) = Map.lookup k m + +insert :: (Ord k, Semigroup v) => k -> v -> MMap k v -> MMap k v +insert k v (MMap m) = MMap (Map.insertWith (<>) k v m) diff --git a/src/HSVIS/AST.hs b/src/HSVIS/AST.hs index 2986248..91e08eb 100644 --- a/src/HSVIS/AST.hs +++ b/src/HSVIS/AST.hs @@ -95,7 +95,7 @@ data Type s | TFun (X Type s) (Type s) (Type s) | TCon (X Type s) Name | TVar (X Type s) Name - | TForall (X Type s) Name (Type s) -- ^ implicit + | TForall (X Type s) Name (Type s) -- ^ implicit; also, not parsed -- extension point | TExt (X Type s) !(E Type s) @@ -161,7 +161,7 @@ instance Pretty (E Type s) => Pretty (Type s) where prettysPrec _ (TCon _ n) = prettysPrec 11 n prettysPrec _ (TVar _ n) = prettysPrec 11 n prettysPrec d (TForall _ n t) = showParen (d > -1) $ - showString "forall " . prettysPrec 11 n . showString "." . prettysPrec (-1) t + showString "forall " . prettysPrec 11 n . showString ". " . prettysPrec (-1) t prettysPrec d (TExt _ e) = prettysPrec d e instance (Pretty (X Type s), Pretty (E Type s)) => Pretty (DataDef s) where diff --git a/src/HSVIS/Typecheck.hs b/src/HSVIS/Typecheck.hs index f292c0e..2dd103f 100644 --- a/src/HSVIS/Typecheck.hs +++ b/src/HSVIS/Typecheck.hs @@ -3,6 +3,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE GADTs #-} @@ -11,6 +12,7 @@ {-# OPTIONS -Wno-unused-top-binds #-} {-# OPTIONS -Wno-unused-imports #-} +{-# LANGUAGE DataKinds #-} @@ -30,6 +32,7 @@ import Data.Maybe (fromMaybe) import Data.Monoid (Ap(..)) import qualified Data.Map.Strict as Map import Data.Tuple (swap) +import Data.Semigroup (First(..)) import Data.Set (Set) import qualified Data.Set as Set import GHC.Stack @@ -38,6 +41,8 @@ import Debug.Trace import Data.Bag import Data.List.NonEmpty.Util +import Data.Map.Monoidal (MMap(..)) +import qualified Data.Map.Monoidal as MMap import HSVIS.AST import HSVIS.Parser import HSVIS.Diagnostic @@ -153,9 +158,12 @@ instance Monad TCM where (ds3, cs3, i3, env3, y) = runTCM (g x) ctx i2 env2 in (ds2 <> ds3, cs2 <> cs3, i3, env3, y) -raise :: Severity -> Range -> String -> TCM () -raise sev rng@(Range (Pos y _) _) msg = TCM $ \(fp, source) i env -> - (pure (Diagnostic sev fp rng [] (lines source !! y) msg), mempty, i, env, ()) +class Monad m => MonadRaise m where + raise :: Severity -> Range -> String -> m () + +instance MonadRaise TCM where + raise sev rng@(Range (Pos y _) _) msg = TCM $ \(fp, source) i env -> + (pure (Diagnostic sev fp rng [] (lines source !! y) msg), mempty, i, env, ()) emit :: Constr -> TCM () emit c = TCM $ \_ i env -> (mempty, pure c, i, env, ()) @@ -243,9 +251,7 @@ tcTop :: PProgram -> TCM TProgram tcTop prog = do (cs, prog') <- collectConstraints Just (tcProgram prog) (subK, subT) <- solveConstrs cs - let subK' = Map.map (substFinKind mempty) subK - subT' = Map.map (substFinType subK' mempty) subT - return $ substFinProg subK' subT' prog' + return $ finaliseProg (substProg subK subT prog') tcProgram :: PProgram -> TCM CProgram tcProgram (Program ddefs1 fdefs1) = do @@ -296,7 +302,7 @@ kcDataDef (kd, parkinds) (DataDef _ name params cons) = do cons' <- scopeTEnv $ do modifyTEnv (Map.fromList (zip params' parkinds) <>) forM cons $ \(cname, fieldtys) -> do - fieldtys' <- mapM (kcType (Just (KType ()))) fieldtys + fieldtys' <- mapM (kcType KCTMNormal (Just (KType ()))) fieldtys return (cname, fieldtys') return (DataDef kd name (zip parkinds params') cons') @@ -315,57 +321,74 @@ downEqK :: Range -> Maybe CKind -> CKind -> TCM () downEqK _ Nothing _ = return () downEqK rng (Just k1) k2 = emit $ CEqK k1 k2 rng +data KCTypeMode ext ret where + -- | Kind-check a normal type: out-of-scope type variables are reported as errors. + KCTMNormal :: KCTypeMode () CType + + -- | Kind-check an open type: out-of-scope type variables are returned. This + -- is used to check function type signatures, which may have an implicit + -- forall telescope at the head. + KCTMOpen :: KCTypeMode (MMap Name (First CKind)) (CType, Map Name CKind) + +-- | Given (maybe) the expected kind of this type, and a type, check it for +-- kind-correctness. +kcType :: forall ext ret. KCTypeMode ext ret -> Maybe CKind -> PType -> TCM ret +kcType KCTMNormal mdown t = snd <$> kcType' KCTMNormal mdown t +kcType KCTMOpen mdown t = second (\(MMap m) -> Map.map getFirst m) . swap <$> kcType' KCTMOpen mdown t + -- | Given (maybe) the expected kind of this type, and a type, check it for -- kind-correctness. -kcType :: Maybe CKind -> PType -> TCM CType -kcType mdown = \case +kcType' :: forall ext ret. Monoid ext => KCTypeMode ext ret -> Maybe CKind -> PType -> TCM (ext, CType) +kcType' mode mdown = \case TApp rng t ts -> do - t' <- kcType Nothing t - ts' <- mapM (kcType Nothing) ts + (ext1, t') <- kcType' mode Nothing t + (ext2, ts') <- sequence <$> mapM (kcType' mode Nothing) ts retk <- promoteDownK mdown let expected = foldr (KFun ()) retk (map extOf ts') emit $ CEqK (extOf t') expected rng - return (TApp retk t' ts') + return (ext1 <> ext2, TApp retk t' ts') TTup rng ts -> do - ts' <- mapM (kcType (Just (KType ()))) ts + (ext, ts') <- sequence <$> mapM (kcType' mode (Just (KType ()))) ts forM_ (zip (map extOf ts) ts') $ \(trng, ct) -> emit $ CEqK (extOf ct) (KType ()) trng downEqK rng mdown (KType ()) - return (TTup (KType ()) ts') + return (ext, TTup (KType ()) ts') TList rng t -> do - t' <- kcType (Just (KType ())) t + (ext, t') <- kcType' mode (Just (KType ())) t emit $ CEqK (extOf t') (KType ()) (extOf t) downEqK rng mdown (KType ()) - return (TList (KType ()) t') + return (ext, TList (KType ()) t') TFun rng t1 t2 -> do - t1' <- kcType (Just (KType ())) t1 - t2' <- kcType (Just (KType ())) t2 + (ext1, t1') <- kcType' mode (Just (KType ())) t1 + (ext2, t2') <- kcType' mode (Just (KType ())) t2 emit $ CEqK (extOf t1') (KType ()) (extOf t1) emit $ CEqK (extOf t2') (KType ()) (extOf t2) downEqK rng mdown (KType ()) - return (TFun (KType ()) t1' t2') + return (ext1 <> ext2, TFun (KType ()) t1' t2') TCon rng n -> do k <- getKind' rng n downEqK rng mdown k - return (TCon k n) + return (mempty, TCon k n) TVar rng n -> do k <- getKind' rng n downEqK rng mdown k - return (TVar k n) + return (case mode of KCTMNormal -> () + KCTMOpen -> MMap.singleton n (pure k) + ,TVar k n) TForall rng n t -> do -- implicit forall k1 <- genKUniVar k2 <- genKUniVar downEqK rng mdown k2 - t' <- scopeTEnv $ do + (ext, t') <- scopeTEnv $ do modifyTEnv (Map.insert n k1) - kcType (Just k2) t - return (TForall k2 n t') -- not 'k1 -> k2' because the forall is implicit + kcType' mode (Just k2) t + return (ext, TForall k2 n t') -- not 'k1 -> k2' because the forall is implicit tcFunDefBlock :: [PFunDef] -> TCM [CFunDef] tcFunDefBlock fdefs = do @@ -390,7 +413,15 @@ tcFunDef (FunDef rng name msig eqs) = do raise SError rng "Function equations have differing numbers of arguments" typ <- case msig of - TypeSig sig -> kcType (Just (KType ())) sig + TypeSig sig -> do + (typ, freetvars) <- kcType KCTMOpen (Just (KType ())) sig + TODO -- We need to check that these free type variables do not escape. + -- Perhaps with levels on unification variables? Associate a level + -- to a generated uvar, and increment the global level counter when + -- passing below a forall. + -- But how do we deal with functions without a type signature + -- anyway? We should be able to infer a polymorphic type for them. + return $ foldr (\(n, k) -> TForall k n) typ (Map.assocs freetvars) TypeSigExt NoTypeSig -> genUniVar (KType ()) eqs' <- scopeVEnv $ do @@ -555,14 +586,15 @@ unfoldFunTy rng n t = do emit $ CEq expected t rng return (vars, core) -solveConstrs :: Bag Constr -> TCM (Map Int CKind, Map Int CType) +solveConstrs :: MonadRaise m => Bag Constr -> m (Map Int CKind, Map Int CType) solveConstrs constrs = do let (tcs, kcs) = partitionConstrs constrs subK <- solveKindVars kcs subT <- solveTypeVars tcs - return (subK, subT) + let subT' = Map.map (substType subK mempty mempty) subT + return (subK, subT') -solveKindVars :: Bag (CKind, CKind, Range) -> TCM (Map Int CKind) +solveKindVars :: MonadRaise m => Bag (CKind, CKind, Range) -> m (Map Int CKind) solveKindVars cs = do let (asg, errs) = solveConstraints @@ -610,7 +642,7 @@ solveKindVars cs = do kindSize (KFun () a b) = 1 + kindSize a + kindSize b kindSize (KExt () KUniVar{}) = 2 -solveTypeVars :: Bag (CType, CType, Range) -> TCM (Map Int CType) +solveTypeVars :: MonadRaise m => Bag (CType, CType, Range) -> m (Map Int CType) solveTypeVars cs = do let (asg, errs) = solveConstraints @@ -670,41 +702,43 @@ partitionConstrs = foldMap $ \case CEq t1 t2 r -> (pure (t1, t2, r), mempty) -------------------- SUBSTITUTION FUNCTIONS -------------------- -- These take some of: --- - an instantiation map for kind unification variables (Map Int {C,T}Kind) --- - an instantiation map for type unification variables (Map Int {C,T}Type) +-- - an instantiation map for kind unification variables (Map Int CKind) +-- - an instantiation map for type unification variables (Map Int CType) -- - an instantiation map for type variables (Map Name CType) -substFinProg :: HasCallStack - => Map Int TKind -> Map Int TType -> CProgram -> TProgram -substFinProg mk mt (Program ds fs) = Program (map (substFinDdef mk mt) ds) (map (substFinFdef mk mt) fs) - -substFinDdef :: HasCallStack - => Map Int TKind -> Map Int TType -> CDataDef -> TDataDef -substFinDdef mk mt (DataDef k n ps cs) = - DataDef (substFinKind mk k) n (map (first (substFinKind mk)) ps) (map (second (map (substFinType mk mt))) cs) - -substFinFdef :: HasCallStack - => Map Int TKind -> Map Int TType -> CFunDef -> TFunDef -substFinFdef mk mt (FunDef t n (TypeSig sig) eqs) = - FunDef (substFinType mk mt t) n - (TypeSig (substFinType mk mt sig)) - (fmap (substFinFunEq mk mt) eqs) - -substFinFunEq :: HasCallStack - => Map Int TKind -> Map Int TType -> CFunEq -> TFunEq -substFinFunEq mk mt (FunEq () n ps rhs) = +substProg :: HasCallStack + => Map Int CKind -> Map Int CType -> CProgram -> CProgram +substProg mk mt (Program ds fs) = Program (map (substDdef mk mt) ds) (map (substFdef mk mt) fs) + +substDdef :: HasCallStack + => Map Int CKind -> Map Int CType -> CDataDef -> CDataDef +substDdef mk mt (DataDef k name pars cons) = + DataDef (substKind mk k) name + (map (first (substKind mk)) pars) + (map (second (map (substType mk mt mempty))) cons) + +substFdef :: HasCallStack + => Map Int CKind -> Map Int CType -> CFunDef -> CFunDef +substFdef mk mt (FunDef t n (TypeSig sig) eqs) = + FunDef (substType mk mt mempty t) n + (TypeSig (substType mk mt mempty sig)) + (fmap (substFunEq mk mt) eqs) + +substFunEq :: HasCallStack + => Map Int CKind -> Map Int CType -> CFunEq -> CFunEq +substFunEq mk mt (FunEq () n ps rhs) = FunEq () n - (map (substFinPattern mk mt) ps) - (substFinRHS mk mt rhs) + (map (substPattern mk mt) ps) + (substRHS mk mt rhs) -substFinRHS :: HasCallStack - => Map Int TKind -> Map Int TType -> CRHS -> TRHS -substFinRHS _ _ (Guarded _ _) = error "typecheck: guards unsupported" -substFinRHS mk mt (Plain t e) = Plain (substFinType mk mt t) (substFinExpr mk mt e) +substRHS :: HasCallStack + => Map Int CKind -> Map Int CType -> CRHS -> CRHS +substRHS _ _ (Guarded _ _) = error "typecheck: guards unsupported" +substRHS mk mt (Plain t e) = Plain (substType mk mt mempty t) (substExpr mk mt e) -substFinPattern :: HasCallStack - => Map Int TKind -> Map Int TType -> CPattern -> TPattern -substFinPattern mk mt = go +substPattern :: HasCallStack + => Map Int CKind -> Map Int CType -> CPattern -> CPattern +substPattern mk mt = go where go (PWildcard t) = PWildcard (goType t) go (PVar t n) = PVar (goType t) n @@ -714,11 +748,11 @@ substFinPattern mk mt = go go (PList t ps) = PList (goType t) (map go ps) go (PTup t ps) = PTup (goType t) (map go ps) - goType = substFinType mk mt + goType = substType mk mt mempty -substFinExpr :: HasCallStack - => Map Int TKind -> Map Int TType -> CExpr -> TExpr -substFinExpr mk mt = go +substExpr :: HasCallStack + => Map Int CKind -> Map Int CType -> CExpr -> CExpr +substExpr mk mt = go where go (ELit t lit) = ELit (goType t) lit go (EVar t n) = EVar (goType t) n @@ -728,42 +762,14 @@ substFinExpr mk mt = go go (EApp t e1 es) = EApp (goType t) (go e1) (map go es) go (EOp t e1 op e2) = EOp (goType t) (go e1) op (go e2) go (EIf t e1 e2 e3) = EIf (goType t) (go e1) (go e2) (go e3) - go (ECase t e1 alts) = ECase (goType t) (go e1) (map (bimap (substFinPattern mk mt) (substFinRHS mk mt)) alts) - go (ELet t defs body) = ELet (goType t) (map (substFinFdef mk mt) defs) (go body) + go (ECase t e1 alts) = ECase (goType t) (go e1) (map (bimap (substPattern mk mt) (substRHS mk mt)) alts) + go (ELet t defs body) = ELet (goType t) (map (substFdef mk mt) defs) (go body) go (EError t) = EError (goType t) - goType = substFinType mk mt - -substFinType :: HasCallStack - => Map Int TKind -- ^ kind uvars - -> Map Int TType -- ^ type uvars - -> CType -> TType -substFinType mk mt = go - where - go (TApp k t ts) = TApp (substFinKind mk k) (go t) (map go ts) - go (TTup k ts) = TTup (substFinKind mk k) (map go ts) - go (TList k t) = TList (substFinKind mk k) (go t) - go (TFun k t1 t2) = TFun (substFinKind mk k) (go t1) (go t2) - go (TCon k n) = TCon (substFinKind mk k) n - go (TVar k n) = TVar (substFinKind mk k) n - go (TForall k n t) = TForall (substFinKind mk k) n (go t) - go t@(TExt _ (TUniVar v)) = fromMaybe (error $ "substFinType: unification variables left: " ++ show t) - (Map.lookup v mt) - -substFinKind :: HasCallStack => Map Int TKind -> CKind -> TKind -substFinKind m = \case - KType () -> KType () - KFun () k1 k2 -> KFun () (substFinKind m k1) (substFinKind m k2) - k@(KExt () (KUniVar v)) -> fromMaybe (error $ "substFinKind: unification variables left: " ++ show k) - (Map.lookup v m) + goType = substType mk mt mempty -substDdef :: Map Int CKind -> Map Int CType -> CDataDef -> CDataDef -substDdef mk mt (DataDef k name pars cons) = - DataDef (substKind mk k) name - (map (first (substKind mk)) pars) - (map (second (map (substType mk mt mempty))) cons) - -substType :: Map Int CKind -> Map Int CType -> Map Name CType -> CType -> CType +substType :: HasCallStack + => Map Int CKind -> Map Int CType -> Map Name CType -> CType -> CType substType mk mt mtv = go where go (TApp k t ts) = TApp (goKind k) (go t) (map go ts) @@ -777,13 +783,17 @@ substType mk mt mtv = go goKind = substKind mk -substKind :: Map Int CKind -> CKind -> CKind +substKind :: HasCallStack + => Map Int CKind -> CKind -> CKind substKind m = \case KType () -> KType () KFun () k1 k2 -> KFun () (substKind m k1) (substKind m k2) k@(KExt () (KUniVar v)) -> fromMaybe k (Map.lookup v m) --------------------- END OF SUBSTITUTION FUNCTIONS -------------------- +-------------------- FINALISATION FUNCTIONS -------------------- +-- These report free type unification variables. + +-- TODO the finalise* functions typeUniVars :: CType -> Set Int typeUniVars = \case |