From 897fefce372f00d3e904e83eb92c83e3e653b5be Mon Sep 17 00:00:00 2001
From: Tom Smeding <t.j.smeding@uu.nl>
Date: Wed, 20 Sep 2023 15:53:59 +0200
Subject: Examples with conditionals

---
 src/AST/Count.hs  | 55 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 src/AST/Pretty.hs | 38 ++++++++++++++++++++++++--------------
 2 files changed, 79 insertions(+), 14 deletions(-)
 create mode 100644 src/AST/Count.hs

(limited to 'src/AST')

diff --git a/src/AST/Count.hs b/src/AST/Count.hs
new file mode 100644
index 0000000..baf132e
--- /dev/null
+++ b/src/AST/Count.hs
@@ -0,0 +1,55 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE GADTs #-}
+module AST.Count where
+
+import AST
+
+
+data Count = Zero | One | Many
+  deriving (Show, Eq, Ord)
+
+instance Semigroup Count where
+  Zero <> n = n
+  n <> Zero = n
+  _ <> _ = Many
+instance Monoid Count where
+  mempty = Zero
+
+data Occ = Occ { _occLexical :: Count
+               , _occRuntime :: Count }
+  deriving (Eq)
+instance Semigroup Occ where Occ a b <> Occ c d = Occ (a <> c) (b <> d)
+instance Monoid Occ    where mempty = Occ mempty mempty
+
+-- | One of the two branches is taken
+(<||>) :: Occ -> Occ -> Occ
+Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)
+
+-- | This code is executed many times
+scaleMany :: Occ -> Occ
+scaleMany (Occ l _) = Occ l Many
+
+occCount :: Idx env a -> Expr x env t -> Occ
+occCount idx = \case
+  EVar _ _ i | idx2int i == idx2int idx -> Occ One One
+             | otherwise -> mempty
+  ELet _ rhs body -> occCount idx rhs <> occCount (IS idx) body
+  EPair _ a b -> occCount idx a <> occCount idx b
+  EFst _ e -> occCount idx e
+  ESnd _ e -> occCount idx e
+  ENil _ -> mempty
+  EInl _ _ e -> occCount idx e
+  EInr _ _ e -> occCount idx e
+  ECase _ e a b -> occCount idx e <> (occCount (IS idx) a <||> occCount (IS idx) b)
+  EBuild1 _ a b -> occCount idx a <> scaleMany (occCount (IS idx) b)
+  EBuild _ es e -> foldMap (occCount idx) es <> scaleMany (occCount (wsinkN (vecLength es) @> idx) e)
+  EFold1 _ a b -> scaleMany (occCount (IS (IS idx)) a) <> occCount idx b
+  EConst{} -> mempty
+  EIdx1 _ a b -> occCount idx a <> occCount idx b
+  EIdx _ e es -> occCount idx e <> foldMap (occCount idx) es
+  EOp _ _ e -> occCount idx e
+  EMOne _ _ e -> occCount idx e
+  EMScope e -> occCount idx e
+  EMReturn _ e -> occCount idx e
+  EMBind a b -> occCount idx a <> occCount (IS idx) b
+  EError{} -> mempty
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index e793ce1..289294d 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -14,6 +14,7 @@ import Data.Foldable (toList)
 import Data.Functor.Const
 
 import AST
+import AST.Count
 
 
 data Val f env where
@@ -42,6 +43,12 @@ genId = M (\i -> (i, i + 1))
 genName :: M String
 genName = ('x' :) . show <$> genId
 
+genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
+genNameIfUsedIn ty idx ex
+  | occCount idx ex == mempty = case ty of STNil -> return "()"
+                                           _ -> return "_"
+  | otherwise                 = genName
+
 ppExpr :: SList STy env -> Expr x env t -> String
 ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) ""
   where
@@ -59,7 +66,7 @@ ppExpr' d val = \case
   etop@ELet{} -> do
     let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS)
         collect val' (ELet _ rhs body) = do
-          name <- genName
+          name <- genNameIfUsedIn (typeOf rhs) IZ body
           (binds, core) <- collect (VPush (Const name) val') body
           rhs' <- ppExpr' 0 val' rhs
           return ((name, rhs') : binds, core)
@@ -101,9 +108,10 @@ ppExpr' d val = \case
 
   ECase _ e a b -> do
     e' <- ppExpr' 0 val e
-    name1 <- genName
+    let STEither t1 t2 = typeOf e
+    name1 <- genNameIfUsedIn t1 IZ a
     a' <- ppExpr' 0 (VPush (Const name1) val) a
-    name2 <- genName
+    name2 <- genNameIfUsedIn t2 IZ b
     b' <- ppExpr' 0 (VPush (Const name2) val) b
     return $ showParen (d > 0) $
       showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a'
@@ -111,14 +119,14 @@ ppExpr' d val = \case
 
   EBuild1 _ a b -> do
     a' <- ppExpr' 11 val a
-    name <- genName
+    name <- genNameIfUsedIn (STScal STI64) IZ b
     b' <- ppExpr' 0 (VPush (Const name) val) b
     return $ showParen (d > 10) $
       showString "build1 " . a' . showString (" (\\" ++ name ++ " -> ") . b' . showString ")"
 
   EBuild _ es e -> do
     es' <- mapM (ppExpr' 0 val) es
-    names <- mapM (const genName) es
+    names <- mapM (const genName) es  -- TODO generate underscores
     e' <- ppExpr' 0 (vpushN names val) e
     return $ showParen (d > 10) $
       showString "build ["
@@ -128,8 +136,8 @@ ppExpr' d val = \case
       . showString ("] -> ") . e' . showString ")"
 
   EFold1 _ a b -> do
-    name1 <- genName
-    name2 <- genName
+    name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
+    name2 <- genNameIfUsedIn (typeOf a) IZ a
     a' <- ppExpr' 0 (VPush (Const name2) (VPush (Const name1) val)) a
     b' <- ppExpr' 11 val b
     return $ showParen (d > 10) $
@@ -185,10 +193,11 @@ ppExpr' d val = \case
     e' <- ppExpr' 11 val e
     return $ showParen (d > 10) $ showString ("return ") . e'
 
-  etop@(EMBind _ EMBind{}) -> do
+  etop@(EMBind _ _) -> do
     let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS)
         collect val' (EMBind lhs cont) = do
-          name <- genName
+          let STEVM _ t = typeOf lhs
+          name <- genNameIfUsedIn t IZ cont
           (binds, core) <- collect (VPush (Const name) val') cont
           lhs' <- ppExpr' 0 val' lhs
           return ((name, lhs') : binds, core)
@@ -201,11 +210,11 @@ ppExpr' d val = \case
                         (map (\(name, rhs) -> showString (name ++ " <- ") . rhs) binds))
       . showString " ; " . core . showString " }"
 
-  EMBind a b -> do
-    a' <- ppExpr' 0 val a
-    name <- genName
-    b' <- ppExpr' 0 (VPush (Const name) val) b
-    return $ showParen (d > 10) $ a' . showString (" >>= \\" ++ name ++ " -> ") . b'
+  -- EMBind a b -> do
+  --   a' <- ppExpr' 0 val a
+  --   name <- genNameIfUsedIn IZ b
+  --   b' <- ppExpr' 0 (VPush (Const name) val) b
+  --   return $ showParen (d > 10) $ a' . showString (" >>= \\" ++ name ++ " -> ") . b'
 
   EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)
 
@@ -220,3 +229,4 @@ operator OLt{} = (Infix, "<")
 operator OLe{} = (Infix, "<=")
 operator OEq{} = (Infix, "==")
 operator ONot = (Prefix, "not")
+operator OIf = (Prefix, "ifB")
-- 
cgit v1.2.3-70-g09d2