From 0f13e61b9eecb0a14e4c62e78218d1652d9d4cf2 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 20 May 2024 11:27:18 +0200 Subject: mtranspose without rerank --- src/Data/Array/Mixed.hs | 99 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 69 insertions(+), 30 deletions(-) (limited to 'src/Data/Array/Mixed.hs') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 6ac3ab3..f62d781 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -303,27 +303,33 @@ completeShXzeros ZKX = ZSX completeShXzeros (SUnknown () :!% ssh) = SUnknown 0 :$% completeShXzeros ssh completeShXzeros (SKnown n :!% ssh) = SKnown n :$% completeShXzeros ssh --- TODO: generalise all these things to arbitrary @i@ -ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh') -ixAppend ZIX idx' = idx' -ixAppend (i :.% idx) idx' = i :.% ixAppend idx idx' +listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f +listxAppend ZX idx' = idx' +listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' + +ixAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i +ixAppend = coerce (listxAppend @_ @(Const i)) + +shAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i +shAppend = coerce (listxAppend @_ @(SMayNat i SNat)) -shAppend :: IShX sh -> IShX sh' -> IShX (sh ++ sh') -shAppend ZSX sh' = sh' -shAppend (n :$% sh) sh' = n :$% shAppend sh sh' +listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f +listxDrop long ZX = long +listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short -ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh' -ixDrop long ZIX = long -ixDrop long (_ :.% short) = case long of _ :.% long' -> ixDrop long' short +ixDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i +ixDrop = coerce (listxDrop @(Const i) @(Const i)) -shDropIx :: IShX (sh ++ sh') -> IIxX sh -> IShX sh' -shDropIx sh ZIX = sh -shDropIx sh (_ :.% idx) = case sh of _ :$% sh' -> shDropIx sh' idx +shDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i +shDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) -ssxDropIx :: StaticShX (sh ++ sh') -> IIxX sh -> StaticShX sh' -ssxDropIx ssh ZIX = ssh -ssxDropIx ssh (_ :.% idx) = case ssh of _ :!% ssh' -> ssxDropIx ssh' idx +shDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i +shDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) +ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' +ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) + +-- TODO: generalise all these things to arbitrary @i@ shTail :: IShX (n : sh) -> IShX sh shTail (_ :$% sh) = sh @@ -603,26 +609,59 @@ lemRankDropLen ZKX (_ `HCons` _) = error "1 <= 0" lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l lemIndexSucc _ _ _ = unsafeCoerce Refl -ssxTakeLen :: HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh) -ssxTakeLen HNil _ = ZKX -ssxTakeLen (_ `HCons` is) (n :!% sh) = n :!% ssxTakeLen is sh -ssxTakeLen (_ `HCons` _) ZKX = error "Permutation longer than shape" +listxTakeLen :: forall f is sh. HList SNat is -> ListX sh f -> ListX (TakeLen is sh) f +listxTakeLen HNil _ = ZX +listxTakeLen (_ `HCons` is) (n ::% sh) = n ::% listxTakeLen is sh +listxTakeLen (_ `HCons` _) ZX = error "Permutation longer than shape" + +listxDropLen :: forall f is sh. HList SNat is -> ListX sh f -> ListX (DropLen is sh) f +listxDropLen HNil sh = sh +listxDropLen (_ `HCons` is) (_ ::% sh) = listxDropLen is sh +listxDropLen (_ `HCons` _) ZX = error "Permutation longer than shape" + +listxPermute :: forall f is sh. HList SNat is -> ListX sh f -> ListX (Permute is sh) f +listxPermute HNil _ = ZX +listxPermute (i `HCons` (is :: HList SNat is')) (sh :: ListX sh f) = listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh) + +listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f +listxIndex _ _ SZ (n ::% _) rest = n ::% rest +listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest + | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = listxIndex p pT i sh rest +listxIndex _ _ _ ZX _ = error "Index into empty shape" + +listxPermutePrefix :: forall f is sh. HList SNat is -> ListX sh f -> ListX (PermutePrefix is sh) f +listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) + +ssxTakeLen :: forall is sh. HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh) +ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh) -ssxDropLen HNil sh = sh -ssxDropLen (_ `HCons` is) (_ :!% sh) = ssxDropLen is sh -ssxDropLen (_ `HCons` _) ZKX = error "Permutation longer than shape" +ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh) -ssxPermute HNil _ = ZKX -ssxPermute (i `HCons` (is :: HList SNat is')) (sh :: StaticShX sh) = ssxIndex (Proxy @is') (Proxy @sh) i sh (ssxPermute is sh) +ssxPermute = coerce (listxPermute @(SMayNat () SNat)) ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT) -ssxIndex _ _ SZ (n :!% _) rest = n :!% rest -ssxIndex p pT (SS (i :: SNat i')) ((_ :: SMayNat () SNat n) :!% (sh :: StaticShX sh')) rest - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = ssxIndex p pT i sh rest -ssxIndex _ _ _ ZKX _ = error "Index into empty shape" +ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) + +ssxPermutePrefix :: HList SNat is -> StaticShX sh -> StaticShX (PermutePrefix is sh) +ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) + +shTakeLen :: forall is sh. HList SNat is -> IShX sh -> IShX (TakeLen is sh) +shTakeLen = coerce (listxTakeLen @(SMayNat Int SNat)) + +shDropLen :: HList SNat is -> IShX sh -> IShX (DropLen is sh) +shDropLen = coerce (listxDropLen @(SMayNat Int SNat)) + +shPermute :: HList SNat is -> IShX sh -> IShX (Permute is sh) +shPermute = coerce (listxPermute @(SMayNat Int SNat)) + +shIndex :: Proxy is -> Proxy shT -> SNat i -> IShX sh -> IShX (Permute is shT) -> IShX (Index i sh : Permute is shT) +shIndex p1 p2 = coerce (listxIndex @(SMayNat Int SNat) p1 p2) + +shPermutePrefix :: HList SNat is -> IShX sh -> IShX (PermutePrefix is sh) +shPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) -- | The list argument gives indices into the original dimension list. transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh) -- cgit v1.2.3-70-g09d2