summaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-04-21 15:45:47 +0200
committerTom Smeding <tom@tomsmeding.com>2024-04-21 15:45:47 +0200
commit3a82a91be0f1b18f071cdb35526b2b2d0b8e093f (patch)
tree6a65376a17ef4051446952541bab57149c66fca7 /src/Data/Array/Nested
parent181c02b60445310105be126fd6cc2ee9f5dc2c8a (diff)
Make index types useful for horde-ad by parameterising Int
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Internal.hs144
1 files changed, 74 insertions, 70 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index fadf1a7..9a87389 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -47,7 +47,7 @@ import qualified Data.Vector.Storable.Mutable as VSM
import Foreign.Storable (Storable)
import GHC.TypeLits
-import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat)
+import Data.Array.Mixed (XArray, IxX(..), IIxX, KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat)
import qualified Data.Array.Mixed as X
import Data.INat
@@ -105,7 +105,7 @@ lemReplicatePlusApp _ _ _ = go (inatSing @n)
go SZ = Refl
go (SS n) | Refl <- go n = Refl
-ixAppSplit :: Proxy sh' -> StaticShapeX sh -> IxX (sh ++ sh') -> (IxX sh, IxX sh')
+ixAppSplit :: Proxy sh' -> StaticShapeX sh -> IIxX (sh ++ sh') -> (IIxX sh, IIxX sh')
ixAppSplit _ SZX idx = (IZX, idx)
ixAppSplit p (_ :$@ ssh) (i ::@ idx) = first (i ::@) (ixAppSplit p ssh idx)
ixAppSplit p (_ :$? ssh) (i ::? idx) = first (i ::?) (ixAppSplit p ssh idx)
@@ -168,7 +168,7 @@ newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MV
data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b)
-- etc.
-data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IxX sh2) !(MixedVecs s (sh1 ++ sh2) a)
+data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IIxX sh2) !(MixedVecs s (sh1 ++ sh2) a)
-- | Tree giving the shape of every array component.
@@ -179,9 +179,9 @@ type family ShapeTree a where
ShapeTree () = ()
ShapeTree (a, b) = (ShapeTree a, ShapeTree b)
- ShapeTree (Mixed sh a) = (IxX sh, ShapeTree a)
- ShapeTree (Ranked n a) = (IxR n, ShapeTree a)
- ShapeTree (Shaped sh a) = (IxS sh, ShapeTree a)
+ ShapeTree (Mixed sh a) = (IIxX sh, ShapeTree a)
+ ShapeTree (Ranked n a) = (IIxR n, ShapeTree a)
+ ShapeTree (Shaped sh a) = (IIxS sh, ShapeTree a)
-- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or
@@ -190,9 +190,9 @@ type family ShapeTree a where
class Elt a where
-- ====== PUBLIC METHODS ====== --
- mshape :: KnownShapeX sh => Mixed sh a -> IxX sh
- mindex :: Mixed sh a -> IxX sh -> a
- mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a
+ mshape :: KnownShapeX sh => Mixed sh a -> IIxX sh
+ mindex :: Mixed sh a -> IIxX sh -> a
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a
mscalar :: a -> Mixed '[] a
-- | All arrays in the list, even subarrays inside @a@, must have the same
@@ -226,7 +226,7 @@ class Elt a where
-- Remember I said that this module needed better management of exports?
-- | Create an empty array. The given shape must have size zero; this may or may not be checked.
- memptyArray :: IxX sh -> Mixed sh a
+ memptyArray :: IIxX sh -> Mixed sh a
mshapeTree :: a -> ShapeTree a
@@ -240,20 +240,20 @@ class Elt a where
-- | Create uninitialised vectors for this array type, given the shape of
-- this vector and an example for the contents.
- mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a)
+ mvecsUnsafeNew :: IIxX sh -> a -> ST s (MixedVecs s sh a)
mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
-- | Given the shape of this array, an index and a value, write the value at
-- that index in the vectors.
- mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s ()
+ mvecsWrite :: IIxX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
-- | Given the shape of this array, an index and a value, write the value at
-- that index in the vectors.
- mvecsWritePartial :: KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+ mvecsWritePartial :: KnownShapeX sh' => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
-- | Given the shape of this array, finalise the vectors into 'XArray's.
- mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
+ mvecsFreeze :: IIxX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
-- Arrays of scalars are basically just arrays of scalars.
@@ -295,7 +295,7 @@ instance Storable a => Elt (Primitive a) where
-- TODO: this use of toVector is suboptimal
mvecsWritePartial
:: forall sh' sh s. KnownShapeX sh'
- => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
+ => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do
let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' (X.shape arr)))
VS.copy (VSM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr)
@@ -339,7 +339,7 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
-- TODO: this is quadratic in the nesting depth because it repeatedly
-- truncates the shape vector to one a little shorter. Fix with a
-- moverlongShape method, a prefix of which is mshape.
- mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh
+ mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IIxX sh
mshape (M_Nest arr)
| Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
= fst (ixAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr))
@@ -347,7 +347,7 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
mindex (M_Nest arr) i = mindexPartial arr i
mindexPartial :: forall sh1 sh2.
- Mixed (sh1 ++ sh2) (Mixed sh' a) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
+ Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
mindexPartial (M_Nest arr) i
| Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
= M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
@@ -416,7 +416,7 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs
mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
- => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
+ => IIxX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
-> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
-> ST s ()
mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)
@@ -428,7 +428,7 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
-- | Check whether a given shape corresponds on the statically-known components of the shape.
-checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool
+checkBounds :: IIxX sh' -> StaticShapeX sh' -> Bool
checkBounds IZX SZX = True
checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh'
checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
@@ -450,7 +450,7 @@ checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
-- the entire hierarchy (after distributing out tuples) must be a rectangular
-- array. The type of 'mgenerate' allows this requirement to be broken very
-- easily, hence the runtime check.
-mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IxX sh -> (IxX sh -> a) -> Mixed sh a
+mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IIxX sh -> (IIxX sh -> a) -> Mixed sh a
mgenerate sh f
-- TODO: Do we need this checkBounds check elsewhere as well?
| not (checkBounds sh (knownShapeX @sh)) =
@@ -488,7 +488,7 @@ mappend = mlift2 go
=> Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b
go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append
-mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IxX sh -> VS.Vector a -> Mixed sh (Primitive a)
+mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IIxX sh -> VS.Vector a -> Mixed sh (Primitive a)
mfromVector sh v
| not (checkBounds sh (knownShapeX @sh)) =
error $ "mfromVector: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
@@ -504,7 +504,7 @@ mtoList1 = map munScalar . mtoList
munScalar :: Elt a => Mixed '[] a -> a
munScalar arr = mindex arr IZX
-mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IxX sh -> a -> Mixed sh (Primitive a)
+mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IIxX sh -> a -> Mixed sh (Primitive a)
mconstantP sh x
| not (checkBounds sh (knownShapeX @sh)) =
error $ "mconstant: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
@@ -514,7 +514,7 @@ mconstantP sh x
-- | This 'Coercible' constraint holds for the scalar types for which 'Mixed'
-- is defined.
mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a))
- => IxX sh -> a -> Mixed sh a
+ => IIxX sh -> a -> Mixed sh a
mconstant sh x = coerce (mconstantP sh x)
mslice :: (KnownShapeX sh, Elt a) => [(Int, Int)] -> Mixed sh a -> Mixed sh a
@@ -594,7 +594,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr
mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i)
- mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IxX sh -> Mixed sh' (Ranked n a)
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
mindexPartial (M_Ranked arr) i
| Dict <- lemKnownReplicate (Proxy @n)
= coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
@@ -629,7 +629,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
= coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
mlift2 f arr1 arr2
- memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a)
+ memptyArray :: forall sh. IIxX sh -> Mixed sh (Ranked n a)
memptyArray i
| Dict <- lemKnownReplicate (Proxy @n)
= coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
@@ -656,7 +656,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
| Dict <- lemKnownReplicate (Proxy @n)
= MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
- mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
+ mvecsWrite :: forall sh s. IIxX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
mvecsWrite sh idx (Ranked arr) vecs
| Dict <- lemKnownReplicate (Proxy @n)
= mvecsWrite sh idx arr
@@ -664,7 +664,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
vecs)
mvecsWritePartial :: forall sh sh' s. KnownShapeX sh'
- => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Ranked n a)
+ => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
-> MixedVecs s (sh ++ sh') (Ranked n a)
-> ST s ()
mvecsWritePartial sh idx arr vecs
@@ -677,7 +677,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
@(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
vecs)
- mvecsFreeze :: forall sh s. IxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
+ mvecsFreeze :: forall sh s. IIxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
mvecsFreeze sh vecs
| Dict <- lemKnownReplicate (Proxy @n)
= coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
@@ -723,7 +723,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr
mindex (M_Shaped arr) i | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mindex arr i)
- mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IxX sh1 -> Mixed sh2 (Shaped sh a)
+ mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
mindexPartial (M_Shaped arr) i
| Dict <- lemKnownMapJust (Proxy @sh)
= coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
@@ -758,7 +758,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
= coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
mlift2 f arr1 arr2
- memptyArray :: forall sh'. IxX sh' -> Mixed sh' (Shaped sh a)
+ memptyArray :: forall sh'. IIxX sh' -> Mixed sh' (Shaped sh a)
memptyArray i
| Dict <- lemKnownMapJust (Proxy @sh)
= coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
@@ -784,7 +784,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
| Dict <- lemKnownMapJust (Proxy @sh)
= MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
- mvecsWrite :: forall sh' s. IxX sh' -> IxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
+ mvecsWrite :: forall sh' s. IIxX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
mvecsWrite sh idx (Shaped arr) vecs
| Dict <- lemKnownMapJust (Proxy @sh)
= mvecsWrite sh idx arr
@@ -792,7 +792,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
vecs)
mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
- => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Shaped sh a)
+ => IIxX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
-> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
-> ST s ()
mvecsWritePartial sh idx arr vecs
@@ -805,7 +805,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
@(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
vecs)
- mvecsFreeze :: forall sh' s. IxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
+ mvecsFreeze :: forall sh' s. IIxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
mvecsFreeze sh vecs
| Dict <- lemKnownMapJust (Proxy @sh)
= coerce @(Mixed sh' (Mixed (MapJust sh) a))
@@ -850,46 +850,48 @@ deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int)
deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double)
-- | An index into a rank-typed array.
-type IxR :: INat -> Type
-data IxR n where
- IZR :: IxR Z
- (:::) :: Int -> IxR n -> IxR (S n)
-deriving instance Show (IxR n)
-deriving instance Eq (IxR n)
-infixr 5 :::
-
-zeroIxR :: SINat n -> IxR n
+type IxR :: Type -> INat -> Type
+data IxR i n where
+ IZR :: IxR i Z
+ (:::) :: forall n i. i -> IxR i n -> IxR i (S n)
+deriving instance Show i => Show (IxR i n)
+deriving instance Eq i => Eq (IxR i n)
+infixr 3 :::
+
+type IIxR = IxR Int
+
+zeroIxR :: SINat n -> IIxR n
zeroIxR SZ = IZR
zeroIxR (SS n) = 0 ::: zeroIxR n
-ixCvtXR :: IxX sh -> IxR (X.Rank sh)
+ixCvtXR :: IIxX sh -> IIxR (X.Rank sh)
ixCvtXR IZX = IZR
ixCvtXR (n ::@ idx) = n ::: ixCvtXR idx
ixCvtXR (n ::? idx) = n ::: ixCvtXR idx
-ixCvtRX :: IxR n -> IxX (Replicate n Nothing)
+ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
ixCvtRX IZR = IZX
ixCvtRX (n ::: idx) = n ::? ixCvtRX idx
-knownIxR :: IxR n -> Dict KnownINat n
+knownIxR :: IIxR n -> Dict KnownINat n
knownIxR IZR = Dict
knownIxR (_ ::: idx) | Dict <- knownIxR idx = Dict
-shapeSizeR :: IxR n -> Int
+shapeSizeR :: IIxR n -> Int
shapeSizeR IZR = 1
shapeSizeR (n ::: sh) = n * shapeSizeR sh
-rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IxR n
+rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IIxR n
rshape (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
, Refl <- lemRankReplicate (Proxy @n)
= ixCvtXR (mshape arr)
-rindex :: Elt a => Ranked n a -> IxR n -> a
+rindex :: Elt a => Ranked n a -> IIxR n -> a
rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
-rindexPartial :: forall n m a. (KnownINat n, Elt a) => Ranked (n +! m) a -> IxR n -> Ranked m a
+rindexPartial :: forall n m a. (KnownINat n, Elt a) => Ranked (n +! m) a -> IIxR n -> Ranked m a
rindexPartial (Ranked arr) idx =
Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
(rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr)
@@ -897,7 +899,7 @@ rindexPartial (Ranked arr) idx =
-- | __WARNING__: All values returned from the function must have equal shape.
-- See the documentation of 'mgenerate' for more details.
-rgenerate :: forall n a. Elt a => IxR n -> (IxR n -> a) -> Ranked n a
+rgenerate :: forall n a. Elt a => IIxR n -> (IIxR n -> a) -> Ranked n a
rgenerate sh f
| Dict <- knownIxR sh
, Dict <- lemKnownReplicate (Proxy @n)
@@ -935,7 +937,7 @@ rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend
rscalar :: Elt a => a -> Ranked I0 a
rscalar x = Ranked (mscalar x)
-rfromVector :: forall n a. (KnownINat n, Storable a) => IxR n -> VS.Vector a -> Ranked n (Primitive a)
+rfromVector :: forall n a. (KnownINat n, Storable a) => IIxR n -> VS.Vector a -> Ranked n (Primitive a)
rfromVector sh v
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mfromVector (ixCvtRX sh) v)
@@ -957,13 +959,13 @@ rtoList1 = map runScalar . rtoList
runScalar :: Elt a => Ranked I0 a -> a
runScalar arr = rindex arr IZR
-rconstantP :: forall n a. (KnownINat n, Storable a) => IxR n -> a -> Ranked n (Primitive a)
+rconstantP :: forall n a. (KnownINat n, Storable a) => IIxR n -> a -> Ranked n (Primitive a)
rconstantP sh x
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mconstantP (ixCvtRX sh) x)
rconstant :: forall n a. (KnownINat n, Storable a, Coercible (Mixed (Replicate n Nothing) (Primitive a)) (Mixed (Replicate n Nothing) a))
- => IxR n -> a -> Ranked n a
+ => IIxR n -> a -> Ranked n a
rconstant sh x = coerce (rconstantP sh x)
rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a
@@ -1000,43 +1002,45 @@ deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped
-- (traditionally called \"@Fin@\"). Note that because the shape of a
-- shape-typed array is known statically, you can also retrieve the array shape
-- from a 'KnownShape' dictionary.
-type IxS :: [Nat] -> Type
-data IxS sh where
- IZS :: IxS '[]
- (::$) :: Int -> IxS sh -> IxS (n : sh)
-deriving instance Show (IxS n)
-deriving instance Eq (IxS n)
-infixr 5 ::$
-
-zeroIxS :: SShape sh -> IxS sh
+type IxS :: Type -> [Nat] -> Type
+data IxS i sh where
+ IZS :: IxS i '[]
+ (::$) :: forall n sh i. i -> IxS i sh -> IxS i (n : sh)
+deriving instance Show i => Show (IxS i n)
+deriving instance Eq i => Eq (IxS i n)
+infixr 3 ::$
+
+type IIxS = IxS Int
+
+zeroIxS :: SShape sh -> IIxS sh
zeroIxS ShNil = IZS
zeroIxS (ShCons _ sh) = 0 ::$ zeroIxS sh
-cvtSShapeIxS :: SShape sh -> IxS sh
+cvtSShapeIxS :: SShape sh -> IIxS sh
cvtSShapeIxS ShNil = IZS
cvtSShapeIxS (ShCons n sh) = fromIntegral (fromSNat n) ::$ cvtSShapeIxS sh
-ixCvtXS :: SShape sh -> IxX (MapJust sh) -> IxS sh
+ixCvtXS :: SShape sh -> IIxX (MapJust sh) -> IIxS sh
ixCvtXS ShNil IZX = IZS
ixCvtXS (ShCons _ sh) (n ::@ idx) = n ::$ ixCvtXS sh idx
-ixCvtSX :: IxS sh -> IxX (MapJust sh)
+ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
ixCvtSX IZS = IZX
ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh
-shapeSizeS :: IxS sh -> Int
+shapeSizeS :: IIxS sh -> Int
shapeSizeS IZS = 1
shapeSizeS (n ::$ sh) = n * shapeSizeS sh
-- | This does not touch the passed array, all information comes from 'KnownShape'.
-sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IxS sh
+sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IIxS sh
sshape _ = cvtSShapeIxS (knownShape @sh)
-sindex :: Elt a => Shaped sh a -> IxS sh -> a
+sindex :: Elt a => Shaped sh a -> IIxS sh -> a
sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)
-sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IxS sh1 -> Shaped sh2 a
+sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
sindexPartial (Shaped arr) idx =
Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
(rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr)
@@ -1044,7 +1048,7 @@ sindexPartial (Shaped arr) idx =
-- | __WARNING__: All values returned from the function must have equal shape.
-- See the documentation of 'mgenerate' for more details.
-sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IxS sh -> a) -> Shaped sh a
+sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IIxS sh -> a) -> Shaped sh a
sgenerate f
| Dict <- lemKnownMapJust (Proxy @sh)
= Shaped (mgenerate (ixCvtSX (cvtSShapeIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh)))