diff options
Diffstat (limited to 'src/Numeric/ADDual')
| -rw-r--r-- | src/Numeric/ADDual/Internal.hs | 29 | 
1 files changed, 18 insertions, 11 deletions
| diff --git a/src/Numeric/ADDual/Internal.hs b/src/Numeric/ADDual/Internal.hs index d8e6e0a..55c47c2 100644 --- a/src/Numeric/ADDual/Internal.hs +++ b/src/Numeric/ADDual/Internal.hs @@ -22,6 +22,10 @@ import System.IO.Unsafe  import System.IO (hPutStrLn, stderr) +debug :: Bool +debug = toEnum 0 + +  -- TODO: full vjp (just some more Traversable mess)  {-# NOINLINE gradient' #-}  gradient' :: forall a f. (Traversable f, Num a, Storable a) @@ -30,27 +34,27 @@ gradient' :: forall a f. (Traversable f, Num a, Storable a)            => (forall s. Taping s a => f (Dual s a) -> Dual s a)            -> f a -> a -> (a, f a)  gradient' f inp topctg = unsafePerformIO $ do -  -- hPutStrLn stderr "Preparing input" +  when debug $ hPutStrLn stderr "Preparing input"    let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0    idref <- newIORef starti    vec1 <- VSM.new (max 128 (2 * starti))    taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil) -  -- hPutStrLn stderr "Running function" +  when debug $ hPutStrLn stderr "Running function"    let Dual result outi = withDict @(Taping () a) taperef $ f @() inp' -  -- hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi +  when debug $ hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi    MLog _ lastChunk tapeTail <- readIORef taperef -  -- do tapestr <- showTape (tapeTail `Snoc` lastChunk) -  --    hPutStrLn stderr $ "tape = " ++ tapestr "" +  when debug $ do +    tapestr <- showTape (tapeTail `Snoc` lastChunk) +    hPutStrLn stderr $ "tape = " ++ tapestr "" -  -- hPutStrLn stderr "Backpropagating" +  when debug $ hPutStrLn stderr "Backpropagating"    accums <- VSM.new (outi+1)    VSM.write accums outi topctg    let backpropagate i chunk@(Chunk ci0 vec) tape          | i >= ci0 = do -            -- hPutStrLn stderr $ "read at ci0=" ++ show ci0 ++ " i=" ++ show i              ctg <- VSM.read accums i              Contrib i1 dx i2 dy <- VSM.read vec (i - ci0)              when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1 @@ -72,10 +76,11 @@ gradient' f inp topctg = unsafePerformIO $ do                                  SLNil      Snoc{} -> backpropagate outi lastChunk tapeTail -  -- do accums' <- VS.freeze accums -  --    hPutStrLn stderr $ "accums = " ++ show accums' +  when debug $ do +    accums' <- VS.freeze accums +    hPutStrLn stderr $ "accums = " ++ show accums' -  -- hPutStrLn stderr "Reconstructing gradient" +  when debug $ hPutStrLn stderr "Reconstructing gradient"    let readDeriv = do i <- get                       d <- lift $ VSM.read accums i                       put (i+1) @@ -210,7 +215,9 @@ writeTapeIO _ i1 dx i2 dy = do                   -- chunk in the first place!                   error "writeTape: impossible"           case action of -           WTANewvec vec' -> VSM.write vec' (idx - n) (Contrib i1 dx i2 dy) +           WTANewvec vec' -> do +             when debug $ hPutStrLn stderr $ "writeTapeIO: new vec of size " ++ show (VSM.length vec') +             VSM.write vec' (idx - n) (Contrib i1 dx i2 dy)             WTAOldTape tape ->               let go SLNil = error "writeTape: no appropriate tape chunk?"                   go (tape' `Snoc` Chunk ci0' vec') | 
