diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:49:45 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:50:25 +0100 |
| commit | 174af2ba568de66e0d890825b8bda930b8e7bb96 (patch) | |
| tree | 5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/AST/UnMonoid.hs | |
| parent | 92bca235e3aaa287286b6af082d3fce585825a35 (diff) | |
Move module hierarchy under CHAD.
Diffstat (limited to 'src/AST/UnMonoid.hs')
| -rw-r--r-- | src/AST/UnMonoid.hs | 255 |
1 files changed, 0 insertions, 255 deletions
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs deleted file mode 100644 index 1712ba5..0000000 --- a/src/AST/UnMonoid.hs +++ /dev/null @@ -1,255 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE TypeOperators #-} -module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where - -import AST -import AST.Sparse.Types -import Data - - --- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by --- expanding them into their concrete implementations. Also ensure that --- 'EAccum' has a dense sparsity. -unMonoid :: Ex env t -> Ex env t -unMonoid = \case - EZero _ t e -> zero t e - EDeepZero _ t e -> deepZero t e - EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) - EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) - - EVar _ t i -> EVar ext t i - ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) - EPair _ a b -> EPair ext (unMonoid a) (unMonoid b) - EFst _ e -> EFst ext (unMonoid e) - ESnd _ e -> ESnd ext (unMonoid e) - ENil _ -> ENil ext - EInl _ t e -> EInl ext t (unMonoid e) - EInr _ t e -> EInr ext t (unMonoid e) - ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b) - ENothing _ t -> ENothing ext t - EJust _ e -> EJust ext (unMonoid e) - EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e) - ELNil _ t1 t2 -> ELNil ext t1 t2 - ELInl _ t e -> ELInl ext t (unMonoid e) - ELInr _ t e -> ELInr ext t (unMonoid e) - ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c) - EConstArr _ n t x -> EConstArr ext n t x - EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) - EMap _ a b -> EMap ext (unMonoid a) (unMonoid b) - EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c) - ESum1Inner _ e -> ESum1Inner ext (unMonoid e) - EUnit _ e -> EUnit ext (unMonoid e) - EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b) - EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) - EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) - EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) - EZip _ a b -> EZip ext (unMonoid a) (unMonoid b) - EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) - EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) - EConst _ t x -> EConst ext t x - EIdx0 _ e -> EIdx0 ext (unMonoid e) - EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) - EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b) - EShape _ e -> EShape ext (unMonoid e) - EOp _ op e -> EOp ext op (unMonoid e) - ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) - ERecompute _ e -> ERecompute ext (unMonoid e) - EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) - EAccum _ t p eidx sp eval eacc -> - accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 -> - acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> - EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc)) - EError _ t s -> EError ext t s - -zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t --- don't destroy the effects! -zero SMTNil e = ELet ext e $ ENil ext -zero (SMTPair t1 t2) e = - ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ))) - (zero t2 (ESnd ext (EVar ext (typeOf e) IZ))) -zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2) -zero (SMTMaybe t) _ = ENothing ext (fromSMTy t) -zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e -zero (SMTScal t) _ = case t of - STI32 -> EConst ext STI32 0 - STI64 -> EConst ext STI64 0 - STF32 -> EConst ext STF32 0.0 - STF64 -> EConst ext STF64 0.0 - -deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t -deepZero SMTNil e = elet e $ ENil ext -deepZero (SMTPair t1 t2) e = - ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ))) - (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ))) -deepZero (SMTLEither t1 t2) e = - elcase e - (ELNil ext (fromSMTy t1) (fromSMTy t2)) - (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ))) - (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ))) -deepZero (SMTMaybe t) e = - emaybe e - (ENothing ext (fromSMTy t)) - (EJust ext (deepZero t (evar IZ))) -deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e -deepZero (SMTScal t) _ = case t of - STI32 -> EConst ext STI32 0 - STI64 -> EConst ext STI64 0 - STF32 -> EConst ext STF32 0.0 - STF64 -> EConst ext STF64 0.0 - -plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t --- don't destroy the effects! -plus SMTNil a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ ENil ext -plus (SMTPair t1 t2) a b = - let t = STPair (fromSMTy t1) (fromSMTy t2) - in ELet ext a $ - ELet ext (weakenExpr WSink b) $ - EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) - (EFst ext (EVar ext t IZ))) - (plus t2 (ESnd ext (EVar ext t (IS IZ))) - (ESnd ext (EVar ext t IZ))) -plus (SMTLEither t1 t2) a b = - let t = STLEither (fromSMTy t1) (fromSMTy t2) - in ELet ext a $ - ELet ext (weakenExpr WSink b) $ - ELCase ext (EVar ext t (IS IZ)) - (EVar ext t IZ) - (ELCase ext (EVar ext t (IS IZ)) - (EVar ext t (IS (IS IZ))) - (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ))) - (EError ext t "plus l+r")) - (ELCase ext (EVar ext t (IS IZ)) - (EVar ext t (IS (IS IZ))) - (EError ext t "plus r+l") - (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ)))) -plus (SMTMaybe t) a b = - ELet ext b $ - EMaybe ext - (EVar ext (STMaybe (fromSMTy t)) IZ) - (EJust ext - (EMaybe ext - (EVar ext (fromSMTy t) IZ) - (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) - (EVar ext (STMaybe (fromSMTy t)) (IS IZ)))) - (weakenExpr WSink a) -plus (SMTArr _ t) a b = - ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) - a b -plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b) - -onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t -onehot typ topprj idx arg = case (typ, topprj) of - (_, SAPHere) -> - ELet ext arg $ - EVar ext (fromSMTy typ) IZ - - (SMTPair t1 t2, SAPFst prj) -> - ELet ext idx $ - let tidx = typeOf idx in - ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ - let toh = fromSMTy t1 in - EPair ext (EVar ext toh IZ) - (zero t2 (ESnd ext (EVar ext tidx (IS IZ)))) - - (SMTPair t1 t2, SAPSnd prj) -> - ELet ext idx $ - let tidx = typeOf idx in - ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ - let toh = fromSMTy t2 in - EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ)))) - (EVar ext toh IZ) - - (SMTLEither t1 t2, SAPLeft prj) -> - ELInl ext (fromSMTy t2) (onehot t1 prj idx arg) - (SMTLEither t1 t2, SAPRight prj) -> - ELInr ext (fromSMTy t1) (onehot t2 prj idx arg) - - (SMTMaybe t1, SAPJust prj) -> - EJust ext (onehot t1 prj idx arg) - - (SMTArr n t1, SAPArrIdx prj) -> - let tidx = tTup (sreplicate n tIx) - in ELet ext idx $ - EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $ - eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) - (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) - (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $ - zero t1 (EVar ext (tZeroInfo t1) IZ)) - -accumulateSparse - :: SMTy t -> Sparse t t' -> Ex env t' - -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil) - -> Ex env TNil -accumulateSparse topty topsp arg accum = case (topty, topsp) of - (_, s) | Just Refl <- isDense topty s -> - accum WId SAPHere (ENil ext) arg - (SMTScal _, SpScal) -> - accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh - (_, SpSparse s) -> - emaybe arg - (ENil ext) - (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w))) - (_, SpAbsent) -> - ENil ext - (SMTPair t1 t2, SpPair s1 s2) -> - eunPair arg $ \w1 e1 e2 -> - elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $ - accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj)) - (SMTLEither t1 t2, SpLEither s1 s2) -> - elcase arg - (ENil ext) - (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj))) - (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj))) - (SMTMaybe t, SpMaybe s) -> - emaybe arg - (ENil ext) - (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj))) - (SMTArr n t, SpArr s) -> - let tn = tTup (sreplicate n tIx) in - elet arg $ - elet (EBuild ext n (EShape ext (evar IZ)) $ - accumulateSparse t s - (EIdx ext (evar (IS IZ)) (EVar ext tn IZ)) - (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $ - ENil ext - -acPrjCompose - :: SAIDense dense - -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a) - -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b) - -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r -acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2 -acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k = - acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> - k (SAPFst p') idx' -acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k = - acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> - k (SAPSnd p') idx' -acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k - | Dict <- styKnown (typeOf idx1) = - acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> - k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ))) -acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k - | Dict <- styKnown (typeOf idx1) = - acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> - k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx') -acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k = - acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> - k (SAPLeft p') idx' -acPrjCompose d (SAPRight p1) idx1 p2 idx2 k = - acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> - k (SAPRight p') idx' -acPrjCompose d (SAPJust p1) idx1 p2 idx2 k = - acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> - k (SAPJust p') idx' -acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k - | Dict <- styKnown (typeOf idx1) = - acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> - k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') -acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k - | Dict <- styKnown (typeOf idx1) = - acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> - k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') |
