aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Mixed.hs23
-rw-r--r--src/Data/Array/Nested.hs6
-rw-r--r--src/Data/Array/Nested/Internal.hs19
3 files changed, 45 insertions, 3 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 0351beb..d782e9f 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -211,6 +211,11 @@ ssxIotaFrom _ ZKSX = []
ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh
ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh
+staticShapeFrom :: IShX sh -> StaticShX sh
+staticShapeFrom ZSX = ZKSX
+staticShapeFrom (n :$@ sh) = n :!$@ staticShapeFrom sh
+staticShapeFrom (_ :$? sh) = () :!$? staticShapeFrom sh
+
lemRankApp :: StaticShX sh1 -> StaticShX sh2
-> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2)
lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this
@@ -414,3 +419,21 @@ slice ivs (XArray arr) = XArray (S.slice ivs arr)
rev1 :: XArray (n : sh) a -> XArray (n : sh) a
rev1 (XArray arr) = XArray (S.rev [0] arr)
+
+-- | Throws if the given array and the target shape do not have the same number of elements.
+reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a
+reshape ssh1 sh2 (XArray arr)
+ | Dict <- lemKnownINatRankSSX ssh1
+ , Dict <- knownNatFromINat (Proxy @(Rank sh1))
+ , Dict <- lemKnownINatRank sh2
+ , Dict <- knownNatFromINat (Proxy @(Rank sh2))
+ = XArray (S.reshape (shapeLshape sh2) arr)
+
+-- | Throws if the given array and the target shape do not have the same number of elements.
+reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
+reshapePartial ssh1 ssh' sh2 (XArray arr)
+ | Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh')
+ , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh')))
+ , Dict <- lemKnownINatRankSSX (ssxAppend (staticShapeFrom sh2) ssh')
+ , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh')))
+ = XArray (S.reshape (shapeLshape sh2 ++ drop (length sh2) (S.shapeL arr)) arr)
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index ec5f0b5..c7d1819 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -9,7 +9,7 @@ module Data.Array.Nested (
rshape, rindex, rindexPartial, rgenerate, rsumOuter1,
rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar,
rconstant, rfromList, rfromList1, rtoList, rtoList1,
- rslice, rrev1,
+ rslice, rrev1, rreshape,
-- ** Lifting orthotope operations to 'Ranked' arrays
rlift,
@@ -21,7 +21,7 @@ module Data.Array.Nested (
sshape, sindex, sindexPartial, sgenerate, ssumOuter1,
stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,
sconstant, sfromList, sfromList1, stoList, stoList1,
- sslice, srev1,
+ sslice, srev1, sreshape,
-- ** Lifting orthotope operations to 'Shaped' arrays
slift,
@@ -30,7 +30,7 @@ module Data.Array.Nested (
IxX(..), IIxX,
KnownShapeX(..), StaticShX(..),
mgenerate, mtranspose, mappend, mfromVector, mtoVector, munScalar,
- mconstant, mfromList, mtoList, mslice, mrev1,
+ mconstant, mfromList, mtoList, mslice, mrev1, mreshape,
-- * Array elements
Elt(mshape, mindex, mindexPartial, mscalar, mfromList1, mtoList1, mlift, mlift2),
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 350eb6f..d041aff 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -540,6 +540,11 @@ mslice ivs = mlift $ \_ -> X.slice ivs
mrev1 :: (KnownShapeX (n : sh), Elt a) => Mixed (n : sh) a -> Mixed (n : sh) a
mrev1 = mlift $ \_ -> X.rev1
+mreshape :: forall sh sh' a. (KnownShapeX sh, KnownShapeX sh', Elt a)
+ => IShX sh' -> Mixed sh a -> Mixed sh' a
+mreshape sh' = mlift $ \(_ :: Proxy shIn) ->
+ X.reshapePartial (knownShapeX @sh) (knownShapeX @shIn) sh'
+
mliftPrim :: (KnownShapeX sh, Storable a)
=> (a -> a)
-> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
@@ -1083,6 +1088,13 @@ rslice ivs = rlift $ \_ -> X.slice ivs
rrev1 :: (KnownINat n, Elt a) => Ranked (S n) a -> Ranked (S n) a
rrev1 = rlift $ \_ -> X.rev1
+rreshape :: forall n n' a. (KnownINat n, KnownINat n', Elt a)
+ => IShR n' -> Ranked n a -> Ranked n' a
+rreshape sh' (Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ , Dict <- lemKnownReplicate (Proxy @n')
+ = Ranked (mreshape (shCvtRX sh') arr)
+
-- ====== API OF SHAPED ARRAYS ====== --
@@ -1292,3 +1304,10 @@ sslice ivs = slift $ \_ -> X.slice ivs
srev1 :: (KnownNat n, KnownShape sh, Elt a) => Shaped (n : sh) a -> Shaped (n : sh) a
srev1 = slift $ \_ -> X.rev1
+
+sreshape :: forall sh sh' a. (KnownShape sh, KnownShape sh', Elt a)
+ => ShS sh' -> Shaped sh a -> Shaped sh' a
+sreshape sh' (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ , Dict <- lemKnownMapJust (Proxy @sh')
+ = Shaped (mreshape (shCvtSX sh') arr)