diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Simplify.hs | 180 | ||||
-rw-r--r-- | src/Simplify/TH.hs | 80 |
2 files changed, 197 insertions, 63 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs index 2a1d3b6..469c7a1 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -1,8 +1,10 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImplicitParams #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -13,20 +15,27 @@ module Simplify ( SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith, ) where +import Control.Monad (ap) +import Data.Bifunctor (first) import Data.Function (fix) import Data.Monoid (Any(..)) import Data.Type.Equality (testEquality) +import Debug.Trace + import AST import AST.Count +import AST.Pretty import Data +import Simplify.TH --- | This has no fields now, hence this type is useless as-is. When debugging, however, it's useful to be able to add some. data SimplifyConfig = SimplifyConfig + { scLogging :: Bool + } defaultSimplifyConfig :: SimplifyConfig -defaultSimplifyConfig = SimplifyConfig +defaultSimplifyConfig = SimplifyConfig False simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t simplifyN 0 = id @@ -36,13 +45,13 @@ simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t simplify = let ?accumInScope = checkAccumInScope @env knownEnv ?config = defaultSimplifyConfig - in snd . simplify' + in snd . runSM . simplify' simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t simplifyWith config = let ?accumInScope = checkAccumInScope @env knownEnv ?config = config - in snd . simplify' + in snd . runSM . simplify' simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t simplifyFix = simplifyFixWith defaultSimplifyConfig @@ -52,11 +61,51 @@ simplifyFixWith config = let ?accumInScope = checkAccumInScope @env knownEnv ?config = config in fix $ \loop e -> - let (Any act, e') = simplify' e + let (act, e') = runSM (simplify' e) in if act then loop e' else e' -simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig) => Ex env t -> (Any, Ex env t) -simplify' = \case +-- | simplify monad +newtype SM tenv tt env t a = SM ((Ex env t -> Ex tenv tt) -> (Any, a)) + deriving (Functor) + +instance Applicative (SM tenv tt env t) where + pure x = SM (\_ -> (Any False, x)) + (<*>) = ap + +instance Monad (SM tenv tt env t) where + SM f >>= g = SM $ \ctx -> f ctx >>= \x -> let SM h = g x in h ctx + +runSM :: SM env t env t a -> (Bool, a) +runSM (SM f) = first getAny (f id) + +smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt) +smReconstruct core = SM (\ctx -> (Any False, ctx core)) + +tellActed :: SM tenv tt env t () +tellActed = SM (\_ -> (Any True, ())) + +-- more convenient in practice +acted :: SM tenv tt env t a -> SM tenv tt env t a +acted m = tellActed >> m + +within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a +within subctx (SM f) = SM $ \ctx -> f (ctx . subctx) + +simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) +simplify' expr + | scLogging ?config = do + res <- simplify'Rec expr + full <- smReconstruct res + let printed = ppExpr knownEnv full + replace a bs = concatMap (\x -> if x == a then bs else [x]) + str | '\n' `elem` printed = "--- simplify step:\n " ++ replace '\n' "\n " printed + | otherwise = "--- simplify step: " ++ printed + traceM str + return res + | otherwise = simplify'Rec expr + +simplify'Rec :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) +simplify'Rec = \case -- inlining ELet _ rhs body | cheapExpr rhs @@ -83,11 +132,12 @@ simplify' = \case acted $ simplify' $ ELet ext a $ subst0 (EInr ext t1 (EVar ext (typeOf a) IZ)) body -- let rotation - ELet _ (ELet _ rhs a) b -> + ELet _ (ELet _ rhs a) b -> do + b' <- within (ELet ext (ELet ext rhs a)) $ simplify' b acted $ simplify' $ ELet ext rhs $ ELet ext a $ - weakenExpr (WCopy WSink) (snd (simplify' b)) + weakenExpr (WCopy WSink) b' -- beta rules for products EFst _ (EPair _ e e') @@ -133,8 +183,8 @@ simplify' = \case EIdx _ (EUnit _ e1) _ -> acted $ simplify' $ e1 -- TODO: more constant folding - EOp _ OIf (EConst _ STBool True) -> (Any True, EInl ext STNil (ENil ext)) - EOp _ OIf (EConst _ STBool False) -> (Any True, EInr ext STNil (ENil ext)) + EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext)) + EOp _ OIf (EConst _ STBool False) -> acted $ return (EInr ext STNil (ENil ext)) -- inline cheap array constructors ELet _ (EReplicate1Inner _ e1 e2) e3 -> @@ -153,29 +203,29 @@ simplify' = \case -- eta rule for unit e | STNil <- typeOf e, not ?accumInScope || not (hasAdds e) -> case e of - ENil _ -> (Any False, e) - _ -> (Any True, ENil ext) + ENil _ -> return e + _ -> acted $ return (ENil ext) EBuild _ SZ _ e -> acted $ simplify' $ EUnit ext (substInline (ENil ext) e) -- monoid rules EAccum _ t p e1 e2 acc -> do - e1' <- simplify' e1 - e2' <- simplify' e2 - acc' <- simplify' acc + e1' <- within (\e1' -> EAccum ext t p e1' e2 acc ) $ simplify' e1 + e2' <- within (\e2' -> EAccum ext t p e1' e2' acc ) $ simplify' e2 + acc' <- within (\acc' -> EAccum ext t p e1' e2' acc') $ simplify' acc simplifyOneHotTerm (OneHotTerm t p e1' e2') - (Any True, ENil ext) - (\e -> (Any False, EAccum ext t SAPHere (ENil ext) e acc')) + (acted $ return (ENil ext)) + (\e -> return (EAccum ext t SAPHere (ENil ext) e acc')) (\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc')) EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e EOneHot _ t p e1 e2 -> do - e1' <- simplify' e1 - e2' <- simplify' e2 + e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1 + e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2 simplifyOneHotTerm (OneHotTerm t p e1' e2') - (Any True, EZero ext t (zeroInfoFromOneHot t p e1 e2)) - (\e -> (Any True, e)) + (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2))) + (\e -> acted $ return e) (\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2'')) -- type-specific equations for plus @@ -198,49 +248,50 @@ simplify' = \case -- fallback recursion EVar _ t i -> pure $ EVar ext t i - ELet _ a b -> ELet ext <$> simplify' a <*> simplify' b - EPair _ a b -> EPair ext <$> simplify' a <*> simplify' b - EFst _ e -> EFst ext <$> simplify' e - ESnd _ e -> ESnd ext <$> simplify' e + ELet _ a b -> [simprec| ELet ext *a *b |] + EPair _ a b -> [simprec| EPair ext *a *b |] + EFst _ e -> [simprec| EFst ext *e |] + ESnd _ e -> [simprec| ESnd ext *e |] ENil _ -> pure $ ENil ext - EInl _ t e -> EInl ext t <$> simplify' e - EInr _ t e -> EInr ext t <$> simplify' e - ECase _ e a b -> ECase ext <$> simplify' e <*> simplify' a <*> simplify' b + EInl _ t e -> [simprec| EInl ext t *e |] + EInr _ t e -> [simprec| EInr ext t *e |] + ECase _ e a b -> [simprec| ECase ext *e *a *b |] ENothing _ t -> pure $ ENothing ext t - EJust _ e -> EJust ext <$> simplify' e - EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e + EJust _ e -> [simprec| EJust ext *e |] + EMaybe _ a b e -> [simprec| EMaybe ext *a *b *e |] ELNil _ t1 t2 -> pure $ ELNil ext t1 t2 - ELInl _ t e -> ELInl ext t <$> simplify' e - ELInr _ t e -> ELInr ext t <$> simplify' e - ELCase _ e a b c -> ELCase ext <$> simplify' e <*> simplify' a <*> simplify' b <*> simplify' c + ELInl _ t e -> [simprec| ELInl ext t *e |] + ELInr _ t e -> [simprec| ELInr ext t *e |] + ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |] EConstArr _ n t v -> pure $ EConstArr ext n t v - EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b - EFold1Inner _ cm a b c -> EFold1Inner ext cm <$> simplify' a <*> simplify' b <*> simplify' c - ESum1Inner _ e -> ESum1Inner ext <$> simplify' e - EUnit _ e -> EUnit ext <$> simplify' e - EReplicate1Inner _ a b -> EReplicate1Inner ext <$> simplify' a <*> simplify' b - EMaximum1Inner _ e -> EMaximum1Inner ext <$> simplify' e - EMinimum1Inner _ e -> EMinimum1Inner ext <$> simplify' e + EBuild _ n a b -> [simprec| EBuild ext n *a *b |] + EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |] + ESum1Inner _ e -> [simprec| ESum1Inner ext *e |] + EUnit _ e -> [simprec| EUnit ext *e |] + EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |] + EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |] + EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |] EConst _ t v -> pure $ EConst ext t v - EIdx0 _ e -> EIdx0 ext <$> simplify' e - EIdx1 _ a b -> EIdx1 ext <$> simplify' a <*> simplify' b - EIdx _ a b -> EIdx ext <$> simplify' a <*> simplify' b - EShape _ e -> EShape ext <$> simplify' e - EOp _ op e -> EOp ext op <$> simplify' e - ECustom _ s t p a b c e1 e2 -> - ECustom ext s t p - <$> (let ?accumInScope = False in simplify' a) - <*> (let ?accumInScope = False in simplify' b) - <*> (let ?accumInScope = False in simplify' c) - <*> simplify' e1 <*> simplify' e2 - EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2) - EZero _ t e -> EZero ext t <$> simplify' e - EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b + EIdx0 _ e -> [simprec| EIdx0 ext *e |] + EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |] + EIdx _ a b -> [simprec| EIdx ext *a *b |] + EShape _ e -> [simprec| EShape ext *e |] + EOp _ op e -> [simprec| EOp ext op *e |] + ECustom _ s t p a b c e1 e2 -> do + a' <- within (\a' -> ECustom ext s t p a' b c e1 e2) (let ?accumInScope = False in simplify' a) + b' <- within (\b' -> ECustom ext s t p a' b' c e1 e2) (let ?accumInScope = False in simplify' b) + c' <- within (\c' -> ECustom ext s t p a' b' c' e1 e2) (let ?accumInScope = False in simplify' c) + e1' <- within (\e1' -> ECustom ext s t p a' b' c' e1' e2) (simplify' e1) + e2' <- within (\e2' -> ECustom ext s t p a' b' c' e1' e2') (simplify' e2) + pure (ECustom ext s t p a' b' c' e1' e2') + EWith _ t e1 e2 -> do + e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1) + e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2) + pure (EWith ext t e1' e2') + EZero _ t e -> [simprec| EZero ext t *e |] -- EZero ext t <$> simplify' e + EPlus _ t a b -> [simprec| EPlus ext t *a *b |] -- EPlus ext t <$> simplify' a <*> simplify' b EError _ t s -> pure $ EError ext t s -acted :: (Any, a) -> (Any, a) -acted (_, x) = (Any True, x) - cheapExpr :: Expr x env t -> Bool cheapExpr = \case EVar{} -> True @@ -312,18 +363,21 @@ data OneHotTerm env p a b where deriving instance Show (OneHotTerm env p a b) simplifyOneHotTerm :: OneHotTerm env p a b - -> (Any, r) -- ^ Zero case (onehot is actually zero) - -> (Ex env a -> (Any, r)) -- ^ Trivial case (no zeros in onehot) - -> (forall p' b'. OneHotTerm env p' a b' -> (Any, r)) - -> (Any, r) + -> SM tenv tt env t r -- ^ Zero case (onehot is actually zero) + -> (Ex env a -> SM tenv tt env t r) -- ^ Trivial case (no zeros in onehot) + -> (forall p' b'. OneHotTerm env p' a b' -> SM tenv tt env t r) + -> SM tenv tt env t r simplifyOneHotTerm (OneHotTerm _ _ _ EZero{}) kzero _ _ = kzero simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k | Just Refl <- testEquality (acPrjTy prj1 t1) t2 - = do (Any True, ()) -- record, whatever happens later, that we've modified something + = do tellActed -- record, whatever happens later, that we've modified something concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k +-- TODO: This does not actually recurse unless it just so happens to contain +-- another EZero or EOnehot in the final position. Should match on something +-- more general than SAPHere here. simplifyOneHotTerm (OneHotTerm t SAPHere _ e) kzero ktriv k = case (t, e) of (SMTNil, _) -> kzero diff --git a/src/Simplify/TH.hs b/src/Simplify/TH.hs new file mode 100644 index 0000000..2e0076a --- /dev/null +++ b/src/Simplify/TH.hs @@ -0,0 +1,80 @@ +{-# LANGUAGE TemplateHaskellQuotes #-} +module Simplify.TH (simprec) where + +import Data.Bifunctor (first) +import Data.Char +import Data.List (foldl1') +import Language.Haskell.TH +import Language.Haskell.TH.Quote +import Text.ParserCombinators.ReadP + + +-- [simprec| EPair ext *a *b |] +-- ~> +-- do a' <- within (\a' -> EPair ext a' b) (simplify' a) +-- b' <- within (\b' -> EPair ext a' b') (simplify' b) +-- pure (EPair ext a' b') + +simprec :: QuasiQuoter +simprec = QuasiQuoter + { quoteDec = \_ -> fail "simprec used outside of expression context" + , quoteType = \_ -> fail "simprec used outside of expression context" + , quoteExp = handler + , quotePat = \_ -> fail "simprec used outside of expression context" + } + +handler :: String -> Q Exp +handler str = + case readP_to_S pTemplate str of + [(template, "")] -> generate template + _:_:_ -> fail "simprec: template grammar ambiguous" + _ -> fail "simprec: could not parse template" + +generate :: Template -> Q Exp +generate (Template topitems) = + let takePrefix (Plain x : xs) = first (x:) (takePrefix xs) + takePrefix xs = ([], xs) + + itemVar "" = error "simprec: empty item name?" + itemVar name@(c:_) | isLower c = VarE (mkName name) + | isUpper c = ConE (mkName name) + | otherwise = error "simprec: non-letter item name?" + + loop :: Exp -> [Item] -> Q [Stmt] + loop yet [] = return [NoBindS (VarE 'pure `AppE` yet)] + loop yet (Plain x : xs) = loop (yet `AppE` itemVar x) xs + loop yet (Recurse x : xs) = do + primeName <- newName (x ++ "'") + let appPrePrime e (Plain y) = e `AppE` itemVar y + appPrePrime e (Recurse y) = e `AppE` itemVar y + let stmt = BindS (VarP primeName) $ + VarE (mkName "within") + `AppE` LamE [VarP primeName] (foldl' appPrePrime (yet `AppE` VarE primeName) xs) + `AppE` (VarE (mkName "simplify'") `AppE` VarE (mkName x)) + stmts <- loop (yet `AppE` VarE primeName) xs + return (stmt : stmts) + + (prefix, items') = takePrefix topitems + in DoE Nothing <$> loop (foldl1' AppE (map itemVar prefix)) items' + +data Template = Template [Item] + deriving (Show) + +data Item = Plain String | Recurse String + deriving (Show) + +pTemplate :: ReadP Template +pTemplate = do + items <- many (skipSpaces >> pItem) + skipSpaces + eof + return (Template items) + +pItem :: ReadP Item +pItem = (char '*' >> Recurse <$> pName) +++ (Plain <$> pName) + +pName :: ReadP String +pName = do + c1 <- satisfy (\c -> isAlpha c || c == '_') + cs <- munch (\c -> isAlphaNum c || c `elem` "_'") + return (c1:cs) |