diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-05-14 19:16:21 +0200 | 
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-05-14 19:16:35 +0200 | 
| commit | 554eff1ebc7bf4f467c8566a0e22b8a0cfb9d0a4 (patch) | |
| tree | 53cdebd831061cd13740134ab7dd5fdba3aa6b68 /src/Data/Array/Nested/Internal | |
| parent | 03626ae119438452551962359b5d445a4ddbc0b3 (diff) | |
Rename the three main public tensor API modules
Diffstat (limited to 'src/Data/Array/Nested/Internal')
| -rw-r--r-- | src/Data/Array/Nested/Internal/Convert.hs | 6 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 955 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 559 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 495 | 
4 files changed, 3 insertions, 2012 deletions
| diff --git a/src/Data/Array/Nested/Internal/Convert.hs b/src/Data/Array/Nested/Internal/Convert.hs index 4e0f17d..611b45e 100644 --- a/src/Data/Array/Nested/Internal/Convert.hs +++ b/src/Data/Array/Nested/Internal/Convert.hs @@ -14,9 +14,9 @@ import Data.Type.Equality  import Data.Array.Mixed.Lemmas  import Data.Array.Mixed.Types  import Data.Array.Nested.Internal.Lemmas -import Data.Array.Nested.Internal.Mixed -import Data.Array.Nested.Internal.Ranked -import Data.Array.Nested.Internal.Shaped +import Data.Array.Nested.Mixed +import Data.Array.Nested.Ranked +import Data.Array.Nested.Shaped  import Data.Array.Nested.Mixed.Shape  import Data.Array.Nested.Shaped.Shape diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs deleted file mode 100644 index a979d6c..0000000 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ /dev/null @@ -1,955 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DefaultSignatures #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE StrictData #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -module Data.Array.Nested.Internal.Mixed where - -import Prelude hiding (mconcat) - -import Control.DeepSeq (NFData(..)) -import Control.Monad (forM_, when) -import Control.Monad.ST -import Data.Array.RankedS qualified as S -import Data.Bifunctor (bimap) -import Data.Coerce -import Data.Foldable (toList) -import Data.Int -import Data.Kind (Constraint, Type) -import Data.List.NonEmpty (NonEmpty(..)) -import Data.List.NonEmpty qualified as NE -import Data.Proxy -import Data.Type.Equality -import Data.Vector.Storable qualified as VS -import Data.Vector.Storable.Mutable qualified as VSM -import Foreign.C.Types (CInt) -import Foreign.Storable (Storable) -import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) -import GHC.Generics (Generic) -import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce) - -import Data.Array.Mixed.Internal.Arith -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray(..)) -import Data.Array.Mixed.XArray qualified as X -import Data.Array.Nested.Mixed.Shape -import Data.Bag - - --- TODO: ---   sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a ---   rminIndex1 :: Ranked (n + 1) a -> Ranked n Int ---   gather/scatter-like things (most generally, the higher-order variants: accelerate's backpermute/permute) ---   After benchmarking: matmul and matvec - - - --- Invariant in the API --- ==================== --- --- In the underlying XArray, there is some shape for elements of an empty --- array. For example, for this array: --- ---   arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float) ---   rshape arr == 0 :.: 0 :.: 0 :.: ZIR --- --- the two underlying XArrays have a shape, and those shapes might be anything. --- The invariant is that these element shapes are unobservable in the API. --- (This is possible because you ought to not be able to get to such an element --- without indexing out of bounds.) --- --- Note, though, that the converse situation may arise: the outer array might --- be nonempty but then the inner arrays might. This is fine, an invariant only --- applies if the _outer_ array is empty. --- --- TODO: can we enforce that the elements of an empty (nested) array have --- all-zero shape? ---   -> no, because mlift and also any kind of internals probing from outsiders - - --- Primitive element types --- ======================= --- --- There are a few primitive element types; arrays containing elements of such --- type are a newtype over an XArray, which it itself a newtype over a Vector. --- Unfortunately, the setup of the library requires us to list these primitive --- element types multiple times; to aid in extending the list, all these lists --- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. - - --- | Wrapper type used as a tag to attach instances on. The instances on arrays --- of @'Primitive' a@ are more polymorphic than the direct instances for arrays --- of scalars; this means that if @orthotope@ supports an element type @T@ that --- this library does not (directly), it may just work if you use an array of --- @'Primitive' T@ instead. -newtype Primitive a = Primitive a -  deriving (Show) - --- | Element types that are primitive; arrays of these types are just a newtype --- wrapper over an array. -class (Storable a, Elt a) => PrimElt a where -  fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a -  toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) - -  default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a -  fromPrimitive = coerce - -  default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a) -  toPrimitive = coerce - --- [PRIMITIVE ELEMENT TYPES LIST] -instance PrimElt Bool -instance PrimElt Int -instance PrimElt Int64 -instance PrimElt Int32 -instance PrimElt CInt -instance PrimElt Float -instance PrimElt Double -instance PrimElt () - - --- | Mixed arrays: some dimensions are size-typed, some are not. Distributes --- over product-typed elements using a data family so that the full array is --- always in struct-of-arrays format. --- --- Built on top of 'XArray' which is built on top of @orthotope@, meaning that --- dimension permutations (e.g. 'mtranspose') are typically free. --- --- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type --- class. -type Mixed :: [Maybe Nat] -> Type -> Type -data family Mixed sh a --- NOTE: When opening up the Mixed abstraction, you might see dimension sizes --- that you're not supposed to see. In particular, you might see (nonempty) --- sizes of the elements of an empty array, which is information that should --- ostensibly not exist; the full array is still empty. - -data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a) -  deriving (Eq, Ord, Generic) - --- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Eq, Ord, Generic)  -- no content, orthotope optimises this (via Vector) --- etc. - -data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic) --- etc., larger tuples (perhaps use generics to allow arbitrary product types) - -deriving instance (Eq (Mixed sh a), Eq (Mixed sh b)) => Eq (Mixed sh (a, b)) -deriving instance (Ord (Mixed sh a), Ord (Mixed sh b)) => Ord (Mixed sh (a, b)) - -data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic) - -deriving instance Eq (Mixed (sh1 ++ sh2) a) => Eq (Mixed sh1 (Mixed sh2 a)) -deriving instance Ord (Mixed (sh1 ++ sh2) a) => Ord (Mixed sh1 (Mixed sh2 a)) - - --- | Internal helper data family mirroring 'Mixed' that consists of mutable --- vectors instead of 'XArray's. -type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type -data family MixedVecs s sh a - -newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a) - --- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance MixedVecs s sh Bool = MV_Bool (VS.MVector s Bool) -newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) -newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64) -newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32) -newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt) -newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) -newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float) -newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ())  -- no content, MVector optimises this --- etc. - -data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b) --- etc. - -data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) - - -showsMixedArray :: (Show a, Elt a) -                => String  -- ^ fromList prefix: e.g. @rfromListLinear [2,3]@ -                -> String  -- ^ replicate prefix: e.g. @rreplicate [2,3]@ -                -> Int -> Mixed sh a -> ShowS -showsMixedArray fromlistPrefix replicatePrefix d arr = -  showParen (d > 10) $ -    -- TODO: to avoid ambiguity, we should type-apply the shape to mfromListLinear here -    case mtoListLinear arr of -      hd : _ : _ -        | all (all (== 0) . take (shxLength (mshape arr))) (marrayStrides arr) -> -            showString replicatePrefix . showString " " . showsPrec 11 hd -      _ -> -       showString fromlistPrefix . showString " " . shows (mtoListLinear arr) - -instance (Show a, Elt a) => Show (Mixed sh a) where -  showsPrec d arr = -    let sh = show (shxToList (mshape arr)) -    in showsMixedArray ("mfromListLinear " ++ sh) ("mreplicate " ++ sh) d arr - -instance Elt a => NFData (Mixed sh a) where -  rnf = mrnf - - -mliftNumElt1 :: (PrimElt a, PrimElt b) -             => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b) -             -> Mixed sh a -> Mixed sh b -mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr)) - -mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c) -             => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c) -             -> Mixed sh a -> Mixed sh b -> Mixed sh c -mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2)) -  | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2)) -  | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 - -instance (NumElt a, PrimElt a) => Num (Mixed sh a) where -  (+) = mliftNumElt2 (liftO2 . numEltAdd) -  (-) = mliftNumElt2 (liftO2 . numEltSub) -  (*) = mliftNumElt2 (liftO2 . numEltMul) -  negate = mliftNumElt1 (liftO1 . numEltNeg) -  abs = mliftNumElt1 (liftO1 . numEltAbs) -  signum = mliftNumElt1 (liftO1 . numEltSignum) -  -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS -  fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal" - -instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where -  fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" -  recip = mliftNumElt1 (liftO1 . floatEltRecip) -  (/) = mliftNumElt2 (liftO2 . floatEltDiv) - -instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where -  pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" -  exp = mliftNumElt1 (liftO1 . floatEltExp) -  log = mliftNumElt1 (liftO1 . floatEltLog) -  sqrt = mliftNumElt1 (liftO1 . floatEltSqrt) - -  (**) = mliftNumElt2 (liftO2 . floatEltPow) -  logBase = mliftNumElt2 (liftO2 . floatEltLogbase) - -  sin = mliftNumElt1 (liftO1 . floatEltSin) -  cos = mliftNumElt1 (liftO1 . floatEltCos) -  tan = mliftNumElt1 (liftO1 . floatEltTan) -  asin = mliftNumElt1 (liftO1 . floatEltAsin) -  acos = mliftNumElt1 (liftO1 . floatEltAcos) -  atan = mliftNumElt1 (liftO1 . floatEltAtan) -  sinh = mliftNumElt1 (liftO1 . floatEltSinh) -  cosh = mliftNumElt1 (liftO1 . floatEltCosh) -  tanh = mliftNumElt1 (liftO1 . floatEltTanh) -  asinh = mliftNumElt1 (liftO1 . floatEltAsinh) -  acosh = mliftNumElt1 (liftO1 . floatEltAcosh) -  atanh = mliftNumElt1 (liftO1 . floatEltAtanh) -  log1p = mliftNumElt1 (liftO1 . floatEltLog1p) -  expm1 = mliftNumElt1 (liftO1 . floatEltExpm1) -  log1pexp = mliftNumElt1 (liftO1 . floatEltLog1pexp) -  log1mexp = mliftNumElt1 (liftO1 . floatEltLog1mexp) - -mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a -mquotArray = mliftNumElt2 (liftO2 . intEltQuot) -mremArray = mliftNumElt2 (liftO2 . intEltRem) - -matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a -matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2) - --- | Allowable element types in a mixed array, and by extension in a 'Ranked' or --- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' --- a@; see the documentation for 'Primitive' for more details. -class Elt a where -  -- ====== PUBLIC METHODS ====== -- - -  mshape :: Mixed sh a -> IShX sh -  mindex :: Mixed sh a -> IIxX sh -> a -  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a -  mscalar :: a -> Mixed '[] a - -  -- | All arrays in the list, even subarrays inside @a@, must have the same -  -- shape; if they do not, a runtime error will be thrown. See the -  -- documentation of 'mgenerate' for more information about this restriction. -  -- Furthermore, the length of the list must correspond with @n@: if @n@ is -  -- @Just m@ and @m@ does not equal the length of the list, a runtime error is -  -- thrown. -  -- -  -- Consider also 'mfromListPrim', which can avoid intermediate arrays. -  mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a - -  mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] - -  -- | Note: this library makes no particular guarantees about the shapes of -  -- arrays "inside" an empty array. With 'mlift', 'mlift2' and 'mliftL' you can see the -  -- full 'XArray' and as such you can distinguish different empty arrays by -  -- the "shapes" of their elements. This information is meaningless, so you -  -- should not use it. -  mlift :: forall sh1 sh2. -           StaticShX sh2 -        -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -        -> Mixed sh1 a -> Mixed sh2 a - -  -- | See the documentation for 'mlift'. -  mlift2 :: forall sh1 sh2 sh3. -            StaticShX sh3 -         -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) -         -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a - -  -- TODO: mliftL is currently unused. -  -- | All arrays in the input must have equal shapes, including subarrays -  -- inside their elements. -  mliftL :: forall sh1 sh2. -            StaticShX sh2 -         -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) -         -> NonEmpty (Mixed sh1 a) -> NonEmpty (Mixed sh2 a) - -  mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 -               => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a - -  mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) -             => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a - -  -- | All arrays in the input must have equal shapes, including subarrays -  -- inside their elements. -  mconcat :: NonEmpty (Mixed (Nothing : sh) a) -> Mixed (Nothing : sh) a - -  mrnf :: Mixed sh a -> () - -  -- ====== PRIVATE METHODS ====== -- - -  -- | Tree giving the shape of every array component. -  type ShapeTree a - -  mshapeTree :: a -> ShapeTree a - -  mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool - -  mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool - -  mshowShapeTree :: Proxy a -> ShapeTree a -> String - -  -- | Returns the stride vector of each underlying component array making up -  -- this mixed array. -  marrayStrides :: Mixed sh a -> Bag [Int] - -  -- | Given the shape of this array, an index and a value, write the value at -  -- that index in the vectors. -  mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () - -  -- | Given the shape of this array, an index and a value, write the value at -  -- that index in the vectors. -  mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () - -  -- | Given the shape of this array, finalise the vectors into 'XArray's. -  mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) - - --- | Element types for which we have evidence of the (static part of the) shape --- in a type class constraint. Compare the instance contexts of the instances --- of this class with those of 'Elt': some instances have an additional --- "known-shape" constraint. --- --- This class is (currently) only required for 'mgenerate', --- 'Data.Array.Nested.Ranked.rgenerate' and --- 'Data.Array.Nested.Shaped.sgenerate'. -class Elt a => KnownElt a where -  -- | Create an empty array. The given shape must have size zero; this may or may not be checked. -  memptyArrayUnsafe :: IShX sh -> Mixed sh a - -  -- | Create uninitialised vectors for this array type, given the shape of -  -- this vector and an example for the contents. -  mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) - -  mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) - - --- Arrays of scalars are basically just arrays of scalars. -instance Storable a => Elt (Primitive a) where -  mshape (M_Primitive sh _) = sh -  mindex (M_Primitive _ a) i = Primitive (X.index a i) -  mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) -  mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) -  mfromListOuter l@(arr1 :| _) = -    let sh = SUnknown (length l) :$% mshape arr1 -    in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l))) -  mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) - -  mlift :: forall sh1 sh2. -           StaticShX sh2 -        -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) -        -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -  mlift ssh2 f (M_Primitive _ a) -    | Refl <- lemAppNil @sh1 -    , Refl <- lemAppNil @sh2 -    , let result = f ZKX a -    = M_Primitive (X.shape ssh2 result) result - -  mlift2 :: forall sh1 sh2 sh3. -            StaticShX sh3 -         -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) -         -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) -  mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b) -    | Refl <- lemAppNil @sh1 -    , Refl <- lemAppNil @sh2 -    , Refl <- lemAppNil @sh3 -    , let result = f ZKX a b -    = M_Primitive (X.shape ssh3 result) result - -  mliftL :: forall sh1 sh2. -            StaticShX sh2 -         -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) -         -> NonEmpty (Mixed sh1 (Primitive a)) -> NonEmpty (Mixed sh2 (Primitive a)) -  mliftL ssh2 f l -    | Refl <- lemAppNil @sh1 -    , Refl <- lemAppNil @sh2 -    = fmap (\arr -> M_Primitive (X.shape ssh2 arr) arr) $ -        f ZKX (fmap (\(M_Primitive _ arr) -> arr) l) - -  mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 -               => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) -  mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) = -    let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' -        sh2 = shxCast' sh1 ssh2 -    in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) - -  mtranspose perm (M_Primitive sh arr) = -    M_Primitive (shxPermutePrefix perm sh) -                (X.transpose (ssxFromShape sh) perm arr) - -  mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a) -  mconcat l@(M_Primitive (_ :$% sh) _ :| _) = -    let result = X.concat (ssxFromShape sh) (fmap (\(M_Primitive _ arr) -> arr) l) -    in M_Primitive (X.shape (SUnknown () :!% ssxFromShape sh) result) result - -  mrnf (M_Primitive sh a) = rnf sh `seq` rnf a - -  type ShapeTree (Primitive a) = () -  mshapeTree _ = () -  mshapeTreeEq _ () () = True -  mshapeTreeEmpty _ () = False -  mshowShapeTree _ () = "()" -  marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) -  mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x - -  -- TODO: this use of toVector is suboptimal -  mvecsWritePartial -    :: forall sh' sh s. -       IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () -  mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do -    let arrsh = X.shape (ssxFromShape sh') arr -        offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) -    VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) - -  mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving via Primitive Bool instance Elt Bool -deriving via Primitive Int instance Elt Int -deriving via Primitive Int64 instance Elt Int64 -deriving via Primitive Int32 instance Elt Int32 -deriving via Primitive CInt instance Elt CInt -deriving via Primitive Double instance Elt Double -deriving via Primitive Float instance Elt Float -deriving via Primitive () instance Elt () - -instance Storable a => KnownElt (Primitive a) where -  memptyArrayUnsafe sh = M_Primitive sh (X.empty sh) -  mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) -  mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving via Primitive Bool instance KnownElt Bool -deriving via Primitive Int instance KnownElt Int -deriving via Primitive Int64 instance KnownElt Int64 -deriving via Primitive Int32 instance KnownElt Int32 -deriving via Primitive CInt instance KnownElt CInt -deriving via Primitive Double instance KnownElt Double -deriving via Primitive Float instance KnownElt Float -deriving via Primitive () instance KnownElt () - --- Arrays of pairs are pairs of arrays. -instance (Elt a, Elt b) => Elt (a, b) where -  mshape (M_Tup2 a _) = mshape a -  mindex (M_Tup2 a b) i = (mindex a i, mindex b i) -  mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) -  mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) -  mfromListOuter l = -    M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) -           (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) -  mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) -  mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) -  mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) -  mliftL ssh2 f = -    let unzipT2l [] = ([], []) -        unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) -        unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2) -    in uncurry (NE.zipWith M_Tup2) . bimap (mliftL ssh2 f) (mliftL ssh2 f) . unzipT2 - -  mcastPartial ssh1 sh2 psh' (M_Tup2 a b) = -    M_Tup2 (mcastPartial ssh1 sh2 psh' a) (mcastPartial ssh1 sh2 psh' b) - -  mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) -  mconcat = -    let unzipT2l [] = ([], []) -        unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) -        unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2) -    in uncurry M_Tup2 . bimap mconcat mconcat . unzipT2 - -  mrnf (M_Tup2 a b) = mrnf a `seq` mrnf b - -  type ShapeTree (a, b) = (ShapeTree a, ShapeTree b) -  mshapeTree (x, y) = (mshapeTree x, mshapeTree y) -  mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' -  mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 -  mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" -  marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b -  mvecsWrite sh i (x, y) (MV_Tup2 a b) = do -    mvecsWrite sh i x a -    mvecsWrite sh i y b -  mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do -    mvecsWritePartial sh i x a -    mvecsWritePartial sh i y b -  mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b - -instance (KnownElt a, KnownElt b) => KnownElt (a, b) where -  memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh) -  mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y -  mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) - --- Arrays of arrays are just arrays, but with more dimensions. -instance Elt a => Elt (Mixed sh' a) where -  -- TODO: this is quadratic in the nesting depth because it repeatedly -  -- truncates the shape vector to one a little shorter. Fix with a -  -- moverlongShape method, a prefix of which is mshape. -  mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh -  mshape (M_Nest sh arr) -    = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr)) - -  mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a -  mindex (M_Nest _ arr) i = mindexPartial arr i - -  mindexPartial :: forall sh1 sh2. -                   Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) -  mindexPartial (M_Nest sh arr) i -    | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') -    = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) - -  mscalar = M_Nest ZSX - -  mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) -  mfromListOuter l@(arr :| _) = -    M_Nest (SUnknown (length l) :$% mshape arr) -           (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) - -  mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) - -  mlift :: forall sh1 sh2. -           StaticShX sh2 -        -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) -        -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -  mlift ssh2 f (M_Nest sh1 arr) = -    let result = mlift (ssxAppend ssh2 ssh') f' arr -        (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) -    in M_Nest sh2 result -    where -      ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr))) - -      f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -      f' sshT -        | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) -        , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) -        = f (ssxAppend ssh' sshT) - -  mlift2 :: forall sh1 sh2 sh3. -            StaticShX sh3 -         -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) -         -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) -  mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = -    let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 -        (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) -    in M_Nest sh3 result -    where -      ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) - -      f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b -      f' sshT -        | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) -        , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) -        , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) -        = f (ssxAppend ssh' sshT) - -  mliftL :: forall sh1 sh2. -            StaticShX sh2 -         -> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b)) -         -> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a)) -  mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) = -    let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l) -        (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result)) -    in fmap (M_Nest sh2) result -    where -      ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) - -      f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b) -      f' sshT -        | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) -        , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) -        = f (ssxAppend ssh' sshT) - -  mcastPartial :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 -               => StaticShX sh1 -> StaticShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) -  mcastPartial ssh1 ssh2 _ (M_Nest sh1T arr) -    | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') -    , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') -    = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T -          sh2 = shxCast' sh1 ssh2 -      in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr) - -  mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) -             => Perm is -> Mixed sh (Mixed sh' a) -             -> Mixed (PermutePrefix is sh) (Mixed sh' a) -  mtranspose perm (M_Nest sh arr) -    | let sh' = shxDropSh @sh @sh' (mshape arr) sh -    , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh') -    , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) -    , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') -    , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') -    , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') -    = M_Nest (shxPermutePrefix perm sh) -             (mtranspose perm arr) - -  mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) -  mconcat l@(M_Nest sh1 _ :| _) = -    let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l) -    in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape result))) result - -  mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr - -  type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a) - -  mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) -  mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr))))) - -  mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - -  mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - -  mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - -  marrayStrides (M_Nest _ arr) = marrayStrides arr - -  mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs - -  mvecsWritePartial :: forall sh1 sh2 s. -                       IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) -                    -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) -                    -> ST s () -  mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) -    | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') -    = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs - -  mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs - -instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where -  memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh')))) - -  mvecsUnsafeNew sh example -    | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) -    | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh'))) -    where -      sh' = mshape example - -  mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) - - -memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a -memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh) - -mrank :: Elt a => Mixed sh a -> SNat (Rank sh) -mrank = shxRank . mshape - --- | The total number of elements in the array. -msize :: Elt a => Mixed sh a -> Int -msize = shxSize . mshape - --- | Create an array given a size and a function that computes the element at a --- given index. --- --- __WARNING__: It is required that every @a@ returned by the argument to --- 'mgenerate' has the same shape. For example, the following will throw a --- runtime error: --- --- > foo :: Mixed [Nothing] (Mixed [Nothing] Double) --- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) -> --- >         mgenerate (i :.: ZIR) $ \(j :.: ZIR) -> --- >           ... --- --- because the size of the inner 'mgenerate' is not always the same (it depends --- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so --- the entire hierarchy (after distributing out tuples) must be a rectangular --- array. The type of 'mgenerate' allows this requirement to be broken very --- easily, hence the runtime check. -mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a -mgenerate sh f = case shxEnum sh of -  [] -> memptyArrayUnsafe sh -  firstidx : restidxs -> -    let firstelem = f (ixxZero' sh) -        shapetree = mshapeTree firstelem -    in if mshapeTreeEmpty (Proxy @a) shapetree -         then memptyArrayUnsafe sh -         else runST $ do -                vecs <- mvecsUnsafeNew sh firstelem -                mvecsWrite sh firstidx firstelem vecs -                -- TODO: This is likely fine if @a@ is big, but if @a@ is a -                -- scalar this array copying inefficient. Should improve this. -                forM_ restidxs $ \idx -> do -                  let val = f idx -                  when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ -                    error "Data.Array.Nested mgenerate: generated values do not have equal shapes" -                  mvecsWrite sh idx val vecs -                mvecsFreeze sh vecs - -msumOuter1P :: forall sh n a. (Storable a, NumElt a) -            => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) -msumOuter1P (M_Primitive (n :$% sh) arr) = -  let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX -  in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr) - -msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) -           => Mixed (n : sh) a -> Mixed sh a -msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive - -msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a -msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShape sh) arr - -mappend :: forall n m sh a. Elt a -        => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a -mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 -  where -    sn :$% sh = mshape arr1 -    sm :$% _ = mshape arr2 -    ssh = ssxFromShape sh -    snm :: SMayNat () SNat (AddMaybe n m) -    snm = case (sn, sm) of -            (SUnknown{}, _) -> SUnknown () -            (SKnown{}, SUnknown{}) -> SUnknown () -            (SKnown n, SKnown m) -> SKnown (snatPlus n m) - -    f :: forall sh' b. Storable b -      => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b -    f ssh' = X.append (ssxAppend ssh ssh') - -mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) -mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) - -mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a -mfromVector sh v = fromPrimitive (mfromVectorP sh v) - -mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a -mtoVectorP (M_Primitive _ v) = X.toVector v - -mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a -mtoVector arr = mtoVectorP (toPrimitive arr) - -mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a -mfromList1 = mfromListOuter . fmap mscalar  -- TODO: optimise? - -mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromList1Prim l = -  let ssh = SUnknown () :!% ZKX -      xarr = X.fromList1 ssh l -  in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -mtoList1 :: Elt a => Mixed '[n] a -> [a] -mtoList1 = map munScalar . mtoListOuter - -mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromListPrim l = -  let ssh = SUnknown () :!% ZKX -      xarr = X.fromList1 ssh l -  in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a -mfromListPrimLinear sh l = -  let M_Primitive _ xarr = toPrimitive (mfromListPrim l) -  in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) - --- This forall is there so that a simple type application can constrain the --- shape, in case the user wants to use OverloadedLists for the shape. -mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a -mfromListLinear sh l = mreshape sh (mfromList1 l) - -mtoListLinear :: Elt a => Mixed sh a -> [a] -mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr))  -- TODO: optimise - -munScalar :: Elt a => Mixed '[] a -> a -munScalar arr = mindex arr ZIX - -mnest :: forall sh sh' a. Elt a => StaticShX sh -> Mixed (sh ++ sh') a -> Mixed sh (Mixed sh' a) -mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr - -munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a -munNest (M_Nest _ arr) = arr - -mzip :: Mixed sh a -> Mixed sh b -> Mixed sh (a, b) -mzip = M_Tup2 - -munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b) -munzip (M_Tup2 a b) = (a, b) - -mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) -         => StaticShX sh -> IShX sh2 -         -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) -         -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) -mrerankP ssh sh2 f (M_Primitive sh arr) = -  let sh1 = shxDropSSX sh ssh -  in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) -                 (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2) -                           (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) -                           arr) - --- | See the caveats at @X.rerank@. -mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) -        => StaticShX sh -> IShX sh2 -        -> (Mixed sh1 a -> Mixed sh2 b) -        -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b -mrerank ssh sh2 f (toPrimitive -> arr) = -  fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr - -mreplicate :: forall sh sh' a. Elt a -           => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a -mreplicate sh arr = -  let ssh' = ssxFromShape (mshape arr) -  in mlift (ssxAppend (ssxFromShape sh) ssh') -           (\(sshT :: StaticShX shT) -> -              case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of -                Refl -> X.replicate sh (ssxAppend ssh' sshT)) -           arr - -mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) -mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) - -mreplicateScal :: forall sh a. PrimElt a -               => IShX sh -> a -> Mixed sh a -mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) - -mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a -mslice i n arr = -  let _ :$% sh = mshape arr -  in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr - -msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr - -mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a -mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr - -mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a -mreshape sh' arr = -  mlift (ssxFromShape sh') -        (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh') -        arr - -mflatten :: Elt a => Mixed sh a -> Mixed '[Flatten sh] a -mflatten arr = mreshape (shxFlatten (mshape arr) :$% ZSX) arr - -miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a -miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) - --- | Throws if the array is empty. -mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh -mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = -  ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) (fromO arr)) - --- | Throws if the array is empty. -mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh -mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = -  ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) (fromO arr)) - -mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) -           => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a -mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b)) -  | Refl <- lemInitApp (Proxy @sh) (Proxy @n) -  , Refl <- lemLastApp (Proxy @sh) (Proxy @n) -  = case sh1 of -      _ :$% _ -        | sh1 == sh2 -        , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) -> -            fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b)) -        | otherwise -> error $ "mdot1Inner: Unequal shapes (" ++ show sh1 ++ " and " ++ show sh2 ++ ")" -      ZSX -> error "unreachable" - --- | This has a temporary, suboptimal implementation in terms of 'mflatten'. --- Prefer 'mdot1Inner' if applicable. -mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a -mdot a b = -  munScalar $ -    mdot1Inner Proxy (fromPrimitive (mflatten (toPrimitive a))) -                     (fromPrimitive (mflatten (toPrimitive b))) - -mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) -mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr) - -mtoXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a) -mtoXArrayPrim = mtoXArrayPrimP . toPrimitive - -mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a) -mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr - -mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a -mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP - -mliftPrim :: (PrimElt a, PrimElt b) -          => (a -> b) -          -> Mixed sh a -> Mixed sh b -mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) - -mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c) -           => (a -> b -> c) -           -> Mixed sh a -> Mixed sh b -> Mixed sh c -mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = -  fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) - -mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) -      => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a -mcast ssh2 arr -  | Refl <- lemAppNil @sh1 -  , Refl <- lemAppNil @sh2 -  = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr - --- TODO: This should be `type data` but a bug in GHC 9.10 means that that throws linker errors -data SafeMCastSpec -  = MCastId -  | MCastApp [Maybe Nat] [Maybe Nat] [Maybe Nat] [Maybe Nat] SafeMCastSpec SafeMCastSpec -  | MCastForget - -type SafeMCast :: SafeMCastSpec -> [Maybe Nat] -> [Maybe Nat] -> Constraint -type family SafeMCast spec sh1 sh2 where -  SafeMCast MCastId sh sh = () -  SafeMCast (MCastApp sh1A sh1B sh2A sh2B specA specB) sh1 sh2 = (sh1 ~ sh1A ++ sh1B, sh2 ~ sh2A ++ sh2B, SafeMCast specA sh1A sh2A, SafeMCast specB sh1B sh2B) -  SafeMCast MCastForget sh1 sh2 = sh2 ~ Replicate (Rank sh1) Nothing - --- | This is an O(1) operation: the 'SafeMCast' constraint ensures that --- type-level shape information can only be forgotten, not introduced, and thus --- that no runtime shape checks are required. The @spec@ describes to --- 'SafeMCast' how exactly you intend @sh2@ to be a weakening of @sh1@. --- --- To see how to construct the spec, read the equations of 'SafeMCast' closely. -mcastSafe :: forall spec sh1 sh2 a proxy. SafeMCast spec sh1 sh2 => proxy spec -> Mixed sh1 a -> Mixed sh2 a -mcastSafe _ = unsafeCoerce @(Mixed sh1 a) @(Mixed sh2 a) diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs deleted file mode 100644 index 368e337..0000000 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ /dev/null @@ -1,559 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Internal.Ranked where - -import Prelude hiding (mappend, mconcat) - -import Control.DeepSeq (NFData(..)) -import Control.Monad.ST -import Data.Array.RankedS qualified as S -import Data.Bifunctor (first) -import Data.Coerce (coerce) -import Data.Foldable (toList) -import Data.Kind (Type) -import Data.List.NonEmpty (NonEmpty) -import Data.Proxy -import Data.Type.Equality -import Data.Vector.Storable qualified as VS -import Foreign.Storable (Storable) -import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) -import GHC.Generics (Generic) -import GHC.TypeLits -import GHC.TypeNats qualified as TN - -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray(..)) -import Data.Array.Mixed.XArray qualified as X -import Data.Array.Nested.Internal.Mixed -import Data.Array.Nested.Mixed.Shape -import Data.Array.Nested.Ranked.Shape -import Data.Array.Strided.Arith - - --- | A rank-typed array: the number of dimensions of the array (its /rank/) is --- represented on the type level as a 'Nat'. --- --- Valid elements of a ranked arrays are described by the 'Elt' type class. --- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are --- supported (and are represented as a single, flattened, struct-of-arrays --- array internally). --- --- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. -type Ranked :: Nat -> Type -> Type -newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) -deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a) -deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a) - -instance (Show a, Elt a) => Show (Ranked n a) where -  showsPrec d arr@(Ranked marr) = -    let sh = show (toList (rshape arr)) -    in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr - -instance Elt a => NFData (Ranked n a) where -  rnf (Ranked arr) = rnf arr - --- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) -  deriving (Generic) - -deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a)) - -newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) - --- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; --- these instances allow them to also be used as elements of arrays, thus --- making them first-class in the API. -instance Elt a => Elt (Ranked n a) where -  mshape (M_Ranked arr) = mshape arr -  mindex (M_Ranked arr) i = Ranked (mindex arr i) - -  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) -  mindexPartial (M_Ranked arr) i = -    coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ -        mindexPartial arr i - -  mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) - -  mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) -  mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) - -  mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] -  mtoListOuter (M_Ranked arr) = -    coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) - -  mlift :: forall sh1 sh2. -           StaticShX sh2 -        -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -        -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -  mlift ssh2 f (M_Ranked arr) = -    coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ -      mlift ssh2 f arr - -  mlift2 :: forall sh1 sh2 sh3. -            StaticShX sh3 -         -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) -         -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) -  mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) = -    coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ -      mlift2 ssh3 f arr1 arr2 - -  mliftL :: forall sh1 sh2. -            StaticShX sh2 -         -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) -         -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a)) -  mliftL ssh2 f l = -    coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a))) -           @(NonEmpty (Mixed sh2 (Ranked n a))) $ -      mliftL ssh2 f (coerce l) - -  mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr) - -  mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) - -  mconcat l = M_Ranked (mconcat (coerce l)) - -  mrnf (M_Ranked arr) = mrnf arr - -  type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) - -  mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) - -  mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - -  mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - -  mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - -  marrayStrides (M_Ranked arr) = marrayStrides arr - -  mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () -  mvecsWrite sh idx (Ranked arr) vecs = -    mvecsWrite sh idx arr -      (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) -         vecs) - -  mvecsWritePartial :: forall sh sh' s. -                       IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) -                    -> MixedVecs s (sh ++ sh') (Ranked n a) -                    -> ST s () -  mvecsWritePartial sh idx arr vecs = -    mvecsWritePartial sh idx -      (coerce @(Mixed sh' (Ranked n a)) -              @(Mixed sh' (Mixed (Replicate n Nothing) a)) -         arr) -      (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) -              @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) -         vecs) - -  mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) -  mvecsFreeze sh vecs = -    coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) -           @(Mixed sh (Ranked n a)) -      <$> mvecsFreeze sh -            (coerce @(MixedVecs s sh (Ranked n a)) -                    @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) -                    vecs) - -instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where -  memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a) -  memptyArrayUnsafe i -    | Dict <- lemKnownReplicate (SNat @n) -    = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ -        memptyArrayUnsafe i - -  mvecsUnsafeNew idx (Ranked arr) -    | Dict <- lemKnownReplicate (SNat @n) -    = MV_Ranked <$> mvecsUnsafeNew idx arr - -  mvecsNewEmpty _ -    | Dict <- lemKnownReplicate (SNat @n) -    = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) - - -liftRanked1 :: forall n a b. -               (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b) -            -> Ranked n a -> Ranked n b -liftRanked1 = coerce - -liftRanked2 :: forall n a b c. -               (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c) -            -> Ranked n a -> Ranked n b -> Ranked n c -liftRanked2 = coerce - -instance (NumElt a, PrimElt a) => Num (Ranked n a) where -  (+) = liftRanked2 (+) -  (-) = liftRanked2 (-) -  (*) = liftRanked2 (*) -  negate = liftRanked1 negate -  abs = liftRanked1 abs -  signum = liftRanked1 signum -  fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal" - -instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where -  fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal" -  recip = liftRanked1 recip -  (/) = liftRanked2 (/) - -instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where -  pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal" -  exp = liftRanked1 exp -  log = liftRanked1 log -  sqrt = liftRanked1 sqrt -  (**) = liftRanked2 (**) -  logBase = liftRanked2 logBase -  sin = liftRanked1 sin -  cos = liftRanked1 cos -  tan = liftRanked1 tan -  asin = liftRanked1 asin -  acos = liftRanked1 acos -  atan = liftRanked1 atan -  sinh = liftRanked1 sinh -  cosh = liftRanked1 cosh -  tanh = liftRanked1 tanh -  asinh = liftRanked1 asinh -  acosh = liftRanked1 acosh -  atanh = liftRanked1 atanh -  log1p = liftRanked1 GHC.Float.log1p -  expm1 = liftRanked1 GHC.Float.expm1 -  log1pexp = liftRanked1 GHC.Float.log1pexp -  log1mexp = liftRanked1 GHC.Float.log1mexp - -rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a -rquotArray = liftRanked2 mquotArray -rremArray = liftRanked2 mremArray - -ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a -ratan2Array = liftRanked2 matan2Array - - -remptyArray :: KnownElt a => Ranked 1 a -remptyArray = mtoRanked (memptyArray ZSX) - -rshape :: Elt a => Ranked n a -> IShR n -rshape (Ranked arr) = shCvtXR' (mshape arr) - -rrank :: Elt a => Ranked n a -> SNat n -rrank = shrRank . rshape - --- | The total number of elements in the array. -rsize :: Elt a => Ranked n a -> Int -rsize = shrSize . rshape - -rindex :: Elt a => Ranked n a -> IIxR n -> a -rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) - -rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a -rindexPartial (Ranked arr) idx = -  Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) -            (castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @Nothing))) arr) -            (ixCvtRX idx)) - --- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. -rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a -rgenerate sh f -  | sn@SNat <- shrRank sh -  , Dict <- lemKnownReplicate sn -  , Refl <- lemRankReplicate sn -  = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) - --- | See the documentation of 'mlift'. -rlift :: forall n1 n2 a. Elt a -      => SNat n2 -      -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) -      -> Ranked n1 a -> Ranked n2 a -rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr) - --- | See the documentation of 'mlift2'. -rlift2 :: forall n1 n2 n3 a. Elt a -       => SNat n3 -       -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b) -       -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a -rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) - -rsumOuter1P :: forall n a. -               (Storable a, NumElt a) -            => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) -rsumOuter1P (Ranked arr) -  | Refl <- lemReplicateSucc @(Nothing @Nat) @n -  = Ranked (msumOuter1P arr) - -rsumOuter1 :: forall n a. (NumElt a, PrimElt a) -           => Ranked (n + 1) a -> Ranked n a -rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive - -rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a -rsumAllPrim (Ranked arr) = msumAllPrim arr - -rtranspose :: forall n a. Elt a => PermR -> Ranked n a -> Ranked n a -rtranspose perm arr -  | sn@SNat <- rrank arr -  , Dict <- lemKnownReplicate sn -  , length perm <= fromIntegral (natVal (Proxy @n)) -  = rlift sn -          (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm) -          arr -  | otherwise -  = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array" - -rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a -rconcat -  | Refl <- lemReplicateSucc @(Nothing @Nat) @n -  = coerce mconcat - -rappend :: forall n a. Elt a -        => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a -rappend arr1 arr2 -  | sn@SNat <- rrank arr1 -  , Dict <- lemKnownReplicate sn -  , Refl <- lemReplicateSucc @(Nothing @Nat) @n -  = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) -      arr1 arr2 - -rscalar :: Elt a => a -> Ranked 0 a -rscalar x = Ranked (mscalar x) - -rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a) -rfromVectorP sh v -  | Dict <- lemKnownReplicate (shrRank sh) -  = Ranked (mfromVectorP (shCvtRX sh) v) - -rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a -rfromVector sh v = rfromPrimitive (rfromVectorP sh v) - -rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a -rtoVectorP = coerce mtoVectorP - -rtoVector :: PrimElt a => Ranked n a -> VS.Vector a -rtoVector = coerce mtoVector - -rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a -rfromListOuter l -  | Refl <- lemReplicateSucc @(Nothing @Nat) @n -  = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) - -rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a -rfromList1 l = Ranked (mfromList1 l) - -rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a -rfromList1Prim l = Ranked (mfromList1Prim l) - -rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] -rtoListOuter (Ranked arr) -  | Refl <- lemReplicateSucc @(Nothing @Nat) @n -  = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) - -rtoList1 :: Elt a => Ranked 1 a -> [a] -rtoList1 = map runScalar . rtoListOuter - -rfromListPrim :: PrimElt a => [a] -> Ranked 1 a -rfromListPrim l = -  let ssh = SUnknown () :!% ZKX -      xarr = X.fromList1 ssh l -  in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a -rfromListPrimLinear sh l = -  let M_Primitive _ xarr = toPrimitive (mfromListPrim l) -  in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr) - -rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a -rfromListLinear sh l = rreshape sh (rfromList1 l) - -rtoListLinear :: Elt a => Ranked n a -> [a] -rtoListLinear (Ranked arr) = mtoListLinear arr - -rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a -rfromOrthotope sn arr -  | Refl <- lemRankReplicate sn -  = let xarr = XArray arr -    in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr)) - -rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a -rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr))) -  | Refl <- lemRankReplicate (shrRank $ shCvtXR' sh) -  = arr - -runScalar :: Elt a => Ranked 0 a -> a -runScalar arr = rindex arr ZIR - -rnest :: forall n m a. Elt a => SNat n -> Ranked (n + m) a -> Ranked n (Ranked m a) -rnest n arr -  | Refl <- lemReplicatePlusApp n (Proxy @m) (Proxy @(Nothing @Nat)) -  = coerce (mnest (ssxFromSNat n) (coerce arr)) - -runNest :: forall n m a. Elt a => Ranked n (Ranked m a) -> Ranked (n + m) a -runNest rarr@(Ranked (M_Ranked (M_Nest _ arr))) -  | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat)) -  = Ranked arr - -rzip :: Ranked n a -> Ranked n b -> Ranked n (a, b) -rzip = coerce mzip - -runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b) -runzip = coerce munzip - -rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) -         => SNat n -> IShR n2 -         -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) -         -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b) -rrerankP sn sh2 f (Ranked arr) -  | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) -  , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) -  = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2) -                     (\a -> let Ranked r = f (Ranked a) in r) -                     arr) - --- | If there is a zero-sized dimension in the @n@-prefix of the shape of the --- input array, then there is no way to deduce the full shape of the output --- array (more precisely, the @n2@ part): that could only come from calling --- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in --- this case; we choose to fill the @n2@ part of the output shape with zeros. --- --- For example, if: --- --- @ --- arr :: Ranked 5 Int   -- of shape [3, 0, 4, 2, 21] --- f :: Ranked 2 Int -> Ranked 3 Float --- @ --- --- then: --- --- @ --- rrerank _ _ _ f arr :: Ranked 5 Float --- @ --- --- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the --- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended --- to return an array with shape all-0 here (it probably didn't), but there is --- no better number to put here absent a subarray of the input to pass to @f@. -rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) -        => SNat n -> IShR n2 -        -> (Ranked n1 a -> Ranked n2 b) -        -> Ranked (n + n1) a -> Ranked (n + n2) b -rrerank sn sh2 f (rtoPrimitive -> arr) = -  rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr - -rreplicate :: forall n m a. Elt a -           => IShR n -> Ranked m a -> Ranked (n + m) a -rreplicate sh (Ranked arr) -  | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat)) -  = Ranked (mreplicate (shCvtRX sh) arr) - -rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) -rreplicateScalP sh x -  | Dict <- lemKnownReplicate (shrRank sh) -  = Ranked (mreplicateScalP (shCvtRX sh) x) - -rreplicateScal :: forall n a. PrimElt a -               => IShR n -> a -> Ranked n a -rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) - -rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a -rslice i n arr -  | Refl <- lemReplicateSucc @(Nothing @Nat) @n -  = rlift (rrank arr) -          (\_ -> X.sliceU i n) -          arr - -rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a -rrev1 arr = -  rlift (rrank arr) -        (\(_ :: StaticShX sh') -> -          case lemReplicateSucc @(Nothing @Nat) @n of -            Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) -        arr - -rreshape :: forall n n' a. Elt a -         => IShR n' -> Ranked n a -> Ranked n' a -rreshape sh' rarr@(Ranked arr) -  | Dict <- lemKnownReplicate (rrank rarr) -  , Dict <- lemKnownReplicate (shrRank sh') -  = Ranked (mreshape (shCvtRX sh') arr) - -rflatten :: Elt a => Ranked n a -> Ranked 1 a -rflatten (Ranked arr) = mtoRanked (mflatten arr) - -riota :: (Enum a, PrimElt a) => Int -> Ranked 1 a -riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota - --- | Throws if the array is empty. -rminIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n -rminIndexPrim rarr@(Ranked arr) -  | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) -  = ixCvtXR (mminIndexPrim arr) - --- | Throws if the array is empty. -rmaxIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n -rmaxIndexPrim rarr@(Ranked arr) -  | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) -  = ixCvtXR (mmaxIndexPrim arr) - -rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a -rdot1Inner arr1 arr2 -  | SNat <- rrank arr1 -  , Refl <- lemReplicatePlusApp (SNat @n) (Proxy @1) (Proxy @(Nothing @Nat)) -  = coerce (mdot1Inner (Proxy @(Nothing @Nat))) arr1 arr2 - --- | This has a temporary, suboptimal implementation in terms of 'mflatten'. --- Prefer 'rdot1Inner' if applicable. -rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a -rdot = coerce mdot - -rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) -rtoXArrayPrimP (Ranked arr) = first shCvtXR' (mtoXArrayPrimP arr) - -rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a) -rtoXArrayPrim (Ranked arr) = first shCvtXR' (mtoXArrayPrim arr) - -rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) -rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) - -rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a -rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) - -rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a -rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr) - -rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a) -rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr) - -mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a -mtoRanked arr -  | Refl <- lemRankReplicate (shxRank (mshape arr)) -  = Ranked (mcast (ssxFromShape (convSh (mshape arr))) arr) -  where -    convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing) -    convSh ZSX = ZSX -    convSh (smn :$% (sh :: IShX sh'T)) -      | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T) -      = SUnknown (fromSMayNat' smn) :$% convSh sh - -rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a -rtoMixed (Ranked arr) = arr - --- | A more weakly-typed version of 'rtoMixed' that does a runtime shape --- compatibility check. -rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a -rcastToMixed sshx rarr@(Ranked arr) -  | Refl <- lemRankReplicate (rrank rarr) -  = mcast sshx arr diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs deleted file mode 100644 index 1415815..0000000 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ /dev/null @@ -1,495 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Internal.Shaped where - -import Prelude hiding (mappend, mconcat) - -import Control.DeepSeq (NFData(..)) -import Control.Monad.ST -import Data.Array.Internal.RankedG qualified as RG -import Data.Array.Internal.RankedS qualified as RS -import Data.Array.Internal.ShapedG qualified as SG -import Data.Array.Internal.ShapedS qualified as SS -import Data.Bifunctor (first) -import Data.Coerce (coerce) -import Data.Kind (Type) -import Data.List.NonEmpty (NonEmpty) -import Data.Proxy -import Data.Type.Equality -import Data.Vector.Storable qualified as VS -import Foreign.Storable (Storable) -import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) -import GHC.Generics (Generic) -import GHC.TypeLits - -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray) -import Data.Array.Mixed.XArray qualified as X -import Data.Array.Nested.Internal.Lemmas -import Data.Array.Nested.Internal.Mixed -import Data.Array.Nested.Mixed.Shape -import Data.Array.Nested.Shaped.Shape -import Data.Array.Strided.Arith - - --- | A shape-typed array: the full shape of the array (the sizes of its --- dimensions) is represented on the type level as a list of 'Nat's. Note that --- these are "GHC.TypeLits" naturals, because we do not need induction over --- them and we want very large arrays to be possible. --- --- Like for 'Ranked', the valid elements are described by the 'Elt' type class, --- and 'Shaped' itself is again an instance of 'Elt' as well. --- --- 'Shaped' is a newtype around a 'Mixed' of 'Just's. -type Shaped :: [Nat] -> Type -> Type -newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) -deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) -deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a) - -instance (Show a, Elt a) => Show (Shaped n a) where -  showsPrec d arr@(Shaped marr) = -    let sh = show (shsToList (sshape arr)) -    in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr - -instance Elt a => NFData (Shaped sh a) where -  rnf (Shaped arr) = rnf arr - --- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a)) -  deriving (Generic) - -deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a)) - -newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) - -instance Elt a => Elt (Shaped sh a) where -  mshape (M_Shaped arr) = mshape arr -  mindex (M_Shaped arr) i = Shaped (mindex arr i) - -  mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) -  mindexPartial (M_Shaped arr) i = -    coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ -      mindexPartial arr i - -  mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) - -  mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) -  mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) - -  mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] -  mtoListOuter (M_Shaped arr) -    = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) - -  mlift :: forall sh1 sh2. -           StaticShX sh2 -        -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -        -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -  mlift ssh2 f (M_Shaped arr) = -    coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ -      mlift ssh2 f arr - -  mlift2 :: forall sh1 sh2 sh3. -            StaticShX sh3 -         -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) -         -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) -  mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) = -    coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ -      mlift2 ssh3 f arr1 arr2 - -  mliftL :: forall sh1 sh2. -            StaticShX sh2 -         -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) -         -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a)) -  mliftL ssh2 f l = -    coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a))) -           @(NonEmpty (Mixed sh2 (Shaped sh a))) $ -      mliftL ssh2 f (coerce l) - -  mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr) - -  mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) - -  mconcat l = M_Shaped (mconcat (coerce l)) - -  mrnf (M_Shaped arr) = mrnf arr - -  type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) - -  mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) - -  mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - -  mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - -  mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - -  marrayStrides (M_Shaped arr) = marrayStrides arr - -  mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () -  mvecsWrite sh idx (Shaped arr) vecs = -    mvecsWrite sh idx arr -      (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) -         vecs) - -  mvecsWritePartial :: forall sh1 sh2 s. -                       IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) -                    -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) -                    -> ST s () -  mvecsWritePartial sh idx arr vecs = -    mvecsWritePartial sh idx -      (coerce @(Mixed sh2 (Shaped sh a)) -              @(Mixed sh2 (Mixed (MapJust sh) a)) -         arr) -      (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) -              @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) -         vecs) - -  mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) -  mvecsFreeze sh vecs = -    coerce @(Mixed sh' (Mixed (MapJust sh) a)) -           @(Mixed sh' (Shaped sh a)) -      <$> mvecsFreeze sh -            (coerce @(MixedVecs s sh' (Shaped sh a)) -                    @(MixedVecs s sh' (Mixed (MapJust sh) a)) -                    vecs) - -instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where -  memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) -  memptyArrayUnsafe i -    | Dict <- lemKnownMapJust (Proxy @sh) -    = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ -        memptyArrayUnsafe i - -  mvecsUnsafeNew idx (Shaped arr) -    | Dict <- lemKnownMapJust (Proxy @sh) -    = MV_Shaped <$> mvecsUnsafeNew idx arr - -  mvecsNewEmpty _ -    | Dict <- lemKnownMapJust (Proxy @sh) -    = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) - - -liftShaped1 :: forall sh a b. -               (Mixed (MapJust sh) a -> Mixed (MapJust sh) b) -            -> Shaped sh a -> Shaped sh b -liftShaped1 = coerce - -liftShaped2 :: forall sh a b c. -               (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c) -            -> Shaped sh a -> Shaped sh b -> Shaped sh c -liftShaped2 = coerce - -instance (NumElt a, PrimElt a) => Num (Shaped sh a) where -  (+) = liftShaped2 (+) -  (-) = liftShaped2 (-) -  (*) = liftShaped2 (*) -  negate = liftShaped1 negate -  abs = liftShaped1 abs -  signum = liftShaped1 signum -  fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal" - -instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where -  fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" -  recip = liftShaped1 recip -  (/) = liftShaped2 (/) - -instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where -  pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" -  exp = liftShaped1 exp -  log = liftShaped1 log -  sqrt = liftShaped1 sqrt -  (**) = liftShaped2 (**) -  logBase = liftShaped2 logBase -  sin = liftShaped1 sin -  cos = liftShaped1 cos -  tan = liftShaped1 tan -  asin = liftShaped1 asin -  acos = liftShaped1 acos -  atan = liftShaped1 atan -  sinh = liftShaped1 sinh -  cosh = liftShaped1 cosh -  tanh = liftShaped1 tanh -  asinh = liftShaped1 asinh -  acosh = liftShaped1 acosh -  atanh = liftShaped1 atanh -  log1p = liftShaped1 GHC.Float.log1p -  expm1 = liftShaped1 GHC.Float.expm1 -  log1pexp = liftShaped1 GHC.Float.log1pexp -  log1mexp = liftShaped1 GHC.Float.log1mexp - -squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a -squotArray = liftShaped2 mquotArray -sremArray = liftShaped2 mremArray - -satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a -satan2Array = liftShaped2 matan2Array - - -semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a -semptyArray sh = Shaped (memptyArray (shCvtSX sh)) - -sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh -sshape (Shaped arr) = shCvtXS' (mshape arr) - -srank :: Elt a => Shaped sh a -> SNat (Rank sh) -srank = shsRank . sshape - --- | The total number of elements in the array. -ssize :: Elt a => Shaped sh a -> Int -ssize = shsSize . sshape - -sindex :: Elt a => Shaped sh a -> IIxS sh -> a -sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) - -shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh -shsTakeIx _ _ ZIS = ZSS -shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx - -sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a -sindexPartial sarr@(Shaped arr) idx = -  Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) -            (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) -            (ixCvtSX idx)) - --- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. -sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a -sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh)) - --- | See the documentation of 'mlift'. -slift :: forall sh1 sh2 a. Elt a -      => ShS sh2 -      -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) -      -> Shaped sh1 a -> Shaped sh2 a -slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr) - --- | See the documentation of 'mlift'. -slift2 :: forall sh1 sh2 sh3 a. Elt a -       => ShS sh3 -       -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) -       -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a -slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2) - -ssumOuter1P :: forall sh n a. (Storable a, NumElt a) -            => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) -ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) - -ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) -           => Shaped (n : sh) a -> Shaped sh a -ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive - -ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a -ssumAllPrim (Shaped arr) = msumAllPrim arr - -stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a) -           => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a -stranspose perm sarr@(Shaped arr) -  | Refl <- lemRankMapJust (sshape sarr) -  , Refl <- lemTakeLenMapJust perm (sshape sarr) -  , Refl <- lemDropLenMapJust perm (sshape sarr) -  , Refl <- lemPermuteMapJust perm (shsTakeLen perm (sshape sarr)) -  , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh)) -  = Shaped (mtranspose perm arr) - -sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a -sappend = coerce mappend - -sscalar :: Elt a => a -> Shaped '[] a -sscalar x = Shaped (mscalar x) - -sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) -sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v) - -sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a -sfromVector sh v = sfromPrimitive (sfromVectorP sh v) - -stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a -stoVectorP = coerce mtoVectorP - -stoVector :: PrimElt a => Shaped sh a -> VS.Vector a -stoVector = coerce mtoVector - -sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l)) - -sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a -sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1 - -sfromList1Prim :: PrimElt a => SNat n -> [a] -> Shaped '[n] a -sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim - -stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] -stoListOuter (Shaped arr) = coerce (mtoListOuter arr) - -stoList1 :: Elt a => Shaped '[n] a -> [a] -stoList1 = map sunScalar . stoListOuter - -sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a -sfromListPrim sn l -  | Refl <- lemAppNil @'[Just n] -  = let ssh = SUnknown () :!% ZKX -        xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) -    in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr - -sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a -sfromListPrimLinear sh l = -  let M_Primitive _ xarr = toPrimitive (mfromListPrim l) -  in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr) - -sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a -sfromListLinear sh l = Shaped (mfromListLinear (shCvtSX sh) l) - -stoListLinear :: Elt a => Shaped sh a -> [a] -stoListLinear (Shaped arr) = mtoListLinear arr - -sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a -sfromOrthotope sh (SS.A (SG.A arr)) = -  Shaped (fromPrimitive (M_Primitive (shCvtSX sh) (X.XArray (RS.A (RG.A (shsToList sh) arr))))) - -stoOrthotope :: PrimElt a => Shaped sh a -> SS.Array sh a -stoOrthotope (stoPrimitive -> Shaped (M_Primitive _ (X.XArray (RS.A (RG.A _ arr))))) = SS.A (SG.A arr) - -sunScalar :: Elt a => Shaped '[] a -> a -sunScalar arr = sindex arr ZIS - -snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a) -snest sh arr -  | Refl <- lemMapJustApp sh (Proxy @sh') -  = coerce (mnest (ssxFromShape (shCvtSX sh)) (coerce arr)) - -sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a -sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr))) -  | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh') -  = Shaped arr - -szip :: Shaped sh a -> Shaped sh b -> Shaped sh (a, b) -szip = coerce mzip - -sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b) -sunzip = coerce munzip - -srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) -         => ShS sh -> ShS sh2 -         -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) -         -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b) -srerankP sh sh2 f sarr@(Shaped arr) -  | Refl <- lemMapJustApp sh (Proxy @sh1) -  , Refl <- lemMapJustApp sh (Proxy @sh2) -  = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh)))) -                     (shCvtSX sh2) -                     (\a -> let Shaped r = f (Shaped a) in r) -                     arr) - -srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) -        => ShS sh -> ShS sh2 -        -> (Shaped sh1 a -> Shaped sh2 b) -        -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b -srerank sh sh2 f (stoPrimitive -> arr) = -  sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr - -sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a -sreplicate sh (Shaped arr) -  | Refl <- lemMapJustApp sh (Proxy @sh') -  = Shaped (mreplicate (shCvtSX sh) arr) - -sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) -sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x) - -sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a -sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) - -sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a -sslice i n@SNat arr = -  let _ :$$ sh = sshape arr -  in slift (n :$$ sh) (\_ -> X.slice i n) arr - -srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a -srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr - -sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shaped sh' a -sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr) - -sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a -sflatten arr = -  case shsProduct (sshape arr) of  -- TODO: simplify when removing the KnownNat stuff -    n@SNat -> sreshape (n :$$ ZSS) arr - -siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a -siota sn = Shaped (miota sn) - --- | Throws if the array is empty. -sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminIndexPrim arr) - --- | Throws if the array is empty. -smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) - -sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) -           => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a -sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) -  | Refl <- lemInitApp (Proxy @sh) (Proxy @n) -  , Refl <- lemLastApp (Proxy @sh) (Proxy @n) -  = case sshape sarr1 of -      _ :$$ _ -        | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n]) -        -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2) -      _ -> error "unreachable" - --- | This has a temporary, suboptimal implementation in terms of 'mflatten'. --- Prefer 'sdot1Inner' if applicable. -sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a -sdot = coerce mdot - -stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) -stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr) - -stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) -stoXArrayPrim (Shaped arr) = first shCvtXS' (mtoXArrayPrim arr) - -sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) -sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr) - -sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a -sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr) - -sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a -sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) - -stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a) -stoPrimitive (Shaped arr) = Shaped (toPrimitive arr) - -mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') -              => Mixed sh a -> ShS sh' -> Shaped sh' a -mcastToShaped arr targetsh -  | Refl <- lemRankMapJust targetsh -  = Shaped (mcast (ssxFromShape (shCvtSX targetsh)) arr) - -stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a -stoMixed (Shaped arr) = arr - --- | A more weakly-typed version of 'stoMixed' that does a runtime shape --- compatibility check. -scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') -             => StaticShX sh' -> Shaped sh a -> Mixed sh' a -scastToMixed sshx sarr@(Shaped arr) -  | Refl <- lemRankMapJust (sshape sarr) -  = mcast sshx arr | 
