{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnboxedTuples #-}
module Interpreter.Accum where

import Control.Monad.ST
import Control.Monad.ST.Unsafe
import GHC.Exts
import GHC.Int
import GHC.ST (ST(..))

import AST
import Data
import Interpreter.Array


-- | Monad in which accumulators are manipulated; a newtype over 'ST'.
newtype AcM s a = AcM (ST s a)
  deriving newtype (Functor, Applicative, Monad)

-- | Run an 'AcM' computation to a pure result (via 'runST').
runAcM :: (forall s. AcM s a) -> a
runAcM m = runST (case m of AcM m' -> m')

-- | Haskell-side representation of values of an object-language type.
type family Rep t where
  Rep TNil = ()
  Rep (TPair a b) = (Rep a, Rep b)
  Rep (TEither a b) = Either (Rep a) (Rep b)
  Rep (TArr n t) = Array n (Rep t)
  Rep (TScal sty) = ScalRep sty
  -- Rep (TAccum t) = _

-- | A mutable accumulator holding a value of type @t@ in a flat byte buffer.
--
-- Memory layout (all offsets are byte offsets; nothing is aligned):
--
-- * @TNil@: zero bytes.
-- * @TPair a b@: the bytes of @a@ immediately followed by those of @b@.
-- * @TEither a b@: one tag byte (0 = 'Left', 1 = 'Right') followed by the
--   payload. Inside an array element the payload is padded to the larger of
--   the two variants, so that every element has the same size.
-- * @TArr n t@: @n@ dimension sizes stored as 8-byte integers, followed by
--   the elements; the element with linear index @i@ starts
--   @i * tSize' t Nothing@ bytes past the shape. Arrays cannot be nested.
-- * Scalars: 4 bytes (I32, F32), 8 bytes (I64, F64) or 1 byte (Bool).
data Accum s t = Accum (STy t) (MutableByteArray# s)

-- | Number of bytes needed to store the given value in an accumulator.
tSize :: STy t -> Rep t -> Int
tSize ty x = tSize' ty (Just x)

-- | Byte size of a value of the given type.
--
-- Passing Nothing as the value means "this is (inside) an array element":
-- the size is then the fixed slot size, with sum types padded to their
-- larger variant.
tSize' :: STy t -> Maybe (Rep t) -> Int
tSize' typ val = case typ of
  STNil -> 0
  STPair a b -> tSize' a (fst <$> val) + tSize' b (snd <$> val)
  STEither a b ->
    case val of
      Nothing -> 1 + max (tSize' a Nothing) (tSize' b Nothing)
      Just (Left x) -> 1 + tSize a x  -- '1 +': the tag byte (also used for runtime sanity checking)
      Just (Right y) -> 1 + tSize b y  -- idem
  STArr ndim t ->
    case val of
      Nothing -> error "Nested arrays not supported in this implementation"
      Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' t Nothing
  STScal sty -> goScal sty
  STAccum{} -> error "Nested accumulators unsupported"
  where
    goScal :: SScalTy t -> Int
    goScal STI32 = 4
    goScal STI64 = 8
    goScal STF32 = 4
    goScal STF64 = 8
    goScal STBool = 1

-- | Write the dimensions of a shape as consecutive 8-byte integers starting at
-- byte offset @off#@ (the dimension in the outermost 'ShCons' is written
-- last). Returns the byte offset just past the shape.
writeShape :: MutableByteArray# s -> Shape n -> Int# -> ST s Int
writeShape _ ShNil off# = return (I# off#)
writeShape mbarr (ShCons sh n) off# = do
  I# off1# <- writeShape mbarr sh off#
  let !(I64# n') = fromIntegral n
  -- Byte-offset primop: 'writeInt64Array#' would index in 8-byte words.
  ST $ \s -> (# writeWord8ArrayAsInt64# mbarr off1# n' s, I# (off1# +# 8#) #)

-- | Write a scalar at byte offset @off#@ (no alignment is assumed). Returns
-- the byte offset just past it.
--
-- The @writeWord8ArrayAs*#@ primops take their offset in bytes; the
-- @write*Array#@ primops would interpret it in units of the element size.
writeScal :: MutableByteArray# s -> SScalTy t -> ScalRep t -> Int# -> ST s Int
writeScal mbarr STI32 (I32# x) off# =
  ST $ \s -> (# writeWord8ArrayAsInt32# mbarr off# x s, I# (off# +# 4#) #)
writeScal mbarr STI64 (I64# x) off# =
  ST $ \s -> (# writeWord8ArrayAsInt64# mbarr off# x s, I# (off# +# 8#) #)
writeScal mbarr STF32 (F# x) off# =
  ST $ \s -> (# writeWord8ArrayAsFloat# mbarr off# x s, I# (off# +# 4#) #)
writeScal mbarr STF64 (D# x) off# =
  ST $ \s -> (# writeWord8ArrayAsDouble# mbarr off# x s, I# (off# +# 8#) #)
writeScal mbarr STBool b off# =
  let !(I8# i) = fromIntegral (fromEnum b)
  in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #)

-- | Initialise the accumulator's memory with the given value.
--
-- This operation does not commute with 'accumAdd', so it must be used with
-- care. Furthermore it must be used on exactly the same value as 'tSize' was
-- called on (that determines the buffer size). Hence it lives in ST, not in
-- AcM.
accumWrite :: forall s t. Accum s t -> Rep t -> ST s ()
accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0#
  where
    -- Write a value at byte offset @off#@ and return the byte offset just past
    -- it. @inarr@ is True while writing (part of) an array element.
    go :: Bool -> STy t' -> Rep t' -> Int# -> ST s Int
    go inarr ty val off# = case ty of
      STNil -> return (I# off#)
      STPair a b -> do
        I# off1# <- go inarr a (fst val) off#
        go inarr b (snd val) off1#
      STEither a b -> do
        end <- case val of
          Left x -> writeTag 0 off# >> go inarr a x (off# +# 1#)
          Right y -> writeTag 1 off# >> go inarr b y (off# +# 1#)
        -- Array elements all occupy a slot of the same size ('tSize'' with
        -- Nothing), so inside an array the smaller variant is padded up to the
        -- size of the larger one; 'accumRead' relies on this.
        return $ if inarr then I# off# + tSize' ty Nothing else end
      STArr _ t
        | inarr -> error "Nested arrays not supported in this implementation"
        | otherwise -> do
            I# off1# <- writeShape mbarr (arrayShape val) off#
            let !(I# eltsize#) = tSize' t Nothing
                !(I# n#) = arraySize val
            traverseArray_
              (\(I# lini#) x -> () <$ go True t x (off1# +# eltsize# *# lini#))
              val
            return (I# (off1# +# eltsize# *# n#))
      STScal sty -> writeScal mbarr sty val off#
      STAccum{} -> error "Nested accumulators unsupported"

    -- Write a one-byte Either tag at byte offset @off#@.
    writeTag :: Int8 -> Int# -> ST s ()
    writeTag (I8# tag#) off# = ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #)

-- | Read the current contents of the accumulator back as a value.
accumRead :: forall s t. Accum s t -> AcM s (Rep t)
accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0#
  where
    -- Read a value at byte offset @off#@; return the byte offset just past it
    -- together with the value. @inarr@ is True inside an array element.
    go :: Bool -> STy t' -> Int# -> ST s (Int, Rep t')
    go inarr ty off# = case ty of
      STNil -> return (I# off#, ())
      STPair a b -> do
        (I# off1#, x) <- go inarr a off#
        -- 'off2' is already an absolute offset, not a length.
        (off2, y) <- go inarr b off1#
        return (off2, (x, y))
      STEither a b -> do
        tag <- ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', I8# i #)
        if inarr
          then do
            val <- case tag of
              0 -> Left . snd <$> go inarr a (off# +# 1#)
              1 -> Right . snd <$> go inarr b (off# +# 1#)
              _ -> error "Invalid tag in accum memory"
            -- Inside an array the variant is padded to a fixed slot whose size
            -- ('tSize'' with Nothing) includes the tag byte.
            return (I# off# + tSize' ty Nothing, val)
          else case tag of
            0 -> fmap Left <$> go inarr a (off# +# 1#)
            1 -> fmap Right <$> go inarr b (off# +# 1#)
            _ -> error "Invalid tag in accum memory"
      STArr ndim t
        | inarr -> error "Nested arrays not supported in this implementation"
        | otherwise -> do
            (I# off1#, sh) <- readShape mbarr ndim off#
            let !(I# eltsize#) = tSize' t Nothing
                !(I# n#) = shapeSize sh
            arr <- arrayGenerateLinM sh
                     (\(I# lini#) -> snd <$> go True t (off1# +# eltsize# *# lini#))
            return (I# (off1# +# eltsize# *# n#), arr)
      STScal sty -> goScal sty off#
      STAccum{} -> error "Nested accumulators unsupported"

    -- Read a scalar at byte offset @off#@ using byte-offset primops.
    goScal :: SScalTy t' -> Int# -> ST s (Int, ScalRep t')
    goScal STI32 off# = ST $ \s -> case readWord8ArrayAsInt32# mbarr off# s of
      (# s', i #) -> (# s', (I# (off# +# 4#), I32# i) #)
    goScal STI64 off# = ST $ \s -> case readWord8ArrayAsInt64# mbarr off# s of
      (# s', i #) -> (# s', (I# (off# +# 8#), I64# i) #)
    goScal STF32 off# = ST $ \s -> case readWord8ArrayAsFloat# mbarr off# s of
      (# s', f #) -> (# s', (I# (off# +# 4#), F# f) #)
    goScal STF64 off# = ST $ \s -> case readWord8ArrayAsDouble# mbarr off# s of
      (# s', f #) -> (# s', (I# (off# +# 8#), D# f) #)
    goScal STBool off# = ST $ \s -> case readInt8Array# mbarr off# s of
      (# s', i #) -> (# s', (I# (off# +# 1#), toEnum (fromIntegral (I8# i))) #)

-- | Read a shape written by 'writeShape' at byte offset @off#@; returns the
-- byte offset just past it together with the shape.
readShape :: MutableByteArray# s -> SNat n -> Int# -> ST s (Int, Shape n)
readShape _ SZ off# = return (I# off#, ShNil)
readShape mbarr (SS ndim) off# = do
  (I# off1#, sh) <- readShape mbarr ndim off#
  n' <- ST $ \s -> case readWord8ArrayAsInt64# mbarr off1# s of (# s', i #) -> (# s', I64# i #)
  return (I# (off1# +# 8#), ShCons sh (fromIntegral n'))

-- | A shape seen "from the outside in": @InvShape full yet@ holds the
-- dimensions of a rank-@full@ shape that lie outside its first @yet@ ones.
data InvShape full yet where
  IShFull :: InvShape full full
  IShCons :: InvShape full (S yet) -> Int -> InvShape full yet

-- | Re-nest a shape so that its outermost dimension becomes the innermost
-- 'IShCons' layer.
invertShape :: Shape n -> InvShape n Z
invertShape = go IShFull
  where
    go :: InvShape n m -> Shape m -> InvShape n Z
    go acc ShNil = acc
    go acc (ShCons sh i) = go (IShCons acc i) sh

-- | Add a value into the accumulator at the given (depth-limited) index.
--
-- NOTE(review): work in progress; contains typed holes and is not complete.
accumAdd :: forall s t i. Accum s t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s ()
accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty depth index value 0#
  where
    go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int# -> ST s Int
    go inarr ty SZ idx val off# = performAdd inarr ty val off#
    go inarr ty (SS dep) idx val off# = case (ty, idx, val) of
      -- TODO(review): per the layout written by 'accumWrite', the second
      -- component of a pair starts after the first one (not at 'off#'), and an
      -- Either payload starts after its tag byte (at off# +# 1#). The offsets
      -- passed below look unfinished.
      (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off#
      (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off#
      (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off#
      (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off#
      (STArr rank eltty, _, _)
        | inarr -> error "Nested arrays"
        | otherwise -> goArr (SS dep) rank eltty idx val off#
      -- (STArr SZ t1, _, _) -> go inarr t1 dep idx val off#
      -- (STArr (SS rankm1) t1, _, _) -> go inarr t1 dep idx val off#
      _ -> _

    goArr :: SNat i' -> SNat n -> STy t' -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int# -> ST s Int
    goArr dep n t1 idx val off# = do
      (I# off1#, sh) <- readShape mbarr n off#
      _

    collectArrIndex :: SNat i' -> STy t' -> Shape n
                    -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i')
                    -> (forall i2. Index n -> SNat i2 -> Rep (AcIdx t' i2) -> Rep (AcVal t' i2) -> ST s r)
                    -> ST s r
    collectArrIndex (SS dep) eltty ShNil idx val k = k IxNil dep idx val
    collectArrIndex (SS dep) eltty (ShCons sh size) (i, idx) val k =
      collectArrIndex dep eltty sh idx val $ \index dep' idx' val' ->
        k (IxCons index (fromIntegral i)) dep' idx' val'
    -- collectArrIndex SZ eltty

-- | Allocate an accumulator initialised to @start@, run @fun@ on it, and
-- return the final accumulator contents together with @fun@'s result.
withAccum :: STy t -> Rep t -> (Accum s t -> AcM s b) -> AcM s (Rep t, b)
withAccum ty start fun = do
  let !(I# size) = tSize ty start
  accum <- AcM . ST $ \s -> case newByteArray# size s of
             (# s', mbarr #) -> (# s', Accum ty mbarr #)
  AcM $ accumWrite accum start
  b <- fun accum
  out <- accumRead accum
  return (out, b)