aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-16 12:17:26 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-16 12:17:26 +0200
commitbad1902b2b3d8835cfe65700893c8ed8b560c893 (patch)
treee9ab386f77372794fe314a427b7470fee29bc8a2 /src/Data
parentac5c0f1d9f3ba04d1e6647625a7699f463bb3e73 (diff)
Fix transpose
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Mixed.hs114
-rw-r--r--src/Data/Array/Nested/Internal.hs130
2 files changed, 166 insertions, 78 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 7f9076b..856e6cb 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -16,6 +16,7 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed where
@@ -47,6 +48,28 @@ pattern GHC_SNat = SNat
fromSNat' :: SNat n -> Int
fromSNat' = fromIntegral . fromSNat
+pattern SZ :: () => (n ~ 0) => SNat n
+pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl)
+ where SZ = SNat
+
+pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
+pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
+ where SS = snatSucc
+
+{-# COMPLETE SZ, SS #-}
+
+snatSucc :: SNat n -> SNat (n + 1)
+snatSucc SNat = SNat
+
+data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1)
+snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
+snatPred snp1 =
+ withKnownNat snp1 $
+ case cmpNat (Proxy @1) (Proxy @np1) of
+ LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+ EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+ GTI -> Nothing
+
-- | Type-level list append.
type family l1 ++ l2 where
@@ -59,6 +82,11 @@ lemAppNil = unsafeCoerce Refl
lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerce Refl
+type family Replicate n a where
+ Replicate 0 a = '[]
+ Replicate n a = a : Replicate (n - 1) a
+
+
type IxX :: [Maybe Nat] -> Type -> Type
data IxX sh i where
ZIX :: IxX '[] i
@@ -165,6 +193,15 @@ ssxToShape' ZKSX = Just ZSX
ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh
ssxToShape' (_ :!$? _) = Nothing
+lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
+lemReplicateSucc = unsafeCoerce Refl
+
+ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
+ssxReplicate SZ = ZKSX
+ssxReplicate (SS (n :: SNat n'))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n'
+ = () :!$? ssxReplicate n
+
fromLinearIdx :: IShX sh -> Int -> IIxX sh
fromLinearIdx = \sh i -> case go sh i of
(idx, 0) -> idx
@@ -331,7 +368,7 @@ rerank ssh ssh1 ssh2 f (XArray arr)
where
unXArray (XArray a) = a
-rerankTop :: forall sh sh1 sh2 a b.
+rerankTop :: forall sh1 sh2 sh a b.
(Storable a, Storable b)
=> StaticShX sh1 -> StaticShX sh2 -> StaticShX sh
-> (XArray sh1 a -> XArray sh2 b)
@@ -402,40 +439,75 @@ type family DropLen ref l where
DropLen '[] l = l
DropLen (_ : ref) (_ : xs) = DropLen ref xs
-lemPermuteRank :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is
-lemPermuteRank _ HNil = Refl
-lemPermuteRank p (_ `HCons` is) | Refl <- lemPermuteRank p is = Refl
-
-lemPermuteRank2 :: forall is sh. (Rank is <= Rank sh)
- => Proxy sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is
-lemPermuteRank2 _ HNil = Refl
-lemPermuteRank2 p ((_ :: SNat n) `HCons` (is :: HList SNat is')) =
- let p1 :: Rank (DropLen is' sh) :~: Rank sh - Rank is'
- p1 = lemPermuteRank2 p is
- p9 :: Rank (DropLen (n : is') sh) :~: Rank sh - (1 + Rank is')
- p9 = _
- in p9
+lemRankPermute :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is
+lemRankPermute _ HNil = Refl
+lemRankPermute p (_ `HCons` is) | Refl <- lemRankPermute p is = Refl
+
+lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
+ => StaticShX sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is
+lemRankDropLen ZKSX HNil = Refl
+lemRankDropLen (_ :!$@ sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl
+lemRankDropLen (_ :!$? sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl
+lemRankDropLen (_ :!$@ _) HNil = Refl
+lemRankDropLen (_ :!$? _) HNil = Refl
+lemRankDropLen ZKSX (_ `HCons` _) = error "1 <= 0"
+
+lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l
+lemIndexSucc _ _ _ = unsafeCoerce Refl
+
+ssxTakeLen :: HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh)
+ssxTakeLen HNil _ = ZKSX
+ssxTakeLen (_ `HCons` is) (n :!$@ sh) = n :!$@ ssxTakeLen is sh
+ssxTakeLen (_ `HCons` is) (n :!$? sh) = n :!$? ssxTakeLen is sh
+ssxTakeLen (_ `HCons` _) ZKSX = error "Permutation longer than shape"
+
+ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh)
+ssxDropLen HNil sh = sh
+ssxDropLen (_ `HCons` is) (_ :!$@ sh) = ssxDropLen is sh
+ssxDropLen (_ `HCons` is) (_ :!$? sh) = ssxDropLen is sh
+ssxDropLen (_ `HCons` _) ZKSX = error "Permutation longer than shape"
+
+ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh)
+ssxPermute HNil _ = ZKSX
+ssxPermute (i `HCons` (is :: HList SNat is')) (sh :: StaticShX sh) = ssxIndex (Proxy @is') (Proxy @sh) i sh (ssxPermute is sh)
+
+ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT)
+ssxIndex _ _ SZ (n :!$@ _) rest = n :!$@ rest
+ssxIndex _ _ SZ (n :!$? _) rest = n :!$? rest
+ssxIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :!$@ (sh :: StaticShX sh')) rest
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @sh')
+ = ssxIndex p pT i sh rest
+ssxIndex p pT (SS (i :: SNat i')) (() :!$? (sh :: StaticShX sh')) rest
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @Nothing) (Proxy @sh')
+ = ssxIndex p pT i sh rest
+ssxIndex _ _ _ ZKSX _ = error "Index into empty shape"
-- | The list argument gives indices into the original dimension list.
---
--- This function does not throw: the constraints ensure that the permutation is always valid.
transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh, KnownShapeX sh)
=> HList SNat is
-> XArray sh a
-> XArray (Permute is (TakeLen is sh) ++ DropLen is sh) a
transpose perm (XArray arr)
| Dict <- lemKnownNatRankSSX (knownShapeX @sh)
- , Refl <- lemPermuteRank (Proxy @(TakeLen is sh)) perm
+ , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm (knownShapeX @sh))) (ssxDropLen perm (knownShapeX @sh))
+ , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm
+ , Refl <- lemRankDropLen (knownShapeX @sh) perm
= let perm' = foldHList (\sn -> [fromSNat' sn]) perm :: [Int]
in XArray (S.transpose perm' arr)
-- | The list argument gives indices into the original dimension list.
--
--- This version throws a runtime error if the permutation is invalid.
-transposeUntyped :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a
-transposeUntyped perm (XArray arr)
- | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
+-- The permutation (the list) must have length <= @n@. If it is longer, this
+-- function throws.
+transposeUntyped :: forall n sh a.
+ SNat n -> StaticShX sh -> [Int]
+ -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a
+transposeUntyped sn ssh perm (XArray arr)
+ | length perm <= fromSNat' sn
+ , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh)
= XArray (S.transpose perm arr)
+ | otherwise
+ = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type"
transpose2 :: forall sh1 sh2 a.
StaticShX sh1 -> StaticShX sh2
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index b3f8143..de27336 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -85,9 +85,8 @@ import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import Foreign.Storable (Storable)
import GHC.TypeLits
-import Unsafe.Coerce (unsafeCoerce)
-import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..), HList(..))
+import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..), HList(..), pattern SZ, pattern SS, Replicate)
import qualified Data.Array.Mixed as X
@@ -124,36 +123,10 @@ import qualified Data.Array.Mixed as X
-- have been marked with [PRIMITIVE ELEMENT TYPES LIST].
-type family Replicate n a where
- Replicate 0 a = '[]
- Replicate n a = a : Replicate (n - 1) a
-
type family MapJust l where
MapJust '[] = '[]
MapJust (x : xs) = Just x : MapJust xs
-pattern SZ :: () => (n ~ 0) => SNat n
-pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl)
- where SZ = SNat
-
-pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
-pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
- where SS = snatSucc
-
-{-# COMPLETE SZ, SS #-}
-
-snatSucc :: SNat n -> SNat (n + 1)
-snatSucc SNat = SNat
-
-data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1)
-snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
-snatPred snp1 =
- withKnownNat snp1 $
- case cmpNat (Proxy @1) (Proxy @np1) of
- LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
- EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
- GTI -> Nothing
-
-- Stupid things that the type checker should be able to figure out in-line, but can't
@@ -163,10 +136,6 @@ subst1 Refl = Refl
subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
subst2 Refl = Refl
--- TODO: is this sound? @n@ cannot be negative, surely, but the plugin doesn't see even that.
-lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
-lemReplicateSucc = unsafeCoerce Refl
-
lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l
lemAppLeft _ Refl = Refl
@@ -179,7 +148,7 @@ lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n))
where
go :: SNat m -> StaticShX (Replicate m Nothing)
go SZ = ZKSX
- go (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = () :!$? go n
+ go (SS (n :: SNat nm1)) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 = () :!$? go n
lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
lemRankReplicate _ = go (natSing @n)
@@ -187,7 +156,7 @@ lemRankReplicate _ = go (natSing @n)
go :: forall m. SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
go SZ = Refl
go (SS (n :: SNat nm1))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1
+ | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1
, Refl <- go n
= Refl
@@ -205,9 +174,9 @@ lemReplicatePlusApp _ _ _ = go (natSing @n)
go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
go SZ = Refl
go (SS (n :: SNat n'm1))
- | Refl <- lemReplicateSucc @a @n'm1
+ | Refl <- X.lemReplicateSucc @a @n'm1
, Refl <- go n
- = sym (lemReplicateSucc @a @(n'm1 + m))
+ = sym (X.lemReplicateSucc @a @(n'm1 + m))
shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh')
shAppSplit _ ZKSX idx = (ZSX, idx)
@@ -583,10 +552,12 @@ mgenerate sh f = case X.enumShape sh of
mvecsWrite sh idx val vecs
mvecsFreeze sh vecs
-mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShapeX sh, Elt a) => HList SNat is -> Mixed sh a -> Mixed sh a
-mtranspose perm = mlift $ \(Proxy @sh') ->
- X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh')
- (X.transpose perm)
+mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShapeX sh, Elt a) => HList SNat is -> Mixed sh a -> Mixed (X.Permute is (X.TakeLen is sh) ++ X.DropLen is sh) a
+mtranspose perm
+ | Dict <- X.lemKnownShapeX (X.ssxAppend (X.ssxPermute perm (X.ssxTakeLen perm (knownShapeX @sh))) (X.ssxDropLen perm (knownShapeX @sh)))
+ = mlift $ \(Proxy @sh') ->
+ X.rerankTop (knownShapeX @sh) (knownShapeX @(X.Permute is (X.TakeLen is sh) ++ X.DropLen is sh)) (knownShapeX @sh')
+ (X.transpose perm)
mappend :: forall n m sh a. (KnownShapeX sh, KnownShapeX (n : sh), KnownShapeX (m : sh), KnownShapeX (X.AddMaybe n m : sh), Elt a)
=> Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a
@@ -823,13 +794,10 @@ lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh))
go ZSS = ZKSX
go (n :$$ sh) = n :!$@ go sh
-lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2
+lemCommMapJustApp :: forall sh1 sh2. ShS sh1 -> Proxy sh2
-> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
-lemMapJustPlusApp _ _ = go (knownShape @sh1)
- where
- go :: ShS sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2
- go ZSS = Refl
- go (_ :$$ sh) | Refl <- go sh = Refl
+lemCommMapJustApp ZSS _ = Refl
+lemCommMapJustApp (_ :$$ sh) p | Refl <- lemCommMapJustApp sh p = Refl
instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr
@@ -1057,11 +1025,11 @@ shCvtXR (n :$? idx) = n :$: shCvtXR idx
ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
ixCvtRX ZIR = ZIX
-ixCvtRX (n :.: (idx :: IxR m Int)) = castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :.? ixCvtRX idx)
+ixCvtRX (n :.: (idx :: IxR m Int)) = castWith (subst2 @IxX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) (n :.? ixCvtRX idx)
shCvtRX :: IShR n -> IShX (Replicate n Nothing)
shCvtRX ZSR = ZSX
-shCvtRX (n :$: (idx :: ShR m Int)) = castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :$? shCvtRX idx)
+shCvtRX (n :$: (idx :: ShR m Int)) = castWith (subst2 @ShX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) (n :$? shCvtRX idx)
shapeSizeR :: IShR n -> Int
shapeSizeR ZSR = 1
@@ -1105,7 +1073,7 @@ rsumOuter1P :: forall n a.
=> Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
rsumOuter1P (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
- , Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
= Ranked
. coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a))
. X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing))
@@ -1119,15 +1087,17 @@ rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive
rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
rtranspose perm
| Dict <- lemKnownReplicate (Proxy @n)
+ , length perm <= fromIntegral (natVal (Proxy @n))
= rlift $ \(Proxy @sh') ->
- X.rerankTop (knownShapeX @(Replicate n Nothing)) (knownShapeX @(Replicate n Nothing)) (knownShapeX @sh')
- (X.transposeUntyped perm)
+ X.transposeUntyped (natSing @n) (knownShapeX @sh') perm
+ | otherwise
+ = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array"
rappend :: forall n a. (KnownNat n, Elt a)
=> Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
rappend
| Dict <- lemKnownReplicate (Proxy @n)
- , Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
= coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
rscalar :: Elt a => a -> Ranked 0 a
@@ -1150,7 +1120,7 @@ rtoVector = coerce mtoVector
rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (n + 1) a
rfromList1 l
| Dict <- lemKnownReplicate (Proxy @n)
- , Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
= Ranked (mfromList1 @a @Nothing @(Replicate n Nothing) (coerce l))
rfromList :: Elt a => NonEmpty a -> Ranked 1 a
@@ -1158,7 +1128,7 @@ rfromList = Ranked . mfromList1 . fmap mscalar
rtoList :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
rtoList (Ranked arr)
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
= coerce (mtoList1 @a @Nothing @(Replicate n Nothing) arr)
rtoList1 :: Elt a => Ranked 1 a -> [a]
@@ -1181,7 +1151,7 @@ rslice ivs = rlift $ \_ -> X.slice ivs
rrev1 :: forall n a. (KnownNat n, Elt a) => Ranked (n + 1) a -> Ranked (n + 1) a
rrev1 = rlift $ \(Proxy @sh') ->
- case lemReplicateSucc @(Nothing @Nat) @n of
+ case X.lemReplicateSucc @(Nothing @Nat) @n of
Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')
rreshape :: forall n n' a. (KnownNat n, KnownNat n', Elt a)
@@ -1308,7 +1278,7 @@ sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)
sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
sindexPartial (Shaped arr) idx =
Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
- (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr)
+ (rewriteMixed (lemCommMapJustApp (knownShape @sh1) (Proxy @sh2)) arr)
(ixCvtSX idx))
-- | __WARNING__: All values returned from the function must have equal shape.
@@ -1342,10 +1312,56 @@ ssumOuter1 :: forall sh n a.
=> Shaped (n : sh) a -> Shaped sh a
ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive
-stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShape sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped sh a
+lemCommMapJustTakeLen :: HList SNat is -> ShS sh -> X.TakeLen is (MapJust sh) :~: MapJust (X.TakeLen is sh)
+lemCommMapJustTakeLen HNil _ = Refl
+lemCommMapJustTakeLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl
+lemCommMapJustTakeLen (_ `HCons` _) ZSS = error "TakeLen of empty"
+
+lemCommMapJustDropLen :: HList SNat is -> ShS sh -> X.DropLen is (MapJust sh) :~: MapJust (X.DropLen is sh)
+lemCommMapJustDropLen HNil _ = Refl
+lemCommMapJustDropLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl
+lemCommMapJustDropLen (_ `HCons` _) ZSS = error "DropLen of empty"
+
+lemCommMapJustIndex :: SNat i -> ShS sh -> X.Index i (MapJust sh) :~: Just (X.Index i sh)
+lemCommMapJustIndex SZ (_ :$$ _) = Refl
+lemCommMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh'))
+ | Refl <- lemCommMapJustIndex i sh
+ , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh'))
+ , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = Refl
+lemCommMapJustIndex _ ZSS = error "Index of empty"
+
+lemCommMapJustPermute :: HList SNat is -> ShS sh -> X.Permute is (MapJust sh) :~: MapJust (X.Permute is sh)
+lemCommMapJustPermute HNil _ = Refl
+lemCommMapJustPermute (i `HCons` is) sh
+ | Refl <- lemCommMapJustPermute is sh
+ , Refl <- lemCommMapJustIndex i sh
+ = Refl
+
+shTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh)
+shTakeLen HNil _ = ZSS
+shTakeLen (_ `HCons` is) (n :$$ sh) = n :$$ shTakeLen is sh
+shTakeLen (_ `HCons` _) ZSS = error "Permutation longer than shape"
+
+shPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh)
+shPermute HNil _ = ZSS
+shPermute (i `HCons` (is :: HList SNat is')) (sh :: ShS sh) = shIndex (Proxy @is') (Proxy @sh) i sh (shPermute is sh)
+
+shIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT)
+shIndex _ _ SZ (n :$$ _) rest = n :$$ rest
+shIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest
+ | Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = shIndex p pT i sh rest
+shIndex _ _ _ ZSS _ = error "Index into empty shape"
+
+stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShape sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped (X.Permute is (X.TakeLen is sh) ++ X.DropLen is sh) a
stranspose perm (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
, Refl <- lemRankMapJust (Proxy @sh)
+ , Refl <- lemCommMapJustTakeLen perm (knownShape @sh)
+ , Refl <- lemCommMapJustDropLen perm (knownShape @sh)
+ , Refl <- lemCommMapJustPermute perm (shTakeLen perm (knownShape @sh))
+ , Refl <- lemCommMapJustApp (shPermute perm (shTakeLen perm (knownShape @sh))) (Proxy @(X.DropLen is sh))
= Shaped (mtranspose perm arr)
sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a)