diff options
Diffstat (limited to 'src/Data/Array/Nested/Internal/Shaped.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 495 |
1 files changed, 0 insertions, 495 deletions
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs deleted file mode 100644 index 372439f..0000000 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ /dev/null @@ -1,495 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Internal.Shaped where - -import Prelude hiding (mappend, mconcat) - -import Control.DeepSeq (NFData(..)) -import Control.Monad.ST -import Data.Array.Internal.RankedG qualified as RG -import Data.Array.Internal.RankedS qualified as RS -import Data.Array.Internal.ShapedG qualified as SG -import Data.Array.Internal.ShapedS qualified as SS -import Data.Bifunctor (first) -import Data.Coerce (coerce) -import Data.Kind (Type) -import Data.List.NonEmpty (NonEmpty) -import Data.Proxy -import Data.Type.Equality -import Data.Vector.Storable qualified as VS -import Foreign.Storable (Storable) -import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) -import GHC.Generics (Generic) -import GHC.TypeLits - -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray) -import Data.Array.Mixed.XArray qualified as X -import Data.Array.Nested.Internal.Lemmas -import Data.Array.Nested.Internal.Mixed -import Data.Array.Nested.Internal.Shape -import Data.Array.Strided.Arith - - --- | A shape-typed array: the full shape of the array (the sizes of its --- dimensions) is represented on the type level as a list of 'Nat's. Note that --- these are "GHC.TypeLits" naturals, because we do not need induction over --- them and we want very large arrays to be possible. --- --- Like for 'Ranked', the valid elements are described by the 'Elt' type class, --- and 'Shaped' itself is again an instance of 'Elt' as well. --- --- 'Shaped' is a newtype around a 'Mixed' of 'Just's. -type Shaped :: [Nat] -> Type -> Type -newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) -deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) -deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a) - -instance (Show a, Elt a) => Show (Shaped n a) where - showsPrec d arr@(Shaped marr) = - let sh = show (shsToList (sshape arr)) - in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr - -instance Elt a => NFData (Shaped sh a) where - rnf (Shaped arr) = rnf arr - --- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a)) - deriving (Generic) - -deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a)) - -newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) - -instance Elt a => Elt (Shaped sh a) where - mshape (M_Shaped arr) = mshape arr - mindex (M_Shaped arr) i = Shaped (mindex arr i) - - mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - mindexPartial (M_Shaped arr) i = - coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mindexPartial arr i - - mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) - - mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) - mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) - - mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] - mtoListOuter (M_Shaped arr) - = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) - mlift ssh2 f (M_Shaped arr) = - coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mlift ssh2 f arr - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) - mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) = - coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ - mlift2 ssh3 f arr1 arr2 - - mliftL :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) - -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a)) - mliftL ssh2 f l = - coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a))) - @(NonEmpty (Mixed sh2 (Shaped sh a))) $ - mliftL ssh2 f (coerce l) - - mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr) - - mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) - - mconcat l = M_Shaped (mconcat (coerce l)) - - mrnf (M_Shaped arr) = mrnf arr - - type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) - - mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - marrayStrides (M_Shaped arr) = marrayStrides arr - - mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () - mvecsWrite sh idx (Shaped arr) vecs = - mvecsWrite sh idx arr - (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx - (coerce @(Mixed sh2 (Shaped sh a)) - @(Mixed sh2 (Mixed (MapJust sh) a)) - arr) - (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) - @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) - vecs) - - mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) - mvecsFreeze sh vecs = - coerce @(Mixed sh' (Mixed (MapJust sh) a)) - @(Mixed sh' (Shaped sh a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh' (Shaped sh a)) - @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - -instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where - memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) - memptyArrayUnsafe i - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ - memptyArrayUnsafe i - - mvecsUnsafeNew idx (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsUnsafeNew idx arr - - mvecsNewEmpty _ - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) - - -liftShaped1 :: forall sh a b. - (Mixed (MapJust sh) a -> Mixed (MapJust sh) b) - -> Shaped sh a -> Shaped sh b -liftShaped1 = coerce - -liftShaped2 :: forall sh a b c. - (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c) - -> Shaped sh a -> Shaped sh b -> Shaped sh c -liftShaped2 = coerce - -instance (NumElt a, PrimElt a) => Num (Shaped sh a) where - (+) = liftShaped2 (+) - (-) = liftShaped2 (-) - (*) = liftShaped2 (*) - negate = liftShaped1 negate - abs = liftShaped1 abs - signum = liftShaped1 signum - fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal" - -instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where - fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" - recip = liftShaped1 recip - (/) = liftShaped2 (/) - -instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" - exp = liftShaped1 exp - log = liftShaped1 log - sqrt = liftShaped1 sqrt - (**) = liftShaped2 (**) - logBase = liftShaped2 logBase - sin = liftShaped1 sin - cos = liftShaped1 cos - tan = liftShaped1 tan - asin = liftShaped1 asin - acos = liftShaped1 acos - atan = liftShaped1 atan - sinh = liftShaped1 sinh - cosh = liftShaped1 cosh - tanh = liftShaped1 tanh - asinh = liftShaped1 asinh - acosh = liftShaped1 acosh - atanh = liftShaped1 atanh - log1p = liftShaped1 GHC.Float.log1p - expm1 = liftShaped1 GHC.Float.expm1 - log1pexp = liftShaped1 GHC.Float.log1pexp - log1mexp = liftShaped1 GHC.Float.log1mexp - -squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a -squotArray = liftShaped2 mquotArray -sremArray = liftShaped2 mremArray - -satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a -satan2Array = liftShaped2 matan2Array - - -semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a -semptyArray sh = Shaped (memptyArray (shCvtSX sh)) - -sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh -sshape (Shaped arr) = shCvtXS' (mshape arr) - -srank :: Elt a => Shaped sh a -> SNat (Rank sh) -srank = shsRank . sshape - --- | The total number of elements in the array. -ssize :: Elt a => Shaped sh a -> Int -ssize = shsSize . sshape - -sindex :: Elt a => Shaped sh a -> IIxS sh -> a -sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) - -shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh -shsTakeIx _ _ ZIS = ZSS -shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx - -sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a -sindexPartial sarr@(Shaped arr) idx = - Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) - (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) - (ixCvtSX idx)) - --- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. -sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a -sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh)) - --- | See the documentation of 'mlift'. -slift :: forall sh1 sh2 a. Elt a - => ShS sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) - -> Shaped sh1 a -> Shaped sh2 a -slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr) - --- | See the documentation of 'mlift'. -slift2 :: forall sh1 sh2 sh3 a. Elt a - => ShS sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) - -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a -slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2) - -ssumOuter1P :: forall sh n a. (Storable a, NumElt a) - => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) -ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) - -ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) - => Shaped (n : sh) a -> Shaped sh a -ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive - -ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a -ssumAllPrim (Shaped arr) = msumAllPrim arr - -stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a) - => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a -stranspose perm sarr@(Shaped arr) - | Refl <- lemRankMapJust (sshape sarr) - , Refl <- lemTakeLenMapJust perm (sshape sarr) - , Refl <- lemDropLenMapJust perm (sshape sarr) - , Refl <- lemPermuteMapJust perm (shsTakeLen perm (sshape sarr)) - , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh)) - = Shaped (mtranspose perm arr) - -sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a -sappend = coerce mappend - -sscalar :: Elt a => a -> Shaped '[] a -sscalar x = Shaped (mscalar x) - -sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) -sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v) - -sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a -sfromVector sh v = sfromPrimitive (sfromVectorP sh v) - -stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a -stoVectorP = coerce mtoVectorP - -stoVector :: PrimElt a => Shaped sh a -> VS.Vector a -stoVector = coerce mtoVector - -sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l)) - -sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a -sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1 - -sfromList1Prim :: PrimElt a => SNat n -> [a] -> Shaped '[n] a -sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim - -stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] -stoListOuter (Shaped arr) = coerce (mtoListOuter arr) - -stoList1 :: Elt a => Shaped '[n] a -> [a] -stoList1 = map sunScalar . stoListOuter - -sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a -sfromListPrim sn l - | Refl <- lemAppNil @'[Just n] - = let ssh = SUnknown () :!% ZKX - xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) - in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr - -sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a -sfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr) - -sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a -sfromListLinear sh l = Shaped (mfromListLinear (shCvtSX sh) l) - -stoListLinear :: Elt a => Shaped sh a -> [a] -stoListLinear (Shaped arr) = mtoListLinear arr - -sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a -sfromOrthotope sh (SS.A (SG.A arr)) = - Shaped (fromPrimitive (M_Primitive (shCvtSX sh) (X.XArray (RS.A (RG.A (shsToList sh) arr))))) - -stoOrthotope :: PrimElt a => Shaped sh a -> SS.Array sh a -stoOrthotope (stoPrimitive -> Shaped (M_Primitive _ (X.XArray (RS.A (RG.A _ arr))))) = SS.A (SG.A arr) - -sunScalar :: Elt a => Shaped '[] a -> a -sunScalar arr = sindex arr ZIS - -snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a) -snest sh arr - | Refl <- lemMapJustApp sh (Proxy @sh') - = coerce (mnest (ssxFromShape (shCvtSX sh)) (coerce arr)) - -sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a -sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr))) - | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh') - = Shaped arr - -szip :: Shaped sh a -> Shaped sh b -> Shaped sh (a, b) -szip = coerce mzip - -sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b) -sunzip = coerce munzip - -srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) - => ShS sh -> ShS sh2 - -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) - -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b) -srerankP sh sh2 f sarr@(Shaped arr) - | Refl <- lemMapJustApp sh (Proxy @sh1) - , Refl <- lemMapJustApp sh (Proxy @sh2) - = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh)))) - (shCvtSX sh2) - (\a -> let Shaped r = f (Shaped a) in r) - arr) - -srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) - => ShS sh -> ShS sh2 - -> (Shaped sh1 a -> Shaped sh2 b) - -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b -srerank sh sh2 f (stoPrimitive -> arr) = - sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr - -sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a -sreplicate sh (Shaped arr) - | Refl <- lemMapJustApp sh (Proxy @sh') - = Shaped (mreplicate (shCvtSX sh) arr) - -sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) -sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x) - -sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a -sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) - -sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a -sslice i n@SNat arr = - let _ :$$ sh = sshape arr - in slift (n :$$ sh) (\_ -> X.slice i n) arr - -srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a -srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr - -sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shaped sh' a -sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr) - -sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a -sflatten arr = - case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff - n@SNat -> sreshape (n :$$ ZSS) arr - -siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a -siota sn = Shaped (miota sn) - --- | Throws if the array is empty. -sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminIndexPrim arr) - --- | Throws if the array is empty. -smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) - -sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) - => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a -sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) - | Refl <- lemInitApp (Proxy @sh) (Proxy @n) - , Refl <- lemLastApp (Proxy @sh) (Proxy @n) - = case sshape sarr1 of - _ :$$ _ - | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n]) - -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2) - _ -> error "unreachable" - --- | This has a temporary, suboptimal implementation in terms of 'mflatten'. --- Prefer 'sdot1Inner' if applicable. -sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a -sdot = coerce mdot - -stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) -stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr) - -stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) -stoXArrayPrim (Shaped arr) = first shCvtXS' (mtoXArrayPrim arr) - -sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) -sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr) - -sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a -sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr) - -sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a -sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) - -stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a) -stoPrimitive (Shaped arr) = Shaped (toPrimitive arr) - -mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') - => Mixed sh a -> ShS sh' -> Shaped sh' a -mcastToShaped arr targetsh - | Refl <- lemRankMapJust targetsh - = Shaped (mcast (ssxFromShape (shCvtSX targetsh)) arr) - -stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a -stoMixed (Shaped arr) = arr - --- | A more weakly-typed version of 'stoMixed' that does a runtime shape --- compatibility check. -scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') - => StaticShX sh' -> Shaped sh a -> Mixed sh' a -scastToMixed sshx sarr@(Shaped arr) - | Refl <- lemRankMapJust (sshape sarr) - = mcast sshx arr |