blob: 7e70a7dceb7fb9175c1d14cc7302012bac7bfbf7 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module AST.Count where
import Data.Functor.Const
import GHC.Generics (Generic, Generically(..))
import AST
import Data
-- | A saturating occurrence counter: no occurrence, exactly one, or more
-- than one.
--
-- The derived 'Ord' follows constructor order: @Zero < One < Many@.
data Count = Zero | One | Many
  deriving (Eq, Ord, Show)

-- | Saturating addition: 'Zero' is neutral on either side, and combining
-- two non-'Zero' counts always yields 'Many'.
instance Semigroup Count where
  a <> b = case (a, b) of
    (Zero, _) -> b
    (_, Zero) -> a
    _         -> Many

-- | The neutral count is 'Zero'.
instance Monoid Count where
  mempty = Zero
-- | Occurrence information for a single variable.
--
-- The lexical counter tracks occurrences in the program text; the runtime
-- counter tracks how often those occurrences may be evaluated (alternatives
-- take the maximum, repeated code saturates to 'Many').
--
-- '<>' combines both counters field-wise with the 'Count' monoid, as
-- derived through 'Generically'.
data Occ = Occ
  { _occLexical :: Count
  , _occRuntime :: Count
  }
  deriving stock (Generic, Eq)
  deriving (Semigroup, Monoid) via Generically Occ
-- | Combine the occurrences of two alternatives of which exactly one is
-- taken (e.g. the two branches of a case).
--
-- Both alternatives are present in the program text, so the lexical counts
-- add up.  Only one of them runs, so the runtime count is the larger of the
-- two rather than their sum.
(<||>) :: Occ -> Occ -> Occ
a <||> b = case (a, b) of
  (Occ lexA runA, Occ lexB runB) -> Occ (lexA <> lexB) (max runA runB)
-- | Account for code that may be executed many times (e.g. the body of a
-- build or fold).
--
-- The lexical count is unchanged: the program text is the same.  Any
-- runtime count other than 'Zero' saturates to 'Many'.  A 'Zero' runtime
-- count stays 'Zero', because running code that never evaluates the variable
-- any number of times still never evaluates it.  Preserving 'Zero' keeps
-- an explicit 'mempty' entry equivalent to "no entry" (environment
-- scaling leaves 'OccEnd' untouched) and avoids the contradictory state of
-- lexical 'Zero' with runtime 'Many'.
scaleMany :: Occ -> Occ
scaleMany (Occ l Zero) = Occ l Zero
scaleMany (Occ l _) = Occ l Many
-- | Count the occurrences of a single variable in an expression.
--
-- The variable's entry is projected out of 'occCountAll'.  That analysis
-- drops each binder's entry when leaving its scope, so a reference to the
-- variable from underneath binders (where its de Bruijn index is shifted)
-- is attributed correctly.  Comparing raw index numbers against every
-- 'EVar' without that shift would miss such references and would instead
-- count references to inner binders that happen to share the number.
--
-- A variable that does not occur gets 'mempty' (zero lexical and zero
-- runtime occurrences).
occCount :: Idx env a -> Expr x env t -> Occ
occCount idx = lookupOcc idx . occCountAll
  where
    -- Read one variable's entry; 'OccEnd' stands for all-'mempty' entries.
    lookupOcc :: Idx env' a' -> OccEnv env' -> Occ
    lookupOcc _ OccEnd = mempty
    lookupOcc IZ (OccPush _ o) = o
    lookupOcc (IS i) (OccPush e _) = lookupOcc i e
-- | Occurrence information for every variable of an environment, innermost
-- variable (index 'IZ') first.
--
-- 'OccEnd' can terminate the list at any @env@, not necessarily the
-- top-level environment; it stands for "all remaining variables have
-- 'mempty' occurrences".
data OccEnv env where
  OccEnd :: OccEnv env
  OccPush :: OccEnv env -> Occ -> OccEnv (t : env)

-- | Pointwise combination of sequentially composed code; 'OccEnd' is
-- neutral on either side.
instance Semigroup (OccEnv env) where
  lhs <> rhs = case (lhs, rhs) of
    (OccEnd, _) -> rhs
    (_, OccEnd) -> lhs
    (OccPush restL occL, OccPush restR occR) ->
      OccPush (restL <> restR) (occL <> occR)

-- | The empty environment: nothing occurs.
instance Monoid (OccEnv env) where
  mempty = OccEnd
-- | An environment in which only the given variable carries the given
-- occurrence information.
--
-- Variables inside it (smaller indices) get explicit 'mempty' entries, and
-- variables outside it are covered by the trailing 'OccEnd'.
onehotOccEnv :: Idx env t -> Occ -> OccEnv env
onehotOccEnv idx occ = case idx of
  IZ -> OccPush OccEnd occ
  IS rest -> OccPush (onehotOccEnv rest occ) mempty
-- | Pointwise '<||>' on environments: combine the occurrence information of
-- two alternatives of which exactly one is taken.  'OccEnd' is neutral on
-- either side.
(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env
lhs <||>! rhs = case (lhs, rhs) of
  (OccEnd, _) -> rhs
  (_, OccEnd) -> lhs
  (OccPush restL occL, OccPush restR occR) ->
    OccPush (restL <||>! restR) (occL <||> occR)
-- | Apply 'scaleMany' to every explicit entry of an environment, for code
-- that may run many times.  A trailing 'OccEnd' is kept as is.
scaleManyOccEnv :: OccEnv env -> OccEnv env
scaleManyOccEnv env = case env of
  OccEnd -> OccEnd
  OccPush rest occ -> OccPush (scaleManyOccEnv rest) (scaleMany occ)
-- | Occurrence information for all variables of the expression's
-- environment, computed in one traversal with 'OccEnv' as the result
-- representation of 'occCountGeneral'.
occCountAll :: Expr x env t -> OccEnv env
occCountAll =
  occCountGeneral onehotOccEnv dropTop dropTopN (<||>!) scaleManyOccEnv
  where
    -- Leave the scope of one binder: discard its (innermost) entry.
    dropTop :: OccEnv (t : env) -> OccEnv env
    dropTop env = case env of
      OccEnd -> OccEnd
      OccPush rest _ -> rest

    -- Leave the scope of @n@ index binders at once, discarding their
    -- entries one by one.
    dropTopN :: SNat n -> OccEnv (ConsN n TIx env) -> OccEnv env
    dropTopN _ OccEnd = OccEnd
    dropTopN SZ env = env
    dropTopN (SS k) (OccPush rest _) = dropTopN k rest
-- | Occurrence analysis, generic in the result representation @r@.
--
-- The traversal is fixed; the caller supplies how results are built and
-- combined:
--
--   * one-hot: the result for a single variable reference;
--   * unpush: leave the scope of one binder (drop the innermost variable);
--   * unpushN: leave the scope of @n@ binders at once;
--   * alternation: combine two alternatives of which one is taken;
--   * scale-many: account for code that may run many times.
--
-- Sequential composition of subexpressions uses the 'Monoid' instance of
-- @r@; the operand order of '<>' follows the subexpression order.
occCountGeneral :: forall r env t x.
                   (forall env'. Monoid (r env'))
                => (forall env' a. Idx env' a -> Occ -> r env')  -- ^ one-hot
                -> (forall env' a. r (a : env') -> r env')  -- ^ unpush
                -> (forall env' n. SNat n -> r (ConsN n TIx env') -> r env')  -- ^ unpushN
                -> (forall env'. r env' -> r env' -> r env')  -- ^ alternation
                -> (forall env'. r env' -> r env')  -- ^ scale-many
                -> Expr x env t -> r env
occCountGeneral onehot unpush unpushN alter many = countIn
  where
    countIn :: Monoid (r env') => Expr x env' t' -> r env'
    countIn expr = case expr of
      -- Leaves.  A variable reference is one lexical occurrence that is
      -- evaluated once; the other leaves mention no variables.
      EVar _ _ i -> onehot i (Occ One One)
      ENil _ -> mempty
      EConst{} -> mempty
      EError{} -> mempty

      -- No binders: each subexpression is counted once, left to right.
      EFst _ e -> countIn e
      ESnd _ e -> countIn e
      EInl _ _ e -> countIn e
      EInr _ _ e -> countIn e
      EIdx0 _ e -> countIn e
      EOp _ _ e -> countIn e
      EPair _ a b -> countIn a <> countIn b
      EIdx1 _ a b -> countIn a <> countIn b
      EIdx _ e es -> countIn e <> foldMap countIn es
      EAccum a b e -> countIn a <> countIn b <> countIn e

      -- One binder over the second subexpression; its entry is dropped
      -- again with 'unpush'.
      ELet _ rhs body -> countIn rhs <> unpush (countIn body)
      EWith a b -> countIn a <> unpush (countIn b)

      -- The scrutinee is counted, then exactly one of the two branches
      -- (each under one binder) is taken.
      ECase _ e a b ->
        countIn e <> (unpush (countIn a) `alter` unpush (countIn b))

      -- Bodies that may run many times are scaled with 'many'.  EBuild's
      -- body binds as many indices as there are elements in @es@; EFold1's
      -- function binds two variables.
      EBuild1 _ a b -> countIn a <> many (unpush (countIn b))
      EBuild _ es e ->
        foldMap countIn es <> many (unpushN (vecLength es) (countIn e))
      EFold1 _ a b -> many (unpush (unpush (countIn a))) <> countIn b
|