summaryrefslogtreecommitdiff
path: root/src/Simplify.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-08-30 17:48:15 +0200
committerTom Smeding <tom@tomsmeding.com>2024-08-30 17:48:15 +0200
commit8b047ff11ebd4715647bfc041a190f72dcf4d5a9 (patch)
treee8440120b7bbd4e45b367acb3f7185d25e7f3766 /src/Simplify.hs
parentf4b94d7cc2cb05611b462ba278e4f12f7a7a5e5e (diff)
Migrate to accumulators (mostly removing EVM code)
Diffstat (limited to 'src/Simplify.hs')
-rw-r--r--src/Simplify.hs125
1 files changed, 80 insertions, 45 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 44de164..a5f90b3 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -1,83 +1,84 @@
+{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE DataKinds #-}
module Simplify where
+import Data.Monoid
+
import AST
import AST.Count
import Data
-simplifyN :: Int -> Ex env t -> Ex env t
+simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t
simplifyN 0 = id
simplifyN n = simplifyN (n - 1) . simplify
-simplify :: Ex env t -> Ex env t
-simplify = \case
+simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t
+simplify = let ?accumInScope = checkAccumInScope @env knownEnv in simplify'
+
+simplify' :: (?accumInScope :: Bool) => Ex env t -> Ex env t
+simplify' = \case
-- inlining
ELet _ rhs body
- | Occ lexOcc runOcc <- occCount IZ body
+ | not ?accumInScope || not (hasAdds rhs) -- cannot discard effectful computations
+ , Occ lexOcc runOcc <- occCount IZ body
, lexOcc <= One -- prevent code size blowup
, runOcc <= One -- prevent runtime increase
- -> simplify (subst1 rhs body)
+ -> simplify' (subst1 rhs body)
| cheapExpr rhs
- -> simplify (subst1 rhs body)
+ -> simplify' (subst1 rhs body)
-- let splitting
ELet _ (EPair _ a b) body ->
- simplify $
+ simplify' $
ELet ext a $
ELet ext (weakenExpr WSink b) $
subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ)
IS i -> EVar ext t (IS (IS i)))
body
+ -- let rotation
+ ELet _ (ELet _ rhs a) b ->
+ ELet ext (simplify' rhs) $
+ ELet ext (simplify' a) $
+ weakenExpr (WCopy WSink) (simplify' b)
+
-- beta rules for products
- EFst _ (EPair _ e _) -> simplify e
- ESnd _ (EPair _ _ e) -> simplify e
+ EFst _ (EPair _ e _) -> simplify' e
+ ESnd _ (EPair _ _ e) -> simplify' e
-- beta rules for coproducts
- ECase _ (EInl _ _ e) rhs _ -> simplify (ELet ext e rhs)
- ECase _ (EInr _ _ e) _ rhs -> simplify (ELet ext e rhs)
+ ECase _ (EInl _ _ e) rhs _ -> simplify' (ELet ext e rhs)
+ ECase _ (EInr _ _ e) _ rhs -> simplify' (ELet ext e rhs)
-- TODO: array indexing (index of build, index of fold)
-- TODO: constant folding for operations
- -- eta rule for return+bind
- EMBind (EMReturn _ a) b -> simplify (ELet ext a b)
-
- -- associativity of bind
- EMBind (EMBind a b) c -> simplify (EMBind a (EMBind b (weakenExpr (WCopy WSink) c)))
-
- -- bind-let commute
- EMBind (ELet _ a b) c -> simplify (ELet ext a (EMBind b (weakenExpr (WCopy WSink) c)))
-
- -- return-let commute
- EMReturn env (ELet _ a b) -> simplify (ELet ext a (EMReturn env b))
-
EVar _ t i -> EVar ext t i
- ELet _ a b -> ELet ext (simplify a) (simplify b)
- EPair _ a b -> EPair ext (simplify a) (simplify b)
- EFst _ e -> EFst ext (simplify e)
- ESnd _ e -> ESnd ext (simplify e)
+ ELet _ a b -> ELet ext (simplify' a) (simplify' b)
+ EPair _ a b -> EPair ext (simplify' a) (simplify' b)
+ EFst _ e -> EFst ext (simplify' e)
+ ESnd _ e -> ESnd ext (simplify' e)
ENil _ -> ENil ext
- EInl _ t e -> EInl ext t (simplify e)
- EInr _ t e -> EInr ext t (simplify e)
- ECase _ e a b -> ECase ext (simplify e) (simplify a) (simplify b)
- EBuild1 _ a b -> EBuild1 ext (simplify a) (simplify b)
- EBuild _ es e -> EBuild ext (fmap simplify es) (simplify e)
- EFold1 _ a b -> EFold1 ext (simplify a) (simplify b)
+ EInl _ t e -> EInl ext t (simplify' e)
+ EInr _ t e -> EInr ext t (simplify' e)
+ ECase _ e a b -> ECase ext (simplify' e) (simplify' a) (simplify' b)
+ EBuild1 _ a b -> EBuild1 ext (simplify' a) (simplify' b)
+ EBuild _ es e -> EBuild ext (fmap simplify' es) (simplify' e)
+ EFold1 _ a b -> EFold1 ext (simplify' a) (simplify' b)
EConst _ t v -> EConst ext t v
- EIdx1 _ a b -> EIdx1 ext (simplify a) (simplify b)
- EIdx _ e es -> EIdx ext (simplify e) (fmap simplify es)
- EOp _ op e -> EOp ext op (simplify e)
- EMOne t i e -> EMOne t i (simplify e)
- EMScope e -> EMScope (simplify e)
- EMReturn t e -> EMReturn t (simplify e)
- EMBind a b -> EMBind (simplify a) (simplify b)
+ EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b)
+ EIdx _ e es -> EIdx ext (simplify' e) (fmap simplify' es)
+ EOp _ op e -> EOp ext op (simplify' e)
+ EWith e1 e2 -> EWith (simplify' e1) (let ?accumInScope = True in simplify' e2)
+ EAccum e1 e2 e3 -> EAccum (simplify' e1) (simplify' e2) (simplify' e3)
EError t s -> EError t s
cheapExpr :: Expr x env t -> Bool
@@ -116,10 +117,8 @@ subst' f w = \case
EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es)
EOp x op e -> EOp x op (subst' f w e)
- EMOne t i e -> EMOne t i (subst' f w e)
- EMScope e -> EMScope (subst' f w e)
- EMReturn t e -> EMReturn t (subst' f w e)
- EMBind a b -> EMBind (subst' f w a) (subst' (sinkF f) (WCopy w) b)
+ EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
+ EAccum e1 e2 e3 -> EAccum (subst' f w e1) (subst' f w e2) (subst' f w e3)
EError t s -> EError t s
where
sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
@@ -134,3 +133,39 @@ subst' f w = \case
sinkFN SZ f' x t w' i = f' x t w' i
sinkFN (SS _) _ x t w' IZ = EVar x t (w' @> IZ)
sinkFN (SS n) f' x t w' (IS i) = sinkFN n f' x t (WPop w') i
+
+-- | This can be made more precise by tracking (and not counting) adds on
+-- locally eliminated accumulators.
+hasAdds :: Expr x env t -> Bool
+hasAdds = \case
+ EVar _ _ _ -> False
+ ELet _ rhs body -> hasAdds rhs || hasAdds body
+ EPair _ a b -> hasAdds a || hasAdds b
+ EFst _ e -> hasAdds e
+ ESnd _ e -> hasAdds e
+ ENil _ -> False
+ EInl _ _ e -> hasAdds e
+ EInr _ _ e -> hasAdds e
+ ECase _ e a b -> hasAdds e || hasAdds a || hasAdds b
+ EBuild1 _ a b -> hasAdds a || hasAdds b
+ EBuild _ es e -> getAny (foldMap (Any . hasAdds) es) || hasAdds e
+ EFold1 _ a b -> hasAdds a || hasAdds b
+ EConst _ _ _ -> False
+ EIdx1 _ a b -> hasAdds a || hasAdds b
+ EIdx _ e es -> hasAdds e || getAny (foldMap (Any . hasAdds) es)
+ EOp _ _ e -> hasAdds e
+ EWith a b -> hasAdds a || hasAdds b
+ EAccum _ _ _ -> True
+ EError _ _ -> False
+
+checkAccumInScope :: SList STy env -> Bool
+checkAccumInScope = \case SNil -> False
+ SCons t env -> check t || checkAccumInScope env
+ where
+ check :: STy t -> Bool
+ check STNil = False
+ check (STPair s t) = check s || check t
+ check (STEither s t) = check s || check t
+ check (STArr _ t) = check t
+ check (STScal _) = False
+ check STAccum{} = True