diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2026-02-08 15:43:02 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2026-02-12 20:44:47 +0100 |
| commit | 62796be35e6e768147aab70ba0beeb94c058c714 (patch) | |
| tree | dd43c8c2f37c59308b6b7d503fd25420621b0ab9 | |
| parent | c2831ef0f8be71f2a72ee4eee446e2ac473fb638 (diff) | |
WIP (continue in UnMonoid)
| -rw-r--r-- | src/CHAD/AST.hs | 6 | ||||
| -rw-r--r-- | src/CHAD/AST/Accum.hs | 4 | ||||
| -rw-r--r-- | src/CHAD/AST/Sparse/Types.hs | 18 | ||||
| -rw-r--r-- | src/CHAD/AST/Types.hs | 5 | ||||
| -rw-r--r-- | src/CHAD/AST/UnMonoid.hs | 190 | ||||
| -rw-r--r-- | src/CHAD/Drev/Types.hs | 3 |
6 files changed, 159 insertions, 67 deletions
diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs index 3f6dfc4..bb14218 100644 --- a/src/CHAD/AST.hs +++ b/src/CHAD/AST.hs @@ -129,6 +129,9 @@ data Expr x env t where EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t + EIdxPair :: x (TIdxPair n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t -> Expr x env (TIdxPair n t) + EUnIdxPair :: x (TPair (Tup (Replicate n TIx)) t) -> Expr x env (TIdxPair n t) -> Expr x env (TPair (Tup (Replicate n TIx)) t) + -- interface of abstract monoidal types ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b) @@ -601,6 +604,9 @@ makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy t go SMTMaybe{} _ = ENil ext go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e go SMTScal{} _ = ENil ext + go (SMTIdxPair _ t) e = + eunPair (EUnIdxPair ext e) $ \_ e1 e2 -> + EPair ext e1 (go t e2) splitSparsePair :: -- given a sparsity diff --git a/src/CHAD/AST/Accum.hs b/src/CHAD/AST/Accum.hs index ea74a95..b09e1bb 100644 --- a/src/CHAD/AST/Accum.hs +++ b/src/CHAD/AST/Accum.hs @@ -83,6 +83,7 @@ type family ZeroInfo t where ZeroInfo (TMaybe a) = TNil ZeroInfo (TArr n t) = TArr n (ZeroInfo t) ZeroInfo (TScal t) = TNil + ZeroInfo (TIdxPair n t) = TPair (Tup (Replicate n TIx)) (ZeroInfo t) tZeroInfo :: SMTy t -> STy (ZeroInfo t) tZeroInfo SMTNil = STNil @@ -91,6 +92,7 @@ tZeroInfo (SMTLEither _ _) = STNil tZeroInfo (SMTMaybe _) = STNil tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) tZeroInfo (SMTScal _) = STNil +tZeroInfo (SMTIdxPair n t) = STPair (tTup (sreplicate n tIx)) (tZeroInfo t) -- | Info needed to create a zero-valued deep accumulator for a monoid type. -- Should be constructable from a D1. @@ -101,6 +103,7 @@ type family DeepZeroInfo t where DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a) DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) DeepZeroInfo (TScal t) = TNil + DeepZeroInfo (TIdxPair n t) = TPair (Tup (Replicate n TIx)) (DeepZeroInfo t) tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t) tDeepZeroInfo SMTNil = STNil @@ -109,6 +112,7 @@ tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b) tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a) tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t) tDeepZeroInfo (SMTScal _) = STNil +tDeepZeroInfo (SMTIdxPair n t) = STPair (tTup (sreplicate n tIx)) (tDeepZeroInfo t) -- -- | Additional info needed for accumulation. This is empty unless there is -- -- sparsity in the monoid. diff --git a/src/CHAD/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs index 930475a..9a4cf99 100644 --- a/src/CHAD/AST/Sparse/Types.hs +++ b/src/CHAD/AST/Sparse/Types.hs @@ -25,19 +25,20 @@ data Sparse t t' where SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') SpArrIdx :: SList (Sparse t) t's -> Sparse (TArr n t) (MultiHot n t's) SpScal :: Sparse (TScal t) (TScal t) + SpIdxPair :: Sparse t t' -> Sparse (TIdxPair n t) (TIdxPair n t') deriving instance Show (Sparse t t') type family MultiHot n t's where MultiHot n '[] = TNil - MultiHot n (t' : t's) = TPair (TPair (Tup (Replicate n TIx)) t') (MultiHot n t's) + MultiHot n (t' : t's) = TPair (TIdxPair n t') (MultiHot n t's) tMultiHot :: SNat n -> SList STy ts -> STy (MultiHot n ts) tMultiHot _ SNil = STNil -tMultiHot n (t `SCons` ts) = STPair (STPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts) +tMultiHot n (t `SCons` ts) = STPair (STIdxPair n t) (tMultiHot n ts) mtMultiHot :: SNat n -> SList SMTy ts -> SMTy (MultiHot n ts) mtMultiHot _ SNil = SMTNil -mtMultiHot n (t `SCons` ts) = SMTPair (SMTPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts) +mtMultiHot n (t `SCons` ts) = SMTPair (SMTIdxPair n t) (mtMultiHot n ts) class ApplySparse f where applySparse :: Sparse t t' -> f t -> f t' @@ -51,6 +52,7 @@ instance ApplySparse STy where applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) applySparse (SpArrIdx ss) (STArr n t) = tMultiHot n (slistMap (`applySparse` t) ss) applySparse SpScal t = t + applySparse (SpIdxPair s) (STIdxPair n t) = STIdxPair n (applySparse s t) instance ApplySparse SMTy where applySparse (SpSparse s) t = SMTMaybe (applySparse s t) @@ -59,8 +61,9 @@ instance ApplySparse SMTy where applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2) applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) - applySparse (SpArrIdx s) (SMTArr n t) = SMTPair (mkTup SMTNil SMTPair (sreplicate n (knownMTy @TIx))) (applySparse s t) + applySparse (SpArrIdx l) (SMTArr n t) = mtMultiHot n (slistMap (`applySparse` t) l) applySparse SpScal t = t + applySparse (SpIdxPair s) (SMTIdxPair n t) = SMTIdxPair n (applySparse s t) class IsSubType s where @@ -85,6 +88,7 @@ spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2) spDense (SMTMaybe t) = SpMaybe (spDense t) spDense (SMTArr _ t) = SpArr (spDense t) spDense (SMTScal _) = SpScal +spDense (SMTIdxPair _ t) = SpIdxPair (spDense t) isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t') isDense SMTNil SpAbsent = Just Refl @@ -104,6 +108,9 @@ isDense (SMTArr _ t) (SpArr s) | otherwise = Nothing isDense SMTArr{} SpArrIdx{} = Nothing isDense (SMTScal _) SpScal = Just Refl +isDense (SMTIdxPair _ t) (SpIdxPair s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing isAbsent :: Sparse t t' -> Bool isAbsent (SpSparse s) = isAbsent s @@ -112,5 +119,6 @@ isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2 isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 isAbsent (SpMaybe s) = isAbsent s isAbsent (SpArr s) = isAbsent s -isAbsent (SpArrIdx s) = isAbsent s +isAbsent (SpArrIdx l) = and (unSList isAbsent l) isAbsent SpScal = False +isAbsent (SpIdxPair s) = isAbsent s diff --git a/src/CHAD/AST/Types.hs b/src/CHAD/AST/Types.hs index bec2201..1e3d36c 100644 --- a/src/CHAD/AST/Types.hs +++ b/src/CHAD/AST/Types.hs @@ -28,6 +28,7 @@ type data Ty | TArr Nat Ty -- ^ rank, element type | TScal ScalTy | TAccum Ty -- ^ contained type must be a monoid type + | TIdxPair Nat Ty -- ^ an array one-hot (eliminated by UnMonoid) type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool @@ -43,6 +44,7 @@ data STy t where STArr :: SNat n -> STy t -> STy (TArr n t) STScal :: SScalTy t -> STy (TScal t) STAccum :: SMTy t -> STy (TAccum t) + STIdxPair :: SNat n -> STy t -> STy (TIdxPair n t) deriving instance Show (STy t) instance GCompare STy where @@ -77,7 +79,7 @@ data SMTy t where SMTMaybe :: SMTy a -> SMTy (TMaybe a) SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t) SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t) - SMTData :: STy a -> SMTy a -- ^ inclusion of non-monoidal information + SMTIdxPair :: SNat n -> SMTy t -> SMTy (TIdxPair n t) deriving instance Show (SMTy t) instance GCompare SMTy where @@ -107,6 +109,7 @@ fromSMTy = \case SMTMaybe t -> STMaybe (fromSMTy t) SMTArr n t -> STArr n (fromSMTy t) SMTScal sty -> STScal sty + SMTIdxPair n t -> STIdxPair n (fromSMTy t) data SScalTy t where STI32 :: SScalTy TI32 diff --git a/src/CHAD/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs index 06de00c..6578438 100644 --- a/src/CHAD/AST/UnMonoid.hs +++ b/src/CHAD/AST/UnMonoid.hs @@ -2,7 +2,9 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} module CHAD.AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where import CHAD.AST @@ -10,40 +12,106 @@ import CHAD.AST.Sparse.Types import CHAD.Data --- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by --- expanding them into their concrete implementations. Also ensure that --- 'EAccum' has a dense sparsity. -unMonoid :: Ex env t -> Ex env t +type family UnMonoid t where + UnMonoid (TIdxPair n t) = TPair (Tup (Replicate n TIx)) (UnMonoid t) + + UnMonoid TNil = TNil + UnMonoid (TPair a b) = TPair (UnMonoid a) (UnMonoid b) + UnMonoid (TEither a b) = TEither (UnMonoid a) (UnMonoid b) + UnMonoid (TLEither a b) = TLEither (UnMonoid a) (UnMonoid b) + UnMonoid (TMaybe a) = TMaybe (UnMonoid a) + UnMonoid (TArr n a) = TArr n (UnMonoid a) + UnMonoid (TScal t) = TScal t + UnMonoid (TAccum t) = TAccum (UnMonoid t) + +type family UnMonoidE env where + UnMonoidE '[] = '[] + UnMonoidE (t : ts) = UnMonoid t : UnMonoidE ts + +tUnMonoid :: STy t -> STy (UnMonoid t) +tUnMonoid (STIdxPair n t) = STPair (tTup (sreplicate n tIx)) (tUnMonoid t) +tUnMonoid STNil = STNil +tUnMonoid (STPair a b) = STPair (tUnMonoid a) (tUnMonoid b) +tUnMonoid (STEither a b) = STEither (tUnMonoid a) (tUnMonoid b) +tUnMonoid (STLEither a b) = STLEither (tUnMonoid a) (tUnMonoid b) +tUnMonoid (STMaybe a) = STMaybe (tUnMonoid a) +tUnMonoid (STArr n t) = STArr n (tUnMonoid t) +tUnMonoid (STAccum t) = STAccum (mtUnMonoid t) +tUnMonoid (STScal t) = STScal t + +mtUnMonoid :: SMTy t -> SMTy (UnMonoid t) +mtUnMonoid (SMTIdxPair n t) = SMTPair (mkTup SMTNil SMTPair (sreplicate n (SMTScal STI64))) (mtUnMonoid t) +mtUnMonoid SMTNil = SMTNil +mtUnMonoid (SMTPair a b) = SMTPair (mtUnMonoid a) (mtUnMonoid b) +mtUnMonoid (SMTLEither a b) = SMTLEither (mtUnMonoid a) (mtUnMonoid b) +mtUnMonoid (SMTMaybe a) = SMTMaybe (mtUnMonoid a) +mtUnMonoid (SMTArr n t) = SMTArr n (mtUnMonoid t) +mtUnMonoid (SMTScal t) = SMTScal t + +unMonoidIndexId :: SNat n -> UnMonoid (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx) +unMonoidIndexId SZ = Refl +unMonoidIndexId (SS n) | Refl <- unMonoidIndexId n = Refl + +unMonoidZeroInfoId :: SMTy t -> UnMonoid (ZeroInfo t) :~: ZeroInfo t +unMonoidZeroInfoId SMTNil = Refl +unMonoidZeroInfoId (SMTPair a b) | Refl <- unMonoidZeroInfoId a, Refl <- unMonoidZeroInfoId b = Refl +unMonoidZeroInfoId (SMTLEither a b) | Refl <- unMonoidZeroInfoId a, Refl <- unMonoidZeroInfoId b = Refl +unMonoidZeroInfoId (SMTMaybe a) | Refl <- unMonoidZeroInfoId a = Refl +unMonoidZeroInfoId (SMTArr _ a) | Refl <- unMonoidZeroInfoId a = Refl +unMonoidZeroInfoId (SMTScal _) = Refl +unMonoidZeroInfoId (SMTIdxPair n a) | Refl <- unMonoidIndexId n, Refl <- unMonoidZeroInfoId a = Refl + +unMonoidDeepZeroInfoId :: SMTy t -> UnMonoid (DeepZeroInfo t) :~: DeepZeroInfo t +unMonoidDeepZeroInfoId SMTNil = Refl +unMonoidDeepZeroInfoId (SMTPair a b) | Refl <- unMonoidDeepZeroInfoId a, Refl <- unMonoidDeepZeroInfoId b = Refl +unMonoidDeepZeroInfoId (SMTLEither a b) | Refl <- unMonoidDeepZeroInfoId a, Refl <- unMonoidDeepZeroInfoId b = Refl +unMonoidDeepZeroInfoId (SMTMaybe a) | Refl <- unMonoidDeepZeroInfoId a = Refl +unMonoidDeepZeroInfoId (SMTArr _ a) | Refl <- unMonoidDeepZeroInfoId a = Refl +unMonoidDeepZeroInfoId (SMTScal _) = Refl +unMonoidDeepZeroInfoId (SMTIdxPair n a) | Refl <- unMonoidIndexId n, Refl <- unMonoidDeepZeroInfoId a = Refl + + +-- | Removes monoidal stuff from the program. In particular: +-- +-- * Removes 'EZero', 'EDeepZero', 'EPlus', 'EOneHot' from the program by +-- expanding them into their concrete implementations. +-- * Removes 'EIdxPair' and 'EUnIdxPair', as well as all occurrences of the +-- 'TIdxPair' type in general. +-- * Ensures that 'EAccum' has a dense sparsity. +unMonoid :: Ex env t -> Ex (UnMonoidE env) (UnMonoid t) unMonoid = \case - EZero _ t e -> zero t e - EDeepZero _ t e -> deepZero t e - EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) + EZero _ t e | Refl <- unMonoidZeroInfoId t -> zero t (unMonoid e) + EDeepZero _ t e | Refl <- unMonoidDeepZeroInfoId t -> deepZero t (unMonoid e) + EPlus _ t a b -> plus (mtUnMonoid t) (unMonoid a) (unMonoid b) EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) + EIdxPair _ n a b | Refl <- unMonoidIndexId n -> EPair ext (unMonoid a) (unMonoid b) + EUnIdxPair _ e | STIdxPair n _ <- typeOf e, Refl <- unMonoidIndexId n -> unMonoid e EAccum _ t p eidx sp eval eacc -> - elet (unMonoid eacc) $ - elet (weakenExpr WSink (unMonoid eidx)) $ - accumulateSparse (acPrjTy p t) sp (weakenExpr (WSink .> WSink) (unMonoid eval)) $ \w prj2 idx2 val2 -> - acPrjCompose SAID p (evar (w @> IZ)) prj2 idx2 $ \prj' idx' -> - EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (evar (w @> IS IZ)) + accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 -> + acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> + EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc)) - EVar _ t i -> EVar ext t i + EVar _ t i -> EVar ext (tUnMonoid t) (go i) + where go :: Idx env t -> Idx (UnMonoidE env) (UnMonoid t) + go IZ = IZ + go (IS j) = IS (go j) ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) EPair _ a b -> EPair ext (unMonoid a) (unMonoid b) EFst _ e -> EFst ext (unMonoid e) ESnd _ e -> ESnd ext (unMonoid e) ENil _ -> ENil ext - EInl _ t e -> EInl ext t (unMonoid e) - EInr _ t e -> EInr ext t (unMonoid e) + EInl _ t e -> EInl ext (tUnMonoid t) (unMonoid e) + EInr _ t e -> EInr ext (tUnMonoid t) (unMonoid e) ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b) - ENothing _ t -> ENothing ext t + ENothing _ t -> ENothing ext (tUnMonoid t) EJust _ e -> EJust ext (unMonoid e) EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e) - ELNil _ t1 t2 -> ELNil ext t1 t2 - ELInl _ t e -> ELInl ext t (unMonoid e) - ELInr _ t e -> ELInr ext t (unMonoid e) + ELNil _ t1 t2 -> ELNil ext (tUnMonoid t1) (tUnMonoid t2) + ELInl _ t e -> ELInl ext (tUnMonoid t) (unMonoid e) + ELInr _ t e -> ELInr ext (tUnMonoid t) (unMonoid e) ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c) EConstArr _ n t x -> EConstArr ext n t x - EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) + EBuild _ n a b | Refl <- unMonoidIndexId n -> EBuild ext n (unMonoid a) (unMonoid b) EMap _ a b -> EMap ext (unMonoid a) (unMonoid b) EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c) ESum1Inner _ e -> ESum1Inner ext (unMonoid e) @@ -51,49 +119,52 @@ unMonoid = \case EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b) EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) - EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) + EReshape _ n a b | Refl <- unMonoidIndexId n -> EReshape ext n (unMonoid a) (unMonoid b) EZip _ a b -> EZip ext (unMonoid a) (unMonoid b) EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) EConst _ t x -> EConst ext t x EIdx0 _ e -> EIdx0 ext (unMonoid e) EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) - EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b) - EShape _ e -> EShape ext (unMonoid e) + EIdx _ a b | STArr n _ <- typeOf a, Refl <- unMonoidIndexId n -> EIdx ext (unMonoid a) (unMonoid b) + EShape _ e | STArr n _ <- typeOf e, Refl <- unMonoidIndexId n -> EShape ext (unMonoid e) EOp _ op e -> EOp ext op (unMonoid e) - ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) + ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext (tUnMonoid t1) (tUnMonoid t2) (tUnMonoid t3) (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) ERecompute _ e -> ERecompute ext (unMonoid e) EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) EError _ t s -> EError ext t s -zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t +zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env (UnMonoid t) -- don't destroy the effects! -zero SMTNil e = ELet ext e $ ENil ext +zero SMTNil e = use e $ ENil ext zero (SMTPair t1 t2) e = ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ))) (zero t2 (ESnd ext (EVar ext (typeOf e) IZ))) -zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2) -zero (SMTMaybe t) _ = ENothing ext (fromSMTy t) +zero (SMTLEither t1 t2) _ = ELNil ext (tUnMonoid (fromSMTy t1)) (tUnMonoid (fromSMTy t2)) +zero (SMTMaybe t) _ = ENothing ext (tUnMonoid (fromSMTy t)) zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e zero (SMTScal t) _ = case t of STI32 -> EConst ext STI32 0 STI64 -> EConst ext STI64 0 STF32 -> EConst ext STF32 0.0 STF64 -> EConst ext STF64 0.0 +zero (SMTIdxPair _ t) e = + eunPair e $ \_ e1 e2 -> + EPair ext e1 (zero t e2) -deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t +deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env (UnMonoid t) deepZero SMTNil e = elet e $ ENil ext deepZero (SMTPair t1 t2) e = ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ))) (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ))) deepZero (SMTLEither t1 t2) e = elcase e - (ELNil ext (fromSMTy t1) (fromSMTy t2)) - (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ))) - (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ))) + (ELNil ext (tUnMonoid (fromSMTy t1)) (tUnMonoid (fromSMTy t2))) + (ELInl ext (tUnMonoid (fromSMTy t2)) (deepZero t1 (evar IZ))) + (ELInr ext (tUnMonoid (fromSMTy t1)) (deepZero t2 (evar IZ))) deepZero (SMTMaybe t) e = emaybe e - (ENothing ext (fromSMTy t)) + (ENothing ext (tUnMonoid (fromSMTy t))) (EJust ext (deepZero t (evar IZ))) deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e deepZero (SMTScal t) _ = case t of @@ -101,6 +172,9 @@ deepZero (SMTScal t) _ = case t of STI64 -> EConst ext STI64 0 STF32 -> EConst ext STF32 0.0 STF64 -> EConst ext STF64 0.0 +deepZero (SMTIdxPair _ t) e = + eunPair e $ \_ e1 e2 -> + EPair ext e1 (deepZero t e2) plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t -- don't destroy the effects! @@ -134,44 +208,38 @@ plus (SMTArr _ t) a b = a b plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b) -onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t +onehot :: SMTy t -> SAcPrj p t a -> Ex env (UnMonoid (AcIdxS p t)) -> Ex env (UnMonoid a) -> Ex env (UnMonoid t) onehot typ topprj idx arg = case (typ, topprj) of (_, SAPHere) -> - ELet ext arg $ - EVar ext (fromSMTy typ) IZ + arg + + (SMTPair t1 t2, SAPFst prj) | Refl <- unMonoidZeroInfoId t2 -> + elet idx $ + elet (onehot t1 prj (EFst ext (evar IZ)) (weakenExpr WSink arg)) $ + EPair ext (evar IZ) + (zero t2 (ESnd ext (evar (IS IZ)))) - (SMTPair t1 t2, SAPFst prj) -> - ELet ext idx $ - let tidx = typeOf idx in - ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ - let toh = fromSMTy t1 in - EPair ext (EVar ext toh IZ) - (zero t2 (ESnd ext (EVar ext tidx (IS IZ)))) - - (SMTPair t1 t2, SAPSnd prj) -> - ELet ext idx $ - let tidx = typeOf idx in - ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ - let toh = fromSMTy t2 in - EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ)))) - (EVar ext toh IZ) + (SMTPair t1 t2, SAPSnd prj) | Refl <- unMonoidZeroInfoId t1 -> + elet idx $ + elet (onehot t2 prj (ESnd ext (evar IZ)) (weakenExpr WSink arg)) $ + EPair ext (zero t1 (EFst ext (evar (IS IZ)))) + (evar IZ) (SMTLEither t1 t2, SAPLeft prj) -> - ELInl ext (fromSMTy t2) (onehot t1 prj idx arg) + ELInl ext (tUnMonoid (fromSMTy t2)) (onehot t1 prj idx arg) (SMTLEither t1 t2, SAPRight prj) -> - ELInr ext (fromSMTy t1) (onehot t2 prj idx arg) + ELInr ext (tUnMonoid (fromSMTy t1)) (onehot t2 prj idx arg) (SMTMaybe t1, SAPJust prj) -> EJust ext (onehot t1 prj idx arg) - (SMTArr n t1, SAPArrIdx prj) -> - let tidx = tTup (sreplicate n tIx) - in ELet ext idx $ - EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $ - eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) - (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) - (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $ - zero t1 (EVar ext (tZeroInfo t1) IZ)) + (SMTArr n t1, SAPArrIdx prj) | Refl <- unMonoidIndexId n, Refl <- unMonoidZeroInfoId t1 -> + elet idx $ + EBuild ext n (EShape ext (ESnd ext (EFst ext (evar IZ)))) $ + eif (eidxEq n (evar IZ) (EFst ext (EFst ext (evar (IS IZ))))) + (onehot t1 prj (ESnd ext (evar (IS IZ))) (weakenExpr (WSink .> WSink) arg)) + (elet (EIdx ext (ESnd ext (EFst ext (evar (IS IZ)))) (evar IZ)) $ + zero t1 (evar IZ)) accumulateSparse :: SMTy t -> Sparse t t' -> Ex env t' diff --git a/src/CHAD/Drev/Types.hs b/src/CHAD/Drev/Types.hs index 367a974..f23aeba 100644 --- a/src/CHAD/Drev/Types.hs +++ b/src/CHAD/Drev/Types.hs @@ -55,6 +55,7 @@ d1 (STMaybe t) = STMaybe (d1 t) d1 (STArr n t) = STArr n (d1 t) d1 (STScal t) = STScal t d1 STAccum{} = error "Accumulators not allowed in input program" +d1 STIdxPair{} = error "Index pairs not allowed in input program" d1e :: SList STy env -> SList STy (D1E env) d1e SNil = SNil @@ -74,6 +75,7 @@ d2M (STScal t) = case t of STF64 -> SMTScal STF64 STBool -> SMTNil d2M STAccum{} = error "Accumulators not allowed in input program" +d2M STIdxPair{} = error "Index pairs not allowed in input program" d2 :: STy t -> STy (D2 t) d2 = fromSMTy . d2M @@ -147,6 +149,7 @@ d1Identity = \case STArr _ t | Refl <- d1Identity t -> Refl STScal _ -> Refl STAccum{} -> error "Accumulators not allowed in input program" + STIdxPair{} -> error "Index pairs not allowed in input program" d1eIdentity :: SList STy env -> D1E env :~: env d1eIdentity SNil = Refl |
