diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-16 12:17:26 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-16 12:17:26 +0200 |
commit | bad1902b2b3d8835cfe65700893c8ed8b560c893 (patch) | |
tree | e9ab386f77372794fe314a427b7470fee29bc8a2 /src/Data/Array/Mixed.hs | |
parent | ac5c0f1d9f3ba04d1e6647625a7699f463bb3e73 (diff) |
Fix transpose
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r-- | src/Data/Array/Mixed.hs | 114 |
1 files changed, 93 insertions, 21 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 7f9076b..856e6cb 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -16,6 +16,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Data.Array.Mixed where @@ -47,6 +48,28 @@ pattern GHC_SNat = SNat fromSNat' :: SNat n -> Int fromSNat' = fromIntegral . fromSNat +pattern SZ :: () => (n ~ 0) => SNat n +pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) + where SZ = SNat + +pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 +pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) + where SS = snatSucc + +{-# COMPLETE SZ, SS #-} + +snatSucc :: SNat n -> SNat (n + 1) +snatSucc SNat = SNat + +data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) +snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) +snatPred snp1 = + withKnownNat snp1 $ + case cmpNat (Proxy @1) (Proxy @np1) of + LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + GTI -> Nothing + -- | Type-level list append. type family l1 ++ l2 where @@ -59,6 +82,11 @@ lemAppNil = unsafeCoerce Refl lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) lemAppAssoc _ _ _ = unsafeCoerce Refl +type family Replicate n a where + Replicate 0 a = '[] + Replicate n a = a : Replicate (n - 1) a + + type IxX :: [Maybe Nat] -> Type -> Type data IxX sh i where ZIX :: IxX '[] i @@ -165,6 +193,15 @@ ssxToShape' ZKSX = Just ZSX ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh ssxToShape' (_ :!$? _) = Nothing +lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a +lemReplicateSucc = unsafeCoerce Refl + +ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) +ssxReplicate SZ = ZKSX +ssxReplicate (SS (n :: SNat n')) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n' + = () :!$? ssxReplicate n + fromLinearIdx :: IShX sh -> Int -> IIxX sh fromLinearIdx = \sh i -> case go sh i of (idx, 0) -> idx @@ -331,7 +368,7 @@ rerank ssh ssh1 ssh2 f (XArray arr) where unXArray (XArray a) = a -rerankTop :: forall sh sh1 sh2 a b. +rerankTop :: forall sh1 sh2 sh a b. (Storable a, Storable b) => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh -> (XArray sh1 a -> XArray sh2 b) @@ -402,40 +439,75 @@ type family DropLen ref l where DropLen '[] l = l DropLen (_ : ref) (_ : xs) = DropLen ref xs -lemPermuteRank :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is -lemPermuteRank _ HNil = Refl -lemPermuteRank p (_ `HCons` is) | Refl <- lemPermuteRank p is = Refl - -lemPermuteRank2 :: forall is sh. (Rank is <= Rank sh) - => Proxy sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is -lemPermuteRank2 _ HNil = Refl -lemPermuteRank2 p ((_ :: SNat n) `HCons` (is :: HList SNat is')) = - let p1 :: Rank (DropLen is' sh) :~: Rank sh - Rank is' - p1 = lemPermuteRank2 p is - p9 :: Rank (DropLen (n : is') sh) :~: Rank sh - (1 + Rank is') - p9 = _ - in p9 +lemRankPermute :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is +lemRankPermute _ HNil = Refl +lemRankPermute p (_ `HCons` is) | Refl <- lemRankPermute p is = Refl + +lemRankDropLen :: forall is sh. (Rank is <= Rank sh) + => StaticShX sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is +lemRankDropLen ZKSX HNil = Refl +lemRankDropLen (_ :!$@ sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!$? sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!$@ _) HNil = Refl +lemRankDropLen (_ :!$? _) HNil = Refl +lemRankDropLen ZKSX (_ `HCons` _) = error "1 <= 0" + +lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l +lemIndexSucc _ _ _ = unsafeCoerce Refl + +ssxTakeLen :: HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh) +ssxTakeLen HNil _ = ZKSX +ssxTakeLen (_ `HCons` is) (n :!$@ sh) = n :!$@ ssxTakeLen is sh +ssxTakeLen (_ `HCons` is) (n :!$? sh) = n :!$? ssxTakeLen is sh +ssxTakeLen (_ `HCons` _) ZKSX = error "Permutation longer than shape" + +ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh) +ssxDropLen HNil sh = sh +ssxDropLen (_ `HCons` is) (_ :!$@ sh) = ssxDropLen is sh +ssxDropLen (_ `HCons` is) (_ :!$? sh) = ssxDropLen is sh +ssxDropLen (_ `HCons` _) ZKSX = error "Permutation longer than shape" + +ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh) +ssxPermute HNil _ = ZKSX +ssxPermute (i `HCons` (is :: HList SNat is')) (sh :: StaticShX sh) = ssxIndex (Proxy @is') (Proxy @sh) i sh (ssxPermute is sh) + +ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT) +ssxIndex _ _ SZ (n :!$@ _) rest = n :!$@ rest +ssxIndex _ _ SZ (n :!$? _) rest = n :!$? rest +ssxIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :!$@ (sh :: StaticShX sh')) rest + | Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @sh') + = ssxIndex p pT i sh rest +ssxIndex p pT (SS (i :: SNat i')) (() :!$? (sh :: StaticShX sh')) rest + | Refl <- lemIndexSucc (Proxy @i') (Proxy @Nothing) (Proxy @sh') + = ssxIndex p pT i sh rest +ssxIndex _ _ _ ZKSX _ = error "Index into empty shape" -- | The list argument gives indices into the original dimension list. --- --- This function does not throw: the constraints ensure that the permutation is always valid. transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh, KnownShapeX sh) => HList SNat is -> XArray sh a -> XArray (Permute is (TakeLen is sh) ++ DropLen is sh) a transpose perm (XArray arr) | Dict <- lemKnownNatRankSSX (knownShapeX @sh) - , Refl <- lemPermuteRank (Proxy @(TakeLen is sh)) perm + , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm (knownShapeX @sh))) (ssxDropLen perm (knownShapeX @sh)) + , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm + , Refl <- lemRankDropLen (knownShapeX @sh) perm = let perm' = foldHList (\sn -> [fromSNat' sn]) perm :: [Int] in XArray (S.transpose perm' arr) -- | The list argument gives indices into the original dimension list. -- --- This version throws a runtime error if the permutation is invalid. -transposeUntyped :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a -transposeUntyped perm (XArray arr) - | Dict <- lemKnownNatRankSSX (knownShapeX @sh) +-- The permutation (the list) must have length <= @n@. If it is longer, this +-- function throws. +transposeUntyped :: forall n sh a. + SNat n -> StaticShX sh -> [Int] + -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a +transposeUntyped sn ssh perm (XArray arr) + | length perm <= fromSNat' sn + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh) = XArray (S.transpose perm arr) + | otherwise + = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type" transpose2 :: forall sh1 sh2 a. StaticShX sh1 -> StaticShX sh2 |