diff options
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 20 | ||||
-rw-r--r-- | test/Gen.hs | 38 | ||||
-rw-r--r-- | test/Tests/C.hs | 107 | ||||
-rw-r--r-- | test/Util.hs | 18 |
4 files changed, 141 insertions, 42 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index bb3ee4a..6417413 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -22,6 +22,7 @@ import Foreign.C.Types import Foreign.Ptr import Foreign.Storable (Storable) import GHC.TypeLits +import GHC.TypeNats qualified as TypeNats import Language.Haskell.TH import System.IO.Unsafe @@ -133,7 +134,6 @@ vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases VS.unsafeFreeze outv | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) --- TODO: test all the weird cases of this function -- | Reduce along the inner dimension {-# NOINLINE vectorRedInnerOp #-} vectorRedInnerOp :: forall a b n. (Num a, Storable a) @@ -155,9 +155,15 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride (RS.A (RG.A (init sh) (OI.T (init strides) offset vec))) -- now there is useful work along the inner dimension | otherwise = - let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those - (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides) - ndimsF = length shF + let -- replicated dimensions: dimensions with zero stride. The reduction + -- kernel need not concern itself with those (and in fact has a + -- precondition that there are no such dimensions in its input). + replDims = map (== 0) strides + -- filter out replicated dimensions + (shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims) + -- replace replicated dimensions with ones + shOnes = zipWith (\n repl -> if repl then 1 else n) sh replDims + ndimsF = length shF -- > 0, otherwise `last strides == 0` in unsafePerformIO $ do outv <- VSM.unsafeNew (product (init shF)) VSM.unsafeWith outv $ \poutv -> @@ -165,7 +171,11 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec) - RS.fromVector (init sh) <$> VS.unsafeFreeze outv + TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) -> + RS.stretch (init sh) + . RS.reshape (init shOnes) + . RS.fromVector @_ @lenFm1 (init shF) + <$> VS.unsafeFreeze outv flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) -> Int64 -> Ptr a -> Ptr a -> a -> IO () diff --git a/test/Gen.hs b/test/Gen.hs index 29dffb2..559fecf 100644 --- a/test/Gen.hs +++ b/test/Gen.hs @@ -15,6 +15,7 @@ module Gen where import Data.ByteString qualified as BS import Data.Foldable (toList) import Data.Type.Equality +import Data.Type.Ord import Data.Vector.Storable qualified as VS import Foreign import GHC.TypeLits @@ -24,6 +25,7 @@ import Data.Array.Mixed.Permutation import Data.Array.Mixed.Shape import Data.Array.Mixed.Types import Data.Array.Nested +import Data.Array.Nested.Internal.Shape import Hedgehog import Hedgehog.Gen qualified as Gen @@ -81,6 +83,42 @@ genShR sn = do cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize) shuffleShR (min cap <$> dims) +-- | Example: given 3 and 7, might return: +-- +-- @ +-- ([ 13, 4, 27 ] +-- ,[1, 13, 1, 1, 4, 27, 1] +-- ,[4, 13, 1, 3, 4, 27, 2]) +-- @ +-- +-- The up-replicated dimensions are always nonzero and not very large, but the +-- other dimensions might be zero. +genReplicatedShR :: m <= n => SNat m -> SNat n -> Gen (IShR m, IShR n, IShR n) +genReplicatedShR = \m n -> do + sh1 <- genShR m + (sh2, sh3) <- injectOnes n sh1 sh1 + return (sh1, sh2, sh3) + where + injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n) + injectOnes n@SNat shOnes sh + | m@SNat <- shrLengthSNat sh + = case cmpNat n m of + LTI -> error "unreachable" + EQI -> return (shOnes, sh) + GTI -> do + index <- Gen.int (Range.linear 0 (fromSNat' m)) + value <- Gen.int (Range.linear 1 5) + Refl <- return (lem n m) + injectOnes n (inject index 1 shOnes) (inject index value sh) + + lem :: forall n m proxy. Compare n m ~ GT => proxy n -> proxy m -> (m + 1 <=? n) :~: True + lem _ _ = unsafeCoerceRefl + + inject :: Int -> Int -> IShR m -> IShR (m + 1) + inject 0 v sh = v :$: sh + inject i v (w :$: sh) = w :$: inject (i - 1) v sh + inject _ v ZSR = v :$: ZSR -- invalid input, but meh + genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a) genStorables rng f = do n <- Gen.int rng diff --git a/test/Tests/C.hs b/test/Tests/C.hs index 2a3949f..148e7f6 100644 --- a/test/Tests/C.hs +++ b/test/Tests/C.hs @@ -16,10 +16,8 @@ import Data.Type.Equality import Foreign import GHC.TypeLits -import Data.Array.Mixed.XArray qualified as X -import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Types (fromSNat') import Data.Array.Nested -import Data.Array.Nested.Internal.Mixed import Data.Array.Nested.Internal.Shape import Hedgehog @@ -35,42 +33,77 @@ import Gen import Util +prop_sum_nonempty :: Property +prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do + -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. + let inrank = SNat @(n + 1) + sh <- forAll $ genShR inrank + -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) + guard (all (> 0) (toList (shrTail sh))) -- only constrain the tail + arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> + genStorables (Range.singleton (product sh)) + (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) + let rarr = rfromOrthotope inrank arr + rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr + +prop_sum_empty :: Property +prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do + -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. + _outrank :: SNat n <- return $ SNat @(nm1 + 1) + let inrank = SNat @(n + 1) + sh <- forAll $ do + shtt <- genShR outrankm1 -- nm1 + sht <- shuffleShR (0 :$: shtt) -- n + n <- Gen.int (Range.linear 0 20) + return (n :$: sht) -- n + 1 + guard (any (== 0) (toList (shrTail sh))) + -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) + let arr = OR.fromList @Double @(n + 1) (toList sh) [] + let rarr = rfromOrthotope inrank arr + OR.toList (rtoOrthotope (rsumOuter1 rarr)) === [] + +prop_sum_lasteq1 :: Property +prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do + let inrank = SNat @(n + 1) + outsh <- forAll $ genShR outrank + guard (all (> 0) (toList outsh)) + let insh = shrAppend outsh (1 :$: ZSR) + arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$> + genStorables (Range.singleton (product insh)) + (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) + let rarr = rfromOrthotope inrank arr + rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr + +prop_sum_replicated :: Bool -> Property +prop_sum_replicated doTranspose = property $ + genRank $ \inrank1@(SNat @m) -> + genRank $ \outrank@(SNat @nm1) -> do + inrank2 :: SNat n <- return $ SNat @(nm1 + 1) + (Refl :: (m <=? n) :~: True) <- case cmpNat inrank1 inrank2 of + LTI -> return Refl -- actually we only continue if m < n + _ -> discard + (sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2 + guard (all (> 0) (toList sh3)) + arr <- forAllT $ + OR.stretch (toList sh3) + . OR.reshape (toList sh2) + . OR.fromVector @Double @m (toList sh1) <$> + genStorables (Range.singleton (product sh1)) + (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) + arrTrans <- + if doTranspose then do perm <- forAll $ genPermR (fromSNat' inrank2) + return $ OR.transpose perm arr + else return arr + let rarr = rfromOrthotope inrank2 arrTrans + almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans) + tests :: TestTree tests = testGroup "C" [testGroup "sum" - [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do - -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. - let inrank = SNat @(n + 1) - sh <- forAll $ genShR inrank - -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) - guard (all (> 0) (toList (shrTail sh))) -- only constrain the tail - arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> - genStorables (Range.singleton (product sh)) - (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) - let rarr = rfromOrthotope inrank arr - -- annotateShow rarr - Refl <- return $ lemRankReplicate outrank - let Ranked (M_Double (M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr - let rhs = orSumOuter1 outrank arr - -- annotateShow lhs - -- annotateShow rhs - lhs === rhs - - ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do - -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. - outrank :: SNat n <- return $ SNat @(nm1 + 1) - let inrank = SNat @(n + 1) - sh <- forAll $ do - shtt <- genShR outrankm1 -- nm1 - sht <- shuffleShR (0 :$: shtt) -- n - n <- Gen.int (Range.linear 0 20) - return (n :$: sht) -- n + 1 - guard (any (== 0) (toList (shrTail sh))) - -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) - let arr = OR.fromList @Double @(n + 1) (toList sh) [] - let rarr = rfromOrthotope inrank arr - Refl <- return $ lemRankReplicate outrank - let Ranked (M_Double (M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr - OR.toList lhs === [] + [testProperty "nonempty" prop_sum_nonempty + ,testProperty "empty" prop_sum_empty + ,testProperty "last==1" prop_sum_lasteq1 + ,testProperty "replicated" (prop_sum_replicated False) + ,testProperty "replicated_transposed" (prop_sum_replicated True) ] ] diff --git a/test/Util.hs b/test/Util.hs index f377e5b..ce6ec23 100644 --- a/test/Util.hs +++ b/test/Util.hs @@ -4,12 +4,16 @@ {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Util where import Data.Array.RankedS qualified as OR +import Data.Kind +import Hedgehog +import Hedgehog.Internal.Property (failDiff) import GHC.TypeLits import Data.Array.Mixed.Types (fromSNat') @@ -32,3 +36,17 @@ orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n orSumOuter1 (sn@SNat :: SNat n) = let n = fromSNat' sn in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0]) + +class AlmostEq f where + type AlmostEqConstr f :: Type -> Constraint + -- | absolute tolerance, lhs, rhs + almostEq :: (AlmostEqConstr f a, Ord a, Show a, Fractional a, MonadTest m) + => a -> f a -> f a -> m () + +instance KnownNat n => AlmostEq (OR.Array n) where + type AlmostEqConstr (OR.Array n) = OR.Unbox + almostEq atol lhs rhs + | OR.allA (< atol) (OR.zipWithA (\a b -> abs (a - b)) rhs lhs) = + success + | otherwise = + failDiff lhs rhs |