diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-02-20 22:41:53 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-02-20 22:41:53 +0100 |
commit | 20754e0ae9226f658590f46105399aee5c0dcee2 (patch) | |
tree | 261be51f07c530d09665708fb149dd241c296b70 /src/Numeric/ADDual | |
parent | fe3132304b6c25e5bebc9fb327e3ea5d6018be7a (diff) |
Try debugging crash
Diffstat (limited to 'src/Numeric/ADDual')
-rw-r--r-- | src/Numeric/ADDual/Internal.hs | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/src/Numeric/ADDual/Internal.hs b/src/Numeric/ADDual/Internal.hs index 8228694..d8e6e0a 100644 --- a/src/Numeric/ADDual/Internal.hs +++ b/src/Numeric/ADDual/Internal.hs @@ -15,18 +15,18 @@ import qualified Data.Vector.Storable as VS import qualified Data.Vector.Storable.Mutable as VSM import Foreign.Ptr import Foreign.Storable +import GHC.Stack import GHC.Exts (withDict) import System.IO.Unsafe --- import System.IO (hPutStrLn, stderr) - --- import Numeric.ADDual.IDGen +import System.IO (hPutStrLn, stderr) -- TODO: full vjp (just some more Traversable mess) {-# NOINLINE gradient' #-} gradient' :: forall a f. (Traversable f, Num a, Storable a) - -- => Show a -- TODO: remove + => HasCallStack + => Show a -- TODO: remove => (forall s. Taping s a => f (Dual s a) -> Dual s a) -> f a -> a -> (a, f a) gradient' f inp topctg = unsafePerformIO $ do @@ -50,6 +50,7 @@ gradient' f inp topctg = unsafePerformIO $ do let backpropagate i chunk@(Chunk ci0 vec) tape | i >= ci0 = do + -- hPutStrLn stderr $ "read at ci0=" ++ show ci0 ++ " i=" ++ show i ctg <- VSM.read accums i Contrib i1 dx i2 dy <- VSM.read vec (i - ci0) when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1 @@ -174,7 +175,9 @@ data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a)) writeTape :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int writeTape _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy -writeTapeIO :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> IO Int +writeTapeIO :: forall a s proxy. (Num a, Storable a, Taping s a) + => HasCallStack + => proxy s -> Int -> a -> Int -> a -> IO Int writeTapeIO _ i1 dx i2 dy = do MLog idref (Chunk ci0 vec) _ <- readIORef (getTape @s) let n = VSM.length vec |