diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-04-29 15:56:39 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-04-29 15:57:17 +0200 | 
| commit | a1074fc851afcb6e858285ab9c6585b042ac1782 (patch) | |
| tree | 8c40b943ee05134d79d418d23949a965eab1deae | |
| parent | 6899e81e8e1fc7fad32515eb0d40465407c7cf87 (diff) | |
Tracing simplifier
| -rw-r--r-- | chad-fast.cabal | 3 | ||||
| -rw-r--r-- | src/Simplify.hs | 180 | ||||
| -rw-r--r-- | src/Simplify/TH.hs | 80 | 
3 files changed, 199 insertions, 64 deletions
| diff --git a/chad-fast.cabal b/chad-fast.cabal index 1aadc6b..b0ed639 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -46,6 +46,7 @@ library      Language.AST      Lemmas      Simplify +    Simplify.TH      Util.IdGen    other-modules:    build-depends: @@ -53,10 +54,10 @@ library      containers,      deepseq,      directory, -    -- template-haskell,      prettyprinter,      process,      some, +    template-haskell,      transformers,      unix,      vector, diff --git a/src/Simplify.hs b/src/Simplify.hs index 2a1d3b6..469c7a1 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -1,8 +1,10 @@  {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFunctor #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE ImplicitParams #-}  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE QuasiQuotes #-}  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE StandaloneDeriving #-} @@ -13,20 +15,27 @@ module Simplify (    SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith,  ) where +import Control.Monad (ap) +import Data.Bifunctor (first)  import Data.Function (fix)  import Data.Monoid (Any(..))  import Data.Type.Equality (testEquality) +import Debug.Trace +  import AST  import AST.Count +import AST.Pretty  import Data +import Simplify.TH --- | This has no fields now, hence this type is useless as-is. When debugging, however, it's useful to be able to add some.  data SimplifyConfig = SimplifyConfig +  { scLogging :: Bool +  }  defaultSimplifyConfig :: SimplifyConfig -defaultSimplifyConfig = SimplifyConfig +defaultSimplifyConfig = SimplifyConfig False  simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t  simplifyN 0 = id @@ -36,13 +45,13 @@ simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t  simplify =    let ?accumInScope = checkAccumInScope @env knownEnv        ?config = defaultSimplifyConfig -  in snd . simplify' +  in snd . runSM . simplify'  simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t  simplifyWith config =    let ?accumInScope = checkAccumInScope @env knownEnv        ?config = config -  in snd . simplify' +  in snd . runSM . simplify'  simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t  simplifyFix = simplifyFixWith defaultSimplifyConfig @@ -52,11 +61,51 @@ simplifyFixWith config =    let ?accumInScope = checkAccumInScope @env knownEnv        ?config = config    in fix $ \loop e -> -            let (Any act, e') = simplify' e +            let (act, e') = runSM (simplify' e)              in if act then loop e' else e' -simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig) => Ex env t -> (Any, Ex env t) -simplify' = \case +-- | simplify monad +newtype SM tenv tt env t a = SM ((Ex env t -> Ex tenv tt) -> (Any, a)) +  deriving (Functor) + +instance Applicative (SM tenv tt env t) where +  pure x = SM (\_ -> (Any False, x)) +  (<*>) = ap + +instance Monad (SM tenv tt env t) where +  SM f >>= g = SM $ \ctx -> f ctx >>= \x -> let SM h = g x in h ctx + +runSM :: SM env t env t a -> (Bool, a) +runSM (SM f) = first getAny (f id) + +smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt) +smReconstruct core = SM (\ctx -> (Any False, ctx core)) + +tellActed :: SM tenv tt env t () +tellActed = SM (\_ -> (Any True, ())) + +-- more convenient in practice +acted :: SM tenv tt env t a -> SM tenv tt env t a +acted m = tellActed >> m + +within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a +within subctx (SM f) = SM $ \ctx -> f (ctx . subctx) + +simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) +simplify' expr +  | scLogging ?config = do +      res <- simplify'Rec expr +      full <- smReconstruct res +      let printed = ppExpr knownEnv full +          replace a bs = concatMap (\x -> if x == a then bs else [x]) +          str | '\n' `elem` printed = "--- simplify step:\n  " ++ replace '\n' "\n  " printed +              | otherwise = "--- simplify step: " ++ printed +      traceM str +      return res +  | otherwise = simplify'Rec expr + +simplify'Rec :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) +simplify'Rec = \case    -- inlining    ELet _ rhs body      | cheapExpr rhs @@ -83,11 +132,12 @@ simplify' = \case      acted $ simplify' $ ELet ext a $ subst0 (EInr ext t1 (EVar ext (typeOf a) IZ)) body    -- let rotation -  ELet _ (ELet _ rhs a) b -> +  ELet _ (ELet _ rhs a) b -> do +    b' <- within (ELet ext (ELet ext rhs a)) $ simplify' b      acted $ simplify' $        ELet ext rhs $        ELet ext a $ -        weakenExpr (WCopy WSink) (snd (simplify' b)) +        weakenExpr (WCopy WSink) b'    -- beta rules for products    EFst _ (EPair _ e e') @@ -133,8 +183,8 @@ simplify' = \case    EIdx _ (EUnit _ e1) _ -> acted $ simplify' $ e1    -- TODO: more constant folding -  EOp _ OIf (EConst _ STBool True) -> (Any True, EInl ext STNil (ENil ext)) -  EOp _ OIf (EConst _ STBool False) -> (Any True, EInr ext STNil (ENil ext)) +  EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext)) +  EOp _ OIf (EConst _ STBool False) -> acted $ return (EInr ext STNil (ENil ext))    -- inline cheap array constructors    ELet _ (EReplicate1Inner _ e1 e2) e3 -> @@ -153,29 +203,29 @@ simplify' = \case    -- eta rule for unit    e | STNil <- typeOf e, not ?accumInScope || not (hasAdds e) ->      case e of -      ENil _ -> (Any False, e) -      _      -> (Any True, ENil ext) +      ENil _ -> return e +      _      -> acted $ return (ENil ext)    EBuild _ SZ _ e ->      acted $ simplify' $ EUnit ext (substInline (ENil ext) e)    -- monoid rules    EAccum _ t p e1 e2 acc -> do -    e1' <- simplify' e1 -    e2' <- simplify' e2 -    acc' <- simplify' acc +    e1' <- within  (\e1'  -> EAccum ext t p e1' e2  acc ) $ simplify' e1 +    e2' <- within  (\e2'  -> EAccum ext t p e1' e2' acc ) $ simplify' e2 +    acc' <- within (\acc' -> EAccum ext t p e1' e2' acc') $ simplify' acc      simplifyOneHotTerm (OneHotTerm t p e1' e2') -      (Any True, ENil ext) -      (\e -> (Any False, EAccum ext t SAPHere (ENil ext) e acc')) +      (acted $ return (ENil ext)) +      (\e -> return (EAccum ext t SAPHere (ENil ext) e acc'))        (\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc'))    EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e    EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e    EOneHot _ t p e1 e2 -> do -    e1' <- simplify' e1 -    e2' <- simplify' e2 +    e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1 +    e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2      simplifyOneHotTerm (OneHotTerm t p e1' e2') -      (Any True, EZero ext t (zeroInfoFromOneHot t p e1 e2)) -      (\e -> (Any True, e)) +      (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2))) +      (\e -> acted $ return e)        (\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2''))    -- type-specific equations for plus @@ -198,49 +248,50 @@ simplify' = \case    -- fallback recursion    EVar _ t i -> pure $ EVar ext t i -  ELet _ a b -> ELet ext <$> simplify' a <*> simplify' b -  EPair _ a b -> EPair ext <$> simplify' a <*> simplify' b -  EFst _ e -> EFst ext <$> simplify' e -  ESnd _ e -> ESnd ext <$> simplify' e +  ELet _ a b -> [simprec| ELet ext *a *b |] +  EPair _ a b -> [simprec| EPair ext *a *b |] +  EFst _ e -> [simprec| EFst ext *e |] +  ESnd _ e -> [simprec| ESnd ext *e |]    ENil _ -> pure $ ENil ext -  EInl _ t e -> EInl ext t <$> simplify' e -  EInr _ t e -> EInr ext t <$> simplify' e -  ECase _ e a b -> ECase ext <$> simplify' e <*> simplify' a <*> simplify' b +  EInl _ t e -> [simprec| EInl ext t *e |] +  EInr _ t e -> [simprec| EInr ext t *e |] +  ECase _ e a b -> [simprec| ECase ext *e *a *b |]    ENothing _ t -> pure $ ENothing ext t -  EJust _ e -> EJust ext <$> simplify' e -  EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e +  EJust _ e -> [simprec| EJust ext *e |] +  EMaybe _ a b e -> [simprec| EMaybe ext *a *b *e |]    ELNil _ t1 t2 -> pure $ ELNil ext t1 t2 -  ELInl _ t e -> ELInl ext t <$> simplify' e -  ELInr _ t e -> ELInr ext t <$> simplify' e -  ELCase _ e a b c -> ELCase ext <$> simplify' e <*> simplify' a <*> simplify' b <*> simplify' c +  ELInl _ t e -> [simprec| ELInl ext t *e |] +  ELInr _ t e -> [simprec| ELInr ext t *e |] +  ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |]    EConstArr _ n t v -> pure $ EConstArr ext n t v -  EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b -  EFold1Inner _ cm a b c -> EFold1Inner ext cm <$> simplify' a <*> simplify' b <*> simplify' c -  ESum1Inner _ e -> ESum1Inner ext <$> simplify' e -  EUnit _ e -> EUnit ext <$> simplify' e -  EReplicate1Inner _ a b -> EReplicate1Inner ext <$> simplify' a <*> simplify' b -  EMaximum1Inner _ e -> EMaximum1Inner ext <$> simplify' e -  EMinimum1Inner _ e -> EMinimum1Inner ext <$> simplify' e +  EBuild _ n a b -> [simprec| EBuild ext n *a *b |] +  EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |] +  ESum1Inner _ e -> [simprec| ESum1Inner ext *e |] +  EUnit _ e -> [simprec| EUnit ext *e |] +  EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |] +  EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |] +  EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |]    EConst _ t v -> pure $ EConst ext t v -  EIdx0 _ e -> EIdx0 ext <$> simplify' e -  EIdx1 _ a b -> EIdx1 ext <$> simplify' a <*> simplify' b -  EIdx _ a b -> EIdx ext <$> simplify' a <*> simplify' b -  EShape _ e -> EShape ext <$> simplify' e -  EOp _ op e -> EOp ext op <$> simplify' e -  ECustom _ s t p a b c e1 e2 -> -    ECustom ext s t p -      <$> (let ?accumInScope = False in simplify' a) -      <*> (let ?accumInScope = False in simplify' b) -      <*> (let ?accumInScope = False in simplify' c) -      <*> simplify' e1 <*> simplify' e2 -  EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2) -  EZero _ t e -> EZero ext t <$> simplify' e -  EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b +  EIdx0 _ e -> [simprec| EIdx0 ext *e |] +  EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |] +  EIdx _ a b -> [simprec| EIdx ext *a *b |] +  EShape _ e -> [simprec| EShape ext *e |] +  EOp _ op e -> [simprec| EOp ext op *e |] +  ECustom _ s t p a b c e1 e2 -> do +    a' <- within (\a' -> ECustom ext s t p a' b c e1 e2) (let ?accumInScope = False in simplify' a) +    b' <- within (\b' -> ECustom ext s t p a' b' c e1 e2) (let ?accumInScope = False in simplify' b) +    c' <- within (\c' -> ECustom ext s t p a' b' c' e1 e2) (let ?accumInScope = False in simplify' c) +    e1' <- within (\e1' -> ECustom ext s t p a' b' c' e1' e2) (simplify' e1) +    e2' <- within (\e2' -> ECustom ext s t p a' b' c' e1' e2') (simplify' e2) +    pure (ECustom ext s t p a' b' c' e1' e2') +  EWith _ t e1 e2 -> do +    e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1) +    e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2) +    pure (EWith ext t e1' e2') +  EZero _ t e -> [simprec| EZero ext t *e |] -- EZero ext t <$> simplify' e +  EPlus _ t a b -> [simprec| EPlus ext t *a *b |] -- EPlus ext t <$> simplify' a <*> simplify' b    EError _ t s -> pure $ EError ext t s -acted :: (Any, a) -> (Any, a) -acted (_, x) = (Any True, x) -  cheapExpr :: Expr x env t -> Bool  cheapExpr = \case    EVar{} -> True @@ -312,18 +363,21 @@ data OneHotTerm env p a b where  deriving instance Show (OneHotTerm env p a b)  simplifyOneHotTerm :: OneHotTerm env p a b -                   -> (Any, r)  -- ^ Zero case (onehot is actually zero) -                   -> (Ex env a -> (Any, r))  -- ^ Trivial case (no zeros in onehot) -                   -> (forall p' b'. OneHotTerm env p' a b' -> (Any, r)) -                   -> (Any, r) +                   -> SM tenv tt env t r  -- ^ Zero case (onehot is actually zero) +                   -> (Ex env a -> SM tenv tt env t r)  -- ^ Trivial case (no zeros in onehot) +                   -> (forall p' b'. OneHotTerm env p' a b' -> SM tenv tt env t r) +                   -> SM tenv tt env t r  simplifyOneHotTerm (OneHotTerm _ _ _ EZero{}) kzero _ _ = kzero  simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k    | Just Refl <- testEquality (acPrjTy prj1 t1) t2 -  = do (Any True, ())  -- record, whatever happens later, that we've modified something +  = do tellActed  -- record, whatever happens later, that we've modified something         concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->           simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k +-- TODO: This does not actually recurse unless it just so happens to contain +-- another EZero or EOnehot in the final position. Should match on something +-- more general than SAPHere here.  simplifyOneHotTerm (OneHotTerm t SAPHere _ e) kzero ktriv k = case (t, e) of    (SMTNil, _) -> kzero diff --git a/src/Simplify/TH.hs b/src/Simplify/TH.hs new file mode 100644 index 0000000..2e0076a --- /dev/null +++ b/src/Simplify/TH.hs @@ -0,0 +1,80 @@ +{-# LANGUAGE TemplateHaskellQuotes #-} +module Simplify.TH (simprec) where + +import Data.Bifunctor (first) +import Data.Char +import Data.List (foldl1') +import Language.Haskell.TH +import Language.Haskell.TH.Quote +import Text.ParserCombinators.ReadP + + +-- [simprec| EPair ext *a *b |] +-- ~> +-- do a' <- within (\a' -> EPair ext a' b) (simplify' a) +--    b' <- within (\b' -> EPair ext a' b') (simplify' b) +--    pure (EPair ext a' b') + +simprec :: QuasiQuoter +simprec = QuasiQuoter +  { quoteDec = \_ -> fail "simprec used outside of expression context" +  , quoteType = \_ -> fail "simprec used outside of expression context" +  , quoteExp = handler +  , quotePat = \_ -> fail "simprec used outside of expression context" +  } + +handler :: String -> Q Exp +handler str = +  case readP_to_S pTemplate str of +    [(template, "")] -> generate template +    _:_:_ -> fail "simprec: template grammar ambiguous" +    _ -> fail "simprec: could not parse template" + +generate :: Template -> Q Exp +generate (Template topitems) = +  let takePrefix (Plain x : xs) = first (x:) (takePrefix xs) +      takePrefix xs = ([], xs) + +      itemVar "" = error "simprec: empty item name?" +      itemVar name@(c:_) | isLower c = VarE (mkName name) +                         | isUpper c = ConE (mkName name) +                         | otherwise = error "simprec: non-letter item name?" + +      loop :: Exp -> [Item] -> Q [Stmt] +      loop yet [] = return [NoBindS (VarE 'pure `AppE` yet)] +      loop yet (Plain x : xs) = loop (yet `AppE` itemVar x) xs +      loop yet (Recurse x : xs) = do +        primeName <- newName (x ++ "'") +        let appPrePrime e (Plain y) = e `AppE` itemVar y +            appPrePrime e (Recurse y) = e `AppE` itemVar y +        let stmt = BindS (VarP primeName) $ +                     VarE (mkName "within") +                     `AppE` LamE [VarP primeName] (foldl' appPrePrime (yet `AppE` VarE primeName) xs) +                     `AppE` (VarE (mkName "simplify'") `AppE` VarE (mkName x)) +        stmts <- loop (yet `AppE` VarE primeName) xs +        return (stmt : stmts) + +      (prefix, items') = takePrefix topitems +  in DoE Nothing <$> loop (foldl1' AppE (map itemVar prefix)) items' + +data Template = Template [Item] +  deriving (Show) + +data Item = Plain String | Recurse String +  deriving (Show) + +pTemplate :: ReadP Template +pTemplate = do +  items <- many (skipSpaces >> pItem) +  skipSpaces +  eof +  return (Template items) + +pItem :: ReadP Item +pItem = (char '*' >> Recurse <$> pName) +++ (Plain <$> pName) + +pName :: ReadP String +pName = do +  c1 <- satisfy (\c -> isAlpha c || c == '_') +  cs <- munch (\c -> isAlphaNum c || c `elem` "_'") +  return (c1:cs) | 
