aboutsummaryrefslogtreecommitdiff
path: root/test/Util.hs
diff options
context:
space:
mode:
Diffstat (limited to 'test/Util.hs')
-rw-r--r--test/Util.hs18
1 files changed, 18 insertions, 0 deletions
diff --git a/test/Util.hs b/test/Util.hs
index f377e5b..ce6ec23 100644
--- a/test/Util.hs
+++ b/test/Util.hs
@@ -4,12 +4,16 @@
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Util where
import Data.Array.RankedS qualified as OR
+import Data.Kind
+import Hedgehog
+import Hedgehog.Internal.Property (failDiff)
import GHC.TypeLits
import Data.Array.Mixed.Types (fromSNat')
@@ -32,3 +36,17 @@ orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n
orSumOuter1 (sn@SNat :: SNat n) =
let n = fromSNat' sn
in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0])
+
+class AlmostEq f where
+ type AlmostEqConstr f :: Type -> Constraint
+ -- | absolute tolerance, lhs, rhs
+ almostEq :: (AlmostEqConstr f a, Ord a, Show a, Fractional a, MonadTest m)
+ => a -> f a -> f a -> m ()
+
+instance KnownNat n => AlmostEq (OR.Array n) where
+ type AlmostEqConstr (OR.Array n) = OR.Unbox
+ almostEq atol lhs rhs
+ | OR.allA (< atol) (OR.zipWithA (\a b -> abs (a - b)) rhs lhs) =
+ success
+ | otherwise =
+ failDiff lhs rhs