diff options
Diffstat (limited to 'src/Data/Array/Nested/Shaped.hs')
-rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 44 |
1 files changed, 22 insertions, 22 deletions
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 97c7277..c442d6f 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -44,7 +44,7 @@ import Data.Array.Strided.Arith semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a -semptyArray sh = Shaped (memptyArray (shCvtSX sh)) +semptyArray sh = Shaped (memptyArray (shxFromShS sh)) srank :: Elt a => Shaped sh a -> SNat (Rank sh) srank = shsRank . sshape @@ -54,7 +54,7 @@ ssize :: Elt a => Shaped sh a -> Int ssize = shsSize . sshape sindex :: Elt a => Shaped sh a -> IIxS sh -> a -sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) +sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx) shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh shsTakeIx _ _ ZIS = ZSS @@ -64,26 +64,26 @@ sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> sindexPartial sarr@(Shaped arr) idx = Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) - (ixCvtSX idx)) + (ixxFromIxS idx)) -- | __WARNING__: All values returned from the function must have equal shape. -- See the documentation of 'mgenerate' for more details. sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a -sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh)) +sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh)) -- | See the documentation of 'mlift'. slift :: forall sh1 sh2 a. Elt a => ShS sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a -slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr) +slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShX (shxFromShS sh2)) f arr) -- | See the documentation of 'mlift'. slift2 :: forall sh1 sh2 sh3 a. Elt a => ShS sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a -slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2) +slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2) ssumOuter1P :: forall sh n a. (Storable a, NumElt a) => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) @@ -113,7 +113,7 @@ sscalar :: Elt a => a -> Shaped '[] a sscalar x = Shaped (mscalar x) sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) -sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v) +sfromVectorP sh v = Shaped (mfromVectorP (shxFromShS sh) v) sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a sfromVector sh v = sfromPrimitive (sfromVectorP sh v) @@ -149,17 +149,17 @@ sfromListPrim sn l sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a sfromListPrimLinear sh l = let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr) + in Shaped $ fromPrimitive $ M_Primitive (shxFromShS sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShS sh) xarr) sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a -sfromListLinear sh l = Shaped (mfromListLinear (shCvtSX sh) l) +sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l) stoListLinear :: Elt a => Shaped sh a -> [a] stoListLinear (Shaped arr) = mtoListLinear arr sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a sfromOrthotope sh (SS.A (SG.A arr)) = - Shaped (fromPrimitive (M_Primitive (shCvtSX sh) (X.XArray (RS.A (RG.A (shsToList sh) arr))))) + Shaped (fromPrimitive (M_Primitive (shxFromShS sh) (X.XArray (RS.A (RG.A (shsToList sh) arr))))) stoOrthotope :: PrimElt a => Shaped sh a -> SS.Array sh a stoOrthotope (stoPrimitive -> Shaped (M_Primitive _ (X.XArray (RS.A (RG.A _ arr))))) = SS.A (SG.A arr) @@ -170,7 +170,7 @@ sunScalar arr = sindex arr ZIS snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a) snest sh arr | Refl <- lemMapJustApp sh (Proxy @sh') - = coerce (mnest (ssxFromShape (shCvtSX sh)) (coerce arr)) + = coerce (mnest (ssxFromShX (shxFromShS sh)) (coerce arr)) sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr))) @@ -190,8 +190,8 @@ srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) srerankP sh sh2 f sarr@(Shaped arr) | Refl <- lemMapJustApp sh (Proxy @sh1) , Refl <- lemMapJustApp sh (Proxy @sh2) - = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh)))) - (shCvtSX sh2) + = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (shxFromShS (sshape sarr)) (ssxFromShX (shxFromShS sh)))) + (shxFromShS sh2) (\a -> let Shaped r = f (Shaped a) in r) arr) @@ -205,10 +205,10 @@ srerank sh sh2 f (stoPrimitive -> arr) = sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a sreplicate sh (Shaped arr) | Refl <- lemMapJustApp sh (Proxy @sh') - = Shaped (mreplicate (shCvtSX sh) arr) + = Shaped (mreplicate (shxFromShS sh) arr) sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) -sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x) +sreplicateScalP sh x = Shaped (mreplicateScalP (shxFromShS sh) x) sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) @@ -222,7 +222,7 @@ srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shaped sh' a -sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr) +sreshape sh' (Shaped arr) = Shaped (mreshape (shxFromShS sh') arr) sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a sflatten arr = @@ -234,11 +234,11 @@ siota sn = Shaped (miota sn) -- | Throws if the array is empty. sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminIndexPrim arr) +sminIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mminIndexPrim arr) -- | Throws if the array is empty. smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) +smaxIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a @@ -257,16 +257,16 @@ sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a sdot = coerce mdot stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) -stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr) +stoXArrayPrimP (Shaped arr) = first shsFromShX (mtoXArrayPrimP arr) stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) -stoXArrayPrim (Shaped arr) = first shCvtXS' (mtoXArrayPrim arr) +stoXArrayPrim (Shaped arr) = first shsFromShX (mtoXArrayPrim arr) sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) -sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr) +sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShX (shxFromShS sh)) arr) sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a -sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr) +sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShX (shxFromShS sh)) arr) sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) |