{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-- | Property tests comparing 'rsumOuter1' on nested arrays against the
-- 'orSumOuter1' reference applied to the corresponding orthotope arrays.
module Tests.C where

import Control.Monad
import Data.Array.RankedS qualified as OR
import Data.Foldable (toList)
import Data.Type.Equality
import Foreign
import GHC.TypeLits

import Data.Array.Mixed.Types (fromSNat')
import Data.Array.Nested
import Data.Array.Nested.Internal.Shape

import Hedgehog
import Hedgehog.Internal.Property (forAllT)
import Hedgehog.Gen qualified as Gen
import Hedgehog.Range qualified as Range
import Test.Tasty
import Test.Tasty.Hedgehog

-- import Debug.Trace

import Gen
import Util


-- | Divide an integral value by the largest 'Word64' value. Used as the
-- conversion function handed to 'genStorables' when filling test arrays.
scaleByMaxWord64 :: (Integral w, Fractional a) => w -> a
scaleByMaxWord64 w = fromIntegral w / fromIntegral (maxBound :: Word64)

-- | Summation over inputs whose /result/ is nonempty: the outcome of
-- 'rsumOuter1' must be exactly what 'orSumOuter1' yields on the same data.
prop_sum_nonempty :: Property
prop_sum_nonempty = property $ genRank $ \resultRank@(SNat @n) -> do
  -- Test nonempty _results_. The first dimension of the input may be 0,
  -- because in that case OR.rerank doesn't fail yet.
  let inputRank = SNat @(n + 1)
  inputSh <- forAll (genShR inputRank)
  -- traceM ("sh: " ++ show inputSh ++ " -> " ++ show (product inputSh))
  let tailDims = toList (shrTail inputSh)
  guard (all (> 0) tailDims)  -- only the tail is constrained
  inputArr <-
    forAllT $
      fmap
        (OR.fromVector @Double @(n + 1) (toList inputSh))
        (genStorables (Range.singleton (product inputSh)) scaleByMaxWord64)
  let nestedArr = rfromOrthotope inputRank inputArr
  rtoOrthotope (rsumOuter1 nestedArr) === orSumOuter1 resultRank inputArr

-- | Summation over inputs whose tail shape contains a zero: the array is
-- built from an empty element list and the summed result must be empty too.
prop_sum_empty :: Property
prop_sum_empty = property $ genRank $ \innerRank@(SNat @nm1) -> do
  -- Only shapes where the _result_ is empty are needed here; everything else
  -- is covered by the "nonempty" property.
  _resultRank :: SNat n <- pure (SNat @(nm1 + 1))
  let inputRank = SNat @(n + 1)
  inputSh <- forAll $ do
    innerSh <- genShR innerRank                -- rank nm1
    shuffledSh <- shuffleShR (0 :$: innerSh)   -- rank n
    outerLen <- Gen.int (Range.linear 0 20)
    pure (outerLen :$: shuffledSh)             -- rank n + 1
  guard (0 `elem` toList (shrTail inputSh))
  -- traceM ("sh: " ++ show inputSh ++ " -> " ++ show (product inputSh))
  let inputArr = OR.fromList @Double @(n + 1) (toList inputSh) []
      nestedArr = rfromOrthotope inputRank inputArr
  OR.toList (rtoOrthotope (rsumOuter1 nestedArr)) === []

-- | Summation over inputs whose shape is an all-positive shape with a single
-- trailing dimension of size 1 appended.
prop_sum_lasteq1 :: Property
prop_sum_lasteq1 = property $ genRank $ \resultRank@(SNat @n) -> do
  let inputRank = SNat @(n + 1)
  resultSh <- forAll (genShR resultRank)
  guard (all (> 0) (toList resultSh))
  let inputSh = shrAppend resultSh (1 :$: ZSR)
  inputArr <-
    forAllT $
      fmap
        (OR.fromVector @Double @(n + 1) (toList inputSh))
        (genStorables (Range.singleton (product inputSh)) scaleByMaxWord64)
  let nestedArr = rfromOrthotope inputRank inputArr
  rtoOrthotope (rsumOuter1 nestedArr) === orSumOuter1 resultRank inputArr

-- | Summation over arrays obtained by reshaping and stretching a smaller
-- array. When the argument is 'True' the array is additionally transposed by
-- a generated permutation. Results are compared with 'almostEq' at @1e-8@
-- rather than exactly.
prop_sum_replicated :: Bool -> Property
prop_sum_replicated doTranspose = property $
  genRank $ \sourceRank@(SNat @m) ->
  genRank $ \resultRank@(SNat @nm1) -> do
    inputRank :: SNat n <- pure (SNat @(nm1 + 1))
    -- Every comparison outcome other than "less than" is discarded, so we
    -- only continue when m < n.
    (Refl :: (m <=? n) :~: True) <- case cmpNat sourceRank inputRank of
      LTI -> pure Refl
      _ -> discard
    (sourceSh, reshapedSh, stretchedSh) <-
      forAll (genReplicatedShR sourceRank inputRank)
    guard (all (> 0) (toList stretchedSh))
    replicatedArr <-
      forAllT $
        fmap
          ( \vec ->
              OR.stretch (toList stretchedSh)
                (OR.reshape (toList reshapedSh)
                  (OR.fromVector @Double @m (toList sourceSh) vec))
          )
          (genStorables (Range.singleton (product sourceSh)) scaleByMaxWord64)
    finalArr <-
      if doTranspose
        then
          (\perm -> OR.transpose perm replicatedArr)
            <$> forAll (genPermR (fromSNat' inputRank))
        else pure replicatedArr
    let nestedArr = rfromOrthotope inputRank finalArr
    almostEq 1e-8
      (rtoOrthotope (rsumOuter1 nestedArr))
      (orSumOuter1 resultRank finalArr)

-- | All properties of this module, grouped for the Tasty runner.
tests :: TestTree
tests =
  testGroup "C"
    [ testGroup "sum"
        [ testProperty "nonempty" prop_sum_nonempty
        , testProperty "empty" prop_sum_empty
        , testProperty "last==1" prop_sum_lasteq1
        , testProperty "replicated" (prop_sum_replicated False)
        , testProperty "replicated_transposed" (prop_sum_replicated True)
        ]
    ]