aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Internal/Mixed.hs')
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs955
1 files changed, 0 insertions, 955 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
deleted file mode 100644
index a2f9737..0000000
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ /dev/null
@@ -1,955 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DefaultSignatures #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingVia #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE InstanceSigs #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE StrictData #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE ViewPatterns #-}
-module Data.Array.Nested.Internal.Mixed where
-
-import Prelude hiding (mconcat)
-
-import Control.DeepSeq (NFData(..))
-import Control.Monad (forM_, when)
-import Control.Monad.ST
-import Data.Array.RankedS qualified as S
-import Data.Bifunctor (bimap)
-import Data.Coerce
-import Data.Foldable (toList)
-import Data.Int
-import Data.Kind (Constraint, Type)
-import Data.List.NonEmpty (NonEmpty(..))
-import Data.List.NonEmpty qualified as NE
-import Data.Proxy
-import Data.Type.Equality
-import Data.Vector.Storable qualified as VS
-import Data.Vector.Storable.Mutable qualified as VSM
-import Foreign.C.Types (CInt)
-import Foreign.Storable (Storable)
-import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
-import GHC.Generics (Generic)
-import GHC.TypeLits
-import Unsafe.Coerce (unsafeCoerce)
-
-import Data.Array.Mixed.Internal.Arith
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
-import Data.Array.Mixed.Types
-import Data.Array.Mixed.XArray (XArray(..))
-import Data.Array.Mixed.XArray qualified as X
-import Data.Bag
-
-
--- TODO:
--- sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
--- rminIndex1 :: Ranked (n + 1) a -> Ranked n Int
--- gather/scatter-like things (most generally, the higher-order variants: accelerate's backpermute/permute)
--- After benchmarking: matmul and matvec
-
-
-
--- Invariant in the API
--- ====================
---
--- In the underlying XArray, there is some shape for elements of an empty
--- array. For example, for this array:
---
--- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float)
--- rshape arr == 0 :.: 0 :.: 0 :.: ZIR
---
--- the two underlying XArrays have a shape, and those shapes might be anything.
--- The invariant is that these element shapes are unobservable in the API.
--- (This is possible because you ought to not be able to get to such an element
--- without indexing out of bounds.)
---
--- Note, though, that the converse situation may arise: the outer array might
--- be nonempty but then the inner arrays might. This is fine, an invariant only
--- applies if the _outer_ array is empty.
---
--- TODO: can we enforce that the elements of an empty (nested) array have
--- all-zero shape?
--- -> no, because mlift and also any kind of internals probing from outsiders
-
-
--- Primitive element types
--- =======================
---
--- There are a few primitive element types; arrays containing elements of such
--- type are a newtype over an XArray, which it itself a newtype over a Vector.
--- Unfortunately, the setup of the library requires us to list these primitive
--- element types multiple times; to aid in extending the list, all these lists
--- have been marked with [PRIMITIVE ELEMENT TYPES LIST].
-
-
--- | Wrapper type used as a tag to attach instances on. The instances on arrays
--- of @'Primitive' a@ are more polymorphic than the direct instances for arrays
--- of scalars; this means that if @orthotope@ supports an element type @T@ that
--- this library does not (directly), it may just work if you use an array of
--- @'Primitive' T@ instead.
-newtype Primitive a = Primitive a
- deriving (Show)
-
--- | Element types that are primitive; arrays of these types are just a newtype
--- wrapper over an array.
-class (Storable a, Elt a) => PrimElt a where
- fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a
- toPrimitive :: Mixed sh a -> Mixed sh (Primitive a)
-
- default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a
- fromPrimitive = coerce
-
- default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a)
- toPrimitive = coerce
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-instance PrimElt Bool
-instance PrimElt Int
-instance PrimElt Int64
-instance PrimElt Int32
-instance PrimElt CInt
-instance PrimElt Float
-instance PrimElt Double
-instance PrimElt ()
-
-
--- | Mixed arrays: some dimensions are size-typed, some are not. Distributes
--- over product-typed elements using a data family so that the full array is
--- always in struct-of-arrays format.
---
--- Built on top of 'XArray' which is built on top of @orthotope@, meaning that
--- dimension permutations (e.g. 'mtranspose') are typically free.
---
--- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type
--- class.
-type Mixed :: [Maybe Nat] -> Type -> Type
-data family Mixed sh a
--- NOTE: When opening up the Mixed abstraction, you might see dimension sizes
--- that you're not supposed to see. In particular, you might see (nonempty)
--- sizes of the elements of an empty array, which is information that should
--- ostensibly not exist; the full array is still empty.
-
-data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a)
- deriving (Eq, Ord, Generic)
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Eq, Ord, Generic) -- no content, orthotope optimises this (via Vector)
--- etc.
-
-data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic)
--- etc., larger tuples (perhaps use generics to allow arbitrary product types)
-
-deriving instance (Eq (Mixed sh a), Eq (Mixed sh b)) => Eq (Mixed sh (a, b))
-deriving instance (Ord (Mixed sh a), Ord (Mixed sh b)) => Ord (Mixed sh (a, b))
-
-data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic)
-
-deriving instance Eq (Mixed (sh1 ++ sh2) a) => Eq (Mixed sh1 (Mixed sh2 a))
-deriving instance Ord (Mixed (sh1 ++ sh2) a) => Ord (Mixed sh1 (Mixed sh2 a))
-
-
--- | Internal helper data family mirroring 'Mixed' that consists of mutable
--- vectors instead of 'XArray's.
-type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
-data family MixedVecs s sh a
-
-newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a)
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-newtype instance MixedVecs s sh Bool = MV_Bool (VS.MVector s Bool)
-newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int)
-newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64)
-newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32)
-newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt)
-newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double)
-newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float)
-newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this
--- etc.
-
-data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b)
--- etc.
-
-data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a)
-
-
-showsMixedArray :: (Show a, Elt a)
- => String -- ^ fromList prefix: e.g. @rfromListLinear [2,3]@
- -> String -- ^ replicate prefix: e.g. @rreplicate [2,3]@
- -> Int -> Mixed sh a -> ShowS
-showsMixedArray fromlistPrefix replicatePrefix d arr =
- showParen (d > 10) $
- -- TODO: to avoid ambiguity, we should type-apply the shape to mfromListLinear here
- case mtoListLinear arr of
- hd : _ : _
- | all (all (== 0) . take (shxLength (mshape arr))) (marrayStrides arr) ->
- showString replicatePrefix . showString " " . showsPrec 11 hd
- _ ->
- showString fromlistPrefix . showString " " . shows (mtoListLinear arr)
-
-instance (Show a, Elt a) => Show (Mixed sh a) where
- showsPrec d arr =
- let sh = show (shxToList (mshape arr))
- in showsMixedArray ("mfromListLinear " ++ sh) ("mreplicate " ++ sh) d arr
-
-instance Elt a => NFData (Mixed sh a) where
- rnf = mrnf
-
-
-mliftNumElt1 :: (PrimElt a, PrimElt b)
- => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b)
- -> Mixed sh a -> Mixed sh b
-mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr))
-
-mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c)
- => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c)
- -> Mixed sh a -> Mixed sh b -> Mixed sh c
-mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2))
- | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2))
- | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2
-
-instance (NumElt a, PrimElt a) => Num (Mixed sh a) where
- (+) = mliftNumElt2 (liftO2 . numEltAdd)
- (-) = mliftNumElt2 (liftO2 . numEltSub)
- (*) = mliftNumElt2 (liftO2 . numEltMul)
- negate = mliftNumElt1 (liftO1 . numEltNeg)
- abs = mliftNumElt1 (liftO1 . numEltAbs)
- signum = mliftNumElt1 (liftO1 . numEltSignum)
- -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS
- fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal"
-
-instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where
- fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"
- recip = mliftNumElt1 (liftO1 . floatEltRecip)
- (/) = mliftNumElt2 (liftO2 . floatEltDiv)
-
-instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where
- pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"
- exp = mliftNumElt1 (liftO1 . floatEltExp)
- log = mliftNumElt1 (liftO1 . floatEltLog)
- sqrt = mliftNumElt1 (liftO1 . floatEltSqrt)
-
- (**) = mliftNumElt2 (liftO2 . floatEltPow)
- logBase = mliftNumElt2 (liftO2 . floatEltLogbase)
-
- sin = mliftNumElt1 (liftO1 . floatEltSin)
- cos = mliftNumElt1 (liftO1 . floatEltCos)
- tan = mliftNumElt1 (liftO1 . floatEltTan)
- asin = mliftNumElt1 (liftO1 . floatEltAsin)
- acos = mliftNumElt1 (liftO1 . floatEltAcos)
- atan = mliftNumElt1 (liftO1 . floatEltAtan)
- sinh = mliftNumElt1 (liftO1 . floatEltSinh)
- cosh = mliftNumElt1 (liftO1 . floatEltCosh)
- tanh = mliftNumElt1 (liftO1 . floatEltTanh)
- asinh = mliftNumElt1 (liftO1 . floatEltAsinh)
- acosh = mliftNumElt1 (liftO1 . floatEltAcosh)
- atanh = mliftNumElt1 (liftO1 . floatEltAtanh)
- log1p = mliftNumElt1 (liftO1 . floatEltLog1p)
- expm1 = mliftNumElt1 (liftO1 . floatEltExpm1)
- log1pexp = mliftNumElt1 (liftO1 . floatEltLog1pexp)
- log1mexp = mliftNumElt1 (liftO1 . floatEltLog1mexp)
-
-mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a
-mquotArray = mliftNumElt2 (liftO2 . intEltQuot)
-mremArray = mliftNumElt2 (liftO2 . intEltRem)
-
-matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a
-matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2)
-
--- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
--- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
--- a@; see the documentation for 'Primitive' for more details.
-class Elt a where
- -- ====== PUBLIC METHODS ====== --
-
- mshape :: Mixed sh a -> IShX sh
- mindex :: Mixed sh a -> IIxX sh -> a
- mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a
- mscalar :: a -> Mixed '[] a
-
- -- | All arrays in the list, even subarrays inside @a@, must have the same
- -- shape; if they do not, a runtime error will be thrown. See the
- -- documentation of 'mgenerate' for more information about this restriction.
- -- Furthermore, the length of the list must correspond with @n@: if @n@ is
- -- @Just m@ and @m@ does not equal the length of the list, a runtime error is
- -- thrown.
- --
- -- Consider also 'mfromListPrim', which can avoid intermediate arrays.
- mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
-
- mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]
-
- -- | Note: this library makes no particular guarantees about the shapes of
- -- arrays "inside" an empty array. With 'mlift', 'mlift2' and 'mliftL' you can see the
- -- full 'XArray' and as such you can distinguish different empty arrays by
- -- the "shapes" of their elements. This information is meaningless, so you
- -- should not use it.
- mlift :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
- -> Mixed sh1 a -> Mixed sh2 a
-
- -- | See the documentation for 'mlift'.
- mlift2 :: forall sh1 sh2 sh3.
- StaticShX sh3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
- -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a
-
- -- TODO: mliftL is currently unused.
- -- | All arrays in the input must have equal shapes, including subarrays
- -- inside their elements.
- mliftL :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
- -> NonEmpty (Mixed sh1 a) -> NonEmpty (Mixed sh2 a)
-
- mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
- => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a
-
- mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
- => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a
-
- -- | All arrays in the input must have equal shapes, including subarrays
- -- inside their elements.
- mconcat :: NonEmpty (Mixed (Nothing : sh) a) -> Mixed (Nothing : sh) a
-
- mrnf :: Mixed sh a -> ()
-
- -- ====== PRIVATE METHODS ====== --
-
- -- | Tree giving the shape of every array component.
- type ShapeTree a
-
- mshapeTree :: a -> ShapeTree a
-
- mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool
-
- mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool
-
- mshowShapeTree :: Proxy a -> ShapeTree a -> String
-
- -- | Returns the stride vector of each underlying component array making up
- -- this mixed array.
- marrayStrides :: Mixed sh a -> Bag [Int]
-
- -- | Given the shape of this array, an index and a value, write the value at
- -- that index in the vectors.
- mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
-
- -- | Given the shape of this array, an index and a value, write the value at
- -- that index in the vectors.
- mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
-
- -- | Given the shape of this array, finalise the vectors into 'XArray's.
- mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
-
-
--- | Element types for which we have evidence of the (static part of the) shape
--- in a type class constraint. Compare the instance contexts of the instances
--- of this class with those of 'Elt': some instances have an additional
--- "known-shape" constraint.
---
--- This class is (currently) only required for 'mgenerate',
--- 'Data.Array.Nested.Ranked.rgenerate' and
--- 'Data.Array.Nested.Shaped.sgenerate'.
-class Elt a => KnownElt a where
- -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
- memptyArrayUnsafe :: IShX sh -> Mixed sh a
-
- -- | Create uninitialised vectors for this array type, given the shape of
- -- this vector and an example for the contents.
- mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a)
-
- mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
-
-
--- Arrays of scalars are basically just arrays of scalars.
-instance Storable a => Elt (Primitive a) where
- mshape (M_Primitive sh _) = sh
- mindex (M_Primitive _ a) i = Primitive (X.index a i)
- mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i)
- mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
- mfromListOuter l@(arr1 :| _) =
- let sh = SUnknown (length l) :$% mshape arr1
- in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l)))
- mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
-
- mlift :: forall sh1 sh2.
- StaticShX sh2
- -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
- -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)
- mlift ssh2 f (M_Primitive _ a)
- | Refl <- lemAppNil @sh1
- , Refl <- lemAppNil @sh2
- , let result = f ZKX a
- = M_Primitive (X.shape ssh2 result) result
-
- mlift2 :: forall sh1 sh2 sh3.
- StaticShX sh3
- -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)
- -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a)
- mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b)
- | Refl <- lemAppNil @sh1
- , Refl <- lemAppNil @sh2
- , Refl <- lemAppNil @sh3
- , let result = f ZKX a b
- = M_Primitive (X.shape ssh3 result) result
-
- mliftL :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
- -> NonEmpty (Mixed sh1 (Primitive a)) -> NonEmpty (Mixed sh2 (Primitive a))
- mliftL ssh2 f l
- | Refl <- lemAppNil @sh1
- , Refl <- lemAppNil @sh2
- = fmap (\arr -> M_Primitive (X.shape ssh2 arr) arr) $
- f ZKX (fmap (\(M_Primitive _ arr) -> arr) l)
-
- mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
- => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
- mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) =
- let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
- sh2 = shxCast' sh1 ssh2
- in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr)
-
- mtranspose perm (M_Primitive sh arr) =
- M_Primitive (shxPermutePrefix perm sh)
- (X.transpose (ssxFromShape sh) perm arr)
-
- mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a)
- mconcat l@(M_Primitive (_ :$% sh) _ :| _) =
- let result = X.concat (ssxFromShape sh) (fmap (\(M_Primitive _ arr) -> arr) l)
- in M_Primitive (X.shape (SUnknown () :!% ssxFromShape sh) result) result
-
- mrnf (M_Primitive sh a) = rnf sh `seq` rnf a
-
- type ShapeTree (Primitive a) = ()
- mshapeTree _ = ()
- mshapeTreeEq _ () () = True
- mshapeTreeEmpty _ () = False
- mshowShapeTree _ () = "()"
- marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr)
- mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
-
- -- TODO: this use of toVector is suboptimal
- mvecsWritePartial
- :: forall sh' sh s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
- mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
- let arrsh = X.shape (ssxFromShape sh') arr
- offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
- VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
-
- mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-deriving via Primitive Bool instance Elt Bool
-deriving via Primitive Int instance Elt Int
-deriving via Primitive Int64 instance Elt Int64
-deriving via Primitive Int32 instance Elt Int32
-deriving via Primitive CInt instance Elt CInt
-deriving via Primitive Double instance Elt Double
-deriving via Primitive Float instance Elt Float
-deriving via Primitive () instance Elt ()
-
-instance Storable a => KnownElt (Primitive a) where
- memptyArrayUnsafe sh = M_Primitive sh (X.empty sh)
- mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh)
- mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-deriving via Primitive Bool instance KnownElt Bool
-deriving via Primitive Int instance KnownElt Int
-deriving via Primitive Int64 instance KnownElt Int64
-deriving via Primitive Int32 instance KnownElt Int32
-deriving via Primitive CInt instance KnownElt CInt
-deriving via Primitive Double instance KnownElt Double
-deriving via Primitive Float instance KnownElt Float
-deriving via Primitive () instance KnownElt ()
-
--- Arrays of pairs are pairs of arrays.
-instance (Elt a, Elt b) => Elt (a, b) where
- mshape (M_Tup2 a _) = mshape a
- mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
- mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
- mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
- mfromListOuter l =
- M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l))
- (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l))
- mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)
- mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
- mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
- mliftL ssh2 f =
- let unzipT2l [] = ([], [])
- unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2)
- unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2)
- in uncurry (NE.zipWith M_Tup2) . bimap (mliftL ssh2 f) (mliftL ssh2 f) . unzipT2
-
- mcastPartial ssh1 sh2 psh' (M_Tup2 a b) =
- M_Tup2 (mcastPartial ssh1 sh2 psh' a) (mcastPartial ssh1 sh2 psh' b)
-
- mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b)
- mconcat =
- let unzipT2l [] = ([], [])
- unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2)
- unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2)
- in uncurry M_Tup2 . bimap mconcat mconcat . unzipT2
-
- mrnf (M_Tup2 a b) = mrnf a `seq` mrnf b
-
- type ShapeTree (a, b) = (ShapeTree a, ShapeTree b)
- mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
- mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
- mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
- mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
- marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b
- mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
- mvecsWrite sh i x a
- mvecsWrite sh i y b
- mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
- mvecsWritePartial sh i x a
- mvecsWritePartial sh i y b
- mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
-
-instance (KnownElt a, KnownElt b) => KnownElt (a, b) where
- memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh)
- mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
- mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b)
-
--- Arrays of arrays are just arrays, but with more dimensions.
-instance Elt a => Elt (Mixed sh' a) where
- -- TODO: this is quadratic in the nesting depth because it repeatedly
- -- truncates the shape vector to one a little shorter. Fix with a
- -- moverlongShape method, a prefix of which is mshape.
- mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
- mshape (M_Nest sh arr)
- = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr))
-
- mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
- mindex (M_Nest _ arr) i = mindexPartial arr i
-
- mindexPartial :: forall sh1 sh2.
- Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
- mindexPartial (M_Nest sh arr) i
- | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
-
- mscalar = M_Nest ZSX
-
- mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
- mfromListOuter l@(arr :| _) =
- M_Nest (SUnknown (length l) :$% mshape arr)
- (mfromListOuter ((\(M_Nest _ a) -> a) <$> l))
-
- mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr)
-
- mlift :: forall sh1 sh2.
- StaticShX sh2
- -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
- -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
- mlift ssh2 f (M_Nest sh1 arr) =
- let result = mlift (ssxAppend ssh2 ssh') f' arr
- (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result)
- in M_Nest sh2 result
- where
- ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr)))
-
- f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
- f' sshT
- | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
- , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
- = f (ssxAppend ssh' sshT)
-
- mlift2 :: forall sh1 sh2 sh3.
- StaticShX sh3
- -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
- -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)
- mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) =
- let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2
- (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result)
- in M_Nest sh3 result
- where
- ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1)))
-
- f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
- f' sshT
- | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
- , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
- , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
- = f (ssxAppend ssh' sshT)
-
- mliftL :: forall sh1 sh2.
- StaticShX sh2
- -> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b))
- -> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a))
- mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) =
- let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l)
- (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result))
- in fmap (M_Nest sh2) result
- where
- ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1)))
-
- f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b)
- f' sshT
- | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
- , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
- = f (ssxAppend ssh' sshT)
-
- mcastPartial :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2
- => StaticShX sh1 -> StaticShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a)
- mcastPartial ssh1 ssh2 _ (M_Nest sh1T arr)
- | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
- , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
- = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T
- sh2 = shxCast' sh1 ssh2
- in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr)
-
- mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
- => Perm is -> Mixed sh (Mixed sh' a)
- -> Mixed (PermutePrefix is sh) (Mixed sh' a)
- mtranspose perm (M_Nest sh arr)
- | let sh' = shxDropSh @sh @sh' (mshape arr) sh
- , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh')
- , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh'))
- , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
- , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
- , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
- = M_Nest (shxPermutePrefix perm sh)
- (mtranspose perm arr)
-
- mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
- mconcat l@(M_Nest sh1 _ :| _) =
- let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l)
- in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape result))) result
-
- mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr
-
- type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a)
-
- mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
- mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr)))))
-
- mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
-
- mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
-
- mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
-
- marrayStrides (M_Nest _ arr) = marrayStrides arr
-
- mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
-
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
- -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
- -> ST s ()
- mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
- | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs
-
- mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs
-
-instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
- memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh'))))
-
- mvecsUnsafeNew sh example
- | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
- | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh')))
- where
- sh' = mshape example
-
- mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
-
-
-memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a
-memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh)
-
-mrank :: Elt a => Mixed sh a -> SNat (Rank sh)
-mrank = shxRank . mshape
-
--- | The total number of elements in the array.
-msize :: Elt a => Mixed sh a -> Int
-msize = shxSize . mshape
-
--- | Create an array given a size and a function that computes the element at a
--- given index.
---
--- __WARNING__: It is required that every @a@ returned by the argument to
--- 'mgenerate' has the same shape. For example, the following will throw a
--- runtime error:
---
--- > foo :: Mixed [Nothing] (Mixed [Nothing] Double)
--- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) ->
--- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) ->
--- > ...
---
--- because the size of the inner 'mgenerate' is not always the same (it depends
--- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so
--- the entire hierarchy (after distributing out tuples) must be a rectangular
--- array. The type of 'mgenerate' allows this requirement to be broken very
--- easily, hence the runtime check.
-mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a
-mgenerate sh f = case shxEnum sh of
- [] -> memptyArrayUnsafe sh
- firstidx : restidxs ->
- let firstelem = f (ixxZero' sh)
- shapetree = mshapeTree firstelem
- in if mshapeTreeEmpty (Proxy @a) shapetree
- then memptyArrayUnsafe sh
- else runST $ do
- vecs <- mvecsUnsafeNew sh firstelem
- mvecsWrite sh firstidx firstelem vecs
- -- TODO: This is likely fine if @a@ is big, but if @a@ is a
- -- scalar this array copying inefficient. Should improve this.
- forM_ restidxs $ \idx -> do
- let val = f idx
- when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $
- error "Data.Array.Nested mgenerate: generated values do not have equal shapes"
- mvecsWrite sh idx val vecs
- mvecsFreeze sh vecs
-
-msumOuter1P :: forall sh n a. (Storable a, NumElt a)
- => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
-msumOuter1P (M_Primitive (n :$% sh) arr) =
- let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
- in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr)
-
-msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
- => Mixed (n : sh) a -> Mixed sh a
-msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
-
-msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
-msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShape sh) arr
-
-mappend :: forall n m sh a. Elt a
- => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
-mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
- where
- sn :$% sh = mshape arr1
- sm :$% _ = mshape arr2
- ssh = ssxFromShape sh
- snm :: SMayNat () SNat (AddMaybe n m)
- snm = case (sn, sm) of
- (SUnknown{}, _) -> SUnknown ()
- (SKnown{}, SUnknown{}) -> SUnknown ()
- (SKnown n, SKnown m) -> SKnown (snatPlus n m)
-
- f :: forall sh' b. Storable b
- => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b
- f ssh' = X.append (ssxAppend ssh ssh')
-
-mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
-mfromVectorP sh v = M_Primitive sh (X.fromVector sh v)
-
-mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a
-mfromVector sh v = fromPrimitive (mfromVectorP sh v)
-
-mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a
-mtoVectorP (M_Primitive _ v) = X.toVector v
-
-mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a
-mtoVector arr = mtoVectorP (toPrimitive arr)
-
-mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a
-mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise?
-
-mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a
-mfromList1Prim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
-
-mtoList1 :: Elt a => Mixed '[n] a -> [a]
-mtoList1 = map munScalar . mtoListOuter
-
-mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a
-mfromListPrim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
-
-mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a
-mfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
- in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)
-
--- This forall is there so that a simple type application can constrain the
--- shape, in case the user wants to use OverloadedLists for the shape.
-mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a
-mfromListLinear sh l = mreshape sh (mfromList1 l)
-
-mtoListLinear :: Elt a => Mixed sh a -> [a]
-mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr)) -- TODO: optimise
-
-munScalar :: Elt a => Mixed '[] a -> a
-munScalar arr = mindex arr ZIX
-
-mnest :: forall sh sh' a. Elt a => StaticShX sh -> Mixed (sh ++ sh') a -> Mixed sh (Mixed sh' a)
-mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr
-
-munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a
-munNest (M_Nest _ arr) = arr
-
-mzip :: Mixed sh a -> Mixed sh b -> Mixed sh (a, b)
-mzip = M_Tup2
-
-munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b)
-munzip (M_Tup2 a b) = (a, b)
-
-mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
- => StaticShX sh -> IShX sh2
- -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
- -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b)
-mrerankP ssh sh2 f (M_Primitive sh arr) =
- let sh1 = shxDropSSX sh ssh
- in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2)
- (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2)
- (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
- arr)
-
--- | See the caveats at @X.rerank@.
-mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
- => StaticShX sh -> IShX sh2
- -> (Mixed sh1 a -> Mixed sh2 b)
- -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b
-mrerank ssh sh2 f (toPrimitive -> arr) =
- fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr
-
-mreplicate :: forall sh sh' a. Elt a
- => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
-mreplicate sh arr =
- let ssh' = ssxFromShape (mshape arr)
- in mlift (ssxAppend (ssxFromShape sh) ssh')
- (\(sshT :: StaticShX shT) ->
- case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of
- Refl -> X.replicate sh (ssxAppend ssh' sshT))
- arr
-
-mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
-mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x)
-
-mreplicateScal :: forall sh a. PrimElt a
- => IShX sh -> a -> Mixed sh a
-mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
-
-mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
-mslice i n arr =
- let _ :$% sh = mshape arr
- in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr
-
-msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
-msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr
-
-mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a
-mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr
-
-mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a
-mreshape sh' arr =
- mlift (ssxFromShape sh')
- (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh')
- arr
-
-mflatten :: Elt a => Mixed sh a -> Mixed '[Flatten sh] a
-mflatten arr = mreshape (shxFlatten (mshape arr) :$% ZSX) arr
-
-miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a
-miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
-
--- | Throws if the array is empty.
-mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
-mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
- ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) (fromO arr))
-
--- | Throws if the array is empty.
-mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
-mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
- ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) (fromO arr))
-
-mdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
- => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a
-mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b))
- | Refl <- lemInitApp (Proxy @sh) (Proxy @n)
- , Refl <- lemLastApp (Proxy @sh) (Proxy @n)
- = case sh1 of
- _ :$% _
- | sh1 == sh2
- , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) ->
- fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b))
- | otherwise -> error $ "mdot1Inner: Unequal shapes (" ++ show sh1 ++ " and " ++ show sh2 ++ ")"
- ZSX -> error "unreachable"
-
--- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
--- Prefer 'mdot1Inner' if applicable.
-mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a
-mdot a b =
- munScalar $
- mdot1Inner Proxy (fromPrimitive (mflatten (toPrimitive a)))
- (fromPrimitive (mflatten (toPrimitive b)))
-
-mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
-mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr)
-
-mtoXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a)
-mtoXArrayPrim = mtoXArrayPrimP . toPrimitive
-
-mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a)
-mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr
-
-mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a
-mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP
-
-mliftPrim :: (PrimElt a, PrimElt b)
- => (a -> b)
- -> Mixed sh a -> Mixed sh b
-mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr))
-
-mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c)
- => (a -> b -> c)
- -> Mixed sh a -> Mixed sh b -> Mixed sh c
-mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =
- fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2))
-
-mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)
- => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a
-mcast ssh2 arr
- | Refl <- lemAppNil @sh1
- , Refl <- lemAppNil @sh2
- = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr
-
--- TODO: This should be `type data` but a bug in GHC 9.10 means that that throws linker errors
-data SafeMCastSpec
- = MCastId
- | MCastApp [Maybe Nat] [Maybe Nat] [Maybe Nat] [Maybe Nat] SafeMCastSpec SafeMCastSpec
- | MCastForget
-
-type SafeMCast :: SafeMCastSpec -> [Maybe Nat] -> [Maybe Nat] -> Constraint
-type family SafeMCast spec sh1 sh2 where
- SafeMCast MCastId sh sh = ()
- SafeMCast (MCastApp sh1A sh1B sh2A sh2B specA specB) sh1 sh2 = (sh1 ~ sh1A ++ sh1B, sh2 ~ sh2A ++ sh2B, SafeMCast specA sh1A sh2A, SafeMCast specB sh1B sh2B)
- SafeMCast MCastForget sh1 sh2 = sh2 ~ Replicate (Rank sh1) Nothing
-
--- | This is an O(1) operation: the 'SafeMCast' constraint ensures that
--- type-level shape information can only be forgotten, not introduced, and thus
--- that no runtime shape checks are required. The @spec@ describes to
--- 'SafeMCast' how exactly you intend @sh2@ to be a weakening of @sh1@.
---
--- To see how to construct the spec, read the equations of 'SafeMCast' closely.
-mcastSafe :: forall spec sh1 sh2 a proxy. SafeMCast spec sh1 sh2 => proxy spec -> Mixed sh1 a -> Mixed sh2 a
-mcastSafe _ = unsafeCoerce @(Mixed sh1 a) @(Mixed sh2 a)