aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ox-arrays.cabal1
-rw-r--r--test/Main.hs11
2 files changed, 9 insertions, 3 deletions
diff --git a/ox-arrays.cabal b/ox-arrays.cabal
index 192471a..af985f4 100644
--- a/ox-arrays.cabal
+++ b/ox-arrays.cabal
@@ -37,6 +37,7 @@ test-suite test
base,
bytestring,
ghc-typelits-knownnat,
+ ghc-typelits-natnormalise,
hedgehog,
orthotope,
random,
diff --git a/test/Main.hs b/test/Main.hs
index dd59586..b5237e5 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -7,6 +7,7 @@
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Main where
@@ -34,7 +35,7 @@ import qualified System.Random as Random
import Test.Tasty
import Test.Tasty.Hedgehog
-import Debug.Trace
+-- import Debug.Trace
-- Returns highest value that satisfies the predicate, or `lo` if none does
@@ -109,6 +110,10 @@ orSumOuter1 (sn@SNat :: SNat n) =
let n = fromSNat' sn
in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0])
+rshTail :: ShR (n + 1) i -> ShR n i
+rshTail (_ :$: sh) = sh
+rshTail ZSR = error "unreachable"
+
main :: IO ()
main = defaultMain $
testGroup "Tests"
@@ -119,7 +124,7 @@ main = defaultMain $
let inrank = SNat @(n + 1)
sh <- forAll $ genShR inrank
-- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
- guard (all (> 0) (tail (toList sh))) -- only constrain the tail
+ guard (all (> 0) (toList (rshTail sh))) -- only constrain the tail
arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
genStorables (Range.singleton (product sh))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
@@ -141,7 +146,7 @@ main = defaultMain $
sht <- shuffleShR (0 :$: shtt) -- n
n <- Gen.int (Range.linear 0 20)
return (n :$: sht) -- n + 1
- guard (any (== 0) (tail (toList sh)))
+ guard (any (== 0) (toList (rshTail sh)))
-- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
let arr = OR.fromList @Double @(n + 1) (toList sh) []
let rarr = rfromOrthotope inrank arr