aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-02-20 22:41:53 +0100
committerTom Smeding <t.j.smeding@uu.nl>2025-02-20 22:41:53 +0100
commit20754e0ae9226f658590f46105399aee5c0dcee2 (patch)
tree261be51f07c530d09665708fb149dd241c296b70 /src
parentfe3132304b6c25e5bebc9fb327e3ea5d6018be7a (diff)
Try debugging crash
Diffstat (limited to 'src')
-rw-r--r--src/Numeric/ADDual/Internal.hs13
1 files changed, 8 insertions, 5 deletions
diff --git a/src/Numeric/ADDual/Internal.hs b/src/Numeric/ADDual/Internal.hs
index 8228694..d8e6e0a 100644
--- a/src/Numeric/ADDual/Internal.hs
+++ b/src/Numeric/ADDual/Internal.hs
@@ -15,18 +15,18 @@ import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import Foreign.Ptr
import Foreign.Storable
+import GHC.Stack
import GHC.Exts (withDict)
import System.IO.Unsafe
--- import System.IO (hPutStrLn, stderr)
-
--- import Numeric.ADDual.IDGen
+import System.IO (hPutStrLn, stderr)
-- TODO: full vjp (just some more Traversable mess)
{-# NOINLINE gradient' #-}
gradient' :: forall a f. (Traversable f, Num a, Storable a)
- -- => Show a -- TODO: remove
+ => HasCallStack
+ => Show a -- TODO: remove
=> (forall s. Taping s a => f (Dual s a) -> Dual s a)
-> f a -> a -> (a, f a)
gradient' f inp topctg = unsafePerformIO $ do
@@ -50,6 +50,7 @@ gradient' f inp topctg = unsafePerformIO $ do
let backpropagate i chunk@(Chunk ci0 vec) tape
| i >= ci0 = do
+ -- hPutStrLn stderr $ "read at ci0=" ++ show ci0 ++ " i=" ++ show i
ctg <- VSM.read accums i
Contrib i1 dx i2 dy <- VSM.read vec (i - ci0)
when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1
@@ -174,7 +175,9 @@ data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a))
writeTape :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int
writeTape _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy
-writeTapeIO :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> IO Int
+writeTapeIO :: forall a s proxy. (Num a, Storable a, Taping s a)
+ => HasCallStack
+ => proxy s -> Int -> a -> Int -> a -> IO Int
writeTapeIO _ i1 dx i2 dy = do
MLog idref (Chunk ci0 vec) _ <- readIORef (getTape @s)
let n = VSM.length vec