{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Main where

import Control.Monad
import qualified Data.Array.RankedS as OR
import qualified Data.ByteString as BS
import Data.Foldable (toList)
import Data.Type.Equality
import qualified Data.Vector.Storable as VS
import Foreign
import GHC.TypeLits
import qualified GHC.TypeNats as TN

import qualified Data.Array.Mixed as X
import Data.Array.Mixed (fromSNat', pattern SZ, pattern SS)
import Data.Array.Nested
import qualified Data.Array.Nested.Internal as I

-- test framework stuff
import Hedgehog
import Hedgehog.Internal.Property (forAllT)
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import qualified System.Random as Random
import Test.Tasty
import Test.Tasty.Hedgehog

-- import Debug.Trace


-- | Returns the highest value in @[lo, hi]@ that satisfies the predicate, or
-- @lo@ if none does.
--
-- The predicate is assumed to be monotone (true up to some point, false
-- afterwards); this is not checked. @div2@ is used to halve the distance
-- between the bounds, so the search works for any 'Num' type (e.g.
-- @(`div` 2)@ for integral types).
binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a
binarySearch div2 = \lo hi f -> case (f lo, f hi) of
    (False, _) -> lo
    (_, True) -> hi
    (_, _) -> go lo hi f
  where
    -- Invariant: f lo && not (f hi).
    go lo hi f =
      let mid = lo + div2 (hi - lo)
      in if mid == lo || mid == hi
           -- The interval cannot be narrowed any further. 'lo' is the
           -- highest value known to satisfy the predicate; returning 'mid'
           -- when it equals 'hi' would yield a value that fails it.
           then lo
           else if f mid then go mid hi f else go lo mid f

-- | Draws a rank in @[0, 8]@ and passes it to the continuation as a
-- type-level 'SNat'.
genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO ()
genRank k = do
  rank <- forAll $ Gen.int (Range.linear 0 8)
  TN.withSomeSNat (fromIntegral rank) k

-- | Generates a value between @lo@ and @hi@, biased towards @lo@: a value
-- @x@ drawn from the unit range is cubed before scaling.
genLowBiased :: RealFloat a => (a, a) -> Gen a
genLowBiased (lo, hi) = do
  x <- Gen.realFloat (Range.linearFrac 0 1)
  return (lo + x * x * x * (hi - lo))

-- | Randomly permutes the dimensions of a shape, keeping its rank.
--
-- Each output position takes a uniformly chosen element from the bag of
-- dimensions not yet used.
shuffleShR :: IShR n -> Gen (IShR n)
shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh
  where
    -- 'nbag' is the length of 'bag'; the shape argument only drives the
    -- recursion (and thus the type-level rank), its entries are ignored.
    go :: Int -> [Int] -> IShR n -> Gen (IShR n)
    go _ _ ZSR = return ZSR
    go nbag bag (_ :$: sh) = do
      idx <- Gen.int (Range.linear 0 (nbag - 1))
      let (dim, bag') = case splitAt idx bag of
            (pre, n : post) -> (n, pre ++ post)
            _ -> error "unreachable"
      (dim :$:) <$> go (nbag - 1) bag' sh

-- | Generates a random shape of rank @n@.
--
-- A target element count in @[0, 100_000]@ is drawn and dimensions are
-- generated to roughly fit it. Afterwards every dimension is capped at the
-- largest value (found by 'binarySearch') that keeps the product within the
-- target. If even a cap of 1 exceeds it, the cap is 1. Finally the
-- dimensions are shuffled so that large dimensions are not always outermost.
genShR :: SNat n -> Gen (IShR n)
genShR sn = do
  let n = fromSNat' sn
  targetSize <- Gen.int (Range.linear 0 100_000)
  let genDims :: SNat m -> Int -> Gen (IShR m)
      genDims SZ _ = return ZSR
      -- Target exhausted (or a zero dimension was chosen): remaining
      -- dimensions are small and independent of the target.
      genDims (SS m) 0 = do
        dim <- Gen.int (Range.linear 0 20)
        dims <- genDims m 0
        return (dim :$: dims)
      genDims (SS m) tgt = do
        dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt))))
                             ,(2     , return tgt)
                             ,(4     , return 1)
                             ,(1     , return 0)]
        dims <- genDims m (if dim == 0 then 0 else tgt `div` dim)
        return (dim :$: dims)
  dims <- genDims sn targetSize
  let dimsL = toList dims
      -- Seed with 0 so this stays total for rank-0 shapes; all generated
      -- dimensions are non-negative, so the result is otherwise unchanged.
      maxdim = maximum (0 : dimsL)
      cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize)
  shuffleShR (min cap <$> dims)

-- | Generates a storable vector whose length is drawn from @rng@.
--
-- Rather than generating each element through hedgehog, a single seed is
-- drawn and expanded into @8 * n@ pseudo-random bytes with "System.Random"
-- (presumably to keep generation cheap for large vectors). Each group of 8
-- bytes is combined little-endian into a 'Word64' and mapped through @f@.
genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
genStorables rng f = do
  n <- Gen.int rng
  seed <- Gen.resize 99 $ Gen.int Range.linearBounded
  let gen0 = Random.mkStdGen seed
      (bs, _) = Random.genByteString (8 * n) gen0
  -- Byte j of element i is weighted by 256^j.
  let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]])
  return $ VS.generate n (f . readW64)

-- | Reference implementation (via orthotope) of summing over the outermost
-- dimension.
--
-- Dimension 0 is transposed to the innermost position. Each innermost
-- vector is then summed to a scalar, leaving an array of rank @n@.
orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a
orSumOuter1 (sn@SNat :: SNat n) =
  let n = fromSNat' sn
  in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0])

-- | Drops the outermost dimension of a non-empty-rank shape.
rshTail :: ShR (n + 1) i -> ShR n i
rshTail (_ :$: sh) = sh
rshTail ZSR = error "unreachable"

-- | Property tests comparing 'rsumOuter1' against the orthotope reference
-- 'orSumOuter1', for inputs with non-empty and with empty results.
main :: IO ()
main = defaultMain $ testGroup "Tests"
  [testGroup "C"
    [testGroup "sum"
      [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do
        -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
        let inrank = SNat @(n + 1)
        sh <- forAll $ genShR inrank
        -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
        guard (all (> 0) (toList (rshTail sh)))  -- only constrain the tail
        arr <- forAllT $
          OR.fromVector @Double @(n + 1) (toList sh)
            <$> genStorables (Range.singleton (product sh))
                             (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
        let rarr = rfromOrthotope inrank arr
        -- annotateShow rarr

        -- Brings the equality from lemRankReplicate into scope; presumably
        -- needed to type the representation pattern below.
        Refl <- return $ I.lemRankReplicate outrank
        -- Unwrap the nested array down to its underlying orthotope array so
        -- it can be compared directly with the reference result.
        let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
        let rhs = orSumOuter1 outrank arr
        -- annotateShow lhs
        -- annotateShow rhs

        lhs === rhs
      ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do
        -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above.
        outrank :: SNat n <- return $ SNat @(nm1 + 1)
        let inrank = SNat @(n + 1)
        sh <- forAll $ do
          shtt <- genShR outrankm1           -- nm1
          sht <- shuffleShR (0 :$: shtt)     -- n
          n <- Gen.int (Range.linear 0 20)
          return (n :$: sht)                 -- n + 1
        guard (any (== 0) (toList (rshTail sh)))
        -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
        let arr = OR.fromList @Double @(n + 1) (toList sh) []
        let rarr = rfromOrthotope inrank arr

        Refl <- return $ I.lemRankReplicate outrank
        let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr

        OR.toList lhs === []
      ]
    ]
  ]