aboutsummaryrefslogtreecommitdiff
path: root/src/Interpreter/AccumOld.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Interpreter/AccumOld.hs')
-rw-r--r--src/Interpreter/AccumOld.hs366
1 files changed, 0 insertions, 366 deletions
diff --git a/src/Interpreter/AccumOld.hs b/src/Interpreter/AccumOld.hs
deleted file mode 100644
index af7be1e..0000000
--- a/src/Interpreter/AccumOld.hs
+++ /dev/null
@@ -1,366 +0,0 @@
-{-# LANGUAGE BangPatterns #-}
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DerivingStrategies #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-{-# LANGUAGE MagicHash #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TupleSections #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UnboxedTuples #-}
-module Interpreter.Accum (
- AcM,
- runAcM,
- Rep',
- Accum,
- withAccum,
- accumAdd,
- inParallel,
-) where
-
-import Control.Concurrent
-import Control.Monad (when, forM_)
-import Data.Bifunctor (second)
-import Data.Proxy
-import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr)
-import Foreign.Storable (sizeOf)
-import GHC.Exts
-import GHC.Float
-import GHC.Int
-import GHC.IO (IO(..))
-import GHC.Word
-import System.IO.Unsafe (unsafePerformIO)
-
-import Array
-import AST
-import Data
-
-
-newtype AcM s a = AcM (IO a)
- deriving newtype (Functor, Applicative, Monad)
-
-runAcM :: (forall s. AcM s a) -> a
-runAcM (AcM m) = unsafePerformIO m
-
--- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined.
-type family Rep' s t where
- Rep' s TNil = ()
- Rep' s (TPair a b) = (Rep' s a, Rep' s b)
- Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b)
- Rep' s (TMaybe t) = Maybe (Rep' s t)
- Rep' s (TArr n t) = Array n (Rep' s t)
- Rep' s (TScal sty) = ScalRep sty
- Rep' s (TAccum t) = Accum s t
-
--- | Floats and integers are accumulated; booleans are left as-is.
-data Accum s t = Accum (STy t) (ForeignPtr ())
-
-tSize :: Proxy s -> STy t -> Rep' s t -> Int
-tSize p ty x = tSize' p ty (Just x)
-
-tSize' :: Proxy s -> STy t -> Int
-tSize' p typ = case typ of
- STNil -> 0
- STPair a b -> tSize' p a + tSize' p b
- STEither a b -> 1 + max (tSize' p a) (tSize' p b)
- -- Representation of Maybe t is the same as Either () t; the add operation is different, however.
- STMaybe t -> tSize' p (STEither STNil t)
- STArr ndim t ->
- case val of
- Nothing -> error "Nested arrays not supported in this implementation"
- Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing
- STScal sty -> goScal sty
- STAccum{} -> error "Nested accumulators unsupported"
- where
- goScal :: SScalTy t -> Int
- goScal STI32 = 4
- goScal STI64 = 8
- goScal STF32 = 4
- goScal STF64 = 8
- goScal STBool = 1
-
--- | This operation does not commute with 'accumAdd', so it must be used with
--- care. Furthermore it must be used on exactly the same value as tSize was
--- called on. Hence it lives in IO, not in AcM.
-accumWrite :: forall s t. Accum s t -> Rep' s t -> IO ()
-accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
- let
- go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
- go inarr ty val off = case ty of
- STNil -> return off
- STPair a b -> do
- off1 <- go inarr a (fst val) off
- go inarr b (snd val) off1
- STEither a b -> do
- let !(I# off#) = off
- off1 <- case val of
- Left x -> do
- let !(I8# tag#) = 0
- writeInt8# addr# off# tag#
- go inarr a x (off + 1)
- Right y -> do
- let !(I8# tag#) = 1
- writeInt8# addr# off# tag#
- go inarr b y (off + 1)
- if inarr
- then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing))
- else return off1
- -- Representation is the same, but add operation is different
- STMaybe t -> go inarr (STEither STNil t) (maybe (Left ()) Right val) off
- STArr _ t
- | inarr -> error "Nested arrays not supported in this implementation"
- | otherwise -> do
- off1 <- goShape (arrayShape val) off
- let eltsize = tSize' (Proxy @s) t Nothing
- n = arraySize val
- traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val
- return (off1 + eltsize * n)
- STScal sty -> goScal sty val off
- STAccum{} -> error "Nested accumulators unsupported"
-
- goShape :: Shape n -> Int -> IO Int
- goShape ShNil off = return off
- goShape (ShCons sh n) off = do
- off1@(I# off1#) <- goShape sh off
- let !(I64# n'#) = fromIntegral n
- writeInt64# addr# off1# n'#
- return (off1 + 8)
-
- goScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int
- goScal STI32 (I32# x) off@(I# off#) = off + 4 <$ writeInt32# addr# off# x
- goScal STI64 (I64# x) off@(I# off#) = off + 8 <$ writeInt64# addr# off# x
- goScal STF32 (F# x) off@(I# off#) = off + 4 <$ writeFloat# addr# off# x
- goScal STF64 (D# x) off@(I# off#) = off + 8 <$ writeDouble# addr# off# x
- goScal STBool b off@(I# off#) = do
- let !(I8# i) = fromIntegral (fromEnum b)
- off + 1 <$ writeInt8# addr# off# i
-
- in () <$ go False topty top_value 0
-
-accumRead :: forall s t. Accum s t -> AcM s (Rep' s t)
-accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
- let
- go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t')
- go inarr ty off = case ty of
- STNil -> return (off, ())
- STPair a b -> do
- (off1, x) <- go inarr a off
- (off2, y) <- go inarr b off1
- return (off1 + off2, (x, y))
- STEither a b -> do
- let !(I# off#) = off
- tag <- readInt8 addr# off#
- (off1, val) <- case tag of
- 0 -> fmap Left <$> go inarr a (off + 1)
- 1 -> fmap Right <$> go inarr b (off + 1)
- _ -> error "Invalid tag in accum memory"
- if inarr
- then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val)
- else return (off1, val)
- -- Representation is the same, but add operation is different
- STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t) off
- STArr ndim t
- | inarr -> error "Nested arrays not supported in this implementation"
- | otherwise -> do
- (off1, sh) <- readShape addr# ndim off
- let eltsize = tSize' (Proxy @s) t Nothing
- n = shapeSize sh
- arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini))
- return (off1 + eltsize * n, arr)
- STScal sty -> goScal sty off
- STAccum{} -> error "Nested accumulators unsupported"
-
- goScal :: SScalTy t' -> Int -> IO (Int, ScalRep t')
- goScal STI32 off@(I# off#) = (off + 4,) <$> readInt32 addr# off#
- goScal STI64 off@(I# off#) = (off + 8,) <$> readInt64 addr# off#
- goScal STF32 off@(I# off#) = (off + 4,) <$> readFloat addr# off#
- goScal STF64 off@(I# off#) = (off + 8,) <$> readDouble addr# off#
- goScal STBool off@(I# off#) = do
- i8 <- readInt8 addr# off#
- return (off + 1, toEnum (fromIntegral i8))
-
- in snd <$> go False topty 0
-
-readShape :: Addr# -> SNat n -> Int -> IO (Int, Shape n)
-readShape _ SZ off = return (off, ShNil)
-readShape mbarr (SS ndim) off = do
- (off1@(I# off1#), sh) <- readShape mbarr ndim off
- n' <- readInt64 mbarr off1#
- return (off1 + 8, ShCons sh (fromIntegral n'))
-
--- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of
--- the list.
-data InvShape n where
- IShNil :: InvShape Z
- IShCons :: Int -- ^ How many subarrays are there?
- -> Int -- ^ What is the size of all subarrays together?
- -> InvShape n -- ^ Sub array inverted shape
- -> InvShape (S n)
-
-ishSize :: InvShape n -> Int
-ishSize IShNil = 1
-ishSize (IShCons _ sz _) = sz
-
-invertShape :: forall n. Shape n -> InvShape n
-invertShape | Refl <- lemPlusZero @n = flip go IShNil
- where
- go :: forall n' m. Shape n' -> InvShape m -> InvShape (n' + m)
- go ShNil ish = ish
- go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish)
-
-accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s ()
-accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
- let
- go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO ()
- go inarr ty SZ () val off = () <$ performAdd inarr ty val off
- go inarr ty (SS dep) idx val off = case (ty, idx, val) of
- (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
- (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off
- (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd"
- (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
- (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off
- (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd"
- (STMaybe t, _, _) -> _ idx val
- (STArr rank eltty, _, _)
- | inarr -> error "Nested arrays"
- | otherwise -> do
- (off1, ish) <- second invertShape <$> readShape addr# rank off
- goArr (SS dep) ish eltty idx val off1
- (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth"
- (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth"
- (STAccum{}, _, _) -> error "Nested accumulators unsupported"
-
- goArr :: SNat i' -> InvShape n -> STy t'
- -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO ()
- goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off
- goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off
- goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do
- let i' = fromIntegral @(Rep' s TIx) @Int i
- when (i' < 0 || i' >= n) $
- error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")"
- goArr depm1 ish t1 idx val (off + i' * ishSize ish)
-
- performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int
- performAddArr arraySz eltty val off = do
- let eltsize = tSize' (Proxy @s) eltty Nothing
- forM_ [0 .. arraySz - 1] $ \lini ->
- performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize)
- return (off + arraySz * eltsize)
-
- performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
- performAdd inarr ty val off = case ty of
- STNil -> return off
- STPair t1 t2 -> do
- off1 <- performAdd inarr t1 (fst val) off
- performAdd inarr t2 (snd val) off1
- STEither t1 t2 -> do
- let !(I# off#) = off
- tag <- readInt8 addr# off#
- off1 <- case (val, tag) of
- (Left val1, 0) -> performAdd inarr t1 val1 (off + 1)
- (Right val2, 1) -> performAdd inarr t2 val2 (off + 1)
- _ -> error "accumAdd: Tag mismatch for Either"
- if inarr
- then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing))
- else return off1
- STArr n ty'
- | inarr -> error "Nested array"
- | otherwise -> do
- (off1, sh) <- readShape addr# n off
- performAddArr (shapeSize sh) ty' val off1
- STScal ty' -> performAddScal ty' val off
- STAccum{} -> error "Nested accumulators unsupported"
-
- performAddScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int
- performAddScal STI32 (I32# x#) off@(I# off#)
- | sizeOf (undefined :: Int) == 4
- = off + 4 <$ fetchAddWord# addr# off# (word32ToWord# (int32ToWord32# x#))
- | otherwise
- = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\(W32# w#) -> W32# (int32ToWord32# x# `plusWord32#` w#))
- performAddScal STI64 (I64# x#) off@(I# off#)
- | sizeOf (undefined :: Int) == 8
- = off + 8 <$ fetchAddWord# addr# off# (word64ToWord# (int64ToWord64# x#))
- | otherwise
- = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\(W64# w#) -> W64# (int64ToWord64# x# `plusWord64#` w#))
- performAddScal STF32 x off@(I# off#) =
- off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\w -> castFloatToWord32 (x + castWord32ToFloat w))
- performAddScal STF64 x off@(I# off#) =
- off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\w -> castDoubleToWord64 (x + castWord64ToDouble w))
- performAddScal STBool _ off = return (off + 1) -- don't do anything with booleans
-
- casLoop :: Eq w
- => (Addr# -> Int# -> IO w) -- ^ read value (from a given byte offset; will get 0#)
- -> (Addr# -> w -> w -> IO w) -- ^ CAS value at address (expected -> desired -> IO observed)
- -> Addr# -- ^ Address to attempt to modify
- -> (w -> w) -- ^ Operation to apply to the value
- -> IO ()
- casLoop readOp casOp addr modify = readOp addr 0# >>= loop
- where
- loop value = do
- value' <- casOp addr value (modify value)
- if value == value'
- then return ()
- else loop value'
-
- in () <$ go False topty top_depth top_index top_value 0
-
-withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t)
-withAccum ty start fun = do
- -- The initial write must happen before any of the adds or reads, so it makes
- -- sense to put it in IO together with the allocation, instead of in AcM.
- accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start)
- ptr <- newForeignPtr finalizerFree buffer
- let accum = Accum ty ptr
- accumWrite accum start
- return accum
- b <- fun accum
- out <- accumRead accum
- return (b, out)
-
-inParallel :: [AcM s t] -> AcM s [t]
-inParallel actions = AcM $ do
- mvars <- mapM (\_ -> newEmptyMVar) actions
- forM_ (zip actions mvars) $ \(AcM action, var) ->
- forkIO $ action >>= putMVar var
- mapM takeMVar mvars
-
--- | Offset is in bytes.
-readInt8 :: Addr# -> Int# -> IO Int8
-readInt32 :: Addr# -> Int# -> IO Int32
-readInt64 :: Addr# -> Int# -> IO Int64
-readWord32 :: Addr# -> Int# -> IO Word32
-readWord64 :: Addr# -> Int# -> IO Word64
-readFloat :: Addr# -> Int# -> IO Float
-readDouble :: Addr# -> Int# -> IO Double
-readInt8 addr off# = IO $ \s -> case readInt8OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I8# val #)
-readInt32 addr off# = IO $ \s -> case readInt32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I32# val #)
-readInt64 addr off# = IO $ \s -> case readInt64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I64# val #)
-readWord32 addr off# = IO $ \s -> case readWord32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W32# val #)
-readWord64 addr off# = IO $ \s -> case readWord64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W64# val #)
-readFloat addr off# = IO $ \s -> case readFloatOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', F# val #)
-readDouble addr off# = IO $ \s -> case readDoubleOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', D# val #)
-
-writeInt8# :: Addr# -> Int# -> Int8# -> IO ()
-writeInt32# :: Addr# -> Int# -> Int32# -> IO ()
-writeInt64# :: Addr# -> Int# -> Int64# -> IO ()
-writeFloat# :: Addr# -> Int# -> Float# -> IO ()
-writeDouble# :: Addr# -> Int# -> Double# -> IO ()
-writeInt8# addr off# val = IO $ \s -> (# writeInt8OffAddr# (addr `plusAddr#` off#) 0# val s, () #)
-writeInt32# addr off# val = IO $ \s -> (# writeInt32OffAddr# (addr `plusAddr#` off#) 0# val s, () #)
-writeInt64# addr off# val = IO $ \s -> (# writeInt64OffAddr# (addr `plusAddr#` off#) 0# val s, () #)
-writeFloat# addr off# val = IO $ \s -> (# writeFloatOffAddr# (addr `plusAddr#` off#) 0# val s, () #)
-writeDouble# addr off# val = IO $ \s -> (# writeDoubleOffAddr# (addr `plusAddr#` off#) 0# val s, () #)
-
-fetchAddWord# :: Addr# -> Int# -> Word# -> IO ()
-fetchAddWord# addr off# val = IO $ \s -> case fetchAddWordAddr# (addr `plusAddr#` off#) val s of (# s', _ #) -> (# s', () #)
-
-atomicCasWord32Addr :: Addr# -> Word32 -> Word32 -> IO Word32
-atomicCasWord64Addr :: Addr# -> Word64 -> Word64 -> IO Word64
-atomicCasWord32Addr addr (W32# expected) (W32# desired) =
- IO $ \s -> case atomicCasWord32Addr# addr expected desired s of (# s', old #) -> (# s', W32# old #)
-atomicCasWord64Addr addr (W64# expected) (W64# desired) =
- IO $ \s -> case atomicCasWord64Addr# addr expected desired s of (# s', old #) -> (# s', W64# old #)