diff options
-rw-r--r-- | src/Data/Array/Mixed.hs | 114 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 130 |
2 files changed, 166 insertions, 78 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 7f9076b..856e6cb 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -16,6 +16,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Data.Array.Mixed where @@ -47,6 +48,28 @@ pattern GHC_SNat = SNat fromSNat' :: SNat n -> Int fromSNat' = fromIntegral . fromSNat +pattern SZ :: () => (n ~ 0) => SNat n +pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) + where SZ = SNat + +pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 +pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) + where SS = snatSucc + +{-# COMPLETE SZ, SS #-} + +snatSucc :: SNat n -> SNat (n + 1) +snatSucc SNat = SNat + +data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) +snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) +snatPred snp1 = + withKnownNat snp1 $ + case cmpNat (Proxy @1) (Proxy @np1) of + LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + GTI -> Nothing + -- | Type-level list append. type family l1 ++ l2 where @@ -59,6 +82,11 @@ lemAppNil = unsafeCoerce Refl lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) lemAppAssoc _ _ _ = unsafeCoerce Refl +type family Replicate n a where + Replicate 0 a = '[] + Replicate n a = a : Replicate (n - 1) a + + type IxX :: [Maybe Nat] -> Type -> Type data IxX sh i where ZIX :: IxX '[] i @@ -165,6 +193,15 @@ ssxToShape' ZKSX = Just ZSX ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh ssxToShape' (_ :!$? _) = Nothing +lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a +lemReplicateSucc = unsafeCoerce Refl + +ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) +ssxReplicate SZ = ZKSX +ssxReplicate (SS (n :: SNat n')) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n' + = () :!$? ssxReplicate n + fromLinearIdx :: IShX sh -> Int -> IIxX sh fromLinearIdx = \sh i -> case go sh i of (idx, 0) -> idx @@ -331,7 +368,7 @@ rerank ssh ssh1 ssh2 f (XArray arr) where unXArray (XArray a) = a -rerankTop :: forall sh sh1 sh2 a b. +rerankTop :: forall sh1 sh2 sh a b. (Storable a, Storable b) => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh -> (XArray sh1 a -> XArray sh2 b) @@ -402,40 +439,75 @@ type family DropLen ref l where DropLen '[] l = l DropLen (_ : ref) (_ : xs) = DropLen ref xs -lemPermuteRank :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is -lemPermuteRank _ HNil = Refl -lemPermuteRank p (_ `HCons` is) | Refl <- lemPermuteRank p is = Refl - -lemPermuteRank2 :: forall is sh. (Rank is <= Rank sh) - => Proxy sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is -lemPermuteRank2 _ HNil = Refl -lemPermuteRank2 p ((_ :: SNat n) `HCons` (is :: HList SNat is')) = - let p1 :: Rank (DropLen is' sh) :~: Rank sh - Rank is' - p1 = lemPermuteRank2 p is - p9 :: Rank (DropLen (n : is') sh) :~: Rank sh - (1 + Rank is') - p9 = _ - in p9 +lemRankPermute :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is +lemRankPermute _ HNil = Refl +lemRankPermute p (_ `HCons` is) | Refl <- lemRankPermute p is = Refl + +lemRankDropLen :: forall is sh. (Rank is <= Rank sh) + => StaticShX sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is +lemRankDropLen ZKSX HNil = Refl +lemRankDropLen (_ :!$@ sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!$? sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!$@ _) HNil = Refl +lemRankDropLen (_ :!$? _) HNil = Refl +lemRankDropLen ZKSX (_ `HCons` _) = error "1 <= 0" + +lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l +lemIndexSucc _ _ _ = unsafeCoerce Refl + +ssxTakeLen :: HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh) +ssxTakeLen HNil _ = ZKSX +ssxTakeLen (_ `HCons` is) (n :!$@ sh) = n :!$@ ssxTakeLen is sh +ssxTakeLen (_ `HCons` is) (n :!$? sh) = n :!$? ssxTakeLen is sh +ssxTakeLen (_ `HCons` _) ZKSX = error "Permutation longer than shape" + +ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh) +ssxDropLen HNil sh = sh +ssxDropLen (_ `HCons` is) (_ :!$@ sh) = ssxDropLen is sh +ssxDropLen (_ `HCons` is) (_ :!$? sh) = ssxDropLen is sh +ssxDropLen (_ `HCons` _) ZKSX = error "Permutation longer than shape" + +ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh) +ssxPermute HNil _ = ZKSX +ssxPermute (i `HCons` (is :: HList SNat is')) (sh :: StaticShX sh) = ssxIndex (Proxy @is') (Proxy @sh) i sh (ssxPermute is sh) + +ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT) +ssxIndex _ _ SZ (n :!$@ _) rest = n :!$@ rest +ssxIndex _ _ SZ (n :!$? _) rest = n :!$? rest +ssxIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :!$@ (sh :: StaticShX sh')) rest + | Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @sh') + = ssxIndex p pT i sh rest +ssxIndex p pT (SS (i :: SNat i')) (() :!$? (sh :: StaticShX sh')) rest + | Refl <- lemIndexSucc (Proxy @i') (Proxy @Nothing) (Proxy @sh') + = ssxIndex p pT i sh rest +ssxIndex _ _ _ ZKSX _ = error "Index into empty shape" -- | The list argument gives indices into the original dimension list. --- --- This function does not throw: the constraints ensure that the permutation is always valid. transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh, KnownShapeX sh) => HList SNat is -> XArray sh a -> XArray (Permute is (TakeLen is sh) ++ DropLen is sh) a transpose perm (XArray arr) | Dict <- lemKnownNatRankSSX (knownShapeX @sh) - , Refl <- lemPermuteRank (Proxy @(TakeLen is sh)) perm + , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm (knownShapeX @sh))) (ssxDropLen perm (knownShapeX @sh)) + , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm + , Refl <- lemRankDropLen (knownShapeX @sh) perm = let perm' = foldHList (\sn -> [fromSNat' sn]) perm :: [Int] in XArray (S.transpose perm' arr) -- | The list argument gives indices into the original dimension list. -- --- This version throws a runtime error if the permutation is invalid. -transposeUntyped :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a -transposeUntyped perm (XArray arr) - | Dict <- lemKnownNatRankSSX (knownShapeX @sh) +-- The permutation (the list) must have length <= @n@. If it is longer, this +-- function throws. +transposeUntyped :: forall n sh a. + SNat n -> StaticShX sh -> [Int] + -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a +transposeUntyped sn ssh perm (XArray arr) + | length perm <= fromSNat' sn + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh) = XArray (S.transpose perm arr) + | otherwise + = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type" transpose2 :: forall sh1 sh2 a. StaticShX sh1 -> StaticShX sh2 diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index b3f8143..de27336 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -85,9 +85,8 @@ import qualified Data.Vector.Storable as VS import qualified Data.Vector.Storable.Mutable as VSM import Foreign.Storable (Storable) import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce) -import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..), HList(..)) +import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..), HList(..), pattern SZ, pattern SS, Replicate) import qualified Data.Array.Mixed as X @@ -124,36 +123,10 @@ import qualified Data.Array.Mixed as X -- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. -type family Replicate n a where - Replicate 0 a = '[] - Replicate n a = a : Replicate (n - 1) a - type family MapJust l where MapJust '[] = '[] MapJust (x : xs) = Just x : MapJust xs -pattern SZ :: () => (n ~ 0) => SNat n -pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) - where SZ = SNat - -pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 -pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) - where SS = snatSucc - -{-# COMPLETE SZ, SS #-} - -snatSucc :: SNat n -> SNat (n + 1) -snatSucc SNat = SNat - -data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) -snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) -snatPred snp1 = - withKnownNat snp1 $ - case cmpNat (Proxy @1) (Proxy @np1) of - LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) - EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) - GTI -> Nothing - -- Stupid things that the type checker should be able to figure out in-line, but can't @@ -163,10 +136,6 @@ subst1 Refl = Refl subst2 :: forall f c a b. a :~: b -> f a c :~: f b c subst2 Refl = Refl --- TODO: is this sound? @n@ cannot be negative, surely, but the plugin doesn't see even that. -lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a -lemReplicateSucc = unsafeCoerce Refl - lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l lemAppLeft _ Refl = Refl @@ -179,7 +148,7 @@ lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n)) where go :: SNat m -> StaticShX (Replicate m Nothing) go SZ = ZKSX - go (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = () :!$? go n + go (SS (n :: SNat nm1)) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 = () :!$? go n lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n lemRankReplicate _ = go (natSing @n) @@ -187,7 +156,7 @@ lemRankReplicate _ = go (natSing @n) go :: forall m. SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m go SZ = Refl go (SS (n :: SNat nm1)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 + | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 , Refl <- go n = Refl @@ -205,9 +174,9 @@ lemReplicatePlusApp _ _ _ = go (natSing @n) go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a go SZ = Refl go (SS (n :: SNat n'm1)) - | Refl <- lemReplicateSucc @a @n'm1 + | Refl <- X.lemReplicateSucc @a @n'm1 , Refl <- go n - = sym (lemReplicateSucc @a @(n'm1 + m)) + = sym (X.lemReplicateSucc @a @(n'm1 + m)) shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh') shAppSplit _ ZKSX idx = (ZSX, idx) @@ -583,10 +552,12 @@ mgenerate sh f = case X.enumShape sh of mvecsWrite sh idx val vecs mvecsFreeze sh vecs -mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShapeX sh, Elt a) => HList SNat is -> Mixed sh a -> Mixed sh a -mtranspose perm = mlift $ \(Proxy @sh') -> - X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh') - (X.transpose perm) +mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShapeX sh, Elt a) => HList SNat is -> Mixed sh a -> Mixed (X.Permute is (X.TakeLen is sh) ++ X.DropLen is sh) a +mtranspose perm + | Dict <- X.lemKnownShapeX (X.ssxAppend (X.ssxPermute perm (X.ssxTakeLen perm (knownShapeX @sh))) (X.ssxDropLen perm (knownShapeX @sh))) + = mlift $ \(Proxy @sh') -> + X.rerankTop (knownShapeX @sh) (knownShapeX @(X.Permute is (X.TakeLen is sh) ++ X.DropLen is sh)) (knownShapeX @sh') + (X.transpose perm) mappend :: forall n m sh a. (KnownShapeX sh, KnownShapeX (n : sh), KnownShapeX (m : sh), KnownShapeX (X.AddMaybe n m : sh), Elt a) => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a @@ -823,13 +794,10 @@ lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh)) go ZSS = ZKSX go (n :$$ sh) = n :!$@ go sh -lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2 +lemCommMapJustApp :: forall sh1 sh2. ShS sh1 -> Proxy sh2 -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 -lemMapJustPlusApp _ _ = go (knownShape @sh1) - where - go :: ShS sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2 - go ZSS = Refl - go (_ :$$ sh) | Refl <- go sh = Refl +lemCommMapJustApp ZSS _ = Refl +lemCommMapJustApp (_ :$$ sh) p | Refl <- lemCommMapJustApp sh p = Refl instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr @@ -1057,11 +1025,11 @@ shCvtXR (n :$? idx) = n :$: shCvtXR idx ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) ixCvtRX ZIR = ZIX -ixCvtRX (n :.: (idx :: IxR m Int)) = castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :.? ixCvtRX idx) +ixCvtRX (n :.: (idx :: IxR m Int)) = castWith (subst2 @IxX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) (n :.? ixCvtRX idx) shCvtRX :: IShR n -> IShX (Replicate n Nothing) shCvtRX ZSR = ZSX -shCvtRX (n :$: (idx :: ShR m Int)) = castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :$? shCvtRX idx) +shCvtRX (n :$: (idx :: ShR m Int)) = castWith (subst2 @ShX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) (n :$? shCvtRX idx) shapeSizeR :: IShR n -> Int shapeSizeR ZSR = 1 @@ -1105,7 +1073,7 @@ rsumOuter1P :: forall n a. => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) rsumOuter1P (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) - , Refl <- lemReplicateSucc @(Nothing @Nat) @n + , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n = Ranked . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a)) . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing)) @@ -1119,15 +1087,17 @@ rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a rtranspose perm | Dict <- lemKnownReplicate (Proxy @n) + , length perm <= fromIntegral (natVal (Proxy @n)) = rlift $ \(Proxy @sh') -> - X.rerankTop (knownShapeX @(Replicate n Nothing)) (knownShapeX @(Replicate n Nothing)) (knownShapeX @sh') - (X.transposeUntyped perm) + X.transposeUntyped (natSing @n) (knownShapeX @sh') perm + | otherwise + = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array" rappend :: forall n a. (KnownNat n, Elt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a rappend | Dict <- lemKnownReplicate (Proxy @n) - , Refl <- lemReplicateSucc @(Nothing @Nat) @n + , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) rscalar :: Elt a => a -> Ranked 0 a @@ -1150,7 +1120,7 @@ rtoVector = coerce mtoVector rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (n + 1) a rfromList1 l | Dict <- lemKnownReplicate (Proxy @n) - , Refl <- lemReplicateSucc @(Nothing @Nat) @n + , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n = Ranked (mfromList1 @a @Nothing @(Replicate n Nothing) (coerce l)) rfromList :: Elt a => NonEmpty a -> Ranked 1 a @@ -1158,7 +1128,7 @@ rfromList = Ranked . mfromList1 . fmap mscalar rtoList :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] rtoList (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n = coerce (mtoList1 @a @Nothing @(Replicate n Nothing) arr) rtoList1 :: Elt a => Ranked 1 a -> [a] @@ -1181,7 +1151,7 @@ rslice ivs = rlift $ \_ -> X.slice ivs rrev1 :: forall n a. (KnownNat n, Elt a) => Ranked (n + 1) a -> Ranked (n + 1) a rrev1 = rlift $ \(Proxy @sh') -> - case lemReplicateSucc @(Nothing @Nat) @n of + case X.lemReplicateSucc @(Nothing @Nat) @n of Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh') rreshape :: forall n n' a. (KnownNat n, KnownNat n', Elt a) @@ -1308,7 +1278,7 @@ sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a sindexPartial (Shaped arr) idx = Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) - (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr) + (rewriteMixed (lemCommMapJustApp (knownShape @sh1) (Proxy @sh2)) arr) (ixCvtSX idx)) -- | __WARNING__: All values returned from the function must have equal shape. @@ -1342,10 +1312,56 @@ ssumOuter1 :: forall sh n a. => Shaped (n : sh) a -> Shaped sh a ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive -stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShape sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped sh a +lemCommMapJustTakeLen :: HList SNat is -> ShS sh -> X.TakeLen is (MapJust sh) :~: MapJust (X.TakeLen is sh) +lemCommMapJustTakeLen HNil _ = Refl +lemCommMapJustTakeLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl +lemCommMapJustTakeLen (_ `HCons` _) ZSS = error "TakeLen of empty" + +lemCommMapJustDropLen :: HList SNat is -> ShS sh -> X.DropLen is (MapJust sh) :~: MapJust (X.DropLen is sh) +lemCommMapJustDropLen HNil _ = Refl +lemCommMapJustDropLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl +lemCommMapJustDropLen (_ `HCons` _) ZSS = error "DropLen of empty" + +lemCommMapJustIndex :: SNat i -> ShS sh -> X.Index i (MapJust sh) :~: Just (X.Index i sh) +lemCommMapJustIndex SZ (_ :$$ _) = Refl +lemCommMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) + | Refl <- lemCommMapJustIndex i sh + , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) + , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = Refl +lemCommMapJustIndex _ ZSS = error "Index of empty" + +lemCommMapJustPermute :: HList SNat is -> ShS sh -> X.Permute is (MapJust sh) :~: MapJust (X.Permute is sh) +lemCommMapJustPermute HNil _ = Refl +lemCommMapJustPermute (i `HCons` is) sh + | Refl <- lemCommMapJustPermute is sh + , Refl <- lemCommMapJustIndex i sh + = Refl + +shTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh) +shTakeLen HNil _ = ZSS +shTakeLen (_ `HCons` is) (n :$$ sh) = n :$$ shTakeLen is sh +shTakeLen (_ `HCons` _) ZSS = error "Permutation longer than shape" + +shPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh) +shPermute HNil _ = ZSS +shPermute (i `HCons` (is :: HList SNat is')) (sh :: ShS sh) = shIndex (Proxy @is') (Proxy @sh) i sh (shPermute is sh) + +shIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT) +shIndex _ _ SZ (n :$$ _) rest = n :$$ rest +shIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest + | Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = shIndex p pT i sh rest +shIndex _ _ _ ZSS _ = error "Index into empty shape" + +stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShape sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped (X.Permute is (X.TakeLen is sh) ++ X.DropLen is sh) a stranspose perm (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) , Refl <- lemRankMapJust (Proxy @sh) + , Refl <- lemCommMapJustTakeLen perm (knownShape @sh) + , Refl <- lemCommMapJustDropLen perm (knownShape @sh) + , Refl <- lemCommMapJustPermute perm (shTakeLen perm (knownShape @sh)) + , Refl <- lemCommMapJustApp (shPermute perm (shTakeLen perm (knownShape @sh))) (Proxy @(X.DropLen is sh)) = Shaped (mtranspose perm arr) sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a) |