From 411d563023c65270aca746f12c4d597b49122b45 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Fri, 16 Jan 2026 11:57:44 +0100 Subject: test: Port sum tests to sumAll --- test/Tests/C.hs | 78 +++++++++++++++++++++++++++++++++++++++++++++------------ test/Util.hs | 16 +++++++----- 2 files changed, 72 insertions(+), 22 deletions(-) diff --git a/test/Tests/C.hs b/test/Tests/C.hs index 0656107..e26c3dd 100644 --- a/test/Tests/C.hs +++ b/test/Tests/C.hs @@ -45,8 +45,8 @@ fineTol = 1e-8 debugCoverage :: Bool debugCoverage = False -prop_sum_nonempty :: Property -prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do +gen_red_nonempty :: (forall n. SNat (n + 1) -> SNat n -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property +gen_red_nonempty f = property $ genRank $ \outrank@(SNat @n) -> do -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. let inrank = SNat @(n + 1) sh <- forAll $ genShR inrank @@ -55,11 +55,10 @@ prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> genStorables (Range.singleton (product sh)) (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) - let rarr = rfromOrthotope inrank arr - almostEq fineTol (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr) + f inrank outrank arr -prop_sum_empty :: Property -prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do +gen_red_empty :: (forall n. SNat (n + 1) -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property +gen_red_empty f = property $ genRank $ \outrankm1@(SNat @nm1) -> do -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. _outrank :: SNat n <- return $ SNat @(nm1 + 1) let inrank = SNat @(n + 1) @@ -71,11 +70,10 @@ prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do guard (0 `elem` shrTail sh) -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) let arr = OR.fromList @(n + 1) @Double (toList sh) [] - let rarr = rfromOrthotope inrank arr - OR.toList (rtoOrthotope (rsumOuter1Prim rarr)) === [] + f inrank arr -prop_sum_lasteq1 :: Property -prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do +gen_red_lasteq1 :: (forall n. SNat (n + 1) -> SNat n -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property +gen_red_lasteq1 f = property $ genRank $ \outrank@(SNat @n) -> do let inrank = SNat @(n + 1) outsh <- forAll $ genShR outrank guard (all (> 0) outsh) @@ -83,11 +81,10 @@ prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$> genStorables (Range.singleton (product insh)) (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) - let rarr = rfromOrthotope inrank arr - almostEq fineTol (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr) + f inrank outrank arr -prop_sum_replicated :: Bool -> Property -prop_sum_replicated doTranspose = property $ +gen_red_replicated :: Bool -> (forall n. SNat (n + 1) -> SNat n -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property +gen_red_replicated doTranspose f = property $ genRank $ \inrank1@(SNat @m) -> genRank $ \outrank@(SNat @nm1) -> do inrank2 :: SNat n <- return $ SNat @(nm1 + 1) @@ -110,8 +107,50 @@ prop_sum_replicated doTranspose = property $ if doTranspose then do perm <- forAll $ genPermR (fromSNat' inrank2) return $ OR.transpose perm arr else return arr - let rarr = rfromOrthotope inrank2 arrTrans - almostEq 1e-8 (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arrTrans) + f inrank2 outrank arrTrans + + +prop_sum_nonempty :: Property +prop_sum_nonempty = gen_red_nonempty $ \inrank outrank arr -> do + let rarr = rfromOrthotope inrank arr + almostEq fineTol (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr) + +prop_sum_empty :: Property +prop_sum_empty = gen_red_empty $ \inrank arr -> do + let rarr = rfromOrthotope inrank arr + OR.toList (rtoOrthotope (rsumOuter1Prim rarr)) === [] + +prop_sum_lasteq1 :: Property +prop_sum_lasteq1 = gen_red_lasteq1 $ \inrank outrank arr -> do + let rarr = rfromOrthotope inrank arr + almostEq fineTol (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr) + +prop_sum_replicated :: Bool -> Property +prop_sum_replicated doTranspose = gen_red_replicated doTranspose $ \inrank outrank arr -> do + let rarr = rfromOrthotope inrank arr + almostEq 1e-8 (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr) + + +prop_sumall_nonempty :: Property +prop_sumall_nonempty = gen_red_nonempty $ \inrank _outrank arr -> do + let rarr = rfromOrthotope inrank arr + almostEq fineTol (rsumAllPrim rarr) (OR.sumA arr) + +prop_sumall_empty :: Property +prop_sumall_empty = gen_red_empty $ \inrank arr -> do + let rarr = rfromOrthotope inrank arr + rsumAllPrim rarr === 0.0 + +prop_sumall_lasteq1 :: Property +prop_sumall_lasteq1 = gen_red_lasteq1 $ \inrank _outrank arr -> do + let rarr = rfromOrthotope inrank arr + almostEq fineTol (rsumAllPrim rarr) (OR.sumA arr) + +prop_sumall_replicated :: Bool -> Property +prop_sumall_replicated doTranspose = gen_red_replicated doTranspose $ \inrank _outrank arr -> do + let rarr = rfromOrthotope inrank arr + almostEq 1e-6 (rsumAllPrim rarr) (OR.sumA arr) + prop_negate_with :: forall f b. Show b => ((forall n. f n -> SNat n -> PropertyT IO ()) -> PropertyT IO ()) @@ -140,6 +179,13 @@ tests = testGroup "C" ,testProperty "replicated" (prop_sum_replicated False) ,testProperty "replicated_transposed" (prop_sum_replicated True) ] + ,testGroup "sumAll" + [testProperty "nonempty" prop_sumall_nonempty + ,testProperty "empty" prop_sumall_empty + ,testProperty "last==1" prop_sumall_lasteq1 + ,testProperty "replicated" (prop_sumall_replicated False) + ,testProperty "replicated_transposed" (prop_sumall_replicated True) + ] ,testGroup "negate" [testProperty "normalised" $ prop_negate_with (\k -> genRank (k (Const ()))) diff --git a/test/Util.hs b/test/Util.hs index 8a5ba72..6514fbf 100644 --- a/test/Util.hs +++ b/test/Util.hs @@ -36,16 +36,20 @@ orSumOuter1 (sn@SNat :: SNat n) = let n = fromSNat' sn in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0]) -class AlmostEq f where - type AlmostEqConstr f :: Type -> Constraint +class AlmostEq t where + type EltOf t :: Type -- | absolute tolerance, lhs, rhs - almostEq :: (AlmostEqConstr f a, Ord a, Show a, Fractional a, MonadTest m) - => a -> f a -> f a -> m () + almostEq :: MonadTest m => EltOf t -> t -> t -> m () -instance AlmostEq (OR.Array n) where - type AlmostEqConstr (OR.Array n) = OR.Unbox +instance (OR.Unbox a, Ord a, Show a, Fractional a) => AlmostEq (OR.Array n a) where + type EltOf (OR.Array n a) = a almostEq atol lhs rhs | OR.allA (< atol) (OR.zipWithA (\a b -> abs (a - b)) rhs lhs) = success | otherwise = failDiff lhs rhs + +instance AlmostEq Double where + type EltOf Double = Double + almostEq atol lhs rhs | abs (lhs - rhs) < atol = success + | otherwise = failDiff lhs rhs -- cgit v1.2.3-70-g09d2 From 96795853db5a3ee85d7c838a508b4153988e6042 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Fri, 16 Jan 2026 19:14:20 +0100 Subject: C: Fix REDUCEFULL Only the last inner vector was kept... --- cbits/arith.c | 4 +++- ops/Data/Array/Strided/Arith/Internal.hs | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/cbits/arith.c b/cbits/arith.c index 1066463..ee248a4 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -494,7 +494,9 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { if (rank == 0) return arr[0]; \ typ result = 0; \ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ - REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, result); \ + typ dest = 0; \ + REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, dest); \ + result = result op dest; \ }); \ return result; \ } diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs index d94fc65..7578dd8 100644 --- a/ops/Data/Array/Strided/Arith/Internal.hs +++ b/ops/Data/Array/Strided/Arith/Internal.hs @@ -396,6 +396,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides off Nothing -> error "impossible" -- TODO: test handling of negative strides +-- TODO: simplify away normalised dimensions -- | Reduce full array {-# NOINLINE vectorRedFullOp #-} vectorRedFullOp :: forall a b n. (Num a, Storable a) -- cgit v1.2.3-70-g09d2 From 0216dacb82f305e30f147ec7242dcd8599da721a Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Fri, 16 Jan 2026 19:14:59 +0100 Subject: Use numEltSumFull in X.sumFull Thanks Mikolaj :) --- src/Data/Array/XArray.hs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index 6389e67..1445ce6 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -268,11 +268,7 @@ transpose2 ssh1 ssh2 (XArray arr) = XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr) sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a -sumFull _ (XArray arr) = - S.unScalar $ - liftO1 (numEltSum1Inner (SNat @0)) $ - S.fromVector [product (S.shapeL arr)] $ - S.toVector arr +sumFull ssx (XArray arr) = numEltSumFull (ssxRank ssx) $ fromO arr sumInner :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a -- cgit v1.2.3-70-g09d2