diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-27 21:30:17 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-27 21:30:17 +0100 |
| commit | 20f7d7be13cd7869b338f98d1ab3fd33e8bbfb3e (patch) | |
| tree | a21c90034a02cdeb7240563dbbab355e49622d0a /src/CHAD/Simplify.hs | |
| parent | ae634c056b500a568b2d89b7f8e225404a2c0c62 (diff) | |
WIP user-specified custom typesuser-types
The big roadblock encountered is that accumulation wants addition of
monoids to be elementwise float addition; this fundamentally clashes
with the concept of a user type with a custom zero and plus.
Diffstat (limited to 'src/CHAD/Simplify.hs')
| -rw-r--r-- | src/CHAD/Simplify.hs | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/src/CHAD/Simplify.hs b/src/CHAD/Simplify.hs index ea253d6..bbc2db8 100644 --- a/src/CHAD/Simplify.hs +++ b/src/CHAD/Simplify.hs @@ -12,6 +12,8 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} + +{-# OPTIONS_GHC -fmax-pmcheck-models=50 #-} module CHAD.Simplify ( simplifyN, simplifyFix, SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith, @@ -307,6 +309,9 @@ simplify'Rec = \case EPlus _ SMTMaybe{} ENothing{} e -> acted $ simplify' e EPlus _ SMTMaybe{} e ENothing{} -> acted $ simplify' e + -- user types + EUnUser _ (EUser _ _ e) -> acted $ simplify' e + -- fallback recursion EVar _ t i -> pure $ EVar ext t i ELet _ a b -> [simprec| ELet ext *a *b |] @@ -361,6 +366,8 @@ simplify'Rec = \case EDeepZero _ t e -> [simprec| EDeepZero ext t *e |] EPlus _ t a b -> [simprec| EPlus ext t *a *b |] EError _ t s -> pure $ EError ext t s + EUser _ t e -> [simprec| EUser ext t *e |] + EUnUser _ e -> [simprec| EUnUser ext *e |] -- | This can be made more precise by tracking (and not counting) adds on -- locally eliminated accumulators. @@ -410,6 +417,8 @@ hasAdds = \case EPlus _ _ a b -> hasAdds a || hasAdds b EOneHot _ _ _ a b -> hasAdds a || hasAdds b EError _ _ _ -> False + EUser _ _ e -> hasAdds e + EUnUser _ e -> hasAdds e checkAccumInScope :: SList STy env -> Bool checkAccumInScope = \case SNil -> False @@ -424,6 +433,7 @@ checkAccumInScope = \case SNil -> False check (STArr _ t) = check t check (STScal _) = False check STAccum{} = True + check (STUser t) = check (userRepTy t) data OneHotTerm dense env a where OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a |
