diff options
| -rw-r--r-- | src/Data/Array/Mixed.hs | 23 | ||||
| -rw-r--r-- | src/Data/Array/Nested.hs | 6 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 19 | 
3 files changed, 45 insertions, 3 deletions
| diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 0351beb..d782e9f 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -211,6 +211,11 @@ ssxIotaFrom _ ZKSX = []  ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh  ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh +staticShapeFrom :: IShX sh -> StaticShX sh +staticShapeFrom ZSX = ZKSX +staticShapeFrom (n :$@ sh) = n :!$@ staticShapeFrom sh +staticShapeFrom (_ :$? sh) = () :!$? staticShapeFrom sh +  lemRankApp :: StaticShX sh1 -> StaticShX sh2             -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2)  lemRankApp _ _ = unsafeCoerce Refl  -- TODO improve this @@ -414,3 +419,21 @@ slice ivs (XArray arr) = XArray (S.slice ivs arr)  rev1 :: XArray (n : sh) a -> XArray (n : sh) a  rev1 (XArray arr) = XArray (S.rev [0] arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a +reshape ssh1 sh2 (XArray arr) +  | Dict <- lemKnownINatRankSSX ssh1 +  , Dict <- knownNatFromINat (Proxy @(Rank sh1)) +  , Dict <- lemKnownINatRank sh2 +  , Dict <- knownNatFromINat (Proxy @(Rank sh2)) +  = XArray (S.reshape (shapeLshape sh2) arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a +reshapePartial ssh1 ssh' sh2 (XArray arr) +  | Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh') +  , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh'))) +  , Dict <- lemKnownINatRankSSX (ssxAppend (staticShapeFrom sh2) ssh') +  , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh'))) +  = XArray (S.reshape (shapeLshape sh2 ++ drop (length sh2) (S.shapeL arr)) arr) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index ec5f0b5..c7d1819 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -9,7 +9,7 @@ module Data.Array.Nested (    rshape, rindex, rindexPartial, rgenerate, rsumOuter1,    rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar,    rconstant, rfromList, rfromList1, rtoList, rtoList1, -  rslice, rrev1, +  rslice, rrev1, rreshape,    -- ** Lifting orthotope operations to 'Ranked' arrays    rlift, @@ -21,7 +21,7 @@ module Data.Array.Nested (    sshape, sindex, sindexPartial, sgenerate, ssumOuter1,    stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,    sconstant, sfromList, sfromList1, stoList, stoList1, -  sslice, srev1, +  sslice, srev1, sreshape,    -- ** Lifting orthotope operations to 'Shaped' arrays    slift, @@ -30,7 +30,7 @@ module Data.Array.Nested (    IxX(..), IIxX,    KnownShapeX(..), StaticShX(..),    mgenerate, mtranspose, mappend, mfromVector, mtoVector, munScalar, -  mconstant, mfromList, mtoList, mslice, mrev1, +  mconstant, mfromList, mtoList, mslice, mrev1, mreshape,    -- * Array elements    Elt(mshape, mindex, mindexPartial, mscalar, mfromList1, mtoList1, mlift, mlift2), diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 350eb6f..d041aff 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -540,6 +540,11 @@ mslice ivs = mlift $ \_ -> X.slice ivs  mrev1 :: (KnownShapeX (n : sh), Elt a) => Mixed (n : sh) a -> Mixed (n : sh) a  mrev1 = mlift $ \_ -> X.rev1 +mreshape :: forall sh sh' a. (KnownShapeX sh, KnownShapeX sh', Elt a) +         => IShX sh' -> Mixed sh a -> Mixed sh' a +mreshape sh' = mlift $ \(_ :: Proxy shIn) -> +                        X.reshapePartial (knownShapeX @sh) (knownShapeX @shIn) sh' +  mliftPrim :: (KnownShapeX sh, Storable a)            => (a -> a)            -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) @@ -1083,6 +1088,13 @@ rslice ivs = rlift $ \_ -> X.slice ivs  rrev1 :: (KnownINat n, Elt a) => Ranked (S n) a -> Ranked (S n) a  rrev1 = rlift $ \_ -> X.rev1 +rreshape :: forall n n' a. (KnownINat n, KnownINat n', Elt a) +         => IShR n' -> Ranked n a -> Ranked n' a +rreshape sh' (Ranked arr) +  | Dict <- lemKnownReplicate (Proxy @n) +  , Dict <- lemKnownReplicate (Proxy @n') +  = Ranked (mreshape (shCvtRX sh') arr) +  -- ====== API OF SHAPED ARRAYS ====== -- @@ -1292,3 +1304,10 @@ sslice ivs = slift $ \_ -> X.slice ivs  srev1 :: (KnownNat n, KnownShape sh, Elt a) => Shaped (n : sh) a -> Shaped (n : sh) a  srev1 = slift $ \_ -> X.rev1 + +sreshape :: forall sh sh' a. (KnownShape sh, KnownShape sh', Elt a) +         => ShS sh' -> Shaped sh a -> Shaped sh' a +sreshape sh' (Shaped arr) +  | Dict <- lemKnownMapJust (Proxy @sh) +  , Dict <- lemKnownMapJust (Proxy @sh') +  = Shaped (mreshape (shCvtSX sh') arr) | 
