diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:49:45 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:50:25 +0100 |
| commit | 174af2ba568de66e0d890825b8bda930b8e7bb96 (patch) | |
| tree | 5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/Interpreter/Accum.hs | |
| parent | 92bca235e3aaa287286b6af082d3fce585825a35 (diff) | |
Move module hierarchy under CHAD.
Diffstat (limited to 'src/Interpreter/Accum.hs')
| -rw-r--r-- | src/Interpreter/Accum.hs | 366 |
1 files changed, 0 insertions, 366 deletions
diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs deleted file mode 100644 index af7be1e..0000000 --- a/src/Interpreter/Accum.hs +++ /dev/null @@ -1,366 +0,0 @@ -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE MagicHash #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UnboxedTuples #-} -module Interpreter.Accum ( - AcM, - runAcM, - Rep', - Accum, - withAccum, - accumAdd, - inParallel, -) where - -import Control.Concurrent -import Control.Monad (when, forM_) -import Data.Bifunctor (second) -import Data.Proxy -import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr) -import Foreign.Storable (sizeOf) -import GHC.Exts -import GHC.Float -import GHC.Int -import GHC.IO (IO(..)) -import GHC.Word -import System.IO.Unsafe (unsafePerformIO) - -import Array -import AST -import Data - - -newtype AcM s a = AcM (IO a) - deriving newtype (Functor, Applicative, Monad) - -runAcM :: (forall s. AcM s a) -> a -runAcM (AcM m) = unsafePerformIO m - --- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined. -type family Rep' s t where - Rep' s TNil = () - Rep' s (TPair a b) = (Rep' s a, Rep' s b) - Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b) - Rep' s (TMaybe t) = Maybe (Rep' s t) - Rep' s (TArr n t) = Array n (Rep' s t) - Rep' s (TScal sty) = ScalRep sty - Rep' s (TAccum t) = Accum s t - --- | Floats and integers are accumulated; booleans are left as-is. -data Accum s t = Accum (STy t) (ForeignPtr ()) - -tSize :: Proxy s -> STy t -> Rep' s t -> Int -tSize p ty x = tSize' p ty (Just x) - -tSize' :: Proxy s -> STy t -> Int -tSize' p typ = case typ of - STNil -> 0 - STPair a b -> tSize' p a + tSize' p b - STEither a b -> 1 + max (tSize' p a) (tSize' p b) - -- Representation of Maybe t is the same as Either () t; the add operation is different, however. - STMaybe t -> tSize' p (STEither STNil t) - STArr ndim t -> - case val of - Nothing -> error "Nested arrays not supported in this implementation" - Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing - STScal sty -> goScal sty - STAccum{} -> error "Nested accumulators unsupported" - where - goScal :: SScalTy t -> Int - goScal STI32 = 4 - goScal STI64 = 8 - goScal STF32 = 4 - goScal STF64 = 8 - goScal STBool = 1 - --- | This operation does not commute with 'accumAdd', so it must be used with --- care. Furthermore it must be used on exactly the same value as tSize was --- called on. Hence it lives in IO, not in AcM. -accumWrite :: forall s t. Accum s t -> Rep' s t -> IO () -accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> - let - go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int - go inarr ty val off = case ty of - STNil -> return off - STPair a b -> do - off1 <- go inarr a (fst val) off - go inarr b (snd val) off1 - STEither a b -> do - let !(I# off#) = off - off1 <- case val of - Left x -> do - let !(I8# tag#) = 0 - writeInt8# addr# off# tag# - go inarr a x (off + 1) - Right y -> do - let !(I8# tag#) = 1 - writeInt8# addr# off# tag# - go inarr b y (off + 1) - if inarr - then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing)) - else return off1 - -- Representation is the same, but add operation is different - STMaybe t -> go inarr (STEither STNil t) (maybe (Left ()) Right val) off - STArr _ t - | inarr -> error "Nested arrays not supported in this implementation" - | otherwise -> do - off1 <- goShape (arrayShape val) off - let eltsize = tSize' (Proxy @s) t Nothing - n = arraySize val - traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val - return (off1 + eltsize * n) - STScal sty -> goScal sty val off - STAccum{} -> error "Nested accumulators unsupported" - - goShape :: Shape n -> Int -> IO Int - goShape ShNil off = return off - goShape (ShCons sh n) off = do - off1@(I# off1#) <- goShape sh off - let !(I64# n'#) = fromIntegral n - writeInt64# addr# off1# n'# - return (off1 + 8) - - goScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int - goScal STI32 (I32# x) off@(I# off#) = off + 4 <$ writeInt32# addr# off# x - goScal STI64 (I64# x) off@(I# off#) = off + 8 <$ writeInt64# addr# off# x - goScal STF32 (F# x) off@(I# off#) = off + 4 <$ writeFloat# addr# off# x - goScal STF64 (D# x) off@(I# off#) = off + 8 <$ writeDouble# addr# off# x - goScal STBool b off@(I# off#) = do - let !(I8# i) = fromIntegral (fromEnum b) - off + 1 <$ writeInt8# addr# off# i - - in () <$ go False topty top_value 0 - -accumRead :: forall s t. Accum s t -> AcM s (Rep' s t) -accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> - let - go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t') - go inarr ty off = case ty of - STNil -> return (off, ()) - STPair a b -> do - (off1, x) <- go inarr a off - (off2, y) <- go inarr b off1 - return (off1 + off2, (x, y)) - STEither a b -> do - let !(I# off#) = off - tag <- readInt8 addr# off# - (off1, val) <- case tag of - 0 -> fmap Left <$> go inarr a (off + 1) - 1 -> fmap Right <$> go inarr b (off + 1) - _ -> error "Invalid tag in accum memory" - if inarr - then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val) - else return (off1, val) - -- Representation is the same, but add operation is different - STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t) off - STArr ndim t - | inarr -> error "Nested arrays not supported in this implementation" - | otherwise -> do - (off1, sh) <- readShape addr# ndim off - let eltsize = tSize' (Proxy @s) t Nothing - n = shapeSize sh - arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini)) - return (off1 + eltsize * n, arr) - STScal sty -> goScal sty off - STAccum{} -> error "Nested accumulators unsupported" - - goScal :: SScalTy t' -> Int -> IO (Int, ScalRep t') - goScal STI32 off@(I# off#) = (off + 4,) <$> readInt32 addr# off# - goScal STI64 off@(I# off#) = (off + 8,) <$> readInt64 addr# off# - goScal STF32 off@(I# off#) = (off + 4,) <$> readFloat addr# off# - goScal STF64 off@(I# off#) = (off + 8,) <$> readDouble addr# off# - goScal STBool off@(I# off#) = do - i8 <- readInt8 addr# off# - return (off + 1, toEnum (fromIntegral i8)) - - in snd <$> go False topty 0 - -readShape :: Addr# -> SNat n -> Int -> IO (Int, Shape n) -readShape _ SZ off = return (off, ShNil) -readShape mbarr (SS ndim) off = do - (off1@(I# off1#), sh) <- readShape mbarr ndim off - n' <- readInt64 mbarr off1# - return (off1 + 8, ShCons sh (fromIntegral n')) - --- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of --- the list. -data InvShape n where - IShNil :: InvShape Z - IShCons :: Int -- ^ How many subarrays are there? - -> Int -- ^ What is the size of all subarrays together? - -> InvShape n -- ^ Sub array inverted shape - -> InvShape (S n) - -ishSize :: InvShape n -> Int -ishSize IShNil = 1 -ishSize (IShCons _ sz _) = sz - -invertShape :: forall n. Shape n -> InvShape n -invertShape | Refl <- lemPlusZero @n = flip go IShNil - where - go :: forall n' m. Shape n' -> InvShape m -> InvShape (n' + m) - go ShNil ish = ish - go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish) - -accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s () -accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> - let - go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO () - go inarr ty SZ () val off = () <$ performAdd inarr ty val off - go inarr ty (SS dep) idx val off = case (ty, idx, val) of - (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off - (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off - (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd" - (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off - (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off - (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd" - (STMaybe t, _, _) -> _ idx val - (STArr rank eltty, _, _) - | inarr -> error "Nested arrays" - | otherwise -> do - (off1, ish) <- second invertShape <$> readShape addr# rank off - goArr (SS dep) ish eltty idx val off1 - (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth" - (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth" - (STAccum{}, _, _) -> error "Nested accumulators unsupported" - - goArr :: SNat i' -> InvShape n -> STy t' - -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO () - goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off - goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off - goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do - let i' = fromIntegral @(Rep' s TIx) @Int i - when (i' < 0 || i' >= n) $ - error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")" - goArr depm1 ish t1 idx val (off + i' * ishSize ish) - - performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int - performAddArr arraySz eltty val off = do - let eltsize = tSize' (Proxy @s) eltty Nothing - forM_ [0 .. arraySz - 1] $ \lini -> - performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize) - return (off + arraySz * eltsize) - - performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int - performAdd inarr ty val off = case ty of - STNil -> return off - STPair t1 t2 -> do - off1 <- performAdd inarr t1 (fst val) off - performAdd inarr t2 (snd val) off1 - STEither t1 t2 -> do - let !(I# off#) = off - tag <- readInt8 addr# off# - off1 <- case (val, tag) of - (Left val1, 0) -> performAdd inarr t1 val1 (off + 1) - (Right val2, 1) -> performAdd inarr t2 val2 (off + 1) - _ -> error "accumAdd: Tag mismatch for Either" - if inarr - then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing)) - else return off1 - STArr n ty' - | inarr -> error "Nested array" - | otherwise -> do - (off1, sh) <- readShape addr# n off - performAddArr (shapeSize sh) ty' val off1 - STScal ty' -> performAddScal ty' val off - STAccum{} -> error "Nested accumulators unsupported" - - performAddScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int - performAddScal STI32 (I32# x#) off@(I# off#) - | sizeOf (undefined :: Int) == 4 - = off + 4 <$ fetchAddWord# addr# off# (word32ToWord# (int32ToWord32# x#)) - | otherwise - = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\(W32# w#) -> W32# (int32ToWord32# x# `plusWord32#` w#)) - performAddScal STI64 (I64# x#) off@(I# off#) - | sizeOf (undefined :: Int) == 8 - = off + 8 <$ fetchAddWord# addr# off# (word64ToWord# (int64ToWord64# x#)) - | otherwise - = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\(W64# w#) -> W64# (int64ToWord64# x# `plusWord64#` w#)) - performAddScal STF32 x off@(I# off#) = - off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\w -> castFloatToWord32 (x + castWord32ToFloat w)) - performAddScal STF64 x off@(I# off#) = - off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\w -> castDoubleToWord64 (x + castWord64ToDouble w)) - performAddScal STBool _ off = return (off + 1) -- don't do anything with booleans - - casLoop :: Eq w - => (Addr# -> Int# -> IO w) -- ^ read value (from a given byte offset; will get 0#) - -> (Addr# -> w -> w -> IO w) -- ^ CAS value at address (expected -> desired -> IO observed) - -> Addr# -- ^ Address to attempt to modify - -> (w -> w) -- ^ Operation to apply to the value - -> IO () - casLoop readOp casOp addr modify = readOp addr 0# >>= loop - where - loop value = do - value' <- casOp addr value (modify value) - if value == value' - then return () - else loop value' - - in () <$ go False topty top_depth top_index top_value 0 - -withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t) -withAccum ty start fun = do - -- The initial write must happen before any of the adds or reads, so it makes - -- sense to put it in IO together with the allocation, instead of in AcM. - accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start) - ptr <- newForeignPtr finalizerFree buffer - let accum = Accum ty ptr - accumWrite accum start - return accum - b <- fun accum - out <- accumRead accum - return (b, out) - -inParallel :: [AcM s t] -> AcM s [t] -inParallel actions = AcM $ do - mvars <- mapM (\_ -> newEmptyMVar) actions - forM_ (zip actions mvars) $ \(AcM action, var) -> - forkIO $ action >>= putMVar var - mapM takeMVar mvars - --- | Offset is in bytes. -readInt8 :: Addr# -> Int# -> IO Int8 -readInt32 :: Addr# -> Int# -> IO Int32 -readInt64 :: Addr# -> Int# -> IO Int64 -readWord32 :: Addr# -> Int# -> IO Word32 -readWord64 :: Addr# -> Int# -> IO Word64 -readFloat :: Addr# -> Int# -> IO Float -readDouble :: Addr# -> Int# -> IO Double -readInt8 addr off# = IO $ \s -> case readInt8OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I8# val #) -readInt32 addr off# = IO $ \s -> case readInt32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I32# val #) -readInt64 addr off# = IO $ \s -> case readInt64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I64# val #) -readWord32 addr off# = IO $ \s -> case readWord32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W32# val #) -readWord64 addr off# = IO $ \s -> case readWord64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W64# val #) -readFloat addr off# = IO $ \s -> case readFloatOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', F# val #) -readDouble addr off# = IO $ \s -> case readDoubleOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', D# val #) - -writeInt8# :: Addr# -> Int# -> Int8# -> IO () -writeInt32# :: Addr# -> Int# -> Int32# -> IO () -writeInt64# :: Addr# -> Int# -> Int64# -> IO () -writeFloat# :: Addr# -> Int# -> Float# -> IO () -writeDouble# :: Addr# -> Int# -> Double# -> IO () -writeInt8# addr off# val = IO $ \s -> (# writeInt8OffAddr# (addr `plusAddr#` off#) 0# val s, () #) -writeInt32# addr off# val = IO $ \s -> (# writeInt32OffAddr# (addr `plusAddr#` off#) 0# val s, () #) -writeInt64# addr off# val = IO $ \s -> (# writeInt64OffAddr# (addr `plusAddr#` off#) 0# val s, () #) -writeFloat# addr off# val = IO $ \s -> (# writeFloatOffAddr# (addr `plusAddr#` off#) 0# val s, () #) -writeDouble# addr off# val = IO $ \s -> (# writeDoubleOffAddr# (addr `plusAddr#` off#) 0# val s, () #) - -fetchAddWord# :: Addr# -> Int# -> Word# -> IO () -fetchAddWord# addr off# val = IO $ \s -> case fetchAddWordAddr# (addr `plusAddr#` off#) val s of (# s', _ #) -> (# s', () #) - -atomicCasWord32Addr :: Addr# -> Word32 -> Word32 -> IO Word32 -atomicCasWord64Addr :: Addr# -> Word64 -> Word64 -> IO Word64 -atomicCasWord32Addr addr (W32# expected) (W32# desired) = - IO $ \s -> case atomicCasWord32Addr# addr expected desired s of (# s', old #) -> (# s', W32# old #) -atomicCasWord64Addr addr (W64# expected) (W64# desired) = - IO $ \s -> case atomicCasWord64Addr# addr expected desired s of (# s', old #) -> (# s', W64# old #) |
