diff options
Diffstat (limited to 'src/Interpreter/Accum.hs')
-rw-r--r-- | src/Interpreter/Accum.hs | 76 |
1 files changed, 39 insertions, 37 deletions
diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs index d15ea10..b6a91df 100644 --- a/src/Interpreter/Accum.hs +++ b/src/Interpreter/Accum.hs @@ -15,7 +15,7 @@ module Interpreter.Accum ( AcM, runAcM, - Rep, + Rep', Accum, withAccum, accumAdd, @@ -25,6 +25,8 @@ module Interpreter.Accum ( import Control.Concurrent import Control.Monad (when, forM_) import Data.Bifunctor (second) +import Data.Proxy +import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr) import Foreign.Storable (sizeOf) import GHC.Exts import GHC.Float @@ -33,10 +35,9 @@ import GHC.IO (IO(..)) import GHC.Word import System.IO.Unsafe (unsafePerformIO) +import Array import AST import Data -import Interpreter.Array -import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr) newtype AcM s a = AcM (IO a) @@ -45,34 +46,35 @@ newtype AcM s a = AcM (IO a) runAcM :: (forall s. AcM s a) -> a runAcM (AcM m) = unsafePerformIO m -type family Rep t where - Rep TNil = () - Rep (TPair a b) = (Rep a, Rep b) - Rep (TEither a b) = Either (Rep a) (Rep b) - Rep (TArr n t) = Array n (Rep t) - Rep (TScal sty) = ScalRep sty - -- Rep (TAccum t) = _ +-- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined. +type family Rep' s t where + Rep' s TNil = () + Rep' s (TPair a b) = (Rep' s a, Rep' s b) + Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b) + Rep' s (TArr n t) = Array n (Rep' s t) + Rep' s (TScal sty) = ScalRep sty + Rep' s (TAccum t) = Accum s t -- | Floats and integers are accumulated; booleans are left as-is. data Accum s t = Accum (STy t) (ForeignPtr ()) -tSize :: STy t -> Rep t -> Int -tSize ty x = tSize' ty (Just x) +tSize :: Proxy s -> STy t -> Rep' s t -> Int +tSize p ty x = tSize' p ty (Just x) -- | Passing Nothing as the value means "this is (inside) an array element". -tSize' :: STy t -> Maybe (Rep t) -> Int -tSize' typ val = case typ of +tSize' :: Proxy s -> STy t -> Maybe (Rep' s t) -> Int +tSize' p typ val = case typ of STNil -> 0 - STPair a b -> tSize' a (fst <$> val) + tSize' b (snd <$> val) + STPair a b -> tSize' p a (fst <$> val) + tSize' p b (snd <$> val) STEither a b -> case val of - Nothing -> 1 + max (tSize' a Nothing) (tSize' b Nothing) - Just (Left x) -> 1 + tSize a x -- '1 +' is for runtime sanity checking - Just (Right y) -> 1 + tSize b y -- idem + Nothing -> 1 + max (tSize' p a Nothing) (tSize' p b Nothing) + Just (Left x) -> 1 + tSize' p a (Just x) -- '1 +' is for runtime sanity checking + Just (Right y) -> 1 + tSize' p b (Just y) -- idem STArr ndim t -> case val of Nothing -> error "Nested arrays not supported in this implementation" - Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' t Nothing + Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing STScal sty -> goScal sty STAccum{} -> error "Nested accumulators unsupported" where @@ -86,10 +88,10 @@ tSize' typ val = case typ of -- | This operation does not commute with 'accumAdd', so it must be used with -- care. Furthermore it must be used on exactly the same value as tSize was -- called on. Hence it lives in IO, not in AcM. -accumWrite :: forall s t. Accum s t -> Rep t -> IO () +accumWrite :: forall s t. Accum s t -> Rep' s t -> IO () accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> let - go :: Bool -> STy t' -> Rep t' -> Int -> IO Int + go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int go inarr ty val off = case ty of STNil -> return off STPair a b -> do @@ -110,7 +112,7 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> | inarr -> error "Nested arrays not supported in this implementation" | otherwise -> do off1 <- goShape (arrayShape val) off - let eltsize = tSize' t Nothing + let eltsize = tSize' (Proxy @s) t Nothing n = arraySize val traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val return (off1 + eltsize * n) @@ -136,10 +138,10 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> in () <$ go False topty top_value 0 -accumRead :: forall s t. Accum s t -> AcM s (Rep t) +accumRead :: forall s t. Accum s t -> AcM s (Rep' s t) accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> let - go :: Bool -> STy t' -> Int -> IO (Int, Rep t') + go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t') go inarr ty off = case ty of STNil -> return (off, ()) STPair a b -> do @@ -154,13 +156,13 @@ accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> 1 -> fmap Right <$> go inarr b (off + 1) _ -> error "Invalid tag in accum memory" if inarr - then return (off + 1 + max (tSize' a Nothing) (tSize' b Nothing), val) + then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val) else return (off1, val) STArr ndim t | inarr -> error "Nested arrays not supported in this implementation" | otherwise -> do (off1, sh) <- readShape addr# ndim off - let eltsize = tSize' t Nothing + let eltsize = tSize' (Proxy @s) t Nothing n = shapeSize sh arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini)) return (off1 + eltsize * n, arr) @@ -205,10 +207,10 @@ invertShape | Refl <- lemPlusZero @n = flip go IShNil go ShNil ish = ish go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish) -accumAdd :: forall s t i. Accum s t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s () +accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s () accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> let - go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int -> IO () + go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO () go inarr ty SZ () val off = () <$ performAdd inarr ty val off go inarr ty (SS dep) idx val off = case (ty, idx, val) of (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off @@ -227,23 +229,23 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr (STAccum{}, _, _) -> error "Nested accumulators unsupported" goArr :: SNat i' -> InvShape n -> STy t' - -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int -> IO () + -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO () goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do - let i' = fromIntegral @(Rep TIx) @Int i + let i' = fromIntegral @(Rep' s TIx) @Int i when (i' < 0 || i' >= n) $ error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")" goArr depm1 ish t1 idx val (off + i' * ishSize ish) - performAddArr :: Int -> STy t' -> Array n (Rep t') -> Int -> IO Int + performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int performAddArr arraySz eltty val off = do - let eltsize = tSize' eltty Nothing + let eltsize = tSize' (Proxy @s) eltty Nothing forM_ [0 .. arraySz - 1] $ \lini -> performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize) return (off + arraySz * eltsize) - performAdd :: Bool -> STy t' -> Rep t' -> Int -> IO Int + performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int performAdd inarr ty val off = case ty of STNil -> return off STPair t1 t2 -> do @@ -257,7 +259,7 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr (Right val2, 1) -> performAdd inarr t2 val2 (off + 1) _ -> error "accumAdd: Tag mismatch for Either" if inarr - then return (off + 1 + max (tSize' t1 Nothing) (tSize' t2 Nothing)) + then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing)) else return off1 STArr n ty' | inarr -> error "Nested array" @@ -300,18 +302,18 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr in () <$ go False topty top_depth top_index top_value 0 -withAccum :: STy t -> Rep t -> (Accum s t -> AcM s b) -> AcM s (Rep t, b) +withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t) withAccum ty start fun = do -- The initial write must happen before any of the adds or reads, so it makes -- sense to put it in IO together with the allocation, instead of in AcM. - accum <- AcM $ do buffer <- mallocBytes (tSize ty start) + accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start) ptr <- newForeignPtr finalizerFree buffer let accum = Accum ty ptr accumWrite accum start return accum b <- fun accum out <- accumRead accum - return (out, b) + return (b, out) inParallel :: [AcM s t] -> AcM s [t] inParallel actions = AcM $ do |