aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
-rw-r--r--src/Data/Array/Nested/Mixed.hs955
1 files changed, 955 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
new file mode 100644
index 0000000..ec19c21
--- /dev/null
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -0,0 +1,955 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DefaultSignatures #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE DerivingVia #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+module Data.Array.Nested.Mixed where
+
+import Prelude hiding (mconcat)
+
+import Control.DeepSeq (NFData(..))
+import Control.Monad (forM_, when)
+import Control.Monad.ST
+import Data.Array.RankedS qualified as S
+import Data.Bifunctor (bimap)
+import Data.Coerce
+import Data.Foldable (toList)
+import Data.Int
+import Data.Kind (Constraint, Type)
+import Data.List.NonEmpty (NonEmpty(..))
+import Data.List.NonEmpty qualified as NE
+import Data.Proxy
+import Data.Type.Equality
+import Data.Vector.Storable qualified as VS
+import Data.Vector.Storable.Mutable qualified as VSM
+import Foreign.C.Types (CInt)
+import Foreign.Storable (Storable)
+import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
+import GHC.Generics (Generic)
+import GHC.TypeLits
+import Unsafe.Coerce (unsafeCoerce)
+
+import Data.Array.Arith
+import Data.Array.Mixed.Lemmas
+import Data.Array.Mixed.Permutation
+import Data.Array.Mixed.Types
+import Data.Array.XArray (XArray(..))
+import Data.Array.XArray qualified as X
+import Data.Array.Nested.Mixed.Shape
+import Data.Bag
+
+
+-- TODO:
+-- sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
+-- rminIndex1 :: Ranked (n + 1) a -> Ranked n Int
+-- gather/scatter-like things (most generally, the higher-order variants: accelerate's backpermute/permute)
+-- After benchmarking: matmul and matvec
+
+
+
+-- Invariant in the API
+-- ====================
+--
+-- In the underlying XArray, there is some shape for elements of an empty
+-- array. For example, for this array:
+--
+-- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float)
+-- rshape arr == 0 :.: 0 :.: 0 :.: ZIR
+--
+-- the two underlying XArrays have a shape, and those shapes might be anything.
+-- The invariant is that these element shapes are unobservable in the API.
+-- (This is possible because you ought to not be able to get to such an element
+-- without indexing out of bounds.)
+--
+-- Note, though, that the converse situation may arise: the outer array might
+-- be nonempty but then the inner arrays might. This is fine, an invariant only
+-- applies if the _outer_ array is empty.
+--
+-- TODO: can we enforce that the elements of an empty (nested) array have
+-- all-zero shape?
+-- -> no, because mlift and also any kind of internals probing from outsiders
+
+
+-- Primitive element types
+-- =======================
+--
+-- There are a few primitive element types; arrays containing elements of such
+-- type are a newtype over an XArray, which it itself a newtype over a Vector.
+-- Unfortunately, the setup of the library requires us to list these primitive
+-- element types multiple times; to aid in extending the list, all these lists
+-- have been marked with [PRIMITIVE ELEMENT TYPES LIST].
+
+
+-- | Wrapper type used as a tag to attach instances on. The instances on arrays
+-- of @'Primitive' a@ are more polymorphic than the direct instances for arrays
+-- of scalars; this means that if @orthotope@ supports an element type @T@ that
+-- this library does not (directly), it may just work if you use an array of
+-- @'Primitive' T@ instead.
+newtype Primitive a = Primitive a
+ deriving (Show)
+
+-- | Element types that are primitive; arrays of these types are just a newtype
+-- wrapper over an array.
+class (Storable a, Elt a) => PrimElt a where
+ fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a
+ toPrimitive :: Mixed sh a -> Mixed sh (Primitive a)
+
+ default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a
+ fromPrimitive = coerce
+
+ default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a)
+ toPrimitive = coerce
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+instance PrimElt Bool
+instance PrimElt Int
+instance PrimElt Int64
+instance PrimElt Int32
+instance PrimElt CInt
+instance PrimElt Float
+instance PrimElt Double
+instance PrimElt ()
+
+
+-- | Mixed arrays: some dimensions are size-typed, some are not. Distributes
+-- over product-typed elements using a data family so that the full array is
+-- always in struct-of-arrays format.
+--
+-- Built on top of 'XArray' which is built on top of @orthotope@, meaning that
+-- dimension permutations (e.g. 'mtranspose') are typically free.
+--
+-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type
+-- class.
+type Mixed :: [Maybe Nat] -> Type -> Type
+data family Mixed sh a
+-- NOTE: When opening up the Mixed abstraction, you might see dimension sizes
+-- that you're not supposed to see. In particular, you might see (nonempty)
+-- sizes of the elements of an empty array, which is information that should
+-- ostensibly not exist; the full array is still empty.
+
+data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a)
+ deriving (Eq, Ord, Generic)
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq, Ord, Generic)
+newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic)
+newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic)
+newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic)
+newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic)
+newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic)
+newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic)
+newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Eq, Ord, Generic) -- no content, orthotope optimises this (via Vector)
+-- etc.
+
+data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic)
+-- etc., larger tuples (perhaps use generics to allow arbitrary product types)
+
+deriving instance (Eq (Mixed sh a), Eq (Mixed sh b)) => Eq (Mixed sh (a, b))
+deriving instance (Ord (Mixed sh a), Ord (Mixed sh b)) => Ord (Mixed sh (a, b))
+
+data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic)
+
+deriving instance Eq (Mixed (sh1 ++ sh2) a) => Eq (Mixed sh1 (Mixed sh2 a))
+deriving instance Ord (Mixed (sh1 ++ sh2) a) => Ord (Mixed sh1 (Mixed sh2 a))
+
+
+-- | Internal helper data family mirroring 'Mixed' that consists of mutable
+-- vectors instead of 'XArray's.
+type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
+data family MixedVecs s sh a
+
+newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a)
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+newtype instance MixedVecs s sh Bool = MV_Bool (VS.MVector s Bool)
+newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int)
+newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64)
+newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32)
+newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt)
+newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double)
+newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float)
+newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this
+-- etc.
+
+data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b)
+-- etc.
+
+data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a)
+
+
+showsMixedArray :: (Show a, Elt a)
+ => String -- ^ fromList prefix: e.g. @rfromListLinear [2,3]@
+ -> String -- ^ replicate prefix: e.g. @rreplicate [2,3]@
+ -> Int -> Mixed sh a -> ShowS
+showsMixedArray fromlistPrefix replicatePrefix d arr =
+ showParen (d > 10) $
+ -- TODO: to avoid ambiguity, we should type-apply the shape to mfromListLinear here
+ case mtoListLinear arr of
+ hd : _ : _
+ | all (all (== 0) . take (shxLength (mshape arr))) (marrayStrides arr) ->
+ showString replicatePrefix . showString " " . showsPrec 11 hd
+ _ ->
+ showString fromlistPrefix . showString " " . shows (mtoListLinear arr)
+
+instance (Show a, Elt a) => Show (Mixed sh a) where
+ showsPrec d arr =
+ let sh = show (shxToList (mshape arr))
+ in showsMixedArray ("mfromListLinear " ++ sh) ("mreplicate " ++ sh) d arr
+
+instance Elt a => NFData (Mixed sh a) where
+ rnf = mrnf
+
+
+mliftNumElt1 :: (PrimElt a, PrimElt b)
+ => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b)
+ -> Mixed sh a -> Mixed sh b
+mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr))
+
+mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c)
+ => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c)
+ -> Mixed sh a -> Mixed sh b -> Mixed sh c
+mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2))
+ | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2))
+ | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2
+
+instance (NumElt a, PrimElt a) => Num (Mixed sh a) where
+ (+) = mliftNumElt2 (liftO2 . numEltAdd)
+ (-) = mliftNumElt2 (liftO2 . numEltSub)
+ (*) = mliftNumElt2 (liftO2 . numEltMul)
+ negate = mliftNumElt1 (liftO1 . numEltNeg)
+ abs = mliftNumElt1 (liftO1 . numEltAbs)
+ signum = mliftNumElt1 (liftO1 . numEltSignum)
+ -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS
+ fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal"
+
+instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where
+ fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"
+ recip = mliftNumElt1 (liftO1 . floatEltRecip)
+ (/) = mliftNumElt2 (liftO2 . floatEltDiv)
+
+instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"
+ exp = mliftNumElt1 (liftO1 . floatEltExp)
+ log = mliftNumElt1 (liftO1 . floatEltLog)
+ sqrt = mliftNumElt1 (liftO1 . floatEltSqrt)
+
+ (**) = mliftNumElt2 (liftO2 . floatEltPow)
+ logBase = mliftNumElt2 (liftO2 . floatEltLogbase)
+
+ sin = mliftNumElt1 (liftO1 . floatEltSin)
+ cos = mliftNumElt1 (liftO1 . floatEltCos)
+ tan = mliftNumElt1 (liftO1 . floatEltTan)
+ asin = mliftNumElt1 (liftO1 . floatEltAsin)
+ acos = mliftNumElt1 (liftO1 . floatEltAcos)
+ atan = mliftNumElt1 (liftO1 . floatEltAtan)
+ sinh = mliftNumElt1 (liftO1 . floatEltSinh)
+ cosh = mliftNumElt1 (liftO1 . floatEltCosh)
+ tanh = mliftNumElt1 (liftO1 . floatEltTanh)
+ asinh = mliftNumElt1 (liftO1 . floatEltAsinh)
+ acosh = mliftNumElt1 (liftO1 . floatEltAcosh)
+ atanh = mliftNumElt1 (liftO1 . floatEltAtanh)
+ log1p = mliftNumElt1 (liftO1 . floatEltLog1p)
+ expm1 = mliftNumElt1 (liftO1 . floatEltExpm1)
+ log1pexp = mliftNumElt1 (liftO1 . floatEltLog1pexp)
+ log1mexp = mliftNumElt1 (liftO1 . floatEltLog1mexp)
+
+mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a
+mquotArray = mliftNumElt2 (liftO2 . intEltQuot)
+mremArray = mliftNumElt2 (liftO2 . intEltRem)
+
+matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a
+matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2)
+
+-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
+-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
+-- a@; see the documentation for 'Primitive' for more details.
+class Elt a where
+ -- ====== PUBLIC METHODS ====== --
+
+ mshape :: Mixed sh a -> IShX sh
+ mindex :: Mixed sh a -> IIxX sh -> a
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a
+ mscalar :: a -> Mixed '[] a
+
+ -- | All arrays in the list, even subarrays inside @a@, must have the same
+ -- shape; if they do not, a runtime error will be thrown. See the
+ -- documentation of 'mgenerate' for more information about this restriction.
+ -- Furthermore, the length of the list must correspond with @n@: if @n@ is
+ -- @Just m@ and @m@ does not equal the length of the list, a runtime error is
+ -- thrown.
+ --
+ -- Consider also 'mfromListPrim', which can avoid intermediate arrays.
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+
+ mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]
+
+ -- | Note: this library makes no particular guarantees about the shapes of
+ -- arrays "inside" an empty array. With 'mlift', 'mlift2' and 'mliftL' you can see the
+ -- full 'XArray' and as such you can distinguish different empty arrays by
+ -- the "shapes" of their elements. This information is meaningless, so you
+ -- should not use it.
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 a -> Mixed sh2 a
+
+ -- | See the documentation for 'mlift'.
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a
+
+ -- TODO: mliftL is currently unused.
+ -- | All arrays in the input must have equal shapes, including subarrays
+ -- inside their elements.
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
+ -> NonEmpty (Mixed sh1 a) -> NonEmpty (Mixed sh2 a)
+
+ mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a
+
+ mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
+ => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a
+
+ -- | All arrays in the input must have equal shapes, including subarrays
+ -- inside their elements.
+ mconcat :: NonEmpty (Mixed (Nothing : sh) a) -> Mixed (Nothing : sh) a
+
+ mrnf :: Mixed sh a -> ()
+
+ -- ====== PRIVATE METHODS ====== --
+
+ -- | Tree giving the shape of every array component.
+ type ShapeTree a
+
+ mshapeTree :: a -> ShapeTree a
+
+ mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool
+
+ mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool
+
+ mshowShapeTree :: Proxy a -> ShapeTree a -> String
+
+ -- | Returns the stride vector of each underlying component array making up
+ -- this mixed array.
+ marrayStrides :: Mixed sh a -> Bag [Int]
+
+ -- | Given the shape of this array, an index and a value, write the value at
+ -- that index in the vectors.
+ mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+
+ -- | Given the shape of this array, an index and a value, write the value at
+ -- that index in the vectors.
+ mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+
+ -- | Given the shape of this array, finalise the vectors into 'XArray's.
+ mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
+
+
+-- | Element types for which we have evidence of the (static part of the) shape
+-- in a type class constraint. Compare the instance contexts of the instances
+-- of this class with those of 'Elt': some instances have an additional
+-- "known-shape" constraint.
+--
+-- This class is (currently) only required for 'mgenerate',
+-- 'Data.Array.Nested.Ranked.rgenerate' and
+-- 'Data.Array.Nested.Shaped.sgenerate'.
+class Elt a => KnownElt a where
+ -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
+ memptyArrayUnsafe :: IShX sh -> Mixed sh a
+
+ -- | Create uninitialised vectors for this array type, given the shape of
+ -- this vector and an example for the contents.
+ mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a)
+
+ mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
+
+
+-- Arrays of scalars are basically just arrays of scalars.
+instance Storable a => Elt (Primitive a) where
+ mshape (M_Primitive sh _) = sh
+ mindex (M_Primitive _ a) i = Primitive (X.index a i)
+ mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i)
+ mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
+ mfromListOuter l@(arr1 :| _) =
+ let sh = SUnknown (length l) :$% mshape arr1
+ in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l)))
+ mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
+
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
+ -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)
+ mlift ssh2 f (M_Primitive _ a)
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ , let result = f ZKX a
+ = M_Primitive (X.shape ssh2 result) result
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)
+ -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a)
+ mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b)
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ , Refl <- lemAppNil @sh3
+ , let result = f ZKX a b
+ = M_Primitive (X.shape ssh3 result) result
+
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
+ -> NonEmpty (Mixed sh1 (Primitive a)) -> NonEmpty (Mixed sh2 (Primitive a))
+ mliftL ssh2 f l
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ = fmap (\arr -> M_Primitive (X.shape ssh2 arr) arr) $
+ f ZKX (fmap (\(M_Primitive _ arr) -> arr) l)
+
+ mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
+ mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) =
+ let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
+ sh2 = shxCast' sh1 ssh2
+ in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr)
+
+ mtranspose perm (M_Primitive sh arr) =
+ M_Primitive (shxPermutePrefix perm sh)
+ (X.transpose (ssxFromShape sh) perm arr)
+
+ mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a)
+ mconcat l@(M_Primitive (_ :$% sh) _ :| _) =
+ let result = X.concat (ssxFromShape sh) (fmap (\(M_Primitive _ arr) -> arr) l)
+ in M_Primitive (X.shape (SUnknown () :!% ssxFromShape sh) result) result
+
+ mrnf (M_Primitive sh a) = rnf sh `seq` rnf a
+
+ type ShapeTree (Primitive a) = ()
+ mshapeTree _ = ()
+ mshapeTreeEq _ () () = True
+ mshapeTreeEmpty _ () = False
+ mshowShapeTree _ () = "()"
+ marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr)
+ mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
+
+ -- TODO: this use of toVector is suboptimal
+ mvecsWritePartial
+ :: forall sh' sh s.
+ IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
+ mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
+ let arrsh = X.shape (ssxFromShape sh') arr
+ offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
+ VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
+
+ mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+deriving via Primitive Bool instance Elt Bool
+deriving via Primitive Int instance Elt Int
+deriving via Primitive Int64 instance Elt Int64
+deriving via Primitive Int32 instance Elt Int32
+deriving via Primitive CInt instance Elt CInt
+deriving via Primitive Double instance Elt Double
+deriving via Primitive Float instance Elt Float
+deriving via Primitive () instance Elt ()
+
+instance Storable a => KnownElt (Primitive a) where
+ memptyArrayUnsafe sh = M_Primitive sh (X.empty sh)
+ mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh)
+ mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+deriving via Primitive Bool instance KnownElt Bool
+deriving via Primitive Int instance KnownElt Int
+deriving via Primitive Int64 instance KnownElt Int64
+deriving via Primitive Int32 instance KnownElt Int32
+deriving via Primitive CInt instance KnownElt CInt
+deriving via Primitive Double instance KnownElt Double
+deriving via Primitive Float instance KnownElt Float
+deriving via Primitive () instance KnownElt ()
+
+-- Arrays of pairs are pairs of arrays.
+instance (Elt a, Elt b) => Elt (a, b) where
+ mshape (M_Tup2 a _) = mshape a
+ mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
+ mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
+ mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
+ mfromListOuter l =
+ M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l))
+ (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l))
+ mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)
+ mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
+ mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
+ mliftL ssh2 f =
+ let unzipT2l [] = ([], [])
+ unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2)
+ unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2)
+ in uncurry (NE.zipWith M_Tup2) . bimap (mliftL ssh2 f) (mliftL ssh2 f) . unzipT2
+
+ mcastPartial ssh1 sh2 psh' (M_Tup2 a b) =
+ M_Tup2 (mcastPartial ssh1 sh2 psh' a) (mcastPartial ssh1 sh2 psh' b)
+
+ mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b)
+ mconcat =
+ let unzipT2l [] = ([], [])
+ unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2)
+ unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2)
+ in uncurry M_Tup2 . bimap mconcat mconcat . unzipT2
+
+ mrnf (M_Tup2 a b) = mrnf a `seq` mrnf b
+
+ type ShapeTree (a, b) = (ShapeTree a, ShapeTree b)
+ mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
+ mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
+ mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
+ mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
+ marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b
+ mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
+ mvecsWrite sh i x a
+ mvecsWrite sh i y b
+ mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
+ mvecsWritePartial sh i x a
+ mvecsWritePartial sh i y b
+ mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
+
+instance (KnownElt a, KnownElt b) => KnownElt (a, b) where
+ memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh)
+ mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
+ mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b)
+
+-- Arrays of arrays are just arrays, but with more dimensions.
+instance Elt a => Elt (Mixed sh' a) where
+ -- TODO: this is quadratic in the nesting depth because it repeatedly
+ -- truncates the shape vector to one a little shorter. Fix with a
+ -- moverlongShape method, a prefix of which is mshape.
+ mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
+ mshape (M_Nest sh arr)
+ = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr))
+
+ mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
+ mindex (M_Nest _ arr) i = mindexPartial arr i
+
+ mindexPartial :: forall sh1 sh2.
+ Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
+ mindexPartial (M_Nest sh arr) i
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
+
+ mscalar = M_Nest ZSX
+
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
+ mfromListOuter l@(arr :| _) =
+ M_Nest (SUnknown (length l) :$% mshape arr)
+ (mfromListOuter ((\(M_Nest _ a) -> a) <$> l))
+
+ mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr)
+
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
+ -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
+ mlift ssh2 f (M_Nest sh1 arr) =
+ let result = mlift (ssxAppend ssh2 ssh') f' arr
+ (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result)
+ in M_Nest sh2 result
+ where
+ ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr)))
+
+ f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
+ f' sshT
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
+ = f (ssxAppend ssh' sshT)
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
+ -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)
+ mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) =
+ let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2
+ (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result)
+ in M_Nest sh3 result
+ where
+ ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1)))
+
+ f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
+ f' sshT
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
+ = f (ssxAppend ssh' sshT)
+
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b))
+ -> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a))
+ mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) =
+ let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l)
+ (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result))
+ in fmap (M_Nest sh2) result
+ where
+ ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1)))
+
+ f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b)
+ f' sshT
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
+ = f (ssxAppend ssh' sshT)
+
+ mcastPartial :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> StaticShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a)
+ mcastPartial ssh1 ssh2 _ (M_Nest sh1T arr)
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
+ = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T
+ sh2 = shxCast' sh1 ssh2
+ in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr)
+
+ mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
+ => Perm is -> Mixed sh (Mixed sh' a)
+ -> Mixed (PermutePrefix is sh) (Mixed sh' a)
+ mtranspose perm (M_Nest sh arr)
+ | let sh' = shxDropSh @sh @sh' (mshape arr) sh
+ , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh')
+ , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh'))
+ , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
+ , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
+ , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
+ = M_Nest (shxPermutePrefix perm sh)
+ (mtranspose perm arr)
+
+ mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
+ mconcat l@(M_Nest sh1 _ :| _) =
+ let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l)
+ in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape result))) result
+
+ mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr
+
+ type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a)
+
+ mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
+ mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr)))))
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+
+ mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
+ marrayStrides (M_Nest _ arr) = marrayStrides arr
+
+ mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
+
+ mvecsWritePartial :: forall sh1 sh2 s.
+ IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
+ -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
+ -> ST s ()
+ mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs
+
+ mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs
+
+instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
+ memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh'))))
+
+ mvecsUnsafeNew sh example
+ | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
+ | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh')))
+ where
+ sh' = mshape example
+
+ mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
+
+
+memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a
+memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh)
+
+mrank :: Elt a => Mixed sh a -> SNat (Rank sh)
+mrank = shxRank . mshape
+
+-- | The total number of elements in the array.
+msize :: Elt a => Mixed sh a -> Int
+msize = shxSize . mshape
+
+-- | Create an array given a size and a function that computes the element at a
+-- given index.
+--
+-- __WARNING__: It is required that every @a@ returned by the argument to
+-- 'mgenerate' has the same shape. For example, the following will throw a
+-- runtime error:
+--
+-- > foo :: Mixed [Nothing] (Mixed [Nothing] Double)
+-- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) ->
+-- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) ->
+-- > ...
+--
+-- because the size of the inner 'mgenerate' is not always the same (it depends
+-- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so
+-- the entire hierarchy (after distributing out tuples) must be a rectangular
+-- array. The type of 'mgenerate' allows this requirement to be broken very
+-- easily, hence the runtime check.
+mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a
+mgenerate sh f = case shxEnum sh of
+ [] -> memptyArrayUnsafe sh
+ firstidx : restidxs ->
+ let firstelem = f (ixxZero' sh)
+ shapetree = mshapeTree firstelem
+ in if mshapeTreeEmpty (Proxy @a) shapetree
+ then memptyArrayUnsafe sh
+ else runST $ do
+ vecs <- mvecsUnsafeNew sh firstelem
+ mvecsWrite sh firstidx firstelem vecs
+ -- TODO: This is likely fine if @a@ is big, but if @a@ is a
+ -- scalar this array copying inefficient. Should improve this.
+ forM_ restidxs $ \idx -> do
+ let val = f idx
+ when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $
+ error "Data.Array.Nested mgenerate: generated values do not have equal shapes"
+ mvecsWrite sh idx val vecs
+ mvecsFreeze sh vecs
+
+msumOuter1P :: forall sh n a. (Storable a, NumElt a)
+ => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
+msumOuter1P (M_Primitive (n :$% sh) arr) =
+ let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
+ in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr)
+
+msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
+ => Mixed (n : sh) a -> Mixed sh a
+msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
+
+msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
+msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShape sh) arr
+
+mappend :: forall n m sh a. Elt a
+ => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
+mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
+ where
+ sn :$% sh = mshape arr1
+ sm :$% _ = mshape arr2
+ ssh = ssxFromShape sh
+ snm :: SMayNat () SNat (AddMaybe n m)
+ snm = case (sn, sm) of
+ (SUnknown{}, _) -> SUnknown ()
+ (SKnown{}, SUnknown{}) -> SUnknown ()
+ (SKnown n, SKnown m) -> SKnown (snatPlus n m)
+
+ f :: forall sh' b. Storable b
+ => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b
+ f ssh' = X.append (ssxAppend ssh ssh')
+
+mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
+mfromVectorP sh v = M_Primitive sh (X.fromVector sh v)
+
+mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a
+mfromVector sh v = fromPrimitive (mfromVectorP sh v)
+
+mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a
+mtoVectorP (M_Primitive _ v) = X.toVector v
+
+mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a
+mtoVector arr = mtoVectorP (toPrimitive arr)
+
+mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a
+mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise?
+
+mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a
+mfromList1Prim l =
+ let ssh = SUnknown () :!% ZKX
+ xarr = X.fromList1 ssh l
+ in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+
+mtoList1 :: Elt a => Mixed '[n] a -> [a]
+mtoList1 = map munScalar . mtoListOuter
+
+mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a
+mfromListPrim l =
+ let ssh = SUnknown () :!% ZKX
+ xarr = X.fromList1 ssh l
+ in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+
+mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a
+mfromListPrimLinear sh l =
+ let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)
+
+-- This forall is there so that a simple type application can constrain the
+-- shape, in case the user wants to use OverloadedLists for the shape.
+mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a
+mfromListLinear sh l = mreshape sh (mfromList1 l)
+
+mtoListLinear :: Elt a => Mixed sh a -> [a]
+mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr)) -- TODO: optimise
+
+munScalar :: Elt a => Mixed '[] a -> a
+munScalar arr = mindex arr ZIX
+
+mnest :: forall sh sh' a. Elt a => StaticShX sh -> Mixed (sh ++ sh') a -> Mixed sh (Mixed sh' a)
+mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr
+
+munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a
+munNest (M_Nest _ arr) = arr
+
+mzip :: Mixed sh a -> Mixed sh b -> Mixed sh (a, b)
+mzip = M_Tup2
+
+munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b)
+munzip (M_Tup2 a b) = (a, b)
+
+mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
+ => StaticShX sh -> IShX sh2
+ -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
+ -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b)
+mrerankP ssh sh2 f (M_Primitive sh arr) =
+ let sh1 = shxDropSSX sh ssh
+ in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2)
+ (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2)
+ (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
+ arr)
+
+-- | See the caveats at @X.rerank@.
+mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
+ => StaticShX sh -> IShX sh2
+ -> (Mixed sh1 a -> Mixed sh2 b)
+ -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b
+mrerank ssh sh2 f (toPrimitive -> arr) =
+ fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr
+
+mreplicate :: forall sh sh' a. Elt a
+ => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
+mreplicate sh arr =
+ let ssh' = ssxFromShape (mshape arr)
+ in mlift (ssxAppend (ssxFromShape sh) ssh')
+ (\(sshT :: StaticShX shT) ->
+ case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of
+ Refl -> X.replicate sh (ssxAppend ssh' sshT))
+ arr
+
+mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
+mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x)
+
+mreplicateScal :: forall sh a. PrimElt a
+ => IShX sh -> a -> Mixed sh a
+mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
+
+mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
+mslice i n arr =
+ let _ :$% sh = mshape arr
+ in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr
+
+msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
+msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr
+
+mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a
+mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr
+
+mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a
+mreshape sh' arr =
+ mlift (ssxFromShape sh')
+ (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh')
+ arr
+
+mflatten :: Elt a => Mixed sh a -> Mixed '[Flatten sh] a
+mflatten arr = mreshape (shxFlatten (mshape arr) :$% ZSX) arr
+
+miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a
+miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
+
+-- | Throws if the array is empty.
+mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
+mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
+ ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) (fromO arr))
+
+-- | Throws if the array is empty.
+mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
+mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
+ ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) (fromO arr))
+
+mdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
+ => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a
+mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b))
+ | Refl <- lemInitApp (Proxy @sh) (Proxy @n)
+ , Refl <- lemLastApp (Proxy @sh) (Proxy @n)
+ = case sh1 of
+ _ :$% _
+ | sh1 == sh2
+ , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) ->
+ fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b))
+ | otherwise -> error $ "mdot1Inner: Unequal shapes (" ++ show sh1 ++ " and " ++ show sh2 ++ ")"
+ ZSX -> error "unreachable"
+
+-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
+-- Prefer 'mdot1Inner' if applicable.
+mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a
+mdot a b =
+ munScalar $
+ mdot1Inner Proxy (fromPrimitive (mflatten (toPrimitive a)))
+ (fromPrimitive (mflatten (toPrimitive b)))
+
+mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
+mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr)
+
+mtoXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a)
+mtoXArrayPrim = mtoXArrayPrimP . toPrimitive
+
+mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a)
+mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr
+
+mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a
+mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP
+
+mliftPrim :: (PrimElt a, PrimElt b)
+ => (a -> b)
+ -> Mixed sh a -> Mixed sh b
+mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr))
+
+mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c)
+ => (a -> b -> c)
+ -> Mixed sh a -> Mixed sh b -> Mixed sh c
+mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =
+ fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2))
+
+mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)
+ => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a
+mcast ssh2 arr
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr
+
+-- TODO: This should be `type data` but a bug in GHC 9.10 means that that throws linker errors
+data SafeMCastSpec
+ = MCastId
+ | MCastApp [Maybe Nat] [Maybe Nat] [Maybe Nat] [Maybe Nat] SafeMCastSpec SafeMCastSpec
+ | MCastForget
+
+type SafeMCast :: SafeMCastSpec -> [Maybe Nat] -> [Maybe Nat] -> Constraint
+type family SafeMCast spec sh1 sh2 where
+ SafeMCast MCastId sh sh = ()
+ SafeMCast (MCastApp sh1A sh1B sh2A sh2B specA specB) sh1 sh2 = (sh1 ~ sh1A ++ sh1B, sh2 ~ sh2A ++ sh2B, SafeMCast specA sh1A sh2A, SafeMCast specB sh1B sh2B)
+ SafeMCast MCastForget sh1 sh2 = sh2 ~ Replicate (Rank sh1) Nothing
+
+-- | This is an O(1) operation: the 'SafeMCast' constraint ensures that
+-- type-level shape information can only be forgotten, not introduced, and thus
+-- that no runtime shape checks are required. The @spec@ describes to
+-- 'SafeMCast' how exactly you intend @sh2@ to be a weakening of @sh1@.
+--
+-- To see how to construct the spec, read the equations of 'SafeMCast' closely.
+mcastSafe :: forall spec sh1 sh2 a proxy. SafeMCast spec sh1 sh2 => proxy spec -> Mixed sh1 a -> Mixed sh2 a
+mcastSafe _ = unsafeCoerce @(Mixed sh1 a) @(Mixed sh2 a)