diff options
Diffstat (limited to 'src/Simplify.hs')
-rw-r--r-- | src/Simplify.hs | 59 |
1 files changed, 45 insertions, 14 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs index 2177789..0aa7a66 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -14,6 +14,7 @@ module Simplify ( import Data.Function (fix) import Data.Monoid (Any(..)) +import Data.Type.Equality (testEquality) import AST import AST.Count @@ -105,10 +106,10 @@ simplify' = \case EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e))) -- projection down-commuting - EFst _ (ECase _ e1 e2@EPair{} e3@EPair{}) -> + EFst _ (ECase _ e1 e2 e3) -> acted $ simplify' $ ECase ext e1 (EFst ext e2) (EFst ext e3) - ESnd _ (ECase _ e1 e2@EPair{} e3@EPair{}) -> + ESnd _ (ECase _ e1 e2 e3) -> acted $ simplify' $ ECase ext e1 (ESnd ext e2) (ESnd ext e3) @@ -118,16 +119,22 @@ simplify' = \case -- TODO: constant folding for operations - -- TODO: properly concatenate accum/onehot - EAccum _ SZ _ (EOneHot _ _ i idx val) acc -> - acted $ simplify' $ - EAccum ext i idx val acc - EAccum _ _ _ (EZero _ _) _ -> (Any True, ENil ext) + -- monoid rules + EAccum _ t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val) acc + | Just Refl <- testEquality (acPrjTy prj1 t1) t2 + -> concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> + acted $ simplify' (EAccum ext t1 prj12 idx12 val acc) + EAccum _ _ _ _ (EZero _ _) _ -> (Any True, ENil ext) EPlus _ _ (EZero _ _) e -> acted $ simplify' e EPlus _ _ e (EZero _ _) -> acted $ simplify' e - EOneHot _ _ SZ _ e -> acted $ simplify' e - - -- equations for plus + EOneHot _ t _ _ (EZero _ _) -> (Any True, EZero ext t) + EOneHot _ _ SAPHere _ e -> acted $ simplify' e + EOneHot _ t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val) + | Just Refl <- testEquality (acPrjTy prj1 t1) t2 + -> concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> + acted $ simplify' (EOneHot ext t1 prj12 idx12 val) + + -- type-specific equations for plus EPlus _ STNil _ _ -> (Any True, ENil ext) EPlus _ (STPair t1 t2) (EJust _ (EPair _ a1 b1)) (EJust _ (EPair _ a2 b2)) -> @@ -180,8 +187,8 @@ simplify' = \case <*> (let ?accumInScope = False in simplify' b) <*> (let ?accumInScope = False in simplify' c) <*> simplify' e1 <*> simplify' e2 - EWith _ e1 e2 -> EWith ext <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2) - EAccum _ i e1 e2 e3 -> EAccum ext i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3 + EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2) + EAccum _ t i e1 e2 e3 -> EAccum ext t i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3 EZero _ t -> pure $ EZero ext t EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b EOneHot _ t i a b -> EOneHot ext t i <$> simplify' a <*> simplify' b @@ -230,8 +237,8 @@ hasAdds = \case EIdx _ a b -> hasAdds a || hasAdds b EShape _ e -> hasAdds e EOp _ _ e -> hasAdds e - EWith _ a b -> hasAdds a || hasAdds b - EAccum _ _ _ _ _ -> True + EWith _ _ a b -> hasAdds a || hasAdds b + EAccum _ _ _ _ _ _ -> True EZero _ _ -> False EPlus _ _ a b -> hasAdds a || hasAdds b EOneHot _ _ _ a b -> hasAdds a || hasAdds b @@ -249,3 +256,27 @@ checkAccumInScope = \case SNil -> False check (STArr _ t) = check t check (STScal _) = False check STAccum{} = True + +concatOneHots :: STy a + -> SAcPrj p1 a b -> Ex env (AcIdx p1 a) + -> SAcPrj p2 b c -> Ex env (AcIdx p2 b) + -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r +concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of + (_, SAPHere) -> k prj2 idx2 + + (STPair a _, SAPFst prj1') -> + concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPFst prj12) idx12 + (STPair _ b, SAPSnd prj1') -> + concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPSnd prj12) idx12 + + (STEither a _, SAPLeft prj1') -> + concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12 + (STEither _ b, SAPRight prj1') -> + concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12 + + (STMaybe a, SAPJust prj1') -> + concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12 + + (STArr n a, SAPArrIdx prj1' _) -> + concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> + k (SAPArrIdx prj12 n) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) |