aboutsummaryrefslogtreecommitdiff
path: root/test/Main.hs
blob: dd59586b5e1c1eff16b0b70a918e2296c566a5e9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Main where

import Control.Monad
import qualified Data.Array.RankedS as OR
import qualified Data.ByteString as BS
import Data.Foldable (toList)
import Data.Type.Equality
import qualified Data.Vector.Storable as VS
import Foreign
import GHC.TypeLits
import qualified GHC.TypeNats as TN

import qualified Data.Array.Mixed as X
import Data.Array.Mixed (fromSNat', pattern SZ, pattern SS)
import Data.Array.Nested
import qualified Data.Array.Nested.Internal as I

-- test framework stuff
import Hedgehog
import Hedgehog.Internal.Property (forAllT)
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import qualified System.Random as Random
import Test.Tasty
import Test.Tasty.Hedgehog

import Debug.Trace


-- | Binary search for the boundary of a predicate that is 'True' on an
-- initial part of @[lo, hi]@ and 'False' afterwards.
--
-- Returns the highest value found that satisfies @f@. If @f lo@ is already
-- 'False', @lo@ is returned (and then does not satisfy @f@); if @f hi@ holds,
-- @hi@ is returned without searching.
--
-- @div2@ halves a difference, e.g. @(`div` 2)@ for integral types or @(/ 2)@
-- for fractional ones. The search stops as soon as the midpoint coincides
-- with one of the bounds, and then returns the lower bound: that is the one
-- known to satisfy @f@. (With floating-point rounding the midpoint of two
-- adjacent values can round up to the upper bound, which does not satisfy
-- @f@, so returning the midpoint itself would be wrong.)
binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a
binarySearch div2 = \lo hi f -> case (f lo, f hi) of
  (False, _) -> lo
  (_, True) -> hi
  (_, _) -> go f lo hi
  where
    -- Invariant: p good && not (p bad).
    go p good bad
      | mid == good || mid == bad = good
      | p mid = go p mid bad
      | otherwise = go p good mid
      where
        mid = good + div2 (bad - good)

-- | Pick a random rank in [0, 8] (recorded with 'forAll', so it shows up in
-- counterexamples) and hand it to the continuation as a singleton 'SNat'.
genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO ()
genRank k = forAll rankGen >>= \r -> TN.withSomeSNat (fromIntegral r) k
  where
    rankGen = Gen.int (Range.linear 0 8)

-- | Draw a value from @[lo, hi]@ skewed towards @lo@: a sample @x@ from
-- @[0, 1]@ is cubed (which keeps it in @[0, 1]@ but pushes it towards 0)
-- before being mapped linearly onto the interval.
genLowBiased :: RealFloat a => (a, a) -> Gen a
genLowBiased (lo, hi) = stretch <$> Gen.realFloat (Range.linearFrac 0 1)
  where
    stretch x = lo + x * x * x * (hi - lo)

-- | Randomly permute the dimensions of a shape; the rank is unchanged.
shuffleShR :: IShR n -> Gen (IShR n)
shuffleShR sh0 = pick (toList sh0) sh0
  where
    -- Walk the shape's spine (to preserve the rank index) and, at every
    -- position, take a randomly chosen dimension out of the remaining pool.
    pick :: [Int] -> IShR m -> Gen (IShR m)
    pick _ ZSR = return ZSR
    pick pool (_ :$: rest) = do
      k <- Gen.int (Range.linear 0 (length pool - 1))
      let (d, pool') = case splitAt k pool of
                         (front, x : back) -> (x, front ++ back)
                         _ -> error "unreachable"
      (d :$:) <$> pick pool' rest

-- | Generate a random shape of the given rank whose element count is
-- (normally) at most a randomly drawn target size in @[0, 100_000]@.
--
-- Dimensions are drawn one after another, each one dividing the remaining
-- target for the dimensions after it. Any overshoot is then removed by
-- capping every dimension at the largest value for which the capped product
-- still fits the target. Finally the dimensions are shuffled, which removes
-- the positional bias of the sequential generation (earlier dimensions see a
-- larger remaining target).
genShR :: SNat n -> Gen (IShR n)
genShR sn = do
  let n = fromSNat' sn
  targetSize <- Gen.int (Range.linear 0 100_000)
  let genDims :: SNat m -> Int -> Gen (IShR m)
      genDims SZ _ = return ZSR
      -- No target left (or a 0 dimension was drawn): the remaining dimensions
      -- are small and unconstrained; the capping step below trims overshoot.
      genDims (SS m) 0 = do
        dim <- Gen.int (Range.linear 0 20)
        dims <- genDims m 0
        return (dim :$: dims)
      genDims (SS m) tgt = do
        dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt))))
                             ,(2     , return tgt)
                             ,(4     , return 1)
                             ,(1     , return 0)]
        dims <- genDims m (if dim == 0 then 0 else tgt `div` dim)
        return (dim :$: dims)
  dims <- genDims sn targetSize
  let dimsL = toList dims
      -- 'maximum' is partial and dimsL is empty for rank 0. All generated
      -- dimensions are non-negative, so seeding with 0 changes nothing else.
      maxdim = maximum (0 : dimsL)
      cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize)
  shuffleShR (min cap <$> dims)

-- | Generate a storable vector whose length is drawn from @rng@.
--
-- A single seed is drawn through Hedgehog (at the maximum size, 99, so the
-- 'Range.linearBounded' range is at full width). That seed feeds a
-- 'Random.StdGen', which produces @8 * length@ bytes; every group of 8 bytes is
-- decoded as a little-endian 'Word64' and mapped through @f@.
genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
genStorables rng f = do
  len <- Gen.int rng
  seed <- Gen.resize 99 (Gen.int Range.linearBounded)
  let bytes = fst (Random.genByteString (8 * len) (Random.mkStdGen seed))
      -- Horner form over bytes 7..0: byte 0 ends up least significant.
      wordAt i = foldr (\j acc -> acc * 256 + fromIntegral (BS.index bytes (8 * i + j))) 0 [0 .. 7]
  return (VS.generate len (f . wordAt))

-- | Reference "sum over the outermost dimension" on plain orthotope arrays,
-- used as the oracle for 'rsumOuter1' in the tests.
--
-- A transpose moves the outermost axis to the innermost position, and each
-- innermost rank-1 slice is then reduced to a scalar with 'OR.sumA'.
orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a
orSumOuter1 (sn@SNat :: SNat n) arr =
  OR.rerank @n @1 @0 (OR.scalar . OR.sumA) (OR.transpose outerLast arr)
  where
    -- Permutation [1, 2, .., n, 0]: axes 1..n keep their order, axis 0 goes last.
    outerLast = [1 .. fromSNat' sn] ++ [0]

-- | Test-suite entry point. The "sum" group checks 'rsumOuter1' on randomly
-- generated ranked arrays: nonempty results are compared against the
-- orthotope reference 'orSumOuter1', and empty results are checked to
-- contain no elements.
main :: IO ()
main = defaultMain $
  testGroup "Tests"
    [testGroup "C"
      [testGroup "sum"
        [testProperty "random nonempty" $ property $ genRank $ \outrank@(SNat @n) -> do
          -- Test shapes whose _result_ is nonempty. The first (summed) dimension of the
          -- input may still be 0, because OR.rerank in the reference 'orSumOuter1' does
          -- not fail in that case; zero result dimensions are covered by 'random empty'.
          let inrank = SNat @(n + 1)
          sh <- forAll $ genShR inrank
          -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
          -- Reject shapes that have a 0 anywhere except the first dimension.
          guard (all (> 0) (tail (toList sh)))  -- only constrain the tail
          -- Elements are random Word64s scaled into [0, 1].
          arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
                   genStorables (Range.singleton (product sh))
                                (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
          let rarr = rfromOrthotope inrank arr
          -- annotateShow rarr
          -- NOTE(review): the equality from lemRankReplicate is presumably what lets the
          -- pattern below see through the result type down to the raw orthotope array.
          Refl <- return $ I.lemRankReplicate outrank
          -- Unwrap the result to the underlying orthotope array for comparison.
          let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
          let rhs = orSumOuter1 outrank arr
          -- annotateShow lhs
          -- annotateShow rhs
          lhs === rhs

        ,testProperty "random empty" $ property $ genRank $ \outrankm1@(SNat @nm1) -> do
          -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above.
          outrank :: SNat n <- return $ SNat @(nm1 + 1)
          let inrank = SNat @(n + 1)
          -- Build a result shape that contains a 0 (placed at a random position by
          -- the shuffle), then prepend an arbitrary outer (summed) dimension.
          sh <- forAll $ do
            shtt <- genShR outrankm1  -- nm1
            sht <- shuffleShR (0 :$: shtt)  -- n
            n <- Gen.int (Range.linear 0 20)
            return (n :$: sht)  -- n + 1
          -- Holds by construction; kept as a sanity check on the generator above.
          guard (any (== 0) (tail (toList sh)))
          -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
          -- The shape contains a 0, so the input array has no elements.
          let arr = OR.fromList @Double @(n + 1) (toList sh) []
          let rarr = rfromOrthotope inrank arr
          Refl <- return $ I.lemRankReplicate outrank
          let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
          OR.toList lhs === []
        ]
      ]
    ]