aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-16 12:17:26 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-16 12:17:26 +0200
commitbad1902b2b3d8835cfe65700893c8ed8b560c893 (patch)
treee9ab386f77372794fe314a427b7470fee29bc8a2 /src/Data/Array/Mixed.hs
parentac5c0f1d9f3ba04d1e6647625a7699f463bb3e73 (diff)
Fix transpose
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs114
1 files changed, 93 insertions, 21 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 7f9076b..856e6cb 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -16,6 +16,7 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed where
@@ -47,6 +48,28 @@ pattern GHC_SNat = SNat
fromSNat' :: SNat n -> Int
fromSNat' = fromIntegral . fromSNat
+pattern SZ :: () => (n ~ 0) => SNat n
+pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl)
+ where SZ = SNat
+
+pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
+pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
+ where SS = snatSucc
+
+{-# COMPLETE SZ, SS #-}
+
+snatSucc :: SNat n -> SNat (n + 1)
+snatSucc SNat = SNat
+
+data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1)
+snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
+snatPred snp1 =
+ withKnownNat snp1 $
+ case cmpNat (Proxy @1) (Proxy @np1) of
+ LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+ EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+ GTI -> Nothing
+
-- | Type-level list append.
type family l1 ++ l2 where
@@ -59,6 +82,11 @@ lemAppNil = unsafeCoerce Refl
lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerce Refl
+type family Replicate n a where
+ Replicate 0 a = '[]
+ Replicate n a = a : Replicate (n - 1) a
+
+
type IxX :: [Maybe Nat] -> Type -> Type
data IxX sh i where
ZIX :: IxX '[] i
@@ -165,6 +193,15 @@ ssxToShape' ZKSX = Just ZSX
ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh
ssxToShape' (_ :!$? _) = Nothing
+lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
+lemReplicateSucc = unsafeCoerce Refl
+
+ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
+ssxReplicate SZ = ZKSX
+ssxReplicate (SS (n :: SNat n'))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n'
+ = () :!$? ssxReplicate n
+
fromLinearIdx :: IShX sh -> Int -> IIxX sh
fromLinearIdx = \sh i -> case go sh i of
(idx, 0) -> idx
@@ -331,7 +368,7 @@ rerank ssh ssh1 ssh2 f (XArray arr)
where
unXArray (XArray a) = a
-rerankTop :: forall sh sh1 sh2 a b.
+rerankTop :: forall sh1 sh2 sh a b.
(Storable a, Storable b)
=> StaticShX sh1 -> StaticShX sh2 -> StaticShX sh
-> (XArray sh1 a -> XArray sh2 b)
@@ -402,40 +439,75 @@ type family DropLen ref l where
DropLen '[] l = l
DropLen (_ : ref) (_ : xs) = DropLen ref xs
-lemPermuteRank :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is
-lemPermuteRank _ HNil = Refl
-lemPermuteRank p (_ `HCons` is) | Refl <- lemPermuteRank p is = Refl
-
-lemPermuteRank2 :: forall is sh. (Rank is <= Rank sh)
- => Proxy sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is
-lemPermuteRank2 _ HNil = Refl
-lemPermuteRank2 p ((_ :: SNat n) `HCons` (is :: HList SNat is')) =
- let p1 :: Rank (DropLen is' sh) :~: Rank sh - Rank is'
- p1 = lemPermuteRank2 p is
- p9 :: Rank (DropLen (n : is') sh) :~: Rank sh - (1 + Rank is')
- p9 = _
- in p9
+lemRankPermute :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is
+lemRankPermute _ HNil = Refl
+lemRankPermute p (_ `HCons` is) | Refl <- lemRankPermute p is = Refl
+
+lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
+ => StaticShX sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is
+lemRankDropLen ZKSX HNil = Refl
+lemRankDropLen (_ :!$@ sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl
+lemRankDropLen (_ :!$? sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl
+lemRankDropLen (_ :!$@ _) HNil = Refl
+lemRankDropLen (_ :!$? _) HNil = Refl
+lemRankDropLen ZKSX (_ `HCons` _) = error "1 <= 0"
+
+lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l
+lemIndexSucc _ _ _ = unsafeCoerce Refl
+
+ssxTakeLen :: HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh)
+ssxTakeLen HNil _ = ZKSX
+ssxTakeLen (_ `HCons` is) (n :!$@ sh) = n :!$@ ssxTakeLen is sh
+ssxTakeLen (_ `HCons` is) (n :!$? sh) = n :!$? ssxTakeLen is sh
+ssxTakeLen (_ `HCons` _) ZKSX = error "Permutation longer than shape"
+
+ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh)
+ssxDropLen HNil sh = sh
+ssxDropLen (_ `HCons` is) (_ :!$@ sh) = ssxDropLen is sh
+ssxDropLen (_ `HCons` is) (_ :!$? sh) = ssxDropLen is sh
+ssxDropLen (_ `HCons` _) ZKSX = error "Permutation longer than shape"
+
+ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh)
+ssxPermute HNil _ = ZKSX
+ssxPermute (i `HCons` (is :: HList SNat is')) (sh :: StaticShX sh) = ssxIndex (Proxy @is') (Proxy @sh) i sh (ssxPermute is sh)
+
+ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT)
+ssxIndex _ _ SZ (n :!$@ _) rest = n :!$@ rest
+ssxIndex _ _ SZ (n :!$? _) rest = n :!$? rest
+ssxIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :!$@ (sh :: StaticShX sh')) rest
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @sh')
+ = ssxIndex p pT i sh rest
+ssxIndex p pT (SS (i :: SNat i')) (() :!$? (sh :: StaticShX sh')) rest
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @Nothing) (Proxy @sh')
+ = ssxIndex p pT i sh rest
+ssxIndex _ _ _ ZKSX _ = error "Index into empty shape"
-- | The list argument gives indices into the original dimension list.
---
--- This function does not throw: the constraints ensure that the permutation is always valid.
transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh, KnownShapeX sh)
=> HList SNat is
-> XArray sh a
-> XArray (Permute is (TakeLen is sh) ++ DropLen is sh) a
transpose perm (XArray arr)
| Dict <- lemKnownNatRankSSX (knownShapeX @sh)
- , Refl <- lemPermuteRank (Proxy @(TakeLen is sh)) perm
+ , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm (knownShapeX @sh))) (ssxDropLen perm (knownShapeX @sh))
+ , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm
+ , Refl <- lemRankDropLen (knownShapeX @sh) perm
= let perm' = foldHList (\sn -> [fromSNat' sn]) perm :: [Int]
in XArray (S.transpose perm' arr)
-- | The list argument gives indices into the original dimension list.
--
--- This version throws a runtime error if the permutation is invalid.
-transposeUntyped :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a
-transposeUntyped perm (XArray arr)
- | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
+-- The permutation (the list) must have length <= @n@. If it is longer, this
+-- function throws.
+transposeUntyped :: forall n sh a.
+ SNat n -> StaticShX sh -> [Int]
+ -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a
+transposeUntyped sn ssh perm (XArray arr)
+ | length perm <= fromSNat' sn
+ , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh)
= XArray (S.transpose perm arr)
+ | otherwise
+ = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type"
transpose2 :: forall sh1 sh2 a.
StaticShX sh1 -> StaticShX sh2