diff options
Diffstat (limited to 'src/Simplify.hs')
| -rw-r--r-- | src/Simplify.hs | 300 |
1 files changed, 0 insertions, 300 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs deleted file mode 100644 index 0bf5482..0000000 --- a/src/Simplify.hs +++ /dev/null @@ -1,300 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImplicitParams #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -module Simplify ( - simplifyN, simplifyFix, - SimplifyConfig(..), simplifyWith, simplifyFixWith, -) where - -import Data.Function (fix) -import Data.Monoid (Any(..)) -import Data.Type.Equality (testEquality) - -import AST -import AST.Count -import CHAD.Types -import Data - - --- | This has no fields now, hence this type is useless as-is. When debugging, however, it's useful to be able to add some. -data SimplifyConfig = SimplifyConfig - -defaultSimplifyConfig :: SimplifyConfig -defaultSimplifyConfig = SimplifyConfig - -simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t -simplifyN 0 = id -simplifyN n = simplifyN (n - 1) . simplify - -simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t -simplify = - let ?accumInScope = checkAccumInScope @env knownEnv - ?config = defaultSimplifyConfig - in snd . simplify' - -simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t -simplifyWith config = - let ?accumInScope = checkAccumInScope @env knownEnv - ?config = config - in snd . simplify' - -simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t -simplifyFix = simplifyFixWith defaultSimplifyConfig - -simplifyFixWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t -simplifyFixWith config = - let ?accumInScope = checkAccumInScope @env knownEnv - ?config = config - in fix $ \loop e -> - let (Any act, e') = simplify' e - in if act then loop e' else e' - -simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig) => Ex env t -> (Any, Ex env t) -simplify' = \case - -- inlining - ELet _ rhs body - | cheapExpr rhs - -> acted $ simplify' (subst1 rhs body) - - | Occ lexOcc runOcc <- occCount IZ body - , ((not ?accumInScope || not (hasAdds rhs)) && lexOcc <= One && runOcc <= One) -- without effects, normal rules apply - || (lexOcc == One && runOcc == One) -- with effects, linear inlining is still allowed, but weakening is not - -> acted $ simplify' (subst1 rhs body) - - -- let splitting - ELet _ (EPair _ a b) body -> - acted $ simplify' $ - ELet ext a $ - ELet ext (weakenExpr WSink b) $ - subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ) - IS i -> EVar ext t (IS (IS i))) - body - - -- let rotation - ELet _ (ELet _ rhs a) b -> - acted $ simplify' $ - ELet ext rhs $ - ELet ext a $ - weakenExpr (WCopy WSink) (snd (simplify' b)) - - -- beta rules for products - EFst _ (EPair _ e e') - | not (hasAdds e') -> acted $ simplify' e - | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e) - ESnd _ (EPair _ e' e) - | not (hasAdds e') -> acted $ simplify' e - | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e) - - -- beta rules for coproducts - ECase _ (EInl _ _ e) rhs _ -> acted $ simplify' (ELet ext e rhs) - ECase _ (EInr _ _ e) _ rhs -> acted $ simplify' (ELet ext e rhs) - - -- beta rules for maybe - EMaybe _ e1 _ ENothing{} -> acted $ simplify' e1 - EMaybe _ _ e1 (EJust _ e2) -> acted $ simplify' $ ELet ext e2 e1 - - -- let floating to facilitate beta reduction - EFst _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EFst ext body)) - ESnd _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (ESnd ext body)) - ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2))) - EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body)) - EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e))) - - -- projection down-commuting - EFst _ (ECase _ e1 e2 e3) -> - acted $ simplify' $ - ECase ext e1 (EFst ext e2) (EFst ext e3) - ESnd _ (ECase _ e1 e2 e3) -> - acted $ simplify' $ - ECase ext e1 (ESnd ext e2) (ESnd ext e3) - - -- TODO: array indexing (index of build, index of fold) - - -- TODO: beta rules for maybe - - -- TODO: constant folding for operations - - -- monoid rules - EAccum _ t p e1 e2 acc -> do - acc' <- simplify' acc - simplifyOneHotTerm (OneHotTerm t p e1 e2) - (Any True, ENil ext) - (\e -> (Any False, EAccum ext t SAPHere (ENil ext) e acc')) - (\(OneHotTerm t' p' e1' e2') -> return (EAccum ext t' p' e1' e2' acc')) - EPlus _ _ (EZero _ _) e -> acted $ simplify' e - EPlus _ _ e (EZero _ _) -> acted $ simplify' e - EOneHot _ t p e1 e2 -> - simplifyOneHotTerm (OneHotTerm t p e1 e2) - (Any True, EZero ext t) - (\e -> (Any True, e)) - (\(OneHotTerm t' p' e1' e2') -> return (EOneHot ext t' p' e1' e2')) - - -- type-specific equations for plus - EPlus _ STNil _ _ -> (Any True, ENil ext) - - EPlus _ (STPair t1 t2) (EJust _ (EPair _ a1 b1)) (EJust _ (EPair _ a2 b2)) -> - acted $ simplify' $ EJust ext (EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2)) - EPlus _ STPair{} ENothing{} e -> acted $ simplify' e - EPlus _ STPair{} e ENothing{} -> acted $ simplify' e - - EPlus _ (STEither t1 _) (EJust _ (EInl _ dt2 a1)) (EJust _ (EInl _ _ a2)) -> - acted $ simplify' $ EJust ext (EInl ext dt2 (EPlus ext t1 a1 a2)) - EPlus _ (STEither _ t2) (EJust _ (EInr _ dt1 b1)) (EJust _ (EInr _ _ b2)) -> - acted $ simplify' $ EJust ext (EInr ext dt1 (EPlus ext t2 b1 b2)) - EPlus _ STEither{} ENothing{} e -> acted $ simplify' e - EPlus _ STEither{} e ENothing{} -> acted $ simplify' e - - EPlus _ (STMaybe t) (EJust _ e1) (EJust _ e2) -> - acted $ simplify' $ EJust ext (EPlus ext t e1 e2) - EPlus _ STMaybe{} ENothing{} e -> acted $ simplify' e - EPlus _ STMaybe{} e ENothing{} -> acted $ simplify' e - - -- fallback recursion - EVar _ t i -> pure $ EVar ext t i - ELet _ a b -> ELet ext <$> simplify' a <*> simplify' b - EPair _ a b -> EPair ext <$> simplify' a <*> simplify' b - EFst _ e -> EFst ext <$> simplify' e - ESnd _ e -> ESnd ext <$> simplify' e - ENil _ -> pure $ ENil ext - EInl _ t e -> EInl ext t <$> simplify' e - EInr _ t e -> EInr ext t <$> simplify' e - ECase _ e a b -> ECase ext <$> simplify' e <*> simplify' a <*> simplify' b - ENothing _ t -> pure $ ENothing ext t - EJust _ e -> EJust ext <$> simplify' e - EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e - EConstArr _ n t v -> pure $ EConstArr ext n t v - EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b - EFold1Inner _ cm a b c -> EFold1Inner ext cm <$> simplify' a <*> simplify' b <*> simplify' c - ESum1Inner _ e -> ESum1Inner ext <$> simplify' e - EUnit _ e -> EUnit ext <$> simplify' e - EReplicate1Inner _ a b -> EReplicate1Inner ext <$> simplify' a <*> simplify' b - EMaximum1Inner _ e -> EMaximum1Inner ext <$> simplify' e - EMinimum1Inner _ e -> EMinimum1Inner ext <$> simplify' e - EConst _ t v -> pure $ EConst ext t v - EIdx0 _ e -> EIdx0 ext <$> simplify' e - EIdx1 _ a b -> EIdx1 ext <$> simplify' a <*> simplify' b - EIdx _ a b -> EIdx ext <$> simplify' a <*> simplify' b - EShape _ e -> EShape ext <$> simplify' e - EOp _ op e -> EOp ext op <$> simplify' e - ECustom _ s t p a b c e1 e2 -> - ECustom ext s t p - <$> (let ?accumInScope = False in simplify' a) - <*> (let ?accumInScope = False in simplify' b) - <*> (let ?accumInScope = False in simplify' c) - <*> simplify' e1 <*> simplify' e2 - EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2) - EZero _ t -> pure $ EZero ext t - EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b - EError _ t s -> pure $ EError ext t s - -acted :: (Any, a) -> (Any, a) -acted (_, x) = (Any True, x) - -cheapExpr :: Expr x env t -> Bool -cheapExpr = \case - EVar{} -> True - ENil{} -> True - EConst{} -> True - EFst _ e -> cheapExpr e - ESnd _ e -> cheapExpr e - _ -> False - --- | This can be made more precise by tracking (and not counting) adds on --- locally eliminated accumulators. -hasAdds :: Expr x env t -> Bool -hasAdds = \case - EVar _ _ _ -> False - ELet _ rhs body -> hasAdds rhs || hasAdds body - EPair _ a b -> hasAdds a || hasAdds b - EFst _ e -> hasAdds e - ESnd _ e -> hasAdds e - ENil _ -> False - EInl _ _ e -> hasAdds e - EInr _ _ e -> hasAdds e - ECase _ e a b -> hasAdds e || hasAdds a || hasAdds b - ENothing _ _ -> False - EJust _ e -> hasAdds e - EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e - EConstArr _ _ _ _ -> False - EBuild _ _ a b -> hasAdds a || hasAdds b - EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c - ESum1Inner _ e -> hasAdds e - EUnit _ e -> hasAdds e - EReplicate1Inner _ a b -> hasAdds a || hasAdds b - EMaximum1Inner _ e -> hasAdds e - EMinimum1Inner _ e -> hasAdds e - ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e - EConst _ _ _ -> False - EIdx0 _ e -> hasAdds e - EIdx1 _ a b -> hasAdds a || hasAdds b - EIdx _ a b -> hasAdds a || hasAdds b - EShape _ e -> hasAdds e - EOp _ _ e -> hasAdds e - EWith _ _ a b -> hasAdds a || hasAdds b - EAccum _ _ _ _ _ _ -> True - EZero _ _ -> False - EPlus _ _ a b -> hasAdds a || hasAdds b - EOneHot _ _ _ a b -> hasAdds a || hasAdds b - EError _ _ _ -> False - -checkAccumInScope :: SList STy env -> Bool -checkAccumInScope = \case SNil -> False - SCons t env -> check t || checkAccumInScope env - where - check :: STy t -> Bool - check STNil = False - check (STPair s t) = check s || check t - check (STEither s t) = check s || check t - check (STMaybe t) = check t - check (STArr _ t) = check t - check (STScal _) = False - check STAccum{} = True - -data OneHotTerm env p a b where - OneHotTerm :: STy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env (D2 b) -> OneHotTerm env p a b -deriving instance Show (OneHotTerm env p a b) - -simplifyOneHotTerm :: OneHotTerm env p a b - -> (Any, r) -- ^ Zero case (onehot is actually zero) - -> (Ex env (D2 a) -> (Any, r)) -- ^ Trivial case (no zeros in onehot) - -> (forall p' b'. OneHotTerm env p' a b' -> (Any, r)) - -> (Any, r) -simplifyOneHotTerm (OneHotTerm _ _ _ (EZero _ _)) kzero _ _ = kzero -simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k - | Just Refl <- testEquality (acPrjTy prj1 t1) t2 - = do (Any True, ()) -- record, whatever happens later, that we've modified something - concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> - simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k -simplifyOneHotTerm (OneHotTerm _ SAPHere _ e) _ ktriv _ = ktriv e -simplifyOneHotTerm term _ _ k = k term - -concatOneHots :: STy a - -> SAcPrj p1 a b -> Ex env (AcIdx p1 a) - -> SAcPrj p2 b c -> Ex env (AcIdx p2 b) - -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r -concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of - (_, SAPHere) -> k prj2 idx2 - - (STPair a _, SAPFst prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPFst prj12) idx12 - (STPair _ b, SAPSnd prj1') -> - concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPSnd prj12) idx12 - - (STEither a _, SAPLeft prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12 - (STEither _ b, SAPRight prj1') -> - concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12 - - (STMaybe a, SAPJust prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12 - - (STArr n a, SAPArrIdx prj1' _) -> - concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPArrIdx prj12 n) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) |
