aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/AST/UnMonoid.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-27 21:30:17 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-27 21:30:17 +0100
commit20f7d7be13cd7869b338f98d1ab3fd33e8bbfb3e (patch)
treea21c90034a02cdeb7240563dbbab355e49622d0a /src/CHAD/AST/UnMonoid.hs
parentae634c056b500a568b2d89b7f8e225404a2c0c62 (diff)
WIP user-specified custom typesuser-types
The big roadblock encountered is that accumulation wants addition of monoids to be elementwise float addition; this fundamentally clashes with the concept of a user type with a custom zero and plus.
Diffstat (limited to 'src/CHAD/AST/UnMonoid.hs')
-rw-r--r--src/CHAD/AST/UnMonoid.hs9
1 files changed, 7 insertions, 2 deletions
diff --git a/src/CHAD/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs
index d3cad25..2166fc6 100644
--- a/src/CHAD/AST/UnMonoid.hs
+++ b/src/CHAD/AST/UnMonoid.hs
@@ -63,6 +63,8 @@ unMonoid = \case
acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc))
EError _ t s -> EError ext t s
+ EUser _ t e -> EUser ext t (unMonoid e)
+ EUnUser _ e -> EUnUser ext (unMonoid e)
zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
-- don't destroy the effects!
@@ -78,6 +80,7 @@ zero (SMTScal t) _ = case t of
STI64 -> EConst ext STI64 0
STF32 -> EConst ext STF32 0.0
STF64 -> EConst ext STF64 0.0
+zero (SMTUser t) e = EUser ext (STUser t) (euserZero t e)
deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t
deepZero SMTNil e = elet e $ ENil ext
@@ -99,6 +102,7 @@ deepZero (SMTScal t) _ = case t of
STI64 -> EConst ext STI64 0
STF32 -> EConst ext STF32 0.0
STF64 -> EConst ext STF64 0.0
+deepZero (SMTUser t) e = EUser ext (STUser t) (euserDeepZero t e)
plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
-- don't destroy the effects!
@@ -136,6 +140,7 @@ plus (SMTArr _ t) a b =
ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
a b
plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)
+plus (SMTUser t) a b = EUser ext (STUser t) (euserPlus t (EUnUser ext a) (EUnUser ext b))
onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t
onehot typ topprj idx arg = case (typ, topprj) of
@@ -183,8 +188,8 @@ accumulateSparse
accumulateSparse topty topsp arg accum = case (topty, topsp) of
(_, s) | Just Refl <- isDense topty s ->
accum WId SAPHere (ENil ext) arg
- (SMTScal _, SpScal) ->
- accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh
+ (SMTScal _, SpScal) -> error "TScal is dense"
+ (SMTUser _, SpUser) -> error "TUser is dense"
(_, SpSparse s) ->
emaybe arg
(ENil ext)