diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/Main.hs | 100 |
1 files changed, 100 insertions, 0 deletions
diff --git a/test/Main.hs b/test/Main.hs new file mode 100644 index 0000000..002c606 --- /dev/null +++ b/test/Main.hs @@ -0,0 +1,100 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE DataKinds #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Main where + +import qualified Data.Array.RankedS as OR +import Data.Foldable (toList) +import Data.Type.Equality +import GHC.TypeLits +import qualified GHC.TypeNats as TN + +import qualified Data.Array.Mixed as X +import Data.Array.Mixed (fromSNat', pattern SZ, pattern SS) +import Data.Array.Nested +import qualified Data.Array.Nested.Internal as I + +-- test framework stuff +import Hedgehog +import qualified Hedgehog.Gen as Gen +import qualified Hedgehog.Range as Range +import Test.Tasty +import Test.Tasty.Hedgehog + +import Debug.Trace + + +genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO () +genRank k = do + rank <- forAll $ Gen.int (Range.linear 0 8) + TN.withSomeSNat (fromIntegral rank) k + +genLowBiased :: RealFloat a => (a, a) -> Gen a +genLowBiased (lo, hi) = do + x <- Gen.realFloat (Range.linearFrac 0 1) + return (lo + x * x * x * (hi - lo)) + +shuffleShR :: IShR n -> Gen (IShR n) +shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh + where + go :: Int -> [Int] -> IShR n -> Gen (IShR n) + go _ _ ZSR = return ZSR + go nbag bag (_ :$: sh) = do + idx <- Gen.int (Range.linear 0 (nbag - 1)) + let (dim, bag') = case splitAt idx bag of + (pre, n : post) -> (n, pre ++ post) + _ -> error "unreachable" + (dim :$:) <$> go (nbag - 1) bag' sh + +genShR :: SNat n -> Gen (IShR n) +genShR sn = do + let n = fromSNat' sn + targetSize <- Gen.int (Range.linear 0 (1000 * 3 ^ n)) + let genDims :: SNat m -> Int -> Gen (IShR m) + genDims SZ _ = return ZSR + genDims (SS m) 0 = do + dim <- Gen.int (Range.linear 0 20) + dims <- genDims m 0 + return (dim :$: dims) + genDims (SS m) tgt = do + dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt)))) + ,(2 , return tgt) + ,(4 , return 1) + ,(1 , return 0)] + dims <- genDims m (if dim == 0 then 0 else tgt `div` dim) + return (dim :$: dims) + shuffleShR =<< genDims sn targetSize + +orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a +orSumOuter1 (sn@SNat :: SNat n) = + let n = fromSNat' sn + in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0]) + +main :: IO () +main = defaultMain $ + testGroup "Tests" + [testGroup "C" + [testGroup "sum" + [testProperty "random" $ property $ genRank $ \outrank@(SNat @n) -> do + let inrank = SNat @(n + 1) + sh <- forAll $ genShR inrank + arr <- forAll $ OR.fromList @_ @(n + 1) (toList sh) <$> + Gen.list (Range.singleton (product sh)) + (Gen.realFloat (Range.linearFrac @Double 0 1)) + let rarr = rfromOrthotope inrank arr + annotateShow rarr + Refl <- return $ I.lemRankReplicate outrank + let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr + let rhs = orSumOuter1 outrank arr + annotateShow lhs + annotateShow rhs + lhs === rhs + ] + ] + ] |