diff options
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | src/Interpreter.hs | 5 | ||||
-rw-r--r-- | src/Interpreter/Accum.hs | 235 |
3 files changed, 163 insertions, 78 deletions
@@ -1 +1,2 @@ dist-newstyle/ +cabal.project.local diff --git a/src/Interpreter.hs b/src/Interpreter.hs index d1dfb1d..afc50f9 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -1,3 +1,8 @@ module Interpreter where +import AST +import Interpreter.Array +import Interpreter.Accum + + diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs index 45f507b..d15ea10 100644 --- a/src/Interpreter/Accum.hs +++ b/src/Interpreter/Accum.hs @@ -4,33 +4,46 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE MagicHash #-} +{-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UnboxedTuples #-} -module Interpreter.Accum where +module Interpreter.Accum ( + AcM, + runAcM, + Rep, + Accum, + withAccum, + accumAdd, + inParallel, +) where -import Control.Monad.ST -import Control.Monad.ST.Unsafe +import Control.Concurrent +import Control.Monad (when, forM_) +import Data.Bifunctor (second) +import Foreign.Storable (sizeOf) import GHC.Exts +import GHC.Float import GHC.Int -import GHC.ST (ST(..)) +import GHC.IO (IO(..)) +import GHC.Word +import System.IO.Unsafe (unsafePerformIO) import AST import Data import Interpreter.Array -import Data.Bifunctor (second) -import Control.Monad (when, forM_) -import Foreign.Storable (sizeOf) +import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr) -newtype AcM s a = AcM (ST s a) +newtype AcM s a = AcM (IO a) deriving newtype (Functor, Applicative, Monad) runAcM :: (forall s. AcM s a) -> a -runAcM m = runST (case m of AcM m' -> m') +runAcM (AcM m) = unsafePerformIO m type family Rep t where Rep TNil = () @@ -40,7 +53,8 @@ type family Rep t where Rep (TScal sty) = ScalRep sty -- Rep (TAccum t) = _ -data Accum s t = Accum (STy t) (MutableByteArray# s) +-- | Floats and integers are accumulated; booleans are left as-is. +data Accum s t = Accum (STy t) (ForeignPtr ()) tSize :: STy t -> Rep t -> Int tSize ty x = tSize' ty (Just x) @@ -71,11 +85,11 @@ tSize' typ val = case typ of -- | This operation does not commute with 'accumAdd', so it must be used with -- care. Furthermore it must be used on exactly the same value as tSize was --- called on. Hence it lives in ST, not in AcM. -accumWrite :: forall s t. Accum s t -> Rep t -> ST s () -accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0 - where - go :: Bool -> STy t' -> Rep t' -> Int -> ST s Int +-- called on. Hence it lives in IO, not in AcM. +accumWrite :: forall s t. Accum s t -> Rep t -> IO () +accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> + let + go :: Bool -> STy t' -> Rep t' -> Int -> IO Int go inarr ty val off = case ty of STNil -> return off STPair a b -> do @@ -86,11 +100,11 @@ accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0 case val of Left x -> do let !(I8# tag#) = 0 - ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #) + writeInt8# addr# off# tag# go inarr a x (off + 1) Right y -> do let !(I8# tag#) = 1 - ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #) + writeInt8# addr# off# tag# go inarr b y (off + 1) STArr _ t | inarr -> error "Nested arrays not supported in this implementation" @@ -103,26 +117,29 @@ accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0 STScal sty -> goScal sty val off STAccum{} -> error "Nested accumulators unsupported" - goShape :: Shape n -> Int -> ST s Int + goShape :: Shape n -> Int -> IO Int goShape ShNil off = return off goShape (ShCons sh n) off = do off1@(I# off1#) <- goShape sh off let !(I64# n'#) = fromIntegral n - ST $ \s -> (# writeInt64Array# mbarr off1# n'# s, off1 + 8 #) - - goScal :: SScalTy t' -> ScalRep t' -> Int -> ST s Int - goScal STI32 (I32# x) (I# off#) = ST $ \s -> (# writeInt32Array# mbarr off# x s, I# (off# +# 4#) #) - goScal STI64 (I64# x) (I# off#) = ST $ \s -> (# writeInt64Array# mbarr off# x s, I# (off# +# 8#) #) - goScal STF32 (F# x) (I# off#) = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #) - goScal STF64 (D# x) (I# off#) = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #) - goScal STBool b (I# off#) = + writeInt64# addr# off1# n'# + return (off1 + 8) + + goScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int + goScal STI32 (I32# x) off@(I# off#) = off + 4 <$ writeInt32# addr# off# x + goScal STI64 (I64# x) off@(I# off#) = off + 8 <$ writeInt64# addr# off# x + goScal STF32 (F# x) off@(I# off#) = off + 4 <$ writeFloat# addr# off# x + goScal STF64 (D# x) off@(I# off#) = off + 8 <$ writeDouble# addr# off# x + goScal STBool b off@(I# off#) = do let !(I8# i) = fromIntegral (fromEnum b) - in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #) + off + 1 <$ writeInt8# addr# off# i + + in () <$ go False topty top_value 0 accumRead :: forall s t. Accum s t -> AcM s (Rep t) -accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0 - where - go :: Bool -> STy t' -> Int -> ST s (Int, Rep t') +accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> + let + go :: Bool -> STy t' -> Int -> IO (Int, Rep t') go inarr ty off = case ty of STNil -> return (off, ()) STPair a b -> do @@ -131,7 +148,7 @@ accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0 return (off1 + off2, (x, y)) STEither a b -> do let !(I# off#) = off - tag <- ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', I8# i #) + tag <- readInt8 addr# off# (off1, val) <- case tag of 0 -> fmap Left <$> go inarr a (off + 1) 1 -> fmap Right <$> go inarr b (off + 1) @@ -142,7 +159,7 @@ accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0 STArr ndim t | inarr -> error "Nested arrays not supported in this implementation" | otherwise -> do - (off1, sh) <- readShape mbarr ndim off + (off1, sh) <- readShape addr# ndim off let eltsize = tSize' t Nothing n = shapeSize sh arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini)) @@ -150,18 +167,22 @@ accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0 STScal sty -> goScal sty off STAccum{} -> error "Nested accumulators unsupported" - goScal :: SScalTy t' -> Int -> ST s (Int, ScalRep t') - goScal STI32 (I# off#) = ST $ \s -> case readInt32Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 4#), I32# i) #) - goScal STI64 (I# off#) = ST $ \s -> case readInt64Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 8#), I64# i) #) - goScal STF32 (I# off#) = ST $ \s -> case readFloatArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 4#), F# f) #) - goScal STF64 (I# off#) = ST $ \s -> case readDoubleArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 8#), D# f) #) - goScal STBool (I# off#) = ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 1#), toEnum (fromIntegral (I8# i))) #) + goScal :: SScalTy t' -> Int -> IO (Int, ScalRep t') + goScal STI32 off@(I# off#) = (off + 4,) <$> readInt32 addr# off# + goScal STI64 off@(I# off#) = (off + 8,) <$> readInt64 addr# off# + goScal STF32 off@(I# off#) = (off + 4,) <$> readFloat addr# off# + goScal STF64 off@(I# off#) = (off + 8,) <$> readDouble addr# off# + goScal STBool off@(I# off#) = do + i8 <- readInt8 addr# off# + return (off + 1, toEnum (fromIntegral i8)) -readShape :: MutableByteArray# s -> SNat n -> Int -> ST s (Int, Shape n) + in snd <$> go False topty 0 + +readShape :: Addr# -> SNat n -> Int -> IO (Int, Shape n) readShape _ SZ off = return (off, ShNil) readShape mbarr (SS ndim) off = do (off1@(I# off1#), sh) <- readShape mbarr ndim off - n' <- ST $ \s -> case readInt64Array# mbarr off1# s of (# s', i #) -> (# s', I64# i #) + n' <- readInt64 mbarr off1# return (off1 + 8, ShCons sh (fromIntegral n')) -- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of @@ -174,7 +195,7 @@ data InvShape n where -> InvShape (S n) ishSize :: InvShape n -> Int -ishSize IShNil = 0 +ishSize IShNil = 1 ishSize (IShCons _ sz _) = sz invertShape :: forall n. Shape n -> InvShape n @@ -185,9 +206,9 @@ invertShape | Refl <- lemPlusZero @n = flip go IShNil go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish) accumAdd :: forall s t i. Accum s t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s () -accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty depth index value 0 - where - go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int -> ST s () +accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> + let + go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int -> IO () go inarr ty SZ () val off = () <$ performAdd inarr ty val off go inarr ty (SS dep) idx val off = case (ty, idx, val) of (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off @@ -199,15 +220,15 @@ accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty (STArr rank eltty, _, _) | inarr -> error "Nested arrays" | otherwise -> do - (off1, ish) <- second invertShape <$> readShape mbarr rank off + (off1, ish) <- second invertShape <$> readShape addr# rank off goArr (SS dep) ish eltty idx val off1 (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth" (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth" (STAccum{}, _, _) -> error "Nested accumulators unsupported" goArr :: SNat i' -> InvShape n -> STy t' - -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int -> ST s () - goArr SZ ish t1 () val off = performAddArr ish t1 val off + -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int -> IO () + goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do let i' = fromIntegral @(Rep TIx) @Int i @@ -215,7 +236,14 @@ accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")" goArr depm1 ish t1 idx val (off + i' * ishSize ish) - performAdd :: Bool -> STy t' -> Rep t' -> Int -> ST s Int + performAddArr :: Int -> STy t' -> Array n (Rep t') -> Int -> IO Int + performAddArr arraySz eltty val off = do + let eltsize = tSize' eltty Nothing + forM_ [0 .. arraySz - 1] $ \lini -> + performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize) + return (off + arraySz * eltsize) + + performAdd :: Bool -> STy t' -> Rep t' -> Int -> IO Int performAdd inarr ty val off = case ty of STNil -> return off STPair t1 t2 -> do @@ -223,7 +251,7 @@ accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty performAdd inarr t2 (snd val) off1 STEither t1 t2 -> do let !(I# off#) = off - tag <- ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', I8# i #) + tag <- readInt8 addr# off# off1 <- case (val, tag) of (Left val1, 0) -> performAdd inarr t1 val1 (off + 1) (Right val2, 1) -> performAdd inarr t2 val2 (off + 1) @@ -234,46 +262,97 @@ accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty STArr n ty' | inarr -> error "Nested array" | otherwise -> do - (off1, sh) <- readShape mbarr n off - let sz = shapeSize sh - eltsize = tSize' ty' Nothing - forM_ [0 .. sz - 1] $ \lini -> - performAdd True ty' (arrayIndexLinear val lini) (off1 + lini * eltsize) - return (off1 + sz * eltsize) + (off1, sh) <- readShape addr# n off + performAddArr (shapeSize sh) ty' val off1 STScal ty' -> performAddScal ty' val off STAccum{} -> error "Nested accumulators unsupported" - performAddScal :: SScalTy t' -> ScalRep t' -> Int -> ST s Int - performAddScal STI32 (I32# x#) (I# off#) + performAddScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int + performAddScal STI32 (I32# x#) off@(I# off#) | sizeOf (undefined :: Int) == 4 - = ST $ \s -> case fetchAddIntArray# mbarr off# (int32ToInt# x#) s of - (# s', _ #) -> (# s', I# (off# +# 4#) #) + = off + 4 <$ fetchAddWord# addr# off# (word32ToWord# (int32ToWord32# x#)) | otherwise - = _ - performAddScal STI64 (I64# x#) (I# off#) + = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\(W32# w#) -> W32# (int32ToWord32# x# `plusWord32#` w#)) + performAddScal STI64 (I64# x#) off@(I# off#) | sizeOf (undefined :: Int) == 8 - = ST $ \s -> case fetchAddIntArray# mbarr off# (int64ToInt# x#) s of - (# s', _ #) -> (# s', I# (off# +# 8#) #) + = off + 8 <$ fetchAddWord# addr# off# (word64ToWord# (int64ToWord64# x#)) | otherwise - = _ - performAddScal STF32 (F# x) (I# off#) = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #) - performAddScal STF64 (D# x) (I# off#) = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #) - performAddScal STBool b (I# off#) = - let !(I8# i) = fromIntegral (fromEnum b) - in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #) + = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\(W64# w#) -> W64# (int64ToWord64# x# `plusWord64#` w#)) + performAddScal STF32 x off@(I# off#) = + off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\w -> castFloatToWord32 (x + castWord32ToFloat w)) + performAddScal STF64 x off@(I# off#) = + off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\w -> castDoubleToWord64 (x + castWord64ToDouble w)) + performAddScal STBool _ off = return (off + 1) -- don't do anything with booleans + + casLoop :: Eq w + => (Addr# -> Int# -> IO w) -- ^ read value (from a given byte offset; will get 0#) + -> (Addr# -> w -> w -> IO w) -- ^ CAS value at address (expected -> desired -> IO observed) + -> Addr# -- ^ Address to attempt to modify + -> (w -> w) -- ^ Operation to apply to the value + -> IO () + casLoop readOp casOp addr modify = readOp addr 0# >>= loop + where + loop value = do + value' <- casOp addr value (modify value) + if value == value' + then return () + else loop value' - casLoop :: (Addr# -> w -> w -> State# d -> (# State# d, w #)) - -> (w -> w) - -> r - -> State# d -> (# State# d, r #) - casLoop casOp add ret = _ + in () <$ go False topty top_depth top_index top_value 0 withAccum :: STy t -> Rep t -> (Accum s t -> AcM s b) -> AcM s (Rep t, b) withAccum ty start fun = do - let !(I# size) = tSize ty start - accum <- AcM . ST $ \s -> case newByteArray# size s of - (# s', mbarr #) -> (# s', Accum ty mbarr #) - AcM $ accumWrite accum start + -- The initial write must happen before any of the adds or reads, so it makes + -- sense to put it in IO together with the allocation, instead of in AcM. + accum <- AcM $ do buffer <- mallocBytes (tSize ty start) + ptr <- newForeignPtr finalizerFree buffer + let accum = Accum ty ptr + accumWrite accum start + return accum b <- fun accum out <- accumRead accum return (out, b) + +inParallel :: [AcM s t] -> AcM s [t] +inParallel actions = AcM $ do + mvars <- mapM (\_ -> newEmptyMVar) actions + forM_ (zip actions mvars) $ \(AcM action, var) -> + forkIO $ action >>= putMVar var + mapM takeMVar mvars + +-- | Offset is in bytes. +readInt8 :: Addr# -> Int# -> IO Int8 +readInt32 :: Addr# -> Int# -> IO Int32 +readInt64 :: Addr# -> Int# -> IO Int64 +readWord32 :: Addr# -> Int# -> IO Word32 +readWord64 :: Addr# -> Int# -> IO Word64 +readFloat :: Addr# -> Int# -> IO Float +readDouble :: Addr# -> Int# -> IO Double +readInt8 addr off# = IO $ \s -> case readInt8OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I8# val #) +readInt32 addr off# = IO $ \s -> case readInt32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I32# val #) +readInt64 addr off# = IO $ \s -> case readInt64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I64# val #) +readWord32 addr off# = IO $ \s -> case readWord32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W32# val #) +readWord64 addr off# = IO $ \s -> case readWord64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W64# val #) +readFloat addr off# = IO $ \s -> case readFloatOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', F# val #) +readDouble addr off# = IO $ \s -> case readDoubleOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', D# val #) + +writeInt8# :: Addr# -> Int# -> Int8# -> IO () +writeInt32# :: Addr# -> Int# -> Int32# -> IO () +writeInt64# :: Addr# -> Int# -> Int64# -> IO () +writeFloat# :: Addr# -> Int# -> Float# -> IO () +writeDouble# :: Addr# -> Int# -> Double# -> IO () +writeInt8# addr off# val = IO $ \s -> (# writeInt8OffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeInt32# addr off# val = IO $ \s -> (# writeInt32OffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeInt64# addr off# val = IO $ \s -> (# writeInt64OffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeFloat# addr off# val = IO $ \s -> (# writeFloatOffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeDouble# addr off# val = IO $ \s -> (# writeDoubleOffAddr# (addr `plusAddr#` off#) 0# val s, () #) + +fetchAddWord# :: Addr# -> Int# -> Word# -> IO () +fetchAddWord# addr off# val = IO $ \s -> case fetchAddWordAddr# (addr `plusAddr#` off#) val s of (# s', _ #) -> (# s', () #) + +atomicCasWord32Addr :: Addr# -> Word32 -> Word32 -> IO Word32 +atomicCasWord64Addr :: Addr# -> Word64 -> Word64 -> IO Word64 +atomicCasWord32Addr addr (W32# expected) (W32# desired) = + IO $ \s -> case atomicCasWord32Addr# addr expected desired s of (# s', old #) -> (# s', W32# old #) +atomicCasWord64Addr addr (W64# expected) (W64# desired) = + IO $ \s -> case atomicCasWord64Addr# addr expected desired s of (# s', old #) -> (# s', W64# old #) |