aboutsummaryrefslogtreecommitdiff
path: root/test/Util.hs
diff options
context:
space:
mode:
Diffstat (limited to 'test/Util.hs')
-rw-r--r--test/Util.hs38
1 files changed, 38 insertions, 0 deletions
diff --git a/test/Util.hs b/test/Util.hs
new file mode 100644
index 0000000..1249bf9
--- /dev/null
+++ b/test/Util.hs
@@ -0,0 +1,38 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Util where
+
+import qualified Data.Array.RankedS as OR
+import GHC.TypeLits
+
+import Data.Array.Mixed (fromSNat')
+import Data.Array.Nested
+
+
+-- Returns highest value that satisfies the predicate, or `lo` if none does
+binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a
+binarySearch div2 = \lo hi f -> case (f lo, f hi) of
+ (False, _) -> lo
+ (_, True) -> hi
+ (_, _ ) -> go lo hi f
+ where
+ go lo hi f = -- invariant: f lo && not (f hi)
+ let mid = lo + div2 (hi - lo)
+ in if mid `elem` [lo, hi]
+ then mid
+ else if f mid then go mid hi f else go lo mid f
+
+orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a
+orSumOuter1 (sn@SNat :: SNat n) =
+ let n = fromSNat' sn
+ in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0])
+
+rshTail :: ShR (n + 1) i -> ShR n i
+rshTail (_ :$: sh) = sh
+rshTail ZSR = error "unreachable"