diff options
Diffstat (limited to 'test')
| -rw-r--r-- | test/Gen.hs | 38 | ||||
| -rw-r--r-- | test/Tests/C.hs | 107 | ||||
| -rw-r--r-- | test/Util.hs | 18 | 
3 files changed, 126 insertions, 37 deletions
| diff --git a/test/Gen.hs b/test/Gen.hs index 29dffb2..559fecf 100644 --- a/test/Gen.hs +++ b/test/Gen.hs @@ -15,6 +15,7 @@ module Gen where  import Data.ByteString qualified as BS  import Data.Foldable (toList)  import Data.Type.Equality +import Data.Type.Ord  import Data.Vector.Storable qualified as VS  import Foreign  import GHC.TypeLits @@ -24,6 +25,7 @@ import Data.Array.Mixed.Permutation  import Data.Array.Mixed.Shape  import Data.Array.Mixed.Types  import Data.Array.Nested +import Data.Array.Nested.Internal.Shape  import Hedgehog  import Hedgehog.Gen qualified as Gen @@ -81,6 +83,42 @@ genShR sn = do        cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize)    shuffleShR (min cap <$> dims) +-- | Example: given 3 and 7, might return: +-- +-- @ +-- ([   13,       4, 27   ] +-- ,[1, 13, 1, 1, 4, 27, 1] +-- ,[4, 13, 1, 3, 4, 27, 2]) +-- @ +-- +-- The up-replicated dimensions are always nonzero and not very large, but the +-- other dimensions might be zero. +genReplicatedShR :: m <= n => SNat m -> SNat n -> Gen (IShR m, IShR n, IShR n) +genReplicatedShR = \m n -> do +  sh1 <- genShR m +  (sh2, sh3) <- injectOnes n sh1 sh1 +  return (sh1, sh2, sh3) +  where +    injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n) +    injectOnes n@SNat shOnes sh +      | m@SNat <- shrLengthSNat sh +      = case cmpNat n m of +          LTI -> error "unreachable" +          EQI -> return (shOnes, sh) +          GTI -> do +            index <- Gen.int (Range.linear 0 (fromSNat' m)) +            value <- Gen.int (Range.linear 1 5) +            Refl <- return (lem n m) +            injectOnes n (inject index 1 shOnes) (inject index value sh) + +    lem :: forall n m proxy. Compare n m ~ GT => proxy n -> proxy m -> (m + 1 <=? n) :~: True +    lem _ _ = unsafeCoerceRefl + +    inject :: Int -> Int -> IShR m -> IShR (m + 1) +    inject 0 v sh = v :$: sh +    inject i v (w :$: sh) = w :$: inject (i - 1) v sh +    inject _ v ZSR = v :$: ZSR  -- invalid input, but meh +  genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)  genStorables rng f = do    n <- Gen.int rng diff --git a/test/Tests/C.hs b/test/Tests/C.hs index 2a3949f..148e7f6 100644 --- a/test/Tests/C.hs +++ b/test/Tests/C.hs @@ -16,10 +16,8 @@ import Data.Type.Equality  import Foreign  import GHC.TypeLits -import Data.Array.Mixed.XArray qualified as X -import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Types (fromSNat')  import Data.Array.Nested -import Data.Array.Nested.Internal.Mixed  import Data.Array.Nested.Internal.Shape  import Hedgehog @@ -35,42 +33,77 @@ import Gen  import Util +prop_sum_nonempty :: Property +prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do +  -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. +  let inrank = SNat @(n + 1) +  sh <- forAll $ genShR inrank +  -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) +  guard (all (> 0) (toList (shrTail sh)))  -- only constrain the tail +  arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> +           genStorables (Range.singleton (product sh)) +                        (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) +  let rarr = rfromOrthotope inrank arr +  rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr + +prop_sum_empty :: Property +prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do +  -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. +  _outrank :: SNat n <- return $ SNat @(nm1 + 1) +  let inrank = SNat @(n + 1) +  sh <- forAll $ do +    shtt <- genShR outrankm1  -- nm1 +    sht <- shuffleShR (0 :$: shtt)  -- n +    n <- Gen.int (Range.linear 0 20) +    return (n :$: sht)  -- n + 1 +  guard (any (== 0) (toList (shrTail sh))) +  -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) +  let arr = OR.fromList @Double @(n + 1) (toList sh) [] +  let rarr = rfromOrthotope inrank arr +  OR.toList (rtoOrthotope (rsumOuter1 rarr)) === [] + +prop_sum_lasteq1 :: Property +prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do +  let inrank = SNat @(n + 1) +  outsh <- forAll $ genShR outrank +  guard (all (> 0) (toList outsh)) +  let insh = shrAppend outsh (1 :$: ZSR) +  arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$> +           genStorables (Range.singleton (product insh)) +                        (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) +  let rarr = rfromOrthotope inrank arr +  rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr + +prop_sum_replicated :: Bool -> Property +prop_sum_replicated doTranspose = property $ +  genRank $ \inrank1@(SNat @m) -> +  genRank $ \outrank@(SNat @nm1) -> do +    inrank2 :: SNat n <- return $ SNat @(nm1 + 1) +    (Refl :: (m <=? n) :~: True) <- case cmpNat inrank1 inrank2 of +      LTI -> return Refl  -- actually we only continue if m < n +      _ -> discard +    (sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2 +    guard (all (> 0) (toList sh3)) +    arr <- forAllT $ +      OR.stretch (toList sh3) +      . OR.reshape (toList sh2) +      . OR.fromVector @Double @m (toList sh1) <$> +               genStorables (Range.singleton (product sh1)) +                            (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) +    arrTrans <- +      if doTranspose then do perm <- forAll $ genPermR (fromSNat' inrank2) +                             return $ OR.transpose perm arr +                     else return arr +    let rarr = rfromOrthotope inrank2 arrTrans +    almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans) +  tests :: TestTree  tests = testGroup "C"    [testGroup "sum" -    [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do -      -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. -      let inrank = SNat @(n + 1) -      sh <- forAll $ genShR inrank -      -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) -      guard (all (> 0) (toList (shrTail sh)))  -- only constrain the tail -      arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> -               genStorables (Range.singleton (product sh)) -                            (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) -      let rarr = rfromOrthotope inrank arr -      -- annotateShow rarr -      Refl <- return $ lemRankReplicate outrank -      let Ranked (M_Double (M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr -      let rhs = orSumOuter1 outrank arr -      -- annotateShow lhs -      -- annotateShow rhs -      lhs === rhs - -    ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do -      -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. -      outrank :: SNat n <- return $ SNat @(nm1 + 1) -      let inrank = SNat @(n + 1) -      sh <- forAll $ do -        shtt <- genShR outrankm1  -- nm1 -        sht <- shuffleShR (0 :$: shtt)  -- n -        n <- Gen.int (Range.linear 0 20) -        return (n :$: sht)  -- n + 1 -      guard (any (== 0) (toList (shrTail sh))) -      -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) -      let arr = OR.fromList @Double @(n + 1) (toList sh) [] -      let rarr = rfromOrthotope inrank arr -      Refl <- return $ lemRankReplicate outrank -      let Ranked (M_Double (M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr -      OR.toList lhs === [] +    [testProperty "nonempty" prop_sum_nonempty +    ,testProperty "empty" prop_sum_empty +    ,testProperty "last==1" prop_sum_lasteq1 +    ,testProperty "replicated" (prop_sum_replicated False) +    ,testProperty "replicated_transposed" (prop_sum_replicated True)      ]    ] diff --git a/test/Util.hs b/test/Util.hs index f377e5b..ce6ec23 100644 --- a/test/Util.hs +++ b/test/Util.hs @@ -4,12 +4,16 @@  {-# LANGUAGE PatternSynonyms #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-}  {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}  {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}  module Util where  import Data.Array.RankedS qualified as OR +import Data.Kind +import Hedgehog +import Hedgehog.Internal.Property (failDiff)  import GHC.TypeLits  import Data.Array.Mixed.Types (fromSNat') @@ -32,3 +36,17 @@ orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n  orSumOuter1 (sn@SNat :: SNat n) =    let n = fromSNat' sn    in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0]) + +class AlmostEq f where +  type AlmostEqConstr f :: Type -> Constraint +  -- | absolute tolerance, lhs, rhs +  almostEq :: (AlmostEqConstr f a, Ord a, Show a, Fractional a, MonadTest m) +           => a -> f a -> f a -> m () + +instance KnownNat n => AlmostEq (OR.Array n) where +  type AlmostEqConstr (OR.Array n) = OR.Unbox +  almostEq atol lhs rhs +    | OR.allA (< atol) (OR.zipWithA (\a b -> abs (a - b)) rhs lhs) = +        success +    | otherwise = +        failDiff lhs rhs | 
