diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-09-11 23:06:44 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-09-11 23:06:44 +0200 | 
| commit | 1f53cea6a1352db125e1897ca574360180be2550 (patch) | |
| tree | 0c933929808479eed08a3da26ff0c6d825305631 /src/Interpreter/Accum.hs | |
| parent | b728b22248414c8319681a75f1c8e8cdf8da1fb2 (diff) | |
Finish Accum implementation
Diffstat (limited to 'src/Interpreter/Accum.hs')
| -rw-r--r-- | src/Interpreter/Accum.hs | 233 | 
1 files changed, 156 insertions, 77 deletions
| diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs index 45f507b..d15ea10 100644 --- a/src/Interpreter/Accum.hs +++ b/src/Interpreter/Accum.hs @@ -4,33 +4,46 @@  {-# LANGUAGE GADTs #-}  {-# LANGUAGE GeneralizedNewtypeDeriving #-}  {-# LANGUAGE MagicHash #-} +{-# LANGUAGE PolyKinds #-}  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-}  {-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-}  {-# LANGUAGE UnboxedTuples #-} -module Interpreter.Accum where +module Interpreter.Accum ( +  AcM, +  runAcM, +  Rep, +  Accum, +  withAccum, +  accumAdd, +  inParallel, +) where -import Control.Monad.ST -import Control.Monad.ST.Unsafe +import Control.Concurrent +import Control.Monad (when, forM_) +import Data.Bifunctor (second) +import Foreign.Storable (sizeOf)  import GHC.Exts +import GHC.Float  import GHC.Int -import GHC.ST (ST(..)) +import GHC.IO (IO(..)) +import GHC.Word +import System.IO.Unsafe (unsafePerformIO)  import AST  import Data  import Interpreter.Array -import Data.Bifunctor (second) -import Control.Monad (when, forM_) -import Foreign.Storable (sizeOf) +import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr) -newtype AcM s a = AcM (ST s a) +newtype AcM s a = AcM (IO a)    deriving newtype (Functor, Applicative, Monad)  runAcM :: (forall s. AcM s a) -> a -runAcM m = runST (case m of AcM m' -> m') +runAcM (AcM m) = unsafePerformIO m  type family Rep t where    Rep TNil = () @@ -40,7 +53,8 @@ type family Rep t where    Rep (TScal sty) = ScalRep sty    -- Rep (TAccum t) = _ -data Accum s t = Accum (STy t) (MutableByteArray# s) +-- | Floats and integers are accumulated; booleans are left as-is. +data Accum s t = Accum (STy t) (ForeignPtr ())  tSize :: STy t -> Rep t -> Int  tSize ty x = tSize' ty (Just x) @@ -71,11 +85,11 @@ tSize' typ val = case typ of  -- | This operation does not commute with 'accumAdd', so it must be used with  -- care. Furthermore it must be used on exactly the same value as tSize was --- called on. Hence it lives in ST, not in AcM. -accumWrite :: forall s t. Accum s t -> Rep t -> ST s () -accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0 -  where -    go :: Bool -> STy t' -> Rep t' -> Int -> ST s Int +-- called on. Hence it lives in IO, not in AcM. +accumWrite :: forall s t. Accum s t -> Rep t -> IO () +accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> +  let +    go :: Bool -> STy t' -> Rep t' -> Int -> IO Int      go inarr ty val off = case ty of        STNil -> return off        STPair a b -> do @@ -86,11 +100,11 @@ accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0          case val of            Left x -> do              let !(I8# tag#) = 0 -            ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #) +            writeInt8# addr# off# tag#              go inarr a x (off + 1)            Right y -> do              let !(I8# tag#) = 1 -            ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #) +            writeInt8# addr# off# tag#              go inarr b y (off + 1)        STArr _ t          | inarr -> error "Nested arrays not supported in this implementation" @@ -103,26 +117,29 @@ accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0        STScal sty -> goScal sty val off        STAccum{} -> error "Nested accumulators unsupported" -    goShape :: Shape n -> Int -> ST s Int +    goShape :: Shape n -> Int -> IO Int      goShape ShNil off = return off      goShape (ShCons sh n) off = do        off1@(I# off1#) <- goShape sh off        let !(I64# n'#) = fromIntegral n -      ST $ \s -> (# writeInt64Array# mbarr off1# n'# s, off1 + 8 #) +      writeInt64# addr# off1# n'# +      return (off1 + 8) -    goScal :: SScalTy t' -> ScalRep t' -> Int -> ST s Int -    goScal STI32 (I32# x) (I# off#) = ST $ \s -> (# writeInt32Array# mbarr off# x s, I# (off# +# 4#) #) -    goScal STI64 (I64# x) (I# off#) = ST $ \s -> (# writeInt64Array# mbarr off# x s, I# (off# +# 8#) #) -    goScal STF32 (F# x) (I# off#) = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #) -    goScal STF64 (D# x) (I# off#) = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #) -    goScal STBool b (I# off#) = +    goScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int +    goScal STI32 (I32# x) off@(I# off#) = off + 4 <$ writeInt32# addr# off# x +    goScal STI64 (I64# x) off@(I# off#) = off + 8 <$ writeInt64# addr# off# x +    goScal STF32 (F# x) off@(I# off#) = off + 4 <$ writeFloat# addr# off# x +    goScal STF64 (D# x) off@(I# off#) = off + 8 <$ writeDouble# addr# off# x +    goScal STBool b off@(I# off#) = do        let !(I8# i) = fromIntegral (fromEnum b) -      in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #) +      off + 1 <$ writeInt8# addr# off# i + +  in () <$ go False topty top_value 0  accumRead :: forall s t. Accum s t -> AcM s (Rep t) -accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0 -  where -    go :: Bool -> STy t' -> Int -> ST s (Int, Rep t') +accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> +  let +    go :: Bool -> STy t' -> Int -> IO (Int, Rep t')      go inarr ty off = case ty of        STNil -> return (off, ())        STPair a b -> do @@ -131,7 +148,7 @@ accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0          return (off1 + off2, (x, y))        STEither a b -> do          let !(I# off#) = off -        tag <- ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', I8# i #) +        tag <- readInt8 addr# off#          (off1, val) <- case tag of                           0 -> fmap Left <$> go inarr a (off + 1)                           1 -> fmap Right <$> go inarr b (off + 1) @@ -142,7 +159,7 @@ accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0        STArr ndim t          | inarr -> error "Nested arrays not supported in this implementation"          | otherwise -> do -            (off1, sh) <- readShape mbarr ndim off +            (off1, sh) <- readShape addr# ndim off              let eltsize = tSize' t Nothing                  n = shapeSize sh              arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini)) @@ -150,18 +167,22 @@ accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0        STScal sty -> goScal sty off        STAccum{} -> error "Nested accumulators unsupported" -    goScal :: SScalTy t' -> Int -> ST s (Int, ScalRep t') -    goScal STI32 (I# off#) = ST $ \s -> case readInt32Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 4#), I32# i) #) -    goScal STI64 (I# off#) = ST $ \s -> case readInt64Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 8#), I64# i) #) -    goScal STF32 (I# off#) = ST $ \s -> case readFloatArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 4#), F# f) #) -    goScal STF64 (I# off#) = ST $ \s -> case readDoubleArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 8#), D# f) #) -    goScal STBool (I# off#) = ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 1#), toEnum (fromIntegral (I8# i))) #) +    goScal :: SScalTy t' -> Int -> IO (Int, ScalRep t') +    goScal STI32 off@(I# off#) = (off + 4,) <$> readInt32 addr# off# +    goScal STI64 off@(I# off#) = (off + 8,) <$> readInt64 addr# off# +    goScal STF32 off@(I# off#) = (off + 4,) <$> readFloat addr# off# +    goScal STF64 off@(I# off#) = (off + 8,) <$> readDouble addr# off# +    goScal STBool off@(I# off#) = do +      i8 <- readInt8 addr# off# +      return (off + 1, toEnum (fromIntegral i8)) + +  in snd <$> go False topty 0 -readShape :: MutableByteArray# s -> SNat n -> Int -> ST s (Int, Shape n) +readShape :: Addr# -> SNat n -> Int -> IO (Int, Shape n)  readShape _ SZ off = return (off, ShNil)  readShape mbarr (SS ndim) off = do    (off1@(I# off1#), sh) <- readShape mbarr ndim off -  n' <- ST $ \s -> case readInt64Array# mbarr off1# s of (# s', i #) -> (# s', I64# i #) +  n' <- readInt64 mbarr off1#    return (off1 + 8, ShCons sh (fromIntegral n'))  -- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of @@ -174,7 +195,7 @@ data InvShape n where            -> InvShape (S n)  ishSize :: InvShape n -> Int -ishSize IShNil = 0 +ishSize IShNil = 1  ishSize (IShCons _ sz _) = sz  invertShape :: forall n. Shape n -> InvShape n @@ -185,9 +206,9 @@ invertShape | Refl <- lemPlusZero @n = flip go IShNil      go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish)  accumAdd :: forall s t i. Accum s t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s () -accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty depth index value 0 -  where -    go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int -> ST s () +accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> +  let +    go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int -> IO ()      go inarr ty SZ () val off = () <$ performAdd inarr ty val off      go inarr ty (SS dep) idx val off = case (ty, idx, val) of        (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off @@ -199,15 +220,15 @@ accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty        (STArr rank eltty, _, _)          | inarr -> error "Nested arrays"          | otherwise -> do -            (off1, ish) <- second invertShape <$> readShape mbarr rank off +            (off1, ish) <- second invertShape <$> readShape addr# rank off              goArr (SS dep) ish eltty idx val off1        (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth"        (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth"        (STAccum{}, _, _) -> error "Nested accumulators unsupported"      goArr :: SNat i' -> InvShape n -> STy t' -          -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int -> ST s () -    goArr SZ ish t1 () val off = performAddArr ish t1 val off +          -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int -> IO () +    goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off      goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off      goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do        let i' = fromIntegral @(Rep TIx) @Int i @@ -215,7 +236,14 @@ accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty          error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")"        goArr depm1 ish t1 idx val (off + i' * ishSize ish) -    performAdd :: Bool -> STy t' -> Rep t' -> Int -> ST s Int +    performAddArr :: Int -> STy t' -> Array n (Rep t') -> Int -> IO Int +    performAddArr arraySz eltty val off = do +      let eltsize = tSize' eltty Nothing +      forM_ [0 .. arraySz - 1] $ \lini -> +        performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize) +      return (off + arraySz * eltsize) + +    performAdd :: Bool -> STy t' -> Rep t' -> Int -> IO Int      performAdd inarr ty val off = case ty of        STNil -> return off        STPair t1 t2 -> do @@ -223,7 +251,7 @@ accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty          performAdd inarr t2 (snd val) off1        STEither t1 t2 -> do          let !(I# off#) = off -        tag <- ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', I8# i #) +        tag <- readInt8 addr# off#          off1 <- case (val, tag) of                    (Left val1, 0) -> performAdd inarr t1 val1 (off + 1)                    (Right val2, 1) -> performAdd inarr t2 val2 (off + 1) @@ -234,46 +262,97 @@ accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty        STArr n ty'          | inarr -> error "Nested array"          | otherwise -> do -            (off1, sh) <- readShape mbarr n off -            let sz = shapeSize sh -                eltsize = tSize' ty' Nothing -            forM_ [0 .. sz - 1] $ \lini -> -              performAdd True ty' (arrayIndexLinear val lini) (off1 + lini * eltsize) -            return (off1 + sz * eltsize) +            (off1, sh) <- readShape addr# n off +            performAddArr (shapeSize sh) ty' val off1        STScal ty' -> performAddScal ty' val off        STAccum{} -> error "Nested accumulators unsupported" -    performAddScal :: SScalTy t' -> ScalRep t' -> Int -> ST s Int -    performAddScal STI32 (I32# x#) (I# off#) +    performAddScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int +    performAddScal STI32 (I32# x#) off@(I# off#)        | sizeOf (undefined :: Int) == 4 -      = ST $ \s -> case fetchAddIntArray# mbarr off# (int32ToInt# x#) s of -                     (# s', _ #) -> (# s', I# (off# +# 4#) #) +      = off + 4 <$ fetchAddWord# addr# off# (word32ToWord# (int32ToWord32# x#))        | otherwise -      = _ -    performAddScal STI64 (I64# x#) (I# off#) +      = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\(W32# w#) -> W32# (int32ToWord32# x# `plusWord32#` w#)) +    performAddScal STI64 (I64# x#) off@(I# off#)        | sizeOf (undefined :: Int) == 8 -      = ST $ \s -> case fetchAddIntArray# mbarr off# (int64ToInt# x#) s of -                     (# s', _ #) -> (# s', I# (off# +# 8#) #) +      = off + 8 <$ fetchAddWord# addr# off# (word64ToWord# (int64ToWord64# x#))        | otherwise -      = _ -    performAddScal STF32 (F# x) (I# off#) = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #) -    performAddScal STF64 (D# x) (I# off#) = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #) -    performAddScal STBool b (I# off#) = -      let !(I8# i) = fromIntegral (fromEnum b) -      in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #) +      = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\(W64# w#) -> W64# (int64ToWord64# x# `plusWord64#` w#)) +    performAddScal STF32 x off@(I# off#) = +      off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\w -> castFloatToWord32 (x + castWord32ToFloat w)) +    performAddScal STF64 x off@(I# off#) = +      off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\w -> castDoubleToWord64 (x + castWord64ToDouble w)) +    performAddScal STBool _ off = return (off + 1)  -- don't do anything with booleans -    casLoop :: (Addr# -> w -> w -> State# d -> (# State# d, w #)) -            -> (w -> w) -            -> r -            -> State# d -> (# State# d, r #) -    casLoop casOp add ret = _ +    casLoop :: Eq w +            => (Addr# -> Int# -> IO w)    -- ^ read value (from a given byte offset; will get 0#) +            -> (Addr# -> w -> w -> IO w)  -- ^ CAS value at address (expected -> desired -> IO observed) +            -> Addr#                      -- ^ Address to attempt to modify +            -> (w -> w)                   -- ^ Operation to apply to the value +            -> IO () +    casLoop readOp casOp addr modify = readOp addr 0# >>= loop +      where +        loop value = do +          value' <- casOp addr value (modify value) +          if value == value' +            then return () +            else loop value' + +  in () <$ go False topty top_depth top_index top_value 0  withAccum :: STy t -> Rep t -> (Accum s t -> AcM s b) -> AcM s (Rep t, b)  withAccum ty start fun = do -  let !(I# size) = tSize ty start -  accum <- AcM . ST $ \s -> case newByteArray# size s of -                              (# s', mbarr #) -> (# s', Accum ty mbarr #) -  AcM $ accumWrite accum start +  -- The initial write must happen before any of the adds or reads, so it makes +  -- sense to put it in IO together with the allocation, instead of in AcM. +  accum <- AcM $ do buffer <- mallocBytes (tSize ty start) +                    ptr <- newForeignPtr finalizerFree buffer +                    let accum = Accum ty ptr +                    accumWrite accum start +                    return accum    b <- fun accum    out <- accumRead accum    return (out, b) + +inParallel :: [AcM s t] -> AcM s [t] +inParallel actions = AcM $ do +  mvars <- mapM (\_ -> newEmptyMVar) actions +  forM_ (zip actions mvars) $ \(AcM action, var) -> +    forkIO $ action >>= putMVar var +  mapM takeMVar mvars + +-- | Offset is in bytes. +readInt8   :: Addr# -> Int# -> IO Int8 +readInt32  :: Addr# -> Int# -> IO Int32 +readInt64  :: Addr# -> Int# -> IO Int64 +readWord32 :: Addr# -> Int# -> IO Word32 +readWord64 :: Addr# -> Int# -> IO Word64 +readFloat  :: Addr# -> Int# -> IO Float +readDouble :: Addr# -> Int# -> IO Double +readInt8   addr off# = IO $ \s -> case readInt8OffAddr#   (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I8#  val #) +readInt32  addr off# = IO $ \s -> case readInt32OffAddr#  (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I32# val #) +readInt64  addr off# = IO $ \s -> case readInt64OffAddr#  (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I64# val #) +readWord32 addr off# = IO $ \s -> case readWord32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W32# val #) +readWord64 addr off# = IO $ \s -> case readWord64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W64# val #) +readFloat  addr off# = IO $ \s -> case readFloatOffAddr#  (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', F#   val #) +readDouble addr off# = IO $ \s -> case readDoubleOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', D#   val #) + +writeInt8#   :: Addr# -> Int# -> Int8#   -> IO () +writeInt32#  :: Addr# -> Int# -> Int32#  -> IO () +writeInt64#  :: Addr# -> Int# -> Int64#  -> IO () +writeFloat#  :: Addr# -> Int# -> Float#  -> IO () +writeDouble# :: Addr# -> Int# -> Double# -> IO () +writeInt8#   addr off# val = IO $ \s -> (# writeInt8OffAddr#   (addr `plusAddr#` off#) 0# val s, () #) +writeInt32#  addr off# val = IO $ \s -> (# writeInt32OffAddr#  (addr `plusAddr#` off#) 0# val s, () #) +writeInt64#  addr off# val = IO $ \s -> (# writeInt64OffAddr#  (addr `plusAddr#` off#) 0# val s, () #) +writeFloat#  addr off# val = IO $ \s -> (# writeFloatOffAddr#  (addr `plusAddr#` off#) 0# val s, () #) +writeDouble# addr off# val = IO $ \s -> (# writeDoubleOffAddr# (addr `plusAddr#` off#) 0# val s, () #) + +fetchAddWord# :: Addr# -> Int# -> Word# -> IO () +fetchAddWord# addr off# val = IO $ \s -> case fetchAddWordAddr# (addr `plusAddr#` off#) val s of (# s', _ #) -> (# s', () #) + +atomicCasWord32Addr :: Addr# -> Word32 -> Word32 -> IO Word32 +atomicCasWord64Addr :: Addr# -> Word64 -> Word64 -> IO Word64 +atomicCasWord32Addr addr (W32# expected) (W32# desired) = +  IO $ \s -> case atomicCasWord32Addr# addr expected desired s of (# s', old #) -> (# s', W32# old #) +atomicCasWord64Addr addr (W64# expected) (W64# desired) = +  IO $ \s -> case atomicCasWord64Addr# addr expected desired s of (# s', old #) -> (# s', W64# old #) | 
