aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Nested/Convert.hs7
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs41
-rw-r--r--src/Data/Array/Nested/Mixed/Shape/Internal.hs59
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs14
-rw-r--r--src/Data/Array/Nested/Shaped.hs6
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs12
6 files changed, 58 insertions, 81 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index 32248c4..91752c4 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -63,8 +63,7 @@ import Data.Array.Nested.Types
ixrFromIxS :: IxS sh i -> IxR (Rank sh) i
ixrFromIxS = unsafeCoerce
-ixrFromIxX :: IxX sh i -> IxR (Rank sh) i
-ixrFromIxX = unsafeCoerce
+-- ixrFromIxX re-exported
shrFromShS :: ShS sh -> IShR (Rank sh)
shrFromShS ZSS = ZSR
@@ -97,9 +96,7 @@ ixsFromIxR' ZSS ZIR = ZIS
ixsFromIxR' (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR' sh idx
ixsFromIxR' _ _ = error "ixsFromIxR': index rank does not match shape rank"
--- TODO: remove the ShS now that no KnownNats is inside IxS.
-ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i
-ixsFromIxX _ = unsafeCoerce
+-- ixsFromIxX re-exported
-- TODO: if possible, remove the ShS now that no KnownNats is inside IxS.
-- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index 5ffd40c..ebf0a07 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -46,7 +46,6 @@ import GHC.TypeLits
import GHC.TypeLits.Orphans ()
#endif
-import Data.Array.Nested.Mixed.Shape.Internal
import Data.Array.Nested.Types
@@ -293,6 +292,41 @@ ixxToLinear = \sh i -> go sh i 0
go ZSX ZIX a = a
go (n :$% sh) (i :.% ix) a = go sh ix (fromIntegral (fromSMayNat' n) * a + i)
+{-# INLINEABLE ixxFromLinear #-}
+ixxFromLinear :: Num i => IShX sh -> Int -> IxX sh i
+ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared when it's called many times
+ let suffixes = drop 1 (scanr (*) 1 (shxToList sh))
+ in fromLin0 sh suffixes
+ where
+ -- Unfold first iteration of fromLin to do the range check.
+ -- Don't inline this function at first to allow GHC to inline the outer
+ -- function and realise that 'suffixes' is shared. But then later inline it
+ -- anyway, to avoid the function call. Removing the pragma makes GHC
+ -- somehow unable to recognise that 'suffixes' can be shared in a loop.
+ {-# NOINLINE [0] fromLin0 #-}
+ fromLin0 :: Num i => IShX sh -> [Int] -> Int -> IxX sh i
+ fromLin0 sh suffixes i =
+ if i < 0 then outrange sh i else
+ case (sh, suffixes) of
+ (ZSX, _) | i > 0 -> outrange sh i
+ | otherwise -> ZIX
+ ((fromSMayNat' -> n) :$% sh', suff : suffs) ->
+ let (q, r) = i `quotRem` suff
+ in if q >= n then outrange sh i else
+ fromIntegral q :.% fromLin sh' suffs r
+ _ -> error "impossible"
+
+ fromLin :: Num i => IShX sh -> [Int] -> Int -> IxX sh i
+ fromLin ZSX _ !_ = ZIX
+ fromLin (_ :$% sh') (suff : suffs) i =
+ let (q, r) = i `quotRem` suff -- suff == shrSize sh'
+ in fromIntegral q :.% fromLin sh' suffs r
+ fromLin _ _ _ = error "impossible"
+
+ {-# NOINLINE outrange #-}
+ outrange :: IShX sh -> Int -> a
+ outrange sh i = error $ "ixxFromLinear: out of range (" ++ show i ++
+ " in array of shape " ++ show sh ++ ")"
-- * Mixed shape-like lists to be used for ShX and StaticShX
@@ -798,8 +832,3 @@ instance KnownShX sh => IsList (IShX sh) where
type Item (IShX sh) = Int
fromList = shxFromList (knownShX @sh)
toList = shxToList
-
--- This needs to be at the bottom of the file to not split the file into
--- pieces; some of the shape/index stuff refers to StaticShX.
-$(ixFromLinearStub "ixxFromLinear" [t| IShX |] [t| IxX |] [p| ZSX |] (\a b -> [p| (fromSMayNat' -> $a) :$% $b |]) [| ZIX |] [| (:.%) |] [| shxToList |])
-{-# INLINEABLE ixxFromLinear #-}
diff --git a/src/Data/Array/Nested/Mixed/Shape/Internal.hs b/src/Data/Array/Nested/Mixed/Shape/Internal.hs
deleted file mode 100644
index 2a86ac1..0000000
--- a/src/Data/Array/Nested/Mixed/Shape/Internal.hs
+++ /dev/null
@@ -1,59 +0,0 @@
-{-# LANGUAGE BangPatterns #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TemplateHaskell #-}
-module Data.Array.Nested.Mixed.Shape.Internal where
-
-import Language.Haskell.TH
-
-
--- | A TH stub function to avoid having to write the same code three times for
--- the three kinds of shapes.
-ixFromLinearStub :: String
- -> TypeQ -> TypeQ
- -> PatQ -> (PatQ -> PatQ -> PatQ)
- -> ExpQ -> ExpQ
- -> ExpQ
- -> DecsQ
-ixFromLinearStub fname' ishty ixty zshC consshC ixz ixcons shtolist = do
- let fname = mkName fname'
- typesig <- [t| forall i sh. Num i => $ishty sh -> Int -> $ixty sh i |]
-
- locals <- [d|
- -- Unfold first iteration of fromLin to do the range check.
- -- Don't inline this function at first to allow GHC to inline the outer
- -- function and realise that 'suffixes' is shared. But then later inline it
- -- anyway, to avoid the function call. Removing the pragma makes GHC
- -- somehow unable to recognise that 'suffixes' can be shared in a loop.
- {-# NOINLINE [0] fromLin0 #-}
- fromLin0 :: Num i => $ishty sh -> [Int] -> Int -> $ixty sh i
- fromLin0 sh suffixes i =
- if i < 0 then outrange sh i else
- case (sh, suffixes) of
- ($zshC, _) | i > 0 -> outrange sh i
- | otherwise -> $ixz
- ($(consshC (varP (mkName "n")) (varP (mkName "sh'"))), suff : suffs) ->
- let (q, r) = i `quotRem` suff
- in if q >= n then outrange sh i else
- $ixcons (fromIntegral q) (fromLin sh' suffs r)
- _ -> error "impossible"
-
- fromLin :: Num i => $ishty sh -> [Int] -> Int -> $ixty sh i
- fromLin $zshC _ !_ = $ixz
- fromLin ($(consshC wildP (varP (mkName "sh'")))) (suff : suffs) i =
- let (q, r) = i `quotRem` suff -- suff == shrSize sh'
- in $ixcons (fromIntegral q) (fromLin sh' suffs r)
- fromLin _ _ _ = error "impossible"
-
- {-# NOINLINE outrange #-}
- outrange :: $ishty sh -> Int -> a
- outrange sh i = error $ fname' ++ ": out of range (" ++ show i ++
- " in array of shape " ++ show sh ++ ")" |]
-
- body <- [|
- \sh -> -- give this function arity 1 so that 'suffixes' is shared when
- -- it's called many times
- let suffixes = drop 1 (scanr (*) 1 ($shtolist sh))
- in fromLin0 sh suffixes |]
-
- return [SigD fname typesig
- ,FunD fname [Clause [] (NormalB body) locals]]
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index 6ce0f4f..59289fb 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -39,10 +39,10 @@ import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
import GHC.TypeNats qualified as TN
+import Unsafe.Coerce (unsafeCoerce)
import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed.Shape
-import Data.Array.Nested.Mixed.Shape.Internal
import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
@@ -300,6 +300,15 @@ ixrToLinear = \sh i -> go sh i 0
go ZSR ZIR a = a
go (n :$: sh) (i :.: ix) a = go sh ix (fromIntegral n * a + i)
+{-# INLINEABLE ixrFromLinear #-}
+ixrFromLinear :: forall i m. Num i => IShR m -> Int -> IxR m i
+ixrFromLinear (ShR sh) i
+ | Refl <- lemRankReplicate (Proxy @m)
+ = ixrFromIxX $ ixxFromLinear sh i
+
+ixrFromIxX :: IxX sh i -> IxR (Rank sh) i
+ixrFromIxX = unsafeCoerce
+
-- * Ranked shapes
@@ -505,6 +514,3 @@ listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i
listrCastWithName _ SZ ZR = ZR
listrCastWithName name (SS n) (i ::: l) = i ::: listrCastWithName name n l
listrCastWithName name _ _ = error $ name ++ ": ranks don't match"
-
-$(ixFromLinearStub "ixrFromLinear" [t| IShR |] [t| IxR |] [p| ZSR |] (\a b -> [p| $a :$: $b |]) [| ZIR |] [| (:.:) |] [| shrToList |])
-{-# INLINEABLE ixrFromLinear #-}
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 85042f2..23a4fc8 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -70,7 +70,7 @@ sindexPartial sarr@(Shaped arr) idx =
-- | __WARNING__: All values returned from the function must have equal shape.
-- See the documentation of 'mgenerate' for more details.
sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
-sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh))
+sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX))
-- | See 'mgeneratePrim'.
{-# INLINE sgeneratePrim #-}
@@ -253,11 +253,11 @@ siota sn = Shaped (miota sn)
-- | Throws if the array is empty.
sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
-sminIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mminIndexPrim arr)
+sminIndexPrim (Shaped arr) = ixsFromIxX (mminIndexPrim arr)
-- | Throws if the array is empty.
smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
-smaxIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr)
+smaxIndexPrim (Shaped arr) = ixsFromIxX (mmaxIndexPrim arr)
sdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
=> Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index 521ec2f..f57e7dd 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -41,9 +41,9 @@ import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
+import Unsafe.Coerce (unsafeCoerce)
import Data.Array.Nested.Mixed.Shape
-import Data.Array.Nested.Mixed.Shape.Internal
import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
@@ -318,6 +318,13 @@ ixsToLinear = \sh i -> go sh i 0
go ZSS ZIS a = a
go (n :$$ sh) (i :.$ ix) a = go sh ix (fromIntegral (fromSNat' n) * a + i)
+{-# INLINEABLE ixsFromLinear #-}
+ixsFromLinear :: Num i => ShS sh -> Int -> IxS sh i
+ixsFromLinear (ShS sh) i = ixsFromIxX $ ixxFromLinear sh i
+
+ixsFromIxX :: IxX (MapJust sh) i -> IxS sh i
+ixsFromIxX = unsafeCoerce
+
-- * Shaped shapes
@@ -533,6 +540,3 @@ instance KnownShS sh => IsList (ShS sh) where
type Item (ShS sh) = Int
fromList = shsFromList (knownShS @sh)
toList = shsToList
-
-$(ixFromLinearStub "ixsFromLinear" [t| ShS |] [t| IxS |] [p| ZSS |] (\a b -> [p| (fromSNat' -> $a) :$$ $b |]) [| ZIS |] [| (:.$) |] [| shsToList |])
-{-# INLINEABLE ixsFromLinear #-}