aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs34
-rw-r--r--src/Data/Array/Nested/Mixed/Shape/Internal.hs50
-rw-r--r--src/Data/Array/Nested/Ranked.hs11
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs5
-rw-r--r--src/Data/Array/Nested/Shaped.hs5
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs5
6 files changed, 75 insertions, 35 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index 8aa5a77..5a45a09 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -16,6 +16,7 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
@@ -44,6 +45,7 @@ import GHC.TypeLits
import GHC.TypeLits.Orphans ()
#endif
+import Data.Array.Nested.Mixed.Shape.Internal
import Data.Array.Nested.Types
@@ -276,33 +278,6 @@ ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k
ixxZipWith _ ZIX ZIX = ZIX
ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js
-{-# INLINEABLE ixxFromLinear #-}
-ixxFromLinear :: Num i => IShX sh -> Int -> IxX sh i
-ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared when it's called many times
- let suffixes = drop 1 (scanr (*) 1 (shxToList sh))
- in \i ->
- if i < 0 then outrange sh i else
- case (sh, suffixes) of -- unfold first iteration of fromLin to do the range check
- (ZSX, _) | i > 0 -> outrange sh i
- | otherwise -> ZIX
- (n :$% sh', suff : suffs) ->
- let (q, r) = i `quotRem` suff
- in if q >= fromSMayNat' n then outrange sh i else
- fromIntegral q :.% fromLin sh' suffs r
- _ -> error "impossible"
- where
- fromLin :: Num i => IShX sh -> [Int] -> Int -> IxX sh i
- fromLin ZSX _ !_ = ZIX
- fromLin (_ :$% sh') (suff : suffs) i =
- let (q, r) = i `quotRem` suff -- suff == shrSize sh'
- in fromIntegral q :.% fromLin sh' suffs r
- fromLin _ _ _ = error "impossible"
-
- {-# NOINLINE outrange #-}
- outrange :: IShX sh -> Int -> a
- outrange sh i = error $ "ixxFromLinear: out of range (" ++ show i ++
- " in array of shape " ++ show sh ++ ")"
-
ixxToLinear :: IShX sh -> IIxX sh -> Int
ixxToLinear = \sh i -> fst (go sh i)
where
@@ -684,3 +659,8 @@ instance KnownShX sh => IsList (ShX sh Int) where
type Item (ShX sh Int) = Int
fromList = shxFromList (knownShX @sh)
toList = shxToList
+
+-- This needs to be at the bottom of the file to not split the file into
+-- pieces; some of the shape/index stuff refers to StaticShX.
+$(ixFromLinearStub "ixxFromLinear" [t| IShX |] [t| IxX |] [p| ZSX |] (\a b -> [p| (fromSMayNat' -> $a) :$% $b |]) [| ZIX |] [| (:.%) |] [| shxToList |])
+{-# INLINEABLE ixxFromLinear #-}
diff --git a/src/Data/Array/Nested/Mixed/Shape/Internal.hs b/src/Data/Array/Nested/Mixed/Shape/Internal.hs
new file mode 100644
index 0000000..cf44522
--- /dev/null
+++ b/src/Data/Array/Nested/Mixed/Shape/Internal.hs
@@ -0,0 +1,50 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TemplateHaskell #-}
+module Data.Array.Nested.Mixed.Shape.Internal where
+
+import Language.Haskell.TH
+
+
+-- | A TH stub function to avoid having to write the same code three times for
+-- the three kinds of shapes.
+ixFromLinearStub :: String
+ -> TypeQ -> TypeQ
+ -> PatQ -> (PatQ -> PatQ -> PatQ)
+ -> ExpQ -> ExpQ
+ -> ExpQ
+ -> DecsQ
+ixFromLinearStub fname' ishty ixty zshC consshC ixz ixcons shtolist = do
+ let fname = mkName fname'
+ typesig <- [t| forall i sh. Num i => $ishty sh -> Int -> $ixty sh i |]
+
+ locals <- [d|
+ fromLin :: Num i => $ishty sh -> [Int] -> Int -> $ixty sh i
+ fromLin $zshC _ !_ = $ixz
+ fromLin ($(consshC wildP (varP (mkName "sh'")))) (suff : suffs) i =
+ let (q, r) = i `quotRem` suff -- suff == shrSize sh'
+ in $ixcons (fromIntegral q) (fromLin sh' suffs r)
+ fromLin _ _ _ = error "impossible"
+
+ {-# NOINLINE outrange #-}
+ outrange :: $ishty sh -> Int -> a
+ outrange sh i = error $ fname' ++ ": out of range (" ++ show i ++
+ " in array of shape " ++ show sh ++ ")" |]
+
+ body <- [|
+ \sh -> -- give this function arity 1 so that 'suffixes' is shared when
+ -- it's called many times
+ let suffixes = drop 1 (scanr (*) 1 ($shtolist sh))
+ in \i ->
+ if i < 0 then outrange sh i else
+ case (sh, suffixes) of -- unfold first iteration of fromLin to do the range check
+ ($zshC, _) | i > 0 -> outrange sh i
+ | otherwise -> $ixz
+ ($(consshC (varP (mkName "n")) (varP (mkName "sh'"))), suff : suffs) ->
+ let (q, r) = i `quotRem` suff
+ in if q >= n then outrange sh i else
+ $ixcons (fromIntegral q) (fromLin sh' suffs r)
+ _ -> error "impossible" |]
+
+ return [SigD fname typesig
+ ,FunD fname [Clause [] (NormalB body) locals]]
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index b77b529..d687983 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -70,16 +70,13 @@ rgenerate sh f
, Refl <- lemRankReplicate sn
= Ranked (mgenerate (shxFromShR sh) (f . ixrFromIxX))
--- TODO: this would be shorter and faster written with rfromVector,
--- but unfortunately we don't have ixrFromLinear
+-- | See 'mgeneratePrim'.
{-# INLINE rgeneratePrim #-}
rgeneratePrim :: forall n a i. (PrimElt a, Num i)
=> IShR n -> (IxR n i -> a) -> Ranked n a
-rgeneratePrim sh f
- | sn@SNat <- shrRank sh
- , Dict <- lemKnownReplicate sn
- , Refl <- lemRankReplicate sn
- = Ranked (mgeneratePrim (shxFromShR sh) (f . ixrFromIxX))
+rgeneratePrim sh f =
+ let g i = f (ixrFromLinear sh i)
+ in rfromVector sh $ VS.generate (shrSize sh) g
-- | See the documentation of 'mlift'.
rlift :: forall n1 n2 a. Elt a
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index 9815c42..02d65b6 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -19,6 +19,7 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
@@ -43,6 +44,7 @@ import GHC.TypeLits
import GHC.TypeNats qualified as TN
import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed.Shape.Internal
import Data.Array.Nested.Types
@@ -417,3 +419,6 @@ listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i
listrCastWithName _ SZ ZR = ZR
listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx
listrCastWithName name _ _ = error $ name ++ ": ranks don't match"
+
+$(ixFromLinearStub "ixrFromLinear" [t| IShR |] [t| IxR |] [p| ZSR |] (\a b -> [p| $a :$: $b |]) [| ZIR |] [| (:.:) |] [| shrToList |])
+{-# INLINEABLE ixrFromLinear #-}
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 075549d..99ad590 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -72,10 +72,13 @@ sindexPartial sarr@(Shaped arr) idx =
sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh))
+-- | See 'mgeneratePrim'.
{-# INLINE sgeneratePrim #-}
sgeneratePrim :: forall sh a i. (PrimElt a, Num i)
=> ShS sh -> (IxS sh i -> a) -> Shaped sh a
-sgeneratePrim sh f = Shaped (mgeneratePrim (shxFromShS sh) (f . ixsFromIxX sh))
+sgeneratePrim sh f =
+ let g i = f (ixsFromLinear sh i)
+ in sfromVector sh $ VS.generate (shsSize sh) g
-- | See the documentation of 'mlift'.
slift :: forall sh1 sh2 a. Elt a
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index 0a4c1b9..a237b88 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -17,6 +17,7 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
@@ -44,6 +45,7 @@ import GHC.IsList qualified as IsList
import GHC.TypeLits
import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Mixed.Shape.Internal
import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
@@ -465,3 +467,6 @@ instance KnownShS sh => IsList (ShS sh) where
type Item (ShS sh) = Int
fromList = shsFromList (knownShS @sh)
toList = shsToList
+
+$(ixFromLinearStub "ixsFromLinear" [t| ShS |] [t| IxS |] [p| ZSS |] (\a b -> [p| (fromSNat' -> $a) :$$ $b |]) [| ZIS |] [| (:.$) |] [| shsToList |])
+{-# INLINEABLE ixsFromLinear #-}