aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-20 11:27:18 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-20 11:27:18 +0200
commit0f13e61b9eecb0a14e4c62e78218d1652d9d4cf2 (patch)
treeced22fefa5df63d671df391c7e0a6df49b9997a9 /src/Data/Array/Mixed.hs
parent31662543b094e04dc373daf264aa62cfc3550457 (diff)
mtranspose without rerank
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs99
1 files changed, 69 insertions, 30 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 6ac3ab3..f62d781 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -303,27 +303,33 @@ completeShXzeros ZKX = ZSX
completeShXzeros (SUnknown () :!% ssh) = SUnknown 0 :$% completeShXzeros ssh
completeShXzeros (SKnown n :!% ssh) = SKnown n :$% completeShXzeros ssh
--- TODO: generalise all these things to arbitrary @i@
-ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh')
-ixAppend ZIX idx' = idx'
-ixAppend (i :.% idx) idx' = i :.% ixAppend idx idx'
+listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f
+listxAppend ZX idx' = idx'
+listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx'
+
+ixAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
+ixAppend = coerce (listxAppend @_ @(Const i))
+
+shAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
+shAppend = coerce (listxAppend @_ @(SMayNat i SNat))
-shAppend :: IShX sh -> IShX sh' -> IShX (sh ++ sh')
-shAppend ZSX sh' = sh'
-shAppend (n :$% sh) sh' = n :$% shAppend sh sh'
+listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f
+listxDrop long ZX = long
+listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short
-ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh'
-ixDrop long ZIX = long
-ixDrop long (_ :.% short) = case long of _ :.% long' -> ixDrop long' short
+ixDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i
+ixDrop = coerce (listxDrop @(Const i) @(Const i))
-shDropIx :: IShX (sh ++ sh') -> IIxX sh -> IShX sh'
-shDropIx sh ZIX = sh
-shDropIx sh (_ :.% idx) = case sh of _ :$% sh' -> shDropIx sh' idx
+shDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i
+shDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j))
-ssxDropIx :: StaticShX (sh ++ sh') -> IIxX sh -> StaticShX sh'
-ssxDropIx ssh ZIX = ssh
-ssxDropIx ssh (_ :.% idx) = case ssh of _ :!% ssh' -> ssxDropIx ssh' idx
+shDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i
+shDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))
+ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'
+ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))
+
+-- TODO: generalise all these things to arbitrary @i@
shTail :: IShX (n : sh) -> IShX sh
shTail (_ :$% sh) = sh
@@ -603,26 +609,59 @@ lemRankDropLen ZKX (_ `HCons` _) = error "1 <= 0"
lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l
lemIndexSucc _ _ _ = unsafeCoerce Refl
-ssxTakeLen :: HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh)
-ssxTakeLen HNil _ = ZKX
-ssxTakeLen (_ `HCons` is) (n :!% sh) = n :!% ssxTakeLen is sh
-ssxTakeLen (_ `HCons` _) ZKX = error "Permutation longer than shape"
+listxTakeLen :: forall f is sh. HList SNat is -> ListX sh f -> ListX (TakeLen is sh) f
+listxTakeLen HNil _ = ZX
+listxTakeLen (_ `HCons` is) (n ::% sh) = n ::% listxTakeLen is sh
+listxTakeLen (_ `HCons` _) ZX = error "Permutation longer than shape"
+
+listxDropLen :: forall f is sh. HList SNat is -> ListX sh f -> ListX (DropLen is sh) f
+listxDropLen HNil sh = sh
+listxDropLen (_ `HCons` is) (_ ::% sh) = listxDropLen is sh
+listxDropLen (_ `HCons` _) ZX = error "Permutation longer than shape"
+
+listxPermute :: forall f is sh. HList SNat is -> ListX sh f -> ListX (Permute is sh) f
+listxPermute HNil _ = ZX
+listxPermute (i `HCons` (is :: HList SNat is')) (sh :: ListX sh f) = listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh)
+
+listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f
+listxIndex _ _ SZ (n ::% _) rest = n ::% rest
+listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = listxIndex p pT i sh rest
+listxIndex _ _ _ ZX _ = error "Index into empty shape"
+
+listxPermutePrefix :: forall f is sh. HList SNat is -> ListX sh f -> ListX (PermutePrefix is sh) f
+listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)
+
+ssxTakeLen :: forall is sh. HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh)
+ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat))
ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh)
-ssxDropLen HNil sh = sh
-ssxDropLen (_ `HCons` is) (_ :!% sh) = ssxDropLen is sh
-ssxDropLen (_ `HCons` _) ZKX = error "Permutation longer than shape"
+ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))
ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh)
-ssxPermute HNil _ = ZKX
-ssxPermute (i `HCons` (is :: HList SNat is')) (sh :: StaticShX sh) = ssxIndex (Proxy @is') (Proxy @sh) i sh (ssxPermute is sh)
+ssxPermute = coerce (listxPermute @(SMayNat () SNat))
ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT)
-ssxIndex _ _ SZ (n :!% _) rest = n :!% rest
-ssxIndex p pT (SS (i :: SNat i')) ((_ :: SMayNat () SNat n) :!% (sh :: StaticShX sh')) rest
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = ssxIndex p pT i sh rest
-ssxIndex _ _ _ ZKX _ = error "Index into empty shape"
+ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2)
+
+ssxPermutePrefix :: HList SNat is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
+ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
+
+shTakeLen :: forall is sh. HList SNat is -> IShX sh -> IShX (TakeLen is sh)
+shTakeLen = coerce (listxTakeLen @(SMayNat Int SNat))
+
+shDropLen :: HList SNat is -> IShX sh -> IShX (DropLen is sh)
+shDropLen = coerce (listxDropLen @(SMayNat Int SNat))
+
+shPermute :: HList SNat is -> IShX sh -> IShX (Permute is sh)
+shPermute = coerce (listxPermute @(SMayNat Int SNat))
+
+shIndex :: Proxy is -> Proxy shT -> SNat i -> IShX sh -> IShX (Permute is shT) -> IShX (Index i sh : Permute is shT)
+shIndex p1 p2 = coerce (listxIndex @(SMayNat Int SNat) p1 p2)
+
+shPermutePrefix :: HList SNat is -> IShX sh -> IShX (PermutePrefix is sh)
+shPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))
-- | The list argument gives indices into the original dimension list.
transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh)