aboutsummaryrefslogtreecommitdiff
path: root/test/Tests
diff options
context:
space:
mode:
Diffstat (limited to 'test/Tests')
-rw-r--r--test/Tests/C.hs160
-rw-r--r--test/Tests/Permutation.hs39
2 files changed, 199 insertions, 0 deletions
diff --git a/test/Tests/C.hs b/test/Tests/C.hs
new file mode 100644
index 0000000..9567393
--- /dev/null
+++ b/test/Tests/C.hs
@@ -0,0 +1,160 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeAbstractions #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Tests.C where
+
+import Control.Monad
+import Data.Array.RankedS qualified as OR
+import Data.Foldable (toList)
+import Data.Functor.Const
+import Data.Type.Equality
+import Foreign
+import GHC.TypeLits
+
+import Data.Array.Nested
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types (fromSNat')
+
+import Hedgehog
+import Hedgehog.Gen qualified as Gen
+import Hedgehog.Internal.Property (LabelName(..), forAllT)
+import Hedgehog.Range qualified as Range
+import Test.Tasty
+import Test.Tasty.Hedgehog
+
+-- import Debug.Trace
+
+import Gen
+import Util
+
+
+-- | Appropriate for simple different summation orders
+fineTol :: Double
+fineTol = 1e-8
+
+debugCoverage :: Bool
+debugCoverage = False
+
+prop_sum_nonempty :: Property
+prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do
+ -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
+ let inrank = SNat @(n + 1)
+ sh <- forAll $ genShR inrank
+ -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
+ guard (all (> 0) (shrTail sh)) -- only constrain the tail
+ arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
+ genStorables (Range.singleton (product sh))
+ (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+ let rarr = rfromOrthotope inrank arr
+ almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr)
+
+prop_sum_empty :: Property
+prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do
+ -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above.
+ _outrank :: SNat n <- return $ SNat @(nm1 + 1)
+ let inrank = SNat @(n + 1)
+ sh <- forAll $ do
+ shtt <- genShR outrankm1 -- nm1
+ sht <- shuffleShR (0 :$: shtt) -- n
+ n <- Gen.int (Range.linear 0 20)
+ return (n :$: sht) -- n + 1
+ guard (0 `elem` shrTail sh)
+ -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
+ let arr = OR.fromList @(n + 1) @Double (toList sh) []
+ let rarr = rfromOrthotope inrank arr
+ OR.toList (rtoOrthotope (rsumOuter1 rarr)) === []
+
+prop_sum_lasteq1 :: Property
+prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do
+ let inrank = SNat @(n + 1)
+ outsh <- forAll $ genShR outrank
+ guard (all (> 0) outsh)
+ let insh = shrAppend outsh (1 :$: ZSR)
+ arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$>
+ genStorables (Range.singleton (product insh))
+ (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+ let rarr = rfromOrthotope inrank arr
+ almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr)
+
+prop_sum_replicated :: Bool -> Property
+prop_sum_replicated doTranspose = property $
+ genRank $ \inrank1@(SNat @m) ->
+ genRank $ \outrank@(SNat @nm1) -> do
+ inrank2 :: SNat n <- return $ SNat @(nm1 + 1)
+ (Refl :: (m <=? n) :~: True) <- case cmpNat inrank1 inrank2 of
+ LTI -> return Refl -- actually we only continue if m < n
+ _ -> discard
+ (sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2
+ when debugCoverage $ do
+ label (LabelName ("rankdiff " ++ show (fromSNat' inrank2 - fromSNat' inrank1)))
+ label (LabelName ("size sh1 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh1) :: Double)) :: Int)))
+ label (LabelName ("size sh3 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh3) :: Double)) :: Int)))
+ guard (all (> 0) sh3)
+ arr <- forAllT $
+ OR.stretch (toList sh3)
+ . OR.reshape (toList sh2)
+ . OR.fromVector @Double @m (toList sh1) <$>
+ genStorables (Range.singleton (product sh1))
+ (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+ arrTrans <-
+ if doTranspose then do perm <- forAll $ genPermR (fromSNat' inrank2)
+ return $ OR.transpose perm arr
+ else return arr
+ let rarr = rfromOrthotope inrank2 arrTrans
+ almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans)
+
+prop_negate_with :: forall f b. Show b
+ => ((forall n. f n -> SNat n -> PropertyT IO ()) -> PropertyT IO ())
+ -> (forall n. f n -> IShR n -> Gen b)
+ -> (forall n. f n -> b -> OR.Array n Double -> OR.Array n Double)
+ -> Property
+prop_negate_with genRank' genB preproc = property $
+ genRank' $ \extra rank@(SNat @n) -> do
+ sh <- forAll $ genShR rank
+ guard (all (> 0) sh)
+ arr <- forAllT $ OR.fromVector @Double @n (toList sh) <$>
+ genStorables (Range.singleton (product sh))
+ (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+ bval <- forAll $ genB extra sh
+ let arr' = preproc extra bval arr
+ annotate (show (OR.shapeL arr'))
+ let rarr = rfromOrthotope rank arr'
+ rtoOrthotope (negate rarr) === OR.mapA negate arr'
+
+tests :: TestTree
+tests = testGroup "C"
+ [testGroup "sum"
+ [testProperty "nonempty" prop_sum_nonempty
+ ,testProperty "empty" prop_sum_empty
+ ,testProperty "last==1" prop_sum_lasteq1
+ ,testProperty "replicated" (prop_sum_replicated False)
+ ,testProperty "replicated_transposed" (prop_sum_replicated True)
+ ]
+ ,testGroup "negate"
+ [testProperty "normalised" $ prop_negate_with
+ (\k -> genRank (k (Const ())))
+ (\_ _ -> pure ())
+ (\_ _ -> id)
+ ,testProperty "slice 1D" $ prop_negate_with @((:~:) 1)
+ (\k -> k Refl (SNat @1))
+ (\Refl (n :$: _) -> do lo <- Gen.integral (Range.constant 0 (n-1))
+ len <- Gen.integral (Range.constant 0 (n-lo))
+ return [(lo, len)])
+ (\_ -> OR.slice)
+ ,testProperty "slice nD" $ prop_negate_with
+ (\k -> genRank (k (Const ())))
+ (\_ sh -> do let genPair n = do lo <- Gen.integral (Range.constant 0 (n-1))
+ len <- Gen.integral (Range.constant 0 (n-lo-1))
+ return (lo, len)
+ pairs <- mapM genPair (toList sh)
+ return pairs)
+ (\_ -> OR.slice)
+ ]
+ ]
diff --git a/test/Tests/Permutation.hs b/test/Tests/Permutation.hs
new file mode 100644
index 0000000..98a6da5
--- /dev/null
+++ b/test/Tests/Permutation.hs
@@ -0,0 +1,39 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Tests.Permutation where
+
+import Data.Type.Equality
+
+import Data.Array.Nested.Permutation
+
+import Hedgehog
+import Hedgehog.Gen qualified as Gen
+import Hedgehog.Range qualified as Range
+import Test.Tasty
+import Test.Tasty.Hedgehog
+
+-- import Debug.Trace
+
+import Gen
+
+
+tests :: TestTree
+tests = testGroup "Permutation"
+ [testProperty "permCheckPermutation" $ property $ do
+ n <- forAll $ Gen.int (Range.linear 0 10)
+ list <- forAll $ genPermR n
+ let r = permFromList list $ \perm ->
+ permCheckPermutation perm ()
+ case r of
+ Just () -> return ()
+ Nothing -> failure
+ ,testProperty "permInverse" $ property $
+ genRank $ \n ->
+ genPerm n $ \perm ->
+ genStaticShX n $ \ssh ->
+ permInverse perm $ \_invperm proof ->
+ case proof ssh of
+ Refl -> return ()
+ ]