diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-28 21:46:34 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-28 21:51:22 +0200 |
commit | d8e2fcf4ea979fe272db48fc2889f4c2636c50d7 (patch) | |
tree | 64dcb00c9c61ad57177db5ec01c189d74dbc2d4a /test/Main.hs | |
parent | 5a802da40e5836ee19d46b9a2c771912dbff010e (diff) |
Reorganise test files
Diffstat (limited to 'test/Main.hs')
-rw-r--r-- | test/Main.hs | 151 |
1 files changed, 2 insertions, 149 deletions
diff --git a/test/Main.hs b/test/Main.hs index b5237e5..7e62641 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,158 +1,11 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE NumericUnderscores #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeAbstractions #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Main where -import Control.Monad -import qualified Data.Array.RankedS as OR -import qualified Data.ByteString as BS -import Data.Foldable (toList) -import Data.Type.Equality -import qualified Data.Vector.Storable as VS -import Foreign -import GHC.TypeLits -import qualified GHC.TypeNats as TN - -import qualified Data.Array.Mixed as X -import Data.Array.Mixed (fromSNat', pattern SZ, pattern SS) -import Data.Array.Nested -import qualified Data.Array.Nested.Internal as I - --- test framework stuff -import Hedgehog -import Hedgehog.Internal.Property (forAllT) -import qualified Hedgehog.Gen as Gen -import qualified Hedgehog.Range as Range -import qualified System.Random as Random import Test.Tasty -import Test.Tasty.Hedgehog - --- import Debug.Trace - - --- Returns highest value that satisfies the predicate, or `lo` if none does -binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a -binarySearch div2 = \lo hi f -> case (f lo, f hi) of - (False, _) -> lo - (_, True) -> hi - (_, _ ) -> go lo hi f - where - go lo hi f = -- invariant: f lo && not (f hi) - let mid = lo + div2 (hi - lo) - in if mid `elem` [lo, hi] - then mid - else if f mid then go mid hi f else go lo mid f - -genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO () -genRank k = do - rank <- forAll $ Gen.int (Range.linear 0 8) - TN.withSomeSNat (fromIntegral rank) k -genLowBiased :: RealFloat a => (a, a) -> Gen a -genLowBiased (lo, hi) = do - x <- Gen.realFloat (Range.linearFrac 0 1) - return (lo + x * x * x * (hi - lo)) +import qualified Tests.C -shuffleShR :: IShR n -> Gen (IShR n) -shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh - where - go :: Int -> [Int] -> IShR n -> Gen (IShR n) - go _ _ ZSR = return ZSR - go nbag bag (_ :$: sh) = do - idx <- Gen.int (Range.linear 0 (nbag - 1)) - let (dim, bag') = case splitAt idx bag of - (pre, n : post) -> (n, pre ++ post) - _ -> error "unreachable" - (dim :$:) <$> go (nbag - 1) bag' sh - -genShR :: SNat n -> Gen (IShR n) -genShR sn = do - let n = fromSNat' sn - targetSize <- Gen.int (Range.linear 0 100_000) - let genDims :: SNat m -> Int -> Gen (IShR m) - genDims SZ _ = return ZSR - genDims (SS m) 0 = do - dim <- Gen.int (Range.linear 0 20) - dims <- genDims m 0 - return (dim :$: dims) - genDims (SS m) tgt = do - dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt)))) - ,(2 , return tgt) - ,(4 , return 1) - ,(1 , return 0)] - dims <- genDims m (if dim == 0 then 0 else tgt `div` dim) - return (dim :$: dims) - dims <- genDims sn targetSize - let dimsL = toList dims - maxdim = maximum dimsL - cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize) - shuffleShR (min cap <$> dims) - -genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a) -genStorables rng f = do - n <- Gen.int rng - seed <- Gen.resize 99 $ Gen.int Range.linearBounded - let gen0 = Random.mkStdGen seed - (bs, _) = Random.genByteString (8 * n) gen0 - let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]]) - return $ VS.generate n (f . readW64) - -orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a -orSumOuter1 (sn@SNat :: SNat n) = - let n = fromSNat' sn - in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0]) - -rshTail :: ShR (n + 1) i -> ShR n i -rshTail (_ :$: sh) = sh -rshTail ZSR = error "unreachable" main :: IO () main = defaultMain $ testGroup "Tests" - [testGroup "C" - [testGroup "sum" - [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do - -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. - let inrank = SNat @(n + 1) - sh <- forAll $ genShR inrank - -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) - guard (all (> 0) (toList (rshTail sh))) -- only constrain the tail - arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> - genStorables (Range.singleton (product sh)) - (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) - let rarr = rfromOrthotope inrank arr - -- annotateShow rarr - Refl <- return $ I.lemRankReplicate outrank - let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr - let rhs = orSumOuter1 outrank arr - -- annotateShow lhs - -- annotateShow rhs - lhs === rhs - - ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do - -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. - outrank :: SNat n <- return $ SNat @(nm1 + 1) - let inrank = SNat @(n + 1) - sh <- forAll $ do - shtt <- genShR outrankm1 -- nm1 - sht <- shuffleShR (0 :$: shtt) -- n - n <- Gen.int (Range.linear 0 20) - return (n :$: sht) -- n + 1 - guard (any (== 0) (toList (rshTail sh))) - -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) - let arr = OR.fromList @Double @(n + 1) (toList sh) [] - let rarr = rfromOrthotope inrank arr - Refl <- return $ I.lemRankReplicate outrank - let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr - OR.toList lhs === [] - ] - ] - ] + [Tests.C.tests] |