summaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r--src/Interpreter.hs129
1 files changed, 124 insertions, 5 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index d2b8074..8fb4a78 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -213,7 +213,7 @@ pattern IIxCons :: () => S n ~ succn => Int -> InvIndex n -> InvIndex succn
pattern IIxCons i ix = InvCons i ix
{-# COMPLETE IIxNil, IIxCons #-}
-pattern IShNil :: () => n ~ Z => InvShape Z
+pattern IShNil :: () => n ~ Z => InvShape n
pattern IShNil = InvNil
pattern IShCons :: () => S n ~ succn => Int -> InvShape n -> InvShape succn
pattern IShCons n sh = InvCons n sh
@@ -234,8 +234,15 @@ instance Shapey Shape where
shapeyCase ShNil k0 _ = k0
shapeyCase (ShCons sh n) _ k1 = k1 sh n
+enumInvShape :: InvShape n -> [InvIndex n]
+enumInvShape IShNil = [IIxNil]
+enumInvShape (n `IShCons` sh) = [i `IIxCons` ix | i <- [0 .. n - 1], ix <- enumInvShape sh]
+
+enumShape :: Shape n -> [Index n]
+enumShape = map uninvert . enumInvShape . invert
+
invert :: forall f n. Shapey f => f n -> Inverted f n
-invert | Refl <- lemPlusZero @n = flip go shapeyNil
+invert | Refl <- lemPlusZero @n = flip go InvNil
where
go :: forall n' m. f n' -> Inverted f m -> Inverted f (n' + m)
go sh ish = shapeyCase sh
@@ -255,6 +262,10 @@ piindexMatch (PIIxCons i pix) (IIxCons i' ix)
| i == i' = piindexMatch pix ix
| otherwise = Nothing
+piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n
+piindexConcat PIIxEnd ix = ix
+piindexConcat (PIIxCons i pix) ix = IIxCons i (piindexConcat pix ix)
+
newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)
newAcSparse typ SZ () val = case typ of
STNil -> return ()
@@ -358,7 +369,7 @@ accumAddSparse typ SZ ref () val = case typ of
case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of
(_, 0) -> return ()
(0, _) -> do
- newrefarr <- arrayGenerateLinM (arrayShape val) (\i -> newAcSparse t SZ () (arrayIndexLinear val i))
+ newrefarr <- newAcDense typ SZ () val
join $ atomicModifyIORef' ref $ \refarr ->
if shapeSize (arrayShape refarr) == 0
then (newrefarr, return ())
@@ -396,13 +407,121 @@ accumAddSparse typ (SS dep) ref idx val = case typ of
-- Yep, ref already had a value in there, so we can just add
-- val' to it recursively.
Just ac -> bimap Just unAcM (accumAddDense t dep ac idx val)
- STArr _ t -> _ ref idx val
+ STArr dim (t :: STy t) -> AcM $ do
+ refs <- readIORef ref
+ if shapeSize (arrayShape refs) == 0
+ then do newrefarr <- newAcDense typ (SS dep) idx val
+ join $ atomicModifyIORef' ref $ \refarr ->
+ if shapeSize (arrayShape refarr) == 0
+ then (newrefarr, return ())
+ else -- someone was faster than us in initialising the reference!
+ (refarr, unAcM $ accumAddSparse typ (SS dep) ref idx val) -- just try again from the start (dropping newrefarr for the GC to clean up)
+ else do let sh = unTupRepIdx ShNil ShCons dim (fst val)
+ go (SS dep) (invert sh) idx (snd val)
+ (\j index idxj valj -> unAcM $ accumAddSparse t j (refs `arrayIndex` index) idxj valj)
+ (\piix subsh val' -> unAcM $ sequence_
+ [accumAddSparse t SZ (refs `arrayIndex` uninvert (piindexConcat piix (invert subix)))
+ () (val' `arrayIndex` subix)
+ | subix <- enumShape subsh])
+ where
+ go :: SNat i -> InvShape n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i)
+ -> (forall j. SNat j -> Index n -> Rep (AcIdx t j) -> Rep (AcVal t j) -> r) -- ^ Indexing into element of the array
+ -> (forall m. PartialInvIndex n m -> Shape m -> Rep (TArr m t) -> r) -- ^ Accumulating onto a subarray
+ -> r
+ go SZ ish () val' _ k0 = k0 PIIxEnd (uninvert ish) val' -- ^ Ran out of AcIdx: accumulating onto subarray
+ go (SS dep') IShNil idx' val' kj _ = kj dep' IxNil idx' val' -- ^ Ran out of array dimensions: accumulating into (part of) element
+ go (SS dep') (IShCons _ ish) (i, idx') val' kj k0 =
+ go dep' ish idx' val'
+ (\j index idxj valj -> kj j (IxCons index (fromIntegral @Int64 @Int i)) idxj valj)
+ (\pidxm shm valm -> k0 (PIIxCons (fromIntegral @Int64 @Int i) pidxm) shm valm)
STScal{} -> error "Cannot index into scalar"
STAccum{} -> error "Nested accumulators"
STEither{} -> error "Bare Either in accumulator"
accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ())
-accumAddDense = _
+accumAddDense typ SZ ref () val = case typ of
+ STNil -> ((), return ())
+ STPair t1 t2 ->
+ (ref, do accumAddSparse t1 SZ (fst ref) () (fst val)
+ accumAddSparse t2 SZ (snd ref) () (snd val))
+ STMaybe t ->
+ case (ref, val) of
+ (_, Nothing) -> (ref, return ())
+ (Nothing, Just val') -> _ val'
+ (Just ref', Just val') -> _ ref' val'
+ STArr _ t -> AcM $ do
+ refs <- readIORef ref
+ case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of
+ (_, 0) -> return ()
+ (0, _) -> do
+ newrefarr <- newAcDense typ SZ () val
+ join $ atomicModifyIORef' ref $ \refarr ->
+ if shapeSize (arrayShape refarr) == 0
+ then (newrefarr, return ())
+ else -- someone was faster than us in initialising the reference!
+ (refarr, unAcM $ accumAddSparse typ SZ ref () val) -- just try again from the start (dropping newrefarr for the GC to clean up)
+ _ | arrayShape refs == arrayShape val ->
+ sequence_ [unAcM $ accumAddSparse t SZ (arrayIndexLinear refs i) () (arrayIndexLinear val i)
+ | i <- [0 .. shapeSize (arrayShape val) - 1]]
+ | otherwise -> error "Array shape mismatch in accum add"
+ STScal sty -> AcM $ case sty of
+ STI32 -> atomicModifyIORef' ref (\x -> (x + val, ()))
+ STI64 -> atomicModifyIORef' ref (\x -> (x + val, ()))
+ STF32 -> atomicModifyIORef' ref (\x -> (x + val, ()))
+ STF64 -> atomicModifyIORef' ref (\x -> (x + val, ()))
+ STBool -> error "Accumulator of Bool"
+ STAccum{} -> error "Nested accumulators"
+ STEither{} -> error "Bare Either in accumulator"
+
+accumAddDense typ (SS dep) ref idx val = case typ of
+ STNil -> return ()
+ STPair t1 t2 -> AcM $ do
+ (ref1, ref2) <- readIORef ref
+ case (idx, val) of
+ (Left idx', Left val') -> unAcM $ accumAddSparse t1 dep ref1 idx' val'
+ (Right idx', Right val') -> unAcM $ accumAddSparse t2 dep ref2 idx' val'
+ _ -> error "Index/value mismatch in pair accumulator add"
+ STMaybe t ->
+ AcM $ join $ atomicModifyIORef' ref $ \case
+ -- Oops, ref's contents was still sparse. Have to initialise
+ -- it first, then try again.
+ Nothing -> (Nothing, do newac <- newAcDense t dep idx val
+ join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of
+ Nothing -> (Just newac, return ())
+ Just ac2' -> bimap Just unAcM (accumAddDense t dep ac2' idx val))
+ -- Yep, ref already had a value in there, so we can just add
+ -- val' to it recursively.
+ Just ac -> bimap Just unAcM (accumAddDense t dep ac idx val)
+ STArr dim (t :: STy t) -> AcM $ do
+ refs <- readIORef ref
+ if shapeSize (arrayShape refs) == 0
+ then do newrefarr <- newAcDense typ (SS dep) idx val
+ join $ atomicModifyIORef' ref $ \refarr ->
+ if shapeSize (arrayShape refarr) == 0
+ then (newrefarr, return ())
+ else -- someone was faster than us in initialising the reference!
+ (refarr, unAcM $ accumAddSparse typ (SS dep) ref idx val) -- just try again from the start (dropping newrefarr for the GC to clean up)
+ else do let sh = unTupRepIdx ShNil ShCons dim (fst val)
+ go (SS dep) (invert sh) idx (snd val)
+ (\j index idxj valj -> unAcM $ accumAddSparse t j (refs `arrayIndex` index) idxj valj)
+ (\piix subsh val' -> unAcM $ sequence_
+ [accumAddSparse t SZ (refs `arrayIndex` uninvert (piindexConcat piix (invert subix)))
+ () (val' `arrayIndex` subix)
+ | subix <- enumShape subsh])
+ where
+ go :: SNat i -> InvShape n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i)
+ -> (forall j. SNat j -> Index n -> Rep (AcIdx t j) -> Rep (AcVal t j) -> r) -- ^ Indexing into element of the array
+ -> (forall m. PartialInvIndex n m -> Shape m -> Rep (TArr m t) -> r) -- ^ Accumulating onto a subarray
+ -> r
+ go SZ ish () val' _ k0 = k0 PIIxEnd (uninvert ish) val' -- ^ Ran out of AcIdx: accumulating onto subarray
+ go (SS dep') IShNil idx' val' kj _ = kj dep' IxNil idx' val' -- ^ Ran out of array dimensions: accumulating into (part of) element
+ go (SS dep') (IShCons _ ish) (i, idx') val' kj k0 =
+ go dep' ish idx' val'
+ (\j index idxj valj -> kj j (IxCons index (fromIntegral @Int64 @Int i)) idxj valj)
+ (\pidxm shm valm -> k0 (PIIxCons (fromIntegral @Int64 @Int i) pidxm) shm valm)
+ STScal{} -> error "Cannot index into scalar"
+ STAccum{} -> error "Nested accumulators"
+ STEither{} -> error "Bare Either in accumulator"
-- accumAddVal :: forall t i s. STy t -> SNat i -> RepAc t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAc t, AcM s ())
-- accumAddVal typ SZ ac () val = case typ of