aboutsummaryrefslogtreecommitdiff
path: root/src/Simplify.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Simplify.hs')
-rw-r--r--src/Simplify.hs300
1 files changed, 0 insertions, 300 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs
deleted file mode 100644
index 0bf5482..0000000
--- a/src/Simplify.hs
+++ /dev/null
@@ -1,300 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE ImplicitParams #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE MultiWayIf #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeOperators #-}
-module Simplify (
- simplifyN, simplifyFix,
- SimplifyConfig(..), simplifyWith, simplifyFixWith,
-) where
-
-import Data.Function (fix)
-import Data.Monoid (Any(..))
-import Data.Type.Equality (testEquality)
-
-import AST
-import AST.Count
-import CHAD.Types
-import Data
-
-
--- | This has no fields now, hence this type is useless as-is. When debugging, however, it's useful to be able to add some.
-data SimplifyConfig = SimplifyConfig
-
-defaultSimplifyConfig :: SimplifyConfig
-defaultSimplifyConfig = SimplifyConfig
-
-simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t
-simplifyN 0 = id
-simplifyN n = simplifyN (n - 1) . simplify
-
-simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t
-simplify =
- let ?accumInScope = checkAccumInScope @env knownEnv
- ?config = defaultSimplifyConfig
- in snd . simplify'
-
-simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t
-simplifyWith config =
- let ?accumInScope = checkAccumInScope @env knownEnv
- ?config = config
- in snd . simplify'
-
-simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t
-simplifyFix = simplifyFixWith defaultSimplifyConfig
-
-simplifyFixWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t
-simplifyFixWith config =
- let ?accumInScope = checkAccumInScope @env knownEnv
- ?config = config
- in fix $ \loop e ->
- let (Any act, e') = simplify' e
- in if act then loop e' else e'
-
-simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig) => Ex env t -> (Any, Ex env t)
-simplify' = \case
- -- inlining
- ELet _ rhs body
- | cheapExpr rhs
- -> acted $ simplify' (subst1 rhs body)
-
- | Occ lexOcc runOcc <- occCount IZ body
- , ((not ?accumInScope || not (hasAdds rhs)) && lexOcc <= One && runOcc <= One) -- without effects, normal rules apply
- || (lexOcc == One && runOcc == One) -- with effects, linear inlining is still allowed, but weakening is not
- -> acted $ simplify' (subst1 rhs body)
-
- -- let splitting
- ELet _ (EPair _ a b) body ->
- acted $ simplify' $
- ELet ext a $
- ELet ext (weakenExpr WSink b) $
- subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ)
- IS i -> EVar ext t (IS (IS i)))
- body
-
- -- let rotation
- ELet _ (ELet _ rhs a) b ->
- acted $ simplify' $
- ELet ext rhs $
- ELet ext a $
- weakenExpr (WCopy WSink) (snd (simplify' b))
-
- -- beta rules for products
- EFst _ (EPair _ e e')
- | not (hasAdds e') -> acted $ simplify' e
- | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e)
- ESnd _ (EPair _ e' e)
- | not (hasAdds e') -> acted $ simplify' e
- | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e)
-
- -- beta rules for coproducts
- ECase _ (EInl _ _ e) rhs _ -> acted $ simplify' (ELet ext e rhs)
- ECase _ (EInr _ _ e) _ rhs -> acted $ simplify' (ELet ext e rhs)
-
- -- beta rules for maybe
- EMaybe _ e1 _ ENothing{} -> acted $ simplify' e1
- EMaybe _ _ e1 (EJust _ e2) -> acted $ simplify' $ ELet ext e2 e1
-
- -- let floating to facilitate beta reduction
- EFst _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EFst ext body))
- ESnd _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (ESnd ext body))
- ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2)))
- EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body))
- EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e)))
-
- -- projection down-commuting
- EFst _ (ECase _ e1 e2 e3) ->
- acted $ simplify' $
- ECase ext e1 (EFst ext e2) (EFst ext e3)
- ESnd _ (ECase _ e1 e2 e3) ->
- acted $ simplify' $
- ECase ext e1 (ESnd ext e2) (ESnd ext e3)
-
- -- TODO: array indexing (index of build, index of fold)
-
- -- TODO: beta rules for maybe
-
- -- TODO: constant folding for operations
-
- -- monoid rules
- EAccum _ t p e1 e2 acc -> do
- acc' <- simplify' acc
- simplifyOneHotTerm (OneHotTerm t p e1 e2)
- (Any True, ENil ext)
- (\e -> (Any False, EAccum ext t SAPHere (ENil ext) e acc'))
- (\(OneHotTerm t' p' e1' e2') -> return (EAccum ext t' p' e1' e2' acc'))
- EPlus _ _ (EZero _ _) e -> acted $ simplify' e
- EPlus _ _ e (EZero _ _) -> acted $ simplify' e
- EOneHot _ t p e1 e2 ->
- simplifyOneHotTerm (OneHotTerm t p e1 e2)
- (Any True, EZero ext t)
- (\e -> (Any True, e))
- (\(OneHotTerm t' p' e1' e2') -> return (EOneHot ext t' p' e1' e2'))
-
- -- type-specific equations for plus
- EPlus _ STNil _ _ -> (Any True, ENil ext)
-
- EPlus _ (STPair t1 t2) (EJust _ (EPair _ a1 b1)) (EJust _ (EPair _ a2 b2)) ->
- acted $ simplify' $ EJust ext (EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2))
- EPlus _ STPair{} ENothing{} e -> acted $ simplify' e
- EPlus _ STPair{} e ENothing{} -> acted $ simplify' e
-
- EPlus _ (STEither t1 _) (EJust _ (EInl _ dt2 a1)) (EJust _ (EInl _ _ a2)) ->
- acted $ simplify' $ EJust ext (EInl ext dt2 (EPlus ext t1 a1 a2))
- EPlus _ (STEither _ t2) (EJust _ (EInr _ dt1 b1)) (EJust _ (EInr _ _ b2)) ->
- acted $ simplify' $ EJust ext (EInr ext dt1 (EPlus ext t2 b1 b2))
- EPlus _ STEither{} ENothing{} e -> acted $ simplify' e
- EPlus _ STEither{} e ENothing{} -> acted $ simplify' e
-
- EPlus _ (STMaybe t) (EJust _ e1) (EJust _ e2) ->
- acted $ simplify' $ EJust ext (EPlus ext t e1 e2)
- EPlus _ STMaybe{} ENothing{} e -> acted $ simplify' e
- EPlus _ STMaybe{} e ENothing{} -> acted $ simplify' e
-
- -- fallback recursion
- EVar _ t i -> pure $ EVar ext t i
- ELet _ a b -> ELet ext <$> simplify' a <*> simplify' b
- EPair _ a b -> EPair ext <$> simplify' a <*> simplify' b
- EFst _ e -> EFst ext <$> simplify' e
- ESnd _ e -> ESnd ext <$> simplify' e
- ENil _ -> pure $ ENil ext
- EInl _ t e -> EInl ext t <$> simplify' e
- EInr _ t e -> EInr ext t <$> simplify' e
- ECase _ e a b -> ECase ext <$> simplify' e <*> simplify' a <*> simplify' b
- ENothing _ t -> pure $ ENothing ext t
- EJust _ e -> EJust ext <$> simplify' e
- EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e
- EConstArr _ n t v -> pure $ EConstArr ext n t v
- EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b
- EFold1Inner _ cm a b c -> EFold1Inner ext cm <$> simplify' a <*> simplify' b <*> simplify' c
- ESum1Inner _ e -> ESum1Inner ext <$> simplify' e
- EUnit _ e -> EUnit ext <$> simplify' e
- EReplicate1Inner _ a b -> EReplicate1Inner ext <$> simplify' a <*> simplify' b
- EMaximum1Inner _ e -> EMaximum1Inner ext <$> simplify' e
- EMinimum1Inner _ e -> EMinimum1Inner ext <$> simplify' e
- EConst _ t v -> pure $ EConst ext t v
- EIdx0 _ e -> EIdx0 ext <$> simplify' e
- EIdx1 _ a b -> EIdx1 ext <$> simplify' a <*> simplify' b
- EIdx _ a b -> EIdx ext <$> simplify' a <*> simplify' b
- EShape _ e -> EShape ext <$> simplify' e
- EOp _ op e -> EOp ext op <$> simplify' e
- ECustom _ s t p a b c e1 e2 ->
- ECustom ext s t p
- <$> (let ?accumInScope = False in simplify' a)
- <*> (let ?accumInScope = False in simplify' b)
- <*> (let ?accumInScope = False in simplify' c)
- <*> simplify' e1 <*> simplify' e2
- EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2)
- EZero _ t -> pure $ EZero ext t
- EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b
- EError _ t s -> pure $ EError ext t s
-
-acted :: (Any, a) -> (Any, a)
-acted (_, x) = (Any True, x)
-
-cheapExpr :: Expr x env t -> Bool
-cheapExpr = \case
- EVar{} -> True
- ENil{} -> True
- EConst{} -> True
- EFst _ e -> cheapExpr e
- ESnd _ e -> cheapExpr e
- _ -> False
-
--- | This can be made more precise by tracking (and not counting) adds on
--- locally eliminated accumulators.
-hasAdds :: Expr x env t -> Bool
-hasAdds = \case
- EVar _ _ _ -> False
- ELet _ rhs body -> hasAdds rhs || hasAdds body
- EPair _ a b -> hasAdds a || hasAdds b
- EFst _ e -> hasAdds e
- ESnd _ e -> hasAdds e
- ENil _ -> False
- EInl _ _ e -> hasAdds e
- EInr _ _ e -> hasAdds e
- ECase _ e a b -> hasAdds e || hasAdds a || hasAdds b
- ENothing _ _ -> False
- EJust _ e -> hasAdds e
- EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e
- EConstArr _ _ _ _ -> False
- EBuild _ _ a b -> hasAdds a || hasAdds b
- EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
- ESum1Inner _ e -> hasAdds e
- EUnit _ e -> hasAdds e
- EReplicate1Inner _ a b -> hasAdds a || hasAdds b
- EMaximum1Inner _ e -> hasAdds e
- EMinimum1Inner _ e -> hasAdds e
- ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
- EConst _ _ _ -> False
- EIdx0 _ e -> hasAdds e
- EIdx1 _ a b -> hasAdds a || hasAdds b
- EIdx _ a b -> hasAdds a || hasAdds b
- EShape _ e -> hasAdds e
- EOp _ _ e -> hasAdds e
- EWith _ _ a b -> hasAdds a || hasAdds b
- EAccum _ _ _ _ _ _ -> True
- EZero _ _ -> False
- EPlus _ _ a b -> hasAdds a || hasAdds b
- EOneHot _ _ _ a b -> hasAdds a || hasAdds b
- EError _ _ _ -> False
-
-checkAccumInScope :: SList STy env -> Bool
-checkAccumInScope = \case SNil -> False
- SCons t env -> check t || checkAccumInScope env
- where
- check :: STy t -> Bool
- check STNil = False
- check (STPair s t) = check s || check t
- check (STEither s t) = check s || check t
- check (STMaybe t) = check t
- check (STArr _ t) = check t
- check (STScal _) = False
- check STAccum{} = True
-
-data OneHotTerm env p a b where
- OneHotTerm :: STy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env (D2 b) -> OneHotTerm env p a b
-deriving instance Show (OneHotTerm env p a b)
-
-simplifyOneHotTerm :: OneHotTerm env p a b
- -> (Any, r) -- ^ Zero case (onehot is actually zero)
- -> (Ex env (D2 a) -> (Any, r)) -- ^ Trivial case (no zeros in onehot)
- -> (forall p' b'. OneHotTerm env p' a b' -> (Any, r))
- -> (Any, r)
-simplifyOneHotTerm (OneHotTerm _ _ _ (EZero _ _)) kzero _ _ = kzero
-simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k
- | Just Refl <- testEquality (acPrjTy prj1 t1) t2
- = do (Any True, ()) -- record, whatever happens later, that we've modified something
- concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
- simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k
-simplifyOneHotTerm (OneHotTerm _ SAPHere _ e) _ ktriv _ = ktriv e
-simplifyOneHotTerm term _ _ k = k term
-
-concatOneHots :: STy a
- -> SAcPrj p1 a b -> Ex env (AcIdx p1 a)
- -> SAcPrj p2 b c -> Ex env (AcIdx p2 b)
- -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r
-concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of
- (_, SAPHere) -> k prj2 idx2
-
- (STPair a _, SAPFst prj1') ->
- concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPFst prj12) idx12
- (STPair _ b, SAPSnd prj1') ->
- concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPSnd prj12) idx12
-
- (STEither a _, SAPLeft prj1') ->
- concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12
- (STEither _ b, SAPRight prj1') ->
- concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12
-
- (STMaybe a, SAPJust prj1') ->
- concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12
-
- (STArr n a, SAPArrIdx prj1' _) ->
- concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
- k (SAPArrIdx prj12 n) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)