summaryrefslogtreecommitdiff
path: root/src/AST/Count.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/Count.hs')
-rw-r--r--src/AST/Count.hs55
1 files changed, 55 insertions, 0 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
new file mode 100644
index 0000000..baf132e
--- /dev/null
+++ b/src/AST/Count.hs
@@ -0,0 +1,55 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE GADTs #-}
+module AST.Count where
+
+import AST
+
+
+data Count = Zero | One | Many
+ deriving (Show, Eq, Ord)
+
+instance Semigroup Count where
+ Zero <> n = n
+ n <> Zero = n
+ _ <> _ = Many
+instance Monoid Count where
+ mempty = Zero
+
+data Occ = Occ { _occLexical :: Count
+ , _occRuntime :: Count }
+ deriving (Eq)
+instance Semigroup Occ where Occ a b <> Occ c d = Occ (a <> c) (b <> d)
+instance Monoid Occ where mempty = Occ mempty mempty
+
+-- | One of the two branches is taken
+(<||>) :: Occ -> Occ -> Occ
+Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)
+
+-- | This code is executed many times
+scaleMany :: Occ -> Occ
+scaleMany (Occ l _) = Occ l Many
+
+occCount :: Idx env a -> Expr x env t -> Occ
+occCount idx = \case
+ EVar _ _ i | idx2int i == idx2int idx -> Occ One One
+ | otherwise -> mempty
+ ELet _ rhs body -> occCount idx rhs <> occCount (IS idx) body
+ EPair _ a b -> occCount idx a <> occCount idx b
+ EFst _ e -> occCount idx e
+ ESnd _ e -> occCount idx e
+ ENil _ -> mempty
+ EInl _ _ e -> occCount idx e
+ EInr _ _ e -> occCount idx e
+ ECase _ e a b -> occCount idx e <> (occCount (IS idx) a <||> occCount (IS idx) b)
+ EBuild1 _ a b -> occCount idx a <> scaleMany (occCount (IS idx) b)
+ EBuild _ es e -> foldMap (occCount idx) es <> scaleMany (occCount (wsinkN (vecLength es) @> idx) e)
+ EFold1 _ a b -> scaleMany (occCount (IS (IS idx)) a) <> occCount idx b
+ EConst{} -> mempty
+ EIdx1 _ a b -> occCount idx a <> occCount idx b
+ EIdx _ e es -> occCount idx e <> foldMap (occCount idx) es
+ EOp _ _ e -> occCount idx e
+ EMOne _ _ e -> occCount idx e
+ EMScope e -> occCount idx e
+ EMReturn _ e -> occCount idx e
+ EMBind a b -> occCount idx a <> occCount (IS idx) b
+ EError{} -> mempty