aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-16 12:17:26 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-16 12:17:26 +0200
commitbad1902b2b3d8835cfe65700893c8ed8b560c893 (patch)
treee9ab386f77372794fe314a427b7470fee29bc8a2 /src/Data/Array/Nested
parentac5c0f1d9f3ba04d1e6647625a7699f463bb3e73 (diff)
Fix transpose
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Internal.hs130
1 files changed, 73 insertions, 57 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index b3f8143..de27336 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -85,9 +85,8 @@ import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import Foreign.Storable (Storable)
import GHC.TypeLits
-import Unsafe.Coerce (unsafeCoerce)
-import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..), HList(..))
+import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..), HList(..), pattern SZ, pattern SS, Replicate)
import qualified Data.Array.Mixed as X
@@ -124,36 +123,10 @@ import qualified Data.Array.Mixed as X
-- have been marked with [PRIMITIVE ELEMENT TYPES LIST].
-type family Replicate n a where
- Replicate 0 a = '[]
- Replicate n a = a : Replicate (n - 1) a
-
type family MapJust l where
MapJust '[] = '[]
MapJust (x : xs) = Just x : MapJust xs
-pattern SZ :: () => (n ~ 0) => SNat n
-pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl)
- where SZ = SNat
-
-pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
-pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
- where SS = snatSucc
-
-{-# COMPLETE SZ, SS #-}
-
-snatSucc :: SNat n -> SNat (n + 1)
-snatSucc SNat = SNat
-
-data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1)
-snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
-snatPred snp1 =
- withKnownNat snp1 $
- case cmpNat (Proxy @1) (Proxy @np1) of
- LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
- EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
- GTI -> Nothing
-
-- Stupid things that the type checker should be able to figure out in-line, but can't
@@ -163,10 +136,6 @@ subst1 Refl = Refl
subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
subst2 Refl = Refl
--- TODO: is this sound? @n@ cannot be negative, surely, but the plugin doesn't see even that.
-lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
-lemReplicateSucc = unsafeCoerce Refl
-
lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l
lemAppLeft _ Refl = Refl
@@ -179,7 +148,7 @@ lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n))
where
go :: SNat m -> StaticShX (Replicate m Nothing)
go SZ = ZKSX
- go (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = () :!$? go n
+ go (SS (n :: SNat nm1)) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 = () :!$? go n
lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
lemRankReplicate _ = go (natSing @n)
@@ -187,7 +156,7 @@ lemRankReplicate _ = go (natSing @n)
go :: forall m. SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
go SZ = Refl
go (SS (n :: SNat nm1))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1
+ | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1
, Refl <- go n
= Refl
@@ -205,9 +174,9 @@ lemReplicatePlusApp _ _ _ = go (natSing @n)
go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
go SZ = Refl
go (SS (n :: SNat n'm1))
- | Refl <- lemReplicateSucc @a @n'm1
+ | Refl <- X.lemReplicateSucc @a @n'm1
, Refl <- go n
- = sym (lemReplicateSucc @a @(n'm1 + m))
+ = sym (X.lemReplicateSucc @a @(n'm1 + m))
shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh')
shAppSplit _ ZKSX idx = (ZSX, idx)
@@ -583,10 +552,12 @@ mgenerate sh f = case X.enumShape sh of
mvecsWrite sh idx val vecs
mvecsFreeze sh vecs
-mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShapeX sh, Elt a) => HList SNat is -> Mixed sh a -> Mixed sh a
-mtranspose perm = mlift $ \(Proxy @sh') ->
- X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh')
- (X.transpose perm)
+mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShapeX sh, Elt a) => HList SNat is -> Mixed sh a -> Mixed (X.Permute is (X.TakeLen is sh) ++ X.DropLen is sh) a
+mtranspose perm
+ | Dict <- X.lemKnownShapeX (X.ssxAppend (X.ssxPermute perm (X.ssxTakeLen perm (knownShapeX @sh))) (X.ssxDropLen perm (knownShapeX @sh)))
+ = mlift $ \(Proxy @sh') ->
+ X.rerankTop (knownShapeX @sh) (knownShapeX @(X.Permute is (X.TakeLen is sh) ++ X.DropLen is sh)) (knownShapeX @sh')
+ (X.transpose perm)
mappend :: forall n m sh a. (KnownShapeX sh, KnownShapeX (n : sh), KnownShapeX (m : sh), KnownShapeX (X.AddMaybe n m : sh), Elt a)
=> Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a
@@ -823,13 +794,10 @@ lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh))
go ZSS = ZKSX
go (n :$$ sh) = n :!$@ go sh
-lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2
+lemCommMapJustApp :: forall sh1 sh2. ShS sh1 -> Proxy sh2
-> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
-lemMapJustPlusApp _ _ = go (knownShape @sh1)
- where
- go :: ShS sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2
- go ZSS = Refl
- go (_ :$$ sh) | Refl <- go sh = Refl
+lemCommMapJustApp ZSS _ = Refl
+lemCommMapJustApp (_ :$$ sh) p | Refl <- lemCommMapJustApp sh p = Refl
instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr
@@ -1057,11 +1025,11 @@ shCvtXR (n :$? idx) = n :$: shCvtXR idx
ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
ixCvtRX ZIR = ZIX
-ixCvtRX (n :.: (idx :: IxR m Int)) = castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :.? ixCvtRX idx)
+ixCvtRX (n :.: (idx :: IxR m Int)) = castWith (subst2 @IxX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) (n :.? ixCvtRX idx)
shCvtRX :: IShR n -> IShX (Replicate n Nothing)
shCvtRX ZSR = ZSX
-shCvtRX (n :$: (idx :: ShR m Int)) = castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :$? shCvtRX idx)
+shCvtRX (n :$: (idx :: ShR m Int)) = castWith (subst2 @ShX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) (n :$? shCvtRX idx)
shapeSizeR :: IShR n -> Int
shapeSizeR ZSR = 1
@@ -1105,7 +1073,7 @@ rsumOuter1P :: forall n a.
=> Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
rsumOuter1P (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
- , Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
= Ranked
. coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a))
. X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing))
@@ -1119,15 +1087,17 @@ rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive
rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
rtranspose perm
| Dict <- lemKnownReplicate (Proxy @n)
+ , length perm <= fromIntegral (natVal (Proxy @n))
= rlift $ \(Proxy @sh') ->
- X.rerankTop (knownShapeX @(Replicate n Nothing)) (knownShapeX @(Replicate n Nothing)) (knownShapeX @sh')
- (X.transposeUntyped perm)
+ X.transposeUntyped (natSing @n) (knownShapeX @sh') perm
+ | otherwise
+ = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array"
rappend :: forall n a. (KnownNat n, Elt a)
=> Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
rappend
| Dict <- lemKnownReplicate (Proxy @n)
- , Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
= coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
rscalar :: Elt a => a -> Ranked 0 a
@@ -1150,7 +1120,7 @@ rtoVector = coerce mtoVector
rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (n + 1) a
rfromList1 l
| Dict <- lemKnownReplicate (Proxy @n)
- , Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
= Ranked (mfromList1 @a @Nothing @(Replicate n Nothing) (coerce l))
rfromList :: Elt a => NonEmpty a -> Ranked 1 a
@@ -1158,7 +1128,7 @@ rfromList = Ranked . mfromList1 . fmap mscalar
rtoList :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
rtoList (Ranked arr)
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
= coerce (mtoList1 @a @Nothing @(Replicate n Nothing) arr)
rtoList1 :: Elt a => Ranked 1 a -> [a]
@@ -1181,7 +1151,7 @@ rslice ivs = rlift $ \_ -> X.slice ivs
rrev1 :: forall n a. (KnownNat n, Elt a) => Ranked (n + 1) a -> Ranked (n + 1) a
rrev1 = rlift $ \(Proxy @sh') ->
- case lemReplicateSucc @(Nothing @Nat) @n of
+ case X.lemReplicateSucc @(Nothing @Nat) @n of
Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')
rreshape :: forall n n' a. (KnownNat n, KnownNat n', Elt a)
@@ -1308,7 +1278,7 @@ sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)
sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
sindexPartial (Shaped arr) idx =
Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
- (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr)
+ (rewriteMixed (lemCommMapJustApp (knownShape @sh1) (Proxy @sh2)) arr)
(ixCvtSX idx))
-- | __WARNING__: All values returned from the function must have equal shape.
@@ -1342,10 +1312,56 @@ ssumOuter1 :: forall sh n a.
=> Shaped (n : sh) a -> Shaped sh a
ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive
-stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShape sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped sh a
+lemCommMapJustTakeLen :: HList SNat is -> ShS sh -> X.TakeLen is (MapJust sh) :~: MapJust (X.TakeLen is sh)
+lemCommMapJustTakeLen HNil _ = Refl
+lemCommMapJustTakeLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl
+lemCommMapJustTakeLen (_ `HCons` _) ZSS = error "TakeLen of empty"
+
+lemCommMapJustDropLen :: HList SNat is -> ShS sh -> X.DropLen is (MapJust sh) :~: MapJust (X.DropLen is sh)
+lemCommMapJustDropLen HNil _ = Refl
+lemCommMapJustDropLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl
+lemCommMapJustDropLen (_ `HCons` _) ZSS = error "DropLen of empty"
+
+lemCommMapJustIndex :: SNat i -> ShS sh -> X.Index i (MapJust sh) :~: Just (X.Index i sh)
+lemCommMapJustIndex SZ (_ :$$ _) = Refl
+lemCommMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh'))
+ | Refl <- lemCommMapJustIndex i sh
+ , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh'))
+ , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = Refl
+lemCommMapJustIndex _ ZSS = error "Index of empty"
+
+lemCommMapJustPermute :: HList SNat is -> ShS sh -> X.Permute is (MapJust sh) :~: MapJust (X.Permute is sh)
+lemCommMapJustPermute HNil _ = Refl
+lemCommMapJustPermute (i `HCons` is) sh
+ | Refl <- lemCommMapJustPermute is sh
+ , Refl <- lemCommMapJustIndex i sh
+ = Refl
+
+shTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh)
+shTakeLen HNil _ = ZSS
+shTakeLen (_ `HCons` is) (n :$$ sh) = n :$$ shTakeLen is sh
+shTakeLen (_ `HCons` _) ZSS = error "Permutation longer than shape"
+
+shPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh)
+shPermute HNil _ = ZSS
+shPermute (i `HCons` (is :: HList SNat is')) (sh :: ShS sh) = shIndex (Proxy @is') (Proxy @sh) i sh (shPermute is sh)
+
+shIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT)
+shIndex _ _ SZ (n :$$ _) rest = n :$$ rest
+shIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest
+ | Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = shIndex p pT i sh rest
+shIndex _ _ _ ZSS _ = error "Index into empty shape"
+
+stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShape sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped (X.Permute is (X.TakeLen is sh) ++ X.DropLen is sh) a
stranspose perm (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
, Refl <- lemRankMapJust (Proxy @sh)
+ , Refl <- lemCommMapJustTakeLen perm (knownShape @sh)
+ , Refl <- lemCommMapJustDropLen perm (knownShape @sh)
+ , Refl <- lemCommMapJustPermute perm (shTakeLen perm (knownShape @sh))
+ , Refl <- lemCommMapJustApp (shPermute perm (shTakeLen perm (knownShape @sh))) (Proxy @(X.DropLen is sh))
= Shaped (mtranspose perm arr)
sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a)