summary | refs | log | tree | commit | diff
path: root/src/Interpreter/AccumOld.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Interpreter/AccumOld.hs')
-rw-r--r-- src/Interpreter/AccumOld.hs 366
1 files changed, 366 insertions, 0 deletions
diff --git a/src/Interpreter/AccumOld.hs b/src/Interpreter/AccumOld.hs
new file mode 100644
index 0000000..af7be1e
--- /dev/null
+++ b/src/Interpreter/AccumOld.hs
@@ -0,0 +1,366 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UnboxedTuples #-}
+module Interpreter.Accum (
+ AcM,
+ runAcM,
+ Rep',
+ Accum,
+ withAccum,
+ accumAdd,
+ inParallel,
+) where
+
+import Control.Concurrent
+import Control.Exception (throwIO)
+import Control.Monad (when, forM_)
+import Data.Bifunctor (second)
+import Data.Proxy
+import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr)
+import Foreign.Storable (sizeOf)
+import GHC.Exts
+import GHC.Float
+import GHC.Int
+import GHC.IO (IO(..))
+import GHC.Word
+import System.IO.Unsafe (unsafePerformIO)
+
+import Array
+import AST
+import Data
+
+
+-- | Computation type for code that works with accumulators. It is a newtype
+-- around 'IO'; the @s@ parameter is phantom (it does not occur in the
+-- representation) and is the same @s@ that tags the 'Accum' handles used by
+-- the operations in this module.
+newtype AcM s a = AcM (IO a)
+ deriving newtype (Functor, Applicative, Monad)
+
+-- | Run an 'AcM' computation and hand back its result as a pure value.
+--
+-- The argument has to be polymorphic in @s@ (rank-2 type). The wrapped 'IO'
+-- action is executed through 'unsafePerformIO'.
+runAcM :: (forall s. AcM s a) -> a
+runAcM action = case action of
+  AcM io -> unsafePerformIO io
+
+-- | Type-level mapping from an AST type @t@ to the Haskell type that holds its
+-- values. Equal to Interpreter.Rep.Rep, except that the TAccum case is
+-- defined: an accumulator type maps to an 'Accum' handle tagged with @s@.
+type family Rep' s t where
+ Rep' s TNil = ()
+ Rep' s (TPair a b) = (Rep' s a, Rep' s b)
+ Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b)
+ Rep' s (TMaybe t) = Maybe (Rep' s t)
+ Rep' s (TArr n t) = Array n (Rep' s t)
+ Rep' s (TScal sty) = ScalRep sty
+ Rep' s (TAccum t) = Accum s t
+
+-- | Accumulator handle: the singleton for the accumulated type @t@ together
+-- with a foreign pointer to the raw buffer that 'accumWrite', 'accumRead' and
+-- 'accumAdd' operate on.
+--
+-- Floats and integers are accumulated; booleans are left as-is.
+data Accum s t = Accum (STy t) (ForeignPtr ())
+
+-- | Number of bytes needed to hold the given top-level value in an accumulator
+-- buffer. Thin wrapper that hands the value to @tSize'@ wrapped in 'Just'.
+tSize :: Proxy s -> STy t -> Rep' s t -> Int
+tSize proxy ty = tSize' proxy ty . Just
+
+-- | Size in bytes of the serialised form of a value of type @typ@.
+--
+-- The third argument is the value being measured, if there is one:
+--
+-- * @Just v@: measure exactly @v@. For an 'STEither' this is one tag byte plus
+--   the size of the branch that is actually present, which is the offset
+--   'accumWrite' ends up at for a value outside an array.
+-- * @Nothing@: there is no value to inspect (this is how array elements are
+--   measured). An 'STEither' then takes one tag byte plus the larger of its two
+--   branches, so every element has the same size. Arrays cannot be measured
+--   this way, hence nested arrays are rejected.
+--
+-- An array is stored as one 8-byte extent per dimension followed by its
+-- elements. Calls 'error' for nested arrays and for nested accumulators.
+tSize' :: Proxy s -> STy t -> Maybe (Rep' s t) -> Int
+tSize' p typ val = case typ of
+  STNil -> 0
+  STPair a b -> tSize' p a (fst <$> val) + tSize' p b (snd <$> val)
+  STEither a b -> case val of
+    Nothing -> 1 + max (tSize' p a Nothing) (tSize' p b Nothing)
+    Just (Left x) -> 1 + tSize' p a (Just x)
+    Just (Right y) -> 1 + tSize' p b (Just y)
+  -- Representation of Maybe t is the same as Either () t; the add operation is different, however.
+  STMaybe t -> tSize' p (STEither STNil t) (maybe (Left ()) Right <$> val)
+  STArr ndim t -> case val of
+    Nothing -> error "Nested arrays not supported in this implementation"
+    Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing
+  STScal sty -> goScal sty
+  STAccum{} -> error "Nested accumulators unsupported"
+  where
+    -- Byte width of each scalar type as stored in the buffer.
+    goScal :: SScalTy t -> Int
+    goScal STI32 = 4
+    goScal STI64 = 8
+    goScal STF32 = 4
+    goScal STF64 = 8
+    goScal STBool = 1
+
+-- | Serialise a value into the accumulator's buffer, starting at byte offset 0.
+--
+-- This operation does not commute with 'accumAdd', so it must be used with
+-- care. Furthermore it must be used on exactly the same value as tSize was
+-- called on. Hence it lives in IO, not in AcM.
+accumWrite :: forall s t. Accum s t -> Rep' s t -> IO ()
+accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr base#) ->
+  let
+    -- Write one value of type @ty@ at byte offset @pos@ and return the offset
+    -- just past it. @nested@ is True while writing an array element.
+    put :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
+    put nested ty v pos = case ty of
+      STNil -> return pos
+      STPair a b -> put nested a (fst v) pos >>= put nested b (snd v)
+      STEither a b -> do
+        afterPayload <- case v of
+          Left x -> putTag 0 pos >> put nested a x (pos + 1)
+          Right y -> putTag 1 pos >> put nested b y (pos + 1)
+        -- Inside an array every element occupies the size of the larger branch.
+        return $
+          if nested
+            then pos + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing)
+            else afterPayload
+      -- Representation is the same, but add operation is different
+      STMaybe t -> put nested (STEither STNil t) (maybe (Left ()) Right v) pos
+      STArr _ t
+        | nested -> error "Nested arrays not supported in this implementation"
+        | otherwise -> do
+            dataStart <- putShape (arrayShape v) pos
+            let stride = tSize' (Proxy @s) t Nothing
+            traverseArray_ (\k x -> () <$ put True t x (dataStart + stride * k)) v
+            return (dataStart + stride * arraySize v)
+      STScal sty -> putScal sty v pos
+      STAccum{} -> error "Nested accumulators unsupported"
+
+    -- Write the one-byte tag of an Either at the given byte offset.
+    putTag :: Int8 -> Int -> IO ()
+    putTag (I8# tag#) (I# pos#) = writeInt8# base# pos# tag#
+
+    -- Write the extents of a shape as consecutive 8-byte integers.
+    putShape :: Shape n -> Int -> IO Int
+    putShape sh pos = case sh of
+      ShNil -> return pos
+      ShCons inner dim -> do
+        here <- putShape inner pos
+        case (here, fromIntegral dim) of
+          (I# here#, I64# dim#) -> writeInt64# base# here# dim#
+        return (here + 8)
+
+    -- Write a single scalar and return the offset past it.
+    putScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int
+    putScal STI32 (I32# x#) pos@(I# pos#) = writeInt32# base# pos# x# >> return (pos + 4)
+    putScal STI64 (I64# x#) pos@(I# pos#) = writeInt64# base# pos# x# >> return (pos + 8)
+    putScal STF32 (F# x#) pos@(I# pos#) = writeFloat# base# pos# x# >> return (pos + 4)
+    putScal STF64 (D# x#) pos@(I# pos#) = writeDouble# base# pos# x# >> return (pos + 8)
+    putScal STBool flag pos@(I# pos#) =
+      case fromIntegral (fromEnum flag) of
+        I8# byte# -> writeInt8# base# pos# byte# >> return (pos + 1)
+
+  in () <$ put False topty top_value 0
+
+-- | Deserialise the current contents of the accumulator's buffer, starting at
+-- byte offset 0, into a value of the accumulated type.
+--
+-- The traversal mirrors 'accumWrite': every helper takes the byte offset to
+-- read from and returns the absolute offset just past what it read, together
+-- with the decoded value. Calls 'error' on nested arrays, nested accumulators
+-- and on an Either tag byte that is neither 0 nor 1.
+accumRead :: forall s t. Accum s t -> AcM s (Rep' s t)
+accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
+  let
+    -- @inarr@ is True while reading an array element.
+    go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t')
+    go inarr ty off = case ty of
+      STNil -> return (off, ())
+      STPair a b -> do
+        (off1, x) <- go inarr a off
+        (off2, y) <- go inarr b off1
+        -- off2 is already the absolute offset past the second component; it
+        -- must not be added to off1.
+        return (off2, (x, y))
+      STEither a b -> do
+        let !(I# off#) = off
+        tag <- readInt8 addr# off#
+        (off1, val) <- case tag of
+          0 -> fmap Left <$> go inarr a (off + 1)
+          1 -> fmap Right <$> go inarr b (off + 1)
+          _ -> error "Invalid tag in accum memory"
+        -- Inside an array every element occupies the size of the larger branch.
+        if inarr
+          then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val)
+          else return (off1, val)
+      -- Representation is the same, but add operation is different
+      STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t) off
+      STArr ndim t
+        | inarr -> error "Nested arrays not supported in this implementation"
+        | otherwise -> do
+            (off1, sh) <- readShape addr# ndim off
+            let eltsize = tSize' (Proxy @s) t Nothing
+                n = shapeSize sh
+            arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini))
+            return (off1 + eltsize * n, arr)
+      STScal sty -> goScal sty off
+      STAccum{} -> error "Nested accumulators unsupported"
+
+    -- Read a single scalar and return the offset past it.
+    goScal :: SScalTy t' -> Int -> IO (Int, ScalRep t')
+    goScal STI32 off@(I# off#) = (off + 4,) <$> readInt32 addr# off#
+    goScal STI64 off@(I# off#) = (off + 8,) <$> readInt64 addr# off#
+    goScal STF32 off@(I# off#) = (off + 4,) <$> readFloat addr# off#
+    goScal STF64 off@(I# off#) = (off + 8,) <$> readDouble addr# off#
+    goScal STBool off@(I# off#) = do
+      i8 <- readInt8 addr# off#
+      return (off + 1, toEnum (fromIntegral i8))
+
+  in snd <$> go False topty 0
+
+-- | Read the extents of an array of the given rank from the buffer, one 8-byte
+-- integer per dimension, starting at the given byte offset. Returns the offset
+-- just past the last extent together with the shape.
+readShape :: Addr# -> SNat n -> Int -> IO (Int, Shape n)
+readShape base rank start = case rank of
+  SZ -> return (start, ShNil)
+  SS rank' -> do
+    (pos, inner) <- readShape base rank' start
+    let !(I# pos#) = pos
+    dim <- readInt64 base pos#
+    return (pos + 8, inner `ShCons` fromIntegral dim)
+
+-- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of
+-- the list.
+--
+-- Each 'IShCons' also stores a precomputed size next to the extent of its
+-- dimension; 'invertShape' fills it with the extent multiplied by the size of
+-- the tail, and 'ishSize' reads it back ('IShNil' counts as size 1).
+data InvShape n where
+ IShNil :: InvShape Z
+ IShCons :: Int -- ^ How many subarrays are there?
+ -> Int -- ^ What is the size of all subarrays together?
+ -> InvShape n -- ^ Sub array inverted shape
+ -> InvShape (S n)
+
+-- | Size recorded for an inverted shape: 1 for the empty shape, otherwise the
+-- size stored in the outermost 'IShCons'.
+ishSize :: InvShape n -> Int
+ishSize ish = case ish of
+  IShNil -> 1
+  IShCons _ total _ -> total
+
+-- | Turn a 'Shape' into its 'InvShape', computing the stored size of every
+-- level on the way: each extent peeled off the shape is consed onto the
+-- accumulated inverted shape together with @extent * size of the accumulator@.
+invertShape :: forall n. Shape n -> InvShape n
+invertShape shape | Refl <- lemPlusZero @n = peel shape IShNil
+  where
+    -- Move dimensions one at a time from the shape onto the inverted shape.
+    peel :: forall k m. Shape k -> InvShape m -> InvShape (k + m)
+    peel ShNil acc = acc
+    peel (rest `ShCons` dim) acc
+      | Refl <- lemPlusSuccRight @k @m = peel rest (IShCons dim (dim * ishSize acc) acc)
+
+-- | Add @top_value@ into the part of the accumulator selected by @top_index@;
+-- @top_depth@ is the number of indexing steps still to take.
+--
+-- @go@ descends one indexing step at a time; at depth zero @performAdd@ walks
+-- the remaining value and updates every scalar in the buffer. Integer and
+-- floating-point scalars are updated with an atomic fetch-add or a
+-- compare-and-swap loop; booleans are skipped.
+--
+-- NOTE(review): this definition is unfinished. The 'STMaybe' alternative of
+-- @go@ is a typed hole (@_ idx val@), and @performAdd@ has no 'STMaybe'
+-- alternative at all.
+accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s ()
+accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
+ let
+ -- Follow the index one step per level of depth; @inarr@ is True once we
+ -- are inside an array element, @off@ is the current byte offset.
+ go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO ()
+ go inarr ty SZ () val off = () <$ performAdd inarr ty val off
+ go inarr ty (SS dep) idx val off = case (ty, idx, val) of
+ -- NOTE(review): the second component of a pair is visited at the same
+ -- offset as the first; confirm whether the size of the first component
+ -- should be skipped here.
+ (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
+ (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off
+ (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd"
+ -- NOTE(review): unlike performAdd below, the tag byte is neither
+ -- checked nor skipped (off is passed on unchanged) -- confirm.
+ (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
+ (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off
+ (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd"
+ -- Typed hole: this case has not been written yet.
+ (STMaybe t, _, _) -> _ idx val
+ (STArr rank eltty, _, _)
+ | inarr -> error "Nested arrays"
+ | otherwise -> do
+ (off1, ish) <- second invertShape <$> readShape addr# rank off
+ goArr (SS dep) ish eltty idx val off1
+ (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth"
+ (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth"
+ (STAccum{}, _, _) -> error "Nested accumulators unsupported"
+
+ -- Index into an array one dimension at a time using the inverted shape;
+ -- once the dimensions are used up, continue into the element with @go@.
+ goArr :: SNat i' -> InvShape n -> STy t'
+ -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO ()
+ goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off
+ goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off
+ goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do
+ let i' = fromIntegral @(Rep' s TIx) @Int i
+ when (i' < 0 || i' >= n) $
+ error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")"
+ -- NOTE(review): the offset advances by i' * ishSize ish, which is a
+ -- size taken from the inverted shape; confirm whether it should also be
+ -- scaled by the element size in bytes.
+ goArr depm1 ish t1 idx val (off + i' * ishSize ish)
+
+ -- Add every element of @val@ into @arraySz@ consecutive element slots
+ -- starting at @off@; returns the offset past the last slot.
+ performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int
+ performAddArr arraySz eltty val off = do
+ let eltsize = tSize' (Proxy @s) eltty Nothing
+ forM_ [0 .. arraySz - 1] $ \lini ->
+ performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize)
+ return (off + arraySz * eltsize)
+
+ -- Add a whole value into the buffer at @off@; returns the offset past it.
+ performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
+ performAdd inarr ty val off = case ty of
+ STNil -> return off
+ STPair t1 t2 -> do
+ off1 <- performAdd inarr t1 (fst val) off
+ performAdd inarr t2 (snd val) off1
+ STEither t1 t2 -> do
+ let !(I# off#) = off
+ tag <- readInt8 addr# off#
+ -- The value being added must be in the same branch as the stored tag.
+ off1 <- case (val, tag) of
+ (Left val1, 0) -> performAdd inarr t1 val1 (off + 1)
+ (Right val2, 1) -> performAdd inarr t2 val2 (off + 1)
+ _ -> error "accumAdd: Tag mismatch for Either"
+ if inarr
+ then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing))
+ else return off1
+ STArr n ty'
+ | inarr -> error "Nested array"
+ | otherwise -> do
+ (off1, sh) <- readShape addr# n off
+ performAddArr (shapeSize sh) ty' val off1
+ STScal ty' -> performAddScal ty' val off
+ STAccum{} -> error "Nested accumulators unsupported"
+
+ -- Add one scalar into the buffer. Integers use fetch-add when the scalar
+ -- has the same byte size as Int, and a CAS loop otherwise; floats always
+ -- use a CAS loop on their bit pattern.
+ performAddScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int
+ performAddScal STI32 (I32# x#) off@(I# off#)
+ | sizeOf (undefined :: Int) == 4
+ = off + 4 <$ fetchAddWord# addr# off# (word32ToWord# (int32ToWord32# x#))
+ | otherwise
+ = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\(W32# w#) -> W32# (int32ToWord32# x# `plusWord32#` w#))
+ performAddScal STI64 (I64# x#) off@(I# off#)
+ | sizeOf (undefined :: Int) == 8
+ = off + 8 <$ fetchAddWord# addr# off# (word64ToWord# (int64ToWord64# x#))
+ | otherwise
+ = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\(W64# w#) -> W64# (int64ToWord64# x# `plusWord64#` w#))
+ performAddScal STF32 x off@(I# off#) =
+ off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\w -> castFloatToWord32 (x + castWord32ToFloat w))
+ performAddScal STF64 x off@(I# off#) =
+ off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\w -> castDoubleToWord64 (x + castWord64ToDouble w))
+ performAddScal STBool _ off = return (off + 1) -- don't do anything with booleans
+
+ -- Read the current value once, then retry the CAS with whatever value
+ -- was observed until the observed value equals the expected one.
+ casLoop :: Eq w
+ => (Addr# -> Int# -> IO w) -- ^ read value (from a given byte offset; will get 0#)
+ -> (Addr# -> w -> w -> IO w) -- ^ CAS value at address (expected -> desired -> IO observed)
+ -> Addr# -- ^ Address to attempt to modify
+ -> (w -> w) -- ^ Operation to apply to the value
+ -> IO ()
+ casLoop readOp casOp addr modify = readOp addr 0# >>= loop
+ where
+ loop value = do
+ value' <- casOp addr value (modify value)
+ if value == value'
+ then return ()
+ else loop value'
+
+ in () <$ go False topty top_depth top_index top_value 0
+
+-- | Allocate an accumulator initialised with @start@, run @fun@ on it, and
+-- return the result of @fun@ paired with the final contents of the
+-- accumulator.
+withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t)
+withAccum ty start fun =
+  initialise >>= \acc -> (,) <$> fun acc <*> accumRead acc
+  where
+    -- The initial write must happen before any of the adds or reads, so it
+    -- makes sense to put it in IO together with the allocation, instead of in
+    -- AcM.
+    initialise :: AcM s (Accum s t)
+    initialise = AcM $ do
+      raw <- mallocBytes (tSize (Proxy @s) ty start)
+      acc <- Accum ty <$> newForeignPtr finalizerFree raw
+      accumWrite acc start
+      return acc
+
+-- | Run all the given computations concurrently, each in its own thread, and
+-- collect their results in the order of the input list.
+--
+-- Every thread reports either its result or the exception that ended it. The
+-- first exception encountered while collecting (in list order) is rethrown in
+-- the calling thread, so a failing action no longer leaves the caller waiting
+-- on an 'MVar' that is never filled. The other threads are not cancelled.
+inParallel :: [AcM s t] -> AcM s [t]
+inParallel actions = AcM $ do
+  -- Fork everything first so that all actions run concurrently, then wait.
+  mvars <- mapM spawn actions
+  mapM collect mvars
+  where
+    -- Start one action; its outcome (result or exception) lands in the MVar.
+    spawn (AcM action) = do
+      var <- newEmptyMVar
+      _ <- forkFinally action (putMVar var)
+      return var
+
+    -- Wait for one outcome and rethrow the exception, if any.
+    collect var = takeMVar var >>= either throwIO return
+
+-- Typed reads from raw memory. In every function below the offset is in
+-- bytes: it is added to the base address first and the primop is then used
+-- with element index 0.
+
+-- | Read an 'Int8' at the given byte offset.
+readInt8 :: Addr# -> Int# -> IO Int8
+readInt8 base# byteOff# = IO $ \st ->
+  case readInt8OffAddr# (plusAddr# base# byteOff#) 0# st of
+    (# st', v# #) -> (# st', I8# v# #)
+
+-- | Read an 'Int32' at the given byte offset.
+readInt32 :: Addr# -> Int# -> IO Int32
+readInt32 base# byteOff# = IO $ \st ->
+  case readInt32OffAddr# (plusAddr# base# byteOff#) 0# st of
+    (# st', v# #) -> (# st', I32# v# #)
+
+-- | Read an 'Int64' at the given byte offset.
+readInt64 :: Addr# -> Int# -> IO Int64
+readInt64 base# byteOff# = IO $ \st ->
+  case readInt64OffAddr# (plusAddr# base# byteOff#) 0# st of
+    (# st', v# #) -> (# st', I64# v# #)
+
+-- | Read a 'Word32' at the given byte offset.
+readWord32 :: Addr# -> Int# -> IO Word32
+readWord32 base# byteOff# = IO $ \st ->
+  case readWord32OffAddr# (plusAddr# base# byteOff#) 0# st of
+    (# st', v# #) -> (# st', W32# v# #)
+
+-- | Read a 'Word64' at the given byte offset.
+readWord64 :: Addr# -> Int# -> IO Word64
+readWord64 base# byteOff# = IO $ \st ->
+  case readWord64OffAddr# (plusAddr# base# byteOff#) 0# st of
+    (# st', v# #) -> (# st', W64# v# #)
+
+-- | Read a 'Float' at the given byte offset.
+readFloat :: Addr# -> Int# -> IO Float
+readFloat base# byteOff# = IO $ \st ->
+  case readFloatOffAddr# (plusAddr# base# byteOff#) 0# st of
+    (# st', v# #) -> (# st', F# v# #)
+
+-- | Read a 'Double' at the given byte offset.
+readDouble :: Addr# -> Int# -> IO Double
+readDouble base# byteOff# = IO $ \st ->
+  case readDoubleOffAddr# (plusAddr# base# byteOff#) 0# st of
+    (# st', v# #) -> (# st', D# v# #)
+
+-- Typed writes to raw memory. As for the read helpers, the offset is in
+-- bytes: it is added to the base address and the primop is used with element
+-- index 0.
+
+-- | Write an 8-bit integer at the given byte offset.
+writeInt8# :: Addr# -> Int# -> Int8# -> IO ()
+writeInt8# base# byteOff# v# = IO $ \st ->
+  case writeInt8OffAddr# (plusAddr# base# byteOff#) 0# v# st of
+    st' -> (# st', () #)
+
+-- | Write a 32-bit integer at the given byte offset.
+writeInt32# :: Addr# -> Int# -> Int32# -> IO ()
+writeInt32# base# byteOff# v# = IO $ \st ->
+  case writeInt32OffAddr# (plusAddr# base# byteOff#) 0# v# st of
+    st' -> (# st', () #)
+
+-- | Write a 64-bit integer at the given byte offset.
+writeInt64# :: Addr# -> Int# -> Int64# -> IO ()
+writeInt64# base# byteOff# v# = IO $ \st ->
+  case writeInt64OffAddr# (plusAddr# base# byteOff#) 0# v# st of
+    st' -> (# st', () #)
+
+-- | Write a single-precision float at the given byte offset.
+writeFloat# :: Addr# -> Int# -> Float# -> IO ()
+writeFloat# base# byteOff# v# = IO $ \st ->
+  case writeFloatOffAddr# (plusAddr# base# byteOff#) 0# v# st of
+    st' -> (# st', () #)
+
+-- | Write a double-precision float at the given byte offset.
+writeDouble# :: Addr# -> Int# -> Double# -> IO ()
+writeDouble# base# byteOff# v# = IO $ \st ->
+  case writeDoubleOffAddr# (plusAddr# base# byteOff#) 0# v# st of
+    st' -> (# st', () #)
+
+-- | Atomically add a machine word to the word stored at the given byte offset
+-- from the base address. The previous contents are discarded.
+fetchAddWord# :: Addr# -> Int# -> Word# -> IO ()
+fetchAddWord# base# byteOff# delta# = IO $ \st ->
+  case fetchAddWordAddr# (plusAddr# base# byteOff#) delta# st of
+    (# st', _previous #) -> (# st', () #)
+
+-- | Atomic compare-and-swap on a 32-bit word at the given address: arguments
+-- are the expected value and the desired value; the result is the value the
+-- primop reports back as the old contents.
+atomicCasWord32Addr :: Addr# -> Word32 -> Word32 -> IO Word32
+atomicCasWord32Addr target# (W32# expected#) (W32# desired#) = IO $ \st ->
+  case atomicCasWord32Addr# target# expected# desired# st of
+    (# st', seen# #) -> (# st', W32# seen# #)
+
+-- | Atomic compare-and-swap on a 64-bit word at the given address: arguments
+-- are the expected value and the desired value; the result is the value the
+-- primop reports back as the old contents.
+atomicCasWord64Addr :: Addr# -> Word64 -> Word64 -> IO Word64
+atomicCasWord64Addr target# (W64# expected#) (W64# desired#) = IO $ \st ->
+  case atomicCasWord64Addr# target# expected# desired# st of
+    (# st', seen# #) -> (# st', W64# seen# #)