From bad1902b2b3d8835cfe65700893c8ed8b560c893 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 16 May 2024 12:17:26 +0200 Subject: Fix transpose --- src/Data/Array/Nested/Internal.hs | 130 +++++++++++++++++++++----------------- 1 file changed, 73 insertions(+), 57 deletions(-) (limited to 'src/Data/Array/Nested/Internal.hs') diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index b3f8143..de27336 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -85,9 +85,8 @@ import qualified Data.Vector.Storable as VS import qualified Data.Vector.Storable.Mutable as VSM import Foreign.Storable (Storable) import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce) -import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..), HList(..)) +import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..), HList(..), pattern SZ, pattern SS, Replicate) import qualified Data.Array.Mixed as X @@ -124,36 +123,10 @@ import qualified Data.Array.Mixed as X -- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. -type family Replicate n a where - Replicate 0 a = '[] - Replicate n a = a : Replicate (n - 1) a - type family MapJust l where MapJust '[] = '[] MapJust (x : xs) = Just x : MapJust xs -pattern SZ :: () => (n ~ 0) => SNat n -pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) - where SZ = SNat - -pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 -pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) - where SS = snatSucc - -{-# COMPLETE SZ, SS #-} - -snatSucc :: SNat n -> SNat (n + 1) -snatSucc SNat = SNat - -data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) -snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) -snatPred snp1 = - withKnownNat snp1 $ - case cmpNat (Proxy @1) (Proxy @np1) of - LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) - EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) - GTI -> Nothing - -- Stupid things that the type checker should be able to figure out in-line, but can't @@ -163,10 +136,6 @@ subst1 Refl = Refl subst2 :: forall f c a b. a :~: b -> f a c :~: f b c subst2 Refl = Refl --- TODO: is this sound? @n@ cannot be negative, surely, but the plugin doesn't see even that. -lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a -lemReplicateSucc = unsafeCoerce Refl - lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l lemAppLeft _ Refl = Refl @@ -179,7 +148,7 @@ lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n)) where go :: SNat m -> StaticShX (Replicate m Nothing) go SZ = ZKSX - go (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = () :!$? go n + go (SS (n :: SNat nm1)) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 = () :!$? go n lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n lemRankReplicate _ = go (natSing @n) @@ -187,7 +156,7 @@ lemRankReplicate _ = go (natSing @n) go :: forall m. SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m go SZ = Refl go (SS (n :: SNat nm1)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 + | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 , Refl <- go n = Refl @@ -205,9 +174,9 @@ lemReplicatePlusApp _ _ _ = go (natSing @n) go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a go SZ = Refl go (SS (n :: SNat n'm1)) - | Refl <- lemReplicateSucc @a @n'm1 + | Refl <- X.lemReplicateSucc @a @n'm1 , Refl <- go n - = sym (lemReplicateSucc @a @(n'm1 + m)) + = sym (X.lemReplicateSucc @a @(n'm1 + m)) shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh') shAppSplit _ ZKSX idx = (ZSX, idx) @@ -583,10 +552,12 @@ mgenerate sh f = case X.enumShape sh of mvecsWrite sh idx val vecs mvecsFreeze sh vecs -mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShapeX sh, Elt a) => HList SNat is -> Mixed sh a -> Mixed sh a -mtranspose perm = mlift $ \(Proxy @sh') -> - X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh') - (X.transpose perm) +mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShapeX sh, Elt a) => HList SNat is -> Mixed sh a -> Mixed (X.Permute is (X.TakeLen is sh) ++ X.DropLen is sh) a +mtranspose perm + | Dict <- X.lemKnownShapeX (X.ssxAppend (X.ssxPermute perm (X.ssxTakeLen perm (knownShapeX @sh))) (X.ssxDropLen perm (knownShapeX @sh))) + = mlift $ \(Proxy @sh') -> + X.rerankTop (knownShapeX @sh) (knownShapeX @(X.Permute is (X.TakeLen is sh) ++ X.DropLen is sh)) (knownShapeX @sh') + (X.transpose perm) mappend :: forall n m sh a. (KnownShapeX sh, KnownShapeX (n : sh), KnownShapeX (m : sh), KnownShapeX (X.AddMaybe n m : sh), Elt a) => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a @@ -823,13 +794,10 @@ lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh)) go ZSS = ZKSX go (n :$$ sh) = n :!$@ go sh -lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2 +lemCommMapJustApp :: forall sh1 sh2. ShS sh1 -> Proxy sh2 -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 -lemMapJustPlusApp _ _ = go (knownShape @sh1) - where - go :: ShS sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2 - go ZSS = Refl - go (_ :$$ sh) | Refl <- go sh = Refl +lemCommMapJustApp ZSS _ = Refl +lemCommMapJustApp (_ :$$ sh) p | Refl <- lemCommMapJustApp sh p = Refl instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr @@ -1057,11 +1025,11 @@ shCvtXR (n :$? idx) = n :$: shCvtXR idx ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) ixCvtRX ZIR = ZIX -ixCvtRX (n :.: (idx :: IxR m Int)) = castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :.? ixCvtRX idx) +ixCvtRX (n :.: (idx :: IxR m Int)) = castWith (subst2 @IxX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) (n :.? ixCvtRX idx) shCvtRX :: IShR n -> IShX (Replicate n Nothing) shCvtRX ZSR = ZSX -shCvtRX (n :$: (idx :: ShR m Int)) = castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :$? shCvtRX idx) +shCvtRX (n :$: (idx :: ShR m Int)) = castWith (subst2 @ShX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) (n :$? shCvtRX idx) shapeSizeR :: IShR n -> Int shapeSizeR ZSR = 1 @@ -1105,7 +1073,7 @@ rsumOuter1P :: forall n a. => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) rsumOuter1P (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) - , Refl <- lemReplicateSucc @(Nothing @Nat) @n + , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n = Ranked . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a)) . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing)) @@ -1119,15 +1087,17 @@ rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a rtranspose perm | Dict <- lemKnownReplicate (Proxy @n) + , length perm <= fromIntegral (natVal (Proxy @n)) = rlift $ \(Proxy @sh') -> - X.rerankTop (knownShapeX @(Replicate n Nothing)) (knownShapeX @(Replicate n Nothing)) (knownShapeX @sh') - (X.transposeUntyped perm) + X.transposeUntyped (natSing @n) (knownShapeX @sh') perm + | otherwise + = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array" rappend :: forall n a. (KnownNat n, Elt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a rappend | Dict <- lemKnownReplicate (Proxy @n) - , Refl <- lemReplicateSucc @(Nothing @Nat) @n + , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) rscalar :: Elt a => a -> Ranked 0 a @@ -1150,7 +1120,7 @@ rtoVector = coerce mtoVector rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (n + 1) a rfromList1 l | Dict <- lemKnownReplicate (Proxy @n) - , Refl <- lemReplicateSucc @(Nothing @Nat) @n + , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n = Ranked (mfromList1 @a @Nothing @(Replicate n Nothing) (coerce l)) rfromList :: Elt a => NonEmpty a -> Ranked 1 a @@ -1158,7 +1128,7 @@ rfromList = Ranked . mfromList1 . fmap mscalar rtoList :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] rtoList (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n = coerce (mtoList1 @a @Nothing @(Replicate n Nothing) arr) rtoList1 :: Elt a => Ranked 1 a -> [a] @@ -1181,7 +1151,7 @@ rslice ivs = rlift $ \_ -> X.slice ivs rrev1 :: forall n a. (KnownNat n, Elt a) => Ranked (n + 1) a -> Ranked (n + 1) a rrev1 = rlift $ \(Proxy @sh') -> - case lemReplicateSucc @(Nothing @Nat) @n of + case X.lemReplicateSucc @(Nothing @Nat) @n of Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh') rreshape :: forall n n' a. (KnownNat n, KnownNat n', Elt a) @@ -1308,7 +1278,7 @@ sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a sindexPartial (Shaped arr) idx = Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) - (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr) + (rewriteMixed (lemCommMapJustApp (knownShape @sh1) (Proxy @sh2)) arr) (ixCvtSX idx)) -- | __WARNING__: All values returned from the function must have equal shape. @@ -1342,10 +1312,56 @@ ssumOuter1 :: forall sh n a. => Shaped (n : sh) a -> Shaped sh a ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive -stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShape sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped sh a +lemCommMapJustTakeLen :: HList SNat is -> ShS sh -> X.TakeLen is (MapJust sh) :~: MapJust (X.TakeLen is sh) +lemCommMapJustTakeLen HNil _ = Refl +lemCommMapJustTakeLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl +lemCommMapJustTakeLen (_ `HCons` _) ZSS = error "TakeLen of empty" + +lemCommMapJustDropLen :: HList SNat is -> ShS sh -> X.DropLen is (MapJust sh) :~: MapJust (X.DropLen is sh) +lemCommMapJustDropLen HNil _ = Refl +lemCommMapJustDropLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl +lemCommMapJustDropLen (_ `HCons` _) ZSS = error "DropLen of empty" + +lemCommMapJustIndex :: SNat i -> ShS sh -> X.Index i (MapJust sh) :~: Just (X.Index i sh) +lemCommMapJustIndex SZ (_ :$$ _) = Refl +lemCommMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) + | Refl <- lemCommMapJustIndex i sh + , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) + , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = Refl +lemCommMapJustIndex _ ZSS = error "Index of empty" + +lemCommMapJustPermute :: HList SNat is -> ShS sh -> X.Permute is (MapJust sh) :~: MapJust (X.Permute is sh) +lemCommMapJustPermute HNil _ = Refl +lemCommMapJustPermute (i `HCons` is) sh + | Refl <- lemCommMapJustPermute is sh + , Refl <- lemCommMapJustIndex i sh + = Refl + +shTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh) +shTakeLen HNil _ = ZSS +shTakeLen (_ `HCons` is) (n :$$ sh) = n :$$ shTakeLen is sh +shTakeLen (_ `HCons` _) ZSS = error "Permutation longer than shape" + +shPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh) +shPermute HNil _ = ZSS +shPermute (i `HCons` (is :: HList SNat is')) (sh :: ShS sh) = shIndex (Proxy @is') (Proxy @sh) i sh (shPermute is sh) + +shIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT) +shIndex _ _ SZ (n :$$ _) rest = n :$$ rest +shIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest + | Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = shIndex p pT i sh rest +shIndex _ _ _ ZSS _ = error "Index into empty shape" + +stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShape sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped (X.Permute is (X.TakeLen is sh) ++ X.DropLen is sh) a stranspose perm (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) , Refl <- lemRankMapJust (Proxy @sh) + , Refl <- lemCommMapJustTakeLen perm (knownShape @sh) + , Refl <- lemCommMapJustDropLen perm (knownShape @sh) + , Refl <- lemCommMapJustPermute perm (shTakeLen perm (knownShape @sh)) + , Refl <- lemCommMapJustApp (shPermute perm (shTakeLen perm (knownShape @sh))) (Proxy @(X.DropLen is sh)) = Shaped (mtranspose perm arr) sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a) -- cgit v1.2.3-70-g09d2