From 554eff1ebc7bf4f467c8566a0e22b8a0cfb9d0a4 Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Wed, 14 May 2025 19:16:21 +0200 Subject: Rename the three main public tensor API modules --- src/Data/Array/Nested.hs | 6 +- src/Data/Array/Nested/Internal/Convert.hs | 6 +- src/Data/Array/Nested/Internal/Mixed.hs | 955 ------------------------------ src/Data/Array/Nested/Internal/Ranked.hs | 559 ----------------- src/Data/Array/Nested/Internal/Shaped.hs | 495 ---------------- src/Data/Array/Nested/Mixed.hs | 955 ++++++++++++++++++++++++++++++ src/Data/Array/Nested/Ranked.hs | 559 +++++++++++++++++ src/Data/Array/Nested/Shaped.hs | 495 ++++++++++++++++ 8 files changed, 2015 insertions(+), 2015 deletions(-) delete mode 100644 src/Data/Array/Nested/Internal/Mixed.hs delete mode 100644 src/Data/Array/Nested/Internal/Ranked.hs delete mode 100644 src/Data/Array/Nested/Internal/Shaped.hs create mode 100644 src/Data/Array/Nested/Mixed.hs create mode 100644 src/Data/Array/Nested/Ranked.hs create mode 100644 src/Data/Array/Nested/Shaped.hs (limited to 'src/Data') diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 8198a54..af195ee 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -106,9 +106,9 @@ import Prelude hiding (mappend, mconcat) import Data.Array.Mixed.Permutation import Data.Array.Mixed.Types import Data.Array.Nested.Internal.Convert -import Data.Array.Nested.Internal.Mixed -import Data.Array.Nested.Internal.Ranked -import Data.Array.Nested.Internal.Shaped +import Data.Array.Nested.Mixed +import Data.Array.Nested.Ranked +import Data.Array.Nested.Shaped import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Ranked.Shape import Data.Array.Nested.Shaped.Shape diff --git a/src/Data/Array/Nested/Internal/Convert.hs b/src/Data/Array/Nested/Internal/Convert.hs index 4e0f17d..611b45e 100644 --- a/src/Data/Array/Nested/Internal/Convert.hs +++ b/src/Data/Array/Nested/Internal/Convert.hs @@ -14,9 +14,9 @@ import Data.Type.Equality import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Types import Data.Array.Nested.Internal.Lemmas -import Data.Array.Nested.Internal.Mixed -import Data.Array.Nested.Internal.Ranked -import Data.Array.Nested.Internal.Shaped +import Data.Array.Nested.Mixed +import Data.Array.Nested.Ranked +import Data.Array.Nested.Shaped import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Shaped.Shape diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs deleted file mode 100644 index a979d6c..0000000 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ /dev/null @@ -1,955 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DefaultSignatures #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE StrictData #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -module Data.Array.Nested.Internal.Mixed where - -import Prelude hiding (mconcat) - -import Control.DeepSeq (NFData(..)) -import Control.Monad (forM_, when) -import Control.Monad.ST -import Data.Array.RankedS qualified as S -import Data.Bifunctor (bimap) -import Data.Coerce -import Data.Foldable (toList) -import Data.Int -import Data.Kind (Constraint, Type) -import Data.List.NonEmpty (NonEmpty(..)) -import Data.List.NonEmpty qualified as NE -import Data.Proxy -import Data.Type.Equality -import Data.Vector.Storable qualified as VS -import Data.Vector.Storable.Mutable qualified as VSM -import Foreign.C.Types (CInt) -import Foreign.Storable (Storable) -import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) -import GHC.Generics (Generic) -import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce) - -import Data.Array.Mixed.Internal.Arith -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray(..)) -import Data.Array.Mixed.XArray qualified as X -import Data.Array.Nested.Mixed.Shape -import Data.Bag - - --- TODO: --- sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a --- rminIndex1 :: Ranked (n + 1) a -> Ranked n Int --- gather/scatter-like things (most generally, the higher-order variants: accelerate's backpermute/permute) --- After benchmarking: matmul and matvec - - - --- Invariant in the API --- ==================== --- --- In the underlying XArray, there is some shape for elements of an empty --- array. For example, for this array: --- --- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float) --- rshape arr == 0 :.: 0 :.: 0 :.: ZIR --- --- the two underlying XArrays have a shape, and those shapes might be anything. --- The invariant is that these element shapes are unobservable in the API. --- (This is possible because you ought to not be able to get to such an element --- without indexing out of bounds.) --- --- Note, though, that the converse situation may arise: the outer array might --- be nonempty but then the inner arrays might. This is fine, an invariant only --- applies if the _outer_ array is empty. --- --- TODO: can we enforce that the elements of an empty (nested) array have --- all-zero shape? --- -> no, because mlift and also any kind of internals probing from outsiders - - --- Primitive element types --- ======================= --- --- There are a few primitive element types; arrays containing elements of such --- type are a newtype over an XArray, which it itself a newtype over a Vector. --- Unfortunately, the setup of the library requires us to list these primitive --- element types multiple times; to aid in extending the list, all these lists --- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. - - --- | Wrapper type used as a tag to attach instances on. The instances on arrays --- of @'Primitive' a@ are more polymorphic than the direct instances for arrays --- of scalars; this means that if @orthotope@ supports an element type @T@ that --- this library does not (directly), it may just work if you use an array of --- @'Primitive' T@ instead. -newtype Primitive a = Primitive a - deriving (Show) - --- | Element types that are primitive; arrays of these types are just a newtype --- wrapper over an array. -class (Storable a, Elt a) => PrimElt a where - fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a - toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) - - default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a - fromPrimitive = coerce - - default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a) - toPrimitive = coerce - --- [PRIMITIVE ELEMENT TYPES LIST] -instance PrimElt Bool -instance PrimElt Int -instance PrimElt Int64 -instance PrimElt Int32 -instance PrimElt CInt -instance PrimElt Float -instance PrimElt Double -instance PrimElt () - - --- | Mixed arrays: some dimensions are size-typed, some are not. Distributes --- over product-typed elements using a data family so that the full array is --- always in struct-of-arrays format. --- --- Built on top of 'XArray' which is built on top of @orthotope@, meaning that --- dimension permutations (e.g. 'mtranspose') are typically free. --- --- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type --- class. -type Mixed :: [Maybe Nat] -> Type -> Type -data family Mixed sh a --- NOTE: When opening up the Mixed abstraction, you might see dimension sizes --- that you're not supposed to see. In particular, you might see (nonempty) --- sizes of the elements of an empty array, which is information that should --- ostensibly not exist; the full array is still empty. - -data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a) - deriving (Eq, Ord, Generic) - --- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Eq, Ord, Generic) -- no content, orthotope optimises this (via Vector) --- etc. - -data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic) --- etc., larger tuples (perhaps use generics to allow arbitrary product types) - -deriving instance (Eq (Mixed sh a), Eq (Mixed sh b)) => Eq (Mixed sh (a, b)) -deriving instance (Ord (Mixed sh a), Ord (Mixed sh b)) => Ord (Mixed sh (a, b)) - -data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic) - -deriving instance Eq (Mixed (sh1 ++ sh2) a) => Eq (Mixed sh1 (Mixed sh2 a)) -deriving instance Ord (Mixed (sh1 ++ sh2) a) => Ord (Mixed sh1 (Mixed sh2 a)) - - --- | Internal helper data family mirroring 'Mixed' that consists of mutable --- vectors instead of 'XArray's. -type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type -data family MixedVecs s sh a - -newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a) - --- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance MixedVecs s sh Bool = MV_Bool (VS.MVector s Bool) -newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) -newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64) -newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32) -newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt) -newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) -newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float) -newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this --- etc. - -data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b) --- etc. - -data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) - - -showsMixedArray :: (Show a, Elt a) - => String -- ^ fromList prefix: e.g. @rfromListLinear [2,3]@ - -> String -- ^ replicate prefix: e.g. @rreplicate [2,3]@ - -> Int -> Mixed sh a -> ShowS -showsMixedArray fromlistPrefix replicatePrefix d arr = - showParen (d > 10) $ - -- TODO: to avoid ambiguity, we should type-apply the shape to mfromListLinear here - case mtoListLinear arr of - hd : _ : _ - | all (all (== 0) . take (shxLength (mshape arr))) (marrayStrides arr) -> - showString replicatePrefix . showString " " . showsPrec 11 hd - _ -> - showString fromlistPrefix . showString " " . shows (mtoListLinear arr) - -instance (Show a, Elt a) => Show (Mixed sh a) where - showsPrec d arr = - let sh = show (shxToList (mshape arr)) - in showsMixedArray ("mfromListLinear " ++ sh) ("mreplicate " ++ sh) d arr - -instance Elt a => NFData (Mixed sh a) where - rnf = mrnf - - -mliftNumElt1 :: (PrimElt a, PrimElt b) - => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b) - -> Mixed sh a -> Mixed sh b -mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr)) - -mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c) - => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c) - -> Mixed sh a -> Mixed sh b -> Mixed sh c -mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2)) - | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2)) - | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 - -instance (NumElt a, PrimElt a) => Num (Mixed sh a) where - (+) = mliftNumElt2 (liftO2 . numEltAdd) - (-) = mliftNumElt2 (liftO2 . numEltSub) - (*) = mliftNumElt2 (liftO2 . numEltMul) - negate = mliftNumElt1 (liftO1 . numEltNeg) - abs = mliftNumElt1 (liftO1 . numEltAbs) - signum = mliftNumElt1 (liftO1 . numEltSignum) - -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS - fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal" - -instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where - fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" - recip = mliftNumElt1 (liftO1 . floatEltRecip) - (/) = mliftNumElt2 (liftO2 . floatEltDiv) - -instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" - exp = mliftNumElt1 (liftO1 . floatEltExp) - log = mliftNumElt1 (liftO1 . floatEltLog) - sqrt = mliftNumElt1 (liftO1 . floatEltSqrt) - - (**) = mliftNumElt2 (liftO2 . floatEltPow) - logBase = mliftNumElt2 (liftO2 . floatEltLogbase) - - sin = mliftNumElt1 (liftO1 . floatEltSin) - cos = mliftNumElt1 (liftO1 . floatEltCos) - tan = mliftNumElt1 (liftO1 . floatEltTan) - asin = mliftNumElt1 (liftO1 . floatEltAsin) - acos = mliftNumElt1 (liftO1 . floatEltAcos) - atan = mliftNumElt1 (liftO1 . floatEltAtan) - sinh = mliftNumElt1 (liftO1 . floatEltSinh) - cosh = mliftNumElt1 (liftO1 . floatEltCosh) - tanh = mliftNumElt1 (liftO1 . floatEltTanh) - asinh = mliftNumElt1 (liftO1 . floatEltAsinh) - acosh = mliftNumElt1 (liftO1 . floatEltAcosh) - atanh = mliftNumElt1 (liftO1 . floatEltAtanh) - log1p = mliftNumElt1 (liftO1 . floatEltLog1p) - expm1 = mliftNumElt1 (liftO1 . floatEltExpm1) - log1pexp = mliftNumElt1 (liftO1 . floatEltLog1pexp) - log1mexp = mliftNumElt1 (liftO1 . floatEltLog1mexp) - -mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a -mquotArray = mliftNumElt2 (liftO2 . intEltQuot) -mremArray = mliftNumElt2 (liftO2 . intEltRem) - -matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a -matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2) - --- | Allowable element types in a mixed array, and by extension in a 'Ranked' or --- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' --- a@; see the documentation for 'Primitive' for more details. -class Elt a where - -- ====== PUBLIC METHODS ====== -- - - mshape :: Mixed sh a -> IShX sh - mindex :: Mixed sh a -> IIxX sh -> a - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a - mscalar :: a -> Mixed '[] a - - -- | All arrays in the list, even subarrays inside @a@, must have the same - -- shape; if they do not, a runtime error will be thrown. See the - -- documentation of 'mgenerate' for more information about this restriction. - -- Furthermore, the length of the list must correspond with @n@: if @n@ is - -- @Just m@ and @m@ does not equal the length of the list, a runtime error is - -- thrown. - -- - -- Consider also 'mfromListPrim', which can avoid intermediate arrays. - mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a - - mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] - - -- | Note: this library makes no particular guarantees about the shapes of - -- arrays "inside" an empty array. With 'mlift', 'mlift2' and 'mliftL' you can see the - -- full 'XArray' and as such you can distinguish different empty arrays by - -- the "shapes" of their elements. This information is meaningless, so you - -- should not use it. - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 a -> Mixed sh2 a - - -- | See the documentation for 'mlift'. - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a - - -- TODO: mliftL is currently unused. - -- | All arrays in the input must have equal shapes, including subarrays - -- inside their elements. - mliftL :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) - -> NonEmpty (Mixed sh1 a) -> NonEmpty (Mixed sh2 a) - - mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a - - mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) - => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a - - -- | All arrays in the input must have equal shapes, including subarrays - -- inside their elements. - mconcat :: NonEmpty (Mixed (Nothing : sh) a) -> Mixed (Nothing : sh) a - - mrnf :: Mixed sh a -> () - - -- ====== PRIVATE METHODS ====== -- - - -- | Tree giving the shape of every array component. - type ShapeTree a - - mshapeTree :: a -> ShapeTree a - - mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool - - mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool - - mshowShapeTree :: Proxy a -> ShapeTree a -> String - - -- | Returns the stride vector of each underlying component array making up - -- this mixed array. - marrayStrides :: Mixed sh a -> Bag [Int] - - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () - - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () - - -- | Given the shape of this array, finalise the vectors into 'XArray's. - mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) - - --- | Element types for which we have evidence of the (static part of the) shape --- in a type class constraint. Compare the instance contexts of the instances --- of this class with those of 'Elt': some instances have an additional --- "known-shape" constraint. --- --- This class is (currently) only required for 'mgenerate', --- 'Data.Array.Nested.Ranked.rgenerate' and --- 'Data.Array.Nested.Shaped.sgenerate'. -class Elt a => KnownElt a where - -- | Create an empty array. The given shape must have size zero; this may or may not be checked. - memptyArrayUnsafe :: IShX sh -> Mixed sh a - - -- | Create uninitialised vectors for this array type, given the shape of - -- this vector and an example for the contents. - mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) - - mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) - - --- Arrays of scalars are basically just arrays of scalars. -instance Storable a => Elt (Primitive a) where - mshape (M_Primitive sh _) = sh - mindex (M_Primitive _ a) i = Primitive (X.index a i) - mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) - mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) - mfromListOuter l@(arr1 :| _) = - let sh = SUnknown (length l) :$% mshape arr1 - in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l))) - mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) - -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) - mlift ssh2 f (M_Primitive _ a) - | Refl <- lemAppNil @sh1 - , Refl <- lemAppNil @sh2 - , let result = f ZKX a - = M_Primitive (X.shape ssh2 result) result - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) - -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) - mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b) - | Refl <- lemAppNil @sh1 - , Refl <- lemAppNil @sh2 - , Refl <- lemAppNil @sh3 - , let result = f ZKX a b - = M_Primitive (X.shape ssh3 result) result - - mliftL :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) - -> NonEmpty (Mixed sh1 (Primitive a)) -> NonEmpty (Mixed sh2 (Primitive a)) - mliftL ssh2 f l - | Refl <- lemAppNil @sh1 - , Refl <- lemAppNil @sh2 - = fmap (\arr -> M_Primitive (X.shape ssh2 arr) arr) $ - f ZKX (fmap (\(M_Primitive _ arr) -> arr) l) - - mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) - mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) = - let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' - sh2 = shxCast' sh1 ssh2 - in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) - - mtranspose perm (M_Primitive sh arr) = - M_Primitive (shxPermutePrefix perm sh) - (X.transpose (ssxFromShape sh) perm arr) - - mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a) - mconcat l@(M_Primitive (_ :$% sh) _ :| _) = - let result = X.concat (ssxFromShape sh) (fmap (\(M_Primitive _ arr) -> arr) l) - in M_Primitive (X.shape (SUnknown () :!% ssxFromShape sh) result) result - - mrnf (M_Primitive sh a) = rnf sh `seq` rnf a - - type ShapeTree (Primitive a) = () - mshapeTree _ = () - mshapeTreeEq _ () () = True - mshapeTreeEmpty _ () = False - mshowShapeTree _ () = "()" - marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x - - -- TODO: this use of toVector is suboptimal - mvecsWritePartial - :: forall sh' sh s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do - let arrsh = X.shape (ssxFromShape sh') arr - offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) - VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) - - mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving via Primitive Bool instance Elt Bool -deriving via Primitive Int instance Elt Int -deriving via Primitive Int64 instance Elt Int64 -deriving via Primitive Int32 instance Elt Int32 -deriving via Primitive CInt instance Elt CInt -deriving via Primitive Double instance Elt Double -deriving via Primitive Float instance Elt Float -deriving via Primitive () instance Elt () - -instance Storable a => KnownElt (Primitive a) where - memptyArrayUnsafe sh = M_Primitive sh (X.empty sh) - mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) - mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving via Primitive Bool instance KnownElt Bool -deriving via Primitive Int instance KnownElt Int -deriving via Primitive Int64 instance KnownElt Int64 -deriving via Primitive Int32 instance KnownElt Int32 -deriving via Primitive CInt instance KnownElt CInt -deriving via Primitive Double instance KnownElt Double -deriving via Primitive Float instance KnownElt Float -deriving via Primitive () instance KnownElt () - --- Arrays of pairs are pairs of arrays. -instance (Elt a, Elt b) => Elt (a, b) where - mshape (M_Tup2 a _) = mshape a - mindex (M_Tup2 a b) i = (mindex a i, mindex b i) - mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) - mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) - mfromListOuter l = - M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) - (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) - mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) - mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) - mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) - mliftL ssh2 f = - let unzipT2l [] = ([], []) - unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) - unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2) - in uncurry (NE.zipWith M_Tup2) . bimap (mliftL ssh2 f) (mliftL ssh2 f) . unzipT2 - - mcastPartial ssh1 sh2 psh' (M_Tup2 a b) = - M_Tup2 (mcastPartial ssh1 sh2 psh' a) (mcastPartial ssh1 sh2 psh' b) - - mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) - mconcat = - let unzipT2l [] = ([], []) - unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) - unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2) - in uncurry M_Tup2 . bimap mconcat mconcat . unzipT2 - - mrnf (M_Tup2 a b) = mrnf a `seq` mrnf b - - type ShapeTree (a, b) = (ShapeTree a, ShapeTree b) - mshapeTree (x, y) = (mshapeTree x, mshapeTree y) - mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' - mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 - mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" - marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b - mvecsWrite sh i (x, y) (MV_Tup2 a b) = do - mvecsWrite sh i x a - mvecsWrite sh i y b - mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do - mvecsWritePartial sh i x a - mvecsWritePartial sh i y b - mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b - -instance (KnownElt a, KnownElt b) => KnownElt (a, b) where - memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh) - mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y - mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) - --- Arrays of arrays are just arrays, but with more dimensions. -instance Elt a => Elt (Mixed sh' a) where - -- TODO: this is quadratic in the nesting depth because it repeatedly - -- truncates the shape vector to one a little shorter. Fix with a - -- moverlongShape method, a prefix of which is mshape. - mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh - mshape (M_Nest sh arr) - = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr)) - - mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a - mindex (M_Nest _ arr) i = mindexPartial arr i - - mindexPartial :: forall sh1 sh2. - Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - mindexPartial (M_Nest sh arr) i - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) - - mscalar = M_Nest ZSX - - mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) - mfromListOuter l@(arr :| _) = - M_Nest (SUnknown (length l) :$% mshape arr) - (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) - - mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) - -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) - mlift ssh2 f (M_Nest sh1 arr) = - let result = mlift (ssxAppend ssh2 ssh') f' arr - (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) - in M_Nest sh2 result - where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr))) - - f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b - f' sshT - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - = f (ssxAppend ssh' sshT) - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) - -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) - mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = - let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 - (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) - in M_Nest sh3 result - where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) - - f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b - f' sshT - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) - = f (ssxAppend ssh' sshT) - - mliftL :: forall sh1 sh2. - StaticShX sh2 - -> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b)) - -> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a)) - mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) = - let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l) - (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result)) - in fmap (M_Nest sh2) result - where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) - - f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b) - f' sshT - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - = f (ssxAppend ssh' sshT) - - mcastPartial :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> StaticShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) - mcastPartial ssh1 ssh2 _ (M_Nest sh1T arr) - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') - , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') - = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T - sh2 = shxCast' sh1 ssh2 - in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr) - - mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) - => Perm is -> Mixed sh (Mixed sh' a) - -> Mixed (PermutePrefix is sh) (Mixed sh' a) - mtranspose perm (M_Nest sh arr) - | let sh' = shxDropSh @sh @sh' (mshape arr) sh - , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh') - , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) - , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') - , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') - , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') - = M_Nest (shxPermutePrefix perm sh) - (mtranspose perm arr) - - mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) - mconcat l@(M_Nest sh1 _ :| _) = - let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l) - in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape result))) result - - mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr - - type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a) - - mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) - mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr))))) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - marrayStrides (M_Nest _ arr) = marrayStrides arr - - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs - - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) - -> ST s () - mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs - - mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs - -instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where - memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh')))) - - mvecsUnsafeNew sh example - | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) - | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh'))) - where - sh' = mshape example - - mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) - - -memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a -memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh) - -mrank :: Elt a => Mixed sh a -> SNat (Rank sh) -mrank = shxRank . mshape - --- | The total number of elements in the array. -msize :: Elt a => Mixed sh a -> Int -msize = shxSize . mshape - --- | Create an array given a size and a function that computes the element at a --- given index. --- --- __WARNING__: It is required that every @a@ returned by the argument to --- 'mgenerate' has the same shape. For example, the following will throw a --- runtime error: --- --- > foo :: Mixed [Nothing] (Mixed [Nothing] Double) --- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) -> --- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) -> --- > ... --- --- because the size of the inner 'mgenerate' is not always the same (it depends --- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so --- the entire hierarchy (after distributing out tuples) must be a rectangular --- array. The type of 'mgenerate' allows this requirement to be broken very --- easily, hence the runtime check. -mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a -mgenerate sh f = case shxEnum sh of - [] -> memptyArrayUnsafe sh - firstidx : restidxs -> - let firstelem = f (ixxZero' sh) - shapetree = mshapeTree firstelem - in if mshapeTreeEmpty (Proxy @a) shapetree - then memptyArrayUnsafe sh - else runST $ do - vecs <- mvecsUnsafeNew sh firstelem - mvecsWrite sh firstidx firstelem vecs - -- TODO: This is likely fine if @a@ is big, but if @a@ is a - -- scalar this array copying inefficient. Should improve this. - forM_ restidxs $ \idx -> do - let val = f idx - when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ - error "Data.Array.Nested mgenerate: generated values do not have equal shapes" - mvecsWrite sh idx val vecs - mvecsFreeze sh vecs - -msumOuter1P :: forall sh n a. (Storable a, NumElt a) - => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) -msumOuter1P (M_Primitive (n :$% sh) arr) = - let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX - in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr) - -msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) - => Mixed (n : sh) a -> Mixed sh a -msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive - -msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a -msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShape sh) arr - -mappend :: forall n m sh a. Elt a - => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a -mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 - where - sn :$% sh = mshape arr1 - sm :$% _ = mshape arr2 - ssh = ssxFromShape sh - snm :: SMayNat () SNat (AddMaybe n m) - snm = case (sn, sm) of - (SUnknown{}, _) -> SUnknown () - (SKnown{}, SUnknown{}) -> SUnknown () - (SKnown n, SKnown m) -> SKnown (snatPlus n m) - - f :: forall sh' b. Storable b - => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b - f ssh' = X.append (ssxAppend ssh ssh') - -mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) -mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) - -mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a -mfromVector sh v = fromPrimitive (mfromVectorP sh v) - -mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a -mtoVectorP (M_Primitive _ v) = X.toVector v - -mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a -mtoVector arr = mtoVectorP (toPrimitive arr) - -mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a -mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise? - -mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromList1Prim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -mtoList1 :: Elt a => Mixed '[n] a -> [a] -mtoList1 = map munScalar . mtoListOuter - -mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromListPrim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a -mfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) - --- This forall is there so that a simple type application can constrain the --- shape, in case the user wants to use OverloadedLists for the shape. -mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a -mfromListLinear sh l = mreshape sh (mfromList1 l) - -mtoListLinear :: Elt a => Mixed sh a -> [a] -mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr)) -- TODO: optimise - -munScalar :: Elt a => Mixed '[] a -> a -munScalar arr = mindex arr ZIX - -mnest :: forall sh sh' a. Elt a => StaticShX sh -> Mixed (sh ++ sh') a -> Mixed sh (Mixed sh' a) -mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr - -munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a -munNest (M_Nest _ arr) = arr - -mzip :: Mixed sh a -> Mixed sh b -> Mixed sh (a, b) -mzip = M_Tup2 - -munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b) -munzip (M_Tup2 a b) = (a, b) - -mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) - => StaticShX sh -> IShX sh2 - -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) - -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) -mrerankP ssh sh2 f (M_Primitive sh arr) = - let sh1 = shxDropSSX sh ssh - in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) - (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2) - (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) - arr) - --- | See the caveats at @X.rerank@. -mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) - => StaticShX sh -> IShX sh2 - -> (Mixed sh1 a -> Mixed sh2 b) - -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b -mrerank ssh sh2 f (toPrimitive -> arr) = - fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr - -mreplicate :: forall sh sh' a. Elt a - => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a -mreplicate sh arr = - let ssh' = ssxFromShape (mshape arr) - in mlift (ssxAppend (ssxFromShape sh) ssh') - (\(sshT :: StaticShX shT) -> - case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of - Refl -> X.replicate sh (ssxAppend ssh' sshT)) - arr - -mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) -mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) - -mreplicateScal :: forall sh a. PrimElt a - => IShX sh -> a -> Mixed sh a -mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) - -mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a -mslice i n arr = - let _ :$% sh = mshape arr - in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr - -msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr - -mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a -mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr - -mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a -mreshape sh' arr = - mlift (ssxFromShape sh') - (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh') - arr - -mflatten :: Elt a => Mixed sh a -> Mixed '[Flatten sh] a -mflatten arr = mreshape (shxFlatten (mshape arr) :$% ZSX) arr - -miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a -miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) - --- | Throws if the array is empty. -mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh -mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = - ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) (fromO arr)) - --- | Throws if the array is empty. -mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh -mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = - ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) (fromO arr)) - -mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) - => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a -mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b)) - | Refl <- lemInitApp (Proxy @sh) (Proxy @n) - , Refl <- lemLastApp (Proxy @sh) (Proxy @n) - = case sh1 of - _ :$% _ - | sh1 == sh2 - , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) -> - fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b)) - | otherwise -> error $ "mdot1Inner: Unequal shapes (" ++ show sh1 ++ " and " ++ show sh2 ++ ")" - ZSX -> error "unreachable" - --- | This has a temporary, suboptimal implementation in terms of 'mflatten'. --- Prefer 'mdot1Inner' if applicable. -mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a -mdot a b = - munScalar $ - mdot1Inner Proxy (fromPrimitive (mflatten (toPrimitive a))) - (fromPrimitive (mflatten (toPrimitive b))) - -mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) -mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr) - -mtoXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a) -mtoXArrayPrim = mtoXArrayPrimP . toPrimitive - -mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a) -mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr - -mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a -mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP - -mliftPrim :: (PrimElt a, PrimElt b) - => (a -> b) - -> Mixed sh a -> Mixed sh b -mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) - -mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c) - => (a -> b -> c) - -> Mixed sh a -> Mixed sh b -> Mixed sh c -mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = - fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) - -mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) - => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a -mcast ssh2 arr - | Refl <- lemAppNil @sh1 - , Refl <- lemAppNil @sh2 - = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr - --- TODO: This should be `type data` but a bug in GHC 9.10 means that that throws linker errors -data SafeMCastSpec - = MCastId - | MCastApp [Maybe Nat] [Maybe Nat] [Maybe Nat] [Maybe Nat] SafeMCastSpec SafeMCastSpec - | MCastForget - -type SafeMCast :: SafeMCastSpec -> [Maybe Nat] -> [Maybe Nat] -> Constraint -type family SafeMCast spec sh1 sh2 where - SafeMCast MCastId sh sh = () - SafeMCast (MCastApp sh1A sh1B sh2A sh2B specA specB) sh1 sh2 = (sh1 ~ sh1A ++ sh1B, sh2 ~ sh2A ++ sh2B, SafeMCast specA sh1A sh2A, SafeMCast specB sh1B sh2B) - SafeMCast MCastForget sh1 sh2 = sh2 ~ Replicate (Rank sh1) Nothing - --- | This is an O(1) operation: the 'SafeMCast' constraint ensures that --- type-level shape information can only be forgotten, not introduced, and thus --- that no runtime shape checks are required. The @spec@ describes to --- 'SafeMCast' how exactly you intend @sh2@ to be a weakening of @sh1@. --- --- To see how to construct the spec, read the equations of 'SafeMCast' closely. -mcastSafe :: forall spec sh1 sh2 a proxy. SafeMCast spec sh1 sh2 => proxy spec -> Mixed sh1 a -> Mixed sh2 a -mcastSafe _ = unsafeCoerce @(Mixed sh1 a) @(Mixed sh2 a) diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs deleted file mode 100644 index 368e337..0000000 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ /dev/null @@ -1,559 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Internal.Ranked where - -import Prelude hiding (mappend, mconcat) - -import Control.DeepSeq (NFData(..)) -import Control.Monad.ST -import Data.Array.RankedS qualified as S -import Data.Bifunctor (first) -import Data.Coerce (coerce) -import Data.Foldable (toList) -import Data.Kind (Type) -import Data.List.NonEmpty (NonEmpty) -import Data.Proxy -import Data.Type.Equality -import Data.Vector.Storable qualified as VS -import Foreign.Storable (Storable) -import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) -import GHC.Generics (Generic) -import GHC.TypeLits -import GHC.TypeNats qualified as TN - -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray(..)) -import Data.Array.Mixed.XArray qualified as X -import Data.Array.Nested.Internal.Mixed -import Data.Array.Nested.Mixed.Shape -import Data.Array.Nested.Ranked.Shape -import Data.Array.Strided.Arith - - --- | A rank-typed array: the number of dimensions of the array (its /rank/) is --- represented on the type level as a 'Nat'. --- --- Valid elements of a ranked arrays are described by the 'Elt' type class. --- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are --- supported (and are represented as a single, flattened, struct-of-arrays --- array internally). --- --- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. -type Ranked :: Nat -> Type -> Type -newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) -deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a) -deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a) - -instance (Show a, Elt a) => Show (Ranked n a) where - showsPrec d arr@(Ranked marr) = - let sh = show (toList (rshape arr)) - in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr - -instance Elt a => NFData (Ranked n a) where - rnf (Ranked arr) = rnf arr - --- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) - deriving (Generic) - -deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a)) - -newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) - --- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; --- these instances allow them to also be used as elements of arrays, thus --- making them first-class in the API. -instance Elt a => Elt (Ranked n a) where - mshape (M_Ranked arr) = mshape arr - mindex (M_Ranked arr) i = Ranked (mindex arr i) - - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) - mindexPartial (M_Ranked arr) i = - coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ - mindexPartial arr i - - mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) - - mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) - mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) - - mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] - mtoListOuter (M_Ranked arr) = - coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) - mlift ssh2 f (M_Ranked arr) = - coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ - mlift ssh2 f arr - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) - mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) = - coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ - mlift2 ssh3 f arr1 arr2 - - mliftL :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) - -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a)) - mliftL ssh2 f l = - coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a))) - @(NonEmpty (Mixed sh2 (Ranked n a))) $ - mliftL ssh2 f (coerce l) - - mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr) - - mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) - - mconcat l = M_Ranked (mconcat (coerce l)) - - mrnf (M_Ranked arr) = mrnf arr - - type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) - - mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - marrayStrides (M_Ranked arr) = marrayStrides arr - - mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () - mvecsWrite sh idx (Ranked arr) vecs = - mvecsWrite sh idx arr - (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsWritePartial :: forall sh sh' s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) - -> MixedVecs s (sh ++ sh') (Ranked n a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx - (coerce @(Mixed sh' (Ranked n a)) - @(Mixed sh' (Mixed (Replicate n Nothing) a)) - arr) - (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) - @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) - mvecsFreeze sh vecs = - coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) - @(Mixed sh (Ranked n a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh (Ranked n a)) - @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - -instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where - memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a) - memptyArrayUnsafe i - | Dict <- lemKnownReplicate (SNat @n) - = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ - memptyArrayUnsafe i - - mvecsUnsafeNew idx (Ranked arr) - | Dict <- lemKnownReplicate (SNat @n) - = MV_Ranked <$> mvecsUnsafeNew idx arr - - mvecsNewEmpty _ - | Dict <- lemKnownReplicate (SNat @n) - = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) - - -liftRanked1 :: forall n a b. - (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b) - -> Ranked n a -> Ranked n b -liftRanked1 = coerce - -liftRanked2 :: forall n a b c. - (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c) - -> Ranked n a -> Ranked n b -> Ranked n c -liftRanked2 = coerce - -instance (NumElt a, PrimElt a) => Num (Ranked n a) where - (+) = liftRanked2 (+) - (-) = liftRanked2 (-) - (*) = liftRanked2 (*) - negate = liftRanked1 negate - abs = liftRanked1 abs - signum = liftRanked1 signum - fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal" - -instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where - fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal" - recip = liftRanked1 recip - (/) = liftRanked2 (/) - -instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where - pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal" - exp = liftRanked1 exp - log = liftRanked1 log - sqrt = liftRanked1 sqrt - (**) = liftRanked2 (**) - logBase = liftRanked2 logBase - sin = liftRanked1 sin - cos = liftRanked1 cos - tan = liftRanked1 tan - asin = liftRanked1 asin - acos = liftRanked1 acos - atan = liftRanked1 atan - sinh = liftRanked1 sinh - cosh = liftRanked1 cosh - tanh = liftRanked1 tanh - asinh = liftRanked1 asinh - acosh = liftRanked1 acosh - atanh = liftRanked1 atanh - log1p = liftRanked1 GHC.Float.log1p - expm1 = liftRanked1 GHC.Float.expm1 - log1pexp = liftRanked1 GHC.Float.log1pexp - log1mexp = liftRanked1 GHC.Float.log1mexp - -rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a -rquotArray = liftRanked2 mquotArray -rremArray = liftRanked2 mremArray - -ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a -ratan2Array = liftRanked2 matan2Array - - -remptyArray :: KnownElt a => Ranked 1 a -remptyArray = mtoRanked (memptyArray ZSX) - -rshape :: Elt a => Ranked n a -> IShR n -rshape (Ranked arr) = shCvtXR' (mshape arr) - -rrank :: Elt a => Ranked n a -> SNat n -rrank = shrRank . rshape - --- | The total number of elements in the array. -rsize :: Elt a => Ranked n a -> Int -rsize = shrSize . rshape - -rindex :: Elt a => Ranked n a -> IIxR n -> a -rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) - -rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a -rindexPartial (Ranked arr) idx = - Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) - (castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @Nothing))) arr) - (ixCvtRX idx)) - --- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. -rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a -rgenerate sh f - | sn@SNat <- shrRank sh - , Dict <- lemKnownReplicate sn - , Refl <- lemRankReplicate sn - = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) - --- | See the documentation of 'mlift'. -rlift :: forall n1 n2 a. Elt a - => SNat n2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) - -> Ranked n1 a -> Ranked n2 a -rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr) - --- | See the documentation of 'mlift2'. -rlift2 :: forall n1 n2 n3 a. Elt a - => SNat n3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b) - -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a -rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) - -rsumOuter1P :: forall n a. - (Storable a, NumElt a) - => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) -rsumOuter1P (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = Ranked (msumOuter1P arr) - -rsumOuter1 :: forall n a. (NumElt a, PrimElt a) - => Ranked (n + 1) a -> Ranked n a -rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive - -rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a -rsumAllPrim (Ranked arr) = msumAllPrim arr - -rtranspose :: forall n a. Elt a => PermR -> Ranked n a -> Ranked n a -rtranspose perm arr - | sn@SNat <- rrank arr - , Dict <- lemKnownReplicate sn - , length perm <= fromIntegral (natVal (Proxy @n)) - = rlift sn - (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm) - arr - | otherwise - = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array" - -rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a -rconcat - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = coerce mconcat - -rappend :: forall n a. Elt a - => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a -rappend arr1 arr2 - | sn@SNat <- rrank arr1 - , Dict <- lemKnownReplicate sn - , Refl <- lemReplicateSucc @(Nothing @Nat) @n - = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) - arr1 arr2 - -rscalar :: Elt a => a -> Ranked 0 a -rscalar x = Ranked (mscalar x) - -rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a) -rfromVectorP sh v - | Dict <- lemKnownReplicate (shrRank sh) - = Ranked (mfromVectorP (shCvtRX sh) v) - -rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a -rfromVector sh v = rfromPrimitive (rfromVectorP sh v) - -rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a -rtoVectorP = coerce mtoVectorP - -rtoVector :: PrimElt a => Ranked n a -> VS.Vector a -rtoVector = coerce mtoVector - -rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a -rfromListOuter l - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) - -rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a -rfromList1 l = Ranked (mfromList1 l) - -rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a -rfromList1Prim l = Ranked (mfromList1Prim l) - -rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] -rtoListOuter (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) - -rtoList1 :: Elt a => Ranked 1 a -> [a] -rtoList1 = map runScalar . rtoListOuter - -rfromListPrim :: PrimElt a => [a] -> Ranked 1 a -rfromListPrim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a -rfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr) - -rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a -rfromListLinear sh l = rreshape sh (rfromList1 l) - -rtoListLinear :: Elt a => Ranked n a -> [a] -rtoListLinear (Ranked arr) = mtoListLinear arr - -rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a -rfromOrthotope sn arr - | Refl <- lemRankReplicate sn - = let xarr = XArray arr - in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr)) - -rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a -rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr))) - | Refl <- lemRankReplicate (shrRank $ shCvtXR' sh) - = arr - -runScalar :: Elt a => Ranked 0 a -> a -runScalar arr = rindex arr ZIR - -rnest :: forall n m a. Elt a => SNat n -> Ranked (n + m) a -> Ranked n (Ranked m a) -rnest n arr - | Refl <- lemReplicatePlusApp n (Proxy @m) (Proxy @(Nothing @Nat)) - = coerce (mnest (ssxFromSNat n) (coerce arr)) - -runNest :: forall n m a. Elt a => Ranked n (Ranked m a) -> Ranked (n + m) a -runNest rarr@(Ranked (M_Ranked (M_Nest _ arr))) - | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat)) - = Ranked arr - -rzip :: Ranked n a -> Ranked n b -> Ranked n (a, b) -rzip = coerce mzip - -runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b) -runzip = coerce munzip - -rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) - => SNat n -> IShR n2 - -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) - -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b) -rrerankP sn sh2 f (Ranked arr) - | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) - , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) - = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2) - (\a -> let Ranked r = f (Ranked a) in r) - arr) - --- | If there is a zero-sized dimension in the @n@-prefix of the shape of the --- input array, then there is no way to deduce the full shape of the output --- array (more precisely, the @n2@ part): that could only come from calling --- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in --- this case; we choose to fill the @n2@ part of the output shape with zeros. --- --- For example, if: --- --- @ --- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21] --- f :: Ranked 2 Int -> Ranked 3 Float --- @ --- --- then: --- --- @ --- rrerank _ _ _ f arr :: Ranked 5 Float --- @ --- --- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the --- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended --- to return an array with shape all-0 here (it probably didn't), but there is --- no better number to put here absent a subarray of the input to pass to @f@. -rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) - => SNat n -> IShR n2 - -> (Ranked n1 a -> Ranked n2 b) - -> Ranked (n + n1) a -> Ranked (n + n2) b -rrerank sn sh2 f (rtoPrimitive -> arr) = - rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr - -rreplicate :: forall n m a. Elt a - => IShR n -> Ranked m a -> Ranked (n + m) a -rreplicate sh (Ranked arr) - | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat)) - = Ranked (mreplicate (shCvtRX sh) arr) - -rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) -rreplicateScalP sh x - | Dict <- lemKnownReplicate (shrRank sh) - = Ranked (mreplicateScalP (shCvtRX sh) x) - -rreplicateScal :: forall n a. PrimElt a - => IShR n -> a -> Ranked n a -rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) - -rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a -rslice i n arr - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = rlift (rrank arr) - (\_ -> X.sliceU i n) - arr - -rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a -rrev1 arr = - rlift (rrank arr) - (\(_ :: StaticShX sh') -> - case lemReplicateSucc @(Nothing @Nat) @n of - Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) - arr - -rreshape :: forall n n' a. Elt a - => IShR n' -> Ranked n a -> Ranked n' a -rreshape sh' rarr@(Ranked arr) - | Dict <- lemKnownReplicate (rrank rarr) - , Dict <- lemKnownReplicate (shrRank sh') - = Ranked (mreshape (shCvtRX sh') arr) - -rflatten :: Elt a => Ranked n a -> Ranked 1 a -rflatten (Ranked arr) = mtoRanked (mflatten arr) - -riota :: (Enum a, PrimElt a) => Int -> Ranked 1 a -riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota - --- | Throws if the array is empty. -rminIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n -rminIndexPrim rarr@(Ranked arr) - | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) - = ixCvtXR (mminIndexPrim arr) - --- | Throws if the array is empty. -rmaxIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n -rmaxIndexPrim rarr@(Ranked arr) - | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) - = ixCvtXR (mmaxIndexPrim arr) - -rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a -rdot1Inner arr1 arr2 - | SNat <- rrank arr1 - , Refl <- lemReplicatePlusApp (SNat @n) (Proxy @1) (Proxy @(Nothing @Nat)) - = coerce (mdot1Inner (Proxy @(Nothing @Nat))) arr1 arr2 - --- | This has a temporary, suboptimal implementation in terms of 'mflatten'. --- Prefer 'rdot1Inner' if applicable. -rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a -rdot = coerce mdot - -rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) -rtoXArrayPrimP (Ranked arr) = first shCvtXR' (mtoXArrayPrimP arr) - -rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a) -rtoXArrayPrim (Ranked arr) = first shCvtXR' (mtoXArrayPrim arr) - -rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) -rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) - -rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a -rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) - -rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a -rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr) - -rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a) -rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr) - -mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a -mtoRanked arr - | Refl <- lemRankReplicate (shxRank (mshape arr)) - = Ranked (mcast (ssxFromShape (convSh (mshape arr))) arr) - where - convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing) - convSh ZSX = ZSX - convSh (smn :$% (sh :: IShX sh'T)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T) - = SUnknown (fromSMayNat' smn) :$% convSh sh - -rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a -rtoMixed (Ranked arr) = arr - --- | A more weakly-typed version of 'rtoMixed' that does a runtime shape --- compatibility check. -rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a -rcastToMixed sshx rarr@(Ranked arr) - | Refl <- lemRankReplicate (rrank rarr) - = mcast sshx arr diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs deleted file mode 100644 index 1415815..0000000 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ /dev/null @@ -1,495 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Internal.Shaped where - -import Prelude hiding (mappend, mconcat) - -import Control.DeepSeq (NFData(..)) -import Control.Monad.ST -import Data.Array.Internal.RankedG qualified as RG -import Data.Array.Internal.RankedS qualified as RS -import Data.Array.Internal.ShapedG qualified as SG -import Data.Array.Internal.ShapedS qualified as SS -import Data.Bifunctor (first) -import Data.Coerce (coerce) -import Data.Kind (Type) -import Data.List.NonEmpty (NonEmpty) -import Data.Proxy -import Data.Type.Equality -import Data.Vector.Storable qualified as VS -import Foreign.Storable (Storable) -import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) -import GHC.Generics (Generic) -import GHC.TypeLits - -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray) -import Data.Array.Mixed.XArray qualified as X -import Data.Array.Nested.Internal.Lemmas -import Data.Array.Nested.Internal.Mixed -import Data.Array.Nested.Mixed.Shape -import Data.Array.Nested.Shaped.Shape -import Data.Array.Strided.Arith - - --- | A shape-typed array: the full shape of the array (the sizes of its --- dimensions) is represented on the type level as a list of 'Nat's. Note that --- these are "GHC.TypeLits" naturals, because we do not need induction over --- them and we want very large arrays to be possible. --- --- Like for 'Ranked', the valid elements are described by the 'Elt' type class, --- and 'Shaped' itself is again an instance of 'Elt' as well. --- --- 'Shaped' is a newtype around a 'Mixed' of 'Just's. -type Shaped :: [Nat] -> Type -> Type -newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) -deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) -deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a) - -instance (Show a, Elt a) => Show (Shaped n a) where - showsPrec d arr@(Shaped marr) = - let sh = show (shsToList (sshape arr)) - in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr - -instance Elt a => NFData (Shaped sh a) where - rnf (Shaped arr) = rnf arr - --- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a)) - deriving (Generic) - -deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a)) - -newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) - -instance Elt a => Elt (Shaped sh a) where - mshape (M_Shaped arr) = mshape arr - mindex (M_Shaped arr) i = Shaped (mindex arr i) - - mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - mindexPartial (M_Shaped arr) i = - coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mindexPartial arr i - - mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) - - mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) - mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) - - mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] - mtoListOuter (M_Shaped arr) - = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) - mlift ssh2 f (M_Shaped arr) = - coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mlift ssh2 f arr - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) - mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) = - coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ - mlift2 ssh3 f arr1 arr2 - - mliftL :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) - -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a)) - mliftL ssh2 f l = - coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a))) - @(NonEmpty (Mixed sh2 (Shaped sh a))) $ - mliftL ssh2 f (coerce l) - - mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr) - - mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) - - mconcat l = M_Shaped (mconcat (coerce l)) - - mrnf (M_Shaped arr) = mrnf arr - - type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) - - mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - marrayStrides (M_Shaped arr) = marrayStrides arr - - mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () - mvecsWrite sh idx (Shaped arr) vecs = - mvecsWrite sh idx arr - (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx - (coerce @(Mixed sh2 (Shaped sh a)) - @(Mixed sh2 (Mixed (MapJust sh) a)) - arr) - (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) - @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) - vecs) - - mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) - mvecsFreeze sh vecs = - coerce @(Mixed sh' (Mixed (MapJust sh) a)) - @(Mixed sh' (Shaped sh a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh' (Shaped sh a)) - @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - -instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where - memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) - memptyArrayUnsafe i - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ - memptyArrayUnsafe i - - mvecsUnsafeNew idx (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsUnsafeNew idx arr - - mvecsNewEmpty _ - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) - - -liftShaped1 :: forall sh a b. - (Mixed (MapJust sh) a -> Mixed (MapJust sh) b) - -> Shaped sh a -> Shaped sh b -liftShaped1 = coerce - -liftShaped2 :: forall sh a b c. - (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c) - -> Shaped sh a -> Shaped sh b -> Shaped sh c -liftShaped2 = coerce - -instance (NumElt a, PrimElt a) => Num (Shaped sh a) where - (+) = liftShaped2 (+) - (-) = liftShaped2 (-) - (*) = liftShaped2 (*) - negate = liftShaped1 negate - abs = liftShaped1 abs - signum = liftShaped1 signum - fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal" - -instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where - fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" - recip = liftShaped1 recip - (/) = liftShaped2 (/) - -instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" - exp = liftShaped1 exp - log = liftShaped1 log - sqrt = liftShaped1 sqrt - (**) = liftShaped2 (**) - logBase = liftShaped2 logBase - sin = liftShaped1 sin - cos = liftShaped1 cos - tan = liftShaped1 tan - asin = liftShaped1 asin - acos = liftShaped1 acos - atan = liftShaped1 atan - sinh = liftShaped1 sinh - cosh = liftShaped1 cosh - tanh = liftShaped1 tanh - asinh = liftShaped1 asinh - acosh = liftShaped1 acosh - atanh = liftShaped1 atanh - log1p = liftShaped1 GHC.Float.log1p - expm1 = liftShaped1 GHC.Float.expm1 - log1pexp = liftShaped1 GHC.Float.log1pexp - log1mexp = liftShaped1 GHC.Float.log1mexp - -squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a -squotArray = liftShaped2 mquotArray -sremArray = liftShaped2 mremArray - -satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a -satan2Array = liftShaped2 matan2Array - - -semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a -semptyArray sh = Shaped (memptyArray (shCvtSX sh)) - -sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh -sshape (Shaped arr) = shCvtXS' (mshape arr) - -srank :: Elt a => Shaped sh a -> SNat (Rank sh) -srank = shsRank . sshape - --- | The total number of elements in the array. -ssize :: Elt a => Shaped sh a -> Int -ssize = shsSize . sshape - -sindex :: Elt a => Shaped sh a -> IIxS sh -> a -sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) - -shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh -shsTakeIx _ _ ZIS = ZSS -shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx - -sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a -sindexPartial sarr@(Shaped arr) idx = - Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) - (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) - (ixCvtSX idx)) - --- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. -sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a -sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh)) - --- | See the documentation of 'mlift'. -slift :: forall sh1 sh2 a. Elt a - => ShS sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) - -> Shaped sh1 a -> Shaped sh2 a -slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr) - --- | See the documentation of 'mlift'. -slift2 :: forall sh1 sh2 sh3 a. Elt a - => ShS sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) - -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a -slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2) - -ssumOuter1P :: forall sh n a. (Storable a, NumElt a) - => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) -ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) - -ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) - => Shaped (n : sh) a -> Shaped sh a -ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive - -ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a -ssumAllPrim (Shaped arr) = msumAllPrim arr - -stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a) - => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a -stranspose perm sarr@(Shaped arr) - | Refl <- lemRankMapJust (sshape sarr) - , Refl <- lemTakeLenMapJust perm (sshape sarr) - , Refl <- lemDropLenMapJust perm (sshape sarr) - , Refl <- lemPermuteMapJust perm (shsTakeLen perm (sshape sarr)) - , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh)) - = Shaped (mtranspose perm arr) - -sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a -sappend = coerce mappend - -sscalar :: Elt a => a -> Shaped '[] a -sscalar x = Shaped (mscalar x) - -sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) -sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v) - -sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a -sfromVector sh v = sfromPrimitive (sfromVectorP sh v) - -stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a -stoVectorP = coerce mtoVectorP - -stoVector :: PrimElt a => Shaped sh a -> VS.Vector a -stoVector = coerce mtoVector - -sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l)) - -sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a -sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1 - -sfromList1Prim :: PrimElt a => SNat n -> [a] -> Shaped '[n] a -sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim - -stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] -stoListOuter (Shaped arr) = coerce (mtoListOuter arr) - -stoList1 :: Elt a => Shaped '[n] a -> [a] -stoList1 = map sunScalar . stoListOuter - -sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a -sfromListPrim sn l - | Refl <- lemAppNil @'[Just n] - = let ssh = SUnknown () :!% ZKX - xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) - in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr - -sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a -sfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr) - -sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a -sfromListLinear sh l = Shaped (mfromListLinear (shCvtSX sh) l) - -stoListLinear :: Elt a => Shaped sh a -> [a] -stoListLinear (Shaped arr) = mtoListLinear arr - -sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a -sfromOrthotope sh (SS.A (SG.A arr)) = - Shaped (fromPrimitive (M_Primitive (shCvtSX sh) (X.XArray (RS.A (RG.A (shsToList sh) arr))))) - -stoOrthotope :: PrimElt a => Shaped sh a -> SS.Array sh a -stoOrthotope (stoPrimitive -> Shaped (M_Primitive _ (X.XArray (RS.A (RG.A _ arr))))) = SS.A (SG.A arr) - -sunScalar :: Elt a => Shaped '[] a -> a -sunScalar arr = sindex arr ZIS - -snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a) -snest sh arr - | Refl <- lemMapJustApp sh (Proxy @sh') - = coerce (mnest (ssxFromShape (shCvtSX sh)) (coerce arr)) - -sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a -sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr))) - | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh') - = Shaped arr - -szip :: Shaped sh a -> Shaped sh b -> Shaped sh (a, b) -szip = coerce mzip - -sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b) -sunzip = coerce munzip - -srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) - => ShS sh -> ShS sh2 - -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) - -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b) -srerankP sh sh2 f sarr@(Shaped arr) - | Refl <- lemMapJustApp sh (Proxy @sh1) - , Refl <- lemMapJustApp sh (Proxy @sh2) - = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh)))) - (shCvtSX sh2) - (\a -> let Shaped r = f (Shaped a) in r) - arr) - -srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) - => ShS sh -> ShS sh2 - -> (Shaped sh1 a -> Shaped sh2 b) - -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b -srerank sh sh2 f (stoPrimitive -> arr) = - sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr - -sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a -sreplicate sh (Shaped arr) - | Refl <- lemMapJustApp sh (Proxy @sh') - = Shaped (mreplicate (shCvtSX sh) arr) - -sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) -sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x) - -sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a -sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) - -sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a -sslice i n@SNat arr = - let _ :$$ sh = sshape arr - in slift (n :$$ sh) (\_ -> X.slice i n) arr - -srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a -srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr - -sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shaped sh' a -sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr) - -sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a -sflatten arr = - case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff - n@SNat -> sreshape (n :$$ ZSS) arr - -siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a -siota sn = Shaped (miota sn) - --- | Throws if the array is empty. -sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminIndexPrim arr) - --- | Throws if the array is empty. -smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) - -sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) - => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a -sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) - | Refl <- lemInitApp (Proxy @sh) (Proxy @n) - , Refl <- lemLastApp (Proxy @sh) (Proxy @n) - = case sshape sarr1 of - _ :$$ _ - | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n]) - -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2) - _ -> error "unreachable" - --- | This has a temporary, suboptimal implementation in terms of 'mflatten'. --- Prefer 'sdot1Inner' if applicable. -sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a -sdot = coerce mdot - -stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) -stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr) - -stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) -stoXArrayPrim (Shaped arr) = first shCvtXS' (mtoXArrayPrim arr) - -sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) -sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr) - -sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a -sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr) - -sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a -sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) - -stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a) -stoPrimitive (Shaped arr) = Shaped (toPrimitive arr) - -mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') - => Mixed sh a -> ShS sh' -> Shaped sh' a -mcastToShaped arr targetsh - | Refl <- lemRankMapJust targetsh - = Shaped (mcast (ssxFromShape (shCvtSX targetsh)) arr) - -stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a -stoMixed (Shaped arr) = arr - --- | A more weakly-typed version of 'stoMixed' that does a runtime shape --- compatibility check. -scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') - => StaticShX sh' -> Shaped sh a -> Mixed sh' a -scastToMixed sshx sarr@(Shaped arr) - | Refl <- lemRankMapJust (sshape sarr) - = mcast sshx arr diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs new file mode 100644 index 0000000..50a1b71 --- /dev/null +++ b/src/Data/Array/Nested/Mixed.hs @@ -0,0 +1,955 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +module Data.Array.Nested.Mixed where + +import Prelude hiding (mconcat) + +import Control.DeepSeq (NFData(..)) +import Control.Monad (forM_, when) +import Control.Monad.ST +import Data.Array.RankedS qualified as S +import Data.Bifunctor (bimap) +import Data.Coerce +import Data.Foldable (toList) +import Data.Int +import Data.Kind (Constraint, Type) +import Data.List.NonEmpty (NonEmpty(..)) +import Data.List.NonEmpty qualified as NE +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Data.Vector.Storable.Mutable qualified as VSM +import Foreign.C.Types (CInt) +import Foreign.Storable (Storable) +import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) +import GHC.Generics (Generic) +import GHC.TypeLits +import Unsafe.Coerce (unsafeCoerce) + +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Types +import Data.Array.Mixed.XArray (XArray(..)) +import Data.Array.Mixed.XArray qualified as X +import Data.Array.Nested.Mixed.Shape +import Data.Bag + + +-- TODO: +-- sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a +-- rminIndex1 :: Ranked (n + 1) a -> Ranked n Int +-- gather/scatter-like things (most generally, the higher-order variants: accelerate's backpermute/permute) +-- After benchmarking: matmul and matvec + + + +-- Invariant in the API +-- ==================== +-- +-- In the underlying XArray, there is some shape for elements of an empty +-- array. For example, for this array: +-- +-- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float) +-- rshape arr == 0 :.: 0 :.: 0 :.: ZIR +-- +-- the two underlying XArrays have a shape, and those shapes might be anything. +-- The invariant is that these element shapes are unobservable in the API. +-- (This is possible because you ought to not be able to get to such an element +-- without indexing out of bounds.) +-- +-- Note, though, that the converse situation may arise: the outer array might +-- be nonempty but then the inner arrays might. This is fine, an invariant only +-- applies if the _outer_ array is empty. +-- +-- TODO: can we enforce that the elements of an empty (nested) array have +-- all-zero shape? +-- -> no, because mlift and also any kind of internals probing from outsiders + + +-- Primitive element types +-- ======================= +-- +-- There are a few primitive element types; arrays containing elements of such +-- type are a newtype over an XArray, which it itself a newtype over a Vector. +-- Unfortunately, the setup of the library requires us to list these primitive +-- element types multiple times; to aid in extending the list, all these lists +-- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. + + +-- | Wrapper type used as a tag to attach instances on. The instances on arrays +-- of @'Primitive' a@ are more polymorphic than the direct instances for arrays +-- of scalars; this means that if @orthotope@ supports an element type @T@ that +-- this library does not (directly), it may just work if you use an array of +-- @'Primitive' T@ instead. +newtype Primitive a = Primitive a + deriving (Show) + +-- | Element types that are primitive; arrays of these types are just a newtype +-- wrapper over an array. +class (Storable a, Elt a) => PrimElt a where + fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a + toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) + + default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a + fromPrimitive = coerce + + default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a) + toPrimitive = coerce + +-- [PRIMITIVE ELEMENT TYPES LIST] +instance PrimElt Bool +instance PrimElt Int +instance PrimElt Int64 +instance PrimElt Int32 +instance PrimElt CInt +instance PrimElt Float +instance PrimElt Double +instance PrimElt () + + +-- | Mixed arrays: some dimensions are size-typed, some are not. Distributes +-- over product-typed elements using a data family so that the full array is +-- always in struct-of-arrays format. +-- +-- Built on top of 'XArray' which is built on top of @orthotope@, meaning that +-- dimension permutations (e.g. 'mtranspose') are typically free. +-- +-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type +-- class. +type Mixed :: [Maybe Nat] -> Type -> Type +data family Mixed sh a +-- NOTE: When opening up the Mixed abstraction, you might see dimension sizes +-- that you're not supposed to see. In particular, you might see (nonempty) +-- sizes of the elements of an empty array, which is information that should +-- ostensibly not exist; the full array is still empty. + +data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a) + deriving (Eq, Ord, Generic) + +-- [PRIMITIVE ELEMENT TYPES LIST] +newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq, Ord, Generic) +newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic) +newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic) +newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic) +newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic) +newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic) +newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic) +newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Eq, Ord, Generic) -- no content, orthotope optimises this (via Vector) +-- etc. + +data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic) +-- etc., larger tuples (perhaps use generics to allow arbitrary product types) + +deriving instance (Eq (Mixed sh a), Eq (Mixed sh b)) => Eq (Mixed sh (a, b)) +deriving instance (Ord (Mixed sh a), Ord (Mixed sh b)) => Ord (Mixed sh (a, b)) + +data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic) + +deriving instance Eq (Mixed (sh1 ++ sh2) a) => Eq (Mixed sh1 (Mixed sh2 a)) +deriving instance Ord (Mixed (sh1 ++ sh2) a) => Ord (Mixed sh1 (Mixed sh2 a)) + + +-- | Internal helper data family mirroring 'Mixed' that consists of mutable +-- vectors instead of 'XArray's. +type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type +data family MixedVecs s sh a + +newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a) + +-- [PRIMITIVE ELEMENT TYPES LIST] +newtype instance MixedVecs s sh Bool = MV_Bool (VS.MVector s Bool) +newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) +newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64) +newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32) +newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt) +newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) +newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float) +newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this +-- etc. + +data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b) +-- etc. + +data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) + + +showsMixedArray :: (Show a, Elt a) + => String -- ^ fromList prefix: e.g. @rfromListLinear [2,3]@ + -> String -- ^ replicate prefix: e.g. @rreplicate [2,3]@ + -> Int -> Mixed sh a -> ShowS +showsMixedArray fromlistPrefix replicatePrefix d arr = + showParen (d > 10) $ + -- TODO: to avoid ambiguity, we should type-apply the shape to mfromListLinear here + case mtoListLinear arr of + hd : _ : _ + | all (all (== 0) . take (shxLength (mshape arr))) (marrayStrides arr) -> + showString replicatePrefix . showString " " . showsPrec 11 hd + _ -> + showString fromlistPrefix . showString " " . shows (mtoListLinear arr) + +instance (Show a, Elt a) => Show (Mixed sh a) where + showsPrec d arr = + let sh = show (shxToList (mshape arr)) + in showsMixedArray ("mfromListLinear " ++ sh) ("mreplicate " ++ sh) d arr + +instance Elt a => NFData (Mixed sh a) where + rnf = mrnf + + +mliftNumElt1 :: (PrimElt a, PrimElt b) + => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b) + -> Mixed sh a -> Mixed sh b +mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr)) + +mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c) + => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c) + -> Mixed sh a -> Mixed sh b -> Mixed sh c +mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2)) + | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2)) + | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 + +instance (NumElt a, PrimElt a) => Num (Mixed sh a) where + (+) = mliftNumElt2 (liftO2 . numEltAdd) + (-) = mliftNumElt2 (liftO2 . numEltSub) + (*) = mliftNumElt2 (liftO2 . numEltMul) + negate = mliftNumElt1 (liftO1 . numEltNeg) + abs = mliftNumElt1 (liftO1 . numEltAbs) + signum = mliftNumElt1 (liftO1 . numEltSignum) + -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS + fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal" + +instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where + fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" + recip = mliftNumElt1 (liftO1 . floatEltRecip) + (/) = mliftNumElt2 (liftO2 . floatEltDiv) + +instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where + pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" + exp = mliftNumElt1 (liftO1 . floatEltExp) + log = mliftNumElt1 (liftO1 . floatEltLog) + sqrt = mliftNumElt1 (liftO1 . floatEltSqrt) + + (**) = mliftNumElt2 (liftO2 . floatEltPow) + logBase = mliftNumElt2 (liftO2 . floatEltLogbase) + + sin = mliftNumElt1 (liftO1 . floatEltSin) + cos = mliftNumElt1 (liftO1 . floatEltCos) + tan = mliftNumElt1 (liftO1 . floatEltTan) + asin = mliftNumElt1 (liftO1 . floatEltAsin) + acos = mliftNumElt1 (liftO1 . floatEltAcos) + atan = mliftNumElt1 (liftO1 . floatEltAtan) + sinh = mliftNumElt1 (liftO1 . floatEltSinh) + cosh = mliftNumElt1 (liftO1 . floatEltCosh) + tanh = mliftNumElt1 (liftO1 . floatEltTanh) + asinh = mliftNumElt1 (liftO1 . floatEltAsinh) + acosh = mliftNumElt1 (liftO1 . floatEltAcosh) + atanh = mliftNumElt1 (liftO1 . floatEltAtanh) + log1p = mliftNumElt1 (liftO1 . floatEltLog1p) + expm1 = mliftNumElt1 (liftO1 . floatEltExpm1) + log1pexp = mliftNumElt1 (liftO1 . floatEltLog1pexp) + log1mexp = mliftNumElt1 (liftO1 . floatEltLog1mexp) + +mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a +mquotArray = mliftNumElt2 (liftO2 . intEltQuot) +mremArray = mliftNumElt2 (liftO2 . intEltRem) + +matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a +matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2) + +-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or +-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' +-- a@; see the documentation for 'Primitive' for more details. +class Elt a where + -- ====== PUBLIC METHODS ====== -- + + mshape :: Mixed sh a -> IShX sh + mindex :: Mixed sh a -> IIxX sh -> a + mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a + mscalar :: a -> Mixed '[] a + + -- | All arrays in the list, even subarrays inside @a@, must have the same + -- shape; if they do not, a runtime error will be thrown. See the + -- documentation of 'mgenerate' for more information about this restriction. + -- Furthermore, the length of the list must correspond with @n@: if @n@ is + -- @Just m@ and @m@ does not equal the length of the list, a runtime error is + -- thrown. + -- + -- Consider also 'mfromListPrim', which can avoid intermediate arrays. + mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a + + mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] + + -- | Note: this library makes no particular guarantees about the shapes of + -- arrays "inside" an empty array. With 'mlift', 'mlift2' and 'mliftL' you can see the + -- full 'XArray' and as such you can distinguish different empty arrays by + -- the "shapes" of their elements. This information is meaningless, so you + -- should not use it. + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 a -> Mixed sh2 a + + -- | See the documentation for 'mlift'. + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a + + -- TODO: mliftL is currently unused. + -- | All arrays in the input must have equal shapes, including subarrays + -- inside their elements. + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 a) -> NonEmpty (Mixed sh2 a) + + mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a + + mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) + => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a + + -- | All arrays in the input must have equal shapes, including subarrays + -- inside their elements. + mconcat :: NonEmpty (Mixed (Nothing : sh) a) -> Mixed (Nothing : sh) a + + mrnf :: Mixed sh a -> () + + -- ====== PRIVATE METHODS ====== -- + + -- | Tree giving the shape of every array component. + type ShapeTree a + + mshapeTree :: a -> ShapeTree a + + mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool + + mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool + + mshowShapeTree :: Proxy a -> ShapeTree a -> String + + -- | Returns the stride vector of each underlying component array making up + -- this mixed array. + marrayStrides :: Mixed sh a -> Bag [Int] + + -- | Given the shape of this array, an index and a value, write the value at + -- that index in the vectors. + mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () + + -- | Given the shape of this array, an index and a value, write the value at + -- that index in the vectors. + mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + + -- | Given the shape of this array, finalise the vectors into 'XArray's. + mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) + + +-- | Element types for which we have evidence of the (static part of the) shape +-- in a type class constraint. Compare the instance contexts of the instances +-- of this class with those of 'Elt': some instances have an additional +-- "known-shape" constraint. +-- +-- This class is (currently) only required for 'mgenerate', +-- 'Data.Array.Nested.Ranked.rgenerate' and +-- 'Data.Array.Nested.Shaped.sgenerate'. +class Elt a => KnownElt a where + -- | Create an empty array. The given shape must have size zero; this may or may not be checked. + memptyArrayUnsafe :: IShX sh -> Mixed sh a + + -- | Create uninitialised vectors for this array type, given the shape of + -- this vector and an example for the contents. + mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) + + mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) + + +-- Arrays of scalars are basically just arrays of scalars. +instance Storable a => Elt (Primitive a) where + mshape (M_Primitive sh _) = sh + mindex (M_Primitive _ a) i = Primitive (X.index a i) + mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) + mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) + mfromListOuter l@(arr1 :| _) = + let sh = SUnknown (length l) :$% mshape arr1 + in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l))) + mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) + -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) + mlift ssh2 f (M_Primitive _ a) + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + , let result = f ZKX a + = M_Primitive (X.shape ssh2 result) result + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) + -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) + mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b) + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + , Refl <- lemAppNil @sh3 + , let result = f ZKX a b + = M_Primitive (X.shape ssh3 result) result + + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 (Primitive a)) -> NonEmpty (Mixed sh2 (Primitive a)) + mliftL ssh2 f l + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + = fmap (\arr -> M_Primitive (X.shape ssh2 arr) arr) $ + f ZKX (fmap (\(M_Primitive _ arr) -> arr) l) + + mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) + mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) = + let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' + sh2 = shxCast' sh1 ssh2 + in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) + + mtranspose perm (M_Primitive sh arr) = + M_Primitive (shxPermutePrefix perm sh) + (X.transpose (ssxFromShape sh) perm arr) + + mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a) + mconcat l@(M_Primitive (_ :$% sh) _ :| _) = + let result = X.concat (ssxFromShape sh) (fmap (\(M_Primitive _ arr) -> arr) l) + in M_Primitive (X.shape (SUnknown () :!% ssxFromShape sh) result) result + + mrnf (M_Primitive sh a) = rnf sh `seq` rnf a + + type ShapeTree (Primitive a) = () + mshapeTree _ = () + mshapeTreeEq _ () () = True + mshapeTreeEmpty _ () = False + mshowShapeTree _ () = "()" + marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) + mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x + + -- TODO: this use of toVector is suboptimal + mvecsWritePartial + :: forall sh' sh s. + IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do + let arrsh = X.shape (ssxFromShape sh') arr + offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) + VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) + + mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v + +-- [PRIMITIVE ELEMENT TYPES LIST] +deriving via Primitive Bool instance Elt Bool +deriving via Primitive Int instance Elt Int +deriving via Primitive Int64 instance Elt Int64 +deriving via Primitive Int32 instance Elt Int32 +deriving via Primitive CInt instance Elt CInt +deriving via Primitive Double instance Elt Double +deriving via Primitive Float instance Elt Float +deriving via Primitive () instance Elt () + +instance Storable a => KnownElt (Primitive a) where + memptyArrayUnsafe sh = M_Primitive sh (X.empty sh) + mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) + mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 + +-- [PRIMITIVE ELEMENT TYPES LIST] +deriving via Primitive Bool instance KnownElt Bool +deriving via Primitive Int instance KnownElt Int +deriving via Primitive Int64 instance KnownElt Int64 +deriving via Primitive Int32 instance KnownElt Int32 +deriving via Primitive CInt instance KnownElt CInt +deriving via Primitive Double instance KnownElt Double +deriving via Primitive Float instance KnownElt Float +deriving via Primitive () instance KnownElt () + +-- Arrays of pairs are pairs of arrays. +instance (Elt a, Elt b) => Elt (a, b) where + mshape (M_Tup2 a _) = mshape a + mindex (M_Tup2 a b) i = (mindex a i, mindex b i) + mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) + mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) + mfromListOuter l = + M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) + (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) + mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) + mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) + mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) + mliftL ssh2 f = + let unzipT2l [] = ([], []) + unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) + unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2) + in uncurry (NE.zipWith M_Tup2) . bimap (mliftL ssh2 f) (mliftL ssh2 f) . unzipT2 + + mcastPartial ssh1 sh2 psh' (M_Tup2 a b) = + M_Tup2 (mcastPartial ssh1 sh2 psh' a) (mcastPartial ssh1 sh2 psh' b) + + mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) + mconcat = + let unzipT2l [] = ([], []) + unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) + unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2) + in uncurry M_Tup2 . bimap mconcat mconcat . unzipT2 + + mrnf (M_Tup2 a b) = mrnf a `seq` mrnf b + + type ShapeTree (a, b) = (ShapeTree a, ShapeTree b) + mshapeTree (x, y) = (mshapeTree x, mshapeTree y) + mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' + mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 + mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" + marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b + mvecsWrite sh i (x, y) (MV_Tup2 a b) = do + mvecsWrite sh i x a + mvecsWrite sh i y b + mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do + mvecsWritePartial sh i x a + mvecsWritePartial sh i y b + mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b + +instance (KnownElt a, KnownElt b) => KnownElt (a, b) where + memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh) + mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y + mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) + +-- Arrays of arrays are just arrays, but with more dimensions. +instance Elt a => Elt (Mixed sh' a) where + -- TODO: this is quadratic in the nesting depth because it repeatedly + -- truncates the shape vector to one a little shorter. Fix with a + -- moverlongShape method, a prefix of which is mshape. + mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh + mshape (M_Nest sh arr) + = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr)) + + mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a + mindex (M_Nest _ arr) i = mindexPartial arr i + + mindexPartial :: forall sh1 sh2. + Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) + mindexPartial (M_Nest sh arr) i + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + + mscalar = M_Nest ZSX + + mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) + mfromListOuter l@(arr :| _) = + M_Nest (SUnknown (length l) :$% mshape arr) + (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) + + mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) + -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) + mlift ssh2 f (M_Nest sh1 arr) = + let result = mlift (ssxAppend ssh2 ssh') f' arr + (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) + in M_Nest sh2 result + where + ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr))) + + f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b + f' sshT + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) + -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) + mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = + let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 + (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) + in M_Nest sh3 result + where + ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) + + f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b + f' sshT + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) + + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b)) + -> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a)) + mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) = + let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l) + (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result)) + in fmap (M_Nest sh2) result + where + ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) + + f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b) + f' sshT + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) + + mcastPartial :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> StaticShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) + mcastPartial ssh1 ssh2 _ (M_Nest sh1T arr) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') + = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T + sh2 = shxCast' sh1 ssh2 + in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr) + + mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) + => Perm is -> Mixed sh (Mixed sh' a) + -> Mixed (PermutePrefix is sh) (Mixed sh' a) + mtranspose perm (M_Nest sh arr) + | let sh' = shxDropSh @sh @sh' (mshape arr) sh + , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh') + , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) + , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') + , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') + , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') + = M_Nest (shxPermutePrefix perm sh) + (mtranspose perm arr) + + mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) + mconcat l@(M_Nest sh1 _ :| _) = + let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l) + in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape result))) result + + mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr + + type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a) + + mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) + mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr))))) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + marrayStrides (M_Nest _ arr) = marrayStrides arr + + mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs + + mvecsWritePartial :: forall sh1 sh2 s. + IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) + -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) + -> ST s () + mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs + + mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs + +instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where + memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh')))) + + mvecsUnsafeNew sh example + | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) + | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh'))) + where + sh' = mshape example + + mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) + + +memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a +memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh) + +mrank :: Elt a => Mixed sh a -> SNat (Rank sh) +mrank = shxRank . mshape + +-- | The total number of elements in the array. +msize :: Elt a => Mixed sh a -> Int +msize = shxSize . mshape + +-- | Create an array given a size and a function that computes the element at a +-- given index. +-- +-- __WARNING__: It is required that every @a@ returned by the argument to +-- 'mgenerate' has the same shape. For example, the following will throw a +-- runtime error: +-- +-- > foo :: Mixed [Nothing] (Mixed [Nothing] Double) +-- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) -> +-- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) -> +-- > ... +-- +-- because the size of the inner 'mgenerate' is not always the same (it depends +-- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so +-- the entire hierarchy (after distributing out tuples) must be a rectangular +-- array. The type of 'mgenerate' allows this requirement to be broken very +-- easily, hence the runtime check. +mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a +mgenerate sh f = case shxEnum sh of + [] -> memptyArrayUnsafe sh + firstidx : restidxs -> + let firstelem = f (ixxZero' sh) + shapetree = mshapeTree firstelem + in if mshapeTreeEmpty (Proxy @a) shapetree + then memptyArrayUnsafe sh + else runST $ do + vecs <- mvecsUnsafeNew sh firstelem + mvecsWrite sh firstidx firstelem vecs + -- TODO: This is likely fine if @a@ is big, but if @a@ is a + -- scalar this array copying inefficient. Should improve this. + forM_ restidxs $ \idx -> do + let val = f idx + when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ + error "Data.Array.Nested mgenerate: generated values do not have equal shapes" + mvecsWrite sh idx val vecs + mvecsFreeze sh vecs + +msumOuter1P :: forall sh n a. (Storable a, NumElt a) + => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) +msumOuter1P (M_Primitive (n :$% sh) arr) = + let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX + in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr) + +msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) + => Mixed (n : sh) a -> Mixed sh a +msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive + +msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a +msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShape sh) arr + +mappend :: forall n m sh a. Elt a + => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a +mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 + where + sn :$% sh = mshape arr1 + sm :$% _ = mshape arr2 + ssh = ssxFromShape sh + snm :: SMayNat () SNat (AddMaybe n m) + snm = case (sn, sm) of + (SUnknown{}, _) -> SUnknown () + (SKnown{}, SUnknown{}) -> SUnknown () + (SKnown n, SKnown m) -> SKnown (snatPlus n m) + + f :: forall sh' b. Storable b + => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b + f ssh' = X.append (ssxAppend ssh ssh') + +mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) +mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) + +mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a +mfromVector sh v = fromPrimitive (mfromVectorP sh v) + +mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a +mtoVectorP (M_Primitive _ v) = X.toVector v + +mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a +mtoVector arr = mtoVectorP (toPrimitive arr) + +mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a +mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise? + +mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a +mfromList1Prim l = + let ssh = SUnknown () :!% ZKX + xarr = X.fromList1 ssh l + in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +mtoList1 :: Elt a => Mixed '[n] a -> [a] +mtoList1 = map munScalar . mtoListOuter + +mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a +mfromListPrim l = + let ssh = SUnknown () :!% ZKX + xarr = X.fromList1 ssh l + in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a +mfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) + +-- This forall is there so that a simple type application can constrain the +-- shape, in case the user wants to use OverloadedLists for the shape. +mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a +mfromListLinear sh l = mreshape sh (mfromList1 l) + +mtoListLinear :: Elt a => Mixed sh a -> [a] +mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr)) -- TODO: optimise + +munScalar :: Elt a => Mixed '[] a -> a +munScalar arr = mindex arr ZIX + +mnest :: forall sh sh' a. Elt a => StaticShX sh -> Mixed (sh ++ sh') a -> Mixed sh (Mixed sh' a) +mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr + +munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a +munNest (M_Nest _ arr) = arr + +mzip :: Mixed sh a -> Mixed sh b -> Mixed sh (a, b) +mzip = M_Tup2 + +munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b) +munzip (M_Tup2 a b) = (a, b) + +mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) + => StaticShX sh -> IShX sh2 + -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) + -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) +mrerankP ssh sh2 f (M_Primitive sh arr) = + let sh1 = shxDropSSX sh ssh + in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) + (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2) + (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) + arr) + +-- | See the caveats at @X.rerank@. +mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) + => StaticShX sh -> IShX sh2 + -> (Mixed sh1 a -> Mixed sh2 b) + -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b +mrerank ssh sh2 f (toPrimitive -> arr) = + fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr + +mreplicate :: forall sh sh' a. Elt a + => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a +mreplicate sh arr = + let ssh' = ssxFromShape (mshape arr) + in mlift (ssxAppend (ssxFromShape sh) ssh') + (\(sshT :: StaticShX shT) -> + case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of + Refl -> X.replicate sh (ssxAppend ssh' sshT)) + arr + +mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) +mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) + +mreplicateScal :: forall sh a. PrimElt a + => IShX sh -> a -> Mixed sh a +mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) + +mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a +mslice i n arr = + let _ :$% sh = mshape arr + in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr + +msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a +msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr + +mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a +mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr + +mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a +mreshape sh' arr = + mlift (ssxFromShape sh') + (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh') + arr + +mflatten :: Elt a => Mixed sh a -> Mixed '[Flatten sh] a +mflatten arr = mreshape (shxFlatten (mshape arr) :$% ZSX) arr + +miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a +miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) + +-- | Throws if the array is empty. +mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh +mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = + ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) (fromO arr)) + +-- | Throws if the array is empty. +mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh +mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = + ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) (fromO arr)) + +mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) + => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a +mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b)) + | Refl <- lemInitApp (Proxy @sh) (Proxy @n) + , Refl <- lemLastApp (Proxy @sh) (Proxy @n) + = case sh1 of + _ :$% _ + | sh1 == sh2 + , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) -> + fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b)) + | otherwise -> error $ "mdot1Inner: Unequal shapes (" ++ show sh1 ++ " and " ++ show sh2 ++ ")" + ZSX -> error "unreachable" + +-- | This has a temporary, suboptimal implementation in terms of 'mflatten'. +-- Prefer 'mdot1Inner' if applicable. +mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a +mdot a b = + munScalar $ + mdot1Inner Proxy (fromPrimitive (mflatten (toPrimitive a))) + (fromPrimitive (mflatten (toPrimitive b))) + +mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) +mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr) + +mtoXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a) +mtoXArrayPrim = mtoXArrayPrimP . toPrimitive + +mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a) +mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr + +mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a +mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP + +mliftPrim :: (PrimElt a, PrimElt b) + => (a -> b) + -> Mixed sh a -> Mixed sh b +mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) + +mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c) + => (a -> b -> c) + -> Mixed sh a -> Mixed sh b -> Mixed sh c +mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = + fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) + +mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) + => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a +mcast ssh2 arr + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr + +-- TODO: This should be `type data` but a bug in GHC 9.10 means that that throws linker errors +data SafeMCastSpec + = MCastId + | MCastApp [Maybe Nat] [Maybe Nat] [Maybe Nat] [Maybe Nat] SafeMCastSpec SafeMCastSpec + | MCastForget + +type SafeMCast :: SafeMCastSpec -> [Maybe Nat] -> [Maybe Nat] -> Constraint +type family SafeMCast spec sh1 sh2 where + SafeMCast MCastId sh sh = () + SafeMCast (MCastApp sh1A sh1B sh2A sh2B specA specB) sh1 sh2 = (sh1 ~ sh1A ++ sh1B, sh2 ~ sh2A ++ sh2B, SafeMCast specA sh1A sh2A, SafeMCast specB sh1B sh2B) + SafeMCast MCastForget sh1 sh2 = sh2 ~ Replicate (Rank sh1) Nothing + +-- | This is an O(1) operation: the 'SafeMCast' constraint ensures that +-- type-level shape information can only be forgotten, not introduced, and thus +-- that no runtime shape checks are required. The @spec@ describes to +-- 'SafeMCast' how exactly you intend @sh2@ to be a weakening of @sh1@. +-- +-- To see how to construct the spec, read the equations of 'SafeMCast' closely. +mcastSafe :: forall spec sh1 sh2 a proxy. SafeMCast spec sh1 sh2 => proxy spec -> Mixed sh1 a -> Mixed sh2 a +mcastSafe _ = unsafeCoerce @(Mixed sh1 a) @(Mixed sh2 a) diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs new file mode 100644 index 0000000..fb5caa9 --- /dev/null +++ b/src/Data/Array/Nested/Ranked.hs @@ -0,0 +1,559 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Ranked where + +import Prelude hiding (mappend, mconcat) + +import Control.DeepSeq (NFData(..)) +import Control.Monad.ST +import Data.Array.RankedS qualified as S +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.Foldable (toList) +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Foreign.Storable (Storable) +import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) +import GHC.Generics (Generic) +import GHC.TypeLits +import GHC.TypeNats qualified as TN + +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Types +import Data.Array.Mixed.XArray (XArray(..)) +import Data.Array.Mixed.XArray qualified as X +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked.Shape +import Data.Array.Strided.Arith + + +-- | A rank-typed array: the number of dimensions of the array (its /rank/) is +-- represented on the type level as a 'Nat'. +-- +-- Valid elements of a ranked arrays are described by the 'Elt' type class. +-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are +-- supported (and are represented as a single, flattened, struct-of-arrays +-- array internally). +-- +-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. +type Ranked :: Nat -> Type -> Type +newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) +deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a) +deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a) + +instance (Show a, Elt a) => Show (Ranked n a) where + showsPrec d arr@(Ranked marr) = + let sh = show (toList (rshape arr)) + in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr + +instance Elt a => NFData (Ranked n a) where + rnf (Ranked arr) = rnf arr + +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) + deriving (Generic) + +deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a)) + +newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) + +-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; +-- these instances allow them to also be used as elements of arrays, thus +-- making them first-class in the API. +instance Elt a => Elt (Ranked n a) where + mshape (M_Ranked arr) = mshape arr + mindex (M_Ranked arr) i = Ranked (mindex arr i) + + mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) + mindexPartial (M_Ranked arr) i = + coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ + mindexPartial arr i + + mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) + + mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) + mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) + + mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] + mtoListOuter (M_Ranked arr) = + coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) + mlift ssh2 f (M_Ranked arr) = + coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ + mlift ssh2 f arr + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) + mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) = + coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ + mlift2 ssh3 f arr1 arr2 + + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a)) + mliftL ssh2 f l = + coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a))) + @(NonEmpty (Mixed sh2 (Ranked n a))) $ + mliftL ssh2 f (coerce l) + + mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr) + + mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) + + mconcat l = M_Ranked (mconcat (coerce l)) + + mrnf (M_Ranked arr) = mrnf arr + + type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) + + mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + marrayStrides (M_Ranked arr) = marrayStrides arr + + mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () + mvecsWrite sh idx (Ranked arr) vecs = + mvecsWrite sh idx arr + (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) + + mvecsWritePartial :: forall sh sh' s. + IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) + -> MixedVecs s (sh ++ sh') (Ranked n a) + -> ST s () + mvecsWritePartial sh idx arr vecs = + mvecsWritePartial sh idx + (coerce @(Mixed sh' (Ranked n a)) + @(Mixed sh' (Mixed (Replicate n Nothing) a)) + arr) + (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) + @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) + vecs) + + mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) + mvecsFreeze sh vecs = + coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) + @(Mixed sh (Ranked n a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh (Ranked n a)) + @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) + +instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where + memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a) + memptyArrayUnsafe i + | Dict <- lemKnownReplicate (SNat @n) + = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ + memptyArrayUnsafe i + + mvecsUnsafeNew idx (Ranked arr) + | Dict <- lemKnownReplicate (SNat @n) + = MV_Ranked <$> mvecsUnsafeNew idx arr + + mvecsNewEmpty _ + | Dict <- lemKnownReplicate (SNat @n) + = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) + + +liftRanked1 :: forall n a b. + (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b) + -> Ranked n a -> Ranked n b +liftRanked1 = coerce + +liftRanked2 :: forall n a b c. + (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c) + -> Ranked n a -> Ranked n b -> Ranked n c +liftRanked2 = coerce + +instance (NumElt a, PrimElt a) => Num (Ranked n a) where + (+) = liftRanked2 (+) + (-) = liftRanked2 (-) + (*) = liftRanked2 (*) + negate = liftRanked1 negate + abs = liftRanked1 abs + signum = liftRanked1 signum + fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal" + +instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where + fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal" + recip = liftRanked1 recip + (/) = liftRanked2 (/) + +instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where + pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal" + exp = liftRanked1 exp + log = liftRanked1 log + sqrt = liftRanked1 sqrt + (**) = liftRanked2 (**) + logBase = liftRanked2 logBase + sin = liftRanked1 sin + cos = liftRanked1 cos + tan = liftRanked1 tan + asin = liftRanked1 asin + acos = liftRanked1 acos + atan = liftRanked1 atan + sinh = liftRanked1 sinh + cosh = liftRanked1 cosh + tanh = liftRanked1 tanh + asinh = liftRanked1 asinh + acosh = liftRanked1 acosh + atanh = liftRanked1 atanh + log1p = liftRanked1 GHC.Float.log1p + expm1 = liftRanked1 GHC.Float.expm1 + log1pexp = liftRanked1 GHC.Float.log1pexp + log1mexp = liftRanked1 GHC.Float.log1mexp + +rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a +rquotArray = liftRanked2 mquotArray +rremArray = liftRanked2 mremArray + +ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a +ratan2Array = liftRanked2 matan2Array + + +remptyArray :: KnownElt a => Ranked 1 a +remptyArray = mtoRanked (memptyArray ZSX) + +rshape :: Elt a => Ranked n a -> IShR n +rshape (Ranked arr) = shCvtXR' (mshape arr) + +rrank :: Elt a => Ranked n a -> SNat n +rrank = shrRank . rshape + +-- | The total number of elements in the array. +rsize :: Elt a => Ranked n a -> Int +rsize = shrSize . rshape + +rindex :: Elt a => Ranked n a -> IIxR n -> a +rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) + +rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a +rindexPartial (Ranked arr) idx = + Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) + (castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @Nothing))) arr) + (ixCvtRX idx)) + +-- | __WARNING__: All values returned from the function must have equal shape. +-- See the documentation of 'mgenerate' for more details. +rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a +rgenerate sh f + | sn@SNat <- shrRank sh + , Dict <- lemKnownReplicate sn + , Refl <- lemRankReplicate sn + = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) + +-- | See the documentation of 'mlift'. +rlift :: forall n1 n2 a. Elt a + => SNat n2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) + -> Ranked n1 a -> Ranked n2 a +rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr) + +-- | See the documentation of 'mlift2'. +rlift2 :: forall n1 n2 n3 a. Elt a + => SNat n3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b) + -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a +rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) + +rsumOuter1P :: forall n a. + (Storable a, NumElt a) + => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) +rsumOuter1P (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = Ranked (msumOuter1P arr) + +rsumOuter1 :: forall n a. (NumElt a, PrimElt a) + => Ranked (n + 1) a -> Ranked n a +rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive + +rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a +rsumAllPrim (Ranked arr) = msumAllPrim arr + +rtranspose :: forall n a. Elt a => PermR -> Ranked n a -> Ranked n a +rtranspose perm arr + | sn@SNat <- rrank arr + , Dict <- lemKnownReplicate sn + , length perm <= fromIntegral (natVal (Proxy @n)) + = rlift sn + (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm) + arr + | otherwise + = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array" + +rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a +rconcat + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = coerce mconcat + +rappend :: forall n a. Elt a + => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a +rappend arr1 arr2 + | sn@SNat <- rrank arr1 + , Dict <- lemKnownReplicate sn + , Refl <- lemReplicateSucc @(Nothing @Nat) @n + = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) + arr1 arr2 + +rscalar :: Elt a => a -> Ranked 0 a +rscalar x = Ranked (mscalar x) + +rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a) +rfromVectorP sh v + | Dict <- lemKnownReplicate (shrRank sh) + = Ranked (mfromVectorP (shCvtRX sh) v) + +rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a +rfromVector sh v = rfromPrimitive (rfromVectorP sh v) + +rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a +rtoVectorP = coerce mtoVectorP + +rtoVector :: PrimElt a => Ranked n a -> VS.Vector a +rtoVector = coerce mtoVector + +rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a +rfromListOuter l + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) + +rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a +rfromList1 l = Ranked (mfromList1 l) + +rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a +rfromList1Prim l = Ranked (mfromList1Prim l) + +rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] +rtoListOuter (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) + +rtoList1 :: Elt a => Ranked 1 a -> [a] +rtoList1 = map runScalar . rtoListOuter + +rfromListPrim :: PrimElt a => [a] -> Ranked 1 a +rfromListPrim l = + let ssh = SUnknown () :!% ZKX + xarr = X.fromList1 ssh l + in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a +rfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr) + +rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a +rfromListLinear sh l = rreshape sh (rfromList1 l) + +rtoListLinear :: Elt a => Ranked n a -> [a] +rtoListLinear (Ranked arr) = mtoListLinear arr + +rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a +rfromOrthotope sn arr + | Refl <- lemRankReplicate sn + = let xarr = XArray arr + in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr)) + +rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a +rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr))) + | Refl <- lemRankReplicate (shrRank $ shCvtXR' sh) + = arr + +runScalar :: Elt a => Ranked 0 a -> a +runScalar arr = rindex arr ZIR + +rnest :: forall n m a. Elt a => SNat n -> Ranked (n + m) a -> Ranked n (Ranked m a) +rnest n arr + | Refl <- lemReplicatePlusApp n (Proxy @m) (Proxy @(Nothing @Nat)) + = coerce (mnest (ssxFromSNat n) (coerce arr)) + +runNest :: forall n m a. Elt a => Ranked n (Ranked m a) -> Ranked (n + m) a +runNest rarr@(Ranked (M_Ranked (M_Nest _ arr))) + | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat)) + = Ranked arr + +rzip :: Ranked n a -> Ranked n b -> Ranked n (a, b) +rzip = coerce mzip + +runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b) +runzip = coerce munzip + +rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) + => SNat n -> IShR n2 + -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) + -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b) +rrerankP sn sh2 f (Ranked arr) + | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) + , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) + = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2) + (\a -> let Ranked r = f (Ranked a) in r) + arr) + +-- | If there is a zero-sized dimension in the @n@-prefix of the shape of the +-- input array, then there is no way to deduce the full shape of the output +-- array (more precisely, the @n2@ part): that could only come from calling +-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in +-- this case; we choose to fill the @n2@ part of the output shape with zeros. +-- +-- For example, if: +-- +-- @ +-- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21] +-- f :: Ranked 2 Int -> Ranked 3 Float +-- @ +-- +-- then: +-- +-- @ +-- rrerank _ _ _ f arr :: Ranked 5 Float +-- @ +-- +-- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the +-- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended +-- to return an array with shape all-0 here (it probably didn't), but there is +-- no better number to put here absent a subarray of the input to pass to @f@. +rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) + => SNat n -> IShR n2 + -> (Ranked n1 a -> Ranked n2 b) + -> Ranked (n + n1) a -> Ranked (n + n2) b +rrerank sn sh2 f (rtoPrimitive -> arr) = + rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr + +rreplicate :: forall n m a. Elt a + => IShR n -> Ranked m a -> Ranked (n + m) a +rreplicate sh (Ranked arr) + | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat)) + = Ranked (mreplicate (shCvtRX sh) arr) + +rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) +rreplicateScalP sh x + | Dict <- lemKnownReplicate (shrRank sh) + = Ranked (mreplicateScalP (shCvtRX sh) x) + +rreplicateScal :: forall n a. PrimElt a + => IShR n -> a -> Ranked n a +rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) + +rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a +rslice i n arr + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = rlift (rrank arr) + (\_ -> X.sliceU i n) + arr + +rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a +rrev1 arr = + rlift (rrank arr) + (\(_ :: StaticShX sh') -> + case lemReplicateSucc @(Nothing @Nat) @n of + Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) + arr + +rreshape :: forall n n' a. Elt a + => IShR n' -> Ranked n a -> Ranked n' a +rreshape sh' rarr@(Ranked arr) + | Dict <- lemKnownReplicate (rrank rarr) + , Dict <- lemKnownReplicate (shrRank sh') + = Ranked (mreshape (shCvtRX sh') arr) + +rflatten :: Elt a => Ranked n a -> Ranked 1 a +rflatten (Ranked arr) = mtoRanked (mflatten arr) + +riota :: (Enum a, PrimElt a) => Int -> Ranked 1 a +riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota + +-- | Throws if the array is empty. +rminIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n +rminIndexPrim rarr@(Ranked arr) + | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) + = ixCvtXR (mminIndexPrim arr) + +-- | Throws if the array is empty. +rmaxIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n +rmaxIndexPrim rarr@(Ranked arr) + | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) + = ixCvtXR (mmaxIndexPrim arr) + +rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a +rdot1Inner arr1 arr2 + | SNat <- rrank arr1 + , Refl <- lemReplicatePlusApp (SNat @n) (Proxy @1) (Proxy @(Nothing @Nat)) + = coerce (mdot1Inner (Proxy @(Nothing @Nat))) arr1 arr2 + +-- | This has a temporary, suboptimal implementation in terms of 'mflatten'. +-- Prefer 'rdot1Inner' if applicable. +rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a +rdot = coerce mdot + +rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) +rtoXArrayPrimP (Ranked arr) = first shCvtXR' (mtoXArrayPrimP arr) + +rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a) +rtoXArrayPrim (Ranked arr) = first shCvtXR' (mtoXArrayPrim arr) + +rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) +rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) + +rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a +rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) + +rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a +rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr) + +rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a) +rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr) + +mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a +mtoRanked arr + | Refl <- lemRankReplicate (shxRank (mshape arr)) + = Ranked (mcast (ssxFromShape (convSh (mshape arr))) arr) + where + convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing) + convSh ZSX = ZSX + convSh (smn :$% (sh :: IShX sh'T)) + | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T) + = SUnknown (fromSMayNat' smn) :$% convSh sh + +rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a +rtoMixed (Ranked arr) = arr + +-- | A more weakly-typed version of 'rtoMixed' that does a runtime shape +-- compatibility check. +rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a +rcastToMixed sshx rarr@(Ranked arr) + | Refl <- lemRankReplicate (rrank rarr) + = mcast sshx arr diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs new file mode 100644 index 0000000..ba767cd --- /dev/null +++ b/src/Data/Array/Nested/Shaped.hs @@ -0,0 +1,495 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Shaped where + +import Prelude hiding (mappend, mconcat) + +import Control.DeepSeq (NFData(..)) +import Control.Monad.ST +import Data.Array.Internal.RankedG qualified as RG +import Data.Array.Internal.RankedS qualified as RS +import Data.Array.Internal.ShapedG qualified as SG +import Data.Array.Internal.ShapedS qualified as SS +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Foreign.Storable (Storable) +import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Types +import Data.Array.Mixed.XArray (XArray) +import Data.Array.Mixed.XArray qualified as X +import Data.Array.Nested.Internal.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Shaped.Shape +import Data.Array.Strided.Arith + + +-- | A shape-typed array: the full shape of the array (the sizes of its +-- dimensions) is represented on the type level as a list of 'Nat's. Note that +-- these are "GHC.TypeLits" naturals, because we do not need induction over +-- them and we want very large arrays to be possible. +-- +-- Like for 'Ranked', the valid elements are described by the 'Elt' type class, +-- and 'Shaped' itself is again an instance of 'Elt' as well. +-- +-- 'Shaped' is a newtype around a 'Mixed' of 'Just's. +type Shaped :: [Nat] -> Type -> Type +newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) +deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) +deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a) + +instance (Show a, Elt a) => Show (Shaped n a) where + showsPrec d arr@(Shaped marr) = + let sh = show (shsToList (sshape arr)) + in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr + +instance Elt a => NFData (Shaped sh a) where + rnf (Shaped arr) = rnf arr + +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a)) + deriving (Generic) + +deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a)) + +newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) + +instance Elt a => Elt (Shaped sh a) where + mshape (M_Shaped arr) = mshape arr + mindex (M_Shaped arr) i = Shaped (mindex arr i) + + mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + mindexPartial (M_Shaped arr) i = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mindexPartial arr i + + mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) + + mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) + mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) + + mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] + mtoListOuter (M_Shaped arr) + = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) + mlift ssh2 f (M_Shaped arr) = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mlift ssh2 f arr + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) + mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) = + coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ + mlift2 ssh3 f arr1 arr2 + + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a)) + mliftL ssh2 f l = + coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a))) + @(NonEmpty (Mixed sh2 (Shaped sh a))) $ + mliftL ssh2 f (coerce l) + + mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr) + + mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) + + mconcat l = M_Shaped (mconcat (coerce l)) + + mrnf (M_Shaped arr) = mrnf arr + + type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) + + mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + marrayStrides (M_Shaped arr) = marrayStrides arr + + mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWrite sh idx (Shaped arr) vecs = + mvecsWrite sh idx arr + (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + + mvecsWritePartial :: forall sh1 sh2 s. + IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) + -> ST s () + mvecsWritePartial sh idx arr vecs = + mvecsWritePartial sh idx + (coerce @(Mixed sh2 (Shaped sh a)) + @(Mixed sh2 (Mixed (MapJust sh) a)) + arr) + (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) + @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) + vecs) + + mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsFreeze sh vecs = + coerce @(Mixed sh' (Mixed (MapJust sh) a)) + @(Mixed sh' (Shaped sh a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh' (Shaped sh a)) + @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + +instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where + memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) + memptyArrayUnsafe i + | Dict <- lemKnownMapJust (Proxy @sh) + = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ + memptyArrayUnsafe i + + mvecsUnsafeNew idx (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsUnsafeNew idx arr + + mvecsNewEmpty _ + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) + + +liftShaped1 :: forall sh a b. + (Mixed (MapJust sh) a -> Mixed (MapJust sh) b) + -> Shaped sh a -> Shaped sh b +liftShaped1 = coerce + +liftShaped2 :: forall sh a b c. + (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c) + -> Shaped sh a -> Shaped sh b -> Shaped sh c +liftShaped2 = coerce + +instance (NumElt a, PrimElt a) => Num (Shaped sh a) where + (+) = liftShaped2 (+) + (-) = liftShaped2 (-) + (*) = liftShaped2 (*) + negate = liftShaped1 negate + abs = liftShaped1 abs + signum = liftShaped1 signum + fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal" + +instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where + fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" + recip = liftShaped1 recip + (/) = liftShaped2 (/) + +instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where + pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" + exp = liftShaped1 exp + log = liftShaped1 log + sqrt = liftShaped1 sqrt + (**) = liftShaped2 (**) + logBase = liftShaped2 logBase + sin = liftShaped1 sin + cos = liftShaped1 cos + tan = liftShaped1 tan + asin = liftShaped1 asin + acos = liftShaped1 acos + atan = liftShaped1 atan + sinh = liftShaped1 sinh + cosh = liftShaped1 cosh + tanh = liftShaped1 tanh + asinh = liftShaped1 asinh + acosh = liftShaped1 acosh + atanh = liftShaped1 atanh + log1p = liftShaped1 GHC.Float.log1p + expm1 = liftShaped1 GHC.Float.expm1 + log1pexp = liftShaped1 GHC.Float.log1pexp + log1mexp = liftShaped1 GHC.Float.log1mexp + +squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a +squotArray = liftShaped2 mquotArray +sremArray = liftShaped2 mremArray + +satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a +satan2Array = liftShaped2 matan2Array + + +semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a +semptyArray sh = Shaped (memptyArray (shCvtSX sh)) + +sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh +sshape (Shaped arr) = shCvtXS' (mshape arr) + +srank :: Elt a => Shaped sh a -> SNat (Rank sh) +srank = shsRank . sshape + +-- | The total number of elements in the array. +ssize :: Elt a => Shaped sh a -> Int +ssize = shsSize . sshape + +sindex :: Elt a => Shaped sh a -> IIxS sh -> a +sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) + +shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh +shsTakeIx _ _ ZIS = ZSS +shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx + +sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a +sindexPartial sarr@(Shaped arr) idx = + Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) + (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) + (ixCvtSX idx)) + +-- | __WARNING__: All values returned from the function must have equal shape. +-- See the documentation of 'mgenerate' for more details. +sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a +sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh)) + +-- | See the documentation of 'mlift'. +slift :: forall sh1 sh2 a. Elt a + => ShS sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) + -> Shaped sh1 a -> Shaped sh2 a +slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr) + +-- | See the documentation of 'mlift'. +slift2 :: forall sh1 sh2 sh3 a. Elt a + => ShS sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) + -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a +slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2) + +ssumOuter1P :: forall sh n a. (Storable a, NumElt a) + => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) +ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) + +ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) + => Shaped (n : sh) a -> Shaped sh a +ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive + +ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a +ssumAllPrim (Shaped arr) = msumAllPrim arr + +stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a) + => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a +stranspose perm sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + , Refl <- lemTakeLenMapJust perm (sshape sarr) + , Refl <- lemDropLenMapJust perm (sshape sarr) + , Refl <- lemPermuteMapJust perm (shsTakeLen perm (sshape sarr)) + , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh)) + = Shaped (mtranspose perm arr) + +sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a +sappend = coerce mappend + +sscalar :: Elt a => a -> Shaped '[] a +sscalar x = Shaped (mscalar x) + +sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) +sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v) + +sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a +sfromVector sh v = sfromPrimitive (sfromVectorP sh v) + +stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a +stoVectorP = coerce mtoVectorP + +stoVector :: PrimElt a => Shaped sh a -> VS.Vector a +stoVector = coerce mtoVector + +sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a +sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l)) + +sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a +sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1 + +sfromList1Prim :: PrimElt a => SNat n -> [a] -> Shaped '[n] a +sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim + +stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] +stoListOuter (Shaped arr) = coerce (mtoListOuter arr) + +stoList1 :: Elt a => Shaped '[n] a -> [a] +stoList1 = map sunScalar . stoListOuter + +sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a +sfromListPrim sn l + | Refl <- lemAppNil @'[Just n] + = let ssh = SUnknown () :!% ZKX + xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) + in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr + +sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a +sfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr) + +sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a +sfromListLinear sh l = Shaped (mfromListLinear (shCvtSX sh) l) + +stoListLinear :: Elt a => Shaped sh a -> [a] +stoListLinear (Shaped arr) = mtoListLinear arr + +sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a +sfromOrthotope sh (SS.A (SG.A arr)) = + Shaped (fromPrimitive (M_Primitive (shCvtSX sh) (X.XArray (RS.A (RG.A (shsToList sh) arr))))) + +stoOrthotope :: PrimElt a => Shaped sh a -> SS.Array sh a +stoOrthotope (stoPrimitive -> Shaped (M_Primitive _ (X.XArray (RS.A (RG.A _ arr))))) = SS.A (SG.A arr) + +sunScalar :: Elt a => Shaped '[] a -> a +sunScalar arr = sindex arr ZIS + +snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a) +snest sh arr + | Refl <- lemMapJustApp sh (Proxy @sh') + = coerce (mnest (ssxFromShape (shCvtSX sh)) (coerce arr)) + +sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a +sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr))) + | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh') + = Shaped arr + +szip :: Shaped sh a -> Shaped sh b -> Shaped sh (a, b) +szip = coerce mzip + +sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b) +sunzip = coerce munzip + +srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) + => ShS sh -> ShS sh2 + -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) + -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b) +srerankP sh sh2 f sarr@(Shaped arr) + | Refl <- lemMapJustApp sh (Proxy @sh1) + , Refl <- lemMapJustApp sh (Proxy @sh2) + = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh)))) + (shCvtSX sh2) + (\a -> let Shaped r = f (Shaped a) in r) + arr) + +srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) + => ShS sh -> ShS sh2 + -> (Shaped sh1 a -> Shaped sh2 b) + -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b +srerank sh sh2 f (stoPrimitive -> arr) = + sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr + +sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a +sreplicate sh (Shaped arr) + | Refl <- lemMapJustApp sh (Proxy @sh') + = Shaped (mreplicate (shCvtSX sh) arr) + +sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) +sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x) + +sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a +sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) + +sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a +sslice i n@SNat arr = + let _ :$$ sh = sshape arr + in slift (n :$$ sh) (\_ -> X.slice i n) arr + +srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a +srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr + +sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shaped sh' a +sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr) + +sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a +sflatten arr = + case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff + n@SNat -> sreshape (n :$$ ZSS) arr + +siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a +siota sn = Shaped (miota sn) + +-- | Throws if the array is empty. +sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh +sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminIndexPrim arr) + +-- | Throws if the array is empty. +smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh +smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) + +sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) + => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a +sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) + | Refl <- lemInitApp (Proxy @sh) (Proxy @n) + , Refl <- lemLastApp (Proxy @sh) (Proxy @n) + = case sshape sarr1 of + _ :$$ _ + | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n]) + -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2) + _ -> error "unreachable" + +-- | This has a temporary, suboptimal implementation in terms of 'mflatten'. +-- Prefer 'sdot1Inner' if applicable. +sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a +sdot = coerce mdot + +stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) +stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr) + +stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) +stoXArrayPrim (Shaped arr) = first shCvtXS' (mtoXArrayPrim arr) + +sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) +sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr) + +sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a +sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr) + +sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a +sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) + +stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a) +stoPrimitive (Shaped arr) = Shaped (toPrimitive arr) + +mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') + => Mixed sh a -> ShS sh' -> Shaped sh' a +mcastToShaped arr targetsh + | Refl <- lemRankMapJust targetsh + = Shaped (mcast (ssxFromShape (shCvtSX targetsh)) arr) + +stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a +stoMixed (Shaped arr) = arr + +-- | A more weakly-typed version of 'stoMixed' that does a runtime shape +-- compatibility check. +scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') + => StaticShX sh' -> Shaped sh a -> Mixed sh' a +scastToMixed sshx sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mcast sshx arr -- cgit v1.2.3-70-g09d2