aboutsummaryrefslogtreecommitdiff
path: root/test/Tests/C.hs
diff options
context:
space:
mode:
Diffstat (limited to 'test/Tests/C.hs')
-rw-r--r--test/Tests/C.hs107
1 files changed, 70 insertions, 37 deletions
diff --git a/test/Tests/C.hs b/test/Tests/C.hs
index 2a3949f..148e7f6 100644
--- a/test/Tests/C.hs
+++ b/test/Tests/C.hs
@@ -16,10 +16,8 @@ import Data.Type.Equality
import Foreign
import GHC.TypeLits
-import Data.Array.Mixed.XArray qualified as X
-import Data.Array.Mixed.Lemmas
+import Data.Array.Mixed.Types (fromSNat')
import Data.Array.Nested
-import Data.Array.Nested.Internal.Mixed
import Data.Array.Nested.Internal.Shape
import Hedgehog
@@ -35,42 +33,77 @@ import Gen
import Util
+prop_sum_nonempty :: Property
+prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do
+ -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
+ let inrank = SNat @(n + 1)
+ sh <- forAll $ genShR inrank
+ -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
+ guard (all (> 0) (toList (shrTail sh))) -- only constrain the tail
+ arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
+ genStorables (Range.singleton (product sh))
+ (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+ let rarr = rfromOrthotope inrank arr
+ rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr
+
+prop_sum_empty :: Property
+prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do
+ -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above.
+ _outrank :: SNat n <- return $ SNat @(nm1 + 1)
+ let inrank = SNat @(n + 1)
+ sh <- forAll $ do
+ shtt <- genShR outrankm1 -- nm1
+ sht <- shuffleShR (0 :$: shtt) -- n
+ n <- Gen.int (Range.linear 0 20)
+ return (n :$: sht) -- n + 1
+ guard (any (== 0) (toList (shrTail sh)))
+ -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
+ let arr = OR.fromList @Double @(n + 1) (toList sh) []
+ let rarr = rfromOrthotope inrank arr
+ OR.toList (rtoOrthotope (rsumOuter1 rarr)) === []
+
+prop_sum_lasteq1 :: Property
+prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do
+ let inrank = SNat @(n + 1)
+ outsh <- forAll $ genShR outrank
+ guard (all (> 0) (toList outsh))
+ let insh = shrAppend outsh (1 :$: ZSR)
+ arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$>
+ genStorables (Range.singleton (product insh))
+ (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+ let rarr = rfromOrthotope inrank arr
+ rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr
+
+prop_sum_replicated :: Bool -> Property
+prop_sum_replicated doTranspose = property $
+ genRank $ \inrank1@(SNat @m) ->
+ genRank $ \outrank@(SNat @nm1) -> do
+ inrank2 :: SNat n <- return $ SNat @(nm1 + 1)
+ (Refl :: (m <=? n) :~: True) <- case cmpNat inrank1 inrank2 of
+ LTI -> return Refl -- actually we only continue if m < n
+ _ -> discard
+ (sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2
+ guard (all (> 0) (toList sh3))
+ arr <- forAllT $
+ OR.stretch (toList sh3)
+ . OR.reshape (toList sh2)
+ . OR.fromVector @Double @m (toList sh1) <$>
+ genStorables (Range.singleton (product sh1))
+ (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+ arrTrans <-
+ if doTranspose then do perm <- forAll $ genPermR (fromSNat' inrank2)
+ return $ OR.transpose perm arr
+ else return arr
+ let rarr = rfromOrthotope inrank2 arrTrans
+ almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans)
+
tests :: TestTree
tests = testGroup "C"
[testGroup "sum"
- [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do
- -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
- let inrank = SNat @(n + 1)
- sh <- forAll $ genShR inrank
- -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
- guard (all (> 0) (toList (shrTail sh))) -- only constrain the tail
- arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
- genStorables (Range.singleton (product sh))
- (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
- let rarr = rfromOrthotope inrank arr
- -- annotateShow rarr
- Refl <- return $ lemRankReplicate outrank
- let Ranked (M_Double (M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
- let rhs = orSumOuter1 outrank arr
- -- annotateShow lhs
- -- annotateShow rhs
- lhs === rhs
-
- ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do
- -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above.
- outrank :: SNat n <- return $ SNat @(nm1 + 1)
- let inrank = SNat @(n + 1)
- sh <- forAll $ do
- shtt <- genShR outrankm1 -- nm1
- sht <- shuffleShR (0 :$: shtt) -- n
- n <- Gen.int (Range.linear 0 20)
- return (n :$: sht) -- n + 1
- guard (any (== 0) (toList (shrTail sh)))
- -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
- let arr = OR.fromList @Double @(n + 1) (toList sh) []
- let rarr = rfromOrthotope inrank arr
- Refl <- return $ lemRankReplicate outrank
- let Ranked (M_Double (M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
- OR.toList lhs === []
+ [testProperty "nonempty" prop_sum_nonempty
+ ,testProperty "empty" prop_sum_empty
+ ,testProperty "last==1" prop_sum_lasteq1
+ ,testProperty "replicated" (prop_sum_replicated False)
+ ,testProperty "replicated_transposed" (prop_sum_replicated True)
]
]