{-# LANGUAGE DataKinds #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} module Data.Array.Nested.Shaped ( Shaped(Shaped), squotArray, sremArray, satan2Array, sshape, module Data.Array.Nested.Shaped, liftShaped1, liftShaped2, ) where import Prelude hiding (mappend, mconcat) import Data.Array.Internal.RankedG qualified as RG import Data.Array.Internal.RankedS qualified as RS import Data.Array.Internal.ShapedG qualified as SG import Data.Array.Internal.ShapedS qualified as SS import Data.Bifunctor (first) import Data.Coerce (coerce) import Data.List.NonEmpty (NonEmpty) import Data.Proxy import Data.Type.Equality import Data.Vector.Storable qualified as VS import Foreign.Storable (Storable) import GHC.TypeLits import Data.Array.Nested.Convert import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Permutation import Data.Array.Nested.Shaped.Base import Data.Array.Nested.Shaped.Shape import Data.Array.Nested.Types import Data.Array.Strided.Arith import Data.Array.XArray (XArray) import Data.Array.XArray qualified as X semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a semptyArray sh = Shaped (memptyArray (shxFromShS sh)) srank :: Elt a => Shaped sh a -> SNat (Rank sh) srank = shsRank . sshape -- | The total number of elements in the array. ssize :: Elt a => Shaped sh a -> Int ssize = shsSize . sshape sindex :: Elt a => Shaped sh a -> IIxS sh -> a sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx) shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh shsTakeIx _ _ ZIS = ZSS shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a sindexPartial sarr@(Shaped arr) idx = Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) (ixxFromIxS idx)) -- | __WARNING__: All values returned from the function must have equal shape. -- See the documentation of 'mgenerate' for more details. sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh)) -- | See the documentation of 'mlift'. slift :: forall sh1 sh2 a. Elt a => ShS sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShX (shxFromShS sh2)) f arr) -- | See the documentation of 'mlift'. slift2 :: forall sh1 sh2 sh3 a. Elt a => ShS sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2) ssumOuter1P :: forall sh n a. (Storable a, NumElt a) => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) => Shaped (n : sh) a -> Shaped sh a ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a ssumAllPrim (Shaped arr) = msumAllPrim arr stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a) => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a stranspose perm sarr@(Shaped arr) | Refl <- lemRankMapJust (sshape sarr) , Refl <- lemTakeLenMapJust perm (sshape sarr) , Refl <- lemDropLenMapJust perm (sshape sarr) , Refl <- lemPermuteMapJust perm (shsTakeLen perm (sshape sarr)) , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh)) = Shaped (mtranspose perm arr) sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a sappend = coerce mappend sscalar :: Elt a => a -> Shaped '[] a sscalar x = Shaped (mscalar x) sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) sfromVectorP sh v = Shaped (mfromVectorP (shxFromShS sh) v) sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a sfromVector sh v = sfromPrimitive (sfromVectorP sh v) stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a stoVectorP = coerce mtoVectorP stoVector :: PrimElt a => Shaped sh a -> VS.Vector a stoVector = coerce mtoVector sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l)) sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1 sfromList1Prim :: PrimElt a => SNat n -> [a] -> Shaped '[n] a sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] stoListOuter (Shaped arr) = coerce (mtoListOuter arr) stoList1 :: Elt a => Shaped '[n] a -> [a] stoList1 = map sunScalar . stoListOuter sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a sfromListPrim sn l | Refl <- lemAppNil @'[Just n] = let ssh = SUnknown () :!% ZKX xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a sfromListPrimLinear sh l = let M_Primitive _ xarr = toPrimitive (mfromListPrim l) in Shaped $ fromPrimitive $ M_Primitive (shxFromShS sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShS sh) xarr) sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l) stoListLinear :: Elt a => Shaped sh a -> [a] stoListLinear (Shaped arr) = mtoListLinear arr sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a sfromOrthotope sh (SS.A (SG.A arr)) = Shaped (fromPrimitive (M_Primitive (shxFromShS sh) (X.XArray (RS.A (RG.A (shsToList sh) arr))))) stoOrthotope :: PrimElt a => Shaped sh a -> SS.Array sh a stoOrthotope (stoPrimitive -> Shaped (M_Primitive _ (X.XArray (RS.A (RG.A _ arr))))) = SS.A (SG.A arr) sunScalar :: Elt a => Shaped '[] a -> a sunScalar arr = sindex arr ZIS snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a) snest sh arr | Refl <- lemMapJustApp sh (Proxy @sh') = coerce (mnest (ssxFromShX (shxFromShS sh)) (coerce arr)) sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr))) | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh') = Shaped arr szip :: Shaped sh a -> Shaped sh b -> Shaped sh (a, b) szip = coerce mzip sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b) sunzip = coerce munzip srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) => ShS sh -> ShS sh2 -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b) srerankP sh sh2 f sarr@(Shaped arr) | Refl <- lemMapJustApp sh (Proxy @sh1) , Refl <- lemMapJustApp sh (Proxy @sh2) = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (shxFromShS (sshape sarr)) (ssxFromShX (shxFromShS sh)))) (shxFromShS sh2) (\a -> let Shaped r = f (Shaped a) in r) arr) srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) => ShS sh -> ShS sh2 -> (Shaped sh1 a -> Shaped sh2 b) -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b srerank sh sh2 f (stoPrimitive -> arr) = sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a sreplicate sh (Shaped arr) | Refl <- lemMapJustApp sh (Proxy @sh') = Shaped (mreplicate (shxFromShS sh) arr) sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) sreplicateScalP sh x = Shaped (mreplicateScalP (shxFromShS sh) x) sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a sslice i n@SNat arr = let _ :$$ sh = sshape arr in slift (n :$$ sh) (\_ -> X.slice i n) arr srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shaped sh' a sreshape sh' (Shaped arr) = Shaped (mreshape (shxFromShS sh') arr) sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a sflatten arr = case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff n@SNat -> sreshape (n :$$ ZSS) arr siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a siota sn = Shaped (miota sn) -- | Throws if the array is empty. sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh sminIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mminIndexPrim arr) -- | Throws if the array is empty. smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh smaxIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) | Refl <- lemInitApp (Proxy @sh) (Proxy @n) , Refl <- lemLastApp (Proxy @sh) (Proxy @n) = case sshape sarr1 of _ :$$ _ | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n]) -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2) _ -> error "unreachable" -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. -- Prefer 'sdot1Inner' if applicable. sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a sdot = coerce mdot stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) stoXArrayPrimP (Shaped arr) = first shsFromShX (mtoXArrayPrimP arr) stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) stoXArrayPrim (Shaped arr) = first shsFromShX (mtoXArrayPrim arr) sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShX (shxFromShS sh)) arr) sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShX (shxFromShS sh)) arr) sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a) stoPrimitive (Shaped arr) = Shaped (toPrimitive arr)