{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}
module Interpreter.Accum where

import Control.Monad (forM_, when)
import Control.Monad.ST
import Control.Monad.ST.Unsafe
import Data.Bifunctor (second)
import Foreign.Storable (sizeOf)
import GHC.Exts
import GHC.Int
import GHC.ST (ST (..))

import AST
import Data
import Interpreter.Array


-- | The monad accumulator operations run in: a newtype over 'ST'.
-- Operations that may be interleaved freely (such as 'accumAdd' and
-- 'accumRead') live here; 'accumWrite' deliberately stays in plain 'ST'.
newtype AcM s a = AcM (ST s a)
  deriving newtype (Functor, Applicative, Monad)

-- | Run an accumulator computation to completion, in the style of 'runST'.
runAcM :: (forall s. AcM s a) -> a
runAcM m = runST (unwrap m)
  where
    unwrap :: AcM s' b -> ST s' b
    unwrap (AcM st) = st

-- | Haskell representation of values of each object-language type.
type family Rep t where
  Rep TNil = ()
  Rep (TPair a b) = (Rep a, Rep b)
  Rep (TEither a b) = Either (Rep a) (Rep b)
  Rep (TArr n t) = Array n (Rep t)
  Rep (TScal sty) = ScalRep sty
  -- Rep (TAccum t) = _

-- | A mutable accumulator holding a value of type @t@, stored as raw bytes
-- in a 'MutableByteArray#'.
data Accum s t = Accum (STy t) (MutableByteArray# s)

-- | Number of bytes needed to store this particular top-level value.
tSize :: STy t -> Rep t -> Int
tSize ty = tSize' ty . Just

-- | Number of bytes a value of the given type occupies in accumulator memory.
--
-- Passing 'Nothing' as the value means "this is (inside) an array element".
-- Such slots must have a size independent of the value: an 'Either' reserves
-- room for its larger branch, and arrays are rejected because they cannot be
-- nested.
tSize' :: STy t -> Maybe (Rep t) -> Int
tSize' typ mval = case typ of
  STNil -> 0
  STPair l r -> tSize' l (fst <$> mval) + tSize' r (snd <$> mval)
  STEither l r ->
    -- One tag byte (read back later to decide which branch is stored),
    -- followed by the payload of the stored branch, or of the larger branch
    -- when no value is known.
    let payload =
          maybe
            (max (tSize' l Nothing) (tSize' r Nothing))
            (either (tSize l) (tSize r))
            mval
     in 1 + payload
  STArr ndim elt ->
    -- One 8-byte dimension per axis, then the fixed-size elements.
    maybe
      (error "Nested arrays not supported in this implementation")
      (\arr -> fromSNat ndim * 8 + arraySize arr * tSize' elt Nothing)
      mval
  STScal sty -> scalarBytes sty
  STAccum{} -> error "Nested accumulators unsupported"
  where
    scalarBytes :: SScalTy s' -> Int
    scalarBytes sty = case sty of
      STI32 -> 4
      STI64 -> 8
      STF32 -> 4
      STF64 -> 8
      STBool -> 1
-- Byte-offset access to accumulator memory.
--
-- All offsets in this module are in bytes. The @read/write<Ty>Array#@
-- primops index in units of the element size, so for anything wider than one
-- byte we use the @Word8ArrayAs<Ty>#@ variants, which take a byte offset.
-- The layout packs one-byte tags and booleans between wider scalars.

-- | Read the one-byte 'Either' tag stored at a byte offset.
memReadTag :: MutableByteArray# s -> Int -> ST s Int8
memReadTag mbarr (I# off#) = ST $ \s -> case readInt8Array# mbarr off# s of
  (# s', i #) -> (# s', I8# i #)

-- | Write a one-byte 'Either' tag at a byte offset.
memWriteTag :: MutableByteArray# s -> Int -> Int8 -> ST s ()
memWriteTag mbarr (I# off#) (I8# tag#) =
  ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #)

-- | Read an 8-byte integer (an array dimension) at a byte offset.
memReadI64 :: MutableByteArray# s -> Int -> ST s Int64
memReadI64 mbarr (I# off#) = ST $ \s -> case readWord8ArrayAsInt64# mbarr off# s of
  (# s', i #) -> (# s', I64# i #)

-- | Write an 8-byte integer (an array dimension) at a byte offset.
memWriteI64 :: MutableByteArray# s -> Int -> Int64 -> ST s ()
memWriteI64 mbarr (I# off#) (I64# x) =
  ST $ \s -> (# writeWord8ArrayAsInt64# mbarr off# x s, () #)

-- | Read a scalar at a byte offset; returns the offset just past it.
memReadScal :: MutableByteArray# s -> SScalTy t -> Int -> ST s (Int, ScalRep t)
memReadScal mbarr STI32 (I# off#) = ST $ \s -> case readWord8ArrayAsInt32# mbarr off# s of
  (# s', i #) -> (# s', (I# (off# +# 4#), I32# i) #)
memReadScal mbarr STI64 (I# off#) = ST $ \s -> case readWord8ArrayAsInt64# mbarr off# s of
  (# s', i #) -> (# s', (I# (off# +# 8#), I64# i) #)
memReadScal mbarr STF32 (I# off#) = ST $ \s -> case readWord8ArrayAsFloat# mbarr off# s of
  (# s', f #) -> (# s', (I# (off# +# 4#), F# f) #)
memReadScal mbarr STF64 (I# off#) = ST $ \s -> case readWord8ArrayAsDouble# mbarr off# s of
  (# s', f #) -> (# s', (I# (off# +# 8#), D# f) #)
memReadScal mbarr STBool (I# off#) = ST $ \s -> case readInt8Array# mbarr off# s of
  (# s', i #) -> (# s', (I# (off# +# 1#), toEnum (fromIntegral (I8# i))) #)

-- | Write a scalar at a byte offset; returns the offset just past it.
memWriteScal :: MutableByteArray# s -> SScalTy t -> ScalRep t -> Int -> ST s Int
memWriteScal mbarr STI32 (I32# x) (I# off#) =
  ST $ \s -> (# writeWord8ArrayAsInt32# mbarr off# x s, I# (off# +# 4#) #)
memWriteScal mbarr STI64 (I64# x) (I# off#) =
  ST $ \s -> (# writeWord8ArrayAsInt64# mbarr off# x s, I# (off# +# 8#) #)
memWriteScal mbarr STF32 (F# x) (I# off#) =
  ST $ \s -> (# writeWord8ArrayAsFloat# mbarr off# x s, I# (off# +# 4#) #)
memWriteScal mbarr STF64 (D# x) (I# off#) =
  ST $ \s -> (# writeWord8ArrayAsDouble# mbarr off# x s, I# (off# +# 8#) #)
memWriteScal mbarr STBool b (I# off#) =
  let !(I8# i) = fromIntegral (fromEnum b)
   in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #)

-- | The addition 'accumAdd' performs on scalars.
accScalPlus :: SScalTy t -> ScalRep t -> ScalRep t -> ScalRep t
accScalPlus STI32 old new = old + new
accScalPlus STI64 old new = old + new
accScalPlus STF32 old new = old + new
accScalPlus STF64 old new = old + new
-- NOTE(review): booleans are overwritten, as in the previous implementation;
-- confirm this is the intended semantics.
accScalPlus STBool _ new = new

-- | Write a complete value into the accumulator memory, overwriting whatever
-- was there.
--
-- This operation does not commute with 'accumAdd', so it must be used with
-- care. Furthermore it must be used on exactly the same value as 'tSize' was
-- called on, because the layout (array shapes, 'Either' branches outside
-- arrays) depends on the value. Hence it lives in 'ST', not in 'AcM'.
--
-- Layout (byte offsets, no alignment padding):
--
-- * 'TNil': nothing.
-- * 'TPair': the first component immediately followed by the second.
-- * 'TEither': a one-byte tag (0 = 'Left', 1 = 'Right') and the payload.
--   Inside an array element the slot is padded to the larger branch, so that
--   every element has the size @'tSize'' t 'Nothing'@.
-- * 'TArr': one 8-byte dimension per axis (in the order 'readShape' reads
--   them), then the elements in linear order.
-- * 'TScal': 4 bytes for i32/f32, 8 for i64/f64, 1 for bool.
accumWrite :: forall s t. Accum s t -> Rep t -> ST s ()
accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0
  where
    go :: Bool -> STy t' -> Rep t' -> Int -> ST s Int
    go inarr ty val off = case ty of
      STNil -> return off
      STPair a b -> do
        off1 <- go inarr a (fst val) off
        go inarr b (snd val) off1
      STEither a b -> do
        off1 <- case val of
          Left x -> do
            memWriteTag mbarr off 0
            go inarr a x (off + 1)
          Right y -> do
            memWriteTag mbarr off 1
            go inarr b y (off + 1)
        -- Array elements occupy fixed-size slots: pad to the larger branch,
        -- exactly as 'tSize'', 'accumRead' and 'accumAdd' assume.
        return (if inarr then off + tSize' ty Nothing else off1)
      STArr _ t
        | inarr -> error "Nested arrays not supported in this implementation"
        | otherwise -> do
            off1 <- goShape (arrayShape val) off
            let eltsize = tSize' t Nothing
                n = arraySize val
            traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val
            return (off1 + eltsize * n)
      STScal sty -> memWriteScal mbarr sty val off
      STAccum{} -> error "Nested accumulators unsupported"

    goShape :: Shape n -> Int -> ST s Int
    goShape ShNil off = return off
    goShape (ShCons sh n) off = do
      off1 <- goShape sh off
      memWriteI64 mbarr off1 (fromIntegral n)
      return (off1 + 8)

-- | Read back the complete value currently stored in the accumulator.
accumRead :: forall s t. Accum s t -> AcM s (Rep t)
accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0
  where
    -- Returns the offset just past the value together with the value.
    go :: Bool -> STy t' -> Int -> ST s (Int, Rep t')
    go inarr ty off = case ty of
      STNil -> return (off, ())
      STPair a b -> do
        (off1, x) <- go inarr a off
        (off2, y) <- go inarr b off1
        -- off2 is already absolute: it was computed starting from off1.
        return (off2, (x, y))
      STEither a b -> do
        tag <- memReadTag mbarr off
        (off1, val) <- case tag of
          0 -> fmap Left <$> go inarr a (off + 1)
          1 -> fmap Right <$> go inarr b (off + 1)
          _ -> error "Invalid tag in accum memory"
        return (if inarr then off + tSize' ty Nothing else off1, val)
      STArr ndim t
        | inarr -> error "Nested arrays not supported in this implementation"
        | otherwise -> do
            (off1, sh) <- readShape mbarr ndim off
            let eltsize = tSize' t Nothing
                n = shapeSize sh
            arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini))
            return (off1 + eltsize * n, arr)
      STScal sty -> memReadScal mbarr sty off
      STAccum{} -> error "Nested accumulators unsupported"

-- | Read an array shape stored at a byte offset (one 8-byte dimension per
-- axis); returns the offset just past it together with the shape.
readShape :: MutableByteArray# s -> SNat n -> Int -> ST s (Int, Shape n)
readShape _ SZ off = return (off, ShNil)
readShape mbarr (SS ndim) off = do
  (off1, sh) <- readShape mbarr ndim off
  n' <- memReadI64 mbarr off1
  return (off1 + 8, ShCons sh (fromIntegral n'))

-- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of
-- the list.
data InvShape n where
  IShNil :: InvShape Z
  IShCons :: Int  -- ^ How many subarrays are there?
          -> Int  -- ^ What is the size (in elements) of all subarrays together?
          -> InvShape n  -- ^ Sub array inverted shape
          -> InvShape (S n)

-- | Number of elements described by an inverted shape. A zero-dimensional
-- shape describes exactly one element.
ishSize :: InvShape n -> Int
ishSize IShNil = 1
ishSize (IShCons _ sz _) = sz

invertShape :: forall n. Shape n -> InvShape n
invertShape | Refl <- lemPlusZero @n = flip go IShNil
  where
    go :: forall n' m. Shape n' -> InvShape m -> InvShape (n' + m)
    go ShNil ish = ish
    go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish)

-- | Add a value into the accumulator at a (possibly partial) index.
--
-- At depth 'SZ' the whole stored value is added to. Each 'SS' descends one
-- level into a pair or either branch, or through the array dimensions; the
-- exact index and value types are given by 'AcIdx' and 'AcVal'.
--
-- The addition is a plain read-modify-write, not an atomic operation. 'AcM'
-- is a newtype over 'ST', run with 'runST', and nothing in this module runs
-- accumulations concurrently. NOTE(review): if 'accumAdd' is ever used from
-- several threads, this needs pinned memory and a CAS loop.
accumAdd :: forall s t i. Accum s t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s ()
accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty depth index value 0
  where
    go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int -> ST s ()
    go inarr ty SZ () val off = () <$ performAdd inarr ty val off
    go inarr ty (SS dep) idx val off = case (ty, idx, val) of
      (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
      (STPair t1 t2, Right idx2, Right val2) -> do
        -- The second component starts after the stored first component.
        off1 <- skipValue inarr t1 off
        go inarr t2 dep idx2 val2 off1
      (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd"
      -- The branch payload starts after the one-byte tag, and the stored
      -- branch must be the one being indexed into.
      (STEither t1 _, Left idx1, Left val1) -> do
        expectTag 0 off
        go inarr t1 dep idx1 val1 (off + 1)
      (STEither _ t2, Right idx2, Right val2) -> do
        expectTag 1 off
        go inarr t2 dep idx2 val2 (off + 1)
      (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd"
      (STArr rank eltty, _, _)
        | inarr -> error "Nested arrays"
        | otherwise -> do
            (off1, ish) <- second invertShape <$> readShape mbarr rank off
            goArr (SS dep) ish eltty idx val off1
      (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth"
      (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth"
      (STAccum{}, _, _) -> error "Nested accumulators unsupported"

    -- Fail unless the 'Either' tag stored at @off@ equals @want@.
    expectTag :: Int8 -> Int -> ST s ()
    expectTag want off = do
      tag <- memReadTag mbarr off
      when (tag /= want) $ error "accumAdd: Tag mismatch for Either"

    -- Byte offset just past the value of type @ty@ stored at @off@. Inside
    -- an array every value has the fixed size 'tSize'' gives for 'Nothing';
    -- outside, the size depends on the stored 'Either' tags and array
    -- shapes, so we walk the memory.
    skipValue :: Bool -> STy t' -> Int -> ST s Int
    skipValue True ty off = return (off + tSize' ty Nothing)
    skipValue False ty off = case ty of
      STNil -> return off
      STPair t1 t2 -> skipValue False t1 off >>= skipValue False t2
      STEither t1 t2 -> do
        tag <- memReadTag mbarr off
        case tag of
          0 -> skipValue False t1 (off + 1)
          1 -> skipValue False t2 (off + 1)
          _ -> error "Invalid tag in accum memory"
      STArr rank eltty -> do
        (off1, sh) <- readShape mbarr rank off
        return (off1 + shapeSize sh * tSize' eltty Nothing)
      STScal _ -> return (off + tSize' ty Nothing)
      STAccum{} -> error "Nested accumulators unsupported"

    -- Walk the array dimensions (outermost first). @off@ is the byte offset of
    -- the first element of the current subarray.
    goArr :: SNat i' -> InvShape n -> STy t' -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int -> ST s ()
    goArr SZ ish t1 () val off = performAddArr ish t1 val off
    goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off
    goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do
      let i' = fromIntegral @(Rep TIx) @Int i
      when (i' < 0 || i' >= n) $
        error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")"
      -- ishSize counts elements; convert to bytes with the element size.
      goArr depm1 ish t1 idx val (off + i' * ishSize ish * tSize' t1 Nothing)

    -- Element-wise add of a whole subarray starting at byte offset @off@.
    -- NOTE(review): assumes 'AcVal' at depth 'Z' of @TArr n t'@ reduces to
    -- @TArr n t'@, so the value is an @Array n (Rep t')@.
    performAddArr :: InvShape n -> STy t' -> Array n (Rep t') -> Int -> ST s ()
    performAddArr ish t1 arr off = do
      let n = ishSize ish
          eltsize = tSize' t1 Nothing
      when (arraySize arr /= n) $ error "accumAdd: subarray size mismatch"
      forM_ [0 .. n - 1] $ \lini ->
        performAdd True t1 (arrayIndexLinear arr lini) (off + lini * eltsize)

    -- Add a complete value at @off@; returns the offset just past it.
    performAdd :: Bool -> STy t' -> Rep t' -> Int -> ST s Int
    performAdd inarr ty val off = case ty of
      STNil -> return off
      STPair t1 t2 -> do
        off1 <- performAdd inarr t1 (fst val) off
        performAdd inarr t2 (snd val) off1
      STEither t1 t2 -> do
        tag <- memReadTag mbarr off
        off1 <- case (val, tag) of
          (Left val1, 0) -> performAdd inarr t1 val1 (off + 1)
          (Right val2, 1) -> performAdd inarr t2 val2 (off + 1)
          _ -> error "accumAdd: Tag mismatch for Either"
        return (if inarr then off + tSize' ty Nothing else off1)
      STArr n ty'
        | inarr -> error "Nested array"
        | otherwise -> do
            (off1, sh) <- readShape mbarr n off
            let sz = shapeSize sh
                eltsize = tSize' ty' Nothing
            -- Guard arrayIndexLinear against a value of a different shape.
            when (arraySize val /= sz) $ error "accumAdd: array size mismatch"
            forM_ [0 .. sz - 1] $ \lini ->
              performAdd True ty' (arrayIndexLinear val lini) (off1 + lini * eltsize)
            return (off1 + sz * eltsize)
      STScal ty' -> performAddScal ty' val off
      STAccum{} -> error "Nested accumulators unsupported"

    -- Read-modify-write of one scalar; returns the offset just past it.
    performAddScal :: SScalTy t' -> ScalRep t' -> Int -> ST s Int
    performAddScal sty x off = do
      (_, old) <- memReadScal mbarr sty off
      memWriteScal mbarr sty (accScalPlus sty old x) off

-- | Allocate an accumulator initialised to @start@, run @fun@ with it, and
-- return the final accumulator contents together with @fun@'s result.
withAccum :: STy t -> Rep t -> (Accum s t -> AcM s b) -> AcM s (Rep t, b)
withAccum ty start fun = do
  let !(I# size) = tSize ty start
  accum <- AcM . ST $ \s -> case newByteArray# size s of
    (# s', mbarr #) -> (# s', Accum ty mbarr #)
  AcM $ accumWrite accum start
  b <- fun accum
  out <- accumRead accum
  return (out, b)