diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-04-29 15:54:12 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-04-29 15:54:12 +0200 | 
| commit | 3fd8d35cca2a23c137934a170c67e8ce310edf13 (patch) | |
| tree | 429fb99f9c1395272f1f9a94bfbc0e003fa39b21 /src/Interpreter.hs | |
| parent | 919a36f8eed21501357185a90e2b7a4d9eaf7f08 (diff) | |
Complete monoidal accumulator rewrite
Diffstat (limited to 'src/Interpreter.hs')
| -rw-r--r-- | src/Interpreter.hs | 98 | 
1 files changed, 39 insertions, 59 deletions
| diff --git a/src/Interpreter.hs b/src/Interpreter.hs index af11de8..d7916d8 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -35,7 +35,6 @@ import Debug.Trace  import Array  import AST  import AST.Pretty -import CHAD.Types  import Data  import Interpreter.Rep @@ -253,7 +252,7 @@ withAccum :: SMTy t -> STy a -> Rep t -> (RepAc t -> AcM s (Rep a)) -> AcM s (Re  withAccum t _ initval f = AcM $ do    accum <- newAcDense t initval    out <- unAcM $ f accum -  val <- readAcSparse t accum +  val <- readAc t accum    return (out, val)  newAcZero :: SMTy t -> Rep (ZeroInfo t) -> IO (RepAc t) @@ -300,81 +299,62 @@ onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) =        !linindex = toLinearIndex arrsh arrindex    in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero (ziarr `arrayIndexLinear` i)) -readAcSparse :: SMTy t -> RepAc t -> IO (Rep t) -readAcSparse typ val = case typ of +readAc :: SMTy t -> RepAc t -> IO (Rep t) +readAc typ val = case typ of    SMTNil -> return () -  SMTPair t1 t2 -> bitraverse (readAcSparse t1) (readAcSparse t2) val -  SMTLEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val -  SMTMaybe t -> traverse (readAcSparse t) =<< readIORef val -  SMTArr _ t -> traverse (readAcSparse t) val +  SMTPair t1 t2 -> bitraverse (readAc t1) (readAc t2) val +  SMTLEither t1 t2 -> traverse (bitraverse (readAc t1) (readAc t2)) =<< readIORef val +  SMTMaybe t -> traverse (readAc t) =<< readIORef val +  SMTArr _ t -> traverse (readAc t) val    SMTScal _ -> readIORef val -accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s () -accumAddSparse typ prj ref idx val = case (typ, prj) of -  (STNil, SAPHere) -> return () - -  (STPair t1 t2, SAPHere) -> -    case val of -      Nothing -> return () -      Just (val1, val2) -> -        realiseMaybeSparse ref ((,) <$> newAcDense t1 val1 -                                    <*> newAcDense t2 val2) -                               (\(ac1, ac2) -> do accumAddSparse t1 SAPHere ac1 () val1 -                                                  accumAddSparse t2 SAPHere ac2 () val2) -  (STPair t1 t2, SAPFst prj') -> -    realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2) -                           (\(ac1, _) -> do accumAddSparse t1 prj' ac1 idx val) -  (STPair t1 t2, SAPSnd prj') -> -    realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val) -                           (\(_, ac2) -> do accumAddSparse t2 prj' ac2 idx val) - -  (STEither{}, SAPHere) -> +accumAddDense :: SMTy a -> RepAc a -> Rep a -> AcM s () +accumAddDense typ ref val = case typ of +  SMTNil -> return () +  SMTPair t1 t2 -> do +    accumAddDense t1 (fst ref) (fst val) +    accumAddDense t2 (snd ref) (snd val) +  SMTLEither{} ->      case val of        Nothing -> return ()        Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1        Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2 -  (STEither t1 _, SAPLeft prj') -> +  SMTMaybe{} -> +    case val of +      Nothing -> return () +      Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val' +  SMTArr _ t1 -> +    forM_ [0 .. arraySize ref - 1] $ \i -> +      accumAddDense t1 (arrayIndexLinear ref i) (arrayIndexLinear val i) +  SMTScal sty -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) + +accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s () +accumAddSparse typ prj ref idx val = case (typ, prj) of +  (_, SAPHere) -> accumAddDense typ ref val + +  (SMTPair t1 _, SAPFst prj') -> accumAddSparse t1 prj' (fst ref) (fst idx) val +  (SMTPair _ t2, SAPSnd prj') -> accumAddSparse t2 prj' (snd ref) (snd idx) val + +  (SMTLEither t1 _, SAPLeft prj') ->      realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val)                             (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val                                    Right{} -> error "Mismatched Either in accumAddSparse (r +l)") -  (STEither _ t2, SAPRight prj') -> +  (SMTLEither _ t2, SAPRight prj') ->      realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val)                             (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val                                    Left{} -> error "Mismatched Either in accumAddSparse (l +r)") -  (STMaybe{}, SAPHere) -> -    case val of -      Nothing -> return () -      Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val' -  (STMaybe t1, SAPJust prj') -> -      realiseMaybeSparse ref (newAcSparse t1 prj' idx val) -                             (\ac -> accumAddSparse t1 prj' ac idx val) +  (SMTMaybe t1, SAPJust prj') -> +    realiseMaybeSparse ref (newAcSparse t1 prj' idx val) +                           (\ac -> accumAddSparse t1 prj' ac idx val) -  (STArr _ t1, SAPHere) -> -    case val of -      Nothing -> return () -      Just val' -> -        realiseMaybeSparse ref -          (arrayMapM (newAcDense t1) val') -          (\ac -> forM_ [0 .. arraySize ac - 1] $ \i -> -                    accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i)) -  (STArr n t1, SAPArrIdx prj') -> -    let ((arrindex', arrsh'), idx') = idx +  (SMTArr n t1, SAPArrIdx prj') -> +    let ((arrindex', ziarr), idx') = idx          arrindex = unTupRepIdx IxNil IxCons n arrindex' -        arrsh = unTupRepIdx ShNil ShCons n arrsh' +        arrsh = arrayShape ziarr          linindex = toLinearIndex arrsh arrindex -    in realiseMaybeSparse ref -         (onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx) -         (\ac -> accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val) - -  (STScal sty, SAPHere) -> AcM $ case sty of -    STI32 -> return () -    STI64 -> return () -    STF32 -> atomicModifyIORef' ref (\x -> (x + val, ())) -    STF64 -> atomicModifyIORef' ref (\x -> (x + val, ())) -    STBool -> return () +    in accumAddSparse t1 prj' (arrayIndexLinear ref linindex) idx' val -  (STAccum{}, _) -> error "Accumulators not allowed in source program"  realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s ()  realiseMaybeSparse ref makeval modifyval = | 
