From 77ab86ede90938fa43f7f9988ac10f7026440a1c Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 14 May 2024 13:47:44 +0200 Subject: reshape --- src/Data/Array/Mixed.hs | 23 +++++++++++++++++++++++ src/Data/Array/Nested.hs | 6 +++--- src/Data/Array/Nested/Internal.hs | 19 +++++++++++++++++++ 3 files changed, 45 insertions(+), 3 deletions(-) diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 0351beb..d782e9f 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -211,6 +211,11 @@ ssxIotaFrom _ ZKSX = [] ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh +staticShapeFrom :: IShX sh -> StaticShX sh +staticShapeFrom ZSX = ZKSX +staticShapeFrom (n :$@ sh) = n :!$@ staticShapeFrom sh +staticShapeFrom (_ :$? sh) = () :!$? staticShapeFrom sh + lemRankApp :: StaticShX sh1 -> StaticShX sh2 -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2) lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this @@ -414,3 +419,21 @@ slice ivs (XArray arr) = XArray (S.slice ivs arr) rev1 :: XArray (n : sh) a -> XArray (n : sh) a rev1 (XArray arr) = XArray (S.rev [0] arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a +reshape ssh1 sh2 (XArray arr) + | Dict <- lemKnownINatRankSSX ssh1 + , Dict <- knownNatFromINat (Proxy @(Rank sh1)) + , Dict <- lemKnownINatRank sh2 + , Dict <- knownNatFromINat (Proxy @(Rank sh2)) + = XArray (S.reshape (shapeLshape sh2) arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a +reshapePartial ssh1 ssh' sh2 (XArray arr) + | Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh') + , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh'))) + , Dict <- lemKnownINatRankSSX (ssxAppend (staticShapeFrom sh2) ssh') + , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh'))) + = XArray (S.reshape (shapeLshape sh2 ++ drop (length sh2) (S.shapeL arr)) arr) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index ec5f0b5..c7d1819 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -9,7 +9,7 @@ module Data.Array.Nested ( rshape, rindex, rindexPartial, rgenerate, rsumOuter1, rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar, rconstant, rfromList, rfromList1, rtoList, rtoList1, - rslice, rrev1, + rslice, rrev1, rreshape, -- ** Lifting orthotope operations to 'Ranked' arrays rlift, @@ -21,7 +21,7 @@ module Data.Array.Nested ( sshape, sindex, sindexPartial, sgenerate, ssumOuter1, stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar, sconstant, sfromList, sfromList1, stoList, stoList1, - sslice, srev1, + sslice, srev1, sreshape, -- ** Lifting orthotope operations to 'Shaped' arrays slift, @@ -30,7 +30,7 @@ module Data.Array.Nested ( IxX(..), IIxX, KnownShapeX(..), StaticShX(..), mgenerate, mtranspose, mappend, mfromVector, mtoVector, munScalar, - mconstant, mfromList, mtoList, mslice, mrev1, + mconstant, mfromList, mtoList, mslice, mrev1, mreshape, -- * Array elements Elt(mshape, mindex, mindexPartial, mscalar, mfromList1, mtoList1, mlift, mlift2), diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 350eb6f..d041aff 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -540,6 +540,11 @@ mslice ivs = mlift $ \_ -> X.slice ivs mrev1 :: (KnownShapeX (n : sh), Elt a) => Mixed (n : sh) a -> Mixed (n : sh) a mrev1 = mlift $ \_ -> X.rev1 +mreshape :: forall sh sh' a. (KnownShapeX sh, KnownShapeX sh', Elt a) + => IShX sh' -> Mixed sh a -> Mixed sh' a +mreshape sh' = mlift $ \(_ :: Proxy shIn) -> + X.reshapePartial (knownShapeX @sh) (knownShapeX @shIn) sh' + mliftPrim :: (KnownShapeX sh, Storable a) => (a -> a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) @@ -1083,6 +1088,13 @@ rslice ivs = rlift $ \_ -> X.slice ivs rrev1 :: (KnownINat n, Elt a) => Ranked (S n) a -> Ranked (S n) a rrev1 = rlift $ \_ -> X.rev1 +rreshape :: forall n n' a. (KnownINat n, KnownINat n', Elt a) + => IShR n' -> Ranked n a -> Ranked n' a +rreshape sh' (Ranked arr) + | Dict <- lemKnownReplicate (Proxy @n) + , Dict <- lemKnownReplicate (Proxy @n') + = Ranked (mreshape (shCvtRX sh') arr) + -- ====== API OF SHAPED ARRAYS ====== -- @@ -1292,3 +1304,10 @@ sslice ivs = slift $ \_ -> X.slice ivs srev1 :: (KnownNat n, KnownShape sh, Elt a) => Shaped (n : sh) a -> Shaped (n : sh) a srev1 = slift $ \_ -> X.rev1 + +sreshape :: forall sh sh' a. (KnownShape sh, KnownShape sh', Elt a) + => ShS sh' -> Shaped sh a -> Shaped sh' a +sreshape sh' (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + , Dict <- lemKnownMapJust (Proxy @sh') + = Shaped (mreshape (shCvtSX sh') arr) -- cgit v1.2.3-70-g09d2