diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-02-21 13:04:16 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-02-21 13:04:16 +0100 |
commit | 14be880ce787a35f70a749c1176c4607bfb535ed (patch) | |
tree | 7397efa6a119fc324bf418ab06271979a81e4e9d | |
parent | 94a59b0d78ff16903f250989a6121d13dae23e2f (diff) |
Evaluate result Dual before backpropagating
This ensures that the tape is actually created/written before we start
backpropagating over it, lol.
-rw-r--r-- | src/Numeric/ADDual/Internal.hs | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/src/Numeric/ADDual/Internal.hs b/src/Numeric/ADDual/Internal.hs index 858e0db..5955fae 100644 --- a/src/Numeric/ADDual/Internal.hs +++ b/src/Numeric/ADDual/Internal.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE MultiWayIf #-} @@ -41,7 +42,7 @@ gradient' f inp topctg = unsafePerformIO $ do taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil) when debug $ hPutStrLn stderr "Running function" - let Dual result outi = withDict @(Taping () a) taperef $ f @() inp' + let !(Dual result outi) = withDict @(Taping () a) taperef $ f @() inp' when debug $ hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi MLog _ lastChunk tapeTail <- readIORef taperef |