blob: 6514fbf95dd1f36b480cc372d5c732393292a735 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Util where
import Data.Array.RankedS qualified as OR
import Data.Kind
import GHC.TypeLits
import Hedgehog
import Hedgehog.Internal.Property (failDiff)
import Data.Array.Nested.Types (fromSNat')
-- | Binary search for the boundary of a monotone predicate on @[lo, hi]@.
--
-- Returns the highest value in the range that satisfies the predicate, or
-- @lo@ if none does. The predicate is assumed to be monotone: 'True' on some
-- prefix of the range and 'False' on the rest.
--
-- @div2@ halves a distance; it is a parameter so the same search works for
-- integral types (e.g. @(`div` 2)@) and fractional ones (e.g. @(/ 2)@). The
-- search stops once the midpoint can no longer be told apart from an endpoint.
binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a
binarySearch div2 = \lo hi f -> case (f lo, f hi) of
    (False, _) -> lo
    (_, True) -> hi
    (_, _) -> go lo hi f
  where
    -- Invariant: f lo && not (f hi).
    go lo hi f =
      let mid = lo + div2 (hi - lo)
      in if mid == lo || mid == hi
           -- No further progress is possible. 'lo' is the highest value known
           -- to satisfy f. Returning 'mid' would be wrong when mid == hi,
           -- because f hi is False (this happens with a rounding-up div2, or
           -- with floating-point rounding between adjacent values).
           then lo
           else if f mid then go mid hi f else go lo mid f
-- | Sum a rank-@(n + 1)@ array over its outermost (first) axis, producing a
-- rank-@n@ array.
--
-- The outer axis is first moved to the innermost position with
-- 'OR.transpose'. Each innermost rank-1 slice is then collapsed to a scalar
-- with 'OR.sumA' via 'OR.rerank'.
orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a
orSumOuter1 (sn@SNat :: SNat n) arr =
  OR.rerank @n @1 @0 (OR.scalar . OR.sumA) (OR.transpose outerLast arr)
  where
    -- Permutation [1, 2, .., n, 0]: keep axes 1..n in order and send axis 0
    -- to the end.
    outerLast = [1 .. fromSNat' sn] ++ [0]
-- | Approximate equality under an absolute tolerance. The outcome is reported
-- in a Hedgehog 'MonadTest' rather than returned as a 'Bool'.
class AlmostEq t where
  -- | Element type of @t@; this is the type of the tolerance.
  type EltOf t :: Type
  -- | Compare two values within an absolute tolerance.
  almostEq
    :: MonadTest m
    => EltOf t  -- ^ absolute tolerance
    -> t        -- ^ lhs
    -> t        -- ^ rhs
    -> m ()
-- | Element-wise comparison of two ranked arrays.
--
-- Succeeds when both arrays have the same shape and every pair of
-- corresponding elements differs by strictly less than the tolerance.
-- Otherwise it fails with a diff of the two arrays. A NaN difference is never
-- @< atol@, so it fails.
instance (OR.Unbox a, Ord a, Show a, Fractional a) => AlmostEq (OR.Array n a) where
  type EltOf (OR.Array n a) = a
  almostEq atol lhs rhs
    -- Arrays of different shapes can never be "almost equal". Report the
    -- mismatch as a diff instead of handing mismatched arrays to
    -- 'OR.zipWithA'.
    | OR.shapeL lhs /= OR.shapeL rhs =
        failDiff lhs rhs
    | OR.allA (< atol) (OR.zipWithA (\a b -> abs (a - b)) rhs lhs) =
        success
    | otherwise =
        failDiff lhs rhs
-- | Scalar comparison. Passes when @|lhs - rhs|@ is strictly below the
-- absolute tolerance; otherwise fails with a diff of the two values.
instance AlmostEq Double where
  type EltOf Double = Double
  almostEq atol lhs rhs =
    if abs (lhs - rhs) < atol
      then success
      else failDiff lhs rhs
|