summaryrefslogtreecommitdiff
path: root/src/AST/Count.hs
blob: a4ff9f2840e4500c5caf7132fd16e2cf0aca5b20 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
module AST.Count where

import Data.Functor.Const
import GHC.Generics (Generic, Generically(..))

import AST
import AST.Env
import Data


-- | How often something occurs, saturating at two: not at all, exactly once,
-- or more than once. The derived 'Ord' instance orders the constructors as
-- @Zero < One < Many@.
data Count = Zero | One | Many
  deriving (Show, Eq, Ord)

-- | Addition of counts that saturates at 'Many'; 'Zero' is neutral.
instance Semigroup Count where
  a <> b = case (a, b) of
    (Zero, _) -> b
    (_, Zero) -> a
    _         -> Many

-- | No occurrences at all.
instance Monoid Count where
  mempty = Zero

-- | Occurrence information for a single variable.
data Occ = Occ
  { _occLexical :: Count  -- ^ how often the variable appears in the program text
  , _occRuntime :: Count  -- ^ estimate of how often it is used when the program runs
  }
  deriving (Eq, Generic)
  -- Both fields are combined pointwise with the 'Count' monoid.
  deriving (Semigroup, Monoid) via (Generically Occ)

-- | Combine the occurrences of two alternatives of which only one is taken.
-- Lexical occurrences are summed. The run-time count is the larger of the
-- two, because only one branch actually executes.
(<||>) :: Occ -> Occ -> Occ
lhs <||> rhs = case (lhs, rhs) of
  (Occ lexL runL, Occ lexR runR) -> Occ (lexL <> lexR) (runL `max` runR)

-- | Adjust occurrence information for code that is executed many times.
--
-- Lexical occurrences are unaffected. A variable that is used at run time at
-- all is now used 'Many' times. A variable that is not used at run time
-- ('Zero') stays unused, because repeating zero uses still gives zero uses.
-- Without this, every entry of an occurrence environment under a loop
-- (including padding entries that are 'mempty') would be reported as used,
-- and unused variables could never be deleted.
scaleMany :: Occ -> Occ
scaleMany (Occ l Zero) = Occ l Zero
scaleMany (Occ l _) = Occ l Many

-- | Count the occurrences of the variable @idx@ in an expression.
--
-- The de Bruijn index that denotes @idx@ grows by one under every binder.
-- For each subexpression we therefore build a function from "the index of
-- @idx@ at this point" to the occurrences found there. Going under @n@
-- binders shifts that index by @n@. At the root, the function is applied to
-- the index of @idx@ itself. The previous version compared raw indices
-- without this shift, so under a binder it counted the wrong variable.
occCount :: Idx env a -> Expr x env t -> Occ
occCount idx expr =
  getConst
    (occCountGeneral
       -- A variable occurrence counts iff it is the target at this depth.
       (\i o -> Const (\target -> if idx2int i == target then o else mempty))
       -- Under one more binder the target's index is one larger.
       (\(Const f) -> Const (f . (+ 1)))
       (\n (Const f) -> Const (f . (+ snatToNum n)))
       (\(Const f) (Const g) -> Const (\target -> f target <||> g target))
       (\(Const f) -> Const (scaleMany . f))
       expr)
    (idx2int idx)
  where
    -- Value of a singleton natural.
    snatToNum :: Num i => SNat n -> i
    snatToNum SZ = 0
    snatToNum (SS m) = 1 + snatToNum m


-- | Occurrence information for each variable of @env@, innermost (index 0)
-- first. 'OccEnd' may cut the list short. Every variable beyond it has no
-- recorded occurrences.
data OccEnv env where
  -- | No (further) occurrences; not necessarily the end of @env@.
  OccEnd  :: OccEnv env
  -- | Information for the innermost variable, preceded by the rest.
  OccPush :: OccEnv env -> Occ -> OccEnv (t : env)

-- | Pointwise combination of occurrence environments, using the 'Occ'
-- semigroup per variable. A missing tail ('OccEnd') is neutral.
instance Semigroup (OccEnv env) where
  lhs <> rhs = case lhs of
    OccEnd -> rhs
    OccPush ls lo -> case rhs of
      OccEnd -> lhs
      OccPush rs ro -> OccPush (ls <> rs) (lo <> ro)

-- | The environment in which nothing occurs.
instance Monoid (OccEnv env) where
  mempty = OccEnd

-- | Occurrence environment with the given information for one variable and
-- 'mempty' for every variable inside it. Variables further out are covered
-- by the trailing 'OccEnd'.
onehotOccEnv :: Idx env t -> Occ -> OccEnv env
onehotOccEnv i occ = case i of
  IZ   -> OccPush OccEnd occ
  IS j -> OccPush (onehotOccEnv j occ) mempty

-- | Pointwise '<||>' of two occurrence environments, for two alternatives of
-- which only one is taken. 'OccEnd' is neutral.
(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env
l <||>! r = case l of
  OccEnd -> r
  OccPush ls lo -> case r of
    OccEnd -> l
    OccPush rs ro -> OccPush (ls <||>! rs) (lo <||> ro)

-- | Apply 'scaleMany' to every recorded entry of an occurrence environment,
-- for code that is executed many times.
scaleManyOccEnv :: OccEnv env -> OccEnv env
scaleManyOccEnv = \case
  OccEnd         -> OccEnd
  OccPush rest o -> OccPush (scaleManyOccEnv rest) (scaleMany o)

-- | Drop the entry of the innermost variable.
occEnvPop :: OccEnv (t : env) -> OccEnv env
occEnvPop = \case
  OccEnd         -> OccEnd
  OccPush rest _ -> rest

-- | Occurrence information for every variable in the environment of an
-- expression, computed in a single traversal.
occCountAll :: Expr x env t -> OccEnv env
occCountAll =
  occCountGeneral onehotOccEnv occEnvPop dropBinders (<||>!) scaleManyOccEnv
  where
    -- Drop the entries of the @n@ innermost ('TIx'-typed) binders.
    dropBinders :: SNat n -> OccEnv (ConsN n TIx env') -> OccEnv env'
    dropBinders n occs = case occs of
      OccEnd -> OccEnd
      OccPush rest _ -> case n of
        SZ   -> occs
        SS m -> dropBinders m rest

-- | Generic occurrence analysis of an expression. It is parameterised over
-- an environment-indexed result type @r@ and the operations that assemble
-- results:
--
-- * one-hot: the result for a single variable occurrence;
-- * unpush: remove the innermost binder from the result of a subexpression
--   that lives under one extra binder;
-- * unpushN: the same for @n@ 'TIx' binders at once;
-- * alternation: combine two branches of which only one is executed;
-- * scale-many: adjust a result for code wrapped in a build or fold.
--
-- Results of subexpressions that are evaluated in sequence are combined with
-- '<>'.
occCountGeneral :: forall r env t x.
                   (forall env'. Monoid (r env'))
                => (forall env' a. Idx env' a -> Occ -> r env')  -- ^ one-hot
                -> (forall env' a. r (a : env') -> r env')  -- ^ unpush
                -> (forall env' n. SNat n -> r (ConsN n TIx env') -> r env')  -- ^ unpushN
                -> (forall env'. r env' -> r env' -> r env')  -- ^ alternation
                -> (forall env'. r env' -> r env')  -- ^ scale-many
                -> Expr x env t -> r env
occCountGeneral onehot unpush unpushN alter many = analyse
  where
    analyse :: Monoid (r env') => Expr x env' t' -> r env'
    analyse expr = case expr of
      EVar _ _ i -> onehot i (Occ One One)
      ELet _ rhs body -> analyse rhs <> under1 body
      EPair _ a b -> analyse a <> analyse b
      EFst _ e -> analyse e
      ESnd _ e -> analyse e
      ENil _ -> mempty
      EInl _ _ e -> analyse e
      EInr _ _ e -> analyse e
      -- The two branches are alternatives.
      ECase _ e a b -> analyse e <> alter (under1 a) (under1 b)
      -- The body under the build's binder(s) is treated as running many times.
      EBuild1 _ a b -> analyse a <> many (under1 b)
      EBuild _ n a b -> analyse a <> many (unpushN n (analyse b))
      -- The fold's function binds two variables and is treated as running
      -- many times.
      EFold1 _ a b -> many (unpush (under1 a)) <> analyse b
      EUnit _ e -> analyse e
      EReplicate _ e -> analyse e
      EConst{} -> mempty
      EIdx0 _ e -> analyse e
      EIdx1 _ a b -> analyse a <> analyse b
      EIdx _ e es -> analyse e <> foldMap analyse es
      EOp _ _ e -> analyse e
      EWith a b -> analyse a <> under1 b
      EAccum1 a b e -> analyse a <> analyse b <> analyse e
      EError{} -> mempty

    -- Analyse a subexpression that lives under one extra binder, then drop
    -- that binder from the result.
    under1 :: Expr x (a : env') t' -> r env'
    under1 = unpush . analyse


-- | Compute a 'Subenv' of @env@ and pass it to the continuation. The
-- 'Subenv' drops every variable whose run-time count is 'Zero' (including
-- all variables past an 'OccEnd') and keeps all others.
deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r
deleteUnused SNil OccEnd k = k SETop
deleteUnused (_ `SCons` vars) occs k = case occs of
  -- No information left: nothing from here outwards is used.
  OccEnd -> deleteUnused vars OccEnd (\sub -> k (SENo sub))
  OccPush rest (Occ _ runtime) ->
    deleteUnused vars rest $ \sub ->
      if runtime == Zero then k (SENo sub) else k (SEYes sub)

-- | Move an expression into the smaller environment selected by a 'Subenv',
-- renumbering its variable references via 'subst'.
--
-- Unsafe: a reference to a variable that the 'Subenv' drops is replaced by a
-- call to 'error'. Only use this when the expression does not mention any
-- dropped variable. (Assumes 'subst' invokes the callback for each variable
-- occurrence; confirm in AST.)
unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t
unsafeWeakenWithSubenv sub =
  subst (\x t i ->
           maybe (error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away")
                 (EVar x t)
                 (sinkWithSubenv i sub))
  where
    -- Translate an index of @env@ into @env'@. Returns 'Nothing' when the
    -- variable it refers to is not kept by the 'Subenv'.
    sinkWithSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t)
    sinkWithSubenv i s = case i of
      IZ -> case s of
        SEYes _ -> Just IZ
        SENo _  -> Nothing
      IS j -> case s of
        SEYes s' -> IS <$> sinkWithSubenv j s'
        SENo s'  -> sinkWithSubenv j s'