aboutsummaryrefslogtreecommitdiff
path: root/src/HSVIS
diff options
context:
space:
mode:
Diffstat (limited to 'src/HSVIS')
-rw-r--r--src/HSVIS/AST.hs38
-rw-r--r--src/HSVIS/Diagnostic.hs23
-rw-r--r--src/HSVIS/Parser.hs10
-rw-r--r--src/HSVIS/Typecheck.hs83
-rw-r--r--src/HSVIS/Typecheck/Solve.hs131
5 files changed, 197 insertions, 88 deletions
diff --git a/src/HSVIS/AST.hs b/src/HSVIS/AST.hs
index 8bb2d6c..2b125b9 100644
--- a/src/HSVIS/AST.hs
+++ b/src/HSVIS/AST.hs
@@ -11,6 +11,7 @@ module HSVIS.AST where
import Data.Bifunctor (bimap, first, second)
import qualified Data.Kind as DK
+import Data.List (intersperse)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
@@ -136,13 +137,36 @@ data Operator = OAdd | OSub | OMul | ODiv | OMod | OEqu | OPow
deriving (Show)
instance Pretty Name where
- prettysPrec _ (Name n) = showString ("\"" ++ n ++ "\"")
-
-instance (X Kind s ~ (), Pretty (E Kind s)) => Pretty (Kind s) where
- prettysPrec _ (KType ()) = showString "Type"
- prettysPrec d (KFun () a b) =
- showParen (d > -1) $ prettysPrec 0 a . showString " -> " . prettysPrec 0 b
- prettysPrec d (KExt () e) = prettysPrec d e
+ prettysPrec _ (Name n) = showString n
+
+instance Pretty (E Kind s) => Pretty (Kind s) where
+ prettysPrec _ (KType _) = showString "Type"
+ prettysPrec d (KFun _ a b) = showParen (d > -1) $
+ prettysPrec 0 a . showString " -> " . prettysPrec (-1) b
+ prettysPrec d (KExt _ e) = prettysPrec d e
+
+instance Pretty (E Type s) => Pretty (Type s) where
+ prettysPrec d (TApp _ a ts) = showParen (d > 10) $
+ prettysPrec 10 a . foldr (.) id [showString " " . prettysPrec 11 t | t <- ts]
+ prettysPrec _ (TTup _ ts) =
+ showString "(" . foldr (.) id (intersperse (showString ",") (map (prettysPrec 0) ts)) . showString ")"
+ prettysPrec _ (TList _ t) =
+ showString "[" . prettysPrec 0 t . showString "]"
+ prettysPrec d (TFun _ a b) = showParen (d > -1) $
+ prettysPrec 0 a . showString " -> " . prettysPrec (-1) b
+ prettysPrec _ (TCon _ n) = prettysPrec 11 n
+ prettysPrec _ (TVar _ n) = prettysPrec 11 n
+ prettysPrec d (TExt _ e) = prettysPrec d e
+
+instance (Pretty (X Type s), Pretty (E Type s)) => Pretty (DataDef s) where
+ prettysPrec _ (DataDef _ name vars cons) =
+ showString "data " . prettysPrec 11 name
+ . foldr (.) id [showString " (" . prettysPrec 11 n . showString " :: " . prettysPrec 11 k . showString ")"
+ | (k, n) <- vars]
+ . showString " = "
+ . foldr (.) id (intersperse (showString " | ")
+ [prettysPrec 11 cname . foldr (.) id [showString " " . prettysPrec 11 t | t <- fields]
+ | (cname, fields) <- cons])
instance HasExt DataDef where
type HasXField DataDef = 'True
diff --git a/src/HSVIS/Diagnostic.hs b/src/HSVIS/Diagnostic.hs
index 322f9eb..675482d 100644
--- a/src/HSVIS/Diagnostic.hs
+++ b/src/HSVIS/Diagnostic.hs
@@ -2,6 +2,8 @@ module HSVIS.Diagnostic where
import Data.List (intercalate)
+import HSVIS.Pretty
+
data Pos = Pos
{ posLine :: Int -- ^ zero-based
@@ -9,6 +11,9 @@ data Pos = Pos
}
deriving (Show, Eq, Ord)
+instance Pretty Pos where
+ prettysPrec _ (Pos y x) = showString (show (y + 1) ++ ":" ++ show (x + 1))
+
-- | Inclusive-exclusive range of positions in a file.
data Range = Range Pos Pos
deriving (Show)
@@ -16,8 +21,15 @@ data Range = Range Pos Pos
instance Semigroup Range where
Range a b <> Range c d = Range (min a c) (max b d)
-data Diagnostic = Diagnostic
- { dFile :: FilePath -- ^ The file for which the diagnostic was raised
+instance Pretty Range where
+ prettysPrec _ (Range (Pos y1 x1) (Pos y2 x2))
+ | y2 <= y1 + 1, x2 <= x1 + 1 = showString (show (y1 + 1) ++ ":" ++ show (x1 + 1))
+ | y2 <= y1 + 1 = showString (show (y1 + 1) ++ ":" ++ show (x1 + 1) ++ "-" ++ show x2)
+ | otherwise =
+ showString ("(" ++ show (y1 + 1) ++ ":" ++ show (x1 + 1) ++ ")-(" ++ show (y2 + 1) ++ ":" ++ show x2 ++ ")")
+
+data Diagnostic = Diagnostic
+ { dFile :: FilePath -- ^ The file for which the diagnostic was rai sed
, dRange :: Range -- ^ Where in the file
, dStk :: [String] -- ^ Stack of contexts (innermost at head) of the diagnostic
, dSourceLine :: String -- ^ The line in the source file of the start of the range
@@ -26,12 +38,9 @@ data Diagnostic = Diagnostic
deriving (Show)
printDiagnostic :: Diagnostic -> String
-printDiagnostic (Diagnostic fp (Range (Pos y1 x1) (Pos y2 x2)) stk srcline msg) =
+printDiagnostic (Diagnostic fp rng@(Range (Pos y1 x1) (Pos y2 x2)) stk srcline msg) =
let linenum = show (y1 + 1)
- locstr | y1 == y2, x1 == x2 = show y1 ++ ":" ++ show x1
- | y1 == y2 = show y1 ++ ":" ++ show x1 ++ "-" ++ show x2
- | otherwise = "(" ++ show y1 ++ ":" ++ show x1 ++ ")-(" ++
- show y1 ++ ":" ++ show x1 ++ ")"
+ locstr = pretty rng
ncarets | y1 == y2 = max 1 (x2 - x1 + 1)
| otherwise = length srcline - x1
caretsuffix | y1 == y2 = ""
diff --git a/src/HSVIS/Parser.hs b/src/HSVIS/Parser.hs
index 0df4aa8..b4d8754 100644
--- a/src/HSVIS/Parser.hs
+++ b/src/HSVIS/Parser.hs
@@ -672,9 +672,13 @@ pType = do
return (TFun (Range pos1 pos2) ty1 ty2)
pTypeApp :: FParser PType
-pTypeApp = fasome pTypeAtom >>= \case
- t :| [] -> return t
- t :| ts -> return (TApp (foldMapne extOf (t :| ts)) t ts)
+pTypeApp = do
+ pos1 <- gets psCur
+ ts <- fasome pTypeAtom
+ pos2 <- gets psCur
+ case ts of
+ t :| [] -> return t
+ t :| ts' -> return (TApp (Range pos1 pos2) t ts')
pTypeAtom :: FParser PType
pTypeAtom = faasum' [pTypeParens, pTypeList, pTypeName]
diff --git a/src/HSVIS/Typecheck.hs b/src/HSVIS/Typecheck.hs
index ba853a0..c97064a 100644
--- a/src/HSVIS/Typecheck.hs
+++ b/src/HSVIS/Typecheck.hs
@@ -78,6 +78,9 @@ type TPattern = Pattern StageTyped
type TRHS = RHS StageTyped
type TExpr = Expr StageTyped
+instance Pretty (E Type StageTC) where
+ prettysPrec _ (TUniVar n) = showString ("?t" ++ show n)
+
instance Pretty (E Kind StageTC) where
prettysPrec _ (KUniVar n) = showString ("?k" ++ show n)
@@ -144,7 +147,7 @@ putFullEnv :: Env -> TCM ()
putFullEnv env = TCM $ \_ i _ -> (mempty, mempty, i, env, ())
genId :: TCM Int
-genId = TCM $ \_ i env -> (mempty, mempty, i, env, i)
+genId = TCM $ \_ i env -> (mempty, mempty, i + 1, env, i)
getKind :: Name -> TCM (Maybe CKind)
getKind name = do
@@ -208,6 +211,8 @@ tcProgram (Program ddefs fdefs) = do
solveKindVars kconstrs
+ traceM (unlines (map pretty ddefs'))
+
fdefs' <- mapM tcFunDef fdefs
return (Program ddefs' fdefs')
@@ -232,40 +237,59 @@ tcDataDef (DataDef rng name params cons) = do
cons' <- scopeTEnv $ do
modifyTEnv (Map.fromList (zip (map snd params) pkinds) <>)
- mapM (\(cname, ty) -> (cname,) <$> mapM kcType ty) cons
+ mapM (\(cname, fieldtys) -> (cname,) <$> mapM (kcType (Just (KType ()))) fieldtys) cons
return (DataDef () name (zip pkinds (map snd params)) cons')
-kcType :: PType -> TCM CType
-kcType = \case
+promoteDown :: Maybe CKind -> TCM CKind
+promoteDown Nothing = genKUniVar
+promoteDown (Just k) = return k
+
+downEqK :: Range -> Maybe CKind -> CKind -> TCM ()
+downEqK _ Nothing _ = return ()
+downEqK rng (Just k1) k2 = emit $ CEqK k1 k2 rng
+
+-- | Given (maybe) the expected kind of this type, and a type, check it for
+-- kind-correctness.
+kcType :: Maybe CKind -> PType -> TCM CType
+kcType mdown = \case
TApp rng t ts -> do
- t' <- kcType t
- ts' <- mapM kcType ts
- retk <- genKUniVar
+ t' <- kcType Nothing t
+ ts' <- mapM (kcType Nothing) ts
+ retk <- promoteDown mdown
let expected = foldr (KFun ()) retk (map extOf ts')
emit $ CEqK (extOf t') expected rng
return (TApp retk t' ts')
- TTup _ ts -> do
- ts' <- mapM kcType ts
+ TTup rng ts -> do
+ ts' <- mapM (kcType (Just (KType ()))) ts
forM_ (zip (map extOf ts) ts') $ \(trng, ct) ->
emit $ CEqK (extOf ct) (KType ()) trng
+ downEqK rng mdown (KType ())
return (TTup (KType ()) ts')
- TList _ t -> do
- t' <- kcType t
+ TList rng t -> do
+ t' <- kcType (Just (KType ())) t
emit $ CEqK (extOf t') (KType ()) (extOf t)
+ downEqK rng mdown (KType ())
return (TList (KType ()) t')
- TFun _ t1 t2 -> do
- t1' <- kcType t1
- t2' <- kcType t2
+ TFun rng t1 t2 -> do
+ t1' <- kcType (Just (KType ())) t1
+ t2' <- kcType (Just (KType ())) t2
emit $ CEqK (extOf t1') (KType ()) (extOf t1)
emit $ CEqK (extOf t2') (KType ()) (extOf t2)
+ downEqK rng mdown (KType ())
return (TFun (KType ()) t1' t2')
- TCon rng n -> TCon <$> getKind' rng n <*> pure n
+ TCon rng n -> do
+ k <- getKind' rng n
+ downEqK rng mdown k
+ return (TCon k n)
- TVar rng n -> TVar <$> getKind' rng n <*> pure n
+ TVar rng n -> do
+ k <- getKind' rng n
+ downEqK rng mdown k
+ return (TVar k n)
tcFunDef :: PFunDef -> TCM CFunDef
tcFunDef (FunDef _ name msig eqs) = do
@@ -273,7 +297,7 @@ tcFunDef (FunDef _ name msig eqs) = do
raise (sconcatne (fmap extOf eqs)) "Function equations have differing numbers of arguments"
typ <- case msig of
- TypeSig sig -> kcType sig
+ TypeSig sig -> kcType (Just (KType ())) sig
TypeSigExt NoTypeSig -> genUniVar (KType ())
eqs' <- mapM (tcFunEq typ) eqs
@@ -326,18 +350,21 @@ solveKindVars cs = do
(\case KExt () (KUniVar v) -> Just v
_ -> Nothing)
kindSize
- (map (\(a, b, _) -> (a, b)) (toList cs))
+ (toList cs)
where
- reduce :: CKind -> CKind -> (Bag (Int, CKind), Bag (CKind, CKind))
- -- unification variables produce constraints on a unification variable
- reduce (KExt () (KUniVar i)) (KExt () (KUniVar j)) | i == j = mempty
- reduce (KExt () (KUniVar i)) k = (pure (i, k), mempty)
- reduce k (KExt () (KUniVar i)) = (pure (i, k), mempty)
- -- if lhs and rhs have equal prefixes, recurse
- reduce (KType ()) (KType ()) = mempty
- reduce (KFun () a b) (KFun () c d) = reduce a c <> reduce b d
- -- otherwise, this is a kind mismatch
- reduce k1 k2 = (mempty, pure (k1, k2))
+ reduce :: CKind -> CKind -> Range -> (Bag (Int, CKind, Range), Bag (CKind, CKind, Range))
+ reduce lhs rhs rng = case (lhs, rhs) of
+ -- unification variables produce constraints on a unification variable
+ (KExt () (KUniVar i), KExt () (KUniVar j)) | i == j -> mempty
+ (KExt () (KUniVar i), k ) -> (pure (i, k, rng), mempty)
+ (k , KExt () (KUniVar i)) -> (pure (i, k, rng), mempty)
+
+ -- if lhs and rhs have equal prefixes, recurse
+ (KType () , KType () ) -> mempty
+ (KFun () a b, KFun () c d) -> reduce a c rng <> reduce b d rng
+
+ -- otherwise, this is a kind mismatch
+ (k1, k2) -> (mempty, pure (k1, k2, rng))
kindSize :: CKind -> Int
kindSize KType{} = 1
diff --git a/src/HSVIS/Typecheck/Solve.hs b/src/HSVIS/Typecheck/Solve.hs
index 184937c..5f51abe 100644
--- a/src/HSVIS/Typecheck/Solve.hs
+++ b/src/HSVIS/Typecheck/Solve.hs
@@ -1,10 +1,15 @@
+{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE ScopedTypeVariables #-}
-module HSVIS.Typecheck.Solve where
+{-# LANGUAGE TypeApplications #-}
+module HSVIS.Typecheck.Solve (
+ solveConstraints,
+ UnifyErr(..),
+) where
import Control.Monad (guard, (>=>))
-import Data.Bifunctor (second)
+import Data.Bifunctor (Bifunctor(..))
import Data.Foldable (toList, foldl')
-import Data.List (sort)
+import Data.List (sortBy, minimumBy, groupBy)
import Data.Ord (comparing)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
@@ -12,26 +17,53 @@ import qualified Data.Map.Strict as Map
import Debug.Trace
import Data.Bag
-import Data.List (minimumBy)
+import HSVIS.Diagnostic (Range(..))
+import HSVIS.Pretty
+import Data.Function (on)
data UnifyErr v t
- = UEUnequal t t
- | UERecursive v t
+ = UEUnequal t t Range
+ | UERecursive v t Range
deriving (Show)
+data Constr a b = Constr a b Range
+ deriving (Show)
+
+instance Bifunctor Constr where
+ bimap f g (Constr x y r) = Constr (f x) (g y) r
+
+-- right-hand side of a constraint
+data RConstr b = RConstr b Range
+ deriving (Show, Functor)
+
+splitConstr :: Constr a b -> (a, RConstr b)
+splitConstr (Constr x y r) = (x, RConstr y r)
+
+unsplitConstr :: a -> RConstr b -> Constr a b
+unsplitConstr x (RConstr y r) = Constr x y r
+
+constrUnequal :: Constr t t -> UnifyErr v t
+constrUnequal (Constr x y r) = UEUnequal x y r
+
+constrRecursive :: Constr v t -> UnifyErr v t
+constrRecursive (Constr x y r) = UERecursive x y r
+
+rconType :: RConstr b -> b
+rconType (RConstr t _) = t
+
-- | Returns a pair of:
-- 1. A set of unification errors;
-- 2. An assignment of the variables that had any constraints on them.
-- The producedure was successful if the set of errors is empty. Note that
-- unconstrained variables do not appear in the output.
solveConstraints
- :: forall v t. (Ord v, Ord t, Show v, Show t)
+ :: forall v t. (Ord v, Ord t, Show v, Pretty t)
-- | reduce: take two types and unify them, resulting in:
-- 1. A bag of resulting constraints on variables;
-- 2. A bag of errors: pairs of two types that are provably distinct but
-- need to be equal for the input types to unify.
- => (t -> t -> (Bag (v, t), Bag (t, t)))
+ => (t -> t -> Range -> (Bag (v, t, Range), Bag (t, t, Range)))
-- | Free variables in a type
-> (t -> Bag v)
-- | \v repl term -> Substitute v by repl in term
@@ -41,63 +73,76 @@ solveConstraints
-- | Some kind of size measure on types
-> (t -> Int)
-- | Equality constraints to solve
- -> [(t, t)]
+ -> [(t, t, Range)]
-> (Map v t, Bag (UnifyErr v t))
-solveConstraints reduce frees subst detect size = \cs ->
- let (vcs, errs) = foldMap (uncurry reduce) cs
- asg = Map.fromListWith (<>) (map (second pure) (toList vcs))
+solveConstraints reduce frees subst detect size = \tupcs ->
+ let cs = map (uncurry3 Constr) tupcs :: [Constr t t]
+ (vcs, errs) = foldMap reduce' cs
+ asg = Map.fromListWith (<>) (map (second pure . splitConstr) (toList vcs))
(errs', asg') = loop asg []
- errs'' = fmap (uncurry UEUnequal) errs <> errs'
- in trace ("[solver] Solving:" ++ concat ["\n- " ++ show a ++ " == " ++ show b | (a, b) <- cs]) $
+ errs'' = fmap constrUnequal errs <> errs'
+ in trace ("[solver] Solving:" ++ concat ["\n- " ++ pretty a ++ " == " ++ pretty b ++ " {" ++ pretty r ++ "}" | Constr a b r <- cs]) $
trace ("[solver] Result: (with " ++ show (length errs'') ++ " errors)" ++
- concat ["\n- " ++ show v ++ " = " ++ show t | (v, t) <- Map.assocs asg'])
+ concat ["\n- " ++ show v ++ " = " ++ pretty t | (v, t) <- Map.assocs asg'])
(asg', errs'')
where
- loop :: Map v (Bag t) -> [(v, t)] -> (Bag (UnifyErr v t), Map v t)
+ reduce' :: Constr t t -> (Bag (Constr v t), Bag (Constr t t))
+ reduce' (Constr t1 t2 r) = bimap (fmap (uncurry3 Constr)) (fmap (uncurry3 Constr)) $ reduce t1 t2 r
+
+ loop :: Map v (Bag (RConstr t)) -> [(v, t)] -> (Bag (UnifyErr v t), Map v t)
loop m eqlog = do
- traceM $ "[solver] Step:" ++ concat ["\n- " ++ show v ++ " == " ++ show t | (v, t) <- Map.assocs m]
+ traceM $ "[solver] Step:" ++ concat
+ [case toList rhss of
+ [] -> "\n- " ++ show v ++ " <no RHSs>"
+ RConstr t r : rest ->
+ "\n- " ++ show v ++ " == " ++ pretty t ++ " {" ++ pretty r ++ "}" ++
+ concat ["\n " ++ replicate (length (show v)) ' ' ++ " == " ++ pretty t' ++ " {" ++ pretty r' ++ "}"
+ | RConstr t' r' <- rest]
+ | (v, rhss) <- Map.assocs m]
+
m' <- Map.traverseWithKey
- (\v ts ->
- let ts' = bagFromList (dedup (toList ts))
+ (\v rhss ->
+ let rhss' = bagFromList (dedupRCs (toList rhss))
-- filter out recursive equations
- (recs, nonrecs) = bagPartition (\t -> if v `elem` frees t then Just t else Nothing) ts'
+ (recs, nonrecs) = bagPartition (\c@(RConstr t _) -> if v `elem` frees t then Just c else Nothing) rhss'
-- filter out trivial equations (v = v)
- (_, nonrecs') = bagPartition (detect >=> guard . (== v)) nonrecs
- in (UERecursive v <$> recs, nonrecs'))
+ (_, nonrecs') = bagPartition (detect . rconType >=> guard . (== v)) nonrecs
+ in (constrRecursive . unsplitConstr v <$> recs, nonrecs'))
m
- let msmallestvar = -- var with its smallest RHS, if such a var exists
- minimumByMay (comparing (size . snd))
- . map (second (minimumBy (comparing size)))
+ let msmallestvar :: Maybe (v, RConstr t) -- var with its smallest RHS, if such a var exists
+ msmallestvar =
+ minimumByMay (comparing (size . rconType . snd))
+ . map (second (minimumBy (comparing (size . rconType))))
. filter (not . null . snd)
$ Map.assocs m'
case msmallestvar of
Nothing -> return $ applyLog eqlog mempty
- Just (var, smallrhs) -> do
- let (_, otherrhss) = bagPartition (guard . (== smallrhs)) (m' Map.! var)
- let (newcs, errs) = foldMap (reduce smallrhs) (dedup (toList otherrhss))
- (fmap (uncurry UEUnequal) errs, ()) -- write the errors
+ Just (var, RConstr smallrhs _) -> do
+ let (_, otherrhss) = bagPartition (guard . (== smallrhs) . rconType) (m' Map.! var)
+ let (newcs, errs) = foldMap (reduce' . unsplitConstr smallrhs) (dedupRCs (toList otherrhss))
+ (fmap constrUnequal errs, ()) -- write the errors
let m'' = Map.unionWith (<>)
- (Map.map (fmap (subst var smallrhs)) (Map.delete var m'))
- (Map.fromListWith (<>) (map (second pure) (toList newcs)))
+ (Map.map (fmap @Bag (fmap @RConstr (subst var smallrhs)))
+ (Map.delete var m'))
+ (Map.fromListWith (<>) (map (second pure . splitConstr) (toList newcs)))
loop m'' ((var, smallrhs) : eqlog)
applyLog :: [(v, t)] -> Map v t -> Map v t
applyLog ((v, t) : l) m = applyLog l $ Map.insert v t (Map.map (subst v t) m)
applyLog [] m = m
- dedup :: Ord t => [t] -> [t]
- dedup = uniq . sort
+ -- If there are multiple sources for the same cosntraint, only one of them is kept.
+ dedupRCs :: Ord t => [RConstr t] -> [RConstr t]
+ dedupRCs = map head . groupBy ((==) `on` rconType) . sortBy (comparing rconType)
- uniq :: Eq a => [a] -> [a]
- uniq (x:y:xs) | x == y = uniq (x : xs)
- | otherwise = x : uniq (y : xs)
- uniq l = l
+minimumByMay :: Foldable t' => (a -> a -> Ordering) -> t' a -> Maybe a
+minimumByMay cmp = foldl' min' Nothing
+ where min' mx y = Just $! case mx of
+ Nothing -> y
+ Just x | GT <- cmp x y -> y
+ | otherwise -> x
- minimumByMay :: Foldable t' => (a -> a -> Ordering) -> t' a -> Maybe a
- minimumByMay cmp = foldl' min' Nothing
- where min' mx y = Just $! case mx of
- Nothing -> y
- Just x | GT <- cmp x y -> y
- | otherwise -> x
+uncurry3 :: (a -> b -> c -> d) -> (a, b, c) -> d
+uncurry3 f (x, y, z) = f x y z