diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-15 11:32:34 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-15 11:32:34 +0100 |
commit | 6da98aedf2f28ec8848d1cb8f5605b0c7e64d644 (patch) | |
tree | c5723c920a200c001fd5f156f5f80f4a6eb11455 | |
parent | 095e7be937c2414cd34eb6288bd2c0856be63def (diff) |
Complete accumulator revamp!
-rw-r--r-- | src/AST.hs | 1 | ||||
-rw-r--r-- | src/CHAD.hs | 24 | ||||
-rw-r--r-- | src/Compile.hs | 6 | ||||
-rw-r--r-- | src/Interpreter.hs | 320 |
4 files changed, 92 insertions, 259 deletions
@@ -89,6 +89,7 @@ data Expr x env t where -- accumulation effect on monoids EWith :: x (TPair a (D2 t)) -> STy t -> Expr x env (D2 t) -> Expr x (TAccum t : env) a -> Expr x env (TPair a (D2 t)) + -- TODO: let this contain a OneHotTerm that is shared with EOneHot for uniformity in Simplify EAccum :: x TNil -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (TAccum t) -> Expr x env TNil -- monoidal operations (to be desugared to regular operations after simplification) diff --git a/src/CHAD.hs b/src/CHAD.hs index 4675434..d7d7da2 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -286,7 +286,7 @@ conv1Idx IZ = IZ conv1Idx (IS i) = IS (conv1Idx i) data Idx2 env sto t - = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) + = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum t)) | Idx2Me (Idx (Select env sto "merge") t) | Idx2Di (Idx (Select env sto "discr") t) @@ -409,11 +409,11 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = envpro prosub (\shbinds -> - autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) + autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum descr))) (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) (#pro :++: #d :++: #shb :++: #acc :++: #tl) .> WCopy (wf shbinds) - .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl))) + .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum storepl))) (#d :++: #shb :++: #acc :++: #tl) (#acc :++: (#d :++: #shb :++: #tl))) @@ -441,7 +441,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = -- goal: | ARE EQUAL || -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) WCopy (wf shbinds) - .> WPick @(TAccum (D2 t)) @(D2 dt : shbinds) (Const () `SCons` shbindsC) + .> WPick @(TAccum t) @(D2 dt : shbinds) (Const () `SCons` shbindsC) (WId @(D2AcE (Select env1 stoRepl "accum")))) -- Discrete values are left as-is, nothing to do @@ -568,7 +568,7 @@ drev des = \case SETop (EVar ext (d1 t) (conv1Idx i)) (subenvNone (select SMerge des)) - (EAccum ext SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI))) + (EAccum ext t SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum t) (IS accI))) Idx2Me tupI -> Ret BTop @@ -944,12 +944,10 @@ drev des = \case (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) sub - (ELet ext (EOneHot ext (STArr n eltty) n - (arrIdxToAcIdx (d2 eltty) n $ EVar ext tIxN (IS IZ)) - (case n of SZ -> EUnit ext (EVar ext (d2 eltty) IZ) - SS{} | Refl <- lemAcValArrN (d2 eltty) n -> - EPair ext (EVar ext tIxN (IS (IS IZ))) - (EUnit ext (EVar ext (d2 eltty) IZ)))) $ + (ELet ext (EOneHot ext (STArr n eltty) (SAPArrIdx SAPHere n) + (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) (EVar ext tIxN (IS (IS IZ)))) + (ENil ext)) + (EVar ext (d2 eltty) IZ)) $ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) EShape _ e @@ -1050,10 +1048,10 @@ drevScoped des argty argsto expr SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero ext argty)) SAccum -> RetScoped e0 subtape e1 sub $ - EWith ext (EZero ext argty) $ + EWith ext argty (EZero ext argty) $ weakenExpr (autoWeak (#d (auto1 @(D2 t)) &. #body (subList (bindingsBinds e0) subtape) - &. #ac (auto1 @(TAccum (D2 a))) + &. #ac (auto1 @(TAccum a)) &. #tl (d2ace (select SAccum des))) (#d :++: #body :++: #ac :++: #tl) (#ac :++: #d :++: #body :++: #tl)) diff --git a/src/Compile.hs b/src/Compile.hs index 424b28d..5c9d1a2 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -755,12 +755,10 @@ compile' env = \case emit $ SVarDecl True (repSTy t2) name2 e2' compile' (Const name2 `SCons` Const name1 `SCons` SNil) earg - EWith _ e1 e2 -> do - let t = typeOf e1 - + EWith _ t e1 e2 -> do e1' <- compile' env e1 name1 <- genName - emit $ SVarDecl True (repSTy t) name1 e1' + emit $ SVarDecl True (repSTy (typeOf e1)) name1 e1' mcopy <- copyForWriting t name1 accname <- genName' "accum" diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 11caac0..11184c9 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleContexts #-} @@ -19,12 +20,10 @@ module Interpreter ( Value(..), ) where -import Control.Monad (foldM, join, when) -import Data.Bifunctor (bimap) +import Control.Monad (foldM, join, when, forM_) import Data.Bitraversable (bitraverse) import Data.Char (isSpace) import Data.Functor.Identity -import Data.Kind (Type) import Data.Int (Int64) import Data.IORef import System.IO (hPutStrLn, stderr) @@ -64,8 +63,9 @@ interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value en interpret' env e = do let dep = ?depth let lenlimit = max 20 (100 - dep) - let trunc s | length s > lenlimit = take (lenlimit - 3) s ++ "..." - | otherwise = s + let replace a b = map (\c -> if c == a then b else c) + let trunc s | length s > lenlimit = take (lenlimit - 3) (replace '\n' ' ' s) ++ "..." + | otherwise = replace '\n' ' ' s when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr env e) res <- let ?depth = dep + 1 in interpret'Rec env e when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res "" @@ -243,7 +243,7 @@ onehotD2 (SAPArrIdx prj _) (STArr n a) idx val = withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t)) withAccum t _ initval f = AcM $ do accum <- newAcSparse t SAPHere () initval - out <- case f accum of AcM m -> m + out <- unAcM $ f accum val <- readAcSparse t accum return (out, val) @@ -262,59 +262,6 @@ newAcZero = \case STBool -> return () STAccum{} -> error "Nested accumulators" --- | Inverted index: the outermost index is at the /outside/ of this list. -data PartialInvIndex n m where - PIIxEnd :: PartialInvIndex m m - PIIxCons :: Int -> PartialInvIndex n m -> PartialInvIndex (S n) m - --- | Inverted shapey thing: the outermost dimension is at the /outside/ of this list. -data Inverted (f :: Nat -> Type) n where - InvNil :: Inverted f Z - InvCons :: Int -> Inverted f n -> Inverted f (S n) - -type InvShape = Inverted Shape -type InvIndex = Inverted Index - -class Shapey f where - shapeyNil :: f Z - shapeyCons :: f n -> Int -> f (S n) - shapeyCase :: f n -> (n ~ Z => r) -> (forall m. n ~ S m => f m -> Int -> r) -> r -instance Shapey Index where - shapeyNil = IxNil - shapeyCons = IxCons - shapeyCase IxNil k0 _ = k0 - shapeyCase (IxCons idx i) _ k1 = k1 idx i -instance Shapey Shape where - shapeyNil = ShNil - shapeyCons = ShCons - shapeyCase ShNil k0 _ = k0 - shapeyCase (ShCons sh n) _ k1 = k1 sh n - -invert :: forall f n. Shapey f => f n -> Inverted f n -invert | Refl <- lemPlusZero @n = flip go InvNil - where - go :: forall n' m. f n' -> Inverted f m -> Inverted f (n' + m) - go sh ish = shapeyCase sh - ish - (\sh' n -> case lemPlusSuccRight @n' @m of Refl -> go sh' (InvCons n ish)) - -uninvert :: forall f n. Shapey f => Inverted f n -> f n -uninvert = go shapeyNil - where - go :: forall n' m. f n' -> Inverted f m -> f (n' + m) - go sh InvNil | Refl <- lemPlusZero @n' = sh - go sh (InvCons n (ish :: Inverted f predm)) | Refl <- lemPlusSuccRight @n' @predm = go (shapeyCons sh n) ish - -piindexMatch :: PartialInvIndex n m -> InvIndex n -> Maybe (InvIndex m) -piindexMatch PIIxEnd ix = Just ix -piindexMatch (PIIxCons i pix) (InvCons i' ix) - | i == i' = piindexMatch pix ix - | otherwise = Nothing - -piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n -piindexConcat PIIxEnd ix = ix -piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix) - newAcSparse :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAc a) newAcSparse typ prj idx val = case (typ, prj) of (STNil, SAPHere) -> return () @@ -353,27 +300,8 @@ onehotArray :: Monad m onehotArray mkone mkzero n _ ((arrindex', arrsh'), idx) = let arrindex = unTupRepIdx IxNil IxCons n arrindex' arrsh = unTupRepIdx ShNil ShCons n arrsh' - in arrayGenerateM arrsh (\i -> if i == arrindex then mkone idx else mkzero) - --- newAcDense :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAcDense (D2 a)) --- newAcDense typ SZ () val = case typ of --- STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) --- STEither t1 t2 -> case val of --- Left x -> Left <$> newAcSparse t1 SZ () x --- Right y -> Right <$> newAcSparse t2 SZ () y --- _ -> error "newAcDense: invalid dense type" --- newAcDense typ (SS dep) idx val = case typ of --- STPair t1 t2 -> --- case (idx, val) of --- (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2 --- (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val' --- _ -> error "Index/value mismatch in newAc pair" --- STEither t1 t2 -> --- case (idx, val) of --- (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val' --- (Right idx', Right val') -> Right <$> newAcSparse t2 dep idx' val' --- _ -> error "Index/value mismatch in newAc either" --- _ -> error "newAcDense: invalid dense type" + !linindex = toLinearIndex arrsh arrindex + in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero) readAcSparse :: STy t -> RepAc t -> IO (Rep (D2 t)) readAcSparse typ val = case typ of @@ -390,14 +318,6 @@ readAcSparse typ val = case typ of STBool -> return () STAccum{} -> error "Nested accumulators" --- readAcDense :: STy t -> RepAcDense t -> IO (Rep t) --- readAcDense typ val = case typ of --- STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val) --- STEither t1 t2 -> case val of --- Left x -> Left <$> readAcSparse t1 x --- Right y -> Right <$> readAcSparse t2 y --- _ -> error "readAcDense: invalid dense type" - accumAddSparse :: STy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep (D2 b) -> AcM s () accumAddSparse typ prj ref idx val = case (typ, prj) of (STNil, SAPHere) -> return () @@ -406,26 +326,66 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of case val of Nothing -> return () Just (val1, val2) -> - AcM $ realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1 - <*> newAcSparse t2 SAPHere () val2) - (\(ac1, ac2) -> do unAcM $ accumAddSparse t1 SAPHere ac1 () val1 - unAcM $ accumAddSparse t2 SAPHere ac2 () val2) + realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1 + <*> newAcSparse t2 SAPHere () val2) + (\(ac1, ac2) -> do accumAddSparse t1 SAPHere ac1 () val1 + accumAddSparse t2 SAPHere ac2 () val2) (STPair t1 t2, SAPFst prj') -> - AcM $ realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2) - (\(ac1, _) -> do unAcM $ accumAddSparse t1 prj' ac1 idx val) + realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2) + (\(ac1, _) -> do accumAddSparse t1 prj' ac1 idx val) (STPair t1 t2, SAPSnd prj') -> - AcM $ realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val) - (\(_, ac2) -> do unAcM $ accumAddSparse t2 prj' ac2 idx val) - - (STEither t1 t2, SAPHere) -> _ ref val - (STEither t1 _, SAPLeft prj') -> _ ref idx val - (STEither _ t2, SAPRight prj') -> _ ref idx val - - (STMaybe t1, SAPHere) -> _ ref val - (STMaybe t1, SAPJust prj') -> _ ref idx val + realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val) + (\(_, ac2) -> do accumAddSparse t2 prj' ac2 idx val) - (STArr _ t1, SAPHere) -> _ ref val - (STArr n t, SAPArrIdx prj' _) -> _ ref idx val + (STEither{}, SAPHere) -> + case val of + Nothing -> return () + Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1 + Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2 + (STEither t1 _, SAPLeft prj') -> + realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val) + (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val + Right{} -> error "Mismatched Either in accumAddSparse (r +l)") + (STEither _ t2, SAPRight prj') -> + realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val) + (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val + Left{} -> error "Mismatched Either in accumAddSparse (l +r)") + + (STMaybe{}, SAPHere) -> + case val of + Nothing -> return () + Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val' + (STMaybe t1, SAPJust prj') -> + realiseMaybeSparse ref (newAcSparse t1 prj' idx val) + (\ac -> accumAddSparse t1 prj' ac idx val) + + (STArr _ t1, SAPHere) -> + let add ac = forM_ [0 .. arraySize ac - 1] $ \i -> + unAcM $ accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val i) + in if arraySize val == 0 + then return () + else AcM $ join $ atomicModifyIORef' ref $ \ac -> + if arraySize ac == 0 + then (ac, do newac <- arrayMapM (newAcSparse t1 SAPHere ()) val + join $ atomicModifyIORef' ref $ \ac' -> + if arraySize ac == 0 + then (newac, return ()) + else (ac', add ac')) + else (ac, add ac) + (STArr n t1, SAPArrIdx prj' _) -> + let ((arrindex', arrsh'), idx') = idx + arrindex = unTupRepIdx IxNil IxCons n arrindex' + arrsh = unTupRepIdx ShNil ShCons n arrsh' + linindex = toLinearIndex arrsh arrindex + add ac = unAcM $ accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val + in AcM $ join $ atomicModifyIORef' ref $ \ac -> + if arraySize ac == 0 + then (ac, do newac <- onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx + join $ atomicModifyIORef' ref $ \ac' -> + if arraySize ac == 0 + then (newac, return ()) + else (ac', add ac')) + else (ac, add ac) (STScal sty, SAPHere) -> AcM $ case sty of STI32 -> return () @@ -436,145 +396,21 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of (STAccum{}, _) -> error "Accumulators not allowed in source program" -realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> IO ()) -> IO () +realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s () realiseMaybeSparse ref makeval modifyval = -- Try modifying what's already in ref. The 'join' makes the snd -- of the function's return value a _continuation_ that is run after -- the critical section ends. - join $ atomicModifyIORef' ref $ \ac -> case ac of - -- Oops, ref's contents was still sparse. Have to initialise - -- it first, then try again. - Nothing -> (ac, do val <- makeval - join $ atomicModifyIORef' ref $ \ac' -> case ac' of - Nothing -> (Just val, return ()) - Just val' -> (ac', modifyval val')) - -- Yep, ref already had a value in there, so we can just add - -- val' to it recursively. - Just val -> (ac, modifyval val) - -{- -accumAddSparse typ SZ ref () val = case typ of - STNil -> return () - STPair t1 t2 -> AcM $ do - (r1, r2) <- readIORef ref - unAcM $ accumAddSparse t1 SZ r1 () (fst val) - unAcM $ accumAddSparse t2 SZ r2 () (snd val) - STMaybe t -> - case val of - Nothing -> return () - Just val' -> - -- Try adding val' to what's already in ref. The 'join' makes the snd - -- of the function's return value a _continuation_ that is run after - -- the critical section ends. - AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of - -- Oops, ref's contents was still sparse. Have to initialise - -- it first, then try again. - Nothing -> (ac, do newac <- newAcDense t SZ () val' - join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of - Nothing -> (Just newac, return ()) - Just ac2' -> bimap Just unAcM (accumAddDense t SZ ac2' () val')) - -- Yep, ref already had a value in there, so we can just add - -- val' to it recursively. - Just ac' -> bimap Just unAcM (accumAddDense t SZ ac' () val') - STArr _ t -> AcM $ do - refs <- readIORef ref - case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of - (_, 0) -> return () - (0, _) -> do - newrefarr <- traverse (newAcSparse t SZ ()) val - join $ atomicModifyIORef' ref $ \refarr -> - if shapeSize (arrayShape refarr) == 0 - then (newrefarr, return ()) - else -- someone was faster than us in initialising the reference! - (refarr, unAcM $ accumAddSparse typ SZ ref () val) -- just try again from the start (dropping newrefarr for the GC to clean up) - _ | arrayShape refs == arrayShape val -> - sequence_ [unAcM $ accumAddSparse t SZ (arrayIndexLinear refs i) () (arrayIndexLinear val i) - | i <- [0 .. shapeSize (arrayShape val) - 1]] - | otherwise -> error "Array shape mismatch in accum add" - STScal sty -> AcM $ case sty of - STI32 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STI64 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STF32 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STF64 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STBool -> error "Accumulator of Bool" - STAccum{} -> error "Nested accumulators" - STEither{} -> error "Bare Either in accumulator" - -accumAddSparse typ (SS dep) ref idx val = case typ of - STNil -> return () - STPair t1 t2 -> AcM $ do - (ref1, ref2) <- readIORef ref - case (idx, val) of - (Left idx', Left val') -> unAcM $ accumAddSparse t1 dep ref1 idx' val' - (Right idx', Right val') -> unAcM $ accumAddSparse t2 dep ref2 idx' val' - _ -> error "Index/value mismatch in pair accumulator add" - STMaybe t -> - AcM $ join $ atomicModifyIORef' ref $ \case - -- Oops, ref's contents was still sparse. Have to initialise - -- it first, then try again. - Nothing -> (Nothing, do newac <- newAcDense t dep idx val - join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of - Nothing -> (Just newac, return ()) - Just ac2' -> bimap Just unAcM (accumAddDense t dep ac2' idx val)) - -- Yep, ref already had a value in there, so we can just add - -- val' to it recursively. - Just ac -> bimap Just unAcM (accumAddDense t dep ac idx val) - STArr dim (t :: STy t) -> AcM $ do - refs <- readIORef ref - if shapeSize (arrayShape refs) == 0 - then do newrefarr <- newAcArray dim t (SS dep) idx val - join $ atomicModifyIORef' ref $ \refarr -> - if shapeSize (arrayShape refarr) == 0 - then (newrefarr, return ()) - else -- someone was faster than us in initialising the reference! - (refarr, unAcM $ accumAddSparse typ (SS dep) ref idx val) -- just try again from the start (dropping newrefarr for the GC to clean up) - else do let sh = unTupRepIdx ShNil ShCons dim (fst val) - go (SS dep) (invert sh) idx (snd val) - (\j index idxj valj -> unAcM $ accumAddSparse t j (refs `arrayIndex` index) idxj valj) - (\piix subsh val' -> unAcM $ sequence_ - [accumAddSparse t SZ (refs `arrayIndex` uninvert (piindexConcat piix (invert subix))) - () (val' `arrayIndex` subix) - | subix <- enumShape subsh]) - where - go :: SNat i -> InvShape n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) - -> (forall j. SNat j -> Index n -> Rep (AcIdx t j) -> Rep (AcVal t j) -> r) -- ^ Indexing into element of the array - -> (forall m. PartialInvIndex n m -> Shape m -> Rep (TArr m t) -> r) -- ^ Accumulating onto a subarray - -> r - go SZ ish () val' _ k0 = k0 PIIxEnd (uninvert ish) val' -- ^ Ran out of AcIdx: accumulating onto subarray - go (SS dep') InvNil idx' val' kj _ = kj dep' IxNil idx' val' -- ^ Ran out of array dimensions: accumulating into (part of) element - go (SS dep') (InvCons _ ish) (i, idx') val' kj k0 = - go dep' ish idx' val' - (\j index idxj valj -> kj j (IxCons index (fromIntegral @Int64 @Int i)) idxj valj) - (\pidxm shm valm -> k0 (PIIxCons (fromIntegral @Int64 @Int i) pidxm) shm valm) - STScal{} -> error "Cannot index into scalar" - STAccum{} -> error "Nested accumulators" - STEither{} -> error "Bare Either in accumulator" - -accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ()) -accumAddDense typ SZ ref () val = case typ of - STPair t1 t2 -> - (ref, do accumAddSparse t1 SZ (fst ref) () (fst val) - accumAddSparse t2 SZ (snd ref) () (snd val)) - STEither t1 t2 -> - case (ref, val) of - (Left ref', Left val') -> (ref, accumAddSparse t1 SZ ref' () val') - (Right ref', Right val') -> (ref, accumAddSparse t2 SZ ref' () val') - _ -> error "Mismatched Either in accumAddDense either" - _ -> error "accumAddDense: invalid dense type" - -accumAddDense typ (SS dep) ref idx val = case typ of - STPair t1 t2 -> - case (idx, val) of - (Left idx', Left val') -> (ref, accumAddSparse t1 dep (fst ref) idx' val') - (Right idx', Right val') -> (ref, accumAddSparse t2 dep (snd ref) idx' val') - _ -> error "Mismatched Either in accumAddDense pair" - STEither t1 t2 -> - case (ref, idx, val) of - (Left ref', Left idx', Left val') -> (Left ref', accumAddSparse t1 dep ref' idx' val') - (Right ref', Right idx', Right val') -> (Right ref', accumAddSparse t2 dep ref' idx' val') - _ -> error "Mismatched Either in accumAddDense either" - _ -> error "accumAddDense: invalid dense type" --} + AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of + -- Oops, ref's contents was still sparse. Have to initialise + -- it first, then try again. + Nothing -> (ac, do val <- makeval + join $ atomicModifyIORef' ref $ \ac' -> case ac' of + Nothing -> (Just val, return ()) + Just val' -> (ac', unAcM $ modifyval val')) + -- Yep, ref already had a value in there, so we can just add + -- val' to it recursively. + Just val -> (ac, unAcM $ modifyval val) numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r @@ -593,12 +429,12 @@ integralIsIntegral STI64 = id unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m)) -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n -unTupRepIdx nil _ SZ _ = nil +unTupRepIdx nil _ SZ _ = nil unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i tupRepIdx :: (forall m. f (S m) -> (f m, Int)) -> SNat n -> f n -> Rep (Tup (Replicate n TIx)) -tupRepIdx _ SZ _ = () +tupRepIdx _ SZ _ = () tupRepIdx uncons (SS n) tup = let (tup', i) = uncons tup in ((,) $! tupRepIdx uncons n tup') $! fromIntegral @Int @Int64 i |