summaryrefslogtreecommitdiff
path: root/src/CHAD
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-26 15:11:48 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-26 15:11:48 +0100
commita00234388d1b4e14481067d030bf90031258b756 (patch)
tree501b6778fc5779ce220aba1e22f56ae60f68d970 /src/CHAD
parent7971f6dff12bc7b66a5d4ae91a6791ac08872c31 (diff)
D2[Array] now has a Maybe instead of zero-size for zero
Remaining problem: 'add' in Compile doesn't use the D2 stuff
Diffstat (limited to 'src/CHAD')
-rw-r--r--src/CHAD/Types.hs4
-rw-r--r--src/CHAD/Types/ToTan.hs16
2 files changed, 10 insertions, 10 deletions
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index e8ec0c9..7f49cef 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -20,7 +20,7 @@ type family D2 t where
D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b))
D2 (TEither a b) = TMaybe (TEither (D2 a) (D2 b))
D2 (TMaybe t) = TMaybe (D2 t)
- D2 (TArr n t) = TArr n (D2 t)
+ D2 (TArr n t) = TMaybe (TArr n (D2 t))
D2 (TScal t) = D2s t
type family D2s t where
@@ -60,7 +60,7 @@ d2 STNil = STNil
d2 (STPair a b) = STMaybe (STPair (d2 a) (d2 b))
d2 (STEither a b) = STMaybe (STEither (d2 a) (d2 b))
d2 (STMaybe t) = STMaybe (d2 t)
-d2 (STArr n t) = STArr n (d2 t)
+d2 (STArr n t) = STMaybe (STArr n (d2 t))
d2 (STScal t) = case t of
STI32 -> STNil
STI64 -> STNil
diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs
index a75fdb8..f843206 100644
--- a/src/CHAD/Types/ToTan.hs
+++ b/src/CHAD/Types/ToTan.hs
@@ -29,14 +29,14 @@ toTan typ primal der = case typ of
(Right p, Right d') -> Right (toTan t2 p d')
_ -> error "Primal and cotangent disagree on Either alternative"
STMaybe t -> liftA2 (toTan t) primal der
- STArr _ t
- | shapeSize (arrayShape der) == 0 ->
- arrayMap (zeroTan t) primal
- | arrayShape primal == arrayShape der ->
- arrayGenerateLin (arrayShape primal) $ \i ->
- toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
- | otherwise ->
- error "Primal and cotangent disagree on array shape"
+ STArr _ t -> case der of
+ Nothing -> arrayMap (zeroTan t) primal
+ Just d
+ | arrayShape primal == arrayShape d ->
+ arrayGenerateLin (arrayShape primal) $ \i ->
+ toTan t (arrayIndexLinear primal i) (arrayIndexLinear d i)
+ | otherwise ->
+ error "Primal and cotangent disagree on array shape"
STScal sty -> case sty of
STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
STAccum{} -> error "Accumulators not allowed in input program"