aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-11 14:08:18 +0100
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2026-03-15 10:35:21 +0100
commit01a79440952d789184101fc1aad277c00d010a25 (patch)
tree9d41db480536f9e856bdc058351adacd6b10e99e /src
parent816249cd59a7e243bec82651e2def22f8c3b439c (diff)
Make ShS a newtype over ShX
TODO: use lemmas in place of the unsafeCoerceRefl
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Nested/Convert.hs7
-rw-r--r--src/Data/Array/Nested/Permutation.hs12
-rw-r--r--src/Data/Array/Nested/Shaped.hs4
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs147
4 files changed, 114 insertions, 56 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index 8c88d23..3d0da37 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -75,12 +75,12 @@ shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh
-- * To shaped
--- TODO: these take a ShS because there are KnownNats inside IxS.
-
+-- TODO: remove the ShS now that no KnownNats is inside IxS.
ixsFromIxR :: ShS sh -> IxR (Rank sh) i -> IxS sh i
ixsFromIxR ZSS ZIR = ZIS
ixsFromIxR (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR sh idx
+-- TODO: if possible, remove the ShS now that no KnownNats is inside IxS.
-- | Performs a runtime check that @n@ matches @Rank sh@. Equivalent to the
-- following, but more efficient:
--
@@ -90,11 +90,12 @@ ixsFromIxR' ZSS ZIR = ZIS
ixsFromIxR' (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR' sh idx
ixsFromIxR' _ _ = error "ixsFromIxR': index rank does not match shape rank"
--- TODO: this takes a ShS because there are KnownNats inside IxS.
+-- TODO: remove the ShS now that no KnownNats is inside IxS.
ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i
ixsFromIxX ZSS ZIX = ZIS
ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx
+-- TODO: if possible, remove the ShS now that no KnownNats is inside IxS.
-- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to
-- the following, but more efficient:
--
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs
index 6bebcfb..8b46d81 100644
--- a/src/Data/Array/Nested/Permutation.hs
+++ b/src/Data/Array/Nested/Permutation.hs
@@ -215,6 +215,18 @@ ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat ()) p1 p2 i)
ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat ()))
+shxTakeLen :: forall is sh. Perm is -> IShX sh -> IShX (TakeLen is sh)
+shxTakeLen = coerce (listxTakeLen @(SMayNat Int))
+
+shxDropLen :: Perm is -> IShX sh -> IShX (DropLen is sh)
+shxDropLen = coerce (listxDropLen @(SMayNat Int))
+
+shxPermute :: Perm is -> IShX sh -> IShX (Permute is sh)
+shxPermute = coerce (listxPermute @(SMayNat Int))
+
+shxIndex :: Proxy is -> Proxy shT -> SNat i -> IShX sh -> SMayNat Int (Index i sh)
+shxIndex p1 p2 i = coerce (listxIndex @(SMayNat Int) p1 p2 i)
+
shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh)
shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int))
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index acb7c89..142a536 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -255,9 +255,7 @@ sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shape
sreshape sh' (Shaped arr) = Shaped (mreshape (shxFromShS sh') arr)
sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a
-sflatten arr =
- case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff
- n@SNat -> sreshape (n :$$ ZSS) arr
+sflatten arr = sreshape (shsProduct (sshape arr) :$$ ZSS) arr
siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
siota sn = Shaped (miota sn)
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index 0c042b7..d815adf 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -52,16 +52,14 @@ import Data.Array.Nested.Types
-- * Shaped lists
--- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be
--- removed in a future release.
type role ListS nominal representational
type ListS :: [Nat] -> (Nat -> Type) -> Type
data ListS sh f where
ZS :: ListS '[] f
- -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity
- (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
+ (::$) :: forall n sh {f}. f n -> ListS sh f -> ListS (n : sh) f
deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
+
infixr 3 ::$
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
@@ -76,7 +74,7 @@ instance (forall m. NFData (f m)) => NFData (ListS n f) where
rnf (x ::$ l) = rnf x `seq` rnf l
data UnconsListSRes f sh1 =
- forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
+ forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)
listsUncons ZS = Nothing
@@ -199,11 +197,11 @@ listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
listsPermute PNil _ = ZS
listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
case listsIndex (Proxy @is') (Proxy @sh) i sh of
- (item, SNat) -> item ::$ listsPermute is sh
+ item -> item ::$ listsPermute is sh
--- TODO: remove this SNat when the KnownNat constaint in ListS is removed
-listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh))
-listsIndex _ _ SZ (n ::$ _) = (n, SNat)
+-- TODO: try to remove this SNat now that the KnownNat constraint in ListS is removed
+listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> f (Index i sh)
+listsIndex _ _ SZ (n ::$ _) = n
listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
| Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
= listsIndex p pT i sh
@@ -227,7 +225,7 @@ pattern ZIS = IxS ZS
-- removed in a future release.
pattern (:.$)
:: forall {sh1} {i}.
- forall n sh. (KnownNat n, n : sh ~ sh1)
+ forall n sh. (n : sh ~ sh1)
=> i -> IxS sh i -> IxS sh1 i
pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
where i :.$ IxS shl = IxS (Const i ::$ shl)
@@ -295,11 +293,9 @@ ixsInit (IxS list) = IxS (listsInit list)
ixsLast :: IxS (n : sh) i -> i
ixsLast (IxS list) = getConst (listsLast list)
--- TODO: this takes a ShS because there are KnownNats inside IxS.
-ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i
-ixsCast ZSS ZIS = ZIS
-ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx
-ixsCast _ _ = error "ixsCast: ranks don't match"
+ixsCast :: IxS sh i -> IxS sh i
+ixsCast ZIS = ZIS
+ixsCast (i :.$ idx) = i :.$ ixsCast idx
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
ixsAppend = coerce (listsAppend @_ @(Const i))
@@ -335,21 +331,34 @@ ixsToLinear = \sh i -> go sh i 0
-- can also retrieve the array shape from a 'KnownShS' dictionary.
type role ShS nominal
type ShS :: [Nat] -> Type
-newtype ShS sh = ShS (ListS sh SNat)
+newtype ShS sh = ShS (ShX (MapJust sh) Int)
deriving (Generic)
instance Eq (ShS sh) where _ == _ = True
instance Ord (ShS sh) where compare _ _ = EQ
pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
-pattern ZSS = ShS ZS
+pattern ZSS <- ShS (matchZSX -> Just Refl)
+ where ZSS = ShS ZSX
+
+matchZSX :: forall sh i. ShX (MapJust sh) i -> Maybe (sh :~: '[])
+matchZSX ZSX | Refl <- lemMapJustEmpty @sh Refl = Just Refl
+matchZSX _ = Nothing
pattern (:$$)
:: forall {sh1}.
- forall n sh. (KnownNat n, n : sh ~ sh1)
+ forall n sh. (n : sh ~ sh1)
=> SNat n -> ShS sh -> ShS sh1
-pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
- where i :$$ ShS shl = ShS (i ::$ shl)
+pattern i :$$ shl <- (shsUncons -> Just (UnconsShSRes i shl))
+ where i :$$ ShS shl = ShS (SKnown i :$% shl)
+
+data UnconsShSRes sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsShSRes (SNat n) (ShS sh)
+shsUncons :: forall sh1. ShS sh1 -> Maybe (UnconsShSRes sh1)
+shsUncons (ShS (SKnown x :$% sh'))
+ | Refl <- lemMapJustCons @sh1 Refl
+ = Just (UnconsShSRes x (ShS sh'))
+shsUncons (ShS _) = Nothing
infixr 3 :$$
@@ -359,15 +368,16 @@ infixr 3 :$$
deriving instance Show (ShS sh)
#else
instance Show (ShS sh) where
- showsPrec _ (ShS l) = listsShow (shows . fromSNat) l
+ showsPrec d (ShS shx) = showsPrec d shx
#endif
instance NFData (ShS sh) where
- rnf (ShS ZS) = ()
- rnf (ShS (SNat ::$ l)) = rnf (ShS l)
+ rnf (ShS shx) = rnf shx
instance TestEquality ShS where
- testEquality (ShS l1) (ShS l2) = listsEqType l1 l2
+ testEquality (ShS shx1) (ShS shx2) = case shxEqType shx1 shx2 of
+ Nothing -> Nothing
+ Just Refl -> Just unsafeCoerceRefl
-- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are
-- equal if and only if values are equal.)
@@ -375,17 +385,20 @@ shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh')
shsEqual = testEquality
shsLength :: ShS sh -> Int
-shsLength (ShS l) = listsLength l
+shsLength (ShS shx) = shxLength shx
-shsRank :: ShS sh -> SNat (Rank sh)
-shsRank (ShS l) = listsRank l
+shsRank :: forall sh. ShS sh -> SNat (Rank sh)
+shsRank (ShS shx) =
+ gcastWith (unsafeCoerceRefl
+ :: Rank (MapJust sh) :~: Rank sh) $
+ shxRank shx
shsSize :: ShS sh -> Int
-shsSize ZSS = 1
-shsSize (n :$$ sh) = fromSNat' n * shsSize sh
+shsSize (ShS sh) = shxSize sh
-- | This is a partial @const@ that fails when the second argument
--- doesn't match the first.
+-- doesn't match the first. It also has a better error message comparing
+-- to just coercing 'shxFromList'.
shsFromList :: ShS sh -> [Int] -> ShS sh
shsFromList topsh topl = go topsh topl `seq` topsh
where
@@ -401,38 +414,72 @@ shsFromList topsh topl = go topsh topl `seq` topsh
{-# INLINEABLE shsToList #-}
shsToList :: ShS sh -> [Int]
-shsToList topsh = build (\(cons :: Int -> is -> is) (nil :: is) ->
- let go :: ShS sh -> is
- go ZSS = nil
- go (sn :$$ sh) = fromSNat' sn `cons` go sh
- in go topsh)
+shsToList = coerce shxToList
shsHead :: ShS (n : sh) -> SNat n
-shsHead (ShS list) = listsHead list
+shsHead (ShS shx) = case shxHead shx of
+ SKnown SNat -> SNat
-shsTail :: ShS (n : sh) -> ShS sh
-shsTail (ShS list) = ShS (listsTail list)
+shsTail :: forall n sh. ShS (n : sh) -> ShS sh
+shsTail = coerce (shxTail @_ @_ @Int)
-shsInit :: ShS (n : sh) -> ShS (Init (n : sh))
-shsInit (ShS list) = ShS (listsInit list)
+shsInit :: forall n sh. ShS (n : sh) -> ShS (Init (n : sh))
+shsInit =
+ gcastWith (unsafeCoerceRefl
+ :: Init (Just n : MapJust sh) :~: MapJust (Init (n : sh))) $
+ coerce (shxInit @_ @_ @Int)
-shsLast :: ShS (n : sh) -> SNat (Last (n : sh))
-shsLast (ShS list) = listsLast list
+shsLast :: forall n sh. ShS (n : sh) -> SNat (Last (n : sh))
+shsLast (ShS shx) =
+ gcastWith (unsafeCoerceRefl
+ :: Last (Just n : MapJust sh) :~: Just (Last (n : sh))) $
+ case shxLast shx of
+ SKnown SNat -> SNat
shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')
-shsAppend = coerce (listsAppend @_ @SNat)
+shsAppend =
+ gcastWith (unsafeCoerceRefl
+ :: MapJust sh ++ MapJust sh' :~: MapJust (sh ++ sh')) $
+ coerce (shxAppend @_ @_ @Int)
+
+shsTakeLen :: forall is sh. Perm is -> ShS sh -> ShS (TakeLen is sh)
+shsTakeLen =
+ gcastWith (unsafeCoerceRefl
+ :: TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)) $
+ coerce shxTakeLen
-shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
-shsTakeLen = coerce (listsTakeLenPerm @SNat)
+shsDropLen :: forall is sh. Perm is -> ShS sh -> ShS (DropLen is sh)
+shsDropLen =
+ gcastWith (unsafeCoerceRefl
+ :: DropLen is (MapJust sh) :~: MapJust (DropLen is sh)) $
+ coerce shxDropLen
-shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
-shsPermute = coerce (listsPermute @SNat)
+shsPermute :: forall is sh. Perm is -> ShS sh -> ShS (Permute is sh)
+shsPermute =
+ gcastWith (unsafeCoerceRefl
+ :: Permute is (MapJust sh) :~: MapJust (Permute is sh)) $
+ coerce shxPermute
-shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
-shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh)))
+shsIndex :: forall is shT i sh.
+ Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
+shsIndex pis pshT i sh =
+ gcastWith (unsafeCoerceRefl
+ :: Index i (MapJust sh) :~: Just (Index i sh)) $
+ case shxIndex pis pshT i (coerce sh) of
+ SKnown SNat -> SNat
shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
-shsPermutePrefix = coerce (listsPermutePrefix @SNat)
+shsPermutePrefix perm (ShS shx)
+ {- TODO: here and elsewhere, solve the module dependency cycle and add this:
+ | Refl <- lemTakeLenMapJust perm sh
+ , Refl <- lemDropLenMapJust perm sh
+ , Refl <- lemPermuteMapJust perm sh
+ , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm sh)) (shsDropLen perm sh) -}
+ = gcastWith (unsafeCoerceRefl
+ :: Permute is (TakeLen is (MapJust sh))
+ ++ DropLen is (MapJust sh)
+ :~: MapJust (Permute is (TakeLen is sh) ++ DropLen is sh)) $
+ ShS (shxPermutePrefix perm shx)
type family Product sh where
Product '[] = 1