diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-09-06 22:22:09 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-09-06 22:22:09 +0200 |
commit | 8942db88390ece8b1cf018ad723fedc6ae82cf64 (patch) | |
tree | 6c4bc400655f090db3c7deb19157eedcdd9cdb0e /src/Interpreter | |
parent | 0f94ac819d664b0c1f8feaf567648a3724b5eadb (diff) |
WIP interpreter
Diffstat (limited to 'src/Interpreter')
-rw-r--r-- | src/Interpreter/Accum.hs | 229 | ||||
-rw-r--r-- | src/Interpreter/Array.hs | 42 |
2 files changed, 271 insertions, 0 deletions
diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs new file mode 100644 index 0000000..b0deaef --- /dev/null +++ b/src/Interpreter/Accum.hs @@ -0,0 +1,229 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UnboxedTuples #-} +{-# LANGUAGE BangPatterns #-} +module Interpreter.Accum where + +import Control.Monad.ST +import Control.Monad.ST.Unsafe +import GHC.Exts +import GHC.Int +import GHC.ST (ST(..)) + +import AST +import Data +import Interpreter.Array + + +newtype AcM s a = AcM (ST s a) + deriving newtype (Functor, Applicative, Monad) + +runAcM :: (forall s. AcM s a) -> a +runAcM m = runST (case m of AcM m' -> m') + +type family Rep t where + Rep TNil = () + Rep (TPair a b) = (Rep a, Rep b) + Rep (TEither a b) = Either (Rep a) (Rep b) + Rep (TArr n t) = Array n (Rep t) + Rep (TScal sty) = ScalRep sty + -- Rep (TAccum t) = _ + +data Accum s t = Accum (STy t) (MutableByteArray# s) + +tSize :: STy t -> Rep t -> Int +tSize ty x = tSize' ty (Just x) + +-- | Passing Nothing as the value means "this is (inside) an array element". +tSize' :: STy t -> Maybe (Rep t) -> Int +tSize' typ val = case typ of + STNil -> 0 + STPair a b -> tSize' a (fst <$> val) + tSize' b (snd <$> val) + STEither a b -> + case val of + Nothing -> 1 + max (tSize' a Nothing) (tSize' b Nothing) + Just (Left x) -> 1 + tSize a x -- '1 +' is for runtime sanity checking + Just (Right y) -> 1 + tSize b y -- idem + STArr ndim t -> + case val of + Nothing -> error "Nested arrays not supported in this implementation" + Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' t Nothing + STScal sty -> goScal sty + STAccum{} -> error "Nested accumulators unsupported" + where + goScal :: SScalTy t -> Int + goScal STI32 = 4 + goScal STI64 = 8 + goScal STF32 = 4 + goScal STF64 = 8 + goScal STBool = 1 + +-- | This operation does not commute with 'accumAdd', so it must be used with +-- care. Furthermore it must be used on exactly the same value as tSize was +-- called on. Hence it lives in ST, not in AcM. +accumWrite :: forall s t. Accum s t -> Rep t -> ST s () +accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0# + where + go :: Bool -> STy t' -> Rep t' -> Int# -> ST s Int + go inarr ty val off# = case ty of + STNil -> return (I# off#) + STPair a b -> do + I# off1# <- go inarr a (fst val) off# + go inarr b (snd val) off1# + STEither a b -> + case val of + Left x -> do + let !(I8# tag#) = 0 + ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #) + go inarr a x (off# +# 1#) + Right y -> do + let !(I8# tag#) = 1 + ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #) + go inarr b y (off# +# 1#) + STArr _ t + | inarr -> error "Nested arrays not supported in this implementation" + | otherwise -> do + I# off1# <- goShape (arrayShape val) off# + let !(I# eltsize#) = tSize' t Nothing + !(I# n#) = arraySize val + traverseArray_ (\(I# lini#) x -> () <$ go True t x (off1# +# eltsize# *# lini#)) val + return (I# (off1# +# eltsize# *# n#)) + STScal sty -> goScal sty val off# + STAccum{} -> error "Nested accumulators unsupported" + + goShape :: Shape n -> Int# -> ST s Int + goShape ShNil off# = return (I# off#) + goShape (ShCons sh n) off# = do + I# off1# <- goShape sh off# + let !(I64# n') = fromIntegral n + ST $ \s -> (# writeInt64Array# mbarr off1# n' s, I# (off1# +# 8#) #) + + goScal :: SScalTy t' -> ScalRep t' -> Int# -> ST s Int + goScal STI32 (I32# x) off# = ST $ \s -> (# writeInt32Array# mbarr off# x s, I# (off# +# 4#) #) + goScal STI64 (I64# x) off# = ST $ \s -> (# writeInt64Array# mbarr off# x s, I# (off# +# 8#) #) + goScal STF32 (F# x) off# = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #) + goScal STF64 (D# x) off# = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #) + goScal STBool b off# = + let !(I8# i) = fromIntegral (fromEnum b) + in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #) + +accumRead :: forall s t. Accum s t -> AcM s (Rep t) +accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0# + where + go :: Bool -> STy t' -> Int# -> ST s (Int, Rep t') + go inarr ty off# = case ty of + STNil -> return (I# off#, ()) + STPair a b -> do + (I# off1#, x) <- go inarr a off# + (I# off2#, y) <- go inarr b off1# + return (I# (off1# +# off2#), (x, y)) + STEither a b -> do + tag <- ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', I8# i #) + if inarr + then do val <- case tag of + 0 -> Left . snd <$> go inarr a (off# +# 1#) + 1 -> Right . snd <$> go inarr b (off# +# 1#) + _ -> error "Invalid tag in accum memory" + return (I# off# + max (tSize' a Nothing) (tSize' b Nothing), val) + else case tag of + 0 -> fmap Left <$> go inarr a (off# +# 1#) + 1 -> fmap Right <$> go inarr b (off# +# 1#) + _ -> error "Invalid tag in accum memory" + STArr ndim t + | inarr -> error "Nested arrays not supported in this implementation" + | otherwise -> do + (I# off1#, sh) <- readShape mbarr ndim off# + let !(I# eltsize#) = tSize' t Nothing + !(I# n#) = shapeSize sh + arr <- arrayGenerateLinM sh (\(I# lini#) -> snd <$> go True t (off1# +# eltsize# *# lini#)) + return (I# (off1# +# eltsize# *# n#), arr) + STScal sty -> goScal sty off# + STAccum{} -> error "Nested accumulators unsupported" + + goScal :: SScalTy t' -> Int# -> ST s (Int, ScalRep t') + goScal STI32 off# = ST $ \s -> case readInt32Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 4#), I32# i) #) + goScal STI64 off# = ST $ \s -> case readInt64Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 8#), I64# i) #) + goScal STF32 off# = ST $ \s -> case readFloatArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 4#), F# f) #) + goScal STF64 off# = ST $ \s -> case readDoubleArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 8#), D# f) #) + goScal STBool off# = ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 1#), toEnum (fromIntegral (I8# i))) #) + +readShape :: MutableByteArray# s -> SNat n -> Int# -> ST s (Int, Shape n) +readShape _ SZ off# = return (I# off#, ShNil) +readShape mbarr (SS ndim) off# = do + (I# off1#, sh) <- readShape mbarr ndim off# + n' <- ST $ \s -> case readInt64Array# mbarr off1# s of (# s', i #) -> (# s', I64# i #) + return (I# (off1# +# 8#), ShCons sh (fromIntegral n')) + +data InvShape full yet where + IShFull :: InvShape full full + IShCons :: InvShape full (S yet) -> Int -> InvShape full yet + +invertShape :: Shape n -> InvShape n Z +invertShape + + +accumAdd :: forall s t i. Accum s t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s () +accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty depth index value 0# + where + go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int# -> ST s Int + go inarr ty SZ idx val off# = performAdd inarr ty val off# + go inarr ty (SS dep) idx val off# = case (ty, idx, val) of + (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off# + (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off# + (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off# + (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off# + (STArr rank eltty, _, _) + | inarr -> error "Nested arrays" + | otherwise -> goArr (SS dep) rank eltty idx val off# + -- (STArr SZ t1, _, _) -> go inarr t1 dep idx val off# + -- (STArr (SS rankm1) t1, _, _) -> go inarr t1 dep idx val off# + _ -> _ + + goArr :: SNat i' -> SNat n -> STy t' + -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int# -> ST s Int + goArr dep n t1 idx val off# = do + (I# off1#, sh) <- readShape mbarr n off# + _ + + collectArrIndex :: SNat i' -> STy t' -> Shape n + -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') + -> (forall i2. Index n -> SNat i2 -> Rep (AcIdx t' i2) -> Rep (AcVal t' i2) -> ST s r) + -> (forall i2. Index n -> SNat i2 -> Rep (AcIdx t' i2) -> Rep (AcVal t' i2) -> ST s r) + -> ST s r + collectArrIndex (SS dep) eltty ShNil idx val k = k IxNil dep idx val + collectArrIndex (SS dep) eltty (ShCons sh size) (i, idx) val k = + collectArrIndex dep eltty sh idx val $ \index dep' idx' val' -> + k (IxCons index (fromIntegral i)) dep' idx' val' + -- collectArrIndex SZ eltty + + goShape :: Shape n -> Int# -> ST s Int + goShape ShNil off# = return (I# off#) + goShape (ShCons sh n) off# = do + I# off1# <- goShape sh off# + let !(I64# n') = fromIntegral n + ST $ \s -> (# writeInt64Array# mbarr off1# n' s, I# (off1# +# 8#) #) + + goScal :: SScalTy t' -> ScalRep t' -> Int# -> ST s Int + goScal STI32 (I32# x) off# = ST $ \s -> (# writeInt32Array# mbarr off# x s, I# (off# +# 4#) #) + goScal STI64 (I64# x) off# = ST $ \s -> (# writeInt64Array# mbarr off# x s, I# (off# +# 8#) #) + goScal STF32 (F# x) off# = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #) + goScal STF64 (D# x) off# = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #) + goScal STBool b off# = + let !(I8# i) = fromIntegral (fromEnum b) + in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #) + +withAccum :: STy t -> Rep t -> (Accum s t -> AcM s b) -> AcM s (Rep t, b) +withAccum ty start fun = do + let !(I# size) = tSize ty start + accum <- AcM . ST $ \s -> case newByteArray# size s of + (# s', mbarr #) -> (# s', Accum ty mbarr #) + AcM $ accumWrite accum start + b <- fun accum + out <- accumRead accum + return (out, b) diff --git a/src/Interpreter/Array.hs b/src/Interpreter/Array.hs new file mode 100644 index 0000000..f358225 --- /dev/null +++ b/src/Interpreter/Array.hs @@ -0,0 +1,42 @@ +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE TupleSections #-} +module Interpreter.Array where + +import Control.Monad.Trans.State.Strict +import Data.Foldable (traverse_) +import Data.Vector (Vector) +import qualified Data.Vector as V + +import Data + + +data Shape n where + ShNil :: Shape Z + ShCons :: Shape n -> Int -> Shape (S n) + +data Index n where + IxNil :: Index Z + IxCons :: Index n -> Int -> Index (S n) + +shapeSize :: Shape n -> Int +shapeSize ShNil = 0 +shapeSize (ShCons sh n) = shapeSize sh * n + + +-- | TODO: this Vector is a boxed vector, which is horrendously inefficient. +data Array (n :: Nat) t = Array (Shape n) (Vector t) + +arrayShape :: Array n t -> Shape n +arrayShape (Array sh _) = sh + +arraySize :: Array n t -> Int +arraySize (Array sh _) = shapeSize sh + +arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t) +arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f + +-- | The Int is the linear index of the value. +traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m () +traverseArray_ f (Array _ v) = evalStateT (traverse_ (\x -> StateT (\i -> (,i+1) <$> f i x)) v) 0 |