aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-30 22:47:52 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-30 22:47:52 +0200
commit8b59d8ef4ff97936f2a753d1ce345e0404c26b2b (patch)
tree947f75cb43982fbdb551dc329f036b0591f3c2b2 /src/Data/Array/Nested/Internal
parentf0752d67cd188f438280e1f0c692dc1f5f14a190 (diff)
Clearer module purposes
Thanks Mikolaj for discussion
Diffstat (limited to 'src/Data/Array/Nested/Internal')
-rw-r--r--src/Data/Array/Nested/Internal/Convert.hs28
-rw-r--r--src/Data/Array/Nested/Internal/Lemmas.hs59
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs741
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs446
-rw-r--r--src/Data/Array/Nested/Internal/Shape.hs467
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs379
6 files changed, 2120 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal/Convert.hs b/src/Data/Array/Nested/Internal/Convert.hs
new file mode 100644
index 0000000..e101981
--- /dev/null
+++ b/src/Data/Array/Nested/Internal/Convert.hs
@@ -0,0 +1,28 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE TypeOperators #-}
+module Data.Array.Nested.Internal.Convert where
+
+import Data.Type.Equality
+
+import Data.Array.Mixed.Lemmas
+import Data.Array.Mixed.Shape
+import Data.Array.Nested.Internal.Lemmas
+import Data.Array.Nested.Internal.Mixed
+import Data.Array.Nested.Internal.Ranked
+import Data.Array.Nested.Internal.Shape
+import Data.Array.Nested.Internal.Shaped
+
+
+stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
+stoRanked sarr@(Shaped arr)
+ | Refl <- lemRankMapJust (sshape sarr)
+ = mtoRanked arr
+
+rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
+rcastToShaped (Ranked arr) targetsh
+ | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh))
+ , Refl <- lemRankMapJust targetsh
+ = mcastToShaped arr targetsh
diff --git a/src/Data/Array/Nested/Internal/Lemmas.hs b/src/Data/Array/Nested/Internal/Lemmas.hs
new file mode 100644
index 0000000..5ce5373
--- /dev/null
+++ b/src/Data/Array/Nested/Internal/Lemmas.hs
@@ -0,0 +1,59 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+module Data.Array.Nested.Internal.Lemmas where
+
+import Data.Proxy
+import Data.Type.Equality
+import GHC.TypeLits
+
+import Data.Array.Mixed.Lemmas
+import Data.Array.Mixed.Permutation
+import Data.Array.Mixed.Shape
+import Data.Array.Mixed.Types
+import Data.Array.Nested.Internal.Shape
+
+
+lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh
+lemRankMapJust ZSS = Refl
+lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl
+
+lemMapJustApp :: ShS sh1 -> Proxy sh2
+ -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
+lemMapJustApp ZSS _ = Refl
+lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl
+
+lemMapJustTakeLen :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)
+lemMapJustTakeLen PNil _ = Refl
+lemMapJustTakeLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemMapJustTakeLen is sh = Refl
+lemMapJustTakeLen (_ `PCons` _) ZSS = error "TakeLen of empty"
+
+lemMapJustDropLen :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh)
+lemMapJustDropLen PNil _ = Refl
+lemMapJustDropLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemMapJustDropLen is sh = Refl
+lemMapJustDropLen (_ `PCons` _) ZSS = error "DropLen of empty"
+
+lemMapJustIndex :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh)
+lemMapJustIndex SZ (_ :$$ _) = Refl
+lemMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh'))
+ | Refl <- lemMapJustIndex i sh
+ , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh'))
+ , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = Refl
+lemMapJustIndex _ ZSS = error "Index of empty"
+
+lemMapJustPermute :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh)
+lemMapJustPermute PNil _ = Refl
+lemMapJustPermute (i `PCons` is) sh
+ | Refl <- lemMapJustPermute is sh
+ , Refl <- lemMapJustIndex i sh
+ = Refl
+
+lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh)
+lemKnownMapJust _ = lemKnownShX (go (knownShS @sh))
+ where
+ go :: ShS sh' -> StaticShX (MapJust sh')
+ go ZSS = ZKX
+ go (n :$$ sh) = SKnown n :!% go sh
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
new file mode 100644
index 0000000..98871d5
--- /dev/null
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -0,0 +1,741 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DefaultSignatures #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE DerivingVia #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+module Data.Array.Nested.Internal.Mixed where
+
+import Control.DeepSeq (NFData)
+import Control.Monad (forM_, when)
+import Control.Monad.ST
+import Data.Array.RankedS qualified as S
+import Data.Coerce
+import Data.Foldable (toList)
+import Data.Int
+import Data.Kind (Type)
+import Data.List.NonEmpty (NonEmpty(..))
+import Data.Proxy
+import Data.Type.Equality
+import Data.Vector.Storable qualified as VS
+import Data.Vector.Storable.Mutable qualified as VSM
+import Foreign.C.Types (CInt)
+import Foreign.Storable (Storable)
+import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp)
+import GHC.Generics (Generic)
+import GHC.TypeLits
+
+import Data.Array.Mixed.XArray (XArray(..))
+import Data.Array.Mixed.XArray qualified as X
+import Data.Array.Mixed.Internal.Arith
+import Data.Array.Mixed.Shape
+import Data.Array.Mixed.Types
+import Data.Array.Mixed.Permutation
+import Data.Array.Mixed.Lemmas
+
+
+-- Invariant in the API
+-- ====================
+--
+-- In the underlying XArray, there is some shape for elements of an empty
+-- array. For example, for this array:
+--
+-- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float)
+-- rshape arr == 0 :.: 0 :.: 0 :.: ZIR
+--
+-- the two underlying XArrays have a shape, and those shapes might be anything.
+-- The invariant is that these element shapes are unobservable in the API.
+-- (This is possible because you ought to not be able to get to such an element
+-- without indexing out of bounds.)
+--
+-- Note, though, that the converse situation may arise: the outer array might
+-- be nonempty but then the inner arrays might. This is fine, an invariant only
+-- applies if the _outer_ array is empty.
+--
+-- TODO: can we enforce that the elements of an empty (nested) array have
+-- all-zero shape?
+-- -> no, because mlift and also any kind of internals probing from outsiders
+
+
+-- Primitive element types
+-- =======================
+--
+-- There are a few primitive element types; arrays containing elements of such
+-- type are a newtype over an XArray, which it itself a newtype over a Vector.
+-- Unfortunately, the setup of the library requires us to list these primitive
+-- element types multiple times; to aid in extending the list, all these lists
+-- have been marked with [PRIMITIVE ELEMENT TYPES LIST].
+
+
+-- | Wrapper type used as a tag to attach instances on. The instances on arrays
+-- of @'Primitive' a@ are more polymorphic than the direct instances for arrays
+-- of scalars; this means that if @orthotope@ supports an element type @T@ that
+-- this library does not (directly), it may just work if you use an array of
+-- @'Primitive' T@ instead.
+newtype Primitive a = Primitive a
+
+-- | Element types that are primitive; arrays of these types are just a newtype
+-- wrapper over an array.
+class Storable a => PrimElt a where
+ fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a
+ toPrimitive :: Mixed sh a -> Mixed sh (Primitive a)
+
+ default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a
+ fromPrimitive = coerce
+
+ default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a)
+ toPrimitive = coerce
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+instance PrimElt Int
+instance PrimElt Int64
+instance PrimElt Int32
+instance PrimElt CInt
+instance PrimElt Float
+instance PrimElt Double
+instance PrimElt ()
+
+
+-- | Mixed arrays: some dimensions are size-typed, some are not. Distributes
+-- over product-typed elements using a data family so that the full array is
+-- always in struct-of-arrays format.
+--
+-- Built on top of 'XArray' which is built on top of @orthotope@, meaning that
+-- dimension permutations (e.g. 'mtranspose') are typically free.
+--
+-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type
+-- class.
+type Mixed :: [Maybe Nat] -> Type -> Type
+data family Mixed sh a
+-- NOTE: When opening up the Mixed abstraction, you might see dimension sizes
+-- that you're not supposed to see. In particular, you might see (nonempty)
+-- sizes of the elements of an empty array, which is information that should
+-- ostensibly not exist; the full array is still empty.
+
+data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a)
+ deriving (Show, Eq, Generic)
+
+-- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays.
+deriving instance (Ord a, Storable a) => Ord (Mixed '[] (Primitive a))
+
+instance NFData a => NFData (Mixed sh (Primitive a))
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Show, Eq, Generic)
+newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Show, Eq, Generic)
+newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Show, Eq, Generic)
+newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Show, Eq, Generic)
+newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Show, Eq, Generic)
+newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Show, Eq, Generic)
+newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Show, Eq, Generic) -- no content, orthotope optimises this (via Vector)
+-- etc.
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+deriving instance Ord (Mixed '[] Int) ; instance NFData (Mixed sh Int)
+deriving instance Ord (Mixed '[] Int64) ; instance NFData (Mixed sh Int64)
+deriving instance Ord (Mixed '[] Int32) ; instance NFData (Mixed sh Int32)
+deriving instance Ord (Mixed '[] CInt) ; instance NFData (Mixed sh CInt)
+deriving instance Ord (Mixed '[] Float) ; instance NFData (Mixed sh Float)
+deriving instance Ord (Mixed '[] Double) ; instance NFData (Mixed sh Double)
+deriving instance Ord (Mixed '[] ()) ; instance NFData (Mixed sh ())
+
+data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic)
+deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b))
+instance (NFData (Mixed sh a), NFData (Mixed sh b)) => NFData (Mixed sh (a, b))
+-- etc., larger tuples (perhaps use generics to allow arbitrary product types)
+
+data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic)
+deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a))
+instance NFData (Mixed (sh1 ++ sh2) a) => NFData (Mixed sh1 (Mixed sh2 a))
+
+
+-- | Internal helper data family mirroring 'Mixed' that consists of mutable
+-- vectors instead of 'XArray's.
+type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
+data family MixedVecs s sh a
+
+newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a)
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int)
+newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64)
+newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32)
+newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt)
+newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double)
+newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float)
+newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this
+-- etc.
+
+data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b)
+-- etc.
+
+data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a)
+
+
+mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a
+mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr))
+
+mliftNumElt2 :: PrimElt a
+ => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a)
+ -> Mixed sh a -> Mixed sh a -> Mixed sh a
+mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2))
+ | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2))
+ | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2
+
+instance (NumElt a, PrimElt a) => Num (Mixed sh a) where
+ (+) = mliftNumElt2 numEltAdd
+ (-) = mliftNumElt2 numEltSub
+ (*) = mliftNumElt2 numEltMul
+ negate = mliftNumElt1 numEltNeg
+ abs = mliftNumElt1 numEltAbs
+ signum = mliftNumElt1 numEltSignum
+ fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate"
+
+instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Mixed sh a) where
+ fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"
+ recip = mliftNumElt1 floatEltRecip
+ (/) = mliftNumElt2 floatEltDiv
+
+instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"
+ exp = mliftNumElt1 floatEltExp
+ log = mliftNumElt1 floatEltLog
+ sqrt = mliftNumElt1 floatEltSqrt
+
+ (**) = mliftNumElt2 floatEltPow
+ logBase = mliftNumElt2 floatEltLogbase
+
+ sin = mliftNumElt1 floatEltSin
+ cos = mliftNumElt1 floatEltCos
+ tan = mliftNumElt1 floatEltTan
+ asin = mliftNumElt1 floatEltAsin
+ acos = mliftNumElt1 floatEltAcos
+ atan = mliftNumElt1 floatEltAtan
+ sinh = mliftNumElt1 floatEltSinh
+ cosh = mliftNumElt1 floatEltCosh
+ tanh = mliftNumElt1 floatEltTanh
+ asinh = mliftNumElt1 floatEltAsinh
+ acosh = mliftNumElt1 floatEltAcosh
+ atanh = mliftNumElt1 floatEltAtanh
+ log1p = mliftNumElt1 floatEltLog1p
+ expm1 = mliftNumElt1 floatEltExpm1
+ log1pexp = mliftNumElt1 floatEltLog1pexp
+ log1mexp = mliftNumElt1 floatEltLog1mexp
+
+
+-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
+-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
+-- a@; see the documentation for 'Primitive' for more details.
+class Elt a where
+ -- ====== PUBLIC METHODS ====== --
+
+ mshape :: Mixed sh a -> IShX sh
+ mindex :: Mixed sh a -> IIxX sh -> a
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a
+ mscalar :: a -> Mixed '[] a
+
+ -- | All arrays in the list, even subarrays inside @a@, must have the same
+ -- shape; if they do not, a runtime error will be thrown. See the
+ -- documentation of 'mgenerate' for more information about this restriction.
+ -- Furthermore, the length of the list must correspond with @n@: if @n@ is
+ -- @Just m@ and @m@ does not equal the length of the list, a runtime error is
+ -- thrown.
+ --
+ -- Consider also 'mfromListPrim', which can avoid intermediate arrays.
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+
+ mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]
+
+ -- | Note: this library makes no particular guarantees about the shapes of
+ -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the
+ -- full 'XArray' and as such you can distinguish different empty arrays by
+ -- the "shapes" of their elements. This information is meaningless, so you
+ -- should not use it.
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 a -> Mixed sh2 a
+
+ -- | See the documentation for 'mlift'.
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a
+
+ mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a
+
+ mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
+ => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a
+
+ -- ====== PRIVATE METHODS ====== --
+
+ -- | Tree giving the shape of every array component.
+ type ShapeTree a
+
+ mshapeTree :: a -> ShapeTree a
+
+ mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool
+
+ mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool
+
+ mshowShapeTree :: Proxy a -> ShapeTree a -> String
+
+ -- | Given the shape of this array, an index and a value, write the value at
+ -- that index in the vectors.
+ mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+
+ -- | Given the shape of this array, an index and a value, write the value at
+ -- that index in the vectors.
+ mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+
+ -- | Given the shape of this array, finalise the vectors into 'XArray's.
+ mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
+
+
+-- | Element types for which we have evidence of the (static part of the) shape
+-- in a type class constraint. Compare the instance contexts of the instances
+-- of this class with those of 'Elt': some instances have an additional
+-- "known-shape" constraint.
+--
+-- This class is (currently) only required for 'mgenerate' / 'rgenerate' /
+-- 'sgenerate'.
+class Elt a => KnownElt a where
+ -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
+ memptyArray :: IShX sh -> Mixed sh a
+
+ -- | Create uninitialised vectors for this array type, given the shape of
+ -- this vector and an example for the contents.
+ mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a)
+
+ mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
+
+
+-- Arrays of scalars are basically just arrays of scalars.
+instance Storable a => Elt (Primitive a) where
+ mshape (M_Primitive sh _) = sh
+ mindex (M_Primitive _ a) i = Primitive (X.index a i)
+ mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i)
+ mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
+ mfromListOuter l@(arr1 :| _) =
+ let sh = SUnknown (length l) :$% mshape arr1
+ in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l)))
+ mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
+
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
+ -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)
+ mlift ssh2 f (M_Primitive _ a)
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ , let result = f ZKX a
+ = M_Primitive (X.shape ssh2 result) result
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)
+ -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a)
+ mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b)
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ , Refl <- lemAppNil @sh3
+ , let result = f ZKX a b
+ = M_Primitive (X.shape ssh3 result) result
+
+ mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
+ mcast ssh1 sh2 _ (M_Primitive sh1' arr) =
+ let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
+ in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr)
+
+ mtranspose perm (M_Primitive sh arr) =
+ M_Primitive (shxPermutePrefix perm sh)
+ (X.transpose (ssxFromShape sh) perm arr)
+
+ type ShapeTree (Primitive a) = ()
+ mshapeTree _ = ()
+ mshapeTreeEq _ () () = True
+ mshapeTreeEmpty _ () = False
+ mshowShapeTree _ () = "()"
+ mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
+
+ -- TODO: this use of toVector is suboptimal
+ mvecsWritePartial
+ :: forall sh' sh s.
+ IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
+ mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
+ let arrsh = X.shape (ssxFromShape sh') arr
+ offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
+ VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
+
+ mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+deriving via Primitive Int instance Elt Int
+deriving via Primitive Int64 instance Elt Int64
+deriving via Primitive Int32 instance Elt Int32
+deriving via Primitive CInt instance Elt CInt
+deriving via Primitive Double instance Elt Double
+deriving via Primitive Float instance Elt Float
+deriving via Primitive () instance Elt ()
+
+instance Storable a => KnownElt (Primitive a) where
+ memptyArray sh = M_Primitive sh (X.empty sh)
+ mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh)
+ mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+deriving via Primitive Int instance KnownElt Int
+deriving via Primitive Int64 instance KnownElt Int64
+deriving via Primitive Int32 instance KnownElt Int32
+deriving via Primitive CInt instance KnownElt CInt
+deriving via Primitive Double instance KnownElt Double
+deriving via Primitive Float instance KnownElt Float
+deriving via Primitive () instance KnownElt ()
+
+-- Arrays of pairs are pairs of arrays.
+instance (Elt a, Elt b) => Elt (a, b) where
+ mshape (M_Tup2 a _) = mshape a
+ mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
+ mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
+ mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
+ mfromListOuter l =
+ M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l))
+ (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l))
+ mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)
+ mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
+ mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
+
+ mcast ssh1 sh2 psh' (M_Tup2 a b) =
+ M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b)
+
+ mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b)
+
+ type ShapeTree (a, b) = (ShapeTree a, ShapeTree b)
+ mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
+ mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
+ mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
+ mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
+ mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
+ mvecsWrite sh i x a
+ mvecsWrite sh i y b
+ mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
+ mvecsWritePartial sh i x a
+ mvecsWritePartial sh i y b
+ mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
+
+instance (KnownElt a, KnownElt b) => KnownElt (a, b) where
+ memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
+ mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
+ mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b)
+
+-- Arrays of arrays are just arrays, but with more dimensions.
+instance Elt a => Elt (Mixed sh' a) where
+ -- TODO: this is quadratic in the nesting depth because it repeatedly
+ -- truncates the shape vector to one a little shorter. Fix with a
+ -- moverlongShape method, a prefix of which is mshape.
+ mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
+ mshape (M_Nest sh arr)
+ = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr))
+
+ mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
+ mindex (M_Nest _ arr) i = mindexPartial arr i
+
+ mindexPartial :: forall sh1 sh2.
+ Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
+ mindexPartial (M_Nest sh arr) i
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
+
+ mscalar = M_Nest ZSX
+
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
+ mfromListOuter l@(arr :| _) =
+ M_Nest (SUnknown (length l) :$% mshape arr)
+ (mfromListOuter ((\(M_Nest _ a) -> a) <$> l))
+
+ mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr)
+
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
+ -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
+ mlift ssh2 f (M_Nest sh1 arr) =
+ let result = mlift (ssxAppend ssh2 ssh') f' arr
+ (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result)
+ in M_Nest sh2 result
+ where
+ ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr)))
+
+ f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
+ f' sshT
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
+ = f (ssxAppend ssh' sshT)
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
+ -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)
+ mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) =
+ let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2
+ (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result)
+ in M_Nest sh3 result
+ where
+ ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1)))
+
+ f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
+ f' sshT
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
+ = f (ssxAppend ssh' sshT)
+
+ mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a)
+ mcast ssh1 sh2 _ (M_Nest sh1T arr)
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
+ = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T
+ in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr)
+
+ mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
+ => Perm is -> Mixed sh (Mixed sh' a)
+ -> Mixed (PermutePrefix is sh) (Mixed sh' a)
+ mtranspose perm (M_Nest sh arr)
+ | let sh' = shxDropSh @sh @sh' (mshape arr) sh
+ , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh')
+ , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh'))
+ , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
+ , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
+ , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
+ = M_Nest (shxPermutePrefix perm sh)
+ (mtranspose perm arr)
+
+ type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a)
+
+ mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
+ mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr)))))
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+
+ mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
+ mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
+
+ mvecsWritePartial :: forall sh1 sh2 s.
+ IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
+ -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
+ -> ST s ()
+ mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs
+
+ mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs
+
+instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
+ memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh'))))
+
+ mvecsUnsafeNew sh example
+ | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
+ | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh')))
+ where
+ sh' = mshape example
+
+ mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
+
+
+-- | Create an array given a size and a function that computes the element at a
+-- given index.
+--
+-- __WARNING__: It is required that every @a@ returned by the argument to
+-- 'mgenerate' has the same shape. For example, the following will throw a
+-- runtime error:
+--
+-- > foo :: Mixed [Nothing] (Mixed [Nothing] Double)
+-- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) ->
+-- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) ->
+-- > ...
+--
+-- because the size of the inner 'mgenerate' is not always the same (it depends
+-- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so
+-- the entire hierarchy (after distributing out tuples) must be a rectangular
+-- array. The type of 'mgenerate' allows this requirement to be broken very
+-- easily, hence the runtime check.
+mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a
+mgenerate sh f = case shxEnum sh of
+ [] -> memptyArray sh
+ firstidx : restidxs ->
+ let firstelem = f (ixxZero' sh)
+ shapetree = mshapeTree firstelem
+ in if mshapeTreeEmpty (Proxy @a) shapetree
+ then memptyArray sh
+ else runST $ do
+ vecs <- mvecsUnsafeNew sh firstelem
+ mvecsWrite sh firstidx firstelem vecs
+ -- TODO: This is likely fine if @a@ is big, but if @a@ is a
+ -- scalar this array copying inefficient. Should improve this.
+ forM_ restidxs $ \idx -> do
+ let val = f idx
+ when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $
+ error "Data.Array.Nested mgenerate: generated values do not have equal shapes"
+ mvecsWrite sh idx val vecs
+ mvecsFreeze sh vecs
+
+msumOuter1P :: forall sh n a. (Storable a, NumElt a)
+ => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
+msumOuter1P (M_Primitive (n :$% sh) arr) =
+ let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
+ in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr)
+
+msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
+ => Mixed (n : sh) a -> Mixed sh a
+msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
+
+mappend :: forall n m sh a. Elt a
+ => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
+mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
+ where
+ sn :$% sh = mshape arr1
+ sm :$% _ = mshape arr2
+ ssh = ssxFromShape sh
+ snm :: SMayNat () SNat (AddMaybe n m)
+ snm = case (sn, sm) of
+ (SUnknown{}, _) -> SUnknown ()
+ (SKnown{}, SUnknown{}) -> SUnknown ()
+ (SKnown n, SKnown m) -> SKnown (snatPlus n m)
+
+ f :: forall sh' b. Storable b
+ => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b
+ f ssh' = X.append (ssxAppend ssh ssh')
+
+mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
+mfromVectorP sh v = M_Primitive sh (X.fromVector sh v)
+
+mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a
+mfromVector sh v = fromPrimitive (mfromVectorP sh v)
+
+mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a
+mtoVectorP (M_Primitive _ v) = X.toVector v
+
+mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a
+mtoVector arr = mtoVectorP (toPrimitive arr)
+
+mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a
+mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise?
+
+mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a
+mfromList1Prim l =
+ let ssh = SUnknown () :!% ZKX
+ xarr = X.fromList1 ssh l
+ in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+
+mtoList1 :: Elt a => Mixed '[n] a -> [a]
+mtoList1 = map munScalar . mtoListOuter
+
+mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a
+mfromListPrim l =
+ let ssh = SUnknown () :!% ZKX
+ xarr = X.fromList1 ssh l
+ in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+
+mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a
+mfromListPrimLinear sh l =
+ let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)
+
+munScalar :: Elt a => Mixed '[] a -> a
+munScalar arr = mindex arr ZIX
+
+mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
+ => StaticShX sh -> IShX sh2
+ -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
+ -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b)
+mrerankP ssh sh2 f (M_Primitive sh arr) =
+ let sh1 = shxDropSSX sh ssh
+ in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2)
+ (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2)
+ (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
+ arr)
+
+-- | See the caveats at @X.rerank@.
+mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
+ => StaticShX sh -> IShX sh2
+ -> (Mixed sh1 a -> Mixed sh2 b)
+ -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b
+mrerank ssh sh2 f (toPrimitive -> arr) =
+ fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr
+
+mreplicate :: forall sh sh' a. Elt a
+ => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
+mreplicate sh arr =
+ let ssh' = ssxFromShape (mshape arr)
+ in mlift (ssxAppend (ssxFromShape sh) ssh')
+ (\(sshT :: StaticShX shT) ->
+ case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of
+ Refl -> X.replicate sh (ssxAppend ssh' sshT))
+ arr
+
+mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
+mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x)
+
+mreplicateScal :: forall sh a. PrimElt a
+ => IShX sh -> a -> Mixed sh a
+mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
+
+mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
+mslice i n arr =
+ let _ :$% sh = mshape arr
+ in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr
+
+msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
+msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr
+
+mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a
+mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr
+
+mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a
+mreshape sh' arr =
+ mlift (ssxFromShape sh')
+ (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh')
+ arr
+
+miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a
+miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
+
+masXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
+masXArrayPrimP (M_Primitive sh arr) = (sh, arr)
+
+masXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a)
+masXArrayPrim = masXArrayPrimP . toPrimitive
+
+mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a)
+mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr
+
+mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a
+mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP
+
+mliftPrim :: PrimElt a
+ => (a -> a)
+ -> Mixed sh a -> Mixed sh a
+mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr))
+
+mliftPrim2 :: PrimElt a
+ => (a -> a -> a)
+ -> Mixed sh a -> Mixed sh a -> Mixed sh a
+mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =
+ fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2))
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
new file mode 100644
index 0000000..d5bd70f
--- /dev/null
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -0,0 +1,446 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Internal.Ranked where
+
+import Prelude hiding (mappend)
+
+import Control.DeepSeq (NFData)
+import Control.Monad.ST
+import Data.Array.RankedS qualified as S
+import Data.Bifunctor (first)
+import Data.Coerce (coerce)
+import Data.Kind (Type)
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Data.Type.Equality
+import Data.Vector.Storable qualified as VS
+import Foreign.Storable (Storable)
+import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp)
+import GHC.TypeLits
+import GHC.TypeNats qualified as TN
+
+import Data.Array.Mixed.XArray (XArray(..))
+import Data.Array.Mixed.XArray qualified as X
+import Data.Array.Mixed.Internal.Arith
+import Data.Array.Mixed.Lemmas
+import Data.Array.Mixed.Shape
+import Data.Array.Mixed.Types
+import Data.Array.Nested.Internal.Mixed
+import Data.Array.Nested.Internal.Shape
+
+
+-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
+-- represented on the type level as a 'Nat'.
+--
+-- Valid elements of a ranked arrays are described by the 'Elt' type class.
+-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
+-- supported (and are represented as a single, flattened, struct-of-arrays
+-- array internally).
+--
+-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
+type Ranked :: Nat -> Type -> Type
+newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
+deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)
+deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a)
+deriving instance Ord (Mixed '[] a) => Ord (Ranked 0 a)
+deriving instance NFData (Mixed (Replicate n Nothing) a) => NFData (Ranked n a)
+
+-- just unwrap the newtype and defer to the general instance for nested arrays
+newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
+deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a))
+
+newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
+
+-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
+-- these instances allow them to also be used as elements of arrays, thus
+-- making them first-class in the API.
+instance Elt a => Elt (Ranked n a) where
+ mshape (M_Ranked arr) = mshape arr
+ mindex (M_Ranked arr) i = Ranked (mindex arr i)
+
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
+ mindexPartial (M_Ranked arr) i =
+ coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
+ mindexPartial arr i
+
+ mscalar (Ranked x) = M_Ranked (M_Nest ZSX x)
+
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a)
+ mfromListOuter l = M_Ranked (mfromListOuter (coerce l))
+
+ mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
+ mtoListOuter (M_Ranked arr) =
+ coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr)
+
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
+ mlift ssh2 f (M_Ranked arr) =
+ coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
+ mlift ssh2 f arr
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a)
+ mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) =
+ coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
+ mlift2 ssh3 f arr1 arr2
+
+ mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr)
+
+ mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr)
+
+ type ShapeTree (Ranked n a) = (IShR n, ShapeTree a)
+
+ mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr)
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+
+ mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
+ mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
+ mvecsWrite sh idx (Ranked arr) vecs =
+ mvecsWrite sh idx arr
+ (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsWritePartial :: forall sh sh' s.
+ IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
+ -> MixedVecs s (sh ++ sh') (Ranked n a)
+ -> ST s ()
+ mvecsWritePartial sh idx arr vecs =
+ mvecsWritePartial sh idx
+ (coerce @(Mixed sh' (Ranked n a))
+ @(Mixed sh' (Mixed (Replicate n Nothing) a))
+ arr)
+ (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
+ @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
+ mvecsFreeze sh vecs =
+ coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
+ @(Mixed sh (Ranked n a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh (Ranked n a))
+ @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
+ memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a)
+ memptyArray i
+ | Dict <- lemKnownReplicate (SNat @n)
+ = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
+ memptyArray i
+
+ mvecsUnsafeNew idx (Ranked arr)
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsUnsafeNew idx arr
+
+ mvecsNewEmpty _
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
+
+
+arithPromoteRanked :: forall n a. PrimElt a
+ => (forall sh. Mixed sh a -> Mixed sh a)
+ -> Ranked n a -> Ranked n a
+arithPromoteRanked = coerce
+
+arithPromoteRanked2 :: forall n a. PrimElt a
+ => (forall sh. Mixed sh a -> Mixed sh a -> Mixed sh a)
+ -> Ranked n a -> Ranked n a -> Ranked n a
+arithPromoteRanked2 = coerce
+
+instance (NumElt a, PrimElt a) => Num (Ranked n a) where
+ (+) = arithPromoteRanked2 (+)
+ (-) = arithPromoteRanked2 (-)
+ (*) = arithPromoteRanked2 (*)
+ negate = arithPromoteRanked negate
+ abs = arithPromoteRanked abs
+ signum = arithPromoteRanked signum
+ fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal"
+
+instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Ranked n a) where
+ fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal"
+ recip = arithPromoteRanked recip
+ (/) = arithPromoteRanked2 (/)
+
+instance (FloatElt a, NumElt a, PrimElt a) => Floating (Ranked n a) where
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal"
+ exp = arithPromoteRanked exp
+ log = arithPromoteRanked log
+ sqrt = arithPromoteRanked sqrt
+ (**) = arithPromoteRanked2 (**)
+ logBase = arithPromoteRanked2 logBase
+ sin = arithPromoteRanked sin
+ cos = arithPromoteRanked cos
+ tan = arithPromoteRanked tan
+ asin = arithPromoteRanked asin
+ acos = arithPromoteRanked acos
+ atan = arithPromoteRanked atan
+ sinh = arithPromoteRanked sinh
+ cosh = arithPromoteRanked cosh
+ tanh = arithPromoteRanked tanh
+ asinh = arithPromoteRanked asinh
+ acosh = arithPromoteRanked acosh
+ atanh = arithPromoteRanked atanh
+ log1p = arithPromoteRanked GHC.Float.log1p
+ expm1 = arithPromoteRanked GHC.Float.expm1
+ log1pexp = arithPromoteRanked GHC.Float.log1pexp
+ log1mexp = arithPromoteRanked GHC.Float.log1mexp
+
+
+rshape :: forall n a. Elt a => Ranked n a -> IShR n
+rshape (Ranked arr) = shCvtXR' (mshape arr)
+
+rindex :: Elt a => Ranked n a -> IIxR n -> a
+rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
+
+rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a
+rindexPartial (Ranked arr) idx =
+ Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
+ (castWith (subst2 (lemReplicatePlusApp (ixrToSNat idx) (Proxy @m) (Proxy @Nothing))) arr)
+ (ixCvtRX idx))
+
+-- | __WARNING__: All values returned from the function must have equal shape.
+-- See the documentation of 'mgenerate' for more details.
+rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a
+rgenerate sh f
+ | sn@SNat <- shrToSNat sh
+ , Dict <- lemKnownReplicate sn
+ , Refl <- lemRankReplicate sn
+ = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR))
+
+-- | See the documentation of 'mlift'.
+rlift :: forall n1 n2 a. Elt a
+ => SNat n2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
+ -> Ranked n1 a -> Ranked n2 a
+rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr)
+
+-- | See the documentation of 'mlift2'.
+rlift2 :: forall n1 n2 n3 a. Elt a
+ => SNat n3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b)
+ -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a
+rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)
+
+rsumOuter1P :: forall n a.
+ (Storable a, NumElt a)
+ => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
+rsumOuter1P (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = Ranked (msumOuter1P arr)
+
+rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
+ => Ranked (n + 1) a -> Ranked n a
+rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive
+
+rtranspose :: forall n a. Elt a => [Int] -> Ranked n a -> Ranked n a
+rtranspose perm arr
+ | sn@SNat <- shrToSNat (rshape arr)
+ , Dict <- lemKnownReplicate sn
+ , length perm <= fromIntegral (natVal (Proxy @n))
+ = rlift sn
+ (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm)
+ arr
+ | otherwise
+ = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array"
+
+rappend :: forall n a. Elt a
+ => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
+rappend arr1 arr2
+ | sn@SNat <- shrToSNat (rshape arr1)
+ , Dict <- lemKnownReplicate sn
+ , Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
+ arr1 arr2
+
+rscalar :: Elt a => a -> Ranked 0 a
+rscalar x = Ranked (mscalar x)
+
+rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a)
+rfromVectorP sh v
+ | Dict <- lemKnownReplicate (shrToSNat sh)
+ = Ranked (mfromVectorP (shCvtRX sh) v)
+
+rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a
+rfromVector sh v = rfromPrimitive (rfromVectorP sh v)
+
+rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a
+rtoVectorP = coerce mtoVectorP
+
+rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
+rtoVector = coerce mtoVector
+
+rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
+rfromListOuter l
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
+
+rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
+rfromList1 l = Ranked (mfromList1 l)
+
+rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a
+rfromList1Prim l = Ranked (mfromList1Prim l)
+
+rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
+rtoListOuter (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
+
+rtoList1 :: Elt a => Ranked 1 a -> [a]
+rtoList1 = map runScalar . rtoListOuter
+
+rfromListPrim :: PrimElt a => [a] -> Ranked 1 a
+rfromListPrim l =
+ let ssh = SUnknown () :!% ZKX
+ xarr = X.fromList1 ssh l
+ in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+
+rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a
+rfromListPrimLinear sh l =
+ let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr)
+
+rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a
+rfromOrthotope sn arr
+ | Refl <- lemRankReplicate sn
+ = let xarr = XArray arr
+ in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr))
+
+runScalar :: Elt a => Ranked 0 a -> a
+runScalar arr = rindex arr ZIR
+
+rrerankP :: forall n1 n2 n a b. (Storable a, Storable b)
+ => SNat n -> IShR n2
+ -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b))
+ -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b)
+rrerankP sn sh2 f (Ranked arr)
+ | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat))
+ , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat))
+ = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2)
+ (\a -> let Ranked r = f (Ranked a) in r)
+ arr)
+
+-- | If there is a zero-sized dimension in the @n@-prefix of the shape of the
+-- input array, then there is no way to deduce the full shape of the output
+-- array (more precisely, the @n2@ part): that could only come from calling
+-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in
+-- this case; we choose to fill the @n2@ part of the output shape with zeros.
+--
+-- For example, if:
+--
+-- @
+-- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21]
+-- f :: Ranked 2 Int -> Ranked 3 Float
+-- @
+--
+-- then:
+--
+-- @
+-- rrerank _ _ _ f arr :: Ranked 5 Float
+-- @
+--
+-- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the
+-- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended
+-- to return an array with shape all-0 here (it probably didn't), but there is
+-- no better number to put here absent a subarray of the input to pass to @f@.
+rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
+ => SNat n -> IShR n2
+ -> (Ranked n1 a -> Ranked n2 b)
+ -> Ranked (n + n1) a -> Ranked (n + n2) b
+rrerank ssh sh2 f (rtoPrimitive -> arr) =
+ rfromPrimitive $ rrerankP ssh sh2 (rtoPrimitive . f . rfromPrimitive) arr
+
+rreplicate :: forall n m a. Elt a
+ => IShR n -> Ranked m a -> Ranked (n + m) a
+rreplicate sh (Ranked arr)
+ | Refl <- lemReplicatePlusApp (shrToSNat sh) (Proxy @m) (Proxy @(Nothing @Nat))
+ = Ranked (mreplicate (shCvtRX sh) arr)
+
+rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
+rreplicateScalP sh x
+ | Dict <- lemKnownReplicate (shrToSNat sh)
+ = Ranked (mreplicateScalP (shCvtRX sh) x)
+
+rreplicateScal :: forall n a. PrimElt a
+ => IShR n -> a -> Ranked n a
+rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
+
+rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
+rslice i n arr
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = rlift (shrToSNat (rshape arr))
+ (\_ -> X.sliceU i n)
+ arr
+
+rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
+rrev1 arr =
+ rlift (shrToSNat (rshape arr))
+ (\(_ :: StaticShX sh') ->
+ case lemReplicateSucc @(Nothing @Nat) @n of
+ Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))
+ arr
+
+rreshape :: forall n n' a. Elt a
+ => IShR n' -> Ranked n a -> Ranked n' a
+rreshape sh' rarr@(Ranked arr)
+ | Dict <- lemKnownReplicate (shrToSNat (rshape rarr))
+ , Dict <- lemKnownReplicate (shrToSNat sh')
+ = Ranked (mreshape (shCvtRX sh') arr)
+
+riota :: (Enum a, PrimElt a, Elt a) => Int -> Ranked 1 a
+riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota
+
+rasXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
+rasXArrayPrimP (Ranked arr) = first shCvtXR' (masXArrayPrimP arr)
+
+rasXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a)
+rasXArrayPrim (Ranked arr) = first shCvtXR' (masXArrayPrim arr)
+
+rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a)
+rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr)
+
+rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a
+rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr)
+
+rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a
+rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr)
+
+rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a)
+rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr)
+
+mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
+mtoRanked arr
+ | Refl <- lemAppNil @sh
+ , Refl <- lemAppNil @(Replicate (Rank sh) (Nothing @Nat))
+ , Refl <- lemRankReplicate (shxRank (mshape arr))
+ = Ranked (mcast (ssxFromShape (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr)
+ where
+ convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing)
+ convSh ZSX = ZSX
+ convSh (smn :$% (sh :: IShX sh'T))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T)
+ = SUnknown (fromSMayNat' smn) :$% convSh sh
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs
new file mode 100644
index 0000000..319f171
--- /dev/null
+++ b/src/Data/Array/Nested/Internal/Shape.hs
@@ -0,0 +1,467 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveFoldable #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Internal.Shape where
+
+import Data.Array.Mixed.Types
+import Data.Coerce (coerce)
+import Data.Foldable qualified as Foldable
+import Data.Functor.Const
+import Data.Kind (Type, Constraint)
+import Data.Monoid (Sum(..))
+import Data.Proxy
+import Data.Type.Equality
+import GHC.IsList (IsList)
+import GHC.IsList qualified as IsList
+import GHC.TypeLits
+import GHC.TypeNats qualified as TN
+
+import Data.Array.Mixed.Lemmas
+import Data.Array.Mixed.Permutation
+import Data.Array.Mixed.Shape
+
+
+type role ListR nominal representational
+type ListR :: Nat -> Type -> Type
+data ListR n i where
+ ZR :: ListR 0 i
+ (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
+deriving instance Eq i => Eq (ListR n i)
+deriving instance Ord i => Ord (ListR n i)
+deriving instance Functor (ListR n)
+deriving instance Foldable (ListR n)
+infixr 3 :::
+
+instance Show i => Show (ListR n i) where
+ showsPrec _ = listrShow shows
+
+data UnconsListRRes i n1 =
+ forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
+listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
+listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)
+listrUncons ZR = Nothing
+
+listrShow :: forall sh i. (i -> ShowS) -> ListR sh i -> ShowS
+listrShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> ListR sh' i -> ShowS
+ go _ ZR = id
+ go prefix (x ::: xs) = showString prefix . f x . go "," xs
+
+listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
+listrAppend ZR sh = sh
+listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
+
+listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
+listrFromList [] k = k ZR
+listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
+
+listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
+listrIndex SZ (x ::: _) = x
+listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
+listrIndex _ ZR = error "k + 1 <= 0"
+
+listrToSNat :: ListR n i -> SNat n
+listrToSNat ZR = SNat
+listrToSNat (_ ::: (l :: ListR n i)) | SNat <- listrToSNat l, Dict <- lemKnownNatSucc @n = SNat
+
+listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
+listrPermutePrefix = \perm sh ->
+ listrFromList perm $ \sperm ->
+ case (listrToSNat sperm, listrToSNat sh) of
+ (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
+ LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
+ ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ where
+ listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
+ listrSplitAt SZ sh = (ZR, sh)
+ listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
+ listrSplitAt SS{} ZR = error "m' + 1 <= 0"
+
+ applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
+ applyPermRFull _ ZR _ = ZR
+ applyPermRFull sm@SNat (i ::: perm) l =
+ TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
+ case cmpNat (SNat @(idx + 1)) sm of
+ LTI -> listrIndex si l ::: applyPermRFull sm perm l
+ EQI -> listrIndex si l ::: applyPermRFull sm perm l
+ GTI -> error "listrPermutePrefix: Index in permutation out of range"
+
+
+-- | An index into a rank-typed array.
+type role IxR nominal representational
+type IxR :: Nat -> Type -> Type
+newtype IxR n i = IxR (ListR n i)
+ deriving (Eq, Ord)
+ deriving newtype (Functor, Foldable)
+
+pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
+pattern ZIR = IxR ZR
+
+pattern (:.:)
+ :: forall {n1} {i}.
+ forall n. (n + 1 ~ n1)
+ => i -> IxR n i -> IxR n1 i
+pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i))
+ where i :.: IxR sh = IxR (i ::: sh)
+infixr 3 :.:
+
+{-# COMPLETE ZIR, (:.:) #-}
+
+type IIxR n = IxR n Int
+
+instance Show i => Show (IxR n i) where
+ showsPrec _ (IxR l) = listrShow shows l
+
+ixrZero :: SNat n -> IIxR n
+ixrZero SZ = ZIR
+ixrZero (SS n) = 0 :.: ixrZero n
+
+ixCvtXR :: IIxX sh -> IIxR (Rank sh)
+ixCvtXR ZIX = ZIR
+ixCvtXR (n :.% idx) = n :.: ixCvtXR idx
+
+ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
+ixCvtRX ZIR = ZIX
+ixCvtRX (n :.: (idx :: IxR m Int)) =
+ castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
+ (n :.% ixCvtRX idx)
+
+ixrToSNat :: IxR n i -> SNat n
+ixrToSNat (IxR sh) = listrToSNat sh
+
+ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
+ixrPermutePrefix = coerce (listrPermutePrefix @i)
+
+
+type role ShR nominal representational
+type ShR :: Nat -> Type -> Type
+newtype ShR n i = ShR (ListR n i)
+ deriving (Eq, Ord)
+ deriving newtype (Functor, Foldable)
+
+pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
+pattern ZSR = ShR ZR
+
+pattern (:$:)
+ :: forall {n1} {i}.
+ forall n. (n + 1 ~ n1)
+ => i -> ShR n i -> ShR n1 i
+pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
+ where i :$: (ShR sh) = ShR (i ::: sh)
+infixr 3 :$:
+
+{-# COMPLETE ZSR, (:$:) #-}
+
+type IShR n = ShR n Int
+
+instance Show i => Show (ShR n i) where
+ showsPrec _ (ShR l) = listrShow shows l
+
+shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n
+shCvtXR' ZSX =
+ castWith (subst2 (unsafeCoerceRefl :: 0 :~: n))
+ ZSR
+shCvtXR' (n :$% (idx :: IShX sh))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =
+ castWith (subst2 (lem1 @sh Refl))
+ (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
+ where
+ lem1 :: forall sh' n' k.
+ k : sh' :~: Replicate n' Nothing
+ -> Rank sh' + 1 :~: n'
+ lem1 Refl = unsafeCoerceRefl
+
+ lem2 :: k : sh :~: Replicate n Nothing
+ -> sh :~: Replicate (Rank sh) Nothing
+ lem2 Refl = unsafeCoerceRefl
+
+shCvtRX :: IShR n -> IShX (Replicate n Nothing)
+shCvtRX ZSR = ZSX
+shCvtRX (n :$: (idx :: ShR m Int)) =
+ castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
+ (SUnknown n :$% shCvtRX idx)
+
+-- | The number of elements in an array described by this shape.
+shrSize :: IShR n -> Int
+shrSize ZSR = 1
+shrSize (n :$: sh) = n * shrSize sh
+
+shrToSNat :: ShR n i -> SNat n
+shrToSNat (ShR sh) = listrToSNat sh
+
+shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
+shrPermutePrefix = coerce (listrPermutePrefix @i)
+
+
+-- | Untyped: length is checked at runtime.
+instance KnownNat n => IsList (ListR n i) where
+ type Item (ListR n i) = i
+ fromList = go (SNat @n)
+ where
+ go :: SNat n' -> [i] -> ListR n' i
+ go SZ [] = ZR
+ go (SS n) (i : is) = i ::: go n is
+ go _ _ = error "IsList(ListR): Mismatched list length"
+ toList = Foldable.toList
+
+-- | Untyped: length is checked at runtime.
+instance KnownNat n => IsList (IxR n i) where
+ type Item (IxR n i) = i
+ fromList = IxR . IsList.fromList
+ toList = Foldable.toList
+
+-- | Untyped: length is checked at runtime.
+instance KnownNat n => IsList (ShR n i) where
+ type Item (ShR n i) = i
+ fromList = ShR . IsList.fromList
+ toList = Foldable.toList
+
+
+type role ListS nominal representational
+type ListS :: [Nat] -> (Nat -> Type) -> Type
+data ListS sh f where
+ ZS :: ListS '[] f
+ -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity
+ (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
+deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
+deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
+infixr 3 ::$
+
+instance (forall n. Show (f n)) => Show (ListS sh f) where
+ showsPrec _ = listsShow shows
+
+data UnconsListSRes f sh1 =
+ forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
+listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
+listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)
+listsUncons ZS = Nothing
+
+listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
+listsFmap _ ZS = ZS
+listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs
+
+listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
+listsFold _ ZS = mempty
+listsFold f (x ::$ xs) = f x <> listsFold f xs
+
+listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
+listsShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> ListS sh' f -> ShowS
+ go _ ZS = id
+ go prefix (x ::$ xs) = showString prefix . f x . go "," xs
+
+listsToList :: ListS sh (Const i) -> [i]
+listsToList ZS = []
+listsToList (Const i ::$ is) = i : listsToList is
+
+listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
+listsAppend ZS idx' = idx'
+listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
+
+listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
+listsTakeLenPerm PNil _ = ZS
+listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh
+listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
+
+listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
+listsDropLenPerm PNil sh = sh
+listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh
+listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
+
+listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
+listsPermute PNil _ = ZS
+listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
+ case listsIndex (Proxy @is') (Proxy @sh) i sh of
+ (item, SNat) -> item ::$ listsPermute is sh
+
+-- TODO: remove this SNat when the KnownNat constaint in ListS is removed
+listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh))
+listsIndex _ _ SZ (n ::$ _) = (n, SNat)
+listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = listsIndex p pT i sh
+listsIndex _ _ _ ZS = error "Index into empty shape"
+
+shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
+shsTakeLen = coerce (listsTakeLenPerm @SNat)
+
+shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
+shsPermute = coerce (listsPermute @SNat)
+
+shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
+shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh)))
+
+applyPermS :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
+applyPermS perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh)
+
+applyPermIxS :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
+applyPermIxS = coerce (applyPermS @(Const i))
+
+applyPermShS :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
+applyPermShS = coerce (applyPermS @SNat)
+
+
+-- | An index into a shape-typed array.
+--
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\"). Note that because the shape of a
+-- shape-typed array is known statically, you can also retrieve the array shape
+-- from a 'KnownShape' dictionary.
+type role IxS nominal representational
+type IxS :: [Nat] -> Type -> Type
+newtype IxS sh i = IxS (ListS sh (Const i))
+ deriving (Eq, Ord)
+
+pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
+pattern ZIS = IxS ZS
+
+pattern (:.$)
+ :: forall {sh1} {i}.
+ forall n sh. (KnownNat n, n : sh ~ sh1)
+ => i -> IxS sh i -> IxS sh1 i
+pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
+ where i :.$ IxS shl = IxS (Const i ::$ shl)
+infixr 3 :.$
+
+{-# COMPLETE ZIS, (:.$) #-}
+
+type IIxS sh = IxS sh Int
+
+instance Show i => Show (IxS sh i) where
+ showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l
+
+instance Functor (IxS sh) where
+ fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l)
+
+instance Foldable (IxS sh) where
+ foldMap f (IxS l) = listsFold (f . getConst) l
+
+ixsZero :: ShS sh -> IIxS sh
+ixsZero ZSS = ZIS
+ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
+
+ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh
+ixCvtXS ZSS ZIX = ZIS
+ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx
+
+ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
+ixCvtSX ZIS = ZIX
+ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh
+
+
+-- | The shape of a shape-typed array given as a list of 'SNat' values.
+type role ShS nominal
+type ShS :: [Nat] -> Type
+newtype ShS sh = ShS (ListS sh SNat)
+ deriving (Eq, Ord)
+
+pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
+pattern ZSS = ShS ZS
+
+pattern (:$$)
+ :: forall {sh1}.
+ forall n sh. (KnownNat n, n : sh ~ sh1)
+ => SNat n -> ShS sh -> ShS sh1
+pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
+ where i :$$ ShS shl = ShS (i ::$ shl)
+
+infixr 3 :$$
+
+{-# COMPLETE ZSS, (:$$) #-}
+
+instance Show (ShS sh) where
+ showsPrec _ (ShS l) = listsShow (shows . fromSNat) l
+
+shsLength :: ShS sh -> Int
+shsLength (ShS l) = getSum (listsFold (\_ -> Sum 1) l)
+
+shsToList :: ShS sh -> [Int]
+shsToList ZSS = []
+shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh
+
+shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh
+shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
+shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) =
+ castWith (subst1 (lem Refl)) $
+ n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
+ idx)
+ where
+ lem :: forall sh1 sh' n.
+ Just n : sh1 :~: MapJust sh'
+ -> n : Tail sh' :~: sh'
+ lem Refl = unsafeCoerceRefl
+shCvtXS' (SUnknown _ :$% _) = error "impossible"
+
+shCvtSX :: ShS sh -> IShX (MapJust sh)
+shCvtSX ZSS = ZSX
+shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh
+
+shsSize :: ShS sh -> Int
+shsSize ZSS = 1
+shsSize (n :$$ sh) = fromSNat' n * shsSize sh
+
+-- | Evidence for the static part of a shape. This pops up only when you are
+-- polymorphic in the element type of an array.
+type KnownShS :: [Nat] -> Constraint
+class KnownShS sh where knownShS :: ShS sh
+instance KnownShS '[] where knownShS = ZSS
+instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS
+
+
+-- | Untyped: length is checked at runtime.
+instance KnownShS sh => IsList (ListS sh (Const i)) where
+ type Item (ListS sh (Const i)) = i
+ fromList topl = go (knownShS @sh) topl
+ where
+ go :: ShS sh' -> [i] -> ListS sh' (Const i)
+ go ZSS [] = ZS
+ go (_ :$$ sh) (i : is) = Const i ::$ go sh is
+ go _ _ = error $ "IsList(ListS): Mismatched list length (type says "
+ ++ show (shsLength (knownShS @sh)) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+ toList = listsToList
+
+-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
+instance KnownShS sh => IsList (IxS sh i) where
+ type Item (IxS sh i) = i
+ fromList = IxS . IsList.fromList
+ toList = Foldable.toList
+
+-- | Untyped: length and values are checked at runtime.
+instance KnownShS sh => IsList (ShS sh) where
+ type Item (ShS sh) = Int
+ fromList topl = ShS (go (knownShS @sh) topl)
+ where
+ go :: ShS sh' -> [Int] -> ListS sh' SNat
+ go ZSS [] = ZS
+ go (sn :$$ sh) (i : is)
+ | i == fromSNat' sn = sn ::$ go sh is
+ | otherwise = error $ "IsList(ShS): Value does not match typing (type says "
+ ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
+ go _ _ = error $ "IsList(ShS): Mismatched list length (type says "
+ ++ show (shsLength (knownShS @sh)) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+ toList = shsToList
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
new file mode 100644
index 0000000..afa91eb
--- /dev/null
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -0,0 +1,379 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Internal.Shaped where
+
+import Prelude hiding (mappend)
+
+import Control.DeepSeq (NFData)
+import Control.Monad.ST
+import Data.Bifunctor (first)
+import Data.Coerce (coerce)
+import Data.Kind (Type)
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Data.Type.Equality
+import Data.Vector.Storable qualified as VS
+import Foreign.Storable (Storable)
+import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp)
+import GHC.TypeLits
+
+import Data.Array.Mixed.XArray (XArray)
+import Data.Array.Mixed.XArray qualified as X
+import Data.Array.Mixed.Internal.Arith
+import Data.Array.Mixed.Lemmas
+import Data.Array.Mixed.Permutation
+import Data.Array.Mixed.Shape
+import Data.Array.Mixed.Types
+import Data.Array.Nested.Internal.Lemmas
+import Data.Array.Nested.Internal.Mixed
+import Data.Array.Nested.Internal.Shape
+
+
+-- | A shape-typed array: the full shape of the array (the sizes of its
+-- dimensions) is represented on the type level as a list of 'Nat's. Note that
+-- these are "GHC.TypeLits" naturals, because we do not need induction over
+-- them and we want very large arrays to be possible.
+--
+-- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
+-- and 'Shaped' itself is again an instance of 'Elt' as well.
+--
+-- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
+type Shaped :: [Nat] -> Type -> Type
+newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
+deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a)
+deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a)
+deriving instance Ord (Mixed '[] a) => Ord (Shaped '[] a)
+deriving instance NFData (Mixed (MapJust sh) a) => NFData (Shaped sh a)
+
+-- just unwrap the newtype and defer to the general instance for nested arrays
+newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))
+deriving instance Show (Mixed sh (Mixed (MapJust sh') a)) => Show (Mixed sh (Shaped sh' a))
+
+newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a))
+
+instance Elt a => Elt (Shaped sh a) where
+ mshape (M_Shaped arr) = mshape arr
+ mindex (M_Shaped arr) i = Shaped (mindex arr i)
+
+ mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
+ mindexPartial (M_Shaped arr) i =
+ coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mindexPartial arr i
+
+ mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)
+
+ mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
+ mfromListOuter l = M_Shaped (mfromListOuter (coerce l))
+
+ mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
+ mtoListOuter (M_Shaped arr)
+ = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)
+
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
+ mlift ssh2 f (M_Shaped arr) =
+ coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mlift ssh2 f arr
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a)
+ mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) =
+ coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
+ mlift2 ssh3 f arr1 arr2
+
+ mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr)
+
+ mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)
+
+ type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
+
+ mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr)
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+
+ mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
+ mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
+ mvecsWrite sh idx (Shaped arr) vecs =
+ mvecsWrite sh idx arr
+ (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+
+ mvecsWritePartial :: forall sh1 sh2 s.
+ IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
+ -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
+ -> ST s ()
+ mvecsWritePartial sh idx arr vecs =
+ mvecsWritePartial sh idx
+ (coerce @(Mixed sh2 (Shaped sh a))
+ @(Mixed sh2 (Mixed (MapJust sh) a))
+ arr)
+ (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
+ @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
+ vecs)
+
+ mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
+ mvecsFreeze sh vecs =
+ coerce @(Mixed sh' (Mixed (MapJust sh) a))
+ @(Mixed sh' (Shaped sh a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh' (Shaped sh a))
+ @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+
+instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
+ memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
+ memptyArray i
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
+ memptyArray i
+
+ mvecsUnsafeNew idx (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsUnsafeNew idx arr
+
+ mvecsNewEmpty _
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
+
+
+arithPromoteShaped :: forall sh a. PrimElt a
+ => (forall shx. Mixed shx a -> Mixed shx a)
+ -> Shaped sh a -> Shaped sh a
+arithPromoteShaped = coerce
+
+arithPromoteShaped2 :: forall sh a. PrimElt a
+ => (forall shx. Mixed shx a -> Mixed shx a -> Mixed shx a)
+ -> Shaped sh a -> Shaped sh a -> Shaped sh a
+arithPromoteShaped2 = coerce
+
+instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
+ (+) = arithPromoteShaped2 (+)
+ (-) = arithPromoteShaped2 (-)
+ (*) = arithPromoteShaped2 (*)
+ negate = arithPromoteShaped negate
+ abs = arithPromoteShaped abs
+ signum = arithPromoteShaped signum
+ fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal"
+
+instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Shaped sh a) where
+ fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"
+ recip = arithPromoteShaped recip
+ (/) = arithPromoteShaped2 (/)
+
+instance (FloatElt a, NumElt a, PrimElt a) => Floating (Shaped sh a) where
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"
+ exp = arithPromoteShaped exp
+ log = arithPromoteShaped log
+ sqrt = arithPromoteShaped sqrt
+ (**) = arithPromoteShaped2 (**)
+ logBase = arithPromoteShaped2 logBase
+ sin = arithPromoteShaped sin
+ cos = arithPromoteShaped cos
+ tan = arithPromoteShaped tan
+ asin = arithPromoteShaped asin
+ acos = arithPromoteShaped acos
+ atan = arithPromoteShaped atan
+ sinh = arithPromoteShaped sinh
+ cosh = arithPromoteShaped cosh
+ tanh = arithPromoteShaped tanh
+ asinh = arithPromoteShaped asinh
+ acosh = arithPromoteShaped acosh
+ atanh = arithPromoteShaped atanh
+ log1p = arithPromoteShaped GHC.Float.log1p
+ expm1 = arithPromoteShaped GHC.Float.expm1
+ log1pexp = arithPromoteShaped GHC.Float.log1pexp
+ log1mexp = arithPromoteShaped GHC.Float.log1mexp
+
+
+sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
+sshape (Shaped arr) = shCvtXS' (mshape arr)
+
+sindex :: Elt a => Shaped sh a -> IIxS sh -> a
+sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)
+
+shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh
+shsTakeIx _ _ ZIS = ZSS
+shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx
+
+sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
+sindexPartial sarr@(Shaped arr) idx =
+ Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
+ (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr)
+ (ixCvtSX idx))
+
+-- | __WARNING__: All values returned from the function must have equal shape.
+-- See the documentation of 'mgenerate' for more details.
+sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
+sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh))
+
+-- | See the documentation of 'mlift'.
+slift :: forall sh1 sh2 a. Elt a
+ => ShS sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
+ -> Shaped sh1 a -> Shaped sh2 a
+slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr)
+
+-- | See the documentation of 'mlift'.
+slift2 :: forall sh1 sh2 sh3 a. Elt a
+ => ShS sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)
+ -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
+slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2)
+
+ssumOuter1P :: forall sh n a. (Storable a, NumElt a)
+ => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
+ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr)
+
+ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
+ => Shaped (n : sh) a -> Shaped sh a
+ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive
+
+stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a)
+ => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a
+stranspose perm sarr@(Shaped arr)
+ | Refl <- lemRankMapJust (sshape sarr)
+ , Refl <- lemMapJustTakeLen perm (sshape sarr)
+ , Refl <- lemMapJustDropLen perm (sshape sarr)
+ , Refl <- lemMapJustPermute perm (shsTakeLen perm (sshape sarr))
+ , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh))
+ = Shaped (mtranspose perm arr)
+
+sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
+sappend = coerce mappend
+
+sscalar :: Elt a => a -> Shaped '[] a
+sscalar x = Shaped (mscalar x)
+
+sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a)
+sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v)
+
+sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a
+sfromVector sh v = sfromPrimitive (sfromVectorP sh v)
+
+stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a
+stoVectorP = coerce mtoVectorP
+
+stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
+stoVector = coerce mtoVector
+
+sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
+sfromListOuter sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromListOuter (coerce l))
+
+sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
+sfromList1 sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1
+
+sfromList1Prim :: (PrimElt a, Elt a) => SNat n -> [a] -> Shaped '[n] a
+sfromList1Prim sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1Prim
+
+stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
+stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
+
+stoList1 :: Elt a => Shaped '[n] a -> [a]
+stoList1 = map sunScalar . stoListOuter
+
+sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
+sfromListPrim sn l
+ | Refl <- lemAppNil @'[Just n]
+ = let ssh = SUnknown () :!% ZKX
+ xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l)
+ in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr
+
+sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a
+sfromListPrimLinear sh l =
+ let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr)
+
+sunScalar :: Elt a => Shaped '[] a -> a
+sunScalar arr = sindex arr ZIS
+
+srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
+ => ShS sh -> ShS sh2
+ -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))
+ -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b)
+srerankP sh sh2 f sarr@(Shaped arr)
+ | Refl <- lemMapJustApp sh (Proxy @sh1)
+ , Refl <- lemMapJustApp sh (Proxy @sh2)
+ = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh))))
+ (shCvtSX sh2)
+ (\a -> let Shaped r = f (Shaped a) in r)
+ arr)
+
+srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
+ => ShS sh -> ShS sh2
+ -> (Shaped sh1 a -> Shaped sh2 b)
+ -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b
+srerank sh sh2 f (stoPrimitive -> arr) =
+ sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr
+
+sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a
+sreplicate sh (Shaped arr)
+ | Refl <- lemMapJustApp sh (Proxy @sh')
+ = Shaped (mreplicate (shCvtSX sh) arr)
+
+sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
+sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x)
+
+sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a
+sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x)
+
+sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
+sslice i n@SNat arr =
+ let _ :$$ sh = sshape arr
+ in slift (n :$$ sh) (\_ -> X.slice i n) arr
+
+srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a
+srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr
+
+sreshape :: Elt a => ShS sh' -> Shaped sh a -> Shaped sh' a
+sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr)
+
+siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
+siota sn = Shaped (miota sn)
+
+sasXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
+sasXArrayPrimP (Shaped arr) = first shCvtXS' (masXArrayPrimP arr)
+
+sasXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a)
+sasXArrayPrim (Shaped arr) = first shCvtXS' (masXArrayPrim arr)
+
+sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a)
+sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr)
+
+sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a
+sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr)
+
+sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a
+sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr)
+
+stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a)
+stoPrimitive (Shaped arr) = Shaped (toPrimitive arr)
+
+mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
+ => Mixed sh a -> ShS sh' -> Shaped sh' a
+mcastToShaped arr targetsh
+ | Refl <- lemAppNil @sh
+ , Refl <- lemAppNil @(MapJust sh')
+ , Refl <- lemRankMapJust targetsh
+ = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr)