diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-10-24 23:34:30 +0200 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-10-24 23:34:30 +0200 |
| commit | 42176d4a8a0fe7954f17da5c0506721695aa477f (patch) | |
| tree | 8a29e847faa613e9becf1bccdcaad010187e639b /src/Interpreter.hs | |
| parent | 7729c45a325fe653421d654ed4c28b040585fce9 (diff) | |
WIP fold: everything but Compile (slow, but should be sound)
Diffstat (limited to 'src/Interpreter.hs')
| -rw-r--r-- | src/Interpreter.hs | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index ffc2929..db66540 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -21,6 +21,8 @@ module Interpreter ( ) where import Control.Monad (foldM, join, when, forM_) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.State.Strict (runStateT, get, put) import Data.Bifunctor (bimap) import Data.Bitraversable (bitraverse) import Data.Char (isSpace) @@ -28,6 +30,7 @@ import Data.Functor.Identity import qualified Data.Functor.Product as Product import Data.Int (Int64) import Data.IORef +import Data.Tuple (swap) import System.IO (hPutStrLn, stderr) import System.IO.Unsafe (unsafePerformIO) @@ -143,6 +146,39 @@ interpret'Rec env = \case sh `ShCons` n = arrayShape arr numericIsNum t $ return $ arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) + EFold1InnerD1 _ _ a b c -> do + let t = typeOf b + let f = \x y -> (\(z, tape) -> (z, (x, tape))) <$> interpret' (V t y `SCons` V t x `SCons` env) a + x0 <- interpret' env b + arr <- interpret' env c + let sh `ShCons` n = arrayShape arr + -- TODO: this is very inefficient, even for an interpreter; with mutable + -- arrays this can be a lot better with no lists + res <- arrayGenerateM sh $ \idx -> do + (y, stores) <- mapAccumLM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + return (y, arrayFromList (ShNil `ShCons` n) stores) + return (arrayMap fst res + ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> + arrayIndexLinear (snd (arrayIndex res idx)) i) + EFold1InnerD2 _ _ t2 ef ep ezi ebog ed -> do + let STArr _ (STPair t1 ttape) = typeOf ebog + let f = \tape x y ctg -> interpret' (V (fromSMTy t2) ctg `SCons` V t1 y `SCons` V t1 x `SCons` V ttape tape `SCons` env) ef + parr <- interpret' env ep + zi <- interpret' env ezi + bog <- interpret' env ebog + arrctg <- interpret' env ed + let sh `ShCons` n = arrayShape parr + res <- arrayGenerateM sh $ \idx -> do + let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs) + loop i !ctg !inpctgs = do + let (prefix, tape) = arrayIndex bog (idx `IxCons` i) + (ctg1, ctg2) <- f tape prefix (arrayIndex parr (idx `IxCons` i)) ctg + loop (i - 1) ctg1 (ctg2 : inpctgs) + (x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) [] + return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg) + return (foldl' (\x (y, _) -> addM t2 x y) (zeroM t2 zi) (arrayToList res) + ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> + arrayIndexLinear (snd (arrayIndex res idx)) i) EConst _ _ v -> return v EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) @@ -411,3 +447,11 @@ ixUncons (IxCons idx i) = (idx, i) shUncons :: Shape (S n) -> (Shape n, Int) shUncons (ShCons idx i) = (idx, i) + +mapAccumLM :: (Traversable t, Monad m) => (s -> a -> m (s, b)) -> s -> t a -> m (s, t b) +mapAccumLM f s0 = fmap swap . flip runStateT s0 . traverse f' + where f' x = do + s <- get + (s', y) <- lift $ f s x + put s' + return y |
