aboutsummaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/Gen.hs174
-rw-r--r--test/Main.hs32
-rw-r--r--test/Tests/C.hs160
-rw-r--r--test/Tests/Permutation.hs39
-rw-r--r--test/Util.hs51
5 files changed, 433 insertions, 23 deletions
diff --git a/test/Gen.hs b/test/Gen.hs
new file mode 100644
index 0000000..044de14
--- /dev/null
+++ b/test/Gen.hs
@@ -0,0 +1,174 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE NumericUnderscores #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeAbstractions #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Gen where
+
+import Data.ByteString qualified as BS
+import Data.Foldable (toList)
+import Data.Type.Equality
+import Data.Type.Ord
+import Data.Vector.Storable qualified as VS
+import Foreign
+import GHC.TypeLits
+import GHC.TypeNats qualified as TN
+
+import Data.Array.Nested
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
+
+import Hedgehog
+import Hedgehog.Gen qualified as Gen
+import Hedgehog.Range qualified as Range
+import System.Random qualified as Random
+
+import Util
+
+
+-- | Generates zero with small probability, because there's typically only one
+-- interesting case for 0 anyway.
+genRank :: Monad m => (forall n. SNat n -> PropertyT m ()) -> PropertyT m ()
+genRank k = do
+ rank <- forAll $ Gen.frequency [(1, return 0)
+ ,(49, Gen.int (Range.linear 1 8))]
+ TN.withSomeSNat (fromIntegral rank) k
+
+genLowBiased :: RealFloat a => (a, a) -> Gen a
+genLowBiased (lo, hi) = do
+ x <- Gen.realFloat (Range.linearFrac 0 1)
+ return (lo + x * x * x * (hi - lo))
+
+shuffleShR :: IShR n -> Gen (IShR n)
+shuffleShR = \sh -> go (length sh) (toList sh) sh
+ where
+ go :: Int -> [Int] -> IShR n -> Gen (IShR n)
+ go _ _ ZSR = return ZSR
+ go nbag bag (_ :$: sh) = do
+ idx <- Gen.int (Range.linear 0 (nbag - 1))
+ let (dim, bag') = case splitAt idx bag of
+ (pre, n : post) -> (n, pre ++ post)
+ _ -> error "unreachable"
+ (dim :$:) <$> go (nbag - 1) bag' sh
+
+genShR :: SNat n -> Gen (IShR n)
+genShR = genShRwithTarget 100_000
+
+genShRwithTarget :: Int -> SNat n -> Gen (IShR n)
+genShRwithTarget targetMax sn = do
+ let n = fromSNat' sn
+ targetSize <- Gen.int (Range.linear 0 targetMax)
+ let genDims :: SNat m -> Int -> Gen (IShR m)
+ genDims SZ _ = return ZSR
+ genDims (SS m) 0 = do
+ dim <- Gen.int (Range.linear 0 20)
+ dims <- genDims m 0
+ return (dim :$: dims)
+ genDims (SS m) tgt = do
+ dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt))))
+ ,(2 , return tgt)
+ ,(4 , return 1)
+ ,(1 , return 0)]
+ dims <- genDims m (if dim == 0 then 0 else tgt `div` dim)
+ return (dim :$: dims)
+ dims <- genDims sn targetSize
+ let dimsL = toList dims
+ maxdim = maximum dimsL
+ cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize)
+ shuffleShR (min cap <$> dims)
+
+-- | Example: given 3 and 7, might return:
+--
+-- @
+-- ([ 13, 4, 27 ]
+-- ,[1, 13, 1, 1, 4, 27, 1]
+-- ,[4, 13, 1, 3, 4, 27, 2])
+-- @
+--
+-- The up-replicated dimensions are always nonzero and not very large, but the
+-- other dimensions might be zero.
+genReplicatedShR :: m <= n => SNat m -> SNat n -> Gen (IShR m, IShR n, IShR n)
+genReplicatedShR = \m n -> do
+ let expectedSizeIncrease = round (repvalavg ^ (fromSNat' n - fromSNat' m))
+ sh1 <- genShRwithTarget (1_000_000 `div` expectedSizeIncrease) m
+ (sh2, sh3) <- injectOnes n sh1 sh1
+ return (sh1, sh2, sh3)
+ where
+ repvalrange = (1::Int, 5)
+ repvalavg = let (lo, hi) = repvalrange in fromIntegral (lo + hi) / 2 :: Double
+
+ injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n)
+ injectOnes n@SNat shOnes sh
+ | m@SNat <- shrRank sh
+ = case cmpNat n m of
+ LTI -> error "unreachable"
+ EQI -> return (shOnes, sh)
+ GTI -> do
+ index <- Gen.int (Range.linear 0 (fromSNat' m))
+ value <- Gen.int (uncurry Range.linear repvalrange)
+ Refl <- return (lem n m)
+ injectOnes n (inject index 1 shOnes) (inject index value sh)
+
+ lem :: forall n m proxy. n > m => proxy n -> proxy m -> (m + 1 <=? n) :~: True
+ lem _ _ = unsafeCoerceRefl
+
+ inject :: Int -> Int -> IShR m -> IShR (m + 1)
+ inject 0 v sh = v :$: sh
+ inject i v (w :$: sh) = w :$: inject (i - 1) v sh
+ inject _ _ ZSR = error "unreachable"
+
+genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
+genStorables rng f = do
+ n <- Gen.int rng
+ seed <- Gen.resize 99 $ Gen.int Range.linearBounded
+ let gen0 = Random.mkStdGen seed
+ (bs, _) = Random.uniformByteString (8 * n) gen0
+ let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]])
+ return $ VS.generate n (f . readW64)
+
+genStaticShX :: Monad m => SNat n -> (forall sh. Rank sh ~ n => StaticShX sh -> PropertyT m ()) -> PropertyT m ()
+genStaticShX = \n k -> case n of
+ SZ -> k ZKX
+ SS n' ->
+ genItem $ \item ->
+ genStaticShX n' $ \ssh ->
+ k (item :!% ssh)
+ where
+ genItem :: Monad m => (forall n. SMayNat () SNat n -> PropertyT m ()) -> PropertyT m ()
+ genItem k = do
+ b <- forAll Gen.bool
+ if b
+ then do
+ n <- forAll $ Gen.frequency [(20, Gen.int (Range.linear 1 4))
+ ,(1, return 0)]
+ TN.withSomeSNat (fromIntegral n) $ \sn -> k (SKnown sn)
+ else k (SUnknown ())
+
+genShX :: StaticShX sh -> Gen (IShX sh)
+genShX ZKX = return ZSX
+genShX (SKnown sn :!% ssh) = (SKnown sn :$%) <$> genShX ssh
+genShX (SUnknown () :!% ssh) = do
+ dim <- Gen.int (Range.linear 1 4)
+ (SUnknown dim :$%) <$> genShX ssh
+
+genPermR :: Int -> Gen PermR
+genPermR n = Gen.shuffle [0 .. n-1]
+
+genPerm :: Monad m => SNat n -> (forall p. (IsPermutation p, Rank p ~ n) => Perm p -> PropertyT m r) -> PropertyT m r
+genPerm n@SNat k = do
+ list <- forAll $ genPermR (fromSNat' n)
+ permFromList list $ \perm -> do
+ case permCheckPermutation perm $
+ case sameNat' (permRank perm) n of
+ Just Refl -> Just (k perm)
+ Nothing -> Nothing
+ of
+ Just (Just act) -> act
+ _ -> error ""
diff --git a/test/Main.hs b/test/Main.hs
index 2363813..575bb15 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -1,29 +1,15 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE ImportQualifiedPost #-}
module Main where
-import Data.Array.Nested
+import Test.Tasty
+import Tests.C qualified
+import Tests.Permutation qualified
-arr :: Ranked I2 (Shaped [2, 3] (Double, Int))
-arr = rgenerate (3 :$: 4 :$: ZSR) $ \(i :.: j :.: ZIR) ->
- sgenerate @[2, 3] $ \(k :.$ l :.$ ZIS) ->
- let s = 24*i + 6*j + 3*k + l
- in (fromIntegral s, s)
-
-foo :: (Double, Int)
-foo = arr `rindex` (2 :.: 1 :.: ZIR) `sindex` (1 :.$ 1 :.$ ZIS)
-
-bad :: Ranked I2 (Ranked I1 Double)
-bad = rgenerate (3 :$: 4 :$: ZSR) $ \(i :.: j :.: ZIR) ->
- rgenerate (i :$: ZSR) $ \(k :.: ZIR) ->
- let s = 24*i + 6*j + 3*k
- in fromIntegral s
main :: IO ()
-main = do
- print arr
- print foo
- print (rtranspose [1,0] arr)
- -- print bad
+main = defaultMain $
+ testGroup "Tests"
+ [Tests.C.tests
+ ,Tests.Permutation.tests
+ ]
diff --git a/test/Tests/C.hs b/test/Tests/C.hs
new file mode 100644
index 0000000..9567393
--- /dev/null
+++ b/test/Tests/C.hs
@@ -0,0 +1,160 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeAbstractions #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Tests.C where
+
+import Control.Monad
+import Data.Array.RankedS qualified as OR
+import Data.Foldable (toList)
+import Data.Functor.Const
+import Data.Type.Equality
+import Foreign
+import GHC.TypeLits
+
+import Data.Array.Nested
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types (fromSNat')
+
+import Hedgehog
+import Hedgehog.Gen qualified as Gen
+import Hedgehog.Internal.Property (LabelName(..), forAllT)
+import Hedgehog.Range qualified as Range
+import Test.Tasty
+import Test.Tasty.Hedgehog
+
+-- import Debug.Trace
+
+import Gen
+import Util
+
+
+-- | Appropriate for simple different summation orders
+fineTol :: Double
+fineTol = 1e-8
+
+debugCoverage :: Bool
+debugCoverage = False
+
+prop_sum_nonempty :: Property
+prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do
+ -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
+ let inrank = SNat @(n + 1)
+ sh <- forAll $ genShR inrank
+ -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
+ guard (all (> 0) (shrTail sh)) -- only constrain the tail
+ arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
+ genStorables (Range.singleton (product sh))
+ (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+ let rarr = rfromOrthotope inrank arr
+ almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr)
+
+prop_sum_empty :: Property
+prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do
+ -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above.
+ _outrank :: SNat n <- return $ SNat @(nm1 + 1)
+ let inrank = SNat @(n + 1)
+ sh <- forAll $ do
+ shtt <- genShR outrankm1 -- nm1
+ sht <- shuffleShR (0 :$: shtt) -- n
+ n <- Gen.int (Range.linear 0 20)
+ return (n :$: sht) -- n + 1
+ guard (0 `elem` shrTail sh)
+ -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
+ let arr = OR.fromList @(n + 1) @Double (toList sh) []
+ let rarr = rfromOrthotope inrank arr
+ OR.toList (rtoOrthotope (rsumOuter1 rarr)) === []
+
+prop_sum_lasteq1 :: Property
+prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do
+ let inrank = SNat @(n + 1)
+ outsh <- forAll $ genShR outrank
+ guard (all (> 0) outsh)
+ let insh = shrAppend outsh (1 :$: ZSR)
+ arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$>
+ genStorables (Range.singleton (product insh))
+ (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+ let rarr = rfromOrthotope inrank arr
+ almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr)
+
+prop_sum_replicated :: Bool -> Property
+prop_sum_replicated doTranspose = property $
+ genRank $ \inrank1@(SNat @m) ->
+ genRank $ \outrank@(SNat @nm1) -> do
+ inrank2 :: SNat n <- return $ SNat @(nm1 + 1)
+ (Refl :: (m <=? n) :~: True) <- case cmpNat inrank1 inrank2 of
+ LTI -> return Refl -- actually we only continue if m < n
+ _ -> discard
+ (sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2
+ when debugCoverage $ do
+ label (LabelName ("rankdiff " ++ show (fromSNat' inrank2 - fromSNat' inrank1)))
+ label (LabelName ("size sh1 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh1) :: Double)) :: Int)))
+ label (LabelName ("size sh3 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh3) :: Double)) :: Int)))
+ guard (all (> 0) sh3)
+ arr <- forAllT $
+ OR.stretch (toList sh3)
+ . OR.reshape (toList sh2)
+ . OR.fromVector @Double @m (toList sh1) <$>
+ genStorables (Range.singleton (product sh1))
+ (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+ arrTrans <-
+ if doTranspose then do perm <- forAll $ genPermR (fromSNat' inrank2)
+ return $ OR.transpose perm arr
+ else return arr
+ let rarr = rfromOrthotope inrank2 arrTrans
+ almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans)
+
+prop_negate_with :: forall f b. Show b
+ => ((forall n. f n -> SNat n -> PropertyT IO ()) -> PropertyT IO ())
+ -> (forall n. f n -> IShR n -> Gen b)
+ -> (forall n. f n -> b -> OR.Array n Double -> OR.Array n Double)
+ -> Property
+prop_negate_with genRank' genB preproc = property $
+ genRank' $ \extra rank@(SNat @n) -> do
+ sh <- forAll $ genShR rank
+ guard (all (> 0) sh)
+ arr <- forAllT $ OR.fromVector @Double @n (toList sh) <$>
+ genStorables (Range.singleton (product sh))
+ (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+ bval <- forAll $ genB extra sh
+ let arr' = preproc extra bval arr
+ annotate (show (OR.shapeL arr'))
+ let rarr = rfromOrthotope rank arr'
+ rtoOrthotope (negate rarr) === OR.mapA negate arr'
+
+tests :: TestTree
+tests = testGroup "C"
+ [testGroup "sum"
+ [testProperty "nonempty" prop_sum_nonempty
+ ,testProperty "empty" prop_sum_empty
+ ,testProperty "last==1" prop_sum_lasteq1
+ ,testProperty "replicated" (prop_sum_replicated False)
+ ,testProperty "replicated_transposed" (prop_sum_replicated True)
+ ]
+ ,testGroup "negate"
+ [testProperty "normalised" $ prop_negate_with
+ (\k -> genRank (k (Const ())))
+ (\_ _ -> pure ())
+ (\_ _ -> id)
+ ,testProperty "slice 1D" $ prop_negate_with @((:~:) 1)
+ (\k -> k Refl (SNat @1))
+ (\Refl (n :$: _) -> do lo <- Gen.integral (Range.constant 0 (n-1))
+ len <- Gen.integral (Range.constant 0 (n-lo))
+ return [(lo, len)])
+ (\_ -> OR.slice)
+ ,testProperty "slice nD" $ prop_negate_with
+ (\k -> genRank (k (Const ())))
+ (\_ sh -> do let genPair n = do lo <- Gen.integral (Range.constant 0 (n-1))
+ len <- Gen.integral (Range.constant 0 (n-lo-1))
+ return (lo, len)
+ pairs <- mapM genPair (toList sh)
+ return pairs)
+ (\_ -> OR.slice)
+ ]
+ ]
diff --git a/test/Tests/Permutation.hs b/test/Tests/Permutation.hs
new file mode 100644
index 0000000..98a6da5
--- /dev/null
+++ b/test/Tests/Permutation.hs
@@ -0,0 +1,39 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Tests.Permutation where
+
+import Data.Type.Equality
+
+import Data.Array.Nested.Permutation
+
+import Hedgehog
+import Hedgehog.Gen qualified as Gen
+import Hedgehog.Range qualified as Range
+import Test.Tasty
+import Test.Tasty.Hedgehog
+
+-- import Debug.Trace
+
+import Gen
+
+
+tests :: TestTree
+tests = testGroup "Permutation"
+ [testProperty "permCheckPermutation" $ property $ do
+ n <- forAll $ Gen.int (Range.linear 0 10)
+ list <- forAll $ genPermR n
+ let r = permFromList list $ \perm ->
+ permCheckPermutation perm ()
+ case r of
+ Just () -> return ()
+ Nothing -> failure
+ ,testProperty "permInverse" $ property $
+ genRank $ \n ->
+ genPerm n $ \perm ->
+ genStaticShX n $ \ssh ->
+ permInverse perm $ \_invperm proof ->
+ case proof ssh of
+ Refl -> return ()
+ ]
diff --git a/test/Util.hs b/test/Util.hs
new file mode 100644
index 0000000..8a5ba72
--- /dev/null
+++ b/test/Util.hs
@@ -0,0 +1,51 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Util where
+
+import Data.Array.RankedS qualified as OR
+import Data.Kind
+import GHC.TypeLits
+import Hedgehog
+import Hedgehog.Internal.Property (failDiff)
+
+import Data.Array.Nested.Types (fromSNat')
+
+
+-- Returns highest value that satisfies the predicate, or `lo` if none does
+binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a
+binarySearch div2 = \lo hi f -> case (f lo, f hi) of
+ (False, _) -> lo
+ (_, True) -> hi
+ (_, _ ) -> go lo hi f
+ where
+ go lo hi f = -- invariant: f lo && not (f hi)
+ let mid = lo + div2 (hi - lo)
+ in if mid `elem` [lo, hi]
+ then mid
+ else if f mid then go mid hi f else go lo mid f
+
+orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a
+orSumOuter1 (sn@SNat :: SNat n) =
+ let n = fromSNat' sn
+ in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0])
+
+class AlmostEq f where
+ type AlmostEqConstr f :: Type -> Constraint
+ -- | absolute tolerance, lhs, rhs
+ almostEq :: (AlmostEqConstr f a, Ord a, Show a, Fractional a, MonadTest m)
+ => a -> f a -> f a -> m ()
+
+instance AlmostEq (OR.Array n) where
+ type AlmostEqConstr (OR.Array n) = OR.Unbox
+ almostEq atol lhs rhs
+ | OR.allA (< atol) (OR.zipWithA (\a b -> abs (a - b)) rhs lhs) =
+ success
+ | otherwise =
+ failDiff lhs rhs