diff options
-rw-r--r-- | app/Main.hs | 8 | ||||
-rw-r--r-- | hs-visinter.cabal | 1 | ||||
-rw-r--r-- | src/Data/Bag.hs | 65 | ||||
-rw-r--r-- | src/HSVIS/AST.hs | 11 | ||||
-rw-r--r-- | src/HSVIS/Parser.hs | 12 | ||||
-rw-r--r-- | src/HSVIS/Typecheck.hs | 112 | ||||
-rw-r--r-- | src/HSVIS/Typecheck/Solve.hs | 103 |
7 files changed, 275 insertions, 37 deletions
diff --git a/app/Main.hs b/app/Main.hs index bf4fcfd..e5144f1 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -8,6 +8,7 @@ import System.Exit (die, exitFailure) import HSVIS.Diagnostic import HSVIS.Parser +import HSVIS.Typecheck main :: IO () @@ -21,8 +22,13 @@ main = do (errs, Nothing) -> do sequence_ $ intersperse (putStrLn "") (map (putStrLn . printDiagnostic) errs) exitFailure - (errs, res) -> do + (errs, Just res) -> do sequence_ $ intersperse (putStrLn "") (map (putStrLn . printDiagnostic) errs) return res print prog + + let (errs, tprog) = typecheck fname source prog + sequence_ $ intersperse (putStrLn "") (map (putStrLn . printDiagnostic) errs) + + print tprog diff --git a/hs-visinter.cabal b/hs-visinter.cabal index 0b21e36..48d3456 100644 --- a/hs-visinter.cabal +++ b/hs-visinter.cabal @@ -17,6 +17,7 @@ library HSVIS.Parser HSVIS.Pretty HSVIS.Typecheck + HSVIS.Typecheck.Solve build-depends: base >= 4.16 && < 4.20, containers >= 0.6.3.1 && < 0.8, diff --git a/src/Data/Bag.hs b/src/Data/Bag.hs index f1173e4..ba1f912 100644 --- a/src/Data/Bag.hs +++ b/src/Data/Bag.hs @@ -1,19 +1,41 @@ {-# LANGUAGE DeriveTraversable #-} -module Data.Bag where +module Data.Bag ( + Bag, + bagFromList, + bagFilter, + bagPartition, +) where + +import Data.Bifunctor (bimap) +import Data.Foldable (toList) +import Data.List.NonEmpty (NonEmpty((:|))) +import Data.Maybe (mapMaybe) +import Data.Semigroup data Bag a = BZero | BOne a | BTwo (Bag a) (Bag a) - deriving (Functor, Foldable, Traversable) + | BList [Bag a] -- make mconcat efficient + | BList' [a] -- make bagFromList efficient + deriving (Functor, Traversable) + +instance Show a => Show (Bag a) where + showsPrec d b = showParen (d > 10) $ + showString "Bag " . showList (toList b) instance Semigroup (Bag a) where BZero <> b = b b <> BZero = b b1 <> b2 = BTwo b1 b2 -instance Monoid (Bag a) where mempty = BZero + sconcat (b :| bs) = b <> mconcat bs + stimes n b = mconcat (replicate (fromIntegral n) b) + +instance Monoid (Bag a) where + mempty = BZero + mconcat = BList instance Applicative Bag where pure = BOne @@ -22,6 +44,36 @@ instance Applicative Bag where _ <*> BZero = BZero BOne f <*> b = f <$> b BTwo b1 b2 <*> b = BTwo (b1 <*> b) (b2 <*> b) + BList bs <*> b = BList (map (<*> b) bs) + BList' xs <*> b = BList (map BOne xs) <*> b + +instance Foldable Bag where + foldMap _ BZero = mempty + foldMap f (BOne x) = f x + foldMap f (BTwo b1 b2) = foldMap f b1 <> foldMap f b2 + foldMap f (BList l) = foldMap (foldMap f) l + foldMap f (BList' l) = foldMap f l + + toList (BList' xs) = xs + toList b = foldr (:) [] b + + null BZero = True + null BOne{} = False + null (BTwo b1 b2) = null b1 && null b2 + null (BList l) = all null l + null (BList' l) = null l + +bagFromList :: [a] -> Bag a +bagFromList = BList' + +bagFilter :: (a -> Maybe b) -> Bag a -> Bag b +bagFilter _ BZero = BZero +bagFilter f (BOne x) + | Just y <- f x = BOne y + | otherwise = BZero +bagFilter f (BTwo b1 b2) = bagFilter f b1 <> bagFilter f b2 +bagFilter f (BList bs) = BList (map (bagFilter f) bs) +bagFilter f (BList' xs) = BList' (mapMaybe f xs) bagPartition :: (a -> Maybe b) -> Bag a -> (Bag b, Bag a) bagPartition _ BZero = (BZero, BZero) @@ -29,3 +81,10 @@ bagPartition f (BOne x) | Just y <- f x = (BOne y, BZero) | otherwise = (BZero, BOne x) bagPartition f (BTwo b1 b2) = bagPartition f b1 <> bagPartition f b2 +bagPartition f (BList bs) = foldMap (bagPartition f) bs +bagPartition f (BList' xs) = + bimap bagFromList bagFromList $ + foldr (\x (l1,l2) -> case f x of + Just y -> (y : l1, l2) + Nothing -> (l1, x : l2)) + ([], []) xs diff --git a/src/HSVIS/AST.hs b/src/HSVIS/AST.hs index f95a3cc..8bb2d6c 100644 --- a/src/HSVIS/AST.hs +++ b/src/HSVIS/AST.hs @@ -1,15 +1,10 @@ -{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE PolyKinds #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE TypeApplications #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE ConstrainedClassMethods #-} module HSVIS.AST where @@ -63,7 +58,7 @@ type family EMapperTelescope fs s1 s2 a where newtype Name = Name String - deriving (Show, Eq) + deriving (Show, Eq, Ord) data Program s = Program [DataDef s] [FunDef s] deriving instance (Show (DataDef s), Show (FunDef s)) => Show (Program s) @@ -89,6 +84,8 @@ data Kind s -- extension point | KExt (X Kind s) !(E Kind s) deriving instance (Show (X Kind s), Show (E Kind s)) => Show (Kind s) +deriving instance (Eq (X Kind s), Eq (E Kind s)) => Eq (Kind s) +deriving instance (Ord (X Kind s), Ord (E Kind s)) => Ord (Kind s) data Type s = TApp (X Type s) (Type s) [Type s] diff --git a/src/HSVIS/Parser.hs b/src/HSVIS/Parser.hs index 3251989..0df4aa8 100644 --- a/src/HSVIS/Parser.hs +++ b/src/HSVIS/Parser.hs @@ -162,9 +162,9 @@ instance KnownFallible fail => MonadChronicle (Bag Diagnostic) (Parser fail) whe (kok ps mempty def) condemn (Parser f) = Parser $ \ctx ps kok kfat kbt -> f ctx ps - (\ps' errs x -> case errs of - BZero -> kok ps' mempty x - _ -> kfat errs) + (\ps' errs x -> if null errs + then kok ps' mempty x + else kfat errs) kfat kbt retcon g (Parser f) = Parser $ \ctx ps kok kfat kbt -> @@ -180,9 +180,9 @@ instance KnownFallible fail => MonadChronicle (Bag Diagnostic) (Parser fail) whe parse :: FilePath -> String -> ([Diagnostic], Maybe PProgram) parse fp source = runParser pProgram (Context fp (lines source) []) (PS (Pos 0 0) (Pos 0 0) source) - (\_ errs res -> case errs of - BZero -> ([], Just res) - _ -> (toList errs, Just res)) + (\_ errs res -> if null errs + then ([], Just res) + else (toList errs, Just res)) (\errs -> (toList errs, Nothing)) () -- the program parser cannot fail! :D diff --git a/src/HSVIS/Typecheck.hs b/src/HSVIS/Typecheck.hs index de9d7db..ba853a0 100644 --- a/src/HSVIS/Typecheck.hs +++ b/src/HSVIS/Typecheck.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE EmptyDataDeriving #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeFamilies #-} @@ -7,11 +8,16 @@ module HSVIS.Typecheck where import Control.Monad -import Data.Bifunctor (first, second) +import Data.Bifunctor (first) import Data.Foldable (toList) import Data.Map.Strict (Map) -import Data.Monoid (First(..)) +import Data.Maybe (fromMaybe) +import Data.Monoid (Ap(..)) import qualified Data.Map.Strict as Map +import Data.Set (Set) +import qualified Data.Set as Set + +import Debug.Trace import Data.Bag import Data.List.NonEmpty.Util @@ -19,6 +25,7 @@ import HSVIS.AST import HSVIS.Parser import HSVIS.Diagnostic import HSVIS.Pretty +import HSVIS.Typecheck.Solve data StageTC @@ -33,7 +40,8 @@ type instance X RHS StageTC = CType type instance X Expr StageTC = CType data instance E Type StageTC = TUniVar Int deriving (Show) -data instance E Kind StageTC = KUniVar Int deriving (Show) +data instance E Kind StageTC = KUniVar Int deriving (Show, Eq, Ord) +data instance E TypeSig StageTC deriving (Show) type CProgram = Program StageTC type CDataDef = DataDef StageTC @@ -58,6 +66,7 @@ type instance X Expr StageTyped = TType data instance E Type StageTyped deriving (Show) data instance E Kind StageTyped deriving (Show) +data instance E TypeSig StageTyped deriving (Show) type TProgram = Program StageTyped type TDataDef = DataDef StageTyped @@ -69,8 +78,11 @@ type TPattern = Pattern StageTyped type TRHS = RHS StageTyped type TExpr = Expr StageTyped +instance Pretty (E Kind StageTC) where + prettysPrec _ (KUniVar n) = showString ("?k" ++ show n) + -typecheck :: FilePath -> String -> PProgram -> ([Diagnostic], Program TType) +typecheck :: FilePath -> String -> PProgram -> ([Diagnostic], TProgram) typecheck fp source prog = let (ds1, cs, _, _, progtc) = runTCM (tcProgram prog) (fp, source) 1 (Env mempty mempty) @@ -269,29 +281,86 @@ tcFunDef (FunDef _ name msig eqs) = do return (FunDef typ name (TypeSig typ) eqs') tcFunEq :: CType -> PFunEq -> TCM CFunEq -tcFunEq = _ +tcFunEq = error "tcFunEq" + +newtype SolveM v t m a = SolveM (Map v (Bag t) -> Map v t -> m (a, Map v (Bag t), Map v t)) +instance Monad m => Functor (SolveM v t m) where + fmap f (SolveM g) = SolveM $ \m r -> do (x, m', r') <- g m r + return (f x, m', r') +instance Monad m => Applicative (SolveM v t m) where + pure x = SolveM $ \m r -> return (x, m, r) + (<*>) = ap +instance Monad m => Monad (SolveM v t m) where + SolveM f >>= g = SolveM $ \m r -> do (x, m1, r1) <- f m r + let SolveM h = g x + h m1 r1 + +solvemStateGet :: Monad m => SolveM v t m (Map v (Bag t)) +solvemStateGet = SolveM $ \m r -> return (m, m, r) + +solvemStateUpdate :: Monad m => (Map v (Bag t) -> Map v (Bag t)) -> SolveM v t m () +solvemStateUpdate f = SolveM $ \m r -> return ((), f m, r) + +solvemLogUpdate :: Monad m => (Map v t -> Map v t) -> SolveM v t m () +solvemLogUpdate f = SolveM $ \m r -> return ((), m, f r) + +solvemStateVars :: Monad m => SolveM v t m [v] +solvemStateVars = Map.keys <$> solvemStateGet + +solvemStateRHS :: (Ord v, Monad m) => v -> SolveM v t m (Bag t) +solvemStateRHS v = fromMaybe mempty . Map.lookup v <$> solvemStateGet + +solvemStateSet :: (Ord v, Monad m) => v -> Bag t -> SolveM v t m () +solvemStateSet v b = solvemStateUpdate (Map.insert v b) + +solvemLogEq :: (Ord v, Monad m) => v -> t -> SolveM v t m () +solvemLogEq v t = solvemLogUpdate (Map.insert v t) solveKindVars :: Bag (CKind, CKind, Range) -> TCM () -solveKindVars = - mapM_ $ \(a, b, rng) -> do - let (subst, First merr) = reduce a b - forM_ merr $ \(erra, errb) -> - raise rng $ - "Kind mismatch:\n\ - \- Expected: " ++ pretty a ++ "\n\ - \- Observed: " ++ pretty b ++ "\n\ - \because '" ++ pretty erra ++ "' and '" ++ pretty errb ++ "' don't match" - let collected :: [(Int, Bag CKind)] - collected = Map.assocs $ Map.fromListWith (<>) (fmap (second pure) (toList subst)) - _ +solveKindVars cs = do + traceShowM cs + traceShowM $ solveConstraints + reduce + (foldMap pure . kindUniVars) + (\v repl -> substKind (Map.singleton v repl)) + (\case KExt () (KUniVar v) -> Just v + _ -> Nothing) + kindSize + (map (\(a, b, _) -> (a, b)) (toList cs)) where - reduce :: CKind -> CKind -> (Bag (Int, CKind), First (CKind, CKind)) - reduce (KType ()) (KType ()) = mempty - reduce (KFun () a b) (KFun () c d) = reduce a c <> reduce b d + reduce :: CKind -> CKind -> (Bag (Int, CKind), Bag (CKind, CKind)) + -- unification variables produce constraints on a unification variable + reduce (KExt () (KUniVar i)) (KExt () (KUniVar j)) | i == j = mempty reduce (KExt () (KUniVar i)) k = (pure (i, k), mempty) reduce k (KExt () (KUniVar i)) = (pure (i, k), mempty) + -- if lhs and rhs have equal prefixes, recurse + reduce (KType ()) (KType ()) = mempty + reduce (KFun () a b) (KFun () c d) = reduce a c <> reduce b d + -- otherwise, this is a kind mismatch reduce k1 k2 = (mempty, pure (k1, k2)) + kindSize :: CKind -> Int + kindSize KType{} = 1 + kindSize (KFun () a b) = 1 + kindSize a + kindSize b + kindSize (KExt () KUniVar{}) = 1 + +solveConstrs :: Bag Constr -> (Bag Diagnostic, Map Name TType) +solveConstrs = error "solveConstrs" + +substProg :: Map Name TType -> CProgram -> TProgram +substProg = error "substProg" + +substKind :: Map Int CKind -> CKind -> CKind +substKind _ k@KType{} = k +substKind m (KFun () k1 k2) = KFun () (substKind m k1) (substKind m k2) +substKind m k@(KExt () (KUniVar v)) = fromMaybe k (Map.lookup v m) + +kindUniVars :: CKind -> Set Int +kindUniVars = \case + KType{} -> mempty + KFun () a b -> kindUniVars a <> kindUniVars b + KExt () (KUniVar v) -> Set.singleton v + allEq :: (Eq a, Foldable t) => t a -> Bool allEq l = case toList l of [] -> True @@ -308,3 +377,6 @@ splitKind (KExt _ e) = ([], Left e) isCEqK :: Constr -> Maybe (CKind, CKind, Range) isCEqK (CEqK k1 k2 rng) = Just (k1, k2, rng) isCEqK _ = Nothing + +foldMapM :: (Applicative f, Monoid m, Foldable t) => (a -> f m) -> t a -> f m +foldMapM f = getAp . foldMap (Ap . f) diff --git a/src/HSVIS/Typecheck/Solve.hs b/src/HSVIS/Typecheck/Solve.hs new file mode 100644 index 0000000..184937c --- /dev/null +++ b/src/HSVIS/Typecheck/Solve.hs @@ -0,0 +1,103 @@ +{-# LANGUAGE ScopedTypeVariables #-} +module HSVIS.Typecheck.Solve where + +import Control.Monad (guard, (>=>)) +import Data.Bifunctor (second) +import Data.Foldable (toList, foldl') +import Data.List (sort) +import Data.Ord (comparing) +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as Map + +import Debug.Trace + +import Data.Bag +import Data.List (minimumBy) + + +data UnifyErr v t + = UEUnequal t t + | UERecursive v t + deriving (Show) + +-- | Returns a pair of: +-- 1. A set of unification errors; +-- 2. An assignment of the variables that had any constraints on them. +-- The producedure was successful if the set of errors is empty. Note that +-- unconstrained variables do not appear in the output. +solveConstraints + :: forall v t. (Ord v, Ord t, Show v, Show t) + -- | reduce: take two types and unify them, resulting in: + -- 1. A bag of resulting constraints on variables; + -- 2. A bag of errors: pairs of two types that are provably distinct but + -- need to be equal for the input types to unify. + => (t -> t -> (Bag (v, t), Bag (t, t))) + -- | Free variables in a type + -> (t -> Bag v) + -- | \v repl term -> Substitute v by repl in term + -> (v -> t -> t -> t) + -- | Detect bare-variable types + -> (t -> Maybe v) + -- | Some kind of size measure on types + -> (t -> Int) + -- | Equality constraints to solve + -> [(t, t)] + -> (Map v t, Bag (UnifyErr v t)) +solveConstraints reduce frees subst detect size = \cs -> + let (vcs, errs) = foldMap (uncurry reduce) cs + asg = Map.fromListWith (<>) (map (second pure) (toList vcs)) + (errs', asg') = loop asg [] + errs'' = fmap (uncurry UEUnequal) errs <> errs' + in trace ("[solver] Solving:" ++ concat ["\n- " ++ show a ++ " == " ++ show b | (a, b) <- cs]) $ + trace ("[solver] Result: (with " ++ show (length errs'') ++ " errors)" ++ + concat ["\n- " ++ show v ++ " = " ++ show t | (v, t) <- Map.assocs asg']) + (asg', errs'') + where + loop :: Map v (Bag t) -> [(v, t)] -> (Bag (UnifyErr v t), Map v t) + loop m eqlog = do + traceM $ "[solver] Step:" ++ concat ["\n- " ++ show v ++ " == " ++ show t | (v, t) <- Map.assocs m] + m' <- Map.traverseWithKey + (\v ts -> + let ts' = bagFromList (dedup (toList ts)) + -- filter out recursive equations + (recs, nonrecs) = bagPartition (\t -> if v `elem` frees t then Just t else Nothing) ts' + -- filter out trivial equations (v = v) + (_, nonrecs') = bagPartition (detect >=> guard . (== v)) nonrecs + in (UERecursive v <$> recs, nonrecs')) + m + + let msmallestvar = -- var with its smallest RHS, if such a var exists + minimumByMay (comparing (size . snd)) + . map (second (minimumBy (comparing size))) + . filter (not . null . snd) + $ Map.assocs m' + + case msmallestvar of + Nothing -> return $ applyLog eqlog mempty + Just (var, smallrhs) -> do + let (_, otherrhss) = bagPartition (guard . (== smallrhs)) (m' Map.! var) + let (newcs, errs) = foldMap (reduce smallrhs) (dedup (toList otherrhss)) + (fmap (uncurry UEUnequal) errs, ()) -- write the errors + let m'' = Map.unionWith (<>) + (Map.map (fmap (subst var smallrhs)) (Map.delete var m')) + (Map.fromListWith (<>) (map (second pure) (toList newcs))) + loop m'' ((var, smallrhs) : eqlog) + + applyLog :: [(v, t)] -> Map v t -> Map v t + applyLog ((v, t) : l) m = applyLog l $ Map.insert v t (Map.map (subst v t) m) + applyLog [] m = m + + dedup :: Ord t => [t] -> [t] + dedup = uniq . sort + + uniq :: Eq a => [a] -> [a] + uniq (x:y:xs) | x == y = uniq (x : xs) + | otherwise = x : uniq (y : xs) + uniq l = l + + minimumByMay :: Foldable t' => (a -> a -> Ordering) -> t' a -> Maybe a + minimumByMay cmp = foldl' min' Nothing + where min' mx y = Just $! case mx of + Nothing -> y + Just x | GT <- cmp x y -> y + | otherwise -> x |