diff options
-rw-r--r-- | src/Interpreter.hs | 48 |
1 files changed, 40 insertions, 8 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index f58cefb..62160aa 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -26,7 +26,7 @@ import AST import CHAD.Types import Data import Interpreter.Rep -import Data.Bifunctor (first) +import Data.Bifunctor (first, bimap) newtype AcM s a = AcM { unAcM :: IO a } @@ -174,7 +174,7 @@ withAccum t _ initval f = AcM $ do newAcSparse :: STy t -> Rep t -> IO (RepAcSparse t) newAcSparse typ val = case typ of STNil -> return () - STPair{} -> newIORef =<<newAcDense typ val + STPair{} -> newIORef =<< newAcDense typ val STMaybe t -> newIORef =<< traverse (newAcDense t) val STArr _ t -> newIORef =<< traverse (newAcSparse t) val STScal{} -> newIORef val @@ -225,14 +225,46 @@ accumAddSparse typ SZ ref () val = case typ of unAcM $ accumAddSparse t1 SZ r1 () (fst val) unAcM $ accumAddSparse t2 SZ r2 () (snd val) STMaybe t -> - join $ AcM $ atomicModifyIORef' ref $ \ac -> case (ac, val) of - (Nothing, _) -> (ac, _) - (Just{}, Nothing) -> (ac, return ()) - (Just ac', Just val') -> first Just (accumAddDense t SZ ac' () val') - STArr _ t -> _ ref val - STScal{} -> _ ref val + case val of + Nothing -> return () + Just val' -> + -- Try adding val' to what's already in ref. The 'join' makes the snd + -- of the function's return value a _continuation_ that is run after + -- the critical section ends. + AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of + -- Oops, ref's contents was still sparse. Have to initialise + -- it first, then try again. + Nothing -> (ac, do newac <- newAcDense t val' + join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of + Nothing -> (Just newac, return ()) + Just ac2' -> bimap Just unAcM (accumAddDense t SZ ac2' () val')) + -- Yep, ref already had a value in there, so we can just add + -- val' to it recursively. + Just ac' -> bimap Just unAcM (accumAddDense t SZ ac' () val') + STArr _ t -> AcM $ do + refs <- readIORef ref + case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of + (_, 0) -> return () + (0, _) -> do + newrefarr <- arrayGenerateLinM (arrayShape val) (\i -> newAcSparse t (arrayIndexLinear val i)) + join $ atomicModifyIORef' ref $ \refarr -> + if shapeSize (arrayShape refarr) == 0 + then (newrefarr, return ()) + else -- someone was faster than us in initialising the reference! + (refarr, unAcM $ accumAddSparse typ SZ ref () val) -- just try again from the start + _ | arrayShape refs == arrayShape val -> + sequence_ [unAcM $ accumAddSparse t SZ (arrayIndexLinear refs i) () (arrayIndexLinear val i) + | i <- [0 .. shapeSize (arrayShape val) - 1]] + | otherwise -> error "Array shape mismatch in accum add" + STScal sty -> AcM $ case sty of + STI32 -> atomicModifyIORef' ref (\x -> (x + val, ())) + STI64 -> atomicModifyIORef' ref (\x -> (x + val, ())) + STF32 -> atomicModifyIORef' ref (\x -> (x + val, ())) + STF64 -> atomicModifyIORef' ref (\x -> (x + val, ())) + STBool -> error "Accumulator of Bool" STAccum{} -> error "Nested accumulators" STEither{} -> error "Bare Either in accumulator" + accumAddSparse typ (SS dep) ref idx val = case typ of STNil -> return () STPair t1 t2 -> _ ref idx val |