aboutsummaryrefslogtreecommitdiff
path: root/test/Tests/C.hs
blob: 148e7f615b5cc573f5d455b3d5252c255c060032 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Tests.C where

import Control.Monad
import Data.Array.RankedS qualified as OR
import Data.Foldable (toList)
import Data.Type.Equality
import Foreign
import GHC.TypeLits

import Data.Array.Mixed.Types (fromSNat')
import Data.Array.Nested
import Data.Array.Nested.Internal.Shape

import Hedgehog
import Hedgehog.Internal.Property (forAllT)
import Hedgehog.Gen qualified as Gen
import Hedgehog.Range qualified as Range
import Test.Tasty
import Test.Tasty.Hedgehog

-- import Debug.Trace

import Gen
import Util


-- | 'rsumOuter1' agrees with the reference 'orSumOuter1' whenever the
-- /result/ of the sum is nonempty.
--
-- Only the tail of the input shape (which becomes the result shape) is
-- required to be positive. The outermost input dimension may still be 0,
-- because 'OR.rerank' does not fail in that case yet.
prop_sum_nonempty :: Property
prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do
  let inrank = SNat @(n + 1)
      -- Divide a generated word by maxBound :: Word64 (lands in [0, 1]
      -- when the word is a Word64).
      toUnit w = fromIntegral w / fromIntegral (maxBound :: Word64)
  inShape <- forAll (genShR inrank)
  -- traceM ("sh: " ++ show inShape ++ " -> " ++ show (product inShape))
  -- Discard inputs whose result would be empty; only the tail is constrained.
  guard $ all (> 0) $ toList (shrTail inShape)
  let numElems = product inShape
      buildArr = OR.fromVector @Double @(n + 1) (toList inShape)
  input <- forAllT (buildArr <$> genStorables (Range.singleton numElems) toUnit)
  rtoOrthotope (rsumOuter1 (rfromOrthotope inrank input))
    === orSumOuter1 outrank input

-- | Summing over the outermost dimension of an input whose /result/ is empty
-- yields an empty array with the expected shape (the input shape minus its
-- outermost dimension).
--
-- Only result-empty shapes are generated here; shapes with a nonempty result
-- are covered by 'prop_sum_nonempty'. The reference 'orSumOuter1' is not used
-- as an oracle here (NOTE(review): presumably because 'OR.rerank' fails on
-- such shapes — confirm).
prop_sum_empty :: Property
prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do
  _outrank :: SNat n <- return $ SNat @(nm1 + 1)
  let inrank = SNat @(n + 1)
  sh <- forAll $ do
    shtt <- genShR outrankm1  -- nm1
    sht <- shuffleShR (0 :$: shtt)  -- n; presumably a permutation, so a 0 remains
    n <- Gen.int (Range.linear 0 20)
    return (n :$: sht)  -- n + 1
  -- Should already hold by construction if shuffleShR only permutes; kept as a check.
  guard (any (== 0) (toList (shrTail sh)))
  -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
  let arr = OR.fromList @Double @(n + 1) (toList sh) []
      rarr = rfromOrthotope inrank arr
      res = rtoOrthotope (rsumOuter1 rarr)
  -- The result must contain no elements ...
  OR.toList res === []
  -- ... and must still have the right shape: the input shape without its
  -- outermost dimension. Checking only the element list would miss this.
  OR.shapeL res === toList (shrTail sh)

-- | 'rsumOuter1' agrees with the reference 'orSumOuter1' on inputs whose
-- innermost dimension is 1.
--
-- A rank-n shape with all-positive dimensions is generated, and a trailing
-- dimension of size 1 is appended to form the rank-(n + 1) input shape.
prop_sum_lasteq1 :: Property
prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do
  let inrank = SNat @(n + 1)
      -- Divide a generated word by maxBound :: Word64.
      toUnit w = fromIntegral w / fromIntegral (maxBound :: Word64)
  baseSh <- forAll (genShR outrank)
  guard $ all (> 0) $ toList baseSh
  let insh = baseSh `shrAppend` (1 :$: ZSR)
      mkArr = OR.fromVector @Double @(n + 1) (toList insh)
  arr <- forAllT $ mkArr <$> genStorables (Range.singleton (product insh)) toUnit
  let summed = rsumOuter1 (rfromOrthotope inrank arr)
  rtoOrthotope summed === orSumOuter1 outrank arr

-- | 'rsumOuter1' agrees (within 1e-8, via 'almostEq') with the reference
-- 'orSumOuter1' on arrays obtained by replicating a lower-rank array.
--
-- A rank-m array (m < n) is generated, then reshaped and stretched to rank n
-- using the three shapes produced by 'genReplicatedShR'. When @doTranspose@
-- is 'True', the array is additionally transposed by a random permutation
-- before being summed.
prop_sum_replicated :: Bool -> Property
prop_sum_replicated doTranspose = property $
  genRank $ \inrank1@(SNat @m) ->
  genRank $ \outrank@(SNat @nm1) -> do
    inrank2 :: SNat n <- return $ SNat @(nm1 + 1)
    -- Continue only when m < n; the match records the evidence m <= n.
    (Refl :: (m <=? n) :~: True) <- case cmpNat inrank1 inrank2 of
      LTI -> return Refl
      _ -> discard
    (sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2
    guard $ all (> 0) $ toList sh3
    let toUnit w = fromIntegral w / fromIntegral (maxBound :: Word64)
        -- Build from sh1, reshape to sh2, then stretch to sh3.
        replicateUp = OR.stretch (toList sh3)
                    . OR.reshape (toList sh2)
                    . OR.fromVector @Double @m (toList sh1)
    arr <- forAllT $
      replicateUp <$> genStorables (Range.singleton (product sh1)) toUnit
    input <-
      if not doTranspose
        then return arr
        else do
          perm <- forAll $ genPermR (fromSNat' inrank2)
          return (OR.transpose perm arr)
    let rarr = rfromOrthotope inrank2 input
    almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank input)

-- | The tasty test tree for this module: a single "C" group containing the
-- "sum" properties.
tests :: TestTree
tests = testGroup "C" [sumGroup]
  where
    -- Hedgehog properties of 'rsumOuter1', in the same order as before.
    sumGroup = testGroup "sum"
      [ testProperty "nonempty"              prop_sum_nonempty
      , testProperty "empty"                 prop_sum_empty
      , testProperty "last==1"               prop_sum_lasteq1
      , testProperty "replicated"            (prop_sum_replicated False)
      , testProperty "replicated_transposed" (prop_sum_replicated True)
      ]