aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs266
1 files changed, 216 insertions, 50 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 0351beb..ce18431 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
@@ -7,12 +8,16 @@
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed where
@@ -21,14 +26,18 @@ import qualified Data.Array.Ranked as ORB
import Data.Coerce
import Data.Kind
import Data.Proxy
+import Data.Type.Bool
import Data.Type.Equality
import qualified Data.Vector.Storable as VS
import Foreign.Storable (Storable)
+import GHC.TypeError
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)
-import Data.INat
+-- | Evidence for the constraint @c a@.
+data Dict c a where
+ Dict :: c a => Dict c a
-- | The 'SNat' pattern synonym is complete, but it doesn't have a
-- @COMPLETE@ pragma. This copy of it does.
@@ -39,6 +48,28 @@ pattern GHC_SNat = SNat
fromSNat' :: SNat n -> Int
fromSNat' = fromIntegral . fromSNat
+pattern SZ :: () => (n ~ 0) => SNat n
+pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl)
+ where SZ = SNat
+
+pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
+pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
+ where SS = snatSucc
+
+{-# COMPLETE SZ, SS #-}
+
+snatSucc :: SNat n -> SNat (n + 1)
+snatSucc SNat = SNat
+
+data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1)
+snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
+snatPred snp1 =
+ withKnownNat snp1 $
+ case cmpNat (Proxy @1) (Proxy @np1) of
+ LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+ EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+ GTI -> Nothing
+
-- | Type-level list append.
type family l1 ++ l2 where
@@ -51,6 +82,11 @@ lemAppNil = unsafeCoerce Refl
lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerce Refl
+type family Replicate n a where
+ Replicate 0 a = '[]
+ Replicate n a = a : Replicate (n - 1) a
+
+
type IxX :: [Maybe Nat] -> Type -> Type
data IxX sh i where
ZIX :: IxX '[] i
@@ -103,11 +139,11 @@ instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
knownShapeX = () :!$? knownShapeX
type family Rank sh where
- Rank '[] = Z
- Rank (_ : sh) = S (Rank sh)
+ Rank '[] = 0
+ Rank (_ : sh) = 1 + Rank sh
type XArray :: [Maybe Nat] -> Type -> Type
-newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a)
+newtype XArray sh a = XArray (S.Array (Rank sh) a)
deriving (Show)
zeroIxX :: StaticShX sh -> IIxX sh
@@ -157,6 +193,15 @@ ssxToShape' ZKSX = Just ZSX
ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh
ssxToShape' (_ :!$? _) = Nothing
+lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
+lemReplicateSucc = unsafeCoerce Refl
+
+ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
+ssxReplicate SZ = ZKSX
+ssxReplicate (SS (n :: SNat n'))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n'
+ = () :!$? ssxReplicate n
+
fromLinearIdx :: IShX sh -> Int -> IIxX sh
fromLinearIdx = \sh i -> case go sh i of
(idx, 0) -> idx
@@ -211,23 +256,28 @@ ssxIotaFrom _ ZKSX = []
ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh
ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh
+staticShapeFrom :: IShX sh -> StaticShX sh
+staticShapeFrom ZSX = ZKSX
+staticShapeFrom (n :$@ sh) = n :!$@ staticShapeFrom sh
+staticShapeFrom (_ :$? sh) = () :!$? staticShapeFrom sh
+
lemRankApp :: StaticShX sh1 -> StaticShX sh2
- -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2)
+ -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this
lemRankAppComm :: StaticShX sh1 -> StaticShX sh2
- -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1))
+ -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this
-lemKnownINatRank :: IShX sh -> Dict KnownINat (Rank sh)
-lemKnownINatRank ZSX = Dict
-lemKnownINatRank (_ :$@ sh) | Dict <- lemKnownINatRank sh = Dict
-lemKnownINatRank (_ :$? sh) | Dict <- lemKnownINatRank sh = Dict
+lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRank ZSX = Dict
+lemKnownNatRank (_ :$@ sh) | Dict <- lemKnownNatRank sh = Dict
+lemKnownNatRank (_ :$? sh) | Dict <- lemKnownNatRank sh = Dict
-lemKnownINatRankSSX :: StaticShX sh -> Dict KnownINat (Rank sh)
-lemKnownINatRankSSX ZKSX = Dict
-lemKnownINatRankSSX (_ :!$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
-lemKnownINatRankSSX (_ :!$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
+lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRankSSX ZKSX = Dict
+lemKnownNatRankSSX (_ :!$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
+lemKnownNatRankSSX (_ :!$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh
lemKnownShapeX ZKSX = Dict
@@ -254,8 +304,7 @@ shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr)
fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
fromVector sh v
- | Dict <- lemKnownINatRank sh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRank sh
= XArray (S.fromVector (shapeLshape sh) v)
toVector :: Storable a => XArray sh a -> VS.Vector a
@@ -269,15 +318,14 @@ unScalar (XArray a) = S.unScalar a
constant :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
constant sh x
- | Dict <- lemKnownINatRank sh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRank sh
= XArray (S.constant (shapeLshape sh) x)
generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a
generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)
-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a)
--- generateM sh f | Dict <- lemKnownINatRank sh =
+-- generateM sh f | Dict <- lemKnownNatRank sh =
-- XArray . S.fromVector (shapeLshape sh)
-- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh)
@@ -300,8 +348,7 @@ type family AddMaybe n m where
append :: forall n m sh a. (KnownShapeX sh, Storable a)
=> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
append (XArray a) (XArray b)
- | Dict <- lemKnownINatRankSSX (knownShapeX @sh)
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
= XArray (S.append a b)
rerank :: forall sh sh1 sh2 a b.
@@ -310,21 +357,18 @@ rerank :: forall sh sh1 sh2 a b.
-> (XArray sh1 a -> XArray sh2 b)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
rerank ssh ssh1 ssh2 f (XArray arr)
- | Dict <- lemKnownINatRankSSX ssh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
- , Dict <- lemKnownINatRankSSX ssh2
- , Dict <- knownNatFromINat (Proxy @(Rank sh2))
+ | Dict <- lemKnownNatRankSSX ssh
+ , Dict <- lemKnownNatRankSSX ssh2
, Refl <- lemRankApp ssh ssh1
, Refl <- lemRankApp ssh ssh2
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
- , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
- = XArray (S.rerank @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2))
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- should be redundant but the solver is not clever enough
+ = XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2)
(\a -> unXArray (f (XArray a)))
arr)
where
unXArray (XArray a) = a
-rerankTop :: forall sh sh1 sh2 a b.
+rerankTop :: forall sh1 sh2 sh a b.
(Storable a, Storable b)
=> StaticShX sh1 -> StaticShX sh2 -> StaticShX sh
-> (XArray sh1 a -> XArray sh2 b)
@@ -337,26 +381,135 @@ rerank2 :: forall sh sh1 sh2 a b c.
-> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
- | Dict <- lemKnownINatRankSSX ssh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
- , Dict <- lemKnownINatRankSSX ssh2
- , Dict <- knownNatFromINat (Proxy @(Rank sh2))
+ | Dict <- lemKnownNatRankSSX ssh
+ , Dict <- lemKnownNatRankSSX ssh2
, Refl <- lemRankApp ssh ssh1
, Refl <- lemRankApp ssh ssh2
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
- , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
- = XArray (S.rerank2 @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2))
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- should be redundant but the solver is not clever enough
+ = XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2)
(\a b -> unXArray (f (XArray a) (XArray b)))
arr1 arr2)
where
unXArray (XArray a) = a
+type family Elem x l where
+ Elem x '[] = 'False
+ Elem x (x : _) = 'True
+ Elem x (_ : ys) = Elem x ys
+
+type family AllElem' as bs where
+ AllElem' '[] bs = 'True
+ AllElem' (a : as) bs = Elem a bs && AllElem' as bs
+
+type AllElem as bs = Assert (AllElem' as bs)
+ (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs))
+
+type family Count i n where
+ Count n n = '[]
+ Count i n = i : Count (i + 1) n
+
+type Permutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as)
+
+type family Index i sh where
+ Index 0 (n : sh) = n
+ Index i (_ : sh) = Index (i - 1) sh
+
+type family Permute is sh where
+ Permute '[] sh = '[]
+ Permute (i : is) sh = Index i sh : Permute is sh
+
+type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh
+
+data HList f list where
+ HNil :: HList f '[]
+ HCons :: f a -> HList f l -> HList f (a : l)
+infixr 5 `HCons`
+
+foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m
+foldHList _ HNil = mempty
+foldHList f (x `HCons` l) = f x <> foldHList f l
+
+class KnownNatList l where makeNatList :: HList SNat l
+instance KnownNatList '[] where makeNatList = HNil
+instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `HCons` makeNatList
+
+type family TakeLen ref l where
+ TakeLen '[] l = '[]
+ TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs
+
+type family DropLen ref l where
+ DropLen '[] l = l
+ DropLen (_ : ref) (_ : xs) = DropLen ref xs
+
+lemRankPermute :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is
+lemRankPermute _ HNil = Refl
+lemRankPermute p (_ `HCons` is) | Refl <- lemRankPermute p is = Refl
+
+lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
+ => StaticShX sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is
+lemRankDropLen ZKSX HNil = Refl
+lemRankDropLen (_ :!$@ sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl
+lemRankDropLen (_ :!$? sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl
+lemRankDropLen (_ :!$@ _) HNil = Refl
+lemRankDropLen (_ :!$? _) HNil = Refl
+lemRankDropLen ZKSX (_ `HCons` _) = error "1 <= 0"
+
+lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l
+lemIndexSucc _ _ _ = unsafeCoerce Refl
+
+ssxTakeLen :: HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh)
+ssxTakeLen HNil _ = ZKSX
+ssxTakeLen (_ `HCons` is) (n :!$@ sh) = n :!$@ ssxTakeLen is sh
+ssxTakeLen (_ `HCons` is) (n :!$? sh) = n :!$? ssxTakeLen is sh
+ssxTakeLen (_ `HCons` _) ZKSX = error "Permutation longer than shape"
+
+ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh)
+ssxDropLen HNil sh = sh
+ssxDropLen (_ `HCons` is) (_ :!$@ sh) = ssxDropLen is sh
+ssxDropLen (_ `HCons` is) (_ :!$? sh) = ssxDropLen is sh
+ssxDropLen (_ `HCons` _) ZKSX = error "Permutation longer than shape"
+
+ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh)
+ssxPermute HNil _ = ZKSX
+ssxPermute (i `HCons` (is :: HList SNat is')) (sh :: StaticShX sh) = ssxIndex (Proxy @is') (Proxy @sh) i sh (ssxPermute is sh)
+
+ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT)
+ssxIndex _ _ SZ (n :!$@ _) rest = n :!$@ rest
+ssxIndex _ _ SZ (n :!$? _) rest = n :!$? rest
+ssxIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :!$@ (sh :: StaticShX sh')) rest
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @sh')
+ = ssxIndex p pT i sh rest
+ssxIndex p pT (SS (i :: SNat i')) (() :!$? (sh :: StaticShX sh')) rest
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @Nothing) (Proxy @sh')
+ = ssxIndex p pT i sh rest
+ssxIndex _ _ _ ZKSX _ = error "Index into empty shape"
+
-- | The list argument gives indices into the original dimension list.
-transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a
+transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh, KnownShapeX sh)
+ => HList SNat is
+ -> XArray sh a
+ -> XArray (PermutePrefix is sh) a
transpose perm (XArray arr)
- | Dict <- lemKnownINatRankSSX (knownShapeX @sh)
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
+ , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm (knownShapeX @sh))) (ssxDropLen perm (knownShapeX @sh))
+ , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm
+ , Refl <- lemRankDropLen (knownShapeX @sh) perm
+ = let perm' = foldHList (\sn -> [fromSNat' sn]) perm :: [Int]
+ in XArray (S.transpose perm' arr)
+
+-- | The list argument gives indices into the original dimension list.
+--
+-- The permutation (the list) must have length <= @n@. If it is longer, this
+-- function throws.
+transposeUntyped :: forall n sh a.
+ SNat n -> StaticShX sh -> [Int]
+ -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a
+transposeUntyped sn ssh perm (XArray arr)
+ | length perm <= fromSNat' sn
+ , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh)
= XArray (S.transpose perm arr)
+ | otherwise
+ = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type"
transpose2 :: forall sh1 sh2 a.
StaticShX sh1 -> StaticShX sh2
@@ -364,10 +517,8 @@ transpose2 :: forall sh1 sh2 a.
transpose2 ssh1 ssh2 (XArray arr)
| Refl <- lemRankApp ssh1 ssh2
, Refl <- lemRankApp ssh2 ssh1
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh2)
- , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh2)))
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh2 ssh1)
- , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh1)))
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2)
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
, Refl <- lemRankAppComm ssh1 ssh2
, let n1 = ssxLength ssh1
= XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
@@ -390,13 +541,12 @@ sumOuter ssh ssh'
fromList1 :: forall n sh a. Storable a
=> StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
fromList1 ssh l
- | Dict <- lemKnownINatRankSSX ssh
- , Dict <- knownNatFromINat (Proxy @(Rank (n : sh)))
+ | Dict <- lemKnownNatRankSSX ssh
= case ssh of
m@GHC_SNat :!$@ _ | natVal m /= fromIntegral (length l) ->
error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++
"does not match the type (" ++ show (natVal m) ++ ")"
- _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l)))
+ _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))
toList1 :: Storable a => XArray (n : sh) a -> [XArray sh a]
toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr))
@@ -404,13 +554,29 @@ toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr))
-- | Throws if the given shape is not, in fact, empty.
empty :: forall sh a. Storable a => IShX sh -> XArray sh a
empty sh
- | Dict <- lemKnownINatRank sh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRank sh
= XArray (S.constant (shapeLshape sh)
(error "Data.Array.Mixed.empty: shape was not empty"))
-slice :: [(Int, Int)] -> XArray sh a -> XArray sh a
-slice ivs (XArray arr) = XArray (S.slice ivs arr)
+slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a
+slice i n (XArray arr) = XArray (S.slice [(fromSNat' i, fromSNat' n)] arr)
+
+sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a
+sliceU i n (XArray arr) = XArray (S.slice [(i, n)] arr)
rev1 :: XArray (n : sh) a -> XArray (n : sh) a
rev1 (XArray arr) = XArray (S.rev [0] arr)
+
+-- | Throws if the given array and the target shape do not have the same number of elements.
+reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a
+reshape ssh1 sh2 (XArray arr)
+ | Dict <- lemKnownNatRankSSX ssh1
+ , Dict <- lemKnownNatRank sh2
+ = XArray (S.reshape (shapeLshape sh2) arr)
+
+-- | Throws if the given array and the target shape do not have the same number of elements.
+reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
+reshapePartial ssh1 ssh' sh2 (XArray arr)
+ | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh')
+ , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh2) ssh')
+ = XArray (S.reshape (shapeLshape sh2 ++ drop (length sh2) (S.shapeL arr)) arr)