aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Mixed/Shape
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-12-02 15:03:10 +0100
committerTom Smeding <tom@tomsmeding.com>2025-12-02 15:03:43 +0100
commitb63642a41f3bddc991d92f2f59b9e3ad53c1f15e (patch)
treea63b978be8baab76c7aa7a99b13a93b408bfc913 /src/Data/Array/Nested/Mixed/Shape
parentaf0c099079dae7aa52a660b883204035cbed99c3 (diff)
Provide ix*FromLinear for all three shape kinds
This speeds up {r,s}generatePrim
Diffstat (limited to 'src/Data/Array/Nested/Mixed/Shape')
-rw-r--r--src/Data/Array/Nested/Mixed/Shape/Internal.hs50
1 files changed, 50 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape/Internal.hs b/src/Data/Array/Nested/Mixed/Shape/Internal.hs
new file mode 100644
index 0000000..cf44522
--- /dev/null
+++ b/src/Data/Array/Nested/Mixed/Shape/Internal.hs
@@ -0,0 +1,50 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TemplateHaskell #-}
+module Data.Array.Nested.Mixed.Shape.Internal where
+
+import Language.Haskell.TH
+
+
+-- | A TH stub function to avoid having to write the same code three times for
+-- the three kinds of shapes.
+ixFromLinearStub :: String
+ -> TypeQ -> TypeQ
+ -> PatQ -> (PatQ -> PatQ -> PatQ)
+ -> ExpQ -> ExpQ
+ -> ExpQ
+ -> DecsQ
+ixFromLinearStub fname' ishty ixty zshC consshC ixz ixcons shtolist = do
+ let fname = mkName fname'
+ typesig <- [t| forall i sh. Num i => $ishty sh -> Int -> $ixty sh i |]
+
+ locals <- [d|
+ fromLin :: Num i => $ishty sh -> [Int] -> Int -> $ixty sh i
+ fromLin $zshC _ !_ = $ixz
+ fromLin ($(consshC wildP (varP (mkName "sh'")))) (suff : suffs) i =
+ let (q, r) = i `quotRem` suff -- suff == shrSize sh'
+ in $ixcons (fromIntegral q) (fromLin sh' suffs r)
+ fromLin _ _ _ = error "impossible"
+
+ {-# NOINLINE outrange #-}
+ outrange :: $ishty sh -> Int -> a
+ outrange sh i = error $ fname' ++ ": out of range (" ++ show i ++
+ " in array of shape " ++ show sh ++ ")" |]
+
+ body <- [|
+ \sh -> -- give this function arity 1 so that 'suffixes' is shared when
+ -- it's called many times
+ let suffixes = drop 1 (scanr (*) 1 ($shtolist sh))
+ in \i ->
+ if i < 0 then outrange sh i else
+ case (sh, suffixes) of -- unfold first iteration of fromLin to do the range check
+ ($zshC, _) | i > 0 -> outrange sh i
+ | otherwise -> $ixz
+ ($(consshC (varP (mkName "n")) (varP (mkName "sh'"))), suff : suffs) ->
+ let (q, r) = i `quotRem` suff
+ in if q >= n then outrange sh i else
+ $ixcons (fromIntegral q) (fromLin sh' suffs r)
+ _ -> error "impossible" |]
+
+ return [SigD fname typesig
+ ,FunD fname [Clause [] (NormalB body) locals]]