diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Numeric/ADDual/Array/Internal.hs | 247 | ||||
| -rw-r--r-- | src/Numeric/ADDual/VectorOps.hs | 63 | 
2 files changed, 310 insertions, 0 deletions
diff --git a/src/Numeric/ADDual/Array/Internal.hs b/src/Numeric/ADDual/Array/Internal.hs
new file mode 100644
index 0000000..1cc2796
--- /dev/null
+++ b/src/Numeric/ADDual/Array/Internal.hs
@@ -0,0 +1,247 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+module Numeric.ADDual.Array.Internal where
+
+import Control.Monad (when)
+import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.State.Strict
+import Data.IORef
+import Data.Proxy
+import Data.Typeable
+import qualified Data.Vector.Storable as VS
+import qualified Data.Vector.Storable.Mutable as VSM
+import Foreign.Storable
+import GHC.Stack
+import GHC.Exts (withDict)
+import System.IO.Unsafe
+
+import System.IO (hPutStrLn, stderr)
+
+import Numeric.ADDual.VectorOps
+
+
+-- TODO: type roles on 's'
+
+
+-- | Switch for the trace messages that 'gradient'' writes to stderr.
+-- @toEnum 0@ is 'False', so tracing is off.
+debug :: Bool
+debug = toEnum 0
+
+
+-- TODO: full vjp (just some more Traversable mess)
+-- TODO: if non-scalar output types are allowed, ensure that all its scalar components are WHNF evaluated before we backpropagate
+{-# NOINLINE gradient' #-}
+-- | Reverse-mode derivative of a scalar-valued function.
+--
+-- @gradient' f inp topctg@ runs @f@ on @inp@ while every arithmetic operation
+-- on 'Dual' values is recorded on a tape, then walks that tape backwards,
+-- starting with cotangent @topctg@ for the result. It returns the result of
+-- @f@ together with the accumulated cotangent of every element of @inp@.
+--
+-- ID scheme: the inputs get @0 .. starti-1@ in traversal order and every
+-- recorded tape node gets the next free ID, so the head of the tape has ID
+-- @nexti-1@. Constants carry the ID @-1@.
+--
+-- Only 'Cscalar' nodes are backpropagated; a tape that contains one of the
+-- vector nodes makes this function call 'error'.
+gradient' :: forall a f. (Traversable f, Num a, Storable a, Typeable a)
+          => HasCallStack
+          => Show a  -- TODO: remove
+          => (forall s. Taping s a => f (Dual s a) -> Dual s a)
+          -> f a -> a -> (a, f a)
+gradient' f inp topctg = unsafePerformIO $ do
+  when debug $ hPutStrLn stderr "Preparing input"
+  let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0
+  -- The tape starts after the input IDs.
+  taperef <- newIORef (Log starti Start)
+
+  when debug $ hPutStrLn stderr "Running function"
+  let !(Dual result outi) = withDict @(Taping () a) taperef $ f @() inp'
+  when debug $ hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi
+  Log nexti tape <- readIORef taperef
+
+  -- when debug $ do
+  --   tapestr <- showTape (tapeTail `Snoc` lastChunk)
+  --   hPutStrLn stderr $ "tape = " ++ tapestr ""
+
+  when debug $ hPutStrLn stderr "Backpropagating"
+  -- One accumulator per ID handed out so far (inputs and tape nodes), each
+  -- starting at an explicit zero. Sizing by nexti rather than outi+1 keeps
+  -- every input slot addressable when the result is a constant, is itself an
+  -- input, or is not the most recently recorded node.
+  accums <- VSM.replicate nexti 0
+  -- A constant result (ID -1) has no accumulator to seed.
+  when (outi >= 0) $ VSM.write accums outi topctg
+
+  let backpropagate i (Cscalar i1 dx i2 dy tape') = do
+        ctg <- VSM.read accums i
+        when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1
+        when (i2 /= -1) $ VSM.modify accums (+ ctg*dy) i2
+        backpropagate (i-1) tape'
+      backpropagate _ Start = return ()
+      backpropagate _ _ =
+        error "gradient': backpropagation through vector tape nodes is not implemented"
+
+  -- The head of the tape is the most recently recorded node, i.e. ID nexti-1.
+  -- Nodes recorded after the result just carry a zero cotangent.
+  backpropagate (nexti - 1) tape
+
+  when debug $ do
+    accums' <- VS.freeze accums
+    hPutStrLn stderr $ "accums = " ++ show accums'
+
+  when debug $ hPutStrLn stderr "Reconstructing gradient"
+  let readDeriv = do i <- get
+                     d <- lift $ VSM.read accums i
+                     put (i+1)
+                     return d
+  grad <- evalStateT (traverse (\_ -> readDeriv) inp) 0
+
+  return (result, grad)
+
+-- | Contribution to a vector-typed value
+data VCon a = VCon {-# UNPACK #-} !Int  -- ^ the ID of the vector value
+                   {-# UNPACK #-} !(VS.Vector a)  -- ^ the cotangent
+            | VConNothing
+  deriving (Show)
+
+-- | The tape of recorded operations, most recent node first. Every
+-- constructor except 'Start' carries the older part of the tape as its last
+-- field.
+data Chain a = Cscalar {-# UNPACK #-} !Int !a  -- ^ ID == -1 -> no contribution
+                       {-# UNPACK #-} !Int !a  -- ^ idem
+                       !(Chain a)
+             | VCfromList {-# UNPACK #-} !(VS.Vector Int)  -- ^ IDs of scalars in the input list
+                          !(Chain a)
+             | VCtoList {-# UNPACK #-} !Int  -- ^ ID of the input vector
+                        {-# UNPACK #-} !Int  -- ^ start of the reserved output ID range
+                        {-# UNPACK #-} !Int  -- ^ number of reserved output IDs (length of the vector)
+                        !(Chain a)
+             | VCsum {-# UNPACK #-} !Int  -- ^ ID of the input vector
+                     !(Chain a)
+             | VCreplicate {-# UNPACK #-} !Int  -- ^ length of the replicated vector
+                           {-# UNPACK #-} !Int  -- ^ ID of the input scalar
+                           !(Chain a)
+             | Start
+  deriving (Show)
+
+-- | Recorder state. The @s@ parameter is phantom.
+data Log s a = Log !Int  -- ^ next ID to generate
+                   !(Chain a)  -- ^ tape
+
+-- | This class does not have any instances defined, on purpose. You'll get one
+-- magically when you differentiate.
+class Taping s a where
+  getTape :: IORef (Log s a)
+
+-- | A primal value together with the ID of its cotangent accumulator.
+data Dual s a = Dual !a
+                     {-# UNPACK #-} !Int  -- ^ -1 if this is a constant
+
+-- | Compares the primal values only; IDs are ignored.
+instance Eq a => Eq (Dual s a) where
+  Dual x _ == Dual y _ = x == y
+
+-- | Orders by the primal values only; IDs are ignored.
+instance Ord a => Ord (Dual s a) where
+  compare (Dual x _) (Dual y _) = compare x y
+
+-- | Each recorded operation passes 'mkDual' the result, then for each operand
+-- its ID and the partial derivative of the result with respect to it. Unary
+-- operations fill the second operand slot with ID -1 and derivative 0.
+instance (Num a, Taping s a) => Num (Dual s a) where
+  Dual x i1 + Dual y i2 = mkDual (x + y) i1 1 i2 1
+  Dual x i1 - Dual y i2 = mkDual (x - y) i1 1 i2 (-1)
+  Dual x i1 * Dual y i2 = mkDual (x * y) i1 y i2 x
+  negate (Dual x i1)    = mkDual (negate x) i1 (-1) (-1) 0
+  -- d/dx |x| = signum x; note that x * signum x would be |x| itself.
+  abs (Dual x i1)       = mkDual (abs x) i1 (signum x) (-1) 0
+  signum (Dual x _) = Dual (signum x) (-1)
+  fromInteger n = Dual (fromInteger n) (-1)
+
+instance (Fractional a, Taping s a) => Fractional (Dual s a) where
+  Dual x i1 / Dual y i2 = mkDual (x / y) i1 (recip y) i2 (-x/(y*y))
+  recip (Dual x i1)     = mkDual (recip x) i1 (-1/(x*x)) (-1) 0
+  fromRational r = Dual (fromRational r) (-1)
+
+instance (Floating a, Taping s a) => Floating (Dual s a) where
+  pi = Dual pi (-1)
+  exp (Dual x i1) = mkDual (exp x) i1 (exp x) (-1) 0
+  log (Dual x i1) = mkDual (log x) i1 (recip x) (-1) 0
+  sqrt (Dual x i1) = mkDual (sqrt x) i1 (recip (2*sqrt x)) (-1) 0
+  -- d/dx (x ^ y) = d/dx (e ^ (y ln x)) = e ^ (y ln x) * d/dx (y ln x) = e ^ (y ln x) * y/x
+  -- d/dy (x ^ y) = d/dy (e ^ (y ln x)) = e ^ (y ln x) * d/dy (y ln x) = e ^ (y ln x) * ln x
+  Dual x i1 ** Dual y i2 =
+    let z = x ** y
+    in mkDual z i1 (z * y/x) i2 (z * log x)
+  -- logBase is left to the class default (log y / log x), which is built
+  -- from the recorded 'log' and '/' above.
+  sin (Dual x i1) = mkDual (sin x) i1 (cos x) (-1) 0
+  cos (Dual x i1) = mkDual (cos x) i1 (negate (sin x)) (-1) 0
+  tan (Dual x i1) = mkDual (tan x) i1 (recip (cos x * cos x)) (-1) 0
+  asin (Dual x i1) = mkDual (asin x) i1 (recip (sqrt (1 - x*x))) (-1) 0
+  acos (Dual x i1) = mkDual (acos x) i1 (negate (recip (sqrt (1 - x*x)))) (-1) 0
+  atan (Dual x i1) = mkDual (atan x) i1 (recip (1 + x*x)) (-1) 0
+  sinh (Dual x i1) = mkDual (sinh x) i1 (cosh x) (-1) 0
+  cosh (Dual x i1) = mkDual (cosh x) i1 (sinh x) (-1) 0
+  tanh (Dual x i1) = let t = tanh x in mkDual t i1 (1 - t*t) (-1) 0
+  asinh (Dual x i1) = mkDual (asinh x) i1 (recip (sqrt (x*x + 1))) (-1) 0
+  acosh (Dual x i1) = mkDual (acosh x) i1 (recip (sqrt (x*x - 1))) (-1) 0
+  atanh (Dual x i1) = mkDual (atanh x) i1 (recip (1 - x*x)) (-1) 0
+
+-- | Lift a value that does not take part in differentiation (ID -1).
+constant :: a -> Dual s a
+constant x = Dual x (-1)
+
+-- | @mkDual res i1 dx i2 dy@ records a 'Cscalar' node with operand IDs @i1@,
+-- @i2@ and partial derivatives @dx@, @dy@, and tags @res@ with the new ID.
+mkDual :: forall a s. Taping s a => a -> Int -> a -> Int -> a -> Dual s a
+mkDual res i1 dx i2 dy = Dual res (writeTapeUnsafe (Proxy @s) (Cscalar i1 dx i2 dy))
+
+-- | Vector counterpart of 'Dual': the primal vector and its ID.
+data VDual s a = VDual !(VS.Vector a)
+                       {-# UNPACK #-} !Int  -- ^ -1 if this is a constant vector
+
+instance (Storable a, Taping s a) => VectorOps (VDual s a) where
+  type VectorOpsScalar (VDual s a) = Dual s a
+  vfromListN n l =
+    let (xs, is) = unzip [(x, i) | Dual x i <- l]
+    in mkVDual (VS.fromListN n xs) (VCfromList (VS.fromListN n is))
+  vfromList l =
+    let (xs, is) = unzip [(x, i) | Dual x i <- l]
+    in mkVDual (VS.fromList xs) (VCfromList (VS.fromList is))
+  -- TODO: record a 'VCtoList' node. Until then this is an explicit runtime
+  -- error instead of the former typed hole, so that the module compiles.
+  vtoList (VDual _ _) = error "vtoList: not implemented for VDual"
+  vreplicate n (Dual x i) = mkVDual (VS.replicate n x) (VCreplicate n i)
+
+instance (Storable a, Num a, Taping s a) => VectorOpsNum (VDual s a) where
+  vsum (VDual v i) = Dual (VS.sum v) (writeTapeUnsafe @a (Proxy @s) (VCsum i))
+
+-- | Lift a vector that does not take part in differentiation (ID -1).
+vconstant :: VS.Vector a -> VDual s a
+vconstant v = VDual v (-1)
+
+-- | Record the tape node built by the given function and tag the vector with
+-- the new ID.
+mkVDual :: forall a s. Taping s a => VS.Vector a -> (Chain a -> Chain a) -> VDual s a
+mkVDual res f = VDual res (writeTapeUnsafe (Proxy @s) f)
+
+{-# NOINLINE writeTapeUnsafe #-}
+-- | Push one node onto the tape of the 'Taping' instance in scope and return
+-- the ID assigned to it. The function receives the current tape and must
+-- return it with the new node in front.
+--
+-- This performs a side effect behind a pure type via 'unsafePerformIO'; the
+-- NOINLINE pragma keeps GHC from inlining the call.
+writeTapeUnsafe :: forall a s proxy. Taping s a
+                => proxy s -> (Chain a -> Chain a) -> Int
+writeTapeUnsafe _ f =
+  unsafePerformIO $
+    atomicModifyIORef' (getTape @s) $ \(Log i tape) ->
+      (Log (i + 1) (f tape), i)
diff --git a/src/Numeric/ADDual/VectorOps.hs b/src/Numeric/ADDual/VectorOps.hs
new file mode 100644
index 0000000..38063e0
--- /dev/null
+++ b/src/Numeric/ADDual/VectorOps.hs
@@ -0,0 +1,63 @@
+{-# LANGUAGE TypeFamilies #-}
+module Numeric.ADDual.VectorOps where
+
+import Data.Kind (Type)
+import qualified Data.Vector as V
+import qualified Data.Vector.Strict as VSr
+import qualified Data.Vector.Storable as VS
+import qualified Data.Vector.Unboxed as VU
+import Foreign.Storable (Storable)
+
+
+-- | Operations shared by the supported vector types. 'VectorOpsScalar' @v@ is
+-- the scalar type used for the elements of @v@.
+class VectorOps v where
+  type VectorOpsScalar v :: Type
+  vfromListN :: Int -> [VectorOpsScalar v] -> v
+  vfromList :: [VectorOpsScalar v] -> v
+  vtoList :: v -> [VectorOpsScalar v]
+  vreplicate :: Int -> VectorOpsScalar v -> v
+
+-- | Vector operations that need arithmetic on the elements.
+class VectorOpsNum v where
+  vsum :: v -> VectorOpsScalar v
+
+instance VectorOps (V.Vector a) where
+  type VectorOpsScalar (V.Vector a) = a
+  vfromListN = V.fromListN
+  vfromList = V.fromList
+  vtoList = V.toList
+  vreplicate = V.replicate
+
+instance Num a => VectorOpsNum (V.Vector a) where
+  vsum = V.sum
+
+instance VectorOps (VSr.Vector a) where
+  type VectorOpsScalar (VSr.Vector a) = a
+  vfromListN = VSr.fromListN
+  vfromList = VSr.fromList
+  vtoList = VSr.toList
+  vreplicate = VSr.replicate
+
+instance Num a => VectorOpsNum (VSr.Vector a) where
+  vsum = VSr.sum
+
+instance Storable a => VectorOps (VS.Vector a) where
+  type VectorOpsScalar (VS.Vector a) = a
+  vfromListN = VS.fromListN
+  vfromList = VS.fromList
+  vtoList = VS.toList
+  vreplicate = VS.replicate
+
+instance (Storable a, Num a) => VectorOpsNum (VS.Vector a) where
+  vsum = VS.sum
+
+instance VU.Unbox a => VectorOps (VU.Vector a) where
+  type VectorOpsScalar (VU.Vector a) = a
+  vfromListN = VU.fromListN
+  vfromList = VU.fromList
+  vtoList = VU.toList
+  vreplicate = VU.replicate
+
+instance (VU.Unbox a, Num a) => VectorOpsNum (VU.Vector a) where
+  vsum = VU.sum
