diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Interpreter.hs | 48 | 
1 files changed, 40 insertions, 8 deletions
| diff --git a/src/Interpreter.hs b/src/Interpreter.hs index f58cefb..62160aa 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -26,7 +26,7 @@ import AST  import CHAD.Types  import Data  import Interpreter.Rep -import Data.Bifunctor (first) +import Data.Bifunctor (first, bimap)  newtype AcM s a = AcM { unAcM :: IO a } @@ -174,7 +174,7 @@ withAccum t _ initval f = AcM $ do  newAcSparse :: STy t -> Rep t -> IO (RepAcSparse t)  newAcSparse typ val = case typ of    STNil -> return () -  STPair{} -> newIORef =<<newAcDense typ val +  STPair{} -> newIORef =<< newAcDense typ val    STMaybe t -> newIORef =<< traverse (newAcDense t) val    STArr _ t -> newIORef =<< traverse (newAcSparse t) val    STScal{} -> newIORef val @@ -225,14 +225,46 @@ accumAddSparse typ SZ ref () val = case typ of      unAcM $ accumAddSparse t1 SZ r1 () (fst val)      unAcM $ accumAddSparse t2 SZ r2 () (snd val)    STMaybe t -> -    join $ AcM $ atomicModifyIORef' ref $ \ac -> case (ac, val) of -                   (Nothing, _) -> (ac, _) -                   (Just{}, Nothing) -> (ac, return ()) -                   (Just ac', Just val') -> first Just (accumAddDense t SZ ac' () val') -  STArr _ t -> _ ref val -  STScal{} -> _ ref val +    case val of +      Nothing -> return () +      Just val' -> +        -- Try adding val' to what's already in ref. The 'join' makes the snd +        -- of the function's return value a _continuation_ that is run after +        -- the critical section ends. +        AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of +                       -- Oops, ref's contents was still sparse. Have to initialise +                       -- it first, then try again. +                       Nothing -> (ac, do newac <- newAcDense t val' +                                          join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of +                                                   Nothing -> (Just newac, return ()) +                                                   Just ac2' -> bimap Just unAcM (accumAddDense t SZ ac2' () val')) +                       -- Yep, ref already had a value in there, so we can just add +                       -- val' to it recursively. +                       Just ac' -> bimap Just unAcM (accumAddDense t SZ ac' () val') +  STArr _ t -> AcM $ do +    refs <- readIORef ref +    case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of +      (_, 0) -> return () +      (0, _) -> do +        newrefarr <- arrayGenerateLinM (arrayShape val) (\i -> newAcSparse t (arrayIndexLinear val i)) +        join $ atomicModifyIORef' ref $ \refarr -> +                 if shapeSize (arrayShape refarr) == 0 +                   then (newrefarr, return ()) +                   else -- someone was faster than us in initialising the reference! +                        (refarr, unAcM $ accumAddSparse typ SZ ref () val)  -- just try again from the start +      _ | arrayShape refs == arrayShape val -> +            sequence_ [unAcM $ accumAddSparse t SZ (arrayIndexLinear refs i) () (arrayIndexLinear val i) +                      | i <- [0 .. shapeSize (arrayShape val) - 1]] +        | otherwise -> error "Array shape mismatch in accum add" +  STScal sty -> AcM $ case sty of +    STI32 -> atomicModifyIORef' ref (\x -> (x + val, ())) +    STI64 -> atomicModifyIORef' ref (\x -> (x + val, ())) +    STF32 -> atomicModifyIORef' ref (\x -> (x + val, ())) +    STF64 -> atomicModifyIORef' ref (\x -> (x + val, ())) +    STBool -> error "Accumulator of Bool"    STAccum{} -> error "Nested accumulators"    STEither{} -> error "Bare Either in accumulator" +  accumAddSparse typ (SS dep) ref idx val = case typ of    STNil -> return ()    STPair t1 t2 -> _ ref idx val | 
