aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD')
-rw-r--r--src/CHAD/AST.hs6
-rw-r--r--src/CHAD/AST/Accum.hs4
-rw-r--r--src/CHAD/AST/Sparse/Types.hs18
-rw-r--r--src/CHAD/AST/Types.hs5
-rw-r--r--src/CHAD/AST/UnMonoid.hs190
-rw-r--r--src/CHAD/Drev/Types.hs3
6 files changed, 159 insertions, 67 deletions
diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs
index 3f6dfc4..bb14218 100644
--- a/src/CHAD/AST.hs
+++ b/src/CHAD/AST.hs
@@ -129,6 +129,9 @@ data Expr x env t where
EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t
EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t
+ EIdxPair :: x (TIdxPair n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t -> Expr x env (TIdxPair n t)
+ EUnIdxPair :: x (TPair (Tup (Replicate n TIx)) t) -> Expr x env (TIdxPair n t) -> Expr x env (TPair (Tup (Replicate n TIx)) t)
+
-- interface of abstract monoidal types
ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b)
ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b)
@@ -601,6 +604,9 @@ makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy t
go SMTMaybe{} _ = ENil ext
go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e
go SMTScal{} _ = ENil ext
+ go (SMTIdxPair _ t) e =
+ eunPair (EUnIdxPair ext e) $ \_ e1 e2 ->
+ EPair ext e1 (go t e2)
splitSparsePair
:: -- given a sparsity
diff --git a/src/CHAD/AST/Accum.hs b/src/CHAD/AST/Accum.hs
index ea74a95..b09e1bb 100644
--- a/src/CHAD/AST/Accum.hs
+++ b/src/CHAD/AST/Accum.hs
@@ -83,6 +83,7 @@ type family ZeroInfo t where
ZeroInfo (TMaybe a) = TNil
ZeroInfo (TArr n t) = TArr n (ZeroInfo t)
ZeroInfo (TScal t) = TNil
+ ZeroInfo (TIdxPair n t) = TPair (Tup (Replicate n TIx)) (ZeroInfo t)
tZeroInfo :: SMTy t -> STy (ZeroInfo t)
tZeroInfo SMTNil = STNil
@@ -91,6 +92,7 @@ tZeroInfo (SMTLEither _ _) = STNil
tZeroInfo (SMTMaybe _) = STNil
tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)
tZeroInfo (SMTScal _) = STNil
+tZeroInfo (SMTIdxPair n t) = STPair (tTup (sreplicate n tIx)) (tZeroInfo t)
-- | Info needed to create a zero-valued deep accumulator for a monoid type.
-- Should be constructable from a D1.
@@ -101,6 +103,7 @@ type family DeepZeroInfo t where
DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a)
DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a)
DeepZeroInfo (TScal t) = TNil
+ DeepZeroInfo (TIdxPair n t) = TPair (Tup (Replicate n TIx)) (DeepZeroInfo t)
tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t)
tDeepZeroInfo SMTNil = STNil
@@ -109,6 +112,7 @@ tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b)
tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a)
tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t)
tDeepZeroInfo (SMTScal _) = STNil
+tDeepZeroInfo (SMTIdxPair n t) = STPair (tTup (sreplicate n tIx)) (tDeepZeroInfo t)
-- -- | Additional info needed for accumulation. This is empty unless there is
-- -- sparsity in the monoid.
diff --git a/src/CHAD/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs
index 930475a..9a4cf99 100644
--- a/src/CHAD/AST/Sparse/Types.hs
+++ b/src/CHAD/AST/Sparse/Types.hs
@@ -25,19 +25,20 @@ data Sparse t t' where
SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t')
SpArrIdx :: SList (Sparse t) t's -> Sparse (TArr n t) (MultiHot n t's)
SpScal :: Sparse (TScal t) (TScal t)
+ SpIdxPair :: Sparse t t' -> Sparse (TIdxPair n t) (TIdxPair n t')
deriving instance Show (Sparse t t')
type family MultiHot n t's where
MultiHot n '[] = TNil
- MultiHot n (t' : t's) = TPair (TPair (Tup (Replicate n TIx)) t') (MultiHot n t's)
+ MultiHot n (t' : t's) = TPair (TIdxPair n t') (MultiHot n t's)
tMultiHot :: SNat n -> SList STy ts -> STy (MultiHot n ts)
tMultiHot _ SNil = STNil
-tMultiHot n (t `SCons` ts) = STPair (STPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts)
+tMultiHot n (t `SCons` ts) = STPair (STIdxPair n t) (tMultiHot n ts)
mtMultiHot :: SNat n -> SList SMTy ts -> SMTy (MultiHot n ts)
mtMultiHot _ SNil = SMTNil
-mtMultiHot n (t `SCons` ts) = SMTPair (SMTPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts)
+mtMultiHot n (t `SCons` ts) = SMTPair (SMTIdxPair n t) (mtMultiHot n ts)
class ApplySparse f where
applySparse :: Sparse t t' -> f t -> f t'
@@ -51,6 +52,7 @@ instance ApplySparse STy where
applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t)
applySparse (SpArrIdx ss) (STArr n t) = tMultiHot n (slistMap (`applySparse` t) ss)
applySparse SpScal t = t
+ applySparse (SpIdxPair s) (STIdxPair n t) = STIdxPair n (applySparse s t)
instance ApplySparse SMTy where
applySparse (SpSparse s) t = SMTMaybe (applySparse s t)
@@ -59,8 +61,9 @@ instance ApplySparse SMTy where
applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2)
applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t)
applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t)
- applySparse (SpArrIdx s) (SMTArr n t) = SMTPair (mkTup SMTNil SMTPair (sreplicate n (knownMTy @TIx))) (applySparse s t)
+ applySparse (SpArrIdx l) (SMTArr n t) = mtMultiHot n (slistMap (`applySparse` t) l)
applySparse SpScal t = t
+ applySparse (SpIdxPair s) (SMTIdxPair n t) = SMTIdxPair n (applySparse s t)
class IsSubType s where
@@ -85,6 +88,7 @@ spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2)
spDense (SMTMaybe t) = SpMaybe (spDense t)
spDense (SMTArr _ t) = SpArr (spDense t)
spDense (SMTScal _) = SpScal
+spDense (SMTIdxPair _ t) = SpIdxPair (spDense t)
isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t')
isDense SMTNil SpAbsent = Just Refl
@@ -104,6 +108,9 @@ isDense (SMTArr _ t) (SpArr s)
| otherwise = Nothing
isDense SMTArr{} SpArrIdx{} = Nothing
isDense (SMTScal _) SpScal = Just Refl
+isDense (SMTIdxPair _ t) (SpIdxPair s)
+ | Just Refl <- isDense t s = Just Refl
+ | otherwise = Nothing
isAbsent :: Sparse t t' -> Bool
isAbsent (SpSparse s) = isAbsent s
@@ -112,5 +119,6 @@ isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2
isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2
isAbsent (SpMaybe s) = isAbsent s
isAbsent (SpArr s) = isAbsent s
-isAbsent (SpArrIdx s) = isAbsent s
+isAbsent (SpArrIdx l) = and (unSList isAbsent l)
isAbsent SpScal = False
+isAbsent (SpIdxPair s) = isAbsent s
diff --git a/src/CHAD/AST/Types.hs b/src/CHAD/AST/Types.hs
index bec2201..1e3d36c 100644
--- a/src/CHAD/AST/Types.hs
+++ b/src/CHAD/AST/Types.hs
@@ -28,6 +28,7 @@ type data Ty
| TArr Nat Ty -- ^ rank, element type
| TScal ScalTy
| TAccum Ty -- ^ contained type must be a monoid type
+ | TIdxPair Nat Ty -- ^ an array one-hot (eliminated by UnMonoid)
type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
@@ -43,6 +44,7 @@ data STy t where
STArr :: SNat n -> STy t -> STy (TArr n t)
STScal :: SScalTy t -> STy (TScal t)
STAccum :: SMTy t -> STy (TAccum t)
+ STIdxPair :: SNat n -> STy t -> STy (TIdxPair n t)
deriving instance Show (STy t)
instance GCompare STy where
@@ -77,7 +79,7 @@ data SMTy t where
SMTMaybe :: SMTy a -> SMTy (TMaybe a)
SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t)
SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t)
- SMTData :: STy a -> SMTy a -- ^ inclusion of non-monoidal information
+ SMTIdxPair :: SNat n -> SMTy t -> SMTy (TIdxPair n t)
deriving instance Show (SMTy t)
instance GCompare SMTy where
@@ -107,6 +109,7 @@ fromSMTy = \case
SMTMaybe t -> STMaybe (fromSMTy t)
SMTArr n t -> STArr n (fromSMTy t)
SMTScal sty -> STScal sty
+ SMTIdxPair n t -> STIdxPair n (fromSMTy t)
data SScalTy t where
STI32 :: SScalTy TI32
diff --git a/src/CHAD/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs
index 06de00c..6578438 100644
--- a/src/CHAD/AST/UnMonoid.hs
+++ b/src/CHAD/AST/UnMonoid.hs
@@ -2,7 +2,9 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
module CHAD.AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where
import CHAD.AST
@@ -10,40 +12,106 @@ import CHAD.AST.Sparse.Types
import CHAD.Data
--- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by
--- expanding them into their concrete implementations. Also ensure that
--- 'EAccum' has a dense sparsity.
-unMonoid :: Ex env t -> Ex env t
+type family UnMonoid t where
+ UnMonoid (TIdxPair n t) = TPair (Tup (Replicate n TIx)) (UnMonoid t)
+
+ UnMonoid TNil = TNil
+ UnMonoid (TPair a b) = TPair (UnMonoid a) (UnMonoid b)
+ UnMonoid (TEither a b) = TEither (UnMonoid a) (UnMonoid b)
+ UnMonoid (TLEither a b) = TLEither (UnMonoid a) (UnMonoid b)
+ UnMonoid (TMaybe a) = TMaybe (UnMonoid a)
+ UnMonoid (TArr n a) = TArr n (UnMonoid a)
+ UnMonoid (TScal t) = TScal t
+ UnMonoid (TAccum t) = TAccum (UnMonoid t)
+
+type family UnMonoidE env where
+ UnMonoidE '[] = '[]
+ UnMonoidE (t : ts) = UnMonoid t : UnMonoidE ts
+
+tUnMonoid :: STy t -> STy (UnMonoid t)
+tUnMonoid (STIdxPair n t) = STPair (tTup (sreplicate n tIx)) (tUnMonoid t)
+tUnMonoid STNil = STNil
+tUnMonoid (STPair a b) = STPair (tUnMonoid a) (tUnMonoid b)
+tUnMonoid (STEither a b) = STEither (tUnMonoid a) (tUnMonoid b)
+tUnMonoid (STLEither a b) = STLEither (tUnMonoid a) (tUnMonoid b)
+tUnMonoid (STMaybe a) = STMaybe (tUnMonoid a)
+tUnMonoid (STArr n t) = STArr n (tUnMonoid t)
+tUnMonoid (STAccum t) = STAccum (mtUnMonoid t)
+tUnMonoid (STScal t) = STScal t
+
+mtUnMonoid :: SMTy t -> SMTy (UnMonoid t)
+mtUnMonoid (SMTIdxPair n t) = SMTPair (mkTup SMTNil SMTPair (sreplicate n (SMTScal STI64))) (mtUnMonoid t)
+mtUnMonoid SMTNil = SMTNil
+mtUnMonoid (SMTPair a b) = SMTPair (mtUnMonoid a) (mtUnMonoid b)
+mtUnMonoid (SMTLEither a b) = SMTLEither (mtUnMonoid a) (mtUnMonoid b)
+mtUnMonoid (SMTMaybe a) = SMTMaybe (mtUnMonoid a)
+mtUnMonoid (SMTArr n t) = SMTArr n (mtUnMonoid t)
+mtUnMonoid (SMTScal t) = SMTScal t
+
+unMonoidIndexId :: SNat n -> UnMonoid (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx)
+unMonoidIndexId SZ = Refl
+unMonoidIndexId (SS n) | Refl <- unMonoidIndexId n = Refl
+
+unMonoidZeroInfoId :: SMTy t -> UnMonoid (ZeroInfo t) :~: ZeroInfo t
+unMonoidZeroInfoId SMTNil = Refl
+unMonoidZeroInfoId (SMTPair a b) | Refl <- unMonoidZeroInfoId a, Refl <- unMonoidZeroInfoId b = Refl
+unMonoidZeroInfoId (SMTLEither a b) | Refl <- unMonoidZeroInfoId a, Refl <- unMonoidZeroInfoId b = Refl
+unMonoidZeroInfoId (SMTMaybe a) | Refl <- unMonoidZeroInfoId a = Refl
+unMonoidZeroInfoId (SMTArr _ a) | Refl <- unMonoidZeroInfoId a = Refl
+unMonoidZeroInfoId (SMTScal _) = Refl
+unMonoidZeroInfoId (SMTIdxPair n a) | Refl <- unMonoidIndexId n, Refl <- unMonoidZeroInfoId a = Refl
+
+unMonoidDeepZeroInfoId :: SMTy t -> UnMonoid (DeepZeroInfo t) :~: DeepZeroInfo t
+unMonoidDeepZeroInfoId SMTNil = Refl
+unMonoidDeepZeroInfoId (SMTPair a b) | Refl <- unMonoidDeepZeroInfoId a, Refl <- unMonoidDeepZeroInfoId b = Refl
+unMonoidDeepZeroInfoId (SMTLEither a b) | Refl <- unMonoidDeepZeroInfoId a, Refl <- unMonoidDeepZeroInfoId b = Refl
+unMonoidDeepZeroInfoId (SMTMaybe a) | Refl <- unMonoidDeepZeroInfoId a = Refl
+unMonoidDeepZeroInfoId (SMTArr _ a) | Refl <- unMonoidDeepZeroInfoId a = Refl
+unMonoidDeepZeroInfoId (SMTScal _) = Refl
+unMonoidDeepZeroInfoId (SMTIdxPair n a) | Refl <- unMonoidIndexId n, Refl <- unMonoidDeepZeroInfoId a = Refl
+
+
+-- | Removes monoidal stuff from the program. In particular:
+--
+-- * Removes 'EZero', 'EDeepZero', 'EPlus', 'EOneHot' from the program by
+-- expanding them into their concrete implementations.
+-- * Removes 'EIdxPair' and 'EUnIdxPair', as well as all occurrences of the
+-- 'TIdxPair' type in general.
+-- * Ensures that 'EAccum' has a dense sparsity.
+unMonoid :: Ex env t -> Ex (UnMonoidE env) (UnMonoid t)
unMonoid = \case
- EZero _ t e -> zero t e
- EDeepZero _ t e -> deepZero t e
- EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
+ EZero _ t e | Refl <- unMonoidZeroInfoId t -> zero t (unMonoid e)
+ EDeepZero _ t e | Refl <- unMonoidDeepZeroInfoId t -> deepZero t (unMonoid e)
+ EPlus _ t a b -> plus (mtUnMonoid t) (unMonoid a) (unMonoid b)
EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)
+ EIdxPair _ n a b | Refl <- unMonoidIndexId n -> EPair ext (unMonoid a) (unMonoid b)
+ EUnIdxPair _ e | STIdxPair n _ <- typeOf e, Refl <- unMonoidIndexId n -> unMonoid e
EAccum _ t p eidx sp eval eacc ->
- elet (unMonoid eacc) $
- elet (weakenExpr WSink (unMonoid eidx)) $
- accumulateSparse (acPrjTy p t) sp (weakenExpr (WSink .> WSink) (unMonoid eval)) $ \w prj2 idx2 val2 ->
- acPrjCompose SAID p (evar (w @> IZ)) prj2 idx2 $ \prj' idx' ->
- EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (evar (w @> IS IZ))
+ accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 ->
+ acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
+ EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc))
- EVar _ t i -> EVar ext t i
+ EVar _ t i -> EVar ext (tUnMonoid t) (go i)
+ where go :: Idx env t -> Idx (UnMonoidE env) (UnMonoid t)
+ go IZ = IZ
+ go (IS j) = IS (go j)
ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
EPair _ a b -> EPair ext (unMonoid a) (unMonoid b)
EFst _ e -> EFst ext (unMonoid e)
ESnd _ e -> ESnd ext (unMonoid e)
ENil _ -> ENil ext
- EInl _ t e -> EInl ext t (unMonoid e)
- EInr _ t e -> EInr ext t (unMonoid e)
+ EInl _ t e -> EInl ext (tUnMonoid t) (unMonoid e)
+ EInr _ t e -> EInr ext (tUnMonoid t) (unMonoid e)
ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b)
- ENothing _ t -> ENothing ext t
+ ENothing _ t -> ENothing ext (tUnMonoid t)
EJust _ e -> EJust ext (unMonoid e)
EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
- ELNil _ t1 t2 -> ELNil ext t1 t2
- ELInl _ t e -> ELInl ext t (unMonoid e)
- ELInr _ t e -> ELInr ext t (unMonoid e)
+ ELNil _ t1 t2 -> ELNil ext (tUnMonoid t1) (tUnMonoid t2)
+ ELInl _ t e -> ELInl ext (tUnMonoid t) (unMonoid e)
+ ELInr _ t e -> ELInr ext (tUnMonoid t) (unMonoid e)
ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)
EConstArr _ n t x -> EConstArr ext n t x
- EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
+ EBuild _ n a b | Refl <- unMonoidIndexId n -> EBuild ext n (unMonoid a) (unMonoid b)
EMap _ a b -> EMap ext (unMonoid a) (unMonoid b)
EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
@@ -51,49 +119,52 @@ unMonoid = \case
EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
- EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b)
+ EReshape _ n a b | Refl <- unMonoidIndexId n -> EReshape ext n (unMonoid a) (unMonoid b)
EZip _ a b -> EZip ext (unMonoid a) (unMonoid b)
EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
EConst _ t x -> EConst ext t x
EIdx0 _ e -> EIdx0 ext (unMonoid e)
EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
- EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b)
- EShape _ e -> EShape ext (unMonoid e)
+ EIdx _ a b | STArr n _ <- typeOf a, Refl <- unMonoidIndexId n -> EIdx ext (unMonoid a) (unMonoid b)
+ EShape _ e | STArr n _ <- typeOf e, Refl <- unMonoidIndexId n -> EShape ext (unMonoid e)
EOp _ op e -> EOp ext op (unMonoid e)
- ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
+ ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext (tUnMonoid t1) (tUnMonoid t2) (tUnMonoid t3) (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
ERecompute _ e -> ERecompute ext (unMonoid e)
EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
EError _ t s -> EError ext t s
-zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
+zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env (UnMonoid t)
-- don't destroy the effects!
-zero SMTNil e = ELet ext e $ ENil ext
+zero SMTNil e = use e $ ENil ext
zero (SMTPair t1 t2) e =
ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ)))
(zero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
-zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2)
-zero (SMTMaybe t) _ = ENothing ext (fromSMTy t)
+zero (SMTLEither t1 t2) _ = ELNil ext (tUnMonoid (fromSMTy t1)) (tUnMonoid (fromSMTy t2))
+zero (SMTMaybe t) _ = ENothing ext (tUnMonoid (fromSMTy t))
zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e
zero (SMTScal t) _ = case t of
STI32 -> EConst ext STI32 0
STI64 -> EConst ext STI64 0
STF32 -> EConst ext STF32 0.0
STF64 -> EConst ext STF64 0.0
+zero (SMTIdxPair _ t) e =
+ eunPair e $ \_ e1 e2 ->
+ EPair ext e1 (zero t e2)
-deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t
+deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env (UnMonoid t)
deepZero SMTNil e = elet e $ ENil ext
deepZero (SMTPair t1 t2) e =
ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ)))
(deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
deepZero (SMTLEither t1 t2) e =
elcase e
- (ELNil ext (fromSMTy t1) (fromSMTy t2))
- (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ)))
- (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ)))
+ (ELNil ext (tUnMonoid (fromSMTy t1)) (tUnMonoid (fromSMTy t2)))
+ (ELInl ext (tUnMonoid (fromSMTy t2)) (deepZero t1 (evar IZ)))
+ (ELInr ext (tUnMonoid (fromSMTy t1)) (deepZero t2 (evar IZ)))
deepZero (SMTMaybe t) e =
emaybe e
- (ENothing ext (fromSMTy t))
+ (ENothing ext (tUnMonoid (fromSMTy t)))
(EJust ext (deepZero t (evar IZ)))
deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e
deepZero (SMTScal t) _ = case t of
@@ -101,6 +172,9 @@ deepZero (SMTScal t) _ = case t of
STI64 -> EConst ext STI64 0
STF32 -> EConst ext STF32 0.0
STF64 -> EConst ext STF64 0.0
+deepZero (SMTIdxPair _ t) e =
+ eunPair e $ \_ e1 e2 ->
+ EPair ext e1 (deepZero t e2)
plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
-- don't destroy the effects!
@@ -134,44 +208,38 @@ plus (SMTArr _ t) a b =
a b
plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)
-onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t
+onehot :: SMTy t -> SAcPrj p t a -> Ex env (UnMonoid (AcIdxS p t)) -> Ex env (UnMonoid a) -> Ex env (UnMonoid t)
onehot typ topprj idx arg = case (typ, topprj) of
(_, SAPHere) ->
- ELet ext arg $
- EVar ext (fromSMTy typ) IZ
+ arg
+
+ (SMTPair t1 t2, SAPFst prj) | Refl <- unMonoidZeroInfoId t2 ->
+ elet idx $
+ elet (onehot t1 prj (EFst ext (evar IZ)) (weakenExpr WSink arg)) $
+ EPair ext (evar IZ)
+ (zero t2 (ESnd ext (evar (IS IZ))))
- (SMTPair t1 t2, SAPFst prj) ->
- ELet ext idx $
- let tidx = typeOf idx in
- ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
- let toh = fromSMTy t1 in
- EPair ext (EVar ext toh IZ)
- (zero t2 (ESnd ext (EVar ext tidx (IS IZ))))
-
- (SMTPair t1 t2, SAPSnd prj) ->
- ELet ext idx $
- let tidx = typeOf idx in
- ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
- let toh = fromSMTy t2 in
- EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ))))
- (EVar ext toh IZ)
+ (SMTPair t1 t2, SAPSnd prj) | Refl <- unMonoidZeroInfoId t1 ->
+ elet idx $
+ elet (onehot t2 prj (ESnd ext (evar IZ)) (weakenExpr WSink arg)) $
+ EPair ext (zero t1 (EFst ext (evar (IS IZ))))
+ (evar IZ)
(SMTLEither t1 t2, SAPLeft prj) ->
- ELInl ext (fromSMTy t2) (onehot t1 prj idx arg)
+ ELInl ext (tUnMonoid (fromSMTy t2)) (onehot t1 prj idx arg)
(SMTLEither t1 t2, SAPRight prj) ->
- ELInr ext (fromSMTy t1) (onehot t2 prj idx arg)
+ ELInr ext (tUnMonoid (fromSMTy t1)) (onehot t2 prj idx arg)
(SMTMaybe t1, SAPJust prj) ->
EJust ext (onehot t1 prj idx arg)
- (SMTArr n t1, SAPArrIdx prj) ->
- let tidx = tTup (sreplicate n tIx)
- in ELet ext idx $
- EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $
- eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
- (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
- (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $
- zero t1 (EVar ext (tZeroInfo t1) IZ))
+ (SMTArr n t1, SAPArrIdx prj) | Refl <- unMonoidIndexId n, Refl <- unMonoidZeroInfoId t1 ->
+ elet idx $
+ EBuild ext n (EShape ext (ESnd ext (EFst ext (evar IZ)))) $
+ eif (eidxEq n (evar IZ) (EFst ext (EFst ext (evar (IS IZ)))))
+ (onehot t1 prj (ESnd ext (evar (IS IZ))) (weakenExpr (WSink .> WSink) arg))
+ (elet (EIdx ext (ESnd ext (EFst ext (evar (IS IZ)))) (evar IZ)) $
+ zero t1 (evar IZ))
accumulateSparse
:: SMTy t -> Sparse t t' -> Ex env t'
diff --git a/src/CHAD/Drev/Types.hs b/src/CHAD/Drev/Types.hs
index 367a974..f23aeba 100644
--- a/src/CHAD/Drev/Types.hs
+++ b/src/CHAD/Drev/Types.hs
@@ -55,6 +55,7 @@ d1 (STMaybe t) = STMaybe (d1 t)
d1 (STArr n t) = STArr n (d1 t)
d1 (STScal t) = STScal t
d1 STAccum{} = error "Accumulators not allowed in input program"
+d1 STIdxPair{} = error "Index pairs not allowed in input program"
d1e :: SList STy env -> SList STy (D1E env)
d1e SNil = SNil
@@ -74,6 +75,7 @@ d2M (STScal t) = case t of
STF64 -> SMTScal STF64
STBool -> SMTNil
d2M STAccum{} = error "Accumulators not allowed in input program"
+d2M STIdxPair{} = error "Index pairs not allowed in input program"
d2 :: STy t -> STy (D2 t)
d2 = fromSMTy . d2M
@@ -147,6 +149,7 @@ d1Identity = \case
STArr _ t | Refl <- d1Identity t -> Refl
STScal _ -> Refl
STAccum{} -> error "Accumulators not allowed in input program"
+ STIdxPair{} -> error "Index pairs not allowed in input program"
d1eIdentity :: SList STy env -> D1E env :~: env
d1eIdentity SNil = Refl