summaryrefslogtreecommitdiff
path: root/src/Simplify.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Simplify.hs')
-rw-r--r--src/Simplify.hs59
1 files changed, 45 insertions, 14 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 2177789..0aa7a66 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -14,6 +14,7 @@ module Simplify (
import Data.Function (fix)
import Data.Monoid (Any(..))
+import Data.Type.Equality (testEquality)
import AST
import AST.Count
@@ -105,10 +106,10 @@ simplify' = \case
EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e)))
-- projection down-commuting
- EFst _ (ECase _ e1 e2@EPair{} e3@EPair{}) ->
+ EFst _ (ECase _ e1 e2 e3) ->
acted $ simplify' $
ECase ext e1 (EFst ext e2) (EFst ext e3)
- ESnd _ (ECase _ e1 e2@EPair{} e3@EPair{}) ->
+ ESnd _ (ECase _ e1 e2 e3) ->
acted $ simplify' $
ECase ext e1 (ESnd ext e2) (ESnd ext e3)
@@ -118,16 +119,22 @@ simplify' = \case
-- TODO: constant folding for operations
- -- TODO: properly concatenate accum/onehot
- EAccum _ SZ _ (EOneHot _ _ i idx val) acc ->
- acted $ simplify' $
- EAccum ext i idx val acc
- EAccum _ _ _ (EZero _ _) _ -> (Any True, ENil ext)
+ -- monoid rules
+ EAccum _ t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val) acc
+ | Just Refl <- testEquality (acPrjTy prj1 t1) t2
+ -> concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
+ acted $ simplify' (EAccum ext t1 prj12 idx12 val acc)
+ EAccum _ _ _ _ (EZero _ _) _ -> (Any True, ENil ext)
EPlus _ _ (EZero _ _) e -> acted $ simplify' e
EPlus _ _ e (EZero _ _) -> acted $ simplify' e
- EOneHot _ _ SZ _ e -> acted $ simplify' e
-
- -- equations for plus
+ EOneHot _ t _ _ (EZero _ _) -> (Any True, EZero ext t)
+ EOneHot _ _ SAPHere _ e -> acted $ simplify' e
+ EOneHot _ t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)
+ | Just Refl <- testEquality (acPrjTy prj1 t1) t2
+ -> concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
+ acted $ simplify' (EOneHot ext t1 prj12 idx12 val)
+
+ -- type-specific equations for plus
EPlus _ STNil _ _ -> (Any True, ENil ext)
EPlus _ (STPair t1 t2) (EJust _ (EPair _ a1 b1)) (EJust _ (EPair _ a2 b2)) ->
@@ -180,8 +187,8 @@ simplify' = \case
<*> (let ?accumInScope = False in simplify' b)
<*> (let ?accumInScope = False in simplify' c)
<*> simplify' e1 <*> simplify' e2
- EWith _ e1 e2 -> EWith ext <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2)
- EAccum _ i e1 e2 e3 -> EAccum ext i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3
+ EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2)
+ EAccum _ t i e1 e2 e3 -> EAccum ext t i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3
EZero _ t -> pure $ EZero ext t
EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b
EOneHot _ t i a b -> EOneHot ext t i <$> simplify' a <*> simplify' b
@@ -230,8 +237,8 @@ hasAdds = \case
EIdx _ a b -> hasAdds a || hasAdds b
EShape _ e -> hasAdds e
EOp _ _ e -> hasAdds e
- EWith _ a b -> hasAdds a || hasAdds b
- EAccum _ _ _ _ _ -> True
+ EWith _ _ a b -> hasAdds a || hasAdds b
+ EAccum _ _ _ _ _ _ -> True
EZero _ _ -> False
EPlus _ _ a b -> hasAdds a || hasAdds b
EOneHot _ _ _ a b -> hasAdds a || hasAdds b
@@ -249,3 +256,27 @@ checkAccumInScope = \case SNil -> False
check (STArr _ t) = check t
check (STScal _) = False
check STAccum{} = True
+
+concatOneHots :: STy a
+ -> SAcPrj p1 a b -> Ex env (AcIdx p1 a)
+ -> SAcPrj p2 b c -> Ex env (AcIdx p2 b)
+ -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r
+concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of
+ (_, SAPHere) -> k prj2 idx2
+
+ (STPair a _, SAPFst prj1') ->
+ concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPFst prj12) idx12
+ (STPair _ b, SAPSnd prj1') ->
+ concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPSnd prj12) idx12
+
+ (STEither a _, SAPLeft prj1') ->
+ concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12
+ (STEither _ b, SAPRight prj1') ->
+ concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12
+
+ (STMaybe a, SAPJust prj1') ->
+ concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12
+
+ (STArr n a, SAPArrIdx prj1' _) ->
+ concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
+ k (SAPArrIdx prj12 n) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)