diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-06-21 09:57:45 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-06-21 09:57:45 +0200 |
commit | b5ed3d2fcc249cb410b9e86d25e9ef808c6dba97 (patch) | |
tree | 66383b16d5d95f939aaa165a783dbbfd99a57fe3 /src/CHAD/Types/ToTan.hs | |
parent | 8bbc2d2867e3d0a4a1f2810b40e92175779822e1 (diff) | |
parent | a4b3eb76acbec30ffeae119a4dc6e4c9f64396fe (diff) |
Diffstat (limited to 'src/CHAD/Types/ToTan.hs')
-rw-r--r-- | src/CHAD/Types/ToTan.hs | 18 |
1 files changed, 7 insertions, 11 deletions
diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs index 8476712..888fed4 100644 --- a/src/CHAD/Types/ToTan.hs +++ b/src/CHAD/Types/ToTan.hs @@ -19,9 +19,7 @@ toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) = toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t) toTan typ primal der = case typ of STNil -> der - STPair t1 t2 -> case der of - Nothing -> bimap (zeroTan t1) (zeroTan t2) primal - Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal + STPair t1 t2 -> bimap (\p1 -> toTan t1 p1 (fst der)) (\p2 -> toTan t2 p2 (snd der)) primal STEither t1 t2 -> case der of Nothing -> bimap (zeroTan t1) (zeroTan t2) primal Just d -> case (primal, d) of @@ -34,14 +32,12 @@ toTan typ primal der = case typ of (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d)) _ -> error "Primal and cotangent disagree on LEither alternative" STMaybe t -> liftA2 (toTan t) primal der - STArr _ t -> case der of - Nothing -> arrayMap (zeroTan t) primal - Just d - | arrayShape primal == arrayShape d -> - arrayGenerateLin (arrayShape primal) $ \i -> - toTan t (arrayIndexLinear primal i) (arrayIndexLinear d i) - | otherwise -> - error "Primal and cotangent disagree on array shape" + STArr _ t + | arrayShape primal == arrayShape der -> + arrayGenerateLin (arrayShape primal) $ \i -> + toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i) + | otherwise -> + error "Primal and cotangent disagree on array shape" STScal sty -> case sty of STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der STAccum{} -> error "Accumulators not allowed in input program" |