aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
-rw-r--r--src/Data/Array/Nested/Mixed.hs741
1 files changed, 0 insertions, 741 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
deleted file mode 100644
index 84e16b3..0000000
--- a/src/Data/Array/Nested/Mixed.hs
+++ /dev/null
@@ -1,741 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DefaultSignatures #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingVia #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE InstanceSigs #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE ViewPatterns #-}
-module Data.Array.Nested.Mixed where
-
-import Control.DeepSeq (NFData)
-import Control.Monad (forM_, when)
-import Control.Monad.ST
-import Data.Array.RankedS qualified as S
-import Data.Coerce
-import Data.Foldable (toList)
-import Data.Int
-import Data.Kind (Type)
-import Data.List.NonEmpty (NonEmpty(..))
-import Data.Proxy
-import Data.Type.Equality
-import Data.Vector.Storable qualified as VS
-import Data.Vector.Storable.Mutable qualified as VSM
-import Foreign.C.Types (CInt)
-import Foreign.Storable (Storable)
-import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp)
-import GHC.Generics (Generic)
-import GHC.TypeLits
-
-import Data.Array.Mixed (XArray(..))
-import Data.Array.Mixed qualified as X
-import Data.Array.Mixed.Internal.Arith
-import Data.Array.Mixed.Shape
-import Data.Array.Mixed.Types
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Lemmas
-
-
--- Invariant in the API
--- ====================
---
--- In the underlying XArray, there is some shape for elements of an empty
--- array. For example, for this array:
---
--- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float)
--- rshape arr == 0 :.: 0 :.: 0 :.: ZIR
---
--- the two underlying XArrays have a shape, and those shapes might be anything.
--- The invariant is that these element shapes are unobservable in the API.
--- (This is possible because you ought to not be able to get to such an element
--- without indexing out of bounds.)
---
--- Note, though, that the converse situation may arise: the outer array might
--- be nonempty but then the inner arrays might. This is fine, an invariant only
--- applies if the _outer_ array is empty.
---
--- TODO: can we enforce that the elements of an empty (nested) array have
--- all-zero shape?
--- -> no, because mlift and also any kind of internals probing from outsiders
-
-
--- Primitive element types
--- =======================
---
--- There are a few primitive element types; arrays containing elements of such
--- type are a newtype over an XArray, which it itself a newtype over a Vector.
--- Unfortunately, the setup of the library requires us to list these primitive
--- element types multiple times; to aid in extending the list, all these lists
--- have been marked with [PRIMITIVE ELEMENT TYPES LIST].
-
-
--- | Wrapper type used as a tag to attach instances on. The instances on arrays
--- of @'Primitive' a@ are more polymorphic than the direct instances for arrays
--- of scalars; this means that if @orthotope@ supports an element type @T@ that
--- this library does not (directly), it may just work if you use an array of
--- @'Primitive' T@ instead.
-newtype Primitive a = Primitive a
-
--- | Element types that are primitive; arrays of these types are just a newtype
--- wrapper over an array.
-class Storable a => PrimElt a where
- fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a
- toPrimitive :: Mixed sh a -> Mixed sh (Primitive a)
-
- default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a
- fromPrimitive = coerce
-
- default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a)
- toPrimitive = coerce
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-instance PrimElt Int
-instance PrimElt Int64
-instance PrimElt Int32
-instance PrimElt CInt
-instance PrimElt Float
-instance PrimElt Double
-instance PrimElt ()
-
-
--- | Mixed arrays: some dimensions are size-typed, some are not. Distributes
--- over product-typed elements using a data family so that the full array is
--- always in struct-of-arrays format.
---
--- Built on top of 'XArray' which is built on top of @orthotope@, meaning that
--- dimension permutations (e.g. 'mtranspose') are typically free.
---
--- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type
--- class.
-type Mixed :: [Maybe Nat] -> Type -> Type
-data family Mixed sh a
--- NOTE: When opening up the Mixed abstraction, you might see dimension sizes
--- that you're not supposed to see. In particular, you might see (nonempty)
--- sizes of the elements of an empty array, which is information that should
--- ostensibly not exist; the full array is still empty.
-
-data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a)
- deriving (Show, Eq, Generic)
-
--- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays.
-deriving instance (Ord a, Storable a) => Ord (Mixed '[] (Primitive a))
-
-instance NFData a => NFData (Mixed sh (Primitive a))
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Show, Eq, Generic)
-newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Show, Eq, Generic)
-newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Show, Eq, Generic)
-newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Show, Eq, Generic)
-newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Show, Eq, Generic)
-newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Show, Eq, Generic)
-newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Show, Eq, Generic) -- no content, orthotope optimises this (via Vector)
--- etc.
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-deriving instance Ord (Mixed '[] Int) ; instance NFData (Mixed sh Int)
-deriving instance Ord (Mixed '[] Int64) ; instance NFData (Mixed sh Int64)
-deriving instance Ord (Mixed '[] Int32) ; instance NFData (Mixed sh Int32)
-deriving instance Ord (Mixed '[] CInt) ; instance NFData (Mixed sh CInt)
-deriving instance Ord (Mixed '[] Float) ; instance NFData (Mixed sh Float)
-deriving instance Ord (Mixed '[] Double) ; instance NFData (Mixed sh Double)
-deriving instance Ord (Mixed '[] ()) ; instance NFData (Mixed sh ())
-
-data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic)
-deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b))
-instance (NFData (Mixed sh a), NFData (Mixed sh b)) => NFData (Mixed sh (a, b))
--- etc., larger tuples (perhaps use generics to allow arbitrary product types)
-
-data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic)
-deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a))
-instance NFData (Mixed (sh1 ++ sh2) a) => NFData (Mixed sh1 (Mixed sh2 a))
-
-
--- | Internal helper data family mirroring 'Mixed' that consists of mutable
--- vectors instead of 'XArray's.
-type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
-data family MixedVecs s sh a
-
-newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a)
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int)
-newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64)
-newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32)
-newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt)
-newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double)
-newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float)
-newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this
--- etc.
-
-data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b)
--- etc.
-
-data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a)
-
-
-mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a
-mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr))
-
-mliftNumElt2 :: PrimElt a
- => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a)
- -> Mixed sh a -> Mixed sh a -> Mixed sh a
-mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2))
- | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2))
- | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2
-
-instance (NumElt a, PrimElt a) => Num (Mixed sh a) where
- (+) = mliftNumElt2 numEltAdd
- (-) = mliftNumElt2 numEltSub
- (*) = mliftNumElt2 numEltMul
- negate = mliftNumElt1 numEltNeg
- abs = mliftNumElt1 numEltAbs
- signum = mliftNumElt1 numEltSignum
- fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate"
-
-instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Mixed sh a) where
- fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"
- recip = mliftNumElt1 floatEltRecip
- (/) = mliftNumElt2 floatEltDiv
-
-instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where
- pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"
- exp = mliftNumElt1 floatEltExp
- log = mliftNumElt1 floatEltLog
- sqrt = mliftNumElt1 floatEltSqrt
-
- (**) = mliftNumElt2 floatEltPow
- logBase = mliftNumElt2 floatEltLogbase
-
- sin = mliftNumElt1 floatEltSin
- cos = mliftNumElt1 floatEltCos
- tan = mliftNumElt1 floatEltTan
- asin = mliftNumElt1 floatEltAsin
- acos = mliftNumElt1 floatEltAcos
- atan = mliftNumElt1 floatEltAtan
- sinh = mliftNumElt1 floatEltSinh
- cosh = mliftNumElt1 floatEltCosh
- tanh = mliftNumElt1 floatEltTanh
- asinh = mliftNumElt1 floatEltAsinh
- acosh = mliftNumElt1 floatEltAcosh
- atanh = mliftNumElt1 floatEltAtanh
- log1p = mliftNumElt1 floatEltLog1p
- expm1 = mliftNumElt1 floatEltExpm1
- log1pexp = mliftNumElt1 floatEltLog1pexp
- log1mexp = mliftNumElt1 floatEltLog1mexp
-
-
--- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
--- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
--- a@; see the documentation for 'Primitive' for more details.
-class Elt a where
- -- ====== PUBLIC METHODS ====== --
-
- mshape :: Mixed sh a -> IShX sh
- mindex :: Mixed sh a -> IIxX sh -> a
- mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a
- mscalar :: a -> Mixed '[] a
-
- -- | All arrays in the list, even subarrays inside @a@, must have the same
- -- shape; if they do not, a runtime error will be thrown. See the
- -- documentation of 'mgenerate' for more information about this restriction.
- -- Furthermore, the length of the list must correspond with @n@: if @n@ is
- -- @Just m@ and @m@ does not equal the length of the list, a runtime error is
- -- thrown.
- --
- -- Consider also 'mfromListPrim', which can avoid intermediate arrays.
- mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
-
- mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]
-
- -- | Note: this library makes no particular guarantees about the shapes of
- -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the
- -- full 'XArray' and as such you can distinguish different empty arrays by
- -- the "shapes" of their elements. This information is meaningless, so you
- -- should not use it.
- mlift :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
- -> Mixed sh1 a -> Mixed sh2 a
-
- -- | See the documentation for 'mlift'.
- mlift2 :: forall sh1 sh2 sh3.
- StaticShX sh3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
- -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a
-
- mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
- => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a
-
- mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
- => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a
-
- -- ====== PRIVATE METHODS ====== --
-
- -- | Tree giving the shape of every array component.
- type ShapeTree a
-
- mshapeTree :: a -> ShapeTree a
-
- mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool
-
- mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool
-
- mshowShapeTree :: Proxy a -> ShapeTree a -> String
-
- -- | Given the shape of this array, an index and a value, write the value at
- -- that index in the vectors.
- mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
-
- -- | Given the shape of this array, an index and a value, write the value at
- -- that index in the vectors.
- mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
-
- -- | Given the shape of this array, finalise the vectors into 'XArray's.
- mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
-
-
--- | Element types for which we have evidence of the (static part of the) shape
--- in a type class constraint. Compare the instance contexts of the instances
--- of this class with those of 'Elt': some instances have an additional
--- "known-shape" constraint.
---
--- This class is (currently) only required for 'mgenerate' / 'rgenerate' /
--- 'sgenerate'.
-class Elt a => KnownElt a where
- -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
- memptyArray :: IShX sh -> Mixed sh a
-
- -- | Create uninitialised vectors for this array type, given the shape of
- -- this vector and an example for the contents.
- mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a)
-
- mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
-
-
--- Arrays of scalars are basically just arrays of scalars.
-instance Storable a => Elt (Primitive a) where
- mshape (M_Primitive sh _) = sh
- mindex (M_Primitive _ a) i = Primitive (X.index a i)
- mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i)
- mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
- mfromListOuter l@(arr1 :| _) =
- let sh = SUnknown (length l) :$% mshape arr1
- in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l)))
- mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
-
- mlift :: forall sh1 sh2.
- StaticShX sh2
- -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
- -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)
- mlift ssh2 f (M_Primitive _ a)
- | Refl <- lemAppNil @sh1
- , Refl <- lemAppNil @sh2
- , let result = f ZKX a
- = M_Primitive (X.shape ssh2 result) result
-
- mlift2 :: forall sh1 sh2 sh3.
- StaticShX sh3
- -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)
- -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a)
- mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b)
- | Refl <- lemAppNil @sh1
- , Refl <- lemAppNil @sh2
- , Refl <- lemAppNil @sh3
- , let result = f ZKX a b
- = M_Primitive (X.shape ssh3 result) result
-
- mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
- => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
- mcast ssh1 sh2 _ (M_Primitive sh1' arr) =
- let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
- in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr)
-
- mtranspose perm (M_Primitive sh arr) =
- M_Primitive (shxPermutePrefix perm sh)
- (X.transpose (ssxFromShape sh) perm arr)
-
- type ShapeTree (Primitive a) = ()
- mshapeTree _ = ()
- mshapeTreeEq _ () () = True
- mshapeTreeEmpty _ () = False
- mshowShapeTree _ () = "()"
- mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
-
- -- TODO: this use of toVector is suboptimal
- mvecsWritePartial
- :: forall sh' sh s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
- mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
- let arrsh = X.shape (ssxFromShape sh') arr
- offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
- VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
-
- mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-deriving via Primitive Int instance Elt Int
-deriving via Primitive Int64 instance Elt Int64
-deriving via Primitive Int32 instance Elt Int32
-deriving via Primitive CInt instance Elt CInt
-deriving via Primitive Double instance Elt Double
-deriving via Primitive Float instance Elt Float
-deriving via Primitive () instance Elt ()
-
-instance Storable a => KnownElt (Primitive a) where
- memptyArray sh = M_Primitive sh (X.empty sh)
- mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh)
- mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-deriving via Primitive Int instance KnownElt Int
-deriving via Primitive Int64 instance KnownElt Int64
-deriving via Primitive Int32 instance KnownElt Int32
-deriving via Primitive CInt instance KnownElt CInt
-deriving via Primitive Double instance KnownElt Double
-deriving via Primitive Float instance KnownElt Float
-deriving via Primitive () instance KnownElt ()
-
--- Arrays of pairs are pairs of arrays.
-instance (Elt a, Elt b) => Elt (a, b) where
- mshape (M_Tup2 a _) = mshape a
- mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
- mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
- mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
- mfromListOuter l =
- M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l))
- (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l))
- mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)
- mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
- mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
-
- mcast ssh1 sh2 psh' (M_Tup2 a b) =
- M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b)
-
- mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b)
-
- type ShapeTree (a, b) = (ShapeTree a, ShapeTree b)
- mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
- mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
- mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
- mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
- mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
- mvecsWrite sh i x a
- mvecsWrite sh i y b
- mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
- mvecsWritePartial sh i x a
- mvecsWritePartial sh i y b
- mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
-
-instance (KnownElt a, KnownElt b) => KnownElt (a, b) where
- memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
- mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
- mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b)
-
--- Arrays of arrays are just arrays, but with more dimensions.
-instance Elt a => Elt (Mixed sh' a) where
- -- TODO: this is quadratic in the nesting depth because it repeatedly
- -- truncates the shape vector to one a little shorter. Fix with a
- -- moverlongShape method, a prefix of which is mshape.
- mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
- mshape (M_Nest sh arr)
- = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr))
-
- mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
- mindex (M_Nest _ arr) i = mindexPartial arr i
-
- mindexPartial :: forall sh1 sh2.
- Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
- mindexPartial (M_Nest sh arr) i
- | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
-
- mscalar = M_Nest ZSX
-
- mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
- mfromListOuter l@(arr :| _) =
- M_Nest (SUnknown (length l) :$% mshape arr)
- (mfromListOuter ((\(M_Nest _ a) -> a) <$> l))
-
- mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr)
-
- mlift :: forall sh1 sh2.
- StaticShX sh2
- -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
- -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
- mlift ssh2 f (M_Nest sh1 arr) =
- let result = mlift (ssxAppend ssh2 ssh') f' arr
- (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result)
- in M_Nest sh2 result
- where
- ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr)))
-
- f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
- f' sshT
- | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
- , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
- = f (ssxAppend ssh' sshT)
-
- mlift2 :: forall sh1 sh2 sh3.
- StaticShX sh3
- -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
- -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)
- mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) =
- let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2
- (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result)
- in M_Nest sh3 result
- where
- ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1)))
-
- f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
- f' sshT
- | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
- , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
- , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
- = f (ssxAppend ssh' sshT)
-
- mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2
- => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a)
- mcast ssh1 sh2 _ (M_Nest sh1T arr)
- | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
- , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
- = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T
- in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr)
-
- mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
- => Perm is -> Mixed sh (Mixed sh' a)
- -> Mixed (PermutePrefix is sh) (Mixed sh' a)
- mtranspose perm (M_Nest sh arr)
- | let sh' = shxDropSh @sh @sh' (mshape arr) sh
- , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh')
- , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh'))
- , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
- , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
- , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
- = M_Nest (shxPermutePrefix perm sh)
- (mtranspose perm arr)
-
- type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a)
-
- mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
- mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr)))))
-
- mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
-
- mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
-
- mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
-
- mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
-
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
- -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
- -> ST s ()
- mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
- | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs
-
- mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs
-
-instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
- memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh'))))
-
- mvecsUnsafeNew sh example
- | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
- | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh')))
- where
- sh' = mshape example
-
- mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
-
-
--- | Create an array given a size and a function that computes the element at a
--- given index.
---
--- __WARNING__: It is required that every @a@ returned by the argument to
--- 'mgenerate' has the same shape. For example, the following will throw a
--- runtime error:
---
--- > foo :: Mixed [Nothing] (Mixed [Nothing] Double)
--- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) ->
--- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) ->
--- > ...
---
--- because the size of the inner 'mgenerate' is not always the same (it depends
--- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so
--- the entire hierarchy (after distributing out tuples) must be a rectangular
--- array. The type of 'mgenerate' allows this requirement to be broken very
--- easily, hence the runtime check.
-mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a
-mgenerate sh f = case shxEnum sh of
- [] -> memptyArray sh
- firstidx : restidxs ->
- let firstelem = f (ixxZero' sh)
- shapetree = mshapeTree firstelem
- in if mshapeTreeEmpty (Proxy @a) shapetree
- then memptyArray sh
- else runST $ do
- vecs <- mvecsUnsafeNew sh firstelem
- mvecsWrite sh firstidx firstelem vecs
- -- TODO: This is likely fine if @a@ is big, but if @a@ is a
- -- scalar this array copying inefficient. Should improve this.
- forM_ restidxs $ \idx -> do
- let val = f idx
- when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $
- error "Data.Array.Nested mgenerate: generated values do not have equal shapes"
- mvecsWrite sh idx val vecs
- mvecsFreeze sh vecs
-
-msumOuter1P :: forall sh n a. (Storable a, NumElt a)
- => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
-msumOuter1P (M_Primitive (n :$% sh) arr) =
- let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
- in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr)
-
-msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
- => Mixed (n : sh) a -> Mixed sh a
-msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
-
-mappend :: forall n m sh a. Elt a
- => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
-mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
- where
- sn :$% sh = mshape arr1
- sm :$% _ = mshape arr2
- ssh = ssxFromShape sh
- snm :: SMayNat () SNat (AddMaybe n m)
- snm = case (sn, sm) of
- (SUnknown{}, _) -> SUnknown ()
- (SKnown{}, SUnknown{}) -> SUnknown ()
- (SKnown n, SKnown m) -> SKnown (snatPlus n m)
-
- f :: forall sh' b. Storable b
- => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b
- f ssh' = X.append (ssxAppend ssh ssh')
-
-mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
-mfromVectorP sh v = M_Primitive sh (X.fromVector sh v)
-
-mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a
-mfromVector sh v = fromPrimitive (mfromVectorP sh v)
-
-mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a
-mtoVectorP (M_Primitive _ v) = X.toVector v
-
-mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a
-mtoVector arr = mtoVectorP (toPrimitive arr)
-
-mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a
-mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise?
-
-mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a
-mfromList1Prim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
-
-mtoList1 :: Elt a => Mixed '[n] a -> [a]
-mtoList1 = map munScalar . mtoListOuter
-
-mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a
-mfromListPrim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
-
-mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a
-mfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
- in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)
-
-munScalar :: Elt a => Mixed '[] a -> a
-munScalar arr = mindex arr ZIX
-
-mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
- => StaticShX sh -> IShX sh2
- -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
- -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b)
-mrerankP ssh sh2 f (M_Primitive sh arr) =
- let sh1 = shxDropSSX sh ssh
- in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2)
- (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2)
- (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
- arr)
-
--- | See the caveats at @X.rerank@.
-mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
- => StaticShX sh -> IShX sh2
- -> (Mixed sh1 a -> Mixed sh2 b)
- -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b
-mrerank ssh sh2 f (toPrimitive -> arr) =
- fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr
-
-mreplicate :: forall sh sh' a. Elt a
- => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
-mreplicate sh arr =
- let ssh' = ssxFromShape (mshape arr)
- in mlift (ssxAppend (ssxFromShape sh) ssh')
- (\(sshT :: StaticShX shT) ->
- case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of
- Refl -> X.replicate sh (ssxAppend ssh' sshT))
- arr
-
-mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
-mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x)
-
-mreplicateScal :: forall sh a. PrimElt a
- => IShX sh -> a -> Mixed sh a
-mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
-
-mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
-mslice i n arr =
- let _ :$% sh = mshape arr
- in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr
-
-msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
-msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr
-
-mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a
-mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr
-
-mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a
-mreshape sh' arr =
- mlift (ssxFromShape sh')
- (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh')
- arr
-
-miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a
-miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
-
-masXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
-masXArrayPrimP (M_Primitive sh arr) = (sh, arr)
-
-masXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a)
-masXArrayPrim = masXArrayPrimP . toPrimitive
-
-mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a)
-mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr
-
-mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a
-mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP
-
-mliftPrim :: PrimElt a
- => (a -> a)
- -> Mixed sh a -> Mixed sh a
-mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr))
-
-mliftPrim2 :: PrimElt a
- => (a -> a -> a)
- -> Mixed sh a -> Mixed sh a -> Mixed sh a
-mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =
- fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2))