summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Interpreter.hs48
1 files changed, 40 insertions, 8 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index f58cefb..62160aa 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -26,7 +26,7 @@ import AST
import CHAD.Types
import Data
import Interpreter.Rep
-import Data.Bifunctor (first)
+import Data.Bifunctor (first, bimap)
newtype AcM s a = AcM { unAcM :: IO a }
@@ -174,7 +174,7 @@ withAccum t _ initval f = AcM $ do
newAcSparse :: STy t -> Rep t -> IO (RepAcSparse t)
newAcSparse typ val = case typ of
STNil -> return ()
- STPair{} -> newIORef =<<newAcDense typ val
+ STPair{} -> newIORef =<< newAcDense typ val
STMaybe t -> newIORef =<< traverse (newAcDense t) val
STArr _ t -> newIORef =<< traverse (newAcSparse t) val
STScal{} -> newIORef val
@@ -225,14 +225,46 @@ accumAddSparse typ SZ ref () val = case typ of
unAcM $ accumAddSparse t1 SZ r1 () (fst val)
unAcM $ accumAddSparse t2 SZ r2 () (snd val)
STMaybe t ->
- join $ AcM $ atomicModifyIORef' ref $ \ac -> case (ac, val) of
- (Nothing, _) -> (ac, _)
- (Just{}, Nothing) -> (ac, return ())
- (Just ac', Just val') -> first Just (accumAddDense t SZ ac' () val')
- STArr _ t -> _ ref val
- STScal{} -> _ ref val
+ case val of
+ Nothing -> return ()
+ Just val' ->
+ -- Try adding val' to what's already in ref. The 'join' makes the snd
+ -- of the function's return value a _continuation_ that is run after
+ -- the critical section ends.
+ AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of
+ -- Oops, ref's contents was still sparse. Have to initialise
+ -- it first, then try again.
+ Nothing -> (ac, do newac <- newAcDense t val'
+ join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of
+ Nothing -> (Just newac, return ())
+ Just ac2' -> bimap Just unAcM (accumAddDense t SZ ac2' () val'))
+ -- Yep, ref already had a value in there, so we can just add
+ -- val' to it recursively.
+ Just ac' -> bimap Just unAcM (accumAddDense t SZ ac' () val')
+ STArr _ t -> AcM $ do
+ refs <- readIORef ref
+ case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of
+ (_, 0) -> return ()
+ (0, _) -> do
+ newrefarr <- arrayGenerateLinM (arrayShape val) (\i -> newAcSparse t (arrayIndexLinear val i))
+ join $ atomicModifyIORef' ref $ \refarr ->
+ if shapeSize (arrayShape refarr) == 0
+ then (newrefarr, return ())
+ else -- someone was faster than us in initialising the reference!
+ (refarr, unAcM $ accumAddSparse typ SZ ref () val) -- just try again from the start
+ _ | arrayShape refs == arrayShape val ->
+ sequence_ [unAcM $ accumAddSparse t SZ (arrayIndexLinear refs i) () (arrayIndexLinear val i)
+ | i <- [0 .. shapeSize (arrayShape val) - 1]]
+ | otherwise -> error "Array shape mismatch in accum add"
+ STScal sty -> AcM $ case sty of
+ STI32 -> atomicModifyIORef' ref (\x -> (x + val, ()))
+ STI64 -> atomicModifyIORef' ref (\x -> (x + val, ()))
+ STF32 -> atomicModifyIORef' ref (\x -> (x + val, ()))
+ STF64 -> atomicModifyIORef' ref (\x -> (x + val, ()))
+ STBool -> error "Accumulator of Bool"
STAccum{} -> error "Nested accumulators"
STEither{} -> error "Bare Either in accumulator"
+
accumAddSparse typ (SS dep) ref idx val = case typ of
STNil -> return ()
STPair t1 t2 -> _ ref idx val