aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Mixed/Shape/Internal.hs
blob: cf44522eaa029fadbc91f24b142190d514dcab16 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
module Data.Array.Nested.Mixed.Shape.Internal where

import Language.Haskell.TH


-- | A TH stub function to avoid having to write the same code three times for
-- the three kinds of shapes.
--
-- The returned declarations define a function
-- @fname :: forall i sh. Num i => ishty sh -> Int -> ixty sh i@ that converts a
-- linear index into a multi-dimensional index. The first (outermost) index
-- component varies slowest.
--
-- The generated function has arity 1. Partially applying it to a shape
-- therefore shares the precomputed suffix products across all later calls.
--
-- Out-of-range indices produce an 'error' whose message starts with the
-- function name. An index is out of range when it is negative or not smaller
-- than the number of elements of the shape. This covers every index into a
-- shape with zero elements, including when only an inner dimension is zero.
ixFromLinearStub :: String                  -- ^ Name of the function to generate;
                                            -- also prefixes its error message.
                 -> TypeQ                   -- ^ Shape type constructor, applied to @sh@.
                 -> TypeQ                   -- ^ Index type constructor, applied to
                                            -- @sh@ and @i@.
                 -> PatQ                    -- ^ Pattern matching the empty shape.
                 -> (PatQ -> PatQ -> PatQ)  -- ^ Builds a pattern for a non-empty
                                            -- shape from a pattern for its
                                            -- outermost size and a pattern for
                                            -- the remaining shape.
                 -> ExpQ                    -- ^ The empty index.
                 -> ExpQ                    -- ^ Index cons: applied to the outermost
                                            -- component and the remaining index.
                 -> ExpQ                    -- ^ Converts a shape to the list of its
                                            -- sizes, outermost first.
                 -> DecsQ                   -- ^ A type signature plus a definition.
ixFromLinearStub fname' ishty ixty zshC consshC ixz ixcons shtolist = do
  let fname = mkName fname'
  typesig <- [t| forall i sh. Num i => $ishty sh -> Int -> $ixty sh i |]

  locals <- [d|
      -- Peels off one dimension per step. No range check happens here: the
      -- main body checks the outermost component before the first call, so
      -- every recursive call receives an in-range index.
      fromLin :: Num i => $ishty sh -> [Int] -> Int -> $ixty sh i
      fromLin $zshC _ !_ = $ixz
      fromLin ($(consshC wildP (varP (mkName "sh'")))) (suff : suffs) i =
        let (q, r) = i `quotRem` suff  -- suff == shrSize sh'
        in $ixcons (fromIntegral q) (fromLin sh' suffs r)
      fromLin _ _ _ = error "impossible"

      -- NOINLINE presumably keeps the error-message construction out of the
      -- hot path of the generated function.
      {-# NOINLINE outrange #-}
      outrange :: $ishty sh -> Int -> a
      outrange sh i = error $ fname' ++ ": out of range (" ++ show i ++
                              " in array of shape " ++ show sh ++ ")" |]

  body <- [|
    \sh ->  -- give this function arity 1 so that 'suffixes' is shared when
            -- it's called many times
      -- Element k of 'suffixes' is the product of all sizes after position k.
      let suffixes = drop 1 (scanr (*) 1 ($shtolist sh))
      in \i ->
          if i < 0 then outrange sh i else
          case (sh, suffixes) of  -- unfold first iteration of fromLin to do the range check
            ($zshC, _) | i > 0 -> outrange sh i
                       | otherwise -> $ixz
            ($(consshC (varP (mkName "n")) (varP (mkName "sh'"))), suff : suffs)
              -- Some inner dimension is zero, so the array has no elements and
              -- every index is out of range. Without this guard, 'quotRem'
              -- below would throw a division-by-zero exception instead.
              | suff == 0 -> outrange sh i
              | otherwise ->
                  let (q, r) = i `quotRem` suff
                  in if q >= n then outrange sh i else
                     $ixcons (fromIntegral q) (fromLin sh' suffs r)
            _ -> error "impossible" |]

  return [SigD fname typesig
         ,FunD fname [Clause [] (NormalB body) locals]]