aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs23
1 files changed, 23 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 0351beb..d782e9f 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -211,6 +211,11 @@ ssxIotaFrom _ ZKSX = []
ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh
ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh
+staticShapeFrom :: IShX sh -> StaticShX sh
+staticShapeFrom ZSX = ZKSX
+staticShapeFrom (n :$@ sh) = n :!$@ staticShapeFrom sh
+staticShapeFrom (_ :$? sh) = () :!$? staticShapeFrom sh
+
lemRankApp :: StaticShX sh1 -> StaticShX sh2
-> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2)
lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this
@@ -414,3 +419,21 @@ slice ivs (XArray arr) = XArray (S.slice ivs arr)
rev1 :: XArray (n : sh) a -> XArray (n : sh) a
rev1 (XArray arr) = XArray (S.rev [0] arr)
+
+-- | Throws if the given array and the target shape do not have the same number of elements.
+reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a
+reshape ssh1 sh2 (XArray arr)
+ | Dict <- lemKnownINatRankSSX ssh1
+ , Dict <- knownNatFromINat (Proxy @(Rank sh1))
+ , Dict <- lemKnownINatRank sh2
+ , Dict <- knownNatFromINat (Proxy @(Rank sh2))
+ = XArray (S.reshape (shapeLshape sh2) arr)
+
+-- | Throws if the given array and the target shape do not have the same number of elements.
+reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
+reshapePartial ssh1 ssh' sh2 (XArray arr)
+ | Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh')
+ , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh')))
+ , Dict <- lemKnownINatRankSSX (ssxAppend (staticShapeFrom sh2) ssh')
+ , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh')))
+ = XArray (S.reshape (shapeLshape sh2 ++ drop (length sh2) (S.shapeL arr)) arr)