diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-02-21 13:03:42 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-02-21 13:03:47 +0100 |
commit | 4bd1890dccb45a90f10183a916f93f025a3f57d2 (patch) | |
tree | 6f706aad594558cfffd9044bedb77671e01a0dd0 /src/Numeric/ADDual | |
parent | 20754e0ae9226f658590f46105399aee5c0dcee2 (diff) |
Add (toggleable) debug code
Diffstat (limited to 'src/Numeric/ADDual')
-rw-r--r-- | src/Numeric/ADDual/Internal.hs | 29 |
1 files changed, 18 insertions, 11 deletions
diff --git a/src/Numeric/ADDual/Internal.hs b/src/Numeric/ADDual/Internal.hs index d8e6e0a..55c47c2 100644 --- a/src/Numeric/ADDual/Internal.hs +++ b/src/Numeric/ADDual/Internal.hs @@ -22,6 +22,10 @@ import System.IO.Unsafe import System.IO (hPutStrLn, stderr) +debug :: Bool +debug = toEnum 0 + + -- TODO: full vjp (just some more Traversable mess) {-# NOINLINE gradient' #-} gradient' :: forall a f. (Traversable f, Num a, Storable a) @@ -30,27 +34,27 @@ gradient' :: forall a f. (Traversable f, Num a, Storable a) => (forall s. Taping s a => f (Dual s a) -> Dual s a) -> f a -> a -> (a, f a) gradient' f inp topctg = unsafePerformIO $ do - -- hPutStrLn stderr "Preparing input" + when debug $ hPutStrLn stderr "Preparing input" let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0 idref <- newIORef starti vec1 <- VSM.new (max 128 (2 * starti)) taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil) - -- hPutStrLn stderr "Running function" + when debug $ hPutStrLn stderr "Running function" let Dual result outi = withDict @(Taping () a) taperef $ f @() inp' - -- hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi + when debug $ hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi MLog _ lastChunk tapeTail <- readIORef taperef - -- do tapestr <- showTape (tapeTail `Snoc` lastChunk) - -- hPutStrLn stderr $ "tape = " ++ tapestr "" + when debug $ do + tapestr <- showTape (tapeTail `Snoc` lastChunk) + hPutStrLn stderr $ "tape = " ++ tapestr "" - -- hPutStrLn stderr "Backpropagating" + when debug $ hPutStrLn stderr "Backpropagating" accums <- VSM.new (outi+1) VSM.write accums outi topctg let backpropagate i chunk@(Chunk ci0 vec) tape | i >= ci0 = do - -- hPutStrLn stderr $ "read at ci0=" ++ show ci0 ++ " i=" ++ show i ctg <- VSM.read accums i Contrib i1 dx i2 dy <- VSM.read vec (i - ci0) when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1 @@ -72,10 +76,11 @@ gradient' f inp topctg = unsafePerformIO $ do SLNil Snoc{} -> backpropagate outi lastChunk tapeTail - -- do accums' <- VS.freeze accums - -- hPutStrLn stderr $ "accums = " ++ show accums' + when debug $ do + accums' <- VS.freeze accums + hPutStrLn stderr $ "accums = " ++ show accums' - -- hPutStrLn stderr "Reconstructing gradient" + when debug $ hPutStrLn stderr "Reconstructing gradient" let readDeriv = do i <- get d <- lift $ VSM.read accums i put (i+1) @@ -210,7 +215,9 @@ writeTapeIO _ i1 dx i2 dy = do -- chunk in the first place! error "writeTape: impossible" case action of - WTANewvec vec' -> VSM.write vec' (idx - n) (Contrib i1 dx i2 dy) + WTANewvec vec' -> do + when debug $ hPutStrLn stderr $ "writeTapeIO: new vec of size " ++ show (VSM.length vec') + VSM.write vec' (idx - n) (Contrib i1 dx i2 dy) WTAOldTape tape -> let go SLNil = error "writeTape: no appropriate tape chunk?" go (tape' `Snoc` Chunk ci0' vec') |