{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}
module Interpreter.Accum (
  AcM,
  runAcM,
  Rep',
  Accum,
  withAccum,
  accumAdd,
  inParallel,
) where

import Control.Concurrent
import Control.Exception (throwIO)
import Control.Monad (when, forM_)
import Data.Bifunctor (second)
import Data.Proxy
import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr)
import Foreign.Storable (sizeOf)
import GHC.Exts
import GHC.Float
import GHC.Int
import GHC.IO (IO(..))
import GHC.Word
import System.IO.Unsafe (unsafePerformIO)

import Array
import AST
import Data


-- | Monad in which accumulators are created, added to and read. It is a thin
-- newtype over 'IO'; @s@ is a phantom parameter.
newtype AcM s a = AcM (IO a)
  deriving newtype (Functor, Applicative, Monad)

-- | Run an accumulator computation as a pure value via 'unsafePerformIO'; the
-- effects happen when the result is forced. The rank-2 type mirrors @runST@,
-- presumably to stop accumulators escaping their computation.
runAcM :: (forall s. AcM s a) -> a
runAcM (AcM m) = unsafePerformIO m

-- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined.
type family Rep' s t where
  Rep' s TNil = ()
  Rep' s (TPair a b) = (Rep' s a, Rep' s b)
  Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b)
  Rep' s (TMaybe t) = Maybe (Rep' s t)
  Rep' s (TArr n t) = Array n (Rep' s t)
  Rep' s (TScal sty) = ScalRep sty
  Rep' s (TAccum t) = Accum s t

-- | Floats and integers are accumulated; booleans are left as-is.
--
-- Holds the type of the accumulated value and the raw buffer its serialised
-- form lives in (see 'accumWrite' for the layout).
data Accum s t = Accum (STy t) (ForeignPtr ())

-- | Number of bytes needed to store the given value in an accumulator buffer.
-- Must be called on exactly the value that is later passed to 'accumWrite'.
tSize :: Proxy s -> STy t -> Rep' s t -> Int
tSize p ty x = tSize' p ty (Just x)

-- | Byte size of a value of the given type in accumulator memory.
--
-- With @Just v@ the size of the concrete value @v@ is computed. An 'Either'
-- then only occupies its tag plus the branch that is present, matching the
-- compact layout 'accumWrite' uses outside arrays, and arrays are sized from
-- their shape. With 'Nothing' (used for array elements), the fixed,
-- value-independent size is returned: an 'Either' reserves room for the
-- larger branch, and arrays are rejected because nesting is unsupported.
tSize' :: Proxy s -> STy t -> Maybe (Rep' s t) -> Int
tSize' p typ val = case typ of
  STNil -> 0
  STPair a b -> tSize' p a (fst <$> val) + tSize' p b (snd <$> val)
  STEither a b -> case val of
    Nothing -> 1 + max (tSize' p a Nothing) (tSize' p b Nothing)
    Just (Left x) -> 1 + tSize' p a (Just x)
    Just (Right y) -> 1 + tSize' p b (Just y)
  -- Representation of Maybe t is the same as Either () t; the add operation
  -- is different, however.
  STMaybe t -> tSize' p (STEither STNil t) (maybe (Left ()) Right <$> val)
  STArr ndim t -> case val of
    Nothing -> error "Nested arrays not supported in this implementation"
    Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing
  STScal sty -> goScal sty
  STAccum{} -> error "Nested accumulators unsupported"
  where
    goScal :: SScalTy t' -> Int
    goScal STI32 = 4
    goScal STI64 = 8
    goScal STF32 = 4
    goScal STF64 = 8
    goScal STBool = 1

-- | Serialise a value into the accumulator buffer, starting at byte offset 0.
--
-- This operation does not commute with 'accumAdd', so it must be used with
-- care. Furthermore it must be used on exactly the same value as tSize was
-- called on. Hence it lives in IO, not in AcM.
--
-- Layout:
--
-- * Pairs are stored component after component.
-- * An 'Either' is a one-byte tag (0 = Left, 1 = Right) followed by its
--   payload. Inside arrays the payload slot is padded to the larger branch.
-- * An array is its shape (one 8-byte integer per dimension) followed by its
--   elements at a fixed stride.
-- * Scalars are stored at their natural width, booleans as one byte.
accumWrite :: forall s t. Accum s t -> Rep' s t -> IO ()
accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
  let go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
      go inarr ty val off = case ty of
        STNil -> return off
        STPair a b -> do
          off1 <- go inarr a (fst val) off
          go inarr b (snd val) off1
        STEither a b -> do
          let !(I# off#) = off
          off1 <- case val of
            Left x -> do
              let !(I8# tag#) = 0
              writeInt8# addr# off# tag#
              go inarr a x (off + 1)
            Right y -> do
              let !(I8# tag#) = 1
              writeInt8# addr# off# tag#
              go inarr b y (off + 1)
          if inarr
            then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing))
            else return off1
        -- Representation is the same, but add operation is different
        STMaybe t -> go inarr (STEither STNil t) (maybe (Left ()) Right val) off
        STArr _ t
          | inarr -> error "Nested arrays not supported in this implementation"
          | otherwise -> do
              off1 <- goShape (arrayShape val) off
              let eltsize = tSize' (Proxy @s) t Nothing
                  n = arraySize val
              traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val
              return (off1 + eltsize * n)
        STScal sty -> goScal sty val off
        STAccum{} -> error "Nested accumulators unsupported"

      goShape :: Shape n -> Int -> IO Int
      goShape ShNil off = return off
      goShape (ShCons sh n) off = do
        off1@(I# off1#) <- goShape sh off
        let !(I64# n'#) = fromIntegral n
        writeInt64# addr# off1# n'#
        return (off1 + 8)

      goScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int
      goScal STI32 (I32# x) off@(I# off#) = off + 4 <$ writeInt32# addr# off# x
      goScal STI64 (I64# x) off@(I# off#) = off + 8 <$ writeInt64# addr# off# x
      goScal STF32 (F# x) off@(I# off#) = off + 4 <$ writeFloat# addr# off# x
      goScal STF64 (D# x) off@(I# off#) = off + 8 <$ writeDouble# addr# off# x
      goScal STBool b off@(I# off#) = do
        let !(I8# i) = fromIntegral (fromEnum b)
        off + 1 <$ writeInt8# addr# off# i
  in () <$ go False topty top_value 0

-- | Deserialise the accumulator buffer back into a value (inverse of
-- 'accumWrite'). Every internal step returns the absolute byte offset just
-- past the data it consumed.
accumRead :: forall s t. Accum s t -> AcM s (Rep' s t)
accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
  let go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t')
      go inarr ty off = case ty of
        STNil -> return (off, ())
        STPair a b -> do
          (off1, x) <- go inarr a off
          (off2, y) <- go inarr b off1
          -- off2 is already an absolute offset; do not add off1 to it.
          return (off2, (x, y))
        STEither a b -> do
          let !(I# off#) = off
          tag <- readInt8 addr# off#
          (off1, val) <- case tag of
            0 -> fmap Left <$> go inarr a (off + 1)
            1 -> fmap Right <$> go inarr b (off + 1)
            _ -> error "Invalid tag in accum memory"
          if inarr
            then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val)
            else return (off1, val)
        -- Representation is the same, but add operation is different
        STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t) off
        STArr ndim t
          | inarr -> error "Nested arrays not supported in this implementation"
          | otherwise -> do
              (off1, sh) <- readShape addr# ndim off
              let eltsize = tSize' (Proxy @s) t Nothing
                  n = shapeSize sh
              arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini))
              return (off1 + eltsize * n, arr)
        STScal sty -> goScal sty off
        STAccum{} -> error "Nested accumulators unsupported"

      goScal :: SScalTy t' -> Int -> IO (Int, ScalRep t')
      goScal STI32 off@(I# off#) = (off + 4,) <$> readInt32 addr# off#
      goScal STI64 off@(I# off#) = (off + 8,) <$> readInt64 addr# off#
      goScal STF32 off@(I# off#) = (off + 4,) <$> readFloat addr# off#
      goScal STF64 off@(I# off#) = (off + 8,) <$> readDouble addr# off#
      goScal STBool off@(I# off#) = do
        i8 <- readInt8 addr# off#
        return (off + 1, toEnum (fromIntegral i8))
  in snd <$> go False topty 0

-- | Read an array shape (one 8-byte integer per dimension, as written by
-- 'accumWrite') starting at the given byte offset. Returns the offset just
-- past the shape.
readShape :: Addr# -> SNat n -> Int -> IO (Int, Shape n)
readShape _ SZ off = return (off, ShNil)
readShape addr (SS ndim) off = do
  (off1@(I# off1#), sh) <- readShape addr ndim off
  n' <- readInt64 addr off1#
  return (off1 + 8, ShCons sh (fromIntegral n'))

-- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of
-- the list.
data InvShape n where
  IShNil :: InvShape Z
  IShCons :: Int  -- ^ How many subarrays are there?
          -> Int  -- ^ What is the size of all subarrays together?
          -> InvShape n  -- ^ Sub array inverted shape
          -> InvShape (S n)

-- | Total number of elements described by an inverted shape.
ishSize :: InvShape n -> Int
ishSize IShNil = 1
ishSize (IShCons _ sz _) = sz

-- | Convert a 'Shape' into an 'InvShape', recording cumulative sizes.
invertShape :: forall n. Shape n -> InvShape n
invertShape | Refl <- lemPlusZero @n = flip go IShNil
  where
    go :: forall n' m. Shape n' -> InvShape m -> InvShape (n' + m)
    go ShNil ish = ish
    go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish)

-- | Add a value into (a sub-component of) an accumulator.
--
-- The depth @i@ says how far the index reaches into the stored structure:
--
-- * At depth 0 the whole value is added.
-- * Each further level on a pair or 'Either' selects a component with
--   'Left' / 'Right'.
-- * On an array, one 'TIx' index per dimension selects a sub-array.
--
-- Integer and floating-point additions are done with atomic fetch-and-add or
-- compare-and-swap loops; booleans are left untouched.
accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s ()
accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
  let go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO ()
      go inarr ty SZ () val off = () <$ performAdd inarr ty val off
      go inarr ty (SS dep) idx val off = case (ty, idx, val) of
        (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
        (STPair t1 t2, Right idx2, Right val2) -> do
          -- The second component is stored after the first one.
          off1 <- skipValue inarr t1 off
          go inarr t2 dep idx2 val2 off1
        (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd"
        -- The payload of an Either follows its one-byte tag.
        (STEither t1 _, Left idx1, Left val1) -> do
          checkTag 0 off
          go inarr t1 dep idx1 val1 (off + 1)
        (STEither _ t2, Right idx2, Right val2) -> do
          checkTag 1 off
          go inarr t2 dep idx2 val2 (off + 1)
        (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd"
        (STMaybe{}, _, _) -> error "accumAdd: Maybe with nonzero depth not yet supported"
        (STArr rank eltty, _, _)
          | inarr -> error "Nested arrays"
          | otherwise -> do
              (off1, ish) <- second invertShape <$> readShape addr# rank off
              goArr (SS dep) ish eltty idx val off1
        (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth"
        (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth"
        (STAccum{}, _, _) -> error "Nested accumulators unsupported"

      -- Fail unless the Either tag stored at the given byte offset is as expected.
      checkTag :: Int8 -> Int -> IO ()
      checkTag expected (I# off#) = do
        tag <- readInt8 addr# off#
        when (tag /= expected) $ error "accumAdd: Tag mismatch for Either"

      -- Byte offset just past the stored value of the given type at 'off'.
      -- Inside arrays the slot size is fixed. Outside them, Either tags and
      -- array shapes must be read from memory to know the size.
      skipValue :: Bool -> STy t' -> Int -> IO Int
      skipValue True ty off = return (off + tSize' (Proxy @s) ty Nothing)
      skipValue False ty off = case ty of
        STNil -> return off
        STPair t1 t2 -> skipValue False t1 off >>= skipValue False t2
        STEither t1 t2 -> do
          let !(I# off#) = off
          tag <- readInt8 addr# off#
          case tag of
            0 -> skipValue False t1 (off + 1)
            1 -> skipValue False t2 (off + 1)
            _ -> error "Invalid tag in accum memory"
        STMaybe t1 -> skipValue False (STEither STNil t1) off
        STArr n t1 -> do
          (off1, sh) <- readShape addr# n off
          return (off1 + shapeSize sh * tSize' (Proxy @s) t1 Nothing)
        STScal{} -> return (off + tSize' (Proxy @s) ty Nothing)
        STAccum{} -> error "Nested accumulators unsupported"

      goArr :: SNat i' -> InvShape n -> STy t' -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO ()
      goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off
      goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off
      goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do
        let i' = fromIntegral @(Rep' s TIx) @Int i
            eltsize = tSize' (Proxy @s) t1 Nothing
        when (i' < 0 || i' >= n) $
          error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")"
        -- Offsets are in bytes: each sub-array holds 'ishSize ish' elements of
        -- 'eltsize' bytes each.
        goArr depm1 ish t1 idx val (off + i' * ishSize ish * eltsize)

      performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int
      performAddArr arraySz eltty val off = do
        let eltsize = tSize' (Proxy @s) eltty Nothing
        forM_ [0 .. arraySz - 1] $ \lini ->
          performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize)
        return (off + arraySz * eltsize)

      performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
      performAdd inarr ty val off = case ty of
        STNil -> return off
        STPair t1 t2 -> do
          off1 <- performAdd inarr t1 (fst val) off
          performAdd inarr t2 (snd val) off1
        STEither t1 t2 -> do
          let !(I# off#) = off
          tag <- readInt8 addr# off#
          off1 <- case (val, tag) of
            (Left val1, 0) -> performAdd inarr t1 val1 (off + 1)
            (Right val2, 1) -> performAdd inarr t2 val2 (off + 1)
            _ -> error "accumAdd: Tag mismatch for Either"
          if inarr
            then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing))
            else return off1
        -- Same layout as Either () t, but a Nothing addend is the zero and
        -- leaves the accumulator untouched. A Just can only be added into a
        -- stored Just, because the buffer cannot be re-tagged in place.
        STMaybe t1 -> case val of
          Nothing -> skipValue inarr ty off
          Just x -> do
            let !(I# off#) = off
            tag <- readInt8 addr# off#
            when (tag /= 1) $ error "accumAdd: cannot add Just into a Nothing accumulator"
            off1 <- performAdd inarr t1 x (off + 1)
            if inarr
              then return (off + 1 + tSize' (Proxy @s) t1 Nothing)
              else return off1
        STArr n ty'
          | inarr -> error "Nested array"
          | otherwise -> do
              (off1, sh) <- readShape addr# n off
              performAddArr (shapeSize sh) ty' val off1
        STScal ty' -> performAddScal ty' val off
        STAccum{} -> error "Nested accumulators unsupported"

      -- NOTE(review): 1-byte Either tags and booleans precede scalars in the
      -- layout, so these atomics may hit unaligned addresses. Unaligned atomics
      -- are slow or fault on some architectures; TODO confirm target platforms.
      performAddScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int
      performAddScal STI32 (I32# x#) off@(I# off#)
        | sizeOf (undefined :: Int) == 4 = off + 4 <$ fetchAddWord# addr# off# (word32ToWord# (int32ToWord32# x#))
        | otherwise = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\(W32# w#) -> W32# (int32ToWord32# x# `plusWord32#` w#))
      performAddScal STI64 (I64# x#) off@(I# off#)
        | sizeOf (undefined :: Int) == 8 = off + 8 <$ fetchAddWord# addr# off# (word64ToWord# (int64ToWord64# x#))
        | otherwise = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\(W64# w#) -> W64# (int64ToWord64# x# `plusWord64#` w#))
      performAddScal STF32 x off@(I# off#) = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\w -> castFloatToWord32 (x + castWord32ToFloat w))
      performAddScal STF64 x off@(I# off#) = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\w -> castDoubleToWord64 (x + castWord64ToDouble w))
      performAddScal STBool _ off = return (off + 1)  -- don't do anything with booleans

      -- Retry 'modify' via compare-and-swap until no other thread has changed
      -- the value in between.
      casLoop :: Eq w
              => (Addr# -> Int# -> IO w)  -- read value (from a given byte offset; will get 0#)
              -> (Addr# -> w -> w -> IO w)  -- CAS value at address (expected -> desired -> IO observed)
              -> Addr#  -- Address to attempt to modify
              -> (w -> w)  -- Operation to apply to the value
              -> IO ()
      casLoop readOp casOp addr modify = readOp addr 0# >>= loop
        where
          loop value = do
            value' <- casOp addr value (modify value)
            if value == value'
              then return ()
              else loop value'
  in () <$ go False topty top_depth top_index top_value 0

-- | Allocate an accumulator initialised to @start@ and pass it to @fun@.
-- Returns @fun@'s result together with the final accumulated value.
withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t)
withAccum ty start fun = do
  -- The initial write must happen before any of the adds or reads, so it makes
  -- sense to put it in IO together with the allocation, instead of in AcM.
  accum <- AcM $ do
    buffer <- mallocBytes (tSize (Proxy @s) ty start)
    ptr <- newForeignPtr finalizerFree buffer
    let accum = Accum ty ptr
    accumWrite accum start
    return accum
  b <- fun accum
  out <- accumRead accum
  return (b, out)

-- | Run each action in its own thread with 'forkFinally' and return the
-- results in list order.
--
-- If an action throws, its exception is re-thrown in the calling thread. When
-- several actions fail, the first failure in list order is re-thrown.
inParallel :: [AcM s t] -> AcM s [t]
inParallel actions = AcM $ do
  mvars <- mapM (\_ -> newEmptyMVar) actions
  forM_ (zip actions mvars) $ \(AcM action, var) ->
    -- forkFinally fills the MVar even when the action throws, so the waits
    -- below cannot block forever on a crashed worker.
    forkFinally action (putMVar var)
  mapM (\var -> takeMVar var >>= either throwIO return) mvars

-- | Offset is in bytes.
-- Raw memory accessors. Every offset argument is a byte offset added to the
-- base address; the primops themselves are always called with index 0#.

readInt8 :: Addr# -> Int# -> IO Int8
readInt8 base off# = IO $ \s0 ->
  case readInt8OffAddr# (plusAddr# base off#) 0# s0 of
    (# s1, r #) -> (# s1, I8# r #)

readInt32 :: Addr# -> Int# -> IO Int32
readInt32 base off# = IO $ \s0 ->
  case readInt32OffAddr# (plusAddr# base off#) 0# s0 of
    (# s1, r #) -> (# s1, I32# r #)

readInt64 :: Addr# -> Int# -> IO Int64
readInt64 base off# = IO $ \s0 ->
  case readInt64OffAddr# (plusAddr# base off#) 0# s0 of
    (# s1, r #) -> (# s1, I64# r #)

readWord32 :: Addr# -> Int# -> IO Word32
readWord32 base off# = IO $ \s0 ->
  case readWord32OffAddr# (plusAddr# base off#) 0# s0 of
    (# s1, r #) -> (# s1, W32# r #)

readWord64 :: Addr# -> Int# -> IO Word64
readWord64 base off# = IO $ \s0 ->
  case readWord64OffAddr# (plusAddr# base off#) 0# s0 of
    (# s1, r #) -> (# s1, W64# r #)

readFloat :: Addr# -> Int# -> IO Float
readFloat base off# = IO $ \s0 ->
  case readFloatOffAddr# (plusAddr# base off#) 0# s0 of
    (# s1, r #) -> (# s1, F# r #)

readDouble :: Addr# -> Int# -> IO Double
readDouble base off# = IO $ \s0 ->
  case readDoubleOffAddr# (plusAddr# base off#) 0# s0 of
    (# s1, r #) -> (# s1, D# r #)

-- Unboxed writers: store the value at the given byte offset.

writeInt8# :: Addr# -> Int# -> Int8# -> IO ()
writeInt8# base off# v = IO $ \s0 ->
  (# writeInt8OffAddr# (plusAddr# base off#) 0# v s0, () #)

writeInt32# :: Addr# -> Int# -> Int32# -> IO ()
writeInt32# base off# v = IO $ \s0 ->
  (# writeInt32OffAddr# (plusAddr# base off#) 0# v s0, () #)

writeInt64# :: Addr# -> Int# -> Int64# -> IO ()
writeInt64# base off# v = IO $ \s0 ->
  (# writeInt64OffAddr# (plusAddr# base off#) 0# v s0, () #)

writeFloat# :: Addr# -> Int# -> Float# -> IO ()
writeFloat# base off# v = IO $ \s0 ->
  (# writeFloatOffAddr# (plusAddr# base off#) 0# v s0, () #)

writeDouble# :: Addr# -> Int# -> Double# -> IO ()
writeDouble# base off# v = IO $ \s0 ->
  (# writeDoubleOffAddr# (plusAddr# base off#) 0# v s0, () #)

-- Atomically add a machine word at the given byte offset. The previous
-- value returned by the primop is discarded.
fetchAddWord# :: Addr# -> Int# -> Word# -> IO ()
fetchAddWord# base off# delta = IO $ \s0 ->
  case fetchAddWordAddr# (plusAddr# base off#) delta s0 of
    (# s1, _previous #) -> (# s1, () #)

-- Atomic compare-and-swap at the exact given address (no offset). Returns the
-- value observed in memory; the swap took place iff it equals the expected
-- value.
atomicCasWord32Addr :: Addr# -> Word32 -> Word32 -> IO Word32
atomicCasWord32Addr base (W32# old) (W32# new) = IO $ \s0 ->
  case atomicCasWord32Addr# base old new s0 of
    (# s1, seen #) -> (# s1, W32# seen #)

atomicCasWord64Addr :: Addr# -> Word64 -> Word64 -> IO Word64
atomicCasWord64Addr base (W64# old) (W64# new) = IO $ \s0 ->
  case atomicCasWord64Addr# base old new s0 of
    (# s1, seen #) -> (# s1, W64# seen #)