diff options
Diffstat (limited to 'src/Interpreter/Accum.hs')
-rw-r--r-- | src/Interpreter/Accum.hs | 274 |
1 files changed, 162 insertions, 112 deletions
diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs index b0deaef..45f507b 100644 --- a/src/Interpreter/Accum.hs +++ b/src/Interpreter/Accum.hs @@ -1,13 +1,15 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE UnboxedTuples #-} -{-# LANGUAGE BangPatterns #-} module Interpreter.Accum where import Control.Monad.ST @@ -19,6 +21,9 @@ import GHC.ST (ST(..)) import AST import Data import Interpreter.Array +import Data.Bifunctor (second) +import Control.Monad (when, forM_) +import Foreign.Storable (sizeOf) newtype AcM s a = AcM (ST s a) @@ -68,156 +73,201 @@ tSize' typ val = case typ of -- care. Furthermore it must be used on exactly the same value as tSize was -- called on. Hence it lives in ST, not in AcM. accumWrite :: forall s t. Accum s t -> Rep t -> ST s () -accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0# +accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0 where - go :: Bool -> STy t' -> Rep t' -> Int# -> ST s Int - go inarr ty val off# = case ty of - STNil -> return (I# off#) + go :: Bool -> STy t' -> Rep t' -> Int -> ST s Int + go inarr ty val off = case ty of + STNil -> return off STPair a b -> do - I# off1# <- go inarr a (fst val) off# - go inarr b (snd val) off1# - STEither a b -> + off1 <- go inarr a (fst val) off + go inarr b (snd val) off1 + STEither a b -> do + let !(I# off#) = off case val of Left x -> do let !(I8# tag#) = 0 ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #) - go inarr a x (off# +# 1#) + go inarr a x (off + 1) Right y -> do let !(I8# tag#) = 1 ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #) - go inarr b y (off# +# 1#) + go inarr b y (off + 1) STArr _ t | inarr -> error "Nested arrays not supported in this implementation" | otherwise -> do - I# off1# <- goShape (arrayShape val) off# - let !(I# eltsize#) = tSize' t Nothing - !(I# n#) = arraySize val - traverseArray_ (\(I# lini#) x -> () <$ go True t x (off1# +# eltsize# *# lini#)) val - return (I# (off1# +# eltsize# *# n#)) - STScal sty -> goScal sty val off# + off1 <- goShape (arrayShape val) off + let eltsize = tSize' t Nothing + n = arraySize val + traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val + return (off1 + eltsize * n) + STScal sty -> goScal sty val off STAccum{} -> error "Nested accumulators unsupported" - goShape :: Shape n -> Int# -> ST s Int - goShape ShNil off# = return (I# off#) - goShape (ShCons sh n) off# = do - I# off1# <- goShape sh off# - let !(I64# n') = fromIntegral n - ST $ \s -> (# writeInt64Array# mbarr off1# n' s, I# (off1# +# 8#) #) - - goScal :: SScalTy t' -> ScalRep t' -> Int# -> ST s Int - goScal STI32 (I32# x) off# = ST $ \s -> (# writeInt32Array# mbarr off# x s, I# (off# +# 4#) #) - goScal STI64 (I64# x) off# = ST $ \s -> (# writeInt64Array# mbarr off# x s, I# (off# +# 8#) #) - goScal STF32 (F# x) off# = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #) - goScal STF64 (D# x) off# = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #) - goScal STBool b off# = + goShape :: Shape n -> Int -> ST s Int + goShape ShNil off = return off + goShape (ShCons sh n) off = do + off1@(I# off1#) <- goShape sh off + let !(I64# n'#) = fromIntegral n + ST $ \s -> (# writeInt64Array# mbarr off1# n'# s, off1 + 8 #) + + goScal :: SScalTy t' -> ScalRep t' -> Int -> ST s Int + goScal STI32 (I32# x) (I# off#) = ST $ \s -> (# writeInt32Array# mbarr off# x s, I# (off# +# 4#) #) + goScal STI64 (I64# x) (I# off#) = ST $ \s -> (# writeInt64Array# mbarr off# x s, I# (off# +# 8#) #) + goScal STF32 (F# x) (I# off#) = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #) + goScal STF64 (D# x) (I# off#) = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #) + goScal STBool b (I# off#) = let !(I8# i) = fromIntegral (fromEnum b) in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #) accumRead :: forall s t. Accum s t -> AcM s (Rep t) -accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0# +accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0 where - go :: Bool -> STy t' -> Int# -> ST s (Int, Rep t') - go inarr ty off# = case ty of - STNil -> return (I# off#, ()) + go :: Bool -> STy t' -> Int -> ST s (Int, Rep t') + go inarr ty off = case ty of + STNil -> return (off, ()) STPair a b -> do - (I# off1#, x) <- go inarr a off# - (I# off2#, y) <- go inarr b off1# - return (I# (off1# +# off2#), (x, y)) + (off1, x) <- go inarr a off + (off2, y) <- go inarr b off1 + return (off1 + off2, (x, y)) STEither a b -> do + let !(I# off#) = off tag <- ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', I8# i #) + (off1, val) <- case tag of + 0 -> fmap Left <$> go inarr a (off + 1) + 1 -> fmap Right <$> go inarr b (off + 1) + _ -> error "Invalid tag in accum memory" if inarr - then do val <- case tag of - 0 -> Left . snd <$> go inarr a (off# +# 1#) - 1 -> Right . snd <$> go inarr b (off# +# 1#) - _ -> error "Invalid tag in accum memory" - return (I# off# + max (tSize' a Nothing) (tSize' b Nothing), val) - else case tag of - 0 -> fmap Left <$> go inarr a (off# +# 1#) - 1 -> fmap Right <$> go inarr b (off# +# 1#) - _ -> error "Invalid tag in accum memory" + then return (off + 1 + max (tSize' a Nothing) (tSize' b Nothing), val) + else return (off1, val) STArr ndim t | inarr -> error "Nested arrays not supported in this implementation" | otherwise -> do - (I# off1#, sh) <- readShape mbarr ndim off# - let !(I# eltsize#) = tSize' t Nothing - !(I# n#) = shapeSize sh - arr <- arrayGenerateLinM sh (\(I# lini#) -> snd <$> go True t (off1# +# eltsize# *# lini#)) - return (I# (off1# +# eltsize# *# n#), arr) - STScal sty -> goScal sty off# + (off1, sh) <- readShape mbarr ndim off + let eltsize = tSize' t Nothing + n = shapeSize sh + arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini)) + return (off1 + eltsize * n, arr) + STScal sty -> goScal sty off STAccum{} -> error "Nested accumulators unsupported" - goScal :: SScalTy t' -> Int# -> ST s (Int, ScalRep t') - goScal STI32 off# = ST $ \s -> case readInt32Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 4#), I32# i) #) - goScal STI64 off# = ST $ \s -> case readInt64Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 8#), I64# i) #) - goScal STF32 off# = ST $ \s -> case readFloatArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 4#), F# f) #) - goScal STF64 off# = ST $ \s -> case readDoubleArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 8#), D# f) #) - goScal STBool off# = ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 1#), toEnum (fromIntegral (I8# i))) #) - -readShape :: MutableByteArray# s -> SNat n -> Int# -> ST s (Int, Shape n) -readShape _ SZ off# = return (I# off#, ShNil) -readShape mbarr (SS ndim) off# = do - (I# off1#, sh) <- readShape mbarr ndim off# + goScal :: SScalTy t' -> Int -> ST s (Int, ScalRep t') + goScal STI32 (I# off#) = ST $ \s -> case readInt32Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 4#), I32# i) #) + goScal STI64 (I# off#) = ST $ \s -> case readInt64Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 8#), I64# i) #) + goScal STF32 (I# off#) = ST $ \s -> case readFloatArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 4#), F# f) #) + goScal STF64 (I# off#) = ST $ \s -> case readDoubleArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 8#), D# f) #) + goScal STBool (I# off#) = ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 1#), toEnum (fromIntegral (I8# i))) #) + +readShape :: MutableByteArray# s -> SNat n -> Int -> ST s (Int, Shape n) +readShape _ SZ off = return (off, ShNil) +readShape mbarr (SS ndim) off = do + (off1@(I# off1#), sh) <- readShape mbarr ndim off n' <- ST $ \s -> case readInt64Array# mbarr off1# s of (# s', i #) -> (# s', I64# i #) - return (I# (off1# +# 8#), ShCons sh (fromIntegral n')) + return (off1 + 8, ShCons sh (fromIntegral n')) -data InvShape full yet where - IShFull :: InvShape full full - IShCons :: InvShape full (S yet) -> Int -> InvShape full yet +-- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of +-- the list. +data InvShape n where + IShNil :: InvShape Z + IShCons :: Int -- ^ How many subarrays are there? + -> Int -- ^ What is the size of all subarrays together? + -> InvShape n -- ^ Sub array inverted shape + -> InvShape (S n) -invertShape :: Shape n -> InvShape n Z -invertShape +ishSize :: InvShape n -> Int +ishSize IShNil = 0 +ishSize (IShCons _ sz _) = sz +invertShape :: forall n. Shape n -> InvShape n +invertShape | Refl <- lemPlusZero @n = flip go IShNil + where + go :: forall n' m. Shape n' -> InvShape m -> InvShape (n' + m) + go ShNil ish = ish + go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish) accumAdd :: forall s t i. Accum s t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s () -accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty depth index value 0# +accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty depth index value 0 where - go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int# -> ST s Int - go inarr ty SZ idx val off# = performAdd inarr ty val off# - go inarr ty (SS dep) idx val off# = case (ty, idx, val) of - (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off# - (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off# - (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off# - (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off# + go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int -> ST s () + go inarr ty SZ () val off = () <$ performAdd inarr ty val off + go inarr ty (SS dep) idx val off = case (ty, idx, val) of + (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off + (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off + (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd" + (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off + (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off + (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd" (STArr rank eltty, _, _) | inarr -> error "Nested arrays" - | otherwise -> goArr (SS dep) rank eltty idx val off# - -- (STArr SZ t1, _, _) -> go inarr t1 dep idx val off# - -- (STArr (SS rankm1) t1, _, _) -> go inarr t1 dep idx val off# - _ -> _ - - goArr :: SNat i' -> SNat n -> STy t' - -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int# -> ST s Int - goArr dep n t1 idx val off# = do - (I# off1#, sh) <- readShape mbarr n off# - _ - - collectArrIndex :: SNat i' -> STy t' -> Shape n - -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') - -> (forall i2. Index n -> SNat i2 -> Rep (AcIdx t' i2) -> Rep (AcVal t' i2) -> ST s r) - -> (forall i2. Index n -> SNat i2 -> Rep (AcIdx t' i2) -> Rep (AcVal t' i2) -> ST s r) - -> ST s r - collectArrIndex (SS dep) eltty ShNil idx val k = k IxNil dep idx val - collectArrIndex (SS dep) eltty (ShCons sh size) (i, idx) val k = - collectArrIndex dep eltty sh idx val $ \index dep' idx' val' -> - k (IxCons index (fromIntegral i)) dep' idx' val' - -- collectArrIndex SZ eltty - - goShape :: Shape n -> Int# -> ST s Int - goShape ShNil off# = return (I# off#) - goShape (ShCons sh n) off# = do - I# off1# <- goShape sh off# - let !(I64# n') = fromIntegral n - ST $ \s -> (# writeInt64Array# mbarr off1# n' s, I# (off1# +# 8#) #) - - goScal :: SScalTy t' -> ScalRep t' -> Int# -> ST s Int - goScal STI32 (I32# x) off# = ST $ \s -> (# writeInt32Array# mbarr off# x s, I# (off# +# 4#) #) - goScal STI64 (I64# x) off# = ST $ \s -> (# writeInt64Array# mbarr off# x s, I# (off# +# 8#) #) - goScal STF32 (F# x) off# = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #) - goScal STF64 (D# x) off# = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #) - goScal STBool b off# = + | otherwise -> do + (off1, ish) <- second invertShape <$> readShape mbarr rank off + goArr (SS dep) ish eltty idx val off1 + (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth" + (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth" + (STAccum{}, _, _) -> error "Nested accumulators unsupported" + + goArr :: SNat i' -> InvShape n -> STy t' + -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int -> ST s () + goArr SZ ish t1 () val off = performAddArr ish t1 val off + goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off + goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do + let i' = fromIntegral @(Rep TIx) @Int i + when (i' < 0 || i' >= n) $ + error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")" + goArr depm1 ish t1 idx val (off + i' * ishSize ish) + + performAdd :: Bool -> STy t' -> Rep t' -> Int -> ST s Int + performAdd inarr ty val off = case ty of + STNil -> return off + STPair t1 t2 -> do + off1 <- performAdd inarr t1 (fst val) off + performAdd inarr t2 (snd val) off1 + STEither t1 t2 -> do + let !(I# off#) = off + tag <- ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', I8# i #) + off1 <- case (val, tag) of + (Left val1, 0) -> performAdd inarr t1 val1 (off + 1) + (Right val2, 1) -> performAdd inarr t2 val2 (off + 1) + _ -> error "accumAdd: Tag mismatch for Either" + if inarr + then return (off + 1 + max (tSize' t1 Nothing) (tSize' t2 Nothing)) + else return off1 + STArr n ty' + | inarr -> error "Nested array" + | otherwise -> do + (off1, sh) <- readShape mbarr n off + let sz = shapeSize sh + eltsize = tSize' ty' Nothing + forM_ [0 .. sz - 1] $ \lini -> + performAdd True ty' (arrayIndexLinear val lini) (off1 + lini * eltsize) + return (off1 + sz * eltsize) + STScal ty' -> performAddScal ty' val off + STAccum{} -> error "Nested accumulators unsupported" + + performAddScal :: SScalTy t' -> ScalRep t' -> Int -> ST s Int + performAddScal STI32 (I32# x#) (I# off#) + | sizeOf (undefined :: Int) == 4 + = ST $ \s -> case fetchAddIntArray# mbarr off# (int32ToInt# x#) s of + (# s', _ #) -> (# s', I# (off# +# 4#) #) + | otherwise + = _ + performAddScal STI64 (I64# x#) (I# off#) + | sizeOf (undefined :: Int) == 8 + = ST $ \s -> case fetchAddIntArray# mbarr off# (int64ToInt# x#) s of + (# s', _ #) -> (# s', I# (off# +# 8#) #) + | otherwise + = _ + performAddScal STF32 (F# x) (I# off#) = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #) + performAddScal STF64 (D# x) (I# off#) = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #) + performAddScal STBool b (I# off#) = let !(I8# i) = fromIntegral (fromEnum b) in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #) + casLoop :: (Addr# -> w -> w -> State# d -> (# State# d, w #)) + -> (w -> w) + -> r + -> State# d -> (# State# d, r #) + casLoop casOp add ret = _ + withAccum :: STy t -> Rep t -> (Accum s t -> AcM s b) -> AcM s (Rep t, b) withAccum ty start fun = do let !(I# size) = tSize ty start |