diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 22:47:52 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 22:47:52 +0200 |
commit | 8b59d8ef4ff97936f2a753d1ce345e0404c26b2b (patch) | |
tree | 947f75cb43982fbdb551dc329f036b0591f3c2b2 /src/Data/Array/Nested/Internal/Mixed.hs | |
parent | f0752d67cd188f438280e1f0c692dc1f5f14a190 (diff) |
Clearer module purposes
Thanks Mikolaj for discussion
Diffstat (limited to 'src/Data/Array/Nested/Internal/Mixed.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 741 |
1 files changed, 741 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs new file mode 100644 index 0000000..98871d5 --- /dev/null +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -0,0 +1,741 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +module Data.Array.Nested.Internal.Mixed where + +import Control.DeepSeq (NFData) +import Control.Monad (forM_, when) +import Control.Monad.ST +import Data.Array.RankedS qualified as S +import Data.Coerce +import Data.Foldable (toList) +import Data.Int +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty(..)) +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Data.Vector.Storable.Mutable qualified as VSM +import Foreign.C.Types (CInt) +import Foreign.Storable (Storable) +import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Mixed.XArray (XArray(..)) +import Data.Array.Mixed.XArray qualified as X +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Lemmas + + +-- Invariant in the API +-- ==================== +-- +-- In the underlying XArray, there is some shape for elements of an empty +-- array. For example, for this array: +-- +-- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float) +-- rshape arr == 0 :.: 0 :.: 0 :.: ZIR +-- +-- the two underlying XArrays have a shape, and those shapes might be anything. +-- The invariant is that these element shapes are unobservable in the API. +-- (This is possible because you ought to not be able to get to such an element +-- without indexing out of bounds.) +-- +-- Note, though, that the converse situation may arise: the outer array might +-- be nonempty but then the inner arrays might. This is fine, an invariant only +-- applies if the _outer_ array is empty. +-- +-- TODO: can we enforce that the elements of an empty (nested) array have +-- all-zero shape? +-- -> no, because mlift and also any kind of internals probing from outsiders + + +-- Primitive element types +-- ======================= +-- +-- There are a few primitive element types; arrays containing elements of such +-- type are a newtype over an XArray, which it itself a newtype over a Vector. +-- Unfortunately, the setup of the library requires us to list these primitive +-- element types multiple times; to aid in extending the list, all these lists +-- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. + + +-- | Wrapper type used as a tag to attach instances on. The instances on arrays +-- of @'Primitive' a@ are more polymorphic than the direct instances for arrays +-- of scalars; this means that if @orthotope@ supports an element type @T@ that +-- this library does not (directly), it may just work if you use an array of +-- @'Primitive' T@ instead. +newtype Primitive a = Primitive a + +-- | Element types that are primitive; arrays of these types are just a newtype +-- wrapper over an array. +class Storable a => PrimElt a where + fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a + toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) + + default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a + fromPrimitive = coerce + + default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a) + toPrimitive = coerce + +-- [PRIMITIVE ELEMENT TYPES LIST] +instance PrimElt Int +instance PrimElt Int64 +instance PrimElt Int32 +instance PrimElt CInt +instance PrimElt Float +instance PrimElt Double +instance PrimElt () + + +-- | Mixed arrays: some dimensions are size-typed, some are not. Distributes +-- over product-typed elements using a data family so that the full array is +-- always in struct-of-arrays format. +-- +-- Built on top of 'XArray' which is built on top of @orthotope@, meaning that +-- dimension permutations (e.g. 'mtranspose') are typically free. +-- +-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type +-- class. +type Mixed :: [Maybe Nat] -> Type -> Type +data family Mixed sh a +-- NOTE: When opening up the Mixed abstraction, you might see dimension sizes +-- that you're not supposed to see. In particular, you might see (nonempty) +-- sizes of the elements of an empty array, which is information that should +-- ostensibly not exist; the full array is still empty. + +data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a) + deriving (Show, Eq, Generic) + +-- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays. +deriving instance (Ord a, Storable a) => Ord (Mixed '[] (Primitive a)) + +instance NFData a => NFData (Mixed sh (Primitive a)) + +-- [PRIMITIVE ELEMENT TYPES LIST] +newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Show, Eq, Generic) +newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Show, Eq, Generic) +newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Show, Eq, Generic) +newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Show, Eq, Generic) +newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Show, Eq, Generic) +newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Show, Eq, Generic) +newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Show, Eq, Generic) -- no content, orthotope optimises this (via Vector) +-- etc. + +-- [PRIMITIVE ELEMENT TYPES LIST] +deriving instance Ord (Mixed '[] Int) ; instance NFData (Mixed sh Int) +deriving instance Ord (Mixed '[] Int64) ; instance NFData (Mixed sh Int64) +deriving instance Ord (Mixed '[] Int32) ; instance NFData (Mixed sh Int32) +deriving instance Ord (Mixed '[] CInt) ; instance NFData (Mixed sh CInt) +deriving instance Ord (Mixed '[] Float) ; instance NFData (Mixed sh Float) +deriving instance Ord (Mixed '[] Double) ; instance NFData (Mixed sh Double) +deriving instance Ord (Mixed '[] ()) ; instance NFData (Mixed sh ()) + +data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic) +deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b)) +instance (NFData (Mixed sh a), NFData (Mixed sh b)) => NFData (Mixed sh (a, b)) +-- etc., larger tuples (perhaps use generics to allow arbitrary product types) + +data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic) +deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a)) +instance NFData (Mixed (sh1 ++ sh2) a) => NFData (Mixed sh1 (Mixed sh2 a)) + + +-- | Internal helper data family mirroring 'Mixed' that consists of mutable +-- vectors instead of 'XArray's. +type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type +data family MixedVecs s sh a + +newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a) + +-- [PRIMITIVE ELEMENT TYPES LIST] +newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) +newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64) +newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32) +newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt) +newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) +newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float) +newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this +-- etc. + +data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b) +-- etc. + +data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) + + +mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a +mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr)) + +mliftNumElt2 :: PrimElt a + => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a) + -> Mixed sh a -> Mixed sh a -> Mixed sh a +mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2)) + | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2)) + | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 + +instance (NumElt a, PrimElt a) => Num (Mixed sh a) where + (+) = mliftNumElt2 numEltAdd + (-) = mliftNumElt2 numEltSub + (*) = mliftNumElt2 numEltMul + negate = mliftNumElt1 numEltNeg + abs = mliftNumElt1 numEltAbs + signum = mliftNumElt1 numEltSignum + fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" + +instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Mixed sh a) where + fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" + recip = mliftNumElt1 floatEltRecip + (/) = mliftNumElt2 floatEltDiv + +instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where + pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" + exp = mliftNumElt1 floatEltExp + log = mliftNumElt1 floatEltLog + sqrt = mliftNumElt1 floatEltSqrt + + (**) = mliftNumElt2 floatEltPow + logBase = mliftNumElt2 floatEltLogbase + + sin = mliftNumElt1 floatEltSin + cos = mliftNumElt1 floatEltCos + tan = mliftNumElt1 floatEltTan + asin = mliftNumElt1 floatEltAsin + acos = mliftNumElt1 floatEltAcos + atan = mliftNumElt1 floatEltAtan + sinh = mliftNumElt1 floatEltSinh + cosh = mliftNumElt1 floatEltCosh + tanh = mliftNumElt1 floatEltTanh + asinh = mliftNumElt1 floatEltAsinh + acosh = mliftNumElt1 floatEltAcosh + atanh = mliftNumElt1 floatEltAtanh + log1p = mliftNumElt1 floatEltLog1p + expm1 = mliftNumElt1 floatEltExpm1 + log1pexp = mliftNumElt1 floatEltLog1pexp + log1mexp = mliftNumElt1 floatEltLog1mexp + + +-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or +-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' +-- a@; see the documentation for 'Primitive' for more details. +class Elt a where + -- ====== PUBLIC METHODS ====== -- + + mshape :: Mixed sh a -> IShX sh + mindex :: Mixed sh a -> IIxX sh -> a + mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a + mscalar :: a -> Mixed '[] a + + -- | All arrays in the list, even subarrays inside @a@, must have the same + -- shape; if they do not, a runtime error will be thrown. See the + -- documentation of 'mgenerate' for more information about this restriction. + -- Furthermore, the length of the list must correspond with @n@: if @n@ is + -- @Just m@ and @m@ does not equal the length of the list, a runtime error is + -- thrown. + -- + -- Consider also 'mfromListPrim', which can avoid intermediate arrays. + mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a + + mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] + + -- | Note: this library makes no particular guarantees about the shapes of + -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the + -- full 'XArray' and as such you can distinguish different empty arrays by + -- the "shapes" of their elements. This information is meaningless, so you + -- should not use it. + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 a -> Mixed sh2 a + + -- | See the documentation for 'mlift'. + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a + + mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a + + mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) + => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a + + -- ====== PRIVATE METHODS ====== -- + + -- | Tree giving the shape of every array component. + type ShapeTree a + + mshapeTree :: a -> ShapeTree a + + mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool + + mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool + + mshowShapeTree :: Proxy a -> ShapeTree a -> String + + -- | Given the shape of this array, an index and a value, write the value at + -- that index in the vectors. + mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () + + -- | Given the shape of this array, an index and a value, write the value at + -- that index in the vectors. + mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + + -- | Given the shape of this array, finalise the vectors into 'XArray's. + mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) + + +-- | Element types for which we have evidence of the (static part of the) shape +-- in a type class constraint. Compare the instance contexts of the instances +-- of this class with those of 'Elt': some instances have an additional +-- "known-shape" constraint. +-- +-- This class is (currently) only required for 'mgenerate' / 'rgenerate' / +-- 'sgenerate'. +class Elt a => KnownElt a where + -- | Create an empty array. The given shape must have size zero; this may or may not be checked. + memptyArray :: IShX sh -> Mixed sh a + + -- | Create uninitialised vectors for this array type, given the shape of + -- this vector and an example for the contents. + mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) + + mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) + + +-- Arrays of scalars are basically just arrays of scalars. +instance Storable a => Elt (Primitive a) where + mshape (M_Primitive sh _) = sh + mindex (M_Primitive _ a) i = Primitive (X.index a i) + mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) + mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) + mfromListOuter l@(arr1 :| _) = + let sh = SUnknown (length l) :$% mshape arr1 + in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l))) + mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) + -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) + mlift ssh2 f (M_Primitive _ a) + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + , let result = f ZKX a + = M_Primitive (X.shape ssh2 result) result + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) + -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) + mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b) + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + , Refl <- lemAppNil @sh3 + , let result = f ZKX a b + = M_Primitive (X.shape ssh3 result) result + + mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) + mcast ssh1 sh2 _ (M_Primitive sh1' arr) = + let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' + in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) + + mtranspose perm (M_Primitive sh arr) = + M_Primitive (shxPermutePrefix perm sh) + (X.transpose (ssxFromShape sh) perm arr) + + type ShapeTree (Primitive a) = () + mshapeTree _ = () + mshapeTreeEq _ () () = True + mshapeTreeEmpty _ () = False + mshowShapeTree _ () = "()" + mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x + + -- TODO: this use of toVector is suboptimal + mvecsWritePartial + :: forall sh' sh s. + IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do + let arrsh = X.shape (ssxFromShape sh') arr + offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) + VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) + + mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v + +-- [PRIMITIVE ELEMENT TYPES LIST] +deriving via Primitive Int instance Elt Int +deriving via Primitive Int64 instance Elt Int64 +deriving via Primitive Int32 instance Elt Int32 +deriving via Primitive CInt instance Elt CInt +deriving via Primitive Double instance Elt Double +deriving via Primitive Float instance Elt Float +deriving via Primitive () instance Elt () + +instance Storable a => KnownElt (Primitive a) where + memptyArray sh = M_Primitive sh (X.empty sh) + mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) + mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 + +-- [PRIMITIVE ELEMENT TYPES LIST] +deriving via Primitive Int instance KnownElt Int +deriving via Primitive Int64 instance KnownElt Int64 +deriving via Primitive Int32 instance KnownElt Int32 +deriving via Primitive CInt instance KnownElt CInt +deriving via Primitive Double instance KnownElt Double +deriving via Primitive Float instance KnownElt Float +deriving via Primitive () instance KnownElt () + +-- Arrays of pairs are pairs of arrays. +instance (Elt a, Elt b) => Elt (a, b) where + mshape (M_Tup2 a _) = mshape a + mindex (M_Tup2 a b) i = (mindex a i, mindex b i) + mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) + mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) + mfromListOuter l = + M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) + (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) + mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) + mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) + mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) + + mcast ssh1 sh2 psh' (M_Tup2 a b) = + M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b) + + mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) + + type ShapeTree (a, b) = (ShapeTree a, ShapeTree b) + mshapeTree (x, y) = (mshapeTree x, mshapeTree y) + mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' + mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 + mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" + mvecsWrite sh i (x, y) (MV_Tup2 a b) = do + mvecsWrite sh i x a + mvecsWrite sh i y b + mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do + mvecsWritePartial sh i x a + mvecsWritePartial sh i y b + mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b + +instance (KnownElt a, KnownElt b) => KnownElt (a, b) where + memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) + mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y + mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) + +-- Arrays of arrays are just arrays, but with more dimensions. +instance Elt a => Elt (Mixed sh' a) where + -- TODO: this is quadratic in the nesting depth because it repeatedly + -- truncates the shape vector to one a little shorter. Fix with a + -- moverlongShape method, a prefix of which is mshape. + mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh + mshape (M_Nest sh arr) + = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr)) + + mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a + mindex (M_Nest _ arr) i = mindexPartial arr i + + mindexPartial :: forall sh1 sh2. + Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) + mindexPartial (M_Nest sh arr) i + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + + mscalar = M_Nest ZSX + + mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) + mfromListOuter l@(arr :| _) = + M_Nest (SUnknown (length l) :$% mshape arr) + (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) + + mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) + -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) + mlift ssh2 f (M_Nest sh1 arr) = + let result = mlift (ssxAppend ssh2 ssh') f' arr + (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) + in M_Nest sh2 result + where + ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr))) + + f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b + f' sshT + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) + -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) + mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = + let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 + (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) + in M_Nest sh3 result + where + ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) + + f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b + f' sshT + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) + + mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) + mcast ssh1 sh2 _ (M_Nest sh1T arr) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') + = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T + in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) + + mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) + => Perm is -> Mixed sh (Mixed sh' a) + -> Mixed (PermutePrefix is sh) (Mixed sh' a) + mtranspose perm (M_Nest sh arr) + | let sh' = shxDropSh @sh @sh' (mshape arr) sh + , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh') + , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) + , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') + , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') + , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') + = M_Nest (shxPermutePrefix perm sh) + (mtranspose perm arr) + + type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a) + + mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) + mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr))))) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs + + mvecsWritePartial :: forall sh1 sh2 s. + IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) + -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) + -> ST s () + mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs + + mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs + +instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where + memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh')))) + + mvecsUnsafeNew sh example + | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) + | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh'))) + where + sh' = mshape example + + mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) + + +-- | Create an array given a size and a function that computes the element at a +-- given index. +-- +-- __WARNING__: It is required that every @a@ returned by the argument to +-- 'mgenerate' has the same shape. For example, the following will throw a +-- runtime error: +-- +-- > foo :: Mixed [Nothing] (Mixed [Nothing] Double) +-- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) -> +-- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) -> +-- > ... +-- +-- because the size of the inner 'mgenerate' is not always the same (it depends +-- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so +-- the entire hierarchy (after distributing out tuples) must be a rectangular +-- array. The type of 'mgenerate' allows this requirement to be broken very +-- easily, hence the runtime check. +mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a +mgenerate sh f = case shxEnum sh of + [] -> memptyArray sh + firstidx : restidxs -> + let firstelem = f (ixxZero' sh) + shapetree = mshapeTree firstelem + in if mshapeTreeEmpty (Proxy @a) shapetree + then memptyArray sh + else runST $ do + vecs <- mvecsUnsafeNew sh firstelem + mvecsWrite sh firstidx firstelem vecs + -- TODO: This is likely fine if @a@ is big, but if @a@ is a + -- scalar this array copying inefficient. Should improve this. + forM_ restidxs $ \idx -> do + let val = f idx + when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ + error "Data.Array.Nested mgenerate: generated values do not have equal shapes" + mvecsWrite sh idx val vecs + mvecsFreeze sh vecs + +msumOuter1P :: forall sh n a. (Storable a, NumElt a) + => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) +msumOuter1P (M_Primitive (n :$% sh) arr) = + let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX + in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr) + +msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) + => Mixed (n : sh) a -> Mixed sh a +msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive + +mappend :: forall n m sh a. Elt a + => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a +mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 + where + sn :$% sh = mshape arr1 + sm :$% _ = mshape arr2 + ssh = ssxFromShape sh + snm :: SMayNat () SNat (AddMaybe n m) + snm = case (sn, sm) of + (SUnknown{}, _) -> SUnknown () + (SKnown{}, SUnknown{}) -> SUnknown () + (SKnown n, SKnown m) -> SKnown (snatPlus n m) + + f :: forall sh' b. Storable b + => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b + f ssh' = X.append (ssxAppend ssh ssh') + +mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) +mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) + +mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a +mfromVector sh v = fromPrimitive (mfromVectorP sh v) + +mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a +mtoVectorP (M_Primitive _ v) = X.toVector v + +mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a +mtoVector arr = mtoVectorP (toPrimitive arr) + +mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a +mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise? + +mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a +mfromList1Prim l = + let ssh = SUnknown () :!% ZKX + xarr = X.fromList1 ssh l + in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +mtoList1 :: Elt a => Mixed '[n] a -> [a] +mtoList1 = map munScalar . mtoListOuter + +mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a +mfromListPrim l = + let ssh = SUnknown () :!% ZKX + xarr = X.fromList1 ssh l + in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a +mfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) + +munScalar :: Elt a => Mixed '[] a -> a +munScalar arr = mindex arr ZIX + +mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) + => StaticShX sh -> IShX sh2 + -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) + -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) +mrerankP ssh sh2 f (M_Primitive sh arr) = + let sh1 = shxDropSSX sh ssh + in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) + (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2) + (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) + arr) + +-- | See the caveats at @X.rerank@. +mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) + => StaticShX sh -> IShX sh2 + -> (Mixed sh1 a -> Mixed sh2 b) + -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b +mrerank ssh sh2 f (toPrimitive -> arr) = + fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr + +mreplicate :: forall sh sh' a. Elt a + => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a +mreplicate sh arr = + let ssh' = ssxFromShape (mshape arr) + in mlift (ssxAppend (ssxFromShape sh) ssh') + (\(sshT :: StaticShX shT) -> + case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of + Refl -> X.replicate sh (ssxAppend ssh' sshT)) + arr + +mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) +mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) + +mreplicateScal :: forall sh a. PrimElt a + => IShX sh -> a -> Mixed sh a +mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) + +mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a +mslice i n arr = + let _ :$% sh = mshape arr + in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr + +msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a +msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr + +mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a +mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr + +mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a +mreshape sh' arr = + mlift (ssxFromShape sh') + (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh') + arr + +miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a +miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) + +masXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) +masXArrayPrimP (M_Primitive sh arr) = (sh, arr) + +masXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a) +masXArrayPrim = masXArrayPrimP . toPrimitive + +mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a) +mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr + +mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a +mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP + +mliftPrim :: PrimElt a + => (a -> a) + -> Mixed sh a -> Mixed sh a +mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) + +mliftPrim2 :: PrimElt a + => (a -> a -> a) + -> Mixed sh a -> Mixed sh a -> Mixed sh a +mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = + fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) |