aboutsummaryrefslogtreecommitdiff
path: root/test/Util.hs
blob: 9afa922e9a747733b3882afa3c12c29474575036 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Util where

import qualified Data.Array.RankedS as OR
import GHC.TypeLits

import Data.Array.Mixed.Types (fromSNat')
import Data.Array.Nested


-- | Binary search for the boundary of a predicate over the range @[lo, hi]@.
--
-- The predicate @f@ is expected to be 'True' on a prefix of the range and
-- 'False' afterwards. Returns the highest value found that satisfies @f@,
-- or @lo@ if @f lo@ is already 'False'. If @f hi@ holds, @hi@ is returned
-- immediately.
--
-- The first argument halves a distance, e.g. @(`div` 2)@ for integral types
-- or @(/ 2)@ for fractional types. The search stops once the computed
-- midpoint no longer lies strictly between the current end points.
binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a
binarySearch div2 = \lo hi f ->
  let -- Invariant on entry: f l && not (f h).
      go l h
        -- No midpoint strictly between l and h could be produced. l is the
        -- last value known to satisfy f (f h is known to be False), so it is
        -- the answer. Returning the midpoint itself would be wrong when it
        -- coincides with h (rounding-up div2, floating-point rounding, ...).
        | mid == l || mid == h = l
        | f mid = go mid h
        | otherwise = go l mid
        where
          mid = l + div2 (h - l)
  in case (f lo, f hi) of
       (False, _) -> lo
       (_, True) -> hi
       _ -> go lo hi

-- | Sum an orthotope array of rank @n + 1@ along its outermost axis, giving
-- an array of rank @n@.
--
-- Axis 0 is first rotated to the innermost position. Each rank-1 innermost
-- sub-array is then reduced to a scalar holding its sum.
orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a
orSumOuter1 (sn@SNat :: SNat n) arr =
  OR.rerank @n @1 @0 (OR.scalar . OR.sumA) (OR.transpose perm arr)
  where
    -- Axis permutation [1, 2, .., n, 0]: the outermost axis becomes the last.
    perm = [1 .. fromSNat' sn] ++ [0]

-- | Drop the first component of a ranked shape, lowering its rank by one.
--
-- The input has rank @n + 1@, so it always has at least one component. The
-- 'ZSR' branch exists only to make the match exhaustive.
rshTail :: ShR (n + 1) i -> ShR n i
rshTail shape = case shape of
  _ :$: rest -> rest
  ZSR -> error "unreachable"