diff options
Diffstat (limited to 'src/Data/Array/Nested/Shaped.hs')
-rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 495 |
1 files changed, 495 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs new file mode 100644 index 0000000..4bccbc4 --- /dev/null +++ b/src/Data/Array/Nested/Shaped.hs @@ -0,0 +1,495 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Shaped where + +import Prelude hiding (mappend, mconcat) + +import Control.DeepSeq (NFData(..)) +import Control.Monad.ST +import Data.Array.Internal.RankedG qualified as RG +import Data.Array.Internal.RankedS qualified as RS +import Data.Array.Internal.ShapedG qualified as SG +import Data.Array.Internal.ShapedS qualified as SS +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Foreign.Storable (Storable) +import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Types +import Data.Array.XArray (XArray) +import Data.Array.XArray qualified as X +import Data.Array.Nested.Internal.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Shaped.Shape +import Data.Array.Strided.Arith + + +-- | A shape-typed array: the full shape of the array (the sizes of its +-- dimensions) is represented on the type level as a list of 'Nat's. Note that +-- these are "GHC.TypeLits" naturals, because we do not need induction over +-- them and we want very large arrays to be possible. +-- +-- Like for 'Ranked', the valid elements are described by the 'Elt' type class, +-- and 'Shaped' itself is again an instance of 'Elt' as well. +-- +-- 'Shaped' is a newtype around a 'Mixed' of 'Just's. +type Shaped :: [Nat] -> Type -> Type +newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) +deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) +deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a) + +instance (Show a, Elt a) => Show (Shaped n a) where + showsPrec d arr@(Shaped marr) = + let sh = show (shsToList (sshape arr)) + in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr + +instance Elt a => NFData (Shaped sh a) where + rnf (Shaped arr) = rnf arr + +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a)) + deriving (Generic) + +deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a)) + +newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) + +instance Elt a => Elt (Shaped sh a) where + mshape (M_Shaped arr) = mshape arr + mindex (M_Shaped arr) i = Shaped (mindex arr i) + + mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + mindexPartial (M_Shaped arr) i = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mindexPartial arr i + + mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) + + mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) + mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) + + mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] + mtoListOuter (M_Shaped arr) + = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) + mlift ssh2 f (M_Shaped arr) = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mlift ssh2 f arr + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) + mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) = + coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ + mlift2 ssh3 f arr1 arr2 + + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a)) + mliftL ssh2 f l = + coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a))) + @(NonEmpty (Mixed sh2 (Shaped sh a))) $ + mliftL ssh2 f (coerce l) + + mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr) + + mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) + + mconcat l = M_Shaped (mconcat (coerce l)) + + mrnf (M_Shaped arr) = mrnf arr + + type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) + + mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + marrayStrides (M_Shaped arr) = marrayStrides arr + + mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWrite sh idx (Shaped arr) vecs = + mvecsWrite sh idx arr + (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + + mvecsWritePartial :: forall sh1 sh2 s. + IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) + -> ST s () + mvecsWritePartial sh idx arr vecs = + mvecsWritePartial sh idx + (coerce @(Mixed sh2 (Shaped sh a)) + @(Mixed sh2 (Mixed (MapJust sh) a)) + arr) + (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) + @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) + vecs) + + mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsFreeze sh vecs = + coerce @(Mixed sh' (Mixed (MapJust sh) a)) + @(Mixed sh' (Shaped sh a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh' (Shaped sh a)) + @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + +instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where + memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) + memptyArrayUnsafe i + | Dict <- lemKnownMapJust (Proxy @sh) + = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ + memptyArrayUnsafe i + + mvecsUnsafeNew idx (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsUnsafeNew idx arr + + mvecsNewEmpty _ + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) + + +liftShaped1 :: forall sh a b. + (Mixed (MapJust sh) a -> Mixed (MapJust sh) b) + -> Shaped sh a -> Shaped sh b +liftShaped1 = coerce + +liftShaped2 :: forall sh a b c. + (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c) + -> Shaped sh a -> Shaped sh b -> Shaped sh c +liftShaped2 = coerce + +instance (NumElt a, PrimElt a) => Num (Shaped sh a) where + (+) = liftShaped2 (+) + (-) = liftShaped2 (-) + (*) = liftShaped2 (*) + negate = liftShaped1 negate + abs = liftShaped1 abs + signum = liftShaped1 signum + fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal" + +instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where + fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" + recip = liftShaped1 recip + (/) = liftShaped2 (/) + +instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where + pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" + exp = liftShaped1 exp + log = liftShaped1 log + sqrt = liftShaped1 sqrt + (**) = liftShaped2 (**) + logBase = liftShaped2 logBase + sin = liftShaped1 sin + cos = liftShaped1 cos + tan = liftShaped1 tan + asin = liftShaped1 asin + acos = liftShaped1 acos + atan = liftShaped1 atan + sinh = liftShaped1 sinh + cosh = liftShaped1 cosh + tanh = liftShaped1 tanh + asinh = liftShaped1 asinh + acosh = liftShaped1 acosh + atanh = liftShaped1 atanh + log1p = liftShaped1 GHC.Float.log1p + expm1 = liftShaped1 GHC.Float.expm1 + log1pexp = liftShaped1 GHC.Float.log1pexp + log1mexp = liftShaped1 GHC.Float.log1mexp + +squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a +squotArray = liftShaped2 mquotArray +sremArray = liftShaped2 mremArray + +satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a +satan2Array = liftShaped2 matan2Array + + +semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a +semptyArray sh = Shaped (memptyArray (shCvtSX sh)) + +sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh +sshape (Shaped arr) = shCvtXS' (mshape arr) + +srank :: Elt a => Shaped sh a -> SNat (Rank sh) +srank = shsRank . sshape + +-- | The total number of elements in the array. +ssize :: Elt a => Shaped sh a -> Int +ssize = shsSize . sshape + +sindex :: Elt a => Shaped sh a -> IIxS sh -> a +sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) + +shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh +shsTakeIx _ _ ZIS = ZSS +shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx + +sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a +sindexPartial sarr@(Shaped arr) idx = + Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) + (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) + (ixCvtSX idx)) + +-- | __WARNING__: All values returned from the function must have equal shape. +-- See the documentation of 'mgenerate' for more details. +sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a +sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh)) + +-- | See the documentation of 'mlift'. +slift :: forall sh1 sh2 a. Elt a + => ShS sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) + -> Shaped sh1 a -> Shaped sh2 a +slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr) + +-- | See the documentation of 'mlift'. +slift2 :: forall sh1 sh2 sh3 a. Elt a + => ShS sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) + -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a +slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2) + +ssumOuter1P :: forall sh n a. (Storable a, NumElt a) + => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) +ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) + +ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) + => Shaped (n : sh) a -> Shaped sh a +ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive + +ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a +ssumAllPrim (Shaped arr) = msumAllPrim arr + +stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a) + => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a +stranspose perm sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + , Refl <- lemTakeLenMapJust perm (sshape sarr) + , Refl <- lemDropLenMapJust perm (sshape sarr) + , Refl <- lemPermuteMapJust perm (shsTakeLen perm (sshape sarr)) + , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh)) + = Shaped (mtranspose perm arr) + +sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a +sappend = coerce mappend + +sscalar :: Elt a => a -> Shaped '[] a +sscalar x = Shaped (mscalar x) + +sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) +sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v) + +sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a +sfromVector sh v = sfromPrimitive (sfromVectorP sh v) + +stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a +stoVectorP = coerce mtoVectorP + +stoVector :: PrimElt a => Shaped sh a -> VS.Vector a +stoVector = coerce mtoVector + +sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a +sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l)) + +sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a +sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1 + +sfromList1Prim :: PrimElt a => SNat n -> [a] -> Shaped '[n] a +sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim + +stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] +stoListOuter (Shaped arr) = coerce (mtoListOuter arr) + +stoList1 :: Elt a => Shaped '[n] a -> [a] +stoList1 = map sunScalar . stoListOuter + +sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a +sfromListPrim sn l + | Refl <- lemAppNil @'[Just n] + = let ssh = SUnknown () :!% ZKX + xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) + in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr + +sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a +sfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr) + +sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a +sfromListLinear sh l = Shaped (mfromListLinear (shCvtSX sh) l) + +stoListLinear :: Elt a => Shaped sh a -> [a] +stoListLinear (Shaped arr) = mtoListLinear arr + +sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a +sfromOrthotope sh (SS.A (SG.A arr)) = + Shaped (fromPrimitive (M_Primitive (shCvtSX sh) (X.XArray (RS.A (RG.A (shsToList sh) arr))))) + +stoOrthotope :: PrimElt a => Shaped sh a -> SS.Array sh a +stoOrthotope (stoPrimitive -> Shaped (M_Primitive _ (X.XArray (RS.A (RG.A _ arr))))) = SS.A (SG.A arr) + +sunScalar :: Elt a => Shaped '[] a -> a +sunScalar arr = sindex arr ZIS + +snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a) +snest sh arr + | Refl <- lemMapJustApp sh (Proxy @sh') + = coerce (mnest (ssxFromShape (shCvtSX sh)) (coerce arr)) + +sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a +sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr))) + | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh') + = Shaped arr + +szip :: Shaped sh a -> Shaped sh b -> Shaped sh (a, b) +szip = coerce mzip + +sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b) +sunzip = coerce munzip + +srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) + => ShS sh -> ShS sh2 + -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) + -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b) +srerankP sh sh2 f sarr@(Shaped arr) + | Refl <- lemMapJustApp sh (Proxy @sh1) + , Refl <- lemMapJustApp sh (Proxy @sh2) + = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh)))) + (shCvtSX sh2) + (\a -> let Shaped r = f (Shaped a) in r) + arr) + +srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) + => ShS sh -> ShS sh2 + -> (Shaped sh1 a -> Shaped sh2 b) + -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b +srerank sh sh2 f (stoPrimitive -> arr) = + sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr + +sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a +sreplicate sh (Shaped arr) + | Refl <- lemMapJustApp sh (Proxy @sh') + = Shaped (mreplicate (shCvtSX sh) arr) + +sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) +sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x) + +sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a +sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) + +sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a +sslice i n@SNat arr = + let _ :$$ sh = sshape arr + in slift (n :$$ sh) (\_ -> X.slice i n) arr + +srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a +srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr + +sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shaped sh' a +sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr) + +sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a +sflatten arr = + case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff + n@SNat -> sreshape (n :$$ ZSS) arr + +siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a +siota sn = Shaped (miota sn) + +-- | Throws if the array is empty. +sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh +sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminIndexPrim arr) + +-- | Throws if the array is empty. +smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh +smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) + +sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) + => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a +sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) + | Refl <- lemInitApp (Proxy @sh) (Proxy @n) + , Refl <- lemLastApp (Proxy @sh) (Proxy @n) + = case sshape sarr1 of + _ :$$ _ + | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n]) + -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2) + _ -> error "unreachable" + +-- | This has a temporary, suboptimal implementation in terms of 'mflatten'. +-- Prefer 'sdot1Inner' if applicable. +sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a +sdot = coerce mdot + +stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) +stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr) + +stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) +stoXArrayPrim (Shaped arr) = first shCvtXS' (mtoXArrayPrim arr) + +sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) +sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr) + +sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a +sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr) + +sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a +sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) + +stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a) +stoPrimitive (Shaped arr) = Shaped (toPrimitive arr) + +mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') + => Mixed sh a -> ShS sh' -> Shaped sh' a +mcastToShaped arr targetsh + | Refl <- lemRankMapJust targetsh + = Shaped (mcast (ssxFromShape (shCvtSX targetsh)) arr) + +stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a +stoMixed (Shaped arr) = arr + +-- | A more weakly-typed version of 'stoMixed' that does a runtime shape +-- compatibility check. +scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') + => StaticShX sh' -> Shaped sh a -> Mixed sh' a +scastToMixed sshx sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mcast sshx arr |