aboutsummaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/Gen.hs31
-rw-r--r--test/Tests/C.hs94
-rw-r--r--test/Tests/Permutation.hs2
-rw-r--r--test/Util.hs16
4 files changed, 104 insertions, 39 deletions
diff --git a/test/Gen.hs b/test/Gen.hs
index 281c620..952e8db 100644
--- a/test/Gen.hs
+++ b/test/Gen.hs
@@ -4,7 +4,6 @@
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
@@ -20,10 +19,10 @@ import Foreign
import GHC.TypeLits
import GHC.TypeNats qualified as TN
-import Data.Array.Nested.Permutation
-import Data.Array.Nested.Types
import Data.Array.Nested
+import Data.Array.Nested.Permutation
import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
import Hedgehog
import Hedgehog.Gen qualified as Gen
@@ -59,9 +58,12 @@ shuffleShR = \sh -> go (length sh) (toList sh) sh
(dim :$:) <$> go (nbag - 1) bag' sh
genShR :: SNat n -> Gen (IShR n)
-genShR sn = do
+genShR = genShRwithTarget 100_000
+
+genShRwithTarget :: Int -> SNat n -> Gen (IShR n)
+genShRwithTarget targetMax sn = do
let n = fromSNat' sn
- targetSize <- Gen.int (Range.linear 0 100_000)
+ targetSize <- Gen.int (Range.linear 0 targetMax)
let genDims :: SNat m -> Int -> Gen (IShR m)
genDims SZ _ = return ZSR
genDims (SS m) 0 = do
@@ -76,9 +78,8 @@ genShR sn = do
dims <- genDims m (if dim == 0 then 0 else tgt `div` dim)
return (dim :$: dims)
dims <- genDims sn targetSize
- let dimsL = toList dims
- maxdim = maximum dimsL
- cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize)
+ let maxdim = maximum dims
+ cap = binarySearch (`div` 2) 1 maxdim (\cap' -> shrSize (min cap' <$> dims) <= targetSize)
shuffleShR (min cap <$> dims)
-- | Example: given 3 and 7, might return:
@@ -93,10 +94,14 @@ genShR sn = do
-- other dimensions might be zero.
genReplicatedShR :: m <= n => SNat m -> SNat n -> Gen (IShR m, IShR n, IShR n)
genReplicatedShR = \m n -> do
- sh1 <- genShR m
+ let expectedSizeIncrease = round (repvalavg ^ (fromSNat' n - fromSNat' m))
+ sh1 <- genShRwithTarget (1_000_000 `div` expectedSizeIncrease) m
(sh2, sh3) <- injectOnes n sh1 sh1
return (sh1, sh2, sh3)
where
+ repvalrange = (1::Int, 5)
+ repvalavg = let (lo, hi) = repvalrange in fromIntegral (lo + hi) / 2 :: Double
+
injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n)
injectOnes n@SNat shOnes sh
| m@SNat <- shrRank sh
@@ -105,7 +110,7 @@ genReplicatedShR = \m n -> do
EQI -> return (shOnes, sh)
GTI -> do
index <- Gen.int (Range.linear 0 (fromSNat' m))
- value <- Gen.int (Range.linear 1 5)
+ value <- Gen.int (uncurry Range.linear repvalrange)
Refl <- return (lem n m)
injectOnes n (inject index 1 shOnes) (inject index value sh)
@@ -115,7 +120,7 @@ genReplicatedShR = \m n -> do
inject :: Int -> Int -> IShR m -> IShR (m + 1)
inject 0 v sh = v :$: sh
inject i v (w :$: sh) = w :$: inject (i - 1) v sh
- inject _ v ZSR = v :$: ZSR -- invalid input, but meh
+ inject _ _ ZSR = error "unreachable"
genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
genStorables rng f = do
@@ -134,7 +139,7 @@ genStaticShX = \n k -> case n of
genStaticShX n' $ \ssh ->
k (item :!% ssh)
where
- genItem :: Monad m => (forall n. SMayNat () SNat n -> PropertyT m ()) -> PropertyT m ()
+ genItem :: Monad m => (forall n. SMayNat () n -> PropertyT m ()) -> PropertyT m ()
genItem k = do
b <- forAll Gen.bool
if b
@@ -157,7 +162,7 @@ genPermR n = Gen.shuffle [0 .. n-1]
genPerm :: Monad m => SNat n -> (forall p. (IsPermutation p, Rank p ~ n) => Perm p -> PropertyT m r) -> PropertyT m r
genPerm n@SNat k = do
list <- forAll $ genPermR (fromSNat' n)
- permFromList list $ \perm -> do
+ permFromListCont list $ \perm -> do
case permCheckPermutation perm $
case sameNat' (permRank perm) n of
Just Refl -> Just (k perm)
diff --git a/test/Tests/C.hs b/test/Tests/C.hs
index 3b78bc0..e26c3dd 100644
--- a/test/Tests/C.hs
+++ b/test/Tests/C.hs
@@ -1,9 +1,12 @@
+{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
+#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
{-# LANGUAGE TypeAbstractions #-}
+#endif
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
@@ -18,13 +21,13 @@ import Data.Type.Equality
import Foreign
import GHC.TypeLits
-import Data.Array.Nested.Types (fromSNat')
import Data.Array.Nested
import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types (fromSNat')
import Hedgehog
import Hedgehog.Gen qualified as Gen
-import Hedgehog.Internal.Property (forAllT)
+import Hedgehog.Internal.Property (LabelName(..), forAllT)
import Hedgehog.Range qualified as Range
import Test.Tasty
import Test.Tasty.Hedgehog
@@ -39,8 +42,11 @@ import Util
fineTol :: Double
fineTol = 1e-8
-prop_sum_nonempty :: Property
-prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do
+debugCoverage :: Bool
+debugCoverage = False
+
+gen_red_nonempty :: (forall n. SNat (n + 1) -> SNat n -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property
+gen_red_nonempty f = property $ genRank $ \outrank@(SNat @n) -> do
-- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
let inrank = SNat @(n + 1)
sh <- forAll $ genShR inrank
@@ -49,11 +55,10 @@ prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do
arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
genStorables (Range.singleton (product sh))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
- let rarr = rfromOrthotope inrank arr
- almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr)
+ f inrank outrank arr
-prop_sum_empty :: Property
-prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do
+gen_red_empty :: (forall n. SNat (n + 1) -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property
+gen_red_empty f = property $ genRank $ \outrankm1@(SNat @nm1) -> do
-- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above.
_outrank :: SNat n <- return $ SNat @(nm1 + 1)
let inrank = SNat @(n + 1)
@@ -62,14 +67,13 @@ prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do
sht <- shuffleShR (0 :$: shtt) -- n
n <- Gen.int (Range.linear 0 20)
return (n :$: sht) -- n + 1
- guard (0 `elem` toList (shrTail sh))
+ guard (0 `elem` shrTail sh)
-- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
let arr = OR.fromList @(n + 1) @Double (toList sh) []
- let rarr = rfromOrthotope inrank arr
- OR.toList (rtoOrthotope (rsumOuter1 rarr)) === []
+ f inrank arr
-prop_sum_lasteq1 :: Property
-prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do
+gen_red_lasteq1 :: (forall n. SNat (n + 1) -> SNat n -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property
+gen_red_lasteq1 f = property $ genRank $ \outrank@(SNat @n) -> do
let inrank = SNat @(n + 1)
outsh <- forAll $ genShR outrank
guard (all (> 0) outsh)
@@ -77,11 +81,10 @@ prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do
arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$>
genStorables (Range.singleton (product insh))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
- let rarr = rfromOrthotope inrank arr
- almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr)
+ f inrank outrank arr
-prop_sum_replicated :: Bool -> Property
-prop_sum_replicated doTranspose = property $
+gen_red_replicated :: Bool -> (forall n. SNat (n + 1) -> SNat n -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property
+gen_red_replicated doTranspose f = property $
genRank $ \inrank1@(SNat @m) ->
genRank $ \outrank@(SNat @nm1) -> do
inrank2 :: SNat n <- return $ SNat @(nm1 + 1)
@@ -89,6 +92,10 @@ prop_sum_replicated doTranspose = property $
LTI -> return Refl -- actually we only continue if m < n
_ -> discard
(sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2
+ when debugCoverage $ do
+ label (LabelName ("rankdiff " ++ show (fromSNat' inrank2 - fromSNat' inrank1)))
+ label (LabelName ("size sh1 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh1) :: Double)) :: Int)))
+ label (LabelName ("size sh3 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh3) :: Double)) :: Int)))
guard (all (> 0) sh3)
arr <- forAllT $
OR.stretch (toList sh3)
@@ -100,8 +107,50 @@ prop_sum_replicated doTranspose = property $
if doTranspose then do perm <- forAll $ genPermR (fromSNat' inrank2)
return $ OR.transpose perm arr
else return arr
- let rarr = rfromOrthotope inrank2 arrTrans
- almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans)
+ f inrank2 outrank arrTrans
+
+
+prop_sum_nonempty :: Property
+prop_sum_nonempty = gen_red_nonempty $ \inrank outrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ almostEq fineTol (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr)
+
+prop_sum_empty :: Property
+prop_sum_empty = gen_red_empty $ \inrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ OR.toList (rtoOrthotope (rsumOuter1Prim rarr)) === []
+
+prop_sum_lasteq1 :: Property
+prop_sum_lasteq1 = gen_red_lasteq1 $ \inrank outrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ almostEq fineTol (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr)
+
+prop_sum_replicated :: Bool -> Property
+prop_sum_replicated doTranspose = gen_red_replicated doTranspose $ \inrank outrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ almostEq 1e-8 (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr)
+
+
+prop_sumall_nonempty :: Property
+prop_sumall_nonempty = gen_red_nonempty $ \inrank _outrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ almostEq fineTol (rsumAllPrim rarr) (OR.sumA arr)
+
+prop_sumall_empty :: Property
+prop_sumall_empty = gen_red_empty $ \inrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ rsumAllPrim rarr === 0.0
+
+prop_sumall_lasteq1 :: Property
+prop_sumall_lasteq1 = gen_red_lasteq1 $ \inrank _outrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ almostEq fineTol (rsumAllPrim rarr) (OR.sumA arr)
+
+prop_sumall_replicated :: Bool -> Property
+prop_sumall_replicated doTranspose = gen_red_replicated doTranspose $ \inrank _outrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ almostEq 1e-6 (rsumAllPrim rarr) (OR.sumA arr)
+
prop_negate_with :: forall f b. Show b
=> ((forall n. f n -> SNat n -> PropertyT IO ()) -> PropertyT IO ())
@@ -130,6 +179,13 @@ tests = testGroup "C"
,testProperty "replicated" (prop_sum_replicated False)
,testProperty "replicated_transposed" (prop_sum_replicated True)
]
+ ,testGroup "sumAll"
+ [testProperty "nonempty" prop_sumall_nonempty
+ ,testProperty "empty" prop_sumall_empty
+ ,testProperty "last==1" prop_sumall_lasteq1
+ ,testProperty "replicated" (prop_sumall_replicated False)
+ ,testProperty "replicated_transposed" (prop_sumall_replicated True)
+ ]
,testGroup "negate"
[testProperty "normalised" $ prop_negate_with
(\k -> genRank (k (Const ())))
diff --git a/test/Tests/Permutation.hs b/test/Tests/Permutation.hs
index 98a6da5..4e75d64 100644
--- a/test/Tests/Permutation.hs
+++ b/test/Tests/Permutation.hs
@@ -24,7 +24,7 @@ tests = testGroup "Permutation"
[testProperty "permCheckPermutation" $ property $ do
n <- forAll $ Gen.int (Range.linear 0 10)
list <- forAll $ genPermR n
- let r = permFromList list $ \perm ->
+ let r = permFromListCont list $ \perm ->
permCheckPermutation perm ()
case r of
Just () -> return ()
diff --git a/test/Util.hs b/test/Util.hs
index 8a5ba72..6514fbf 100644
--- a/test/Util.hs
+++ b/test/Util.hs
@@ -36,16 +36,20 @@ orSumOuter1 (sn@SNat :: SNat n) =
let n = fromSNat' sn
in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0])
-class AlmostEq f where
- type AlmostEqConstr f :: Type -> Constraint
+class AlmostEq t where
+ type EltOf t :: Type
-- | absolute tolerance, lhs, rhs
- almostEq :: (AlmostEqConstr f a, Ord a, Show a, Fractional a, MonadTest m)
- => a -> f a -> f a -> m ()
+ almostEq :: MonadTest m => EltOf t -> t -> t -> m ()
-instance AlmostEq (OR.Array n) where
- type AlmostEqConstr (OR.Array n) = OR.Unbox
+instance (OR.Unbox a, Ord a, Show a, Fractional a) => AlmostEq (OR.Array n a) where
+ type EltOf (OR.Array n a) = a
almostEq atol lhs rhs
| OR.allA (< atol) (OR.zipWithA (\a b -> abs (a - b)) rhs lhs) =
success
| otherwise =
failDiff lhs rhs
+
+instance AlmostEq Double where
+ type EltOf Double = Double
+ almostEq atol lhs rhs | abs (lhs - rhs) < atol = success
+ | otherwise = failDiff lhs rhs