aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs33
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs28
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs26
3 files changed, 35 insertions, 52 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index ebf0a07..fd8c4ce 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -328,6 +328,23 @@ ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared
outrange sh i = error $ "ixxFromLinear: out of range (" ++ show i ++
" in array of shape " ++ show sh ++ ")"
+shxEnum :: IShX sh -> [IIxX sh]
+shxEnum = shxEnum'
+
+{-# INLINABLE shxEnum' #-} -- ensure this can be specialised at use site
+shxEnum' :: Num i => IShX sh -> [IxX sh i]
+shxEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shxSize sh - 1]]
+ where
+ suffixes = drop 1 (scanr (*) 1 (shxToList sh))
+
+ fromLin :: Num i => IShX sh -> [Int] -> Int# -> IxX sh i
+ fromLin ZSX _ _ = ZIX
+ fromLin (_ :$% sh') (I# suff# : suffs) i# =
+ let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh'
+ in fromIntegral (I# q#) :.% fromLin sh' suffs r#
+ fromLin _ _ _ = error "impossible"
+
+
-- * Mixed shape-like lists to be used for ShX and StaticShX
data SMayNat i n where
@@ -648,22 +665,6 @@ shxSplitApp :: proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX
shxSplitApp _ ZKX idx = (ZSX, idx)
shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx)
-shxEnum :: IShX sh -> [IIxX sh]
-shxEnum = shxEnum'
-
-{-# INLINABLE shxEnum' #-} -- ensure this can be specialised at use site
-shxEnum' :: Num i => IShX sh -> [IxX sh i]
-shxEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shxSize sh - 1]]
- where
- suffixes = drop 1 (scanr (*) 1 (shxToList sh))
-
- fromLin :: Num i => IShX sh -> [Int] -> Int# -> IxX sh i
- fromLin ZSX _ _ = ZIX
- fromLin (_ :$% sh') (I# suff# : suffs) i# =
- let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh'
- in fromIntegral (I# q#) :.% fromLin sh' suffs r#
- fromLin _ _ _ = error "impossible"
-
shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh')
shxCast ZKX ZSX = Just ZSX
shxCast (SKnown m :!% ssh) (SKnown n :$% sh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast ssh sh
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index 59289fb..0ac980e 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -1,4 +1,3 @@
-{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
@@ -34,7 +33,7 @@ import Data.Foldable qualified as Foldable
import Data.Kind (Type)
import Data.Proxy
import Data.Type.Equality
-import GHC.Exts (Int(..), Int#, build, quotRemInt#)
+import GHC.Exts (build)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
@@ -309,6 +308,15 @@ ixrFromLinear (ShR sh) i
ixrFromIxX :: IxX sh i -> IxR (Rank sh) i
ixrFromIxX = unsafeCoerce
+shrEnum :: IShR n -> [IIxR n]
+shrEnum = shrEnum'
+
+{-# INLINABLE shrEnum' #-} -- ensure this can be specialised at use site
+shrEnum' :: forall i n. Num i => IShR n -> [IxR n i]
+shrEnum' (ShR sh)
+ | Refl <- lemRankReplicate (Proxy @n)
+ = (unsafeCoerce :: [IxX (Replicate n Nothing) i] -> [IxR n i]) $ shxEnum' sh
+
-- * Ranked shapes
@@ -472,22 +480,6 @@ shrPermutePrefix = \perm sh ->
EQI -> shrIndex si l :$: applyPermRFull sm perm l
GTI -> error "shrPermutePrefix: Index in permutation out of range"
-shrEnum :: IShR sh -> [IIxR sh]
-shrEnum = shrEnum'
-
-{-# INLINABLE shrEnum' #-} -- ensure this can be specialised at use site
-shrEnum' :: Num i => IShR sh -> [IxR sh i]
-shrEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shrSize sh - 1]]
- where
- suffixes = drop 1 (scanr (*) 1 (shrToList sh))
-
- fromLin :: Num i => IShR sh -> [Int] -> Int# -> IxR sh i
- fromLin ZSR _ _ = ZIR
- fromLin (_ :$: sh') (I# suff# : suffs) i# =
- let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh'
- in fromIntegral (I# q#) :.: fromLin sh' suffs r#
- fromLin _ _ _ = error "impossible"
-
-- | Untyped: length is checked at runtime.
instance KnownNat n => IsList (ListR n i) where
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index f57e7dd..39be729 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -1,4 +1,3 @@
-{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
@@ -37,7 +36,7 @@ import Data.Kind (Constraint, Type)
import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Equality
-import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict)
+import GHC.Exts (build, withDict)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
@@ -325,6 +324,13 @@ ixsFromLinear (ShS sh) i = ixsFromIxX $ ixxFromLinear sh i
ixsFromIxX :: IxX (MapJust sh) i -> IxS sh i
ixsFromIxX = unsafeCoerce
+shsEnum :: ShS sh -> [IIxS sh]
+shsEnum = shsEnum'
+
+{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site
+shsEnum' :: Num i => ShS sh -> [IxS sh i]
+shsEnum' (ShS sh) = (unsafeCoerce :: [IxX (MapJust sh) i] -> [IxS sh i]) $ shxEnum' sh
+
-- * Shaped shapes
@@ -506,22 +512,6 @@ shsOrthotopeShape :: ShS sh -> Dict O.Shape sh
shsOrthotopeShape ZSS = Dict
shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict
-shsEnum :: ShS sh -> [IIxS sh]
-shsEnum = shsEnum'
-
-{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site
-shsEnum' :: Num i => ShS sh -> [IxS sh i]
-shsEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shsSize sh - 1]]
- where
- suffixes = drop 1 (scanr (*) 1 (shsToList sh))
-
- fromLin :: Num i => ShS sh -> [Int] -> Int# -> IxS sh i
- fromLin ZSS _ _ = ZIS
- fromLin (_ :$$ sh') (I# suff# : suffs) i# =
- let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shsSize sh'
- in fromIntegral (I# q#) :.$ fromLin sh' suffs r#
- fromLin _ _ _ = error "impossible"
-
-- | Untyped: length is checked at runtime.
instance KnownShS sh => IsList (ListS sh (Const i)) where