From 62796be35e6e768147aab70ba0beeb94c058c714 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 8 Feb 2026 15:43:02 +0100 Subject: WIP (continue in UnMonoid) --- src/CHAD/AST/UnMonoid.hs | 192 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 130 insertions(+), 62 deletions(-) (limited to 'src/CHAD/AST/UnMonoid.hs') diff --git a/src/CHAD/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs index 06de00c..6578438 100644 --- a/src/CHAD/AST/UnMonoid.hs +++ b/src/CHAD/AST/UnMonoid.hs @@ -2,7 +2,9 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} module CHAD.AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where import CHAD.AST @@ -10,40 +12,106 @@ import CHAD.AST.Sparse.Types import CHAD.Data --- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by --- expanding them into their concrete implementations. Also ensure that --- 'EAccum' has a dense sparsity. -unMonoid :: Ex env t -> Ex env t +type family UnMonoid t where + UnMonoid (TIdxPair n t) = TPair (Tup (Replicate n TIx)) (UnMonoid t) + + UnMonoid TNil = TNil + UnMonoid (TPair a b) = TPair (UnMonoid a) (UnMonoid b) + UnMonoid (TEither a b) = TEither (UnMonoid a) (UnMonoid b) + UnMonoid (TLEither a b) = TLEither (UnMonoid a) (UnMonoid b) + UnMonoid (TMaybe a) = TMaybe (UnMonoid a) + UnMonoid (TArr n a) = TArr n (UnMonoid a) + UnMonoid (TScal t) = TScal t + UnMonoid (TAccum t) = TAccum (UnMonoid t) + +type family UnMonoidE env where + UnMonoidE '[] = '[] + UnMonoidE (t : ts) = UnMonoid t : UnMonoidE ts + +tUnMonoid :: STy t -> STy (UnMonoid t) +tUnMonoid (STIdxPair n t) = STPair (tTup (sreplicate n tIx)) (tUnMonoid t) +tUnMonoid STNil = STNil +tUnMonoid (STPair a b) = STPair (tUnMonoid a) (tUnMonoid b) +tUnMonoid (STEither a b) = STEither (tUnMonoid a) (tUnMonoid b) +tUnMonoid (STLEither a b) = STLEither (tUnMonoid a) (tUnMonoid b) +tUnMonoid (STMaybe a) = STMaybe (tUnMonoid a) +tUnMonoid (STArr n t) = STArr n (tUnMonoid t) +tUnMonoid (STAccum t) = STAccum (mtUnMonoid t) +tUnMonoid (STScal t) = STScal t + +mtUnMonoid :: SMTy t -> SMTy (UnMonoid t) +mtUnMonoid (SMTIdxPair n t) = SMTPair (mkTup SMTNil SMTPair (sreplicate n (SMTScal STI64))) (mtUnMonoid t) +mtUnMonoid SMTNil = SMTNil +mtUnMonoid (SMTPair a b) = SMTPair (mtUnMonoid a) (mtUnMonoid b) +mtUnMonoid (SMTLEither a b) = SMTLEither (mtUnMonoid a) (mtUnMonoid b) +mtUnMonoid (SMTMaybe a) = SMTMaybe (mtUnMonoid a) +mtUnMonoid (SMTArr n t) = SMTArr n (mtUnMonoid t) +mtUnMonoid (SMTScal t) = SMTScal t + +unMonoidIndexId :: SNat n -> UnMonoid (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx) +unMonoidIndexId SZ = Refl +unMonoidIndexId (SS n) | Refl <- unMonoidIndexId n = Refl + +unMonoidZeroInfoId :: SMTy t -> UnMonoid (ZeroInfo t) :~: ZeroInfo t +unMonoidZeroInfoId SMTNil = Refl +unMonoidZeroInfoId (SMTPair a b) | Refl <- unMonoidZeroInfoId a, Refl <- unMonoidZeroInfoId b = Refl +unMonoidZeroInfoId (SMTLEither a b) | Refl <- unMonoidZeroInfoId a, Refl <- unMonoidZeroInfoId b = Refl +unMonoidZeroInfoId (SMTMaybe a) | Refl <- unMonoidZeroInfoId a = Refl +unMonoidZeroInfoId (SMTArr _ a) | Refl <- unMonoidZeroInfoId a = Refl +unMonoidZeroInfoId (SMTScal _) = Refl +unMonoidZeroInfoId (SMTIdxPair n a) | Refl <- unMonoidIndexId n, Refl <- unMonoidZeroInfoId a = Refl + +unMonoidDeepZeroInfoId :: SMTy t -> UnMonoid (DeepZeroInfo t) :~: DeepZeroInfo t +unMonoidDeepZeroInfoId SMTNil = Refl +unMonoidDeepZeroInfoId (SMTPair a b) | Refl <- unMonoidDeepZeroInfoId a, Refl <- unMonoidDeepZeroInfoId b = Refl +unMonoidDeepZeroInfoId (SMTLEither a b) | Refl <- unMonoidDeepZeroInfoId a, Refl <- unMonoidDeepZeroInfoId b = Refl +unMonoidDeepZeroInfoId (SMTMaybe a) | Refl <- unMonoidDeepZeroInfoId a = Refl +unMonoidDeepZeroInfoId (SMTArr _ a) | Refl <- unMonoidDeepZeroInfoId a = Refl +unMonoidDeepZeroInfoId (SMTScal _) = Refl +unMonoidDeepZeroInfoId (SMTIdxPair n a) | Refl <- unMonoidIndexId n, Refl <- unMonoidDeepZeroInfoId a = Refl + + +-- | Removes monoidal stuff from the program. In particular: +-- +-- * Removes 'EZero', 'EDeepZero', 'EPlus', 'EOneHot' from the program by +-- expanding them into their concrete implementations. +-- * Removes 'EIdxPair' and 'EUnIdxPair', as well as all occurrences of the +-- 'TIdxPair' type in general. +-- * Ensures that 'EAccum' has a dense sparsity. +unMonoid :: Ex env t -> Ex (UnMonoidE env) (UnMonoid t) unMonoid = \case - EZero _ t e -> zero t e - EDeepZero _ t e -> deepZero t e - EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) + EZero _ t e | Refl <- unMonoidZeroInfoId t -> zero t (unMonoid e) + EDeepZero _ t e | Refl <- unMonoidDeepZeroInfoId t -> deepZero t (unMonoid e) + EPlus _ t a b -> plus (mtUnMonoid t) (unMonoid a) (unMonoid b) EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) + EIdxPair _ n a b | Refl <- unMonoidIndexId n -> EPair ext (unMonoid a) (unMonoid b) + EUnIdxPair _ e | STIdxPair n _ <- typeOf e, Refl <- unMonoidIndexId n -> unMonoid e EAccum _ t p eidx sp eval eacc -> - elet (unMonoid eacc) $ - elet (weakenExpr WSink (unMonoid eidx)) $ - accumulateSparse (acPrjTy p t) sp (weakenExpr (WSink .> WSink) (unMonoid eval)) $ \w prj2 idx2 val2 -> - acPrjCompose SAID p (evar (w @> IZ)) prj2 idx2 $ \prj' idx' -> - EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (evar (w @> IS IZ)) + accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 -> + acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> + EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc)) - EVar _ t i -> EVar ext t i + EVar _ t i -> EVar ext (tUnMonoid t) (go i) + where go :: Idx env t -> Idx (UnMonoidE env) (UnMonoid t) + go IZ = IZ + go (IS j) = IS (go j) ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) EPair _ a b -> EPair ext (unMonoid a) (unMonoid b) EFst _ e -> EFst ext (unMonoid e) ESnd _ e -> ESnd ext (unMonoid e) ENil _ -> ENil ext - EInl _ t e -> EInl ext t (unMonoid e) - EInr _ t e -> EInr ext t (unMonoid e) + EInl _ t e -> EInl ext (tUnMonoid t) (unMonoid e) + EInr _ t e -> EInr ext (tUnMonoid t) (unMonoid e) ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b) - ENothing _ t -> ENothing ext t + ENothing _ t -> ENothing ext (tUnMonoid t) EJust _ e -> EJust ext (unMonoid e) EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e) - ELNil _ t1 t2 -> ELNil ext t1 t2 - ELInl _ t e -> ELInl ext t (unMonoid e) - ELInr _ t e -> ELInr ext t (unMonoid e) + ELNil _ t1 t2 -> ELNil ext (tUnMonoid t1) (tUnMonoid t2) + ELInl _ t e -> ELInl ext (tUnMonoid t) (unMonoid e) + ELInr _ t e -> ELInr ext (tUnMonoid t) (unMonoid e) ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c) EConstArr _ n t x -> EConstArr ext n t x - EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) + EBuild _ n a b | Refl <- unMonoidIndexId n -> EBuild ext n (unMonoid a) (unMonoid b) EMap _ a b -> EMap ext (unMonoid a) (unMonoid b) EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c) ESum1Inner _ e -> ESum1Inner ext (unMonoid e) @@ -51,49 +119,52 @@ unMonoid = \case EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b) EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) - EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) + EReshape _ n a b | Refl <- unMonoidIndexId n -> EReshape ext n (unMonoid a) (unMonoid b) EZip _ a b -> EZip ext (unMonoid a) (unMonoid b) EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) EConst _ t x -> EConst ext t x EIdx0 _ e -> EIdx0 ext (unMonoid e) EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) - EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b) - EShape _ e -> EShape ext (unMonoid e) + EIdx _ a b | STArr n _ <- typeOf a, Refl <- unMonoidIndexId n -> EIdx ext (unMonoid a) (unMonoid b) + EShape _ e | STArr n _ <- typeOf e, Refl <- unMonoidIndexId n -> EShape ext (unMonoid e) EOp _ op e -> EOp ext op (unMonoid e) - ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) + ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext (tUnMonoid t1) (tUnMonoid t2) (tUnMonoid t3) (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) ERecompute _ e -> ERecompute ext (unMonoid e) EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) EError _ t s -> EError ext t s -zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t +zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env (UnMonoid t) -- don't destroy the effects! -zero SMTNil e = ELet ext e $ ENil ext +zero SMTNil e = use e $ ENil ext zero (SMTPair t1 t2) e = ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ))) (zero t2 (ESnd ext (EVar ext (typeOf e) IZ))) -zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2) -zero (SMTMaybe t) _ = ENothing ext (fromSMTy t) +zero (SMTLEither t1 t2) _ = ELNil ext (tUnMonoid (fromSMTy t1)) (tUnMonoid (fromSMTy t2)) +zero (SMTMaybe t) _ = ENothing ext (tUnMonoid (fromSMTy t)) zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e zero (SMTScal t) _ = case t of STI32 -> EConst ext STI32 0 STI64 -> EConst ext STI64 0 STF32 -> EConst ext STF32 0.0 STF64 -> EConst ext STF64 0.0 +zero (SMTIdxPair _ t) e = + eunPair e $ \_ e1 e2 -> + EPair ext e1 (zero t e2) -deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t +deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env (UnMonoid t) deepZero SMTNil e = elet e $ ENil ext deepZero (SMTPair t1 t2) e = ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ))) (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ))) deepZero (SMTLEither t1 t2) e = elcase e - (ELNil ext (fromSMTy t1) (fromSMTy t2)) - (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ))) - (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ))) + (ELNil ext (tUnMonoid (fromSMTy t1)) (tUnMonoid (fromSMTy t2))) + (ELInl ext (tUnMonoid (fromSMTy t2)) (deepZero t1 (evar IZ))) + (ELInr ext (tUnMonoid (fromSMTy t1)) (deepZero t2 (evar IZ))) deepZero (SMTMaybe t) e = emaybe e - (ENothing ext (fromSMTy t)) + (ENothing ext (tUnMonoid (fromSMTy t))) (EJust ext (deepZero t (evar IZ))) deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e deepZero (SMTScal t) _ = case t of @@ -101,6 +172,9 @@ deepZero (SMTScal t) _ = case t of STI64 -> EConst ext STI64 0 STF32 -> EConst ext STF32 0.0 STF64 -> EConst ext STF64 0.0 +deepZero (SMTIdxPair _ t) e = + eunPair e $ \_ e1 e2 -> + EPair ext e1 (deepZero t e2) plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t -- don't destroy the effects! @@ -134,44 +208,38 @@ plus (SMTArr _ t) a b = a b plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b) -onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t +onehot :: SMTy t -> SAcPrj p t a -> Ex env (UnMonoid (AcIdxS p t)) -> Ex env (UnMonoid a) -> Ex env (UnMonoid t) onehot typ topprj idx arg = case (typ, topprj) of (_, SAPHere) -> - ELet ext arg $ - EVar ext (fromSMTy typ) IZ - - (SMTPair t1 t2, SAPFst prj) -> - ELet ext idx $ - let tidx = typeOf idx in - ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ - let toh = fromSMTy t1 in - EPair ext (EVar ext toh IZ) - (zero t2 (ESnd ext (EVar ext tidx (IS IZ)))) - - (SMTPair t1 t2, SAPSnd prj) -> - ELet ext idx $ - let tidx = typeOf idx in - ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ - let toh = fromSMTy t2 in - EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ)))) - (EVar ext toh IZ) + arg + + (SMTPair t1 t2, SAPFst prj) | Refl <- unMonoidZeroInfoId t2 -> + elet idx $ + elet (onehot t1 prj (EFst ext (evar IZ)) (weakenExpr WSink arg)) $ + EPair ext (evar IZ) + (zero t2 (ESnd ext (evar (IS IZ)))) + + (SMTPair t1 t2, SAPSnd prj) | Refl <- unMonoidZeroInfoId t1 -> + elet idx $ + elet (onehot t2 prj (ESnd ext (evar IZ)) (weakenExpr WSink arg)) $ + EPair ext (zero t1 (EFst ext (evar (IS IZ)))) + (evar IZ) (SMTLEither t1 t2, SAPLeft prj) -> - ELInl ext (fromSMTy t2) (onehot t1 prj idx arg) + ELInl ext (tUnMonoid (fromSMTy t2)) (onehot t1 prj idx arg) (SMTLEither t1 t2, SAPRight prj) -> - ELInr ext (fromSMTy t1) (onehot t2 prj idx arg) + ELInr ext (tUnMonoid (fromSMTy t1)) (onehot t2 prj idx arg) (SMTMaybe t1, SAPJust prj) -> EJust ext (onehot t1 prj idx arg) - (SMTArr n t1, SAPArrIdx prj) -> - let tidx = tTup (sreplicate n tIx) - in ELet ext idx $ - EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $ - eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) - (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) - (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $ - zero t1 (EVar ext (tZeroInfo t1) IZ)) + (SMTArr n t1, SAPArrIdx prj) | Refl <- unMonoidIndexId n, Refl <- unMonoidZeroInfoId t1 -> + elet idx $ + EBuild ext n (EShape ext (ESnd ext (EFst ext (evar IZ)))) $ + eif (eidxEq n (evar IZ) (EFst ext (EFst ext (evar (IS IZ))))) + (onehot t1 prj (ESnd ext (evar (IS IZ))) (weakenExpr (WSink .> WSink) arg)) + (elet (EIdx ext (ESnd ext (EFst ext (evar (IS IZ)))) (evar IZ)) $ + zero t1 (evar IZ)) accumulateSparse :: SMTy t -> Sparse t t' -> Ex env t' -- cgit v1.2.3-70-g09d2