summaryrefslogtreecommitdiff
path: root/src/Interpreter/Accum.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Interpreter/Accum.hs')
-rw-r--r--src/Interpreter/Accum.hs274
1 files changed, 162 insertions, 112 deletions
diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs
index b0deaef..45f507b 100644
--- a/src/Interpreter/Accum.hs
+++ b/src/Interpreter/Accum.hs
@@ -1,13 +1,15 @@
+{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}
-{-# LANGUAGE BangPatterns #-}
module Interpreter.Accum where
import Control.Monad.ST
@@ -19,6 +21,9 @@ import GHC.ST (ST(..))
import AST
import Data
import Interpreter.Array
+import Data.Bifunctor (second)
+import Control.Monad (when, forM_)
+import Foreign.Storable (sizeOf)
newtype AcM s a = AcM (ST s a)
@@ -68,156 +73,201 @@ tSize' typ val = case typ of
-- care. Furthermore it must be used on exactly the same value as tSize was
-- called on. Hence it lives in ST, not in AcM.
accumWrite :: forall s t. Accum s t -> Rep t -> ST s ()
-accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0#
+accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0
where
- go :: Bool -> STy t' -> Rep t' -> Int# -> ST s Int
- go inarr ty val off# = case ty of
- STNil -> return (I# off#)
+ go :: Bool -> STy t' -> Rep t' -> Int -> ST s Int
+ go inarr ty val off = case ty of
+ STNil -> return off
STPair a b -> do
- I# off1# <- go inarr a (fst val) off#
- go inarr b (snd val) off1#
- STEither a b ->
+ off1 <- go inarr a (fst val) off
+ go inarr b (snd val) off1
+ STEither a b -> do
+ let !(I# off#) = off
case val of
Left x -> do
let !(I8# tag#) = 0
ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #)
- go inarr a x (off# +# 1#)
+ go inarr a x (off + 1)
Right y -> do
let !(I8# tag#) = 1
ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #)
- go inarr b y (off# +# 1#)
+ go inarr b y (off + 1)
STArr _ t
| inarr -> error "Nested arrays not supported in this implementation"
| otherwise -> do
- I# off1# <- goShape (arrayShape val) off#
- let !(I# eltsize#) = tSize' t Nothing
- !(I# n#) = arraySize val
- traverseArray_ (\(I# lini#) x -> () <$ go True t x (off1# +# eltsize# *# lini#)) val
- return (I# (off1# +# eltsize# *# n#))
- STScal sty -> goScal sty val off#
+ off1 <- goShape (arrayShape val) off
+ let eltsize = tSize' t Nothing
+ n = arraySize val
+ traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val
+ return (off1 + eltsize * n)
+ STScal sty -> goScal sty val off
STAccum{} -> error "Nested accumulators unsupported"
- goShape :: Shape n -> Int# -> ST s Int
- goShape ShNil off# = return (I# off#)
- goShape (ShCons sh n) off# = do
- I# off1# <- goShape sh off#
- let !(I64# n') = fromIntegral n
- ST $ \s -> (# writeInt64Array# mbarr off1# n' s, I# (off1# +# 8#) #)
-
- goScal :: SScalTy t' -> ScalRep t' -> Int# -> ST s Int
- goScal STI32 (I32# x) off# = ST $ \s -> (# writeInt32Array# mbarr off# x s, I# (off# +# 4#) #)
- goScal STI64 (I64# x) off# = ST $ \s -> (# writeInt64Array# mbarr off# x s, I# (off# +# 8#) #)
- goScal STF32 (F# x) off# = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #)
- goScal STF64 (D# x) off# = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #)
- goScal STBool b off# =
+ goShape :: Shape n -> Int -> ST s Int
+ goShape ShNil off = return off
+ goShape (ShCons sh n) off = do
+ off1@(I# off1#) <- goShape sh off
+ let !(I64# n'#) = fromIntegral n
+ ST $ \s -> (# writeInt64Array# mbarr off1# n'# s, off1 + 8 #)
+
+ goScal :: SScalTy t' -> ScalRep t' -> Int -> ST s Int
+ goScal STI32 (I32# x) (I# off#) = ST $ \s -> (# writeInt32Array# mbarr off# x s, I# (off# +# 4#) #)
+ goScal STI64 (I64# x) (I# off#) = ST $ \s -> (# writeInt64Array# mbarr off# x s, I# (off# +# 8#) #)
+ goScal STF32 (F# x) (I# off#) = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #)
+ goScal STF64 (D# x) (I# off#) = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #)
+ goScal STBool b (I# off#) =
let !(I8# i) = fromIntegral (fromEnum b)
in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #)
accumRead :: forall s t. Accum s t -> AcM s (Rep t)
-accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0#
+accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0
where
- go :: Bool -> STy t' -> Int# -> ST s (Int, Rep t')
- go inarr ty off# = case ty of
- STNil -> return (I# off#, ())
+ go :: Bool -> STy t' -> Int -> ST s (Int, Rep t')
+ go inarr ty off = case ty of
+ STNil -> return (off, ())
STPair a b -> do
- (I# off1#, x) <- go inarr a off#
- (I# off2#, y) <- go inarr b off1#
- return (I# (off1# +# off2#), (x, y))
+ (off1, x) <- go inarr a off
+ (off2, y) <- go inarr b off1
+ return (off1 + off2, (x, y))
STEither a b -> do
+ let !(I# off#) = off
tag <- ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', I8# i #)
+ (off1, val) <- case tag of
+ 0 -> fmap Left <$> go inarr a (off + 1)
+ 1 -> fmap Right <$> go inarr b (off + 1)
+ _ -> error "Invalid tag in accum memory"
if inarr
- then do val <- case tag of
- 0 -> Left . snd <$> go inarr a (off# +# 1#)
- 1 -> Right . snd <$> go inarr b (off# +# 1#)
- _ -> error "Invalid tag in accum memory"
- return (I# off# + max (tSize' a Nothing) (tSize' b Nothing), val)
- else case tag of
- 0 -> fmap Left <$> go inarr a (off# +# 1#)
- 1 -> fmap Right <$> go inarr b (off# +# 1#)
- _ -> error "Invalid tag in accum memory"
+ then return (off + 1 + max (tSize' a Nothing) (tSize' b Nothing), val)
+ else return (off1, val)
STArr ndim t
| inarr -> error "Nested arrays not supported in this implementation"
| otherwise -> do
- (I# off1#, sh) <- readShape mbarr ndim off#
- let !(I# eltsize#) = tSize' t Nothing
- !(I# n#) = shapeSize sh
- arr <- arrayGenerateLinM sh (\(I# lini#) -> snd <$> go True t (off1# +# eltsize# *# lini#))
- return (I# (off1# +# eltsize# *# n#), arr)
- STScal sty -> goScal sty off#
+ (off1, sh) <- readShape mbarr ndim off
+ let eltsize = tSize' t Nothing
+ n = shapeSize sh
+ arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini))
+ return (off1 + eltsize * n, arr)
+ STScal sty -> goScal sty off
STAccum{} -> error "Nested accumulators unsupported"
- goScal :: SScalTy t' -> Int# -> ST s (Int, ScalRep t')
- goScal STI32 off# = ST $ \s -> case readInt32Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 4#), I32# i) #)
- goScal STI64 off# = ST $ \s -> case readInt64Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 8#), I64# i) #)
- goScal STF32 off# = ST $ \s -> case readFloatArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 4#), F# f) #)
- goScal STF64 off# = ST $ \s -> case readDoubleArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 8#), D# f) #)
- goScal STBool off# = ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 1#), toEnum (fromIntegral (I8# i))) #)
-
-readShape :: MutableByteArray# s -> SNat n -> Int# -> ST s (Int, Shape n)
-readShape _ SZ off# = return (I# off#, ShNil)
-readShape mbarr (SS ndim) off# = do
- (I# off1#, sh) <- readShape mbarr ndim off#
+ goScal :: SScalTy t' -> Int -> ST s (Int, ScalRep t')
+ goScal STI32 (I# off#) = ST $ \s -> case readInt32Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 4#), I32# i) #)
+ goScal STI64 (I# off#) = ST $ \s -> case readInt64Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 8#), I64# i) #)
+ goScal STF32 (I# off#) = ST $ \s -> case readFloatArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 4#), F# f) #)
+ goScal STF64 (I# off#) = ST $ \s -> case readDoubleArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 8#), D# f) #)
+ goScal STBool (I# off#) = ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 1#), toEnum (fromIntegral (I8# i))) #)
+
+readShape :: MutableByteArray# s -> SNat n -> Int -> ST s (Int, Shape n)
+readShape _ SZ off = return (off, ShNil)
+readShape mbarr (SS ndim) off = do
+ (off1@(I# off1#), sh) <- readShape mbarr ndim off
n' <- ST $ \s -> case readInt64Array# mbarr off1# s of (# s', i #) -> (# s', I64# i #)
- return (I# (off1# +# 8#), ShCons sh (fromIntegral n'))
+ return (off1 + 8, ShCons sh (fromIntegral n'))
-data InvShape full yet where
- IShFull :: InvShape full full
- IShCons :: InvShape full (S yet) -> Int -> InvShape full yet
+-- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of
+-- the list.
+data InvShape n where
+ IShNil :: InvShape Z
+ IShCons :: Int -- ^ How many subarrays are there?
+ -> Int -- ^ What is the size of all subarrays together?
+ -> InvShape n -- ^ Sub array inverted shape
+ -> InvShape (S n)
-invertShape :: Shape n -> InvShape n Z
-invertShape
+ishSize :: InvShape n -> Int
+ishSize IShNil = 0
+ishSize (IShCons _ sz _) = sz
+invertShape :: forall n. Shape n -> InvShape n
+invertShape | Refl <- lemPlusZero @n = flip go IShNil
+ where
+ go :: forall n' m. Shape n' -> InvShape m -> InvShape (n' + m)
+ go ShNil ish = ish
+ go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish)
accumAdd :: forall s t i. Accum s t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s ()
-accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty depth index value 0#
+accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty depth index value 0
where
- go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int# -> ST s Int
- go inarr ty SZ idx val off# = performAdd inarr ty val off#
- go inarr ty (SS dep) idx val off# = case (ty, idx, val) of
- (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off#
- (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off#
- (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off#
- (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off#
+ go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int -> ST s ()
+ go inarr ty SZ () val off = () <$ performAdd inarr ty val off
+ go inarr ty (SS dep) idx val off = case (ty, idx, val) of
+ (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
+ (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off
+ (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd"
+ (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
+ (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off
+ (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd"
(STArr rank eltty, _, _)
| inarr -> error "Nested arrays"
- | otherwise -> goArr (SS dep) rank eltty idx val off#
- -- (STArr SZ t1, _, _) -> go inarr t1 dep idx val off#
- -- (STArr (SS rankm1) t1, _, _) -> go inarr t1 dep idx val off#
- _ -> _
-
- goArr :: SNat i' -> SNat n -> STy t'
- -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int# -> ST s Int
- goArr dep n t1 idx val off# = do
- (I# off1#, sh) <- readShape mbarr n off#
- _
-
- collectArrIndex :: SNat i' -> STy t' -> Shape n
- -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i')
- -> (forall i2. Index n -> SNat i2 -> Rep (AcIdx t' i2) -> Rep (AcVal t' i2) -> ST s r)
- -> (forall i2. Index n -> SNat i2 -> Rep (AcIdx t' i2) -> Rep (AcVal t' i2) -> ST s r)
- -> ST s r
- collectArrIndex (SS dep) eltty ShNil idx val k = k IxNil dep idx val
- collectArrIndex (SS dep) eltty (ShCons sh size) (i, idx) val k =
- collectArrIndex dep eltty sh idx val $ \index dep' idx' val' ->
- k (IxCons index (fromIntegral i)) dep' idx' val'
- -- collectArrIndex SZ eltty
-
- goShape :: Shape n -> Int# -> ST s Int
- goShape ShNil off# = return (I# off#)
- goShape (ShCons sh n) off# = do
- I# off1# <- goShape sh off#
- let !(I64# n') = fromIntegral n
- ST $ \s -> (# writeInt64Array# mbarr off1# n' s, I# (off1# +# 8#) #)
-
- goScal :: SScalTy t' -> ScalRep t' -> Int# -> ST s Int
- goScal STI32 (I32# x) off# = ST $ \s -> (# writeInt32Array# mbarr off# x s, I# (off# +# 4#) #)
- goScal STI64 (I64# x) off# = ST $ \s -> (# writeInt64Array# mbarr off# x s, I# (off# +# 8#) #)
- goScal STF32 (F# x) off# = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #)
- goScal STF64 (D# x) off# = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #)
- goScal STBool b off# =
+ | otherwise -> do
+ (off1, ish) <- second invertShape <$> readShape mbarr rank off
+ goArr (SS dep) ish eltty idx val off1
+ (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth"
+ (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth"
+ (STAccum{}, _, _) -> error "Nested accumulators unsupported"
+
+ goArr :: SNat i' -> InvShape n -> STy t'
+ -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int -> ST s ()
+ goArr SZ ish t1 () val off = performAddArr ish t1 val off
+ goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off
+ goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do
+ let i' = fromIntegral @(Rep TIx) @Int i
+ when (i' < 0 || i' >= n) $
+ error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")"
+ goArr depm1 ish t1 idx val (off + i' * ishSize ish)
+
+ performAdd :: Bool -> STy t' -> Rep t' -> Int -> ST s Int
+ performAdd inarr ty val off = case ty of
+ STNil -> return off
+ STPair t1 t2 -> do
+ off1 <- performAdd inarr t1 (fst val) off
+ performAdd inarr t2 (snd val) off1
+ STEither t1 t2 -> do
+ let !(I# off#) = off
+ tag <- ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', I8# i #)
+ off1 <- case (val, tag) of
+ (Left val1, 0) -> performAdd inarr t1 val1 (off + 1)
+ (Right val2, 1) -> performAdd inarr t2 val2 (off + 1)
+ _ -> error "accumAdd: Tag mismatch for Either"
+ if inarr
+ then return (off + 1 + max (tSize' t1 Nothing) (tSize' t2 Nothing))
+ else return off1
+ STArr n ty'
+ | inarr -> error "Nested array"
+ | otherwise -> do
+ (off1, sh) <- readShape mbarr n off
+ let sz = shapeSize sh
+ eltsize = tSize' ty' Nothing
+ forM_ [0 .. sz - 1] $ \lini ->
+ performAdd True ty' (arrayIndexLinear val lini) (off1 + lini * eltsize)
+ return (off1 + sz * eltsize)
+ STScal ty' -> performAddScal ty' val off
+ STAccum{} -> error "Nested accumulators unsupported"
+
+ performAddScal :: SScalTy t' -> ScalRep t' -> Int -> ST s Int
+ performAddScal STI32 (I32# x#) (I# off#)
+ | sizeOf (undefined :: Int) == 4
+ = ST $ \s -> case fetchAddIntArray# mbarr off# (int32ToInt# x#) s of
+ (# s', _ #) -> (# s', I# (off# +# 4#) #)
+ | otherwise
+ = _
+ performAddScal STI64 (I64# x#) (I# off#)
+ | sizeOf (undefined :: Int) == 8
+ = ST $ \s -> case fetchAddIntArray# mbarr off# (int64ToInt# x#) s of
+ (# s', _ #) -> (# s', I# (off# +# 8#) #)
+ | otherwise
+ = _
+ performAddScal STF32 (F# x) (I# off#) = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #)
+ performAddScal STF64 (D# x) (I# off#) = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #)
+ performAddScal STBool b (I# off#) =
let !(I8# i) = fromIntegral (fromEnum b)
in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #)
+ casLoop :: (Addr# -> w -> w -> State# d -> (# State# d, w #))
+ -> (w -> w)
+ -> r
+ -> State# d -> (# State# d, r #)
+ casLoop casOp add ret = _
+
withAccum :: STy t -> Rep t -> (Accum s t -> AcM s b) -> AcM s (Rep t, b)
withAccum ty start fun = do
let !(I# size) = tSize ty start