summaryrefslogtreecommitdiff
path: root/src/AST/Count.hs
blob: c0d8d2dc9b54d0c3f0f781051d3a4880d6cedcd6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
module AST.Count where

import Data.Functor.Const
import GHC.Generics (Generic, Generically(..))

import AST
import AST.Env
import Data


-- | A three-valued occurrence count.
--
-- The constructor order is significant: the derived 'Ord' instance orders the
-- values as @Zero < One < Many@.
data Count
  = Zero
  | One
  | Many
  deriving (Eq, Ord, Show)

-- | Adding counts: 'Zero' is neutral, and any two non-'Zero' counts add up to
-- 'Many'.
instance Semigroup Count where
  lhs <> rhs = case (lhs, rhs) of
    (Zero, n) -> n
    (n, Zero) -> n
    _         -> Many

-- | The empty count is 'Zero'.
instance Monoid Count where
  mempty = Zero

-- | Occurrence information for a single variable: a \"lexical\" and a
-- \"runtime\" 'Count'.
--
-- @(<>)@ and 'mempty' are derived through 'Generically', so they act on the
-- two fields independently using the 'Count' monoid.
data Occ = Occ
  { _occLexical :: Count  -- ^ presumably the number of syntactic occurrences
  , _occRuntime :: Count  -- ^ presumably the number of evaluations at run time
  }
  deriving stock (Eq, Generic)
  deriving (Semigroup, Monoid) via Generically Occ

-- | Shows an 'Occ' in positional (non-record) syntax, e.g. @Occ One Many@,
-- wrapped in parentheses when the surrounding precedence is above 10.
instance Show Occ where
  showsPrec prec (Occ lexical runtime) = showParen (prec > 10) body
    where
      body =
        showString "Occ "
          . showsPrec 11 lexical
          . showString " "
          . showsPrec 11 runtime

-- | Alternation: combine the occurrences found in two branches of which only
-- one is taken. The lexical counts are added with @(<>)@; the runtime count
-- is the larger of the two.
(<||>) :: Occ -> Occ -> Occ
(<||>) (Occ lexL runL) (Occ lexR runR) =
  Occ { _occLexical = lexL <> lexR
      , _occRuntime = runL `max` runR
      }

-- | Adjust the occurrences found in code that is executed many times: a
-- non-'Zero' runtime count becomes 'Many', a 'Zero' runtime count stays
-- 'Zero'. The lexical count is left untouched.
scaleMany :: Occ -> Occ
scaleMany (Occ lexical runtime) =
  case runtime of
    Zero -> Occ lexical Zero
    _    -> Occ lexical Many

-- | Count the occurrences of one particular variable in an expression.
--
-- This instantiates 'occCountGeneral' with @'Const' 'Occ'@ as result type, so
-- a single 'Occ' is accumulated regardless of the environment in scope.
occCount :: Idx env a -> Expr x env t -> Occ
occCount idx expr =
  getConst $
    occCountGeneral
      -- An occurrence only counts if it is the requested variable, after
      -- weakening that variable's index to the current scope.
      (\w i o -> Const (if idx2int i == idx2int (w @> idx) then o else mempty))
      -- Leaving a binder only changes the phantom environment index.
      (Const . getConst)
      (\lhs rhs -> Const (getConst lhs <||> getConst rhs))
      (Const . scaleMany . getConst)
      expr


-- | Occurrence information for the variables of an environment, innermost
-- variable first.
data OccEnv env where
  -- | No further entries are recorded. This is not necessarily the top of
  -- the environment: the functions in this module treat the remaining
  -- variables as having 'mempty' occurrences.
  OccEnd
    :: OccEnv env
  -- | The occurrences of the innermost variable, on top of the entries for
  -- the rest of the environment.
  OccPush
    :: OccEnv env
    -> Occ
    -> OccEnv (t ': env)

-- | Pointwise @(<>)@ of the per-variable occurrences. An 'OccEnd' on either
-- side stands for an all-'mempty' tail, so the other side is returned as is.
instance Semigroup (OccEnv env) where
  lhs <> rhs = case (lhs, rhs) of
    (OccEnd, _) -> rhs
    (_, OccEnd) -> lhs
    (OccPush restL occL, OccPush restR occR) ->
      OccPush (restL <> restR) (occL <> occR)

-- | The empty occurrence environment is 'OccEnd'.
instance Monoid (OccEnv env) where
  mempty = OccEnd

-- | An occurrence environment that records the given 'Occ' for the variable
-- at the given index. Variables in front of it get an explicit 'mempty'
-- entry; the environment ends with 'OccEnd' right after the indexed variable.
onehotOccEnv :: Idx env t -> Occ -> OccEnv env
onehotOccEnv idx occ = case idx of
  IZ      -> OccPush OccEnd occ
  IS rest -> OccPush (onehotOccEnv rest occ) mempty

-- | Pointwise alternation ('<||>') of two occurrence environments. An
-- 'OccEnd' on either side leaves the other side unchanged.
(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env
lhs <||>! rhs = case (lhs, rhs) of
  (OccEnd, _) -> rhs
  (_, OccEnd) -> lhs
  (OccPush restL occL, OccPush restR occR) ->
    OccPush (restL <||>! restR) (occL <||> occR)

-- | Apply 'scaleMany' to every entry of an occurrence environment.
scaleManyOccEnv :: OccEnv env -> OccEnv env
scaleManyOccEnv occs = case occs of
  OccEnd           -> OccEnd
  OccPush rest occ -> OccPush (scaleManyOccEnv rest) (scaleMany occ)

-- | Drop the entry of the innermost variable. An 'OccEnd' stays an 'OccEnd',
-- since it carries no entry to drop.
occEnvPop :: OccEnv (t : env) -> OccEnv env
occEnvPop occs = case occs of
  OccEnd         -> OccEnd
  OccPush rest _ -> rest

-- | Count the occurrences of all variables of the environment at once,
-- by running 'occCountGeneral' with 'OccEnv' as the result type.
occCountAll :: Expr x env t -> OccEnv env
occCountAll expr =
  occCountGeneral
    (\_ idx occ -> onehotOccEnv idx occ)  -- the weakening is not needed here
    occEnvPop
    (<||>!)
    scaleManyOccEnv
    expr

-- | Generic occurrence analysis over an expression.
--
-- The result type @r@ is indexed by the environment in scope and has to be a
-- 'Monoid' at every index; @(<>)@ combines the results of subexpressions that
-- are all evaluated. The four callbacks provide the remaining structure:
--
-- * /one-hot/ is called for each variable occurrence, with the weakening from
--   the top-level environment to the current scope, the index of the variable
--   and @'Occ' 'One' 'One'@;
-- * /unpush/ removes the innermost variable from a result when a binder is
--   left;
-- * /alternation/ combines the results of two branches of which only one is
--   taken;
-- * /scale-many/ adjusts the result of code that is executed many times.
occCountGeneral :: forall r env t x.
                   (forall env'. Monoid (r env'))
                => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env')  -- ^ one-hot
                -> (forall env' a. r (a : env') -> r env')  -- ^ unpush
                -> (forall env'. r env' -> r env' -> r env')  -- ^ alternation
                -> (forall env'. r env' -> r env')  -- ^ scale-many
                -> Expr x env t -> r env
occCountGeneral onehot unpush alter many = go WId
  where
    -- The weakening tracks how the current scope extends the top-level
    -- environment; it is only consumed by the one-hot callback.
    go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env'
    go w = \case
      EVar _ _ i -> onehot w i (Occ One One)
      ELet _ rhs body -> re rhs <> re1 body
      EPair _ a b -> re a <> re b
      EFst _ e -> re e
      ESnd _ e -> re e
      ENil _ -> mempty
      EInl _ _ e -> re e
      EInr _ _ e -> re e
      -- Only one of the two branches is taken.
      ECase _ e a b -> re e <> (re1 a `alter` re1 b)
      ENothing _ _ -> mempty
      EJust _ e -> re e
      -- As with 'ECase', the two branch arguments are combined with
      -- alternation rather than '<>', so that a variable used once in each
      -- branch is not reported as used 'Many' times at run time. (This
      -- assumes that 'EMaybe' evaluates exactly one of @a@ and @b@.)
      EMaybe _ a b e -> (re a `alter` re1 b) <> re e
      EConstArr{} -> mempty
      -- The body under the binder is executed many times.
      EBuild _ _ a b -> re a <> many (re1 b)
      -- The first argument lives under two binders and is executed many times.
      EFold1Inner _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c
      ESum1Inner _ e -> re e
      EUnit _ e -> re e
      EReplicate1Inner _ a b -> re a <> re b
      EMaximum1Inner _ e -> re e
      EMinimum1Inner _ e -> re e
      EConst{} -> mempty
      EIdx0 _ e -> re e
      EIdx1 _ a b -> re a <> re b
      EIdx _ a b -> re a <> re b
      EShape _ e -> re e
      EOp _ _ e -> re e
      -- Only the last two fields are traversed.
      ECustom _ _ _ _ _ _ _ a b -> re a <> re b
      EWith _ _ a b -> re a <> re1 b
      EAccum _ _ _ a b e -> re a <> re b <> re e
      EZero _ _ -> mempty
      EPlus _ _ a b -> re a <> re b
      EOneHot _ _ _ a b -> re a <> re b
      EError{} -> mempty
      where
        -- Recurse into a subexpression in the same scope.
        re :: Monoid (r env') => Expr x env' t'' -> r env'
        re = go w

        -- Recurse into a subexpression under one binder, then drop that
        -- binder's entry from the result.
        re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env'
        re1 = unpush . go (WSink .> w)


-- | Compute the sub-environment of variables that are used according to an
-- 'OccEnv', and pass it to the continuation.
--
-- A variable is dropped ('SENo') if its lexical count is 'Zero', or if the
-- 'OccEnv' ended ('OccEnd') before reaching it; otherwise it is kept
-- ('SEYes').
--
-- The decision is made on the /lexical/ count: 'unsafeWeakenWithSubenv'
-- fails on every variable that still occurs in the expression, so a variable
-- may only be dropped if it does not occur syntactically at all. For
-- occurrence environments computed by 'occCountGeneral' the lexical count is
-- 'Zero' exactly when the runtime count is, so the result is the same as
-- when looking at the runtime count.
deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r
deleteUnused SNil OccEnd k = k SETop
deleteUnused (_ `SCons` env) OccEnd k =
  -- No entry recorded for this variable: it is unused.
  deleteUnused env OccEnd $ \sub -> k (SENo sub)
deleteUnused (_ `SCons` env) (OccPush occenv (Occ lexical _)) k =
  deleteUnused env occenv $ \sub ->
    case lexical of
      Zero -> k (SENo sub)
      _    -> k (SEYes sub)

-- | Move an expression to a sub-environment, re-indexing its variables via
-- 'subst'.
--
-- This is unsafe: it calls 'error' if a variable is encountered that the
-- 'Subenv' drops.
unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t
unsafeWeakenWithSubenv sub =
  subst
    (\ext ty idx ->
       maybe (error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away")
             (EVar ext ty)
             (sinkViaSubenv idx sub))
  where
    -- The index of a variable in the sub-environment, or 'Nothing' if the
    -- variable was dropped. Every dropped variable in front of the target
    -- lowers the resulting index by one.
    sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t)
    sinkViaSubenv idx sub' = case (idx, sub') of
      (IZ, SEYes _)      -> Just IZ
      (IZ, SENo _)       -> Nothing
      (IS i, SEYes rest) -> IS <$> sinkViaSubenv i rest
      (IS i, SENo rest)  -> sinkViaSubenv i rest