{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-- | Rewrite-based simplifier for 'Ex' expressions.
--
-- A single pass ('simplifyWith') walks the expression once, applying local
-- rewrite rules; 'simplifyFixWith' repeats passes until a pass reports that
-- no rule fired.
module Simplify (
  simplifyN, simplifyFix,
  SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith,
) where

import Control.Monad (ap)
import Data.Bifunctor (first)
import Data.Function (fix)
import Data.Monoid (Any(..))
import Data.Type.Equality (testEquality)
import Debug.Trace

import AST
import AST.Count
import AST.Pretty
import Data
import Simplify.TH


-- | Options for the simplifier.
data SimplifyConfig = SimplifyConfig
  { scLogging :: Bool
    -- ^ When set, each step of the simplifier pretty-prints the whole
    -- reconstructed top-level expression via 'traceM'.
  }

-- | Configuration with logging disabled.
defaultSimplifyConfig :: SimplifyConfig
defaultSimplifyConfig = SimplifyConfig False

-- | Run @n@ simplification passes with 'defaultSimplifyConfig'.
-- A non-positive @n@ returns the expression unchanged (previously a negative
-- count recursed forever).
simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t
simplifyN n
  | n <= 0 = id
  | otherwise = simplifyN (n - 1) . simplify

-- | A single pass with 'defaultSimplifyConfig'.
simplify :: KnownEnv env => Ex env t -> Ex env t
simplify = simplifyWith defaultSimplifyConfig

-- | A single simplification pass with the given configuration.
simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t
simplifyWith config =
  let ?accumInScope = checkAccumInScope @env knownEnv
      ?config = config
  in snd . runSM . simplify'

-- | Repeat passes with 'defaultSimplifyConfig' until one reports no rewrite.
simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t
simplifyFix = simplifyFixWith defaultSimplifyConfig

-- | Repeat passes with the given configuration until a pass reports (via the
-- 'Any' flag of 'SM') that no rule fired.
simplifyFixWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t
simplifyFixWith config =
  let ?accumInScope = checkAccumInScope @env knownEnv
      ?config = config
  in fix $ \loop e ->
       let (act, e') = runSM (simplify' e)
       in if act then loop e' else e'

-- | simplify monad
--
-- A computation receives a function that rebuilds the complete top-level
-- expression (of type @Ex tenv tt@) from the subterm currently being
-- simplified (of type @Ex env t@); this is only consumed by 'smReconstruct'
-- for logging. It also accumulates, in an 'Any', whether any rewrite fired.
newtype SM tenv tt env t a = SM ((Ex env t -> Ex tenv tt) -> (Any, a))
  deriving (Functor)

instance Applicative (SM tenv tt env t) where
  pure x = SM (\_ -> (Any False, x))
  (<*>) = ap

instance Monad (SM tenv tt env t) where
  SM f >>= g = SM $ \ctx -> f ctx >>= \x -> let SM h = g x in h ctx

-- | Run at the top level (identity context); the 'Bool' says whether any
-- rewrite fired.
runSM :: SM env t env t a -> (Bool, a)
runSM (SM f) = first getAny (f id)

-- | Plug a subterm into the current context, yielding the full expression.
smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt)
smReconstruct core = SM (\ctx -> (Any False, ctx core))

-- | Record that a rewrite fired.
tellActed :: SM tenv tt env t ()
tellActed = SM (\_ -> (Any True, ()))

-- more convenient in practice
acted :: SM tenv tt env t a -> SM tenv tt env t a
acted m = tellActed >> m

-- | Run a computation on a subterm; @subctx@ plugs that subterm back into the
-- current node so that the reconstruction context stays accurate.
within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a
within subctx (SM f) = SM $ \ctx -> f (ctx . subctx)

-- | One simplifier step on a subterm. With 'scLogging' enabled, after the
-- subterm is simplified the whole reconstructed expression is traced.
simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t)
simplify' expr
  | scLogging ?config = do
      res <- simplify'Rec expr
      full <- smReconstruct res
      let printed = ppExpr knownEnv full
          replace a bs = concatMap (\x -> if x == a then bs else [x])
          str | '\n' `elem` printed = "--- simplify step:\n " ++ replace '\n' "\n " printed
              | otherwise = "--- simplify step: " ++ printed
      traceM str
      return res
  | otherwise = simplify'Rec expr

-- | The rewrite rules. Alternatives are tried top to bottom, so their order is
-- significant; the final group just recurses into children.
--
-- Invariant kept by the rules: a subexpression that may perform accumulator
-- adds ('hasAdds') is never discarded.
simplify'Rec :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t)
simplify'Rec = \case
  -- inlining
  ELet _ rhs body
    | cheapExpr rhs
    -> acted $ simplify' (substInline rhs body)

    | Occ lexOcc runOcc <- occCount IZ body
    , ((not ?accumInScope || not (hasAdds rhs)) && lexOcc <= One && runOcc <= One)  -- without effects, normal rules apply
      || (lexOcc == One && runOcc == One)  -- with effects, linear inlining is still allowed, but weakening is not
    -> acted $ simplify' (substInline rhs body)

  -- let splitting / let peeling
  ELet _ (EPair _ a b) body ->
    acted $ simplify' $
      ELet ext a $
        ELet ext (weakenExpr WSink b) $
          subst (\_ t -> \case
                    IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ)
                    IS i -> EVar ext t (IS (IS i)))
                body
  ELet _ (EJust _ a) body ->
    acted $ simplify' $ ELet ext a $ subst0 (EJust ext (EVar ext (typeOf a) IZ)) body
  ELet _ (EInl _ t2 a) body ->
    acted $ simplify' $ ELet ext a $ subst0 (EInl ext t2 (EVar ext (typeOf a) IZ)) body
  ELet _ (EInr _ t1 a) body ->
    acted $ simplify' $ ELet ext a $ subst0 (EInr ext t1 (EVar ext (typeOf a) IZ)) body

  -- let rotation
  ELet _ (ELet _ rhs a) b -> do
    b' <- within (ELet ext (ELet ext rhs a)) $ simplify' b
    acted $ simplify' $
      ELet ext rhs $
        ELet ext a $
          weakenExpr (WCopy WSink) b'

  -- beta rules for products
  EFst _ (EPair _ e e')
    | not (hasAdds e') -> acted $ simplify' e
    | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e)
  ESnd _ (EPair _ e' e)
    | not (hasAdds e') -> acted $ simplify' e
    | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e)

  -- beta rules for coproducts
  ECase _ (EInl _ _ e) rhs _ -> acted $ simplify' (ELet ext e rhs)
  ECase _ (EInr _ _ e) _ rhs -> acted $ simplify' (ELet ext e rhs)

  -- beta rules for maybe
  EMaybe _ e1 _ ENothing{} -> acted $ simplify' e1
  EMaybe _ _ e1 (EJust _ e2) -> acted $ simplify' $ ELet ext e2 e1

  -- let floating
  EFst _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EFst ext body))
  ESnd _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (ESnd ext body))
  ECase _ (ELet _ rhs body) e1 e2 ->
    acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2)))
  EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body))
  EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e)))
  EAccum _ t p e1 (ELet _ rhs body) acc ->
    acted $ simplify' $
      ELet ext rhs $
        EAccum ext t p (weakenExpr WSink e1) body (weakenExpr WSink acc)

  -- let () = e in () ~> e
  ELet _ e1 (ENil _) | STNil <- typeOf e1 ->
    acted $ simplify' e1

  -- projection down-commuting
  EFst _ (ECase _ e1 e2 e3) ->
    acted $ simplify' $ ECase ext e1 (EFst ext e2) (EFst ext e3)
  ESnd _ (ECase _ e1 e2 e3) ->
    acted $ simplify' $ ECase ext e1 (ESnd ext e2) (ESnd ext e3)

  -- TODO: more array indexing
  -- Both rules below discard a subexpression (the replication count, resp.
  -- the index), so like the product beta rules they only fire when that
  -- subexpression performs no accumulator adds; otherwise fall through.
  EIdx _ (EReplicate1Inner _ e1 e2) e3
    | not (hasAdds e1) -> acted $ simplify' $ EIdx ext e2 (EFst ext e3)
  EIdx _ (EUnit _ e1) e2
    | not (hasAdds e2) -> acted $ simplify' $ e1

  -- TODO: more constant folding
  EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext))
  EOp _ OIf (EConst _ STBool False) -> acted $ return (EInr ext STNil (ENil ext))

  -- inline cheap array constructors
  ELet _ (EReplicate1Inner _ e1 e2) e3 ->
    acted $ simplify' $
      ELet ext (EPair ext e1 e2) $
        let v = EVar ext (STPair tIx (typeOf e2)) IZ
        in subst0 (EReplicate1Inner ext (EFst ext v) (ESnd ext v)) e3
  -- -- TODO: This is a bad idea and anyway only helps in practice if (!) is
  -- -- cheap, which it can't be because (!) is not cheap if you do AD after.
  -- -- Should do proper SoA representation.
  -- ELet _ (EBuild _ n e1 e2) e3 | cheapExpr e2 ->
  --   acted $ simplify' $
  --     ELet ext e1 $
  --       subst0 (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) IZ) (weakenExpr (WCopy WSink) e2)) e3

  -- eta rule for unit
  e | STNil <- typeOf e, not ?accumInScope || not (hasAdds e) ->
    case e of
      ENil _ -> return e
      _ -> acted $ return (ENil ext)

  -- A zero-dimensional build is a unit array; the (discarded) shape must be
  -- free of accumulator adds.
  EBuild _ SZ sh e
    | not (hasAdds sh) -> acted $ simplify' $ EUnit ext (substInline (ENil ext) e)

  -- monoid rules
  EAccum _ t p e1 e2 acc -> do
    e1' <- within (\e1' -> EAccum ext t p e1' e2 acc) $ simplify' e1
    e2' <- within (\e2' -> EAccum ext t p e1' e2' acc) $ simplify' e2
    acc' <- within (\acc' -> EAccum ext t p e1' e2' acc') $ simplify' acc
    simplifyOneHotTerm (OneHotTerm t p e1' e2')
      (acted $ return (ENil ext))
      (\e -> return (EAccum ext t SAPHere (ENil ext) e acc'))
      (\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc'))
  -- Dropping a zero also drops its zero-info expression, so require that it
  -- performs no accumulator adds.
  EPlus _ _ (EZero _ _ zi) e | not (hasAdds zi) -> acted $ simplify' e
  EPlus _ _ e (EZero _ _ zi) | not (hasAdds zi) -> acted $ simplify' e
  EOneHot _ t p e1 e2 -> do
    e1' <- within (\e1' -> EOneHot ext t p e1' e2) $ simplify' e1
    e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2
    simplifyOneHotTerm (OneHotTerm t p e1' e2')
      -- build the zero info from the already-simplified subterms
      (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1' e2')))
      (\e -> acted $ return e)
      (\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2''))

  -- type-specific equations for plus
  EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) ->
    acted $ return (ENil ext)
  EPlus _ (SMTPair t1 t2) (EPair _ a1 b1) (EPair _ a2 b2) ->
    acted $ simplify' $ EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2)
  EPlus _ (SMTLEither t1 _) (ELInl _ dt2 a1) (ELInl _ _ a2) ->
    acted $ simplify' $ ELInl ext dt2 (EPlus ext t1 a1 a2)
  EPlus _ (SMTLEither _ t2) (ELInr _ dt1 b1) (ELInr _ _ b2) ->
    acted $ simplify' $ ELInr ext dt1 (EPlus ext t2 b1 b2)
  EPlus _ SMTLEither{} ELNil{} e -> acted $ simplify' e
  EPlus _ SMTLEither{} e ELNil{} -> acted $ simplify' e
  EPlus _ (SMTMaybe t) (EJust _ e1) (EJust _ e2) ->
    acted $ simplify' $ EJust ext (EPlus ext t e1 e2)
  EPlus _ SMTMaybe{} ENothing{} e -> acted $ simplify' e
  EPlus _ SMTMaybe{} e ENothing{} -> acted $ simplify' e

  -- fallback recursion
  EVar _ t i -> pure $ EVar ext t i
  ELet _ a b -> [simprec| ELet ext *a *b |]
  EPair _ a b -> [simprec| EPair ext *a *b |]
  EFst _ e -> [simprec| EFst ext *e |]
  ESnd _ e -> [simprec| ESnd ext *e |]
  ENil _ -> pure $ ENil ext
  EInl _ t e -> [simprec| EInl ext t *e |]
  EInr _ t e -> [simprec| EInr ext t *e |]
  ECase _ e a b -> [simprec| ECase ext *e *a *b |]
  ENothing _ t -> pure $ ENothing ext t
  EJust _ e -> [simprec| EJust ext *e |]
  EMaybe _ a b e -> [simprec| EMaybe ext *a *b *e |]
  ELNil _ t1 t2 -> pure $ ELNil ext t1 t2
  ELInl _ t e -> [simprec| ELInl ext t *e |]
  ELInr _ t e -> [simprec| ELInr ext t *e |]
  ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |]
  EConstArr _ n t v -> pure $ EConstArr ext n t v
  EBuild _ n a b -> [simprec| EBuild ext n *a *b |]
  EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |]
  ESum1Inner _ e -> [simprec| ESum1Inner ext *e |]
  EUnit _ e -> [simprec| EUnit ext *e |]
  EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |]
  EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |]
  EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |]
  EConst _ t v -> pure $ EConst ext t v
  EIdx0 _ e -> [simprec| EIdx0 ext *e |]
  EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |]
  EIdx _ a b -> [simprec| EIdx ext *a *b |]
  EShape _ e -> [simprec| EShape ext *e |]
  EOp _ op e -> [simprec| EOp ext op *e |]
  ECustom _ s t p a b c e1 e2 -> do
    -- the first three subterms are simplified with accumulators treated as
    -- out of scope
    a' <- within (\a' -> ECustom ext s t p a' b c e1 e2) (let ?accumInScope = False in simplify' a)
    b' <- within (\b' -> ECustom ext s t p a' b' c e1 e2) (let ?accumInScope = False in simplify' b)
    c' <- within (\c' -> ECustom ext s t p a' b' c' e1 e2) (let ?accumInScope = False in simplify' c)
    e1' <- within (\e1' -> ECustom ext s t p a' b' c' e1' e2) (simplify' e1)
    e2' <- within (\e2' -> ECustom ext s t p a' b' c' e1' e2') (simplify' e2)
    pure (ECustom ext s t p a' b' c' e1' e2')
  EWith _ t e1 e2 -> do
    e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1)
    -- the body of a 'EWith' always has an accumulator in scope
    e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2)
    pure (EWith ext t e1' e2')
  EZero _ t e -> [simprec| EZero ext t *e |]  -- EZero ext t <$> simplify' e
  EPlus _ t a b -> [simprec| EPlus ext t *a *b |]  -- EPlus ext t <$> simplify' a <*> simplify' b
  EError _ t s -> pure $ EError ext t s

-- | Expressions cheap enough to duplicate freely when inlining: variables,
-- nil, constants, and 'EFst' / 'ESnd' / 'EUnit' of cheap expressions.
cheapExpr :: Expr x env t -> Bool
cheapExpr = \case
  EVar{} -> True
  ENil{} -> True
  EConst{} -> True
  EFst _ e -> cheapExpr e
  ESnd _ e -> cheapExpr e
  EUnit _ e -> cheapExpr e
  _ -> False

-- | Conservative check: 'True' iff an 'EAccum' occurs anywhere in the
-- expression.
--
-- This can be made more precise by tracking (and not counting) adds on
-- locally eliminated accumulators.
hasAdds :: Expr x env t -> Bool
hasAdds = \case
  EVar _ _ _ -> False
  ELet _ rhs body -> hasAdds rhs || hasAdds body
  EPair _ a b -> hasAdds a || hasAdds b
  EFst _ e -> hasAdds e
  ESnd _ e -> hasAdds e
  ENil _ -> False
  EInl _ _ e -> hasAdds e
  EInr _ _ e -> hasAdds e
  ECase _ e a b -> hasAdds e || hasAdds a || hasAdds b
  ENothing _ _ -> False
  EJust _ e -> hasAdds e
  EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e
  ELNil _ _ _ -> False
  ELInl _ _ e -> hasAdds e
  ELInr _ _ e -> hasAdds e
  ELCase _ e a b c -> hasAdds e || hasAdds a || hasAdds b || hasAdds c
  EConstArr _ _ _ _ -> False
  EBuild _ _ a b -> hasAdds a || hasAdds b
  EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
  ESum1Inner _ e -> hasAdds e
  EUnit _ e -> hasAdds e
  EReplicate1Inner _ a b -> hasAdds a || hasAdds b
  EMaximum1Inner _ e -> hasAdds e
  EMinimum1Inner _ e -> hasAdds e
  ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
  EConst _ _ _ -> False
  EIdx0 _ e -> hasAdds e
  EIdx1 _ a b -> hasAdds a || hasAdds b
  EIdx _ a b -> hasAdds a || hasAdds b
  EShape _ e -> hasAdds e
  EOp _ _ e -> hasAdds e
  EWith _ _ a b -> hasAdds a || hasAdds b
  EAccum _ _ _ _ _ _ -> True
  EZero _ _ e -> hasAdds e
  EPlus _ _ a b -> hasAdds a || hasAdds b
  EOneHot _ _ _ a b -> hasAdds a || hasAdds b
  EError _ _ _ -> False

-- | Whether any type in the environment contains an accumulator type
-- ('STAccum') anywhere inside it.
checkAccumInScope :: SList STy env -> Bool
checkAccumInScope = \case
    SNil -> False
    SCons t env -> check t || checkAccumInScope env
  where
    check :: STy t -> Bool
    check STNil = False
    check (STPair s t) = check s || check t
    check (STEither s t) = check s || check t
    check (STMaybe t) = check t
    check (STArr _ t) = check t
    check (STScal _) = False
    check STAccum{} = True
    check (STLEither s t) = check s || check t

-- | The pieces of an 'EOneHot' / 'EAccum': monoid type, projection, index
-- expression and value.
data OneHotTerm env p a b where
  OneHotTerm :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env b -> OneHotTerm env p a b
deriving instance Show (OneHotTerm env p a b)

-- | Normalise a one-hot term and hand the result to one of three
-- continuations. Every rewrite performed here is reported via 'tellActed' so
-- that 'simplifyFixWith' keeps iterating.
simplifyOneHotTerm :: OneHotTerm env p a b
                   -> SM tenv tt env t r  -- ^ Zero case (onehot is actually zero)
                   -> (Ex env a -> SM tenv tt env t r)  -- ^ Trivial case (no zeros in onehot)
                   -> (forall p' b'. OneHotTerm env p' a b' -> SM tenv tt env t r)
                   -> SM tenv tt env t r

simplifyOneHotTerm (OneHotTerm _ _ _ EZero{}) kzero _ _ = kzero

-- A one-hot whose value is itself a one-hot: fuse the two projections.
simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k
  | Just Refl <- testEquality (acPrjTy prj1 t1) t2
  = do tellActed  -- record, whatever happens later, that we've modified something
       concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
         simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k

-- TODO: This does not actually recurse unless it just so happens to contain
-- another EZero or EOnehot in the final position. Should match on something
-- more general than SAPHere here.
simplifyOneHotTerm (OneHotTerm t SAPHere _ e) kzero ktriv k = case (t, e) of
  (SMTNil, _) -> kzero

  -- The rewrites below narrow the projection; they always change the term, so
  -- they are wrapped in 'acted' (the continuations do not report it).
  (SMTPair{}, EPair _ e1 (EZero _ _ ezi)) ->
    acted $ simplifyOneHotTerm (OneHotTerm t (SAPFst SAPHere) (EPair ext (ENil ext) ezi) e1) kzero ktriv k
  (SMTPair{}, EPair _ (EZero _ _ ezi) e2) ->
    acted $ simplifyOneHotTerm (OneHotTerm t (SAPSnd SAPHere) (EPair ext ezi (ENil ext)) e2) kzero ktriv k

  (SMTLEither{}, ELNil _ _ _) -> kzero
  (SMTLEither{}, ELInl _ _ e1) ->
    acted $ simplifyOneHotTerm (OneHotTerm t (SAPLeft SAPHere) (ENil ext) e1) kzero ktriv k
  (SMTLEither{}, ELInr _ _ e2) ->
    acted $ simplifyOneHotTerm (OneHotTerm t (SAPRight SAPHere) (ENil ext) e2) kzero ktriv k

  (SMTMaybe{}, ENothing _ _) -> kzero
  (SMTMaybe{}, EJust _ e1) ->
    acted $ simplifyOneHotTerm (OneHotTerm t (SAPJust SAPHere) (ENil ext) e1) kzero ktriv k

  (SMTScal STI32, _) -> kzero
  (SMTScal STI64, _) -> kzero
  (SMTScal STF32, EConst _ _ 0.0) -> kzero
  (SMTScal STF64, EConst _ _ 0.0) -> kzero

  _ -> ktriv e

simplifyOneHotTerm term _ _ k = k term

-- | Compose two one-hot projections (outer @prj1@ then inner @prj2@) into a
-- single projection, building the combined index expression.
concatOneHots :: SMTy a
              -> SAcPrj p1 a b -> Ex env (AcIdx p1 a)
              -> SAcPrj p2 b c -> Ex env (AcIdx p2 b)
              -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r
concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of
  (_, SAPHere) -> k prj2 idx2

  (SMTPair a _, SAPFst prj1') ->
    concatOneHots a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
      k (SAPFst prj12) (ELet ext idx1 $ EPair ext idx12 (ESnd ext (EVar ext (typeOf idx1) IZ)))

  (SMTPair _ b, SAPSnd prj1') ->
    concatOneHots b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
      k (SAPSnd prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)

  (SMTLEither a _, SAPLeft prj1') ->
    concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12
  (SMTLEither _ b, SAPRight prj1') ->
    concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12

  (SMTMaybe a, SAPJust prj1') ->
    concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12

  (SMTArr _ a, SAPArrIdx prj1') ->
    concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
      k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)

-- | Zero info for the full monoid type of a one-hot, reconstructed from its
-- projection, index and value.
zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t)
zeroInfoFromOneHot = \ty prj eidx e ->
    ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e)
  where
    -- invariant: AcIdx expression is duplicable
    go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t)
    go t SAPHere _ e = makeZeroInfo t e
    go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx)
    go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e)
    go SMTLEither{} _ _ _ = ENil ext
    go SMTMaybe{} _ _ _ = ENil ext
    go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx)

-- | Zero info derived from a reference value of the same monoid type.
makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ)
  where
    -- invariant: expression argument is duplicable
    go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
    go SMTNil _ = ENil ext
    go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e))
    go SMTLEither{} _ = ENil ext
    go SMTMaybe{} _ = ENil ext
    go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e
    go SMTScal{} _ = ENil ext