{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
module Numeric.ADDual.Array.Internal where

import Control.Monad (when)
import Control.Monad.Trans.State.Strict
import Data.Foldable (toList)
import Data.IORef
import Data.List (foldl')
import qualified Data.IntMap.Strict as IM
import Data.Proxy
import qualified Data.Vector.Storable as VS
import Foreign.Storable
import GHC.Stack
import GHC.Exts (withDict)
import System.IO.Unsafe
import System.IO (hPutStrLn, stderr)

import Numeric.ADDual.VectorOps


-- TODO: type roles on 's'

-- | Whether to print progress messages to stderr from 'gradient''.
-- @toEnum 0@ is 'False'.
debug :: Bool
debug = toEnum 0

-- TODO: full vjp (just some more Traversable mess)
-- TODO: if non-scalar output types are allowed, ensure that all its scalar components are WHNF evaluated before we backpropagate
{-# NOINLINE gradient' #-}
-- | Reverse-mode gradient of a scalar-valued function of a container of
-- vectors.
--
-- The input vectors receive the IDs @0 .. n-1@ in traversal order; the tape
-- starts at ID @n@. The function is run once with a fresh tape (forcing its
-- result), after which the recorded tape is walked backwards by
-- 'backpropagate', seeded with the cotangent @topctg@ on the output ID.
--
-- Returns the primal result together with, for every input vector, its
-- accumulated cotangent (a zero vector of the input's length if nothing
-- reached that input).
gradient' :: forall a f. (Traversable f, Num a, Storable a)
          => HasCallStack
          => Show a  -- TODO: remove
          => (forall s. Taping s a => f (VDual s a) -> Dual s a)
          -> f (VS.Vector a) -> a -> (a, f (VS.Vector a))
gradient' f inp topctg = unsafePerformIO $ do
  when debug $ hPutStrLn stderr "Preparing input"
  let (inp', starti) = runState (traverse (\x -> state (\i -> (VDual x i, i + 1))) inp) 0
      inpSizes = VS.fromListN starti (map VS.length (toList inp))
  -- The tape starts after the input IDs.
  taperef <- newIORef (Log starti Start)

  when debug $ hPutStrLn stderr "Running function"
  -- The bang forces the result (and, through the strict fields of 'Dual', its
  -- ID), so the tape is written before it is read back below.
  let !(Dual result outi) = withDict @(Taping () a) taperef $ f @() inp'
  when debug $ hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi
  Log nexti tape <- readIORef taperef

  -- when debug $ do
  --   tapestr <- showTape (tapeTail `Snoc` lastChunk)
  --   hPutStrLn stderr $ "tape = " ++ tapestr ""

  when debug $ hPutStrLn stderr "Backpropagating"
  -- The head of the tape always carries ID @nexti - 1@ (see 'writeTapeUnsafe'
  -- and 'allocTapeToListUnsafe'). Start there rather than at @outi@: nodes may
  -- have been recorded after the output node, and the output may even be a
  -- constant (ID -1); starting at @outi@ would then pair tape nodes with the
  -- wrong IDs.
  let (_outaccS, outaccV) = backpropagate (IM.singleton outi topctg) IM.empty (nexti - 1) tape

  -- when debug $ do
  --   accums' <- VS.freeze accums
  --   hPutStrLn stderr $ "accums = " ++ show accums'

  when debug $ hPutStrLn stderr "Reconstructing gradient"
  let readDeriv = do
        i <- get
        let d = IM.findWithDefault (VS.replicate (inpSizes VS.! i) 0) i outaccV
        put (i + 1)
        return d
      grad = evalState (traverse (\_ -> readDeriv) inp) 0

  return (result, grad)

-- | Walk the tape from its head back to 'Start', propagating cotangents.
--
-- Arguments: the scalar cotangent accumulator, the vector cotangent
-- accumulator (both keyed by ID), the ID of the node at the head of the given
-- tape, and the tape itself. Returns the final accumulators.
--
-- Every node occupies exactly one ID, except 'VCtoList', which additionally
-- reserves @len@ IDs (for its output scalars) directly below its own.
backpropagate :: (Num a, Storable a)
              => IM.IntMap a -> IM.IntMap (VS.Vector a)
              -> Int -> Chain a
              -> (IM.IntMap a, IM.IntMap (VS.Vector a))
backpropagate accS accV i (Cscalar i1 dx i2 dy tape) = case IM.lookup i accS of
  Nothing -> backpropagate accS accV (i - 1) tape
  Just ctg ->
    -- An argument ID of -1 denotes a constant: no contribution.
    let accS1 | i1 /= -1 = IM.insertWith (+) i1 (ctg * dx) accS
              | otherwise = accS
        accS2 | i2 /= -1 = IM.insertWith (+) i2 (ctg * dy) accS1
              | otherwise = accS1
    in backpropagate accS2 accV (i - 1) tape

backpropagate accS accV i (VCfromList is tape) = case IM.lookup i accV of
  Nothing -> backpropagate accS accV (i - 1) tape
  Just ctg ->
    -- Scatter the vector cotangent over the scalars the vector was built from.
    let accS1 | VS.length ctg == VS.length is =
                  foldl' (\accS' idx -> IM.insertWith (+) (is VS.! idx) (ctg VS.! idx) accS')
                         accS [0 .. VS.length ctg - 1]
              | otherwise = error "Numeric.ADDual.Array: wrong cotangent length to vfromList"
    in backpropagate accS1 accV (i - 1) tape

backpropagate accS accV i (VCtoList j len tape) =
  -- The output scalars have IDs @i - len .. i - 1@; only do work if at least
  -- one of them received a cotangent.
  case IM.lookupGE (i - len) accS of
    Just (smallid, _) | smallid < i ->
      let ctg = VS.fromListN len [IM.findWithDefault 0 (i - len + idx) accS | idx <- [0 .. len - 1]]
          accV1 = IM.insertWith (VS.zipWith (+)) j ctg accV
      in backpropagate accS accV1 (i - 1 - len) tape
    _ -> backpropagate accS accV (i - 1 - len) tape

backpropagate accS accV i (VCsum j len tape) = case IM.lookup i accS of
  Nothing -> backpropagate accS accV (i - 1) tape
  Just ctg ->
    -- Each element of the summed vector receives the full scalar cotangent.
    let addCtg = \case
          Nothing -> Just (VS.replicate len ctg)
          Just d -> Just (VS.map (+ ctg) d)
        accV1 = IM.alter addCtg j accV
    -- A sum node occupies a single ID, so the previous node is at @i - 1@.
    in backpropagate accS accV1 (i - 1) tape

backpropagate accS accV i (VCreplicate j _len tape) = case IM.lookup i accV of
  Nothing -> backpropagate accS accV (i - 1) tape
  Just ctg ->
    -- Every output element equals the input scalar, so the scalar's cotangent
    -- is the sum of the element cotangents.
    let accS1 = IM.insertWith (+) j (VS.sum ctg) accS
    -- A replicate node occupies a single ID, so the previous node is at @i - 1@.
    in backpropagate accS1 accV (i - 1) tape

backpropagate accS accV _ Start = (accS, accV)

-- | The tape: a snoc-list of recorded operations, most recent first.
data Chain a = Cscalar {-# UNPACK #-} !Int !a  -- ^ ID == -1 -> no contribution
                       {-# UNPACK #-} !Int !a  -- ^ idem
                       !(Chain a)
             | VCfromList {-# UNPACK #-} !(VS.Vector Int)  -- ^ IDs of scalars in the input list
                          !(Chain a)
             | VCtoList {-# UNPACK #-} !Int  -- ^ ID of the input vector
                        {-# UNPACK #-} !Int  -- ^ number of reserved output IDs (length of the vector)
                        !(Chain a)
             | VCsum {-# UNPACK #-} !Int  -- ^ ID of the input vector
                     {-# UNPACK #-} !Int  -- ^ length of the vector
                     !(Chain a)
             | VCreplicate {-# UNPACK #-} !Int  -- ^ ID of the input scalar
                           {-# UNPACK #-} !Int  -- ^ length of the replicated vector
                           !(Chain a)
             | Start
  deriving (Show)

data Log s a = Log !Int        -- ^ next ID to generate
                   !(Chain a)  -- ^ tape

-- | This class does not have any instances defined, on purpose. You'll get one
-- magically when you differentiate.
class Taping s a where
  getTape :: IORef (Log s a)

-- | A scalar value paired with the ID of the tape node that produced it.
data Dual s a = Dual !a
                     {-# UNPACK #-} !Int  -- ^ -1 if this is a constant

-- | Compares primal values only; IDs are ignored.
instance Eq a => Eq (Dual s a) where
  Dual x _ == Dual y _ = x == y

-- | Orders by primal values only; IDs are ignored.
instance Ord a => Ord (Dual s a) where
  compare (Dual x _) (Dual y _) = compare x y

instance (Num a, Taping s a) => Num (Dual s a) where
  Dual x i1 + Dual y i2 = mkDual (x + y) i1 1 i2 1
  Dual x i1 - Dual y i2 = mkDual (x - y) i1 1 i2 (-1)
  Dual x i1 * Dual y i2 = mkDual (x * y) i1 y i2 x
  negate (Dual x i1) = mkDual (negate x) i1 (-1) (-1) 0
  -- d/dx |x| = signum x
  abs (Dual x i1) = mkDual (abs x) i1 (signum x) (-1) 0
  signum (Dual x _) = Dual (signum x) (-1)
  fromInteger n = Dual (fromInteger n) (-1)

instance (Fractional a, Taping s a) => Fractional (Dual s a) where
  Dual x i1 / Dual y i2 = mkDual (x / y) i1 (recip y) i2 (-x/(y*y))
  recip (Dual x i1) = mkDual (recip x) i1 (-1/(x*x)) (-1) 0
  fromRational r = Dual (fromRational r) (-1)

-- | Only 'pi', 'exp', 'log', 'sqrt' and '**' are implemented; the remaining
-- methods are 'undefined'.
instance (Floating a, Taping s a) => Floating (Dual s a) where
  pi = Dual pi (-1)
  exp (Dual x i1) = mkDual (exp x) i1 (exp x) (-1) 0
  log (Dual x i1) = mkDual (log x) i1 (recip x) (-1) 0
  sqrt (Dual x i1) = mkDual (sqrt x) i1 (recip (2*sqrt x)) (-1) 0
  -- d/dx (x ^ y) = d/dx (e ^ (y ln x)) = e ^ (y ln x) * d/dx (y ln x) = e ^ (y ln x) * y/x
  -- d/dy (x ^ y) = d/dy (e ^ (y ln x)) = e ^ (y ln x) * d/dy (y ln x) = e ^ (y ln x) * ln x
  Dual x i1 ** Dual y i2 =
    let z = x ** y
    in mkDual z i1 (z * y/x) i2 (z * log x)
  logBase = undefined ; sin = undefined ; cos = undefined ; tan = undefined
  asin = undefined ; acos = undefined ; atan = undefined ; sinh = undefined
  cosh = undefined ; tanh = undefined ; asinh = undefined ; acosh = undefined
  atanh = undefined

-- | Lift a plain value to a 'Dual' that does not participate in
-- differentiation (ID -1).
constant :: a -> Dual s a
constant x = Dual x (-1)

-- | Record a scalar operation on the tape and return its result as a 'Dual'.
--
-- Arguments: the primal result, then for each of the two arguments its ID and
-- the partial derivative of the result with respect to it. Use ID -1 for an
-- argument that should not contribute.
mkDual :: forall a s. Taping s a => a -> Int -> a -> Int -> a -> Dual s a
mkDual res i1 dx i2 dy = Dual res (writeTapeUnsafe (Proxy @s) (Cscalar i1 dx i2 dy))

-- | A vector value paired with the ID of the tape node that produced it.
data VDual s a = VDual !(VS.Vector a)
                       {-# UNPACK #-} !Int  -- ^ -1 if this is a constant vector

instance (Storable a, Taping s a) => VectorOps (VDual s a) where
  type VectorOpsScalar (VDual s a) = Dual s a
  vfromListN n l =
    let (xs, is) = unzip [(x, i) | Dual x i <- l]
    in mkVDual (VS.fromListN n xs) (VCfromList (VS.fromListN n is))
  vfromList l =
    let (xs, is) = unzip [(x, i) | Dual x i <- l]
    in mkVDual (VS.fromList xs) (VCfromList (VS.fromList is))
  vtoList (VDual v i) =
    -- Reserves one ID per element; the elements get consecutive IDs from
    -- @starti@ on.
    let starti = allocTapeToListUnsafe (Proxy @a) (Proxy @s) i (VS.length v)
    in zipWith Dual (VS.toList v) [starti..]
  vreplicate n (Dual x i) = mkVDual (VS.replicate n x) (VCreplicate i n)

instance (Storable a, Num a, Taping s a) => VectorOpsNum (VDual s a) where
  vsum (VDual v i) = Dual (VS.sum v) (writeTapeUnsafe @a (Proxy @s) (VCsum i (VS.length v)))

-- | Lift a plain vector to a 'VDual' that does not participate in
-- differentiation (ID -1).
vconstant :: VS.Vector a -> VDual s a
vconstant v = VDual v (-1)

-- | Record a vector-producing operation on the tape and return its result as
-- a 'VDual'. The second argument builds the new tape node given the tape so
-- far.
mkVDual :: forall a s. Taping s a => VS.Vector a -> (Chain a -> Chain a) -> VDual s a
mkVDual res f = VDual res (writeTapeUnsafe (Proxy @s) f)

{-# NOINLINE writeTapeUnsafe #-}
-- | Atomically push one node onto the tape and return the ID assigned to it.
-- Exactly one ID is consumed.
--
-- This performs IO behind a pure interface via 'unsafePerformIO'.
writeTapeUnsafe :: forall a s proxy. Taping s a => proxy s -> (Chain a -> Chain a) -> Int
writeTapeUnsafe _ f =
  unsafePerformIO $
    atomicModifyIORef' (getTape @s) $ \(Log i tape) ->
      (Log (i + 1) (f tape), i)

{-# NOINLINE allocTapeToListUnsafe #-}
-- | Atomically push a 'VCtoList' node for the vector with ID @vecid@ and
-- length @len@, reserving @len + 1@ IDs: @len@ for the output scalars followed
-- by one for the node itself. Returns the first of the reserved scalar IDs.
--
-- This performs IO behind a pure interface via 'unsafePerformIO'.
allocTapeToListUnsafe :: forall a s proxy. Taping s a => proxy a -> proxy s -> Int -> Int -> Int
allocTapeToListUnsafe _ _ vecid len =
  unsafePerformIO $
    atomicModifyIORef' (getTape @s @a) $ \(Log i tape) ->
      (Log (i + len + 1) (VCtoList vecid len tape), i)