summaryrefslogtreecommitdiff
path: root/src/AST/Count.hs
blob: baf132e3588f454fa7c9505a87f80797f1765f78 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GADTs #-}
module AST.Count where

import AST


-- | An occurrence count that saturates at 'Many': not at all, exactly once,
-- or more than once. The constructors are declared in increasing order, so
-- the derived 'Ord' instance ranks @Zero < One < Many@.
data Count = Zero | One | Many
  deriving (Eq, Ord, Show)

-- | Saturating addition: 'Zero' is the identity element, and combining two
-- non-zero counts always gives 'Many'.
instance Semigroup Count where
  x <> y = case (x, y) of
    (Zero, _) -> y
    (_, Zero) -> x
    _ -> Many

-- | The empty count is 'Zero'.
instance Monoid Count where
  mempty = Zero

-- | Occurrence information for one variable, as two 'Count's: the number of
-- syntactic occurrences ('_occLexical') and, presumably, how often it may
-- be evaluated when the code runs ('_occRuntime').
data Occ = Occ
  { _occLexical :: Count
  , _occRuntime :: Count
  }
  deriving (Eq)

-- | Combine the information of two pieces of code that are both present:
-- each component is added independently using 'Count'\'s '<>'.
instance Semigroup Occ where
  Occ lexA runA <> Occ lexB runB =
    Occ { _occLexical = lexA <> lexB, _occRuntime = runA <> runB }

-- | No occurrences at all, lexically or at runtime.
instance Monoid Occ where
  mempty = Occ { _occLexical = mempty, _occRuntime = mempty }

-- | Combine the information of two alternative branches, exactly one of
-- which is taken. Both branches are present in the source, so their
-- lexical counts are added. Only one of them runs, so the larger of the two
-- runtime counts is kept.
--
-- NOTE(review): there is no fixity declaration, so this operator has the
-- Haskell default fixity (@infixl 9@); callers should parenthesise it when
-- mixing it with '<>'.
(<||>) :: Occ -> Occ -> Occ
(<||>) (Occ lexA runA) (Occ lexB runB) =
  Occ (lexA <> lexB) (runA `max` runB)

-- | Adjust occurrence information for code that may be executed many times,
-- such as the body of a build or the function of a fold.
--
-- Lexical occurrences are unchanged. Any runtime use is promoted to 'Many'.
-- A runtime count of 'Zero' stays 'Zero': repeating code that never
-- evaluates the variable still never evaluates it. Promoting it to 'Many'
-- would also produce the contradictory @Occ Zero Many@ for variables that
-- do not occur in the body at all.
scaleMany :: Occ -> Occ
scaleMany (Occ l Zero) = Occ l Zero
scaleMany (Occ l _) = Occ l Many

-- | Count the occurrences of the variable @idx@ in an expression, both
-- lexically and at runtime (see 'Occ').
--
-- Variables are compared by their integer position ('idx2int'). The index
-- searched for and the one stored in an 'EVar' may have different types, so
-- they cannot be compared directly.
--
-- For some subterms the index is shifted before descending:
--
-- * let bodies, case branches, 'EBuild1' bodies and the continuation of
--   'EMBind' use 'IS';
-- * the function of 'EFold1' uses 'IS' twice;
-- * the body of 'EBuild' uses @wsinkN (vecLength es) \@> idx@.
--
-- Presumably these constructs bind new variables, and the shift keeps @idx@
-- pointing at the same variable.
--
-- Bodies that may run repeatedly ('EBuild1', 'EBuild', the function of
-- 'EFold1') are adjusted with 'scaleMany'. The two branches of 'ECase' are
-- combined with '<||>'.
occCount :: Idx env a -> Expr x env t -> Occ
occCount idx (EVar _ _ i)
  | idx2int i == idx2int idx = Occ One One
  | otherwise = mempty
occCount idx (ELet _ rhs body) = occCount idx rhs <> occCount (IS idx) body
occCount idx (EPair _ e1 e2) = occCount idx e1 <> occCount idx e2
occCount idx (EFst _ e) = occCount idx e
occCount idx (ESnd _ e) = occCount idx e
occCount _ (ENil _) = mempty
occCount idx (EInl _ _ e) = occCount idx e
occCount idx (EInr _ _ e) = occCount idx e
occCount idx (ECase _ scrut e1 e2) =
  occCount idx scrut <> (occCount (IS idx) e1 <||> occCount (IS idx) e2)
occCount idx (EBuild1 _ e1 e2) =
  occCount idx e1 <> scaleMany (occCount (IS idx) e2)
occCount idx (EBuild _ es e) =
  foldMap (occCount idx) es
    <> scaleMany (occCount (wsinkN (vecLength es) @> idx) e)
occCount idx (EFold1 _ e1 e2) =
  scaleMany (occCount (IS (IS idx)) e1) <> occCount idx e2
occCount _ EConst{} = mempty
occCount idx (EIdx1 _ e1 e2) = occCount idx e1 <> occCount idx e2
occCount idx (EIdx _ e es) = occCount idx e <> foldMap (occCount idx) es
occCount idx (EOp _ _ e) = occCount idx e
occCount idx (EMOne _ _ e) = occCount idx e
occCount idx (EMScope e) = occCount idx e
occCount idx (EMReturn _ e) = occCount idx e
occCount idx (EMBind m k) = occCount idx m <> occCount (IS idx) k
occCount _ EError{} = mempty