summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--chad-fast.cabal3
-rw-r--r--src/Simplify.hs180
-rw-r--r--src/Simplify/TH.hs80
3 files changed, 199 insertions, 64 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 1aadc6b..b0ed639 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -46,6 +46,7 @@ library
Language.AST
Lemmas
Simplify
+ Simplify.TH
Util.IdGen
other-modules:
build-depends:
@@ -53,10 +54,10 @@ library
containers,
deepseq,
directory,
- -- template-haskell,
prettyprinter,
process,
some,
+ template-haskell,
transformers,
unix,
vector,
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 2a1d3b6..469c7a1 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -1,8 +1,10 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -13,20 +15,27 @@ module Simplify (
SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith,
) where
+import Control.Monad (ap)
+import Data.Bifunctor (first)
import Data.Function (fix)
import Data.Monoid (Any(..))
import Data.Type.Equality (testEquality)
+import Debug.Trace
+
import AST
import AST.Count
+import AST.Pretty
import Data
+import Simplify.TH
--- | This has no fields now, hence this type is useless as-is. When debugging, however, it's useful to be able to add some.
data SimplifyConfig = SimplifyConfig
+ { scLogging :: Bool
+ }
defaultSimplifyConfig :: SimplifyConfig
-defaultSimplifyConfig = SimplifyConfig
+defaultSimplifyConfig = SimplifyConfig False
simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t
simplifyN 0 = id
@@ -36,13 +45,13 @@ simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t
simplify =
let ?accumInScope = checkAccumInScope @env knownEnv
?config = defaultSimplifyConfig
- in snd . simplify'
+ in snd . runSM . simplify'
simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t
simplifyWith config =
let ?accumInScope = checkAccumInScope @env knownEnv
?config = config
- in snd . simplify'
+ in snd . runSM . simplify'
simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t
simplifyFix = simplifyFixWith defaultSimplifyConfig
@@ -52,11 +61,51 @@ simplifyFixWith config =
let ?accumInScope = checkAccumInScope @env knownEnv
?config = config
in fix $ \loop e ->
- let (Any act, e') = simplify' e
+ let (act, e') = runSM (simplify' e)
in if act then loop e' else e'
-simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig) => Ex env t -> (Any, Ex env t)
-simplify' = \case
+-- | simplify monad
+newtype SM tenv tt env t a = SM ((Ex env t -> Ex tenv tt) -> (Any, a))
+ deriving (Functor)
+
+instance Applicative (SM tenv tt env t) where
+ pure x = SM (\_ -> (Any False, x))
+ (<*>) = ap
+
+instance Monad (SM tenv tt env t) where
+ SM f >>= g = SM $ \ctx -> f ctx >>= \x -> let SM h = g x in h ctx
+
+runSM :: SM env t env t a -> (Bool, a)
+runSM (SM f) = first getAny (f id)
+
+smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt)
+smReconstruct core = SM (\ctx -> (Any False, ctx core))
+
+tellActed :: SM tenv tt env t ()
+tellActed = SM (\_ -> (Any True, ()))
+
+-- more convenient in practice
+acted :: SM tenv tt env t a -> SM tenv tt env t a
+acted m = tellActed >> m
+
+within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a
+within subctx (SM f) = SM $ \ctx -> f (ctx . subctx)
+
+simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t)
+simplify' expr
+ | scLogging ?config = do
+ res <- simplify'Rec expr
+ full <- smReconstruct res
+ let printed = ppExpr knownEnv full
+ replace a bs = concatMap (\x -> if x == a then bs else [x])
+ str | '\n' `elem` printed = "--- simplify step:\n " ++ replace '\n' "\n " printed
+ | otherwise = "--- simplify step: " ++ printed
+ traceM str
+ return res
+ | otherwise = simplify'Rec expr
+
+simplify'Rec :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t)
+simplify'Rec = \case
-- inlining
ELet _ rhs body
| cheapExpr rhs
@@ -83,11 +132,12 @@ simplify' = \case
acted $ simplify' $ ELet ext a $ subst0 (EInr ext t1 (EVar ext (typeOf a) IZ)) body
-- let rotation
- ELet _ (ELet _ rhs a) b ->
+ ELet _ (ELet _ rhs a) b -> do
+ b' <- within (ELet ext (ELet ext rhs a)) $ simplify' b
acted $ simplify' $
ELet ext rhs $
ELet ext a $
- weakenExpr (WCopy WSink) (snd (simplify' b))
+ weakenExpr (WCopy WSink) b'
-- beta rules for products
EFst _ (EPair _ e e')
@@ -133,8 +183,8 @@ simplify' = \case
EIdx _ (EUnit _ e1) _ -> acted $ simplify' $ e1
-- TODO: more constant folding
- EOp _ OIf (EConst _ STBool True) -> (Any True, EInl ext STNil (ENil ext))
- EOp _ OIf (EConst _ STBool False) -> (Any True, EInr ext STNil (ENil ext))
+ EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext))
+ EOp _ OIf (EConst _ STBool False) -> acted $ return (EInr ext STNil (ENil ext))
-- inline cheap array constructors
ELet _ (EReplicate1Inner _ e1 e2) e3 ->
@@ -153,29 +203,29 @@ simplify' = \case
-- eta rule for unit
e | STNil <- typeOf e, not ?accumInScope || not (hasAdds e) ->
case e of
- ENil _ -> (Any False, e)
- _ -> (Any True, ENil ext)
+ ENil _ -> return e
+ _ -> acted $ return (ENil ext)
EBuild _ SZ _ e ->
acted $ simplify' $ EUnit ext (substInline (ENil ext) e)
-- monoid rules
EAccum _ t p e1 e2 acc -> do
- e1' <- simplify' e1
- e2' <- simplify' e2
- acc' <- simplify' acc
+ e1' <- within (\e1' -> EAccum ext t p e1' e2 acc ) $ simplify' e1
+ e2' <- within (\e2' -> EAccum ext t p e1' e2' acc ) $ simplify' e2
+ acc' <- within (\acc' -> EAccum ext t p e1' e2' acc') $ simplify' acc
simplifyOneHotTerm (OneHotTerm t p e1' e2')
- (Any True, ENil ext)
- (\e -> (Any False, EAccum ext t SAPHere (ENil ext) e acc'))
+ (acted $ return (ENil ext))
+ (\e -> return (EAccum ext t SAPHere (ENil ext) e acc'))
(\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc'))
EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e
EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e
EOneHot _ t p e1 e2 -> do
- e1' <- simplify' e1
- e2' <- simplify' e2
+ e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1
+ e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2
simplifyOneHotTerm (OneHotTerm t p e1' e2')
- (Any True, EZero ext t (zeroInfoFromOneHot t p e1 e2))
- (\e -> (Any True, e))
+ (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2)))
+ (\e -> acted $ return e)
(\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2''))
-- type-specific equations for plus
@@ -198,49 +248,50 @@ simplify' = \case
-- fallback recursion
EVar _ t i -> pure $ EVar ext t i
- ELet _ a b -> ELet ext <$> simplify' a <*> simplify' b
- EPair _ a b -> EPair ext <$> simplify' a <*> simplify' b
- EFst _ e -> EFst ext <$> simplify' e
- ESnd _ e -> ESnd ext <$> simplify' e
+ ELet _ a b -> [simprec| ELet ext *a *b |]
+ EPair _ a b -> [simprec| EPair ext *a *b |]
+ EFst _ e -> [simprec| EFst ext *e |]
+ ESnd _ e -> [simprec| ESnd ext *e |]
ENil _ -> pure $ ENil ext
- EInl _ t e -> EInl ext t <$> simplify' e
- EInr _ t e -> EInr ext t <$> simplify' e
- ECase _ e a b -> ECase ext <$> simplify' e <*> simplify' a <*> simplify' b
+ EInl _ t e -> [simprec| EInl ext t *e |]
+ EInr _ t e -> [simprec| EInr ext t *e |]
+ ECase _ e a b -> [simprec| ECase ext *e *a *b |]
ENothing _ t -> pure $ ENothing ext t
- EJust _ e -> EJust ext <$> simplify' e
- EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e
+ EJust _ e -> [simprec| EJust ext *e |]
+ EMaybe _ a b e -> [simprec| EMaybe ext *a *b *e |]
ELNil _ t1 t2 -> pure $ ELNil ext t1 t2
- ELInl _ t e -> ELInl ext t <$> simplify' e
- ELInr _ t e -> ELInr ext t <$> simplify' e
- ELCase _ e a b c -> ELCase ext <$> simplify' e <*> simplify' a <*> simplify' b <*> simplify' c
+ ELInl _ t e -> [simprec| ELInl ext t *e |]
+ ELInr _ t e -> [simprec| ELInr ext t *e |]
+ ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |]
EConstArr _ n t v -> pure $ EConstArr ext n t v
- EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b
- EFold1Inner _ cm a b c -> EFold1Inner ext cm <$> simplify' a <*> simplify' b <*> simplify' c
- ESum1Inner _ e -> ESum1Inner ext <$> simplify' e
- EUnit _ e -> EUnit ext <$> simplify' e
- EReplicate1Inner _ a b -> EReplicate1Inner ext <$> simplify' a <*> simplify' b
- EMaximum1Inner _ e -> EMaximum1Inner ext <$> simplify' e
- EMinimum1Inner _ e -> EMinimum1Inner ext <$> simplify' e
+ EBuild _ n a b -> [simprec| EBuild ext n *a *b |]
+ EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |]
+ ESum1Inner _ e -> [simprec| ESum1Inner ext *e |]
+ EUnit _ e -> [simprec| EUnit ext *e |]
+ EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |]
+ EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |]
+ EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |]
EConst _ t v -> pure $ EConst ext t v
- EIdx0 _ e -> EIdx0 ext <$> simplify' e
- EIdx1 _ a b -> EIdx1 ext <$> simplify' a <*> simplify' b
- EIdx _ a b -> EIdx ext <$> simplify' a <*> simplify' b
- EShape _ e -> EShape ext <$> simplify' e
- EOp _ op e -> EOp ext op <$> simplify' e
- ECustom _ s t p a b c e1 e2 ->
- ECustom ext s t p
- <$> (let ?accumInScope = False in simplify' a)
- <*> (let ?accumInScope = False in simplify' b)
- <*> (let ?accumInScope = False in simplify' c)
- <*> simplify' e1 <*> simplify' e2
- EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2)
- EZero _ t e -> EZero ext t <$> simplify' e
- EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b
+ EIdx0 _ e -> [simprec| EIdx0 ext *e |]
+ EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |]
+ EIdx _ a b -> [simprec| EIdx ext *a *b |]
+ EShape _ e -> [simprec| EShape ext *e |]
+ EOp _ op e -> [simprec| EOp ext op *e |]
+ ECustom _ s t p a b c e1 e2 -> do
+ a' <- within (\a' -> ECustom ext s t p a' b c e1 e2) (let ?accumInScope = False in simplify' a)
+ b' <- within (\b' -> ECustom ext s t p a' b' c e1 e2) (let ?accumInScope = False in simplify' b)
+ c' <- within (\c' -> ECustom ext s t p a' b' c' e1 e2) (let ?accumInScope = False in simplify' c)
+ e1' <- within (\e1' -> ECustom ext s t p a' b' c' e1' e2) (simplify' e1)
+ e2' <- within (\e2' -> ECustom ext s t p a' b' c' e1' e2') (simplify' e2)
+ pure (ECustom ext s t p a' b' c' e1' e2')
+ EWith _ t e1 e2 -> do
+ e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1)
+ e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2)
+ pure (EWith ext t e1' e2')
+ EZero _ t e -> [simprec| EZero ext t *e |] -- EZero ext t <$> simplify' e
+ EPlus _ t a b -> [simprec| EPlus ext t *a *b |] -- EPlus ext t <$> simplify' a <*> simplify' b
EError _ t s -> pure $ EError ext t s
-acted :: (Any, a) -> (Any, a)
-acted (_, x) = (Any True, x)
-
cheapExpr :: Expr x env t -> Bool
cheapExpr = \case
EVar{} -> True
@@ -312,18 +363,21 @@ data OneHotTerm env p a b where
deriving instance Show (OneHotTerm env p a b)
simplifyOneHotTerm :: OneHotTerm env p a b
- -> (Any, r) -- ^ Zero case (onehot is actually zero)
- -> (Ex env a -> (Any, r)) -- ^ Trivial case (no zeros in onehot)
- -> (forall p' b'. OneHotTerm env p' a b' -> (Any, r))
- -> (Any, r)
+ -> SM tenv tt env t r -- ^ Zero case (onehot is actually zero)
+ -> (Ex env a -> SM tenv tt env t r) -- ^ Trivial case (no zeros in onehot)
+ -> (forall p' b'. OneHotTerm env p' a b' -> SM tenv tt env t r)
+ -> SM tenv tt env t r
simplifyOneHotTerm (OneHotTerm _ _ _ EZero{}) kzero _ _ = kzero
simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k
| Just Refl <- testEquality (acPrjTy prj1 t1) t2
- = do (Any True, ()) -- record, whatever happens later, that we've modified something
+ = do tellActed -- record, whatever happens later, that we've modified something
concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k
+-- TODO: This does not actually recurse unless it just so happens to contain
+-- another EZero or EOnehot in the final position. Should match on something
+-- more general than SAPHere here.
simplifyOneHotTerm (OneHotTerm t SAPHere _ e) kzero ktriv k = case (t, e) of
(SMTNil, _) -> kzero
diff --git a/src/Simplify/TH.hs b/src/Simplify/TH.hs
new file mode 100644
index 0000000..2e0076a
--- /dev/null
+++ b/src/Simplify/TH.hs
@@ -0,0 +1,80 @@
+{-# LANGUAGE TemplateHaskellQuotes #-}
+module Simplify.TH (simprec) where
+
+import Data.Bifunctor (first)
+import Data.Char
+import Data.List (foldl1')
+import Language.Haskell.TH
+import Language.Haskell.TH.Quote
+import Text.ParserCombinators.ReadP
+
+
+-- [simprec| EPair ext *a *b |]
+-- ~>
+-- do a' <- within (\a' -> EPair ext a' b) (simplify' a)
+-- b' <- within (\b' -> EPair ext a' b') (simplify' b)
+-- pure (EPair ext a' b')
+
+simprec :: QuasiQuoter
+simprec = QuasiQuoter
+ { quoteDec = \_ -> fail "simprec used outside of expression context"
+ , quoteType = \_ -> fail "simprec used outside of expression context"
+ , quoteExp = handler
+ , quotePat = \_ -> fail "simprec used outside of expression context"
+ }
+
+handler :: String -> Q Exp
+handler str =
+ case readP_to_S pTemplate str of
+ [(template, "")] -> generate template
+ _:_:_ -> fail "simprec: template grammar ambiguous"
+ _ -> fail "simprec: could not parse template"
+
+generate :: Template -> Q Exp
+generate (Template topitems) =
+ let takePrefix (Plain x : xs) = first (x:) (takePrefix xs)
+ takePrefix xs = ([], xs)
+
+ itemVar "" = error "simprec: empty item name?"
+ itemVar name@(c:_) | isLower c = VarE (mkName name)
+ | isUpper c = ConE (mkName name)
+ | otherwise = error "simprec: non-letter item name?"
+
+ loop :: Exp -> [Item] -> Q [Stmt]
+ loop yet [] = return [NoBindS (VarE 'pure `AppE` yet)]
+ loop yet (Plain x : xs) = loop (yet `AppE` itemVar x) xs
+ loop yet (Recurse x : xs) = do
+ primeName <- newName (x ++ "'")
+ let appPrePrime e (Plain y) = e `AppE` itemVar y
+ appPrePrime e (Recurse y) = e `AppE` itemVar y
+ let stmt = BindS (VarP primeName) $
+ VarE (mkName "within")
+ `AppE` LamE [VarP primeName] (foldl' appPrePrime (yet `AppE` VarE primeName) xs)
+ `AppE` (VarE (mkName "simplify'") `AppE` VarE (mkName x))
+ stmts <- loop (yet `AppE` VarE primeName) xs
+ return (stmt : stmts)
+
+ (prefix, items') = takePrefix topitems
+ in DoE Nothing <$> loop (foldl1' AppE (map itemVar prefix)) items'
+
+data Template = Template [Item]
+ deriving (Show)
+
+data Item = Plain String | Recurse String
+ deriving (Show)
+
+pTemplate :: ReadP Template
+pTemplate = do
+ items <- many (skipSpaces >> pItem)
+ skipSpaces
+ eof
+ return (Template items)
+
+pItem :: ReadP Item
+pItem = (char '*' >> Recurse <$> pName) +++ (Plain <$> pName)
+
+pName :: ReadP String
+pName = do
+ c1 <- satisfy (\c -> isAlpha c || c == '_')
+ cs <- munch (\c -> isAlphaNum c || c `elem` "_'")
+ return (c1:cs)