diff options
Diffstat (limited to 'src/Data/Array/Nested/Ranked')
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 61 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 243 |
2 files changed, 198 insertions, 106 deletions
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index 11a8ffb..beedbcf 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -26,16 +26,11 @@ import Data.Coerce (coerce) import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty) import Data.Proxy -import Data.Type.Equality import Foreign.Storable (Storable) import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) import GHC.TypeLits -#ifndef OXAR_DEFAULT_SHOW_INSTANCES -import Data.Foldable (toList) -#endif - import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape @@ -65,7 +60,7 @@ deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a) #ifndef OXAR_DEFAULT_SHOW_INSTANCES instance (Show a, Elt a) => Show (Ranked n a) where showsPrec d arr@(Ranked marr) = - let sh = show (toList (rshape arr)) + let sh = show (shrToList (rshape arr)) in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr #endif @@ -87,9 +82,12 @@ newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed -- these instances allow them to also be used as elements of arrays, thus -- making them first-class in the API. instance Elt a => Elt (Ranked n a) where + {-# INLINE mshape #-} mshape (M_Ranked arr) = mshape arr + {-# INLINE mindex #-} mindex (M_Ranked arr) i = Ranked (mindex arr i) + {-# INLINE mindexPartial #-} mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) mindexPartial (M_Ranked arr) i = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ @@ -104,6 +102,7 @@ instance Elt a => Elt (Ranked n a) where mtoListOuter (M_Ranked arr) = coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) @@ -112,6 +111,7 @@ instance Elt a => Elt (Ranked n a) where coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ mlift ssh2 f arr + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) @@ -120,6 +120,7 @@ instance Elt a => Elt (Ranked n a) where coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ mlift2 ssh3 f arr1 arr2 + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) @@ -139,7 +140,7 @@ instance Elt a => Elt (Ranked n a) where type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) - mshapeTree (Ranked arr) = first shrFromShX2 (mshapeTree arr) + mshapeTree (Ranked arr) = first coerce (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -149,18 +150,19 @@ instance Elt a => Elt (Ranked n a) where marrayStrides (M_Ranked arr) = marrayStrides arr - mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () - mvecsWrite sh idx (Ranked arr) vecs = - mvecsWrite sh idx arr + mvecsWriteLinear :: forall sh s. Int -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () + mvecsWriteLinear idx (Ranked arr) vecs = + mvecsWriteLinear idx arr (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) vecs) - mvecsWritePartial :: forall sh sh' s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) - -> MixedVecs s (sh ++ sh') (Ranked n a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx + mvecsWritePartialLinear + :: forall sh sh' s. + Proxy sh -> Int -> Mixed sh' (Ranked n a) + -> MixedVecs s (sh ++ sh') (Ranked n a) + -> ST s () + mvecsWritePartialLinear proxy idx arr vecs = + mvecsWritePartialLinear proxy idx (coerce @(Mixed sh' (Ranked n a)) @(Mixed sh' (Mixed (Replicate n Nothing) a)) arr) @@ -176,6 +178,14 @@ instance Elt a => Elt (Ranked n a) where (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) vecs) + mvecsUnsafeFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) + mvecsUnsafeFreeze sh vecs = + coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) + @(Mixed sh (Ranked n a)) + <$> mvecsUnsafeFreeze sh + (coerce @(MixedVecs s sh (Ranked n a)) + @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a) @@ -188,6 +198,10 @@ instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where | Dict <- lemKnownReplicate (SNat @n) = MV_Ranked <$> mvecsUnsafeNew idx arr + mvecsReplicate idx (Ranked arr) + | Dict <- lemKnownReplicate (SNat @n) + = MV_Ranked <$> mvecsReplicate idx arr + mvecsNewEmpty _ | Dict <- lemKnownReplicate (SNat @n) = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) @@ -249,20 +263,9 @@ ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a ratan2Array = liftRanked2 matan2Array +{-# INLINE rshape #-} rshape :: Elt a => Ranked n a -> IShR n -rshape (Ranked arr) = shrFromShX2 (mshape arr) +rshape (Ranked arr) = coerce (mshape arr) rrank :: Elt a => Ranked n a -> SNat n rrank = shrRank . rshape - --- Needed already here, but re-exported in Data.Array.Nested.Convert. -shrFromShX :: forall sh. IShX sh -> IShR (Rank sh) -shrFromShX ZSX = ZSR -shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx - --- Needed already here, but re-exported in Data.Array.Nested.Convert. --- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'. -shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n -shrFromShX2 sh - | Refl <- lemRankReplicate (Proxy @n) - = shrFromShX sh diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 6d61bd5..6d47ade 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -1,8 +1,5 @@ -{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} @@ -36,15 +33,16 @@ import Data.Foldable qualified as Foldable import Data.Kind (Type) import Data.Proxy import Data.Type.Equality -import GHC.Exts (Int(..), Int#, quotRemInt#, build) -import GHC.Generics (Generic) +import GHC.Exts (build) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits import GHC.TypeNats qualified as TN +import Unsafe.Coerce (unsafeCoerce) import Data.Array.Nested.Lemmas -import Data.Array.Nested.Mixed.Shape.Internal +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation import Data.Array.Nested.Types @@ -183,7 +181,12 @@ listrZipWith f (i ::: irest) (j ::: jrest) = listrZipWith _ _ _ = error "listrZipWith: impossible pattern needlessly required" -listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i +listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) +listrSplitAt SZ sh = (ZR, sh) +listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) +listrSplitAt SS{} ZR = error "m' + 1 <= 0" + +listrPermutePrefix :: forall i n. PermR -> ListR n i -> ListR n i listrPermutePrefix = \perm sh -> TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat -> case listrRank sh of { shlen@SNat -> @@ -195,11 +198,6 @@ listrPermutePrefix = \perm sh -> ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" } where - listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) - listrSplitAt SZ sh = (ZR, sh) - listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) - listrSplitAt SS{} ZR = error "m' + 1 <= 0" - applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i applyPermRFull _ ZR _ = ZR applyPermRFull sm@SNat (i ::: perm) l = @@ -216,8 +214,7 @@ listrPermutePrefix = \perm sh -> type role IxR nominal representational type IxR :: Nat -> Type -> Type newtype IxR n i = IxR (ListR n i) - deriving (Eq, Ord, Generic) - deriving newtype (Functor, Foldable) + deriving (Eq, Ord, NFData, Functor, Foldable) pattern ZIR :: forall n i. () => n ~ 0 => IxR n i pattern ZIR = IxR ZR @@ -243,8 +240,6 @@ instance Show i => Show (IxR n i) where showsPrec _ (IxR l) = listrShow shows l #endif -instance NFData i => NFData (IxR sh i) - ixrLength :: IxR sh i -> Int ixrLength (IxR l) = listrLength l @@ -255,12 +250,12 @@ ixrZero :: SNat n -> IIxR n ixrZero SZ = ZIR ixrZero (SS n) = 0 :.: ixrZero n +{-# INLINEABLE ixrFromList #-} ixrFromList :: forall n i. SNat n -> [i] -> IxR n i ixrFromList = coerce (listrFromList @_ @i) -{-# INLINEABLE ixrToList #-} -ixrToList :: forall n i. IxR n i -> [i] -ixrToList = coerce (listrToList @_ @i) +ixrToList :: IxR n i -> [i] +ixrToList = Foldable.toList ixrHead :: IxR (n + 1) i -> i ixrHead (IxR list) = listrHead list @@ -288,27 +283,69 @@ ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 -ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i +ixrPermutePrefix :: forall n i. PermR -> IxR n i -> IxR n i ixrPermutePrefix = coerce (listrPermutePrefix @i) +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixrToLinear #-} +ixrToLinear :: Num i => IShR m -> IxR m i -> i +ixrToLinear (ShR sh) ix = ixxToLinear sh (ixxFromIxR ix) + +ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i +ixxFromIxR = unsafeCoerce -- TODO: switch to coerce once newtypes overhauled + +{-# INLINEABLE ixrFromLinear #-} +ixrFromLinear :: forall i m. Num i => IShR m -> Int -> IxR m i +ixrFromLinear (ShR sh) i + | Refl <- lemRankReplicate (Proxy @m) + = ixrFromIxX $ ixxFromLinear sh i + +ixrFromIxX :: IxX sh i -> IxR (Rank sh) i +ixrFromIxX = unsafeCoerce -- TODO: switch to coerce once newtypes overhauled + +shrEnum :: IShR n -> [IIxR n] +shrEnum = shrEnum' + +{-# INLINABLE shrEnum' #-} -- ensure this can be specialised at use site +shrEnum' :: forall i n. Num i => IShR n -> [IxR n i] +shrEnum' (ShR sh) + | Refl <- lemRankReplicate (Proxy @n) + = (unsafeCoerce :: [IxX (Replicate n Nothing) i] -> [IxR n i]) $ shxEnum' sh + -- TODO: switch to coerce once newtypes overhauled -- * Ranked shapes type role ShR nominal representational type ShR :: Nat -> Type -> Type -newtype ShR n i = ShR (ListR n i) - deriving (Eq, Ord, Generic) - deriving newtype (Functor, Foldable) +newtype ShR n i = ShR (ShX (Replicate n Nothing) i) + deriving (Eq, Ord, NFData, Functor) pattern ZSR :: forall n i. () => n ~ 0 => ShR n i -pattern ZSR = ShR ZR +pattern ZSR <- ShR (matchZSR @n -> Just Refl) + where ZSR = ShR ZSX + +matchZSR :: forall n i. ShX (Replicate n Nothing) i -> Maybe (n :~: 0) +matchZSR ZSX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl +matchZSR _ = Nothing pattern (:$:) :: forall {n1} {i}. forall n. (n + 1 ~ n1) => i -> ShR n i -> ShR n1 i -pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) - where i :$: ShR sh = ShR (i ::: sh) +pattern i :$: shl <- (shrUncons -> Just (UnconsShRRes shl i)) + where i :$: ShR shl | Refl <- lemReplicateSucc2 (Proxy @n1) Refl + = ShR (SUnknown i :$% shl) + +data UnconsShRRes i n1 = + forall n. (n + 1 ~ n1) => UnconsShRRes (ShR n i) i +shrUncons :: forall n1 i. ShR n1 i -> Maybe (UnconsShRRes i n1) +shrUncons (ShR (SUnknown x :$% (sh' :: ShX sh' i))) + | Refl <- lemReplicateCons (Proxy @sh') (Proxy @n1) Refl + , Refl <- lemReplicateCons2 (Proxy @sh') (Proxy @n1) Refl + = Just (UnconsShRRes (ShR sh') x) +shrUncons (ShR _) = Nothing + infixr 3 :$: {-# COMPLETE ZSR, (:$:) #-} @@ -319,85 +356,140 @@ type IShR n = ShR n Int deriving instance Show i => Show (ShR n i) #else instance Show i => Show (ShR n i) where - showsPrec _ (ShR l) = listrShow shows l + showsPrec d (ShR l) = showsPrec d l #endif -instance NFData i => NFData (ShR sh i) - -- | This checks only whether the ranks are equal, not whether the actual -- values are. shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') -shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh' +shrEqRank ZSR ZSR = Just Refl +shrEqRank (_ :$: sh) (_ :$: sh') + | Just Refl <- shrEqRank sh sh' + = Just Refl +shrEqRank _ _ = Nothing -- | This compares the shapes for value equality. shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n') -shrEqual (ShR sh) (ShR sh') = listrEqual sh sh' +shrEqual ZSR ZSR = Just Refl +shrEqual (i :$: sh) (i' :$: sh') + | Just Refl <- shrEqual sh sh' + , i == i' + = Just Refl +shrEqual _ _ = Nothing shrLength :: ShR sh i -> Int -shrLength (ShR l) = listrLength l +shrLength (ShR l) = shxLength l -- | This function can also be used to conjure up a 'KnownNat' dictionary; -- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern -- synonym yields 'KnownNat' evidence. -shrRank :: ShR n i -> SNat n -shrRank (ShR sh) = listrRank sh +shrRank :: forall n i. ShR n i -> SNat n +shrRank (ShR sh) | Refl <- lemRankReplicate (Proxy @n) = shxRank sh -- | The number of elements in an array described by this shape. shrSize :: IShR n -> Int -shrSize ZSR = 1 -shrSize (n :$: sh) = n * shrSize sh +shrSize (ShR sh) = shxSize sh -shrFromList :: forall n i. SNat n -> [i] -> ShR n i -shrFromList = coerce (listrFromList @_ @i) +-- This is equivalent to but faster than @coerce (shxFromList (ssxReplicate snat))@. +-- We don't report the size of the list in case of errors in order not to retain the list. +{-# INLINEABLE shrFromList #-} +shrFromList :: SNat n -> [Int] -> IShR n +shrFromList snat topl = ShR $ ShX $ go snat topl + where + go :: SNat n -> [Int] -> ListH (Replicate n Nothing) Int + go SZ [] = ZH + go SZ _ = error $ "shrFromList: List too long (type says " ++ show (fromSNat' snat) ++ ")" + go (SS sn :: SNat n1) (i : is) | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ConsUnknown i (go sn is) + go _ _ = error $ "shrFromList: List too short (type says " ++ show (fromSNat' snat) ++ ")" +-- This is equivalent to but faster than @coerce shxToList@. {-# INLINEABLE shrToList #-} -shrToList :: forall n i. ShR n i -> [i] -shrToList = coerce (listrToList @_ @i) +shrToList :: IShR n -> [Int] +shrToList (ShR (ShX l)) = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ListH sh Int -> is + go ZH = nil + go (ConsUnknown i rest) = i `cons` go rest + go ConsKnown{} = error "shrToList: impossible case" + in go l) -shrHead :: ShR (n + 1) i -> i -shrHead (ShR list) = listrHead list +shrHead :: forall n i. ShR (n + 1) i -> i +shrHead (ShR sh) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = case shxHead @Nothing @(Replicate n Nothing) sh of + SUnknown i -> i -shrTail :: ShR (n + 1) i -> ShR n i -shrTail (ShR list) = ShR (listrTail list) +shrTail :: forall n i. ShR (n + 1) i -> ShR n i +shrTail + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = coerce (shxTail @_ @_ @i) -shrInit :: ShR (n + 1) i -> ShR n i -shrInit (ShR list) = ShR (listrInit list) +shrInit :: forall n i. ShR (n + 1) i -> ShR n i +shrInit + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = -- TODO: change this and all other unsafeCoerceRefl to lemmas: + gcastWith (unsafeCoerceRefl + :: Init (Replicate (n + 1) (Nothing @Nat)) :~: Replicate n Nothing) $ + coerce (shxInit @_ @_ @i) -shrLast :: ShR (n + 1) i -> i -shrLast (ShR list) = listrLast list +shrLast :: forall n i. ShR (n + 1) i -> i +shrLast (ShR sh) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = case shxLast sh of + SUnknown i -> i + SKnown{} -> error "shrLast: impossible SKnown" -- | Performs a runtime check that the lengths are identical. shrCast :: SNat n' -> ShR n i -> ShR n' i -shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh) +shrCast SZ ZSR = ZSR +shrCast (SS n) (i :$: sh) = i :$: shrCast n sh +shrCast _ _ = error "shrCast: ranks don't match" shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i -shrAppend = coerce (listrAppend @_ @i) - -shrZip :: ShR n i -> ShR n j -> ShR n (i, j) -shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2 +shrAppend = + -- lemReplicatePlusApp requires an SNat + gcastWith (unsafeCoerceRefl + :: Replicate n (Nothing @Nat) ++ Replicate m Nothing :~: Replicate (n + m) Nothing) $ + coerce (shxAppend @_ @_ @i) {-# INLINE shrZipWith #-} shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k -shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2 +shrZipWith _ ZSR ZSR = ZSR +shrZipWith f (i :$: irest) (j :$: jrest) = + f i j :$: shrZipWith f irest jrest +shrZipWith _ _ _ = + error "shrZipWith: impossible pattern needlessly required" -shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i -shrPermutePrefix = coerce (listrPermutePrefix @i) +shrSplitAt :: m <= n' => SNat m -> ShR n' i -> (ShR m i, ShR (n' - m) i) +shrSplitAt SZ sh = (ZSR, sh) +shrSplitAt (SS m) (n :$: sh) = (\(pre, post) -> (n :$: pre, post)) (shrSplitAt m sh) +shrSplitAt SS{} ZSR = error "m' + 1 <= 0" -shrEnum :: IShR sh -> [IIxR sh] -shrEnum = shrEnum' +shrIndex :: forall k sh i. SNat k -> ShR sh i -> i +shrIndex k (ShR sh) = case shxIndex @_ @_ @i k sh of + SUnknown i -> i + SKnown{} -> error "shrIndex: impossible SKnown" -{-# INLINABLE shrEnum' #-} -- ensure this can be specialised at use site -shrEnum' :: Num i => IShR sh -> [IxR sh i] -shrEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shrSize sh - 1]] +-- Copy-pasted from listrPermutePrefix, probably unavoidably. +shrPermutePrefix :: forall i n. PermR -> ShR n i -> ShR n i +shrPermutePrefix = \perm sh -> + TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat -> + case shrRank sh of { shlen@SNat -> + let sperm = shrFromList permlen perm in + case cmpNat permlen shlen of + LTI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post + EQI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post + GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" + ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" + } where - suffixes = drop 1 (scanr (*) 1 (shrToList sh)) - - fromLin :: Num i => IShR sh -> [Int] -> Int# -> IxR sh i - fromLin ZSR _ _ = ZIR - fromLin (_ :$: sh') (I# suff# : suffs) i# = - let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh' - in fromIntegral (I# q#) :.: fromLin sh' suffs r# - fromLin _ _ _ = error "impossible" + applyPermRFull :: SNat m -> ShR k Int -> ShR m i -> ShR k i + applyPermRFull _ ZSR _ = ZSR + applyPermRFull sm@SNat (i :$: perm) l = + TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> + case cmpNat (SNat @(idx + 1)) sm of + LTI -> shrIndex si l :$: applyPermRFull sm perm l + EQI -> shrIndex si l :$: applyPermRFull sm perm l + GTI -> error "shrPermutePrefix: Index in permutation out of range" -- | Untyped: length is checked at runtime. @@ -413,18 +505,15 @@ instance KnownNat n => IsList (IxR n i) where toList = Foldable.toList -- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ShR n i) where - type Item (ShR n i) = i - fromList = ShR . IsList.fromList - toList = Foldable.toList +instance KnownNat n => IsList (IShR n) where + type Item (IShR n) = Int + fromList = shrFromList (SNat @n) + toList = shrToList -- * Internal helper functions listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i listrCastWithName _ SZ ZR = ZR -listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx +listrCastWithName name (SS n) (i ::: l) = i ::: listrCastWithName name n l listrCastWithName name _ _ = error $ name ++ ": ranks don't match" - -$(ixFromLinearStub "ixrFromLinear" [t| IShR |] [t| IxR |] [p| ZSR |] (\a b -> [p| $a :$: $b |]) [| ZIR |] [| (:.:) |] [| shrToList |]) -{-# INLINEABLE ixrFromLinear #-} |
