summaryrefslogtreecommitdiff
path: root/src/AST/Count.hs
blob: 289c1fb200a888bf41be1ce2f9a8daa2eb8e6a4f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module AST.Count where

import Data.Functor.Const
import GHC.Generics (Generic, Generically(..))

import AST
import Data


-- | An abstract occurrence count: no occurrences, exactly one, or more than
-- one. Constructors are ordered so that the derived 'Ord' runs
-- @Zero < One < Many@.
data Count = Zero | One | Many
  deriving stock (Show, Eq, Ord)

-- | Addition of counts that saturates at 'Many'.
instance Semigroup Count where
  a <> b = case (a, b) of
    (Zero, _) -> b
    (_, Zero) -> a
    _         -> Many

-- | 'Zero' is the identity for counting.
instance Monoid Count where
  mempty = Zero

-- | Occurrence information for one variable: how often it appears in the
-- program text ('_occLexical'), and how often it is used when the code runs
-- ('_occRuntime').
data Occ = Occ
  { _occLexical :: Count
  , _occRuntime :: Count
  }
  deriving stock (Eq, Generic)
  -- Field-wise combination using the 'Count' monoid for both components.
  deriving (Semigroup, Monoid) via Generically Occ

-- | Combine the occurrences of two alternative branches, of which only one is
-- taken. Lexical occurrences from both branches add up; at runtime only one
-- branch executes, so the runtime count is the larger of the two.
(<||>) :: Occ -> Occ -> Occ
(<||>) (Occ lexL runL) (Occ lexR runR) =
  Occ { _occLexical = lexL <> lexR
      , _occRuntime = max runL runR
      }

-- | Adjust occurrences for code that is executed many times (the bodies of
-- build and fold constructs in 'occCountGeneral').
--
-- A variable that is used at runtime at all becomes used 'Many' times.
-- A variable that is not used at runtime stays unused: repeating code that
-- does not touch it does not make it used. The lexical count is unchanged.
scaleMany :: Occ -> Occ
-- Without this clause, a variable absent from a loop body would be reported
-- as used 'Many' times. It would also make @OccPush _ mempty@ disagree with
-- 'OccEnd' under 'scaleManyOccEnv'.
scaleMany (Occ l Zero) = Occ l Zero
scaleMany (Occ l _)    = Occ l Many

-- | Count the occurrences of a single variable in an expression.
--
-- The index is relative to the expression's own (top-level) environment.
-- The count is obtained by running 'occCountAll', which tracks binders by
-- popping environment entries, and then projecting out the requested
-- position. Indices are therefore shifted correctly under binders.
--
-- Comparing raw indices (e.g. via 'idx2int') at each variable occurrence would
-- be wrong: under a binder, the index of every outer variable is shifted by
-- one, and index 0 refers to the newly bound variable instead.
occCount :: Idx env a -> Expr x env t -> Occ
occCount idx = lookupOcc idx . occCountAll
  where
    -- Entries beyond an 'OccEnd' are implicitly 'mempty'.
    lookupOcc :: Idx env' b -> OccEnv env' -> Occ
    lookupOcc _ OccEnd = mempty
    lookupOcc IZ (OccPush _ o) = o
    lookupOcc (IS i) (OccPush rest _) = lookupOcc i rest


-- | Occurrence information for every variable of an environment.
--
-- The head of an 'OccPush' is the entry for the innermost variable (index
-- 'IZ'). An 'OccEnd' stands for "every remaining variable has 'mempty'
-- occurrences". It may terminate an environment of any length, not only the
-- empty one.
data OccEnv vars where
  OccEnd :: OccEnv vars  -- not necessarily top!
  OccPush :: OccEnv rest -> Occ -> OccEnv (v : rest)

-- | Entry-wise combination of occurrences. A missing tail ('OccEnd') behaves
-- as all-'mempty', so the other side is returned unchanged.
instance Semigroup (OccEnv env) where
  lhs <> rhs = case lhs of
    OccEnd -> rhs
    OccPush lrest lo -> case rhs of
      OccEnd -> lhs
      OccPush rrest ro -> OccPush (lrest <> rrest) (lo <> ro)

-- | The empty occurrence environment: no variable occurs.
instance Monoid (OccEnv env) where
  mempty = OccEnd

-- | An environment where only the variable at the given index has the given
-- occurrences. Every other variable has 'mempty'.
onehotOccEnv :: Idx env t -> Occ -> OccEnv env
onehotOccEnv idx occ = case idx of
  IZ      -> OccPush OccEnd occ
  IS rest -> OccPush (onehotOccEnv rest occ) mempty

-- | Entry-wise '<||>': combine the environments of two alternative branches.
-- A missing tail ('OccEnd') on either side yields the other side.
(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env
lhs <||>! rhs = case (lhs, rhs) of
  (OccEnd, _) -> rhs
  (_, OccEnd) -> lhs
  (OccPush lrest lo, OccPush rrest ro) ->
    OccPush (lrest <||>! rrest) (lo <||> ro)

-- | Apply 'scaleMany' to every explicit entry. An 'OccEnd' tail is kept as
-- 'OccEnd'.
scaleManyOccEnv :: OccEnv env -> OccEnv env
scaleManyOccEnv = \case
  OccEnd         -> OccEnd
  OccPush rest o -> OccPush (scaleManyOccEnv rest) (scaleMany o)

-- | Drop the entry of the innermost variable. 'occCountAll' uses this when
-- leaving the scope of a binder.
occEnvPop :: OccEnv (t : env) -> OccEnv env
occEnvPop = \case
  OccPush rest _ -> rest
  OccEnd         -> OccEnd

-- | Occurrence information for every variable in the environment of an
-- expression, computed in a single traversal via 'occCountGeneral'.
occCountAll :: Expr x env t -> OccEnv env
occCountAll expr =
  occCountGeneral onehotOccEnv occEnvPop popMany (<||>!) scaleManyOccEnv expr
  where
    -- Drop the entries of @n@ binders at once, one 'occEnvPop' at a time.
    -- An 'OccEnd' needs no popping.
    popMany :: SNat n -> OccEnv (ConsN n TIx env') -> OccEnv env'
    popMany _ OccEnd = OccEnd
    popMany SZ e = e
    popMany (SS n) e = popMany n (occEnvPop e)

-- | Generic occurrence analysis, parameterised over an environment-indexed
-- result type @r@. The caller supplies:
--
-- * how a single variable occurrence is represented (one-hot),
-- * how to leave the scope of one binder, or of @n@ binders at once,
-- * how to combine the results of two alternative branches,
-- * how to account for code that is executed many times.
--
-- Subterms that are evaluated in sequence are combined with the 'Monoid' on
-- @r env@.
occCountGeneral :: forall r env t x.
                   (forall env'. Monoid (r env'))
                => (forall env' a. Idx env' a -> Occ -> r env')  -- ^ one-hot
                -> (forall env' a. r (a : env') -> r env')  -- ^ unpush
                -> (forall env' n. SNat n -> r (ConsN n TIx env') -> r env')  -- ^ unpushN
                -> (forall env'. r env' -> r env' -> r env')  -- ^ alternation
                -> (forall env'. r env' -> r env')  -- ^ scale-many
                -> Expr x env t -> r env
occCountGeneral varOcc pop1 popN branches repeated = walk
  where
    walk :: Monoid (r env') => Expr x env' t' -> r env'
    -- Leaves
    walk (EVar _ _ i) = varOcc i (Occ One One)
    walk (ENil _) = mempty
    walk EConst{} = mempty
    walk EError{} = mempty
    -- Subterms evaluated once, in sequence
    walk (EPair _ a b) = walk a <> walk b
    walk (EFst _ e) = walk e
    walk (ESnd _ e) = walk e
    walk (EInl _ _ e) = walk e
    walk (EInr _ _ e) = walk e
    walk (EUnit _ e) = walk e
    walk (EIdx0 _ e) = walk e
    walk (EIdx1 _ a b) = walk a <> walk b
    walk (EIdx _ e es) = walk e <> foldMap walk es
    walk (EOp _ _ e) = walk e
    walk (EAccum1 a b e) = walk a <> walk b <> walk e
    -- Forms that bring a variable into scope
    walk (ELet _ rhs body) = walk rhs <> under1 body
    walk (EWith a b) = walk a <> under1 b
    walk (ECase _ e brL brR) = walk e <> branches (under1 brL) (under1 brR)
    -- Forms whose body is executed many times
    walk (EBuild1 _ size body) = walk size <> repeated (under1 body)
    walk (EBuild _ es body) =
      foldMap walk es <> repeated (popN (vecLength es) (walk body))
    walk (EFold1 _ f arr) = repeated (pop1 (under1 f)) <> walk arr

    -- Analyse a term under one extra binder and leave that binder's scope.
    under1 :: Expr x (a : env') t' -> r env'
    under1 body = pop1 (walk body)