From 95544b35615f6714fbef914cb6f2935a088e4d06 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 28 May 2024 09:35:06 +0200 Subject: Add invertPermutation --- src/Data/Array/Mixed.hs | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) (limited to 'src/Data/Array') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index f7fca0f..065756d 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -34,6 +34,7 @@ import Data.Coerce import qualified Data.Foldable as Foldable import Data.Functor.Const import Data.Kind +import Data.List (sort) import Data.Monoid (Sum(..)) import Data.Proxy import Data.Type.Bool @@ -259,6 +260,17 @@ instance Show (StaticShX sh) where lengthStaticShX :: StaticShX sh -> Int lengthStaticShX (StaticShX l) = lengthListX l +geqStaticShX :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') +geqStaticShX ZKX ZKX = Just Refl +geqStaticShX (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh') + | Just Refl <- sameNat n m + , Just Refl <- geqStaticShX sh sh' + = Just Refl +geqStaticShX (SUnknown () :!% sh) (SUnknown () :!% sh') + | Just Refl <- geqStaticShX sh sh' + = Just Refl +geqStaticShX _ _ = Nothing + -- | Evidence for the static part of a shape. This pops up only when you are -- polymorphic in the element type of an array. @@ -767,6 +779,33 @@ ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) shPermutePrefix :: HList SNat is -> IShX sh -> IShX (PermutePrefix is sh) shPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) +-- TODO: test this thing more properly +invertPermutation :: HList SNat is + -> (forall is'. + HList SNat is' + -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh) + -> r) + -> r +invertPermutation = \perm k -> + genPerm perm $ \invperm -> + k invperm + (\ssh -> case provePermInverse perm invperm ssh of + Just eq -> eq + Nothing -> error $ "invertPermutation: did not generate inverse? perm = " ++ show perm + ++ " ; invperm = " ++ show invperm) + where + genPerm :: HList SNat is -> (forall is'. HList SNat is' -> r) -> r + genPerm perm = + let permList = foldHList (pure . fromSNat) perm + in toHList $ map snd (sort (zip permList [0..])) + where + toHList :: [Natural] -> (forall is'. HList SNat is' -> r) -> r + toHList [] k = k HNil + toHList (n : ns) k = toHList ns $ \l -> TypeNats.withSomeSNat n $ \sn -> k (HCons sn l) + + provePermInverse :: HList SNat is -> HList SNat is' -> StaticShX sh -> Maybe (Permute is' (Permute is sh) :~: sh) + provePermInverse perm perminv ssh = geqStaticShX (ssxPermute perminv (ssxPermute perm ssh)) ssh + class KnownNatList l where makeNatList :: HList SNat l instance KnownNatList '[] where makeNatList = HNil instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `HCons` makeNatList -- cgit v1.2.3-70-g09d2