From b1664532eaebdf0409ab6d93fc0ba2ef8dfbf372 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 27 Apr 2025 23:34:59 +0200 Subject: WIP revamp accumulators again: explicit monoid types No more D2 in accumulators! Paving the way for configurable sparsity of products and arrays. The idea is to make separate monoid types for a "product cotangent" and an "array cotangent" that can be lowered to either a sparse monoid or a non-sparse monoid. Downsides of this approach: lots of API duplication. --- src/ForwardAD/DualNumbers.hs | 4 ++++ src/ForwardAD/DualNumbers/Types.hs | 2 ++ 2 files changed, 6 insertions(+) (limited to 'src/ForwardAD') diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 2f94076..ebc70d7 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -143,6 +143,10 @@ dfwdDN = \case ENothing _ t -> ENothing ext (dn t) EJust _ e -> EJust ext (dfwdDN e) EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b) + ELNil _ t1 t2 -> ELNil ext (dn t1) (dn t2) + ELInl _ t e -> ELInl ext (dn t) (dfwdDN e) + ELInr _ t e -> ELInr ext (dn t) (dfwdDN e) + ELCase _ e a b c -> ELCase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) (dfwdDN c) EConstArr _ n t x -> scalTyCase t (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0)) (EConstArr ext n t x)) diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/ForwardAD/DualNumbers/Types.hs index fba92d0..3c76cbe 100644 --- a/src/ForwardAD/DualNumbers/Types.hs +++ b/src/ForwardAD/DualNumbers/Types.hs @@ -15,6 +15,7 @@ type family DN t where DN (TMaybe t) = TMaybe (DN t) DN (TArr n t) = TArr n (DN t) DN (TScal t) = DNS t + DN (TLEither a b) = TLEither (DN a) (DN b) type family DNS t where DNS TF32 = TPair (TScal TF32) (TScal TF32) @@ -40,6 +41,7 @@ dn (STScal t) = case t of STI64 -> STScal STI64 STBool -> STScal STBool dn STAccum{} = error "Accum in source program" +dn (STLEither a b) = STLEither (dn a) (dn b) dne :: SList STy env -> SList STy (DNE env) dne SNil = SNil -- cgit v1.2.3-70-g09d2