aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2026-02-12 20:37:46 +0100
committerTom Smeding <tom@tomsmeding.com>2026-02-12 20:44:49 +0100
commit095cc16c62fc6b1f039a10a43bf3bd3a79694f4d (patch)
tree81febb46def1ed7758202d3c57abd3eb4622735c /src/CHAD
parent62796be35e6e768147aab70ba0beeb94c058c714 (diff)
WIP
Diffstat (limited to 'src/CHAD')
-rw-r--r--src/CHAD/AST/UnMonoid.hs33
1 files changed, 27 insertions, 6 deletions
diff --git a/src/CHAD/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs
index 6578438..ceb10de 100644
--- a/src/CHAD/AST/UnMonoid.hs
+++ b/src/CHAD/AST/UnMonoid.hs
@@ -70,6 +70,15 @@ unMonoidDeepZeroInfoId (SMTArr _ a) | Refl <- unMonoidDeepZeroInfoId a = Refl
unMonoidDeepZeroInfoId (SMTScal _) = Refl
unMonoidDeepZeroInfoId (SMTIdxPair n a) | Refl <- unMonoidIndexId n, Refl <- unMonoidDeepZeroInfoId a = Refl
+unMonoidAcIdxDId :: SMTy t -> SAcPrj p t t' -> UnMonoid (AcIdxD p t) :~: AcIdxD p (UnMonoid t)
+unMonoidAcIdxDId _ SAPHere = Refl
+unMonoidAcIdxDId (SMTPair t _) (SAPFst p) | Refl <- unMonoidAcIdxDId t p = Refl
+unMonoidAcIdxDId (SMTPair _ t) (SAPSnd p) | Refl <- unMonoidAcIdxDId t p = Refl
+unMonoidAcIdxDId (SMTLEither t _) (SAPLeft p) | Refl <- unMonoidAcIdxDId t p = Refl
+unMonoidAcIdxDId (SMTLEither _ t) (SAPRight p) | Refl <- unMonoidAcIdxDId t p = Refl
+unMonoidAcIdxDId (SMTMaybe t) (SAPJust p) | Refl <- unMonoidAcIdxDId t p = Refl
+unMonoidAcIdxDId (SMTArr n t) (SAPArrIdx p) | Refl <- unMonoidAcIdxDId t p, Refl <- unMonoidIndexId n = Refl
+
-- | Removes monoidal stuff from the program. In particular:
--
@@ -86,10 +95,13 @@ unMonoid = \case
EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)
EIdxPair _ n a b | Refl <- unMonoidIndexId n -> EPair ext (unMonoid a) (unMonoid b)
EUnIdxPair _ e | STIdxPair n _ <- typeOf e, Refl <- unMonoidIndexId n -> unMonoid e
- EAccum _ t p eidx sp eval eacc ->
- accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 ->
- acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
- EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc))
+ EAccum _ t p eidx sp eval eacc
+ | Refl <- unMonoidAcIdxDId t p ->
+ elet (unMonoid eacc) $
+ elet (weakenExpr WSink (unMonoid eidx)) $
+ accumulateSparse (acPrjTy p t) sp (weakenExpr (WSink .> WSink) (unMonoid eval)) $ \w prj2 idx2 val2 ->
+ acPrjCompose SAID (unmAcPrj p) (evar (w @> IZ)) prj2 idx2 $ \prj' idx' ->
+ EAccum ext (mtUnMonoid t) prj' idx' (spDense (acPrjTy prj' (mtUnMonoid t))) val2 (evar (w @> IS IZ))
EVar _ t i -> EVar ext (tUnMonoid t) (go i)
where go :: Idx env t -> Idx (UnMonoidE env) (UnMonoid t)
@@ -242,8 +254,8 @@ onehot typ topprj idx arg = case (typ, topprj) of
zero t1 (evar IZ))
accumulateSparse
- :: SMTy t -> Sparse t t' -> Ex env t'
- -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil)
+ :: SMTy t -> Sparse t t' -> Ex env (UnMonoid t')
+ -> (forall p b env'. env :> env' -> SAcPrj p (UnMonoid t) b -> Ex env' (AcIdxD p (UnMonoid t)) -> Ex env' b -> Ex env' TNil)
-> Ex env TNil
accumulateSparse topty topsp arg accum = case (topty, topsp) of
(_, s) | Just Refl <- isDense topty s ->
@@ -320,3 +332,12 @@ acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k
| Dict <- styKnown (typeOf idx1) =
acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')
+
+unmAcPrj :: SAcPrj p t t' -> SAcPrj p (UnMonoid t) (UnMonoid t')
+unmAcPrj SAPHere = SAPHere
+unmAcPrj (SAPFst p) = SAPFst (unmAcPrj p)
+unmAcPrj (SAPSnd p) = SAPSnd (unmAcPrj p)
+unmAcPrj (SAPLeft p) = SAPLeft (unmAcPrj p)
+unmAcPrj (SAPRight p) = SAPRight (unmAcPrj p)
+unmAcPrj (SAPJust p) = SAPJust (unmAcPrj p)
+unmAcPrj (SAPArrIdx p) = SAPArrIdx (unmAcPrj p)