summaryrefslogtreecommitdiff
path: root/src/Interpreter/Accum.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Interpreter/Accum.hs')
-rw-r--r--src/Interpreter/Accum.hs235
1 files changed, 157 insertions, 78 deletions
diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs
index 45f507b..d15ea10 100644
--- a/src/Interpreter/Accum.hs
+++ b/src/Interpreter/Accum.hs
@@ -4,33 +4,46 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}
-module Interpreter.Accum where
+module Interpreter.Accum (
+ AcM,
+ runAcM,
+ Rep,
+ Accum,
+ withAccum,
+ accumAdd,
+ inParallel,
+) where
-import Control.Monad.ST
-import Control.Monad.ST.Unsafe
+import Control.Concurrent
+import Control.Exception (throwIO)
+import Control.Monad (when, forM_)
+import Data.Bifunctor (second)
+import Foreign.Storable (sizeOf)
import GHC.Exts
+import GHC.Float
import GHC.Int
-import GHC.ST (ST(..))
+import GHC.IO (IO(..))
+import GHC.Word
+import System.IO.Unsafe (unsafePerformIO)
import AST
import Data
import Interpreter.Array
-import Data.Bifunctor (second)
-import Control.Monad (when, forM_)
-import Foreign.Storable (sizeOf)
+import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr)
-newtype AcM s a = AcM (ST s a)
-- | Monad in which accumulator operations run.
--
-- This is a newtype around 'IO'. The @s@ parameter does not occur in the
-- wrapped type, so it is a phantom index; 'runAcM' quantifies over it.
+newtype AcM s a = AcM (IO a)
deriving newtype (Functor, Applicative, Monad)
-- | Run an 'AcM' computation and hand back its result as a pure value.
--
-- The wrapped 'IO' action is executed through 'unsafePerformIO', so its
-- effects happen when the result is demanded.
-- NOTE(review): presumably the rank-2 @forall s@ plays the same role as in
-- 'runST' (keeping 'Accum' values from escaping); confirm that no exported
-- operation lets one leak out.
runAcM :: (forall s. AcM s a) -> a
-runAcM m = runST (case m of AcM m' -> m')
+runAcM (AcM m) = unsafePerformIO m
type family Rep t where
Rep TNil = ()
@@ -40,7 +53,8 @@ type family Rep t where
Rep (TScal sty) = ScalRep sty
-- Rep (TAccum t) = _
-data Accum s t = Accum (STy t) (MutableByteArray# s)
+-- | Floats and integers are accumulated; booleans are left as-is.
--
-- An accumulator consists of the type descriptor ('STy') of the value it
-- stores and a 'ForeignPtr' to the raw buffer.
-- NOTE(review): @s@ is not used by either field; presumably it only ties the
-- accumulator to the 'AcM' computation that created it. Confirm.
+data Accum s t = Accum (STy t) (ForeignPtr ())
tSize :: STy t -> Rep t -> Int
tSize ty x = tSize' ty (Just x)
@@ -71,11 +85,11 @@ tSize' typ val = case typ of
-- | This operation does not commute with 'accumAdd', so it must be used with
-- care. Furthermore it must be used on exactly the same value as tSize was
--- called on. Hence it lives in ST, not in AcM.
-accumWrite :: forall s t. Accum s t -> Rep t -> ST s ()
-accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0
- where
- go :: Bool -> STy t' -> Rep t' -> Int -> ST s Int
+-- called on. Hence it lives in IO, not in AcM.
+accumWrite :: forall s t. Accum s t -> Rep t -> IO ()
+accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
+ let
+ go :: Bool -> STy t' -> Rep t' -> Int -> IO Int
go inarr ty val off = case ty of
STNil -> return off
STPair a b -> do
@@ -86,11 +100,11 @@ accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0
case val of
Left x -> do
let !(I8# tag#) = 0
- ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #)
+ writeInt8# addr# off# tag#
go inarr a x (off + 1)
Right y -> do
let !(I8# tag#) = 1
- ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #)
+ writeInt8# addr# off# tag#
go inarr b y (off + 1)
STArr _ t
| inarr -> error "Nested arrays not supported in this implementation"
@@ -103,26 +117,29 @@ accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0
STScal sty -> goScal sty val off
STAccum{} -> error "Nested accumulators unsupported"
- goShape :: Shape n -> Int -> ST s Int
+ goShape :: Shape n -> Int -> IO Int
goShape ShNil off = return off
goShape (ShCons sh n) off = do
off1@(I# off1#) <- goShape sh off
let !(I64# n'#) = fromIntegral n
- ST $ \s -> (# writeInt64Array# mbarr off1# n'# s, off1 + 8 #)
-
- goScal :: SScalTy t' -> ScalRep t' -> Int -> ST s Int
- goScal STI32 (I32# x) (I# off#) = ST $ \s -> (# writeInt32Array# mbarr off# x s, I# (off# +# 4#) #)
- goScal STI64 (I64# x) (I# off#) = ST $ \s -> (# writeInt64Array# mbarr off# x s, I# (off# +# 8#) #)
- goScal STF32 (F# x) (I# off#) = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #)
- goScal STF64 (D# x) (I# off#) = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #)
- goScal STBool b (I# off#) =
+ writeInt64# addr# off1# n'#
+ return (off1 + 8)
+
+ goScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int
+ goScal STI32 (I32# x) off@(I# off#) = off + 4 <$ writeInt32# addr# off# x
+ goScal STI64 (I64# x) off@(I# off#) = off + 8 <$ writeInt64# addr# off# x
+ goScal STF32 (F# x) off@(I# off#) = off + 4 <$ writeFloat# addr# off# x
+ goScal STF64 (D# x) off@(I# off#) = off + 8 <$ writeDouble# addr# off# x
+ goScal STBool b off@(I# off#) = do
let !(I8# i) = fromIntegral (fromEnum b)
- in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #)
+ off + 1 <$ writeInt8# addr# off# i
+
+ in () <$ go False topty top_value 0
accumRead :: forall s t. Accum s t -> AcM s (Rep t)
-accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0
- where
- go :: Bool -> STy t' -> Int -> ST s (Int, Rep t')
+accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
+ let
+ go :: Bool -> STy t' -> Int -> IO (Int, Rep t')
go inarr ty off = case ty of
STNil -> return (off, ())
STPair a b -> do
@@ -131,7 +148,7 @@ accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0
return (off1 + off2, (x, y))
STEither a b -> do
let !(I# off#) = off
- tag <- ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', I8# i #)
+ tag <- readInt8 addr# off#
(off1, val) <- case tag of
0 -> fmap Left <$> go inarr a (off + 1)
1 -> fmap Right <$> go inarr b (off + 1)
@@ -142,7 +159,7 @@ accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0
STArr ndim t
| inarr -> error "Nested arrays not supported in this implementation"
| otherwise -> do
- (off1, sh) <- readShape mbarr ndim off
+ (off1, sh) <- readShape addr# ndim off
let eltsize = tSize' t Nothing
n = shapeSize sh
arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini))
@@ -150,18 +167,22 @@ accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0
STScal sty -> goScal sty off
STAccum{} -> error "Nested accumulators unsupported"
- goScal :: SScalTy t' -> Int -> ST s (Int, ScalRep t')
- goScal STI32 (I# off#) = ST $ \s -> case readInt32Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 4#), I32# i) #)
- goScal STI64 (I# off#) = ST $ \s -> case readInt64Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 8#), I64# i) #)
- goScal STF32 (I# off#) = ST $ \s -> case readFloatArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 4#), F# f) #)
- goScal STF64 (I# off#) = ST $ \s -> case readDoubleArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 8#), D# f) #)
- goScal STBool (I# off#) = ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 1#), toEnum (fromIntegral (I8# i))) #)
+ goScal :: SScalTy t' -> Int -> IO (Int, ScalRep t')
+ goScal STI32 off@(I# off#) = (off + 4,) <$> readInt32 addr# off#
+ goScal STI64 off@(I# off#) = (off + 8,) <$> readInt64 addr# off#
+ goScal STF32 off@(I# off#) = (off + 4,) <$> readFloat addr# off#
+ goScal STF64 off@(I# off#) = (off + 8,) <$> readDouble addr# off#
+ goScal STBool off@(I# off#) = do
+ i8 <- readInt8 addr# off#
+ return (off + 1, toEnum (fromIntegral i8))
-readShape :: MutableByteArray# s -> SNat n -> Int -> ST s (Int, Shape n)
+ in snd <$> go False topty 0
+
-- | Read a shape of the given rank from raw memory.
--
-- Arguments are the base address, the rank, and the byte offset at which the
-- shape starts. Every dimension is read as a 64-bit integer and moves the
-- offset forward by 8 bytes. The recursive call runs first, so the dimension
-- held by the outermost 'ShCons' is the one read at the highest offset.
-- The result is the byte offset just past the shape, paired with the shape.
--
-- The address argument is still named @mbarr@ below although it is an 'Addr#'.
+readShape :: Addr# -> SNat n -> Int -> IO (Int, Shape n)
-- Rank zero: nothing is stored, the offset is returned unchanged.
readShape _ SZ off = return (off, ShNil)
readShape mbarr (SS ndim) off = do
-- Read the leading dimensions first; off1 is where this dimension lives.
(off1@(I# off1#), sh) <- readShape mbarr ndim off
- n' <- ST $ \s -> case readInt64Array# mbarr off1# s of (# s', i #) -> (# s', I64# i #)
+ n' <- readInt64 mbarr off1#
return (off1 + 8, ShCons sh (fromIntegral n'))
-- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of
@@ -174,7 +195,7 @@ data InvShape n where
-> InvShape (S n)
-- | Total number of elements described by an inverted shape.
--
-- For 'IShCons' this is the size stored in the constructor; 'invertShape'
-- fills that field with @n * ishSize ish@. The empty shape therefore has to
-- report 1: with 0 as the base case every stored product would be 0 as well.
ishSize :: InvShape n -> Int
-ishSize IShNil = 0
+ishSize IShNil = 1
ishSize (IShCons _ sz _) = sz
invertShape :: forall n. Shape n -> InvShape n
@@ -185,9 +206,9 @@ invertShape | Refl <- lemPlusZero @n = flip go IShNil
go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish)
accumAdd :: forall s t i. Accum s t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s ()
-accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty depth index value 0
- where
- go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int -> ST s ()
+accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
+ let
+ go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int -> IO ()
go inarr ty SZ () val off = () <$ performAdd inarr ty val off
go inarr ty (SS dep) idx val off = case (ty, idx, val) of
(STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
@@ -199,15 +220,15 @@ accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty
(STArr rank eltty, _, _)
| inarr -> error "Nested arrays"
| otherwise -> do
- (off1, ish) <- second invertShape <$> readShape mbarr rank off
+ (off1, ish) <- second invertShape <$> readShape addr# rank off
goArr (SS dep) ish eltty idx val off1
(STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth"
(STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth"
(STAccum{}, _, _) -> error "Nested accumulators unsupported"
goArr :: SNat i' -> InvShape n -> STy t'
- -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int -> ST s ()
- goArr SZ ish t1 () val off = performAddArr ish t1 val off
+ -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int -> IO ()
+ goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off
goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off
goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do
let i' = fromIntegral @(Rep TIx) @Int i
@@ -215,7 +236,14 @@ accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty
error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")"
goArr depm1 ish t1 idx val (off + i' * ishSize ish)
- performAdd :: Bool -> STy t' -> Rep t' -> Int -> ST s Int
+ performAddArr :: Int -> STy t' -> Array n (Rep t') -> Int -> IO Int
+ performAddArr arraySz eltty val off = do
+ let eltsize = tSize' eltty Nothing
+ forM_ [0 .. arraySz - 1] $ \lini ->
+ performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize)
+ return (off + arraySz * eltsize)
+
+ performAdd :: Bool -> STy t' -> Rep t' -> Int -> IO Int
performAdd inarr ty val off = case ty of
STNil -> return off
STPair t1 t2 -> do
@@ -223,7 +251,7 @@ accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty
performAdd inarr t2 (snd val) off1
STEither t1 t2 -> do
let !(I# off#) = off
- tag <- ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', I8# i #)
+ tag <- readInt8 addr# off#
off1 <- case (val, tag) of
(Left val1, 0) -> performAdd inarr t1 val1 (off + 1)
(Right val2, 1) -> performAdd inarr t2 val2 (off + 1)
@@ -234,46 +262,97 @@ accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty
STArr n ty'
| inarr -> error "Nested array"
| otherwise -> do
- (off1, sh) <- readShape mbarr n off
- let sz = shapeSize sh
- eltsize = tSize' ty' Nothing
- forM_ [0 .. sz - 1] $ \lini ->
- performAdd True ty' (arrayIndexLinear val lini) (off1 + lini * eltsize)
- return (off1 + sz * eltsize)
+ (off1, sh) <- readShape addr# n off
+ performAddArr (shapeSize sh) ty' val off1
STScal ty' -> performAddScal ty' val off
STAccum{} -> error "Nested accumulators unsupported"
- performAddScal :: SScalTy t' -> ScalRep t' -> Int -> ST s Int
- performAddScal STI32 (I32# x#) (I# off#)
+ performAddScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int
+ performAddScal STI32 (I32# x#) off@(I# off#)
| sizeOf (undefined :: Int) == 4
- = ST $ \s -> case fetchAddIntArray# mbarr off# (int32ToInt# x#) s of
- (# s', _ #) -> (# s', I# (off# +# 4#) #)
+ = off + 4 <$ fetchAddWord# addr# off# (word32ToWord# (int32ToWord32# x#))
| otherwise
- = _
- performAddScal STI64 (I64# x#) (I# off#)
+ = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\(W32# w#) -> W32# (int32ToWord32# x# `plusWord32#` w#))
+ performAddScal STI64 (I64# x#) off@(I# off#)
| sizeOf (undefined :: Int) == 8
- = ST $ \s -> case fetchAddIntArray# mbarr off# (int64ToInt# x#) s of
- (# s', _ #) -> (# s', I# (off# +# 8#) #)
+ = off + 8 <$ fetchAddWord# addr# off# (word64ToWord# (int64ToWord64# x#))
| otherwise
- = _
- performAddScal STF32 (F# x) (I# off#) = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #)
- performAddScal STF64 (D# x) (I# off#) = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #)
- performAddScal STBool b (I# off#) =
- let !(I8# i) = fromIntegral (fromEnum b)
- in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #)
+ = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\(W64# w#) -> W64# (int64ToWord64# x# `plusWord64#` w#))
+ performAddScal STF32 x off@(I# off#) =
+ off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\w -> castFloatToWord32 (x + castWord32ToFloat w))
+ performAddScal STF64 x off@(I# off#) =
+ off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\w -> castDoubleToWord64 (x + castWord64ToDouble w))
+ performAddScal STBool _ off = return (off + 1) -- don't do anything with booleans
+
+ casLoop :: Eq w
+ => (Addr# -> Int# -> IO w) -- ^ read value (from a given byte offset; will get 0#)
+ -> (Addr# -> w -> w -> IO w) -- ^ CAS value at address (expected -> desired -> IO observed)
+ -> Addr# -- ^ Address to attempt to modify
+ -> (w -> w) -- ^ Operation to apply to the value
+ -> IO ()
+ casLoop readOp casOp addr modify = readOp addr 0# >>= loop
+ where
+ loop value = do
+ value' <- casOp addr value (modify value)
+ if value == value'
+ then return ()
+ else loop value'
- casLoop :: (Addr# -> w -> w -> State# d -> (# State# d, w #))
- -> (w -> w)
- -> r
- -> State# d -> (# State# d, r #)
- casLoop casOp add ret = _
+ in () <$ go False topty top_depth top_index top_value 0
-- | Run a computation with a freshly allocated accumulator.
--
-- A buffer of @tSize ty start@ bytes is allocated and initialised with
-- @start@ via 'accumWrite'. Then @fun@ is run on the accumulator, and
-- afterwards the accumulator is read back with 'accumRead'. The result pairs
-- the final accumulator contents with the result of @fun@.
--
-- The buffer comes from 'mallocBytes' and is wrapped in a 'ForeignPtr' with
-- 'finalizerFree' as its finalizer.
withAccum :: STy t -> Rep t -> (Accum s t -> AcM s b) -> AcM s (Rep t, b)
withAccum ty start fun = do
- let !(I# size) = tSize ty start
- accum <- AcM . ST $ \s -> case newByteArray# size s of
- (# s', mbarr #) -> (# s', Accum ty mbarr #)
- AcM $ accumWrite accum start
+ -- The initial write must happen before any of the adds or reads, so it makes
+ -- sense to put it in IO together with the allocation, instead of in AcM.
+ accum <- AcM $ do buffer <- mallocBytes (tSize ty start)
+ ptr <- newForeignPtr finalizerFree buffer
+ let accum = Accum ty ptr
+ accumWrite accum start
+ return accum
-- Run the user computation, then snapshot what was accumulated.
b <- fun accum
out <- accumRead accum
return (out, b)
+
+-- | Run the given actions concurrently, one thread per action, and return
+-- their results in the order of the input list.
+--
+-- Every worker reports either its result or the exception that ended it.
+-- An exception is rethrown in the calling thread when that worker's slot is
+-- collected; slots are collected in list order, and the remaining workers
+-- are not cancelled. With a bare 'forkIO' a failing worker would never fill
+-- its 'MVar': the caller would block in 'takeMVar' and the original
+-- exception would be lost.
+inParallel :: [AcM s t] -> AcM s [t]
+inParallel actions = AcM $ do
+  mvars <- mapM (\_ -> newEmptyMVar) actions
+  forM_ (zip actions mvars) $ \(AcM action, var) ->
+    -- 'forkFinally' runs its handler on normal and on exceptional exit, so
+    -- each MVar is filled exactly once with an 'Either SomeException t'.
+    forkFinally action (putMVar var)
+  mapM (\var -> takeMVar var >>= either throwIO return) mvars
+
+-- | Typed reads from raw memory. For every reader below the offset is
+-- counted in bytes: it is added to the base address first, and the primop
+-- is then asked for element 0 at the resulting address.
+readInt8 :: Addr# -> Int# -> IO Int8
+readInt8 base byteOff# = IO $ \s0 ->
+  case readInt8OffAddr# (plusAddr# base byteOff#) 0# s0 of
+    (# s1, raw #) -> (# s1, I8# raw #)
+
+readInt32 :: Addr# -> Int# -> IO Int32
+readInt32 base byteOff# = IO $ \s0 ->
+  case readInt32OffAddr# (plusAddr# base byteOff#) 0# s0 of
+    (# s1, raw #) -> (# s1, I32# raw #)
+
+readInt64 :: Addr# -> Int# -> IO Int64
+readInt64 base byteOff# = IO $ \s0 ->
+  case readInt64OffAddr# (plusAddr# base byteOff#) 0# s0 of
+    (# s1, raw #) -> (# s1, I64# raw #)
+
+readWord32 :: Addr# -> Int# -> IO Word32
+readWord32 base byteOff# = IO $ \s0 ->
+  case readWord32OffAddr# (plusAddr# base byteOff#) 0# s0 of
+    (# s1, raw #) -> (# s1, W32# raw #)
+
+readWord64 :: Addr# -> Int# -> IO Word64
+readWord64 base byteOff# = IO $ \s0 ->
+  case readWord64OffAddr# (plusAddr# base byteOff#) 0# s0 of
+    (# s1, raw #) -> (# s1, W64# raw #)
+
+readFloat :: Addr# -> Int# -> IO Float
+readFloat base byteOff# = IO $ \s0 ->
+  case readFloatOffAddr# (plusAddr# base byteOff#) 0# s0 of
+    (# s1, raw #) -> (# s1, F# raw #)
+
+readDouble :: Addr# -> Int# -> IO Double
+readDouble base byteOff# = IO $ \s0 ->
+  case readDoubleOffAddr# (plusAddr# base byteOff#) 0# s0 of
+    (# s1, raw #) -> (# s1, D# raw #)
+
+-- | Typed writes to raw memory, taking already-unboxed values. As with the
+-- readers, the offset is in bytes: it is added to the base address and the
+-- primop then stores to element 0 at the resulting address.
+writeInt8# :: Addr# -> Int# -> Int8# -> IO ()
+writeInt8# base byteOff# x = IO $ \s0 ->
+  case writeInt8OffAddr# (plusAddr# base byteOff#) 0# x s0 of
+    s1 -> (# s1, () #)
+
+writeInt32# :: Addr# -> Int# -> Int32# -> IO ()
+writeInt32# base byteOff# x = IO $ \s0 ->
+  case writeInt32OffAddr# (plusAddr# base byteOff#) 0# x s0 of
+    s1 -> (# s1, () #)
+
+writeInt64# :: Addr# -> Int# -> Int64# -> IO ()
+writeInt64# base byteOff# x = IO $ \s0 ->
+  case writeInt64OffAddr# (plusAddr# base byteOff#) 0# x s0 of
+    s1 -> (# s1, () #)
+
+writeFloat# :: Addr# -> Int# -> Float# -> IO ()
+writeFloat# base byteOff# x = IO $ \s0 ->
+  case writeFloatOffAddr# (plusAddr# base byteOff#) 0# x s0 of
+    s1 -> (# s1, () #)
+
+writeDouble# :: Addr# -> Int# -> Double# -> IO ()
+writeDouble# base byteOff# x = IO $ \s0 ->
+  case writeDoubleOffAddr# (plusAddr# base byteOff#) 0# x s0 of
+    s1 -> (# s1, () #)
+
+-- | Atomically add a machine word to the word stored at the given byte
+-- offset from the base address. The value the primop hands back is dropped.
+fetchAddWord# :: Addr# -> Int# -> Word# -> IO ()
+fetchAddWord# base byteOff# delta = IO $ \s0 ->
+  case fetchAddWordAddr# (plusAddr# base byteOff#) delta s0 of
+    (# s1, _previous #) -> (# s1, () #)
+
+-- | Boxed wrapper around the 32-bit compare-and-swap primop. Arguments are
+-- the address, the expected value and the desired value; the result is the
+-- value the primop reports back, boxed as a 'Word32'.
+atomicCasWord32Addr :: Addr# -> Word32 -> Word32 -> IO Word32
+atomicCasWord32Addr target (W32# want) (W32# new) = IO $ \s0 ->
+  case atomicCasWord32Addr# target want new s0 of
+    (# s1, seen #) -> (# s1, W32# seen #)
+
+-- | 64-bit counterpart of 'atomicCasWord32Addr'.
+atomicCasWord64Addr :: Addr# -> Word64 -> Word64 -> IO Word64
+atomicCasWord64Addr target (W64# want) (W64# new) = IO $ \s0 ->
+  case atomicCasWord64Addr# target want new s0 of
+    (# s1, seen #) -> (# s1, W64# seen #)