aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Mixed/Shape/Internal.hs
blob: 9997b0f9dcd981d982499f011618790a763fde5d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
module Data.Array.Nested.Mixed.Shape.Internal where

import Language.Haskell.TH


-- | A TH stub function to avoid having to write the same code three times for
-- the three kinds of shapes.
--
-- @ixFromLinearStub fname' ishty ixty zshC consshC ixz ixcons shtolist@
-- generates a top-level function named @fname'@ with type
--
-- > forall i sh. Num i => ishty sh -> Int -> ixty sh i
--
-- The generated function converts a linear index into a multi-dimensional
-- index for the given shape, in row-major order (the last dimension varies
-- fastest). A linear index that is negative, or not smaller than the total
-- number of elements, raises an 'error' whose message starts with @fname'@.
--
-- Arguments:
--
-- * @fname'@: name of the generated function; also used in error messages.
-- * @ishty@: shape type constructor, applied to the type-level shape @sh@.
-- * @ixty@: index type constructor, applied to @sh@ and the component type @i@.
-- * @zshC@: pattern matching the empty shape.
-- * @consshC@: given a pattern for the first dimension's size and a pattern
--   for the rest of the shape, builds a pattern matching a non-empty shape.
--   The size it binds is compared against an 'Int', so it must be an 'Int'.
-- * @ixz@: expression for the empty index.
-- * @ixcons@: expression that, applied to a component and an index, prepends
--   the component to the index.
-- * @shtolist@: expression converting a shape to its list of dimension sizes.
ixFromLinearStub :: String
                 -> TypeQ -> TypeQ
                 -> PatQ -> (PatQ -> PatQ -> PatQ)
                 -> ExpQ -> ExpQ
                 -> ExpQ
                 -> DecsQ
ixFromLinearStub fname' ishty ixty zshC consshC ixz ixcons shtolist = do
  let fname = mkName fname'
  typesig <- [t| forall i sh. Num i => $ishty sh -> Int -> $ixty sh i |]

  locals <- [d|
    -- Unfold first iteration of fromLin to do the range check.
    -- Don't inline because if this is inlined, GHC seems to stop sharing
    -- 'suffixes' over multiple calls, which breaks performance in sh*Enum.
    {-# NOINLINE fromLin0 #-}
    fromLin0 :: Num i => $ishty sh -> [Int] -> Int -> $ixty sh i
    fromLin0 sh suffixes i =
        if i < 0 then outrange sh i else
        case (sh, suffixes) of
          ($zshC, _) | i > 0 -> outrange sh i
                     | otherwise -> $ixz
          ($(consshC (varP (mkName "n")) (varP (mkName "sh'"))), suff : suffs)
            -- 'suff' is the product of all dimensions after the first. If it
            -- is 0, some later dimension is 0, so the array is empty and no
            -- index is in range; without this check the 'quotRem' below
            -- would throw a divide-by-zero instead of the out-of-range error.
            | suff == 0 -> outrange sh i
            | otherwise ->
                let (q, r) = i `quotRem` suff
                in if q >= n then outrange sh i else
                   $ixcons (fromIntegral q) (fromLin sh' suffs r)
          _ -> error "impossible"

    -- Unchecked conversion for the remaining dimensions. fromLin0 has already
    -- range-checked the index and rejected empty arrays, so every 'suff'
    -- reaching this point is non-zero.
    fromLin :: Num i => $ishty sh -> [Int] -> Int -> $ixty sh i
    fromLin $zshC _ !_ = $ixz
    fromLin ($(consshC wildP (varP (mkName "sh'")))) (suff : suffs) i =
      let (q, r) = i `quotRem` suff  -- suff == shrSize sh'
      in $ixcons (fromIntegral q) (fromLin sh' suffs r)
    fromLin _ _ _ = error "impossible"

    -- Kept out of line so the error-message construction does not bloat the
    -- callers.
    {-# NOINLINE outrange #-}
    outrange :: $ishty sh -> Int -> a
    outrange sh i = error $ fname' ++ ": out of range (" ++ show i ++
                            " in array of shape " ++ show sh ++ ")" |]

  body <- [|
    \sh ->  -- give this function arity 1 so that 'suffixes' is shared when
            -- it's called many times
      -- suffixes !! k is the product of all dimensions after dimension k,
      -- i.e. the row-major stride of dimension k.
      let suffixes = drop 1 (scanr (*) 1 ($shtolist sh))
      in fromLin0 sh suffixes |]

  return [SigD fname typesig
         ,FunD fname [Clause [] (NormalB body) locals]]