summaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-04-29 15:54:12 +0200
committerTom Smeding <tom@tomsmeding.com>2025-04-29 15:54:12 +0200
commit3fd8d35cca2a23c137934a170c67e8ce310edf13 (patch)
tree429fb99f9c1395272f1f9a94bfbc0e003fa39b21 /src/Interpreter.hs
parent919a36f8eed21501357185a90e2b7a4d9eaf7f08 (diff)
Complete monoidal accumulator rewrite
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r--src/Interpreter.hs100
1 files changed, 40 insertions, 60 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index af11de8..d7916d8 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -35,7 +35,6 @@ import Debug.Trace
import Array
import AST
import AST.Pretty
-import CHAD.Types
import Data
import Interpreter.Rep
@@ -253,7 +252,7 @@ withAccum :: SMTy t -> STy a -> Rep t -> (RepAc t -> AcM s (Rep a)) -> AcM s (Re
withAccum t _ initval f = AcM $ do
accum <- newAcDense t initval
out <- unAcM $ f accum
- val <- readAcSparse t accum
+ val <- readAc t accum
return (out, val)
newAcZero :: SMTy t -> Rep (ZeroInfo t) -> IO (RepAc t)
@@ -300,81 +299,62 @@ onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) =
!linindex = toLinearIndex arrsh arrindex
in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero (ziarr `arrayIndexLinear` i))
-readAcSparse :: SMTy t -> RepAc t -> IO (Rep t)
-readAcSparse typ val = case typ of
+readAc :: SMTy t -> RepAc t -> IO (Rep t)
+readAc typ val = case typ of
SMTNil -> return ()
- SMTPair t1 t2 -> bitraverse (readAcSparse t1) (readAcSparse t2) val
- SMTLEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val
- SMTMaybe t -> traverse (readAcSparse t) =<< readIORef val
- SMTArr _ t -> traverse (readAcSparse t) val
+ SMTPair t1 t2 -> bitraverse (readAc t1) (readAc t2) val
+ SMTLEither t1 t2 -> traverse (bitraverse (readAc t1) (readAc t2)) =<< readIORef val
+ SMTMaybe t -> traverse (readAc t) =<< readIORef val
+ SMTArr _ t -> traverse (readAc t) val
SMTScal _ -> readIORef val
-accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s ()
-accumAddSparse typ prj ref idx val = case (typ, prj) of
- (STNil, SAPHere) -> return ()
-
- (STPair t1 t2, SAPHere) ->
- case val of
- Nothing -> return ()
- Just (val1, val2) ->
- realiseMaybeSparse ref ((,) <$> newAcDense t1 val1
- <*> newAcDense t2 val2)
- (\(ac1, ac2) -> do accumAddSparse t1 SAPHere ac1 () val1
- accumAddSparse t2 SAPHere ac2 () val2)
- (STPair t1 t2, SAPFst prj') ->
- realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2)
- (\(ac1, _) -> do accumAddSparse t1 prj' ac1 idx val)
- (STPair t1 t2, SAPSnd prj') ->
- realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val)
- (\(_, ac2) -> do accumAddSparse t2 prj' ac2 idx val)
-
- (STEither{}, SAPHere) ->
+accumAddDense :: SMTy a -> RepAc a -> Rep a -> AcM s ()
+accumAddDense typ ref val = case typ of
+ SMTNil -> return ()
+ SMTPair t1 t2 -> do
+ accumAddDense t1 (fst ref) (fst val)
+ accumAddDense t2 (snd ref) (snd val)
+ SMTLEither{} ->
case val of
Nothing -> return ()
Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1
Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2
- (STEither t1 _, SAPLeft prj') ->
+ SMTMaybe{} ->
+ case val of
+ Nothing -> return ()
+ Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val'
+ SMTArr _ t1 ->
+ forM_ [0 .. arraySize ref - 1] $ \i ->
+ accumAddDense t1 (arrayIndexLinear ref i) (arrayIndexLinear val i)
+ SMTScal sty -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ()))
+
+accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s ()
+accumAddSparse typ prj ref idx val = case (typ, prj) of
+ (_, SAPHere) -> accumAddDense typ ref val
+
+ (SMTPair t1 _, SAPFst prj') -> accumAddSparse t1 prj' (fst ref) (fst idx) val
+ (SMTPair _ t2, SAPSnd prj') -> accumAddSparse t2 prj' (snd ref) (snd idx) val
+
+ (SMTLEither t1 _, SAPLeft prj') ->
realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val)
(\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val
Right{} -> error "Mismatched Either in accumAddSparse (r +l)")
- (STEither _ t2, SAPRight prj') ->
+ (SMTLEither _ t2, SAPRight prj') ->
realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val)
(\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val
Left{} -> error "Mismatched Either in accumAddSparse (l +r)")
- (STMaybe{}, SAPHere) ->
- case val of
- Nothing -> return ()
- Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val'
- (STMaybe t1, SAPJust prj') ->
- realiseMaybeSparse ref (newAcSparse t1 prj' idx val)
- (\ac -> accumAddSparse t1 prj' ac idx val)
+ (SMTMaybe t1, SAPJust prj') ->
+ realiseMaybeSparse ref (newAcSparse t1 prj' idx val)
+ (\ac -> accumAddSparse t1 prj' ac idx val)
- (STArr _ t1, SAPHere) ->
- case val of
- Nothing -> return ()
- Just val' ->
- realiseMaybeSparse ref
- (arrayMapM (newAcDense t1) val')
- (\ac -> forM_ [0 .. arraySize ac - 1] $ \i ->
- accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i))
- (STArr n t1, SAPArrIdx prj') ->
- let ((arrindex', arrsh'), idx') = idx
+ (SMTArr n t1, SAPArrIdx prj') ->
+ let ((arrindex', ziarr), idx') = idx
arrindex = unTupRepIdx IxNil IxCons n arrindex'
- arrsh = unTupRepIdx ShNil ShCons n arrsh'
+ arrsh = arrayShape ziarr
linindex = toLinearIndex arrsh arrindex
- in realiseMaybeSparse ref
- (onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx)
- (\ac -> accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val)
-
- (STScal sty, SAPHere) -> AcM $ case sty of
- STI32 -> return ()
- STI64 -> return ()
- STF32 -> atomicModifyIORef' ref (\x -> (x + val, ()))
- STF64 -> atomicModifyIORef' ref (\x -> (x + val, ()))
- STBool -> return ()
-
- (STAccum{}, _) -> error "Accumulators not allowed in source program"
+ in accumAddSparse t1 prj' (arrayIndexLinear ref linindex) idx' val
+
realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s ()
realiseMaybeSparse ref makeval modifyval =