aboutsummaryrefslogtreecommitdiff
path: root/test/Main.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-25 12:27:32 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-25 12:27:32 +0200
commit13433346340e4376d8bc286f2e883f57e3962314 (patch)
tree8d9f80e90e79ca7c11368dd0eefcabab63288371 /test/Main.hs
parent84b00455fff01c21953262325c1b6fca69b16ff0 (diff)
Less warnings in test
Diffstat (limited to 'test/Main.hs')
-rw-r--r--test/Main.hs11
1 files changed, 8 insertions, 3 deletions
diff --git a/test/Main.hs b/test/Main.hs
index dd59586..b5237e5 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -7,6 +7,7 @@
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Main where
@@ -34,7 +35,7 @@ import qualified System.Random as Random
import Test.Tasty
import Test.Tasty.Hedgehog
-import Debug.Trace
+-- import Debug.Trace
-- Returns highest value that satisfies the predicate, or `lo` if none does
@@ -109,6 +110,10 @@ orSumOuter1 (sn@SNat :: SNat n) =
let n = fromSNat' sn
in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0])
+rshTail :: ShR (n + 1) i -> ShR n i
+rshTail (_ :$: sh) = sh
+rshTail ZSR = error "unreachable"
+
main :: IO ()
main = defaultMain $
testGroup "Tests"
@@ -119,7 +124,7 @@ main = defaultMain $
let inrank = SNat @(n + 1)
sh <- forAll $ genShR inrank
-- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
- guard (all (> 0) (tail (toList sh))) -- only constrain the tail
+ guard (all (> 0) (toList (rshTail sh))) -- only constrain the tail
arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
genStorables (Range.singleton (product sh))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
@@ -141,7 +146,7 @@ main = defaultMain $
sht <- shuffleShR (0 :$: shtt) -- n
n <- Gen.int (Range.linear 0 20)
return (n :$: sht) -- n + 1
- guard (any (== 0) (tail (toList sh)))
+ guard (any (== 0) (toList (rshTail sh)))
-- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
let arr = OR.fromList @Double @(n + 1) (toList sh) []
let rarr = rfromOrthotope inrank arr