diff options
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
-rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 741 |
1 files changed, 0 insertions, 741 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs deleted file mode 100644 index 84e16b3..0000000 --- a/src/Data/Array/Nested/Mixed.hs +++ /dev/null @@ -1,741 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DefaultSignatures #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -module Data.Array.Nested.Mixed where - -import Control.DeepSeq (NFData) -import Control.Monad (forM_, when) -import Control.Monad.ST -import Data.Array.RankedS qualified as S -import Data.Coerce -import Data.Foldable (toList) -import Data.Int -import Data.Kind (Type) -import Data.List.NonEmpty (NonEmpty(..)) -import Data.Proxy -import Data.Type.Equality -import Data.Vector.Storable qualified as VS -import Data.Vector.Storable.Mutable qualified as VSM -import Foreign.C.Types (CInt) -import Foreign.Storable (Storable) -import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp) -import GHC.Generics (Generic) -import GHC.TypeLits - -import Data.Array.Mixed (XArray(..)) -import Data.Array.Mixed qualified as X -import Data.Array.Mixed.Internal.Arith -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Types -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Lemmas - - --- Invariant in the API --- ==================== --- --- In the underlying XArray, there is some shape for elements of an empty --- array. For example, for this array: --- --- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float) --- rshape arr == 0 :.: 0 :.: 0 :.: ZIR --- --- the two underlying XArrays have a shape, and those shapes might be anything. --- The invariant is that these element shapes are unobservable in the API. --- (This is possible because you ought to not be able to get to such an element --- without indexing out of bounds.) --- --- Note, though, that the converse situation may arise: the outer array might --- be nonempty but then the inner arrays might. This is fine, an invariant only --- applies if the _outer_ array is empty. --- --- TODO: can we enforce that the elements of an empty (nested) array have --- all-zero shape? --- -> no, because mlift and also any kind of internals probing from outsiders - - --- Primitive element types --- ======================= --- --- There are a few primitive element types; arrays containing elements of such --- type are a newtype over an XArray, which it itself a newtype over a Vector. --- Unfortunately, the setup of the library requires us to list these primitive --- element types multiple times; to aid in extending the list, all these lists --- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. - - --- | Wrapper type used as a tag to attach instances on. The instances on arrays --- of @'Primitive' a@ are more polymorphic than the direct instances for arrays --- of scalars; this means that if @orthotope@ supports an element type @T@ that --- this library does not (directly), it may just work if you use an array of --- @'Primitive' T@ instead. -newtype Primitive a = Primitive a - --- | Element types that are primitive; arrays of these types are just a newtype --- wrapper over an array. -class Storable a => PrimElt a where - fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a - toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) - - default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a - fromPrimitive = coerce - - default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a) - toPrimitive = coerce - --- [PRIMITIVE ELEMENT TYPES LIST] -instance PrimElt Int -instance PrimElt Int64 -instance PrimElt Int32 -instance PrimElt CInt -instance PrimElt Float -instance PrimElt Double -instance PrimElt () - - --- | Mixed arrays: some dimensions are size-typed, some are not. Distributes --- over product-typed elements using a data family so that the full array is --- always in struct-of-arrays format. --- --- Built on top of 'XArray' which is built on top of @orthotope@, meaning that --- dimension permutations (e.g. 'mtranspose') are typically free. --- --- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type --- class. -type Mixed :: [Maybe Nat] -> Type -> Type -data family Mixed sh a --- NOTE: When opening up the Mixed abstraction, you might see dimension sizes --- that you're not supposed to see. In particular, you might see (nonempty) --- sizes of the elements of an empty array, which is information that should --- ostensibly not exist; the full array is still empty. - -data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a) - deriving (Show, Eq, Generic) - --- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays. -deriving instance (Ord a, Storable a) => Ord (Mixed '[] (Primitive a)) - -instance NFData a => NFData (Mixed sh (Primitive a)) - --- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Show, Eq, Generic) -newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Show, Eq, Generic) -newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Show, Eq, Generic) -newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Show, Eq, Generic) -newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Show, Eq, Generic) -newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Show, Eq, Generic) -newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Show, Eq, Generic) -- no content, orthotope optimises this (via Vector) --- etc. - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving instance Ord (Mixed '[] Int) ; instance NFData (Mixed sh Int) -deriving instance Ord (Mixed '[] Int64) ; instance NFData (Mixed sh Int64) -deriving instance Ord (Mixed '[] Int32) ; instance NFData (Mixed sh Int32) -deriving instance Ord (Mixed '[] CInt) ; instance NFData (Mixed sh CInt) -deriving instance Ord (Mixed '[] Float) ; instance NFData (Mixed sh Float) -deriving instance Ord (Mixed '[] Double) ; instance NFData (Mixed sh Double) -deriving instance Ord (Mixed '[] ()) ; instance NFData (Mixed sh ()) - -data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic) -deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b)) -instance (NFData (Mixed sh a), NFData (Mixed sh b)) => NFData (Mixed sh (a, b)) --- etc., larger tuples (perhaps use generics to allow arbitrary product types) - -data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic) -deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a)) -instance NFData (Mixed (sh1 ++ sh2) a) => NFData (Mixed sh1 (Mixed sh2 a)) - - --- | Internal helper data family mirroring 'Mixed' that consists of mutable --- vectors instead of 'XArray's. -type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type -data family MixedVecs s sh a - -newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a) - --- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) -newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64) -newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32) -newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt) -newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) -newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float) -newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this --- etc. - -data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b) --- etc. - -data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) - - -mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a -mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr)) - -mliftNumElt2 :: PrimElt a - => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a) - -> Mixed sh a -> Mixed sh a -> Mixed sh a -mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2)) - | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2)) - | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 - -instance (NumElt a, PrimElt a) => Num (Mixed sh a) where - (+) = mliftNumElt2 numEltAdd - (-) = mliftNumElt2 numEltSub - (*) = mliftNumElt2 numEltMul - negate = mliftNumElt1 numEltNeg - abs = mliftNumElt1 numEltAbs - signum = mliftNumElt1 numEltSignum - fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" - -instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Mixed sh a) where - fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" - recip = mliftNumElt1 floatEltRecip - (/) = mliftNumElt2 floatEltDiv - -instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" - exp = mliftNumElt1 floatEltExp - log = mliftNumElt1 floatEltLog - sqrt = mliftNumElt1 floatEltSqrt - - (**) = mliftNumElt2 floatEltPow - logBase = mliftNumElt2 floatEltLogbase - - sin = mliftNumElt1 floatEltSin - cos = mliftNumElt1 floatEltCos - tan = mliftNumElt1 floatEltTan - asin = mliftNumElt1 floatEltAsin - acos = mliftNumElt1 floatEltAcos - atan = mliftNumElt1 floatEltAtan - sinh = mliftNumElt1 floatEltSinh - cosh = mliftNumElt1 floatEltCosh - tanh = mliftNumElt1 floatEltTanh - asinh = mliftNumElt1 floatEltAsinh - acosh = mliftNumElt1 floatEltAcosh - atanh = mliftNumElt1 floatEltAtanh - log1p = mliftNumElt1 floatEltLog1p - expm1 = mliftNumElt1 floatEltExpm1 - log1pexp = mliftNumElt1 floatEltLog1pexp - log1mexp = mliftNumElt1 floatEltLog1mexp - - --- | Allowable element types in a mixed array, and by extension in a 'Ranked' or --- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' --- a@; see the documentation for 'Primitive' for more details. -class Elt a where - -- ====== PUBLIC METHODS ====== -- - - mshape :: Mixed sh a -> IShX sh - mindex :: Mixed sh a -> IIxX sh -> a - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a - mscalar :: a -> Mixed '[] a - - -- | All arrays in the list, even subarrays inside @a@, must have the same - -- shape; if they do not, a runtime error will be thrown. See the - -- documentation of 'mgenerate' for more information about this restriction. - -- Furthermore, the length of the list must correspond with @n@: if @n@ is - -- @Just m@ and @m@ does not equal the length of the list, a runtime error is - -- thrown. - -- - -- Consider also 'mfromListPrim', which can avoid intermediate arrays. - mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a - - mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] - - -- | Note: this library makes no particular guarantees about the shapes of - -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the - -- full 'XArray' and as such you can distinguish different empty arrays by - -- the "shapes" of their elements. This information is meaningless, so you - -- should not use it. - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 a -> Mixed sh2 a - - -- | See the documentation for 'mlift'. - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a - - mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a - - mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) - => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a - - -- ====== PRIVATE METHODS ====== -- - - -- | Tree giving the shape of every array component. - type ShapeTree a - - mshapeTree :: a -> ShapeTree a - - mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool - - mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool - - mshowShapeTree :: Proxy a -> ShapeTree a -> String - - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () - - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () - - -- | Given the shape of this array, finalise the vectors into 'XArray's. - mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) - - --- | Element types for which we have evidence of the (static part of the) shape --- in a type class constraint. Compare the instance contexts of the instances --- of this class with those of 'Elt': some instances have an additional --- "known-shape" constraint. --- --- This class is (currently) only required for 'mgenerate' / 'rgenerate' / --- 'sgenerate'. -class Elt a => KnownElt a where - -- | Create an empty array. The given shape must have size zero; this may or may not be checked. - memptyArray :: IShX sh -> Mixed sh a - - -- | Create uninitialised vectors for this array type, given the shape of - -- this vector and an example for the contents. - mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) - - mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) - - --- Arrays of scalars are basically just arrays of scalars. -instance Storable a => Elt (Primitive a) where - mshape (M_Primitive sh _) = sh - mindex (M_Primitive _ a) i = Primitive (X.index a i) - mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) - mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) - mfromListOuter l@(arr1 :| _) = - let sh = SUnknown (length l) :$% mshape arr1 - in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l))) - mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) - -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) - mlift ssh2 f (M_Primitive _ a) - | Refl <- lemAppNil @sh1 - , Refl <- lemAppNil @sh2 - , let result = f ZKX a - = M_Primitive (X.shape ssh2 result) result - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) - -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) - mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b) - | Refl <- lemAppNil @sh1 - , Refl <- lemAppNil @sh2 - , Refl <- lemAppNil @sh3 - , let result = f ZKX a b - = M_Primitive (X.shape ssh3 result) result - - mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) - mcast ssh1 sh2 _ (M_Primitive sh1' arr) = - let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' - in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) - - mtranspose perm (M_Primitive sh arr) = - M_Primitive (shxPermutePrefix perm sh) - (X.transpose (ssxFromShape sh) perm arr) - - type ShapeTree (Primitive a) = () - mshapeTree _ = () - mshapeTreeEq _ () () = True - mshapeTreeEmpty _ () = False - mshowShapeTree _ () = "()" - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x - - -- TODO: this use of toVector is suboptimal - mvecsWritePartial - :: forall sh' sh s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do - let arrsh = X.shape (ssxFromShape sh') arr - offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) - VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) - - mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving via Primitive Int instance Elt Int -deriving via Primitive Int64 instance Elt Int64 -deriving via Primitive Int32 instance Elt Int32 -deriving via Primitive CInt instance Elt CInt -deriving via Primitive Double instance Elt Double -deriving via Primitive Float instance Elt Float -deriving via Primitive () instance Elt () - -instance Storable a => KnownElt (Primitive a) where - memptyArray sh = M_Primitive sh (X.empty sh) - mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) - mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving via Primitive Int instance KnownElt Int -deriving via Primitive Int64 instance KnownElt Int64 -deriving via Primitive Int32 instance KnownElt Int32 -deriving via Primitive CInt instance KnownElt CInt -deriving via Primitive Double instance KnownElt Double -deriving via Primitive Float instance KnownElt Float -deriving via Primitive () instance KnownElt () - --- Arrays of pairs are pairs of arrays. -instance (Elt a, Elt b) => Elt (a, b) where - mshape (M_Tup2 a _) = mshape a - mindex (M_Tup2 a b) i = (mindex a i, mindex b i) - mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) - mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) - mfromListOuter l = - M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) - (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) - mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) - mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) - mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) - - mcast ssh1 sh2 psh' (M_Tup2 a b) = - M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b) - - mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) - - type ShapeTree (a, b) = (ShapeTree a, ShapeTree b) - mshapeTree (x, y) = (mshapeTree x, mshapeTree y) - mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' - mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 - mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" - mvecsWrite sh i (x, y) (MV_Tup2 a b) = do - mvecsWrite sh i x a - mvecsWrite sh i y b - mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do - mvecsWritePartial sh i x a - mvecsWritePartial sh i y b - mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b - -instance (KnownElt a, KnownElt b) => KnownElt (a, b) where - memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) - mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y - mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) - --- Arrays of arrays are just arrays, but with more dimensions. -instance Elt a => Elt (Mixed sh' a) where - -- TODO: this is quadratic in the nesting depth because it repeatedly - -- truncates the shape vector to one a little shorter. Fix with a - -- moverlongShape method, a prefix of which is mshape. - mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh - mshape (M_Nest sh arr) - = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr)) - - mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a - mindex (M_Nest _ arr) i = mindexPartial arr i - - mindexPartial :: forall sh1 sh2. - Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - mindexPartial (M_Nest sh arr) i - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) - - mscalar = M_Nest ZSX - - mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) - mfromListOuter l@(arr :| _) = - M_Nest (SUnknown (length l) :$% mshape arr) - (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) - - mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) - -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) - mlift ssh2 f (M_Nest sh1 arr) = - let result = mlift (ssxAppend ssh2 ssh') f' arr - (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) - in M_Nest sh2 result - where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr))) - - f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b - f' sshT - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - = f (ssxAppend ssh' sshT) - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) - -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) - mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = - let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 - (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) - in M_Nest sh3 result - where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) - - f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b - f' sshT - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) - = f (ssxAppend ssh' sshT) - - mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) - mcast ssh1 sh2 _ (M_Nest sh1T arr) - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') - , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') - = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T - in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) - - mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) - => Perm is -> Mixed sh (Mixed sh' a) - -> Mixed (PermutePrefix is sh) (Mixed sh' a) - mtranspose perm (M_Nest sh arr) - | let sh' = shxDropSh @sh @sh' (mshape arr) sh - , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh') - , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) - , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') - , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') - , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') - = M_Nest (shxPermutePrefix perm sh) - (mtranspose perm arr) - - type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a) - - mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) - mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr))))) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs - - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) - -> ST s () - mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs - - mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs - -instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where - memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh')))) - - mvecsUnsafeNew sh example - | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) - | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh'))) - where - sh' = mshape example - - mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) - - --- | Create an array given a size and a function that computes the element at a --- given index. --- --- __WARNING__: It is required that every @a@ returned by the argument to --- 'mgenerate' has the same shape. For example, the following will throw a --- runtime error: --- --- > foo :: Mixed [Nothing] (Mixed [Nothing] Double) --- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) -> --- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) -> --- > ... --- --- because the size of the inner 'mgenerate' is not always the same (it depends --- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so --- the entire hierarchy (after distributing out tuples) must be a rectangular --- array. The type of 'mgenerate' allows this requirement to be broken very --- easily, hence the runtime check. -mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a -mgenerate sh f = case shxEnum sh of - [] -> memptyArray sh - firstidx : restidxs -> - let firstelem = f (ixxZero' sh) - shapetree = mshapeTree firstelem - in if mshapeTreeEmpty (Proxy @a) shapetree - then memptyArray sh - else runST $ do - vecs <- mvecsUnsafeNew sh firstelem - mvecsWrite sh firstidx firstelem vecs - -- TODO: This is likely fine if @a@ is big, but if @a@ is a - -- scalar this array copying inefficient. Should improve this. - forM_ restidxs $ \idx -> do - let val = f idx - when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ - error "Data.Array.Nested mgenerate: generated values do not have equal shapes" - mvecsWrite sh idx val vecs - mvecsFreeze sh vecs - -msumOuter1P :: forall sh n a. (Storable a, NumElt a) - => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) -msumOuter1P (M_Primitive (n :$% sh) arr) = - let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX - in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr) - -msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) - => Mixed (n : sh) a -> Mixed sh a -msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive - -mappend :: forall n m sh a. Elt a - => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a -mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 - where - sn :$% sh = mshape arr1 - sm :$% _ = mshape arr2 - ssh = ssxFromShape sh - snm :: SMayNat () SNat (AddMaybe n m) - snm = case (sn, sm) of - (SUnknown{}, _) -> SUnknown () - (SKnown{}, SUnknown{}) -> SUnknown () - (SKnown n, SKnown m) -> SKnown (snatPlus n m) - - f :: forall sh' b. Storable b - => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b - f ssh' = X.append (ssxAppend ssh ssh') - -mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) -mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) - -mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a -mfromVector sh v = fromPrimitive (mfromVectorP sh v) - -mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a -mtoVectorP (M_Primitive _ v) = X.toVector v - -mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a -mtoVector arr = mtoVectorP (toPrimitive arr) - -mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a -mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise? - -mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromList1Prim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -mtoList1 :: Elt a => Mixed '[n] a -> [a] -mtoList1 = map munScalar . mtoListOuter - -mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromListPrim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a -mfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) - -munScalar :: Elt a => Mixed '[] a -> a -munScalar arr = mindex arr ZIX - -mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) - => StaticShX sh -> IShX sh2 - -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) - -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) -mrerankP ssh sh2 f (M_Primitive sh arr) = - let sh1 = shxDropSSX sh ssh - in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) - (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2) - (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) - arr) - --- | See the caveats at @X.rerank@. -mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) - => StaticShX sh -> IShX sh2 - -> (Mixed sh1 a -> Mixed sh2 b) - -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b -mrerank ssh sh2 f (toPrimitive -> arr) = - fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr - -mreplicate :: forall sh sh' a. Elt a - => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a -mreplicate sh arr = - let ssh' = ssxFromShape (mshape arr) - in mlift (ssxAppend (ssxFromShape sh) ssh') - (\(sshT :: StaticShX shT) -> - case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of - Refl -> X.replicate sh (ssxAppend ssh' sshT)) - arr - -mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) -mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) - -mreplicateScal :: forall sh a. PrimElt a - => IShX sh -> a -> Mixed sh a -mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) - -mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a -mslice i n arr = - let _ :$% sh = mshape arr - in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr - -msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr - -mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a -mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr - -mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a -mreshape sh' arr = - mlift (ssxFromShape sh') - (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh') - arr - -miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a -miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) - -masXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) -masXArrayPrimP (M_Primitive sh arr) = (sh, arr) - -masXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a) -masXArrayPrim = masXArrayPrimP . toPrimitive - -mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a) -mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr - -mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a -mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP - -mliftPrim :: PrimElt a - => (a -> a) - -> Mixed sh a -> Mixed sh a -mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) - -mliftPrim2 :: PrimElt a - => (a -> a -> a) - -> Mixed sh a -> Mixed sh a -> Mixed sh a -mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = - fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) |