summaryrefslogtreecommitdiff
path: root/src/Interpreter
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-09-12 17:31:20 +0200
committerTom Smeding <tom@tomsmeding.com>2024-09-12 17:31:42 +0200
commit36732f84cfade5371248806328791d5066673fb7 (patch)
tree68cf208fca197a48e6b0506e783c1bdaf98d2e42 /src/Interpreter
parent1f53cea6a1352db125e1897ca574360180be2550 (diff)
Interpreter, some operations
Diffstat (limited to 'src/Interpreter')
-rw-r--r--src/Interpreter/Accum.hs76
-rw-r--r--src/Interpreter/Array.hs45
-rw-r--r--src/Interpreter/Rep.hs17
3 files changed, 56 insertions, 82 deletions
diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs
index d15ea10..b6a91df 100644
--- a/src/Interpreter/Accum.hs
+++ b/src/Interpreter/Accum.hs
@@ -15,7 +15,7 @@
module Interpreter.Accum (
AcM,
runAcM,
- Rep,
+ Rep',
Accum,
withAccum,
accumAdd,
@@ -25,6 +25,8 @@ module Interpreter.Accum (
import Control.Concurrent
import Control.Monad (when, forM_)
import Data.Bifunctor (second)
+import Data.Proxy
+import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr)
import Foreign.Storable (sizeOf)
import GHC.Exts
import GHC.Float
@@ -33,10 +35,9 @@ import GHC.IO (IO(..))
import GHC.Word
import System.IO.Unsafe (unsafePerformIO)
+import Array
import AST
import Data
-import Interpreter.Array
-import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr)
newtype AcM s a = AcM (IO a)
@@ -45,34 +46,35 @@ newtype AcM s a = AcM (IO a)
runAcM :: (forall s. AcM s a) -> a
runAcM (AcM m) = unsafePerformIO m
-type family Rep t where
- Rep TNil = ()
- Rep (TPair a b) = (Rep a, Rep b)
- Rep (TEither a b) = Either (Rep a) (Rep b)
- Rep (TArr n t) = Array n (Rep t)
- Rep (TScal sty) = ScalRep sty
- -- Rep (TAccum t) = _
+-- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined.
+type family Rep' s t where
+ Rep' s TNil = ()
+ Rep' s (TPair a b) = (Rep' s a, Rep' s b)
+ Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b)
+ Rep' s (TArr n t) = Array n (Rep' s t)
+ Rep' s (TScal sty) = ScalRep sty
+ Rep' s (TAccum t) = Accum s t
-- | Floats and integers are accumulated; booleans are left as-is.
data Accum s t = Accum (STy t) (ForeignPtr ())
-tSize :: STy t -> Rep t -> Int
-tSize ty x = tSize' ty (Just x)
+tSize :: Proxy s -> STy t -> Rep' s t -> Int
+tSize p ty x = tSize' p ty (Just x)
-- | Passing Nothing as the value means "this is (inside) an array element".
-tSize' :: STy t -> Maybe (Rep t) -> Int
-tSize' typ val = case typ of
+tSize' :: Proxy s -> STy t -> Maybe (Rep' s t) -> Int
+tSize' p typ val = case typ of
STNil -> 0
- STPair a b -> tSize' a (fst <$> val) + tSize' b (snd <$> val)
+ STPair a b -> tSize' p a (fst <$> val) + tSize' p b (snd <$> val)
STEither a b ->
case val of
- Nothing -> 1 + max (tSize' a Nothing) (tSize' b Nothing)
- Just (Left x) -> 1 + tSize a x -- '1 +' is for runtime sanity checking
- Just (Right y) -> 1 + tSize b y -- idem
+ Nothing -> 1 + max (tSize' p a Nothing) (tSize' p b Nothing)
+ Just (Left x) -> 1 + tSize' p a (Just x) -- '1 +' is for runtime sanity checking
+ Just (Right y) -> 1 + tSize' p b (Just y) -- idem
STArr ndim t ->
case val of
Nothing -> error "Nested arrays not supported in this implementation"
- Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' t Nothing
+ Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing
STScal sty -> goScal sty
STAccum{} -> error "Nested accumulators unsupported"
where
@@ -86,10 +88,10 @@ tSize' typ val = case typ of
-- | This operation does not commute with 'accumAdd', so it must be used with
-- care. Furthermore it must be used on exactly the same value as tSize was
-- called on. Hence it lives in IO, not in AcM.
-accumWrite :: forall s t. Accum s t -> Rep t -> IO ()
+accumWrite :: forall s t. Accum s t -> Rep' s t -> IO ()
accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
let
- go :: Bool -> STy t' -> Rep t' -> Int -> IO Int
+ go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
go inarr ty val off = case ty of
STNil -> return off
STPair a b -> do
@@ -110,7 +112,7 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
| inarr -> error "Nested arrays not supported in this implementation"
| otherwise -> do
off1 <- goShape (arrayShape val) off
- let eltsize = tSize' t Nothing
+ let eltsize = tSize' (Proxy @s) t Nothing
n = arraySize val
traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val
return (off1 + eltsize * n)
@@ -136,10 +138,10 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
in () <$ go False topty top_value 0
-accumRead :: forall s t. Accum s t -> AcM s (Rep t)
+accumRead :: forall s t. Accum s t -> AcM s (Rep' s t)
accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
let
- go :: Bool -> STy t' -> Int -> IO (Int, Rep t')
+ go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t')
go inarr ty off = case ty of
STNil -> return (off, ())
STPair a b -> do
@@ -154,13 +156,13 @@ accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
1 -> fmap Right <$> go inarr b (off + 1)
_ -> error "Invalid tag in accum memory"
if inarr
- then return (off + 1 + max (tSize' a Nothing) (tSize' b Nothing), val)
+ then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val)
else return (off1, val)
STArr ndim t
| inarr -> error "Nested arrays not supported in this implementation"
| otherwise -> do
(off1, sh) <- readShape addr# ndim off
- let eltsize = tSize' t Nothing
+ let eltsize = tSize' (Proxy @s) t Nothing
n = shapeSize sh
arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini))
return (off1 + eltsize * n, arr)
@@ -205,10 +207,10 @@ invertShape | Refl <- lemPlusZero @n = flip go IShNil
go ShNil ish = ish
go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish)
-accumAdd :: forall s t i. Accum s t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s ()
+accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s ()
accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
let
- go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int -> IO ()
+ go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO ()
go inarr ty SZ () val off = () <$ performAdd inarr ty val off
go inarr ty (SS dep) idx val off = case (ty, idx, val) of
(STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
@@ -227,23 +229,23 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr
(STAccum{}, _, _) -> error "Nested accumulators unsupported"
goArr :: SNat i' -> InvShape n -> STy t'
- -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int -> IO ()
+ -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO ()
goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off
goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off
goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do
- let i' = fromIntegral @(Rep TIx) @Int i
+ let i' = fromIntegral @(Rep' s TIx) @Int i
when (i' < 0 || i' >= n) $
error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")"
goArr depm1 ish t1 idx val (off + i' * ishSize ish)
- performAddArr :: Int -> STy t' -> Array n (Rep t') -> Int -> IO Int
+ performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int
performAddArr arraySz eltty val off = do
- let eltsize = tSize' eltty Nothing
+ let eltsize = tSize' (Proxy @s) eltty Nothing
forM_ [0 .. arraySz - 1] $ \lini ->
performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize)
return (off + arraySz * eltsize)
- performAdd :: Bool -> STy t' -> Rep t' -> Int -> IO Int
+ performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
performAdd inarr ty val off = case ty of
STNil -> return off
STPair t1 t2 -> do
@@ -257,7 +259,7 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr
(Right val2, 1) -> performAdd inarr t2 val2 (off + 1)
_ -> error "accumAdd: Tag mismatch for Either"
if inarr
- then return (off + 1 + max (tSize' t1 Nothing) (tSize' t2 Nothing))
+ then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing))
else return off1
STArr n ty'
| inarr -> error "Nested array"
@@ -300,18 +302,18 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr
in () <$ go False topty top_depth top_index top_value 0
-withAccum :: STy t -> Rep t -> (Accum s t -> AcM s b) -> AcM s (Rep t, b)
+withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t)
withAccum ty start fun = do
-- The initial write must happen before any of the adds or reads, so it makes
-- sense to put it in IO together with the allocation, instead of in AcM.
- accum <- AcM $ do buffer <- mallocBytes (tSize ty start)
+ accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start)
ptr <- newForeignPtr finalizerFree buffer
let accum = Accum ty ptr
accumWrite accum start
return accum
b <- fun accum
out <- accumRead accum
- return (out, b)
+ return (b, out)
inParallel :: [AcM s t] -> AcM s [t]
inParallel actions = AcM $ do
diff --git a/src/Interpreter/Array.hs b/src/Interpreter/Array.hs
deleted file mode 100644
index 54e0791..0000000
--- a/src/Interpreter/Array.hs
+++ /dev/null
@@ -1,45 +0,0 @@
-{-# LANGUAGE KindSignatures #-}
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE TupleSections #-}
-module Interpreter.Array where
-
-import Control.Monad.Trans.State.Strict
-import Data.Foldable (traverse_)
-import Data.Vector (Vector)
-import qualified Data.Vector as V
-
-import Data
-
-
-data Shape n where
- ShNil :: Shape Z
- ShCons :: Shape n -> Int -> Shape (S n)
-
-data Index n where
- IxNil :: Index Z
- IxCons :: Index n -> Int -> Index (S n)
-
-shapeSize :: Shape n -> Int
-shapeSize ShNil = 0
-shapeSize (ShCons sh n) = shapeSize sh * n
-
-
--- | TODO: this Vector is a boxed vector, which is horrendously inefficient.
-data Array (n :: Nat) t = Array (Shape n) (Vector t)
-
-arrayShape :: Array n t -> Shape n
-arrayShape (Array sh _) = sh
-
-arraySize :: Array n t -> Int
-arraySize (Array sh _) = shapeSize sh
-
-arrayIndexLinear :: Array n t -> Int -> t
-arrayIndexLinear (Array _ v) i = v V.! i
-
-arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t)
-arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f
-
--- | The Int is the linear index of the value.
-traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m ()
-traverseArray_ f (Array _ v) = evalStateT (traverse_ (\x -> StateT (\i -> (,i+1) <$> f i x)) v) 0
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
new file mode 100644
index 0000000..1ded773
--- /dev/null
+++ b/src/Interpreter/Rep.hs
@@ -0,0 +1,17 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE TypeFamilies #-}
+module Interpreter.Rep where
+
+import GHC.TypeError
+
+import Array
+import AST
+
+
+type family Rep t where
+ Rep TNil = ()
+ Rep (TPair a b) = (Rep a, Rep b)
+ Rep (TEither a b) = Either (Rep a) (Rep b)
+ Rep (TArr n t) = Array n (Rep t)
+ Rep (TScal sty) = ScalRep sty
+ Rep (TAccum t) = TypeError (Text "Accumulator in Rep")