diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-09-13 23:07:04 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-09-13 23:07:04 +0200 | 
| commit | 94938d648e021d2ace0f3b7bf383d256449d619f (patch) | |
| tree | ef077de27b08027c7117761c3efc7d29b7ad3d56 /src/Interpreter/AccumOld.hs | |
| parent | 3d8a6cca424fc5279c43a266900160feb28aa715 (diff) | |
WIP better zero/plus, fixing Accum (...)
The accumulator implementation was wrong because it forgot (in accumAdd)
to take into account that values may be variably-sized. Furthermore, it
was also complexity-inefficient because it did not build up a sparse
value. Thus let's go for the Haskell-interpreter-equivalent of what a
real, fast, compiled implementation would do: just a tree with mutable
variables. In practice one can decide to indeed flatten parts of that
tree, i.e. using a tree representation for nested pairs is bad, but that
should have been done _before_ execution and for _all_ occurrences of
that type fragment, not live at runtime by the accumulator
implementation.
Diffstat (limited to 'src/Interpreter/AccumOld.hs')
| -rw-r--r-- | src/Interpreter/AccumOld.hs | 366 | 
1 files changed, 366 insertions, 0 deletions
| diff --git a/src/Interpreter/AccumOld.hs b/src/Interpreter/AccumOld.hs new file mode 100644 index 0000000..af7be1e --- /dev/null +++ b/src/Interpreter/AccumOld.hs @@ -0,0 +1,366 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} +module Interpreter.Accum ( +  AcM, +  runAcM, +  Rep', +  Accum, +  withAccum, +  accumAdd, +  inParallel, +) where + +import Control.Concurrent +import Control.Monad (when, forM_) +import Data.Bifunctor (second) +import Data.Proxy +import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr) +import Foreign.Storable (sizeOf) +import GHC.Exts +import GHC.Float +import GHC.Int +import GHC.IO (IO(..)) +import GHC.Word +import System.IO.Unsafe (unsafePerformIO) + +import Array +import AST +import Data + + +newtype AcM s a = AcM (IO a) +  deriving newtype (Functor, Applicative, Monad) + +runAcM :: (forall s. AcM s a) -> a +runAcM (AcM m) = unsafePerformIO m + +-- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined. +type family Rep' s t where +  Rep' s TNil = () +  Rep' s (TPair a b) = (Rep' s a, Rep' s b) +  Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b) +  Rep' s (TMaybe t) = Maybe (Rep' s t) +  Rep' s (TArr n t) = Array n (Rep' s t) +  Rep' s (TScal sty) = ScalRep sty +  Rep' s (TAccum t) = Accum s t + +-- | Floats and integers are accumulated; booleans are left as-is. +data Accum s t = Accum (STy t) (ForeignPtr ()) + +tSize :: Proxy s -> STy t -> Rep' s t -> Int +tSize p ty x = tSize' p ty (Just x) + +tSize' :: Proxy s -> STy t -> Int +tSize' p typ = case typ of +  STNil -> 0 +  STPair a b -> tSize' p a + tSize' p b +  STEither a b -> 1 + max (tSize' p a) (tSize' p b) +  -- Representation of Maybe t is the same as Either () t; the add operation is different, however. +  STMaybe t -> tSize' p (STEither STNil t) +  STArr ndim t -> +    case val of +      Nothing -> error "Nested arrays not supported in this implementation" +      Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing +  STScal sty -> goScal sty +  STAccum{} -> error "Nested accumulators unsupported" +  where +    goScal :: SScalTy t -> Int +    goScal STI32 = 4 +    goScal STI64 = 8 +    goScal STF32 = 4 +    goScal STF64 = 8 +    goScal STBool = 1 + +-- | This operation does not commute with 'accumAdd', so it must be used with +-- care. Furthermore it must be used on exactly the same value as tSize was +-- called on. Hence it lives in IO, not in AcM. +accumWrite :: forall s t. Accum s t -> Rep' s t -> IO () +accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> +  let +    go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int +    go inarr ty val off = case ty of +      STNil -> return off +      STPair a b -> do +        off1 <- go inarr a (fst val) off +        go inarr b (snd val) off1 +      STEither a b -> do +        let !(I# off#) = off +        off1 <- case val of +          Left x -> do +            let !(I8# tag#) = 0 +            writeInt8# addr# off# tag# +            go inarr a x (off + 1) +          Right y -> do +            let !(I8# tag#) = 1 +            writeInt8# addr# off# tag# +            go inarr b y (off + 1) +        if inarr +          then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing)) +          else return off1 +      -- Representation is the same, but add operation is different +      STMaybe t -> go inarr (STEither STNil t) (maybe (Left ()) Right val) off +      STArr _ t +        | inarr -> error "Nested arrays not supported in this implementation" +        | otherwise -> do +            off1 <- goShape (arrayShape val) off +            let eltsize = tSize' (Proxy @s) t Nothing +                n = arraySize val +            traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val +            return (off1 + eltsize * n) +      STScal sty -> goScal sty val off +      STAccum{} -> error "Nested accumulators unsupported" + +    goShape :: Shape n -> Int -> IO Int +    goShape ShNil off = return off +    goShape (ShCons sh n) off = do +      off1@(I# off1#) <- goShape sh off +      let !(I64# n'#) = fromIntegral n +      writeInt64# addr# off1# n'# +      return (off1 + 8) + +    goScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int +    goScal STI32 (I32# x) off@(I# off#) = off + 4 <$ writeInt32# addr# off# x +    goScal STI64 (I64# x) off@(I# off#) = off + 8 <$ writeInt64# addr# off# x +    goScal STF32 (F# x) off@(I# off#) = off + 4 <$ writeFloat# addr# off# x +    goScal STF64 (D# x) off@(I# off#) = off + 8 <$ writeDouble# addr# off# x +    goScal STBool b off@(I# off#) = do +      let !(I8# i) = fromIntegral (fromEnum b) +      off + 1 <$ writeInt8# addr# off# i + +  in () <$ go False topty top_value 0 + +accumRead :: forall s t. Accum s t -> AcM s (Rep' s t) +accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> +  let +    go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t') +    go inarr ty off = case ty of +      STNil -> return (off, ()) +      STPair a b -> do +        (off1, x) <- go inarr a off +        (off2, y) <- go inarr b off1 +        return (off1 + off2, (x, y)) +      STEither a b -> do +        let !(I# off#) = off +        tag <- readInt8 addr# off# +        (off1, val) <- case tag of +                         0 -> fmap Left <$> go inarr a (off + 1) +                         1 -> fmap Right <$> go inarr b (off + 1) +                         _ -> error "Invalid tag in accum memory" +        if inarr +          then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val) +          else return (off1, val) +      -- Representation is the same, but add operation is different +      STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t)  off +      STArr ndim t +        | inarr -> error "Nested arrays not supported in this implementation" +        | otherwise -> do +            (off1, sh) <- readShape addr# ndim off +            let eltsize = tSize' (Proxy @s) t Nothing +                n = shapeSize sh +            arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini)) +            return (off1 + eltsize * n, arr) +      STScal sty -> goScal sty off +      STAccum{} -> error "Nested accumulators unsupported" + +    goScal :: SScalTy t' -> Int -> IO (Int, ScalRep t') +    goScal STI32 off@(I# off#) = (off + 4,) <$> readInt32 addr# off# +    goScal STI64 off@(I# off#) = (off + 8,) <$> readInt64 addr# off# +    goScal STF32 off@(I# off#) = (off + 4,) <$> readFloat addr# off# +    goScal STF64 off@(I# off#) = (off + 8,) <$> readDouble addr# off# +    goScal STBool off@(I# off#) = do +      i8 <- readInt8 addr# off# +      return (off + 1, toEnum (fromIntegral i8)) + +  in snd <$> go False topty 0 + +readShape :: Addr# -> SNat n -> Int -> IO (Int, Shape n) +readShape _ SZ off = return (off, ShNil) +readShape mbarr (SS ndim) off = do +  (off1@(I# off1#), sh) <- readShape mbarr ndim off +  n' <- readInt64 mbarr off1# +  return (off1 + 8, ShCons sh (fromIntegral n')) + +-- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of +-- the list. +data InvShape n where +  IShNil :: InvShape Z +  IShCons :: Int  -- ^ How many subarrays are there? +          -> Int  -- ^ What is the size of all subarrays together? +          -> InvShape n  -- ^ Sub array inverted shape +          -> InvShape (S n) + +ishSize :: InvShape n -> Int +ishSize IShNil = 1 +ishSize (IShCons _ sz _) = sz + +invertShape :: forall n. Shape n -> InvShape n +invertShape | Refl <- lemPlusZero @n = flip go IShNil +  where +    go :: forall n' m. Shape n' -> InvShape m -> InvShape (n' + m) +    go ShNil ish = ish +    go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish) + +accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s () +accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> +  let +    go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO () +    go inarr ty SZ () val off = () <$ performAdd inarr ty val off +    go inarr ty (SS dep) idx val off = case (ty, idx, val) of +      (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off +      (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off +      (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd" +      (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off +      (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off +      (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd" +      (STMaybe t, _, _) -> _ idx val +      (STArr rank eltty, _, _) +        | inarr -> error "Nested arrays" +        | otherwise -> do +            (off1, ish) <- second invertShape <$> readShape addr# rank off +            goArr (SS dep) ish eltty idx val off1 +      (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth" +      (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth" +      (STAccum{}, _, _) -> error "Nested accumulators unsupported" + +    goArr :: SNat i' -> InvShape n -> STy t' +          -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO () +    goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off +    goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off +    goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do +      let i' = fromIntegral @(Rep' s TIx) @Int i +      when (i' < 0 || i' >= n) $ +        error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")" +      goArr depm1 ish t1 idx val (off + i' * ishSize ish) + +    performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int +    performAddArr arraySz eltty val off = do +      let eltsize = tSize' (Proxy @s) eltty Nothing +      forM_ [0 .. arraySz - 1] $ \lini -> +        performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize) +      return (off + arraySz * eltsize) + +    performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int +    performAdd inarr ty val off = case ty of +      STNil -> return off +      STPair t1 t2 -> do +        off1 <- performAdd inarr t1 (fst val) off +        performAdd inarr t2 (snd val) off1 +      STEither t1 t2 -> do +        let !(I# off#) = off +        tag <- readInt8 addr# off# +        off1 <- case (val, tag) of +                  (Left val1, 0) -> performAdd inarr t1 val1 (off + 1) +                  (Right val2, 1) -> performAdd inarr t2 val2 (off + 1) +                  _ -> error "accumAdd: Tag mismatch for Either" +        if inarr +          then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing)) +          else return off1 +      STArr n ty' +        | inarr -> error "Nested array" +        | otherwise -> do +            (off1, sh) <- readShape addr# n off +            performAddArr (shapeSize sh) ty' val off1 +      STScal ty' -> performAddScal ty' val off +      STAccum{} -> error "Nested accumulators unsupported" + +    performAddScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int +    performAddScal STI32 (I32# x#) off@(I# off#) +      | sizeOf (undefined :: Int) == 4 +      = off + 4 <$ fetchAddWord# addr# off# (word32ToWord# (int32ToWord32# x#)) +      | otherwise +      = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\(W32# w#) -> W32# (int32ToWord32# x# `plusWord32#` w#)) +    performAddScal STI64 (I64# x#) off@(I# off#) +      | sizeOf (undefined :: Int) == 8 +      = off + 8 <$ fetchAddWord# addr# off# (word64ToWord# (int64ToWord64# x#)) +      | otherwise +      = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\(W64# w#) -> W64# (int64ToWord64# x# `plusWord64#` w#)) +    performAddScal STF32 x off@(I# off#) = +      off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\w -> castFloatToWord32 (x + castWord32ToFloat w)) +    performAddScal STF64 x off@(I# off#) = +      off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\w -> castDoubleToWord64 (x + castWord64ToDouble w)) +    performAddScal STBool _ off = return (off + 1)  -- don't do anything with booleans + +    casLoop :: Eq w +            => (Addr# -> Int# -> IO w)    -- ^ read value (from a given byte offset; will get 0#) +            -> (Addr# -> w -> w -> IO w)  -- ^ CAS value at address (expected -> desired -> IO observed) +            -> Addr#                      -- ^ Address to attempt to modify +            -> (w -> w)                   -- ^ Operation to apply to the value +            -> IO () +    casLoop readOp casOp addr modify = readOp addr 0# >>= loop +      where +        loop value = do +          value' <- casOp addr value (modify value) +          if value == value' +            then return () +            else loop value' + +  in () <$ go False topty top_depth top_index top_value 0 + +withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t) +withAccum ty start fun = do +  -- The initial write must happen before any of the adds or reads, so it makes +  -- sense to put it in IO together with the allocation, instead of in AcM. +  accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start) +                    ptr <- newForeignPtr finalizerFree buffer +                    let accum = Accum ty ptr +                    accumWrite accum start +                    return accum +  b <- fun accum +  out <- accumRead accum +  return (b, out) + +inParallel :: [AcM s t] -> AcM s [t] +inParallel actions = AcM $ do +  mvars <- mapM (\_ -> newEmptyMVar) actions +  forM_ (zip actions mvars) $ \(AcM action, var) -> +    forkIO $ action >>= putMVar var +  mapM takeMVar mvars + +-- | Offset is in bytes. +readInt8   :: Addr# -> Int# -> IO Int8 +readInt32  :: Addr# -> Int# -> IO Int32 +readInt64  :: Addr# -> Int# -> IO Int64 +readWord32 :: Addr# -> Int# -> IO Word32 +readWord64 :: Addr# -> Int# -> IO Word64 +readFloat  :: Addr# -> Int# -> IO Float +readDouble :: Addr# -> Int# -> IO Double +readInt8   addr off# = IO $ \s -> case readInt8OffAddr#   (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I8#  val #) +readInt32  addr off# = IO $ \s -> case readInt32OffAddr#  (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I32# val #) +readInt64  addr off# = IO $ \s -> case readInt64OffAddr#  (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I64# val #) +readWord32 addr off# = IO $ \s -> case readWord32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W32# val #) +readWord64 addr off# = IO $ \s -> case readWord64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W64# val #) +readFloat  addr off# = IO $ \s -> case readFloatOffAddr#  (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', F#   val #) +readDouble addr off# = IO $ \s -> case readDoubleOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', D#   val #) + +writeInt8#   :: Addr# -> Int# -> Int8#   -> IO () +writeInt32#  :: Addr# -> Int# -> Int32#  -> IO () +writeInt64#  :: Addr# -> Int# -> Int64#  -> IO () +writeFloat#  :: Addr# -> Int# -> Float#  -> IO () +writeDouble# :: Addr# -> Int# -> Double# -> IO () +writeInt8#   addr off# val = IO $ \s -> (# writeInt8OffAddr#   (addr `plusAddr#` off#) 0# val s, () #) +writeInt32#  addr off# val = IO $ \s -> (# writeInt32OffAddr#  (addr `plusAddr#` off#) 0# val s, () #) +writeInt64#  addr off# val = IO $ \s -> (# writeInt64OffAddr#  (addr `plusAddr#` off#) 0# val s, () #) +writeFloat#  addr off# val = IO $ \s -> (# writeFloatOffAddr#  (addr `plusAddr#` off#) 0# val s, () #) +writeDouble# addr off# val = IO $ \s -> (# writeDoubleOffAddr# (addr `plusAddr#` off#) 0# val s, () #) + +fetchAddWord# :: Addr# -> Int# -> Word# -> IO () +fetchAddWord# addr off# val = IO $ \s -> case fetchAddWordAddr# (addr `plusAddr#` off#) val s of (# s', _ #) -> (# s', () #) + +atomicCasWord32Addr :: Addr# -> Word32 -> Word32 -> IO Word32 +atomicCasWord64Addr :: Addr# -> Word64 -> Word64 -> IO Word64 +atomicCasWord32Addr addr (W32# expected) (W32# desired) = +  IO $ \s -> case atomicCasWord32Addr# addr expected desired s of (# s', old #) -> (# s', W32# old #) +atomicCasWord64Addr addr (W64# expected) (W64# desired) = +  IO $ \s -> case atomicCasWord64Addr# addr expected desired s of (# s', old #) -> (# s', W64# old #) | 
