diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Numeric/ADDual/Internal.hs | 13 | 
1 files changed, 8 insertions, 5 deletions
diff --git a/src/Numeric/ADDual/Internal.hs b/src/Numeric/ADDual/Internal.hs index 8228694..d8e6e0a 100644 --- a/src/Numeric/ADDual/Internal.hs +++ b/src/Numeric/ADDual/Internal.hs @@ -15,18 +15,18 @@ import qualified Data.Vector.Storable as VS  import qualified Data.Vector.Storable.Mutable as VSM  import Foreign.Ptr  import Foreign.Storable +import GHC.Stack  import GHC.Exts (withDict)  import System.IO.Unsafe --- import System.IO (hPutStrLn, stderr) - --- import Numeric.ADDual.IDGen +import System.IO (hPutStrLn, stderr)  -- TODO: full vjp (just some more Traversable mess)  {-# NOINLINE gradient' #-}  gradient' :: forall a f. (Traversable f, Num a, Storable a) -          -- => Show a  -- TODO: remove +          => HasCallStack +          => Show a  -- TODO: remove            => (forall s. Taping s a => f (Dual s a) -> Dual s a)            -> f a -> a -> (a, f a)  gradient' f inp topctg = unsafePerformIO $ do @@ -50,6 +50,7 @@ gradient' f inp topctg = unsafePerformIO $ do    let backpropagate i chunk@(Chunk ci0 vec) tape          | i >= ci0 = do +            -- hPutStrLn stderr $ "read at ci0=" ++ show ci0 ++ " i=" ++ show i              ctg <- VSM.read accums i              Contrib i1 dx i2 dy <- VSM.read vec (i - ci0)              when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1 @@ -174,7 +175,9 @@ data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a))  writeTape :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int  writeTape _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy -writeTapeIO :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> IO Int +writeTapeIO :: forall a s proxy. (Num a, Storable a, Taping s a) +            => HasCallStack +            => proxy s -> Int -> a -> Int -> a -> IO Int  writeTapeIO _ i1 dx i2 dy = do    MLog idref (Chunk ci0 vec) _ <- readIORef (getTape @s)    let n = VSM.length vec  | 
