diff options
Diffstat (limited to 'src/Interpreter/AccumOld.hs')
-rw-r--r-- | src/Interpreter/AccumOld.hs | 366 |
1 files changed, 366 insertions, 0 deletions
diff --git a/src/Interpreter/AccumOld.hs b/src/Interpreter/AccumOld.hs new file mode 100644 index 0000000..af7be1e --- /dev/null +++ b/src/Interpreter/AccumOld.hs @@ -0,0 +1,366 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} +module Interpreter.Accum ( + AcM, + runAcM, + Rep', + Accum, + withAccum, + accumAdd, + inParallel, +) where + +import Control.Concurrent +import Control.Monad (when, forM_) +import Data.Bifunctor (second) +import Data.Proxy +import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr) +import Foreign.Storable (sizeOf) +import GHC.Exts +import GHC.Float +import GHC.Int +import GHC.IO (IO(..)) +import GHC.Word +import System.IO.Unsafe (unsafePerformIO) + +import Array +import AST +import Data + + +newtype AcM s a = AcM (IO a) + deriving newtype (Functor, Applicative, Monad) + +runAcM :: (forall s. AcM s a) -> a +runAcM (AcM m) = unsafePerformIO m + +-- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined. +type family Rep' s t where + Rep' s TNil = () + Rep' s (TPair a b) = (Rep' s a, Rep' s b) + Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b) + Rep' s (TMaybe t) = Maybe (Rep' s t) + Rep' s (TArr n t) = Array n (Rep' s t) + Rep' s (TScal sty) = ScalRep sty + Rep' s (TAccum t) = Accum s t + +-- | Floats and integers are accumulated; booleans are left as-is. +data Accum s t = Accum (STy t) (ForeignPtr ()) + +tSize :: Proxy s -> STy t -> Rep' s t -> Int +tSize p ty x = tSize' p ty (Just x) + +tSize' :: Proxy s -> STy t -> Int +tSize' p typ = case typ of + STNil -> 0 + STPair a b -> tSize' p a + tSize' p b + STEither a b -> 1 + max (tSize' p a) (tSize' p b) + -- Representation of Maybe t is the same as Either () t; the add operation is different, however. + STMaybe t -> tSize' p (STEither STNil t) + STArr ndim t -> + case val of + Nothing -> error "Nested arrays not supported in this implementation" + Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing + STScal sty -> goScal sty + STAccum{} -> error "Nested accumulators unsupported" + where + goScal :: SScalTy t -> Int + goScal STI32 = 4 + goScal STI64 = 8 + goScal STF32 = 4 + goScal STF64 = 8 + goScal STBool = 1 + +-- | This operation does not commute with 'accumAdd', so it must be used with +-- care. Furthermore it must be used on exactly the same value as tSize was +-- called on. Hence it lives in IO, not in AcM. +accumWrite :: forall s t. Accum s t -> Rep' s t -> IO () +accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> + let + go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int + go inarr ty val off = case ty of + STNil -> return off + STPair a b -> do + off1 <- go inarr a (fst val) off + go inarr b (snd val) off1 + STEither a b -> do + let !(I# off#) = off + off1 <- case val of + Left x -> do + let !(I8# tag#) = 0 + writeInt8# addr# off# tag# + go inarr a x (off + 1) + Right y -> do + let !(I8# tag#) = 1 + writeInt8# addr# off# tag# + go inarr b y (off + 1) + if inarr + then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing)) + else return off1 + -- Representation is the same, but add operation is different + STMaybe t -> go inarr (STEither STNil t) (maybe (Left ()) Right val) off + STArr _ t + | inarr -> error "Nested arrays not supported in this implementation" + | otherwise -> do + off1 <- goShape (arrayShape val) off + let eltsize = tSize' (Proxy @s) t Nothing + n = arraySize val + traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val + return (off1 + eltsize * n) + STScal sty -> goScal sty val off + STAccum{} -> error "Nested accumulators unsupported" + + goShape :: Shape n -> Int -> IO Int + goShape ShNil off = return off + goShape (ShCons sh n) off = do + off1@(I# off1#) <- goShape sh off + let !(I64# n'#) = fromIntegral n + writeInt64# addr# off1# n'# + return (off1 + 8) + + goScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int + goScal STI32 (I32# x) off@(I# off#) = off + 4 <$ writeInt32# addr# off# x + goScal STI64 (I64# x) off@(I# off#) = off + 8 <$ writeInt64# addr# off# x + goScal STF32 (F# x) off@(I# off#) = off + 4 <$ writeFloat# addr# off# x + goScal STF64 (D# x) off@(I# off#) = off + 8 <$ writeDouble# addr# off# x + goScal STBool b off@(I# off#) = do + let !(I8# i) = fromIntegral (fromEnum b) + off + 1 <$ writeInt8# addr# off# i + + in () <$ go False topty top_value 0 + +accumRead :: forall s t. Accum s t -> AcM s (Rep' s t) +accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> + let + go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t') + go inarr ty off = case ty of + STNil -> return (off, ()) + STPair a b -> do + (off1, x) <- go inarr a off + (off2, y) <- go inarr b off1 + return (off1 + off2, (x, y)) + STEither a b -> do + let !(I# off#) = off + tag <- readInt8 addr# off# + (off1, val) <- case tag of + 0 -> fmap Left <$> go inarr a (off + 1) + 1 -> fmap Right <$> go inarr b (off + 1) + _ -> error "Invalid tag in accum memory" + if inarr + then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val) + else return (off1, val) + -- Representation is the same, but add operation is different + STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t) off + STArr ndim t + | inarr -> error "Nested arrays not supported in this implementation" + | otherwise -> do + (off1, sh) <- readShape addr# ndim off + let eltsize = tSize' (Proxy @s) t Nothing + n = shapeSize sh + arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini)) + return (off1 + eltsize * n, arr) + STScal sty -> goScal sty off + STAccum{} -> error "Nested accumulators unsupported" + + goScal :: SScalTy t' -> Int -> IO (Int, ScalRep t') + goScal STI32 off@(I# off#) = (off + 4,) <$> readInt32 addr# off# + goScal STI64 off@(I# off#) = (off + 8,) <$> readInt64 addr# off# + goScal STF32 off@(I# off#) = (off + 4,) <$> readFloat addr# off# + goScal STF64 off@(I# off#) = (off + 8,) <$> readDouble addr# off# + goScal STBool off@(I# off#) = do + i8 <- readInt8 addr# off# + return (off + 1, toEnum (fromIntegral i8)) + + in snd <$> go False topty 0 + +readShape :: Addr# -> SNat n -> Int -> IO (Int, Shape n) +readShape _ SZ off = return (off, ShNil) +readShape mbarr (SS ndim) off = do + (off1@(I# off1#), sh) <- readShape mbarr ndim off + n' <- readInt64 mbarr off1# + return (off1 + 8, ShCons sh (fromIntegral n')) + +-- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of +-- the list. +data InvShape n where + IShNil :: InvShape Z + IShCons :: Int -- ^ How many subarrays are there? + -> Int -- ^ What is the size of all subarrays together? + -> InvShape n -- ^ Sub array inverted shape + -> InvShape (S n) + +ishSize :: InvShape n -> Int +ishSize IShNil = 1 +ishSize (IShCons _ sz _) = sz + +invertShape :: forall n. Shape n -> InvShape n +invertShape | Refl <- lemPlusZero @n = flip go IShNil + where + go :: forall n' m. Shape n' -> InvShape m -> InvShape (n' + m) + go ShNil ish = ish + go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish) + +accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s () +accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> + let + go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO () + go inarr ty SZ () val off = () <$ performAdd inarr ty val off + go inarr ty (SS dep) idx val off = case (ty, idx, val) of + (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off + (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off + (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd" + (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off + (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off + (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd" + (STMaybe t, _, _) -> _ idx val + (STArr rank eltty, _, _) + | inarr -> error "Nested arrays" + | otherwise -> do + (off1, ish) <- second invertShape <$> readShape addr# rank off + goArr (SS dep) ish eltty idx val off1 + (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth" + (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth" + (STAccum{}, _, _) -> error "Nested accumulators unsupported" + + goArr :: SNat i' -> InvShape n -> STy t' + -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO () + goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off + goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off + goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do + let i' = fromIntegral @(Rep' s TIx) @Int i + when (i' < 0 || i' >= n) $ + error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")" + goArr depm1 ish t1 idx val (off + i' * ishSize ish) + + performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int + performAddArr arraySz eltty val off = do + let eltsize = tSize' (Proxy @s) eltty Nothing + forM_ [0 .. arraySz - 1] $ \lini -> + performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize) + return (off + arraySz * eltsize) + + performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int + performAdd inarr ty val off = case ty of + STNil -> return off + STPair t1 t2 -> do + off1 <- performAdd inarr t1 (fst val) off + performAdd inarr t2 (snd val) off1 + STEither t1 t2 -> do + let !(I# off#) = off + tag <- readInt8 addr# off# + off1 <- case (val, tag) of + (Left val1, 0) -> performAdd inarr t1 val1 (off + 1) + (Right val2, 1) -> performAdd inarr t2 val2 (off + 1) + _ -> error "accumAdd: Tag mismatch for Either" + if inarr + then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing)) + else return off1 + STArr n ty' + | inarr -> error "Nested array" + | otherwise -> do + (off1, sh) <- readShape addr# n off + performAddArr (shapeSize sh) ty' val off1 + STScal ty' -> performAddScal ty' val off + STAccum{} -> error "Nested accumulators unsupported" + + performAddScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int + performAddScal STI32 (I32# x#) off@(I# off#) + | sizeOf (undefined :: Int) == 4 + = off + 4 <$ fetchAddWord# addr# off# (word32ToWord# (int32ToWord32# x#)) + | otherwise + = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\(W32# w#) -> W32# (int32ToWord32# x# `plusWord32#` w#)) + performAddScal STI64 (I64# x#) off@(I# off#) + | sizeOf (undefined :: Int) == 8 + = off + 8 <$ fetchAddWord# addr# off# (word64ToWord# (int64ToWord64# x#)) + | otherwise + = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\(W64# w#) -> W64# (int64ToWord64# x# `plusWord64#` w#)) + performAddScal STF32 x off@(I# off#) = + off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\w -> castFloatToWord32 (x + castWord32ToFloat w)) + performAddScal STF64 x off@(I# off#) = + off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\w -> castDoubleToWord64 (x + castWord64ToDouble w)) + performAddScal STBool _ off = return (off + 1) -- don't do anything with booleans + + casLoop :: Eq w + => (Addr# -> Int# -> IO w) -- ^ read value (from a given byte offset; will get 0#) + -> (Addr# -> w -> w -> IO w) -- ^ CAS value at address (expected -> desired -> IO observed) + -> Addr# -- ^ Address to attempt to modify + -> (w -> w) -- ^ Operation to apply to the value + -> IO () + casLoop readOp casOp addr modify = readOp addr 0# >>= loop + where + loop value = do + value' <- casOp addr value (modify value) + if value == value' + then return () + else loop value' + + in () <$ go False topty top_depth top_index top_value 0 + +withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t) +withAccum ty start fun = do + -- The initial write must happen before any of the adds or reads, so it makes + -- sense to put it in IO together with the allocation, instead of in AcM. + accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start) + ptr <- newForeignPtr finalizerFree buffer + let accum = Accum ty ptr + accumWrite accum start + return accum + b <- fun accum + out <- accumRead accum + return (b, out) + +inParallel :: [AcM s t] -> AcM s [t] +inParallel actions = AcM $ do + mvars <- mapM (\_ -> newEmptyMVar) actions + forM_ (zip actions mvars) $ \(AcM action, var) -> + forkIO $ action >>= putMVar var + mapM takeMVar mvars + +-- | Offset is in bytes. +readInt8 :: Addr# -> Int# -> IO Int8 +readInt32 :: Addr# -> Int# -> IO Int32 +readInt64 :: Addr# -> Int# -> IO Int64 +readWord32 :: Addr# -> Int# -> IO Word32 +readWord64 :: Addr# -> Int# -> IO Word64 +readFloat :: Addr# -> Int# -> IO Float +readDouble :: Addr# -> Int# -> IO Double +readInt8 addr off# = IO $ \s -> case readInt8OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I8# val #) +readInt32 addr off# = IO $ \s -> case readInt32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I32# val #) +readInt64 addr off# = IO $ \s -> case readInt64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I64# val #) +readWord32 addr off# = IO $ \s -> case readWord32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W32# val #) +readWord64 addr off# = IO $ \s -> case readWord64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W64# val #) +readFloat addr off# = IO $ \s -> case readFloatOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', F# val #) +readDouble addr off# = IO $ \s -> case readDoubleOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', D# val #) + +writeInt8# :: Addr# -> Int# -> Int8# -> IO () +writeInt32# :: Addr# -> Int# -> Int32# -> IO () +writeInt64# :: Addr# -> Int# -> Int64# -> IO () +writeFloat# :: Addr# -> Int# -> Float# -> IO () +writeDouble# :: Addr# -> Int# -> Double# -> IO () +writeInt8# addr off# val = IO $ \s -> (# writeInt8OffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeInt32# addr off# val = IO $ \s -> (# writeInt32OffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeInt64# addr off# val = IO $ \s -> (# writeInt64OffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeFloat# addr off# val = IO $ \s -> (# writeFloatOffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeDouble# addr off# val = IO $ \s -> (# writeDoubleOffAddr# (addr `plusAddr#` off#) 0# val s, () #) + +fetchAddWord# :: Addr# -> Int# -> Word# -> IO () +fetchAddWord# addr off# val = IO $ \s -> case fetchAddWordAddr# (addr `plusAddr#` off#) val s of (# s', _ #) -> (# s', () #) + +atomicCasWord32Addr :: Addr# -> Word32 -> Word32 -> IO Word32 +atomicCasWord64Addr :: Addr# -> Word64 -> Word64 -> IO Word64 +atomicCasWord32Addr addr (W32# expected) (W32# desired) = + IO $ \s -> case atomicCasWord32Addr# addr expected desired s of (# s', old #) -> (# s', W32# old #) +atomicCasWord64Addr addr (W64# expected) (W64# desired) = + IO $ \s -> case atomicCasWord64Addr# addr expected desired s of (# s', old #) -> (# s', W64# old #) |