summaryrefslogtreecommitdiff
path: root/src/Interpreter/Accum.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Interpreter/Accum.hs')
-rw-r--r--src/Interpreter/Accum.hs76
1 files changed, 39 insertions, 37 deletions
diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs
index d15ea10..b6a91df 100644
--- a/src/Interpreter/Accum.hs
+++ b/src/Interpreter/Accum.hs
@@ -15,7 +15,7 @@
module Interpreter.Accum (
AcM,
runAcM,
- Rep,
+ Rep',
Accum,
withAccum,
accumAdd,
@@ -25,6 +25,8 @@ module Interpreter.Accum (
import Control.Concurrent
import Control.Monad (when, forM_)
import Data.Bifunctor (second)
+import Data.Proxy
+import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr)
import Foreign.Storable (sizeOf)
import GHC.Exts
import GHC.Float
@@ -33,10 +35,9 @@ import GHC.IO (IO(..))
import GHC.Word
import System.IO.Unsafe (unsafePerformIO)
+import Array
import AST
import Data
-import Interpreter.Array
-import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr)
newtype AcM s a = AcM (IO a)
@@ -45,34 +46,35 @@ newtype AcM s a = AcM (IO a)
runAcM :: (forall s. AcM s a) -> a
runAcM (AcM m) = unsafePerformIO m
-type family Rep t where
- Rep TNil = ()
- Rep (TPair a b) = (Rep a, Rep b)
- Rep (TEither a b) = Either (Rep a) (Rep b)
- Rep (TArr n t) = Array n (Rep t)
- Rep (TScal sty) = ScalRep sty
- -- Rep (TAccum t) = _
+-- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined.
+type family Rep' s t where
+ Rep' s TNil = ()
+ Rep' s (TPair a b) = (Rep' s a, Rep' s b)
+ Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b)
+ Rep' s (TArr n t) = Array n (Rep' s t)
+ Rep' s (TScal sty) = ScalRep sty
+ Rep' s (TAccum t) = Accum s t
-- | Floats and integers are accumulated; booleans are left as-is.
data Accum s t = Accum (STy t) (ForeignPtr ())
-tSize :: STy t -> Rep t -> Int
-tSize ty x = tSize' ty (Just x)
+tSize :: Proxy s -> STy t -> Rep' s t -> Int
+tSize p ty x = tSize' p ty (Just x)
-- | Passing Nothing as the value means "this is (inside) an array element".
-tSize' :: STy t -> Maybe (Rep t) -> Int
-tSize' typ val = case typ of
+tSize' :: Proxy s -> STy t -> Maybe (Rep' s t) -> Int
+tSize' p typ val = case typ of
STNil -> 0
- STPair a b -> tSize' a (fst <$> val) + tSize' b (snd <$> val)
+ STPair a b -> tSize' p a (fst <$> val) + tSize' p b (snd <$> val)
STEither a b ->
case val of
- Nothing -> 1 + max (tSize' a Nothing) (tSize' b Nothing)
- Just (Left x) -> 1 + tSize a x -- '1 +' is for runtime sanity checking
- Just (Right y) -> 1 + tSize b y -- idem
+ Nothing -> 1 + max (tSize' p a Nothing) (tSize' p b Nothing)
+ Just (Left x) -> 1 + tSize' p a (Just x) -- '1 +' is for runtime sanity checking
+ Just (Right y) -> 1 + tSize' p b (Just y) -- idem
STArr ndim t ->
case val of
Nothing -> error "Nested arrays not supported in this implementation"
- Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' t Nothing
+ Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing
STScal sty -> goScal sty
STAccum{} -> error "Nested accumulators unsupported"
where
@@ -86,10 +88,10 @@ tSize' typ val = case typ of
-- | This operation does not commute with 'accumAdd', so it must be used with
-- care. Furthermore it must be used on exactly the same value as tSize was
-- called on. Hence it lives in IO, not in AcM.
-accumWrite :: forall s t. Accum s t -> Rep t -> IO ()
+accumWrite :: forall s t. Accum s t -> Rep' s t -> IO ()
accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
let
- go :: Bool -> STy t' -> Rep t' -> Int -> IO Int
+ go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
go inarr ty val off = case ty of
STNil -> return off
STPair a b -> do
@@ -110,7 +112,7 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
| inarr -> error "Nested arrays not supported in this implementation"
| otherwise -> do
off1 <- goShape (arrayShape val) off
- let eltsize = tSize' t Nothing
+ let eltsize = tSize' (Proxy @s) t Nothing
n = arraySize val
traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val
return (off1 + eltsize * n)
@@ -136,10 +138,10 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
in () <$ go False topty top_value 0
-accumRead :: forall s t. Accum s t -> AcM s (Rep t)
+accumRead :: forall s t. Accum s t -> AcM s (Rep' s t)
accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
let
- go :: Bool -> STy t' -> Int -> IO (Int, Rep t')
+ go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t')
go inarr ty off = case ty of
STNil -> return (off, ())
STPair a b -> do
@@ -154,13 +156,13 @@ accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
1 -> fmap Right <$> go inarr b (off + 1)
_ -> error "Invalid tag in accum memory"
if inarr
- then return (off + 1 + max (tSize' a Nothing) (tSize' b Nothing), val)
+ then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val)
else return (off1, val)
STArr ndim t
| inarr -> error "Nested arrays not supported in this implementation"
| otherwise -> do
(off1, sh) <- readShape addr# ndim off
- let eltsize = tSize' t Nothing
+ let eltsize = tSize' (Proxy @s) t Nothing
n = shapeSize sh
arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini))
return (off1 + eltsize * n, arr)
@@ -205,10 +207,10 @@ invertShape | Refl <- lemPlusZero @n = flip go IShNil
go ShNil ish = ish
go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish)
-accumAdd :: forall s t i. Accum s t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s ()
+accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s ()
accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
let
- go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int -> IO ()
+ go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO ()
go inarr ty SZ () val off = () <$ performAdd inarr ty val off
go inarr ty (SS dep) idx val off = case (ty, idx, val) of
(STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
@@ -227,23 +229,23 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr
(STAccum{}, _, _) -> error "Nested accumulators unsupported"
goArr :: SNat i' -> InvShape n -> STy t'
- -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int -> IO ()
+ -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO ()
goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off
goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off
goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do
- let i' = fromIntegral @(Rep TIx) @Int i
+ let i' = fromIntegral @(Rep' s TIx) @Int i
when (i' < 0 || i' >= n) $
error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")"
goArr depm1 ish t1 idx val (off + i' * ishSize ish)
- performAddArr :: Int -> STy t' -> Array n (Rep t') -> Int -> IO Int
+ performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int
performAddArr arraySz eltty val off = do
- let eltsize = tSize' eltty Nothing
+ let eltsize = tSize' (Proxy @s) eltty Nothing
forM_ [0 .. arraySz - 1] $ \lini ->
performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize)
return (off + arraySz * eltsize)
- performAdd :: Bool -> STy t' -> Rep t' -> Int -> IO Int
+ performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
performAdd inarr ty val off = case ty of
STNil -> return off
STPair t1 t2 -> do
@@ -257,7 +259,7 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr
(Right val2, 1) -> performAdd inarr t2 val2 (off + 1)
_ -> error "accumAdd: Tag mismatch for Either"
if inarr
- then return (off + 1 + max (tSize' t1 Nothing) (tSize' t2 Nothing))
+ then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing))
else return off1
STArr n ty'
| inarr -> error "Nested array"
@@ -300,18 +302,18 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr
in () <$ go False topty top_depth top_index top_value 0
-withAccum :: STy t -> Rep t -> (Accum s t -> AcM s b) -> AcM s (Rep t, b)
+withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t)
withAccum ty start fun = do
-- The initial write must happen before any of the adds or reads, so it makes
-- sense to put it in IO together with the allocation, instead of in AcM.
- accum <- AcM $ do buffer <- mallocBytes (tSize ty start)
+ accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start)
ptr <- newForeignPtr finalizerFree buffer
let accum = Accum ty ptr
accumWrite accum start
return accum
b <- fun accum
out <- accumRead accum
- return (out, b)
+ return (b, out)
inParallel :: [AcM s t] -> AcM s [t]
inParallel actions = AcM $ do