diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-06-03 21:29:53 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-06-03 21:29:53 +0200 | 
| commit | c5108efd1402dcb52beca27d13b4880eed35ef5b (patch) | |
| tree | b25e4ee26c1f894671db2e68c0afdaf6a1378cb5 /test/Tests | |
| parent | 0fd727dcb3fe05816aa9c68be5ebac84a55fcf4b (diff) | |
Properly test C reductions
Diffstat (limited to 'test/Tests')
| -rw-r--r-- | test/Tests/C.hs | 107 | 
1 files changed, 70 insertions, 37 deletions
| diff --git a/test/Tests/C.hs b/test/Tests/C.hs index 2a3949f..148e7f6 100644 --- a/test/Tests/C.hs +++ b/test/Tests/C.hs @@ -16,10 +16,8 @@ import Data.Type.Equality  import Foreign  import GHC.TypeLits -import Data.Array.Mixed.XArray qualified as X -import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Types (fromSNat')  import Data.Array.Nested -import Data.Array.Nested.Internal.Mixed  import Data.Array.Nested.Internal.Shape  import Hedgehog @@ -35,42 +33,77 @@ import Gen  import Util +prop_sum_nonempty :: Property +prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do +  -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. +  let inrank = SNat @(n + 1) +  sh <- forAll $ genShR inrank +  -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) +  guard (all (> 0) (toList (shrTail sh)))  -- only constrain the tail +  arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> +           genStorables (Range.singleton (product sh)) +                        (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) +  let rarr = rfromOrthotope inrank arr +  rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr + +prop_sum_empty :: Property +prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do +  -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. +  _outrank :: SNat n <- return $ SNat @(nm1 + 1) +  let inrank = SNat @(n + 1) +  sh <- forAll $ do +    shtt <- genShR outrankm1  -- nm1 +    sht <- shuffleShR (0 :$: shtt)  -- n +    n <- Gen.int (Range.linear 0 20) +    return (n :$: sht)  -- n + 1 +  guard (any (== 0) (toList (shrTail sh))) +  -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) +  let arr = OR.fromList @Double @(n + 1) (toList sh) [] +  let rarr = rfromOrthotope inrank arr +  OR.toList (rtoOrthotope (rsumOuter1 rarr)) === [] + +prop_sum_lasteq1 :: Property +prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do +  let inrank = SNat @(n + 1) +  outsh <- forAll $ genShR outrank +  guard (all (> 0) (toList outsh)) +  let insh = shrAppend outsh (1 :$: ZSR) +  arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$> +           genStorables (Range.singleton (product insh)) +                        (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) +  let rarr = rfromOrthotope inrank arr +  rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr + +prop_sum_replicated :: Bool -> Property +prop_sum_replicated doTranspose = property $ +  genRank $ \inrank1@(SNat @m) -> +  genRank $ \outrank@(SNat @nm1) -> do +    inrank2 :: SNat n <- return $ SNat @(nm1 + 1) +    (Refl :: (m <=? n) :~: True) <- case cmpNat inrank1 inrank2 of +      LTI -> return Refl  -- actually we only continue if m < n +      _ -> discard +    (sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2 +    guard (all (> 0) (toList sh3)) +    arr <- forAllT $ +      OR.stretch (toList sh3) +      . OR.reshape (toList sh2) +      . OR.fromVector @Double @m (toList sh1) <$> +               genStorables (Range.singleton (product sh1)) +                            (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) +    arrTrans <- +      if doTranspose then do perm <- forAll $ genPermR (fromSNat' inrank2) +                             return $ OR.transpose perm arr +                     else return arr +    let rarr = rfromOrthotope inrank2 arrTrans +    almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans) +  tests :: TestTree  tests = testGroup "C"    [testGroup "sum" -    [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do -      -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. -      let inrank = SNat @(n + 1) -      sh <- forAll $ genShR inrank -      -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) -      guard (all (> 0) (toList (shrTail sh)))  -- only constrain the tail -      arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> -               genStorables (Range.singleton (product sh)) -                            (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) -      let rarr = rfromOrthotope inrank arr -      -- annotateShow rarr -      Refl <- return $ lemRankReplicate outrank -      let Ranked (M_Double (M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr -      let rhs = orSumOuter1 outrank arr -      -- annotateShow lhs -      -- annotateShow rhs -      lhs === rhs - -    ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do -      -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. -      outrank :: SNat n <- return $ SNat @(nm1 + 1) -      let inrank = SNat @(n + 1) -      sh <- forAll $ do -        shtt <- genShR outrankm1  -- nm1 -        sht <- shuffleShR (0 :$: shtt)  -- n -        n <- Gen.int (Range.linear 0 20) -        return (n :$: sht)  -- n + 1 -      guard (any (== 0) (toList (shrTail sh))) -      -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) -      let arr = OR.fromList @Double @(n + 1) (toList sh) [] -      let rarr = rfromOrthotope inrank arr -      Refl <- return $ lemRankReplicate outrank -      let Ranked (M_Double (M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr -      OR.toList lhs === [] +    [testProperty "nonempty" prop_sum_nonempty +    ,testProperty "empty" prop_sum_empty +    ,testProperty "last==1" prop_sum_lasteq1 +    ,testProperty "replicated" (prop_sum_replicated False) +    ,testProperty "replicated_transposed" (prop_sum_replicated True)      ]    ] | 
