From c5108efd1402dcb52beca27d13b4880eed35ef5b Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Mon, 3 Jun 2024 21:29:53 +0200
Subject: Properly test C reductions

---
 test/Gen.hs     |  38 ++++++++++++++++++++
 test/Tests/C.hs | 107 ++++++++++++++++++++++++++++++++++++--------------------
 test/Util.hs    |  18 ++++++++++
 3 files changed, 126 insertions(+), 37 deletions(-)

(limited to 'test')

diff --git a/test/Gen.hs b/test/Gen.hs
index 29dffb2..559fecf 100644
--- a/test/Gen.hs
+++ b/test/Gen.hs
@@ -15,6 +15,7 @@ module Gen where
 import Data.ByteString qualified as BS
 import Data.Foldable (toList)
 import Data.Type.Equality
+import Data.Type.Ord
 import Data.Vector.Storable qualified as VS
 import Foreign
 import GHC.TypeLits
@@ -24,6 +25,7 @@ import Data.Array.Mixed.Permutation
 import Data.Array.Mixed.Shape
 import Data.Array.Mixed.Types
 import Data.Array.Nested
+import Data.Array.Nested.Internal.Shape
 
 import Hedgehog
 import Hedgehog.Gen qualified as Gen
@@ -81,6 +83,42 @@ genShR sn = do
       cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize)
   shuffleShR (min cap <$> dims)
 
+-- | Example: given 3 and 7, might return:
+--
+-- @
+-- ([   13,       4, 27   ]
+-- ,[1, 13, 1, 1, 4, 27, 1]
+-- ,[4, 13, 1, 3, 4, 27, 2])
+-- @
+--
+-- The up-replicated dimensions are always nonzero and not very large, but the
+-- other dimensions might be zero.
+genReplicatedShR :: m <= n => SNat m -> SNat n -> Gen (IShR m, IShR n, IShR n)
+genReplicatedShR = \m n -> do
+  sh1 <- genShR m
+  (sh2, sh3) <- injectOnes n sh1 sh1
+  return (sh1, sh2, sh3)
+  where
+    injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n)
+    injectOnes n@SNat shOnes sh
+      | m@SNat <- shrLengthSNat sh
+      = case cmpNat n m of
+          LTI -> error "unreachable"
+          EQI -> return (shOnes, sh)
+          GTI -> do
+            index <- Gen.int (Range.linear 0 (fromSNat' m))
+            value <- Gen.int (Range.linear 1 5)
+            Refl <- return (lem n m)
+            injectOnes n (inject index 1 shOnes) (inject index value sh)
+
+    lem :: forall n m proxy. Compare n m ~ GT => proxy n -> proxy m -> (m + 1 <=? n) :~: True
+    lem _ _ = unsafeCoerceRefl
+
+    inject :: Int -> Int -> IShR m -> IShR (m + 1)
+    inject 0 v sh = v :$: sh
+    inject i v (w :$: sh) = w :$: inject (i - 1) v sh
+    inject _ v ZSR = v :$: ZSR  -- invalid input, but meh
+
 genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
 genStorables rng f = do
   n <- Gen.int rng
diff --git a/test/Tests/C.hs b/test/Tests/C.hs
index 2a3949f..148e7f6 100644
--- a/test/Tests/C.hs
+++ b/test/Tests/C.hs
@@ -16,10 +16,8 @@ import Data.Type.Equality
 import Foreign
 import GHC.TypeLits
 
-import Data.Array.Mixed.XArray qualified as X
-import Data.Array.Mixed.Lemmas
+import Data.Array.Mixed.Types (fromSNat')
 import Data.Array.Nested
-import Data.Array.Nested.Internal.Mixed
 import Data.Array.Nested.Internal.Shape
 
 import Hedgehog
@@ -35,42 +33,77 @@ import Gen
 import Util
 
 
+prop_sum_nonempty :: Property
+prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do
+  -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
+  let inrank = SNat @(n + 1)
+  sh <- forAll $ genShR inrank
+  -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
+  guard (all (> 0) (toList (shrTail sh)))  -- only constrain the tail
+  arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
+           genStorables (Range.singleton (product sh))
+                        (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+  let rarr = rfromOrthotope inrank arr
+  rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr
+
+prop_sum_empty :: Property
+prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do
+  -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above.
+  _outrank :: SNat n <- return $ SNat @(nm1 + 1)
+  let inrank = SNat @(n + 1)
+  sh <- forAll $ do
+    shtt <- genShR outrankm1  -- nm1
+    sht <- shuffleShR (0 :$: shtt)  -- n
+    n <- Gen.int (Range.linear 0 20)
+    return (n :$: sht)  -- n + 1
+  guard (any (== 0) (toList (shrTail sh)))
+  -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
+  let arr = OR.fromList @Double @(n + 1) (toList sh) []
+  let rarr = rfromOrthotope inrank arr
+  OR.toList (rtoOrthotope (rsumOuter1 rarr)) === []
+
+prop_sum_lasteq1 :: Property
+prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do
+  let inrank = SNat @(n + 1)
+  outsh <- forAll $ genShR outrank
+  guard (all (> 0) (toList outsh))
+  let insh = shrAppend outsh (1 :$: ZSR)
+  arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$>
+           genStorables (Range.singleton (product insh))
+                        (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+  let rarr = rfromOrthotope inrank arr
+  rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr
+
+prop_sum_replicated :: Bool -> Property
+prop_sum_replicated doTranspose = property $
+  genRank $ \inrank1@(SNat @m) ->
+  genRank $ \outrank@(SNat @nm1) -> do
+    inrank2 :: SNat n <- return $ SNat @(nm1 + 1)
+    (Refl :: (m <=? n) :~: True) <- case cmpNat inrank1 inrank2 of
+      LTI -> return Refl  -- actually we only continue if m < n
+      _ -> discard
+    (sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2
+    guard (all (> 0) (toList sh3))
+    arr <- forAllT $
+      OR.stretch (toList sh3)
+      . OR.reshape (toList sh2)
+      . OR.fromVector @Double @m (toList sh1) <$>
+               genStorables (Range.singleton (product sh1))
+                            (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+    arrTrans <-
+      if doTranspose then do perm <- forAll $ genPermR (fromSNat' inrank2)
+                             return $ OR.transpose perm arr
+                     else return arr
+    let rarr = rfromOrthotope inrank2 arrTrans
+    almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans)
+
 tests :: TestTree
 tests = testGroup "C"
   [testGroup "sum"
-    [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do
-      -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
-      let inrank = SNat @(n + 1)
-      sh <- forAll $ genShR inrank
-      -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
-      guard (all (> 0) (toList (shrTail sh)))  -- only constrain the tail
-      arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
-               genStorables (Range.singleton (product sh))
-                            (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
-      let rarr = rfromOrthotope inrank arr
-      -- annotateShow rarr
-      Refl <- return $ lemRankReplicate outrank
-      let Ranked (M_Double (M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
-      let rhs = orSumOuter1 outrank arr
-      -- annotateShow lhs
-      -- annotateShow rhs
-      lhs === rhs
-
-    ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do
-      -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above.
-      outrank :: SNat n <- return $ SNat @(nm1 + 1)
-      let inrank = SNat @(n + 1)
-      sh <- forAll $ do
-        shtt <- genShR outrankm1  -- nm1
-        sht <- shuffleShR (0 :$: shtt)  -- n
-        n <- Gen.int (Range.linear 0 20)
-        return (n :$: sht)  -- n + 1
-      guard (any (== 0) (toList (shrTail sh)))
-      -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
-      let arr = OR.fromList @Double @(n + 1) (toList sh) []
-      let rarr = rfromOrthotope inrank arr
-      Refl <- return $ lemRankReplicate outrank
-      let Ranked (M_Double (M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
-      OR.toList lhs === []
+    [testProperty "nonempty" prop_sum_nonempty
+    ,testProperty "empty" prop_sum_empty
+    ,testProperty "last==1" prop_sum_lasteq1
+    ,testProperty "replicated" (prop_sum_replicated False)
+    ,testProperty "replicated_transposed" (prop_sum_replicated True)
     ]
   ]
diff --git a/test/Util.hs b/test/Util.hs
index f377e5b..ce6ec23 100644
--- a/test/Util.hs
+++ b/test/Util.hs
@@ -4,12 +4,16 @@
 {-# LANGUAGE PatternSynonyms #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE TypeOperators #-}
 {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
 {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
 module Util where
 
 import Data.Array.RankedS qualified as OR
+import Data.Kind
+import Hedgehog
+import Hedgehog.Internal.Property (failDiff)
 import GHC.TypeLits
 
 import Data.Array.Mixed.Types (fromSNat')
@@ -32,3 +36,17 @@ orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n
 orSumOuter1 (sn@SNat :: SNat n) =
   let n = fromSNat' sn
   in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0])
+
+class AlmostEq f where
+  type AlmostEqConstr f :: Type -> Constraint
+  -- | absolute tolerance, lhs, rhs
+  almostEq :: (AlmostEqConstr f a, Ord a, Show a, Fractional a, MonadTest m)
+           => a -> f a -> f a -> m ()
+
+instance KnownNat n => AlmostEq (OR.Array n) where
+  type AlmostEqConstr (OR.Array n) = OR.Unbox
+  almostEq atol lhs rhs
+    | OR.allA (< atol) (OR.zipWithA (\a b -> abs (a - b)) rhs lhs) =
+        success
+    | otherwise =
+        failDiff lhs rhs
-- 
cgit v1.2.3-70-g09d2