diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-13 22:47:42 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-13 22:47:42 +0200 |
commit | e4e23a33f77d250af1e9b6614cf249128ba1510a (patch) | |
tree | 34bb40910003749becbaf8005a7b7ca62024fff2 /src/Data/Array/Nested/Internal.hs | |
parent | 7c9865354442326d55094087ad6a74b6e96341fb (diff) |
Shape/index hygiene
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 239 |
1 files changed, 112 insertions, 127 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 49ed7cb..2f1e79e 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -8,12 +8,12 @@ {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RoleAnnotations #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} @@ -23,7 +23,7 @@ {-| TODO: -* We should be more consistent in whether functions take a 'StaticShapeX' +* We should be more consistent in whether functions take a 'StaticShX' argument or a 'KnownShapeX' constraint. * Document the choice of using 'INat' for ranks and 'Nat' for shapes. Point @@ -51,7 +51,7 @@ import qualified Data.Vector.Storable.Mutable as VSM import Foreign.Storable (Storable) import GHC.TypeLits -import Data.Array.Mixed (XArray, IxX(..), IIxX, KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat) +import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat) import qualified Data.Array.Mixed as X import Data.INat @@ -100,9 +100,9 @@ type family MapJust l where lemKnownReplicate :: forall n. KnownINat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) lemKnownReplicate _ = X.lemKnownShapeX (go (inatSing @n)) where - go :: SINat m -> StaticShapeX (Replicate m Nothing) - go SZ = ZSX - go (SS n) = () :$? go n + go :: SINat m -> StaticShX (Replicate m Nothing) + go SZ = ZKSX + go (SS n) = () :!$? go n lemRankReplicate :: forall n. KnownINat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n lemRankReplicate _ = go (inatSing @n) @@ -119,10 +119,10 @@ lemReplicatePlusApp _ _ _ = go (inatSing @n) go SZ = Refl go (SS n) | Refl <- go n = Refl -ixAppSplit :: Proxy sh' -> StaticShapeX sh -> IIxX (sh ++ sh') -> (IIxX sh, IIxX sh') -ixAppSplit _ ZSX idx = (ZIX, idx) -ixAppSplit p (_ :$@ ssh) (i :.@ idx) = first (i :.@) (ixAppSplit p ssh idx) -ixAppSplit p (_ :$? ssh) (i :.? idx) = first (i :.?) (ixAppSplit p ssh idx) +shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh') +shAppSplit _ ZKSX idx = (ZSX, idx) +shAppSplit p (_ :!$@ ssh) (i :$@ idx) = first (i :$@) (shAppSplit p ssh idx) +shAppSplit p (_ :!$? ssh) (i :$? idx) = first (i :$?) (shAppSplit p ssh idx) -- | Wrapper type used as a tag to attach instances on. The instances on arrays @@ -184,7 +184,7 @@ newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MV data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b) -- etc. -data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IIxX sh2) !(MixedVecs s (sh1 ++ sh2) a) +data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) -- | Tree giving the shape of every array component. @@ -196,9 +196,9 @@ type family ShapeTree a where ShapeTree () = () ShapeTree (a, b) = (ShapeTree a, ShapeTree b) - ShapeTree (Mixed sh a) = (IIxX sh, ShapeTree a) - ShapeTree (Ranked n a) = (IIxR n, ShapeTree a) - ShapeTree (Shaped sh a) = (IIxS sh, ShapeTree a) + ShapeTree (Mixed sh a) = (IShX sh, ShapeTree a) + ShapeTree (Ranked n a) = (IShR n, ShapeTree a) + ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) -- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or @@ -207,7 +207,7 @@ type family ShapeTree a where class Elt a where -- ====== PUBLIC METHODS ====== -- - mshape :: KnownShapeX sh => Mixed sh a -> IIxX sh + mshape :: KnownShapeX sh => Mixed sh a -> IShX sh mindex :: Mixed sh a -> IIxX sh -> a mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a mscalar :: a -> Mixed '[] a @@ -240,15 +240,12 @@ class Elt a where -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a -- ====== PRIVATE METHODS ====== -- - -- Remember I said that this module needed better management of exports? -- | Create an empty array. The given shape must have size zero; this may or may not be checked. - memptyArray :: IIxX sh -> Mixed sh a + memptyArray :: IShX sh -> Mixed sh a mshapeTree :: a -> ShapeTree a - mshapeTreeZero :: Proxy a -> ShapeTree a - mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool @@ -257,20 +254,20 @@ class Elt a where -- | Create uninitialised vectors for this array type, given the shape of -- this vector and an example for the contents. - mvecsUnsafeNew :: IIxX sh -> a -> ST s (MixedVecs s sh a) + mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. - mvecsWrite :: IIxX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () + mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. - mvecsWritePartial :: KnownShapeX sh' => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + mvecsWritePartial :: KnownShapeX sh' => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () -- | Given the shape of this array, finalise the vectors into 'XArray's. - mvecsFreeze :: IIxX sh -> MixedVecs s sh a -> ST s (Mixed sh a) + mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) -- Arrays of scalars are basically just arrays of scalars. @@ -299,9 +296,8 @@ instance Storable a => Elt (Primitive a) where , Refl <- X.lemAppNil @sh3 = M_Primitive (f Proxy a b) - memptyArray sh = M_Primitive (X.generate sh (error $ "memptyArray Int: shape was not empty (" ++ show sh ++ ")")) + memptyArray sh = M_Primitive (X.empty sh) mshapeTree _ = () - mshapeTreeZero _ = () mshapeTreeEq _ () () = True mshapeTreeEmpty _ () = False mshowShapeTree _ () = "()" @@ -312,7 +308,7 @@ instance Storable a => Elt (Primitive a) where -- TODO: this use of toVector is suboptimal mvecsWritePartial :: forall sh' sh s. KnownShapeX sh' - => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' (X.shape arr))) VS.copy (VSM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) @@ -338,7 +334,6 @@ instance (Elt a, Elt b) => Elt (a, b) where memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) mshapeTree (x, y) = (mshapeTree x, mshapeTree y) - mshapeTreeZero _ = (mshapeTreeZero (Proxy @a), mshapeTreeZero (Proxy @b)) mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" @@ -357,10 +352,10 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where -- TODO: this is quadratic in the nesting depth because it repeatedly -- truncates the shape vector to one a little shorter. Fix with a -- moverlongShape method, a prefix of which is mshape. - mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IIxX sh + mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IShX sh mshape (M_Nest arr) | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') - = fst (ixAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr)) + = fst (shAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr)) mindex (M_Nest arr) i = mindexPartial arr i @@ -410,12 +405,10 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT)) = f (Proxy @(sh' ++ shT)) - memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIxX (knownShapeX @sh')))) + memptyArray sh = M_Nest (memptyArray (X.shAppend sh (X.completeShXzeros (knownShapeX @sh')))) mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (knownShapeX @sh')))) - mshapeTreeZero _ = (X.zeroIxX (knownShapeX @sh'), mshapeTreeZero (Proxy @a)) - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 mshapeTreeEmpty _ (sh, t) = X.shapeSize sh == 0 && mshapeTreeEmpty (Proxy @a) t @@ -424,32 +417,26 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where mvecsUnsafeNew sh example | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) - | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example)) + | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh (mshape example)) (mindex example (X.zeroIxX (knownShapeX @sh'))) where sh' = mshape example - mvecsNewEmpty _ = MV_Nest (X.zeroIxX (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a) + mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a) - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs + mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.shAppend sh sh') idx val vecs mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 - => IIxX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) + => IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) -> ST s () mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs) | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs - - mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs + = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.shAppend sh12 sh') idx arr vecs + mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.shAppend sh sh') vecs --- | Check whether a given shape corresponds on the statically-known components of the shape. -checkBounds :: IIxX sh' -> StaticShapeX sh' -> Bool -checkBounds ZIX ZSX = True -checkBounds (n :.@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh' -checkBounds (_ :.? sh') (() :$? ssh') = checkBounds sh' ssh' -- | Create an array given a size and a function that computes the element at a -- given index. @@ -468,31 +455,25 @@ checkBounds (_ :.? sh') (() :$? ssh') = checkBounds sh' ssh' -- the entire hierarchy (after distributing out tuples) must be a rectangular -- array. The type of 'mgenerate' allows this requirement to be broken very -- easily, hence the runtime check. -mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IIxX sh -> (IIxX sh -> a) -> Mixed sh a -mgenerate sh f - -- TODO: Do we need this checkBounds check elsewhere as well? - | not (checkBounds sh (knownShapeX @sh)) = - error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) - -- If the shape is empty, there is no first element, so we should not try to - -- generate it. - | X.shapeSize sh == 0 = memptyArray sh - | otherwise = - let firstidx = X.zeroIxX' sh - firstelem = f (X.zeroIxX' sh) - shapetree = mshapeTree firstelem - in if mshapeTreeEmpty (Proxy @a) shapetree - then memptyArray sh - else runST $ do - vecs <- mvecsUnsafeNew sh firstelem - mvecsWrite sh firstidx firstelem vecs - -- TODO: This is likely fine if @a@ is big, but if @a@ is a - -- scalar this array copying inefficient. Should improve this. - forM_ (tail (X.enumShape sh)) $ \idx -> do - let val = f idx - when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ - error "Data.Array.Nested mgenerate: generated values do not have equal shapes" - mvecsWrite sh idx val vecs - mvecsFreeze sh vecs +mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IShX sh -> (IIxX sh -> a) -> Mixed sh a +mgenerate sh f = case X.enumShape sh of + [] -> memptyArray sh + firstidx : restidxs -> + let firstelem = f (X.zeroIxX' sh) + shapetree = mshapeTree firstelem + in if mshapeTreeEmpty (Proxy @a) shapetree + then memptyArray sh + else runST $ do + vecs <- mvecsUnsafeNew sh firstelem + mvecsWrite sh firstidx firstelem vecs + -- TODO: This is likely fine if @a@ is big, but if @a@ is a + -- scalar this array copying inefficient. Should improve this. + forM_ restidxs $ \idx -> do + let val = f idx + when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ + error "Data.Array.Nested mgenerate: generated values do not have equal shapes" + mvecsWrite sh idx val vecs + mvecsFreeze sh vecs mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a mtranspose perm = @@ -506,12 +487,8 @@ mappend = mlift2 go => Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append -mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IIxX sh -> VS.Vector a -> Mixed sh (Primitive a) -mfromVector sh v - | not (checkBounds sh (knownShapeX @sh)) = - error $ "mfromVector: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) - | otherwise = - M_Primitive (X.fromVector sh v) +mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) +mfromVector sh v = M_Primitive (X.fromVector sh v) mfromList1 :: (KnownShapeX '[n], Elt a) => NonEmpty a -> Mixed '[n] a mfromList1 = mfromList . fmap mscalar @@ -522,17 +499,13 @@ mtoList1 = map munScalar . mtoList munScalar :: Elt a => Mixed '[] a -> a munScalar arr = mindex arr ZIX -mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IIxX sh -> a -> Mixed sh (Primitive a) -mconstantP sh x - | not (checkBounds sh (knownShapeX @sh)) = - error $ "mconstant: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) - | otherwise = - M_Primitive (X.constant sh x) +mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> a -> Mixed sh (Primitive a) +mconstantP sh x = M_Primitive (X.constant sh x) -- | This 'Coercible' constraint holds for the scalar types for which 'Mixed' -- is defined. mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a)) - => IIxX sh -> a -> Mixed sh a + => IShX sh -> a -> Mixed sh a mconstant sh x = coerce (mconstantP sh x) mslice :: (KnownShapeX sh, Elt a) => [(Int, Int)] -> Mixed sh a -> Mixed sh a @@ -648,7 +621,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where = coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ mlift2 f arr1 arr2 - memptyArray :: forall sh. IIxX sh -> Mixed sh (Ranked n a) + memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a) memptyArray i | Dict <- lemKnownReplicate (Proxy @n) = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ @@ -657,9 +630,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where mshapeTree (Ranked arr) | Refl <- lemRankReplicate (Proxy @n) , Dict <- lemKnownReplicate (Proxy @n) - = first ixCvtXR (mshapeTree arr) - - mshapeTreeZero _ = (zeroIxR (inatSing @n), mshapeTreeZero (Proxy @a)) + = first shCvtXR (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -675,7 +646,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where | Dict <- lemKnownReplicate (Proxy @n) = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) - mvecsWrite :: forall sh s. IIxX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () + mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () mvecsWrite sh idx (Ranked arr) vecs | Dict <- lemKnownReplicate (Proxy @n) = mvecsWrite sh idx arr @@ -683,7 +654,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where vecs) mvecsWritePartial :: forall sh sh' s. KnownShapeX sh' - => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) + => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) -> MixedVecs s (sh ++ sh') (Ranked n a) -> ST s () mvecsWritePartial sh idx arr vecs @@ -696,7 +667,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) vecs) - mvecsFreeze :: forall sh s. IIxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) + mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) mvecsFreeze sh vecs | Dict <- lemKnownReplicate (Proxy @n) = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @@ -712,6 +683,8 @@ data ShS sh where ZSS :: ShS '[] (:$$) :: forall n sh. SNat n -> ShS sh -> ShS (n : sh) deriving instance Show (ShS sh) +deriving instance Eq (ShS sh) +deriving instance Ord (ShS sh) infixr 3 :$$ -- | A statically-known shape of a shape-typed array. @@ -726,9 +699,9 @@ sshapeKnown (GHC_SNat :$$ sh) | Dict <- sshapeKnown sh = Dict lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh) lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh)) where - go :: ShS sh' -> StaticShapeX (MapJust sh') - go ZSS = ZSX - go (n :$$ sh) = n :$@ go sh + go :: ShS sh' -> StaticShX (MapJust sh') + go ZSS = ZKSX + go (n :$$ sh) = n :!$@ go sh lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2 -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 @@ -777,7 +750,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where = coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ mlift2 f arr1 arr2 - memptyArray :: forall sh'. IIxX sh' -> Mixed sh' (Shaped sh a) + memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) memptyArray i | Dict <- lemKnownMapJust (Proxy @sh) = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ @@ -785,9 +758,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where mshapeTree (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) - = first (ixCvtXS (knownShape @sh)) (mshapeTree arr) - - mshapeTreeZero _ = (zeroIxS (knownShape @sh), mshapeTreeZero (Proxy @a)) + = first (shCvtXS (knownShape @sh)) (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -803,7 +774,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where | Dict <- lemKnownMapJust (Proxy @sh) = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) - mvecsWrite :: forall sh' s. IIxX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () mvecsWrite sh idx (Shaped arr) vecs | Dict <- lemKnownMapJust (Proxy @sh) = mvecsWrite sh idx arr @@ -811,7 +782,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where vecs) mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 - => IIxX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + => IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) -> ST s () mvecsWritePartial sh idx arr vecs @@ -824,7 +795,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) vecs) - mvecsFreeze :: forall sh' s. IIxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) mvecsFreeze sh vecs | Dict <- lemKnownMapJust (Proxy @sh) = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @@ -927,6 +898,8 @@ newtype ShR n i = ShR (ListR n i) deriving (Show, Eq, Ord) deriving newtype (Functor, Foldable) +type IShR n = ShR n Int + pattern ZSR :: forall n i. () => n ~ Z => ShR n i pattern ZSR = ShR ZR @@ -957,20 +930,29 @@ ixCvtXR ZIX = ZIR ixCvtXR (n :.@ idx) = n :.: ixCvtXR idx ixCvtXR (n :.? idx) = n :.: ixCvtXR idx +shCvtXR :: IShX sh -> IShR (X.Rank sh) +shCvtXR ZSX = ZSR +shCvtXR (n :$@ idx) = X.fromSNat' n :$: shCvtXR idx +shCvtXR (n :$? idx) = n :$: shCvtXR idx + ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) ixCvtRX ZIR = ZIX ixCvtRX (n :.: idx) = n :.? ixCvtRX idx -shapeSizeR :: IIxR n -> Int -shapeSizeR ZIR = 1 -shapeSizeR (n :.: sh) = n * shapeSizeR sh +shCvtRX :: IShR n -> IShX (Replicate n Nothing) +shCvtRX ZSR = ZSX +shCvtRX (n :$: idx) = n :$? shCvtRX idx + +shapeSizeR :: IShR n -> Int +shapeSizeR ZSR = 1 +shapeSizeR (n :$: sh) = n * shapeSizeR sh -rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IIxR n +rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IShR n rshape (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) , Refl <- lemRankReplicate (Proxy @n) - = ixCvtXR (mshape arr) + = shCvtXR (mshape arr) rindex :: Elt a => Ranked n a -> IIxR n -> a rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) @@ -983,12 +965,12 @@ rindexPartial (Ranked arr) idx = -- | __WARNING__: All values returned from the function must have equal shape. -- See the documentation of 'mgenerate' for more details. -rgenerate :: forall n a. Elt a => IIxR n -> (IIxR n -> a) -> Ranked n a +rgenerate :: forall n a. Elt a => IShR n -> (IIxR n -> a) -> Ranked n a rgenerate sh f - | Dict <- knownIxR sh + | Dict <- knownShR sh , Dict <- lemKnownReplicate (Proxy @n) , Refl <- lemRankReplicate (Proxy @n) - = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR)) + = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) -- | See the documentation of 'mlift'. rlift :: forall n1 n2 a. (KnownINat n2, Elt a) @@ -1005,7 +987,7 @@ rsumOuter1 (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = Ranked . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a)) - . X.sumOuter (() :$? ZSX) (knownShapeX @(Replicate n Nothing)) + . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing)) . coerce @(Mixed (Replicate (S n) Nothing) (Primitive a)) @(XArray (Replicate (S n) Nothing) a) $ arr @@ -1021,10 +1003,10 @@ rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend rscalar :: Elt a => a -> Ranked I0 a rscalar x = Ranked (mscalar x) -rfromVector :: forall n a. (KnownINat n, Storable a) => IIxR n -> VS.Vector a -> Ranked n (Primitive a) +rfromVector :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a) rfromVector sh v | Dict <- lemKnownReplicate (Proxy @n) - = Ranked (mfromVector (ixCvtRX sh) v) + = Ranked (mfromVector (shCvtRX sh) v) rfromList :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a rfromList l @@ -1043,13 +1025,13 @@ rtoList1 = map runScalar . rtoList runScalar :: Elt a => Ranked I0 a -> a runScalar arr = rindex arr ZIR -rconstantP :: forall n a. (KnownINat n, Storable a) => IIxR n -> a -> Ranked n (Primitive a) +rconstantP :: forall n a. (KnownINat n, Storable a) => IShR n -> a -> Ranked n (Primitive a) rconstantP sh x | Dict <- lemKnownReplicate (Proxy @n) - = Ranked (mconstantP (ixCvtRX sh) x) + = Ranked (mconstantP (shCvtRX sh) x) rconstant :: forall n a. (KnownINat n, Storable a, Coercible (Mixed (Replicate n Nothing) (Primitive a)) (Mixed (Replicate n Nothing) a)) - => IIxR n -> a -> Ranked n a + => IShR n -> a -> Ranked n a rconstant sh x = coerce (rconstantP sh x) rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a @@ -1141,27 +1123,30 @@ zeroIxS :: ShS sh -> IIxS sh zeroIxS ZSS = ZIS zeroIxS (_ :$$ sh) = 0 :.$ zeroIxS sh --- TODO: this function should not exist perhaps -cvtShSIxS :: ShS sh -> IIxS sh -cvtShSIxS ZSS = ZIS -cvtShSIxS (n :$$ sh) = fromIntegral (fromSNat n) :.$ cvtShSIxS sh - ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh ixCvtXS ZSS ZIX = ZIS ixCvtXS (_ :$$ sh) (n :.@ idx) = n :.$ ixCvtXS sh idx +shCvtXS :: ShS sh -> IShX (MapJust sh) -> ShS sh +shCvtXS ZSS ZSX = ZSS +shCvtXS (_ :$$ sh) (n :$@ idx) = n :$$ shCvtXS sh idx + ixCvtSX :: IIxS sh -> IIxX (MapJust sh) ixCvtSX ZIS = ZIX ixCvtSX (n :.$ sh) = n :.@ ixCvtSX sh -shapeSizeS :: IIxS sh -> Int -shapeSizeS ZIS = 1 -shapeSizeS (n :.$ sh) = n * shapeSizeS sh +shCvtSX :: ShS sh -> IShX (MapJust sh) +shCvtSX ZSS = ZSX +shCvtSX (n :$$ sh) = n :$@ shCvtSX sh + +shapeSizeS :: ShS sh -> Int +shapeSizeS ZSS = 1 +shapeSizeS (n :$$ sh) = X.fromSNat' n * shapeSizeS sh -- | This does not touch the passed array, all information comes from 'KnownShape'. -sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IIxS sh -sshape _ = cvtShSIxS (knownShape @sh) +sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> ShS sh +sshape _ = knownShape @sh sindex :: Elt a => Shaped sh a -> IIxS sh -> a sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) @@ -1177,7 +1162,7 @@ sindexPartial (Shaped arr) idx = sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IIxS sh -> a) -> Shaped sh a sgenerate f | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mgenerate (ixCvtSX (cvtShSIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh))) + = Shaped (mgenerate (shCvtSX (knownShape @sh)) (f . ixCvtXS (knownShape @sh))) -- | See the documentation of 'mlift'. slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) @@ -1194,7 +1179,7 @@ ssumOuter1 (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = Shaped . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) (Primitive a)) - . X.sumOuter (natSing @n :$@ ZSX) (knownShapeX @(MapJust sh)) + . X.sumOuter (natSing @n :!$@ ZKSX) (knownShapeX @(MapJust sh)) . coerce @(Mixed (Just n : MapJust sh) (Primitive a)) @(XArray (Just n : MapJust sh) a) $ arr @@ -1213,7 +1198,7 @@ sscalar x = Shaped (mscalar x) sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a) sfromVector v | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mfromVector (ixCvtSX (cvtShSIxS (knownShape @sh))) v) + = Shaped (mfromVector (shCvtSX (knownShape @sh)) v) sfromList :: forall n sh a. (KnownNat n, KnownShape sh, Elt a) => NonEmpty (Shaped sh a) -> Shaped (n : sh) a @@ -1236,7 +1221,7 @@ sunScalar arr = sindex arr ZIS sconstantP :: forall sh a. (KnownShape sh, Storable a) => a -> Shaped sh (Primitive a) sconstantP x | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mconstantP (ixCvtSX (cvtShSIxS (knownShape @sh))) x) + = Shaped (mconstantP (shCvtSX (knownShape @sh)) x) sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sh) (Primitive a)) (Mixed (MapJust sh) a)) => a -> Shaped sh a |