aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs39
1 files changed, 39 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index f7fca0f..065756d 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -34,6 +34,7 @@ import Data.Coerce
import qualified Data.Foldable as Foldable
import Data.Functor.Const
import Data.Kind
+import Data.List (sort)
import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Bool
@@ -259,6 +260,17 @@ instance Show (StaticShX sh) where
lengthStaticShX :: StaticShX sh -> Int
lengthStaticShX (StaticShX l) = lengthListX l
+geqStaticShX :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
+geqStaticShX ZKX ZKX = Just Refl
+geqStaticShX (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh')
+ | Just Refl <- sameNat n m
+ , Just Refl <- geqStaticShX sh sh'
+ = Just Refl
+geqStaticShX (SUnknown () :!% sh) (SUnknown () :!% sh')
+ | Just Refl <- geqStaticShX sh sh'
+ = Just Refl
+geqStaticShX _ _ = Nothing
+
-- | Evidence for the static part of a shape. This pops up only when you are
-- polymorphic in the element type of an array.
@@ -767,6 +779,33 @@ ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
shPermutePrefix :: HList SNat is -> IShX sh -> IShX (PermutePrefix is sh)
shPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))
+-- TODO: test this thing more properly
+invertPermutation :: HList SNat is
+ -> (forall is'.
+ HList SNat is'
+ -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh)
+ -> r)
+ -> r
+invertPermutation = \perm k ->
+ genPerm perm $ \invperm ->
+ k invperm
+ (\ssh -> case provePermInverse perm invperm ssh of
+ Just eq -> eq
+ Nothing -> error $ "invertPermutation: did not generate inverse? perm = " ++ show perm
+ ++ " ; invperm = " ++ show invperm)
+ where
+ genPerm :: HList SNat is -> (forall is'. HList SNat is' -> r) -> r
+ genPerm perm =
+ let permList = foldHList (pure . fromSNat) perm
+ in toHList $ map snd (sort (zip permList [0..]))
+ where
+ toHList :: [Natural] -> (forall is'. HList SNat is' -> r) -> r
+ toHList [] k = k HNil
+ toHList (n : ns) k = toHList ns $ \l -> TypeNats.withSomeSNat n $ \sn -> k (HCons sn l)
+
+ provePermInverse :: HList SNat is -> HList SNat is' -> StaticShX sh -> Maybe (Permute is' (Permute is sh) :~: sh)
+ provePermInverse perm perminv ssh = geqStaticShX (ssxPermute perminv (ssxPermute perm ssh)) ssh
+
class KnownNatList l where makeNatList :: HList SNat l
instance KnownNatList '[] where makeNatList = HNil
instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `HCons` makeNatList