diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-09-22 23:11:37 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-09-22 23:11:37 +0200 |
commit | 7bc10684870e2249efbdcdddb4950f52d8527699 (patch) | |
tree | 8ff3ffd5966ead77edd9f66b61df9e92dc237a47 /src/Interpreter.hs | |
parent | 1d14fbd9665b25aef5672e0652d5e7e27bcd4908 (diff) |
Interpreter typechecks, at the cost of compositionality of RepAc
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r-- | src/Interpreter.hs | 239 |
1 files changed, 46 insertions, 193 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 8fb4a78..01d15f1 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -29,7 +29,7 @@ import AST import CHAD.Types import Data import Interpreter.Rep -import Data.Bifunctor (first, bimap) +import Data.Bifunctor (bimap) newtype AcM s a = AcM { unAcM :: IO a } @@ -194,11 +194,6 @@ data PartialInvIndex n m where PIIxEnd :: PartialInvIndex m m PIIxCons :: Int -> PartialInvIndex n m -> PartialInvIndex (S n) m --- | Inverted shape: the outermost dimension is at the /outside/ of this list. -data PartialInvShape n m where - PIShEnd :: PartialInvShape m m - PIShCons :: Int -> PartialInvShape n m -> PartialInvShape (S n) m - -- | Inverted shapey thing: the outermost dimension is at the /outside/ of this list. data Inverted (f :: Nat -> Type) n where InvNil :: Inverted f Z @@ -269,52 +264,54 @@ piindexConcat (PIIxCons i pix) ix = IIxCons i (piindexConcat pix ix) newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t) newAcSparse typ SZ () val = case typ of STNil -> return () - STPair{} -> newIORef =<< newAcDense typ SZ () val + STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) STMaybe t -> newIORef =<< traverse (newAcDense t SZ ()) val - STArr{} -> newIORef =<< newAcDense typ SZ () val + STArr _ t -> newIORef =<< traverse (newAcSparse t SZ ()) val STScal{} -> newIORef val STAccum{} -> error "Nested accumulators" STEither{} -> error "Bare Either in accumulator" newAcSparse typ (SS dep) idx val = case typ of STNil -> return () - STPair{} -> newIORef =<< newAcDense typ (SS dep) idx val + STPair t1 t2 -> newIORef =<< case (idx, val) of + (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2 + (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val' + _ -> error "Index/value mismatch in newAc pair" STMaybe t -> newIORef =<< Just <$> newAcDense t dep idx val - STArr{} -> newIORef =<< newAcDense typ (SS dep) idx val + STArr dim (t :: STy t) -> newIORef =<< newAcArray dim t (SS dep) idx val STScal{} -> error "Cannot index into scalar" STAccum{} -> error "Nested accumulators" STEither{} -> error "Bare Either in accumulator" +newAcArray :: SNat n -> STy t -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> IO (Array n (RepAcSparse t)) +newAcArray _ t SZ _ val = + traverse (newAcSparse t SZ ()) val +newAcArray dim (t :: STy t) dep@SS{} idx val = do + let sh = unTupRepIdx ShNil ShCons dim (fst val) + go dep dim idx (snd val) $ \arr position -> + arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of + Just i' -> return $ arr `arrayIndex` i' + Nothing -> newAcZero t) + where + go :: SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) -> (forall m. Array m (RepAcSparse t) -> PartialInvIndex n m -> IO r) -> IO r + go SZ _ () val' k = arrayMapM (newAcSparse t SZ ()) val' >>= \arr -> k arr PIIxEnd + go (SS dep') SZ idx' val' k = newAcSparse t dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd + go (SS dep') (SS dim') (i, idx') val' k = + go dep' dim' idx' val' $ \arr pish -> + k arr (PIIxCons (fromIntegral @Int64 @Int i) pish) + newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t) newAcDense typ SZ () val = case typ of - STNil -> return () - STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) STEither t1 t2 -> case val of Left x -> Left <$> newAcSparse t1 SZ () x Right y -> Right <$> newAcSparse t2 SZ () y - STMaybe t -> traverse (newAcSparse t SZ ()) val - STArr _ t -> traverse (newAcSparse t SZ ()) val - STScal{} -> return val - STAccum{} -> error "Nested accumulators" + _ -> error "newAcDense: invalid dense type" newAcDense typ (SS dep) idx val = case typ of - STNil -> return () - STPair{} -> newAcDense typ (SS dep) idx val - STMaybe t -> Just <$> newAcSparse t dep idx val - STArr dim (t :: STy t) -> do - let sh = unTupRepIdx ShNil ShCons dim (fst val) - go (SS dep) dim idx (snd val) $ \arr position -> - arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of - Just i' -> return $ arr `arrayIndex` i' - Nothing -> newAcZero t) - where - go :: SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) -> (forall m. Array m (RepAcSparse t) -> PartialInvIndex n m -> IO r) -> IO r - go SZ _ () val' k = arrayMapM (newAcSparse t SZ ()) val' >>= \arr -> k arr PIIxEnd - go (SS dep') SZ idx' val' k = newAcSparse t dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd - go (SS dep') (SS dim') (i, idx') val' k = - go dep' dim' idx' val' $ \arr pish -> - k arr (PIIxCons (fromIntegral @Int64 @Int i) pish) - STScal{} -> error "Cannot index into scalar" - STAccum{} -> error "Nested accumulators" - STEither{} -> error "Bare Either in accumulator" + STEither t1 t2 -> + case (idx, val) of + (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val' + (Right idx', Right val') -> Right <$> newAcSparse t2 dep idx' val' + _ -> error "Index/value mismatch in newAc either" + _ -> error "newAcDense: invalid dense type" readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t) readAcSparse typ val = case typ of @@ -330,15 +327,10 @@ readAcSparse typ val = case typ of readAcDense :: STy t -> RepAcDense t -> IO (Rep t) readAcDense typ val = case typ of - STNil -> return () - STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val) STEither t1 t2 -> case val of Left x -> Left <$> readAcSparse t1 x Right y -> Right <$> readAcSparse t2 y - STMaybe t -> traverse (readAcSparse t) val - STArr _ t -> traverse (readAcSparse t) val - STScal{} -> return val - STAccum{} -> error "Nested accumulators" + _ -> error "readAcDense: invalid dense type" accumAddSparse :: STy t -> SNat i -> RepAcSparse t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s () accumAddSparse typ SZ ref () val = case typ of @@ -369,7 +361,7 @@ accumAddSparse typ SZ ref () val = case typ of case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of (_, 0) -> return () (0, _) -> do - newrefarr <- newAcDense typ SZ () val + newrefarr <- traverse (newAcSparse t SZ ()) val join $ atomicModifyIORef' ref $ \refarr -> if shapeSize (arrayShape refarr) == 0 then (newrefarr, return ()) @@ -410,7 +402,7 @@ accumAddSparse typ (SS dep) ref idx val = case typ of STArr dim (t :: STy t) -> AcM $ do refs <- readIORef ref if shapeSize (arrayShape refs) == 0 - then do newrefarr <- newAcDense typ (SS dep) idx val + then do newrefarr <- newAcArray dim t (SS dep) idx val join $ atomicModifyIORef' ref $ \refarr -> if shapeSize (arrayShape refarr) == 0 then (newrefarr, return ()) @@ -440,160 +432,21 @@ accumAddSparse typ (SS dep) ref idx val = case typ of accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ()) accumAddDense typ SZ ref () val = case typ of - STNil -> ((), return ()) - STPair t1 t2 -> - (ref, do accumAddSparse t1 SZ (fst ref) () (fst val) - accumAddSparse t2 SZ (snd ref) () (snd val)) - STMaybe t -> + STEither t1 t2 -> case (ref, val) of - (_, Nothing) -> (ref, return ()) - (Nothing, Just val') -> _ val' - (Just ref', Just val') -> _ ref' val' - STArr _ t -> AcM $ do - refs <- readIORef ref - case (shapeSize (arrayShape refs), shapeSize (arrayShape val)) of - (_, 0) -> return () - (0, _) -> do - newrefarr <- newAcDense typ SZ () val - join $ atomicModifyIORef' ref $ \refarr -> - if shapeSize (arrayShape refarr) == 0 - then (newrefarr, return ()) - else -- someone was faster than us in initialising the reference! - (refarr, unAcM $ accumAddSparse typ SZ ref () val) -- just try again from the start (dropping newrefarr for the GC to clean up) - _ | arrayShape refs == arrayShape val -> - sequence_ [unAcM $ accumAddSparse t SZ (arrayIndexLinear refs i) () (arrayIndexLinear val i) - | i <- [0 .. shapeSize (arrayShape val) - 1]] - | otherwise -> error "Array shape mismatch in accum add" - STScal sty -> AcM $ case sty of - STI32 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STI64 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STF32 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STF64 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STBool -> error "Accumulator of Bool" - STAccum{} -> error "Nested accumulators" - STEither{} -> error "Bare Either in accumulator" + (Left ref', Left val') -> (ref, accumAddSparse t1 SZ ref' () val') + (Right ref', Right val') -> (ref, accumAddSparse t2 SZ ref' () val') + _ -> error "Mismatched Either in accumAdd" + _ -> error "accumAddDense: invalid dense type" accumAddDense typ (SS dep) ref idx val = case typ of - STNil -> return () - STPair t1 t2 -> AcM $ do - (ref1, ref2) <- readIORef ref - case (idx, val) of - (Left idx', Left val') -> unAcM $ accumAddSparse t1 dep ref1 idx' val' - (Right idx', Right val') -> unAcM $ accumAddSparse t2 dep ref2 idx' val' - _ -> error "Index/value mismatch in pair accumulator add" - STMaybe t -> - AcM $ join $ atomicModifyIORef' ref $ \case - -- Oops, ref's contents was still sparse. Have to initialise - -- it first, then try again. - Nothing -> (Nothing, do newac <- newAcDense t dep idx val - join $ atomicModifyIORef' ref $ \ac2 -> case ac2 of - Nothing -> (Just newac, return ()) - Just ac2' -> bimap Just unAcM (accumAddDense t dep ac2' idx val)) - -- Yep, ref already had a value in there, so we can just add - -- val' to it recursively. - Just ac -> bimap Just unAcM (accumAddDense t dep ac idx val) - STArr dim (t :: STy t) -> AcM $ do - refs <- readIORef ref - if shapeSize (arrayShape refs) == 0 - then do newrefarr <- newAcDense typ (SS dep) idx val - join $ atomicModifyIORef' ref $ \refarr -> - if shapeSize (arrayShape refarr) == 0 - then (newrefarr, return ()) - else -- someone was faster than us in initialising the reference! - (refarr, unAcM $ accumAddSparse typ (SS dep) ref idx val) -- just try again from the start (dropping newrefarr for the GC to clean up) - else do let sh = unTupRepIdx ShNil ShCons dim (fst val) - go (SS dep) (invert sh) idx (snd val) - (\j index idxj valj -> unAcM $ accumAddSparse t j (refs `arrayIndex` index) idxj valj) - (\piix subsh val' -> unAcM $ sequence_ - [accumAddSparse t SZ (refs `arrayIndex` uninvert (piindexConcat piix (invert subix))) - () (val' `arrayIndex` subix) - | subix <- enumShape subsh]) - where - go :: SNat i -> InvShape n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) - -> (forall j. SNat j -> Index n -> Rep (AcIdx t j) -> Rep (AcVal t j) -> r) -- ^ Indexing into element of the array - -> (forall m. PartialInvIndex n m -> Shape m -> Rep (TArr m t) -> r) -- ^ Accumulating onto a subarray - -> r - go SZ ish () val' _ k0 = k0 PIIxEnd (uninvert ish) val' -- ^ Ran out of AcIdx: accumulating onto subarray - go (SS dep') IShNil idx' val' kj _ = kj dep' IxNil idx' val' -- ^ Ran out of array dimensions: accumulating into (part of) element - go (SS dep') (IShCons _ ish) (i, idx') val' kj k0 = - go dep' ish idx' val' - (\j index idxj valj -> kj j (IxCons index (fromIntegral @Int64 @Int i)) idxj valj) - (\pidxm shm valm -> k0 (PIIxCons (fromIntegral @Int64 @Int i) pidxm) shm valm) - STScal{} -> error "Cannot index into scalar" - STAccum{} -> error "Nested accumulators" - STEither{} -> error "Bare Either in accumulator" + STEither t1 t2 -> + case (ref, idx, val) of + (Left ref', Left idx', Left val') -> (Left ref', accumAddSparse t1 dep ref' idx' val') + (Right ref', Right idx', Right val') -> (Right ref', accumAddSparse t2 dep ref' idx' val') + _ -> error "Mismatched Either in accumAdd" + _ -> error "accumAddDense: invalid dense type" --- accumAddVal :: forall t i s. STy t -> SNat i -> RepAc t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAc t, AcM s ()) --- accumAddVal typ SZ ac () val = case typ of --- STNil -> ((), return ()) --- STPair t1 t2 -> --- let (ac1', m1) = accumAddVal t1 SZ (fst ac) () (fst val) --- (ac2', m2) = accumAddVal t2 SZ (snd ac) () (snd val) --- in ((ac1', ac2'), m1 >> m2) --- STMaybe t -> case t of --- STEither t1 t2 -> (ac, accumAddValME t1 t2 ac val) --- STNil -> def ; STPair{} -> def ; STMaybe{} -> def ; STArr{} -> def ; STScal{} -> def ; STAccum{} -> def --- where def :: (t ~ TMaybe a, RepAc (TMaybe a) ~ IORef (Maybe (RepAc a))) => (RepAc t, AcM s ()) --- def = (ac, accumAddValM t ac val) --- STArr n t --- | shapeSize (arrayShape ac) == 0 -> makeRepAc (STArr n t) val --- STEither{} -> error "Bare Either in accumulator" --- _ -> _ --- accumAddVal typ (SS dep) ac idx val = case typ of --- STNil -> ((), return ()) --- STPair t1 t2 -> --- case (idx, val) of --- (Left idx', Left val') -> first (,snd ac) $ accumAddVal t1 dep (fst ac) idx' val' --- (Right idx', Right val') -> first (fst ac,) $ accumAddVal t2 dep (snd ac) idx' val' --- _ -> error "Inconsistent idx and val in accumulator add operation" --- _ -> _ - --- accumAddValME :: STy a -> STy b --- -> IORef (Maybe (Either (RepAc a) (RepAc b))) --- -> Maybe (Either (Rep a) (Rep b)) --- -> AcM s () --- accumAddValME t1 t2 ac val = --- case val of --- Nothing -> return () --- Just val' -> --- join $ AcM $ atomicModifyIORef' ac $ \ac' -> case (ac', val') of --- (Nothing, _) -> (Nothing, AcM $ initAccumOrTryAgainME t1 t2 ac val' (unAcM $ accumAddValME t1 t2 ac val)) --- (Just (Left x), Left val'1) -> first (Just . Left) $ accumAddVal t1 SZ x () val'1 --- (Just (Right y), Right val'2) -> first (Just . Right) $ accumAddVal t2 SZ y () val'2 --- _ -> error "Inconsistent accumulator and value in add operation on Maybe Either" - --- initAccumOrTryAgainME :: STy a -> STy b --- -> IORef (Maybe (Either (RepAc a) (RepAc b))) --- -> Either (Rep a) (Rep b) --- -> IO () --- -> IO () --- initAccumOrTryAgainME t1 t2 ac val onRace = do --- newContents <- case val of Left x -> Left <$> makeRepAc t1 x --- Right y -> Right <$> makeRepAc t2 y --- join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ()) --- value@Just{} -> (value, onRace)) - --- accumAddValM :: STy t --- -> IORef (Maybe (RepAc t)) --- -> Maybe (Rep t) --- -> AcM s () --- accumAddValM t ac val = --- case val of --- Nothing -> return () --- Just val' -> --- join $ AcM $ atomicModifyIORef' ac $ \ac' -> case ac' of --- Nothing -> (Nothing, AcM $ initAccumOrTryAgainM t ac val' (unAcM $ accumAddValM t ac val)) --- Just x -> first Just $ accumAddVal t SZ x () val' - --- initAccumOrTryAgainM :: STy t --- -> IORef (Maybe (RepAc t)) --- -> Rep t --- -> IO () --- -> IO () --- initAccumOrTryAgainM t ac val onRace = do --- newContents <- makeRepAc t val --- join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ()) --- value@Just{} -> (value, onRace)) numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r numericIsNum STI32 = id |