aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-23 23:13:02 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-23 23:13:02 +0200
commit2ef5f573f15ff2406b568b7bae79d0cf52eecc4b (patch)
tree3124e68d0fbcbcac791b2367694c0a2188cc51d8
parent6a8520f957531bd0e41bd8adde9dedbf1cc916be (diff)
WIP testing
-rw-r--r--ox-arrays.cabal15
-rw-r--r--test/Main.hs100
2 files changed, 115 insertions, 0 deletions
diff --git a/ox-arrays.cabal b/ox-arrays.cabal
index 3f4fa5b..94a4529 100644
--- a/ox-arrays.cabal
+++ b/ox-arrays.cabal
@@ -28,6 +28,21 @@ library
ghc-options: -Wall
other-extensions: TemplateHaskell
+test-suite test
+ type: exitcode-stdio-1.0
+ main-is: Main.hs
+ build-depends:
+ ox-arrays,
+ base,
+ ghc-typelits-knownnat,
+ hedgehog,
+ orthotope,
+ tasty,
+ tasty-hedgehog
+ hs-source-dirs: test
+ default-language: Haskell2010
+ ghc-options: -Wall
+
test-suite example
type: exitcode-stdio-1.0
main-is: Main.hs
diff --git a/test/Main.hs b/test/Main.hs
new file mode 100644
index 0000000..002c606
--- /dev/null
+++ b/test/Main.hs
@@ -0,0 +1,100 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeAbstractions #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE DataKinds #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Main where
+
+import qualified Data.Array.RankedS as OR
+import Data.Foldable (toList)
+import Data.Type.Equality
+import GHC.TypeLits
+import qualified GHC.TypeNats as TN
+
+import qualified Data.Array.Mixed as X
+import Data.Array.Mixed (fromSNat', pattern SZ, pattern SS)
+import Data.Array.Nested
+import qualified Data.Array.Nested.Internal as I
+
+-- test framework stuff
+import Hedgehog
+import qualified Hedgehog.Gen as Gen
+import qualified Hedgehog.Range as Range
+import Test.Tasty
+import Test.Tasty.Hedgehog
+
+import Debug.Trace
+
+
+genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO ()
+genRank k = do
+ rank <- forAll $ Gen.int (Range.linear 0 8)
+ TN.withSomeSNat (fromIntegral rank) k
+
+genLowBiased :: RealFloat a => (a, a) -> Gen a
+genLowBiased (lo, hi) = do
+ x <- Gen.realFloat (Range.linearFrac 0 1)
+ return (lo + x * x * x * (hi - lo))
+
+shuffleShR :: IShR n -> Gen (IShR n)
+shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh
+ where
+ go :: Int -> [Int] -> IShR n -> Gen (IShR n)
+ go _ _ ZSR = return ZSR
+ go nbag bag (_ :$: sh) = do
+ idx <- Gen.int (Range.linear 0 (nbag - 1))
+ let (dim, bag') = case splitAt idx bag of
+ (pre, n : post) -> (n, pre ++ post)
+ _ -> error "unreachable"
+ (dim :$:) <$> go (nbag - 1) bag' sh
+
+genShR :: SNat n -> Gen (IShR n)
+genShR sn = do
+ let n = fromSNat' sn
+ targetSize <- Gen.int (Range.linear 0 (1000 * 3 ^ n))
+ let genDims :: SNat m -> Int -> Gen (IShR m)
+ genDims SZ _ = return ZSR
+ genDims (SS m) 0 = do
+ dim <- Gen.int (Range.linear 0 20)
+ dims <- genDims m 0
+ return (dim :$: dims)
+ genDims (SS m) tgt = do
+ dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt))))
+ ,(2 , return tgt)
+ ,(4 , return 1)
+ ,(1 , return 0)]
+ dims <- genDims m (if dim == 0 then 0 else tgt `div` dim)
+ return (dim :$: dims)
+ shuffleShR =<< genDims sn targetSize
+
+orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a
+orSumOuter1 (sn@SNat :: SNat n) =
+ let n = fromSNat' sn
+ in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0])
+
+main :: IO ()
+main = defaultMain $
+ testGroup "Tests"
+ [testGroup "C"
+ [testGroup "sum"
+ [testProperty "random" $ property $ genRank $ \outrank@(SNat @n) -> do
+ let inrank = SNat @(n + 1)
+ sh <- forAll $ genShR inrank
+ arr <- forAll $ OR.fromList @_ @(n + 1) (toList sh) <$>
+ Gen.list (Range.singleton (product sh))
+ (Gen.realFloat (Range.linearFrac @Double 0 1))
+ let rarr = rfromOrthotope inrank arr
+ annotateShow rarr
+ Refl <- return $ I.lemRankReplicate outrank
+ let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
+ let rhs = orSumOuter1 outrank arr
+ annotateShow lhs
+ annotateShow rhs
+ lhs === rhs
+ ]
+ ]
+ ]