From 4c9ae47dd5bbd27b1acb6dc5d4a55657ac1f026f Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 30 Oct 2025 15:58:08 +0100 Subject: Simplify foldD2 to not sum x0 contributions --- src/Interpreter.hs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) (limited to 'src/Interpreter.hs') diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 79d5014..9e3d2a6 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -164,15 +164,14 @@ interpret'Rec env = \case return (arrayMap fst res ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> arrayIndexLinear (snd (arrayIndex res idx)) i) - EFold1InnerD2 _ _ ef ez eplus ebog ed -> do + EFold1InnerD2 _ _ ef ebog ed -> do let STArr _ tB = typeOf ebog - t2 = typeOf ez + STArr _ t2 = typeOf ed let f = \tape ctg -> interpret' (V t2 ctg `SCons` V tB tape `SCons` env) ef - zeroval <- interpret' env ez - let plusfun = \x y -> interpret' (V t2 y `SCons` V t2 x `SCons` env) eplus bog <- interpret' env ebog arrctg <- interpret' env ed let sh `ShCons` n = arrayShape bog + when (sh /= arrayShape arrctg) $ error "Interpreter: mismatched shapes in EFold1InnerD2" res <- arrayGenerateM sh $ \idx -> do let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs) loop i !ctg !inpctgs = do @@ -181,8 +180,7 @@ interpret'Rec env = \case loop (i - 1) ctg1 (ctg2 : inpctgs) (x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) [] return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg) - x0ctg <- foldM (\x (y, _) -> plusfun x y) zeroval (arrayToList res) - return (x0ctg + return (arrayMap fst res ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> arrayIndexLinear (snd (arrayIndex res idx)) i) EConst _ _ v -> return v -- cgit v1.2.3-70-g09d2