diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/Main.hs | 75 |
1 files changed, 64 insertions, 11 deletions
diff --git a/test/Main.hs b/test/Main.hs index 002c606..dd59586 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,17 +1,22 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE DataKinds #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Main where +import Control.Monad import qualified Data.Array.RankedS as OR +import qualified Data.ByteString as BS import Data.Foldable (toList) import Data.Type.Equality +import qualified Data.Vector.Storable as VS +import Foreign import GHC.TypeLits import qualified GHC.TypeNats as TN @@ -22,14 +27,29 @@ import qualified Data.Array.Nested.Internal as I -- test framework stuff import Hedgehog +import Hedgehog.Internal.Property (forAllT) import qualified Hedgehog.Gen as Gen import qualified Hedgehog.Range as Range +import qualified System.Random as Random import Test.Tasty import Test.Tasty.Hedgehog import Debug.Trace +-- Returns highest value that satisfies the predicate, or `lo` if none does +binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a +binarySearch div2 = \lo hi f -> case (f lo, f hi) of + (False, _) -> lo + (_, True) -> hi + (_, _ ) -> go lo hi f + where + go lo hi f = -- invariant: f lo && not (f hi) + let mid = lo + div2 (hi - lo) + in if mid `elem` [lo, hi] + then mid + else if f mid then go mid hi f else go lo mid f + genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO () genRank k = do rank <- forAll $ Gen.int (Range.linear 0 8) @@ -55,7 +75,7 @@ shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh genShR :: SNat n -> Gen (IShR n) genShR sn = do let n = fromSNat' sn - targetSize <- Gen.int (Range.linear 0 (1000 * 3 ^ n)) + targetSize <- Gen.int (Range.linear 0 100_000) let genDims :: SNat m -> Int -> Gen (IShR m) genDims SZ _ = return ZSR genDims (SS m) 0 = do @@ -69,7 +89,20 @@ genShR sn = do ,(1 , return 0)] dims <- genDims m (if dim == 0 then 0 else tgt `div` dim) return (dim :$: dims) - shuffleShR =<< genDims sn targetSize + dims <- genDims sn targetSize + let dimsL = toList dims + maxdim = maximum dimsL + cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize) + shuffleShR (min cap <$> dims) + +genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a) +genStorables rng f = do + n <- Gen.int rng + seed <- Gen.resize 99 $ Gen.int Range.linearBounded + let gen0 = Random.mkStdGen seed + (bs, _) = Random.genByteString (8 * n) gen0 + let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]]) + return $ VS.generate n (f . readW64) orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a orSumOuter1 (sn@SNat :: SNat n) = @@ -81,20 +114,40 @@ main = defaultMain $ testGroup "Tests" [testGroup "C" [testGroup "sum" - [testProperty "random" $ property $ genRank $ \outrank@(SNat @n) -> do + [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do + -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. let inrank = SNat @(n + 1) sh <- forAll $ genShR inrank - arr <- forAll $ OR.fromList @_ @(n + 1) (toList sh) <$> - Gen.list (Range.singleton (product sh)) - (Gen.realFloat (Range.linearFrac @Double 0 1)) + -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) + guard (all (> 0) (tail (toList sh))) -- only constrain the tail + arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> + genStorables (Range.singleton (product sh)) + (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) let rarr = rfromOrthotope inrank arr - annotateShow rarr + -- annotateShow rarr Refl <- return $ I.lemRankReplicate outrank let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr let rhs = orSumOuter1 outrank arr - annotateShow lhs - annotateShow rhs + -- annotateShow lhs + -- annotateShow rhs lhs === rhs + + ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do + -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. + outrank :: SNat n <- return $ SNat @(nm1 + 1) + let inrank = SNat @(n + 1) + sh <- forAll $ do + shtt <- genShR outrankm1 -- nm1 + sht <- shuffleShR (0 :$: shtt) -- n + n <- Gen.int (Range.linear 0 20) + return (n :$: sht) -- n + 1 + guard (any (== 0) (tail (toList sh))) + -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) + let arr = OR.fromList @Double @(n + 1) (toList sh) [] + let rarr = rfromOrthotope inrank arr + Refl <- return $ I.lemRankReplicate outrank + let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr + OR.toList lhs === [] ] ] ] |