diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-09-12 17:31:20 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-09-12 17:31:42 +0200 | 
| commit | 36732f84cfade5371248806328791d5066673fb7 (patch) | |
| tree | 68cf208fca197a48e6b0506e783c1bdaf98d2e42 /src/Interpreter | |
| parent | 1f53cea6a1352db125e1897ca574360180be2550 (diff) | |
Interpreter, some operations
Diffstat (limited to 'src/Interpreter')
| -rw-r--r-- | src/Interpreter/Accum.hs | 76 | ||||
| -rw-r--r-- | src/Interpreter/Array.hs | 45 | ||||
| -rw-r--r-- | src/Interpreter/Rep.hs | 17 | 
3 files changed, 56 insertions, 82 deletions
| diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs index d15ea10..b6a91df 100644 --- a/src/Interpreter/Accum.hs +++ b/src/Interpreter/Accum.hs @@ -15,7 +15,7 @@  module Interpreter.Accum (    AcM,    runAcM, -  Rep, +  Rep',    Accum,    withAccum,    accumAdd, @@ -25,6 +25,8 @@ module Interpreter.Accum (  import Control.Concurrent  import Control.Monad (when, forM_)  import Data.Bifunctor (second) +import Data.Proxy +import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr)  import Foreign.Storable (sizeOf)  import GHC.Exts  import GHC.Float @@ -33,10 +35,9 @@ import GHC.IO (IO(..))  import GHC.Word  import System.IO.Unsafe (unsafePerformIO) +import Array  import AST  import Data -import Interpreter.Array -import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr)  newtype AcM s a = AcM (IO a) @@ -45,34 +46,35 @@ newtype AcM s a = AcM (IO a)  runAcM :: (forall s. AcM s a) -> a  runAcM (AcM m) = unsafePerformIO m -type family Rep t where -  Rep TNil = () -  Rep (TPair a b) = (Rep a, Rep b) -  Rep (TEither a b) = Either (Rep a) (Rep b) -  Rep (TArr n t) = Array n (Rep t) -  Rep (TScal sty) = ScalRep sty -  -- Rep (TAccum t) = _ +-- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined. +type family Rep' s t where +  Rep' s TNil = () +  Rep' s (TPair a b) = (Rep' s a, Rep' s b) +  Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b) +  Rep' s (TArr n t) = Array n (Rep' s t) +  Rep' s (TScal sty) = ScalRep sty +  Rep' s (TAccum t) = Accum s t  -- | Floats and integers are accumulated; booleans are left as-is.  data Accum s t = Accum (STy t) (ForeignPtr ()) -tSize :: STy t -> Rep t -> Int -tSize ty x = tSize' ty (Just x) +tSize :: Proxy s -> STy t -> Rep' s t -> Int +tSize p ty x = tSize' p ty (Just x)  -- | Passing Nothing as the value means "this is (inside) an array element". -tSize' :: STy t -> Maybe (Rep t) -> Int -tSize' typ val = case typ of +tSize' :: Proxy s -> STy t -> Maybe (Rep' s t) -> Int +tSize' p typ val = case typ of    STNil -> 0 -  STPair a b -> tSize' a (fst <$> val) + tSize' b (snd <$> val) +  STPair a b -> tSize' p a (fst <$> val) + tSize' p b (snd <$> val)    STEither a b ->      case val of -      Nothing -> 1 + max (tSize' a Nothing) (tSize' b Nothing) -      Just (Left x) -> 1 + tSize a x   -- '1 +' is for runtime sanity checking -      Just (Right y) -> 1 + tSize b y  -- idem +      Nothing -> 1 + max (tSize' p a Nothing) (tSize' p b Nothing) +      Just (Left x) -> 1 + tSize' p a (Just x)   -- '1 +' is for runtime sanity checking +      Just (Right y) -> 1 + tSize' p b (Just y)  -- idem    STArr ndim t ->      case val of        Nothing -> error "Nested arrays not supported in this implementation" -      Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' t Nothing +      Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing    STScal sty -> goScal sty    STAccum{} -> error "Nested accumulators unsupported"    where @@ -86,10 +88,10 @@ tSize' typ val = case typ of  -- | This operation does not commute with 'accumAdd', so it must be used with  -- care. Furthermore it must be used on exactly the same value as tSize was  -- called on. Hence it lives in IO, not in AcM. -accumWrite :: forall s t. Accum s t -> Rep t -> IO () +accumWrite :: forall s t. Accum s t -> Rep' s t -> IO ()  accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->    let -    go :: Bool -> STy t' -> Rep t' -> Int -> IO Int +    go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int      go inarr ty val off = case ty of        STNil -> return off        STPair a b -> do @@ -110,7 +112,7 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->          | inarr -> error "Nested arrays not supported in this implementation"          | otherwise -> do              off1 <- goShape (arrayShape val) off -            let eltsize = tSize' t Nothing +            let eltsize = tSize' (Proxy @s) t Nothing                  n = arraySize val              traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val              return (off1 + eltsize * n) @@ -136,10 +138,10 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->    in () <$ go False topty top_value 0 -accumRead :: forall s t. Accum s t -> AcM s (Rep t) +accumRead :: forall s t. Accum s t -> AcM s (Rep' s t)  accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->    let -    go :: Bool -> STy t' -> Int -> IO (Int, Rep t') +    go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t')      go inarr ty off = case ty of        STNil -> return (off, ())        STPair a b -> do @@ -154,13 +156,13 @@ accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->                           1 -> fmap Right <$> go inarr b (off + 1)                           _ -> error "Invalid tag in accum memory"          if inarr -          then return (off + 1 + max (tSize' a Nothing) (tSize' b Nothing), val) +          then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val)            else return (off1, val)        STArr ndim t          | inarr -> error "Nested arrays not supported in this implementation"          | otherwise -> do              (off1, sh) <- readShape addr# ndim off -            let eltsize = tSize' t Nothing +            let eltsize = tSize' (Proxy @s) t Nothing                  n = shapeSize sh              arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini))              return (off1 + eltsize * n, arr) @@ -205,10 +207,10 @@ invertShape | Refl <- lemPlusZero @n = flip go IShNil      go ShNil ish = ish      go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish) -accumAdd :: forall s t i. Accum s t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s () +accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s ()  accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->    let -    go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int -> IO () +    go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO ()      go inarr ty SZ () val off = () <$ performAdd inarr ty val off      go inarr ty (SS dep) idx val off = case (ty, idx, val) of        (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off @@ -227,23 +229,23 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr        (STAccum{}, _, _) -> error "Nested accumulators unsupported"      goArr :: SNat i' -> InvShape n -> STy t' -          -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int -> IO () +          -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO ()      goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off      goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off      goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do -      let i' = fromIntegral @(Rep TIx) @Int i +      let i' = fromIntegral @(Rep' s TIx) @Int i        when (i' < 0 || i' >= n) $          error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")"        goArr depm1 ish t1 idx val (off + i' * ishSize ish) -    performAddArr :: Int -> STy t' -> Array n (Rep t') -> Int -> IO Int +    performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int      performAddArr arraySz eltty val off = do -      let eltsize = tSize' eltty Nothing +      let eltsize = tSize' (Proxy @s) eltty Nothing        forM_ [0 .. arraySz - 1] $ \lini ->          performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize)        return (off + arraySz * eltsize) -    performAdd :: Bool -> STy t' -> Rep t' -> Int -> IO Int +    performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int      performAdd inarr ty val off = case ty of        STNil -> return off        STPair t1 t2 -> do @@ -257,7 +259,7 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr                    (Right val2, 1) -> performAdd inarr t2 val2 (off + 1)                    _ -> error "accumAdd: Tag mismatch for Either"          if inarr -          then return (off + 1 + max (tSize' t1 Nothing) (tSize' t2 Nothing)) +          then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing))            else return off1        STArr n ty'          | inarr -> error "Nested array" @@ -300,18 +302,18 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr    in () <$ go False topty top_depth top_index top_value 0 -withAccum :: STy t -> Rep t -> (Accum s t -> AcM s b) -> AcM s (Rep t, b) +withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t)  withAccum ty start fun = do    -- The initial write must happen before any of the adds or reads, so it makes    -- sense to put it in IO together with the allocation, instead of in AcM. -  accum <- AcM $ do buffer <- mallocBytes (tSize ty start) +  accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start)                      ptr <- newForeignPtr finalizerFree buffer                      let accum = Accum ty ptr                      accumWrite accum start                      return accum    b <- fun accum    out <- accumRead accum -  return (out, b) +  return (b, out)  inParallel :: [AcM s t] -> AcM s [t]  inParallel actions = AcM $ do diff --git a/src/Interpreter/Array.hs b/src/Interpreter/Array.hs deleted file mode 100644 index 54e0791..0000000 --- a/src/Interpreter/Array.hs +++ /dev/null @@ -1,45 +0,0 @@ -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE TupleSections #-} -module Interpreter.Array where - -import Control.Monad.Trans.State.Strict -import Data.Foldable (traverse_) -import Data.Vector (Vector) -import qualified Data.Vector as V - -import Data - - -data Shape n where -  ShNil :: Shape Z -  ShCons :: Shape n -> Int -> Shape (S n) - -data Index n where -  IxNil :: Index Z -  IxCons :: Index n -> Int -> Index (S n) - -shapeSize :: Shape n -> Int -shapeSize ShNil = 0 -shapeSize (ShCons sh n) = shapeSize sh * n - - --- | TODO: this Vector is a boxed vector, which is horrendously inefficient. -data Array (n :: Nat) t = Array (Shape n) (Vector t) - -arrayShape :: Array n t -> Shape n -arrayShape (Array sh _) = sh - -arraySize :: Array n t -> Int -arraySize (Array sh _) = shapeSize sh - -arrayIndexLinear :: Array n t -> Int -> t -arrayIndexLinear (Array _ v) i = v V.! i - -arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t) -arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f - --- | The Int is the linear index of the value. -traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m () -traverseArray_ f (Array _ v) = evalStateT (traverse_ (\x -> StateT (\i -> (,i+1) <$> f i x)) v) 0 diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs new file mode 100644 index 0000000..1ded773 --- /dev/null +++ b/src/Interpreter/Rep.hs @@ -0,0 +1,17 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeFamilies #-} +module Interpreter.Rep where + +import GHC.TypeError + +import Array +import AST + + +type family Rep t where +  Rep TNil = () +  Rep (TPair a b) = (Rep a, Rep b) +  Rep (TEither a b) = Either (Rep a) (Rep b) +  Rep (TArr n t) = Array n (Rep t) +  Rep (TScal sty) = ScalRep sty +  Rep (TAccum t) = TypeError (Text "Accumulator in Rep") | 
