diff options
author | Tom Smeding <tom@tomsmeding.com> | 2023-09-19 21:55:38 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2023-09-19 21:55:38 +0200 |
commit | 183e8b4a07231aae904b8234ddeb1c646c031173 (patch) | |
tree | d514667bb46f5bf6553abed83042a5771e3c39f2 /src/Simplify.hs | |
parent | 7095bcf4910e2b1525234ca8e88f4effc25315bd (diff) |
Stuff
Diffstat (limited to 'src/Simplify.hs')
-rw-r--r-- | src/Simplify.hs | 164 |
1 files changed, 164 insertions, 0 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs new file mode 100644 index 0000000..cb649d5 --- /dev/null +++ b/src/Simplify.hs @@ -0,0 +1,164 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE DataKinds #-} +module Simplify where + +import AST + + +simplify :: Ex env t -> Ex env t +simplify = \case + -- inlining + ELet _ rhs body + | Occ lexOcc runOcc <- occCount IZ body + , lexOcc <= One -- prevent code size blowup + , runOcc <= One -- prevent runtime increase + -> simplify (subst1 rhs body) + | cheapExpr rhs + -> simplify (subst1 rhs body) + + -- let splitting + ELet _ (EPair _ a b) body -> + simplify $ + ELet ext a $ + ELet ext (weakenExpr WSink b) $ + subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ) + IS i -> EVar ext t (IS (IS i))) + body + + EFst _ (EPair _ e _) -> simplify e + ESnd _ (EPair _ _ e) -> simplify e + + ECase _ (EInl _ _ e) rhs _ -> simplify (ELet ext e rhs) + ECase _ (EInr _ _ e) _ rhs -> simplify (ELet ext e rhs) + + -- TODO: array indexing (index of build, index of fold) + + -- TODO: constant folding for operations + + EVar _ t i -> EVar ext t i + ELet _ a b -> ELet ext (simplify a) (simplify b) + EPair _ a b -> EPair ext (simplify a) (simplify b) + EFst _ e -> EFst ext (simplify e) + ESnd _ e -> ESnd ext (simplify e) + ENil _ -> ENil ext + EInl _ t e -> EInl ext t (simplify e) + EInr _ t e -> EInr ext t (simplify e) + ECase _ e a b -> ECase ext (simplify e) (simplify a) (simplify b) + EBuild1 _ a b -> EBuild1 ext (simplify a) (simplify b) + EBuild _ es e -> EBuild ext (fmap simplify es) (simplify e) + EFold1 _ a b -> EFold1 ext (simplify a) (simplify b) + EConst _ t v -> EConst ext t v + EIdx1 _ a b -> EIdx1 ext (simplify a) (simplify b) + EIdx _ e es -> EIdx ext (simplify e) (fmap simplify es) + EOp _ op e -> EOp ext op (simplify e) + EMOne t i e -> EMOne t i (simplify e) + EMScope e -> EMScope (simplify e) + EMReturn t e -> EMReturn t (simplify e) + EMBind a b -> EMBind (simplify a) (simplify b) + EError t s -> EError t s + +cheapExpr :: Expr x env t -> Bool +cheapExpr = \case + EVar{} -> True + ENil{} -> True + EConst{} -> True + _ -> False + +data Count = Zero | One | Many + deriving (Show, Eq, Ord) + +instance Semigroup Count where + Zero <> n = n + n <> Zero = n + _ <> _ = Many +instance Monoid Count where + mempty = Zero + +data Occ = Occ { _occLexical :: Count + , _occRuntime :: Count } +instance Semigroup Occ where Occ a b <> Occ c d = Occ (a <> c) (b <> d) +instance Monoid Occ where mempty = Occ mempty mempty + +-- | One of the two branches is taken +(<||>) :: Occ -> Occ -> Occ +Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2) + +-- | This code is executed many times +scaleMany :: Occ -> Occ +scaleMany (Occ l _) = Occ l Many + +occCount :: Idx env a -> Expr x env t -> Occ +occCount idx = \case + EVar _ _ i | idx2int i == idx2int idx -> Occ One One + | otherwise -> mempty + ELet _ rhs body -> occCount idx rhs <> occCount (IS idx) body + EPair _ a b -> occCount idx a <> occCount idx b + EFst _ e -> occCount idx e + ESnd _ e -> occCount idx e + ENil _ -> mempty + EInl _ _ e -> occCount idx e + EInr _ _ e -> occCount idx e + ECase _ e a b -> occCount idx e <> (occCount (IS idx) a <||> occCount (IS idx) b) + EBuild1 _ a b -> occCount idx a <> scaleMany (occCount (IS idx) b) + EBuild _ es e -> foldMap (occCount idx) es <> scaleMany (occCount (wsinkN (vecLength es) @> idx) e) + EFold1 _ a b -> scaleMany (occCount (IS (IS idx)) a) <> occCount idx b + EConst{} -> mempty + EIdx1 _ a b -> occCount idx a <> occCount idx b + EIdx _ e es -> occCount idx e <> foldMap (occCount idx) es + EOp _ _ e -> occCount idx e + EMOne _ _ e -> occCount idx e + EMScope e -> occCount idx e + EMReturn _ e -> occCount idx e + EMBind a b -> occCount idx a <> occCount (IS idx) b + EError{} -> mempty + +subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t +subst1 repl = subst $ \x t -> \case IZ -> repl + IS i -> EVar x t i + +subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a) + -> Expr x env t -> Expr x env' t +subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId + +subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a) + -> env' :> envOut + -> Expr x env t + -> Expr x envOut t +subst' f w = \case + EVar x t i -> f x t w i + ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body) + EPair x a b -> EPair x (subst' f w a) (subst' f w b) + EFst x e -> EFst x (subst' f w e) + ESnd x e -> ESnd x (subst' f w e) + ENil x -> ENil x + EInl x t e -> EInl x t (subst' f w e) + EInr x t e -> EInr x t (subst' f w e) + ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b) + EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b) + EBuild x es e -> EBuild x (fmap (subst' f w) es) (subst' (sinkFN (vecLength es) f) (wcopyN (vecLength es) w) e) + EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) + EConst x t v -> EConst x t v + EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) + EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es) + EOp x op e -> EOp x op (subst' f w e) + EMOne t i e -> EMOne t i (subst' f w e) + EMScope e -> EMScope (subst' f w e) + EMReturn t e -> EMReturn t (subst' f w e) + EMBind a b -> EMBind (subst' f w a) (subst' (sinkF f) (WCopy w) b) + EError t s -> EError t s + where + sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) + -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t + sinkF f' x' t w' = \case + IZ -> EVar x' t (w' @> IZ) + IS i -> f' x' t (WPop w') i + + sinkFN :: SNat n + -> (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) + -> x t -> STy t -> (ConsN n TIx env' :> env2) -> Idx (ConsN n TIx env) t -> Expr x env2 t + sinkFN SZ f' x t w' i = f' x t w' i + sinkFN (SS _) _ x t w' IZ = EVar x t (w' @> IZ) + sinkFN (SS n) f' x t w' (IS i) = sinkFN n f' x t (WPop w') i |