aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-20 11:27:18 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-20 11:27:18 +0200
commit0f13e61b9eecb0a14e4c62e78218d1652d9d4cf2 (patch)
treeced22fefa5df63d671df391c7e0a6df49b9997a9 /src/Data
parent31662543b094e04dc373daf264aa62cfc3550457 (diff)
mtranspose without rerank
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Mixed.hs99
-rw-r--r--src/Data/Array/Nested/Internal.hs74
2 files changed, 122 insertions, 51 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 6ac3ab3..f62d781 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -303,27 +303,33 @@ completeShXzeros ZKX = ZSX
completeShXzeros (SUnknown () :!% ssh) = SUnknown 0 :$% completeShXzeros ssh
completeShXzeros (SKnown n :!% ssh) = SKnown n :$% completeShXzeros ssh
--- TODO: generalise all these things to arbitrary @i@
-ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh')
-ixAppend ZIX idx' = idx'
-ixAppend (i :.% idx) idx' = i :.% ixAppend idx idx'
+listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f
+listxAppend ZX idx' = idx'
+listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx'
+
+ixAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
+ixAppend = coerce (listxAppend @_ @(Const i))
+
+shAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
+shAppend = coerce (listxAppend @_ @(SMayNat i SNat))
-shAppend :: IShX sh -> IShX sh' -> IShX (sh ++ sh')
-shAppend ZSX sh' = sh'
-shAppend (n :$% sh) sh' = n :$% shAppend sh sh'
+listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f
+listxDrop long ZX = long
+listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short
-ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh'
-ixDrop long ZIX = long
-ixDrop long (_ :.% short) = case long of _ :.% long' -> ixDrop long' short
+ixDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i
+ixDrop = coerce (listxDrop @(Const i) @(Const i))
-shDropIx :: IShX (sh ++ sh') -> IIxX sh -> IShX sh'
-shDropIx sh ZIX = sh
-shDropIx sh (_ :.% idx) = case sh of _ :$% sh' -> shDropIx sh' idx
+shDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i
+shDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j))
-ssxDropIx :: StaticShX (sh ++ sh') -> IIxX sh -> StaticShX sh'
-ssxDropIx ssh ZIX = ssh
-ssxDropIx ssh (_ :.% idx) = case ssh of _ :!% ssh' -> ssxDropIx ssh' idx
+shDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i
+shDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))
+ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'
+ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))
+
+-- TODO: generalise all these things to arbitrary @i@
shTail :: IShX (n : sh) -> IShX sh
shTail (_ :$% sh) = sh
@@ -603,26 +609,59 @@ lemRankDropLen ZKX (_ `HCons` _) = error "1 <= 0"
lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l
lemIndexSucc _ _ _ = unsafeCoerce Refl
-ssxTakeLen :: HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh)
-ssxTakeLen HNil _ = ZKX
-ssxTakeLen (_ `HCons` is) (n :!% sh) = n :!% ssxTakeLen is sh
-ssxTakeLen (_ `HCons` _) ZKX = error "Permutation longer than shape"
+listxTakeLen :: forall f is sh. HList SNat is -> ListX sh f -> ListX (TakeLen is sh) f
+listxTakeLen HNil _ = ZX
+listxTakeLen (_ `HCons` is) (n ::% sh) = n ::% listxTakeLen is sh
+listxTakeLen (_ `HCons` _) ZX = error "Permutation longer than shape"
+
+listxDropLen :: forall f is sh. HList SNat is -> ListX sh f -> ListX (DropLen is sh) f
+listxDropLen HNil sh = sh
+listxDropLen (_ `HCons` is) (_ ::% sh) = listxDropLen is sh
+listxDropLen (_ `HCons` _) ZX = error "Permutation longer than shape"
+
+listxPermute :: forall f is sh. HList SNat is -> ListX sh f -> ListX (Permute is sh) f
+listxPermute HNil _ = ZX
+listxPermute (i `HCons` (is :: HList SNat is')) (sh :: ListX sh f) = listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh)
+
+listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f
+listxIndex _ _ SZ (n ::% _) rest = n ::% rest
+listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = listxIndex p pT i sh rest
+listxIndex _ _ _ ZX _ = error "Index into empty shape"
+
+listxPermutePrefix :: forall f is sh. HList SNat is -> ListX sh f -> ListX (PermutePrefix is sh) f
+listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)
+
+ssxTakeLen :: forall is sh. HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh)
+ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat))
ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh)
-ssxDropLen HNil sh = sh
-ssxDropLen (_ `HCons` is) (_ :!% sh) = ssxDropLen is sh
-ssxDropLen (_ `HCons` _) ZKX = error "Permutation longer than shape"
+ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))
ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh)
-ssxPermute HNil _ = ZKX
-ssxPermute (i `HCons` (is :: HList SNat is')) (sh :: StaticShX sh) = ssxIndex (Proxy @is') (Proxy @sh) i sh (ssxPermute is sh)
+ssxPermute = coerce (listxPermute @(SMayNat () SNat))
ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT)
-ssxIndex _ _ SZ (n :!% _) rest = n :!% rest
-ssxIndex p pT (SS (i :: SNat i')) ((_ :: SMayNat () SNat n) :!% (sh :: StaticShX sh')) rest
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = ssxIndex p pT i sh rest
-ssxIndex _ _ _ ZKX _ = error "Index into empty shape"
+ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2)
+
+ssxPermutePrefix :: HList SNat is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
+ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
+
+shTakeLen :: forall is sh. HList SNat is -> IShX sh -> IShX (TakeLen is sh)
+shTakeLen = coerce (listxTakeLen @(SMayNat Int SNat))
+
+shDropLen :: HList SNat is -> IShX sh -> IShX (DropLen is sh)
+shDropLen = coerce (listxDropLen @(SMayNat Int SNat))
+
+shPermute :: HList SNat is -> IShX sh -> IShX (Permute is sh)
+shPermute = coerce (listxPermute @(SMayNat Int SNat))
+
+shIndex :: Proxy is -> Proxy shT -> SNat i -> IShX sh -> IShX (Permute is shT) -> IShX (Index i sh : Permute is shT)
+shIndex p1 p2 = coerce (listxIndex @(SMayNat Int SNat) p1 p2)
+
+shPermutePrefix :: HList SNat is -> IShX sh -> IShX (PermutePrefix is sh)
+shPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))
-- | The list argument gives indices into the original dimension list.
transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh)
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index ededb60..2dc69ba 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -151,6 +151,19 @@ lemReplicatePlusApp sn _ _ = go sn
, Refl <- go n
= sym (X.lemReplicateSucc @a @(n'm1 + m))
+lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True
+lemLeqPlus _ _ _ = Refl
+
+lemDropLenApp :: X.Rank l1 <= X.Rank l2
+ => Proxy l1 -> Proxy l2 -> Proxy rest
+ -> X.DropLen l1 l2 ++ rest :~: X.DropLen l1 (l2 ++ rest)
+lemDropLenApp _ _ _ = unsafeCoerce Refl
+
+lemTakeLenApp :: X.Rank l1 <= X.Rank l2
+ => Proxy l1 -> Proxy l2 -> Proxy rest
+ -> X.TakeLen l1 l2 :~: X.TakeLen l1 (l2 ++ rest)
+lemTakeLenApp _ _ _ = unsafeCoerce Refl
+
-- === NEW INDEX TYPES === --
@@ -540,6 +553,9 @@ class Elt a where
mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2
=> StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a
+ mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh)
+ => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a
+
-- ====== PRIVATE METHODS ====== --
mshapeTree :: a -> ShapeTree a
@@ -618,6 +634,10 @@ instance Storable a => Elt (Primitive a) where
let (_, sh') = shAppSplit (Proxy @sh') ssh1 sh1'
in M_Primitive (shAppend sh2 sh') (X.cast ssh1 sh2 (X.staticShapeFrom sh') arr)
+ mtranspose perm (M_Primitive sh arr) =
+ M_Primitive (X.shPermutePrefix perm sh)
+ (X.transpose (X.staticShapeFrom sh) perm arr)
+
mshapeTree _ = ()
mshapeTreeEq _ () () = True
mshapeTreeEmpty _ () = False
@@ -668,6 +688,8 @@ instance (Elt a, Elt b) => Elt (a, b) where
mcast ssh1 sh2 psh' (M_Tup2 a b) =
M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b)
+ mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b)
+
mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
@@ -755,6 +777,19 @@ instance Elt a => Elt (Mixed sh' a) where
= let (_, shT) = shAppSplit (Proxy @shT) ssh1 sh1T
in M_Nest (shAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr)
+ mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh)
+ => HList SNat is -> Mixed sh (Mixed sh' a)
+ -> Mixed (X.PermutePrefix is sh) (Mixed sh' a)
+ mtranspose perm (M_Nest sh arr)
+ | let sh' = X.shDropSh @sh @sh' (mshape arr) sh
+ , Refl <- X.lemRankApp (X.staticShapeFrom sh) (X.staticShapeFrom sh')
+ , Refl <- lemLeqPlus (Proxy @(X.Rank is)) (Proxy @(X.Rank sh)) (Proxy @(X.Rank sh'))
+ , Refl <- X.lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
+ , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
+ , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
+ = M_Nest (X.shPermutePrefix perm sh)
+ (mtranspose perm arr)
+
mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (X.staticShapeFrom (mshape arr)))))
@@ -825,13 +860,6 @@ mgenerate sh f = case X.enumShape sh of
mvecsWrite sh idx val vecs
mvecsFreeze sh vecs
-mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a)
- => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a
-mtranspose perm arr =
- let ssh = X.staticShapeFrom (mshape arr)
- sshPP = X.ssxAppend (X.ssxPermute perm (X.ssxTakeLen perm ssh)) (X.ssxDropLen perm ssh)
- in mlift sshPP (\ssh' -> X.rerankTop ssh sshPP ssh' (X.transpose ssh perm)) arr
-
mappend :: forall n m sh a. Elt a
=> Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a
mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
@@ -1002,6 +1030,8 @@ instance Elt a => Elt (Ranked n a) where
mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr)
+ mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr)
+
mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr)
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
@@ -1098,6 +1128,8 @@ instance Elt a => Elt (Shaped sh a) where
mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr)
+ mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)
+
mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr)
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
@@ -1499,21 +1531,21 @@ lemCommMapJustPermute (i `HCons` is) sh
, Refl <- lemCommMapJustIndex i sh
= Refl
-shTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh)
-shTakeLen HNil _ = ZSS
-shTakeLen (_ `HCons` is) (n :$$ sh) = n :$$ shTakeLen is sh
-shTakeLen (_ `HCons` _) ZSS = error "Permutation longer than shape"
+shsTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh)
+shsTakeLen HNil _ = ZSS
+shsTakeLen (_ `HCons` is) (n :$$ sh) = n :$$ shsTakeLen is sh
+shsTakeLen (_ `HCons` _) ZSS = error "Permutation longer than shape"
-shPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh)
-shPermute HNil _ = ZSS
-shPermute (i `HCons` (is :: HList SNat is')) (sh :: ShS sh) = shIndex (Proxy @is') (Proxy @sh) i sh (shPermute is sh)
+shsPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh)
+shsPermute HNil _ = ZSS
+shsPermute (i `HCons` (is :: HList SNat is')) (sh :: ShS sh) = shsIndex (Proxy @is') (Proxy @sh) i sh (shsPermute is sh)
-shIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT)
-shIndex _ _ SZ (n :$$ _) rest = n :$$ rest
-shIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest
+shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT)
+shsIndex _ _ SZ (n :$$ _) rest = n :$$ rest
+shsIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest
| Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = shIndex p pT i sh rest
-shIndex _ _ _ ZSS _ = error "Index into empty shape"
+ = shsIndex p pT i sh rest
+shsIndex _ _ _ ZSS _ = error "Index into empty shape"
stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a)
=> HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a
@@ -1521,8 +1553,8 @@ stranspose perm sarr@(Shaped arr)
| Refl <- lemRankMapJust (sshape sarr)
, Refl <- lemCommMapJustTakeLen perm (sshape sarr)
, Refl <- lemCommMapJustDropLen perm (sshape sarr)
- , Refl <- lemCommMapJustPermute perm (shTakeLen perm (sshape sarr))
- , Refl <- lemCommMapJustApp (shPermute perm (shTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh))
+ , Refl <- lemCommMapJustPermute perm (shsTakeLen perm (sshape sarr))
+ , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh))
= Shaped (mtranspose perm arr)
sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a