aboutsummaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/Main.hs75
1 files changed, 64 insertions, 11 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 002c606..dd59586 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -1,17 +1,22 @@
+{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeAbstractions #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE DataKinds #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Main where
+import Control.Monad
import qualified Data.Array.RankedS as OR
+import qualified Data.ByteString as BS
import Data.Foldable (toList)
import Data.Type.Equality
+import qualified Data.Vector.Storable as VS
+import Foreign
import GHC.TypeLits
import qualified GHC.TypeNats as TN
@@ -22,14 +27,29 @@ import qualified Data.Array.Nested.Internal as I
-- test framework stuff
import Hedgehog
+import Hedgehog.Internal.Property (forAllT)
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
+import qualified System.Random as Random
import Test.Tasty
import Test.Tasty.Hedgehog
import Debug.Trace
+-- Returns highest value that satisfies the predicate, or `lo` if none does
+binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a
+binarySearch div2 = \lo hi f -> case (f lo, f hi) of
+ (False, _) -> lo
+ (_, True) -> hi
+ (_, _ ) -> go lo hi f
+ where
+ go lo hi f = -- invariant: f lo && not (f hi)
+ let mid = lo + div2 (hi - lo)
+ in if mid `elem` [lo, hi]
+ then mid
+ else if f mid then go mid hi f else go lo mid f
+
genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO ()
genRank k = do
rank <- forAll $ Gen.int (Range.linear 0 8)
@@ -55,7 +75,7 @@ shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh
genShR :: SNat n -> Gen (IShR n)
genShR sn = do
let n = fromSNat' sn
- targetSize <- Gen.int (Range.linear 0 (1000 * 3 ^ n))
+ targetSize <- Gen.int (Range.linear 0 100_000)
let genDims :: SNat m -> Int -> Gen (IShR m)
genDims SZ _ = return ZSR
genDims (SS m) 0 = do
@@ -69,7 +89,20 @@ genShR sn = do
,(1 , return 0)]
dims <- genDims m (if dim == 0 then 0 else tgt `div` dim)
return (dim :$: dims)
- shuffleShR =<< genDims sn targetSize
+ dims <- genDims sn targetSize
+ let dimsL = toList dims
+ maxdim = maximum dimsL
+ cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize)
+ shuffleShR (min cap <$> dims)
+
+genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
+genStorables rng f = do
+ n <- Gen.int rng
+ seed <- Gen.resize 99 $ Gen.int Range.linearBounded
+ let gen0 = Random.mkStdGen seed
+ (bs, _) = Random.genByteString (8 * n) gen0
+ let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]])
+ return $ VS.generate n (f . readW64)
orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a
orSumOuter1 (sn@SNat :: SNat n) =
@@ -81,20 +114,40 @@ main = defaultMain $
testGroup "Tests"
[testGroup "C"
[testGroup "sum"
- [testProperty "random" $ property $ genRank $ \outrank@(SNat @n) -> do
+ [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do
+ -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
let inrank = SNat @(n + 1)
sh <- forAll $ genShR inrank
- arr <- forAll $ OR.fromList @_ @(n + 1) (toList sh) <$>
- Gen.list (Range.singleton (product sh))
- (Gen.realFloat (Range.linearFrac @Double 0 1))
+ -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
+ guard (all (> 0) (tail (toList sh))) -- only constrain the tail
+ arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
+ genStorables (Range.singleton (product sh))
+ (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
let rarr = rfromOrthotope inrank arr
- annotateShow rarr
+ -- annotateShow rarr
Refl <- return $ I.lemRankReplicate outrank
let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
let rhs = orSumOuter1 outrank arr
- annotateShow lhs
- annotateShow rhs
+ -- annotateShow lhs
+ -- annotateShow rhs
lhs === rhs
+
+ ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do
+ -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above.
+ outrank :: SNat n <- return $ SNat @(nm1 + 1)
+ let inrank = SNat @(n + 1)
+ sh <- forAll $ do
+ shtt <- genShR outrankm1 -- nm1
+ sht <- shuffleShR (0 :$: shtt) -- n
+ n <- Gen.int (Range.linear 0 20)
+ return (n :$: sht) -- n + 1
+ guard (any (== 0) (tail (toList sh)))
+ -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
+ let arr = OR.fromList @Double @(n + 1) (toList sh) []
+ let rarr = rfromOrthotope inrank arr
+ Refl <- return $ I.lemRankReplicate outrank
+ let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
+ OR.toList lhs === []
]
]
]