diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-09-06 22:22:09 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-09-06 22:22:09 +0200 | 
| commit | 8942db88390ece8b1cf018ad723fedc6ae82cf64 (patch) | |
| tree | 6c4bc400655f090db3c7deb19157eedcdd9cdb0e | |
| parent | 0f94ac819d664b0c1f8feaf567648a3724b5eadb (diff) | |
WIP interpreter
| -rw-r--r-- | chad-fast.cabal | 4 | ||||
| -rw-r--r-- | src/Interpreter.hs | 3 | ||||
| -rw-r--r-- | src/Interpreter/Accum.hs | 229 | ||||
| -rw-r--r-- | src/Interpreter/Array.hs | 42 | 
4 files changed, 278 insertions, 0 deletions
| diff --git a/chad-fast.cabal b/chad-fast.cabal index 290329b..ef1fd66 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -20,6 +20,9 @@ library      -- Compile      Data      Example +    Interpreter +    Interpreter.Accum +    Interpreter.Array      Language      Language.AST      Lemmas @@ -31,6 +34,7 @@ library      containers,      -- template-haskell,      transformers, +    vector,    hs-source-dirs:      src    default-language: diff --git a/src/Interpreter.hs b/src/Interpreter.hs new file mode 100644 index 0000000..d1dfb1d --- /dev/null +++ b/src/Interpreter.hs @@ -0,0 +1,3 @@ +module Interpreter where + + diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs new file mode 100644 index 0000000..b0deaef --- /dev/null +++ b/src/Interpreter/Accum.hs @@ -0,0 +1,229 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UnboxedTuples #-} +{-# LANGUAGE BangPatterns #-} +module Interpreter.Accum where + +import Control.Monad.ST +import Control.Monad.ST.Unsafe +import GHC.Exts +import GHC.Int +import GHC.ST (ST(..)) + +import AST +import Data +import Interpreter.Array + + +newtype AcM s a = AcM (ST s a) +  deriving newtype (Functor, Applicative, Monad) + +runAcM :: (forall s. AcM s a) -> a +runAcM m = runST (case m of AcM m' -> m') + +type family Rep t where +  Rep TNil = () +  Rep (TPair a b) = (Rep a, Rep b) +  Rep (TEither a b) = Either (Rep a) (Rep b) +  Rep (TArr n t) = Array n (Rep t) +  Rep (TScal sty) = ScalRep sty +  -- Rep (TAccum t) = _ + +data Accum s t = Accum (STy t) (MutableByteArray# s) + +tSize :: STy t -> Rep t -> Int +tSize ty x = tSize' ty (Just x) + +-- | Passing Nothing as the value means "this is (inside) an array element". +tSize' :: STy t -> Maybe (Rep t) -> Int +tSize' typ val = case typ of +  STNil -> 0 +  STPair a b -> tSize' a (fst <$> val) + tSize' b (snd <$> val) +  STEither a b -> +    case val of +      Nothing -> 1 + max (tSize' a Nothing) (tSize' b Nothing) +      Just (Left x) -> 1 + tSize a x   -- '1 +' is for runtime sanity checking +      Just (Right y) -> 1 + tSize b y  -- idem +  STArr ndim t -> +    case val of +      Nothing -> error "Nested arrays not supported in this implementation" +      Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' t Nothing +  STScal sty -> goScal sty +  STAccum{} -> error "Nested accumulators unsupported" +  where +    goScal :: SScalTy t -> Int +    goScal STI32 = 4 +    goScal STI64 = 8 +    goScal STF32 = 4 +    goScal STF64 = 8 +    goScal STBool = 1 + +-- | This operation does not commute with 'accumAdd', so it must be used with +-- care. Furthermore it must be used on exactly the same value as tSize was +-- called on. Hence it lives in ST, not in AcM. +accumWrite :: forall s t. Accum s t -> Rep t -> ST s () +accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0# +  where +    go :: Bool -> STy t' -> Rep t' -> Int# -> ST s Int +    go inarr ty val off# = case ty of +      STNil -> return (I# off#) +      STPair a b -> do +        I# off1# <- go inarr a (fst val) off# +        go inarr b (snd val) off1# +      STEither a b -> +        case val of +          Left x -> do +            let !(I8# tag#) = 0 +            ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #) +            go inarr a x (off# +# 1#) +          Right y -> do +            let !(I8# tag#) = 1 +            ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #) +            go inarr b y (off# +# 1#) +      STArr _ t +        | inarr -> error "Nested arrays not supported in this implementation" +        | otherwise -> do +            I# off1# <- goShape (arrayShape val) off# +            let !(I# eltsize#) = tSize' t Nothing +                !(I# n#) = arraySize val +            traverseArray_ (\(I# lini#) x -> () <$ go True t x (off1# +# eltsize# *# lini#)) val +            return (I# (off1# +# eltsize# *# n#)) +      STScal sty -> goScal sty val off# +      STAccum{} -> error "Nested accumulators unsupported" + +    goShape :: Shape n -> Int# -> ST s Int +    goShape ShNil off# = return (I# off#) +    goShape (ShCons sh n) off# = do +      I# off1# <- goShape sh off# +      let !(I64# n') = fromIntegral n +      ST $ \s -> (# writeInt64Array# mbarr off1# n' s, I# (off1# +# 8#) #) + +    goScal :: SScalTy t' -> ScalRep t' -> Int# -> ST s Int +    goScal STI32 (I32# x) off# = ST $ \s -> (# writeInt32Array# mbarr off# x s, I# (off# +# 4#) #) +    goScal STI64 (I64# x) off# = ST $ \s -> (# writeInt64Array# mbarr off# x s, I# (off# +# 8#) #) +    goScal STF32 (F# x) off# = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #) +    goScal STF64 (D# x) off# = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #) +    goScal STBool b off# = +      let !(I8# i) = fromIntegral (fromEnum b) +      in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #) + +accumRead :: forall s t. Accum s t -> AcM s (Rep t) +accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0# +  where +    go :: Bool -> STy t' -> Int# -> ST s (Int, Rep t') +    go inarr ty off# = case ty of +      STNil -> return (I# off#, ()) +      STPair a b -> do +        (I# off1#, x) <- go inarr a off# +        (I# off2#, y) <- go inarr b off1# +        return (I# (off1# +# off2#), (x, y)) +      STEither a b -> do +        tag <- ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', I8# i #) +        if inarr +          then do val <- case tag of +                           0 -> Left . snd <$> go inarr a (off# +# 1#) +                           1 -> Right . snd <$> go inarr b (off# +# 1#) +                           _ -> error "Invalid tag in accum memory" +                  return (I# off# + max (tSize' a Nothing) (tSize' b Nothing), val) +          else case tag of +                 0 -> fmap Left <$> go inarr a (off# +# 1#) +                 1 -> fmap Right <$> go inarr b (off# +# 1#) +                 _ -> error "Invalid tag in accum memory" +      STArr ndim t +        | inarr -> error "Nested arrays not supported in this implementation" +        | otherwise -> do +            (I# off1#, sh) <- readShape mbarr ndim off# +            let !(I# eltsize#) = tSize' t Nothing +                !(I# n#) = shapeSize sh +            arr <- arrayGenerateLinM sh (\(I# lini#) -> snd <$> go True t (off1# +# eltsize# *# lini#)) +            return (I# (off1# +# eltsize# *# n#), arr) +      STScal sty -> goScal sty off# +      STAccum{} -> error "Nested accumulators unsupported" + +    goScal :: SScalTy t' -> Int# -> ST s (Int, ScalRep t') +    goScal STI32 off# = ST $ \s -> case readInt32Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 4#), I32# i) #) +    goScal STI64 off# = ST $ \s -> case readInt64Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 8#), I64# i) #) +    goScal STF32 off# = ST $ \s -> case readFloatArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 4#), F# f) #) +    goScal STF64 off# = ST $ \s -> case readDoubleArray# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 8#), D# f) #) +    goScal STBool off# = ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 1#), toEnum (fromIntegral (I8# i))) #) + +readShape :: MutableByteArray# s -> SNat n -> Int# -> ST s (Int, Shape n) +readShape _ SZ off# = return (I# off#, ShNil) +readShape mbarr (SS ndim) off# = do +  (I# off1#, sh) <- readShape mbarr ndim off# +  n' <- ST $ \s -> case readInt64Array# mbarr off1# s of (# s', i #) -> (# s', I64# i #) +  return (I# (off1# +# 8#), ShCons sh (fromIntegral n')) + +data InvShape full yet where +  IShFull :: InvShape full full +  IShCons :: InvShape full (S yet) -> Int -> InvShape full yet + +invertShape :: Shape n -> InvShape n Z +invertShape + + +accumAdd :: forall s t i. Accum s t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s () +accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty depth index value 0# +  where +    go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int# -> ST s Int +    go inarr ty SZ idx val off# = performAdd inarr ty val off# +    go inarr ty (SS dep) idx val off# = case (ty, idx, val) of +      (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off# +      (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off# +      (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off# +      (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off# +      (STArr rank eltty, _, _) +        | inarr -> error "Nested arrays" +        | otherwise -> goArr (SS dep) rank eltty idx val off# +      -- (STArr SZ t1, _, _) -> go inarr t1 dep idx val off# +      -- (STArr (SS rankm1) t1, _, _) -> go inarr t1 dep idx val off# +      _ -> _ + +    goArr :: SNat i' -> SNat n -> STy t' +          -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int# -> ST s Int +    goArr dep n t1 idx val off# = do +      (I# off1#, sh) <- readShape mbarr n off# +      _ + +    collectArrIndex :: SNat i' -> STy t' -> Shape n +                    -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') +                    -> (forall i2. Index n -> SNat i2 -> Rep (AcIdx t' i2) -> Rep (AcVal t' i2) -> ST s r) +                    -> (forall i2. Index n -> SNat i2 -> Rep (AcIdx t' i2) -> Rep (AcVal t' i2) -> ST s r) +                    -> ST s r +    collectArrIndex (SS dep) eltty ShNil idx val k = k IxNil dep idx val +    collectArrIndex (SS dep) eltty (ShCons sh size) (i, idx) val k = +      collectArrIndex dep eltty sh idx val $ \index dep' idx' val' -> +        k (IxCons index (fromIntegral i)) dep' idx' val' +    -- collectArrIndex SZ eltty + +    goShape :: Shape n -> Int# -> ST s Int +    goShape ShNil off# = return (I# off#) +    goShape (ShCons sh n) off# = do +      I# off1# <- goShape sh off# +      let !(I64# n') = fromIntegral n +      ST $ \s -> (# writeInt64Array# mbarr off1# n' s, I# (off1# +# 8#) #) + +    goScal :: SScalTy t' -> ScalRep t' -> Int# -> ST s Int +    goScal STI32 (I32# x) off# = ST $ \s -> (# writeInt32Array# mbarr off# x s, I# (off# +# 4#) #) +    goScal STI64 (I64# x) off# = ST $ \s -> (# writeInt64Array# mbarr off# x s, I# (off# +# 8#) #) +    goScal STF32 (F# x) off# = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #) +    goScal STF64 (D# x) off# = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #) +    goScal STBool b off# = +      let !(I8# i) = fromIntegral (fromEnum b) +      in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #) + +withAccum :: STy t -> Rep t -> (Accum s t -> AcM s b) -> AcM s (Rep t, b) +withAccum ty start fun = do +  let !(I# size) = tSize ty start +  accum <- AcM . ST $ \s -> case newByteArray# size s of +                              (# s', mbarr #) -> (# s', Accum ty mbarr #) +  AcM $ accumWrite accum start +  b <- fun accum +  out <- accumRead accum +  return (out, b) diff --git a/src/Interpreter/Array.hs b/src/Interpreter/Array.hs new file mode 100644 index 0000000..f358225 --- /dev/null +++ b/src/Interpreter/Array.hs @@ -0,0 +1,42 @@ +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE TupleSections #-} +module Interpreter.Array where + +import Control.Monad.Trans.State.Strict +import Data.Foldable (traverse_) +import Data.Vector (Vector) +import qualified Data.Vector as V + +import Data + + +data Shape n where +  ShNil :: Shape Z +  ShCons :: Shape n -> Int -> Shape (S n) + +data Index n where +  IxNil :: Index Z +  IxCons :: Index n -> Int -> Index (S n) + +shapeSize :: Shape n -> Int +shapeSize ShNil = 0 +shapeSize (ShCons sh n) = shapeSize sh * n + + +-- | TODO: this Vector is a boxed vector, which is horrendously inefficient. +data Array (n :: Nat) t = Array (Shape n) (Vector t) + +arrayShape :: Array n t -> Shape n +arrayShape (Array sh _) = sh + +arraySize :: Array n t -> Int +arraySize (Array sh _) = shapeSize sh + +arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t) +arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f + +-- | The Int is the linear index of the value. +traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m () +traverseArray_ f (Array _ v) = evalStateT (traverse_ (\x -> StateT (\i -> (,i+1) <$> f i x)) v) 0 | 
