diff options
Diffstat (limited to 'test/Util.hs')
-rw-r--r-- | test/Util.hs | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/test/Util.hs b/test/Util.hs new file mode 100644 index 0000000..1249bf9 --- /dev/null +++ b/test/Util.hs @@ -0,0 +1,38 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Util where + +import qualified Data.Array.RankedS as OR +import GHC.TypeLits + +import Data.Array.Mixed (fromSNat') +import Data.Array.Nested + + +-- Returns highest value that satisfies the predicate, or `lo` if none does +binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a +binarySearch div2 = \lo hi f -> case (f lo, f hi) of + (False, _) -> lo + (_, True) -> hi + (_, _ ) -> go lo hi f + where + go lo hi f = -- invariant: f lo && not (f hi) + let mid = lo + div2 (hi - lo) + in if mid `elem` [lo, hi] + then mid + else if f mid then go mid hi f else go lo mid f + +orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a +orSumOuter1 (sn@SNat :: SNat n) = + let n = fromSNat' sn + in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0]) + +rshTail :: ShR (n + 1) i -> ShR n i +rshTail (_ :$: sh) = sh +rshTail ZSR = error "unreachable" |