diff options
Diffstat (limited to 'src/Interpreter.hs')
| -rw-r--r-- | src/Interpreter.hs | 44 | 
1 files changed, 44 insertions, 0 deletions
| diff --git a/src/Interpreter.hs b/src/Interpreter.hs index ffc2929..db66540 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -21,6 +21,8 @@ module Interpreter (  ) where  import Control.Monad (foldM, join, when, forM_) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.State.Strict (runStateT, get, put)  import Data.Bifunctor (bimap)  import Data.Bitraversable (bitraverse)  import Data.Char (isSpace) @@ -28,6 +30,7 @@ import Data.Functor.Identity  import qualified Data.Functor.Product as Product  import Data.Int (Int64)  import Data.IORef +import Data.Tuple (swap)  import System.IO (hPutStrLn, stderr)  import System.IO.Unsafe (unsafePerformIO) @@ -143,6 +146,39 @@ interpret'Rec env = \case          sh `ShCons` n = arrayShape arr      numericIsNum t $ return $        arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) +  EFold1InnerD1 _ _ a b c -> do +    let t = typeOf b +    let f = \x y -> (\(z, tape) -> (z, (x, tape))) <$> interpret' (V t y `SCons` V t x `SCons` env) a +    x0 <- interpret' env b +    arr <- interpret' env c +    let sh `ShCons` n = arrayShape arr +    -- TODO: this is very inefficient, even for an interpreter; with mutable +    -- arrays this can be a lot better with no lists +    res <- arrayGenerateM sh $ \idx -> do +             (y, stores) <- mapAccumLM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] +             return (y, arrayFromList (ShNil `ShCons` n) stores) +    return (arrayMap fst res +           ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> +              arrayIndexLinear (snd (arrayIndex res idx)) i) +  EFold1InnerD2 _ _ t2 ef ep ezi ebog ed -> do +    let STArr _ (STPair t1 ttape) = typeOf ebog +    let f = \tape x y ctg -> interpret' (V (fromSMTy t2) ctg `SCons` V t1 y `SCons` V t1 x `SCons` V ttape tape `SCons` env) ef +    parr <- interpret' env ep +    zi <- interpret' env ezi +    bog <- interpret' env ebog +    arrctg <- interpret' env ed +    let sh `ShCons` n = arrayShape parr +    res <- arrayGenerateM sh $ \idx -> do +             let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs) +                 loop i !ctg !inpctgs = do +                   let (prefix, tape) = arrayIndex bog (idx `IxCons` i) +                   (ctg1, ctg2) <- f tape prefix (arrayIndex parr (idx `IxCons` i)) ctg +                   loop (i - 1) ctg1 (ctg2 : inpctgs) +             (x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) [] +             return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg) +    return (foldl' (\x (y, _) -> addM t2 x y) (zeroM t2 zi) (arrayToList res) +           ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> +              arrayIndexLinear (snd (arrayIndex res idx)) i)    EConst _ _ v -> return v    EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e    EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) @@ -411,3 +447,11 @@ ixUncons (IxCons idx i) = (idx, i)  shUncons :: Shape (S n) -> (Shape n, Int)  shUncons (ShCons idx i) = (idx, i) + +mapAccumLM :: (Traversable t, Monad m) => (s -> a -> m (s, b)) -> s -> t a -> m (s, t b) +mapAccumLM f s0 = fmap swap . flip runStateT s0 . traverse f' +  where f' x = do +          s <- get +          (s', y) <- lift $ f s x +          put s' +          return y | 
