aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-17 22:53:52 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-17 22:53:52 +0200
commit4adbbd8e2e635cc4c647be40f0dd258668dd2452 (patch)
tree1f89ce0adc26ed98e80e759f2bf403b107d667e1 /src/Data/Array/Nested/Internal.hs
parent06625c89089044b064bbc6cf36ea4e83199c19a4 (diff)
More WIP singletonisation
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r--src/Data/Array/Nested/Internal.hs268
1 files changed, 135 insertions, 133 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index d2883a7..e7e2fd6 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -23,11 +23,10 @@
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+{-# OPTIONS -Wno-unused-imports #-}
+
{-|
TODO:
-* We should be more consistent in whether functions take a 'StaticShX'
- argument or a 'KnownShapeX' constraint.
-
* Allow downtyping certain dimensions, and write conversions between Mixed,
Ranked and Shaped
@@ -89,7 +88,7 @@ import qualified Data.Vector.Storable.Mutable as VSM
import Foreign.Storable (Storable)
import GHC.TypeLits
-import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, StaticShX(..), type (++), pattern GHC_SNat, Dict(..), HList(..), pattern SZ, pattern SS, Replicate)
+import Data.Array.Mixed
import qualified Data.Array.Mixed as X
@@ -179,9 +178,8 @@ lemReplicatePlusApp _ _ _ = go (natSing @n)
= sym (X.lemReplicateSucc @a @(n'm1 + m))
shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh')
-shAppSplit _ ZKSX idx = (ZSX, idx)
-shAppSplit p (_ :!$@ ssh) (i :$@ idx) = first (i :$@) (shAppSplit p ssh idx)
-shAppSplit p (_ :!$? ssh) (i :$? idx) = first (i :$?) (shAppSplit p ssh idx)
+shAppSplit _ ZKX idx = (ZSX, idx)
+shAppSplit p (_ :!% ssh) (i :$% idx) = first (i :$%) (shAppSplit p ssh idx)
-- | Wrapper type used as a tag to attach instances on. The instances on arrays
@@ -197,11 +195,11 @@ class PrimElt a where
fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a
toPrimitive :: Mixed sh a -> Mixed sh (Primitive a)
- default fromPrimitive :: Coercible (Mixed' sh a) (Mixed' sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a
- fromPrimitive (Mixed sh m) = Mixed sh (coerce m)
+ default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a
+ fromPrimitive = coerce
- default toPrimitive :: Coercible (Mixed' sh (Primitive a)) (Mixed' sh a) => Mixed sh a -> Mixed sh (Primitive a)
- toPrimitive (Mixed sh m) = Mixed sh (coerce m)
+ default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a)
+ toPrimitive = coerce
-- [PRIMITIVE ELEMENT TYPES LIST]
instance PrimElt Int
@@ -218,37 +216,31 @@ instance PrimElt ()
--
-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type
-- class.
-data Mixed sh a = Mixed (IShX sh) (Mixed' sh a)
-deriving instance Show (Mixed' sh a) => Show (Mixed sh a)
-
-unMixed :: Mixed sh a -> Mixed' sh a
-unMixed (Mixed _ arr) = arr
-
-type Mixed' :: [Maybe Nat] -> Type -> Type
-data family Mixed' sh a
+type Mixed :: [Maybe Nat] -> Type -> Type
+data family Mixed sh a
-- NOTE: When opening up the Mixed abstraction, you might see dimension sizes
-- that you're not supposed to see. In particular, you might see (nonempty)
-- sizes of the elements of an empty array, which is information that should
-- ostensibly not exist; the full array is still empty.
-newtype instance Mixed' sh (Primitive a) = M_Primitive (XArray sh a)
+data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a)
deriving (Show)
-- [PRIMITIVE ELEMENT TYPES LIST]
-newtype instance Mixed' sh Int = M_Int (XArray sh Int)
+newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int))
deriving (Show)
-newtype instance Mixed' sh Double = M_Double (XArray sh Double)
+newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double))
deriving (Show)
-newtype instance Mixed' sh () = M_Nil (XArray sh ()) -- no content, orthotope optimises this (via Vector)
+newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) -- no content, orthotope optimises this (via Vector)
deriving (Show)
-- etc.
-data instance Mixed' sh (a, b) = M_Tup2 !(Mixed' sh a) !(Mixed' sh b)
-deriving instance (Show (Mixed' sh a), Show (Mixed' sh b)) => Show (Mixed' sh (a, b))
+data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b)
+deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b))
-- etc.
-newtype instance Mixed' sh1 (Mixed sh2 a) = M_Nest (Mixed' (sh1 ++ sh2) a)
-deriving instance Show (Mixed' (sh1 ++ sh2) a) => Show (Mixed' sh1 (Mixed sh2 a))
+data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(StaticShX sh1) !(Mixed (sh1 ++ sh2) a)
+deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a))
-- | Internal helper data family mirroring 'Mixed' that consists of mutable
@@ -279,7 +271,7 @@ type family ShapeTree a where
ShapeTree () = ()
ShapeTree (a, b) = (ShapeTree a, ShapeTree b)
- ShapeTree (Mixed' sh a) = (IShX sh, ShapeTree a)
+ ShapeTree (Mixed sh a) = (IShX sh, ShapeTree a)
ShapeTree (Ranked n a) = (IShR n, ShapeTree a)
ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
@@ -357,37 +349,37 @@ class Elt a where
-- Arrays of scalars are basically just arrays of scalars.
instance Storable a => Elt (Primitive a) where
- mshape (Mixed sh _) = sh
- mindex (Mixed _ (M_Primitive a)) i = Primitive (X.index a i)
- mindexPartial (Mixed sh (M_Primitive a)) i = Mixed (X.shDropIx sh i) (M_Primitive (X.indexPartial a i))
- mscalar (Primitive x) = Mixed ZSX (M_Primitive (X.scalar x))
+ mshape (M_Primitive sh _) = sh
+ mindex (M_Primitive _ a) i = Primitive (X.index a i)
+ mindexPartial (M_Primitive sh a) i = M_Primitive (X.shDropIx sh i) (X.indexPartial a i)
+ mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
mfromList1 sn l@(arr1 :| _) =
- let sh = sn :$@ mshape arr1
- in Mixed sh (M_Primitive (X.fromList1 (X.staticShapeFrom sh) (map (coerce . unMixed) (toList l))))
- mtoList1 (Mixed sh (M_Primitive arr)) = map (Mixed (X.shTail sh) . coerce) (X.toList1 arr)
+ let sh = SKnown sn :$% mshape arr1
+ in M_Primitive sh (X.fromList1 (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l)))
+ mtoList1 (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toList1 arr)
mlift :: forall sh1 sh2.
StaticShX sh2
-> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
-> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)
- mlift ssh2 f (Mixed _ (M_Primitive a))
+ mlift ssh2 f (M_Primitive _ a)
| Refl <- X.lemAppNil @sh1
, Refl <- X.lemAppNil @sh2
- , let result = f ZKSX a
- = Mixed (X.shape ssh2 result) (M_Primitive result)
+ , let result = f ZKX a
+ = M_Primitive (X.shape ssh2 result) result
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)
-> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a)
- mlift2 ssh3 f (Mixed _ (M_Primitive a)) (Mixed _ (M_Primitive b))
+ mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b)
| Refl <- X.lemAppNil @sh1
, Refl <- X.lemAppNil @sh2
, Refl <- X.lemAppNil @sh3
- , let result = f ZKSX a b
- = Mixed (X.shape ssh3 result) (M_Primitive result)
+ , let result = f ZKX a b
+ = M_Primitive (X.shape ssh3 result) result
- memptyArray sh = Mixed sh (M_Primitive (X.empty sh))
+ memptyArray sh = M_Primitive sh (X.empty sh)
mshapeTree _ = ()
mshapeTreeEq _ () () = True
mshapeTreeEmpty _ () = False
@@ -400,18 +392,14 @@ instance Storable a => Elt (Primitive a) where
mvecsWritePartial
:: forall sh' sh s.
IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
- mvecsWritePartial sh i (Mixed sh' (M_Primitive arr)) (MV_Primitive v) = do
+ mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
let arrsh = X.shape (X.staticShapeFrom sh') arr
offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' arrsh))
VS.copy (VSM.slice offset (X.shapeSize arrsh) v) (X.toVector arr)
- mvecsFreeze sh (MV_Primitive v) = Mixed sh . M_Primitive . X.fromVector sh <$> VS.freeze v
+ mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v
-- [PRIMITIVE ELEMENT TYPES LIST]
-
-
-
-TODO -- should rewrite methods of Elt class to take ' in their name, and work on Mixed' instead of Mixed (and take explicit StaticShapeX). Then wrap all of the public functions to work on Mixed. Then don't export the contents of Elt from Nested.hs, and export the wrappers instead. This also makes the haddocks more consistent.
deriving via Primitive Int instance Elt Int
deriving via Primitive Double instance Elt Double
deriving via Primitive () instance Elt ()
@@ -422,11 +410,12 @@ instance (Elt a, Elt b) => Elt (a, b) where
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
- mfromList1 l = M_Tup2 (mfromList1 ((\(M_Tup2 x _) -> x) <$> l))
- (mfromList1 ((\(M_Tup2 _ y) -> y) <$> l))
+ mfromList1 n l =
+ M_Tup2 (mfromList1 n ((\(M_Tup2 x _) -> x) <$> l))
+ (mfromList1 n ((\(M_Tup2 _ y) -> y) <$> l))
mtoList1 (M_Tup2 a b) = zipWith M_Tup2 (mtoList1 a) (mtoList1 b)
- mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b)
- mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y)
+ mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
+ mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
@@ -443,66 +432,74 @@ instance (Elt a, Elt b) => Elt (a, b) where
mvecsWritePartial sh i y b
mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
+-- | Evidence for the static part of a shape. This pops up only when you are
+-- polymorphic in the element type of an array.
+type KnownShX :: [Maybe Nat] -> Constraint
+class KnownShX sh where knownShX :: StaticShX sh
+instance KnownShX '[] where knownShX = ZKX
+instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX
+instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX
+
-- Arrays of arrays are just arrays, but with more dimensions.
-instance Elt a => Elt (Mixed sh' a) where
+instance (Elt a, KnownShX sh') => Elt (Mixed sh' a) where
-- TODO: this is quadratic in the nesting depth because it repeatedly
-- truncates the shape vector to one a little shorter. Fix with a
-- moverlongShape method, a prefix of which is mshape.
mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
- mshape (M_Nest arr)
- | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
- = fst (shAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr))
+ mshape (M_Nest ssh arr)
+ = fst (shAppSplit (Proxy @sh') ssh (mshape arr))
- mindex (M_Nest arr) i = mindexPartial arr i
+ mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
+ mindex (M_Nest _ arr) i = mindexPartial arr i
mindexPartial :: forall sh1 sh2.
Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
- mindexPartial (M_Nest arr) i
+ mindexPartial (M_Nest ssh arr) i
| Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
+ = M_Nest (X.ssxDropIx ssh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
- mscalar = M_Nest
+ mscalar = M_Nest ZKX
mfromList1 :: forall n sh. SNat n -> NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Just n : sh) (Mixed sh' a)
- mfromList1 l
- | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @(n : sh)) (knownShapeX @sh'))
- = M_Nest (mfromList1 (coerce l))
+ mfromList1 sn l@(arr :| _) =
+ M_Nest (SKnown sn :!% X.staticShapeFrom (mshape arr))
+ (mfromList1 sn ((\(M_Nest _ a) -> a) <$> l))
- mtoList1 (M_Nest arr) = coerce (mtoList1 arr)
+ mtoList1 (M_Nest ssh arr) = map (M_Nest (X.ssxTail ssh)) (mtoList1 arr)
mlift :: forall sh1 sh2.
- (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
+ StaticShX sh2
+ -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
- mlift f (M_Nest arr)
- | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
- = M_Nest (mlift f' arr)
+ mlift ssh2 f (M_Nest ssh1 arr) = M_Nest ssh2 (mlift (X.ssxAppend ssh2 ssh') f' arr)
where
+ ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') ssh1 (mshape arr)))
+
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
- f' _
+ f' sshT
| Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
, Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
- , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT))
- = f (Proxy @(sh' ++ shT))
+ = f (X.ssxAppend ssh' sshT)
mlift2 :: forall sh1 sh2 sh3.
- (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
+ StaticShX sh3
+ -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)
- mlift2 f (M_Nest arr1) (M_Nest arr2)
- | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
- , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh3) (knownShapeX @sh'))
- = M_Nest (mlift2 f' arr1 arr2)
+ mlift2 ssh3 f (M_Nest ssh1 arr1) (M_Nest _ arr2) = M_Nest ssh3 (mlift2 (X.ssxAppend ssh3 ssh') f' arr1 arr2)
where
+ ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') ssh1 (mshape arr1)))
+
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
- f' _
+ f' sshT
| Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
, Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
, Refl <- X.lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
- , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT))
- = f (Proxy @(sh' ++ shT))
+ = f (X.ssxAppend ssh' sshT)
- memptyArray sh = M_Nest (memptyArray (X.shAppend sh (X.completeShXzeros (knownShapeX @sh'))))
+ memptyArray sh = M_Nest (X.staticShapeFrom sh) (memptyArray (X.shAppend sh (X.completeShXzeros (knownShX @sh'))))
- mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (knownShapeX @sh'))))
+ mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
+ mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (X.staticShapeFrom (mshape arr)))))
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
@@ -512,12 +509,11 @@ instance Elt a => Elt (Mixed sh' a) where
mvecsUnsafeNew sh example
| X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
- | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh (mshape example))
- (mindex example (X.zeroIxX (knownShapeX @sh')))
+ | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh sh') (mindex example (X.zeroIxX (X.staticShapeFrom sh')))
where
sh' = mshape example
- mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a)
+ mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.shAppend sh sh') idx val vecs
@@ -525,12 +521,11 @@ instance Elt a => Elt (Mixed sh' a) where
IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
-> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
-> ST s ()
- mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)
- | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
- , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.shAppend sh12 sh') idx arr vecs
+ mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
+ | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = mvecsWritePartial (X.shAppend sh12 sh') idx arr vecs
- mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.shAppend sh sh') vecs
+ mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest (X.staticShapeFrom sh) <$> mvecsFreeze (X.shAppend sh sh') vecs
-- | Create an array given a size and a function that computes the element at a
@@ -572,27 +567,36 @@ mgenerate sh f = case X.enumShape sh of
mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a)
=> HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a
-mtranspose perm
- | Dict <- X.lemKnownShapeX (X.ssxAppend (X.ssxPermute perm (X.ssxTakeLen perm (knownShapeX @sh))) (X.ssxDropLen perm (knownShapeX @sh)))
- = mlift $ \(Proxy @sh') ->
- X.rerankTop (knownShapeX @sh) (knownShapeX @(X.PermutePrefix is sh)) (knownShapeX @sh')
- (X.transpose perm)
+mtranspose perm arr =
+ let ssh = X.staticShapeFrom (mshape arr)
+ sshPP = X.ssxAppend (X.ssxPermute perm (X.ssxTakeLen perm ssh)) (X.ssxDropLen perm ssh)
+ in mlift sshPP (\ssh' -> X.rerankTop ssh sshPP ssh' (X.transpose ssh perm)) arr
mappend :: forall n m sh a. Elt a
=> Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a
-mappend = mlift2 go
- where go :: forall sh' b. Storable b
- => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b
- go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append
+mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
+ where
+ sn :$% sh = mshape arr1
+ sm :$% _ = mshape arr2
+ ssh = X.staticShapeFrom sh
+ snm :: SMayNat () SNat (X.AddMaybe n m)
+ snm = case (sn, sm) of
+ (SUnknown{}, _) -> SUnknown ()
+ (SKnown{}, SUnknown{}) -> SUnknown ()
+ (SKnown n, SKnown m) -> SKnown (X.plusSNat n m)
+
+ f :: forall sh' b. Storable b
+ => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b
+ f ssh' = X.append (X.ssxAppend ssh ssh')
mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
-mfromVectorP sh v = M_Primitive (X.fromVector sh v)
+mfromVectorP sh v = M_Primitive sh (X.fromVector sh v)
mfromVector :: forall sh a. (Storable a, PrimElt a) => IShX sh -> VS.Vector a -> Mixed sh a
mfromVector sh v = fromPrimitive (mfromVectorP sh v)
mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a
-mtoVectorP (M_Primitive v) = X.toVector v
+mtoVectorP (M_Primitive _ v) = X.toVector v
mtoVector :: (Storable a, PrimElt a) => Mixed sh a -> VS.Vector a
mtoVector arr = mtoVectorP (coerce toPrimitive arr)
@@ -607,64 +611,60 @@ munScalar :: Elt a => Mixed '[] a -> a
munScalar arr = mindex arr ZIX
mconstantP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
-mconstantP sh x = M_Primitive (X.constant sh x)
+mconstantP sh x = M_Primitive sh (X.constant sh x)
mconstant :: forall sh a. (Storable a, PrimElt a)
=> IShX sh -> a -> Mixed sh a
mconstant sh x = fromPrimitive (mconstantP sh x)
mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
-mslice i n = withKnownNat n $ mlift $ \_ -> X.slice i n
+mslice i n arr =
+ let _ :$% sh = mshape arr
+ in withKnownNat n $ mlift (SKnown n :!% X.staticShapeFrom sh) (\_ -> X.slice i n) arr
msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
-msliceU i n = mlift $ \_ -> X.sliceU i n
+msliceU i n arr = mlift (X.staticShapeFrom (mshape arr)) (\_ -> X.sliceU i n) arr
mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a
-mrev1 = mlift $ \_ -> X.rev1
+mrev1 arr = mlift (X.staticShapeFrom (mshape arr)) (\_ -> X.rev1) arr
mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a
-mreshape sh' = mlift $ \(_ :: Proxy shIn) ->
- X.reshapePartial (knownShapeX @sh) (knownShapeX @shIn) sh'
+mreshape sh' arr =
+ mlift (X.staticShapeFrom sh')
+ (\sshIn -> X.reshapePartial (X.staticShapeFrom (mshape arr)) sshIn sh')
+ arr
-masXArrayPrimP :: Mixed sh (Primitive a) -> XArray sh a
-masXArrayPrimP (M_Primitive arr) = arr
+masXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
+masXArrayPrimP (M_Primitive sh arr) = (sh, arr)
-masXArrayPrim :: PrimElt a => Mixed sh a -> XArray sh a
+masXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a)
masXArrayPrim = masXArrayPrimP . toPrimitive
-mfromXArrayPrimP :: XArray sh a -> Mixed sh (Primitive a)
+mfromXArrayPrimP :: IShX sh -> XArray sh a -> Mixed sh (Primitive a)
mfromXArrayPrimP = M_Primitive
-mfromXArrayPrim :: PrimElt a => XArray sh a -> Mixed sh a
-mfromXArrayPrim = fromPrimitive . mfromXArrayPrimP
+mfromXArrayPrim :: PrimElt a => IShX sh -> XArray sh a -> Mixed sh a
+mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP
-mliftPrim :: Storable a
+mliftPrim :: (Storable a, PrimElt a)
=> (a -> a)
- -> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
-mliftPrim f (M_Primitive (X.XArray arr)) = M_Primitive (X.XArray (S.mapA f arr))
+ -> Mixed sh a -> Mixed sh a
+mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr))
-mliftPrim2 :: Storable a
+mliftPrim2 :: (Storable a, PrimElt a)
=> (a -> a -> a)
- -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
-mliftPrim2 f (M_Primitive (X.XArray arr1)) (M_Primitive (X.XArray arr2)) =
- M_Primitive (X.XArray (S.zipWithA f arr1 arr2))
+ -> Mixed sh a -> Mixed sh a -> Mixed sh a
+mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =
+ fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2))
-instance (Storable a, Num a) => Num (Mixed sh (Primitive a)) where
+instance (Storable a, Num a, PrimElt a) => Num (Mixed sh a) where
(+) = mliftPrim2 (+)
(-) = mliftPrim2 (-)
(*) = mliftPrim2 (*)
negate = mliftPrim negate
abs = mliftPrim abs
signum = mliftPrim signum
- fromInteger n =
- case X.ssxToShape' (knownShapeX @sh) of
- Just sh -> M_Primitive (X.constant sh (fromInteger n))
- Nothing -> error "Data.Array.Nested.fromIntegral: \
- \Unknown components in shape, use explicit mconstant"
-
--- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types)
-deriving via Mixed sh (Primitive Int) instance Num (Mixed sh Int)
-deriving via Mixed sh (Primitive Double) instance Num (Mixed sh Double)
+ fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mconstant"
-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
@@ -694,10 +694,10 @@ newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a)
-- just unwrap the newtype and defer to the general instance for nested arrays
-newtype instance Mixed' sh (Ranked n a) = M_Ranked (Mixed' sh (Mixed (Replicate n Nothing) a))
-deriving instance Show (Mixed' sh (Mixed (Replicate n Nothing) a)) => Show (Mixed' sh (Ranked n a))
-newtype instance Mixed' sh (Shaped sh' a) = M_Shaped (Mixed' sh (Mixed (MapJust sh' ) a))
-deriving instance Show (Mixed' sh (Mixed (MapJust sh' ) a)) => Show (Mixed' sh (Shaped sh' a))
+newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
+deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a))
+newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh' ) a))
+deriving instance Show (Mixed sh (Mixed (MapJust sh' ) a)) => Show (Mixed sh (Shaped sh' a))
newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a))
@@ -804,6 +804,8 @@ instance (Elt a, KnownNat n) => Elt (Ranked n a) where
-}
-- | The shape of a shape-typed array given as a list of 'SNat' values.
+TODO -- write ListS and implement IxS and ShS in terms of it.
+TODO -- for ListR and ListS, write an uncons function like for ListX and implement the cons pattern synonyms in terms of it directly, instead of using a separate uncons function for both types.
data ShS sh where
ZSS :: ShS '[]
(:$$) :: forall n sh. SNat n -> ShS sh -> ShS (n : sh)