aboutsummaryrefslogtreecommitdiff
path: root/test/Main.hs
blob: 002c606b88328a9e7e63184b68c6ffa05d73738a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DataKinds #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Main where

import qualified Data.Array.RankedS as OR
import Data.Foldable (toList)
import Data.Type.Equality
import GHC.TypeLits
import qualified GHC.TypeNats as TN

import qualified Data.Array.Mixed as X
import Data.Array.Mixed (fromSNat', pattern SZ, pattern SS)
import Data.Array.Nested
import qualified Data.Array.Nested.Internal as I

-- test framework stuff
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Test.Tasty
import Test.Tasty.Hedgehog

import Debug.Trace


-- | Draw a rank in @[0, 8]@ (recorded in the counterexample via 'forAll')
-- and pass the continuation a type-level singleton for that rank.
genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO ()
genRank cont =
  forAll (Gen.int (Range.linear 0 8)) >>= \r ->
    TN.withSomeSNat (fromIntegral r) cont

-- | Generate a value in @[lo, hi]@ biased towards @lo@: a sample @u@ drawn
-- from @[0, 1]@ is cubed before being scaled into the interval, which pulls
-- results towards the lower end.
genLowBiased :: RealFloat a => (a, a) -> Gen a
genLowBiased (lo, hi) = scaleIn <$> Gen.realFloat (Range.linearFrac 0 1)
  where
    scaleIn u = lo + u * u * u * (hi - lo)

-- | Randomly permute the dimensions of a shape, keeping its rank.
--
-- The dimensions are put in a pool. Each output position is then filled by
-- removing a randomly chosen element from the pool.
shuffleShR :: IShR n -> Gen (IShR n)
shuffleShR sh0 = pick (length pool0) pool0 sh0
  where
    pool0 = toList sh0

    -- The shape argument only drives the recursion at the right rank; its
    -- entries are ignored in favour of the pool. The Int is the pool size.
    pick :: Int -> [Int] -> IShR m -> Gen (IShR m)
    pick _ _ ZSR = pure ZSR
    pick count pool (_ :$: rest) = do
      i <- Gen.int (Range.linear 0 (count - 1))
      let (d, pool') = takeOut i pool
      (d :$:) <$> pick (count - 1) pool' rest

    -- Remove the element at index @i@, returning it with the remaining list.
    -- The index is always in range because the pool has @count@ elements.
    takeOut :: Int -> [Int] -> (Int, [Int])
    takeOut i pool = case splitAt i pool of
      (before, d : after) -> (d, before ++ after)
      _ -> error "unreachable"

-- | Generate a random shape of the given rank.
--
-- The generator works in three steps:
--
-- 1. Draw a target element count, bounded above by @1000 * 3^rank@.
-- 2. Choose dimensions one at a time; after each choice the remaining
--    budget is divided by that dimension.
-- 3. Shuffle the dimension order with 'shuffleShR'.
genShR :: SNat n -> Gen (IShR n)
genShR sn = do
  budget <- Gen.int (Range.linear 0 (1000 * 3 ^ nDims))
  shuffleShR =<< dimsFor sn budget
  where
    nDims = fromSNat' sn

    dimsFor :: SNat m -> Int -> Gen (IShR m)
    dimsFor SZ _ = return ZSR
    dimsFor (SS m) 0 =
      -- Budget used up: remaining dimensions are drawn freely from [0, 20].
      Gen.int (Range.linear 0 20) >>= \d -> (d :$:) <$> dimsFor m 0
    dimsFor (SS m) remaining = do
      d <- Gen.frequency
             [ (20 * nDims, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral remaining))))
             , (2, return remaining)
             , (4, return 1)
             , (1, return 0)
             ]
      -- A zero dimension leaves no budget for the dimensions after it.
      (d :$:) <$> dimsFor m (if d == 0 then 0 else remaining `div` d)

-- | Reference implementation of summing away the outermost dimension,
-- built from orthotope primitives.
--
-- The array is transposed so that dimension 0 moves to the end. Then each
-- innermost rank-1 vector is reduced to a scalar with 'OR.sumA'.
orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a
orSumOuter1 (sn@SNat :: SNat n) arr =
  OR.rerank @n @1 @0 (OR.scalar . OR.sumA) (OR.transpose perm arr)
  where
    -- Dimensions 1..n keep their relative order; dimension 0 goes last.
    perm = [1 .. fromSNat' sn] ++ [0]

-- | Test-suite entry point.
--
-- @Tests/C/sum/random@ checks 'rsumOuter1' against the orthotope-based
-- reference 'orSumOuter1'. The inputs are random Double arrays of random
-- rank and shape.
main :: IO ()
main = defaultMain $
  testGroup "Tests"
    [testGroup "C"
      [testGroup "sum"
        [testProperty "random" $ property $ genRank $ \outrank@(SNat @n) -> do
          let inrank = SNat @(n + 1)
          sh <- forAll $ genShR inrank
          arr <- forAll $ OR.fromList @_ @(n + 1) (toList sh) <$>
                   Gen.list (Range.singleton (product sh))
                            (Gen.realFloat (Range.linearFrac @Double 0 1))
          let rarr = rfromOrthotope inrank arr
          annotateShow rarr
          -- NOTE(review): presumably this equality lets the result's backing
          -- array be viewed at rank n, so lhs and rhs share a type; confirm.
          Refl <- return $ I.lemRankReplicate outrank
          let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
          let rhs = orSumOuter1 outrank arr
          annotateShow lhs
          annotateShow rhs
          -- Both sides add the same Doubles, but possibly in a different
          -- order. Floating-point addition is not associative, so exact
          -- equality would fail spuriously; compare with a tolerance instead.
          almostEq 1e-8 lhs rhs
        ]
      ]
    ]
  where
    -- | Assert that two arrays have the same shape and that corresponding
    -- elements agree to within @eps@ relative to their magnitude. For
    -- magnitudes below 1 the tolerance is an absolute @eps@.
    almostEq :: (OR.Unbox a, RealFloat a, Show a)
             => a -> OR.Array r a -> OR.Array r a -> PropertyT IO ()
    almostEq eps got expected = do
      OR.shapeL got === OR.shapeL expected
      -- Shapes are equal here, so the element lists have equal length.
      diff (OR.toList got) (\xs ys -> and (zipWith close xs ys)) (OR.toList expected)
      where
        close x y = abs (x - y) <= eps * max 1 (max (abs x) (abs y))