{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
-- | Internals of tape-based reverse-mode automatic differentiation over
-- storable vectors.
--
-- Every differentiable operation on 'Dual' / 'VDual' values appends a node to
-- a tape (a 'Chain') stored in an 'IORef', and receives a fresh integer ID.
-- 'gradient'' runs a function on tagged inputs and then walks the tape
-- backwards ('backpropagate') to accumulate cotangents per ID.
module Numeric.ADDual.Array.Internal where

import Control.Monad (when)
import Control.Monad.Trans.State.Strict
import Data.Foldable (toList)
import Data.IORef
import Data.List (foldl')
import qualified Data.IntMap.Strict as IM
import Data.Proxy
import qualified Data.Vector.Storable as VS
import Foreign.Ptr (castPtr)
import Foreign.Storable
import GHC.Stack
import GHC.Exts (withDict)
import System.IO.Unsafe
import System.IO (hPutStrLn, stderr)

import Numeric.ADDual.VectorOps


-- TODO: type roles on 's'


-- | Switch for the progress tracing in 'gradient''; currently off.
debug :: Bool
debug = toEnum 0

-- | Compute the result of a function together with its gradient with respect
-- to every input vector.
--
-- The input vectors receive IDs @0, 1, ..@ in 'traverse' order; operations
-- recorded while running the function get IDs after those. The second
-- argument is the cotangent seeded at the output.
--
-- Returns the primal result and, per input vector, its accumulated cotangent.
-- Inputs that received no cotangent get an all-zero vector of their own
-- length.
--
-- TODO: full vjp (just some more Traversable mess)
-- TODO: if non-scalar output types are allowed, ensure that all its scalar
-- components are WHNF evaluated before we backpropagate
{-# NOINLINE gradient' #-}
gradient' :: forall a f. (Traversable f, Num a, Storable a)
          => HasCallStack
          => Show a  -- TODO: remove
          => (forall s. Taping s a => f (VDual s a) -> Dual s a)
          -> f (VS.Vector a) -> a -> (a, f (VS.Vector a))
gradient' f inp topctg = unsafePerformIO $ do
  when debug $ hPutStrLn stderr "Preparing input"
  let (inp', starti) = runState (traverse (\x -> state (\i -> (VDual x i, i + 1))) inp) 0
      inpSizes = VS.fromListN starti (map VS.length (toList inp))

  -- The tape starts after the input IDs.
  taperef <- newIORef (Log starti Start)

  when debug $ hPutStrLn stderr "Running function"
  let !(Dual result outi) = withDict @(Taping () a) taperef $ f @() inp'
  when debug $ hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi

  -- The node at the head of the tape has ID (nexti - 1). This is not
  -- necessarily 'outi': the function may have recorded further operations
  -- after computing its output, and 'outi' is -1 for a constant output.
  -- 'backpropagate' requires the ID of the head node, so start there.
  Log nexti tape <- readIORef taperef

  when debug $ hPutStrLn stderr "Backpropagating"
  let initAccS
        | outi == -1 = IM.empty  -- constant output: nothing to propagate
        | otherwise = IM.singleton outi topctg
      (_outaccS, outaccV) = backpropagate initAccS IM.empty (nexti - 1) tape

  when debug $ hPutStrLn stderr "Reconstructing gradient"
  -- Inputs are revisited in the same 'traverse' order in which they were
  -- numbered, so a counter recovers each input's ID.
  let readDeriv = do
        i <- get
        put (i + 1)
        return $ IM.findWithDefault (VS.replicate (inpSizes VS.! i) 0) i outaccV
      grad = evalState (traverse (\_ -> readDeriv) inp) 0

  return (result, grad)

-- | Walk the tape from its head back to 'Start', propagating cotangents.
--
-- Arguments: accumulated cotangents of scalar nodes (by ID), accumulated
-- cotangents of vector nodes (by ID), the ID of the node at the head of the
-- tape, and the tape. Returns the final scalar and vector cotangent maps.
--
-- Every node occupies exactly one ID, except 'VCtoList', which occupies
-- @len + 1@ IDs: its @len@ output scalars followed by the node itself.
-- An argument ID of -1 denotes a constant and receives no contribution.
backpropagate :: (Num a, Storable a)
              => IM.IntMap a -> IM.IntMap (VS.Vector a) -> Int -> Chain a
              -> (IM.IntMap a, IM.IntMap (VS.Vector a))
-- The accumulators are forced at every step so that no chains of pending
-- insertions build up over a long tape.
backpropagate !accS !accV i chain = case chain of
  Cscalar i1 dx i2 dy tape ->
    case IM.lookup i accS of
      Nothing -> backpropagate accS accV (i - 1) tape
      Just ctg ->
        let accS1 | i1 /= -1 = IM.insertWith (+) i1 (ctg * dx) accS
                  | otherwise = accS
            accS2 | i2 /= -1 = IM.insertWith (+) i2 (ctg * dy) accS1
                  | otherwise = accS1
        in backpropagate accS2 accV (i - 1) tape

  VCarith i1 dx i2 dy tape ->
    case IM.lookup i accV of
      Nothing -> backpropagate accS accV (i - 1) tape
      Just ctg ->
        -- Add ctg * d (elementwise) to the cotangent of argument j, unless j
        -- is a constant. Each argument's contribution is independent of the
        -- other's (the i2 == -1 case must keep i1's contribution).
        let contribute j d acc
              | j == -1 = acc
              | VS.length ctg == VS.length d = IM.insertWith vadd j (vmul ctg d) acc
              | otherwise = error "Numeric.ADDual.Array: wrong cotangent length to vectorised arithmetic operation"
            accV1 = contribute i2 dy (contribute i1 dx accV)
        in backpropagate accS accV1 (i - 1) tape

  VCfromList is tape ->
    case IM.lookup i accV of
      Nothing -> backpropagate accS accV (i - 1) tape
      Just ctg ->
        -- Element k of the cotangent goes to the scalar that supplied element k.
        let accS1
              | VS.length ctg == VS.length is =
                  foldl' (\accS' idx -> IM.insertWith (+) (is VS.! idx) (ctg VS.! idx) accS')
                         accS [0 .. VS.length ctg - 1]
              | otherwise = error "Numeric.ADDual.Array: wrong cotangent length to vfromList"
        in backpropagate accS1 accV (i - 1) tape

  VCtoList j len tape ->
    -- The output scalars have IDs (i - len) .. (i - 1); only build a cotangent
    -- vector if at least one of them received a cotangent.
    case IM.lookupGE (i - len) accS of
      Just (smallid, _) | smallid < i ->
        let ctg = VS.fromListN len [IM.findWithDefault 0 (i - len + idx) accS | idx <- [0 .. len - 1]]
            accV1 = IM.insertWith vadd j ctg accV
        in backpropagate accS accV1 (i - 1 - len) tape
      _ -> backpropagate accS accV (i - 1 - len) tape

  VCsum j len tape ->
    case IM.lookup i accS of
      Nothing -> backpropagate accS accV (i - 1) tape
      Just ctg ->
        -- d (sum v) / d v_k = 1 for every k. The node occupies a single ID.
        let addCtg = \case
              Nothing -> Just (VS.replicate len ctg)
              Just d -> Just (VS.map (+ ctg) d)
            accV1 = IM.alter addCtg j accV
        in backpropagate accS accV1 (i - 1) tape

  VCreplicate j _len tape ->
    case IM.lookup i accV of
      Nothing -> backpropagate accS accV (i - 1) tape
      Just ctg ->
        -- Each copy of the scalar contributes its own element of the
        -- cotangent, so the scalar's cotangent is their sum. The node
        -- occupies a single ID.
        let accS1 = IM.insertWith (+) j (VS.sum ctg) accS
        in backpropagate accS1 accV (i - 1) tape

  Start -> (accS, accV)

-- | The tape: a chain of recorded operations, most recent first.
data Chain a
  = -- | Scalar operation with (at most) two scalar arguments, each paired with
    -- the partial derivative of the result with respect to it.
    Cscalar
      {-# UNPACK #-} !Int !a  -- ^ ID == -1 -> no contribution
      {-# UNPACK #-} !Int !a  -- ^ idem
      !(Chain a)
  | -- | Elementwise vector operation with (at most) two vector arguments.
    VCarith
      {-# UNPACK #-} !Int {-# UNPACK #-} !(VS.Vector a)  -- ^ first argument with scale factors
      {-# UNPACK #-} !Int {-# UNPACK #-} !(VS.Vector a)  -- ^ second argument with scale factors
      !(Chain a)
  | -- | Vector built from a list of scalars.
    VCfromList
      {-# UNPACK #-} !(VS.Vector Int)  -- ^ IDs of scalars in the input list
      !(Chain a)
  | -- | Vector split into its scalar elements.
    VCtoList
      {-# UNPACK #-} !Int  -- ^ ID of the input vector
      {-# UNPACK #-} !Int  -- ^ number of reserved output IDs (length of the vector)
      !(Chain a)
  | -- | Sum of the elements of a vector.
    VCsum
      {-# UNPACK #-} !Int  -- ^ ID of the input vector
      {-# UNPACK #-} !Int  -- ^ length of the vector
      !(Chain a)
  | -- | Vector consisting of copies of one scalar.
    VCreplicate
      {-# UNPACK #-} !Int  -- ^ ID of the input scalar
      {-# UNPACK #-} !Int  -- ^ length of the replicated vector
      !(Chain a)
  | -- | Beginning of the tape.
    Start
  deriving (Show)

-- | Mutable tape state for the computation tagged @s@.
data Log s a = Log
  !Int  -- ^ next ID to generate
  !(Chain a)  -- ^ tape

-- | This class does not have any instances defined, on purpose. You'll get one
-- magically when you differentiate.
class Taping s a where
  getTape :: IORef (Log s a)

-- | A scalar together with the ID of the tape node that produced it.
data Dual s a = Dual
  !a
  {-# UNPACK #-} !Int  -- ^ -1 if this is a constant

-- | Compares primal values only.
instance Eq a => Eq (Dual s a) where
  Dual x _ == Dual y _ = x == y

-- | Compares primal values only.
instance Ord a => Ord (Dual s a) where
  compare (Dual x _) (Dual y _) = compare x y

instance (Num a, Taping s a) => Num (Dual s a) where
  Dual x i1 + Dual y i2 = mkDual (x + y) i1 1 i2 1
  Dual x i1 - Dual y i2 = mkDual (x - y) i1 1 i2 (-1)
  Dual x i1 * Dual y i2 = mkDual (x * y) i1 y i2 x
  negate (Dual x i1) = mkDual (negate x) i1 (-1) (-1) 0
  -- d/dx |x| = signum x (taken as 0 at x = 0).
  abs (Dual x i1) = mkDual (abs x) i1 (signum x) (-1) 0
  signum (Dual x _) = Dual (signum x) (-1)
  fromInteger n = Dual (fromInteger n) (-1)

instance (Fractional a, Taping s a) => Fractional (Dual s a) where
  Dual x i1 / Dual y i2 = mkDual (x / y) i1 (recip y) i2 (-x/(y*y))
  recip (Dual x i1) = mkDual (recip x) i1 (-1/(x*x)) (-1) 0
  fromRational r = Dual (fromRational r) (-1)

instance (Floating a, Taping s a) => Floating (Dual s a) where
  pi = Dual pi (-1)
  exp (Dual x i1) = mkDual (exp x) i1 (exp x) (-1) 0
  log (Dual x i1) = mkDual (log x) i1 (recip x) (-1) 0
  sqrt (Dual x i1) = mkDual (sqrt x) i1 (recip (2*sqrt x)) (-1) 0
  -- d/dx (x ^ y) = d/dx (e ^ (y ln x)) = e ^ (y ln x) * d/dx (y ln x) = e ^ (y ln x) * y/x
  -- d/dy (x ^ y) = d/dy (e ^ (y ln x)) = e ^ (y ln x) * d/dy (y ln x) = e ^ (y ln x) * ln x
  Dual x i1 ** Dual y i2 =
    let z = x ** y
    in mkDual z i1 (z * y/x) i2 (z * log x)
  -- Built from the differentiable 'log' and '/' above.
  logBase b x = log x / log b
  sin (Dual x i1) = mkDual (sin x) i1 (cos x) (-1) 0
  cos (Dual x i1) = mkDual (cos x) i1 (negate (sin x)) (-1) 0
  -- d/dx tan x = 1 + tan^2 x
  tan (Dual x i1) = let t = tan x in mkDual t i1 (1 + t * t) (-1) 0
  asin (Dual x i1) = mkDual (asin x) i1 (recip (sqrt (1 - x * x))) (-1) 0
  acos (Dual x i1) = mkDual (acos x) i1 (negate (recip (sqrt (1 - x * x)))) (-1) 0
  atan (Dual x i1) = mkDual (atan x) i1 (recip (1 + x * x)) (-1) 0
  sinh (Dual x i1) = mkDual (sinh x) i1 (cosh x) (-1) 0
  cosh (Dual x i1) = mkDual (cosh x) i1 (sinh x) (-1) 0
  -- d/dx tanh x = 1 - tanh^2 x
  tanh (Dual x i1) = let t = tanh x in mkDual t i1 (1 - t * t) (-1) 0
  asinh (Dual x i1) = mkDual (asinh x) i1 (recip (sqrt (x * x + 1))) (-1) 0
  acosh (Dual x i1) = mkDual (acosh x) i1 (recip (sqrt (x * x - 1))) (-1) 0
  atanh (Dual x i1) = mkDual (atanh x) i1 (recip (1 - x * x)) (-1) 0

-- | @dualAlignUp n al@ rounds @n@ up to the nearest multiple of @al@
-- (requires @al > 0@).
dualAlignUp :: Int -> Int -> Int
dualAlignUp n al = ((n + al - 1) `div` al) * al

-- | Byte offset of the ID field in the serialised form of a @'Dual' s a@: the
-- size of @a@, rounded up so that the 'Int' ID is properly aligned.
dualIdOffset :: forall a. Storable a => Proxy a -> Int
dualIdOffset _ = dualAlignUp (sizeOf (undefined :: a)) (alignment (undefined :: Int))

-- | This instance allows breaking the abstraction of 'Dual'. Don't inspect or
-- modify the serialised representation, and DO NOT use serialised 'Dual'
-- values from one 'gradient'' computation in another!
--
-- Layout: the primal value at offset 0, then the 'Int' ID at 'dualIdOffset'.
-- The alignment is the larger of the two fields' alignments. The size is
-- padded to a multiple of it, so every element of an array stays aligned.
instance Storable a => Storable (Dual s a) where
  sizeOf _ =
    dualAlignUp (dualIdOffset (Proxy @a) + sizeOf (undefined :: Int))
                (alignment (undefined :: Dual s a))
  alignment _ = max (alignment (undefined :: a)) (alignment (undefined :: Int))
  peek ptr = Dual <$> peek (castPtr ptr) <*> peekByteOff ptr (dualIdOffset (Proxy @a))
  poke ptr (Dual x i) = poke (castPtr ptr) x >> pokeByteOff ptr (dualIdOffset (Proxy @a)) i

-- | A scalar that does not depend on the inputs.
constant :: a -> Dual s a
constant x = Dual x (-1)

-- | Record a scalar operation with two arguments (ID, partial derivative) on
-- the tape and return its result tagged with the new node's ID.
mkDual :: forall a s. Taping s a => a -> Int -> a -> Int -> a -> Dual s a
mkDual res i1 dx i2 dy = Dual res (writeTapeUnsafe (Proxy @s) (Cscalar i1 dx i2 dy))

-- | A vector together with the ID of the tape node that produced it.
data VDual s a = VDual
  !(VS.Vector a)
  {-# UNPACK #-} !Int  -- ^ -1 if this is a constant vector

instance (Storable a, Num a, Taping s a) => VectorOps (VDual s a) where
  type VectorOpsScalar (VDual s a) = Dual s a

  vfromListN n l =
    let (xs, is) = unzip [(x, i) | Dual x i <- l]
    in mkVDual (VS.fromListN n xs) (VCfromList (VS.fromListN n is))

  vfromList l =
    let (xs, is) = unzip [(x, i) | Dual x i <- l]
    in mkVDual (VS.fromList xs) (VCfromList (VS.fromList is))

  -- Reserves one ID per element plus one for the 'VCtoList' node itself.
  vtoList (VDual v i) =
    let starti = allocTapeToListUnsafe (Proxy @a) (Proxy @s) i (VS.length v)
    in zipWith Dual (VS.toList v) [starti..]

  vlength (VDual v _) = VS.length v

  vreplicate n (Dual x i) = mkVDual (VS.replicate n x) (VCreplicate i n)

  -- Each element's derivative w.r.t. the first vector is 1 where the mask
  -- selects it and 0 otherwise; w.r.t. the second vector, the opposite.
  vselect bs (VDual a i) (VDual b j) =
    mkVDual (vselect bs a b)
            (VCarith i (VS.map (fromIntegral . fromEnum) bs)
                     j (VS.map (fromIntegral . fromEnum . not) bs))

instance (Storable a, Num a, Taping s a) => VectorOpsNum (VDual s a) where
  vadd (VDual v i) (VDual w j) =
    let len = VS.length v
    in mkVDual (vadd v w) (VCarith i (VS.replicate len 1) j (VS.replicate len 1))
  vsub (VDual v i) (VDual w j) =
    let len = VS.length v
    in mkVDual (vsub v w) (VCarith i (VS.replicate len 1) j (VS.replicate len (-1)))
  vmul (VDual v i) (VDual w j) = mkVDual (vmul v w) (VCarith i w j v)
  vsum (VDual v i) = Dual (VS.sum v) (writeTapeUnsafe @a (Proxy @s) (VCsum i (VS.length v)))

instance (Storable a, Floating a, Taping s a) => VectorOpsFloating (VDual s a) where
  -- d/dx exp x = exp x, elementwise.
  vexp (VDual v i) =
    let ev = vexp v
    in mkVDual ev (VCarith i ev (-1) VS.empty)

instance (Storable a, Num a, Ord a, Taping s a) => VectorOpsOrd (VDual s a) where
  vcmpLE (VDual v _) (VDual w _) = vcmpLE v w

  -- The result is a scalar, so it is recorded as two nodes: an elementwise
  -- 'VCarith' applying the mask (1 at every position equal to the maximum,
  -- 0 elsewhere) followed by a 'VCsum'. Backpropagation thus gives v the
  -- cotangent "ctg * mask". The intermediate is matched with 'case' so that
  -- its ID is already evaluated when the 'VCsum' node is written, as for
  -- every other tape write in this module.
  vmaximum (VDual v i) =
    let w = vmaximum v
        mask = VS.map (\x -> if x == w then 1 else 0) v
    in case (mkVDual (VS.zipWith (*) mask v) (VCarith i mask (-1) VS.empty) :: VDual s a) of
         VDual _ maskedId ->
           Dual w (writeTapeUnsafe @a (Proxy @s) (VCsum maskedId (VS.length v)))

-- | A vector that does not depend on the inputs.
vconstant :: VS.Vector a -> VDual s a
vconstant v = VDual v (-1)

-- | Record a vector operation on the tape and return its result tagged with
-- the new node's ID.
mkVDual :: forall a s. Taping s a => VS.Vector a -> (Chain a -> Chain a) -> VDual s a
mkVDual res f = VDual res (writeTapeUnsafe (Proxy @s) f)

-- | Append one node to the tape of @s@ and return the ID assigned to it.
--
-- The function argument is a node constructor applied to everything except the
-- rest of the tape; the new node is forced inside 'atomicModifyIORef''.
-- NOINLINE because the tape write happens via 'unsafePerformIO'.
{-# NOINLINE writeTapeUnsafe #-}
writeTapeUnsafe :: forall a s proxy. Taping s a => proxy s -> (Chain a -> Chain a) -> Int
writeTapeUnsafe _ f =
  unsafePerformIO $
    atomicModifyIORef' (getTape @s) $ \(Log i tape) ->
      (Log (i + 1) (f tape), i)

-- | Append a 'VCtoList' node for vector @vecid@ of length @len@, reserving
-- @len + 1@ IDs. Returns the first reserved ID: the elements get IDs
-- @start .. start + len - 1@ and the node itself gets @start + len@.
{-# NOINLINE allocTapeToListUnsafe #-}
allocTapeToListUnsafe :: forall a s proxy. Taping s a => proxy a -> proxy s -> Int -> Int -> Int
allocTapeToListUnsafe _ _ vecid len =
  unsafePerformIO $
    atomicModifyIORef' (getTape @s @a) $ \(Log i tape) ->
      (Log (i + len + 1) (VCtoList vecid len tape), i)