aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-03-14 23:21:53 +0100
committerTom Smeding <tom@tomsmeding.com>2024-03-14 23:21:53 +0100
commite7bed242ba52e6d3233928f2c6189e701cfa5e4c (patch)
tree4bdda2b7bc702c87d97f89946362e6b719126831 /src
parente8f09ff3f9d40922238d646c8fbcbacf9cfdfb62 (diff)
Some typechecker work
Diffstat (limited to 'src')
-rw-r--r--src/Data/Bag.hs65
-rw-r--r--src/HSVIS/AST.hs11
-rw-r--r--src/HSVIS/Parser.hs12
-rw-r--r--src/HSVIS/Typecheck.hs112
-rw-r--r--src/HSVIS/Typecheck/Solve.hs103
5 files changed, 267 insertions, 36 deletions
diff --git a/src/Data/Bag.hs b/src/Data/Bag.hs
index f1173e4..ba1f912 100644
--- a/src/Data/Bag.hs
+++ b/src/Data/Bag.hs
@@ -1,19 +1,41 @@
{-# LANGUAGE DeriveTraversable #-}
-module Data.Bag where
+module Data.Bag (
+ Bag,
+ bagFromList,
+ bagFilter,
+ bagPartition,
+) where
+
+import Data.Bifunctor (bimap)
+import Data.Foldable (toList)
+import Data.List.NonEmpty (NonEmpty((:|)))
+import Data.Maybe (mapMaybe)
+import Data.Semigroup
data Bag a
= BZero
| BOne a
| BTwo (Bag a) (Bag a)
- deriving (Functor, Foldable, Traversable)
+ | BList [Bag a] -- make mconcat efficient
+ | BList' [a] -- make bagFromList efficient
+ deriving (Functor, Traversable)
+
+instance Show a => Show (Bag a) where
+ showsPrec d b = showParen (d > 10) $
+ showString "Bag " . showList (toList b)
instance Semigroup (Bag a) where
BZero <> b = b
b <> BZero = b
b1 <> b2 = BTwo b1 b2
-instance Monoid (Bag a) where mempty = BZero
+ sconcat (b :| bs) = b <> mconcat bs
+ stimes n b = mconcat (replicate (fromIntegral n) b)
+
+instance Monoid (Bag a) where
+ mempty = BZero
+ mconcat = BList
instance Applicative Bag where
pure = BOne
@@ -22,6 +44,36 @@ instance Applicative Bag where
_ <*> BZero = BZero
BOne f <*> b = f <$> b
BTwo b1 b2 <*> b = BTwo (b1 <*> b) (b2 <*> b)
+ BList bs <*> b = BList (map (<*> b) bs)
+ BList' xs <*> b = BList (map BOne xs) <*> b
+
+instance Foldable Bag where
+ foldMap _ BZero = mempty
+ foldMap f (BOne x) = f x
+ foldMap f (BTwo b1 b2) = foldMap f b1 <> foldMap f b2
+ foldMap f (BList l) = foldMap (foldMap f) l
+ foldMap f (BList' l) = foldMap f l
+
+ toList (BList' xs) = xs
+ toList b = foldr (:) [] b
+
+ null BZero = True
+ null BOne{} = False
+ null (BTwo b1 b2) = null b1 && null b2
+ null (BList l) = all null l
+ null (BList' l) = null l
+
+bagFromList :: [a] -> Bag a
+bagFromList = BList'
+
+bagFilter :: (a -> Maybe b) -> Bag a -> Bag b
+bagFilter _ BZero = BZero
+bagFilter f (BOne x)
+ | Just y <- f x = BOne y
+ | otherwise = BZero
+bagFilter f (BTwo b1 b2) = bagFilter f b1 <> bagFilter f b2
+bagFilter f (BList bs) = BList (map (bagFilter f) bs)
+bagFilter f (BList' xs) = BList' (mapMaybe f xs)
bagPartition :: (a -> Maybe b) -> Bag a -> (Bag b, Bag a)
bagPartition _ BZero = (BZero, BZero)
@@ -29,3 +81,10 @@ bagPartition f (BOne x)
| Just y <- f x = (BOne y, BZero)
| otherwise = (BZero, BOne x)
bagPartition f (BTwo b1 b2) = bagPartition f b1 <> bagPartition f b2
+bagPartition f (BList bs) = foldMap (bagPartition f) bs
+bagPartition f (BList' xs) =
+ bimap bagFromList bagFromList $
+ foldr (\x (l1,l2) -> case f x of
+ Just y -> (y : l1, l2)
+ Nothing -> (l1, x : l2))
+ ([], []) xs
diff --git a/src/HSVIS/AST.hs b/src/HSVIS/AST.hs
index f95a3cc..8bb2d6c 100644
--- a/src/HSVIS/AST.hs
+++ b/src/HSVIS/AST.hs
@@ -1,15 +1,10 @@
-{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE KindSignatures #-}
-{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ConstrainedClassMethods #-}
module HSVIS.AST where
@@ -63,7 +58,7 @@ type family EMapperTelescope fs s1 s2 a where
newtype Name = Name String
- deriving (Show, Eq)
+ deriving (Show, Eq, Ord)
data Program s = Program [DataDef s] [FunDef s]
deriving instance (Show (DataDef s), Show (FunDef s)) => Show (Program s)
@@ -89,6 +84,8 @@ data Kind s
-- extension point
| KExt (X Kind s) !(E Kind s)
deriving instance (Show (X Kind s), Show (E Kind s)) => Show (Kind s)
+deriving instance (Eq (X Kind s), Eq (E Kind s)) => Eq (Kind s)
+deriving instance (Ord (X Kind s), Ord (E Kind s)) => Ord (Kind s)
data Type s
= TApp (X Type s) (Type s) [Type s]
diff --git a/src/HSVIS/Parser.hs b/src/HSVIS/Parser.hs
index 3251989..0df4aa8 100644
--- a/src/HSVIS/Parser.hs
+++ b/src/HSVIS/Parser.hs
@@ -162,9 +162,9 @@ instance KnownFallible fail => MonadChronicle (Bag Diagnostic) (Parser fail) whe
(kok ps mempty def)
condemn (Parser f) = Parser $ \ctx ps kok kfat kbt ->
f ctx ps
- (\ps' errs x -> case errs of
- BZero -> kok ps' mempty x
- _ -> kfat errs)
+ (\ps' errs x -> if null errs
+ then kok ps' mempty x
+ else kfat errs)
kfat
kbt
retcon g (Parser f) = Parser $ \ctx ps kok kfat kbt ->
@@ -180,9 +180,9 @@ instance KnownFallible fail => MonadChronicle (Bag Diagnostic) (Parser fail) whe
parse :: FilePath -> String -> ([Diagnostic], Maybe PProgram)
parse fp source =
runParser pProgram (Context fp (lines source) []) (PS (Pos 0 0) (Pos 0 0) source)
- (\_ errs res -> case errs of
- BZero -> ([], Just res)
- _ -> (toList errs, Just res))
+ (\_ errs res -> if null errs
+ then ([], Just res)
+ else (toList errs, Just res))
(\errs -> (toList errs, Nothing))
() -- the program parser cannot fail! :D
diff --git a/src/HSVIS/Typecheck.hs b/src/HSVIS/Typecheck.hs
index de9d7db..ba853a0 100644
--- a/src/HSVIS/Typecheck.hs
+++ b/src/HSVIS/Typecheck.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE EmptyDataDeriving #-}
+{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
@@ -7,11 +8,16 @@
module HSVIS.Typecheck where
import Control.Monad
-import Data.Bifunctor (first, second)
+import Data.Bifunctor (first)
import Data.Foldable (toList)
import Data.Map.Strict (Map)
-import Data.Monoid (First(..))
+import Data.Maybe (fromMaybe)
+import Data.Monoid (Ap(..))
import qualified Data.Map.Strict as Map
+import Data.Set (Set)
+import qualified Data.Set as Set
+
+import Debug.Trace
import Data.Bag
import Data.List.NonEmpty.Util
@@ -19,6 +25,7 @@ import HSVIS.AST
import HSVIS.Parser
import HSVIS.Diagnostic
import HSVIS.Pretty
+import HSVIS.Typecheck.Solve
data StageTC
@@ -33,7 +40,8 @@ type instance X RHS StageTC = CType
type instance X Expr StageTC = CType
data instance E Type StageTC = TUniVar Int deriving (Show)
-data instance E Kind StageTC = KUniVar Int deriving (Show)
+data instance E Kind StageTC = KUniVar Int deriving (Show, Eq, Ord)
+data instance E TypeSig StageTC deriving (Show)
type CProgram = Program StageTC
type CDataDef = DataDef StageTC
@@ -58,6 +66,7 @@ type instance X Expr StageTyped = TType
data instance E Type StageTyped deriving (Show)
data instance E Kind StageTyped deriving (Show)
+data instance E TypeSig StageTyped deriving (Show)
type TProgram = Program StageTyped
type TDataDef = DataDef StageTyped
@@ -69,8 +78,11 @@ type TPattern = Pattern StageTyped
type TRHS = RHS StageTyped
type TExpr = Expr StageTyped
+instance Pretty (E Kind StageTC) where
+ prettysPrec _ (KUniVar n) = showString ("?k" ++ show n)
+
-typecheck :: FilePath -> String -> PProgram -> ([Diagnostic], Program TType)
+typecheck :: FilePath -> String -> PProgram -> ([Diagnostic], TProgram)
typecheck fp source prog =
let (ds1, cs, _, _, progtc) =
runTCM (tcProgram prog) (fp, source) 1 (Env mempty mempty)
@@ -269,29 +281,86 @@ tcFunDef (FunDef _ name msig eqs) = do
return (FunDef typ name (TypeSig typ) eqs')
tcFunEq :: CType -> PFunEq -> TCM CFunEq
-tcFunEq = _
+tcFunEq = error "tcFunEq"
+
+newtype SolveM v t m a = SolveM (Map v (Bag t) -> Map v t -> m (a, Map v (Bag t), Map v t))
+instance Monad m => Functor (SolveM v t m) where
+ fmap f (SolveM g) = SolveM $ \m r -> do (x, m', r') <- g m r
+ return (f x, m', r')
+instance Monad m => Applicative (SolveM v t m) where
+ pure x = SolveM $ \m r -> return (x, m, r)
+ (<*>) = ap
+instance Monad m => Monad (SolveM v t m) where
+ SolveM f >>= g = SolveM $ \m r -> do (x, m1, r1) <- f m r
+ let SolveM h = g x
+ h m1 r1
+
+solvemStateGet :: Monad m => SolveM v t m (Map v (Bag t))
+solvemStateGet = SolveM $ \m r -> return (m, m, r)
+
+solvemStateUpdate :: Monad m => (Map v (Bag t) -> Map v (Bag t)) -> SolveM v t m ()
+solvemStateUpdate f = SolveM $ \m r -> return ((), f m, r)
+
+solvemLogUpdate :: Monad m => (Map v t -> Map v t) -> SolveM v t m ()
+solvemLogUpdate f = SolveM $ \m r -> return ((), m, f r)
+
+solvemStateVars :: Monad m => SolveM v t m [v]
+solvemStateVars = Map.keys <$> solvemStateGet
+
+solvemStateRHS :: (Ord v, Monad m) => v -> SolveM v t m (Bag t)
+solvemStateRHS v = fromMaybe mempty . Map.lookup v <$> solvemStateGet
+
+solvemStateSet :: (Ord v, Monad m) => v -> Bag t -> SolveM v t m ()
+solvemStateSet v b = solvemStateUpdate (Map.insert v b)
+
+solvemLogEq :: (Ord v, Monad m) => v -> t -> SolveM v t m ()
+solvemLogEq v t = solvemLogUpdate (Map.insert v t)
solveKindVars :: Bag (CKind, CKind, Range) -> TCM ()
-solveKindVars =
- mapM_ $ \(a, b, rng) -> do
- let (subst, First merr) = reduce a b
- forM_ merr $ \(erra, errb) ->
- raise rng $
- "Kind mismatch:\n\
- \- Expected: " ++ pretty a ++ "\n\
- \- Observed: " ++ pretty b ++ "\n\
- \because '" ++ pretty erra ++ "' and '" ++ pretty errb ++ "' don't match"
- let collected :: [(Int, Bag CKind)]
- collected = Map.assocs $ Map.fromListWith (<>) (fmap (second pure) (toList subst))
- _
+solveKindVars cs = do
+ traceShowM cs
+ traceShowM $ solveConstraints
+ reduce
+ (foldMap pure . kindUniVars)
+ (\v repl -> substKind (Map.singleton v repl))
+ (\case KExt () (KUniVar v) -> Just v
+ _ -> Nothing)
+ kindSize
+ (map (\(a, b, _) -> (a, b)) (toList cs))
where
- reduce :: CKind -> CKind -> (Bag (Int, CKind), First (CKind, CKind))
- reduce (KType ()) (KType ()) = mempty
- reduce (KFun () a b) (KFun () c d) = reduce a c <> reduce b d
+ reduce :: CKind -> CKind -> (Bag (Int, CKind), Bag (CKind, CKind))
+ -- unification variables produce constraints on a unification variable
+ reduce (KExt () (KUniVar i)) (KExt () (KUniVar j)) | i == j = mempty
reduce (KExt () (KUniVar i)) k = (pure (i, k), mempty)
reduce k (KExt () (KUniVar i)) = (pure (i, k), mempty)
+ -- if lhs and rhs have equal prefixes, recurse
+ reduce (KType ()) (KType ()) = mempty
+ reduce (KFun () a b) (KFun () c d) = reduce a c <> reduce b d
+ -- otherwise, this is a kind mismatch
reduce k1 k2 = (mempty, pure (k1, k2))
+ kindSize :: CKind -> Int
+ kindSize KType{} = 1
+ kindSize (KFun () a b) = 1 + kindSize a + kindSize b
+ kindSize (KExt () KUniVar{}) = 1
+
+solveConstrs :: Bag Constr -> (Bag Diagnostic, Map Name TType)
+solveConstrs = error "solveConstrs"
+
+substProg :: Map Name TType -> CProgram -> TProgram
+substProg = error "substProg"
+
+substKind :: Map Int CKind -> CKind -> CKind
+substKind _ k@KType{} = k
+substKind m (KFun () k1 k2) = KFun () (substKind m k1) (substKind m k2)
+substKind m k@(KExt () (KUniVar v)) = fromMaybe k (Map.lookup v m)
+
+kindUniVars :: CKind -> Set Int
+kindUniVars = \case
+ KType{} -> mempty
+ KFun () a b -> kindUniVars a <> kindUniVars b
+ KExt () (KUniVar v) -> Set.singleton v
+
allEq :: (Eq a, Foldable t) => t a -> Bool
allEq l = case toList l of
[] -> True
@@ -308,3 +377,6 @@ splitKind (KExt _ e) = ([], Left e)
isCEqK :: Constr -> Maybe (CKind, CKind, Range)
isCEqK (CEqK k1 k2 rng) = Just (k1, k2, rng)
isCEqK _ = Nothing
+
+foldMapM :: (Applicative f, Monoid m, Foldable t) => (a -> f m) -> t a -> f m
+foldMapM f = getAp . foldMap (Ap . f)
diff --git a/src/HSVIS/Typecheck/Solve.hs b/src/HSVIS/Typecheck/Solve.hs
new file mode 100644
index 0000000..184937c
--- /dev/null
+++ b/src/HSVIS/Typecheck/Solve.hs
@@ -0,0 +1,103 @@
+{-# LANGUAGE ScopedTypeVariables #-}
+module HSVIS.Typecheck.Solve where
+
+import Control.Monad (guard, (>=>))
+import Data.Bifunctor (second)
+import Data.Foldable (toList, foldl')
+import Data.List (sort)
+import Data.Ord (comparing)
+import Data.Map.Strict (Map)
+import qualified Data.Map.Strict as Map
+
+import Debug.Trace
+
+import Data.Bag
+import Data.List (minimumBy)
+
+
+data UnifyErr v t
+ = UEUnequal t t
+ | UERecursive v t
+ deriving (Show)
+
+-- | Returns a pair of:
+-- 1. A set of unification errors;
+-- 2. An assignment of the variables that had any constraints on them.
+-- The producedure was successful if the set of errors is empty. Note that
+-- unconstrained variables do not appear in the output.
+solveConstraints
+ :: forall v t. (Ord v, Ord t, Show v, Show t)
+ -- | reduce: take two types and unify them, resulting in:
+ -- 1. A bag of resulting constraints on variables;
+ -- 2. A bag of errors: pairs of two types that are provably distinct but
+ -- need to be equal for the input types to unify.
+ => (t -> t -> (Bag (v, t), Bag (t, t)))
+ -- | Free variables in a type
+ -> (t -> Bag v)
+ -- | \v repl term -> Substitute v by repl in term
+ -> (v -> t -> t -> t)
+ -- | Detect bare-variable types
+ -> (t -> Maybe v)
+ -- | Some kind of size measure on types
+ -> (t -> Int)
+ -- | Equality constraints to solve
+ -> [(t, t)]
+ -> (Map v t, Bag (UnifyErr v t))
+solveConstraints reduce frees subst detect size = \cs ->
+ let (vcs, errs) = foldMap (uncurry reduce) cs
+ asg = Map.fromListWith (<>) (map (second pure) (toList vcs))
+ (errs', asg') = loop asg []
+ errs'' = fmap (uncurry UEUnequal) errs <> errs'
+ in trace ("[solver] Solving:" ++ concat ["\n- " ++ show a ++ " == " ++ show b | (a, b) <- cs]) $
+ trace ("[solver] Result: (with " ++ show (length errs'') ++ " errors)" ++
+ concat ["\n- " ++ show v ++ " = " ++ show t | (v, t) <- Map.assocs asg'])
+ (asg', errs'')
+ where
+ loop :: Map v (Bag t) -> [(v, t)] -> (Bag (UnifyErr v t), Map v t)
+ loop m eqlog = do
+ traceM $ "[solver] Step:" ++ concat ["\n- " ++ show v ++ " == " ++ show t | (v, t) <- Map.assocs m]
+ m' <- Map.traverseWithKey
+ (\v ts ->
+ let ts' = bagFromList (dedup (toList ts))
+ -- filter out recursive equations
+ (recs, nonrecs) = bagPartition (\t -> if v `elem` frees t then Just t else Nothing) ts'
+ -- filter out trivial equations (v = v)
+ (_, nonrecs') = bagPartition (detect >=> guard . (== v)) nonrecs
+ in (UERecursive v <$> recs, nonrecs'))
+ m
+
+ let msmallestvar = -- var with its smallest RHS, if such a var exists
+ minimumByMay (comparing (size . snd))
+ . map (second (minimumBy (comparing size)))
+ . filter (not . null . snd)
+ $ Map.assocs m'
+
+ case msmallestvar of
+ Nothing -> return $ applyLog eqlog mempty
+ Just (var, smallrhs) -> do
+ let (_, otherrhss) = bagPartition (guard . (== smallrhs)) (m' Map.! var)
+ let (newcs, errs) = foldMap (reduce smallrhs) (dedup (toList otherrhss))
+ (fmap (uncurry UEUnequal) errs, ()) -- write the errors
+ let m'' = Map.unionWith (<>)
+ (Map.map (fmap (subst var smallrhs)) (Map.delete var m'))
+ (Map.fromListWith (<>) (map (second pure) (toList newcs)))
+ loop m'' ((var, smallrhs) : eqlog)
+
+ applyLog :: [(v, t)] -> Map v t -> Map v t
+ applyLog ((v, t) : l) m = applyLog l $ Map.insert v t (Map.map (subst v t) m)
+ applyLog [] m = m
+
+ dedup :: Ord t => [t] -> [t]
+ dedup = uniq . sort
+
+ uniq :: Eq a => [a] -> [a]
+ uniq (x:y:xs) | x == y = uniq (x : xs)
+ | otherwise = x : uniq (y : xs)
+ uniq l = l
+
+ minimumByMay :: Foldable t' => (a -> a -> Ordering) -> t' a -> Maybe a
+ minimumByMay cmp = foldl' min' Nothing
+ where min' mx y = Just $! case mx of
+ Nothing -> y
+ Just x | GT <- cmp x y -> y
+ | otherwise -> x