From 77ab86ede90938fa43f7f9988ac10f7026440a1c Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 14 May 2024 13:47:44 +0200 Subject: reshape --- src/Data/Array/Mixed.hs | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) (limited to 'src/Data/Array/Mixed.hs') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 0351beb..d782e9f 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -211,6 +211,11 @@ ssxIotaFrom _ ZKSX = [] ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh +staticShapeFrom :: IShX sh -> StaticShX sh +staticShapeFrom ZSX = ZKSX +staticShapeFrom (n :$@ sh) = n :!$@ staticShapeFrom sh +staticShapeFrom (_ :$? sh) = () :!$? staticShapeFrom sh + lemRankApp :: StaticShX sh1 -> StaticShX sh2 -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2) lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this @@ -414,3 +419,21 @@ slice ivs (XArray arr) = XArray (S.slice ivs arr) rev1 :: XArray (n : sh) a -> XArray (n : sh) a rev1 (XArray arr) = XArray (S.rev [0] arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a +reshape ssh1 sh2 (XArray arr) + | Dict <- lemKnownINatRankSSX ssh1 + , Dict <- knownNatFromINat (Proxy @(Rank sh1)) + , Dict <- lemKnownINatRank sh2 + , Dict <- knownNatFromINat (Proxy @(Rank sh2)) + = XArray (S.reshape (shapeLshape sh2) arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a +reshapePartial ssh1 ssh' sh2 (XArray arr) + | Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh') + , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh'))) + , Dict <- lemKnownINatRankSSX (ssxAppend (staticShapeFrom sh2) ssh') + , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh'))) + = XArray (S.reshape (shapeLshape sh2 ++ drop (length sh2) (S.shapeL arr)) arr) -- cgit v1.2.3-70-g09d2