aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-03 21:29:53 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-03 21:29:53 +0200
commitc5108efd1402dcb52beca27d13b4880eed35ef5b (patch)
treeb25e4ee26c1f894671db2e68c0afdaf6a1378cb5
parent0fd727dcb3fe05816aa9c68be5ebac84a55fcf4b (diff)
Properly test C reductions
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs20
-rw-r--r--test/Gen.hs38
-rw-r--r--test/Tests/C.hs107
-rw-r--r--test/Util.hs18
4 files changed, 141 insertions, 42 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index bb3ee4a..6417413 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -22,6 +22,7 @@ import Foreign.C.Types
import Foreign.Ptr
import Foreign.Storable (Storable)
import GHC.TypeLits
+import GHC.TypeNats qualified as TypeNats
import Language.Haskell.TH
import System.IO.Unsafe
@@ -133,7 +134,6 @@ vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases
VS.unsafeFreeze outv
| otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy)
--- TODO: test all the weird cases of this function
-- | Reduce along the inner dimension
{-# NOINLINE vectorRedInnerOp #-}
vectorRedInnerOp :: forall a b n. (Num a, Storable a)
@@ -155,9 +155,15 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride
(RS.A (RG.A (init sh) (OI.T (init strides) offset vec)))
-- now there is useful work along the inner dimension
| otherwise =
- let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those
- (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides)
- ndimsF = length shF
+ let -- replicated dimensions: dimensions with zero stride. The reduction
+ -- kernel need not concern itself with those (and in fact has a
+ -- precondition that there are no such dimensions in its input).
+ replDims = map (== 0) strides
+ -- filter out replicated dimensions
+ (shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims)
+ -- replace replicated dimensions with ones
+ shOnes = zipWith (\n repl -> if repl then 1 else n) sh replDims
+ ndimsF = length shF -- > 0, otherwise `last strides == 0`
in unsafePerformIO $ do
outv <- VSM.unsafeNew (product (init shF))
VSM.unsafeWith outv $ \poutv ->
@@ -165,7 +171,11 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride
VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF ->
VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec ->
fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec)
- RS.fromVector (init sh) <$> VS.unsafeFreeze outv
+ TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) ->
+ RS.stretch (init sh)
+ . RS.reshape (init shOnes)
+ . RS.fromVector @_ @lenFm1 (init shF)
+ <$> VS.unsafeFreeze outv
flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
-> Int64 -> Ptr a -> Ptr a -> a -> IO ()
diff --git a/test/Gen.hs b/test/Gen.hs
index 29dffb2..559fecf 100644
--- a/test/Gen.hs
+++ b/test/Gen.hs
@@ -15,6 +15,7 @@ module Gen where
import Data.ByteString qualified as BS
import Data.Foldable (toList)
import Data.Type.Equality
+import Data.Type.Ord
import Data.Vector.Storable qualified as VS
import Foreign
import GHC.TypeLits
@@ -24,6 +25,7 @@ import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Nested
+import Data.Array.Nested.Internal.Shape
import Hedgehog
import Hedgehog.Gen qualified as Gen
@@ -81,6 +83,42 @@ genShR sn = do
cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize)
shuffleShR (min cap <$> dims)
+-- | Example: given 3 and 7, might return:
+--
+-- @
+-- ([ 13, 4, 27 ]
+-- ,[1, 13, 1, 1, 4, 27, 1]
+-- ,[4, 13, 1, 3, 4, 27, 2])
+-- @
+--
+-- The up-replicated dimensions are always nonzero and not very large, but the
+-- other dimensions might be zero.
+genReplicatedShR :: m <= n => SNat m -> SNat n -> Gen (IShR m, IShR n, IShR n)
+genReplicatedShR = \m n -> do
+ sh1 <- genShR m
+ (sh2, sh3) <- injectOnes n sh1 sh1
+ return (sh1, sh2, sh3)
+ where
+ injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n)
+ injectOnes n@SNat shOnes sh
+ | m@SNat <- shrLengthSNat sh
+ = case cmpNat n m of
+ LTI -> error "unreachable"
+ EQI -> return (shOnes, sh)
+ GTI -> do
+ index <- Gen.int (Range.linear 0 (fromSNat' m))
+ value <- Gen.int (Range.linear 1 5)
+ Refl <- return (lem n m)
+ injectOnes n (inject index 1 shOnes) (inject index value sh)
+
+ lem :: forall n m proxy. Compare n m ~ GT => proxy n -> proxy m -> (m + 1 <=? n) :~: True
+ lem _ _ = unsafeCoerceRefl
+
+ inject :: Int -> Int -> IShR m -> IShR (m + 1)
+ inject 0 v sh = v :$: sh
+ inject i v (w :$: sh) = w :$: inject (i - 1) v sh
+ inject _ v ZSR = v :$: ZSR -- invalid input, but meh
+
genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
genStorables rng f = do
n <- Gen.int rng
diff --git a/test/Tests/C.hs b/test/Tests/C.hs
index 2a3949f..148e7f6 100644
--- a/test/Tests/C.hs
+++ b/test/Tests/C.hs
@@ -16,10 +16,8 @@ import Data.Type.Equality
import Foreign
import GHC.TypeLits
-import Data.Array.Mixed.XArray qualified as X
-import Data.Array.Mixed.Lemmas
+import Data.Array.Mixed.Types (fromSNat')
import Data.Array.Nested
-import Data.Array.Nested.Internal.Mixed
import Data.Array.Nested.Internal.Shape
import Hedgehog
@@ -35,42 +33,77 @@ import Gen
import Util
+prop_sum_nonempty :: Property
+prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do
+ -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
+ let inrank = SNat @(n + 1)
+ sh <- forAll $ genShR inrank
+ -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
+ guard (all (> 0) (toList (shrTail sh))) -- only constrain the tail
+ arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
+ genStorables (Range.singleton (product sh))
+ (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+ let rarr = rfromOrthotope inrank arr
+ rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr
+
+prop_sum_empty :: Property
+prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do
+ -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above.
+ _outrank :: SNat n <- return $ SNat @(nm1 + 1)
+ let inrank = SNat @(n + 1)
+ sh <- forAll $ do
+ shtt <- genShR outrankm1 -- nm1
+ sht <- shuffleShR (0 :$: shtt) -- n
+ n <- Gen.int (Range.linear 0 20)
+ return (n :$: sht) -- n + 1
+ guard (any (== 0) (toList (shrTail sh)))
+ -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
+ let arr = OR.fromList @Double @(n + 1) (toList sh) []
+ let rarr = rfromOrthotope inrank arr
+ OR.toList (rtoOrthotope (rsumOuter1 rarr)) === []
+
+prop_sum_lasteq1 :: Property
+prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do
+ let inrank = SNat @(n + 1)
+ outsh <- forAll $ genShR outrank
+ guard (all (> 0) (toList outsh))
+ let insh = shrAppend outsh (1 :$: ZSR)
+ arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$>
+ genStorables (Range.singleton (product insh))
+ (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+ let rarr = rfromOrthotope inrank arr
+ rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr
+
+prop_sum_replicated :: Bool -> Property
+prop_sum_replicated doTranspose = property $
+ genRank $ \inrank1@(SNat @m) ->
+ genRank $ \outrank@(SNat @nm1) -> do
+ inrank2 :: SNat n <- return $ SNat @(nm1 + 1)
+ (Refl :: (m <=? n) :~: True) <- case cmpNat inrank1 inrank2 of
+ LTI -> return Refl -- actually we only continue if m < n
+ _ -> discard
+ (sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2
+ guard (all (> 0) (toList sh3))
+ arr <- forAllT $
+ OR.stretch (toList sh3)
+ . OR.reshape (toList sh2)
+ . OR.fromVector @Double @m (toList sh1) <$>
+ genStorables (Range.singleton (product sh1))
+ (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+ arrTrans <-
+ if doTranspose then do perm <- forAll $ genPermR (fromSNat' inrank2)
+ return $ OR.transpose perm arr
+ else return arr
+ let rarr = rfromOrthotope inrank2 arrTrans
+ almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans)
+
tests :: TestTree
tests = testGroup "C"
[testGroup "sum"
- [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do
- -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
- let inrank = SNat @(n + 1)
- sh <- forAll $ genShR inrank
- -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
- guard (all (> 0) (toList (shrTail sh))) -- only constrain the tail
- arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
- genStorables (Range.singleton (product sh))
- (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
- let rarr = rfromOrthotope inrank arr
- -- annotateShow rarr
- Refl <- return $ lemRankReplicate outrank
- let Ranked (M_Double (M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
- let rhs = orSumOuter1 outrank arr
- -- annotateShow lhs
- -- annotateShow rhs
- lhs === rhs
-
- ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do
- -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above.
- outrank :: SNat n <- return $ SNat @(nm1 + 1)
- let inrank = SNat @(n + 1)
- sh <- forAll $ do
- shtt <- genShR outrankm1 -- nm1
- sht <- shuffleShR (0 :$: shtt) -- n
- n <- Gen.int (Range.linear 0 20)
- return (n :$: sht) -- n + 1
- guard (any (== 0) (toList (shrTail sh)))
- -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
- let arr = OR.fromList @Double @(n + 1) (toList sh) []
- let rarr = rfromOrthotope inrank arr
- Refl <- return $ lemRankReplicate outrank
- let Ranked (M_Double (M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
- OR.toList lhs === []
+ [testProperty "nonempty" prop_sum_nonempty
+ ,testProperty "empty" prop_sum_empty
+ ,testProperty "last==1" prop_sum_lasteq1
+ ,testProperty "replicated" (prop_sum_replicated False)
+ ,testProperty "replicated_transposed" (prop_sum_replicated True)
]
]
diff --git a/test/Util.hs b/test/Util.hs
index f377e5b..ce6ec23 100644
--- a/test/Util.hs
+++ b/test/Util.hs
@@ -4,12 +4,16 @@
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Util where
import Data.Array.RankedS qualified as OR
+import Data.Kind
+import Hedgehog
+import Hedgehog.Internal.Property (failDiff)
import GHC.TypeLits
import Data.Array.Mixed.Types (fromSNat')
@@ -32,3 +36,17 @@ orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n
orSumOuter1 (sn@SNat :: SNat n) =
let n = fromSNat' sn
in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0])
+
+class AlmostEq f where
+ type AlmostEqConstr f :: Type -> Constraint
+ -- | absolute tolerance, lhs, rhs
+ almostEq :: (AlmostEqConstr f a, Ord a, Show a, Fractional a, MonadTest m)
+ => a -> f a -> f a -> m ()
+
+instance KnownNat n => AlmostEq (OR.Array n) where
+ type AlmostEqConstr (OR.Array n) = OR.Unbox
+ almostEq atol lhs rhs
+ | OR.allA (< atol) (OR.zipWithA (\a b -> abs (a - b)) rhs lhs) =
+ success
+ | otherwise =
+ failDiff lhs rhs