summaryrefslogtreecommitdiff
path: root/src/Simplify.hs
blob: 16a3e1dd0ae15bf877b9ad16d310e4fdad431cf1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DataKinds #-}
module Simplify where

import AST
import AST.Count


-- | Run 'simplify' @n@ times in succession.
--
-- A single 'simplify' pass does not re-examine a node after rebuilding it
-- from simplified children. Repeated passes can therefore remove redexes that
-- earlier passes only exposed.
--
-- Any count @<= 0@ performs no passes and returns the expression unchanged.
-- The guard is @<= 0@ rather than a match on the literal @0@ so that a
-- negative count cannot recurse forever.
simplifyN :: Int -> Ex env t -> Ex env t
simplifyN n
  | n <= 0 = id
  | otherwise = simplifyN (n - 1) . simplify

-- | One simplification pass over an expression.
--
-- At each node the rewrite rules below are tried first, in order. When a rule
-- fires, its result is fed back into 'simplify'. When no rule applies, the
-- congruence cases at the bottom rebuild the node with each child simplified
-- once. The rebuilt node itself is not re-examined, so a redex that only
-- appears after its children are simplified is left for a later pass
-- ('simplifyN' runs several passes).
--
-- Nodes are rebuilt with 'ext' as their annotation; the annotations of the
-- input nodes are discarded.
--
-- NOTE(review): equation order is significant. If neither guard of the
-- inlining equation holds, matching falls through to the let-splitting
-- equation and then to the generic 'ELet' congruence case.
simplify :: Ex env t -> Ex env t
simplify = \case
  -- inlining: substitute the right-hand side for the bound variable (IZ) when
  -- 'occCount' reports at most One lexical (lexOcc) and at most One runtime
  -- (runOcc) occurrence in the body ...
  ELet _ rhs body
    | Occ lexOcc runOcc <- occCount IZ body
    , lexOcc <= One  -- prevent code size blowup
    , runOcc <= One  -- prevent runtime increase
    -> simplify (subst1 rhs body)
    -- ... or, whatever the occurrence counts, when the right-hand side passes
    -- 'cheapExpr'.
    | cheapExpr rhs
    -> simplify (subst1 rhs body)

  -- let splitting: a let-bound pair becomes two nested lets, one per
  -- component ('b' is weakened past the new binder for 'a'). In the body,
  -- the old pair variable (IZ) is replaced by a fresh 'EPair' of the two new
  -- variables (IS IZ is the first component, IZ the second). Every other
  -- variable is shifted past the extra binder (IS i becomes IS (IS i)).
  ELet _ (EPair _ a b) body ->
    simplify $
      ELet ext a $
      ELet ext (weakenExpr WSink b) $
        subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ)
                             IS i -> EVar ext t (IS (IS i)))
              body

  -- beta rules for products: projecting from a literal pair keeps only the
  -- selected component.
  -- NOTE(review): the other component is discarded; if it can contain
  -- 'EError', this rule removes that error — confirm this is intended.
  EFst _ (EPair _ e _) -> simplify e
  ESnd _ (EPair _ _ e) -> simplify e

  -- beta rules for coproducts: a case on a literal injection becomes a let
  -- that binds the payload in the matching branch.
  ECase _ (EInl _ _ e) rhs _ -> simplify (ELet ext e rhs)
  ECase _ (EInr _ _ e) _ rhs -> simplify (ELet ext e rhs)

  -- TODO: array indexing (index of build, index of fold)

  -- TODO: constant folding for operations

  -- left identity of return/bind (a beta-style rule): binding the result of
  -- 'EMReturn' is just a let.
  EMBind (EMReturn _ a) b -> simplify (ELet ext a b)

  -- associativity of bind: re-nest to the right. 'c' is weakened with
  -- WCopy WSink because it now also sits under the binder introduced by 'b'.
  EMBind (EMBind a b) c -> simplify (EMBind a (EMBind b (weakenExpr (WCopy WSink) c)))

  -- bind-let commute: float the let outwards. 'c' is weakened past the let's
  -- binder for the same reason as above.
  EMBind (ELet _ a b) c -> simplify (ELet ext a (EMBind b (weakenExpr (WCopy WSink) c)))

  -- congruence: no rule applied at this node, so simplify each child once
  -- and rebuild the node.
  EVar _ t i -> EVar ext t i
  ELet _ a b -> ELet ext (simplify a) (simplify b)
  EPair _ a b -> EPair ext (simplify a) (simplify b)
  EFst _ e -> EFst ext (simplify e)
  ESnd _ e -> ESnd ext (simplify e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext t (simplify e)
  EInr _ t e -> EInr ext t (simplify e)
  ECase _ e a b -> ECase ext (simplify e) (simplify a) (simplify b)
  EBuild1 _ a b -> EBuild1 ext (simplify a) (simplify b)
  EBuild _ es e -> EBuild ext (fmap simplify es) (simplify e)
  EFold1 _ a b -> EFold1 ext (simplify a) (simplify b)
  EConst _ t v -> EConst ext t v
  EIdx1 _ a b -> EIdx1 ext (simplify a) (simplify b)
  EIdx _ e es -> EIdx ext (simplify e) (fmap simplify es)
  EOp _ op e -> EOp ext op (simplify e)
  EMOne t i e -> EMOne t i (simplify e)
  EMScope e -> EMScope (simplify e)
  EMReturn t e -> EMReturn t (simplify e)
  EMBind a b -> EMBind (simplify a) (simplify b)
  EError t s -> EError t s

-- | Conservative, purely syntactic cheapness test.
--
-- Returns 'True' only for constants ('EConst'), 'ENil' and variable
-- references ('EVar'); every other constructor is considered not cheap.
cheapExpr :: Expr x env t -> Bool
cheapExpr expr = case expr of
  EConst{} -> True
  ENil{}   -> True
  EVar{}   -> True
  _        -> False

-- | Replace the innermost bound variable (de Bruijn index 0, 'IZ') of
-- @body@ with @repl@. Every other variable moves down by one index, since
-- the binder it was counted past is gone.
subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t
subst1 repl body =
  subst (\ann ty idx -> case idx of
                          IZ -> repl
                          IS rest -> EVar ann ty rest)
        body

-- | Apply a substitution to the free variables of an expression.
--
-- The mapping produces terms in the target environment @env'@. When a
-- variable occurs underneath binders, its replacement is weakened past those
-- binders before being inserted. The traversal itself is done by 'subst'',
-- starting from the identity weakening.
subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a)
      -> Expr x env t -> Expr x env' t
subst f expr = subst' (\ann ty wk -> weakenExpr wk . f ann ty) WId expr

-- | Worker for 'subst'. It walks an expression while carrying two things:
--
--   * a substitution for the variables of the input environment @env@. The
--     substitution receives the current weakening, so it can produce a term
--     in whatever environment the traversal has reached;
--   * a weakening @env' :> envOut@ from the substitution's base environment
--     to the output environment at the current position.
--
-- Entering a binder extends the weakening with 'WCopy' and wraps the
-- substitution so that the freshly bound variable maps to itself.
subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a)
       -> env' :> envOut
       -> Expr x env t
       -> Expr x envOut t
subst' sub wk expr = case expr of
  -- variables are the only place the substitution itself is consulted
  EVar ann ty idx -> sub ann ty wk idx
  ELet ann rhs body -> ELet ann (subst' sub wk rhs) (subst' (under1 sub) (WCopy wk) body)
  EPair ann l r -> EPair ann (subst' sub wk l) (subst' sub wk r)
  EFst ann e -> EFst ann (subst' sub wk e)
  ESnd ann e -> ESnd ann (subst' sub wk e)
  ENil ann -> ENil ann
  EInl ann ty e -> EInl ann ty (subst' sub wk e)
  EInr ann ty e -> EInr ann ty (subst' sub wk e)
  -- both branches sit under one binder
  ECase ann scrut l r ->
    ECase ann (subst' sub wk scrut)
              (subst' (under1 sub) (WCopy wk) l)
              (subst' (under1 sub) (WCopy wk) r)
  -- the second operand sits under one binder
  EBuild1 ann a body -> EBuild1 ann (subst' sub wk a) (subst' (under1 sub) (WCopy wk) body)
  -- the body sits under (vecLength shape) binders of type TIx
  EBuild ann shape body ->
    EBuild ann (fmap (subst' sub wk) shape)
               (subst' (underN (vecLength shape) sub) (wcopyN (vecLength shape) wk) body)
  -- the first operand sits under two binders, the second under none
  EFold1 ann fun arr ->
    EFold1 ann (subst' (under1 (under1 sub)) (WCopy (WCopy wk)) fun) (subst' sub wk arr)
  EConst ann ty v -> EConst ann ty v
  EIdx1 ann arr ix -> EIdx1 ann (subst' sub wk arr) (subst' sub wk ix)
  EIdx ann arr ixs -> EIdx ann (subst' sub wk arr) (fmap (subst' sub wk) ixs)
  EOp ann op e -> EOp ann op (subst' sub wk e)
  EMOne ty ix e -> EMOne ty ix (subst' sub wk e)
  EMScope e -> EMScope (subst' sub wk e)
  EMReturn ty e -> EMReturn ty (subst' sub wk e)
  -- the continuation sits under one binder
  EMBind m k -> EMBind (subst' sub wk m) (subst' (under1 sub) (WCopy wk) k)
  EError ty msg -> EError ty msg
  where
    -- Push a substitution under one binder. Index 0 becomes a variable that
    -- points, through the weakening, at the binder's own position. Any other
    -- index is delegated to the outer substitution, with the binder popped
    -- off the weakening.
    under1 :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
           -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
    under1 outer ann ty wk' idx = case idx of
      IZ -> EVar ann ty (wk' @> IZ)
      IS idx' -> outer ann ty (WPop wk') idx'

    -- Push a substitution under @count@ binders of type TIx, one layer per
    -- step: with no binders left, defer to the outer substitution; otherwise
    -- handle index 0 locally and recurse on the rest.
    underN :: SNat n
           -> (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
           -> x t -> STy t -> (ConsN n TIx env' :> env2) -> Idx (ConsN n TIx env) t -> Expr x env2 t
    underN count outer ann ty wk' idx = case count of
      SZ -> outer ann ty wk' idx
      SS count' -> case idx of
        IZ -> EVar ann ty (wk' @> IZ)
        IS idx' -> underN count' outer ann ty (WPop wk') idx'