From e7bed242ba52e6d3233928f2c6189e701cfa5e4c Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 14 Mar 2024 23:21:53 +0100 Subject: Some typechecker work --- src/HSVIS/AST.hs | 11 ++--- src/HSVIS/Parser.hs | 12 ++--- src/HSVIS/Typecheck.hs | 112 +++++++++++++++++++++++++++++++++++-------- src/HSVIS/Typecheck/Solve.hs | 103 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 205 insertions(+), 33 deletions(-) create mode 100644 src/HSVIS/Typecheck/Solve.hs (limited to 'src/HSVIS') diff --git a/src/HSVIS/AST.hs b/src/HSVIS/AST.hs index f95a3cc..8bb2d6c 100644 --- a/src/HSVIS/AST.hs +++ b/src/HSVIS/AST.hs @@ -1,15 +1,10 @@ -{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE PolyKinds #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE TypeApplications #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE ConstrainedClassMethods #-} module HSVIS.AST where @@ -63,7 +58,7 @@ type family EMapperTelescope fs s1 s2 a where newtype Name = Name String - deriving (Show, Eq) + deriving (Show, Eq, Ord) data Program s = Program [DataDef s] [FunDef s] deriving instance (Show (DataDef s), Show (FunDef s)) => Show (Program s) @@ -89,6 +84,8 @@ data Kind s -- extension point | KExt (X Kind s) !(E Kind s) deriving instance (Show (X Kind s), Show (E Kind s)) => Show (Kind s) +deriving instance (Eq (X Kind s), Eq (E Kind s)) => Eq (Kind s) +deriving instance (Ord (X Kind s), Ord (E Kind s)) => Ord (Kind s) data Type s = TApp (X Type s) (Type s) [Type s] diff --git a/src/HSVIS/Parser.hs b/src/HSVIS/Parser.hs index 3251989..0df4aa8 100644 --- a/src/HSVIS/Parser.hs +++ b/src/HSVIS/Parser.hs @@ -162,9 +162,9 @@ instance KnownFallible fail => MonadChronicle (Bag Diagnostic) (Parser fail) whe (kok ps mempty def) condemn (Parser f) = Parser $ \ctx ps kok kfat kbt -> f ctx ps - (\ps' errs x -> case errs of - BZero -> kok ps' mempty x - _ -> kfat errs) + (\ps' errs x -> if null errs + then kok ps' mempty x + else kfat errs) kfat kbt retcon g (Parser f) = Parser $ \ctx ps kok kfat kbt -> @@ -180,9 +180,9 @@ instance KnownFallible fail => MonadChronicle (Bag Diagnostic) (Parser fail) whe parse :: FilePath -> String -> ([Diagnostic], Maybe PProgram) parse fp source = runParser pProgram (Context fp (lines source) []) (PS (Pos 0 0) (Pos 0 0) source) - (\_ errs res -> case errs of - BZero -> ([], Just res) - _ -> (toList errs, Just res)) + (\_ errs res -> if null errs + then ([], Just res) + else (toList errs, Just res)) (\errs -> (toList errs, Nothing)) () -- the program parser cannot fail! :D diff --git a/src/HSVIS/Typecheck.hs b/src/HSVIS/Typecheck.hs index de9d7db..ba853a0 100644 --- a/src/HSVIS/Typecheck.hs +++ b/src/HSVIS/Typecheck.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE EmptyDataDeriving #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeFamilies #-} @@ -7,11 +8,16 @@ module HSVIS.Typecheck where import Control.Monad -import Data.Bifunctor (first, second) +import Data.Bifunctor (first) import Data.Foldable (toList) import Data.Map.Strict (Map) -import Data.Monoid (First(..)) +import Data.Maybe (fromMaybe) +import Data.Monoid (Ap(..)) import qualified Data.Map.Strict as Map +import Data.Set (Set) +import qualified Data.Set as Set + +import Debug.Trace import Data.Bag import Data.List.NonEmpty.Util @@ -19,6 +25,7 @@ import HSVIS.AST import HSVIS.Parser import HSVIS.Diagnostic import HSVIS.Pretty +import HSVIS.Typecheck.Solve data StageTC @@ -33,7 +40,8 @@ type instance X RHS StageTC = CType type instance X Expr StageTC = CType data instance E Type StageTC = TUniVar Int deriving (Show) -data instance E Kind StageTC = KUniVar Int deriving (Show) +data instance E Kind StageTC = KUniVar Int deriving (Show, Eq, Ord) +data instance E TypeSig StageTC deriving (Show) type CProgram = Program StageTC type CDataDef = DataDef StageTC @@ -58,6 +66,7 @@ type instance X Expr StageTyped = TType data instance E Type StageTyped deriving (Show) data instance E Kind StageTyped deriving (Show) +data instance E TypeSig StageTyped deriving (Show) type TProgram = Program StageTyped type TDataDef = DataDef StageTyped @@ -69,8 +78,11 @@ type TPattern = Pattern StageTyped type TRHS = RHS StageTyped type TExpr = Expr StageTyped +instance Pretty (E Kind StageTC) where + prettysPrec _ (KUniVar n) = showString ("?k" ++ show n) + -typecheck :: FilePath -> String -> PProgram -> ([Diagnostic], Program TType) +typecheck :: FilePath -> String -> PProgram -> ([Diagnostic], TProgram) typecheck fp source prog = let (ds1, cs, _, _, progtc) = runTCM (tcProgram prog) (fp, source) 1 (Env mempty mempty) @@ -269,29 +281,86 @@ tcFunDef (FunDef _ name msig eqs) = do return (FunDef typ name (TypeSig typ) eqs') tcFunEq :: CType -> PFunEq -> TCM CFunEq -tcFunEq = _ +tcFunEq = error "tcFunEq" + +newtype SolveM v t m a = SolveM (Map v (Bag t) -> Map v t -> m (a, Map v (Bag t), Map v t)) +instance Monad m => Functor (SolveM v t m) where + fmap f (SolveM g) = SolveM $ \m r -> do (x, m', r') <- g m r + return (f x, m', r') +instance Monad m => Applicative (SolveM v t m) where + pure x = SolveM $ \m r -> return (x, m, r) + (<*>) = ap +instance Monad m => Monad (SolveM v t m) where + SolveM f >>= g = SolveM $ \m r -> do (x, m1, r1) <- f m r + let SolveM h = g x + h m1 r1 + +solvemStateGet :: Monad m => SolveM v t m (Map v (Bag t)) +solvemStateGet = SolveM $ \m r -> return (m, m, r) + +solvemStateUpdate :: Monad m => (Map v (Bag t) -> Map v (Bag t)) -> SolveM v t m () +solvemStateUpdate f = SolveM $ \m r -> return ((), f m, r) + +solvemLogUpdate :: Monad m => (Map v t -> Map v t) -> SolveM v t m () +solvemLogUpdate f = SolveM $ \m r -> return ((), m, f r) + +solvemStateVars :: Monad m => SolveM v t m [v] +solvemStateVars = Map.keys <$> solvemStateGet + +solvemStateRHS :: (Ord v, Monad m) => v -> SolveM v t m (Bag t) +solvemStateRHS v = fromMaybe mempty . Map.lookup v <$> solvemStateGet + +solvemStateSet :: (Ord v, Monad m) => v -> Bag t -> SolveM v t m () +solvemStateSet v b = solvemStateUpdate (Map.insert v b) + +solvemLogEq :: (Ord v, Monad m) => v -> t -> SolveM v t m () +solvemLogEq v t = solvemLogUpdate (Map.insert v t) solveKindVars :: Bag (CKind, CKind, Range) -> TCM () -solveKindVars = - mapM_ $ \(a, b, rng) -> do - let (subst, First merr) = reduce a b - forM_ merr $ \(erra, errb) -> - raise rng $ - "Kind mismatch:\n\ - \- Expected: " ++ pretty a ++ "\n\ - \- Observed: " ++ pretty b ++ "\n\ - \because '" ++ pretty erra ++ "' and '" ++ pretty errb ++ "' don't match" - let collected :: [(Int, Bag CKind)] - collected = Map.assocs $ Map.fromListWith (<>) (fmap (second pure) (toList subst)) - _ +solveKindVars cs = do + traceShowM cs + traceShowM $ solveConstraints + reduce + (foldMap pure . kindUniVars) + (\v repl -> substKind (Map.singleton v repl)) + (\case KExt () (KUniVar v) -> Just v + _ -> Nothing) + kindSize + (map (\(a, b, _) -> (a, b)) (toList cs)) where - reduce :: CKind -> CKind -> (Bag (Int, CKind), First (CKind, CKind)) - reduce (KType ()) (KType ()) = mempty - reduce (KFun () a b) (KFun () c d) = reduce a c <> reduce b d + reduce :: CKind -> CKind -> (Bag (Int, CKind), Bag (CKind, CKind)) + -- unification variables produce constraints on a unification variable + reduce (KExt () (KUniVar i)) (KExt () (KUniVar j)) | i == j = mempty reduce (KExt () (KUniVar i)) k = (pure (i, k), mempty) reduce k (KExt () (KUniVar i)) = (pure (i, k), mempty) + -- if lhs and rhs have equal prefixes, recurse + reduce (KType ()) (KType ()) = mempty + reduce (KFun () a b) (KFun () c d) = reduce a c <> reduce b d + -- otherwise, this is a kind mismatch reduce k1 k2 = (mempty, pure (k1, k2)) + kindSize :: CKind -> Int + kindSize KType{} = 1 + kindSize (KFun () a b) = 1 + kindSize a + kindSize b + kindSize (KExt () KUniVar{}) = 1 + +solveConstrs :: Bag Constr -> (Bag Diagnostic, Map Name TType) +solveConstrs = error "solveConstrs" + +substProg :: Map Name TType -> CProgram -> TProgram +substProg = error "substProg" + +substKind :: Map Int CKind -> CKind -> CKind +substKind _ k@KType{} = k +substKind m (KFun () k1 k2) = KFun () (substKind m k1) (substKind m k2) +substKind m k@(KExt () (KUniVar v)) = fromMaybe k (Map.lookup v m) + +kindUniVars :: CKind -> Set Int +kindUniVars = \case + KType{} -> mempty + KFun () a b -> kindUniVars a <> kindUniVars b + KExt () (KUniVar v) -> Set.singleton v + allEq :: (Eq a, Foldable t) => t a -> Bool allEq l = case toList l of [] -> True @@ -308,3 +377,6 @@ splitKind (KExt _ e) = ([], Left e) isCEqK :: Constr -> Maybe (CKind, CKind, Range) isCEqK (CEqK k1 k2 rng) = Just (k1, k2, rng) isCEqK _ = Nothing + +foldMapM :: (Applicative f, Monoid m, Foldable t) => (a -> f m) -> t a -> f m +foldMapM f = getAp . foldMap (Ap . f) diff --git a/src/HSVIS/Typecheck/Solve.hs b/src/HSVIS/Typecheck/Solve.hs new file mode 100644 index 0000000..184937c --- /dev/null +++ b/src/HSVIS/Typecheck/Solve.hs @@ -0,0 +1,103 @@ +{-# LANGUAGE ScopedTypeVariables #-} +module HSVIS.Typecheck.Solve where + +import Control.Monad (guard, (>=>)) +import Data.Bifunctor (second) +import Data.Foldable (toList, foldl') +import Data.List (sort) +import Data.Ord (comparing) +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as Map + +import Debug.Trace + +import Data.Bag +import Data.List (minimumBy) + + +data UnifyErr v t + = UEUnequal t t + | UERecursive v t + deriving (Show) + +-- | Returns a pair of: +-- 1. A set of unification errors; +-- 2. An assignment of the variables that had any constraints on them. +-- The producedure was successful if the set of errors is empty. Note that +-- unconstrained variables do not appear in the output. +solveConstraints + :: forall v t. (Ord v, Ord t, Show v, Show t) + -- | reduce: take two types and unify them, resulting in: + -- 1. A bag of resulting constraints on variables; + -- 2. A bag of errors: pairs of two types that are provably distinct but + -- need to be equal for the input types to unify. + => (t -> t -> (Bag (v, t), Bag (t, t))) + -- | Free variables in a type + -> (t -> Bag v) + -- | \v repl term -> Substitute v by repl in term + -> (v -> t -> t -> t) + -- | Detect bare-variable types + -> (t -> Maybe v) + -- | Some kind of size measure on types + -> (t -> Int) + -- | Equality constraints to solve + -> [(t, t)] + -> (Map v t, Bag (UnifyErr v t)) +solveConstraints reduce frees subst detect size = \cs -> + let (vcs, errs) = foldMap (uncurry reduce) cs + asg = Map.fromListWith (<>) (map (second pure) (toList vcs)) + (errs', asg') = loop asg [] + errs'' = fmap (uncurry UEUnequal) errs <> errs' + in trace ("[solver] Solving:" ++ concat ["\n- " ++ show a ++ " == " ++ show b | (a, b) <- cs]) $ + trace ("[solver] Result: (with " ++ show (length errs'') ++ " errors)" ++ + concat ["\n- " ++ show v ++ " = " ++ show t | (v, t) <- Map.assocs asg']) + (asg', errs'') + where + loop :: Map v (Bag t) -> [(v, t)] -> (Bag (UnifyErr v t), Map v t) + loop m eqlog = do + traceM $ "[solver] Step:" ++ concat ["\n- " ++ show v ++ " == " ++ show t | (v, t) <- Map.assocs m] + m' <- Map.traverseWithKey + (\v ts -> + let ts' = bagFromList (dedup (toList ts)) + -- filter out recursive equations + (recs, nonrecs) = bagPartition (\t -> if v `elem` frees t then Just t else Nothing) ts' + -- filter out trivial equations (v = v) + (_, nonrecs') = bagPartition (detect >=> guard . (== v)) nonrecs + in (UERecursive v <$> recs, nonrecs')) + m + + let msmallestvar = -- var with its smallest RHS, if such a var exists + minimumByMay (comparing (size . snd)) + . map (second (minimumBy (comparing size))) + . filter (not . null . snd) + $ Map.assocs m' + + case msmallestvar of + Nothing -> return $ applyLog eqlog mempty + Just (var, smallrhs) -> do + let (_, otherrhss) = bagPartition (guard . (== smallrhs)) (m' Map.! var) + let (newcs, errs) = foldMap (reduce smallrhs) (dedup (toList otherrhss)) + (fmap (uncurry UEUnequal) errs, ()) -- write the errors + let m'' = Map.unionWith (<>) + (Map.map (fmap (subst var smallrhs)) (Map.delete var m')) + (Map.fromListWith (<>) (map (second pure) (toList newcs))) + loop m'' ((var, smallrhs) : eqlog) + + applyLog :: [(v, t)] -> Map v t -> Map v t + applyLog ((v, t) : l) m = applyLog l $ Map.insert v t (Map.map (subst v t) m) + applyLog [] m = m + + dedup :: Ord t => [t] -> [t] + dedup = uniq . sort + + uniq :: Eq a => [a] -> [a] + uniq (x:y:xs) | x == y = uniq (x : xs) + | otherwise = x : uniq (y : xs) + uniq l = l + + minimumByMay :: Foldable t' => (a -> a -> Ordering) -> t' a -> Maybe a + minimumByMay cmp = foldl' min' Nothing + where min' mx y = Just $! case mx of + Nothing -> y + Just x | GT <- cmp x y -> y + | otherwise -> x -- cgit v1.2.3-70-g09d2