summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-15 11:32:34 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-15 11:32:34 +0100
commit6da98aedf2f28ec8848d1cb8f5605b0c7e64d644 (patch)
treec5723c920a200c001fd5f156f5f80f4a6eb11455
parent095e7be937c2414cd34eb6288bd2c0856be63def (diff)
Complete accumulator revamp!
-rw-r--r--src/AST.hs1
-rw-r--r--src/CHAD.hs24
-rw-r--r--src/Compile.hs6
-rw-r--r--src/Interpreter.hs320
4 files changed, 92 insertions, 259 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 1cdd710..ecd4647 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -89,6 +89,7 @@ data Expr x env t where
-- accumulation effect on monoids
EWith :: x (TPair a (D2 t)) -> STy t -> Expr x env (D2 t) -> Expr x (TAccum t : env) a -> Expr x env (TPair a (D2 t))
+ -- TODO: let this contain a OneHotTerm that is shared with EOneHot for uniformity in Simplify
EAccum :: x TNil -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (TAccum t) -> Expr x env TNil
-- monoidal operations (to be desugared to regular operations after simplification)
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 4675434..d7d7da2 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -286,7 +286,7 @@ conv1Idx IZ = IZ
conv1Idx (IS i) = IS (conv1Idx i)
data Idx2 env sto t
- = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t)))
+ = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum t))
| Idx2Me (Idx (Select env sto "merge") t)
| Idx2Di (Idx (Select env sto "discr") t)
@@ -409,11 +409,11 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =
envpro
prosub
(\shbinds ->
- autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr)))
+ autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum descr)))
(#acc :++: (#pro :++: #d :++: #shb :++: #tl))
(#pro :++: #d :++: #shb :++: #acc :++: #tl)
.> WCopy (wf shbinds)
- .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl)))
+ .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum storepl)))
(#d :++: #shb :++: #acc :++: #tl)
(#acc :++: (#d :++: #shb :++: #tl)))
@@ -441,7 +441,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =
-- goal: | ARE EQUAL ||
-- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
WCopy (wf shbinds)
- .> WPick @(TAccum (D2 t)) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
+ .> WPick @(TAccum t) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
(WId @(D2AcE (Select env1 stoRepl "accum"))))
-- Discrete values are left as-is, nothing to do
@@ -568,7 +568,7 @@ drev des = \case
SETop
(EVar ext (d1 t) (conv1Idx i))
(subenvNone (select SMerge des))
- (EAccum ext SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI)))
+ (EAccum ext t SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum t) (IS accI)))
Idx2Me tupI ->
Ret BTop
@@ -944,12 +944,10 @@ drev des = \case
(EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
(EVar ext (tTup (sreplicate n tIx)) IZ))
sub
- (ELet ext (EOneHot ext (STArr n eltty) n
- (arrIdxToAcIdx (d2 eltty) n $ EVar ext tIxN (IS IZ))
- (case n of SZ -> EUnit ext (EVar ext (d2 eltty) IZ)
- SS{} | Refl <- lemAcValArrN (d2 eltty) n ->
- EPair ext (EVar ext tIxN (IS (IS IZ)))
- (EUnit ext (EVar ext (d2 eltty) IZ)))) $
+ (ELet ext (EOneHot ext (STArr n eltty) (SAPArrIdx SAPHere n)
+ (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) (EVar ext tIxN (IS (IS IZ))))
+ (ENil ext))
+ (EVar ext (d2 eltty) IZ)) $
weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
EShape _ e
@@ -1050,10 +1048,10 @@ drevScoped des argty argsto expr
SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero ext argty))
SAccum ->
RetScoped e0 subtape e1 sub $
- EWith ext (EZero ext argty) $
+ EWith ext argty (EZero ext argty) $
weakenExpr (autoWeak (#d (auto1 @(D2 t))
&. #body (subList (bindingsBinds e0) subtape)
- &. #ac (auto1 @(TAccum (D2 a)))
+ &. #ac (auto1 @(TAccum a))
&. #tl (d2ace (select SAccum des)))
(#d :++: #body :++: #ac :++: #tl)
(#ac :++: #d :++: #body :++: #tl))
diff --git a/src/Compile.hs b/src/Compile.hs
index 424b28d..5c9d1a2 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -755,12 +755,10 @@ compile' env = \case
emit $ SVarDecl True (repSTy t2) name2 e2'
compile' (Const name2 `SCons` Const name1 `SCons` SNil) earg
- EWith _ e1 e2 -> do
- let t = typeOf e1
-
+ EWith _ t e1 e2 -> do
e1' <- compile' env e1
name1 <- genName
- emit $ SVarDecl True (repSTy t) name1 e1'
+ emit $ SVarDecl True (repSTy (typeOf e1)) name1 e1'
mcopy <- copyForWriting t name1
accname <- genName' "accum"
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 11caac0..11184c9 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
@@ -19,12 +20,10 @@ module Interpreter (
Value(..),
) where
-import Control.Monad (foldM, join, when)
-import Data.Bifunctor (bimap)
+import Control.Monad (foldM, join, when, forM_)
import Data.Bitraversable (bitraverse)
import Data.Char (isSpace)
import Data.Functor.Identity
-import Data.Kind (Type)
import Data.Int (Int64)
import Data.IORef
import System.IO (hPutStrLn, stderr)
@@ -64,8 +63,9 @@ interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value en
interpret' env e = do
let dep = ?depth
let lenlimit = max 20 (100 - dep)
- let trunc s | length s > lenlimit = take (lenlimit - 3) s ++ "..."
- | otherwise = s
+ let replace a b = map (\c -> if c == a then b else c)
+ let trunc s | length s > lenlimit = take (lenlimit - 3) (replace '\n' ' ' s) ++ "..."
+ | otherwise = replace '\n' ' ' s
when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr env e)
res <- let ?depth = dep + 1 in interpret'Rec env e
when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res ""
@@ -243,7 +243,7 @@ onehotD2 (SAPArrIdx prj _) (STArr n a) idx val =
withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t))
withAccum t _ initval f = AcM $ do
accum <- newAcSparse t SAPHere () initval
- out <- case f accum of AcM m -> m
+ out <- unAcM $ f accum
val <- readAcSparse t accum
return (out, val)
@@ -262,59 +262,6 @@ newAcZero = \case
STBool -> return ()
STAccum{} -> error "Nested accumulators"
--- | Inverted index: the outermost index is at the /outside/ of this list.
-data PartialInvIndex n m where
- PIIxEnd :: PartialInvIndex m m
- PIIxCons :: Int -> PartialInvIndex n m -> PartialInvIndex (S n) m
-
--- | Inverted shapey thing: the outermost dimension is at the /outside/ of this list.
-data Inverted (f :: Nat -> Type) n where
- InvNil :: Inverted f Z
- InvCons :: Int -> Inverted f n -> Inverted f (S n)
-
-type InvShape = Inverted Shape
-type InvIndex = Inverted Index
-
-class Shapey f where
- shapeyNil :: f Z
- shapeyCons :: f n -> Int -> f (S n)
- shapeyCase :: f n -> (n ~ Z => r) -> (forall m. n ~ S m => f m -> Int -> r) -> r
-instance Shapey Index where
- shapeyNil = IxNil
- shapeyCons = IxCons
- shapeyCase IxNil k0 _ = k0
- shapeyCase (IxCons idx i) _ k1 = k1 idx i
-instance Shapey Shape where
- shapeyNil = ShNil
- shapeyCons = ShCons
- shapeyCase ShNil k0 _ = k0
- shapeyCase (ShCons sh n) _ k1 = k1 sh n
-
-invert :: forall f n. Shapey f => f n -> Inverted f n
-invert | Refl <- lemPlusZero @n = flip go InvNil
- where
- go :: forall n' m. f n' -> Inverted f m -> Inverted f (n' + m)
- go sh ish = shapeyCase sh
- ish
- (\sh' n -> case lemPlusSuccRight @n' @m of Refl -> go sh' (InvCons n ish))
-
-uninvert :: forall f n. Shapey f => Inverted f n -> f n
-uninvert = go shapeyNil
- where
- go :: forall n' m. f n' -> Inverted f m -> f (n' + m)
- go sh InvNil | Refl <- lemPlusZero @n' = sh
- go sh (InvCons n (ish :: Inverted f predm)) | Refl <- lemPlusSuccRight @n' @predm = go (shapeyCons sh n) ish
-
-piindexMatch :: PartialInvIndex n m -> InvIndex n -> Maybe (InvIndex m)
-piindexMatch PIIxEnd ix = Just ix
-piindexMatch (PIIxCons i pix) (InvCons i' ix)
- | i == i' = piindexMatch pix ix
- | otherwise = Nothing
-
-piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n
-piindexConcat PIIxEnd ix = ix
-piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix)
-
newAcSparse :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAc a)
newAcSparse typ prj idx val = case (typ, prj) of
(STNil, SAPHere) -> return ()
@@ -353,27 +300,8 @@ onehotArray :: Monad m
onehotArray mkone mkzero n _ ((arrindex', arrsh'), idx) =
let arrindex = unTupRepIdx IxNil IxCons n arrindex'
arrsh = unTupRepIdx ShNil ShCons n arrsh'
- in arrayGenerateM arrsh (\i -> if i == arrindex then mkone idx else mkzero)
-
--- newAcDense :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAcDense (D2 a))
--- newAcDense typ SZ () val = case typ of
--- STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
--- STEither t1 t2 -> case val of
--- Left x -> Left <$> newAcSparse t1 SZ () x
--- Right y -> Right <$> newAcSparse t2 SZ () y
--- _ -> error "newAcDense: invalid dense type"
--- newAcDense typ (SS dep) idx val = case typ of
--- STPair t1 t2 ->
--- case (idx, val) of
--- (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2
--- (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val'
--- _ -> error "Index/value mismatch in newAc pair"
--- STEither t1 t2 ->
--- case (idx, val) of
--- (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val'
--- (Right idx', Right val') -> Right <$> newAcSparse t2 dep idx' val'
--- _ -> error "Index/value mismatch in newAc either"
--- _ -> error "newAcDense: invalid dense type"
+ !linindex = toLinearIndex arrsh arrindex
+ in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero)
readAcSparse :: STy t -> RepAc t -> IO (Rep (D2 t))
readAcSparse typ val = case typ of
@@ -390,14 +318,6 @@ readAcSparse typ val = case typ of
STBool -> return ()
STAccum{} -> error "Nested accumulators"
--- readAcDense :: STy t -> RepAcDense t -> IO (Rep t)
--- readAcDense typ val = case typ of
--- STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val)
--- STEither t1 t2 -> case val of
--- Left x -> Left <$> readAcSparse t1 x
--- Right y -> Right <$> readAcSparse t2 y
--- _ -> error "readAcDense: invalid dense type"
-
accumAddSparse :: STy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep (D2 b) -> AcM s ()
accumAddSparse typ prj ref idx val = case (typ, prj) of
(STNil, SAPHere) -> return ()
@@ -406,26 +326,66 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of
case val of
Nothing -> return ()
Just (val1, val2) ->
- AcM $ realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1
- <*> newAcSparse t2 SAPHere () val2)
- (\(ac1, ac2) -> do unAcM $ accumAddSparse t1 SAPHere ac1 () val1
- unAcM $ accumAddSparse t2 SAPHere ac2 () val2)
+ realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1
+ <*> newAcSparse t2 SAPHere () val2)
+ (\(ac1, ac2) -> do accumAddSparse t1 SAPHere ac1 () val1
+ accumAddSparse t2 SAPHere ac2 () val2)
(STPair t1 t2, SAPFst prj') ->
- AcM $ realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2)
- (\(ac1, _) -> do unAcM $ accumAddSparse t1 prj' ac1 idx val)
+ realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2)
+ (\(ac1, _) -> do accumAddSparse t1 prj' ac1 idx val)
(STPair t1 t2, SAPSnd prj') ->
- AcM $ realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val)
- (\(_, ac2) -> do unAcM $ accumAddSparse t2 prj' ac2 idx val)
-
- (STEither t1 t2, SAPHere) -> _ ref val
- (STEither t1 _, SAPLeft prj') -> _ ref idx val
- (STEither _ t2, SAPRight prj') -> _ ref idx val
-
- (STMaybe t1, SAPHere) -> _ ref val
- (STMaybe t1, SAPJust prj') -> _ ref idx val
+ realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val)
+ (\(_, ac2) -> do accumAddSparse t2 prj' ac2 idx val)
- (STArr _ t1, SAPHere) -> _ ref val
- (STArr n t, SAPArrIdx prj' _) -> _ ref idx val
+ (STEither{}, SAPHere) ->
+ case val of
+ Nothing -> return ()
+ Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1
+ Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2
+ (STEither t1 _, SAPLeft prj') ->
+ realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val)
+ (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val
+ Right{} -> error "Mismatched Either in accumAddSparse (r +l)")
+ (STEither _ t2, SAPRight prj') ->
+ realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val)
+ (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val
+ Left{} -> error "Mismatched Either in accumAddSparse (l +r)")
+
+ (STMaybe{}, SAPHere) ->
+ case val of
+ Nothing -> return ()
+ Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val'
+ (STMaybe t1, SAPJust prj') ->
+ realiseMaybeSparse ref (newAcSparse t1 prj' idx val)
+ (\ac -> accumAddSparse t1 prj' ac idx val)
+
+ (STArr _ t1, SAPHere) ->
+ let add ac = forM_ [0 .. arraySize ac - 1] $ \i ->
+ unAcM $ accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val i)
+ in if arraySize val == 0
+ then return ()
+ else AcM $ join $ atomicModifyIORef' ref $ \ac ->
+ if arraySize ac == 0
+ then (ac, do newac <- arrayMapM (newAcSparse t1 SAPHere ()) val
+ join $ atomicModifyIORef' ref $ \ac' ->
+ if arraySize ac == 0
+ then (newac, return ())
+ else (ac', add ac'))
+ else (ac, add ac)
+ (STArr n t1, SAPArrIdx prj' _) ->
+ let ((arrindex', arrsh'), idx') = idx
+ arrindex = unTupRepIdx IxNil IxCons n arrindex'
+ arrsh = unTupRepIdx ShNil ShCons n arrsh'
+ linindex = toLinearIndex arrsh arrindex
+ add ac = unAcM $ accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val
+ in AcM $ join $ atomicModifyIORef' ref $ \ac ->
+ if arraySize ac == 0
+ then (ac, do newac <- onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx
+ join $ atomicModifyIORef' ref $ \ac' ->
+ if arraySize ac == 0
+ then (newac, return ())
+ else (ac', add ac'))
+ else (ac, add ac)
(STScal sty, SAPHere) -> AcM $ case sty of
STI32 -> return ()
@@ -436,145 +396,21 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of
(STAccum{}, _) -> error "Accumulators not allowed in source program"
-realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> IO ()) -> IO ()
+realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s ()
realiseMaybeSparse ref makeval modifyval =
-- Try modifying what's already in ref. The 'join' makes the snd
-- of the function's return value a _continuation_ that is run after
-- the critical section ends.
- join $ atomicModifyIORef' ref $ \ac -> case ac of
- -- Oops, ref's contents was still sparse. Have to initialise
- -- it first, then try again.
- Nothing -> (ac, do val <- makeval
- join $ atomicModifyIORef' ref $ \ac' -> case ac' of
- Nothing -> (Just val, return ())
- Just val' -> (ac', modifyval val'))
- -- Yep, ref already had a value in there, so we can just add
- -- val' to it recursively.
- Just val -> (ac, modifyval val)
-
-{-
-accumAddSparse typ SZ ref () val = case typ of
- STNil -> return ()
- STPair t1 t2 -> AcM $ do
- (r1, r2) <- readIORef ref
- unAcM $ accumAddSparse t1 SZ r1 () (fst val)
- unAcM $ accumAddSparse t2 SZ r2 () (snd val)
- STMaybe t ->
- case val of
- Nothing -> return ()
- Just val' ->
- -- Try adding val' to what's already in ref. The 'join' makes the snd
- -- of the function's return value a _continuation_ that is run after
- -- the critical section ends.
- AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of
- -- Oops, ref's contents was still sparse. Have to initialise
- -- it first, then try again.
- Nothing -> (ac, do newac <- newAcDense t SZ () val'
- join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of
- Nothing -> (Just newac, return ())
- Just ac2' -> bimap Just unAcM (accumAddDense t SZ ac2' () val'))
- -- Yep, ref already had a value in there, so we can just add
- -- val' to it recursively.
- Just ac' -> bimap Just unAcM (accumAddDense t SZ ac' () val')
- STArr _ t -> AcM $ do
- refs <- readIORef ref
- case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of
- (_, 0) -> return ()
- (0, _) -> do
- newrefarr <- traverse (newAcSparse t SZ ()) val
- join $ atomicModifyIORef' ref $ \refarr ->
- if shapeSize (arrayShape refarr) == 0
- then (newrefarr, return ())
- else -- someone was faster than us in initialising the reference!
- (refarr, unAcM $ accumAddSparse typ SZ ref () val) -- just try again from the start (dropping newrefarr for the GC to clean up)
- _ | arrayShape refs == arrayShape val ->
- sequence_ [unAcM $ accumAddSparse t SZ (arrayIndexLinear refs i) () (arrayIndexLinear val i)
- | i <- [0 .. shapeSize (arrayShape val) - 1]]
- | otherwise -> error "Array shape mismatch in accum add"
- STScal sty -> AcM $ case sty of
- STI32 -> atomicModifyIORef' ref (\x -> (x + val, ()))
- STI64 -> atomicModifyIORef' ref (\x -> (x + val, ()))
- STF32 -> atomicModifyIORef' ref (\x -> (x + val, ()))
- STF64 -> atomicModifyIORef' ref (\x -> (x + val, ()))
- STBool -> error "Accumulator of Bool"
- STAccum{} -> error "Nested accumulators"
- STEither{} -> error "Bare Either in accumulator"
-
-accumAddSparse typ (SS dep) ref idx val = case typ of
- STNil -> return ()
- STPair t1 t2 -> AcM $ do
- (ref1, ref2) <- readIORef ref
- case (idx, val) of
- (Left idx', Left val') -> unAcM $ accumAddSparse t1 dep ref1 idx' val'
- (Right idx', Right val') -> unAcM $ accumAddSparse t2 dep ref2 idx' val'
- _ -> error "Index/value mismatch in pair accumulator add"
- STMaybe t ->
- AcM $ join $ atomicModifyIORef' ref $ \case
- -- Oops, ref's contents was still sparse. Have to initialise
- -- it first, then try again.
- Nothing -> (Nothing, do newac <- newAcDense t dep idx val
- join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of
- Nothing -> (Just newac, return ())
- Just ac2' -> bimap Just unAcM (accumAddDense t dep ac2' idx val))
- -- Yep, ref already had a value in there, so we can just add
- -- val' to it recursively.
- Just ac -> bimap Just unAcM (accumAddDense t dep ac idx val)
- STArr dim (t :: STy t) -> AcM $ do
- refs <- readIORef ref
- if shapeSize (arrayShape refs) == 0
- then do newrefarr <- newAcArray dim t (SS dep) idx val
- join $ atomicModifyIORef' ref $ \refarr ->
- if shapeSize (arrayShape refarr) == 0
- then (newrefarr, return ())
- else -- someone was faster than us in initialising the reference!
- (refarr, unAcM $ accumAddSparse typ (SS dep) ref idx val) -- just try again from the start (dropping newrefarr for the GC to clean up)
- else do let sh = unTupRepIdx ShNil ShCons dim (fst val)
- go (SS dep) (invert sh) idx (snd val)
- (\j index idxj valj -> unAcM $ accumAddSparse t j (refs `arrayIndex` index) idxj valj)
- (\piix subsh val' -> unAcM $ sequence_
- [accumAddSparse t SZ (refs `arrayIndex` uninvert (piindexConcat piix (invert subix)))
- () (val' `arrayIndex` subix)
- | subix <- enumShape subsh])
- where
- go :: SNat i -> InvShape n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i)
- -> (forall j. SNat j -> Index n -> Rep (AcIdx t j) -> Rep (AcVal t j) -> r) -- ^ Indexing into element of the array
- -> (forall m. PartialInvIndex n m -> Shape m -> Rep (TArr m t) -> r) -- ^ Accumulating onto a subarray
- -> r
- go SZ ish () val' _ k0 = k0 PIIxEnd (uninvert ish) val' -- ^ Ran out of AcIdx: accumulating onto subarray
- go (SS dep') InvNil idx' val' kj _ = kj dep' IxNil idx' val' -- ^ Ran out of array dimensions: accumulating into (part of) element
- go (SS dep') (InvCons _ ish) (i, idx') val' kj k0 =
- go dep' ish idx' val'
- (\j index idxj valj -> kj j (IxCons index (fromIntegral @Int64 @Int i)) idxj valj)
- (\pidxm shm valm -> k0 (PIIxCons (fromIntegral @Int64 @Int i) pidxm) shm valm)
- STScal{} -> error "Cannot index into scalar"
- STAccum{} -> error "Nested accumulators"
- STEither{} -> error "Bare Either in accumulator"
-
-accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ())
-accumAddDense typ SZ ref () val = case typ of
- STPair t1 t2 ->
- (ref, do accumAddSparse t1 SZ (fst ref) () (fst val)
- accumAddSparse t2 SZ (snd ref) () (snd val))
- STEither t1 t2 ->
- case (ref, val) of
- (Left ref', Left val') -> (ref, accumAddSparse t1 SZ ref' () val')
- (Right ref', Right val') -> (ref, accumAddSparse t2 SZ ref' () val')
- _ -> error "Mismatched Either in accumAddDense either"
- _ -> error "accumAddDense: invalid dense type"
-
-accumAddDense typ (SS dep) ref idx val = case typ of
- STPair t1 t2 ->
- case (idx, val) of
- (Left idx', Left val') -> (ref, accumAddSparse t1 dep (fst ref) idx' val')
- (Right idx', Right val') -> (ref, accumAddSparse t2 dep (snd ref) idx' val')
- _ -> error "Mismatched Either in accumAddDense pair"
- STEither t1 t2 ->
- case (ref, idx, val) of
- (Left ref', Left idx', Left val') -> (Left ref', accumAddSparse t1 dep ref' idx' val')
- (Right ref', Right idx', Right val') -> (Right ref', accumAddSparse t2 dep ref' idx' val')
- _ -> error "Mismatched Either in accumAddDense either"
- _ -> error "accumAddDense: invalid dense type"
--}
+ AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of
+ -- Oops, ref's contents was still sparse. Have to initialise
+ -- it first, then try again.
+ Nothing -> (ac, do val <- makeval
+ join $ atomicModifyIORef' ref $ \ac' -> case ac' of
+ Nothing -> (Just val, return ())
+ Just val' -> (ac', unAcM $ modifyval val'))
+ -- Yep, ref already had a value in there, so we can just add
+ -- val' to it recursively.
+ Just val -> (ac, unAcM $ modifyval val)
numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
@@ -593,12 +429,12 @@ integralIsIntegral STI64 = id
unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m))
-> SNat n -> Rep (Tup (Replicate n TIx)) -> f n
-unTupRepIdx nil _ SZ _ = nil
+unTupRepIdx nil _ SZ _ = nil
unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i
tupRepIdx :: (forall m. f (S m) -> (f m, Int))
-> SNat n -> f n -> Rep (Tup (Replicate n TIx))
-tupRepIdx _ SZ _ = ()
+tupRepIdx _ SZ _ = ()
tupRepIdx uncons (SS n) tup =
let (tup', i) = uncons tup
in ((,) $! tupRepIdx uncons n tup') $! fromIntegral @Int @Int64 i