summaryrefslogtreecommitdiff
path: root/src/Interpreter
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-09-06 22:22:09 +0200
committerTom Smeding <tom@tomsmeding.com>2024-09-06 22:22:09 +0200
commit8942db88390ece8b1cf018ad723fedc6ae82cf64 (patch)
tree6c4bc400655f090db3c7deb19157eedcdd9cdb0e /src/Interpreter
parent0f94ac819d664b0c1f8feaf567648a3724b5eadb (diff)
WIP interpreter
Diffstat (limited to 'src/Interpreter')
-rw-r--r--src/Interpreter/Accum.hs229
-rw-r--r--src/Interpreter/Array.hs42
2 files changed, 271 insertions, 0 deletions
diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs
new file mode 100644
index 0000000..b0deaef
--- /dev/null
+++ b/src/Interpreter/Accum.hs
@@ -0,0 +1,229 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE UnboxedTuples #-}
+{-# LANGUAGE BangPatterns #-}
+module Interpreter.Accum where
+
+import Control.Monad.ST
+import Control.Monad.ST.Unsafe
+import GHC.Exts
+import GHC.Int
+import GHC.ST (ST(..))
+
+import AST
+import Data
+import Interpreter.Array
+
+
+newtype AcM s a = AcM (ST s a)
+ deriving newtype (Functor, Applicative, Monad)
+
+runAcM :: (forall s. AcM s a) -> a
+runAcM m = runST (case m of AcM m' -> m')
+
+type family Rep t where
+ Rep TNil = ()
+ Rep (TPair a b) = (Rep a, Rep b)
+ Rep (TEither a b) = Either (Rep a) (Rep b)
+ Rep (TArr n t) = Array n (Rep t)
+ Rep (TScal sty) = ScalRep sty
+ -- Rep (TAccum t) = _
+
+data Accum s t = Accum (STy t) (MutableByteArray# s)
+
+tSize :: STy t -> Rep t -> Int
+tSize ty x = tSize' ty (Just x)
+
+-- | Passing Nothing as the value means "this is (inside) an array element".
+tSize' :: STy t -> Maybe (Rep t) -> Int
+tSize' typ val = case typ of
+ STNil -> 0
+ STPair a b -> tSize' a (fst <$> val) + tSize' b (snd <$> val)
+ STEither a b ->
+ case val of
+ Nothing -> 1 + max (tSize' a Nothing) (tSize' b Nothing)
+ Just (Left x) -> 1 + tSize a x -- '1 +' is for runtime sanity checking
+ Just (Right y) -> 1 + tSize b y -- idem
+ STArr ndim t ->
+ case val of
+ Nothing -> error "Nested arrays not supported in this implementation"
+ Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' t Nothing
+ STScal sty -> goScal sty
+ STAccum{} -> error "Nested accumulators unsupported"
+ where
+ goScal :: SScalTy t -> Int
+ goScal STI32 = 4
+ goScal STI64 = 8
+ goScal STF32 = 4
+ goScal STF64 = 8
+ goScal STBool = 1
+
+-- | This operation does not commute with 'accumAdd', so it must be used with
+-- care. Furthermore it must be used on exactly the same value as tSize was
+-- called on. Hence it lives in ST, not in AcM.
+accumWrite :: forall s t. Accum s t -> Rep t -> ST s ()
+accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0#
+ where
+ go :: Bool -> STy t' -> Rep t' -> Int# -> ST s Int
+ go inarr ty val off# = case ty of
+ STNil -> return (I# off#)
+ STPair a b -> do
+ I# off1# <- go inarr a (fst val) off#
+ go inarr b (snd val) off1#
+ STEither a b ->
+ case val of
+ Left x -> do
+ let !(I8# tag#) = 0
+ ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #)
+ go inarr a x (off# +# 1#)
+ Right y -> do
+ let !(I8# tag#) = 1
+ ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #)
+ go inarr b y (off# +# 1#)
+ STArr _ t
+ | inarr -> error "Nested arrays not supported in this implementation"
+ | otherwise -> do
+ I# off1# <- goShape (arrayShape val) off#
+ let !(I# eltsize#) = tSize' t Nothing
+ !(I# n#) = arraySize val
+ traverseArray_ (\(I# lini#) x -> () <$ go True t x (off1# +# eltsize# *# lini#)) val
+ return (I# (off1# +# eltsize# *# n#))
+ STScal sty -> goScal sty val off#
+ STAccum{} -> error "Nested accumulators unsupported"
+
+ goShape :: Shape n -> Int# -> ST s Int
+ goShape ShNil off# = return (I# off#)
+ goShape (ShCons sh n) off# = do
+ I# off1# <- goShape sh off#
+ let !(I64# n') = fromIntegral n
+ ST $ \s -> (# writeInt64Array# mbarr off1# n' s, I# (off1# +# 8#) #)
+
+ goScal :: SScalTy t' -> ScalRep t' -> Int# -> ST s Int
+ goScal STI32 (I32# x) off# = ST $ \s -> (# writeInt32Array# mbarr off# x s, I# (off# +# 4#) #)
+ goScal STI64 (I64# x) off# = ST $ \s -> (# writeInt64Array# mbarr off# x s, I# (off# +# 8#) #)
+ goScal STF32 (F# x) off# = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #)
+ goScal STF64 (D# x) off# = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #)
+ goScal STBool b off# =
+ let !(I8# i) = fromIntegral (fromEnum b)
+ in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #)
+
+accumRead :: forall s t. Accum s t -> AcM s (Rep t)
+accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0#
+ where
+ go :: Bool -> STy t' -> Int# -> ST s (Int, Rep t')
+ go inarr ty off# = case ty of
+ STNil -> return (I# off#, ())
+ STPair a b -> do
+ (I# off1#, x) <- go inarr a off#
+ (I# off2#, y) <- go inarr b off1#
+ return (I# (off1# +# off2#), (x, y))
+ STEither a b -> do
+ tag <- ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', I8# i #)
+ if inarr
+ then do val <- case tag of
+ 0 -> Left . snd <$> go inarr a (off# +# 1#)
+ 1 -> Right . snd <$> go inarr b (off# +# 1#)
+ _ -> error "Invalid tag in accum memory"
+ return (I# off# + max (tSize' a Nothing) (tSize' b Nothing), val)
+ else case tag of
+ 0 -> fmap Left <$> go inarr a (off# +# 1#)
+ 1 -> fmap Right <$> go inarr b (off# +# 1#)
+ _ -> error "Invalid tag in accum memory"
+ STArr ndim t
+ | inarr -> error "Nested arrays not supported in this implementation"
+ | otherwise -> do
+ (I# off1#, sh) <- readShape mbarr ndim off#
+ let !(I# eltsize#) = tSize' t Nothing
+ !(I# n#) = shapeSize sh
+ arr <- arrayGenerateLinM sh (\(I# lini#) -> snd <$> go True t (off1# +# eltsize# *# lini#))
+ return (I# (off1# +# eltsize# *# n#), arr)
+ STScal sty -> goScal sty off#
+ STAccum{} -> error "Nested accumulators unsupported"
+
+ goScal :: SScalTy t' -> Int# -> ST s (Int, ScalRep t')
+ goScal STI32 off# = ST $ \s -> case readInt32Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 4#), I32# i) #)
+ goScal STI64 off# = ST $ \s -> case readInt64Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 8#), I64# i) #)
+ goScal STF32 off# = ST $ \s -> case readFloatArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 4#), F# f) #)
+ goScal STF64 off# = ST $ \s -> case readDoubleArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 8#), D# f) #)
+ goScal STBool off# = ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 1#), toEnum (fromIntegral (I8# i))) #)
+
+readShape :: MutableByteArray# s -> SNat n -> Int# -> ST s (Int, Shape n)
+readShape _ SZ off# = return (I# off#, ShNil)
+readShape mbarr (SS ndim) off# = do
+ (I# off1#, sh) <- readShape mbarr ndim off#
+ n' <- ST $ \s -> case readInt64Array# mbarr off1# s of (# s', i #) -> (# s', I64# i #)
+ return (I# (off1# +# 8#), ShCons sh (fromIntegral n'))
+
+data InvShape full yet where
+ IShFull :: InvShape full full
+ IShCons :: InvShape full (S yet) -> Int -> InvShape full yet
+
+invertShape :: Shape n -> InvShape n Z
+invertShape
+
+
+accumAdd :: forall s t i. Accum s t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s ()
+accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty depth index value 0#
+ where
+ go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int# -> ST s Int
+ go inarr ty SZ idx val off# = performAdd inarr ty val off#
+ go inarr ty (SS dep) idx val off# = case (ty, idx, val) of
+ (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off#
+ (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off#
+ (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off#
+ (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off#
+ (STArr rank eltty, _, _)
+ | inarr -> error "Nested arrays"
+ | otherwise -> goArr (SS dep) rank eltty idx val off#
+ -- (STArr SZ t1, _, _) -> go inarr t1 dep idx val off#
+ -- (STArr (SS rankm1) t1, _, _) -> go inarr t1 dep idx val off#
+ _ -> _
+
+ goArr :: SNat i' -> SNat n -> STy t'
+ -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int# -> ST s Int
+ goArr dep n t1 idx val off# = do
+ (I# off1#, sh) <- readShape mbarr n off#
+ _
+
+ collectArrIndex :: SNat i' -> STy t' -> Shape n
+ -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i')
+ -> (forall i2. Index n -> SNat i2 -> Rep (AcIdx t' i2) -> Rep (AcVal t' i2) -> ST s r)
+ -> (forall i2. Index n -> SNat i2 -> Rep (AcIdx t' i2) -> Rep (AcVal t' i2) -> ST s r)
+ -> ST s r
+ collectArrIndex (SS dep) eltty ShNil idx val k = k IxNil dep idx val
+ collectArrIndex (SS dep) eltty (ShCons sh size) (i, idx) val k =
+ collectArrIndex dep eltty sh idx val $ \index dep' idx' val' ->
+ k (IxCons index (fromIntegral i)) dep' idx' val'
+ -- collectArrIndex SZ eltty
+
+ goShape :: Shape n -> Int# -> ST s Int
+ goShape ShNil off# = return (I# off#)
+ goShape (ShCons sh n) off# = do
+ I# off1# <- goShape sh off#
+ let !(I64# n') = fromIntegral n
+ ST $ \s -> (# writeInt64Array# mbarr off1# n' s, I# (off1# +# 8#) #)
+
+ goScal :: SScalTy t' -> ScalRep t' -> Int# -> ST s Int
+ goScal STI32 (I32# x) off# = ST $ \s -> (# writeInt32Array# mbarr off# x s, I# (off# +# 4#) #)
+ goScal STI64 (I64# x) off# = ST $ \s -> (# writeInt64Array# mbarr off# x s, I# (off# +# 8#) #)
+ goScal STF32 (F# x) off# = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #)
+ goScal STF64 (D# x) off# = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #)
+ goScal STBool b off# =
+ let !(I8# i) = fromIntegral (fromEnum b)
+ in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #)
+
+withAccum :: STy t -> Rep t -> (Accum s t -> AcM s b) -> AcM s (Rep t, b)
+withAccum ty start fun = do
+ let !(I# size) = tSize ty start
+ accum <- AcM . ST $ \s -> case newByteArray# size s of
+ (# s', mbarr #) -> (# s', Accum ty mbarr #)
+ AcM $ accumWrite accum start
+ b <- fun accum
+ out <- accumRead accum
+ return (out, b)
diff --git a/src/Interpreter/Array.hs b/src/Interpreter/Array.hs
new file mode 100644
index 0000000..f358225
--- /dev/null
+++ b/src/Interpreter/Array.hs
@@ -0,0 +1,42 @@
+{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE TupleSections #-}
+module Interpreter.Array where
+
+import Control.Monad.Trans.State.Strict
+import Data.Foldable (traverse_)
+import Data.Vector (Vector)
+import qualified Data.Vector as V
+
+import Data
+
+
+data Shape n where
+ ShNil :: Shape Z
+ ShCons :: Shape n -> Int -> Shape (S n)
+
+data Index n where
+ IxNil :: Index Z
+ IxCons :: Index n -> Int -> Index (S n)
+
+shapeSize :: Shape n -> Int
+shapeSize ShNil = 0
+shapeSize (ShCons sh n) = shapeSize sh * n
+
+
+-- | TODO: this Vector is a boxed vector, which is horrendously inefficient.
+data Array (n :: Nat) t = Array (Shape n) (Vector t)
+
+arrayShape :: Array n t -> Shape n
+arrayShape (Array sh _) = sh
+
+arraySize :: Array n t -> Int
+arraySize (Array sh _) = shapeSize sh
+
+arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t)
+arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f
+
+-- | The Int is the linear index of the value.
+traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m ()
+traverseArray_ f (Array _ v) = evalStateT (traverse_ (\x -> StateT (\i -> (,i+1) <$> f i x)) v) 0