summaryrefslogtreecommitdiff
path: root/src/Simplify.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-04-29 15:56:39 +0200
committerTom Smeding <tom@tomsmeding.com>2025-04-29 15:57:17 +0200
commita1074fc851afcb6e858285ab9c6585b042ac1782 (patch)
tree8c40b943ee05134d79d418d23949a965eab1deae /src/Simplify.hs
parent6899e81e8e1fc7fad32515eb0d40465407c7cf87 (diff)
Tracing simplifier
Diffstat (limited to 'src/Simplify.hs')
-rw-r--r--src/Simplify.hs180
1 files changed, 117 insertions, 63 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 2a1d3b6..469c7a1 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -1,8 +1,10 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -13,20 +15,27 @@ module Simplify (
SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith,
) where
+import Control.Monad (ap)
+import Data.Bifunctor (first)
import Data.Function (fix)
import Data.Monoid (Any(..))
import Data.Type.Equality (testEquality)
+import Debug.Trace
+
import AST
import AST.Count
+import AST.Pretty
import Data
+import Simplify.TH
--- | This has no fields now, hence this type is useless as-is. When debugging, however, it's useful to be able to add some.
data SimplifyConfig = SimplifyConfig
+ { scLogging :: Bool
+ }
defaultSimplifyConfig :: SimplifyConfig
-defaultSimplifyConfig = SimplifyConfig
+defaultSimplifyConfig = SimplifyConfig False
simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t
simplifyN 0 = id
@@ -36,13 +45,13 @@ simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t
simplify =
let ?accumInScope = checkAccumInScope @env knownEnv
?config = defaultSimplifyConfig
- in snd . simplify'
+ in snd . runSM . simplify'
simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t
simplifyWith config =
let ?accumInScope = checkAccumInScope @env knownEnv
?config = config
- in snd . simplify'
+ in snd . runSM . simplify'
simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t
simplifyFix = simplifyFixWith defaultSimplifyConfig
@@ -52,11 +61,51 @@ simplifyFixWith config =
let ?accumInScope = checkAccumInScope @env knownEnv
?config = config
in fix $ \loop e ->
- let (Any act, e') = simplify' e
+ let (act, e') = runSM (simplify' e)
in if act then loop e' else e'
-simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig) => Ex env t -> (Any, Ex env t)
-simplify' = \case
+-- | simplify monad
+newtype SM tenv tt env t a = SM ((Ex env t -> Ex tenv tt) -> (Any, a))
+ deriving (Functor)
+
+instance Applicative (SM tenv tt env t) where
+ pure x = SM (\_ -> (Any False, x))
+ (<*>) = ap
+
+instance Monad (SM tenv tt env t) where
+ SM f >>= g = SM $ \ctx -> f ctx >>= \x -> let SM h = g x in h ctx
+
+runSM :: SM env t env t a -> (Bool, a)
+runSM (SM f) = first getAny (f id)
+
+smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt)
+smReconstruct core = SM (\ctx -> (Any False, ctx core))
+
+tellActed :: SM tenv tt env t ()
+tellActed = SM (\_ -> (Any True, ()))
+
+-- more convenient in practice
+acted :: SM tenv tt env t a -> SM tenv tt env t a
+acted m = tellActed >> m
+
+within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a
+within subctx (SM f) = SM $ \ctx -> f (ctx . subctx)
+
+simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t)
+simplify' expr
+ | scLogging ?config = do
+ res <- simplify'Rec expr
+ full <- smReconstruct res
+ let printed = ppExpr knownEnv full
+ replace a bs = concatMap (\x -> if x == a then bs else [x])
+ str | '\n' `elem` printed = "--- simplify step:\n " ++ replace '\n' "\n " printed
+ | otherwise = "--- simplify step: " ++ printed
+ traceM str
+ return res
+ | otherwise = simplify'Rec expr
+
+simplify'Rec :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t)
+simplify'Rec = \case
-- inlining
ELet _ rhs body
| cheapExpr rhs
@@ -83,11 +132,12 @@ simplify' = \case
acted $ simplify' $ ELet ext a $ subst0 (EInr ext t1 (EVar ext (typeOf a) IZ)) body
-- let rotation
- ELet _ (ELet _ rhs a) b ->
+ ELet _ (ELet _ rhs a) b -> do
+ b' <- within (ELet ext (ELet ext rhs a)) $ simplify' b
acted $ simplify' $
ELet ext rhs $
ELet ext a $
- weakenExpr (WCopy WSink) (snd (simplify' b))
+ weakenExpr (WCopy WSink) b'
-- beta rules for products
EFst _ (EPair _ e e')
@@ -133,8 +183,8 @@ simplify' = \case
EIdx _ (EUnit _ e1) _ -> acted $ simplify' $ e1
-- TODO: more constant folding
- EOp _ OIf (EConst _ STBool True) -> (Any True, EInl ext STNil (ENil ext))
- EOp _ OIf (EConst _ STBool False) -> (Any True, EInr ext STNil (ENil ext))
+ EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext))
+ EOp _ OIf (EConst _ STBool False) -> acted $ return (EInr ext STNil (ENil ext))
-- inline cheap array constructors
ELet _ (EReplicate1Inner _ e1 e2) e3 ->
@@ -153,29 +203,29 @@ simplify' = \case
-- eta rule for unit
e | STNil <- typeOf e, not ?accumInScope || not (hasAdds e) ->
case e of
- ENil _ -> (Any False, e)
- _ -> (Any True, ENil ext)
+ ENil _ -> return e
+ _ -> acted $ return (ENil ext)
EBuild _ SZ _ e ->
acted $ simplify' $ EUnit ext (substInline (ENil ext) e)
-- monoid rules
EAccum _ t p e1 e2 acc -> do
- e1' <- simplify' e1
- e2' <- simplify' e2
- acc' <- simplify' acc
+ e1' <- within (\e1' -> EAccum ext t p e1' e2 acc ) $ simplify' e1
+ e2' <- within (\e2' -> EAccum ext t p e1' e2' acc ) $ simplify' e2
+ acc' <- within (\acc' -> EAccum ext t p e1' e2' acc') $ simplify' acc
simplifyOneHotTerm (OneHotTerm t p e1' e2')
- (Any True, ENil ext)
- (\e -> (Any False, EAccum ext t SAPHere (ENil ext) e acc'))
+ (acted $ return (ENil ext))
+ (\e -> return (EAccum ext t SAPHere (ENil ext) e acc'))
(\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc'))
EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e
EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e
EOneHot _ t p e1 e2 -> do
- e1' <- simplify' e1
- e2' <- simplify' e2
+ e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1
+ e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2
simplifyOneHotTerm (OneHotTerm t p e1' e2')
- (Any True, EZero ext t (zeroInfoFromOneHot t p e1 e2))
- (\e -> (Any True, e))
+ (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2)))
+ (\e -> acted $ return e)
(\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2''))
-- type-specific equations for plus
@@ -198,49 +248,50 @@ simplify' = \case
-- fallback recursion
EVar _ t i -> pure $ EVar ext t i
- ELet _ a b -> ELet ext <$> simplify' a <*> simplify' b
- EPair _ a b -> EPair ext <$> simplify' a <*> simplify' b
- EFst _ e -> EFst ext <$> simplify' e
- ESnd _ e -> ESnd ext <$> simplify' e
+ ELet _ a b -> [simprec| ELet ext *a *b |]
+ EPair _ a b -> [simprec| EPair ext *a *b |]
+ EFst _ e -> [simprec| EFst ext *e |]
+ ESnd _ e -> [simprec| ESnd ext *e |]
ENil _ -> pure $ ENil ext
- EInl _ t e -> EInl ext t <$> simplify' e
- EInr _ t e -> EInr ext t <$> simplify' e
- ECase _ e a b -> ECase ext <$> simplify' e <*> simplify' a <*> simplify' b
+ EInl _ t e -> [simprec| EInl ext t *e |]
+ EInr _ t e -> [simprec| EInr ext t *e |]
+ ECase _ e a b -> [simprec| ECase ext *e *a *b |]
ENothing _ t -> pure $ ENothing ext t
- EJust _ e -> EJust ext <$> simplify' e
- EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e
+ EJust _ e -> [simprec| EJust ext *e |]
+ EMaybe _ a b e -> [simprec| EMaybe ext *a *b *e |]
ELNil _ t1 t2 -> pure $ ELNil ext t1 t2
- ELInl _ t e -> ELInl ext t <$> simplify' e
- ELInr _ t e -> ELInr ext t <$> simplify' e
- ELCase _ e a b c -> ELCase ext <$> simplify' e <*> simplify' a <*> simplify' b <*> simplify' c
+ ELInl _ t e -> [simprec| ELInl ext t *e |]
+ ELInr _ t e -> [simprec| ELInr ext t *e |]
+ ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |]
EConstArr _ n t v -> pure $ EConstArr ext n t v
- EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b
- EFold1Inner _ cm a b c -> EFold1Inner ext cm <$> simplify' a <*> simplify' b <*> simplify' c
- ESum1Inner _ e -> ESum1Inner ext <$> simplify' e
- EUnit _ e -> EUnit ext <$> simplify' e
- EReplicate1Inner _ a b -> EReplicate1Inner ext <$> simplify' a <*> simplify' b
- EMaximum1Inner _ e -> EMaximum1Inner ext <$> simplify' e
- EMinimum1Inner _ e -> EMinimum1Inner ext <$> simplify' e
+ EBuild _ n a b -> [simprec| EBuild ext n *a *b |]
+ EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |]
+ ESum1Inner _ e -> [simprec| ESum1Inner ext *e |]
+ EUnit _ e -> [simprec| EUnit ext *e |]
+ EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |]
+ EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |]
+ EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |]
EConst _ t v -> pure $ EConst ext t v
- EIdx0 _ e -> EIdx0 ext <$> simplify' e
- EIdx1 _ a b -> EIdx1 ext <$> simplify' a <*> simplify' b
- EIdx _ a b -> EIdx ext <$> simplify' a <*> simplify' b
- EShape _ e -> EShape ext <$> simplify' e
- EOp _ op e -> EOp ext op <$> simplify' e
- ECustom _ s t p a b c e1 e2 ->
- ECustom ext s t p
- <$> (let ?accumInScope = False in simplify' a)
- <*> (let ?accumInScope = False in simplify' b)
- <*> (let ?accumInScope = False in simplify' c)
- <*> simplify' e1 <*> simplify' e2
- EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2)
- EZero _ t e -> EZero ext t <$> simplify' e
- EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b
+ EIdx0 _ e -> [simprec| EIdx0 ext *e |]
+ EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |]
+ EIdx _ a b -> [simprec| EIdx ext *a *b |]
+ EShape _ e -> [simprec| EShape ext *e |]
+ EOp _ op e -> [simprec| EOp ext op *e |]
+ ECustom _ s t p a b c e1 e2 -> do
+ a' <- within (\a' -> ECustom ext s t p a' b c e1 e2) (let ?accumInScope = False in simplify' a)
+ b' <- within (\b' -> ECustom ext s t p a' b' c e1 e2) (let ?accumInScope = False in simplify' b)
+ c' <- within (\c' -> ECustom ext s t p a' b' c' e1 e2) (let ?accumInScope = False in simplify' c)
+ e1' <- within (\e1' -> ECustom ext s t p a' b' c' e1' e2) (simplify' e1)
+ e2' <- within (\e2' -> ECustom ext s t p a' b' c' e1' e2') (simplify' e2)
+ pure (ECustom ext s t p a' b' c' e1' e2')
+ EWith _ t e1 e2 -> do
+ e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1)
+ e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2)
+ pure (EWith ext t e1' e2')
+ EZero _ t e -> [simprec| EZero ext t *e |] -- EZero ext t <$> simplify' e
+ EPlus _ t a b -> [simprec| EPlus ext t *a *b |] -- EPlus ext t <$> simplify' a <*> simplify' b
EError _ t s -> pure $ EError ext t s
-acted :: (Any, a) -> (Any, a)
-acted (_, x) = (Any True, x)
-
cheapExpr :: Expr x env t -> Bool
cheapExpr = \case
EVar{} -> True
@@ -312,18 +363,21 @@ data OneHotTerm env p a b where
deriving instance Show (OneHotTerm env p a b)
simplifyOneHotTerm :: OneHotTerm env p a b
- -> (Any, r) -- ^ Zero case (onehot is actually zero)
- -> (Ex env a -> (Any, r)) -- ^ Trivial case (no zeros in onehot)
- -> (forall p' b'. OneHotTerm env p' a b' -> (Any, r))
- -> (Any, r)
+ -> SM tenv tt env t r -- ^ Zero case (onehot is actually zero)
+ -> (Ex env a -> SM tenv tt env t r) -- ^ Trivial case (no zeros in onehot)
+ -> (forall p' b'. OneHotTerm env p' a b' -> SM tenv tt env t r)
+ -> SM tenv tt env t r
simplifyOneHotTerm (OneHotTerm _ _ _ EZero{}) kzero _ _ = kzero
simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k
| Just Refl <- testEquality (acPrjTy prj1 t1) t2
- = do (Any True, ()) -- record, whatever happens later, that we've modified something
+ = do tellActed -- record, whatever happens later, that we've modified something
concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k
+-- TODO: This does not actually recurse unless it just so happens to contain
+-- another EZero or EOnehot in the final position. Should match on something
+-- more general than SAPHere here.
simplifyOneHotTerm (OneHotTerm t SAPHere _ e) kzero ktriv k = case (t, e) of
(SMTNil, _) -> kzero