From 095cc16c62fc6b1f039a10a43bf3bd3a79694f4d Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 12 Feb 2026 20:37:46 +0100 Subject: WIP --- src/CHAD/AST/UnMonoid.hs | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) (limited to 'src/CHAD') diff --git a/src/CHAD/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs index 6578438..ceb10de 100644 --- a/src/CHAD/AST/UnMonoid.hs +++ b/src/CHAD/AST/UnMonoid.hs @@ -70,6 +70,15 @@ unMonoidDeepZeroInfoId (SMTArr _ a) | Refl <- unMonoidDeepZeroInfoId a = Refl unMonoidDeepZeroInfoId (SMTScal _) = Refl unMonoidDeepZeroInfoId (SMTIdxPair n a) | Refl <- unMonoidIndexId n, Refl <- unMonoidDeepZeroInfoId a = Refl +unMonoidAcIdxDId :: SMTy t -> SAcPrj p t t' -> UnMonoid (AcIdxD p t) :~: AcIdxD p (UnMonoid t) +unMonoidAcIdxDId _ SAPHere = Refl +unMonoidAcIdxDId (SMTPair t _) (SAPFst p) | Refl <- unMonoidAcIdxDId t p = Refl +unMonoidAcIdxDId (SMTPair _ t) (SAPSnd p) | Refl <- unMonoidAcIdxDId t p = Refl +unMonoidAcIdxDId (SMTLEither t _) (SAPLeft p) | Refl <- unMonoidAcIdxDId t p = Refl +unMonoidAcIdxDId (SMTLEither _ t) (SAPRight p) | Refl <- unMonoidAcIdxDId t p = Refl +unMonoidAcIdxDId (SMTMaybe t) (SAPJust p) | Refl <- unMonoidAcIdxDId t p = Refl +unMonoidAcIdxDId (SMTArr n t) (SAPArrIdx p) | Refl <- unMonoidAcIdxDId t p, Refl <- unMonoidIndexId n = Refl + -- | Removes monoidal stuff from the program. In particular: -- @@ -86,10 +95,13 @@ unMonoid = \case EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) EIdxPair _ n a b | Refl <- unMonoidIndexId n -> EPair ext (unMonoid a) (unMonoid b) EUnIdxPair _ e | STIdxPair n _ <- typeOf e, Refl <- unMonoidIndexId n -> unMonoid e - EAccum _ t p eidx sp eval eacc -> - accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 -> - acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> - EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc)) + EAccum _ t p eidx sp eval eacc + | Refl <- unMonoidAcIdxDId t p -> + elet (unMonoid eacc) $ + elet (weakenExpr WSink (unMonoid eidx)) $ + accumulateSparse (acPrjTy p t) sp (weakenExpr (WSink .> WSink) (unMonoid eval)) $ \w prj2 idx2 val2 -> + acPrjCompose SAID (unmAcPrj p) (evar (w @> IZ)) prj2 idx2 $ \prj' idx' -> + EAccum ext (mtUnMonoid t) prj' idx' (spDense (acPrjTy prj' (mtUnMonoid t))) val2 (evar (w @> IS IZ)) EVar _ t i -> EVar ext (tUnMonoid t) (go i) where go :: Idx env t -> Idx (UnMonoidE env) (UnMonoid t) @@ -242,8 +254,8 @@ onehot typ topprj idx arg = case (typ, topprj) of zero t1 (evar IZ)) accumulateSparse - :: SMTy t -> Sparse t t' -> Ex env t' - -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil) + :: SMTy t -> Sparse t t' -> Ex env (UnMonoid t') + -> (forall p b env'. env :> env' -> SAcPrj p (UnMonoid t) b -> Ex env' (AcIdxD p (UnMonoid t)) -> Ex env' b -> Ex env' TNil) -> Ex env TNil accumulateSparse topty topsp arg accum = case (topty, topsp) of (_, s) | Just Refl <- isDense topty s -> @@ -320,3 +332,12 @@ acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k | Dict <- styKnown (typeOf idx1) = acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') + +unmAcPrj :: SAcPrj p t t' -> SAcPrj p (UnMonoid t) (UnMonoid t') +unmAcPrj SAPHere = SAPHere +unmAcPrj (SAPFst p) = SAPFst (unmAcPrj p) +unmAcPrj (SAPSnd p) = SAPSnd (unmAcPrj p) +unmAcPrj (SAPLeft p) = SAPLeft (unmAcPrj p) +unmAcPrj (SAPRight p) = SAPRight (unmAcPrj p) +unmAcPrj (SAPJust p) = SAPJust (unmAcPrj p) +unmAcPrj (SAPArrIdx p) = SAPArrIdx (unmAcPrj p) -- cgit v1.2.3-70-g09d2