From 20f7d7be13cd7869b338f98d1ab3fd33e8bbfb3e Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 27 Nov 2025 21:30:17 +0100 Subject: WIP user-specified custom types The big roadblock encountered is that accumulation wants addition of monoids to be elementwise float addition; this fundamentally clashes with the concept of a user type with a custom zero and plus. --- src/CHAD/Simplify.hs | 10 ++++++++++ 1 file changed, 10 insertions(+) (limited to 'src/CHAD/Simplify.hs') diff --git a/src/CHAD/Simplify.hs b/src/CHAD/Simplify.hs index ea253d6..bbc2db8 100644 --- a/src/CHAD/Simplify.hs +++ b/src/CHAD/Simplify.hs @@ -12,6 +12,8 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} + +{-# OPTIONS_GHC -fmax-pmcheck-models=50 #-} module CHAD.Simplify ( simplifyN, simplifyFix, SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith, @@ -307,6 +309,9 @@ simplify'Rec = \case EPlus _ SMTMaybe{} ENothing{} e -> acted $ simplify' e EPlus _ SMTMaybe{} e ENothing{} -> acted $ simplify' e + -- user types + EUnUser _ (EUser _ _ e) -> acted $ simplify' e + -- fallback recursion EVar _ t i -> pure $ EVar ext t i ELet _ a b -> [simprec| ELet ext *a *b |] @@ -361,6 +366,8 @@ simplify'Rec = \case EDeepZero _ t e -> [simprec| EDeepZero ext t *e |] EPlus _ t a b -> [simprec| EPlus ext t *a *b |] EError _ t s -> pure $ EError ext t s + EUser _ t e -> [simprec| EUser ext t *e |] + EUnUser _ e -> [simprec| EUnUser ext *e |] -- | This can be made more precise by tracking (and not counting) adds on -- locally eliminated accumulators. @@ -410,6 +417,8 @@ hasAdds = \case EPlus _ _ a b -> hasAdds a || hasAdds b EOneHot _ _ _ a b -> hasAdds a || hasAdds b EError _ _ _ -> False + EUser _ _ e -> hasAdds e + EUnUser _ e -> hasAdds e checkAccumInScope :: SList STy env -> Bool checkAccumInScope = \case SNil -> False @@ -424,6 +433,7 @@ checkAccumInScope = \case SNil -> False check (STArr _ t) = check t check (STScal _) = False check STAccum{} = True + check (STUser t) = check (userRepTy t) data OneHotTerm dense env a where OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a -- cgit v1.2.3-70-g09d2