diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-06-03 21:29:53 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-06-03 21:29:53 +0200 |
commit | c5108efd1402dcb52beca27d13b4880eed35ef5b (patch) | |
tree | b25e4ee26c1f894671db2e68c0afdaf6a1378cb5 /test/Tests/C.hs | |
parent | 0fd727dcb3fe05816aa9c68be5ebac84a55fcf4b (diff) |
Properly test C reductions
Diffstat (limited to 'test/Tests/C.hs')
-rw-r--r-- | test/Tests/C.hs | 107 |
1 files changed, 70 insertions, 37 deletions
diff --git a/test/Tests/C.hs b/test/Tests/C.hs index 2a3949f..148e7f6 100644 --- a/test/Tests/C.hs +++ b/test/Tests/C.hs @@ -16,10 +16,8 @@ import Data.Type.Equality import Foreign import GHC.TypeLits -import Data.Array.Mixed.XArray qualified as X -import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Types (fromSNat') import Data.Array.Nested -import Data.Array.Nested.Internal.Mixed import Data.Array.Nested.Internal.Shape import Hedgehog @@ -35,42 +33,77 @@ import Gen import Util +prop_sum_nonempty :: Property +prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do + -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. + let inrank = SNat @(n + 1) + sh <- forAll $ genShR inrank + -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) + guard (all (> 0) (toList (shrTail sh))) -- only constrain the tail + arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> + genStorables (Range.singleton (product sh)) + (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) + let rarr = rfromOrthotope inrank arr + rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr + +prop_sum_empty :: Property +prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do + -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. + _outrank :: SNat n <- return $ SNat @(nm1 + 1) + let inrank = SNat @(n + 1) + sh <- forAll $ do + shtt <- genShR outrankm1 -- nm1 + sht <- shuffleShR (0 :$: shtt) -- n + n <- Gen.int (Range.linear 0 20) + return (n :$: sht) -- n + 1 + guard (any (== 0) (toList (shrTail sh))) + -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) + let arr = OR.fromList @Double @(n + 1) (toList sh) [] + let rarr = rfromOrthotope inrank arr + OR.toList (rtoOrthotope (rsumOuter1 rarr)) === [] + +prop_sum_lasteq1 :: Property +prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do + let inrank = SNat @(n + 1) + outsh <- forAll $ genShR outrank + guard (all (> 0) (toList outsh)) + let insh = shrAppend outsh (1 :$: ZSR) + arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$> + genStorables (Range.singleton (product insh)) + (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) + let rarr = rfromOrthotope inrank arr + rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr + +prop_sum_replicated :: Bool -> Property +prop_sum_replicated doTranspose = property $ + genRank $ \inrank1@(SNat @m) -> + genRank $ \outrank@(SNat @nm1) -> do + inrank2 :: SNat n <- return $ SNat @(nm1 + 1) + (Refl :: (m <=? n) :~: True) <- case cmpNat inrank1 inrank2 of + LTI -> return Refl -- actually we only continue if m < n + _ -> discard + (sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2 + guard (all (> 0) (toList sh3)) + arr <- forAllT $ + OR.stretch (toList sh3) + . OR.reshape (toList sh2) + . OR.fromVector @Double @m (toList sh1) <$> + genStorables (Range.singleton (product sh1)) + (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) + arrTrans <- + if doTranspose then do perm <- forAll $ genPermR (fromSNat' inrank2) + return $ OR.transpose perm arr + else return arr + let rarr = rfromOrthotope inrank2 arrTrans + almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans) + tests :: TestTree tests = testGroup "C" [testGroup "sum" - [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do - -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. - let inrank = SNat @(n + 1) - sh <- forAll $ genShR inrank - -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) - guard (all (> 0) (toList (shrTail sh))) -- only constrain the tail - arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> - genStorables (Range.singleton (product sh)) - (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) - let rarr = rfromOrthotope inrank arr - -- annotateShow rarr - Refl <- return $ lemRankReplicate outrank - let Ranked (M_Double (M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr - let rhs = orSumOuter1 outrank arr - -- annotateShow lhs - -- annotateShow rhs - lhs === rhs - - ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do - -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. - outrank :: SNat n <- return $ SNat @(nm1 + 1) - let inrank = SNat @(n + 1) - sh <- forAll $ do - shtt <- genShR outrankm1 -- nm1 - sht <- shuffleShR (0 :$: shtt) -- n - n <- Gen.int (Range.linear 0 20) - return (n :$: sht) -- n + 1 - guard (any (== 0) (toList (shrTail sh))) - -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) - let arr = OR.fromList @Double @(n + 1) (toList sh) [] - let rarr = rfromOrthotope inrank arr - Refl <- return $ lemRankReplicate outrank - let Ranked (M_Double (M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr - OR.toList lhs === [] + [testProperty "nonempty" prop_sum_nonempty + ,testProperty "empty" prop_sum_empty + ,testProperty "last==1" prop_sum_lasteq1 + ,testProperty "replicated" (prop_sum_replicated False) + ,testProperty "replicated_transposed" (prop_sum_replicated True) ] ] |