From 955af83f664639701fdbee54718186e07b31d42f Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 28 Oct 2025 11:56:40 +0100 Subject: Better fold D{1,2} primitives --- src/Interpreter.hs | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) (limited to 'src/Interpreter.hs') diff --git a/src/Interpreter.hs b/src/Interpreter.hs index db66540..db7033d 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -148,7 +148,7 @@ interpret'Rec env = \case arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) EFold1InnerD1 _ _ a b c -> do let t = typeOf b - let f = \x y -> (\(z, tape) -> (z, (x, tape))) <$> interpret' (V t y `SCons` V t x `SCons` env) a + let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a x0 <- interpret' env b arr <- interpret' env c let sh `ShCons` n = arrayShape arr @@ -160,23 +160,25 @@ interpret'Rec env = \case return (arrayMap fst res ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> arrayIndexLinear (snd (arrayIndex res idx)) i) - EFold1InnerD2 _ _ t2 ef ep ezi ebog ed -> do - let STArr _ (STPair t1 ttape) = typeOf ebog - let f = \tape x y ctg -> interpret' (V (fromSMTy t2) ctg `SCons` V t1 y `SCons` V t1 x `SCons` V ttape tape `SCons` env) ef - parr <- interpret' env ep - zi <- interpret' env ezi + EFold1InnerD2 _ _ ef ez eplus ebog ed -> do + let STArr _ tB = typeOf ebog + t2 = typeOf ez + let f = \tape ctg -> interpret' (V t2 ctg `SCons` V tB tape `SCons` env) ef + zeroval <- interpret' env ez + let plusfun = \x y -> interpret' (V t2 y `SCons` V t2 x `SCons` env) eplus bog <- interpret' env ebog arrctg <- interpret' env ed - let sh `ShCons` n = arrayShape parr + let sh `ShCons` n = arrayShape bog res <- arrayGenerateM sh $ \idx -> do let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs) loop i !ctg !inpctgs = do - let (prefix, tape) = arrayIndex bog (idx `IxCons` i) - (ctg1, ctg2) <- f tape prefix (arrayIndex parr (idx `IxCons` i)) ctg + let b = arrayIndex bog (idx `IxCons` i) + (ctg1, ctg2) <- f b ctg loop (i - 1) ctg1 (ctg2 : inpctgs) (x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) [] return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg) - return (foldl' (\x (y, _) -> addM t2 x y) (zeroM t2 zi) (arrayToList res) + x0ctg <- foldM (\x (y, _) -> plusfun x y) zeroval (arrayToList res) + return (x0ctg ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> arrayIndexLinear (snd (arrayIndex res idx)) i) EConst _ _ v -> return v -- cgit v1.2.3-70-g09d2