aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Ranked
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Ranked')
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs61
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs243
2 files changed, 198 insertions, 106 deletions
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
index 11a8ffb..beedbcf 100644
--- a/src/Data/Array/Nested/Ranked/Base.hs
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -26,16 +26,11 @@ import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
-import Data.Type.Equality
import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import GHC.TypeLits
-#ifndef OXAR_DEFAULT_SHOW_INSTANCES
-import Data.Foldable (toList)
-#endif
-
import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
@@ -65,7 +60,7 @@ deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a)
#ifndef OXAR_DEFAULT_SHOW_INSTANCES
instance (Show a, Elt a) => Show (Ranked n a) where
showsPrec d arr@(Ranked marr) =
- let sh = show (toList (rshape arr))
+ let sh = show (shrToList (rshape arr))
in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr
#endif
@@ -87,9 +82,12 @@ newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed
-- these instances allow them to also be used as elements of arrays, thus
-- making them first-class in the API.
instance Elt a => Elt (Ranked n a) where
+ {-# INLINE mshape #-}
mshape (M_Ranked arr) = mshape arr
+ {-# INLINE mindex #-}
mindex (M_Ranked arr) i = Ranked (mindex arr i)
+ {-# INLINE mindexPartial #-}
mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
mindexPartial (M_Ranked arr) i =
coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
@@ -104,6 +102,7 @@ instance Elt a => Elt (Ranked n a) where
mtoListOuter (M_Ranked arr) =
coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr)
+ {-# INLINE mlift #-}
mlift :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
@@ -112,6 +111,7 @@ instance Elt a => Elt (Ranked n a) where
coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
mlift ssh2 f arr
+ {-# INLINE mlift2 #-}
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
@@ -120,6 +120,7 @@ instance Elt a => Elt (Ranked n a) where
coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
mlift2 ssh3 f arr1 arr2
+ {-# INLINE mliftL #-}
mliftL :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
@@ -139,7 +140,7 @@ instance Elt a => Elt (Ranked n a) where
type ShapeTree (Ranked n a) = (IShR n, ShapeTree a)
- mshapeTree (Ranked arr) = first shrFromShX2 (mshapeTree arr)
+ mshapeTree (Ranked arr) = first coerce (mshapeTree arr)
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
@@ -149,18 +150,19 @@ instance Elt a => Elt (Ranked n a) where
marrayStrides (M_Ranked arr) = marrayStrides arr
- mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
- mvecsWrite sh idx (Ranked arr) vecs =
- mvecsWrite sh idx arr
+ mvecsWriteLinear :: forall sh s. Int -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
+ mvecsWriteLinear idx (Ranked arr) vecs =
+ mvecsWriteLinear idx arr
(coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
vecs)
- mvecsWritePartial :: forall sh sh' s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
- -> MixedVecs s (sh ++ sh') (Ranked n a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs =
- mvecsWritePartial sh idx
+ mvecsWritePartialLinear
+ :: forall sh sh' s.
+ Proxy sh -> Int -> Mixed sh' (Ranked n a)
+ -> MixedVecs s (sh ++ sh') (Ranked n a)
+ -> ST s ()
+ mvecsWritePartialLinear proxy idx arr vecs =
+ mvecsWritePartialLinear proxy idx
(coerce @(Mixed sh' (Ranked n a))
@(Mixed sh' (Mixed (Replicate n Nothing) a))
arr)
@@ -176,6 +178,14 @@ instance Elt a => Elt (Ranked n a) where
(coerce @(MixedVecs s sh (Ranked n a))
@(MixedVecs s sh (Mixed (Replicate n Nothing) a))
vecs)
+ mvecsUnsafeFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
+ mvecsUnsafeFreeze sh vecs =
+ coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
+ @(Mixed sh (Ranked n a))
+ <$> mvecsUnsafeFreeze sh
+ (coerce @(MixedVecs s sh (Ranked n a))
+ @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a)
@@ -188,6 +198,10 @@ instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
| Dict <- lemKnownReplicate (SNat @n)
= MV_Ranked <$> mvecsUnsafeNew idx arr
+ mvecsReplicate idx (Ranked arr)
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsReplicate idx arr
+
mvecsNewEmpty _
| Dict <- lemKnownReplicate (SNat @n)
= MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
@@ -249,20 +263,9 @@ ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
ratan2Array = liftRanked2 matan2Array
+{-# INLINE rshape #-}
rshape :: Elt a => Ranked n a -> IShR n
-rshape (Ranked arr) = shrFromShX2 (mshape arr)
+rshape (Ranked arr) = coerce (mshape arr)
rrank :: Elt a => Ranked n a -> SNat n
rrank = shrRank . rshape
-
--- Needed already here, but re-exported in Data.Array.Nested.Convert.
-shrFromShX :: forall sh. IShX sh -> IShR (Rank sh)
-shrFromShX ZSX = ZSR
-shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx
-
--- Needed already here, but re-exported in Data.Array.Nested.Convert.
--- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'.
-shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n
-shrFromShX2 sh
- | Refl <- lemRankReplicate (Proxy @n)
- = shrFromShX sh
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index 6d61bd5..6d47ade 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -1,8 +1,5 @@
-{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
@@ -36,15 +33,16 @@ import Data.Foldable qualified as Foldable
import Data.Kind (Type)
import Data.Proxy
import Data.Type.Equality
-import GHC.Exts (Int(..), Int#, quotRemInt#, build)
-import GHC.Generics (Generic)
+import GHC.Exts (build)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
import GHC.TypeNats qualified as TN
+import Unsafe.Coerce (unsafeCoerce)
import Data.Array.Nested.Lemmas
-import Data.Array.Nested.Mixed.Shape.Internal
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
@@ -183,7 +181,12 @@ listrZipWith f (i ::: irest) (j ::: jrest) =
listrZipWith _ _ _ =
error "listrZipWith: impossible pattern needlessly required"
-listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
+listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
+listrSplitAt SZ sh = (ZR, sh)
+listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
+listrSplitAt SS{} ZR = error "m' + 1 <= 0"
+
+listrPermutePrefix :: forall i n. PermR -> ListR n i -> ListR n i
listrPermutePrefix = \perm sh ->
TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
case listrRank sh of { shlen@SNat ->
@@ -195,11 +198,6 @@ listrPermutePrefix = \perm sh ->
++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
}
where
- listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
- listrSplitAt SZ sh = (ZR, sh)
- listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
- listrSplitAt SS{} ZR = error "m' + 1 <= 0"
-
applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
applyPermRFull _ ZR _ = ZR
applyPermRFull sm@SNat (i ::: perm) l =
@@ -216,8 +214,7 @@ listrPermutePrefix = \perm sh ->
type role IxR nominal representational
type IxR :: Nat -> Type -> Type
newtype IxR n i = IxR (ListR n i)
- deriving (Eq, Ord, Generic)
- deriving newtype (Functor, Foldable)
+ deriving (Eq, Ord, NFData, Functor, Foldable)
pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
pattern ZIR = IxR ZR
@@ -243,8 +240,6 @@ instance Show i => Show (IxR n i) where
showsPrec _ (IxR l) = listrShow shows l
#endif
-instance NFData i => NFData (IxR sh i)
-
ixrLength :: IxR sh i -> Int
ixrLength (IxR l) = listrLength l
@@ -255,12 +250,12 @@ ixrZero :: SNat n -> IIxR n
ixrZero SZ = ZIR
ixrZero (SS n) = 0 :.: ixrZero n
+{-# INLINEABLE ixrFromList #-}
ixrFromList :: forall n i. SNat n -> [i] -> IxR n i
ixrFromList = coerce (listrFromList @_ @i)
-{-# INLINEABLE ixrToList #-}
-ixrToList :: forall n i. IxR n i -> [i]
-ixrToList = coerce (listrToList @_ @i)
+ixrToList :: IxR n i -> [i]
+ixrToList = Foldable.toList
ixrHead :: IxR (n + 1) i -> i
ixrHead (IxR list) = listrHead list
@@ -288,27 +283,69 @@ ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2
ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
-ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
+ixrPermutePrefix :: forall n i. PermR -> IxR n i -> IxR n i
ixrPermutePrefix = coerce (listrPermutePrefix @i)
+-- | Given a multidimensional index, get the corresponding linear
+-- index into the buffer.
+{-# INLINEABLE ixrToLinear #-}
+ixrToLinear :: Num i => IShR m -> IxR m i -> i
+ixrToLinear (ShR sh) ix = ixxToLinear sh (ixxFromIxR ix)
+
+ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i
+ixxFromIxR = unsafeCoerce -- TODO: switch to coerce once newtypes overhauled
+
+{-# INLINEABLE ixrFromLinear #-}
+ixrFromLinear :: forall i m. Num i => IShR m -> Int -> IxR m i
+ixrFromLinear (ShR sh) i
+ | Refl <- lemRankReplicate (Proxy @m)
+ = ixrFromIxX $ ixxFromLinear sh i
+
+ixrFromIxX :: IxX sh i -> IxR (Rank sh) i
+ixrFromIxX = unsafeCoerce -- TODO: switch to coerce once newtypes overhauled
+
+shrEnum :: IShR n -> [IIxR n]
+shrEnum = shrEnum'
+
+{-# INLINABLE shrEnum' #-} -- ensure this can be specialised at use site
+shrEnum' :: forall i n. Num i => IShR n -> [IxR n i]
+shrEnum' (ShR sh)
+ | Refl <- lemRankReplicate (Proxy @n)
+ = (unsafeCoerce :: [IxX (Replicate n Nothing) i] -> [IxR n i]) $ shxEnum' sh
+ -- TODO: switch to coerce once newtypes overhauled
-- * Ranked shapes
type role ShR nominal representational
type ShR :: Nat -> Type -> Type
-newtype ShR n i = ShR (ListR n i)
- deriving (Eq, Ord, Generic)
- deriving newtype (Functor, Foldable)
+newtype ShR n i = ShR (ShX (Replicate n Nothing) i)
+ deriving (Eq, Ord, NFData, Functor)
pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
-pattern ZSR = ShR ZR
+pattern ZSR <- ShR (matchZSR @n -> Just Refl)
+ where ZSR = ShR ZSX
+
+matchZSR :: forall n i. ShX (Replicate n Nothing) i -> Maybe (n :~: 0)
+matchZSR ZSX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl
+matchZSR _ = Nothing
pattern (:$:)
:: forall {n1} {i}.
forall n. (n + 1 ~ n1)
=> i -> ShR n i -> ShR n1 i
-pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
- where i :$: ShR sh = ShR (i ::: sh)
+pattern i :$: shl <- (shrUncons -> Just (UnconsShRRes shl i))
+ where i :$: ShR shl | Refl <- lemReplicateSucc2 (Proxy @n1) Refl
+ = ShR (SUnknown i :$% shl)
+
+data UnconsShRRes i n1 =
+ forall n. (n + 1 ~ n1) => UnconsShRRes (ShR n i) i
+shrUncons :: forall n1 i. ShR n1 i -> Maybe (UnconsShRRes i n1)
+shrUncons (ShR (SUnknown x :$% (sh' :: ShX sh' i)))
+ | Refl <- lemReplicateCons (Proxy @sh') (Proxy @n1) Refl
+ , Refl <- lemReplicateCons2 (Proxy @sh') (Proxy @n1) Refl
+ = Just (UnconsShRRes (ShR sh') x)
+shrUncons (ShR _) = Nothing
+
infixr 3 :$:
{-# COMPLETE ZSR, (:$:) #-}
@@ -319,85 +356,140 @@ type IShR n = ShR n Int
deriving instance Show i => Show (ShR n i)
#else
instance Show i => Show (ShR n i) where
- showsPrec _ (ShR l) = listrShow shows l
+ showsPrec d (ShR l) = showsPrec d l
#endif
-instance NFData i => NFData (ShR sh i)
-
-- | This checks only whether the ranks are equal, not whether the actual
-- values are.
shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
-shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh'
+shrEqRank ZSR ZSR = Just Refl
+shrEqRank (_ :$: sh) (_ :$: sh')
+ | Just Refl <- shrEqRank sh sh'
+ = Just Refl
+shrEqRank _ _ = Nothing
-- | This compares the shapes for value equality.
shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n')
-shrEqual (ShR sh) (ShR sh') = listrEqual sh sh'
+shrEqual ZSR ZSR = Just Refl
+shrEqual (i :$: sh) (i' :$: sh')
+ | Just Refl <- shrEqual sh sh'
+ , i == i'
+ = Just Refl
+shrEqual _ _ = Nothing
shrLength :: ShR sh i -> Int
-shrLength (ShR l) = listrLength l
+shrLength (ShR l) = shxLength l
-- | This function can also be used to conjure up a 'KnownNat' dictionary;
-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern
-- synonym yields 'KnownNat' evidence.
-shrRank :: ShR n i -> SNat n
-shrRank (ShR sh) = listrRank sh
+shrRank :: forall n i. ShR n i -> SNat n
+shrRank (ShR sh) | Refl <- lemRankReplicate (Proxy @n) = shxRank sh
-- | The number of elements in an array described by this shape.
shrSize :: IShR n -> Int
-shrSize ZSR = 1
-shrSize (n :$: sh) = n * shrSize sh
+shrSize (ShR sh) = shxSize sh
-shrFromList :: forall n i. SNat n -> [i] -> ShR n i
-shrFromList = coerce (listrFromList @_ @i)
+-- This is equivalent to but faster than @coerce (shxFromList (ssxReplicate snat))@.
+-- We don't report the size of the list in case of errors in order not to retain the list.
+{-# INLINEABLE shrFromList #-}
+shrFromList :: SNat n -> [Int] -> IShR n
+shrFromList snat topl = ShR $ ShX $ go snat topl
+ where
+ go :: SNat n -> [Int] -> ListH (Replicate n Nothing) Int
+ go SZ [] = ZH
+ go SZ _ = error $ "shrFromList: List too long (type says " ++ show (fromSNat' snat) ++ ")"
+ go (SS sn :: SNat n1) (i : is) | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ConsUnknown i (go sn is)
+ go _ _ = error $ "shrFromList: List too short (type says " ++ show (fromSNat' snat) ++ ")"
+-- This is equivalent to but faster than @coerce shxToList@.
{-# INLINEABLE shrToList #-}
-shrToList :: forall n i. ShR n i -> [i]
-shrToList = coerce (listrToList @_ @i)
+shrToList :: IShR n -> [Int]
+shrToList (ShR (ShX l)) = build (\(cons :: i -> is -> is) (nil :: is) ->
+ let go :: ListH sh Int -> is
+ go ZH = nil
+ go (ConsUnknown i rest) = i `cons` go rest
+ go ConsKnown{} = error "shrToList: impossible case"
+ in go l)
-shrHead :: ShR (n + 1) i -> i
-shrHead (ShR list) = listrHead list
+shrHead :: forall n i. ShR (n + 1) i -> i
+shrHead (ShR sh)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = case shxHead @Nothing @(Replicate n Nothing) sh of
+ SUnknown i -> i
-shrTail :: ShR (n + 1) i -> ShR n i
-shrTail (ShR list) = ShR (listrTail list)
+shrTail :: forall n i. ShR (n + 1) i -> ShR n i
+shrTail
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = coerce (shxTail @_ @_ @i)
-shrInit :: ShR (n + 1) i -> ShR n i
-shrInit (ShR list) = ShR (listrInit list)
+shrInit :: forall n i. ShR (n + 1) i -> ShR n i
+shrInit
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = -- TODO: change this and all other unsafeCoerceRefl to lemmas:
+ gcastWith (unsafeCoerceRefl
+ :: Init (Replicate (n + 1) (Nothing @Nat)) :~: Replicate n Nothing) $
+ coerce (shxInit @_ @_ @i)
-shrLast :: ShR (n + 1) i -> i
-shrLast (ShR list) = listrLast list
+shrLast :: forall n i. ShR (n + 1) i -> i
+shrLast (ShR sh)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = case shxLast sh of
+ SUnknown i -> i
+ SKnown{} -> error "shrLast: impossible SKnown"
-- | Performs a runtime check that the lengths are identical.
shrCast :: SNat n' -> ShR n i -> ShR n' i
-shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh)
+shrCast SZ ZSR = ZSR
+shrCast (SS n) (i :$: sh) = i :$: shrCast n sh
+shrCast _ _ = error "shrCast: ranks don't match"
shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
-shrAppend = coerce (listrAppend @_ @i)
-
-shrZip :: ShR n i -> ShR n j -> ShR n (i, j)
-shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2
+shrAppend =
+ -- lemReplicatePlusApp requires an SNat
+ gcastWith (unsafeCoerceRefl
+ :: Replicate n (Nothing @Nat) ++ Replicate m Nothing :~: Replicate (n + m) Nothing) $
+ coerce (shxAppend @_ @_ @i)
{-# INLINE shrZipWith #-}
shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k
-shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2
+shrZipWith _ ZSR ZSR = ZSR
+shrZipWith f (i :$: irest) (j :$: jrest) =
+ f i j :$: shrZipWith f irest jrest
+shrZipWith _ _ _ =
+ error "shrZipWith: impossible pattern needlessly required"
-shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
-shrPermutePrefix = coerce (listrPermutePrefix @i)
+shrSplitAt :: m <= n' => SNat m -> ShR n' i -> (ShR m i, ShR (n' - m) i)
+shrSplitAt SZ sh = (ZSR, sh)
+shrSplitAt (SS m) (n :$: sh) = (\(pre, post) -> (n :$: pre, post)) (shrSplitAt m sh)
+shrSplitAt SS{} ZSR = error "m' + 1 <= 0"
-shrEnum :: IShR sh -> [IIxR sh]
-shrEnum = shrEnum'
+shrIndex :: forall k sh i. SNat k -> ShR sh i -> i
+shrIndex k (ShR sh) = case shxIndex @_ @_ @i k sh of
+ SUnknown i -> i
+ SKnown{} -> error "shrIndex: impossible SKnown"
-{-# INLINABLE shrEnum' #-} -- ensure this can be specialised at use site
-shrEnum' :: Num i => IShR sh -> [IxR sh i]
-shrEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shrSize sh - 1]]
+-- Copy-pasted from listrPermutePrefix, probably unavoidably.
+shrPermutePrefix :: forall i n. PermR -> ShR n i -> ShR n i
+shrPermutePrefix = \perm sh ->
+ TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
+ case shrRank sh of { shlen@SNat ->
+ let sperm = shrFromList permlen perm in
+ case cmpNat permlen shlen of
+ LTI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post
+ EQI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post
+ GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
+ ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ }
where
- suffixes = drop 1 (scanr (*) 1 (shrToList sh))
-
- fromLin :: Num i => IShR sh -> [Int] -> Int# -> IxR sh i
- fromLin ZSR _ _ = ZIR
- fromLin (_ :$: sh') (I# suff# : suffs) i# =
- let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh'
- in fromIntegral (I# q#) :.: fromLin sh' suffs r#
- fromLin _ _ _ = error "impossible"
+ applyPermRFull :: SNat m -> ShR k Int -> ShR m i -> ShR k i
+ applyPermRFull _ ZSR _ = ZSR
+ applyPermRFull sm@SNat (i :$: perm) l =
+ TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
+ case cmpNat (SNat @(idx + 1)) sm of
+ LTI -> shrIndex si l :$: applyPermRFull sm perm l
+ EQI -> shrIndex si l :$: applyPermRFull sm perm l
+ GTI -> error "shrPermutePrefix: Index in permutation out of range"
-- | Untyped: length is checked at runtime.
@@ -413,18 +505,15 @@ instance KnownNat n => IsList (IxR n i) where
toList = Foldable.toList
-- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (ShR n i) where
- type Item (ShR n i) = i
- fromList = ShR . IsList.fromList
- toList = Foldable.toList
+instance KnownNat n => IsList (IShR n) where
+ type Item (IShR n) = Int
+ fromList = shrFromList (SNat @n)
+ toList = shrToList
-- * Internal helper functions
listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i
listrCastWithName _ SZ ZR = ZR
-listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx
+listrCastWithName name (SS n) (i ::: l) = i ::: listrCastWithName name n l
listrCastWithName name _ _ = error $ name ++ ": ranks don't match"
-
-$(ixFromLinearStub "ixrFromLinear" [t| IShR |] [t| IxR |] [p| ZSR |] (\a b -> [p| $a :$: $b |]) [| ZIR |] [| (:.:) |] [| shrToList |])
-{-# INLINEABLE ixrFromLinear #-}