summaryrefslogtreecommitdiff
path: root/src/Simplify.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Simplify.hs')
-rw-r--r--src/Simplify.hs164
1 files changed, 164 insertions, 0 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs
new file mode 100644
index 0000000..cb649d5
--- /dev/null
+++ b/src/Simplify.hs
@@ -0,0 +1,164 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE DataKinds #-}
+module Simplify where
+
+import AST
+
+
+simplify :: Ex env t -> Ex env t
+simplify = \case
+ -- inlining
+ ELet _ rhs body
+ | Occ lexOcc runOcc <- occCount IZ body
+ , lexOcc <= One -- prevent code size blowup
+ , runOcc <= One -- prevent runtime increase
+ -> simplify (subst1 rhs body)
+ | cheapExpr rhs
+ -> simplify (subst1 rhs body)
+
+ -- let splitting
+ ELet _ (EPair _ a b) body ->
+ simplify $
+ ELet ext a $
+ ELet ext (weakenExpr WSink b) $
+ subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ)
+ IS i -> EVar ext t (IS (IS i)))
+ body
+
+ EFst _ (EPair _ e _) -> simplify e
+ ESnd _ (EPair _ _ e) -> simplify e
+
+ ECase _ (EInl _ _ e) rhs _ -> simplify (ELet ext e rhs)
+ ECase _ (EInr _ _ e) _ rhs -> simplify (ELet ext e rhs)
+
+ -- TODO: array indexing (index of build, index of fold)
+
+ -- TODO: constant folding for operations
+
+ EVar _ t i -> EVar ext t i
+ ELet _ a b -> ELet ext (simplify a) (simplify b)
+ EPair _ a b -> EPair ext (simplify a) (simplify b)
+ EFst _ e -> EFst ext (simplify e)
+ ESnd _ e -> ESnd ext (simplify e)
+ ENil _ -> ENil ext
+ EInl _ t e -> EInl ext t (simplify e)
+ EInr _ t e -> EInr ext t (simplify e)
+ ECase _ e a b -> ECase ext (simplify e) (simplify a) (simplify b)
+ EBuild1 _ a b -> EBuild1 ext (simplify a) (simplify b)
+ EBuild _ es e -> EBuild ext (fmap simplify es) (simplify e)
+ EFold1 _ a b -> EFold1 ext (simplify a) (simplify b)
+ EConst _ t v -> EConst ext t v
+ EIdx1 _ a b -> EIdx1 ext (simplify a) (simplify b)
+ EIdx _ e es -> EIdx ext (simplify e) (fmap simplify es)
+ EOp _ op e -> EOp ext op (simplify e)
+ EMOne t i e -> EMOne t i (simplify e)
+ EMScope e -> EMScope (simplify e)
+ EMReturn t e -> EMReturn t (simplify e)
+ EMBind a b -> EMBind (simplify a) (simplify b)
+ EError t s -> EError t s
+
+cheapExpr :: Expr x env t -> Bool
+cheapExpr = \case
+ EVar{} -> True
+ ENil{} -> True
+ EConst{} -> True
+ _ -> False
+
+data Count = Zero | One | Many
+ deriving (Show, Eq, Ord)
+
+instance Semigroup Count where
+ Zero <> n = n
+ n <> Zero = n
+ _ <> _ = Many
+instance Monoid Count where
+ mempty = Zero
+
+data Occ = Occ { _occLexical :: Count
+ , _occRuntime :: Count }
+instance Semigroup Occ where Occ a b <> Occ c d = Occ (a <> c) (b <> d)
+instance Monoid Occ where mempty = Occ mempty mempty
+
+-- | One of the two branches is taken
+(<||>) :: Occ -> Occ -> Occ
+Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)
+
+-- | This code is executed many times
+scaleMany :: Occ -> Occ
+scaleMany (Occ l _) = Occ l Many
+
+occCount :: Idx env a -> Expr x env t -> Occ
+occCount idx = \case
+ EVar _ _ i | idx2int i == idx2int idx -> Occ One One
+ | otherwise -> mempty
+ ELet _ rhs body -> occCount idx rhs <> occCount (IS idx) body
+ EPair _ a b -> occCount idx a <> occCount idx b
+ EFst _ e -> occCount idx e
+ ESnd _ e -> occCount idx e
+ ENil _ -> mempty
+ EInl _ _ e -> occCount idx e
+ EInr _ _ e -> occCount idx e
+ ECase _ e a b -> occCount idx e <> (occCount (IS idx) a <||> occCount (IS idx) b)
+ EBuild1 _ a b -> occCount idx a <> scaleMany (occCount (IS idx) b)
+ EBuild _ es e -> foldMap (occCount idx) es <> scaleMany (occCount (wsinkN (vecLength es) @> idx) e)
+ EFold1 _ a b -> scaleMany (occCount (IS (IS idx)) a) <> occCount idx b
+ EConst{} -> mempty
+ EIdx1 _ a b -> occCount idx a <> occCount idx b
+ EIdx _ e es -> occCount idx e <> foldMap (occCount idx) es
+ EOp _ _ e -> occCount idx e
+ EMOne _ _ e -> occCount idx e
+ EMScope e -> occCount idx e
+ EMReturn _ e -> occCount idx e
+ EMBind a b -> occCount idx a <> occCount (IS idx) b
+ EError{} -> mempty
+
+subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t
+subst1 repl = subst $ \x t -> \case IZ -> repl
+ IS i -> EVar x t i
+
+subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a)
+ -> Expr x env t -> Expr x env' t
+subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId
+
+subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a)
+ -> env' :> envOut
+ -> Expr x env t
+ -> Expr x envOut t
+subst' f w = \case
+ EVar x t i -> f x t w i
+ ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body)
+ EPair x a b -> EPair x (subst' f w a) (subst' f w b)
+ EFst x e -> EFst x (subst' f w e)
+ ESnd x e -> ESnd x (subst' f w e)
+ ENil x -> ENil x
+ EInl x t e -> EInl x t (subst' f w e)
+ EInr x t e -> EInr x t (subst' f w e)
+ ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
+ EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b)
+ EBuild x es e -> EBuild x (fmap (subst' f w) es) (subst' (sinkFN (vecLength es) f) (wcopyN (vecLength es) w) e)
+ EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)
+ EConst x t v -> EConst x t v
+ EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
+ EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es)
+ EOp x op e -> EOp x op (subst' f w e)
+ EMOne t i e -> EMOne t i (subst' f w e)
+ EMScope e -> EMScope (subst' f w e)
+ EMReturn t e -> EMReturn t (subst' f w e)
+ EMBind a b -> EMBind (subst' f w a) (subst' (sinkF f) (WCopy w) b)
+ EError t s -> EError t s
+ where
+ sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
+ -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
+ sinkF f' x' t w' = \case
+ IZ -> EVar x' t (w' @> IZ)
+ IS i -> f' x' t (WPop w') i
+
+ sinkFN :: SNat n
+ -> (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
+ -> x t -> STy t -> (ConsN n TIx env' :> env2) -> Idx (ConsN n TIx env) t -> Expr x env2 t
+ sinkFN SZ f' x t w' i = f' x t w' i
+ sinkFN (SS _) _ x t w' IZ = EVar x t (w' @> IZ)
+ sinkFN (SS n) f' x t w' (IS i) = sinkFN n f' x t (WPop w') i