diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-04-29 15:54:12 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-04-29 15:54:12 +0200 |
commit | 3fd8d35cca2a23c137934a170c67e8ce310edf13 (patch) | |
tree | 429fb99f9c1395272f1f9a94bfbc0e003fa39b21 /src/Interpreter.hs | |
parent | 919a36f8eed21501357185a90e2b7a4d9eaf7f08 (diff) |
Complete monoidal accumulator rewrite
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r-- | src/Interpreter.hs | 100 |
1 files changed, 40 insertions, 60 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index af11de8..d7916d8 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -35,7 +35,6 @@ import Debug.Trace import Array import AST import AST.Pretty -import CHAD.Types import Data import Interpreter.Rep @@ -253,7 +252,7 @@ withAccum :: SMTy t -> STy a -> Rep t -> (RepAc t -> AcM s (Rep a)) -> AcM s (Re withAccum t _ initval f = AcM $ do accum <- newAcDense t initval out <- unAcM $ f accum - val <- readAcSparse t accum + val <- readAc t accum return (out, val) newAcZero :: SMTy t -> Rep (ZeroInfo t) -> IO (RepAc t) @@ -300,81 +299,62 @@ onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) = !linindex = toLinearIndex arrsh arrindex in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero (ziarr `arrayIndexLinear` i)) -readAcSparse :: SMTy t -> RepAc t -> IO (Rep t) -readAcSparse typ val = case typ of +readAc :: SMTy t -> RepAc t -> IO (Rep t) +readAc typ val = case typ of SMTNil -> return () - SMTPair t1 t2 -> bitraverse (readAcSparse t1) (readAcSparse t2) val - SMTLEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val - SMTMaybe t -> traverse (readAcSparse t) =<< readIORef val - SMTArr _ t -> traverse (readAcSparse t) val + SMTPair t1 t2 -> bitraverse (readAc t1) (readAc t2) val + SMTLEither t1 t2 -> traverse (bitraverse (readAc t1) (readAc t2)) =<< readIORef val + SMTMaybe t -> traverse (readAc t) =<< readIORef val + SMTArr _ t -> traverse (readAc t) val SMTScal _ -> readIORef val -accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s () -accumAddSparse typ prj ref idx val = case (typ, prj) of - (STNil, SAPHere) -> return () - - (STPair t1 t2, SAPHere) -> - case val of - Nothing -> return () - Just (val1, val2) -> - realiseMaybeSparse ref ((,) <$> newAcDense t1 val1 - <*> newAcDense t2 val2) - (\(ac1, ac2) -> do accumAddSparse t1 SAPHere ac1 () val1 - accumAddSparse t2 SAPHere ac2 () val2) - (STPair t1 t2, SAPFst prj') -> - realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2) - (\(ac1, _) -> do accumAddSparse t1 prj' ac1 idx val) - (STPair t1 t2, SAPSnd prj') -> - realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val) - (\(_, ac2) -> do accumAddSparse t2 prj' ac2 idx val) - - (STEither{}, SAPHere) -> +accumAddDense :: SMTy a -> RepAc a -> Rep a -> AcM s () +accumAddDense typ ref val = case typ of + SMTNil -> return () + SMTPair t1 t2 -> do + accumAddDense t1 (fst ref) (fst val) + accumAddDense t2 (snd ref) (snd val) + SMTLEither{} -> case val of Nothing -> return () Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1 Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2 - (STEither t1 _, SAPLeft prj') -> + SMTMaybe{} -> + case val of + Nothing -> return () + Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val' + SMTArr _ t1 -> + forM_ [0 .. arraySize ref - 1] $ \i -> + accumAddDense t1 (arrayIndexLinear ref i) (arrayIndexLinear val i) + SMTScal sty -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) + +accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s () +accumAddSparse typ prj ref idx val = case (typ, prj) of + (_, SAPHere) -> accumAddDense typ ref val + + (SMTPair t1 _, SAPFst prj') -> accumAddSparse t1 prj' (fst ref) (fst idx) val + (SMTPair _ t2, SAPSnd prj') -> accumAddSparse t2 prj' (snd ref) (snd idx) val + + (SMTLEither t1 _, SAPLeft prj') -> realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val) (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val Right{} -> error "Mismatched Either in accumAddSparse (r +l)") - (STEither _ t2, SAPRight prj') -> + (SMTLEither _ t2, SAPRight prj') -> realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val) (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val Left{} -> error "Mismatched Either in accumAddSparse (l +r)") - (STMaybe{}, SAPHere) -> - case val of - Nothing -> return () - Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val' - (STMaybe t1, SAPJust prj') -> - realiseMaybeSparse ref (newAcSparse t1 prj' idx val) - (\ac -> accumAddSparse t1 prj' ac idx val) + (SMTMaybe t1, SAPJust prj') -> + realiseMaybeSparse ref (newAcSparse t1 prj' idx val) + (\ac -> accumAddSparse t1 prj' ac idx val) - (STArr _ t1, SAPHere) -> - case val of - Nothing -> return () - Just val' -> - realiseMaybeSparse ref - (arrayMapM (newAcDense t1) val') - (\ac -> forM_ [0 .. arraySize ac - 1] $ \i -> - accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i)) - (STArr n t1, SAPArrIdx prj') -> - let ((arrindex', arrsh'), idx') = idx + (SMTArr n t1, SAPArrIdx prj') -> + let ((arrindex', ziarr), idx') = idx arrindex = unTupRepIdx IxNil IxCons n arrindex' - arrsh = unTupRepIdx ShNil ShCons n arrsh' + arrsh = arrayShape ziarr linindex = toLinearIndex arrsh arrindex - in realiseMaybeSparse ref - (onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx) - (\ac -> accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val) - - (STScal sty, SAPHere) -> AcM $ case sty of - STI32 -> return () - STI64 -> return () - STF32 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STF64 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STBool -> return () - - (STAccum{}, _) -> error "Accumulators not allowed in source program" + in accumAddSparse t1 prj' (arrayIndexLinear ref linindex) idx' val + realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s () realiseMaybeSparse ref makeval modifyval = |