diff options
Diffstat (limited to 'src/Data/Array/Nested/Internal/Mixed.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 955 |
1 files changed, 0 insertions, 955 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs deleted file mode 100644 index a2f9737..0000000 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ /dev/null @@ -1,955 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DefaultSignatures #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE StrictData #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -module Data.Array.Nested.Internal.Mixed where - -import Prelude hiding (mconcat) - -import Control.DeepSeq (NFData(..)) -import Control.Monad (forM_, when) -import Control.Monad.ST -import Data.Array.RankedS qualified as S -import Data.Bifunctor (bimap) -import Data.Coerce -import Data.Foldable (toList) -import Data.Int -import Data.Kind (Constraint, Type) -import Data.List.NonEmpty (NonEmpty(..)) -import Data.List.NonEmpty qualified as NE -import Data.Proxy -import Data.Type.Equality -import Data.Vector.Storable qualified as VS -import Data.Vector.Storable.Mutable qualified as VSM -import Foreign.C.Types (CInt) -import Foreign.Storable (Storable) -import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) -import GHC.Generics (Generic) -import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce) - -import Data.Array.Mixed.Internal.Arith -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray(..)) -import Data.Array.Mixed.XArray qualified as X -import Data.Bag - - --- TODO: --- sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a --- rminIndex1 :: Ranked (n + 1) a -> Ranked n Int --- gather/scatter-like things (most generally, the higher-order variants: accelerate's backpermute/permute) --- After benchmarking: matmul and matvec - - - --- Invariant in the API --- ==================== --- --- In the underlying XArray, there is some shape for elements of an empty --- array. For example, for this array: --- --- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float) --- rshape arr == 0 :.: 0 :.: 0 :.: ZIR --- --- the two underlying XArrays have a shape, and those shapes might be anything. --- The invariant is that these element shapes are unobservable in the API. --- (This is possible because you ought to not be able to get to such an element --- without indexing out of bounds.) --- --- Note, though, that the converse situation may arise: the outer array might --- be nonempty but then the inner arrays might. This is fine, an invariant only --- applies if the _outer_ array is empty. --- --- TODO: can we enforce that the elements of an empty (nested) array have --- all-zero shape? --- -> no, because mlift and also any kind of internals probing from outsiders - - --- Primitive element types --- ======================= --- --- There are a few primitive element types; arrays containing elements of such --- type are a newtype over an XArray, which it itself a newtype over a Vector. --- Unfortunately, the setup of the library requires us to list these primitive --- element types multiple times; to aid in extending the list, all these lists --- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. - - --- | Wrapper type used as a tag to attach instances on. The instances on arrays --- of @'Primitive' a@ are more polymorphic than the direct instances for arrays --- of scalars; this means that if @orthotope@ supports an element type @T@ that --- this library does not (directly), it may just work if you use an array of --- @'Primitive' T@ instead. -newtype Primitive a = Primitive a - deriving (Show) - --- | Element types that are primitive; arrays of these types are just a newtype --- wrapper over an array. -class (Storable a, Elt a) => PrimElt a where - fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a - toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) - - default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a - fromPrimitive = coerce - - default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a) - toPrimitive = coerce - --- [PRIMITIVE ELEMENT TYPES LIST] -instance PrimElt Bool -instance PrimElt Int -instance PrimElt Int64 -instance PrimElt Int32 -instance PrimElt CInt -instance PrimElt Float -instance PrimElt Double -instance PrimElt () - - --- | Mixed arrays: some dimensions are size-typed, some are not. Distributes --- over product-typed elements using a data family so that the full array is --- always in struct-of-arrays format. --- --- Built on top of 'XArray' which is built on top of @orthotope@, meaning that --- dimension permutations (e.g. 'mtranspose') are typically free. --- --- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type --- class. -type Mixed :: [Maybe Nat] -> Type -> Type -data family Mixed sh a --- NOTE: When opening up the Mixed abstraction, you might see dimension sizes --- that you're not supposed to see. In particular, you might see (nonempty) --- sizes of the elements of an empty array, which is information that should --- ostensibly not exist; the full array is still empty. - -data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a) - deriving (Eq, Ord, Generic) - --- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Eq, Ord, Generic) -- no content, orthotope optimises this (via Vector) --- etc. - -data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic) --- etc., larger tuples (perhaps use generics to allow arbitrary product types) - -deriving instance (Eq (Mixed sh a), Eq (Mixed sh b)) => Eq (Mixed sh (a, b)) -deriving instance (Ord (Mixed sh a), Ord (Mixed sh b)) => Ord (Mixed sh (a, b)) - -data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic) - -deriving instance Eq (Mixed (sh1 ++ sh2) a) => Eq (Mixed sh1 (Mixed sh2 a)) -deriving instance Ord (Mixed (sh1 ++ sh2) a) => Ord (Mixed sh1 (Mixed sh2 a)) - - --- | Internal helper data family mirroring 'Mixed' that consists of mutable --- vectors instead of 'XArray's. -type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type -data family MixedVecs s sh a - -newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a) - --- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance MixedVecs s sh Bool = MV_Bool (VS.MVector s Bool) -newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) -newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64) -newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32) -newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt) -newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) -newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float) -newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this --- etc. - -data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b) --- etc. - -data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) - - -showsMixedArray :: (Show a, Elt a) - => String -- ^ fromList prefix: e.g. @rfromListLinear [2,3]@ - -> String -- ^ replicate prefix: e.g. @rreplicate [2,3]@ - -> Int -> Mixed sh a -> ShowS -showsMixedArray fromlistPrefix replicatePrefix d arr = - showParen (d > 10) $ - -- TODO: to avoid ambiguity, we should type-apply the shape to mfromListLinear here - case mtoListLinear arr of - hd : _ : _ - | all (all (== 0) . take (shxLength (mshape arr))) (marrayStrides arr) -> - showString replicatePrefix . showString " " . showsPrec 11 hd - _ -> - showString fromlistPrefix . showString " " . shows (mtoListLinear arr) - -instance (Show a, Elt a) => Show (Mixed sh a) where - showsPrec d arr = - let sh = show (shxToList (mshape arr)) - in showsMixedArray ("mfromListLinear " ++ sh) ("mreplicate " ++ sh) d arr - -instance Elt a => NFData (Mixed sh a) where - rnf = mrnf - - -mliftNumElt1 :: (PrimElt a, PrimElt b) - => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b) - -> Mixed sh a -> Mixed sh b -mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr)) - -mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c) - => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c) - -> Mixed sh a -> Mixed sh b -> Mixed sh c -mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2)) - | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2)) - | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 - -instance (NumElt a, PrimElt a) => Num (Mixed sh a) where - (+) = mliftNumElt2 (liftO2 . numEltAdd) - (-) = mliftNumElt2 (liftO2 . numEltSub) - (*) = mliftNumElt2 (liftO2 . numEltMul) - negate = mliftNumElt1 (liftO1 . numEltNeg) - abs = mliftNumElt1 (liftO1 . numEltAbs) - signum = mliftNumElt1 (liftO1 . numEltSignum) - -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS - fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal" - -instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where - fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" - recip = mliftNumElt1 (liftO1 . floatEltRecip) - (/) = mliftNumElt2 (liftO2 . floatEltDiv) - -instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" - exp = mliftNumElt1 (liftO1 . floatEltExp) - log = mliftNumElt1 (liftO1 . floatEltLog) - sqrt = mliftNumElt1 (liftO1 . floatEltSqrt) - - (**) = mliftNumElt2 (liftO2 . floatEltPow) - logBase = mliftNumElt2 (liftO2 . floatEltLogbase) - - sin = mliftNumElt1 (liftO1 . floatEltSin) - cos = mliftNumElt1 (liftO1 . floatEltCos) - tan = mliftNumElt1 (liftO1 . floatEltTan) - asin = mliftNumElt1 (liftO1 . floatEltAsin) - acos = mliftNumElt1 (liftO1 . floatEltAcos) - atan = mliftNumElt1 (liftO1 . floatEltAtan) - sinh = mliftNumElt1 (liftO1 . floatEltSinh) - cosh = mliftNumElt1 (liftO1 . floatEltCosh) - tanh = mliftNumElt1 (liftO1 . floatEltTanh) - asinh = mliftNumElt1 (liftO1 . floatEltAsinh) - acosh = mliftNumElt1 (liftO1 . floatEltAcosh) - atanh = mliftNumElt1 (liftO1 . floatEltAtanh) - log1p = mliftNumElt1 (liftO1 . floatEltLog1p) - expm1 = mliftNumElt1 (liftO1 . floatEltExpm1) - log1pexp = mliftNumElt1 (liftO1 . floatEltLog1pexp) - log1mexp = mliftNumElt1 (liftO1 . floatEltLog1mexp) - -mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a -mquotArray = mliftNumElt2 (liftO2 . intEltQuot) -mremArray = mliftNumElt2 (liftO2 . intEltRem) - -matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a -matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2) - --- | Allowable element types in a mixed array, and by extension in a 'Ranked' or --- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' --- a@; see the documentation for 'Primitive' for more details. -class Elt a where - -- ====== PUBLIC METHODS ====== -- - - mshape :: Mixed sh a -> IShX sh - mindex :: Mixed sh a -> IIxX sh -> a - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a - mscalar :: a -> Mixed '[] a - - -- | All arrays in the list, even subarrays inside @a@, must have the same - -- shape; if they do not, a runtime error will be thrown. See the - -- documentation of 'mgenerate' for more information about this restriction. - -- Furthermore, the length of the list must correspond with @n@: if @n@ is - -- @Just m@ and @m@ does not equal the length of the list, a runtime error is - -- thrown. - -- - -- Consider also 'mfromListPrim', which can avoid intermediate arrays. - mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a - - mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] - - -- | Note: this library makes no particular guarantees about the shapes of - -- arrays "inside" an empty array. With 'mlift', 'mlift2' and 'mliftL' you can see the - -- full 'XArray' and as such you can distinguish different empty arrays by - -- the "shapes" of their elements. This information is meaningless, so you - -- should not use it. - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 a -> Mixed sh2 a - - -- | See the documentation for 'mlift'. - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a - - -- TODO: mliftL is currently unused. - -- | All arrays in the input must have equal shapes, including subarrays - -- inside their elements. - mliftL :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) - -> NonEmpty (Mixed sh1 a) -> NonEmpty (Mixed sh2 a) - - mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a - - mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) - => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a - - -- | All arrays in the input must have equal shapes, including subarrays - -- inside their elements. - mconcat :: NonEmpty (Mixed (Nothing : sh) a) -> Mixed (Nothing : sh) a - - mrnf :: Mixed sh a -> () - - -- ====== PRIVATE METHODS ====== -- - - -- | Tree giving the shape of every array component. - type ShapeTree a - - mshapeTree :: a -> ShapeTree a - - mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool - - mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool - - mshowShapeTree :: Proxy a -> ShapeTree a -> String - - -- | Returns the stride vector of each underlying component array making up - -- this mixed array. - marrayStrides :: Mixed sh a -> Bag [Int] - - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () - - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () - - -- | Given the shape of this array, finalise the vectors into 'XArray's. - mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) - - --- | Element types for which we have evidence of the (static part of the) shape --- in a type class constraint. Compare the instance contexts of the instances --- of this class with those of 'Elt': some instances have an additional --- "known-shape" constraint. --- --- This class is (currently) only required for 'mgenerate', --- 'Data.Array.Nested.Ranked.rgenerate' and --- 'Data.Array.Nested.Shaped.sgenerate'. -class Elt a => KnownElt a where - -- | Create an empty array. The given shape must have size zero; this may or may not be checked. - memptyArrayUnsafe :: IShX sh -> Mixed sh a - - -- | Create uninitialised vectors for this array type, given the shape of - -- this vector and an example for the contents. - mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) - - mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) - - --- Arrays of scalars are basically just arrays of scalars. -instance Storable a => Elt (Primitive a) where - mshape (M_Primitive sh _) = sh - mindex (M_Primitive _ a) i = Primitive (X.index a i) - mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) - mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) - mfromListOuter l@(arr1 :| _) = - let sh = SUnknown (length l) :$% mshape arr1 - in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l))) - mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) - -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) - mlift ssh2 f (M_Primitive _ a) - | Refl <- lemAppNil @sh1 - , Refl <- lemAppNil @sh2 - , let result = f ZKX a - = M_Primitive (X.shape ssh2 result) result - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) - -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) - mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b) - | Refl <- lemAppNil @sh1 - , Refl <- lemAppNil @sh2 - , Refl <- lemAppNil @sh3 - , let result = f ZKX a b - = M_Primitive (X.shape ssh3 result) result - - mliftL :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) - -> NonEmpty (Mixed sh1 (Primitive a)) -> NonEmpty (Mixed sh2 (Primitive a)) - mliftL ssh2 f l - | Refl <- lemAppNil @sh1 - , Refl <- lemAppNil @sh2 - = fmap (\arr -> M_Primitive (X.shape ssh2 arr) arr) $ - f ZKX (fmap (\(M_Primitive _ arr) -> arr) l) - - mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) - mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) = - let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' - sh2 = shxCast' sh1 ssh2 - in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) - - mtranspose perm (M_Primitive sh arr) = - M_Primitive (shxPermutePrefix perm sh) - (X.transpose (ssxFromShape sh) perm arr) - - mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a) - mconcat l@(M_Primitive (_ :$% sh) _ :| _) = - let result = X.concat (ssxFromShape sh) (fmap (\(M_Primitive _ arr) -> arr) l) - in M_Primitive (X.shape (SUnknown () :!% ssxFromShape sh) result) result - - mrnf (M_Primitive sh a) = rnf sh `seq` rnf a - - type ShapeTree (Primitive a) = () - mshapeTree _ = () - mshapeTreeEq _ () () = True - mshapeTreeEmpty _ () = False - mshowShapeTree _ () = "()" - marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x - - -- TODO: this use of toVector is suboptimal - mvecsWritePartial - :: forall sh' sh s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do - let arrsh = X.shape (ssxFromShape sh') arr - offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) - VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) - - mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving via Primitive Bool instance Elt Bool -deriving via Primitive Int instance Elt Int -deriving via Primitive Int64 instance Elt Int64 -deriving via Primitive Int32 instance Elt Int32 -deriving via Primitive CInt instance Elt CInt -deriving via Primitive Double instance Elt Double -deriving via Primitive Float instance Elt Float -deriving via Primitive () instance Elt () - -instance Storable a => KnownElt (Primitive a) where - memptyArrayUnsafe sh = M_Primitive sh (X.empty sh) - mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) - mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving via Primitive Bool instance KnownElt Bool -deriving via Primitive Int instance KnownElt Int -deriving via Primitive Int64 instance KnownElt Int64 -deriving via Primitive Int32 instance KnownElt Int32 -deriving via Primitive CInt instance KnownElt CInt -deriving via Primitive Double instance KnownElt Double -deriving via Primitive Float instance KnownElt Float -deriving via Primitive () instance KnownElt () - --- Arrays of pairs are pairs of arrays. -instance (Elt a, Elt b) => Elt (a, b) where - mshape (M_Tup2 a _) = mshape a - mindex (M_Tup2 a b) i = (mindex a i, mindex b i) - mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) - mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) - mfromListOuter l = - M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) - (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) - mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) - mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) - mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) - mliftL ssh2 f = - let unzipT2l [] = ([], []) - unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) - unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2) - in uncurry (NE.zipWith M_Tup2) . bimap (mliftL ssh2 f) (mliftL ssh2 f) . unzipT2 - - mcastPartial ssh1 sh2 psh' (M_Tup2 a b) = - M_Tup2 (mcastPartial ssh1 sh2 psh' a) (mcastPartial ssh1 sh2 psh' b) - - mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) - mconcat = - let unzipT2l [] = ([], []) - unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) - unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2) - in uncurry M_Tup2 . bimap mconcat mconcat . unzipT2 - - mrnf (M_Tup2 a b) = mrnf a `seq` mrnf b - - type ShapeTree (a, b) = (ShapeTree a, ShapeTree b) - mshapeTree (x, y) = (mshapeTree x, mshapeTree y) - mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' - mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 - mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" - marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b - mvecsWrite sh i (x, y) (MV_Tup2 a b) = do - mvecsWrite sh i x a - mvecsWrite sh i y b - mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do - mvecsWritePartial sh i x a - mvecsWritePartial sh i y b - mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b - -instance (KnownElt a, KnownElt b) => KnownElt (a, b) where - memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh) - mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y - mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) - --- Arrays of arrays are just arrays, but with more dimensions. -instance Elt a => Elt (Mixed sh' a) where - -- TODO: this is quadratic in the nesting depth because it repeatedly - -- truncates the shape vector to one a little shorter. Fix with a - -- moverlongShape method, a prefix of which is mshape. - mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh - mshape (M_Nest sh arr) - = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr)) - - mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a - mindex (M_Nest _ arr) i = mindexPartial arr i - - mindexPartial :: forall sh1 sh2. - Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - mindexPartial (M_Nest sh arr) i - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) - - mscalar = M_Nest ZSX - - mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) - mfromListOuter l@(arr :| _) = - M_Nest (SUnknown (length l) :$% mshape arr) - (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) - - mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) - -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) - mlift ssh2 f (M_Nest sh1 arr) = - let result = mlift (ssxAppend ssh2 ssh') f' arr - (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) - in M_Nest sh2 result - where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr))) - - f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b - f' sshT - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - = f (ssxAppend ssh' sshT) - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) - -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) - mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = - let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 - (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) - in M_Nest sh3 result - where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) - - f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b - f' sshT - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) - = f (ssxAppend ssh' sshT) - - mliftL :: forall sh1 sh2. - StaticShX sh2 - -> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b)) - -> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a)) - mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) = - let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l) - (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result)) - in fmap (M_Nest sh2) result - where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) - - f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b) - f' sshT - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - = f (ssxAppend ssh' sshT) - - mcastPartial :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> StaticShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) - mcastPartial ssh1 ssh2 _ (M_Nest sh1T arr) - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') - , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') - = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T - sh2 = shxCast' sh1 ssh2 - in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr) - - mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) - => Perm is -> Mixed sh (Mixed sh' a) - -> Mixed (PermutePrefix is sh) (Mixed sh' a) - mtranspose perm (M_Nest sh arr) - | let sh' = shxDropSh @sh @sh' (mshape arr) sh - , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh') - , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) - , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') - , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') - , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') - = M_Nest (shxPermutePrefix perm sh) - (mtranspose perm arr) - - mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) - mconcat l@(M_Nest sh1 _ :| _) = - let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l) - in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape result))) result - - mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr - - type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a) - - mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) - mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr))))) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - marrayStrides (M_Nest _ arr) = marrayStrides arr - - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs - - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) - -> ST s () - mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs - - mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs - -instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where - memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh')))) - - mvecsUnsafeNew sh example - | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) - | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh'))) - where - sh' = mshape example - - mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) - - -memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a -memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh) - -mrank :: Elt a => Mixed sh a -> SNat (Rank sh) -mrank = shxRank . mshape - --- | The total number of elements in the array. -msize :: Elt a => Mixed sh a -> Int -msize = shxSize . mshape - --- | Create an array given a size and a function that computes the element at a --- given index. --- --- __WARNING__: It is required that every @a@ returned by the argument to --- 'mgenerate' has the same shape. For example, the following will throw a --- runtime error: --- --- > foo :: Mixed [Nothing] (Mixed [Nothing] Double) --- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) -> --- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) -> --- > ... --- --- because the size of the inner 'mgenerate' is not always the same (it depends --- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so --- the entire hierarchy (after distributing out tuples) must be a rectangular --- array. The type of 'mgenerate' allows this requirement to be broken very --- easily, hence the runtime check. -mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a -mgenerate sh f = case shxEnum sh of - [] -> memptyArrayUnsafe sh - firstidx : restidxs -> - let firstelem = f (ixxZero' sh) - shapetree = mshapeTree firstelem - in if mshapeTreeEmpty (Proxy @a) shapetree - then memptyArrayUnsafe sh - else runST $ do - vecs <- mvecsUnsafeNew sh firstelem - mvecsWrite sh firstidx firstelem vecs - -- TODO: This is likely fine if @a@ is big, but if @a@ is a - -- scalar this array copying inefficient. Should improve this. - forM_ restidxs $ \idx -> do - let val = f idx - when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ - error "Data.Array.Nested mgenerate: generated values do not have equal shapes" - mvecsWrite sh idx val vecs - mvecsFreeze sh vecs - -msumOuter1P :: forall sh n a. (Storable a, NumElt a) - => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) -msumOuter1P (M_Primitive (n :$% sh) arr) = - let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX - in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr) - -msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) - => Mixed (n : sh) a -> Mixed sh a -msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive - -msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a -msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShape sh) arr - -mappend :: forall n m sh a. Elt a - => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a -mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 - where - sn :$% sh = mshape arr1 - sm :$% _ = mshape arr2 - ssh = ssxFromShape sh - snm :: SMayNat () SNat (AddMaybe n m) - snm = case (sn, sm) of - (SUnknown{}, _) -> SUnknown () - (SKnown{}, SUnknown{}) -> SUnknown () - (SKnown n, SKnown m) -> SKnown (snatPlus n m) - - f :: forall sh' b. Storable b - => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b - f ssh' = X.append (ssxAppend ssh ssh') - -mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) -mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) - -mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a -mfromVector sh v = fromPrimitive (mfromVectorP sh v) - -mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a -mtoVectorP (M_Primitive _ v) = X.toVector v - -mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a -mtoVector arr = mtoVectorP (toPrimitive arr) - -mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a -mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise? - -mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromList1Prim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -mtoList1 :: Elt a => Mixed '[n] a -> [a] -mtoList1 = map munScalar . mtoListOuter - -mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromListPrim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a -mfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) - --- This forall is there so that a simple type application can constrain the --- shape, in case the user wants to use OverloadedLists for the shape. -mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a -mfromListLinear sh l = mreshape sh (mfromList1 l) - -mtoListLinear :: Elt a => Mixed sh a -> [a] -mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr)) -- TODO: optimise - -munScalar :: Elt a => Mixed '[] a -> a -munScalar arr = mindex arr ZIX - -mnest :: forall sh sh' a. Elt a => StaticShX sh -> Mixed (sh ++ sh') a -> Mixed sh (Mixed sh' a) -mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr - -munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a -munNest (M_Nest _ arr) = arr - -mzip :: Mixed sh a -> Mixed sh b -> Mixed sh (a, b) -mzip = M_Tup2 - -munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b) -munzip (M_Tup2 a b) = (a, b) - -mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) - => StaticShX sh -> IShX sh2 - -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) - -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) -mrerankP ssh sh2 f (M_Primitive sh arr) = - let sh1 = shxDropSSX sh ssh - in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) - (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2) - (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) - arr) - --- | See the caveats at @X.rerank@. -mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) - => StaticShX sh -> IShX sh2 - -> (Mixed sh1 a -> Mixed sh2 b) - -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b -mrerank ssh sh2 f (toPrimitive -> arr) = - fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr - -mreplicate :: forall sh sh' a. Elt a - => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a -mreplicate sh arr = - let ssh' = ssxFromShape (mshape arr) - in mlift (ssxAppend (ssxFromShape sh) ssh') - (\(sshT :: StaticShX shT) -> - case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of - Refl -> X.replicate sh (ssxAppend ssh' sshT)) - arr - -mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) -mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) - -mreplicateScal :: forall sh a. PrimElt a - => IShX sh -> a -> Mixed sh a -mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) - -mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a -mslice i n arr = - let _ :$% sh = mshape arr - in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr - -msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr - -mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a -mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr - -mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a -mreshape sh' arr = - mlift (ssxFromShape sh') - (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh') - arr - -mflatten :: Elt a => Mixed sh a -> Mixed '[Flatten sh] a -mflatten arr = mreshape (shxFlatten (mshape arr) :$% ZSX) arr - -miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a -miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) - --- | Throws if the array is empty. -mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh -mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = - ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) (fromO arr)) - --- | Throws if the array is empty. -mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh -mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = - ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) (fromO arr)) - -mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) - => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a -mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b)) - | Refl <- lemInitApp (Proxy @sh) (Proxy @n) - , Refl <- lemLastApp (Proxy @sh) (Proxy @n) - = case sh1 of - _ :$% _ - | sh1 == sh2 - , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) -> - fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b)) - | otherwise -> error $ "mdot1Inner: Unequal shapes (" ++ show sh1 ++ " and " ++ show sh2 ++ ")" - ZSX -> error "unreachable" - --- | This has a temporary, suboptimal implementation in terms of 'mflatten'. --- Prefer 'mdot1Inner' if applicable. -mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a -mdot a b = - munScalar $ - mdot1Inner Proxy (fromPrimitive (mflatten (toPrimitive a))) - (fromPrimitive (mflatten (toPrimitive b))) - -mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) -mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr) - -mtoXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a) -mtoXArrayPrim = mtoXArrayPrimP . toPrimitive - -mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a) -mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr - -mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a -mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP - -mliftPrim :: (PrimElt a, PrimElt b) - => (a -> b) - -> Mixed sh a -> Mixed sh b -mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) - -mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c) - => (a -> b -> c) - -> Mixed sh a -> Mixed sh b -> Mixed sh c -mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = - fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) - -mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) - => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a -mcast ssh2 arr - | Refl <- lemAppNil @sh1 - , Refl <- lemAppNil @sh2 - = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr - --- TODO: This should be `type data` but a bug in GHC 9.10 means that that throws linker errors -data SafeMCastSpec - = MCastId - | MCastApp [Maybe Nat] [Maybe Nat] [Maybe Nat] [Maybe Nat] SafeMCastSpec SafeMCastSpec - | MCastForget - -type SafeMCast :: SafeMCastSpec -> [Maybe Nat] -> [Maybe Nat] -> Constraint -type family SafeMCast spec sh1 sh2 where - SafeMCast MCastId sh sh = () - SafeMCast (MCastApp sh1A sh1B sh2A sh2B specA specB) sh1 sh2 = (sh1 ~ sh1A ++ sh1B, sh2 ~ sh2A ++ sh2B, SafeMCast specA sh1A sh2A, SafeMCast specB sh1B sh2B) - SafeMCast MCastForget sh1 sh2 = sh2 ~ Replicate (Rank sh1) Nothing - --- | This is an O(1) operation: the 'SafeMCast' constraint ensures that --- type-level shape information can only be forgotten, not introduced, and thus --- that no runtime shape checks are required. The @spec@ describes to --- 'SafeMCast' how exactly you intend @sh2@ to be a weakening of @sh1@. --- --- To see how to construct the spec, read the equations of 'SafeMCast' closely. -mcastSafe :: forall spec sh1 sh2 a proxy. SafeMCast spec sh1 sh2 => proxy spec -> Mixed sh1 a -> Mixed sh2 a -mcastSafe _ = unsafeCoerce @(Mixed sh1 a) @(Mixed sh2 a) |