1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Main where
import Control.Monad
import qualified Data.Array.RankedS as OR
import qualified Data.ByteString as BS
import Data.Foldable (toList)
import Data.Type.Equality
import qualified Data.Vector.Storable as VS
import Foreign
import GHC.TypeLits
import qualified GHC.TypeNats as TN
import qualified Data.Array.Mixed as X
import Data.Array.Mixed (fromSNat', pattern SZ, pattern SS)
import Data.Array.Nested
import qualified Data.Array.Nested.Internal as I
-- test framework stuff
import Hedgehog
import Hedgehog.Internal.Property (forAllT)
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import qualified System.Random as Random
import Test.Tasty
import Test.Tasty.Hedgehog
import Debug.Trace
-- | Bisect for the highest value that satisfies a predicate.
--
-- @binarySearch div2 lo hi f@ returns @lo@ if @f lo@ is False, @hi@ if @f hi@
-- is True, and otherwise bisects between the two bounds. The result is only
-- the true maximum if @f@ is monotone (True up to some point, False after).
--
-- @div2@ is the halving function for the number type, e.g. @(\`div\` 2)@ for
-- integers or @(/ 2)@ for fractional types.
binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a
binarySearch div2 lo0 hi0 f
  | not (f lo0) = lo0
  | f hi0 = hi0
  | otherwise = go lo0 hi0
  where
    -- invariant: f lo && not (f hi)
    go lo hi
      -- The interval can no longer be split. Answer with 'lo', which is known
      -- to satisfy the predicate; the midpoint itself may have collapsed onto
      -- 'hi' (descending range, floating-point rounding), which does not.
      | mid == lo || mid == hi = lo
      | f mid = go mid hi
      | otherwise = go lo mid
      where
        mid = lo + div2 (hi - lo)
-- | Draw an 'Int' from @Range.linear 0 8@ (recorded via 'forAll') and run the
-- continuation @k@ with the corresponding type-level rank witness.
genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO ()
genRank k =
  forAll (Gen.int (Range.linear 0 8)) >>= \drawn ->
    TN.withSomeSNat (fromIntegral drawn) k
-- | Generate a value between @lo@ and @hi@ that is skewed towards @lo@: a draw
-- from @[0, 1]@ is cubed before being scaled onto the interval.
genLowBiased :: RealFloat a => (a, a) -> Gen a
genLowBiased (lo, hi) = skew <$> Gen.realFloat (Range.linearFrac 0 1)
  where
    -- Cubing a number in [0, 1] pushes it towards 0, hence towards 'lo'.
    skew u = lo + u * u * u * (hi - lo)
-- | Randomly permute the dimensions of a shape.
--
-- The dimensions are put in a bag; for every position of the shape one entry
-- is drawn from the bag at a random index and removed from it.
shuffleShR :: IShR n -> Gen (IShR n)
shuffleShR shape = pick (length dims) dims shape
  where
    dims = toList shape

    -- The first argument is the number of entries left in the bag; the shape
    -- argument only drives the recursion (its dimensions are ignored).
    pick :: Int -> [Int] -> IShR m -> Gen (IShR m)
    pick _ _ ZSR = pure ZSR
    pick remaining pool (_ :$: rest) = do
      i <- Gen.int (Range.linear 0 (remaining - 1))
      let (chosen, pool') =
            case splitAt i pool of
              (before, x : after) -> (x, before ++ after)
              _ -> error "unreachable"
      fmap (chosen :$:) (pick (remaining - 1) pool' rest)
-- | Generate a random shape of the given rank.
--
-- A target element count is drawn from @0 .. 100_000@ and dimensions are
-- generated one by one against the remaining budget. Afterwards all
-- dimensions are capped at the largest bound (found by 'binarySearch') for
-- which the total size does not exceed the target, and the result is shuffled
-- with 'shuffleShR'.
genShR :: SNat n -> Gen (IShR n)
genShR sn = do
  targetSize <- Gen.int (Range.linear 0 100_000)
  dims <- genDims sn targetSize
  let dimsL = toList dims
      maxdim = maximum dimsL
      -- Does capping every dimension at 'c' keep the shape within budget?
      fits c = product (map (min c) dimsL) <= targetSize
      cap = binarySearch (`div` 2) 1 maxdim fits
  shuffleShR (fmap (min cap) dims)
  where
    rank = fromSNat' sn

    -- Second argument: the element budget left for the remaining dimensions.
    genDims :: SNat m -> Int -> Gen (IShR m)
    genDims SZ _ = pure ZSR
    -- With no budget left, the remaining dimensions are drawn from 0 .. 20.
    genDims (SS m) 0 = do
      dim <- Gen.int (Range.linear 0 20)
      fmap (dim :$:) (genDims m 0)
    genDims (SS m) tgt = do
      dim <-
        Gen.frequency
          [ (20 * rank, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt))))
          , (2, pure tgt)
          , (4, pure 1)
          , (1, pure 0)
          ]
      let remaining
            | dim == 0 = 0
            | otherwise = tgt `div` dim
      fmap (dim :$:) (genDims m remaining)
-- | Generate a storable vector whose length is drawn from @rng@.
--
-- Hedgehog only supplies the length and a seed; the element data comes from a
-- 'Random.StdGen' started from that seed. Each element is @f@ applied to a
-- 'Word64' assembled (least significant byte first) from 8 consecutive bytes
-- of the generated byte string.
genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
genStorables rng f = do
  len <- Gen.int rng
  -- NOTE(review): presumably resized to 99 so that the seed range is not
  -- narrowed by the current Hedgehog size parameter -- confirm.
  seed <- Gen.resize 99 (Gen.int Range.linearBounded)
  let bytes = fst (Random.genByteString (8 * len) (Random.mkStdGen seed))
      wordAt :: Int -> Word64
      wordAt i =
        foldr
          (\j acc -> acc * 256 + fromIntegral (BS.index bytes (8 * i + j)))
          0
          [0 .. 7]
  pure (VS.generate len (f . wordAt))
-- | Orthotope-based reference implementation used to check @rsumOuter1@.
--
-- The input is transposed with the permutation @[1 .. n] ++ [0]@ and every
-- innermost rank-1 slice is then reduced with 'OR.sumA'.
-- NOTE(review): going by the name, this sums over the outermost dimension,
-- which the transpose presumably moves innermost -- confirm against the
-- orthotope documentation of 'OR.transpose'.
orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a
orSumOuter1 (sn@SNat :: SNat n) arr =
  OR.rerank @n @1 @0 (OR.scalar . OR.sumA) (OR.transpose perm arr)
  where
    perm = [1 .. fromSNat' sn] ++ [0]
-- | Test-suite entry point: runs the tasty tree of Hedgehog properties that
-- check 'rsumOuter1'.
--
-- * "random nonempty" compares 'rsumOuter1' against the orthotope reference
--   'orSumOuter1' on randomly shaped arrays of 'Double's.
-- * "random empty" checks that shapes with a zero in a non-leading dimension
--   give a result without elements.
main :: IO ()
main = defaultMain $
  testGroup "Tests"
    [ testGroup "C"
        [ testGroup "sum"
            [ testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do
                -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
                let inrank = SNat @(n + 1)
                sh <- forAll $ genShR inrank
                -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
                -- Only constrain the tail. 'drop 1' is the total equivalent of
                -- 'tail'; the shape has rank n + 1, so the list is never empty.
                guard (all (> 0) (drop 1 (toList sh)))
                -- Elements are raw 64-bit words scaled by the largest Word64.
                arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
                         genStorables (Range.singleton (product sh))
                                      (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
                let rarr = rfromOrthotope inrank arr
                -- annotateShow rarr
                Refl <- return $ I.lemRankReplicate outrank
                let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
                let rhs = orSumOuter1 outrank arr
                -- annotateShow lhs
                -- annotateShow rhs
                lhs === rhs

            , testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do
                -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above.
                outrank :: SNat n <- return $ SNat @(nm1 + 1)
                let inrank = SNat @(n + 1)
                sh <- forAll $ do
                  shtt <- genShR outrankm1 -- nm1
                  sht <- shuffleShR (0 :$: shtt) -- n
                  n <- Gen.int (Range.linear 0 20)
                  return (n :$: sht) -- n + 1
                -- 'drop 1' instead of the partial 'tail'; see the property above.
                guard (any (== 0) (drop 1 (toList sh)))
                -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
                let arr = OR.fromList @Double @(n + 1) (toList sh) []
                let rarr = rfromOrthotope inrank arr
                Refl <- return $ I.lemRankReplicate outrank
                let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
                OR.toList lhs === []
            ]
        ]
    ]
|