From f84c5b0cab7e819cdaae8288a06641973cf83437 Mon Sep 17 00:00:00 2001
From: Tom Smeding
Date: Mon, 24 Feb 2025 11:04:50 +0100
Subject: WIP array AD

---
 ad-dual.cabal                        |   2 +
 src/Numeric/ADDual/Array/Internal.hs | 242 +++++++++++++++++++++++++++++++++++
 src/Numeric/ADDual/VectorOps.hs      |  64 +++++++++++
 3 files changed, 308 insertions(+)
 create mode 100644 src/Numeric/ADDual/Array/Internal.hs
 create mode 100644 src/Numeric/ADDual/VectorOps.hs

diff --git a/ad-dual.cabal b/ad-dual.cabal
index 9880744..60c442d 100644
--- a/ad-dual.cabal
+++ b/ad-dual.cabal
@@ -15,6 +15,8 @@ library
   exposed-modules:
     Numeric.ADDual
     Numeric.ADDual.Internal
+    Numeric.ADDual.Array.Internal
+    Numeric.ADDual.VectorOps
   other-modules:
   c-sources: cbits/backprop.c
   cc-options: -O3 -Wall -Wextra -std=c99
diff --git a/src/Numeric/ADDual/Array/Internal.hs b/src/Numeric/ADDual/Array/Internal.hs
new file mode 100644
index 0000000..1cc2796
--- /dev/null
+++ b/src/Numeric/ADDual/Array/Internal.hs
@@ -0,0 +1,242 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+module Numeric.ADDual.Array.Internal where
+
+import Control.Monad (when)
+import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.State.Strict
+import Data.IORef
+import Data.Proxy
+import Data.Typeable
+import qualified Data.Vector.Storable as VS
+import qualified Data.Vector.Storable.Mutable as VSM
+import Foreign.Storable
+import GHC.Stack
+import GHC.Exts (withDict)
+import System.IO.Unsafe
+
+import System.IO (hPutStrLn, stderr)
+
+import Numeric.ADDual.VectorOps
+
+
+-- TODO: type roles on 's'
+
+
+-- | Enables the progress traces that 'gradient'' writes to stderr.
+-- @toEnum 0@ is 'False', so tracing is switched off.
+debug :: Bool
+debug = toEnum 0
+
+
+-- TODO: full vjp (just some more Traversable mess)
+-- TODO: if non-scalar output types are allowed, ensure that all its scalar components are WHNF evaluated before we backpropagate
+{-# NOINLINE gradient' #-}
+-- | Reverse-mode derivative of a scalar-valued function.
+--
+-- @gradient' f inp topctg@ runs @f@ on @inp@ while every primitive operation
+-- is recorded on a tape, then walks that tape backwards starting from the
+-- cotangent @topctg@ of the result. It returns the value of @f@ at @inp@
+-- together with the accumulated cotangent of each input, in the shape of
+-- @inp@.
+--
+-- Only 'Cscalar' tape entries can be backpropagated so far; meeting an entry
+-- recorded by a vector operation raises an error.
+gradient' :: forall a f. (Traversable f, Num a, Storable a, Typeable a)
+          => HasCallStack
+          => Show a  -- TODO: remove
+          => (forall s. Taping s a => f (Dual s a) -> Dual s a)
+          -> f a -> a -> (a, f a)
+gradient' f inp topctg = unsafePerformIO $ do
+  when debug $ hPutStrLn stderr "Preparing input"
+  -- Number the inputs 0 .. starti-1 in traversal order.
+  let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0
+  -- The tape starts after the input IDs.
+  taperef <- newIORef (Log starti Start)
+
+  when debug $ hPutStrLn stderr "Running function"
+  let !(Dual result outi) = withDict @(Taping () a) taperef $ f @() inp'
+  when debug $ hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi
+  -- 'nexti' is one past the last ID handed out, so the head of 'tape' is the
+  -- entry with ID nexti-1. That need not be 'outi': the function may have
+  -- forced further nodes after the one it returns, and it may return an
+  -- input or a constant without recording anything at all.
+  Log nexti tape <- readIORef taperef
+
+  -- when debug $ do
+  --   tapestr <- showTape (tapeTail `Snoc` lastChunk)
+  --   hPutStrLn stderr $ "tape = " ++ tapestr ""
+
+  when debug $ hPutStrLn stderr "Backpropagating"
+  -- One explicitly zeroed accumulator per ID (inputs and tape entries).
+  accums <- VSM.replicate nexti 0
+  -- A constant result (ID -1) has no accumulator: all derivatives stay 0.
+  when (outi /= -1) $ VSM.write accums outi topctg
+
+  let backpropagate i (Cscalar i1 dx i2 dy tape') = do
+        ctg <- VSM.read accums i
+        when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1
+        when (i2 /= -1) $ VSM.modify accums (+ ctg*dy) i2
+        backpropagate (i-1) tape'
+      backpropagate _ Start = return ()
+      backpropagate i _ =
+        error $ "gradient': tape entry " ++ show i
+                ++ " was recorded by a vector operation; backpropagating those is not implemented yet"
+
+  -- Start at the newest entry so that IDs and tape entries stay aligned.
+  backpropagate (nexti - 1) tape
+
+  when debug $ do
+    accums' <- VS.freeze accums
+    hPutStrLn stderr $ "accums = " ++ show accums'
+
+  when debug $ hPutStrLn stderr "Reconstructing gradient"
+  -- The input accumulators are read back in the same traversal order.
+  let readDeriv = do
+        i <- get
+        d <- lift $ VSM.read accums i
+        put (i+1)
+        return d
+  grad <- evalStateT (traverse (\_ -> readDeriv) inp) 0
+
+  return (result, grad)
+
+-- | Contribution to a vector-typed value
+data VCon a = VCon {-# UNPACK #-} !Int  -- ^ the ID of the vector value
+                   {-# UNPACK #-} !(VS.Vector a)  -- ^ the cotangent
+            | VConNothing
+  deriving (Show)
+
+-- | The tape: a linked list of recorded operations with the newest entry
+-- outermost and 'Start' at the end. Each entry stores what is needed to push
+-- the cotangent of its result back to its arguments.
+data Chain a = Cscalar {-# UNPACK #-} !Int !a  -- ^ ID == -1 -> no contribution
+                       {-# UNPACK #-} !Int !a  -- ^ idem
+                       !(Chain a)
+             | VCfromList {-# UNPACK #-} !(VS.Vector Int)  -- ^ IDs of scalars in the input list
+                          !(Chain a)
+             | VCtoList {-# UNPACK #-} !Int  -- ^ ID of the input vector
+                        {-# UNPACK #-} !Int  -- ^ start of the reserved output ID range
+                        {-# UNPACK #-} !Int  -- ^ number of reserved output IDs (length of the vector)
+                        !(Chain a)
+             | VCsum {-# UNPACK #-} !Int  -- ^ ID of the input vector
+                     !(Chain a)
+             | VCreplicate {-# UNPACK #-} !Int  -- ^ length of the replicated vector
+                           {-# UNPACK #-} !Int  -- ^ ID of the input scalar
+                           !(Chain a)
+             | Start
+  deriving (Show)
+
+-- | Mutable state of one differentiation run.
+data Log s a = Log !Int  -- ^ next ID to generate
+                   !(Chain a)  -- ^ tape
+
+-- | This class does not have any instances defined, on purpose. You'll get one
+-- magically when you differentiate.
+class Taping s a where
+  getTape :: IORef (Log s a)
+
+-- | A scalar value paired with the ID under which it is known on the tape.
+data Dual s a = Dual !a
+                     {-# UNPACK #-} !Int  -- ^ -1 if this is a constant
+
+instance Eq a => Eq (Dual s a) where
+  Dual x _ == Dual y _ = x == y
+
+instance Ord a => Ord (Dual s a) where
+  compare (Dual x _) (Dual y _) = compare x y
+
+instance (Num a, Taping s a) => Num (Dual s a) where
+  Dual x i1 + Dual y i2 = mkDual (x + y) i1 1 i2 1
+  Dual x i1 - Dual y i2 = mkDual (x - y) i1 1 i2 (-1)
+  Dual x i1 * Dual y i2 = mkDual (x * y) i1 y i2 x
+  negate (Dual x i1)    = mkDual (negate x) i1 (-1) (-1) 0
+  -- d/dx |x| = signum x
+  abs (Dual x i1)       = mkDual (abs x) i1 (signum x) (-1) 0
+  signum (Dual x _)     = Dual (signum x) (-1)
+  fromInteger n         = Dual (fromInteger n) (-1)
+
+instance (Fractional a, Taping s a) => Fractional (Dual s a) where
+  Dual x i1 / Dual y i2 = mkDual (x / y) i1 (recip y) i2 (-x/(y*y))
+  recip (Dual x i1)     = mkDual (recip x) i1 (-1/(x*x)) (-1) 0
+  fromRational r        = Dual (fromRational r) (-1)
+
+instance (Floating a, Taping s a) => Floating (Dual s a) where
+  pi = Dual pi (-1)
+  exp (Dual x i1) = mkDual (exp x) i1 (exp x) (-1) 0
+  log (Dual x i1) = mkDual (log x) i1 (recip x) (-1) 0
+  sqrt (Dual x i1) = mkDual (sqrt x) i1 (recip (2*sqrt x)) (-1) 0
+  -- d/dx (x ^ y) = d/dx (e ^ (y ln x)) = e ^ (y ln x) * d/dx (y ln x) = e ^ (y ln x) * y/x
+  -- d/dy (x ^ y) = d/dy (e ^ (y ln x)) = e ^ (y ln x) * d/dy (y ln x) = e ^ (y ln x) * ln x
+  Dual x i1 ** Dual y i2 =
+    let z = x ** y
+    in mkDual z i1 (z * y/x) i2 (z * log x)
+  -- log_b x = ln x / ln b; the tape entries come from 'log' and '/'.
+  logBase b x = log x / log b
+  sin (Dual x i1) = mkDual (sin x) i1 (cos x) (-1) 0
+  cos (Dual x i1) = mkDual (cos x) i1 (negate (sin x)) (-1) 0
+  tan (Dual x i1) = mkDual (tan x) i1 (recip (cos x * cos x)) (-1) 0
+  asin (Dual x i1) = mkDual (asin x) i1 (recip (sqrt (1 - x*x))) (-1) 0
+  acos (Dual x i1) = mkDual (acos x) i1 (negate (recip (sqrt (1 - x*x)))) (-1) 0
+  atan (Dual x i1) = mkDual (atan x) i1 (recip (1 + x*x)) (-1) 0
+  sinh (Dual x i1) = mkDual (sinh x) i1 (cosh x) (-1) 0
+  cosh (Dual x i1) = mkDual (cosh x) i1 (sinh x) (-1) 0
+  tanh (Dual x i1) =
+    let t = tanh x
+    in mkDual t i1 (1 - t*t) (-1) 0
+  asinh (Dual x i1) = mkDual (asinh x) i1 (recip (sqrt (x*x + 1))) (-1) 0
+  acosh (Dual x i1) = mkDual (acosh x) i1 (recip (sqrt (x*x - 1))) (-1) 0
+  atanh (Dual x i1) = mkDual (atanh x) i1 (recip (1 - x*x)) (-1) 0
+
+-- | Lift a plain value to a 'Dual' that is not tracked on the tape.
+constant :: a -> Dual s a
+constant x = Dual x (-1)
+
+-- | Build the result of a scalar primitive and record it on the tape:
+-- @mkDual res i1 dx i2 dy@ has value @res@ and partial derivatives @dx@ and
+-- @dy@ with respect to the arguments with IDs @i1@ and @i2@.
+mkDual :: forall a s. Taping s a => a -> Int -> a -> Int -> a -> Dual s a
+mkDual res i1 dx i2 dy = Dual res (writeTapeUnsafe (Proxy @s) (Cscalar i1 dx i2 dy))
+
+-- | A storable vector paired with the ID under which it is known on the tape.
+data VDual s a = VDual !(VS.Vector a)
+                       {-# UNPACK #-} !Int  -- ^ -1 if this is a constant vector
+
+instance (Storable a, Taping s a) => VectorOps (VDual s a) where
+  type VectorOpsScalar (VDual s a) = Dual s a
+  vfromListN n l =
+    let (xs, is) = unzip [(x, i) | Dual x i <- l]
+    in mkVDual (VS.fromListN n xs) (VCfromList (VS.fromListN n is))
+  vfromList l =
+    let (xs, is) = unzip [(x, i) | Dual x i <- l]
+    in mkVDual (VS.fromList xs) (VCfromList (VS.fromList is))
+  -- TODO: reserve one scalar ID per element and record a 'VCtoList' entry.
+  -- Until then fail with a message rather than leaving a typed hole, which
+  -- would stop the whole module from compiling.
+  vtoList (VDual _ _) = error "Numeric.ADDual.Array.Internal.vtoList: not implemented yet for VDual"
+  vreplicate n (Dual x i) = mkVDual (VS.replicate n x) (VCreplicate n i)
+
+instance (Storable a, Num a, Taping s a) => VectorOpsNum (VDual s a) where
+  vsum (VDual v i) = Dual (VS.sum v) (writeTapeUnsafe @a (Proxy @s) (VCsum i))
+
+-- | Lift a plain vector to a 'VDual' that is not tracked on the tape.
+vconstant :: VS.Vector a -> VDual s a
+vconstant v = VDual v (-1)
+
+-- | Build the result of a vector primitive and record the entry @f@ on the tape.
+mkVDual :: forall a s. Taping s a => VS.Vector a -> (Chain a -> Chain a) -> VDual s a
+mkVDual res f = VDual res (writeTapeUnsafe (Proxy @s) f)
+
+{-# NOINLINE writeTapeUnsafe #-}
+-- | Push an entry onto the tape selected by @s@ and return the ID given to
+-- it. The ID counter is advanced by one for every entry.
+writeTapeUnsafe :: forall a s proxy. Taping s a
+                => proxy s -> (Chain a -> Chain a) -> Int
+writeTapeUnsafe _ f =
+  unsafePerformIO $
+    atomicModifyIORef' (getTape @s) $ \(Log i tape) ->
+      (Log (i + 1) (f tape), i)
diff --git a/src/Numeric/ADDual/VectorOps.hs b/src/Numeric/ADDual/VectorOps.hs
new file mode 100644
index 0000000..38063e0
--- /dev/null
+++ b/src/Numeric/ADDual/VectorOps.hs
@@ -0,0 +1,64 @@
+{-# LANGUAGE TypeFamilies #-}
+module Numeric.ADDual.VectorOps where
+
+import Data.Kind (Type)
+import qualified Data.Vector as V
+import qualified Data.Vector.Strict as VSr
+import qualified Data.Vector.Storable as VS
+import qualified Data.Vector.Unboxed as VU
+import Foreign.Storable (Storable)
+
+
+-- | Vector construction and deconstruction, abstracted over the vector type.
+class VectorOps v where
+  -- | The element type of the vector type @v@.
+  type VectorOpsScalar v :: Type
+  vfromListN :: Int -> [VectorOpsScalar v] -> v
+  vfromList :: [VectorOpsScalar v] -> v
+  vtoList :: v -> [VectorOpsScalar v]
+  vreplicate :: Int -> VectorOpsScalar v -> v
+
+-- | Numeric reductions over a vector type.
+class VectorOpsNum v where
+  -- | Sum of all elements.
+  vsum :: v -> VectorOpsScalar v
+
+instance VectorOps (V.Vector a) where
+  type VectorOpsScalar (V.Vector a) = a
+  vfromListN = V.fromListN
+  vfromList = V.fromList
+  vtoList = V.toList
+  vreplicate = V.replicate
+
+instance Num a => VectorOpsNum (V.Vector a) where
+  vsum = V.sum
+
+instance VectorOps (VSr.Vector a) where
+  type VectorOpsScalar (VSr.Vector a) = a
+  vfromListN = VSr.fromListN
+  vfromList = VSr.fromList
+  vtoList = VSr.toList
+  vreplicate = VSr.replicate
+
+instance Num a => VectorOpsNum (VSr.Vector a) where
+  vsum = VSr.sum
+
+instance Storable a => VectorOps (VS.Vector a) where
+  type VectorOpsScalar (VS.Vector a) = a
+  vfromListN = VS.fromListN
+  vfromList = VS.fromList
+  vtoList = VS.toList
+  vreplicate = VS.replicate
+
+instance (Storable a, Num a) => VectorOpsNum (VS.Vector a) where
+  vsum = VS.sum
+
+instance VU.Unbox a => VectorOps (VU.Vector a) where
+  type VectorOpsScalar (VU.Vector a) = a
+  vfromListN = VU.fromListN
+  vfromList = VU.fromList
+  vtoList = VU.toList
+  vreplicate = VU.replicate
+
+instance (VU.Unbox a, Num a) => VectorOpsNum (VU.Vector a) where
+  vsum = VU.sum
-- 
cgit v1.2.3-70-g09d2