diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-09-12 17:31:20 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-09-12 17:31:42 +0200 |
commit | 36732f84cfade5371248806328791d5066673fb7 (patch) | |
tree | 68cf208fca197a48e6b0506e783c1bdaf98d2e42 /src/Interpreter | |
parent | 1f53cea6a1352db125e1897ca574360180be2550 (diff) |
Interpreter, some operations
Diffstat (limited to 'src/Interpreter')
-rw-r--r-- | src/Interpreter/Accum.hs | 76 | ||||
-rw-r--r-- | src/Interpreter/Array.hs | 45 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 17 |
3 files changed, 56 insertions, 82 deletions
diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs index d15ea10..b6a91df 100644 --- a/src/Interpreter/Accum.hs +++ b/src/Interpreter/Accum.hs @@ -15,7 +15,7 @@ module Interpreter.Accum ( AcM, runAcM, - Rep, + Rep', Accum, withAccum, accumAdd, @@ -25,6 +25,8 @@ module Interpreter.Accum ( import Control.Concurrent import Control.Monad (when, forM_) import Data.Bifunctor (second) +import Data.Proxy +import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr) import Foreign.Storable (sizeOf) import GHC.Exts import GHC.Float @@ -33,10 +35,9 @@ import GHC.IO (IO(..)) import GHC.Word import System.IO.Unsafe (unsafePerformIO) +import Array import AST import Data -import Interpreter.Array -import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr) newtype AcM s a = AcM (IO a) @@ -45,34 +46,35 @@ newtype AcM s a = AcM (IO a) runAcM :: (forall s. AcM s a) -> a runAcM (AcM m) = unsafePerformIO m -type family Rep t where - Rep TNil = () - Rep (TPair a b) = (Rep a, Rep b) - Rep (TEither a b) = Either (Rep a) (Rep b) - Rep (TArr n t) = Array n (Rep t) - Rep (TScal sty) = ScalRep sty - -- Rep (TAccum t) = _ +-- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined. +type family Rep' s t where + Rep' s TNil = () + Rep' s (TPair a b) = (Rep' s a, Rep' s b) + Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b) + Rep' s (TArr n t) = Array n (Rep' s t) + Rep' s (TScal sty) = ScalRep sty + Rep' s (TAccum t) = Accum s t -- | Floats and integers are accumulated; booleans are left as-is. data Accum s t = Accum (STy t) (ForeignPtr ()) -tSize :: STy t -> Rep t -> Int -tSize ty x = tSize' ty (Just x) +tSize :: Proxy s -> STy t -> Rep' s t -> Int +tSize p ty x = tSize' p ty (Just x) -- | Passing Nothing as the value means "this is (inside) an array element". -tSize' :: STy t -> Maybe (Rep t) -> Int -tSize' typ val = case typ of +tSize' :: Proxy s -> STy t -> Maybe (Rep' s t) -> Int +tSize' p typ val = case typ of STNil -> 0 - STPair a b -> tSize' a (fst <$> val) + tSize' b (snd <$> val) + STPair a b -> tSize' p a (fst <$> val) + tSize' p b (snd <$> val) STEither a b -> case val of - Nothing -> 1 + max (tSize' a Nothing) (tSize' b Nothing) - Just (Left x) -> 1 + tSize a x -- '1 +' is for runtime sanity checking - Just (Right y) -> 1 + tSize b y -- idem + Nothing -> 1 + max (tSize' p a Nothing) (tSize' p b Nothing) + Just (Left x) -> 1 + tSize' p a (Just x) -- '1 +' is for runtime sanity checking + Just (Right y) -> 1 + tSize' p b (Just y) -- idem STArr ndim t -> case val of Nothing -> error "Nested arrays not supported in this implementation" - Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' t Nothing + Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing STScal sty -> goScal sty STAccum{} -> error "Nested accumulators unsupported" where @@ -86,10 +88,10 @@ tSize' typ val = case typ of -- | This operation does not commute with 'accumAdd', so it must be used with -- care. Furthermore it must be used on exactly the same value as tSize was -- called on. Hence it lives in IO, not in AcM. -accumWrite :: forall s t. Accum s t -> Rep t -> IO () +accumWrite :: forall s t. Accum s t -> Rep' s t -> IO () accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> let - go :: Bool -> STy t' -> Rep t' -> Int -> IO Int + go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int go inarr ty val off = case ty of STNil -> return off STPair a b -> do @@ -110,7 +112,7 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> | inarr -> error "Nested arrays not supported in this implementation" | otherwise -> do off1 <- goShape (arrayShape val) off - let eltsize = tSize' t Nothing + let eltsize = tSize' (Proxy @s) t Nothing n = arraySize val traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val return (off1 + eltsize * n) @@ -136,10 +138,10 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> in () <$ go False topty top_value 0 -accumRead :: forall s t. Accum s t -> AcM s (Rep t) +accumRead :: forall s t. Accum s t -> AcM s (Rep' s t) accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> let - go :: Bool -> STy t' -> Int -> IO (Int, Rep t') + go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t') go inarr ty off = case ty of STNil -> return (off, ()) STPair a b -> do @@ -154,13 +156,13 @@ accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> 1 -> fmap Right <$> go inarr b (off + 1) _ -> error "Invalid tag in accum memory" if inarr - then return (off + 1 + max (tSize' a Nothing) (tSize' b Nothing), val) + then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val) else return (off1, val) STArr ndim t | inarr -> error "Nested arrays not supported in this implementation" | otherwise -> do (off1, sh) <- readShape addr# ndim off - let eltsize = tSize' t Nothing + let eltsize = tSize' (Proxy @s) t Nothing n = shapeSize sh arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini)) return (off1 + eltsize * n, arr) @@ -205,10 +207,10 @@ invertShape | Refl <- lemPlusZero @n = flip go IShNil go ShNil ish = ish go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish) -accumAdd :: forall s t i. Accum s t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s () +accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s () accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> let - go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int -> IO () + go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO () go inarr ty SZ () val off = () <$ performAdd inarr ty val off go inarr ty (SS dep) idx val off = case (ty, idx, val) of (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off @@ -227,23 +229,23 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr (STAccum{}, _, _) -> error "Nested accumulators unsupported" goArr :: SNat i' -> InvShape n -> STy t' - -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int -> IO () + -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO () goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do - let i' = fromIntegral @(Rep TIx) @Int i + let i' = fromIntegral @(Rep' s TIx) @Int i when (i' < 0 || i' >= n) $ error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")" goArr depm1 ish t1 idx val (off + i' * ishSize ish) - performAddArr :: Int -> STy t' -> Array n (Rep t') -> Int -> IO Int + performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int performAddArr arraySz eltty val off = do - let eltsize = tSize' eltty Nothing + let eltsize = tSize' (Proxy @s) eltty Nothing forM_ [0 .. arraySz - 1] $ \lini -> performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize) return (off + arraySz * eltsize) - performAdd :: Bool -> STy t' -> Rep t' -> Int -> IO Int + performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int performAdd inarr ty val off = case ty of STNil -> return off STPair t1 t2 -> do @@ -257,7 +259,7 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr (Right val2, 1) -> performAdd inarr t2 val2 (off + 1) _ -> error "accumAdd: Tag mismatch for Either" if inarr - then return (off + 1 + max (tSize' t1 Nothing) (tSize' t2 Nothing)) + then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing)) else return off1 STArr n ty' | inarr -> error "Nested array" @@ -300,18 +302,18 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr in () <$ go False topty top_depth top_index top_value 0 -withAccum :: STy t -> Rep t -> (Accum s t -> AcM s b) -> AcM s (Rep t, b) +withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t) withAccum ty start fun = do -- The initial write must happen before any of the adds or reads, so it makes -- sense to put it in IO together with the allocation, instead of in AcM. - accum <- AcM $ do buffer <- mallocBytes (tSize ty start) + accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start) ptr <- newForeignPtr finalizerFree buffer let accum = Accum ty ptr accumWrite accum start return accum b <- fun accum out <- accumRead accum - return (out, b) + return (b, out) inParallel :: [AcM s t] -> AcM s [t] inParallel actions = AcM $ do diff --git a/src/Interpreter/Array.hs b/src/Interpreter/Array.hs deleted file mode 100644 index 54e0791..0000000 --- a/src/Interpreter/Array.hs +++ /dev/null @@ -1,45 +0,0 @@ -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE TupleSections #-} -module Interpreter.Array where - -import Control.Monad.Trans.State.Strict -import Data.Foldable (traverse_) -import Data.Vector (Vector) -import qualified Data.Vector as V - -import Data - - -data Shape n where - ShNil :: Shape Z - ShCons :: Shape n -> Int -> Shape (S n) - -data Index n where - IxNil :: Index Z - IxCons :: Index n -> Int -> Index (S n) - -shapeSize :: Shape n -> Int -shapeSize ShNil = 0 -shapeSize (ShCons sh n) = shapeSize sh * n - - --- | TODO: this Vector is a boxed vector, which is horrendously inefficient. -data Array (n :: Nat) t = Array (Shape n) (Vector t) - -arrayShape :: Array n t -> Shape n -arrayShape (Array sh _) = sh - -arraySize :: Array n t -> Int -arraySize (Array sh _) = shapeSize sh - -arrayIndexLinear :: Array n t -> Int -> t -arrayIndexLinear (Array _ v) i = v V.! i - -arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t) -arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f - --- | The Int is the linear index of the value. -traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m () -traverseArray_ f (Array _ v) = evalStateT (traverse_ (\x -> StateT (\i -> (,i+1) <$> f i x)) v) 0 diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs new file mode 100644 index 0000000..1ded773 --- /dev/null +++ b/src/Interpreter/Rep.hs @@ -0,0 +1,17 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeFamilies #-} +module Interpreter.Rep where + +import GHC.TypeError + +import Array +import AST + + +type family Rep t where + Rep TNil = () + Rep (TPair a b) = (Rep a, Rep b) + Rep (TEither a b) = Either (Rep a) (Rep b) + Rep (TArr n t) = Array n (Rep t) + Rep (TScal sty) = ScalRep sty + Rep (TAccum t) = TypeError (Text "Accumulator in Rep") |