aboutsummaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-10-24 23:34:30 +0200
committerTom Smeding <tom@tomsmeding.com>2025-10-24 23:34:30 +0200
commit42176d4a8a0fe7954f17da5c0506721695aa477f (patch)
tree8a29e847faa613e9becf1bccdcaad010187e639b /src/Interpreter.hs
parent7729c45a325fe653421d654ed4c28b040585fce9 (diff)
WIP fold: everything but Compile (slow, but should be sound)
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r--src/Interpreter.hs44
1 files changed, 44 insertions, 0 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index ffc2929..db66540 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -21,6 +21,8 @@ module Interpreter (
) where
import Control.Monad (foldM, join, when, forM_)
+import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.State.Strict (runStateT, get, put)
import Data.Bifunctor (bimap)
import Data.Bitraversable (bitraverse)
import Data.Char (isSpace)
@@ -28,6 +30,7 @@ import Data.Functor.Identity
import qualified Data.Functor.Product as Product
import Data.Int (Int64)
import Data.IORef
+import Data.Tuple (swap)
import System.IO (hPutStrLn, stderr)
import System.IO.Unsafe (unsafePerformIO)
@@ -143,6 +146,39 @@ interpret'Rec env = \case
sh `ShCons` n = arrayShape arr
numericIsNum t $ return $
arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]])
+ EFold1InnerD1 _ _ a b c -> do
+ let t = typeOf b
+ let f = \x y -> (\(z, tape) -> (z, (x, tape))) <$> interpret' (V t y `SCons` V t x `SCons` env) a
+ x0 <- interpret' env b
+ arr <- interpret' env c
+ let sh `ShCons` n = arrayShape arr
+ -- TODO: this is very inefficient, even for an interpreter; with mutable
+ -- arrays this can be a lot better with no lists
+ res <- arrayGenerateM sh $ \idx -> do
+ (y, stores) <- mapAccumLM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
+ return (y, arrayFromList (ShNil `ShCons` n) stores)
+ return (arrayMap fst res
+ ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) ->
+ arrayIndexLinear (snd (arrayIndex res idx)) i)
+ EFold1InnerD2 _ _ t2 ef ep ezi ebog ed -> do
+ let STArr _ (STPair t1 ttape) = typeOf ebog
+ let f = \tape x y ctg -> interpret' (V (fromSMTy t2) ctg `SCons` V t1 y `SCons` V t1 x `SCons` V ttape tape `SCons` env) ef
+ parr <- interpret' env ep
+ zi <- interpret' env ezi
+ bog <- interpret' env ebog
+ arrctg <- interpret' env ed
+ let sh `ShCons` n = arrayShape parr
+ res <- arrayGenerateM sh $ \idx -> do
+ let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs)
+ loop i !ctg !inpctgs = do
+ let (prefix, tape) = arrayIndex bog (idx `IxCons` i)
+ (ctg1, ctg2) <- f tape prefix (arrayIndex parr (idx `IxCons` i)) ctg
+ loop (i - 1) ctg1 (ctg2 : inpctgs)
+ (x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) []
+ return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg)
+ return (foldl' (\x (y, _) -> addM t2 x y) (zeroM t2 zi) (arrayToList res)
+ ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) ->
+ arrayIndexLinear (snd (arrayIndex res idx)) i)
EConst _ _ v -> return v
EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
@@ -411,3 +447,11 @@ ixUncons (IxCons idx i) = (idx, i)
shUncons :: Shape (S n) -> (Shape n, Int)
shUncons (ShCons idx i) = (idx, i)
+
+mapAccumLM :: (Traversable t, Monad m) => (s -> a -> m (s, b)) -> s -> t a -> m (s, t b)
+mapAccumLM f s0 = fmap swap . flip runStateT s0 . traverse f'
+ where f' x = do
+ s <- get
+ (s', y) <- lift $ f s x
+ put s'
+ return y