{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-- | A rewrite-based simplifier for 'Ex' expressions.
--
-- A single pass walks the expression once, applying local rewrite rules
-- (inlining, let splitting/rotation/floating, beta rules, monoid rules for
-- accumulation and one-hot terms) and reporting, via an 'Any' flag, whether
-- some rule fired. 'simplifyN' runs a fixed number of passes; 'simplifyFix'
-- repeats passes until a pass reports that nothing fired.
module Simplify (
  simplifyN, simplifyFix,
  SimplifyConfig(..), simplifyWith, simplifyFixWith,
) where

import Data.Function (fix)
import Data.Monoid (Any(..))
import Data.Type.Equality (testEquality)

import AST
import AST.Count
import CHAD.Types
import Data


-- | Configuration for the simplifier.
--
-- This has no fields now, hence this type is useless as-is. When debugging,
-- however, it's useful to be able to add some.
data SimplifyConfig = SimplifyConfig

-- | The configuration used by 'simplifyN' and 'simplifyFix'.
defaultSimplifyConfig :: SimplifyConfig
defaultSimplifyConfig = SimplifyConfig

-- | Run @n@ simplification passes with the default configuration.
-- A non-positive @n@ performs no passes (rather than looping forever).
simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t
simplifyN n
  | n <= 0 = id
  | otherwise = simplifyN (n - 1) . simplify

-- | Run a single simplification pass with the default configuration.
simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t
simplify = simplifyWith defaultSimplifyConfig

-- | Run a single simplification pass with the given configuration.
simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t
simplifyWith config =
  let ?accumInScope = checkAccumInScope @env knownEnv
      ?config = config
  in snd . simplify'

-- | Simplify until a pass reports that no rule fired, with the default
-- configuration.
simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t
simplifyFix = simplifyFixWith defaultSimplifyConfig

-- | Repeat simplification passes with the given configuration until a pass
-- reports that no rule fired.
simplifyFixWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t
simplifyFixWith config =
  let ?accumInScope = checkAccumInScope @env knownEnv
      ?config = config
  in fix $ \loop e ->
       let (Any act, e') = simplify' e
       in if act then loop e' else e'

-- | One simplification pass. The 'Any' component is set to 'True' when some
-- rewrite rule fired somewhere in the term.
--
-- @?accumInScope@ says whether an accumulator may be in scope; when it is,
-- a let-bound right-hand side containing accumulations ('hasAdds') may only
-- be inlined linearly, never dropped. It is set to 'True' inside the body of
-- 'EWith' and to 'False' inside the first three subterms of 'ECustom'.
simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig) => Ex env t -> (Any, Ex env t)
simplify' = \case
  -- inlining
  ELet _ rhs body
    | cheapExpr rhs
    -> acted $ simplify' (subst1 rhs body)

    | Occ lexOcc runOcc <- occCount IZ body
    , -- without effects, normal rules apply
      ((not ?accumInScope || not (hasAdds rhs)) && lexOcc <= One && runOcc <= One)
      -- with effects, linear inlining is still allowed, but weakening is not
      || (lexOcc == One && runOcc == One)
    -> acted $ simplify' (subst1 rhs body)

  -- let splitting
  ELet _ (EPair _ a b) body ->
    acted $ simplify' $
      ELet ext a $
      ELet ext (weakenExpr WSink b) $
        subst (\_ t -> \case { IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ)
                             ; IS i -> EVar ext t (IS (IS i)) })
              body

  -- let rotation
  ELet _ (ELet _ rhs a) b ->
    acted $ simplify' $
      ELet ext rhs $
      ELet ext a $
        weakenExpr (WCopy WSink) (snd (simplify' b))

  -- beta rules for products; a discarded component that performs
  -- accumulations is let-bound so that its effects are kept
  EFst _ (EPair _ e e')
    | not (hasAdds e') -> acted $ simplify' e
    | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e)
  ESnd _ (EPair _ e' e)
    | not (hasAdds e') -> acted $ simplify' e
    | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e)

  -- beta rules for coproducts
  ECase _ (EInl _ _ e) rhs _ -> acted $ simplify' (ELet ext e rhs)
  ECase _ (EInr _ _ e) _ rhs -> acted $ simplify' (ELet ext e rhs)

  -- beta rules for maybe
  EMaybe _ e1 _ ENothing{} -> acted $ simplify' e1
  EMaybe _ _ e1 (EJust _ e2) -> acted $ simplify' $ ELet ext e2 e1

  -- let floating to facilitate beta reduction
  EFst _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EFst ext body))
  ESnd _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (ESnd ext body))
  ECase _ (ELet _ rhs body) e1 e2 ->
    acted $ simplify' $
      ELet ext rhs $
        ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2)
  EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body))
  EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e)))

  -- projection down-commuting
  EFst _ (ECase _ e1 e2 e3) ->
    acted $ simplify' $
      ECase ext e1 (EFst ext e2) (EFst ext e3)
  ESnd _ (ECase _ e1 e2 e3) ->
    acted $ simplify' $
      ECase ext e1 (ESnd ext e2) (ESnd ext e3)

  -- TODO: array indexing (index of build, index of fold)

  -- TODO: constant folding for operations

  -- monoid rules
  --
  -- The index and value subterms are simplified first: the fallback
  -- recursion below has no case for 'EAccum'/'EOneHot', so this is the only
  -- place where they get simplified. Simplifying the value first also lets
  -- 'simplifyOneHotTerm' spot values that only become zero after
  -- simplification.
  EAccum _ t p e1 e2 acc -> do
    e1' <- simplify' e1
    e2' <- simplify' e2
    acc' <- simplify' acc
    simplifyOneHotTerm (OneHotTerm t p e1' e2')
      (Any True, ENil ext)
      (\e -> (Any False, EAccum ext t SAPHere (ENil ext) e acc'))
      (\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc'))

  EPlus _ _ (EZero _ _) e -> acted $ simplify' e
  EPlus _ _ e (EZero _ _) -> acted $ simplify' e

  EOneHot _ t p e1 e2 -> do
    e1' <- simplify' e1
    e2' <- simplify' e2
    simplifyOneHotTerm (OneHotTerm t p e1' e2')
      (Any True, EZero ext t)
      (\e -> (Any True, e))
      (\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2''))

  -- type-specific equations for plus

  -- Both operands are discarded, so only do this when neither performs
  -- accumulations; otherwise fall through to the fallback recursion.
  EPlus _ STNil a b
    | not (hasAdds a), not (hasAdds b) -> (Any True, ENil ext)

  EPlus _ (STPair t1 t2) (EJust _ (EPair _ a1 b1)) (EJust _ (EPair _ a2 b2)) ->
    acted $ simplify' $ EJust ext (EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2))
  EPlus _ STPair{} ENothing{} e -> acted $ simplify' e
  EPlus _ STPair{} e ENothing{} -> acted $ simplify' e

  EPlus _ (STEither t1 _) (EJust _ (EInl _ dt2 a1)) (EJust _ (EInl _ _ a2)) ->
    acted $ simplify' $ EJust ext (EInl ext dt2 (EPlus ext t1 a1 a2))
  EPlus _ (STEither _ t2) (EJust _ (EInr _ dt1 b1)) (EJust _ (EInr _ _ b2)) ->
    acted $ simplify' $ EJust ext (EInr ext dt1 (EPlus ext t2 b1 b2))
  EPlus _ STEither{} ENothing{} e -> acted $ simplify' e
  EPlus _ STEither{} e ENothing{} -> acted $ simplify' e

  EPlus _ (STMaybe t) (EJust _ e1) (EJust _ e2) ->
    acted $ simplify' $ EJust ext (EPlus ext t e1 e2)
  EPlus _ STMaybe{} ENothing{} e -> acted $ simplify' e
  EPlus _ STMaybe{} e ENothing{} -> acted $ simplify' e

  -- fallback recursion
  EVar _ t i -> pure $ EVar ext t i
  ELet _ a b -> ELet ext <$> simplify' a <*> simplify' b
  EPair _ a b -> EPair ext <$> simplify' a <*> simplify' b
  EFst _ e -> EFst ext <$> simplify' e
  ESnd _ e -> ESnd ext <$> simplify' e
  ENil _ -> pure $ ENil ext
  EInl _ t e -> EInl ext t <$> simplify' e
  EInr _ t e -> EInr ext t <$> simplify' e
  ECase _ e a b -> ECase ext <$> simplify' e <*> simplify' a <*> simplify' b
  ENothing _ t -> pure $ ENothing ext t
  EJust _ e -> EJust ext <$> simplify' e
  EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e
  EConstArr _ n t v -> pure $ EConstArr ext n t v
  EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b
  EFold1Inner _ a b c -> EFold1Inner ext <$> simplify' a <*> simplify' b <*> simplify' c
  ESum1Inner _ e -> ESum1Inner ext <$> simplify' e
  EUnit _ e -> EUnit ext <$> simplify' e
  EReplicate1Inner _ a b -> EReplicate1Inner ext <$> simplify' a <*> simplify' b
  EMaximum1Inner _ e -> EMaximum1Inner ext <$> simplify' e
  EMinimum1Inner _ e -> EMinimum1Inner ext <$> simplify' e
  EConst _ t v -> pure $ EConst ext t v
  EIdx0 _ e -> EIdx0 ext <$> simplify' e
  EIdx1 _ a b -> EIdx1 ext <$> simplify' a <*> simplify' b
  EIdx _ a b -> EIdx ext <$> simplify' a <*> simplify' b
  EShape _ e -> EShape ext <$> simplify' e
  EOp _ op e -> EOp ext op <$> simplify' e
  ECustom _ s t p a b c e1 e2 ->
    ECustom ext s t p
      <$> (let ?accumInScope = False in simplify' a)
      <*> (let ?accumInScope = False in simplify' b)
      <*> (let ?accumInScope = False in simplify' c)
      <*> simplify' e1
      <*> simplify' e2
  EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2)
  EZero _ t -> pure $ EZero ext t
  EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b
  EError _ t s -> pure $ EError ext t s

-- | Mark a simplification result as having fired a rule.
acted :: (Any, a) -> (Any, a)
acted (_, x) = (Any True, x)

-- | Whether an expression is cheap enough to substitute into a let body
-- regardless of how often the bound variable occurs: variables, nil,
-- constants, and (nested) projections of those.
cheapExpr :: Expr x env t -> Bool
cheapExpr = \case
  EVar{} -> True
  ENil{} -> True
  EConst{} -> True
  EFst _ e -> cheapExpr e
  ESnd _ e -> cheapExpr e
  _ -> False

-- | Conservatively reports whether an expression contains an accumulation
-- ('EAccum') anywhere inside it.
--
-- This can be made more precise by tracking (and not counting) adds on
-- locally eliminated accumulators.
hasAdds :: Expr x env t -> Bool
hasAdds = \case
  EVar _ _ _ -> False
  ELet _ rhs body -> hasAdds rhs || hasAdds body
  EPair _ a b -> hasAdds a || hasAdds b
  EFst _ e -> hasAdds e
  ESnd _ e -> hasAdds e
  ENil _ -> False
  EInl _ _ e -> hasAdds e
  EInr _ _ e -> hasAdds e
  ECase _ e a b -> hasAdds e || hasAdds a || hasAdds b
  ENothing _ _ -> False
  EJust _ e -> hasAdds e
  EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e
  EConstArr _ _ _ _ -> False
  EBuild _ _ a b -> hasAdds a || hasAdds b
  EFold1Inner _ a b c -> hasAdds a || hasAdds b || hasAdds c
  ESum1Inner _ e -> hasAdds e
  EUnit _ e -> hasAdds e
  EReplicate1Inner _ a b -> hasAdds a || hasAdds b
  EMaximum1Inner _ e -> hasAdds e
  EMinimum1Inner _ e -> hasAdds e
  ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
  EConst _ _ _ -> False
  EIdx0 _ e -> hasAdds e
  EIdx1 _ a b -> hasAdds a || hasAdds b
  EIdx _ a b -> hasAdds a || hasAdds b
  EShape _ e -> hasAdds e
  EOp _ _ e -> hasAdds e
  EWith _ _ a b -> hasAdds a || hasAdds b
  EAccum _ _ _ _ _ _ -> True
  EZero _ _ -> False
  EPlus _ _ a b -> hasAdds a || hasAdds b
  EOneHot _ _ _ a b -> hasAdds a || hasAdds b
  EError _ _ _ -> False

-- | Whether any type in the environment contains an accumulator type
-- ('STAccum'), looking through pairs, sums, maybes and arrays.
checkAccumInScope :: SList STy env -> Bool
checkAccumInScope = \case
    SNil -> False
    SCons t env -> check t || checkAccumInScope env
  where
    check :: STy t -> Bool
    check STNil = False
    check (STPair s t) = check s || check t
    check (STEither s t) = check s || check t
    check (STMaybe t) = check t
    check (STArr _ t) = check t
    check (STScal _) = False
    check STAccum{} = True

-- | The components shared by 'EAccum' and 'EOneHot': the type being
-- projected into, the projection, the index for that projection, and the
-- value placed at the projected position.
data OneHotTerm env p a b where
  OneHotTerm :: STy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env (D2 b) -> OneHotTerm env p a b
deriving instance Show (OneHotTerm env p a b)

-- | Simplify a one-hot term, dispatching on the result to one of three
-- continuations. Directly nested one-hots (whose types line up) are fused
-- into a single one-hot by composing their projections; this fusion is
-- always reported as an action.
simplifyOneHotTerm :: OneHotTerm env p a b
                   -> (Any, r)  -- ^ Zero case (the value is 'EZero', so the onehot is zero)
                   -> (Ex env (D2 a) -> (Any, r))  -- ^ Trivial case (projection is 'SAPHere': the onehot is just its value)
                   -> (forall p' b'. OneHotTerm env p' a b' -> (Any, r))  -- ^ General case
                   -> (Any, r)
simplifyOneHotTerm (OneHotTerm _ _ _ (EZero _ _)) kzero _ _ = kzero

simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k
  | Just Refl <- testEquality (acPrjTy prj1 t1) t2
  = do
      -- Record, whatever happens later, that we've modified something.
      (Any True, ())
      concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
        simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k

simplifyOneHotTerm (OneHotTerm _ SAPHere _ e) _ ktriv _ = ktriv e

simplifyOneHotTerm term _ _ k = k term

-- | Compose a projection from @a@ to @b@ with a projection from @b@ to @c@,
-- along with their indices, and pass the composed projection (from @a@ to
-- @c@) and its index to the continuation.
--
-- For array projections, the outer index @idx1@ is let-bound and accessed
-- through a variable, so it is evaluated once rather than duplicated.
concatOneHots :: STy a
              -> SAcPrj p1 a b -> Ex env (AcIdx p1 a)
              -> SAcPrj p2 b c -> Ex env (AcIdx p2 b)
              -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r
concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of
  (_, SAPHere) -> k prj2 idx2

  (STPair a _, SAPFst prj1') ->
    concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPFst prj12) idx12
  (STPair _ b, SAPSnd prj1') ->
    concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPSnd prj12) idx12
  (STEither a _, SAPLeft prj1') ->
    concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12
  (STEither _ b, SAPRight prj1') ->
    concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12
  (STMaybe a, SAPJust prj1') ->
    concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12

  (STArr n a, SAPArrIdx prj1' _) ->
    concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
      k (SAPArrIdx prj12 n) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)