aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Mixed.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-30 22:20:57 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-30 22:20:57 +0200
commitf0752d67cd188f438280e1f0c692dc1f5f14a190 (patch)
tree2dd05c13aef3b3c6384bfa091b14633bc86e65a4 /src/Data/Array/Nested/Mixed.hs
parent19eab026f4f4c6a2d38ceb1fffa6062ba2637a46 (diff)
Refactor Nested (modules, function names)
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
-rw-r--r--src/Data/Array/Nested/Mixed.hs741
1 files changed, 741 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
new file mode 100644
index 0000000..84e16b3
--- /dev/null
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -0,0 +1,741 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DefaultSignatures #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE DerivingVia #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+module Data.Array.Nested.Mixed where
+
+import Control.DeepSeq (NFData)
+import Control.Monad (forM_, when)
+import Control.Monad.ST
+import Data.Array.RankedS qualified as S
+import Data.Coerce
+import Data.Foldable (toList)
+import Data.Int
+import Data.Kind (Type)
+import Data.List.NonEmpty (NonEmpty(..))
+import Data.Proxy
+import Data.Type.Equality
+import Data.Vector.Storable qualified as VS
+import Data.Vector.Storable.Mutable qualified as VSM
+import Foreign.C.Types (CInt)
+import Foreign.Storable (Storable)
+import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp)
+import GHC.Generics (Generic)
+import GHC.TypeLits
+
+import Data.Array.Mixed (XArray(..))
+import Data.Array.Mixed qualified as X
+import Data.Array.Mixed.Internal.Arith
+import Data.Array.Mixed.Shape
+import Data.Array.Mixed.Types
+import Data.Array.Mixed.Permutation
+import Data.Array.Mixed.Lemmas
+
+
+-- Invariant in the API
+-- ====================
+--
+-- In the underlying XArray, there is some shape for elements of an empty
+-- array. For example, for this array:
+--
+-- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float)
+-- rshape arr == 0 :.: 0 :.: 0 :.: ZIR
+--
+-- the two underlying XArrays have a shape, and those shapes might be anything.
+-- The invariant is that these element shapes are unobservable in the API.
+-- (This is possible because you ought to not be able to get to such an element
+-- without indexing out of bounds.)
+--
+-- Note, though, that the converse situation may arise: the outer array might
+-- be nonempty but then the inner arrays might. This is fine, an invariant only
+-- applies if the _outer_ array is empty.
+--
+-- TODO: can we enforce that the elements of an empty (nested) array have
+-- all-zero shape?
+-- -> no, because mlift and also any kind of internals probing from outsiders
+
+
+-- Primitive element types
+-- =======================
+--
+-- There are a few primitive element types; arrays containing elements of such
+-- type are a newtype over an XArray, which it itself a newtype over a Vector.
+-- Unfortunately, the setup of the library requires us to list these primitive
+-- element types multiple times; to aid in extending the list, all these lists
+-- have been marked with [PRIMITIVE ELEMENT TYPES LIST].
+
+
+-- | Wrapper type used as a tag to attach instances on. The instances on arrays
+-- of @'Primitive' a@ are more polymorphic than the direct instances for arrays
+-- of scalars; this means that if @orthotope@ supports an element type @T@ that
+-- this library does not (directly), it may just work if you use an array of
+-- @'Primitive' T@ instead.
+newtype Primitive a = Primitive a
+
+-- | Element types that are primitive; arrays of these types are just a newtype
+-- wrapper over an array.
+class Storable a => PrimElt a where
+ fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a
+ toPrimitive :: Mixed sh a -> Mixed sh (Primitive a)
+
+ default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a
+ fromPrimitive = coerce
+
+ default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a)
+ toPrimitive = coerce
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+instance PrimElt Int
+instance PrimElt Int64
+instance PrimElt Int32
+instance PrimElt CInt
+instance PrimElt Float
+instance PrimElt Double
+instance PrimElt ()
+
+
+-- | Mixed arrays: some dimensions are size-typed, some are not. Distributes
+-- over product-typed elements using a data family so that the full array is
+-- always in struct-of-arrays format.
+--
+-- Built on top of 'XArray' which is built on top of @orthotope@, meaning that
+-- dimension permutations (e.g. 'mtranspose') are typically free.
+--
+-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type
+-- class.
+type Mixed :: [Maybe Nat] -> Type -> Type
+data family Mixed sh a
+-- NOTE: When opening up the Mixed abstraction, you might see dimension sizes
+-- that you're not supposed to see. In particular, you might see (nonempty)
+-- sizes of the elements of an empty array, which is information that should
+-- ostensibly not exist; the full array is still empty.
+
+data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a)
+ deriving (Show, Eq, Generic)
+
+-- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays.
+deriving instance (Ord a, Storable a) => Ord (Mixed '[] (Primitive a))
+
+instance NFData a => NFData (Mixed sh (Primitive a))
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Show, Eq, Generic)
+newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Show, Eq, Generic)
+newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Show, Eq, Generic)
+newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Show, Eq, Generic)
+newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Show, Eq, Generic)
+newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Show, Eq, Generic)
+newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Show, Eq, Generic) -- no content, orthotope optimises this (via Vector)
+-- etc.
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+deriving instance Ord (Mixed '[] Int) ; instance NFData (Mixed sh Int)
+deriving instance Ord (Mixed '[] Int64) ; instance NFData (Mixed sh Int64)
+deriving instance Ord (Mixed '[] Int32) ; instance NFData (Mixed sh Int32)
+deriving instance Ord (Mixed '[] CInt) ; instance NFData (Mixed sh CInt)
+deriving instance Ord (Mixed '[] Float) ; instance NFData (Mixed sh Float)
+deriving instance Ord (Mixed '[] Double) ; instance NFData (Mixed sh Double)
+deriving instance Ord (Mixed '[] ()) ; instance NFData (Mixed sh ())
+
+data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic)
+deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b))
+instance (NFData (Mixed sh a), NFData (Mixed sh b)) => NFData (Mixed sh (a, b))
+-- etc., larger tuples (perhaps use generics to allow arbitrary product types)
+
+data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic)
+deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a))
+instance NFData (Mixed (sh1 ++ sh2) a) => NFData (Mixed sh1 (Mixed sh2 a))
+
+
+-- | Internal helper data family mirroring 'Mixed' that consists of mutable
+-- vectors instead of 'XArray's.
+type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
+data family MixedVecs s sh a
+
+newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a)
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int)
+newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64)
+newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32)
+newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt)
+newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double)
+newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float)
+newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this
+-- etc.
+
+data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b)
+-- etc.
+
+data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a)
+
+
+mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a
+mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr))
+
+mliftNumElt2 :: PrimElt a
+ => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a)
+ -> Mixed sh a -> Mixed sh a -> Mixed sh a
+mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2))
+ | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2))
+ | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2
+
+instance (NumElt a, PrimElt a) => Num (Mixed sh a) where
+ (+) = mliftNumElt2 numEltAdd
+ (-) = mliftNumElt2 numEltSub
+ (*) = mliftNumElt2 numEltMul
+ negate = mliftNumElt1 numEltNeg
+ abs = mliftNumElt1 numEltAbs
+ signum = mliftNumElt1 numEltSignum
+ fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate"
+
+instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Mixed sh a) where
+ fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"
+ recip = mliftNumElt1 floatEltRecip
+ (/) = mliftNumElt2 floatEltDiv
+
+instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"
+ exp = mliftNumElt1 floatEltExp
+ log = mliftNumElt1 floatEltLog
+ sqrt = mliftNumElt1 floatEltSqrt
+
+ (**) = mliftNumElt2 floatEltPow
+ logBase = mliftNumElt2 floatEltLogbase
+
+ sin = mliftNumElt1 floatEltSin
+ cos = mliftNumElt1 floatEltCos
+ tan = mliftNumElt1 floatEltTan
+ asin = mliftNumElt1 floatEltAsin
+ acos = mliftNumElt1 floatEltAcos
+ atan = mliftNumElt1 floatEltAtan
+ sinh = mliftNumElt1 floatEltSinh
+ cosh = mliftNumElt1 floatEltCosh
+ tanh = mliftNumElt1 floatEltTanh
+ asinh = mliftNumElt1 floatEltAsinh
+ acosh = mliftNumElt1 floatEltAcosh
+ atanh = mliftNumElt1 floatEltAtanh
+ log1p = mliftNumElt1 floatEltLog1p
+ expm1 = mliftNumElt1 floatEltExpm1
+ log1pexp = mliftNumElt1 floatEltLog1pexp
+ log1mexp = mliftNumElt1 floatEltLog1mexp
+
+
+-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
+-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
+-- a@; see the documentation for 'Primitive' for more details.
+class Elt a where
+ -- ====== PUBLIC METHODS ====== --
+
+ mshape :: Mixed sh a -> IShX sh
+ mindex :: Mixed sh a -> IIxX sh -> a
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a
+ mscalar :: a -> Mixed '[] a
+
+ -- | All arrays in the list, even subarrays inside @a@, must have the same
+ -- shape; if they do not, a runtime error will be thrown. See the
+ -- documentation of 'mgenerate' for more information about this restriction.
+ -- Furthermore, the length of the list must correspond with @n@: if @n@ is
+ -- @Just m@ and @m@ does not equal the length of the list, a runtime error is
+ -- thrown.
+ --
+ -- Consider also 'mfromListPrim', which can avoid intermediate arrays.
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+
+ mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]
+
+ -- | Note: this library makes no particular guarantees about the shapes of
+ -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the
+ -- full 'XArray' and as such you can distinguish different empty arrays by
+ -- the "shapes" of their elements. This information is meaningless, so you
+ -- should not use it.
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 a -> Mixed sh2 a
+
+ -- | See the documentation for 'mlift'.
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a
+
+ mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a
+
+ mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
+ => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a
+
+ -- ====== PRIVATE METHODS ====== --
+
+ -- | Tree giving the shape of every array component.
+ type ShapeTree a
+
+ mshapeTree :: a -> ShapeTree a
+
+ mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool
+
+ mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool
+
+ mshowShapeTree :: Proxy a -> ShapeTree a -> String
+
+ -- | Given the shape of this array, an index and a value, write the value at
+ -- that index in the vectors.
+ mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+
+ -- | Given the shape of this array, an index and a value, write the value at
+ -- that index in the vectors.
+ mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+
+ -- | Given the shape of this array, finalise the vectors into 'XArray's.
+ mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
+
+
+-- | Element types for which we have evidence of the (static part of the) shape
+-- in a type class constraint. Compare the instance contexts of the instances
+-- of this class with those of 'Elt': some instances have an additional
+-- "known-shape" constraint.
+--
+-- This class is (currently) only required for 'mgenerate' / 'rgenerate' /
+-- 'sgenerate'.
+class Elt a => KnownElt a where
+ -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
+ memptyArray :: IShX sh -> Mixed sh a
+
+ -- | Create uninitialised vectors for this array type, given the shape of
+ -- this vector and an example for the contents.
+ mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a)
+
+ mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
+
+
+-- Arrays of scalars are basically just arrays of scalars.
+instance Storable a => Elt (Primitive a) where
+ mshape (M_Primitive sh _) = sh
+ mindex (M_Primitive _ a) i = Primitive (X.index a i)
+ mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i)
+ mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
+ mfromListOuter l@(arr1 :| _) =
+ let sh = SUnknown (length l) :$% mshape arr1
+ in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l)))
+ mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
+
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
+ -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)
+ mlift ssh2 f (M_Primitive _ a)
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ , let result = f ZKX a
+ = M_Primitive (X.shape ssh2 result) result
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)
+ -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a)
+ mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b)
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ , Refl <- lemAppNil @sh3
+ , let result = f ZKX a b
+ = M_Primitive (X.shape ssh3 result) result
+
+ mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
+ mcast ssh1 sh2 _ (M_Primitive sh1' arr) =
+ let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
+ in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr)
+
+ mtranspose perm (M_Primitive sh arr) =
+ M_Primitive (shxPermutePrefix perm sh)
+ (X.transpose (ssxFromShape sh) perm arr)
+
+ type ShapeTree (Primitive a) = ()
+ mshapeTree _ = ()
+ mshapeTreeEq _ () () = True
+ mshapeTreeEmpty _ () = False
+ mshowShapeTree _ () = "()"
+ mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
+
+ -- TODO: this use of toVector is suboptimal
+ mvecsWritePartial
+ :: forall sh' sh s.
+ IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
+ mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
+ let arrsh = X.shape (ssxFromShape sh') arr
+ offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
+ VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
+
+ mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+deriving via Primitive Int instance Elt Int
+deriving via Primitive Int64 instance Elt Int64
+deriving via Primitive Int32 instance Elt Int32
+deriving via Primitive CInt instance Elt CInt
+deriving via Primitive Double instance Elt Double
+deriving via Primitive Float instance Elt Float
+deriving via Primitive () instance Elt ()
+
+instance Storable a => KnownElt (Primitive a) where
+ memptyArray sh = M_Primitive sh (X.empty sh)
+ mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh)
+ mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+deriving via Primitive Int instance KnownElt Int
+deriving via Primitive Int64 instance KnownElt Int64
+deriving via Primitive Int32 instance KnownElt Int32
+deriving via Primitive CInt instance KnownElt CInt
+deriving via Primitive Double instance KnownElt Double
+deriving via Primitive Float instance KnownElt Float
+deriving via Primitive () instance KnownElt ()
+
+-- Arrays of pairs are pairs of arrays.
+instance (Elt a, Elt b) => Elt (a, b) where
+ mshape (M_Tup2 a _) = mshape a
+ mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
+ mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
+ mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
+ mfromListOuter l =
+ M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l))
+ (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l))
+ mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)
+ mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
+ mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
+
+ mcast ssh1 sh2 psh' (M_Tup2 a b) =
+ M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b)
+
+ mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b)
+
+ type ShapeTree (a, b) = (ShapeTree a, ShapeTree b)
+ mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
+ mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
+ mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
+ mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
+ mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
+ mvecsWrite sh i x a
+ mvecsWrite sh i y b
+ mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
+ mvecsWritePartial sh i x a
+ mvecsWritePartial sh i y b
+ mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
+
+instance (KnownElt a, KnownElt b) => KnownElt (a, b) where
+ memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
+ mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
+ mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b)
+
+-- Arrays of arrays are just arrays, but with more dimensions.
+instance Elt a => Elt (Mixed sh' a) where
+ -- TODO: this is quadratic in the nesting depth because it repeatedly
+ -- truncates the shape vector to one a little shorter. Fix with a
+ -- moverlongShape method, a prefix of which is mshape.
+ mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
+ mshape (M_Nest sh arr)
+ = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr))
+
+ mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
+ mindex (M_Nest _ arr) i = mindexPartial arr i
+
+ mindexPartial :: forall sh1 sh2.
+ Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
+ mindexPartial (M_Nest sh arr) i
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
+
+ mscalar = M_Nest ZSX
+
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
+ mfromListOuter l@(arr :| _) =
+ M_Nest (SUnknown (length l) :$% mshape arr)
+ (mfromListOuter ((\(M_Nest _ a) -> a) <$> l))
+
+ mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr)
+
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
+ -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
+ mlift ssh2 f (M_Nest sh1 arr) =
+ let result = mlift (ssxAppend ssh2 ssh') f' arr
+ (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result)
+ in M_Nest sh2 result
+ where
+ ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr)))
+
+ f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
+ f' sshT
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
+ = f (ssxAppend ssh' sshT)
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
+ -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)
+ mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) =
+ let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2
+ (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result)
+ in M_Nest sh3 result
+ where
+ ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1)))
+
+ f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
+ f' sshT
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
+ = f (ssxAppend ssh' sshT)
+
+ mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a)
+ mcast ssh1 sh2 _ (M_Nest sh1T arr)
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
+ = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T
+ in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr)
+
+ mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
+ => Perm is -> Mixed sh (Mixed sh' a)
+ -> Mixed (PermutePrefix is sh) (Mixed sh' a)
+ mtranspose perm (M_Nest sh arr)
+ | let sh' = shxDropSh @sh @sh' (mshape arr) sh
+ , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh')
+ , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh'))
+ , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
+ , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
+ , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
+ = M_Nest (shxPermutePrefix perm sh)
+ (mtranspose perm arr)
+
+ type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a)
+
+ mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
+ mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr)))))
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+
+ mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
+ mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
+
+ mvecsWritePartial :: forall sh1 sh2 s.
+ IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
+ -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
+ -> ST s ()
+ mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs
+
+ mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs
+
+instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
+ memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh'))))
+
+ mvecsUnsafeNew sh example
+ | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
+ | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh')))
+ where
+ sh' = mshape example
+
+ mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
+
+
+-- | Create an array given a size and a function that computes the element at a
+-- given index.
+--
+-- __WARNING__: It is required that every @a@ returned by the argument to
+-- 'mgenerate' has the same shape. For example, the following will throw a
+-- runtime error:
+--
+-- > foo :: Mixed [Nothing] (Mixed [Nothing] Double)
+-- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) ->
+-- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) ->
+-- > ...
+--
+-- because the size of the inner 'mgenerate' is not always the same (it depends
+-- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so
+-- the entire hierarchy (after distributing out tuples) must be a rectangular
+-- array. The type of 'mgenerate' allows this requirement to be broken very
+-- easily, hence the runtime check.
+mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a
+mgenerate sh f = case shxEnum sh of
+ [] -> memptyArray sh
+ firstidx : restidxs ->
+ let firstelem = f (ixxZero' sh)
+ shapetree = mshapeTree firstelem
+ in if mshapeTreeEmpty (Proxy @a) shapetree
+ then memptyArray sh
+ else runST $ do
+ vecs <- mvecsUnsafeNew sh firstelem
+ mvecsWrite sh firstidx firstelem vecs
+ -- TODO: This is likely fine if @a@ is big, but if @a@ is a
+ -- scalar this array copying inefficient. Should improve this.
+ forM_ restidxs $ \idx -> do
+ let val = f idx
+ when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $
+ error "Data.Array.Nested mgenerate: generated values do not have equal shapes"
+ mvecsWrite sh idx val vecs
+ mvecsFreeze sh vecs
+
+msumOuter1P :: forall sh n a. (Storable a, NumElt a)
+ => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
+msumOuter1P (M_Primitive (n :$% sh) arr) =
+ let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
+ in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr)
+
+msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
+ => Mixed (n : sh) a -> Mixed sh a
+msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
+
+mappend :: forall n m sh a. Elt a
+ => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
+mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
+ where
+ sn :$% sh = mshape arr1
+ sm :$% _ = mshape arr2
+ ssh = ssxFromShape sh
+ snm :: SMayNat () SNat (AddMaybe n m)
+ snm = case (sn, sm) of
+ (SUnknown{}, _) -> SUnknown ()
+ (SKnown{}, SUnknown{}) -> SUnknown ()
+ (SKnown n, SKnown m) -> SKnown (snatPlus n m)
+
+ f :: forall sh' b. Storable b
+ => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b
+ f ssh' = X.append (ssxAppend ssh ssh')
+
+mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
+mfromVectorP sh v = M_Primitive sh (X.fromVector sh v)
+
+mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a
+mfromVector sh v = fromPrimitive (mfromVectorP sh v)
+
+mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a
+mtoVectorP (M_Primitive _ v) = X.toVector v
+
+mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a
+mtoVector arr = mtoVectorP (toPrimitive arr)
+
+mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a
+mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise?
+
+mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a
+mfromList1Prim l =
+ let ssh = SUnknown () :!% ZKX
+ xarr = X.fromList1 ssh l
+ in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+
+mtoList1 :: Elt a => Mixed '[n] a -> [a]
+mtoList1 = map munScalar . mtoListOuter
+
+mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a
+mfromListPrim l =
+ let ssh = SUnknown () :!% ZKX
+ xarr = X.fromList1 ssh l
+ in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+
+mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a
+mfromListPrimLinear sh l =
+ let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)
+
+munScalar :: Elt a => Mixed '[] a -> a
+munScalar arr = mindex arr ZIX
+
+mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
+ => StaticShX sh -> IShX sh2
+ -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
+ -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b)
+mrerankP ssh sh2 f (M_Primitive sh arr) =
+ let sh1 = shxDropSSX sh ssh
+ in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2)
+ (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2)
+ (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
+ arr)
+
+-- | See the caveats at @X.rerank@.
+mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
+ => StaticShX sh -> IShX sh2
+ -> (Mixed sh1 a -> Mixed sh2 b)
+ -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b
+mrerank ssh sh2 f (toPrimitive -> arr) =
+ fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr
+
+mreplicate :: forall sh sh' a. Elt a
+ => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
+mreplicate sh arr =
+ let ssh' = ssxFromShape (mshape arr)
+ in mlift (ssxAppend (ssxFromShape sh) ssh')
+ (\(sshT :: StaticShX shT) ->
+ case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of
+ Refl -> X.replicate sh (ssxAppend ssh' sshT))
+ arr
+
+mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
+mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x)
+
+mreplicateScal :: forall sh a. PrimElt a
+ => IShX sh -> a -> Mixed sh a
+mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
+
+mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
+mslice i n arr =
+ let _ :$% sh = mshape arr
+ in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr
+
+msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
+msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr
+
+mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a
+mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr
+
+mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a
+mreshape sh' arr =
+ mlift (ssxFromShape sh')
+ (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh')
+ arr
+
+miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a
+miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
+
+masXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
+masXArrayPrimP (M_Primitive sh arr) = (sh, arr)
+
+masXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a)
+masXArrayPrim = masXArrayPrimP . toPrimitive
+
+mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a)
+mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr
+
+mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a
+mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP
+
+mliftPrim :: PrimElt a
+ => (a -> a)
+ -> Mixed sh a -> Mixed sh a
+mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr))
+
+mliftPrim2 :: PrimElt a
+ => (a -> a -> a)
+ -> Mixed sh a -> Mixed sh a -> Mixed sh a
+mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =
+ fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2))