diff options
Diffstat (limited to 'test')
| -rw-r--r-- | test/Main.hs | 100 | 
1 files changed, 100 insertions, 0 deletions
| diff --git a/test/Main.hs b/test/Main.hs new file mode 100644 index 0000000..002c606 --- /dev/null +++ b/test/Main.hs @@ -0,0 +1,100 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE DataKinds #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Main where + +import qualified Data.Array.RankedS as OR +import Data.Foldable (toList) +import Data.Type.Equality +import GHC.TypeLits +import qualified GHC.TypeNats as TN + +import qualified Data.Array.Mixed as X +import Data.Array.Mixed (fromSNat', pattern SZ, pattern SS) +import Data.Array.Nested +import qualified Data.Array.Nested.Internal as I + +-- test framework stuff +import Hedgehog +import qualified Hedgehog.Gen as Gen +import qualified Hedgehog.Range as Range +import Test.Tasty +import Test.Tasty.Hedgehog + +import Debug.Trace + + +genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO () +genRank k = do +  rank <- forAll $ Gen.int (Range.linear 0 8) +  TN.withSomeSNat (fromIntegral rank) k + +genLowBiased :: RealFloat a => (a, a) -> Gen a +genLowBiased (lo, hi) = do +  x <- Gen.realFloat (Range.linearFrac 0 1) +  return (lo + x * x * x * (hi - lo)) + +shuffleShR :: IShR n -> Gen (IShR n) +shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh +  where +    go :: Int -> [Int] -> IShR n -> Gen (IShR n) +    go _    _   ZSR = return ZSR +    go nbag bag (_ :$: sh) = do +      idx <- Gen.int (Range.linear 0 (nbag - 1)) +      let (dim, bag') = case splitAt idx bag of +                          (pre, n : post) -> (n, pre ++ post) +                          _ -> error "unreachable" +      (dim :$:) <$> go (nbag - 1) bag' sh + +genShR :: SNat n -> Gen (IShR n) +genShR sn = do +  let n = fromSNat' sn +  targetSize <- Gen.int (Range.linear 0 (1000 * 3 ^ n)) +  let genDims :: SNat m -> Int -> Gen (IShR m) +      genDims SZ _ = return ZSR +      genDims (SS m) 0 = do +        dim <- Gen.int (Range.linear 0 20) +        dims <- genDims m 0 +        return (dim :$: dims) +      genDims (SS m) tgt = do +        dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt)))) +                             ,(2     , return tgt) +                             ,(4     , return 1) +                             ,(1     , return 0)] +        dims <- genDims m (if dim == 0 then 0 else tgt `div` dim) +        return (dim :$: dims) +  shuffleShR =<< genDims sn targetSize + +orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a +orSumOuter1 (sn@SNat :: SNat n) = +  let n = fromSNat' sn +  in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0]) + +main :: IO () +main = defaultMain $ +  testGroup "Tests" +    [testGroup "C" +      [testGroup "sum" +        [testProperty "random" $ property $ genRank $ \outrank@(SNat @n) -> do +          let inrank = SNat @(n + 1) +          sh <- forAll $ genShR inrank +          arr <- forAll $ OR.fromList @_ @(n + 1) (toList sh) <$> +                   Gen.list (Range.singleton (product sh)) +                            (Gen.realFloat (Range.linearFrac @Double 0 1)) +          let rarr = rfromOrthotope inrank arr +          annotateShow rarr +          Refl <- return $ I.lemRankReplicate outrank +          let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr +          let rhs = orSumOuter1 outrank arr +          annotateShow lhs +          annotateShow rhs +          lhs === rhs +        ] +      ] +    ] | 
