From 72eddb67bb6f048fc2076184be3a32169026a832 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 7 Oct 2024 14:34:27 +0200 Subject: Towards a test suite --- src/ForwardAD.hs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) (limited to 'src/ForwardAD.hs') diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index 63244a8..6d53b48 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -52,6 +52,21 @@ tanty (STScal t) = case t of STBool -> STNil tanty STAccum{} = error "Accumulators not allowed in input program" +zeroTan :: STy t -> Rep t -> Rep (Tan t) +zeroTan STNil () = () +zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y) +zeroTan (STEither a _) (Left x) = Left (zeroTan a x) +zeroTan (STEither _ b) (Right y) = Right (zeroTan b y) +zeroTan (STMaybe _) Nothing = Nothing +zeroTan (STMaybe t) (Just x) = Just (zeroTan t x) +zeroTan (STArr _ t) x = fmap (zeroTan t) x +zeroTan (STScal STI32) _ = () +zeroTan (STScal STI64) _ = () +zeroTan (STScal STF32) _ = 0.0 +zeroTan (STScal STF64) _ = 0.0 +zeroTan (STScal STBool) _ = () +zeroTan STAccum{} _ = error "Accumulators not allowed in input program" + unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t)) unzipDN STNil _ = ((), ()) unzipDN (STPair a b) (d1, d2) = -- cgit v1.2.3-70-g09d2