diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-14 21:22:20 +0100 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-14 21:22:22 +0100 |
| commit | 6e841f3e7d19253db65874d87e2277c050dad984 (patch) | |
| tree | b7e988b6373865f825a3d1288618287e7877b6ba /src | |
| parent | b0cc8caff4ccf5df85f3bea743be1f03ddde01c6 (diff) | |
Make ShS a newtype over ShX
TODO: use lemmas in place of the unsafeCoerceRefl
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/Nested/Permutation.hs | 12 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 107 |
2 files changed, 93 insertions, 26 deletions
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs index 6bebcfb..8b46d81 100644 --- a/src/Data/Array/Nested/Permutation.hs +++ b/src/Data/Array/Nested/Permutation.hs @@ -215,6 +215,18 @@ ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat ()) p1 p2 i) ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat ())) +shxTakeLen :: forall is sh. Perm is -> IShX sh -> IShX (TakeLen is sh) +shxTakeLen = coerce (listxTakeLen @(SMayNat Int)) + +shxDropLen :: Perm is -> IShX sh -> IShX (DropLen is sh) +shxDropLen = coerce (listxDropLen @(SMayNat Int)) + +shxPermute :: Perm is -> IShX sh -> IShX (Permute is sh) +shxPermute = coerce (listxPermute @(SMayNat Int)) + +shxIndex :: Proxy is -> Proxy shT -> SNat i -> IShX sh -> SMayNat Int (Index i sh) +shxIndex p1 p2 i = coerce (listxIndex @(SMayNat Int) p1 p2 i) + shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int)) diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 6485c72..7206be3 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -331,21 +331,34 @@ ixsToLinear = \sh i -> go sh i 0 -- can also retrieve the array shape from a 'KnownShS' dictionary. type role ShS nominal type ShS :: [Nat] -> Type -newtype ShS sh = ShS (ListS sh SNat) +newtype ShS sh = ShS (ShX (MapJust sh) Int) deriving (Generic) instance Eq (ShS sh) where _ == _ = True instance Ord (ShS sh) where compare _ _ = EQ pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh -pattern ZSS = ShS ZS +pattern ZSS <- ShS (matchZS -> Just Refl) + where ZSS = ShS ZSX + +matchZS :: forall sh f. ShX (MapJust sh) f -> Maybe (sh :~: '[]) +matchZS ZSX | Refl <- lemMapJustEmpty @sh Refl = Just Refl +matchZS _ = Nothing pattern (:$$) :: forall {sh1}. forall n sh. (n : sh ~ sh1) => SNat n -> ShS sh -> ShS sh1 -pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i)) - where i :$$ ShS shl = ShS (i ::$ shl) +pattern i :$$ shl <- (shsUncons -> Just (UnconsShSRes i shl)) + where i :$$ ShS shl = ShS (SKnown i :$% shl) + +data UnconsShSRes sh1 = + forall n sh. (n : sh ~ sh1) => UnconsShSRes (SNat n) (ShS sh) +shsUncons :: forall sh1. ShS sh1 -> Maybe (UnconsShSRes sh1) +shsUncons (ShS (SKnown x :$% sh')) + | Refl <- lemMapJustCons @sh1 Refl + = Just (UnconsShSRes x (ShS sh')) +shsUncons (ShS _) = Nothing infixr 3 :$$ @@ -355,15 +368,16 @@ infixr 3 :$$ deriving instance Show (ShS sh) #else instance Show (ShS sh) where - showsPrec _ (ShS l) = listsShow (shows . fromSNat) l + showsPrec d (ShS shx) = showsPrec d shx #endif instance NFData (ShS sh) where - rnf (ShS ZS) = () - rnf (ShS (SNat ::$ l)) = rnf (ShS l) + rnf (ShS shx) = rnf shx instance TestEquality ShS where - testEquality (ShS l1) (ShS l2) = listsEqType l1 l2 + testEquality (ShS shx1) (ShS shx2) = case shxEqType shx1 shx2 of + Nothing -> Nothing + Just Refl -> Just unsafeCoerceRefl -- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are -- equal if and only if values are equal.) @@ -371,10 +385,13 @@ shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh') shsEqual = testEquality shsLength :: ShS sh -> Int -shsLength (ShS l) = listsLength l +shsLength (ShS shx) = shxLength shx -shsRank :: ShS sh -> SNat (Rank sh) -shsRank (ShS l) = listsRank l +shsRank :: forall sh. ShS sh -> SNat (Rank sh) +shsRank (ShS shx) = + gcastWith (unsafeCoerceRefl + :: Rank (MapJust sh) :~: Rank sh) $ + shxRank shx shsSize :: ShS sh -> Int shsSize ZSS = 1 @@ -404,31 +421,69 @@ shsToList topsh = build (\(cons :: Int -> is -> is) (nil :: is) -> in go topsh) shsHead :: ShS (n : sh) -> SNat n -shsHead (ShS list) = listsHead list +shsHead (ShS shx) = case shxHead shx of + SKnown SNat -> SNat -shsTail :: ShS (n : sh) -> ShS sh -shsTail (ShS list) = ShS (listsTail list) +shsTail :: forall n sh. ShS (n : sh) -> ShS sh +shsTail = coerce (shxTail @_ @_ @Int) -shsInit :: ShS (n : sh) -> ShS (Init (n : sh)) -shsInit (ShS list) = ShS (listsInit list) +shsInit :: forall n sh. ShS (n : sh) -> ShS (Init (n : sh)) +shsInit = + gcastWith (unsafeCoerceRefl + :: Init (Just n : MapJust sh) :~: MapJust (Init (n : sh))) $ + coerce (shxInit @_ @_ @Int) -shsLast :: ShS (n : sh) -> SNat (Last (n : sh)) -shsLast (ShS list) = listsLast list +shsLast :: forall n sh. ShS (n : sh) -> SNat (Last (n : sh)) +shsLast (ShS shx) = + gcastWith (unsafeCoerceRefl + :: Last (Just n : MapJust sh) :~: Just (Last (n : sh))) $ + case shxLast shx of + SKnown SNat -> SNat shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh') -shsAppend = coerce (listsAppend @_ @SNat) +shsAppend = + gcastWith (unsafeCoerceRefl + :: MapJust sh ++ MapJust sh' :~: MapJust (sh ++ sh')) $ + coerce (shxAppend @_ @_ @Int) + +shsTakeLen :: forall is sh. Perm is -> ShS sh -> ShS (TakeLen is sh) +shsTakeLen = + gcastWith (unsafeCoerceRefl + :: TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)) $ + coerce shxTakeLen -shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh) -shsTakeLen = coerce (listsTakeLenPerm @SNat) +shsDropLen :: forall is sh. Perm is -> ShS sh -> ShS (DropLen is sh) +shsDropLen = + gcastWith (unsafeCoerceRefl + :: DropLen is (MapJust sh) :~: MapJust (DropLen is sh)) $ + coerce shxDropLen -shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) -shsPermute = coerce (listsPermute @SNat) +shsPermute :: forall is sh. Perm is -> ShS sh -> ShS (Permute is sh) +shsPermute = + gcastWith (unsafeCoerceRefl + :: Permute is (MapJust sh) :~: MapJust (Permute is sh)) $ + coerce shxPermute -shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh) -shsIndex pis pshT i sh = coerce (listsIndex @SNat pis pshT i (coerce sh)) +shsIndex :: forall is shT i sh. + Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh) +shsIndex pis pshT i sh = + gcastWith (unsafeCoerceRefl + :: Index i (MapJust sh) :~: Just (Index i sh)) $ + case shxIndex pis pshT i (coerce sh) of + SKnown SNat -> SNat shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) -shsPermutePrefix = coerce (listsPermutePrefix @SNat) +shsPermutePrefix perm (ShS shx) + {- TODO: here and elsewhere, solve the module dependency cycle and add this: + | Refl <- lemTakeLenMapJust perm sh + , Refl <- lemDropLenMapJust perm sh + , Refl <- lemPermuteMapJust perm sh + , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm sh)) (shsDropLen perm sh) -} + = gcastWith (unsafeCoerceRefl + :: Permute is (TakeLen is (MapJust sh)) + ++ DropLen is (MapJust sh) + :~: MapJust (Permute is (TakeLen is sh) ++ DropLen is sh)) $ + ShS (shxPermutePrefix perm shx) type family Product sh where Product '[] = 1 |
