diff options
Diffstat (limited to 'test')
| -rw-r--r-- | test/Gen.hs | 39 | ||||
| -rw-r--r-- | test/Tests/C.hs | 34 | ||||
| -rw-r--r-- | test/Tests/Permutation.hs | 4 | ||||
| -rw-r--r-- | test/Util.hs | 7 |
4 files changed, 48 insertions, 36 deletions
diff --git a/test/Gen.hs b/test/Gen.hs index 695b83f..4f5fe96 100644 --- a/test/Gen.hs +++ b/test/Gen.hs @@ -2,10 +2,8 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE NumericUnderscores #-} -{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} @@ -21,11 +19,10 @@ import Foreign import GHC.TypeLits import GHC.TypeNats qualified as TN -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Types import Data.Array.Nested -import Data.Array.Nested.Internal.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types import Hedgehog import Hedgehog.Gen qualified as Gen @@ -49,7 +46,7 @@ genLowBiased (lo, hi) = do return (lo + x * x * x * (hi - lo)) shuffleShR :: IShR n -> Gen (IShR n) -shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh +shuffleShR = \sh -> go (length sh) (toList sh) sh where go :: Int -> [Int] -> IShR n -> Gen (IShR n) go _ _ ZSR = return ZSR @@ -61,9 +58,12 @@ shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh (dim :$:) <$> go (nbag - 1) bag' sh genShR :: SNat n -> Gen (IShR n) -genShR sn = do +genShR = genShRwithTarget 100_000 + +genShRwithTarget :: Int -> SNat n -> Gen (IShR n) +genShRwithTarget targetMax sn = do let n = fromSNat' sn - targetSize <- Gen.int (Range.linear 0 100_000) + targetSize <- Gen.int (Range.linear 0 targetMax) let genDims :: SNat m -> Int -> Gen (IShR m) genDims SZ _ = return ZSR genDims (SS m) 0 = do @@ -78,9 +78,8 @@ genShR sn = do dims <- genDims m (if dim == 0 then 0 else tgt `div` dim) return (dim :$: dims) dims <- genDims sn targetSize - let dimsL = toList dims - maxdim = maximum dimsL - cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize) + let maxdim = maximum dims + cap = binarySearch (`div` 2) 1 maxdim (\cap' -> shrSize (min cap' <$> dims) <= targetSize) shuffleShR (min cap <$> dims) -- | Example: given 3 and 7, might return: @@ -95,10 +94,14 @@ genShR sn = do -- other dimensions might be zero. genReplicatedShR :: m <= n => SNat m -> SNat n -> Gen (IShR m, IShR n, IShR n) genReplicatedShR = \m n -> do - sh1 <- genShR m + let expectedSizeIncrease = round (repvalavg ^ (fromSNat' n - fromSNat' m)) + sh1 <- genShRwithTarget (1_000_000 `div` expectedSizeIncrease) m (sh2, sh3) <- injectOnes n sh1 sh1 return (sh1, sh2, sh3) where + repvalrange = (1::Int, 5) + repvalavg = let (lo, hi) = repvalrange in fromIntegral (lo + hi) / 2 :: Double + injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n) injectOnes n@SNat shOnes sh | m@SNat <- shrRank sh @@ -107,24 +110,24 @@ genReplicatedShR = \m n -> do EQI -> return (shOnes, sh) GTI -> do index <- Gen.int (Range.linear 0 (fromSNat' m)) - value <- Gen.int (Range.linear 1 5) + value <- Gen.int (uncurry Range.linear repvalrange) Refl <- return (lem n m) injectOnes n (inject index 1 shOnes) (inject index value sh) - lem :: forall n m proxy. Compare n m ~ GT => proxy n -> proxy m -> (m + 1 <=? n) :~: True + lem :: forall n m proxy. n > m => proxy n -> proxy m -> (m + 1 <=? n) :~: True lem _ _ = unsafeCoerceRefl inject :: Int -> Int -> IShR m -> IShR (m + 1) inject 0 v sh = v :$: sh inject i v (w :$: sh) = w :$: inject (i - 1) v sh - inject _ v ZSR = v :$: ZSR -- invalid input, but meh + inject _ _ ZSR = error "unreachable" genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a) genStorables rng f = do n <- Gen.int rng seed <- Gen.resize 99 $ Gen.int Range.linearBounded let gen0 = Random.mkStdGen seed - (bs, _) = Random.genByteString (8 * n) gen0 + (bs, _) = Random.uniformByteString (8 * n) gen0 let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]]) return $ VS.generate n (f . readW64) @@ -159,7 +162,7 @@ genPermR n = Gen.shuffle [0 .. n-1] genPerm :: Monad m => SNat n -> (forall p. (IsPermutation p, Rank p ~ n) => Perm p -> PropertyT m r) -> PropertyT m r genPerm n@SNat k = do list <- forAll $ genPermR (fromSNat' n) - permFromList list $ \perm -> do + permFromListCont list $ \perm -> do case permCheckPermutation perm $ case sameNat' (permRank perm) n of Just Refl -> Just (k perm) diff --git a/test/Tests/C.hs b/test/Tests/C.hs index a0f103d..0656107 100644 --- a/test/Tests/C.hs +++ b/test/Tests/C.hs @@ -1,9 +1,12 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) {-# LANGUAGE TypeAbstractions #-} +#endif {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} @@ -18,13 +21,13 @@ import Data.Type.Equality import Foreign import GHC.TypeLits -import Data.Array.Mixed.Types (fromSNat') import Data.Array.Nested -import Data.Array.Nested.Internal.Shape +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types (fromSNat') import Hedgehog -import Hedgehog.Internal.Property (forAllT) import Hedgehog.Gen qualified as Gen +import Hedgehog.Internal.Property (LabelName(..), forAllT) import Hedgehog.Range qualified as Range import Test.Tasty import Test.Tasty.Hedgehog @@ -39,18 +42,21 @@ import Util fineTol :: Double fineTol = 1e-8 +debugCoverage :: Bool +debugCoverage = False + prop_sum_nonempty :: Property prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. let inrank = SNat @(n + 1) sh <- forAll $ genShR inrank -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) - guard (all (> 0) (toList (shrTail sh))) -- only constrain the tail + guard (all (> 0) (shrTail sh)) -- only constrain the tail arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> genStorables (Range.singleton (product sh)) (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) let rarr = rfromOrthotope inrank arr - almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr) + almostEq fineTol (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr) prop_sum_empty :: Property prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do @@ -62,23 +68,23 @@ prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do sht <- shuffleShR (0 :$: shtt) -- n n <- Gen.int (Range.linear 0 20) return (n :$: sht) -- n + 1 - guard (any (== 0) (toList (shrTail sh))) + guard (0 `elem` shrTail sh) -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) let arr = OR.fromList @(n + 1) @Double (toList sh) [] let rarr = rfromOrthotope inrank arr - OR.toList (rtoOrthotope (rsumOuter1 rarr)) === [] + OR.toList (rtoOrthotope (rsumOuter1Prim rarr)) === [] prop_sum_lasteq1 :: Property prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do let inrank = SNat @(n + 1) outsh <- forAll $ genShR outrank - guard (all (> 0) (toList outsh)) + guard (all (> 0) outsh) let insh = shrAppend outsh (1 :$: ZSR) arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$> genStorables (Range.singleton (product insh)) (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) let rarr = rfromOrthotope inrank arr - almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr) + almostEq fineTol (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr) prop_sum_replicated :: Bool -> Property prop_sum_replicated doTranspose = property $ @@ -89,7 +95,11 @@ prop_sum_replicated doTranspose = property $ LTI -> return Refl -- actually we only continue if m < n _ -> discard (sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2 - guard (all (> 0) (toList sh3)) + when debugCoverage $ do + label (LabelName ("rankdiff " ++ show (fromSNat' inrank2 - fromSNat' inrank1))) + label (LabelName ("size sh1 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh1) :: Double)) :: Int))) + label (LabelName ("size sh3 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh3) :: Double)) :: Int))) + guard (all (> 0) sh3) arr <- forAllT $ OR.stretch (toList sh3) . OR.reshape (toList sh2) @@ -101,7 +111,7 @@ prop_sum_replicated doTranspose = property $ return $ OR.transpose perm arr else return arr let rarr = rfromOrthotope inrank2 arrTrans - almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans) + almostEq 1e-8 (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arrTrans) prop_negate_with :: forall f b. Show b => ((forall n. f n -> SNat n -> PropertyT IO ()) -> PropertyT IO ()) @@ -111,7 +121,7 @@ prop_negate_with :: forall f b. Show b prop_negate_with genRank' genB preproc = property $ genRank' $ \extra rank@(SNat @n) -> do sh <- forAll $ genShR rank - guard (all (> 0) (toList sh)) + guard (all (> 0) sh) arr <- forAllT $ OR.fromVector @Double @n (toList sh) <$> genStorables (Range.singleton (product sh)) (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) diff --git a/test/Tests/Permutation.hs b/test/Tests/Permutation.hs index 1e7ad13..4e75d64 100644 --- a/test/Tests/Permutation.hs +++ b/test/Tests/Permutation.hs @@ -6,7 +6,7 @@ module Tests.Permutation where import Data.Type.Equality -import Data.Array.Mixed.Permutation +import Data.Array.Nested.Permutation import Hedgehog import Hedgehog.Gen qualified as Gen @@ -24,7 +24,7 @@ tests = testGroup "Permutation" [testProperty "permCheckPermutation" $ property $ do n <- forAll $ Gen.int (Range.linear 0 10) list <- forAll $ genPermR n - let r = permFromList list $ \perm -> + let r = permFromListCont list $ \perm -> permCheckPermutation perm () case r of Just () -> return () diff --git a/test/Util.hs b/test/Util.hs index ce6ec23..8a5ba72 100644 --- a/test/Util.hs +++ b/test/Util.hs @@ -1,7 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} @@ -12,11 +11,11 @@ module Util where import Data.Array.RankedS qualified as OR import Data.Kind +import GHC.TypeLits import Hedgehog import Hedgehog.Internal.Property (failDiff) -import GHC.TypeLits -import Data.Array.Mixed.Types (fromSNat') +import Data.Array.Nested.Types (fromSNat') -- Returns highest value that satisfies the predicate, or `lo` if none does @@ -43,7 +42,7 @@ class AlmostEq f where almostEq :: (AlmostEqConstr f a, Ord a, Show a, Fractional a, MonadTest m) => a -> f a -> f a -> m () -instance KnownNat n => AlmostEq (OR.Array n) where +instance AlmostEq (OR.Array n) where type AlmostEqConstr (OR.Array n) = OR.Unbox almostEq atol lhs rhs | OR.allA (< atol) (OR.zipWithA (\a b -> abs (a - b)) rhs lhs) = |
