diff options
-rw-r--r-- | src/Interpreter.hs | 239 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 14 |
2 files changed, 53 insertions, 200 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 8fb4a78..01d15f1 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -29,7 +29,7 @@ import AST import CHAD.Types import Data import Interpreter.Rep -import Data.Bifunctor (first, bimap) +import Data.Bifunctor (bimap) newtype AcM s a = AcM { unAcM :: IO a } @@ -194,11 +194,6 @@ data PartialInvIndex n m where PIIxEnd :: PartialInvIndex m m PIIxCons :: Int -> PartialInvIndex n m -> PartialInvIndex (S n) m --- | Inverted shape: the outermost dimension is at the /outside/ of this list. -data PartialInvShape n m where - PIShEnd :: PartialInvShape m m - PIShCons :: Int -> PartialInvShape n m -> PartialInvShape (S n) m - -- | Inverted shapey thing: the outermost dimension is at the /outside/ of this list. data Inverted (f :: Nat -> Type) n where InvNil :: Inverted f Z @@ -269,52 +264,54 @@ piindexConcat (PIIxCons i pix) ix = IIxCons i (piindexConcat pix ix) newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t) newAcSparse typ SZ () val = case typ of STNil -> return () - STPair{} -> newIORef =<< newAcDense typ SZ () val + STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) STMaybe t -> newIORef =<< traverse (newAcDense t SZ ()) val - STArr{} -> newIORef =<< newAcDense typ SZ () val + STArr _ t -> newIORef =<< traverse (newAcSparse t SZ ()) val STScal{} -> newIORef val STAccum{} -> error "Nested accumulators" STEither{} -> error "Bare Either in accumulator" newAcSparse typ (SS dep) idx val = case typ of STNil -> return () - STPair{} -> newIORef =<< newAcDense typ (SS dep) idx val + STPair t1 t2 -> newIORef =<< case (idx, val) of + (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2 + (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val' + _ -> error "Index/value mismatch in newAc pair" STMaybe t -> newIORef =<< Just <$> newAcDense t dep idx val - STArr{} -> newIORef =<< newAcDense typ (SS dep) idx val + STArr dim (t :: STy t) -> newIORef =<< newAcArray dim t (SS dep) idx val STScal{} -> error "Cannot index into scalar" STAccum{} -> error "Nested accumulators" STEither{} -> error "Bare Either in accumulator" +newAcArray :: SNat n -> STy t -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> IO (Array n (RepAcSparse t)) +newAcArray _ t SZ _ val = + traverse (newAcSparse t SZ ()) val +newAcArray dim (t :: STy t) dep@SS{} idx val = do + let sh = unTupRepIdx ShNil ShCons dim (fst val) + go dep dim idx (snd val) $ \arr position -> + arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of + Just i' -> return $ arr `arrayIndex` i' + Nothing -> newAcZero t) + where + go :: SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) -> (forall m. Array m (RepAcSparse t) -> PartialInvIndex n m -> IO r) -> IO r + go SZ _ () val' k = arrayMapM (newAcSparse t SZ ()) val' >>= \arr -> k arr PIIxEnd + go (SS dep') SZ idx' val' k = newAcSparse t dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd + go (SS dep') (SS dim') (i, idx') val' k = + go dep' dim' idx' val' $ \arr pish -> + k arr (PIIxCons (fromIntegral @Int64 @Int i) pish) + newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t) newAcDense typ SZ () val = case typ of - STNil -> return () - STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) STEither t1 t2 -> case val of Left x -> Left <$> newAcSparse t1 SZ () x Right y -> Right <$> newAcSparse t2 SZ () y - STMaybe t -> traverse (newAcSparse t SZ ()) val - STArr _ t -> traverse (newAcSparse t SZ ()) val - STScal{} -> return val - STAccum{} -> error "Nested accumulators" + _ -> error "newAcDense: invalid dense type" newAcDense typ (SS dep) idx val = case typ of - STNil -> return () - STPair{} -> newAcDense typ (SS dep) idx val - STMaybe t -> Just <$> newAcSparse t dep idx val - STArr dim (t :: STy t) -> do - let sh = unTupRepIdx ShNil ShCons dim (fst val) - go (SS dep) dim idx (snd val) $ \arr position -> - arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of - Just i' -> return $ arr `arrayIndex` i' - Nothing -> newAcZero t) - where - go :: SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) -> (forall m. Array m (RepAcSparse t) -> PartialInvIndex n m -> IO r) -> IO r - go SZ _ () val' k = arrayMapM (newAcSparse t SZ ()) val' >>= \arr -> k arr PIIxEnd - go (SS dep') SZ idx' val' k = newAcSparse t dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd - go (SS dep') (SS dim') (i, idx') val' k = - go dep' dim' idx' val' $ \arr pish -> - k arr (PIIxCons (fromIntegral @Int64 @Int i) pish) - STScal{} -> error "Cannot index into scalar" - STAccum{} -> error "Nested accumulators" - STEither{} -> error "Bare Either in accumulator" + STEither t1 t2 -> + case (idx, val) of + (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val' + (Right idx', Right val') -> Right <$> newAcSparse t2 dep idx' val' + _ -> error "Index/value mismatch in newAc either" + _ -> error "newAcDense: invalid dense type" readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t) readAcSparse typ val = case typ of @@ -330,15 +327,10 @@ readAcSparse typ val = case typ of readAcDense :: STy t -> RepAcDense t -> IO (Rep t) readAcDense typ val = case typ of - STNil -> return () - STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val) STEither t1 t2 -> case val of Left x -> Left <$> readAcSparse t1 x Right y -> Right <$> readAcSparse t2 y - STMaybe t -> traverse (readAcSparse t) val - STArr _ t -> traverse (readAcSparse t) val - STScal{} -> return val - STAccum{} -> error "Nested accumulators" + _ -> error "readAcDense: invalid dense type" accumAddSparse :: STy t -> SNat i -> RepAcSparse t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s () accumAddSparse typ SZ ref () val = case typ of @@ -369,7 +361,7 @@ accumAddSparse typ SZ ref () val = case typ of case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of (_, 0) -> return () (0, _) -> do - newrefarr <- newAcDense typ SZ () val + newrefarr <- traverse (newAcSparse t SZ ()) val join $ atomicModifyIORef' ref $ \refarr -> if shapeSize (arrayShape refarr) == 0 then (newrefarr, return ()) @@ -410,7 +402,7 @@ accumAddSparse typ (SS dep) ref idx val = case typ of STArr dim (t :: STy t) -> AcM $ do refs <- readIORef ref if shapeSize (arrayShape refs) == 0 - then do newrefarr <- newAcDense typ (SS dep) idx val + then do newrefarr <- newAcArray dim t (SS dep) idx val join $ atomicModifyIORef' ref $ \refarr -> if shapeSize (arrayShape refarr) == 0 then (newrefarr, return ()) @@ -440,160 +432,21 @@ accumAddSparse typ (SS dep) ref idx val = case typ of accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ()) accumAddDense typ SZ ref () val = case typ of - STNil -> ((), return ()) - STPair t1 t2 -> - (ref, do accumAddSparse t1 SZ (fst ref) () (fst val) - accumAddSparse t2 SZ (snd ref) () (snd val)) - STMaybe t -> + STEither t1 t2 -> case (ref, val) of - (_, Nothing) -> (ref, return ()) - (Nothing, Just val') -> _ val' - (Just ref', Just val') -> _ ref' val' - STArr _ t -> AcM $ do - refs <- readIORef ref - case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of - (_, 0) -> return () - (0, _) -> do - newrefarr <- newAcDense typ SZ () val - join $ atomicModifyIORef' ref $ \refarr -> - if shapeSize (arrayShape refarr) == 0 - then (newrefarr, return ()) - else -- someone was faster than us in initialising the reference! - (refarr, unAcM $ accumAddSparse typ SZ ref () val) -- just try again from the start (dropping newrefarr for the GC to clean up) - _ | arrayShape refs == arrayShape val -> - sequence_ [unAcM $ accumAddSparse t SZ (arrayIndexLinear refs i) () (arrayIndexLinear val i) - | i <- [0 .. shapeSize (arrayShape val) - 1]] - | otherwise -> error "Array shape mismatch in accum add" - STScal sty -> AcM $ case sty of - STI32 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STI64 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STF32 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STF64 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STBool -> error "Accumulator of Bool" - STAccum{} -> error "Nested accumulators" - STEither{} -> error "Bare Either in accumulator" + (Left ref', Left val') -> (ref, accumAddSparse t1 SZ ref' () val') + (Right ref', Right val') -> (ref, accumAddSparse t2 SZ ref' () val') + _ -> error "Mismatched Either in accumAdd" + _ -> error "accumAddDense: invalid dense type" accumAddDense typ (SS dep) ref idx val = case typ of - STNil -> return () - STPair t1 t2 -> AcM $ do - (ref1, ref2) <- readIORef ref - case (idx, val) of - (Left idx', Left val') -> unAcM $ accumAddSparse t1 dep ref1 idx' val' - (Right idx', Right val') -> unAcM $ accumAddSparse t2 dep ref2 idx' val' - _ -> error "Index/value mismatch in pair accumulator add" - STMaybe t -> - AcM $ join $ atomicModifyIORef' ref $ \case - -- Oops, ref's contents was still sparse. Have to initialise - -- it first, then try again. - Nothing -> (Nothing, do newac <- newAcDense t dep idx val - join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of - Nothing -> (Just newac, return ()) - Just ac2' -> bimap Just unAcM (accumAddDense t dep ac2' idx val)) - -- Yep, ref already had a value in there, so we can just add - -- val' to it recursively. - Just ac -> bimap Just unAcM (accumAddDense t dep ac idx val) - STArr dim (t :: STy t) -> AcM $ do - refs <- readIORef ref - if shapeSize (arrayShape refs) == 0 - then do newrefarr <- newAcDense typ (SS dep) idx val - join $ atomicModifyIORef' ref $ \refarr -> - if shapeSize (arrayShape refarr) == 0 - then (newrefarr, return ()) - else -- someone was faster than us in initialising the reference! - (refarr, unAcM $ accumAddSparse typ (SS dep) ref idx val) -- just try again from the start (dropping newrefarr for the GC to clean up) - else do let sh = unTupRepIdx ShNil ShCons dim (fst val) - go (SS dep) (invert sh) idx (snd val) - (\j index idxj valj -> unAcM $ accumAddSparse t j (refs `arrayIndex` index) idxj valj) - (\piix subsh val' -> unAcM $ sequence_ - [accumAddSparse t SZ (refs `arrayIndex` uninvert (piindexConcat piix (invert subix))) - () (val' `arrayIndex` subix) - | subix <- enumShape subsh]) - where - go :: SNat i -> InvShape n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) - -> (forall j. SNat j -> Index n -> Rep (AcIdx t j) -> Rep (AcVal t j) -> r) -- ^ Indexing into element of the array - -> (forall m. PartialInvIndex n m -> Shape m -> Rep (TArr m t) -> r) -- ^ Accumulating onto a subarray - -> r - go SZ ish () val' _ k0 = k0 PIIxEnd (uninvert ish) val' -- ^ Ran out of AcIdx: accumulating onto subarray - go (SS dep') IShNil idx' val' kj _ = kj dep' IxNil idx' val' -- ^ Ran out of array dimensions: accumulating into (part of) element - go (SS dep') (IShCons _ ish) (i, idx') val' kj k0 = - go dep' ish idx' val' - (\j index idxj valj -> kj j (IxCons index (fromIntegral @Int64 @Int i)) idxj valj) - (\pidxm shm valm -> k0 (PIIxCons (fromIntegral @Int64 @Int i) pidxm) shm valm) - STScal{} -> error "Cannot index into scalar" - STAccum{} -> error "Nested accumulators" - STEither{} -> error "Bare Either in accumulator" + STEither t1 t2 -> + case (ref, idx, val) of + (Left ref', Left idx', Left val') -> (Left ref', accumAddSparse t1 dep ref' idx' val') + (Right ref', Right idx', Right val') -> (Right ref', accumAddSparse t2 dep ref' idx' val') + _ -> error "Mismatched Either in accumAdd" + _ -> error "accumAddDense: invalid dense type" --- accumAddVal :: forall t i s. STy t -> SNat i -> RepAc t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAc t, AcM s ()) --- accumAddVal typ SZ ac () val = case typ of --- STNil -> ((), return ()) --- STPair t1 t2 -> --- let (ac1', m1) = accumAddVal t1 SZ (fst ac) () (fst val) --- (ac2', m2) = accumAddVal t2 SZ (snd ac) () (snd val) --- in ((ac1', ac2'), m1 >> m2) --- STMaybe t -> case t of --- STEither t1 t2 -> (ac, accumAddValME t1 t2 ac val) --- STNil -> def ; STPair{} -> def ; STMaybe{} -> def ; STArr{} -> def ; STScal{} -> def ; STAccum{} -> def --- where def :: (t ~ TMaybe a, RepAc (TMaybe a) ~ IORef (Maybe (RepAc a))) => (RepAc t, AcM s ()) --- def = (ac, accumAddValM t ac val) --- STArr n t --- | shapeSize (arrayShape ac) == 0 -> makeRepAc (STArr n t) val --- STEither{} -> error "Bare Either in accumulator" --- _ -> _ --- accumAddVal typ (SS dep) ac idx val = case typ of --- STNil -> ((), return ()) --- STPair t1 t2 -> --- case (idx, val) of --- (Left idx', Left val') -> first (,snd ac) $ accumAddVal t1 dep (fst ac) idx' val' --- (Right idx', Right val') -> first (fst ac,) $ accumAddVal t2 dep (snd ac) idx' val' --- _ -> error "Inconsistent idx and val in accumulator add operation" --- _ -> _ - --- accumAddValME :: STy a -> STy b --- -> IORef (Maybe (Either (RepAc a) (RepAc b))) --- -> Maybe (Either (Rep a) (Rep b)) --- -> AcM s () --- accumAddValME t1 t2 ac val = --- case val of --- Nothing -> return () --- Just val' -> --- join $ AcM $ atomicModifyIORef' ac $ \ac' -> case (ac', val') of --- (Nothing, _) -> (Nothing, AcM $ initAccumOrTryAgainME t1 t2 ac val' (unAcM $ accumAddValME t1 t2 ac val)) --- (Just (Left x), Left val'1) -> first (Just . Left) $ accumAddVal t1 SZ x () val'1 --- (Just (Right y), Right val'2) -> first (Just . Right) $ accumAddVal t2 SZ y () val'2 --- _ -> error "Inconsistent accumulator and value in add operation on Maybe Either" - --- initAccumOrTryAgainME :: STy a -> STy b --- -> IORef (Maybe (Either (RepAc a) (RepAc b))) --- -> Either (Rep a) (Rep b) --- -> IO () --- -> IO () --- initAccumOrTryAgainME t1 t2 ac val onRace = do --- newContents <- case val of Left x -> Left <$> makeRepAc t1 x --- Right y -> Right <$> makeRepAc t2 y --- join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ()) --- value@Just{} -> (value, onRace)) - --- accumAddValM :: STy t --- -> IORef (Maybe (RepAc t)) --- -> Maybe (Rep t) --- -> AcM s () --- accumAddValM t ac val = --- case val of --- Nothing -> return () --- Just val' -> --- join $ AcM $ atomicModifyIORef' ac $ \ac' -> case ac' of --- Nothing -> (Nothing, AcM $ initAccumOrTryAgainM t ac val' (unAcM $ accumAddValM t ac val)) --- Just x -> first Just $ accumAddVal t SZ x () val' - --- initAccumOrTryAgainM :: STy t --- -> IORef (Maybe (RepAc t)) --- -> Rep t --- -> IO () --- -> IO () --- initAccumOrTryAgainM t ac val onRace = do --- newContents <- makeRepAc t val --- join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ()) --- value@Just{} -> (value, onRace)) numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r numericIsNum STI32 = id diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index aa2fcc9..c0c38b2 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -22,19 +22,19 @@ type family Rep t where -- Mutable, and has a zero. The zero may not be O(1), but RepAcSparse (D2 t) will have an O(1) zero. type family RepAcSparse t where RepAcSparse TNil = () - RepAcSparse (TPair a b) = IORef (RepAcDense (TPair a b)) + RepAcSparse (TPair a b) = IORef (RepAcSparse a, RepAcSparse b) RepAcSparse (TEither a b) = TypeError (Text "Non-sparse coproduct is not a monoid") RepAcSparse (TMaybe t) = IORef (Maybe (RepAcDense t)) -- allow the value to be dense, because the Maybe's zero can be used for the contents - RepAcSparse (TArr n t) = IORef (RepAcDense (TArr n t)) -- empty array is zero + RepAcSparse (TArr n t) = IORef (Array n (RepAcSparse t)) -- empty array is zero RepAcSparse (TScal sty) = IORef (ScalRep sty) RepAcSparse (TAccum t) = TypeError (Text "RepAcSparse: Nested accumulators") -- Immutable, and does not necessarily have a zero. type family RepAcDense t where RepAcDense TNil = () - RepAcDense (TPair a b) = (RepAcSparse a, RepAcSparse b) + -- RepAcDense (TPair a b) = (RepAcSparse a, RepAcSparse b) RepAcDense (TEither a b) = Either (RepAcSparse a) (RepAcSparse b) - RepAcDense (TMaybe t) = Maybe (RepAcSparse t) - RepAcDense (TArr n t) = Array n (RepAcSparse t) - RepAcDense (TScal sty) = ScalRep sty - RepAcDense (TAccum t) = TypeError (Text "RepAcDense: Nested accumulators") + -- RepAcDense (TMaybe t) = RepAcSparse (TMaybe t) -- ^ This can be optimised to TMaybe (RepAcSparse t), but that makes accumAddDense very hard to write. And in any case, we don't need it because D2 will not produce Maybe of Maybe. + -- RepAcDense (TArr n t) = Array n (RepAcSparse t) + -- RepAcDense (TScal sty) = ScalRep sty + -- RepAcDense (TAccum t) = TypeError (Text "RepAcDense: Nested accumulators") |