aboutsummaryrefslogtreecommitdiff
path: root/test/Tests
diff options
context:
space:
mode:
Diffstat (limited to 'test/Tests')
-rw-r--r--test/Tests/C.hs34
-rw-r--r--test/Tests/Permutation.hs4
2 files changed, 24 insertions, 14 deletions
diff --git a/test/Tests/C.hs b/test/Tests/C.hs
index a0f103d..0656107 100644
--- a/test/Tests/C.hs
+++ b/test/Tests/C.hs
@@ -1,9 +1,12 @@
+{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
+#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
{-# LANGUAGE TypeAbstractions #-}
+#endif
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
@@ -18,13 +21,13 @@ import Data.Type.Equality
import Foreign
import GHC.TypeLits
-import Data.Array.Mixed.Types (fromSNat')
import Data.Array.Nested
-import Data.Array.Nested.Internal.Shape
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types (fromSNat')
import Hedgehog
-import Hedgehog.Internal.Property (forAllT)
import Hedgehog.Gen qualified as Gen
+import Hedgehog.Internal.Property (LabelName(..), forAllT)
import Hedgehog.Range qualified as Range
import Test.Tasty
import Test.Tasty.Hedgehog
@@ -39,18 +42,21 @@ import Util
fineTol :: Double
fineTol = 1e-8
+debugCoverage :: Bool
+debugCoverage = False
+
prop_sum_nonempty :: Property
prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do
-- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
let inrank = SNat @(n + 1)
sh <- forAll $ genShR inrank
-- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
- guard (all (> 0) (toList (shrTail sh))) -- only constrain the tail
+ guard (all (> 0) (shrTail sh)) -- only constrain the tail
arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
genStorables (Range.singleton (product sh))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
let rarr = rfromOrthotope inrank arr
- almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr)
+ almostEq fineTol (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr)
prop_sum_empty :: Property
prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do
@@ -62,23 +68,23 @@ prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do
sht <- shuffleShR (0 :$: shtt) -- n
n <- Gen.int (Range.linear 0 20)
return (n :$: sht) -- n + 1
- guard (any (== 0) (toList (shrTail sh)))
+ guard (0 `elem` shrTail sh)
-- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
let arr = OR.fromList @(n + 1) @Double (toList sh) []
let rarr = rfromOrthotope inrank arr
- OR.toList (rtoOrthotope (rsumOuter1 rarr)) === []
+ OR.toList (rtoOrthotope (rsumOuter1Prim rarr)) === []
prop_sum_lasteq1 :: Property
prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do
let inrank = SNat @(n + 1)
outsh <- forAll $ genShR outrank
- guard (all (> 0) (toList outsh))
+ guard (all (> 0) outsh)
let insh = shrAppend outsh (1 :$: ZSR)
arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$>
genStorables (Range.singleton (product insh))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
let rarr = rfromOrthotope inrank arr
- almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr)
+ almostEq fineTol (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr)
prop_sum_replicated :: Bool -> Property
prop_sum_replicated doTranspose = property $
@@ -89,7 +95,11 @@ prop_sum_replicated doTranspose = property $
LTI -> return Refl -- actually we only continue if m < n
_ -> discard
(sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2
- guard (all (> 0) (toList sh3))
+ when debugCoverage $ do
+ label (LabelName ("rankdiff " ++ show (fromSNat' inrank2 - fromSNat' inrank1)))
+ label (LabelName ("size sh1 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh1) :: Double)) :: Int)))
+ label (LabelName ("size sh3 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh3) :: Double)) :: Int)))
+ guard (all (> 0) sh3)
arr <- forAllT $
OR.stretch (toList sh3)
. OR.reshape (toList sh2)
@@ -101,7 +111,7 @@ prop_sum_replicated doTranspose = property $
return $ OR.transpose perm arr
else return arr
let rarr = rfromOrthotope inrank2 arrTrans
- almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans)
+ almostEq 1e-8 (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arrTrans)
prop_negate_with :: forall f b. Show b
=> ((forall n. f n -> SNat n -> PropertyT IO ()) -> PropertyT IO ())
@@ -111,7 +121,7 @@ prop_negate_with :: forall f b. Show b
prop_negate_with genRank' genB preproc = property $
genRank' $ \extra rank@(SNat @n) -> do
sh <- forAll $ genShR rank
- guard (all (> 0) (toList sh))
+ guard (all (> 0) sh)
arr <- forAllT $ OR.fromVector @Double @n (toList sh) <$>
genStorables (Range.singleton (product sh))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
diff --git a/test/Tests/Permutation.hs b/test/Tests/Permutation.hs
index 1e7ad13..4e75d64 100644
--- a/test/Tests/Permutation.hs
+++ b/test/Tests/Permutation.hs
@@ -6,7 +6,7 @@ module Tests.Permutation where
import Data.Type.Equality
-import Data.Array.Mixed.Permutation
+import Data.Array.Nested.Permutation
import Hedgehog
import Hedgehog.Gen qualified as Gen
@@ -24,7 +24,7 @@ tests = testGroup "Permutation"
[testProperty "permCheckPermutation" $ property $ do
n <- forAll $ Gen.int (Range.linear 0 10)
list <- forAll $ genPermR n
- let r = permFromList list $ \perm ->
+ let r = permFromListCont list $ \perm ->
permCheckPermutation perm ()
case r of
Just () -> return ()