blob: 148e7f615b5cc573f5d455b3d5252c255c060032 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Tests.C where
import Control.Monad
import Data.Array.RankedS qualified as OR
import Data.Foldable (toList)
import Data.Type.Equality
import Foreign
import GHC.TypeLits
import Data.Array.Mixed.Types (fromSNat')
import Data.Array.Nested
import Data.Array.Nested.Internal.Shape
import Hedgehog
import Hedgehog.Internal.Property (forAllT)
import Hedgehog.Gen qualified as Gen
import Hedgehog.Range qualified as Range
import Test.Tasty
import Test.Tasty.Hedgehog
-- import Debug.Trace
import Gen
import Util
-- | For a generated rank @n@ and a generated shape of rank @n + 1@, checks
-- that 'rsumOuter1' (converted back with 'rtoOrthotope') equals
-- 'orSumOuter1' applied to the same orthotope array.
prop_sum_nonempty :: Property
prop_sum_nonempty = property $ genRank $ \rankOut@(SNat @n) -> do
  let rankIn = SNat @(n + 1)
  shape <- forAll (genShR rankIn)
  -- traceM ("shape: " ++ show shape ++ " -> " ++ show (product shape))
  -- This property covers nonempty _results_, so only the tail of the shape is
  -- constrained. The first dimension of the input may still be 0, because
  -- then OR.rerank doesn't fail yet.
  guard . all (> 0) . toList $ shrTail shape
  input <- forAllT $
    fmap (OR.fromVector @Double @(n + 1) (toList shape))
         (genStorables (Range.singleton (product shape)) toUnit)
  let nested = rfromOrthotope rankIn input
  rtoOrthotope (rsumOuter1 nested) === orSumOuter1 rankOut input
  where
    -- Divides a raw 64-bit word by the largest 'Word64', giving a value in [0, 1].
    toUnit :: Word64 -> Double
    toUnit w = fromIntegral w / fromIntegral (maxBound :: Word64)
-- | Builds an element-less input array whose shape tail contains a zero
-- dimension and checks that 'rsumOuter1' yields an array with no elements.
prop_sum_empty :: Property
prop_sum_empty = property $ genRank $ \rankOutM1@(SNat @nm1) -> do
  -- Only shapes whose _result_ is empty are interesting here; everything else
  -- is already exercised by 'prop_sum_nonempty'.
  -- The pattern signature brings the type variable n (= nm1 + 1) into scope.
  _rankOut :: SNat n <- return (SNat @(nm1 + 1))
  let rankIn = SNat @(n + 1)
  shape <- forAll $ do
    rest <- genShR rankOutM1                   -- rank nm1
    resultShape <- shuffleShR (0 :$: rest)     -- rank n, contains a 0
    outerLen <- Gen.int (Range.linear 0 20)
    return (outerLen :$: resultShape)          -- rank n + 1
  guard (0 `elem` toList (shrTail shape))
  -- traceM ("shape: " ++ show shape ++ " -> " ++ show (product shape))
  let input = OR.fromList @Double @(n + 1) (toList shape) []
      nested = rfromOrthotope rankIn input
  OR.toList (rtoOrthotope (rsumOuter1 nested)) === []
-- | Same comparison as 'prop_sum_nonempty' ('rsumOuter1' against
-- 'orSumOuter1'), but the input shape is a strictly positive shape of rank
-- @n@ with a final dimension of size 1 appended.
prop_sum_lasteq1 :: Property
prop_sum_lasteq1 = property $ genRank $ \rankOut@(SNat @n) -> do
  let rankIn = SNat @(n + 1)
      -- Divides a raw 64-bit word by the largest 'Word64' value.
      scale :: Word64 -> Double
      scale w = fromIntegral w / fromIntegral (maxBound :: Word64)
  shapeOut <- forAll (genShR rankOut)
  guard $ all (> 0) (toList shapeOut)
  let shapeIn = shrAppend shapeOut (1 :$: ZSR)
  input <- forAllT $
    fmap (OR.fromVector @Double @(n + 1) (toList shapeIn))
         (genStorables (Range.singleton (product shapeIn)) scale)
  let nested = rfromOrthotope rankIn input
  rtoOrthotope (rsumOuter1 nested) === orSumOuter1 rankOut input
-- | Compares 'rsumOuter1' with 'orSumOuter1' (up to a tolerance of @1e-8@) on
-- an input obtained by reshaping and then stretching a rank-@m@ array.
--
-- When the argument is 'True', the input is additionally transposed by a
-- generated permutation before the comparison.
prop_sum_replicated :: Bool -> Property
prop_sum_replicated doTranspose = property $
  genRank $ \rankSmall@(SNat @m) ->
  genRank $ \rankOut@(SNat @nm1) -> do
    -- The pattern signature brings the type variable n (= nm1 + 1) into scope.
    rankBig :: SNat n <- return (SNat @(nm1 + 1))
    -- Matching on Refl puts the m <= n evidence in scope for what follows;
    -- in fact the test only proceeds when m < n, anything else is discarded.
    (Refl :: (m <=? n) :~: True) <- case cmpNat rankSmall rankBig of
      LTI -> return Refl
      _ -> discard
    (shData, shReshaped, shStretched) <- forAll (genReplicatedShR rankSmall rankBig)
    guard (all (> 0) (toList shStretched))
    base <- forAllT $
      fmap (OR.stretch (toList shStretched)
              . OR.reshape (toList shReshaped)
              . OR.fromVector @Double @m (toList shData))
           (genStorables (Range.singleton (product shData)) scale)
    input <- case doTranspose of
      True -> do
        perm <- forAll (genPermR (fromSNat' rankBig))
        return (OR.transpose perm base)
      False -> return base
    let nested = rfromOrthotope rankBig input
    almostEq 1e-8 (rtoOrthotope (rsumOuter1 nested)) (orSumOuter1 rankOut input)
  where
    -- Divides a raw 64-bit word by the largest 'Word64' value.
    scale :: Word64 -> Double
    scale w = fromIntegral w / fromIntegral (maxBound :: Word64)
-- | Test tree for this module: a "C" group containing a "sum" group with one
-- entry per property defined above.
tests :: TestTree
tests = testGroup "C" [sumGroup]
  where
    sumGroup = testGroup "sum" (map (uncurry testProperty) sumProps)
    -- Test name paired with the property it runs, in registration order.
    sumProps =
      [ ("nonempty", prop_sum_nonempty)
      , ("empty", prop_sum_empty)
      , ("last==1", prop_sum_lasteq1)
      , ("replicated", prop_sum_replicated False)
      , ("replicated_transposed", prop_sum_replicated True)
      ]
|