diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-04-29 17:52:43 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-04-29 17:52:43 +0200 |
commit | 4ad7eaba73d5fda8ff5028d1e53966f728d704d3 (patch) | |
tree | bb8126c9eee53d565f2c3a86fe70f914548c3782 /src | |
parent | 0a32a06b9c3206484c34860148dfbe23935b8e3b (diff) |
simplify: Better simplify nested monoid ops
Diffstat (limited to 'src')
-rw-r--r-- | src/AST/Types.hs | 1 | ||||
-rw-r--r-- | src/Simplify.hs | 93 |
2 files changed, 56 insertions, 38 deletions
diff --git a/src/AST/Types.hs b/src/AST/Types.hs index c8515fc..efb1e04 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -73,7 +73,6 @@ type SMTy :: Ty -> Type data SMTy t where SMTNil :: SMTy TNil SMTPair :: SMTy a -> SMTy b -> SMTy (TPair a b) - -- TODO: call this SMTLEither SMTLEither :: SMTy a -> SMTy b -> SMTy (TLEither a b) SMTMaybe :: SMTy a -> SMTy (TMaybe a) SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t) diff --git a/src/Simplify.hs b/src/Simplify.hs index 140e673..f5eb0a1 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -91,6 +91,12 @@ acted m = tellActed >> m within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a within subctx (SM f) = SM $ \ctx -> f (ctx . subctx) +acted' :: (Any, a) -> (Any, a) +acted' (_, x) = (Any True, x) + +liftActed :: (Any, a) -> SM tenv tt env t a +liftActed pair = SM $ \_ -> pair + simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) simplify' expr | scLogging ?config = do @@ -368,43 +374,56 @@ simplifyOneHotTerm :: OneHotTerm env p a b -> (Ex env a -> SM tenv tt env t r) -- ^ Trivial case (no zeros in onehot) -> (forall p' b'. OneHotTerm env p' a b' -> SM tenv tt env t r) -> SM tenv tt env t r -simplifyOneHotTerm (OneHotTerm _ _ _ EZero{}) kzero _ _ = kzero - -simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k - | Just Refl <- testEquality (acPrjTy prj1 t1) t2 - = do tellActed -- record, whatever happens later, that we've modified something - concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> - simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k - --- TODO: This does not actually recurse unless it just so happens to contain --- another EZero or EOnehot in the final position. Should match on something --- more general than SAPHere here. -simplifyOneHotTerm (OneHotTerm t SAPHere _ e) kzero ktriv k = case (t, e) of - (SMTNil, _) -> kzero - - (SMTPair{}, EPair _ e1 (EZero _ _ ezi)) -> - simplifyOneHotTerm (OneHotTerm t (SAPFst SAPHere) (EPair ext (ENil ext) ezi) e1) kzero ktriv k - (SMTPair{}, EPair _ (EZero _ _ ezi) e2) -> - simplifyOneHotTerm (OneHotTerm t (SAPSnd SAPHere) (EPair ext ezi (ENil ext)) e2) kzero ktriv k - - (SMTLEither{}, ELNil _ _ _) -> kzero - (SMTLEither{}, ELInl _ _ e1) -> - simplifyOneHotTerm (OneHotTerm t (SAPLeft SAPHere) (ENil ext) e1) kzero ktriv k - (SMTLEither{}, ELInr _ _ e2) -> - simplifyOneHotTerm (OneHotTerm t (SAPRight SAPHere) (ENil ext) e2) kzero ktriv k - - (SMTMaybe{}, ENothing _ _) -> kzero - (SMTMaybe{}, EJust _ e1) -> - simplifyOneHotTerm (OneHotTerm t (SAPJust SAPHere) (ENil ext) e1) kzero ktriv k - - (SMTScal STI32, _) -> kzero - (SMTScal STI64, _) -> kzero - (SMTScal STF32, EConst _ _ 0.0) -> kzero - (SMTScal STF64, EConst _ _ 0.0) -> kzero - - _ -> ktriv e - -simplifyOneHotTerm term _ _ k = k term +simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 val1) kzero ktriv k = do + val1' <- liftActed $ recogniseMonoid (acPrjTy prj1 t1) val1 + case val1' of + EZero{} -> kzero + EOneHot _ t2 prj2 idx2 val2 + | Just Refl <- testEquality (acPrjTy prj1 t1) t2 -> do + tellActed -- record, whatever happens later, that we've modified something + concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> + simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val2) kzero ktriv k + _ -> case prj1 of + SAPHere -> ktriv val1 + _ -> k (OneHotTerm t1 prj1 idx1 val1) + +-- | Recognises 'EZero' and 'EOneHot'. +recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t) +recogniseMonoid _ e@EOneHot{} = return e +recogniseMonoid SMTNil (ENil _) = acted' $ return $ EZero ext SMTNil (ENil ext) +recogniseMonoid typ@(SMTPair t1 t2) (EPair _ a b) = + ((,) <$> recogniseMonoid t1 a <*> recogniseMonoid t2 b) >>= \case + (EZero _ _ ezi1, EZero _ _ ezi2) -> acted' $ return $ EZero ext typ (EPair ext ezi1 ezi2) + (a', EZero _ _ ezi2) -> acted' $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a' + (EZero _ _ ezi1, b') -> acted' $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b' + (a', b') -> return $ EPair ext a' b' +recogniseMonoid typ@(SMTLEither t1 t2) expr = + case expr of + ELNil{} -> acted' $ return $ EZero ext typ (ENil ext) + ELInl _ _ e -> acted' $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e + ELInr _ _ e -> acted' $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e + _ -> return expr +recogniseMonoid typ@(SMTMaybe t1) expr = + case expr of + ENothing{} -> acted' $ return $ EZero ext typ (ENil ext) + EJust _ e -> acted' $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e + _ -> return expr +recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) = + acted' $ do + e' <- recogniseMonoid t e + return $ + ELet ext e' $ + EOneHot ext typ (SAPArrIdx SAPHere) + (EPair ext (EPair ext (ENil ext) (EUnit ext (makeZeroInfo t (EVar ext (fromSMTy t) IZ)))) + (ENil ext)) + (EVar ext (fromSMTy t) IZ) +recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of + (STI32, 0) -> acted' $ return $ EZero ext typ (ENil ext) + (STI64, 0) -> acted' $ return $ EZero ext typ (ENil ext) + (STF32, 0) -> acted' $ return $ EZero ext typ (ENil ext) + (STF64, 0) -> acted' $ return $ EZero ext typ (ENil ext) + _ -> return e +recogniseMonoid _ e = return e concatOneHots :: SMTy a -> SAcPrj p1 a b -> Ex env (AcIdx p1 a) |