summaryrefslogtreecommitdiff
path: root/src/Interpreter/Accum.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Interpreter/Accum.hs')
-rw-r--r--src/Interpreter/Accum.hs26
1 files changed, 16 insertions, 10 deletions
diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs
index b6a91df..af7be1e 100644
--- a/src/Interpreter/Accum.hs
+++ b/src/Interpreter/Accum.hs
@@ -51,6 +51,7 @@ type family Rep' s t where
Rep' s TNil = ()
Rep' s (TPair a b) = (Rep' s a, Rep' s b)
Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b)
+ Rep' s (TMaybe t) = Maybe (Rep' s t)
Rep' s (TArr n t) = Array n (Rep' s t)
Rep' s (TScal sty) = ScalRep sty
Rep' s (TAccum t) = Accum s t
@@ -61,16 +62,13 @@ data Accum s t = Accum (STy t) (ForeignPtr ())
tSize :: Proxy s -> STy t -> Rep' s t -> Int
tSize p ty x = tSize' p ty (Just x)
--- | Passing Nothing as the value means "this is (inside) an array element".
-tSize' :: Proxy s -> STy t -> Maybe (Rep' s t) -> Int
-tSize' p typ val = case typ of
+tSize' :: Proxy s -> STy t -> Int
+tSize' p typ = case typ of
STNil -> 0
- STPair a b -> tSize' p a (fst <$> val) + tSize' p b (snd <$> val)
- STEither a b ->
- case val of
- Nothing -> 1 + max (tSize' p a Nothing) (tSize' p b Nothing)
- Just (Left x) -> 1 + tSize' p a (Just x) -- '1 +' is for runtime sanity checking
- Just (Right y) -> 1 + tSize' p b (Just y) -- idem
+ STPair a b -> tSize' p a + tSize' p b
+ STEither a b -> 1 + max (tSize' p a) (tSize' p b)
+ -- Representation of Maybe t is the same as Either () t; the add operation is different, however.
+ STMaybe t -> tSize' p (STEither STNil t)
STArr ndim t ->
case val of
Nothing -> error "Nested arrays not supported in this implementation"
@@ -99,7 +97,7 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
go inarr b (snd val) off1
STEither a b -> do
let !(I# off#) = off
- case val of
+ off1 <- case val of
Left x -> do
let !(I8# tag#) = 0
writeInt8# addr# off# tag#
@@ -108,6 +106,11 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
let !(I8# tag#) = 1
writeInt8# addr# off# tag#
go inarr b y (off + 1)
+ if inarr
+ then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing))
+ else return off1
+ -- Representation is the same, but add operation is different
+ STMaybe t -> go inarr (STEither STNil t) (maybe (Left ()) Right val) off
STArr _ t
| inarr -> error "Nested arrays not supported in this implementation"
| otherwise -> do
@@ -158,6 +161,8 @@ accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
if inarr
then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val)
else return (off1, val)
+ -- Representation is the same, but add operation is different
+ STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t) off
STArr ndim t
| inarr -> error "Nested arrays not supported in this implementation"
| otherwise -> do
@@ -219,6 +224,7 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr
(STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
(STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off
(STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd"
+ (STMaybe t, _, _) -> _ idx val
(STArr rank eltty, _, _)
| inarr -> error "Nested arrays"
| otherwise -> do