aboutsummaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/Gen.hs29
-rw-r--r--test/Tests/C.hs23
-rw-r--r--test/Tests/Permutation.hs2
-rw-r--r--test/Util.hs7
4 files changed, 36 insertions, 25 deletions
diff --git a/test/Gen.hs b/test/Gen.hs
index 244c735..044de14 100644
--- a/test/Gen.hs
+++ b/test/Gen.hs
@@ -2,7 +2,6 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NumericUnderscores #-}
-{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
@@ -21,11 +20,10 @@ import Foreign
import GHC.TypeLits
import GHC.TypeNats qualified as TN
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
-import Data.Array.Mixed.Types
import Data.Array.Nested
-import Data.Array.Nested.Internal.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
import Hedgehog
import Hedgehog.Gen qualified as Gen
@@ -49,7 +47,7 @@ genLowBiased (lo, hi) = do
return (lo + x * x * x * (hi - lo))
shuffleShR :: IShR n -> Gen (IShR n)
-shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh
+shuffleShR = \sh -> go (length sh) (toList sh) sh
where
go :: Int -> [Int] -> IShR n -> Gen (IShR n)
go _ _ ZSR = return ZSR
@@ -61,9 +59,12 @@ shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh
(dim :$:) <$> go (nbag - 1) bag' sh
genShR :: SNat n -> Gen (IShR n)
-genShR sn = do
+genShR = genShRwithTarget 100_000
+
+genShRwithTarget :: Int -> SNat n -> Gen (IShR n)
+genShRwithTarget targetMax sn = do
let n = fromSNat' sn
- targetSize <- Gen.int (Range.linear 0 100_000)
+ targetSize <- Gen.int (Range.linear 0 targetMax)
let genDims :: SNat m -> Int -> Gen (IShR m)
genDims SZ _ = return ZSR
genDims (SS m) 0 = do
@@ -95,10 +96,14 @@ genShR sn = do
-- other dimensions might be zero.
genReplicatedShR :: m <= n => SNat m -> SNat n -> Gen (IShR m, IShR n, IShR n)
genReplicatedShR = \m n -> do
- sh1 <- genShR m
+ let expectedSizeIncrease = round (repvalavg ^ (fromSNat' n - fromSNat' m))
+ sh1 <- genShRwithTarget (1_000_000 `div` expectedSizeIncrease) m
(sh2, sh3) <- injectOnes n sh1 sh1
return (sh1, sh2, sh3)
where
+ repvalrange = (1::Int, 5)
+ repvalavg = let (lo, hi) = repvalrange in fromIntegral (lo + hi) / 2 :: Double
+
injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n)
injectOnes n@SNat shOnes sh
| m@SNat <- shrRank sh
@@ -107,17 +112,17 @@ genReplicatedShR = \m n -> do
EQI -> return (shOnes, sh)
GTI -> do
index <- Gen.int (Range.linear 0 (fromSNat' m))
- value <- Gen.int (Range.linear 1 5)
+ value <- Gen.int (uncurry Range.linear repvalrange)
Refl <- return (lem n m)
injectOnes n (inject index 1 shOnes) (inject index value sh)
- lem :: forall n m proxy. Compare n m ~ GT => proxy n -> proxy m -> (m + 1 <=? n) :~: True
+ lem :: forall n m proxy. n > m => proxy n -> proxy m -> (m + 1 <=? n) :~: True
lem _ _ = unsafeCoerceRefl
inject :: Int -> Int -> IShR m -> IShR (m + 1)
inject 0 v sh = v :$: sh
inject i v (w :$: sh) = w :$: inject (i - 1) v sh
- inject _ v ZSR = v :$: ZSR -- invalid input, but meh
+ inject _ _ ZSR = error "unreachable"
genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
genStorables rng f = do
diff --git a/test/Tests/C.hs b/test/Tests/C.hs
index a0f103d..9567393 100644
--- a/test/Tests/C.hs
+++ b/test/Tests/C.hs
@@ -18,13 +18,13 @@ import Data.Type.Equality
import Foreign
import GHC.TypeLits
-import Data.Array.Mixed.Types (fromSNat')
import Data.Array.Nested
-import Data.Array.Nested.Internal.Shape
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types (fromSNat')
import Hedgehog
-import Hedgehog.Internal.Property (forAllT)
import Hedgehog.Gen qualified as Gen
+import Hedgehog.Internal.Property (LabelName(..), forAllT)
import Hedgehog.Range qualified as Range
import Test.Tasty
import Test.Tasty.Hedgehog
@@ -39,13 +39,16 @@ import Util
fineTol :: Double
fineTol = 1e-8
+debugCoverage :: Bool
+debugCoverage = False
+
prop_sum_nonempty :: Property
prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do
-- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
let inrank = SNat @(n + 1)
sh <- forAll $ genShR inrank
-- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
- guard (all (> 0) (toList (shrTail sh))) -- only constrain the tail
+ guard (all (> 0) (shrTail sh)) -- only constrain the tail
arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
genStorables (Range.singleton (product sh))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
@@ -62,7 +65,7 @@ prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do
sht <- shuffleShR (0 :$: shtt) -- n
n <- Gen.int (Range.linear 0 20)
return (n :$: sht) -- n + 1
- guard (any (== 0) (toList (shrTail sh)))
+ guard (0 `elem` shrTail sh)
-- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
let arr = OR.fromList @(n + 1) @Double (toList sh) []
let rarr = rfromOrthotope inrank arr
@@ -72,7 +75,7 @@ prop_sum_lasteq1 :: Property
prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do
let inrank = SNat @(n + 1)
outsh <- forAll $ genShR outrank
- guard (all (> 0) (toList outsh))
+ guard (all (> 0) outsh)
let insh = shrAppend outsh (1 :$: ZSR)
arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$>
genStorables (Range.singleton (product insh))
@@ -89,7 +92,11 @@ prop_sum_replicated doTranspose = property $
LTI -> return Refl -- actually we only continue if m < n
_ -> discard
(sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2
- guard (all (> 0) (toList sh3))
+ when debugCoverage $ do
+ label (LabelName ("rankdiff " ++ show (fromSNat' inrank2 - fromSNat' inrank1)))
+ label (LabelName ("size sh1 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh1) :: Double)) :: Int)))
+ label (LabelName ("size sh3 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh3) :: Double)) :: Int)))
+ guard (all (> 0) sh3)
arr <- forAllT $
OR.stretch (toList sh3)
. OR.reshape (toList sh2)
@@ -111,7 +118,7 @@ prop_negate_with :: forall f b. Show b
prop_negate_with genRank' genB preproc = property $
genRank' $ \extra rank@(SNat @n) -> do
sh <- forAll $ genShR rank
- guard (all (> 0) (toList sh))
+ guard (all (> 0) sh)
arr <- forAllT $ OR.fromVector @Double @n (toList sh) <$>
genStorables (Range.singleton (product sh))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
diff --git a/test/Tests/Permutation.hs b/test/Tests/Permutation.hs
index 1e7ad13..98a6da5 100644
--- a/test/Tests/Permutation.hs
+++ b/test/Tests/Permutation.hs
@@ -6,7 +6,7 @@ module Tests.Permutation where
import Data.Type.Equality
-import Data.Array.Mixed.Permutation
+import Data.Array.Nested.Permutation
import Hedgehog
import Hedgehog.Gen qualified as Gen
diff --git a/test/Util.hs b/test/Util.hs
index ce6ec23..8a5ba72 100644
--- a/test/Util.hs
+++ b/test/Util.hs
@@ -1,7 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
@@ -12,11 +11,11 @@ module Util where
import Data.Array.RankedS qualified as OR
import Data.Kind
+import GHC.TypeLits
import Hedgehog
import Hedgehog.Internal.Property (failDiff)
-import GHC.TypeLits
-import Data.Array.Mixed.Types (fromSNat')
+import Data.Array.Nested.Types (fromSNat')
-- Returns highest value that satisfies the predicate, or `lo` if none does
@@ -43,7 +42,7 @@ class AlmostEq f where
almostEq :: (AlmostEqConstr f a, Ord a, Show a, Fractional a, MonadTest m)
=> a -> f a -> f a -> m ()
-instance KnownNat n => AlmostEq (OR.Array n) where
+instance AlmostEq (OR.Array n) where
type AlmostEqConstr (OR.Array n) = OR.Unbox
almostEq atol lhs rhs
| OR.allA (< atol) (OR.zipWithA (\a b -> abs (a - b)) rhs lhs) =