aboutsummaryrefslogtreecommitdiff
path: root/typecheck/CC
diff options
context:
space:
mode:
Diffstat (limited to 'typecheck/CC')
-rw-r--r--typecheck/CC/Typecheck.hs12
-rw-r--r--typecheck/CC/Typecheck/Types.hs3
2 files changed, 13 insertions, 2 deletions
diff --git a/typecheck/CC/Typecheck.hs b/typecheck/CC/Typecheck.hs
index 824a714..292eaeb 100644
--- a/typecheck/CC/Typecheck.hs
+++ b/typecheck/CC/Typecheck.hs
@@ -146,6 +146,14 @@ instantiate (T.TypeScheme bnds ty) = do
freshenFrees :: Env -> T.Type -> TM T.Type
freshenFrees env = instantiate . generalise env
+replaceRigid :: T.Type -> T.Type
+replaceRigid (T.TFun t1 t2) = T.TFun (replaceRigid t1) (replaceRigid t2)
+replaceRigid T.TInt = T.TInt
+replaceRigid (T.TTup ts) = T.TTup (map (replaceRigid) ts)
+replaceRigid (T.TNamed n ts) = T.TNamed n (map replaceRigid ts)
+replaceRigid (T.TUnion ts) = T.TUnion (Set.map replaceRigid ts)
+replaceRigid (T.TyVar _ v) = T.TyVar T.Rigid v
+
data UnifyContext = UnifyContext SourceRange T.Type T.Type
unify :: SourceRange -> T.Type -> T.Type -> TM Subst
@@ -207,6 +215,10 @@ infer env expr = case expr of
S.Annot sr subex ty -> do
(theta1, subex') <- infer env subex
ty' <- convertType (envAliases env) sr ty
+ -- Make sure the type of the subexpression matches the type with rigid
+ -- variables, then make it instantiable variables instead for the rest
+ -- of the code.
+ void $ unify sr (T.exprType subex') (replaceRigid ty')
theta2 <- unify sr (T.exprType subex') ty'
return (theta2 <> theta1, theta2 >>! subex') -- TODO: quadratic complexity
diff --git a/typecheck/CC/Typecheck/Types.hs b/typecheck/CC/Typecheck/Types.hs
index 3f3c471..3009ca1 100644
--- a/typecheck/CC/Typecheck/Types.hs
+++ b/typecheck/CC/Typecheck/Types.hs
@@ -99,5 +99,4 @@ convertType' aliases extraVars sr origtype = do
convert mp (S.TTup ts) = T.TTup (map (convert mp) ts)
convert mp (S.TNamed n ts) = T.TNamed n (map (convert mp) ts)
convert mp (S.TUnion ts) = T.TUnion (Set.map (convert mp) ts)
- -- TODO: Should this be Rigid? I really don't know how this works.
- convert mp (S.TyVar n) = T.TyVar T.Rigid (mp Map.! n)
+ convert mp (S.TyVar n) = T.TyVar T.Instantiable (mp Map.! n)