diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-05-13 22:47:42 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-13 22:47:42 +0200 | 
| commit | e4e23a33f77d250af1e9b6614cf249128ba1510a (patch) | |
| tree | 34bb40910003749becbaf8005a7b7ca62024fff2 /src/Data/Array/Nested | |
| parent | 7c9865354442326d55094087ad6a74b6e96341fb (diff) | |
Shape/index hygiene
Diffstat (limited to 'src/Data/Array/Nested')
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 239 | 
1 files changed, 112 insertions, 127 deletions
| diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 49ed7cb..2f1e79e 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -8,12 +8,12 @@  {-# LANGUAGE InstanceSigs #-}  {-# LANGUAGE PatternSynonyms #-}  {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-}  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE RoleAnnotations #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE StandaloneDeriving #-}  {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeAbstractions #-}  {-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-} @@ -23,7 +23,7 @@  {-|  TODO: -* We should be more consistent in whether functions take a 'StaticShapeX' +* We should be more consistent in whether functions take a 'StaticShX'    argument or a 'KnownShapeX' constraint.  * Document the choice of using 'INat' for ranks and 'Nat' for shapes. Point @@ -51,7 +51,7 @@ import qualified Data.Vector.Storable.Mutable as VSM  import Foreign.Storable (Storable)  import GHC.TypeLits -import Data.Array.Mixed (XArray, IxX(..), IIxX, KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat) +import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat)  import qualified Data.Array.Mixed as X  import Data.INat @@ -100,9 +100,9 @@ type family MapJust l where  lemKnownReplicate :: forall n. KnownINat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)  lemKnownReplicate _ = X.lemKnownShapeX (go (inatSing @n))    where -    go :: SINat m -> StaticShapeX (Replicate m Nothing) -    go SZ = ZSX -    go (SS n) = () :$? go n +    go :: SINat m -> StaticShX (Replicate m Nothing) +    go SZ = ZKSX +    go (SS n) = () :!$? go n  lemRankReplicate :: forall n. KnownINat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n  lemRankReplicate _ = go (inatSing @n) @@ -119,10 +119,10 @@ lemReplicatePlusApp _ _ _ = go (inatSing @n)      go SZ = Refl      go (SS n) | Refl <- go n = Refl -ixAppSplit :: Proxy sh' -> StaticShapeX sh -> IIxX (sh ++ sh') -> (IIxX sh, IIxX sh') -ixAppSplit _ ZSX idx = (ZIX, idx) -ixAppSplit p (_ :$@ ssh) (i :.@ idx) = first (i :.@) (ixAppSplit p ssh idx) -ixAppSplit p (_ :$? ssh) (i :.? idx) = first (i :.?) (ixAppSplit p ssh idx) +shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh') +shAppSplit _ ZKSX idx = (ZSX, idx) +shAppSplit p (_ :!$@ ssh) (i :$@ idx) = first (i :$@) (shAppSplit p ssh idx) +shAppSplit p (_ :!$? ssh) (i :$? idx) = first (i :$?) (shAppSplit p ssh idx)  -- | Wrapper type used as a tag to attach instances on. The instances on arrays @@ -184,7 +184,7 @@ newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ())  -- no content, MV  data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b)  -- etc. -data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IIxX sh2) !(MixedVecs s (sh1 ++ sh2) a) +data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a)  -- | Tree giving the shape of every array component. @@ -196,9 +196,9 @@ type family ShapeTree a where    ShapeTree () = ()    ShapeTree (a, b) = (ShapeTree a, ShapeTree b) -  ShapeTree (Mixed sh a) = (IIxX sh, ShapeTree a) -  ShapeTree (Ranked n a) = (IIxR n, ShapeTree a) -  ShapeTree (Shaped sh a) = (IIxS sh, ShapeTree a) +  ShapeTree (Mixed sh a) = (IShX sh, ShapeTree a) +  ShapeTree (Ranked n a) = (IShR n, ShapeTree a) +  ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)  -- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or @@ -207,7 +207,7 @@ type family ShapeTree a where  class Elt a where    -- ====== PUBLIC METHODS ====== -- -  mshape :: KnownShapeX sh => Mixed sh a -> IIxX sh +  mshape :: KnownShapeX sh => Mixed sh a -> IShX sh    mindex :: Mixed sh a -> IIxX sh -> a    mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a    mscalar :: a -> Mixed '[] a @@ -240,15 +240,12 @@ class Elt a where           -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a    -- ====== PRIVATE METHODS ====== -- -  -- Remember I said that this module needed better management of exports?    -- | Create an empty array. The given shape must have size zero; this may or may not be checked. -  memptyArray :: IIxX sh -> Mixed sh a +  memptyArray :: IShX sh -> Mixed sh a    mshapeTree :: a -> ShapeTree a -  mshapeTreeZero :: Proxy a -> ShapeTree a -    mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool    mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool @@ -257,20 +254,20 @@ class Elt a where    -- | Create uninitialised vectors for this array type, given the shape of    -- this vector and an example for the contents. -  mvecsUnsafeNew :: IIxX sh -> a -> ST s (MixedVecs s sh a) +  mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a)    mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)    -- | Given the shape of this array, an index and a value, write the value at    -- that index in the vectors. -  mvecsWrite :: IIxX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () +  mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()    -- | Given the shape of this array, an index and a value, write the value at    -- that index in the vectors. -  mvecsWritePartial :: KnownShapeX sh' => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () +  mvecsWritePartial :: KnownShapeX sh' => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()    -- | Given the shape of this array, finalise the vectors into 'XArray's. -  mvecsFreeze :: IIxX sh -> MixedVecs s sh a -> ST s (Mixed sh a) +  mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)  -- Arrays of scalars are basically just arrays of scalars. @@ -299,9 +296,8 @@ instance Storable a => Elt (Primitive a) where      , Refl <- X.lemAppNil @sh3      = M_Primitive (f Proxy a b) -  memptyArray sh = M_Primitive (X.generate sh (error $ "memptyArray Int: shape was not empty (" ++ show sh ++ ")")) +  memptyArray sh = M_Primitive (X.empty sh)    mshapeTree _ = () -  mshapeTreeZero _ = ()    mshapeTreeEq _ () () = True    mshapeTreeEmpty _ () = False    mshowShapeTree _ () = "()" @@ -312,7 +308,7 @@ instance Storable a => Elt (Primitive a) where    -- TODO: this use of toVector is suboptimal    mvecsWritePartial      :: forall sh' sh s. KnownShapeX sh' -    => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () +    => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()    mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do      let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' (X.shape arr)))      VS.copy (VSM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) @@ -338,7 +334,6 @@ instance (Elt a, Elt b) => Elt (a, b) where    memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)    mshapeTree (x, y) = (mshapeTree x, mshapeTree y) -  mshapeTreeZero _ = (mshapeTreeZero (Proxy @a), mshapeTreeZero (Proxy @b))    mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'    mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2    mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" @@ -357,10 +352,10 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where    -- TODO: this is quadratic in the nesting depth because it repeatedly    -- truncates the shape vector to one a little shorter. Fix with a    -- moverlongShape method, a prefix of which is mshape. -  mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IIxX sh +  mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IShX sh    mshape (M_Nest arr)      | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') -    = fst (ixAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr)) +    = fst (shAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr))    mindex (M_Nest arr) i = mindexPartial arr i @@ -410,12 +405,10 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where          , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT))          = f (Proxy @(sh' ++ shT)) -  memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIxX (knownShapeX @sh')))) +  memptyArray sh = M_Nest (memptyArray (X.shAppend sh (X.completeShXzeros (knownShapeX @sh'))))    mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (knownShapeX @sh')))) -  mshapeTreeZero _ = (X.zeroIxX (knownShapeX @sh'), mshapeTreeZero (Proxy @a)) -    mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2    mshapeTreeEmpty _ (sh, t) = X.shapeSize sh == 0 && mshapeTreeEmpty (Proxy @a) t @@ -424,32 +417,26 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where    mvecsUnsafeNew sh example      | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) -    | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example)) +    | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh (mshape example))                                                   (mindex example (X.zeroIxX (knownShapeX @sh')))      where        sh' = mshape example -  mvecsNewEmpty _ = MV_Nest (X.zeroIxX (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a) +  mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a) -  mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs +  mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.shAppend sh sh') idx val vecs    mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 -                    => IIxX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) +                    => IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)                      -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)                      -> ST s ()    mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)      | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))      , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') -    = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs - -  mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs +    = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.shAppend sh12 sh') idx arr vecs +  mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.shAppend sh sh') vecs --- | Check whether a given shape corresponds on the statically-known components of the shape. -checkBounds :: IIxX sh' -> StaticShapeX sh' -> Bool -checkBounds ZIX ZSX = True -checkBounds (n :.@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh' -checkBounds (_ :.? sh') (() :$? ssh') = checkBounds sh' ssh'  -- | Create an array given a size and a function that computes the element at a  -- given index. @@ -468,31 +455,25 @@ checkBounds (_ :.? sh') (() :$? ssh') = checkBounds sh' ssh'  -- the entire hierarchy (after distributing out tuples) must be a rectangular  -- array. The type of 'mgenerate' allows this requirement to be broken very  -- easily, hence the runtime check. -mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IIxX sh -> (IIxX sh -> a) -> Mixed sh a -mgenerate sh f -  -- TODO: Do we need this checkBounds check elsewhere as well? -  | not (checkBounds sh (knownShapeX @sh)) = -      error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) -  -- If the shape is empty, there is no first element, so we should not try to -  -- generate it. -  | X.shapeSize sh == 0 = memptyArray sh -  | otherwise = -      let firstidx = X.zeroIxX' sh -          firstelem = f (X.zeroIxX' sh) -          shapetree = mshapeTree firstelem -      in if mshapeTreeEmpty (Proxy @a) shapetree -           then memptyArray sh -           else runST $ do -                  vecs <- mvecsUnsafeNew sh firstelem -                  mvecsWrite sh firstidx firstelem vecs -                  -- TODO: This is likely fine if @a@ is big, but if @a@ is a -                  -- scalar this array copying inefficient. Should improve this. -                  forM_ (tail (X.enumShape sh)) $ \idx -> do -                    let val = f idx -                    when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ -                      error "Data.Array.Nested mgenerate: generated values do not have equal shapes" -                    mvecsWrite sh idx val vecs -                  mvecsFreeze sh vecs +mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IShX sh -> (IIxX sh -> a) -> Mixed sh a +mgenerate sh f = case X.enumShape sh of +  [] -> memptyArray sh +  firstidx : restidxs -> +    let firstelem = f (X.zeroIxX' sh) +        shapetree = mshapeTree firstelem +    in if mshapeTreeEmpty (Proxy @a) shapetree +         then memptyArray sh +         else runST $ do +                vecs <- mvecsUnsafeNew sh firstelem +                mvecsWrite sh firstidx firstelem vecs +                -- TODO: This is likely fine if @a@ is big, but if @a@ is a +                -- scalar this array copying inefficient. Should improve this. +                forM_ restidxs $ \idx -> do +                  let val = f idx +                  when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ +                    error "Data.Array.Nested mgenerate: generated values do not have equal shapes" +                  mvecsWrite sh idx val vecs +                mvecsFreeze sh vecs  mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a  mtranspose perm = @@ -506,12 +487,8 @@ mappend = mlift2 go             => Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b          go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append -mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IIxX sh -> VS.Vector a -> Mixed sh (Primitive a) -mfromVector sh v -  | not (checkBounds sh (knownShapeX @sh)) = -      error $ "mfromVector: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) -  | otherwise = -      M_Primitive (X.fromVector sh v) +mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) +mfromVector sh v = M_Primitive (X.fromVector sh v)  mfromList1 :: (KnownShapeX '[n], Elt a) => NonEmpty a -> Mixed '[n] a  mfromList1 = mfromList . fmap mscalar @@ -522,17 +499,13 @@ mtoList1 = map munScalar . mtoList  munScalar :: Elt a => Mixed '[] a -> a  munScalar arr = mindex arr ZIX -mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IIxX sh -> a -> Mixed sh (Primitive a) -mconstantP sh x -  | not (checkBounds sh (knownShapeX @sh)) = -      error $ "mconstant: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) -  | otherwise = -      M_Primitive (X.constant sh x) +mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> a -> Mixed sh (Primitive a) +mconstantP sh x = M_Primitive (X.constant sh x)  -- | This 'Coercible' constraint holds for the scalar types for which 'Mixed'  -- is defined.  mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a)) -          => IIxX sh -> a -> Mixed sh a +          => IShX sh -> a -> Mixed sh a  mconstant sh x = coerce (mconstantP sh x)  mslice :: (KnownShapeX sh, Elt a) => [(Int, Int)] -> Mixed sh a -> Mixed sh a @@ -648,7 +621,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where      = coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $          mlift2 f arr1 arr2 -  memptyArray :: forall sh. IIxX sh -> Mixed sh (Ranked n a) +  memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a)    memptyArray i      | Dict <- lemKnownReplicate (Proxy @n)      = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ @@ -657,9 +630,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where    mshapeTree (Ranked arr)      | Refl <- lemRankReplicate (Proxy @n)      , Dict <- lemKnownReplicate (Proxy @n) -    = first ixCvtXR (mshapeTree arr) - -  mshapeTreeZero _ = (zeroIxR (inatSing @n), mshapeTreeZero (Proxy @a)) +    = first shCvtXR (mshapeTree arr)    mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -675,7 +646,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where      | Dict <- lemKnownReplicate (Proxy @n)      = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) -  mvecsWrite :: forall sh s. IIxX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () +  mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()    mvecsWrite sh idx (Ranked arr) vecs      | Dict <- lemKnownReplicate (Proxy @n)      = mvecsWrite sh idx arr @@ -683,7 +654,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where             vecs)    mvecsWritePartial :: forall sh sh' s. KnownShapeX sh' -                    => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) +                    => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)                      -> MixedVecs s (sh ++ sh') (Ranked n a)                      -> ST s ()    mvecsWritePartial sh idx arr vecs @@ -696,7 +667,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where                  @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))             vecs) -  mvecsFreeze :: forall sh s. IIxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) +  mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))    mvecsFreeze sh vecs      | Dict <- lemKnownReplicate (Proxy @n)      = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @@ -712,6 +683,8 @@ data ShS sh where    ZSS :: ShS '[]    (:$$) :: forall n sh. SNat n -> ShS sh -> ShS (n : sh)  deriving instance Show (ShS sh) +deriving instance Eq (ShS sh) +deriving instance Ord (ShS sh)  infixr 3 :$$  -- | A statically-known shape of a shape-typed array. @@ -726,9 +699,9 @@ sshapeKnown (GHC_SNat :$$ sh) | Dict <- sshapeKnown sh = Dict  lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh)  lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh))    where -    go :: ShS sh' -> StaticShapeX (MapJust sh') -    go ZSS = ZSX -    go (n :$$ sh) = n :$@ go sh +    go :: ShS sh' -> StaticShX (MapJust sh') +    go ZSS = ZKSX +    go (n :$$ sh) = n :!$@ go sh  lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2                    -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 @@ -777,7 +750,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where      = coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $          mlift2 f arr1 arr2 -  memptyArray :: forall sh'. IIxX sh' -> Mixed sh' (Shaped sh a) +  memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)    memptyArray i      | Dict <- lemKnownMapJust (Proxy @sh)      = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ @@ -785,9 +758,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where    mshapeTree (Shaped arr)      | Dict <- lemKnownMapJust (Proxy @sh) -    = first (ixCvtXS (knownShape @sh)) (mshapeTree arr) - -  mshapeTreeZero _ = (zeroIxS (knownShape @sh), mshapeTreeZero (Proxy @a)) +    = first (shCvtXS (knownShape @sh)) (mshapeTree arr)    mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -803,7 +774,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where      | Dict <- lemKnownMapJust (Proxy @sh)      = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) -  mvecsWrite :: forall sh' s. IIxX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () +  mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()    mvecsWrite sh idx (Shaped arr) vecs      | Dict <- lemKnownMapJust (Proxy @sh)      = mvecsWrite sh idx arr @@ -811,7 +782,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where             vecs)    mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 -                    => IIxX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) +                    => IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)                      -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)                      -> ST s ()    mvecsWritePartial sh idx arr vecs @@ -824,7 +795,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where                  @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))             vecs) -  mvecsFreeze :: forall sh' s. IIxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) +  mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))    mvecsFreeze sh vecs      | Dict <- lemKnownMapJust (Proxy @sh)      = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @@ -927,6 +898,8 @@ newtype ShR n i = ShR (ListR n i)    deriving (Show, Eq, Ord)    deriving newtype (Functor, Foldable) +type IShR n = ShR n Int +  pattern ZSR :: forall n i. () => n ~ Z => ShR n i  pattern ZSR = ShR ZR @@ -957,20 +930,29 @@ ixCvtXR ZIX = ZIR  ixCvtXR (n :.@ idx) = n :.: ixCvtXR idx  ixCvtXR (n :.? idx) = n :.: ixCvtXR idx +shCvtXR :: IShX sh -> IShR (X.Rank sh) +shCvtXR ZSX = ZSR +shCvtXR (n :$@ idx) = X.fromSNat' n :$: shCvtXR idx +shCvtXR (n :$? idx) = n :$: shCvtXR idx +  ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)  ixCvtRX ZIR = ZIX  ixCvtRX (n :.: idx) = n :.? ixCvtRX idx -shapeSizeR :: IIxR n -> Int -shapeSizeR ZIR = 1 -shapeSizeR (n :.: sh) = n * shapeSizeR sh +shCvtRX :: IShR n -> IShX (Replicate n Nothing) +shCvtRX ZSR = ZSX +shCvtRX (n :$: idx) = n :$? shCvtRX idx + +shapeSizeR :: IShR n -> Int +shapeSizeR ZSR = 1 +shapeSizeR (n :$: sh) = n * shapeSizeR sh -rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IIxR n +rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IShR n  rshape (Ranked arr)    | Dict <- lemKnownReplicate (Proxy @n)    , Refl <- lemRankReplicate (Proxy @n) -  = ixCvtXR (mshape arr) +  = shCvtXR (mshape arr)  rindex :: Elt a => Ranked n a -> IIxR n -> a  rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) @@ -983,12 +965,12 @@ rindexPartial (Ranked arr) idx =  -- | __WARNING__: All values returned from the function must have equal shape.  -- See the documentation of 'mgenerate' for more details. -rgenerate :: forall n a. Elt a => IIxR n -> (IIxR n -> a) -> Ranked n a +rgenerate :: forall n a. Elt a => IShR n -> (IIxR n -> a) -> Ranked n a  rgenerate sh f -  | Dict <- knownIxR sh +  | Dict <- knownShR sh    , Dict <- lemKnownReplicate (Proxy @n)    , Refl <- lemRankReplicate (Proxy @n) -  = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR)) +  = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR))  -- | See the documentation of 'mlift'.  rlift :: forall n1 n2 a. (KnownINat n2, Elt a) @@ -1005,7 +987,7 @@ rsumOuter1 (Ranked arr)    | Dict <- lemKnownReplicate (Proxy @n)    = Ranked      . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a)) -    . X.sumOuter (() :$? ZSX) (knownShapeX @(Replicate n Nothing)) +    . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing))      . coerce @(Mixed (Replicate (S n) Nothing) (Primitive a)) @(XArray (Replicate (S n) Nothing) a)      $ arr @@ -1021,10 +1003,10 @@ rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend  rscalar :: Elt a => a -> Ranked I0 a  rscalar x = Ranked (mscalar x) -rfromVector :: forall n a. (KnownINat n, Storable a) => IIxR n -> VS.Vector a -> Ranked n (Primitive a) +rfromVector :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a)  rfromVector sh v    | Dict <- lemKnownReplicate (Proxy @n) -  = Ranked (mfromVector (ixCvtRX sh) v) +  = Ranked (mfromVector (shCvtRX sh) v)  rfromList :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a  rfromList l @@ -1043,13 +1025,13 @@ rtoList1 = map runScalar . rtoList  runScalar :: Elt a => Ranked I0 a -> a  runScalar arr = rindex arr ZIR -rconstantP :: forall n a. (KnownINat n, Storable a) => IIxR n -> a -> Ranked n (Primitive a) +rconstantP :: forall n a. (KnownINat n, Storable a) => IShR n -> a -> Ranked n (Primitive a)  rconstantP sh x    | Dict <- lemKnownReplicate (Proxy @n) -  = Ranked (mconstantP (ixCvtRX sh) x) +  = Ranked (mconstantP (shCvtRX sh) x)  rconstant :: forall n a. (KnownINat n, Storable a, Coercible (Mixed (Replicate n Nothing) (Primitive a)) (Mixed (Replicate n Nothing) a)) -          => IIxR n -> a -> Ranked n a +          => IShR n -> a -> Ranked n a  rconstant sh x = coerce (rconstantP sh x)  rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a @@ -1141,27 +1123,30 @@ zeroIxS :: ShS sh -> IIxS sh  zeroIxS ZSS = ZIS  zeroIxS (_ :$$ sh) = 0 :.$ zeroIxS sh --- TODO: this function should not exist perhaps -cvtShSIxS :: ShS sh -> IIxS sh -cvtShSIxS ZSS = ZIS -cvtShSIxS (n :$$ sh) = fromIntegral (fromSNat n) :.$ cvtShSIxS sh -  ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh  ixCvtXS ZSS ZIX = ZIS  ixCvtXS (_ :$$ sh) (n :.@ idx) = n :.$ ixCvtXS sh idx +shCvtXS :: ShS sh -> IShX (MapJust sh) -> ShS sh +shCvtXS ZSS ZSX = ZSS +shCvtXS (_ :$$ sh) (n :$@ idx) = n :$$ shCvtXS sh idx +  ixCvtSX :: IIxS sh -> IIxX (MapJust sh)  ixCvtSX ZIS = ZIX  ixCvtSX (n :.$ sh) = n :.@ ixCvtSX sh -shapeSizeS :: IIxS sh -> Int -shapeSizeS ZIS = 1 -shapeSizeS (n :.$ sh) = n * shapeSizeS sh +shCvtSX :: ShS sh -> IShX (MapJust sh) +shCvtSX ZSS = ZSX +shCvtSX (n :$$ sh) = n :$@ shCvtSX sh + +shapeSizeS :: ShS sh -> Int +shapeSizeS ZSS = 1 +shapeSizeS (n :$$ sh) = X.fromSNat' n * shapeSizeS sh  -- | This does not touch the passed array, all information comes from 'KnownShape'. -sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IIxS sh -sshape _ = cvtShSIxS (knownShape @sh) +sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> ShS sh +sshape _ = knownShape @sh  sindex :: Elt a => Shaped sh a -> IIxS sh -> a  sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) @@ -1177,7 +1162,7 @@ sindexPartial (Shaped arr) idx =  sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IIxS sh -> a) -> Shaped sh a  sgenerate f    | Dict <- lemKnownMapJust (Proxy @sh) -  = Shaped (mgenerate (ixCvtSX (cvtShSIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh))) +  = Shaped (mgenerate (shCvtSX (knownShape @sh)) (f . ixCvtXS (knownShape @sh)))  -- | See the documentation of 'mlift'.  slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) @@ -1194,7 +1179,7 @@ ssumOuter1 (Shaped arr)    | Dict <- lemKnownMapJust (Proxy @sh)    = Shaped      . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) (Primitive a)) -    . X.sumOuter (natSing @n :$@ ZSX) (knownShapeX @(MapJust sh)) +    . X.sumOuter (natSing @n :!$@ ZKSX) (knownShapeX @(MapJust sh))      . coerce @(Mixed (Just n : MapJust sh) (Primitive a)) @(XArray (Just n : MapJust sh) a)      $ arr @@ -1213,7 +1198,7 @@ sscalar x = Shaped (mscalar x)  sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a)  sfromVector v    | Dict <- lemKnownMapJust (Proxy @sh) -  = Shaped (mfromVector (ixCvtSX (cvtShSIxS (knownShape @sh))) v) +  = Shaped (mfromVector (shCvtSX (knownShape @sh)) v)  sfromList :: forall n sh a. (KnownNat n, KnownShape sh, Elt a)            => NonEmpty (Shaped sh a) -> Shaped (n : sh) a @@ -1236,7 +1221,7 @@ sunScalar arr = sindex arr ZIS  sconstantP :: forall sh a. (KnownShape sh, Storable a) => a -> Shaped sh (Primitive a)  sconstantP x    | Dict <- lemKnownMapJust (Proxy @sh) -  = Shaped (mconstantP (ixCvtSX (cvtShSIxS (knownShape @sh))) x) +  = Shaped (mconstantP (shCvtSX (knownShape @sh)) x)  sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sh) (Primitive a)) (Mixed (MapJust sh) a))            => a -> Shaped sh a | 
