diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 15:11:48 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 15:11:48 +0100 |
commit | a00234388d1b4e14481067d030bf90031258b756 (patch) | |
tree | 501b6778fc5779ce220aba1e22f56ae60f68d970 /src/CHAD | |
parent | 7971f6dff12bc7b66a5d4ae91a6791ac08872c31 (diff) |
D2[Array] now has a Maybe instead of zero-size for zero
Remaining problem: 'add' in Compile doesn't use the D2 stuff
Diffstat (limited to 'src/CHAD')
-rw-r--r-- | src/CHAD/Types.hs | 4 | ||||
-rw-r--r-- | src/CHAD/Types/ToTan.hs | 16 |
2 files changed, 10 insertions, 10 deletions
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index e8ec0c9..7f49cef 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -20,7 +20,7 @@ type family D2 t where D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b)) D2 (TEither a b) = TMaybe (TEither (D2 a) (D2 b)) D2 (TMaybe t) = TMaybe (D2 t) - D2 (TArr n t) = TArr n (D2 t) + D2 (TArr n t) = TMaybe (TArr n (D2 t)) D2 (TScal t) = D2s t type family D2s t where @@ -60,7 +60,7 @@ d2 STNil = STNil d2 (STPair a b) = STMaybe (STPair (d2 a) (d2 b)) d2 (STEither a b) = STMaybe (STEither (d2 a) (d2 b)) d2 (STMaybe t) = STMaybe (d2 t) -d2 (STArr n t) = STArr n (d2 t) +d2 (STArr n t) = STMaybe (STArr n (d2 t)) d2 (STScal t) = case t of STI32 -> STNil STI64 -> STNil diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs index a75fdb8..f843206 100644 --- a/src/CHAD/Types/ToTan.hs +++ b/src/CHAD/Types/ToTan.hs @@ -29,14 +29,14 @@ toTan typ primal der = case typ of (Right p, Right d') -> Right (toTan t2 p d') _ -> error "Primal and cotangent disagree on Either alternative" STMaybe t -> liftA2 (toTan t) primal der - STArr _ t - | shapeSize (arrayShape der) == 0 -> - arrayMap (zeroTan t) primal - | arrayShape primal == arrayShape der -> - arrayGenerateLin (arrayShape primal) $ \i -> - toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i) - | otherwise -> - error "Primal and cotangent disagree on array shape" + STArr _ t -> case der of + Nothing -> arrayMap (zeroTan t) primal + Just d + | arrayShape primal == arrayShape d -> + arrayGenerateLin (arrayShape primal) $ \i -> + toTan t (arrayIndexLinear primal i) (arrayIndexLinear d i) + | otherwise -> + error "Primal and cotangent disagree on array shape" STScal sty -> case sty of STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der STAccum{} -> error "Accumulators not allowed in input program" |