summaryrefslogtreecommitdiff
path: root/src/CHAD/Types
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/Types')
-rw-r--r--src/CHAD/Types/ToTan.hs18
1 files changed, 7 insertions, 11 deletions
diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs
index 8476712..888fed4 100644
--- a/src/CHAD/Types/ToTan.hs
+++ b/src/CHAD/Types/ToTan.hs
@@ -19,9 +19,7 @@ toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
toTan typ primal der = case typ of
STNil -> der
- STPair t1 t2 -> case der of
- Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
- Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
+ STPair t1 t2 -> bimap (\p1 -> toTan t1 p1 (fst der)) (\p2 -> toTan t2 p2 (snd der)) primal
STEither t1 t2 -> case der of
Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
Just d -> case (primal, d) of
@@ -34,14 +32,12 @@ toTan typ primal der = case typ of
(Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d))
_ -> error "Primal and cotangent disagree on LEither alternative"
STMaybe t -> liftA2 (toTan t) primal der
- STArr _ t -> case der of
- Nothing -> arrayMap (zeroTan t) primal
- Just d
- | arrayShape primal == arrayShape d ->
- arrayGenerateLin (arrayShape primal) $ \i ->
- toTan t (arrayIndexLinear primal i) (arrayIndexLinear d i)
- | otherwise ->
- error "Primal and cotangent disagree on array shape"
+ STArr _ t
+ | arrayShape primal == arrayShape der ->
+ arrayGenerateLin (arrayShape primal) $ \i ->
+ toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
+ | otherwise ->
+ error "Primal and cotangent disagree on array shape"
STScal sty -> case sty of
STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
STAccum{} -> error "Accumulators not allowed in input program"