summaryrefslogtreecommitdiff
path: root/src/AST
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST')
-rw-r--r--src/AST/Count.hs55
-rw-r--r--src/AST/Pretty.hs38
2 files changed, 79 insertions, 14 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
new file mode 100644
index 0000000..baf132e
--- /dev/null
+++ b/src/AST/Count.hs
@@ -0,0 +1,55 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE GADTs #-}
+module AST.Count where
+
+import AST
+
+
+data Count = Zero | One | Many
+ deriving (Show, Eq, Ord)
+
+instance Semigroup Count where
+ Zero <> n = n
+ n <> Zero = n
+ _ <> _ = Many
+instance Monoid Count where
+ mempty = Zero
+
+data Occ = Occ { _occLexical :: Count
+ , _occRuntime :: Count }
+ deriving (Eq)
+instance Semigroup Occ where Occ a b <> Occ c d = Occ (a <> c) (b <> d)
+instance Monoid Occ where mempty = Occ mempty mempty
+
+-- | One of the two branches is taken
+(<||>) :: Occ -> Occ -> Occ
+Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)
+
+-- | This code is executed many times
+scaleMany :: Occ -> Occ
+scaleMany (Occ l _) = Occ l Many
+
+occCount :: Idx env a -> Expr x env t -> Occ
+occCount idx = \case
+ EVar _ _ i | idx2int i == idx2int idx -> Occ One One
+ | otherwise -> mempty
+ ELet _ rhs body -> occCount idx rhs <> occCount (IS idx) body
+ EPair _ a b -> occCount idx a <> occCount idx b
+ EFst _ e -> occCount idx e
+ ESnd _ e -> occCount idx e
+ ENil _ -> mempty
+ EInl _ _ e -> occCount idx e
+ EInr _ _ e -> occCount idx e
+ ECase _ e a b -> occCount idx e <> (occCount (IS idx) a <||> occCount (IS idx) b)
+ EBuild1 _ a b -> occCount idx a <> scaleMany (occCount (IS idx) b)
+ EBuild _ es e -> foldMap (occCount idx) es <> scaleMany (occCount (wsinkN (vecLength es) @> idx) e)
+ EFold1 _ a b -> scaleMany (occCount (IS (IS idx)) a) <> occCount idx b
+ EConst{} -> mempty
+ EIdx1 _ a b -> occCount idx a <> occCount idx b
+ EIdx _ e es -> occCount idx e <> foldMap (occCount idx) es
+ EOp _ _ e -> occCount idx e
+ EMOne _ _ e -> occCount idx e
+ EMScope e -> occCount idx e
+ EMReturn _ e -> occCount idx e
+ EMBind a b -> occCount idx a <> occCount (IS idx) b
+ EError{} -> mempty
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index e793ce1..289294d 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -14,6 +14,7 @@ import Data.Foldable (toList)
import Data.Functor.Const
import AST
+import AST.Count
data Val f env where
@@ -42,6 +43,12 @@ genId = M (\i -> (i, i + 1))
genName :: M String
genName = ('x' :) . show <$> genId
+genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
+genNameIfUsedIn ty idx ex
+ | occCount idx ex == mempty = case ty of STNil -> return "()"
+ _ -> return "_"
+ | otherwise = genName
+
ppExpr :: SList STy env -> Expr x env t -> String
ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) ""
where
@@ -59,7 +66,7 @@ ppExpr' d val = \case
etop@ELet{} -> do
let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS)
collect val' (ELet _ rhs body) = do
- name <- genName
+ name <- genNameIfUsedIn (typeOf rhs) IZ body
(binds, core) <- collect (VPush (Const name) val') body
rhs' <- ppExpr' 0 val' rhs
return ((name, rhs') : binds, core)
@@ -101,9 +108,10 @@ ppExpr' d val = \case
ECase _ e a b -> do
e' <- ppExpr' 0 val e
- name1 <- genName
+ let STEither t1 t2 = typeOf e
+ name1 <- genNameIfUsedIn t1 IZ a
a' <- ppExpr' 0 (VPush (Const name1) val) a
- name2 <- genName
+ name2 <- genNameIfUsedIn t2 IZ b
b' <- ppExpr' 0 (VPush (Const name2) val) b
return $ showParen (d > 0) $
showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a'
@@ -111,14 +119,14 @@ ppExpr' d val = \case
EBuild1 _ a b -> do
a' <- ppExpr' 11 val a
- name <- genName
+ name <- genNameIfUsedIn (STScal STI64) IZ b
b' <- ppExpr' 0 (VPush (Const name) val) b
return $ showParen (d > 10) $
showString "build1 " . a' . showString (" (\\" ++ name ++ " -> ") . b' . showString ")"
EBuild _ es e -> do
es' <- mapM (ppExpr' 0 val) es
- names <- mapM (const genName) es
+ names <- mapM (const genName) es -- TODO generate underscores
e' <- ppExpr' 0 (vpushN names val) e
return $ showParen (d > 10) $
showString "build ["
@@ -128,8 +136,8 @@ ppExpr' d val = \case
. showString ("] -> ") . e' . showString ")"
EFold1 _ a b -> do
- name1 <- genName
- name2 <- genName
+ name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
+ name2 <- genNameIfUsedIn (typeOf a) IZ a
a' <- ppExpr' 0 (VPush (Const name2) (VPush (Const name1) val)) a
b' <- ppExpr' 11 val b
return $ showParen (d > 10) $
@@ -185,10 +193,11 @@ ppExpr' d val = \case
e' <- ppExpr' 11 val e
return $ showParen (d > 10) $ showString ("return ") . e'
- etop@(EMBind _ EMBind{}) -> do
+ etop@(EMBind _ _) -> do
let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS)
collect val' (EMBind lhs cont) = do
- name <- genName
+ let STEVM _ t = typeOf lhs
+ name <- genNameIfUsedIn t IZ cont
(binds, core) <- collect (VPush (Const name) val') cont
lhs' <- ppExpr' 0 val' lhs
return ((name, lhs') : binds, core)
@@ -201,11 +210,11 @@ ppExpr' d val = \case
(map (\(name, rhs) -> showString (name ++ " <- ") . rhs) binds))
. showString " ; " . core . showString " }"
- EMBind a b -> do
- a' <- ppExpr' 0 val a
- name <- genName
- b' <- ppExpr' 0 (VPush (Const name) val) b
- return $ showParen (d > 10) $ a' . showString (" >>= \\" ++ name ++ " -> ") . b'
+ -- EMBind a b -> do
+ -- a' <- ppExpr' 0 val a
+ -- name <- genNameIfUsedIn IZ b
+ -- b' <- ppExpr' 0 (VPush (Const name) val) b
+ -- return $ showParen (d > 10) $ a' . showString (" >>= \\" ++ name ++ " -> ") . b'
EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)
@@ -220,3 +229,4 @@ operator OLt{} = (Infix, "<")
operator OLe{} = (Infix, "<=")
operator OEq{} = (Infix, "==")
operator ONot = (Prefix, "not")
+operator OIf = (Prefix, "ifB")