diff options
-rw-r--r-- | src/Interpreter.hs | 129 |
1 files changed, 124 insertions, 5 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index d2b8074..8fb4a78 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -213,7 +213,7 @@ pattern IIxCons :: () => S n ~ succn => Int -> InvIndex n -> InvIndex succn pattern IIxCons i ix = InvCons i ix {-# COMPLETE IIxNil, IIxCons #-} -pattern IShNil :: () => n ~ Z => InvShape Z +pattern IShNil :: () => n ~ Z => InvShape n pattern IShNil = InvNil pattern IShCons :: () => S n ~ succn => Int -> InvShape n -> InvShape succn pattern IShCons n sh = InvCons n sh @@ -234,8 +234,15 @@ instance Shapey Shape where shapeyCase ShNil k0 _ = k0 shapeyCase (ShCons sh n) _ k1 = k1 sh n +enumInvShape :: InvShape n -> [InvIndex n] +enumInvShape IShNil = [IIxNil] +enumInvShape (n `IShCons` sh) = [i `IIxCons` ix | i <- [0 .. n - 1], ix <- enumInvShape sh] + +enumShape :: Shape n -> [Index n] +enumShape = map uninvert . enumInvShape . invert + invert :: forall f n. Shapey f => f n -> Inverted f n -invert | Refl <- lemPlusZero @n = flip go shapeyNil +invert | Refl <- lemPlusZero @n = flip go InvNil where go :: forall n' m. f n' -> Inverted f m -> Inverted f (n' + m) go sh ish = shapeyCase sh @@ -255,6 +262,10 @@ piindexMatch (PIIxCons i pix) (IIxCons i' ix) | i == i' = piindexMatch pix ix | otherwise = Nothing +piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n +piindexConcat PIIxEnd ix = ix +piindexConcat (PIIxCons i pix) ix = IIxCons i (piindexConcat pix ix) + newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t) newAcSparse typ SZ () val = case typ of STNil -> return () @@ -358,7 +369,7 @@ accumAddSparse typ SZ ref () val = case typ of case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of (_, 0) -> return () (0, _) -> do - newrefarr <- arrayGenerateLinM (arrayShape val) (\i -> newAcSparse t SZ () (arrayIndexLinear val i)) + newrefarr <- newAcDense typ SZ () val join $ atomicModifyIORef' ref $ \refarr -> if shapeSize (arrayShape refarr) == 0 then (newrefarr, return ()) @@ -396,13 +407,121 @@ accumAddSparse typ (SS dep) ref idx val = case typ of -- Yep, ref already had a value in there, so we can just add -- val' to it recursively. Just ac -> bimap Just unAcM (accumAddDense t dep ac idx val) - STArr _ t -> _ ref idx val + STArr dim (t :: STy t) -> AcM $ do + refs <- readIORef ref + if shapeSize (arrayShape refs) == 0 + then do newrefarr <- newAcDense typ (SS dep) idx val + join $ atomicModifyIORef' ref $ \refarr -> + if shapeSize (arrayShape refarr) == 0 + then (newrefarr, return ()) + else -- someone was faster than us in initialising the reference! + (refarr, unAcM $ accumAddSparse typ (SS dep) ref idx val) -- just try again from the start (dropping newrefarr for the GC to clean up) + else do let sh = unTupRepIdx ShNil ShCons dim (fst val) + go (SS dep) (invert sh) idx (snd val) + (\j index idxj valj -> unAcM $ accumAddSparse t j (refs `arrayIndex` index) idxj valj) + (\piix subsh val' -> unAcM $ sequence_ + [accumAddSparse t SZ (refs `arrayIndex` uninvert (piindexConcat piix (invert subix))) + () (val' `arrayIndex` subix) + | subix <- enumShape subsh]) + where + go :: SNat i -> InvShape n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) + -> (forall j. SNat j -> Index n -> Rep (AcIdx t j) -> Rep (AcVal t j) -> r) -- ^ Indexing into element of the array + -> (forall m. PartialInvIndex n m -> Shape m -> Rep (TArr m t) -> r) -- ^ Accumulating onto a subarray + -> r + go SZ ish () val' _ k0 = k0 PIIxEnd (uninvert ish) val' -- ^ Ran out of AcIdx: accumulating onto subarray + go (SS dep') IShNil idx' val' kj _ = kj dep' IxNil idx' val' -- ^ Ran out of array dimensions: accumulating into (part of) element + go (SS dep') (IShCons _ ish) (i, idx') val' kj k0 = + go dep' ish idx' val' + (\j index idxj valj -> kj j (IxCons index (fromIntegral @Int64 @Int i)) idxj valj) + (\pidxm shm valm -> k0 (PIIxCons (fromIntegral @Int64 @Int i) pidxm) shm valm) STScal{} -> error "Cannot index into scalar" STAccum{} -> error "Nested accumulators" STEither{} -> error "Bare Either in accumulator" accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ()) -accumAddDense = _ +accumAddDense typ SZ ref () val = case typ of + STNil -> ((), return ()) + STPair t1 t2 -> + (ref, do accumAddSparse t1 SZ (fst ref) () (fst val) + accumAddSparse t2 SZ (snd ref) () (snd val)) + STMaybe t -> + case (ref, val) of + (_, Nothing) -> (ref, return ()) + (Nothing, Just val') -> _ val' + (Just ref', Just val') -> _ ref' val' + STArr _ t -> AcM $ do + refs <- readIORef ref + case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of + (_, 0) -> return () + (0, _) -> do + newrefarr <- newAcDense typ SZ () val + join $ atomicModifyIORef' ref $ \refarr -> + if shapeSize (arrayShape refarr) == 0 + then (newrefarr, return ()) + else -- someone was faster than us in initialising the reference! + (refarr, unAcM $ accumAddSparse typ SZ ref () val) -- just try again from the start (dropping newrefarr for the GC to clean up) + _ | arrayShape refs == arrayShape val -> + sequence_ [unAcM $ accumAddSparse t SZ (arrayIndexLinear refs i) () (arrayIndexLinear val i) + | i <- [0 .. shapeSize (arrayShape val) - 1]] + | otherwise -> error "Array shape mismatch in accum add" + STScal sty -> AcM $ case sty of + STI32 -> atomicModifyIORef' ref (\x -> (x + val, ())) + STI64 -> atomicModifyIORef' ref (\x -> (x + val, ())) + STF32 -> atomicModifyIORef' ref (\x -> (x + val, ())) + STF64 -> atomicModifyIORef' ref (\x -> (x + val, ())) + STBool -> error "Accumulator of Bool" + STAccum{} -> error "Nested accumulators" + STEither{} -> error "Bare Either in accumulator" + +accumAddDense typ (SS dep) ref idx val = case typ of + STNil -> return () + STPair t1 t2 -> AcM $ do + (ref1, ref2) <- readIORef ref + case (idx, val) of + (Left idx', Left val') -> unAcM $ accumAddSparse t1 dep ref1 idx' val' + (Right idx', Right val') -> unAcM $ accumAddSparse t2 dep ref2 idx' val' + _ -> error "Index/value mismatch in pair accumulator add" + STMaybe t -> + AcM $ join $ atomicModifyIORef' ref $ \case + -- Oops, ref's contents was still sparse. Have to initialise + -- it first, then try again. + Nothing -> (Nothing, do newac <- newAcDense t dep idx val + join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of + Nothing -> (Just newac, return ()) + Just ac2' -> bimap Just unAcM (accumAddDense t dep ac2' idx val)) + -- Yep, ref already had a value in there, so we can just add + -- val' to it recursively. + Just ac -> bimap Just unAcM (accumAddDense t dep ac idx val) + STArr dim (t :: STy t) -> AcM $ do + refs <- readIORef ref + if shapeSize (arrayShape refs) == 0 + then do newrefarr <- newAcDense typ (SS dep) idx val + join $ atomicModifyIORef' ref $ \refarr -> + if shapeSize (arrayShape refarr) == 0 + then (newrefarr, return ()) + else -- someone was faster than us in initialising the reference! + (refarr, unAcM $ accumAddSparse typ (SS dep) ref idx val) -- just try again from the start (dropping newrefarr for the GC to clean up) + else do let sh = unTupRepIdx ShNil ShCons dim (fst val) + go (SS dep) (invert sh) idx (snd val) + (\j index idxj valj -> unAcM $ accumAddSparse t j (refs `arrayIndex` index) idxj valj) + (\piix subsh val' -> unAcM $ sequence_ + [accumAddSparse t SZ (refs `arrayIndex` uninvert (piindexConcat piix (invert subix))) + () (val' `arrayIndex` subix) + | subix <- enumShape subsh]) + where + go :: SNat i -> InvShape n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) + -> (forall j. SNat j -> Index n -> Rep (AcIdx t j) -> Rep (AcVal t j) -> r) -- ^ Indexing into element of the array + -> (forall m. PartialInvIndex n m -> Shape m -> Rep (TArr m t) -> r) -- ^ Accumulating onto a subarray + -> r + go SZ ish () val' _ k0 = k0 PIIxEnd (uninvert ish) val' -- ^ Ran out of AcIdx: accumulating onto subarray + go (SS dep') IShNil idx' val' kj _ = kj dep' IxNil idx' val' -- ^ Ran out of array dimensions: accumulating into (part of) element + go (SS dep') (IShCons _ ish) (i, idx') val' kj k0 = + go dep' ish idx' val' + (\j index idxj valj -> kj j (IxCons index (fromIntegral @Int64 @Int i)) idxj valj) + (\pidxm shm valm -> k0 (PIIxCons (fromIntegral @Int64 @Int i) pidxm) shm valm) + STScal{} -> error "Cannot index into scalar" + STAccum{} -> error "Nested accumulators" + STEither{} -> error "Bare Either in accumulator" -- accumAddVal :: forall t i s. STy t -> SNat i -> RepAc t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAc t, AcM s ()) -- accumAddVal typ SZ ac () val = case typ of |