From c5108efd1402dcb52beca27d13b4880eed35ef5b Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 3 Jun 2024 21:29:53 +0200 Subject: Properly test C reductions --- test/Gen.hs | 38 ++++++++++++++++++++ test/Tests/C.hs | 107 ++++++++++++++++++++++++++++++++++++-------------------- test/Util.hs | 18 ++++++++++ 3 files changed, 126 insertions(+), 37 deletions(-) (limited to 'test') diff --git a/test/Gen.hs b/test/Gen.hs index 29dffb2..559fecf 100644 --- a/test/Gen.hs +++ b/test/Gen.hs @@ -15,6 +15,7 @@ module Gen where import Data.ByteString qualified as BS import Data.Foldable (toList) import Data.Type.Equality +import Data.Type.Ord import Data.Vector.Storable qualified as VS import Foreign import GHC.TypeLits @@ -24,6 +25,7 @@ import Data.Array.Mixed.Permutation import Data.Array.Mixed.Shape import Data.Array.Mixed.Types import Data.Array.Nested +import Data.Array.Nested.Internal.Shape import Hedgehog import Hedgehog.Gen qualified as Gen @@ -81,6 +83,42 @@ genShR sn = do cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize) shuffleShR (min cap <$> dims) +-- | Example: given 3 and 7, might return: +-- +-- @ +-- ([ 13, 4, 27 ] +-- ,[1, 13, 1, 1, 4, 27, 1] +-- ,[4, 13, 1, 3, 4, 27, 2]) +-- @ +-- +-- The up-replicated dimensions are always nonzero and not very large, but the +-- other dimensions might be zero. +genReplicatedShR :: m <= n => SNat m -> SNat n -> Gen (IShR m, IShR n, IShR n) +genReplicatedShR = \m n -> do + sh1 <- genShR m + (sh2, sh3) <- injectOnes n sh1 sh1 + return (sh1, sh2, sh3) + where + injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n) + injectOnes n@SNat shOnes sh + | m@SNat <- shrLengthSNat sh + = case cmpNat n m of + LTI -> error "unreachable" + EQI -> return (shOnes, sh) + GTI -> do + index <- Gen.int (Range.linear 0 (fromSNat' m)) + value <- Gen.int (Range.linear 1 5) + Refl <- return (lem n m) + injectOnes n (inject index 1 shOnes) (inject index value sh) + + lem :: forall n m proxy. Compare n m ~ GT => proxy n -> proxy m -> (m + 1 <=? n) :~: True + lem _ _ = unsafeCoerceRefl + + inject :: Int -> Int -> IShR m -> IShR (m + 1) + inject 0 v sh = v :$: sh + inject i v (w :$: sh) = w :$: inject (i - 1) v sh + inject _ v ZSR = v :$: ZSR -- invalid input, but meh + genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a) genStorables rng f = do n <- Gen.int rng diff --git a/test/Tests/C.hs b/test/Tests/C.hs index 2a3949f..148e7f6 100644 --- a/test/Tests/C.hs +++ b/test/Tests/C.hs @@ -16,10 +16,8 @@ import Data.Type.Equality import Foreign import GHC.TypeLits -import Data.Array.Mixed.XArray qualified as X -import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Types (fromSNat') import Data.Array.Nested -import Data.Array.Nested.Internal.Mixed import Data.Array.Nested.Internal.Shape import Hedgehog @@ -35,42 +33,77 @@ import Gen import Util +prop_sum_nonempty :: Property +prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do + -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. + let inrank = SNat @(n + 1) + sh <- forAll $ genShR inrank + -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) + guard (all (> 0) (toList (shrTail sh))) -- only constrain the tail + arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> + genStorables (Range.singleton (product sh)) + (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) + let rarr = rfromOrthotope inrank arr + rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr + +prop_sum_empty :: Property +prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do + -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. + _outrank :: SNat n <- return $ SNat @(nm1 + 1) + let inrank = SNat @(n + 1) + sh <- forAll $ do + shtt <- genShR outrankm1 -- nm1 + sht <- shuffleShR (0 :$: shtt) -- n + n <- Gen.int (Range.linear 0 20) + return (n :$: sht) -- n + 1 + guard (any (== 0) (toList (shrTail sh))) + -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) + let arr = OR.fromList @Double @(n + 1) (toList sh) [] + let rarr = rfromOrthotope inrank arr + OR.toList (rtoOrthotope (rsumOuter1 rarr)) === [] + +prop_sum_lasteq1 :: Property +prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do + let inrank = SNat @(n + 1) + outsh <- forAll $ genShR outrank + guard (all (> 0) (toList outsh)) + let insh = shrAppend outsh (1 :$: ZSR) + arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$> + genStorables (Range.singleton (product insh)) + (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) + let rarr = rfromOrthotope inrank arr + rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr + +prop_sum_replicated :: Bool -> Property +prop_sum_replicated doTranspose = property $ + genRank $ \inrank1@(SNat @m) -> + genRank $ \outrank@(SNat @nm1) -> do + inrank2 :: SNat n <- return $ SNat @(nm1 + 1) + (Refl :: (m <=? n) :~: True) <- case cmpNat inrank1 inrank2 of + LTI -> return Refl -- actually we only continue if m < n + _ -> discard + (sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2 + guard (all (> 0) (toList sh3)) + arr <- forAllT $ + OR.stretch (toList sh3) + . OR.reshape (toList sh2) + . OR.fromVector @Double @m (toList sh1) <$> + genStorables (Range.singleton (product sh1)) + (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) + arrTrans <- + if doTranspose then do perm <- forAll $ genPermR (fromSNat' inrank2) + return $ OR.transpose perm arr + else return arr + let rarr = rfromOrthotope inrank2 arrTrans + almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans) + tests :: TestTree tests = testGroup "C" [testGroup "sum" - [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do - -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. - let inrank = SNat @(n + 1) - sh <- forAll $ genShR inrank - -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) - guard (all (> 0) (toList (shrTail sh))) -- only constrain the tail - arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> - genStorables (Range.singleton (product sh)) - (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) - let rarr = rfromOrthotope inrank arr - -- annotateShow rarr - Refl <- return $ lemRankReplicate outrank - let Ranked (M_Double (M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr - let rhs = orSumOuter1 outrank arr - -- annotateShow lhs - -- annotateShow rhs - lhs === rhs - - ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do - -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. - outrank :: SNat n <- return $ SNat @(nm1 + 1) - let inrank = SNat @(n + 1) - sh <- forAll $ do - shtt <- genShR outrankm1 -- nm1 - sht <- shuffleShR (0 :$: shtt) -- n - n <- Gen.int (Range.linear 0 20) - return (n :$: sht) -- n + 1 - guard (any (== 0) (toList (shrTail sh))) - -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) - let arr = OR.fromList @Double @(n + 1) (toList sh) [] - let rarr = rfromOrthotope inrank arr - Refl <- return $ lemRankReplicate outrank - let Ranked (M_Double (M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr - OR.toList lhs === [] + [testProperty "nonempty" prop_sum_nonempty + ,testProperty "empty" prop_sum_empty + ,testProperty "last==1" prop_sum_lasteq1 + ,testProperty "replicated" (prop_sum_replicated False) + ,testProperty "replicated_transposed" (prop_sum_replicated True) ] ] diff --git a/test/Util.hs b/test/Util.hs index f377e5b..ce6ec23 100644 --- a/test/Util.hs +++ b/test/Util.hs @@ -4,12 +4,16 @@ {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Util where import Data.Array.RankedS qualified as OR +import Data.Kind +import Hedgehog +import Hedgehog.Internal.Property (failDiff) import GHC.TypeLits import Data.Array.Mixed.Types (fromSNat') @@ -32,3 +36,17 @@ orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n orSumOuter1 (sn@SNat :: SNat n) = let n = fromSNat' sn in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0]) + +class AlmostEq f where + type AlmostEqConstr f :: Type -> Constraint + -- | absolute tolerance, lhs, rhs + almostEq :: (AlmostEqConstr f a, Ord a, Show a, Fractional a, MonadTest m) + => a -> f a -> f a -> m () + +instance KnownNat n => AlmostEq (OR.Array n) where + type AlmostEqConstr (OR.Array n) = OR.Unbox + almostEq atol lhs rhs + | OR.allA (< atol) (OR.zipWithA (\a b -> abs (a - b)) rhs lhs) = + success + | otherwise = + failDiff lhs rhs -- cgit v1.2.3-70-g09d2