diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 11:58:40 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 11:58:40 +0200 |
commit | a65306ba5d80891b20ac86fa3a3242f9497751e6 (patch) | |
tree | 834af370556a46bbeca807a92c31bef098b47a89 /src/Data/Array/Nested/Internal.hs | |
parent | d8e2fcf4ea979fe272db48fc2889f4c2636c50d7 (diff) |
Refactor Mixed (modules, regular function names)
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 326 |
1 files changed, 165 insertions, 161 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 712c5f1..0870789 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -60,7 +60,11 @@ import Unsafe.Coerce import Data.Array.Mixed import qualified Data.Array.Mixed as X -import Data.Array.Nested.Internal.Arith +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Types -- Invariant in the API @@ -123,19 +127,19 @@ lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing) ssxFromSNat SZ = ZKX -ssxFromSNat (SS (n :: SNat nm1)) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n +ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing) lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn) -lemRankReplicate :: SNat n -> X.Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n lemRankReplicate SZ = Refl lemRankReplicate (SS (n :: SNat nm1)) - | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 + | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 , Refl <- lemRankReplicate n = Refl -lemRankMapJust :: forall sh. ShS sh -> X.Rank (MapJust sh) :~: X.Rank sh +lemRankMapJust :: forall sh. ShS sh -> Rank (MapJust sh) :~: Rank sh lemRankMapJust ZSS = Refl lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl @@ -146,9 +150,9 @@ lemReplicatePlusApp sn _ _ = go sn go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a go SZ = Refl go (SS (n :: SNat n'm1)) - | Refl <- X.lemReplicateSucc @a @n'm1 + | Refl <- lemReplicateSucc @a @n'm1 , Refl <- go n - = sym (X.lemReplicateSucc @a @(n'm1 + m)) + = sym (lemReplicateSucc @a @(n'm1 + m)) lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True lemLeqPlus _ _ _ = Refl @@ -156,17 +160,17 @@ lemLeqPlus _ _ _ = Refl lemLeqSuccSucc :: (k + 1 <= n) => Proxy k -> Proxy n -> (k <=? n - 1) :~: True lemLeqSuccSucc _ _ = unsafeCoerce Refl -lemDropLenApp :: X.Rank l1 <= X.Rank l2 +lemDropLenApp :: Rank l1 <= Rank l2 => Proxy l1 -> Proxy l2 -> Proxy rest - -> X.DropLen l1 l2 ++ rest :~: X.DropLen l1 (l2 ++ rest) + -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest) lemDropLenApp _ _ _ = unsafeCoerce Refl -lemTakeLenApp :: X.Rank l1 <= X.Rank l2 +lemTakeLenApp :: Rank l1 <= Rank l2 => Proxy l1 -> Proxy l2 -> Proxy rest - -> X.TakeLen l1 l2 :~: X.TakeLen l1 (l2 ++ rest) + -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest) lemTakeLenApp _ _ _ = unsafeCoerce Refl -srankSh :: ShX sh f -> SNat (X.Rank sh) +srankSh :: ShX sh f -> SNat (Rank sh) srankSh ZSX = SNat srankSh (_ :$% sh) | SNat <- srankSh sh = SNat @@ -585,11 +589,11 @@ class Elt a where -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a - mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2 + mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a - mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh) - => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a + mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) + => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a -- ====== PRIVATE METHODS ====== -- @@ -635,20 +639,20 @@ class Elt a => KnownElt a where instance Storable a => Elt (Primitive a) where mshape (M_Primitive sh _) = sh mindex (M_Primitive _ a) i = Primitive (X.index a i) - mindexPartial (M_Primitive sh a) i = M_Primitive (X.shDropIx sh i) (X.indexPartial a i) + mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) mfromListOuter l@(arr1 :| _) = let sh = SUnknown (length l) :$% mshape arr1 - in M_Primitive sh (X.fromListOuter (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l))) - mtoListOuter (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toListOuter arr) + in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l))) + mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) mlift :: forall sh1 sh2. StaticShX sh2 -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) mlift ssh2 f (M_Primitive _ a) - | Refl <- X.lemAppNil @sh1 - , Refl <- X.lemAppNil @sh2 + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 , let result = f ZKX a = M_Primitive (X.shape ssh2 result) result @@ -657,36 +661,36 @@ instance Storable a => Elt (Primitive a) where -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b) - | Refl <- X.lemAppNil @sh1 - , Refl <- X.lemAppNil @sh2 - , Refl <- X.lemAppNil @sh3 + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + , Refl <- lemAppNil @sh3 , let result = f ZKX a b = M_Primitive (X.shape ssh3 result) result - mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2 + mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) mcast ssh1 sh2 _ (M_Primitive sh1' arr) = - let (_, sh') = shAppSplit (Proxy @sh') ssh1 sh1' - in M_Primitive (shAppend sh2 sh') (X.cast ssh1 sh2 (X.staticShapeFrom sh') arr) + let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' + in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) mtranspose perm (M_Primitive sh arr) = - M_Primitive (X.shPermutePrefix perm sh) - (X.transpose (X.staticShapeFrom sh) perm arr) + M_Primitive (shxPermutePrefix perm sh) + (X.transpose (ssxFromShape sh) perm arr) mshapeTree _ = () mshapeTreeEq _ () () = True mshapeTreeEmpty _ () = False mshowShapeTree _ () = "()" - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x + mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x -- TODO: this use of toVector is suboptimal mvecsWritePartial :: forall sh' sh s. IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do - let arrsh = X.shape (X.staticShapeFrom sh') arr - offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' arrsh)) - VS.copy (VSM.slice offset (X.shapeSize arrsh) v) (X.toVector arr) + let arrsh = X.shape (ssxFromShape sh') arr + offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) + VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v @@ -701,7 +705,7 @@ deriving via Primitive () instance Elt () instance Storable a => KnownElt (Primitive a) where memptyArray sh = M_Primitive sh (X.empty sh) - mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh) + mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 -- [PRIMITIVE ELEMENT TYPES LIST] @@ -755,7 +759,7 @@ instance Elt a => Elt (Mixed sh' a) where -- moverlongShape method, a prefix of which is mshape. mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh mshape (M_Nest sh arr) - = fst (shAppSplit (Proxy @sh') (X.staticShapeFrom sh) (mshape arr)) + = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr)) mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a mindex (M_Nest _ arr) i = mindexPartial arr i @@ -763,8 +767,8 @@ instance Elt a => Elt (Mixed sh' a) where mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) mindexPartial (M_Nest sh arr) i - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (X.shDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) mscalar = M_Nest ZSX @@ -773,95 +777,95 @@ instance Elt a => Elt (Mixed sh' a) where M_Nest (SUnknown (length l) :$% mshape arr) (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) - mtoListOuter (M_Nest sh arr) = map (M_Nest (X.shTail sh)) (mtoListOuter arr) + mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) mlift :: forall sh1 sh2. StaticShX sh2 -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) mlift ssh2 f (M_Nest sh1 arr) = - let result = mlift (X.ssxAppend ssh2 ssh') f' arr - (sh2, _) = shAppSplit (Proxy @sh') ssh2 (mshape result) + let result = mlift (ssxAppend ssh2 ssh') f' arr + (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) in M_Nest sh2 result where - ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') (X.staticShapeFrom sh1) (mshape arr))) + ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr))) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b f' sshT - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - = f (X.ssxAppend ssh' sshT) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = - let result = mlift2 (X.ssxAppend ssh3 ssh') f' arr1 arr2 - (sh3, _) = shAppSplit (Proxy @sh') ssh3 (mshape result) + let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 + (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) in M_Nest sh3 result where - ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') (X.staticShapeFrom sh1) (mshape arr1))) + ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b f' sshT - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - , Refl <- X.lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) - = f (X.ssxAppend ssh' sshT) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) - mcast :: forall sh1 sh2 shT. X.Rank sh1 ~ X.Rank sh2 + mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) mcast ssh1 sh2 _ (M_Nest sh1T arr) - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') - , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') - = let (_, shT) = shAppSplit (Proxy @shT) ssh1 sh1T - in M_Nest (shAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) - - mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh) - => HList SNat is -> Mixed sh (Mixed sh' a) - -> Mixed (X.PermutePrefix is sh) (Mixed sh' a) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') + = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T + in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) + + mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) + => Perm is -> Mixed sh (Mixed sh' a) + -> Mixed (PermutePrefix is sh) (Mixed sh' a) mtranspose perm (M_Nest sh arr) - | let sh' = X.shDropSh @sh @sh' (mshape arr) sh - , Refl <- X.lemRankApp (X.staticShapeFrom sh) (X.staticShapeFrom sh') - , Refl <- lemLeqPlus (Proxy @(X.Rank is)) (Proxy @(X.Rank sh)) (Proxy @(X.Rank sh')) - , Refl <- X.lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') + | let sh' = shxDropSh @sh @sh' (mshape arr) sh + , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh') + , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) + , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') - = M_Nest (X.shPermutePrefix perm sh) + = M_Nest (shxPermutePrefix perm sh) (mtranspose perm arr) mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) - mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (X.staticShapeFrom (mshape arr))))) + mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr))))) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - mshapeTreeEmpty _ (sh, t) = X.shapeSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.shAppend sh sh') idx val vecs + mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs mvecsWritePartial :: forall sh1 sh2 s. IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) -> ST s () mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial (X.shAppend sh12 sh') idx arr vecs + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs - mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (X.shAppend sh sh') vecs + mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where - memptyArray sh = M_Nest sh (memptyArray (X.shAppend sh (X.completeShXzeros (knownShX @sh')))) + memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh')))) mvecsUnsafeNew sh example - | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) - | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh sh') (mindex example (X.zeroIxX (X.staticShapeFrom sh'))) + | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) + | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh'))) where sh' = mshape example - mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) + mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) -- | Create an array given a size and a function that computes the element at a @@ -882,10 +886,10 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where -- array. The type of 'mgenerate' allows this requirement to be broken very -- easily, hence the runtime check. mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a -mgenerate sh f = case X.enumShape sh of +mgenerate sh f = case shxEnum sh of [] -> memptyArray sh firstidx : restidxs -> - let firstelem = f (X.zeroIxX' sh) + let firstelem = f (ixxZero' sh) shapetree = mshapeTree firstelem in if mshapeTreeEmpty (Proxy @a) shapetree then memptyArray sh @@ -905,28 +909,28 @@ msumOuter1P :: forall sh n a. (Storable a, NumElt a) => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) msumOuter1P (M_Primitive (n :$% sh) arr) = let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX - in M_Primitive sh (X.sumOuter nssh (X.staticShapeFrom sh) arr) + in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr) msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) => Mixed (n : sh) a -> Mixed sh a msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive mappend :: forall n m sh a. Elt a - => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a + => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 where sn :$% sh = mshape arr1 sm :$% _ = mshape arr2 - ssh = X.staticShapeFrom sh - snm :: SMayNat () SNat (X.AddMaybe n m) + ssh = ssxFromShape sh + snm :: SMayNat () SNat (AddMaybe n m) snm = case (sn, sm) of (SUnknown{}, _) -> SUnknown () (SKnown{}, SUnknown{}) -> SUnknown () - (SKnown n, SKnown m) -> SKnown (X.plusSNat n m) + (SKnown n, SKnown m) -> SKnown (snatPlus n m) f :: forall sh' b. Storable b - => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b - f ssh' = X.append (X.ssxAppend ssh ssh') + => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b + f ssh' = X.append (ssxAppend ssh ssh') mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) @@ -971,9 +975,9 @@ mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) mrerankP ssh sh2 f (M_Primitive sh arr) = - let sh1 = shDropSSX sh ssh - in M_Primitive (X.shAppend (shTakeSSX (Proxy @sh1) sh ssh) sh2) - (X.rerank ssh (X.staticShapeFrom sh1) (X.staticShapeFrom sh2) + let sh1 = shxDropSSX sh ssh + in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) + (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2) (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) arr) @@ -988,10 +992,10 @@ mrerank ssh sh2 f (toPrimitive -> arr) = mreplicate :: forall sh sh' a. Elt a => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a mreplicate sh arr = - let ssh' = X.staticShapeFrom (mshape arr) - in mlift (X.ssxAppend (X.staticShapeFrom sh) ssh') + let ssh' = ssxFromShape (mshape arr) + in mlift (ssxAppend (ssxFromShape sh) ssh') (\(sshT :: StaticShX shT) -> - case X.lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of + case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of Refl -> X.replicate sh (ssxAppend ssh' sshT)) arr @@ -1005,18 +1009,18 @@ mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a mslice i n arr = let _ :$% sh = mshape arr - in mlift (SKnown n :!% X.staticShapeFrom sh) (\_ -> X.slice i n) arr + in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n arr = mlift (X.staticShapeFrom (mshape arr)) (\_ -> X.sliceU i n) arr +msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a -mrev1 arr = mlift (X.staticShapeFrom (mshape arr)) (\_ -> X.rev1) arr +mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a mreshape sh' arr = - mlift (X.staticShapeFrom sh') - (\sshIn -> X.reshapePartial (X.staticShapeFrom (mshape arr)) sshIn sh') + mlift (ssxFromShape sh') + (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh') arr miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a @@ -1095,26 +1099,26 @@ instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where log1pexp = mliftNumElt1 floatEltLog1pexp log1mexp = mliftNumElt1 floatEltLog1mexp -mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (X.Rank sh) a +mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a mtoRanked arr - | Refl <- X.lemAppNil @sh - , Refl <- X.lemAppNil @(Replicate (X.Rank sh) (Nothing @Nat)) + | Refl <- lemAppNil @sh + , Refl <- lemAppNil @(Replicate (Rank sh) (Nothing @Nat)) , Refl <- lemRankReplicate (srankSh (mshape arr)) - = Ranked (mcast (X.staticShapeFrom (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr) + = Ranked (mcast (ssxFromShape (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr) where - convSh :: IShX sh' -> IShX (Replicate (X.Rank sh') Nothing) + convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing) convSh ZSX = ZSX convSh (smn :$% (sh :: IShX sh'T)) - | Refl <- X.lemReplicateSucc @(Nothing @Nat) @(X.Rank sh'T) + | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T) = SUnknown (fromSMayNat' smn) :$% convSh sh -mcastToShaped :: forall sh sh' a. (Elt a, X.Rank sh ~ X.Rank sh') +mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') => Mixed sh a -> ShS sh' -> Shaped sh' a mcastToShaped arr targetsh - | Refl <- X.lemAppNil @sh - , Refl <- X.lemAppNil @(MapJust sh') + | Refl <- lemAppNil @sh + , Refl <- lemAppNil @(MapJust sh') , Refl <- lemRankMapJust targetsh - = Shaped (mcast (X.staticShapeFrom (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr) + = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr) -- | A rank-typed array: the number of dimensions of the array (its /rank/) is @@ -1418,7 +1422,7 @@ zeroIxR :: SNat n -> IIxR n zeroIxR SZ = ZIR zeroIxR (SS n) = 0 :.: zeroIxR n -ixCvtXR :: IIxX sh -> IIxR (X.Rank sh) +ixCvtXR :: IIxX sh -> IIxR (Rank sh) ixCvtXR ZIX = ZIR ixCvtXR (n :.% idx) = n :.: ixCvtXR idx @@ -1429,7 +1433,7 @@ shCvtXR' ZSX = shCvtXR' (n :$% (idx :: IShX sh)) | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) = castWith (subst2 (lem1 @sh Refl)) - (X.fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) + (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) where lem1 :: forall sh' n' k. k : sh' :~: Replicate n' Nothing @@ -1443,13 +1447,13 @@ shCvtXR' (n :$% (idx :: IShX sh)) ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) ixCvtRX ZIR = ZIX ixCvtRX (n :.: (idx :: IxR m Int)) = - castWith (subst2 @IxX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) + castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :.% ixCvtRX idx) shCvtRX :: IShR n -> IShX (Replicate n Nothing) shCvtRX ZSR = ZSX shCvtRX (n :$: (idx :: ShR m Int)) = - castWith (subst2 @ShX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) + castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (SUnknown n :$% shCvtRX idx) shapeSizeR :: IShR n -> Int @@ -1506,7 +1510,7 @@ rsumOuter1P :: forall n a. (Storable a, NumElt a) => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) rsumOuter1P (Ranked arr) - | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) @n = Ranked (msumOuter1P arr) rsumOuter1 :: forall n a. (NumElt a, PrimElt a) @@ -1559,7 +1563,7 @@ rappend :: forall n a. Elt a rappend arr1 arr2 | sn@SNat <- snatFromShR (rshape arr1) , Dict <- lemKnownReplicate sn - , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + , Refl <- lemReplicateSucc @(Nothing @Nat) @n = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) arr1 arr2 @@ -1582,7 +1586,7 @@ rtoVector = coerce mtoVector rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a rfromListOuter l - | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) @n = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a @@ -1593,7 +1597,7 @@ rfromList1Prim l = Ranked (mfromList1Prim l) rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] rtoListOuter (Ranked arr) - | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) @n = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) rtoList1 :: Elt a => Ranked 1 a -> [a] @@ -1677,7 +1681,7 @@ rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a rslice i n arr - | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) @n = rlift (snatFromShR (rshape arr)) (\_ -> X.sliceU i n) arr @@ -1686,7 +1690,7 @@ rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a rrev1 arr = rlift (snatFromShR (rshape arr)) (\(_ :: StaticShX sh') -> - case X.lemReplicateSucc @(Nothing @Nat) @n of + case lemReplicateSucc @(Nothing @Nat) @n of Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) arr @@ -1707,12 +1711,12 @@ rasXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing rasXArrayPrim (Ranked arr) = first shCvtXR' (masXArrayPrim arr) rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) -rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr) +rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a -rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr) +rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) -rcastToShaped :: Elt a => Ranked (X.Rank sh) a -> ShS sh -> Shaped sh a +rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a rcastToShaped (Ranked arr) targetsh | Refl <- lemRankReplicate (srankSh (shCvtSX targetsh)) , Refl <- lemRankMapJust targetsh @@ -1809,7 +1813,7 @@ shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh shapeSizeS :: ShS sh -> Int shapeSizeS ZSS = 1 -shapeSizeS (n :$$ sh) = X.fromSNat' n * shapeSizeS sh +shapeSizeS (n :$$ sh) = fromSNat' n * shapeSizeS sh sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh @@ -1838,14 +1842,14 @@ slift :: forall sh1 sh2 a. Elt a => ShS sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a -slift sh2 f (Shaped arr) = Shaped (mlift (X.staticShapeFrom (shCvtSX sh2)) f arr) +slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr) -- | See the documentation of 'mlift'. slift2 :: forall sh1 sh2 sh3 a. Elt a => ShS sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a -slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (X.staticShapeFrom (shCvtSX sh3)) f arr1 arr2) +slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2) ssumOuter1P :: forall sh n a. (Storable a, NumElt a) => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) @@ -1855,28 +1859,28 @@ ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) => Shaped (n : sh) a -> Shaped sh a ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive -lemCommMapJustTakeLen :: HList SNat is -> ShS sh -> X.TakeLen is (MapJust sh) :~: MapJust (X.TakeLen is sh) -lemCommMapJustTakeLen HNil _ = Refl -lemCommMapJustTakeLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl -lemCommMapJustTakeLen (_ `HCons` _) ZSS = error "TakeLen of empty" +lemCommMapJustTakeLen :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) +lemCommMapJustTakeLen PNil _ = Refl +lemCommMapJustTakeLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl +lemCommMapJustTakeLen (_ `PCons` _) ZSS = error "TakeLen of empty" -lemCommMapJustDropLen :: HList SNat is -> ShS sh -> X.DropLen is (MapJust sh) :~: MapJust (X.DropLen is sh) -lemCommMapJustDropLen HNil _ = Refl -lemCommMapJustDropLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl -lemCommMapJustDropLen (_ `HCons` _) ZSS = error "DropLen of empty" +lemCommMapJustDropLen :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) +lemCommMapJustDropLen PNil _ = Refl +lemCommMapJustDropLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl +lemCommMapJustDropLen (_ `PCons` _) ZSS = error "DropLen of empty" -lemCommMapJustIndex :: SNat i -> ShS sh -> X.Index i (MapJust sh) :~: Just (X.Index i sh) +lemCommMapJustIndex :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) lemCommMapJustIndex SZ (_ :$$ _) = Refl lemCommMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) | Refl <- lemCommMapJustIndex i sh - , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) - , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) + , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') = Refl lemCommMapJustIndex _ ZSS = error "Index of empty" -lemCommMapJustPermute :: HList SNat is -> ShS sh -> X.Permute is (MapJust sh) :~: MapJust (X.Permute is sh) -lemCommMapJustPermute HNil _ = Refl -lemCommMapJustPermute (i `HCons` is) sh +lemCommMapJustPermute :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) +lemCommMapJustPermute PNil _ = Refl +lemCommMapJustPermute (i `PCons` is) sh | Refl <- lemCommMapJustPermute is sh , Refl <- lemCommMapJustIndex i sh = Refl @@ -1885,53 +1889,53 @@ listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f listsAppend ZS idx' = idx' listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' -listsTakeLen :: forall f is sh. HList SNat is -> ListS sh f -> ListS (X.TakeLen is sh) f -listsTakeLen HNil _ = ZS -listsTakeLen (_ `HCons` is) (n ::$ sh) = n ::$ listsTakeLen is sh -listsTakeLen (_ `HCons` _) ZS = error "Permutation longer than shape" +listsTakeLen :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f +listsTakeLen PNil _ = ZS +listsTakeLen (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLen is sh +listsTakeLen (_ `PCons` _) ZS = error "Permutation longer than shape" -listsDropLen :: forall f is sh. HList SNat is -> ListS sh f -> ListS (DropLen is sh) f -listsDropLen HNil sh = sh -listsDropLen (_ `HCons` is) (_ ::$ sh) = listsDropLen is sh -listsDropLen (_ `HCons` _) ZS = error "Permutation longer than shape" +listsDropLen :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f +listsDropLen PNil sh = sh +listsDropLen (_ `PCons` is) (_ ::$ sh) = listsDropLen is sh +listsDropLen (_ `PCons` _) ZS = error "Permutation longer than shape" -listsPermute :: forall f is sh. HList SNat is -> ListS sh f -> ListS (X.Permute is sh) f -listsPermute HNil _ = ZS -listsPermute (i `HCons` (is :: HList SNat is')) (sh :: ListS sh f) = listsIndex (Proxy @is') (Proxy @sh) i sh (listsPermute is sh) +listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f +listsPermute PNil _ = ZS +listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = listsIndex (Proxy @is') (Proxy @sh) i sh (listsPermute is sh) -listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> ListS (X.Permute is shT) f -> ListS (X.Index i sh : X.Permute is shT) f +listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> ListS (Permute is shT) f -> ListS (Index i sh : Permute is shT) f listsIndex _ _ SZ (n ::$ _) rest = n ::$ rest listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) rest - | Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') = listsIndex p pT i sh rest listsIndex _ _ _ ZS _ = error "Index into empty shape" -shsTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh) +shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh) shsTakeLen = coerce (listsTakeLen @SNat) -shsPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh) +shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) shsPermute = coerce (listsPermute @SNat) -shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT) +shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (Permute is shT) -> ShS (Index i sh : Permute is shT) shsIndex pis pshT = coerce (listsIndex @SNat pis pshT) -applyPermS :: forall f is sh. HList SNat is -> ListS sh f -> ListS (PermutePrefix is sh) f +applyPermS :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f applyPermS perm sh = listsAppend (listsPermute perm (listsTakeLen perm sh)) (listsDropLen perm sh) -applyPermIxS :: forall i is sh. HList SNat is -> IxS sh i -> IxS (PermutePrefix is sh) i +applyPermIxS :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i applyPermIxS = coerce (applyPermS @(Const i)) -applyPermShS :: forall is sh. HList SNat is -> ShS sh -> ShS (PermutePrefix is sh) +applyPermShS :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) applyPermShS = coerce (applyPermS @SNat) -stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a) - => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a +stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a) + => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a stranspose perm sarr@(Shaped arr) | Refl <- lemRankMapJust (sshape sarr) , Refl <- lemCommMapJustTakeLen perm (sshape sarr) , Refl <- lemCommMapJustDropLen perm (sshape sarr) , Refl <- lemCommMapJustPermute perm (shsTakeLen perm (sshape sarr)) - , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh)) + , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh)) = Shaped (mtranspose perm arr) sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a @@ -1969,7 +1973,7 @@ stoList1 = map sunScalar . stoListOuter sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a sfromListPrim sn l - | Refl <- X.lemAppNil @'[Just n] + | Refl <- lemAppNil @'[Just n] = let ssh = SUnknown () :!% ZKX xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr @@ -1989,7 +1993,7 @@ srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) srerankP sh sh2 f sarr@(Shaped arr) | Refl <- lemCommMapJustApp sh (Proxy @sh1) , Refl <- lemCommMapJustApp sh (Proxy @sh2) - = Shaped (mrerankP (X.staticShapeFrom (shTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (X.staticShapeFrom (shCvtSX sh)))) + = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh)))) (shCvtSX sh2) (\a -> let Shaped r = f (Shaped a) in r) arr) @@ -2033,12 +2037,12 @@ sasXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) sasXArrayPrim (Shaped arr) = first shCvtXS' (masXArrayPrim arr) sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) -sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (X.staticShapeFrom (shCvtSX sh)) arr) +sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr) sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a -sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (X.staticShapeFrom (shCvtSX sh)) arr) +sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr) -stoRanked :: Elt a => Shaped sh a -> Ranked (X.Rank sh) a +stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a stoRanked sarr@(Shaped arr) | Refl <- lemRankMapJust (sshape sarr) = mtoRanked arr |