summaryrefslogtreecommitdiff
path: root/src/ForwardAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/ForwardAD.hs')
-rw-r--r--src/ForwardAD.hs15
1 files changed, 15 insertions, 0 deletions
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index 6d53b48..86d2fb0 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -67,6 +67,21 @@ zeroTan (STScal STF64) _ = 0.0
zeroTan (STScal STBool) _ = ()
zeroTan STAccum{} _ = error "Accumulators not allowed in input program"
+tanScalars :: STy t -> Rep (Tan t) -> [Double]
+tanScalars STNil () = []
+tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y
+tanScalars (STEither a _) (Left x) = tanScalars a x
+tanScalars (STEither _ b) (Right y) = tanScalars b y
+tanScalars (STMaybe _) Nothing = []
+tanScalars (STMaybe t) (Just x) = tanScalars t x
+tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x
+tanScalars (STScal STI32) _ = []
+tanScalars (STScal STI64) _ = []
+tanScalars (STScal STF32) x = [realToFrac x]
+tanScalars (STScal STF64) x = [x]
+tanScalars (STScal STBool) _ = []
+tanScalars STAccum{} _ = error "Accumulators not allowed in input program"
+
unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t))
unzipDN STNil _ = ((), ())
unzipDN (STPair a b) (d1, d2) =