diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/Gen.hs | 174 | ||||
-rw-r--r-- | test/Main.hs | 32 | ||||
-rw-r--r-- | test/Tests/C.hs | 160 | ||||
-rw-r--r-- | test/Tests/Permutation.hs | 39 | ||||
-rw-r--r-- | test/Util.hs | 51 |
5 files changed, 433 insertions, 23 deletions
diff --git a/test/Gen.hs b/test/Gen.hs new file mode 100644 index 0000000..044de14 --- /dev/null +++ b/test/Gen.hs @@ -0,0 +1,174 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Gen where + +import Data.ByteString qualified as BS +import Data.Foldable (toList) +import Data.Type.Equality +import Data.Type.Ord +import Data.Vector.Storable qualified as VS +import Foreign +import GHC.TypeLits +import GHC.TypeNats qualified as TN + +import Data.Array.Nested +import Data.Array.Nested.Permutation +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types + +import Hedgehog +import Hedgehog.Gen qualified as Gen +import Hedgehog.Range qualified as Range +import System.Random qualified as Random + +import Util + + +-- | Generates zero with small probability, because there's typically only one +-- interesting case for 0 anyway. +genRank :: Monad m => (forall n. SNat n -> PropertyT m ()) -> PropertyT m () +genRank k = do + rank <- forAll $ Gen.frequency [(1, return 0) + ,(49, Gen.int (Range.linear 1 8))] + TN.withSomeSNat (fromIntegral rank) k + +genLowBiased :: RealFloat a => (a, a) -> Gen a +genLowBiased (lo, hi) = do + x <- Gen.realFloat (Range.linearFrac 0 1) + return (lo + x * x * x * (hi - lo)) + +shuffleShR :: IShR n -> Gen (IShR n) +shuffleShR = \sh -> go (length sh) (toList sh) sh + where + go :: Int -> [Int] -> IShR n -> Gen (IShR n) + go _ _ ZSR = return ZSR + go nbag bag (_ :$: sh) = do + idx <- Gen.int (Range.linear 0 (nbag - 1)) + let (dim, bag') = case splitAt idx bag of + (pre, n : post) -> (n, pre ++ post) + _ -> error "unreachable" + (dim :$:) <$> go (nbag - 1) bag' sh + +genShR :: SNat n -> Gen (IShR n) +genShR = genShRwithTarget 100_000 + +genShRwithTarget :: Int -> SNat n -> Gen (IShR n) +genShRwithTarget targetMax sn = do + let n = fromSNat' sn + targetSize <- Gen.int (Range.linear 0 targetMax) + let genDims :: SNat m -> Int -> Gen (IShR m) + genDims SZ _ = return ZSR + genDims (SS m) 0 = do + dim <- Gen.int (Range.linear 0 20) + dims <- genDims m 0 + return (dim :$: dims) + genDims (SS m) tgt = do + dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt)))) + ,(2 , return tgt) + ,(4 , return 1) + ,(1 , return 0)] + dims <- genDims m (if dim == 0 then 0 else tgt `div` dim) + return (dim :$: dims) + dims <- genDims sn targetSize + let dimsL = toList dims + maxdim = maximum dimsL + cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize) + shuffleShR (min cap <$> dims) + +-- | Example: given 3 and 7, might return: +-- +-- @ +-- ([ 13, 4, 27 ] +-- ,[1, 13, 1, 1, 4, 27, 1] +-- ,[4, 13, 1, 3, 4, 27, 2]) +-- @ +-- +-- The up-replicated dimensions are always nonzero and not very large, but the +-- other dimensions might be zero. +genReplicatedShR :: m <= n => SNat m -> SNat n -> Gen (IShR m, IShR n, IShR n) +genReplicatedShR = \m n -> do + let expectedSizeIncrease = round (repvalavg ^ (fromSNat' n - fromSNat' m)) + sh1 <- genShRwithTarget (1_000_000 `div` expectedSizeIncrease) m + (sh2, sh3) <- injectOnes n sh1 sh1 + return (sh1, sh2, sh3) + where + repvalrange = (1::Int, 5) + repvalavg = let (lo, hi) = repvalrange in fromIntegral (lo + hi) / 2 :: Double + + injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n) + injectOnes n@SNat shOnes sh + | m@SNat <- shrRank sh + = case cmpNat n m of + LTI -> error "unreachable" + EQI -> return (shOnes, sh) + GTI -> do + index <- Gen.int (Range.linear 0 (fromSNat' m)) + value <- Gen.int (uncurry Range.linear repvalrange) + Refl <- return (lem n m) + injectOnes n (inject index 1 shOnes) (inject index value sh) + + lem :: forall n m proxy. n > m => proxy n -> proxy m -> (m + 1 <=? n) :~: True + lem _ _ = unsafeCoerceRefl + + inject :: Int -> Int -> IShR m -> IShR (m + 1) + inject 0 v sh = v :$: sh + inject i v (w :$: sh) = w :$: inject (i - 1) v sh + inject _ _ ZSR = error "unreachable" + +genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a) +genStorables rng f = do + n <- Gen.int rng + seed <- Gen.resize 99 $ Gen.int Range.linearBounded + let gen0 = Random.mkStdGen seed + (bs, _) = Random.uniformByteString (8 * n) gen0 + let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]]) + return $ VS.generate n (f . readW64) + +genStaticShX :: Monad m => SNat n -> (forall sh. Rank sh ~ n => StaticShX sh -> PropertyT m ()) -> PropertyT m () +genStaticShX = \n k -> case n of + SZ -> k ZKX + SS n' -> + genItem $ \item -> + genStaticShX n' $ \ssh -> + k (item :!% ssh) + where + genItem :: Monad m => (forall n. SMayNat () SNat n -> PropertyT m ()) -> PropertyT m () + genItem k = do + b <- forAll Gen.bool + if b + then do + n <- forAll $ Gen.frequency [(20, Gen.int (Range.linear 1 4)) + ,(1, return 0)] + TN.withSomeSNat (fromIntegral n) $ \sn -> k (SKnown sn) + else k (SUnknown ()) + +genShX :: StaticShX sh -> Gen (IShX sh) +genShX ZKX = return ZSX +genShX (SKnown sn :!% ssh) = (SKnown sn :$%) <$> genShX ssh +genShX (SUnknown () :!% ssh) = do + dim <- Gen.int (Range.linear 1 4) + (SUnknown dim :$%) <$> genShX ssh + +genPermR :: Int -> Gen PermR +genPermR n = Gen.shuffle [0 .. n-1] + +genPerm :: Monad m => SNat n -> (forall p. (IsPermutation p, Rank p ~ n) => Perm p -> PropertyT m r) -> PropertyT m r +genPerm n@SNat k = do + list <- forAll $ genPermR (fromSNat' n) + permFromList list $ \perm -> do + case permCheckPermutation perm $ + case sameNat' (permRank perm) n of + Just Refl -> Just (k perm) + Nothing -> Nothing + of + Just (Just act) -> act + _ -> error "" diff --git a/test/Main.hs b/test/Main.hs index 2363813..575bb15 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,29 +1,15 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ImportQualifiedPost #-} module Main where -import Data.Array.Nested +import Test.Tasty +import Tests.C qualified +import Tests.Permutation qualified -arr :: Ranked I2 (Shaped [2, 3] (Double, Int)) -arr = rgenerate (3 :$: 4 :$: ZSR) $ \(i :.: j :.: ZIR) -> - sgenerate @[2, 3] $ \(k :.$ l :.$ ZIS) -> - let s = 24*i + 6*j + 3*k + l - in (fromIntegral s, s) - -foo :: (Double, Int) -foo = arr `rindex` (2 :.: 1 :.: ZIR) `sindex` (1 :.$ 1 :.$ ZIS) - -bad :: Ranked I2 (Ranked I1 Double) -bad = rgenerate (3 :$: 4 :$: ZSR) $ \(i :.: j :.: ZIR) -> - rgenerate (i :$: ZSR) $ \(k :.: ZIR) -> - let s = 24*i + 6*j + 3*k - in fromIntegral s main :: IO () -main = do - print arr - print foo - print (rtranspose [1,0] arr) - -- print bad +main = defaultMain $ + testGroup "Tests" + [Tests.C.tests + ,Tests.Permutation.tests + ] diff --git a/test/Tests/C.hs b/test/Tests/C.hs new file mode 100644 index 0000000..9567393 --- /dev/null +++ b/test/Tests/C.hs @@ -0,0 +1,160 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Tests.C where + +import Control.Monad +import Data.Array.RankedS qualified as OR +import Data.Foldable (toList) +import Data.Functor.Const +import Data.Type.Equality +import Foreign +import GHC.TypeLits + +import Data.Array.Nested +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types (fromSNat') + +import Hedgehog +import Hedgehog.Gen qualified as Gen +import Hedgehog.Internal.Property (LabelName(..), forAllT) +import Hedgehog.Range qualified as Range +import Test.Tasty +import Test.Tasty.Hedgehog + +-- import Debug.Trace + +import Gen +import Util + + +-- | Appropriate for simple different summation orders +fineTol :: Double +fineTol = 1e-8 + +debugCoverage :: Bool +debugCoverage = False + +prop_sum_nonempty :: Property +prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do + -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. + let inrank = SNat @(n + 1) + sh <- forAll $ genShR inrank + -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) + guard (all (> 0) (shrTail sh)) -- only constrain the tail + arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> + genStorables (Range.singleton (product sh)) + (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) + let rarr = rfromOrthotope inrank arr + almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr) + +prop_sum_empty :: Property +prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do + -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. + _outrank :: SNat n <- return $ SNat @(nm1 + 1) + let inrank = SNat @(n + 1) + sh <- forAll $ do + shtt <- genShR outrankm1 -- nm1 + sht <- shuffleShR (0 :$: shtt) -- n + n <- Gen.int (Range.linear 0 20) + return (n :$: sht) -- n + 1 + guard (0 `elem` shrTail sh) + -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) + let arr = OR.fromList @(n + 1) @Double (toList sh) [] + let rarr = rfromOrthotope inrank arr + OR.toList (rtoOrthotope (rsumOuter1 rarr)) === [] + +prop_sum_lasteq1 :: Property +prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do + let inrank = SNat @(n + 1) + outsh <- forAll $ genShR outrank + guard (all (> 0) outsh) + let insh = shrAppend outsh (1 :$: ZSR) + arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$> + genStorables (Range.singleton (product insh)) + (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) + let rarr = rfromOrthotope inrank arr + almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr) + +prop_sum_replicated :: Bool -> Property +prop_sum_replicated doTranspose = property $ + genRank $ \inrank1@(SNat @m) -> + genRank $ \outrank@(SNat @nm1) -> do + inrank2 :: SNat n <- return $ SNat @(nm1 + 1) + (Refl :: (m <=? n) :~: True) <- case cmpNat inrank1 inrank2 of + LTI -> return Refl -- actually we only continue if m < n + _ -> discard + (sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2 + when debugCoverage $ do + label (LabelName ("rankdiff " ++ show (fromSNat' inrank2 - fromSNat' inrank1))) + label (LabelName ("size sh1 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh1) :: Double)) :: Int))) + label (LabelName ("size sh3 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh3) :: Double)) :: Int))) + guard (all (> 0) sh3) + arr <- forAllT $ + OR.stretch (toList sh3) + . OR.reshape (toList sh2) + . OR.fromVector @Double @m (toList sh1) <$> + genStorables (Range.singleton (product sh1)) + (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) + arrTrans <- + if doTranspose then do perm <- forAll $ genPermR (fromSNat' inrank2) + return $ OR.transpose perm arr + else return arr + let rarr = rfromOrthotope inrank2 arrTrans + almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans) + +prop_negate_with :: forall f b. Show b + => ((forall n. f n -> SNat n -> PropertyT IO ()) -> PropertyT IO ()) + -> (forall n. f n -> IShR n -> Gen b) + -> (forall n. f n -> b -> OR.Array n Double -> OR.Array n Double) + -> Property +prop_negate_with genRank' genB preproc = property $ + genRank' $ \extra rank@(SNat @n) -> do + sh <- forAll $ genShR rank + guard (all (> 0) sh) + arr <- forAllT $ OR.fromVector @Double @n (toList sh) <$> + genStorables (Range.singleton (product sh)) + (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) + bval <- forAll $ genB extra sh + let arr' = preproc extra bval arr + annotate (show (OR.shapeL arr')) + let rarr = rfromOrthotope rank arr' + rtoOrthotope (negate rarr) === OR.mapA negate arr' + +tests :: TestTree +tests = testGroup "C" + [testGroup "sum" + [testProperty "nonempty" prop_sum_nonempty + ,testProperty "empty" prop_sum_empty + ,testProperty "last==1" prop_sum_lasteq1 + ,testProperty "replicated" (prop_sum_replicated False) + ,testProperty "replicated_transposed" (prop_sum_replicated True) + ] + ,testGroup "negate" + [testProperty "normalised" $ prop_negate_with + (\k -> genRank (k (Const ()))) + (\_ _ -> pure ()) + (\_ _ -> id) + ,testProperty "slice 1D" $ prop_negate_with @((:~:) 1) + (\k -> k Refl (SNat @1)) + (\Refl (n :$: _) -> do lo <- Gen.integral (Range.constant 0 (n-1)) + len <- Gen.integral (Range.constant 0 (n-lo)) + return [(lo, len)]) + (\_ -> OR.slice) + ,testProperty "slice nD" $ prop_negate_with + (\k -> genRank (k (Const ()))) + (\_ sh -> do let genPair n = do lo <- Gen.integral (Range.constant 0 (n-1)) + len <- Gen.integral (Range.constant 0 (n-lo-1)) + return (lo, len) + pairs <- mapM genPair (toList sh) + return pairs) + (\_ -> OR.slice) + ] + ] diff --git a/test/Tests/Permutation.hs b/test/Tests/Permutation.hs new file mode 100644 index 0000000..98a6da5 --- /dev/null +++ b/test/Tests/Permutation.hs @@ -0,0 +1,39 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Tests.Permutation where + +import Data.Type.Equality + +import Data.Array.Nested.Permutation + +import Hedgehog +import Hedgehog.Gen qualified as Gen +import Hedgehog.Range qualified as Range +import Test.Tasty +import Test.Tasty.Hedgehog + +-- import Debug.Trace + +import Gen + + +tests :: TestTree +tests = testGroup "Permutation" + [testProperty "permCheckPermutation" $ property $ do + n <- forAll $ Gen.int (Range.linear 0 10) + list <- forAll $ genPermR n + let r = permFromList list $ \perm -> + permCheckPermutation perm () + case r of + Just () -> return () + Nothing -> failure + ,testProperty "permInverse" $ property $ + genRank $ \n -> + genPerm n $ \perm -> + genStaticShX n $ \ssh -> + permInverse perm $ \_invperm proof -> + case proof ssh of + Refl -> return () + ] diff --git a/test/Util.hs b/test/Util.hs new file mode 100644 index 0000000..8a5ba72 --- /dev/null +++ b/test/Util.hs @@ -0,0 +1,51 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Util where + +import Data.Array.RankedS qualified as OR +import Data.Kind +import GHC.TypeLits +import Hedgehog +import Hedgehog.Internal.Property (failDiff) + +import Data.Array.Nested.Types (fromSNat') + + +-- Returns highest value that satisfies the predicate, or `lo` if none does +binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a +binarySearch div2 = \lo hi f -> case (f lo, f hi) of + (False, _) -> lo + (_, True) -> hi + (_, _ ) -> go lo hi f + where + go lo hi f = -- invariant: f lo && not (f hi) + let mid = lo + div2 (hi - lo) + in if mid `elem` [lo, hi] + then mid + else if f mid then go mid hi f else go lo mid f + +orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a +orSumOuter1 (sn@SNat :: SNat n) = + let n = fromSNat' sn + in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0]) + +class AlmostEq f where + type AlmostEqConstr f :: Type -> Constraint + -- | absolute tolerance, lhs, rhs + almostEq :: (AlmostEqConstr f a, Ord a, Show a, Fractional a, MonadTest m) + => a -> f a -> f a -> m () + +instance AlmostEq (OR.Array n) where + type AlmostEqConstr (OR.Array n) = OR.Unbox + almostEq atol lhs rhs + | OR.allA (< atol) (OR.zipWithA (\a b -> abs (a - b)) rhs lhs) = + success + | otherwise = + failDiff lhs rhs |