aboutsummaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r--src/Interpreter.hs10
1 files changed, 4 insertions, 6 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 79d5014..9e3d2a6 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -164,15 +164,14 @@ interpret'Rec env = \case
return (arrayMap fst res
,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) ->
arrayIndexLinear (snd (arrayIndex res idx)) i)
- EFold1InnerD2 _ _ ef ez eplus ebog ed -> do
+ EFold1InnerD2 _ _ ef ebog ed -> do
let STArr _ tB = typeOf ebog
- t2 = typeOf ez
+ STArr _ t2 = typeOf ed
let f = \tape ctg -> interpret' (V t2 ctg `SCons` V tB tape `SCons` env) ef
- zeroval <- interpret' env ez
- let plusfun = \x y -> interpret' (V t2 y `SCons` V t2 x `SCons` env) eplus
bog <- interpret' env ebog
arrctg <- interpret' env ed
let sh `ShCons` n = arrayShape bog
+ when (sh /= arrayShape arrctg) $ error "Interpreter: mismatched shapes in EFold1InnerD2"
res <- arrayGenerateM sh $ \idx -> do
let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs)
loop i !ctg !inpctgs = do
@@ -181,8 +180,7 @@ interpret'Rec env = \case
loop (i - 1) ctg1 (ctg2 : inpctgs)
(x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) []
return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg)
- x0ctg <- foldM (\x (y, _) -> plusfun x y) zeroval (arrayToList res)
- return (x0ctg
+ return (arrayMap fst res
,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) ->
arrayIndexLinear (snd (arrayIndex res idx)) i)
EConst _ _ v -> return v