aboutsummaryrefslogtreecommitdiff
path: root/src/Numeric/ADDual/Array
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-02-24 11:04:50 +0100
committerTom Smeding <tom@tomsmeding.com>2025-02-24 11:04:50 +0100
commitf84c5b0cab7e819cdaae8288a06641973cf83437 (patch)
tree4938a1777eb10901b2fbdb28b7e2049d0c3556ca /src/Numeric/ADDual/Array
parenta16185618aa6f483f587f8a0c65031fc479afac7 (diff)
WIP array AD
Diffstat (limited to 'src/Numeric/ADDual/Array')
-rw-r--r--src/Numeric/ADDual/Array/Internal.hs187
1 files changed, 187 insertions, 0 deletions
diff --git a/src/Numeric/ADDual/Array/Internal.hs b/src/Numeric/ADDual/Array/Internal.hs
new file mode 100644
index 0000000..1cc2796
--- /dev/null
+++ b/src/Numeric/ADDual/Array/Internal.hs
@@ -0,0 +1,187 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+module Numeric.ADDual.Array.Internal where
+
+import Control.Monad (when)
+import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.State.Strict
+import Data.IORef
+import Data.Proxy
+import Data.Typeable
+import qualified Data.Vector.Storable as VS
+import qualified Data.Vector.Storable.Mutable as VSM
+import Foreign.Storable
+import GHC.Stack
+import GHC.Exts (withDict)
+import System.IO.Unsafe
+
+import System.IO (hPutStrLn, stderr)
+
+import Numeric.ADDual.VectorOps
+
+
+-- TODO: type roles on 's'
+
+
+-- | Switch for the trace messages this module writes to stderr. It is
+-- disabled; flip it to 'True' to see the individual differentiation phases.
+debug :: Bool
+debug = False
+
+
+-- TODO: full vjp (just some more Traversable mess)
+-- TODO: if non-scalar output types are allowed, ensure that all its scalar components are WHNF evaluated before we backpropagate
+{-# NOINLINE gradient' #-}
+-- | Reverse-mode derivative of a scalar-valued function.
+--
+-- @gradient' f inp topctg@ runs @f@ on @inp@ while every primitive operation
+-- is recorded on a tape, and then walks that tape backwards, seeded with the
+-- cotangent @topctg@ for the result. It returns the function value together
+-- with one derivative per element of @inp@ (same shape as @inp@).
+--
+-- ID layout: the inputs receive the IDs @0 .. starti-1@ in traversal order,
+-- and every tape entry receives the next free ID when it is written. Hence
+-- the newest tape entry always has ID @nexti - 1@, where @nexti@ is the
+-- counter stored in the 'Log'. The ID @-1@ denotes a constant.
+--
+-- The output need not be the newest tape entry: it may be a constant, one of
+-- the inputs, or a node that was created before some other (unused but
+-- forced) node. Backpropagation therefore always starts at the newest tape
+-- entry, not at the ID of the output.
+--
+-- Tape entries other than 'Cscalar' are not supported yet and raise an error.
+gradient' :: forall a f. (Traversable f, Num a, Storable a, Typeable a)
+          => HasCallStack
+          => Show a  -- TODO: remove
+          => (forall s. Taping s a => f (Dual s a) -> Dual s a)
+          -> f a -> a -> (a, f a)
+gradient' f inp topctg = unsafePerformIO $ do
+  when debug $ hPutStrLn stderr "Preparing input"
+  -- Number the inputs 0, 1, 2, ...; starti is the number of inputs.
+  let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0
+  -- The tape starts after the input IDs.
+  taperef <- newIORef (Log starti Start)
+
+  when debug $ hPutStrLn stderr "Running function"
+  -- The bang forces the result here, so that all tape writes it depends on
+  -- have happened before the tape is read back below.
+  let !(Dual result outi) = withDict @(Taping () a) taperef $ f @() inp'
+  when debug $ hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi
+  Log nexti tape <- readIORef taperef
+
+  -- when debug $ do
+  --   tapestr <- showTape (tapeTail `Snoc` lastChunk)
+  --   hPutStrLn stderr $ "tape = " ++ tapestr ""
+
+  when debug $ hPutStrLn stderr "Backpropagating"
+  -- One accumulator for every ID handed out (inputs and tape entries alike),
+  -- explicitly initialised to zero.
+  accums <- VSM.replicate nexti 0
+  -- A constant output (ID -1) has no accumulator; its gradient is all zeros.
+  when (outi /= -1) $ VSM.write accums outi topctg
+
+  let backpropagate :: Int -> Chain a -> IO ()
+      backpropagate !i (Cscalar i1 dx i2 dy tape') = do
+        ctg <- VSM.read accums i
+        when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1
+        when (i2 /= -1) $ VSM.modify accums (+ ctg*dy) i2
+        backpropagate (i-1) tape'
+      backpropagate _ Start = return ()
+      backpropagate _ _ =
+        error "Numeric.ADDual.Array.Internal.gradient': array tape entries are not supported yet"
+
+  -- The newest tape entry has ID nexti - 1; see the ID layout above.
+  backpropagate (nexti - 1) tape
+
+  when debug $ do
+    accums' <- VS.freeze accums
+    hPutStrLn stderr $ "accums = " ++ show accums'
+
+  when debug $ hPutStrLn stderr "Reconstructing gradient"
+  -- Read the accumulators of the inputs back in the same traversal order in
+  -- which the input IDs were handed out.
+  let readDeriv = do i <- get
+                     d <- lift $ VSM.read accums i
+                     put (i+1)
+                     return d
+  grad <- evalStateT (traverse (\_ -> readDeriv) inp) 0
+
+  return (result, grad)
+
+-- | A contribution to the cotangent of a vector-typed value.
+data VCon a
+  = VCon
+      {-# UNPACK #-} !Int            -- ^ the ID of the vector value
+      {-# UNPACK #-} !(VS.Vector a)  -- ^ the cotangent
+  | VConNothing
+  deriving (Show)
+
+-- | The tape: a linked chain of recorded operations. Every constructor except
+-- 'Start' carries the remainder of the tape as its last field; 'Start' ends
+-- the chain.
+data Chain a
+  = Cscalar
+      {-# UNPACK #-} !Int  -- ^ ID of the first argument; ID == -1 -> no contribution
+      !a                   -- ^ partial derivative with respect to the first argument
+      {-# UNPACK #-} !Int  -- ^ ID of the second argument; ID == -1 -> no contribution
+      !a                   -- ^ partial derivative with respect to the second argument
+      !(Chain a)
+  | VCfromList
+      {-# UNPACK #-} !(VS.Vector Int)  -- ^ IDs of scalars in the input list
+      !(Chain a)
+  | VCtoList
+      {-# UNPACK #-} !Int  -- ^ ID of the input vector
+      {-# UNPACK #-} !Int  -- ^ start of the reserved output ID range
+      {-# UNPACK #-} !Int  -- ^ number of reserved output IDs (length of the vector)
+      !(Chain a)
+  | VCsum
+      {-# UNPACK #-} !Int  -- ^ ID of the input vector
+      !(Chain a)
+  | VCreplicate
+      {-# UNPACK #-} !Int  -- ^ length of the replicated vector
+      {-# UNPACK #-} !Int  -- ^ ID of the input scalar
+      !(Chain a)
+  | Start
+  deriving (Show)
+
+-- | The state kept while a function is being traced: the ID counter and the
+-- tape recorded so far.
+data Log s a
+  = Log
+      !Int        -- ^ next ID to generate
+      !(Chain a)  -- ^ tape
+
+-- | Gives access to the tape that is being recorded for the differentiation
+-- session tagged with @s@.
+--
+-- This class does not have any instances defined, on purpose. You'll get one
+-- magically when you differentiate: 'gradient'' provides the dictionary with
+-- 'withDict', backed by a freshly created 'IORef'.
+class Taping s a where
+ -- | The mutable cell holding the ID counter and the tape.
+ getTape :: IORef (Log s a)
+
+-- | A scalar that is being differentiated: its value together with the ID
+-- under which it is known on the tape.
+data Dual s a
+  = Dual
+      !a                   -- ^ the value itself
+      {-# UNPACK #-} !Int  -- ^ -1 if this is a constant
+
+-- | Equality looks at the values only; the tape IDs are ignored.
+instance Eq a => Eq (Dual s a) where
+  lhs == rhs = case (lhs, rhs) of
+    (Dual p _, Dual q _) -> p == q
+
+-- | Ordering looks at the values only; the tape IDs are ignored.
+instance Ord a => Ord (Dual s a) where
+  compare lhs rhs = case (lhs, rhs) of
+    (Dual p _, Dual q _) -> compare p q
+
+-- | Arithmetic on dual numbers. Every differentiable operation records one
+-- tape entry via 'mkDual', holding the IDs of its arguments and the partial
+-- derivatives with respect to them. Unary operations use ID @-1@ (no
+-- contribution) with derivative 0 for the unused second slot.
+instance (Num a, Taping s a) => Num (Dual s a) where
+  Dual x i1 + Dual y i2 = mkDual (x + y) i1 1 i2 1
+  Dual x i1 - Dual y i2 = mkDual (x - y) i1 1 i2 (-1)
+  Dual x i1 * Dual y i2 = mkDual (x * y) i1 y i2 x
+  negate (Dual x i1) = mkDual (negate x) i1 (-1) (-1) 0
+  -- d/dx |x| = signum x (which yields 0 as the derivative at x = 0).
+  abs (Dual x i1) = mkDual (abs x) i1 (signum x) (-1) 0
+  -- signum is piecewise constant, so its result is a constant (ID -1).
+  signum (Dual x _) = Dual (signum x) (-1)
+  fromInteger n = Dual (fromInteger n) (-1)
+
+-- | Division and reciprocal on dual numbers, each recorded as a single tape
+-- entry; rational literals are constants.
+instance (Fractional a, Taping s a) => Fractional (Dual s a) where
+  Dual num inum / Dual den iden =
+    mkDual (num / den) inum (recip den) iden (-num/(den*den))
+  recip (Dual v iv) =
+    mkDual (recip v) iv (-1/(v*v)) (-1) 0
+  fromRational q = Dual (fromRational q) (-1)
+
+-- | Transcendental functions on dual numbers. Each operation records exactly
+-- one tape entry via 'mkDual'; unary operations use ID @-1@ (no contribution)
+-- with derivative 0 for the unused second slot.
+instance (Floating a, Taping s a) => Floating (Dual s a) where
+  pi = Dual pi (-1)
+  exp (Dual x i1) = let z = exp x in mkDual z i1 z (-1) 0
+  log (Dual x i1) = mkDual (log x) i1 (recip x) (-1) 0
+  sqrt (Dual x i1) = let z = sqrt x in mkDual z i1 (recip (2*z)) (-1) 0
+  -- d/dx (x ^ y) = d/dx (e ^ (y ln x)) = e ^ (y ln x) * d/dx (y ln x) = e ^ (y ln x) * y/x
+  -- d/dy (x ^ y) = d/dy (e ^ (y ln x)) = e ^ (y ln x) * d/dy (y ln x) = e ^ (y ln x) * ln x
+  Dual x i1 ** Dual y i2 =
+    let z = x ** y
+    in mkDual z i1 (z * y/x) i2 (z * log x)
+  -- logBase b y = ln y / ln b
+  -- d/db = -ln y / (b * (ln b)^2)
+  -- d/dy = 1 / (y * ln b)
+  logBase (Dual b i1) (Dual y i2) =
+    let lb = log b
+    in mkDual (logBase b y) i1 (negate (log y) / (b * lb * lb)) i2 (recip (y * lb))
+  sin (Dual x i1) = mkDual (sin x) i1 (cos x) (-1) 0
+  cos (Dual x i1) = mkDual (cos x) i1 (negate (sin x)) (-1) 0
+  -- d/dx tan x = 1 / (cos x)^2
+  tan (Dual x i1) = let c = cos x in mkDual (tan x) i1 (recip (c * c)) (-1) 0
+  asin (Dual x i1) = mkDual (asin x) i1 (recip (sqrt (1 - x*x))) (-1) 0
+  acos (Dual x i1) = mkDual (acos x) i1 (negate (recip (sqrt (1 - x*x)))) (-1) 0
+  atan (Dual x i1) = mkDual (atan x) i1 (recip (1 + x*x)) (-1) 0
+  sinh (Dual x i1) = mkDual (sinh x) i1 (cosh x) (-1) 0
+  cosh (Dual x i1) = mkDual (cosh x) i1 (sinh x) (-1) 0
+  -- d/dx tanh x = 1 - (tanh x)^2
+  tanh (Dual x i1) = let t = tanh x in mkDual t i1 (1 - t*t) (-1) 0
+  asinh (Dual x i1) = mkDual (asinh x) i1 (recip (sqrt (x*x + 1))) (-1) 0
+  acosh (Dual x i1) = mkDual (acosh x) i1 (recip (sqrt (x*x - 1))) (-1) 0
+  atanh (Dual x i1) = mkDual (atanh x) i1 (recip (1 - x*x)) (-1) 0
+
+-- | Lift a plain value to a dual number that is not recorded on the tape; it
+-- carries the constant ID -1.
+constant :: a -> Dual s a
+constant x = x `Dual` (-1)
+
+-- | Build the result of a scalar operation and record it on the tape.
+--
+-- The arguments are: the result value, then for each of the two operands its
+-- ID followed by the partial derivative of the result with respect to it.
+-- The ID of the returned 'Dual' is the one assigned to the new tape entry.
+mkDual :: forall a s. Taping s a => a -> Int -> a -> Int -> a -> Dual s a
+mkDual res i1 dx i2 dy = Dual res nodeId
+  where
+    nodeId = writeTapeUnsafe (Proxy @s) (Cscalar i1 dx i2 dy)
+
+-- | A vector that is being differentiated: its elements together with the ID
+-- under which the vector is known on the tape.
+data VDual s a
+  = VDual
+      !(VS.Vector a)       -- ^ the elements
+      {-# UNPACK #-} !Int  -- ^ -1 if this is a constant vector
+
+-- | Vector construction on dual vectors. Every operation records one tape
+-- entry via 'mkVDual'; the scalar type that goes with @VDual s a@ is
+-- @Dual s a@.
+instance (Storable a, Taping s a) => VectorOps (VDual s a) where
+  type VectorOpsScalar (VDual s a) = Dual s a
+  -- Split the scalars into their values and their tape IDs; the IDs are what
+  -- the tape entry needs to route the cotangent back to the scalars.
+  vfromListN n l =
+    let (xs, is) = unzip [(x, i) | Dual x i <- l]
+    in mkVDual (VS.fromListN n xs) (VCfromList (VS.fromListN n is))
+  vfromList l =
+    let (xs, is) = unzip [(x, i) | Dual x i <- l]
+    in mkVDual (VS.fromList xs) (VCfromList (VS.fromList is))
+  -- TODO: implement. This needs a tape writer that reserves a whole range of
+  -- output IDs (see 'VCtoList'); until that exists, fail loudly at run time
+  -- instead of leaving a typed hole that stops the module from compiling.
+  vtoList (VDual _ _) =
+    error "Numeric.ADDual.Array.Internal.vtoList: not implemented yet"
+  vreplicate n (Dual x i) = mkVDual (VS.replicate n x) (VCreplicate n i)
+
+-- | Numeric reductions on dual vectors.
+instance (Storable a, Num a, Taping s a) => VectorOpsNum (VDual s a) where
+  -- The sum is a scalar whose tape entry refers back to the input vector.
+  vsum (VDual v i) = Dual total nodeId
+    where
+      total = VS.sum v
+      nodeId = writeTapeUnsafe @a (Proxy @s) (VCsum i)
+
+-- | Lift a plain vector to a dual vector that is not recorded on the tape; it
+-- carries the constant ID -1.
+vconstant :: VS.Vector a -> VDual s a
+vconstant v = v `VDual` (-1)
+
+-- | Build the result of a vector operation and record it on the tape.
+--
+-- The second argument prepends the entry describing the operation to the
+-- existing tape. The ID of the returned 'VDual' is the one assigned to that
+-- new entry.
+mkVDual :: forall a s. Taping s a => VS.Vector a -> (Chain a -> Chain a) -> VDual s a
+mkVDual res f = VDual res nodeId
+  where
+    nodeId = writeTapeUnsafe (Proxy @s) f
+
+-- | Record a new entry on the tape of session @s@ and return the ID assigned
+-- to it.
+--
+-- The function argument receives the current tape and must return the tape
+-- extended with the new entry. The ID counter in the 'Log' is incremented by
+-- one, and the value of the counter before the increment is returned as the
+-- ID of the new entry. Both happen in a single 'atomicModifyIORef''.
+--
+-- This is a side effect hidden behind a pure type via 'unsafePerformIO': the
+-- entry is written when the returned 'Int' is evaluated.
+-- NOTE(review): the NOINLINE pragma is presumably there to keep GHC from
+-- duplicating this effect by inlining it into call sites -- confirm.
+{-# NOINLINE writeTapeUnsafe #-}
+writeTapeUnsafe :: forall a s proxy. Taping s a
+ => proxy s -> (Chain a -> Chain a) -> Int
+writeTapeUnsafe _ f =
+ unsafePerformIO $
+ atomicModifyIORef' (getTape @s) $ \(Log i tape) ->
+ (Log (i + 1) (f tape), i)