aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Numeric/ADDual/Internal.hs29
1 files changed, 18 insertions, 11 deletions
diff --git a/src/Numeric/ADDual/Internal.hs b/src/Numeric/ADDual/Internal.hs
index d8e6e0a..55c47c2 100644
--- a/src/Numeric/ADDual/Internal.hs
+++ b/src/Numeric/ADDual/Internal.hs
@@ -22,6 +22,10 @@ import System.IO.Unsafe
import System.IO (hPutStrLn, stderr)
+debug :: Bool
+debug = toEnum 0
+
+
-- TODO: full vjp (just some more Traversable mess)
{-# NOINLINE gradient' #-}
gradient' :: forall a f. (Traversable f, Num a, Storable a)
@@ -30,27 +34,27 @@ gradient' :: forall a f. (Traversable f, Num a, Storable a)
=> (forall s. Taping s a => f (Dual s a) -> Dual s a)
-> f a -> a -> (a, f a)
gradient' f inp topctg = unsafePerformIO $ do
- -- hPutStrLn stderr "Preparing input"
+ when debug $ hPutStrLn stderr "Preparing input"
let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0
idref <- newIORef starti
vec1 <- VSM.new (max 128 (2 * starti))
taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil)
- -- hPutStrLn stderr "Running function"
+ when debug $ hPutStrLn stderr "Running function"
let Dual result outi = withDict @(Taping () a) taperef $ f @() inp'
- -- hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi
+ when debug $ hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi
MLog _ lastChunk tapeTail <- readIORef taperef
- -- do tapestr <- showTape (tapeTail `Snoc` lastChunk)
- -- hPutStrLn stderr $ "tape = " ++ tapestr ""
+ when debug $ do
+ tapestr <- showTape (tapeTail `Snoc` lastChunk)
+ hPutStrLn stderr $ "tape = " ++ tapestr ""
- -- hPutStrLn stderr "Backpropagating"
+ when debug $ hPutStrLn stderr "Backpropagating"
accums <- VSM.new (outi+1)
VSM.write accums outi topctg
let backpropagate i chunk@(Chunk ci0 vec) tape
| i >= ci0 = do
- -- hPutStrLn stderr $ "read at ci0=" ++ show ci0 ++ " i=" ++ show i
ctg <- VSM.read accums i
Contrib i1 dx i2 dy <- VSM.read vec (i - ci0)
when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1
@@ -72,10 +76,11 @@ gradient' f inp topctg = unsafePerformIO $ do
SLNil
Snoc{} -> backpropagate outi lastChunk tapeTail
- -- do accums' <- VS.freeze accums
- -- hPutStrLn stderr $ "accums = " ++ show accums'
+ when debug $ do
+ accums' <- VS.freeze accums
+ hPutStrLn stderr $ "accums = " ++ show accums'
- -- hPutStrLn stderr "Reconstructing gradient"
+ when debug $ hPutStrLn stderr "Reconstructing gradient"
let readDeriv = do i <- get
d <- lift $ VSM.read accums i
put (i+1)
@@ -210,7 +215,9 @@ writeTapeIO _ i1 dx i2 dy = do
-- chunk in the first place!
error "writeTape: impossible"
case action of
- WTANewvec vec' -> VSM.write vec' (idx - n) (Contrib i1 dx i2 dy)
+ WTANewvec vec' -> do
+ when debug $ hPutStrLn stderr $ "writeTapeIO: new vec of size " ++ show (VSM.length vec')
+ VSM.write vec' (idx - n) (Contrib i1 dx i2 dy)
WTAOldTape tape ->
let go SLNil = error "writeTape: no appropriate tape chunk?"
go (tape' `Snoc` Chunk ci0' vec')