aboutsummaryrefslogtreecommitdiff
path: root/src/Numeric/ADDual/Internal.hs
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-02-21 13:04:16 +0100
committerTom Smeding <t.j.smeding@uu.nl>2025-02-21 13:04:16 +0100
commit14be880ce787a35f70a749c1176c4607bfb535ed (patch)
tree7397efa6a119fc324bf418ab06271979a81e4e9d /src/Numeric/ADDual/Internal.hs
parent94a59b0d78ff16903f250989a6121d13dae23e2f (diff)
Evaluate result Dual before backpropagating
This ensures that the tape is actually created/written before we start backpropagating over it, lol.
Diffstat (limited to 'src/Numeric/ADDual/Internal.hs')
-rw-r--r--src/Numeric/ADDual/Internal.hs3
1 files changed, 2 insertions, 1 deletions
diff --git a/src/Numeric/ADDual/Internal.hs b/src/Numeric/ADDual/Internal.hs
index 858e0db..5955fae 100644
--- a/src/Numeric/ADDual/Internal.hs
+++ b/src/Numeric/ADDual/Internal.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE MultiWayIf #-}
@@ -41,7 +42,7 @@ gradient' f inp topctg = unsafePerformIO $ do
taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil)
when debug $ hPutStrLn stderr "Running function"
- let Dual result outi = withDict @(Taping () a) taperef $ f @() inp'
+ let !(Dual result outi) = withDict @(Taping () a) taperef $ f @() inp'
when debug $ hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi
MLog _ lastChunk tapeTail <- readIORef taperef