diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-04-18 12:54:27 +0200 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-04-18 12:54:27 +0200 |
commit | 7823027b3ff7508c303c0e6e68192a783b65a5c4 (patch) | |
tree | 49a96fa88071ecf281ff5f58b84fb5e0b3d97be8 | |
parent | bd5d0458017862b984b9caf0975c135d154e8515 (diff) |
Better simplification of onehots
-rw-r--r-- | src/Simplify.hs | 46 |
1 files changed, 40 insertions, 6 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs index 0bf5482..f5b7d15 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -123,18 +123,22 @@ simplify' = \case -- monoid rules EAccum _ t p e1 e2 acc -> do + e1' <- simplify' e1 + e2' <- simplify' e2 acc' <- simplify' acc - simplifyOneHotTerm (OneHotTerm t p e1 e2) + simplifyOneHotTerm (OneHotTerm t p e1' e2') (Any True, ENil ext) (\e -> (Any False, EAccum ext t SAPHere (ENil ext) e acc')) - (\(OneHotTerm t' p' e1' e2') -> return (EAccum ext t' p' e1' e2' acc')) + (\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc')) EPlus _ _ (EZero _ _) e -> acted $ simplify' e EPlus _ _ e (EZero _ _) -> acted $ simplify' e - EOneHot _ t p e1 e2 -> - simplifyOneHotTerm (OneHotTerm t p e1 e2) + EOneHot _ t p e1 e2 -> do + e1' <- simplify' e1 + e2' <- simplify' e2 + simplifyOneHotTerm (OneHotTerm t p e1' e2') (Any True, EZero ext t) (\e -> (Any True, e)) - (\(OneHotTerm t' p' e1' e2') -> return (EOneHot ext t' p' e1' e2')) + (\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2'')) -- type-specific equations for plus EPlus _ STNil _ _ -> (Any True, ENil ext) @@ -267,12 +271,42 @@ simplifyOneHotTerm :: OneHotTerm env p a b -> (forall p' b'. OneHotTerm env p' a b' -> (Any, r)) -> (Any, r) simplifyOneHotTerm (OneHotTerm _ _ _ (EZero _ _)) kzero _ _ = kzero + simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k | Just Refl <- testEquality (acPrjTy prj1 t1) t2 = do (Any True, ()) -- record, whatever happens later, that we've modified something concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k -simplifyOneHotTerm (OneHotTerm _ SAPHere _ e) _ ktriv _ = ktriv e + +simplifyOneHotTerm (OneHotTerm t SAPHere idx e) kzero ktriv k = case (t, e) of + (STNil, _) -> kzero + + (STPair{}, ENothing _ _) -> kzero + (STPair{}, EJust _ (EPair _ e1 EZero{})) -> + simplifyOneHotTerm (OneHotTerm t (SAPFst SAPHere) idx e1) kzero ktriv k + (STPair{}, EJust _ (EPair _ EZero{} e2)) -> + simplifyOneHotTerm (OneHotTerm t (SAPSnd SAPHere) idx e2) kzero ktriv k + + (STEither{}, ENothing _ _) -> kzero + (STEither{}, EJust _ (EInl _ _ e1)) -> + simplifyOneHotTerm (OneHotTerm t (SAPLeft SAPHere) idx e1) kzero ktriv k + (STEither{}, EJust _ (EInr _ _ e2)) -> + simplifyOneHotTerm (OneHotTerm t (SAPRight SAPHere) idx e2) kzero ktriv k + + (STMaybe{}, ENothing _ _) -> kzero + (STMaybe{}, EJust _ e1) -> + simplifyOneHotTerm (OneHotTerm t (SAPJust SAPHere) idx e1) kzero ktriv k + + (STArr{}, ENothing _ _) -> kzero + + (STScal STI32, _) -> kzero + (STScal STI64, _) -> kzero + (STScal STF32, EConst _ _ 0.0) -> kzero + (STScal STF64, EConst _ _ 0.0) -> kzero + (STScal STBool, _) -> kzero + + _ -> ktriv e + simplifyOneHotTerm term _ _ k = k term concatOneHots :: STy a |