{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
module Numeric.ADDual.Array.Internal where

import Control.Monad (when)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict
import Data.IORef
import Data.Proxy
import Data.Typeable
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import Foreign.Storable
import GHC.Stack
import GHC.Exts (withDict)
import System.IO.Unsafe
import System.IO (hPutStrLn, stderr)

import Numeric.ADDual.VectorOps


-- TODO: type roles on 's'

-- | Switch for the diagnostic output of 'gradient'' on stderr
-- (@toEnum 0@ is 'False').
debug :: Bool
debug = toEnum 0

-- TODO: full vjp (just some more Traversable mess)
-- TODO: if non-scalar output types are allowed, ensure that all its scalar components are WHNF evaluated before we backpropagate

-- | Run @f@ on the input while recording its scalar operations on a tape,
-- then backpropagate the output cotangent @topctg@ through that tape.
--
-- Returns the primal result together with the gradient of the result with
-- respect to every input scalar, scaled by @topctg@, in the same
-- 'Traversable' shape as the input. If the result does not depend on the
-- input (it is a constant), the gradient is all zeros.
{-# NOINLINE gradient' #-}
gradient' :: forall a f. (Traversable f, Num a, Storable a, Typeable a)
          => HasCallStack
          => Show a  -- TODO: remove
          => (forall s. Taping s a => f (Dual s a) -> Dual s a)
          -> f a -> a -> (a, f a)
gradient' f inp topctg = unsafePerformIO $ do
  when debug $ hPutStrLn stderr "Preparing input"
  -- Input scalars get IDs 0 .. starti-1, in traversal order.
  let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0

  -- The tape starts after the input IDs.
  taperef <- newIORef (Log starti Start)

  when debug $ hPutStrLn stderr "Running function"
  let !(Dual result outi) = withDict @(Taping () a) taperef $ f @() inp'
  when debug $ hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi

  -- 'nextid' is one past the highest ID handed out so far: every input and
  -- every taped value has an ID below it, and the head of the tape has ID
  -- nextid-1. Note that outi need not equal nextid-1: the result may be a
  -- constant (-1), one of the inputs, or values may have been taped after it.
  Log nextid tape <- readIORef taperef
  -- when debug $ do
  --   tapestr <- showTape (tapeTail `Snoc` lastChunk)
  --   hPutStrLn stderr $ "tape = " ++ tapestr ""

  when debug $ hPutStrLn stderr "Backpropagating"
  -- One cotangent slot per ID, explicitly zero-initialised so that IDs that
  -- do not influence the result end up with a zero cotangent.
  accums <- VSM.replicate nextid 0
  -- A constant result has no ID, hence nothing to seed: the gradient is zero.
  when (outi /= -1) $ VSM.write accums outi topctg

  let backpropagate :: Int -> Chain a -> IO ()
      backpropagate i (Cscalar i1 dx i2 dy tape') = do
        ctg <- VSM.read accums i
        when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1
        when (i2 /= -1) $ VSM.modify accums (+ ctg*dy) i2
        backpropagate (i-1) tape'
      backpropagate _ Start = return ()
      backpropagate _ _ =
        error "gradient': backpropagation through vector operations is not implemented yet"
  -- Walk the tape from its head (ID nextid-1) down; each scalar entry
  -- occupies exactly one ID.
  backpropagate (nextid - 1) tape

  when debug $ do
    accums' <- VS.freeze accums
    hPutStrLn stderr $ "accums = " ++ show accums'

  when debug $ hPutStrLn stderr "Reconstructing gradient"
  -- Re-traversing the input with a counter from 0 visits the input IDs in
  -- the same order in which they were assigned above.
  let readDeriv = do
        i <- get
        d <- lift $ VSM.read accums i
        put (i+1)
        return d
  grad <- evalStateT (traverse (\_ -> readDeriv) inp) 0

  return (result, grad)


-- | Contribution to a vector-typed value
data VCon a = VCon {-# UNPACK #-} !Int  -- ^ the ID of the vector value
                   {-# UNPACK #-} !(VS.Vector a)  -- ^ the cotangent
            | VConNothing
  deriving (Show)

-- | The tape: a linked list of recorded operations, most recent first.
-- A 'Cscalar' entry records up to two parent IDs, each with the partial
-- derivative of the taped value with respect to that parent.
data Chain a = Cscalar {-# UNPACK #-} !Int !a  -- ^ ID == -1 -> no contribution
                       {-# UNPACK #-} !Int !a  -- ^ idem
                       !(Chain a)
             | VCfromList {-# UNPACK #-} !(VS.Vector Int)  -- ^ IDs of scalars in the input list
                          !(Chain a)
             | VCtoList {-# UNPACK #-} !Int  -- ^ ID of the input vector
                        {-# UNPACK #-} !Int  -- ^ start of the reserved output ID range
                        {-# UNPACK #-} !Int  -- ^ number of reserved output IDs (length of the vector)
                        !(Chain a)
             | VCsum {-# UNPACK #-} !Int  -- ^ ID of the input vector
                     !(Chain a)
             | VCreplicate {-# UNPACK #-} !Int  -- ^ length of the replicated vector
                           {-# UNPACK #-} !Int  -- ^ ID of the input scalar
                           !(Chain a)
             | Start
  deriving (Show)

-- | Mutable state of one differentiation session @s@.
data Log s a = Log !Int  -- ^ next ID to generate
                   !(Chain a)  -- ^ tape

-- | This class does not have any instances defined, on purpose. You'll get one
-- magically when you differentiate.
class Taping s a where
  getTape :: IORef (Log s a)

-- | A scalar tracked in differentiation session @s@: its primal value and its
-- ID on the tape.
data Dual s a = Dual !a
                     {-# UNPACK #-} !Int  -- ^ -1 if this is a constant

-- | Compares primal values only; the ID is ignored.
instance Eq a => Eq (Dual s a) where
  Dual x _ == Dual y _ = x == y

-- | Compares primal values only; the ID is ignored.
instance Ord a => Ord (Dual s a) where
  compare (Dual x _) (Dual y _) = compare x y

instance (Num a, Taping s a) => Num (Dual s a) where
  Dual x i1 + Dual y i2 = mkDual (x + y) i1 1 i2 1
  Dual x i1 - Dual y i2 = mkDual (x - y) i1 1 i2 (-1)
  Dual x i1 * Dual y i2 = mkDual (x * y) i1 y i2 x
  negate (Dual x i1) = mkDual (negate x) i1 (-1) (-1) 0
  -- d/dx |x| = signum x
  abs (Dual x i1) = mkDual (abs x) i1 (signum x) (-1) 0
  -- signum is piecewise constant: zero derivative, so not taped.
  signum (Dual x _) = Dual (signum x) (-1)
  fromInteger n = Dual (fromInteger n) (-1)

instance (Fractional a, Taping s a) => Fractional (Dual s a) where
  Dual x i1 / Dual y i2 = mkDual (x / y) i1 (recip y) i2 (-x/(y*y))
  recip (Dual x i1) = mkDual (recip x) i1 (-1/(x*x)) (-1) 0
  fromRational r = Dual (fromRational r) (-1)

instance (Floating a, Taping s a) => Floating (Dual s a) where
  pi = Dual pi (-1)
  exp (Dual x i1) = mkDual (exp x) i1 (exp x) (-1) 0
  log (Dual x i1) = mkDual (log x) i1 (recip x) (-1) 0
  sqrt (Dual x i1) = mkDual (sqrt x) i1 (recip (2*sqrt x)) (-1) 0
  -- d/dx (x ^ y) = d/dx (e ^ (y ln x)) = e ^ (y ln x) * d/dx (y ln x) = e ^ (y ln x) * y/x
  -- d/dy (x ^ y) = d/dy (e ^ (y ln x)) = e ^ (y ln x) * d/dy (y ln x) = e ^ (y ln x) * ln x
  Dual x i1 ** Dual y i2 =
    let z = x ** y
    in mkDual z i1 (z * y/x) i2 (z * log x)
  -- Built from 'log' and '/', which are already differentiated above.
  logBase b y = log y / log b
  sin (Dual x i1) = mkDual (sin x) i1 (cos x) (-1) 0
  cos (Dual x i1) = mkDual (cos x) i1 (negate (sin x)) (-1) 0
  -- d/dx tan x = 1 + tan^2 x
  tan (Dual x i1) = let t = tan x in mkDual t i1 (1 + t*t) (-1) 0
  asin (Dual x i1) = mkDual (asin x) i1 (recip (sqrt (1 - x*x))) (-1) 0
  acos (Dual x i1) = mkDual (acos x) i1 (negate (recip (sqrt (1 - x*x)))) (-1) 0
  atan (Dual x i1) = mkDual (atan x) i1 (recip (1 + x*x)) (-1) 0
  sinh (Dual x i1) = mkDual (sinh x) i1 (cosh x) (-1) 0
  cosh (Dual x i1) = mkDual (cosh x) i1 (sinh x) (-1) 0
  -- d/dx tanh x = 1 - tanh^2 x
  tanh (Dual x i1) = let t = tanh x in mkDual t i1 (1 - t*t) (-1) 0
  asinh (Dual x i1) = mkDual (asinh x) i1 (recip (sqrt (x*x + 1))) (-1) 0
  acosh (Dual x i1) = mkDual (acosh x) i1 (recip (sqrt (x*x - 1))) (-1) 0
  atanh (Dual x i1) = mkDual (atanh x) i1 (recip (1 - x*x)) (-1) 0

-- | Lift a value to a 'Dual' that does not depend on the input (ID -1).
constant :: a -> Dual s a
constant x = Dual x (-1)

-- | Build the result of a (at most binary) scalar operation: primal value
-- @res@, parents @i1@ and @i2@, and the partial derivatives @dx@ and @dy@ of
-- @res@ with respect to them. A parent ID of -1 contributes nothing. If
-- neither parent is on the tape, the result is a constant and nothing is
-- recorded.
mkDual :: forall a s. Taping s a => a -> Int -> a -> Int -> a -> Dual s a
mkDual res i1 dx i2 dy
  | i1 == -1, i2 == -1 = Dual res (-1)
  | otherwise = Dual res (writeTapeUnsafe (Proxy @s) (Cscalar i1 dx i2 dy))

-- | A vector tracked in differentiation session @s@: its primal value and its
-- ID on the tape.
data VDual s a = VDual !(VS.Vector a)
                       {-# UNPACK #-} !Int  -- ^ -1 if this is a constant vector

instance (Storable a, Taping s a) => VectorOps (VDual s a) where
  type VectorOpsScalar (VDual s a) = Dual s a
  vfromListN n l =
    let (xs, is) = unzip [(x, i) | Dual x i <- l]
    in mkVDual (VS.fromListN n xs) (VCfromList (VS.fromListN n is))
  vfromList l =
    let (xs, is) = unzip [(x, i) | Dual x i <- l]
    in mkVDual (VS.fromList xs) (VCfromList (VS.fromList is))
  -- Was a typed hole, which kept the module from compiling.
  vtoList _ = error "vtoList: not yet implemented for VDual"
  vreplicate n (Dual x i) = mkVDual (VS.replicate n x) (VCreplicate n i)

instance (Storable a, Num a, Taping s a) => VectorOpsNum (VDual s a) where
  vsum (VDual v i) = Dual (VS.sum v) (writeTapeUnsafe @a (Proxy @s) (VCsum i))

-- | Lift a vector to a 'VDual' that does not depend on the input (ID -1).
vconstant :: VS.Vector a -> VDual s a
vconstant v = VDual v (-1)

-- | Record the vector operation @f@ on the tape and pair its ID with the
-- primal result @res@.
mkVDual :: forall a s. Taping s a => VS.Vector a -> (Chain a -> Chain a) -> VDual s a
mkVDual res f = VDual res (writeTapeUnsafe (Proxy @s) f)

-- | Allocate the next ID of session @s@ and push the entry built by @f@ onto
-- the tape, in a single 'atomicModifyIORef'' step. Returns the allocated ID.
-- NOINLINE, as is customary for 'unsafePerformIO'-based helpers.
{-# NOINLINE writeTapeUnsafe #-}
writeTapeUnsafe :: forall a s proxy. Taping s a => proxy s -> (Chain a -> Chain a) -> Int
writeTapeUnsafe _ f =
  unsafePerformIO $ atomicModifyIORef' (getTape @s) $ \(Log i tape) ->
    (Log (i + 1) (f tape), i)