From dc0270a1fd5db180df88023bb2628b046447df0d Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 29 Jun 2025 13:07:22 +0200 Subject: More shape/index conversion functions --- src/Data/Array/Nested/Convert.hs | 55 +++++++++++++++++++++++++++++------ src/Data/Array/Nested/Mixed/Shape.hs | 20 ++++++++++--- src/Data/Array/Nested/Ranked/Shape.hs | 20 +++++++++++++ src/Data/Array/Nested/Shaped/Base.hs | 7 +---- src/Data/Array/Nested/Shaped/Shape.hs | 17 +++++++++++ src/Data/Array/Nested/Types.hs | 7 +++++ 6 files changed, 107 insertions(+), 19 deletions(-) (limited to 'src/Data/Array') diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 861bf20..fd59ba6 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -7,11 +7,18 @@ {-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} module Data.Array.Nested.Convert ( - -- * Shape/index/list casting functions + -- * Shape\/index\/list casting functions + -- ** To ranked ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShX, shrFromShX2, - ixsFromIxX, shsFromShX, + listrCast, ixrCast, shrCast, + -- ** To shaped + ixsFromIxR, ixsFromIxR', ixsFromIxX, ixsFromIxX', shsFromShX, shsFromSSX, + ixsCast, + -- ** To mixed ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS, + ixxCast, shxCast, shxCast', -- * Array conversions convert, @@ -57,24 +64,50 @@ shrFromShS ZSS = ZSR shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh -- shrFromShX re-exported - -- shrFromShX2 re-exported +-- listrCast re-exported +-- ixrCast re-exported +-- shrCast re-exported -- * To shaped --- ixsFromIxR :: IIxR (Rank sh) -> IIxS sh --- ixsFromIxR = \ix -> go ix _ --- where --- go :: IIxR n -> (forall sh. KnownShS sh => IIxS sh -> r) -> r --- go ZIR k = k ZIS --- go (i :.: ix) k = go ix (i :.$) +-- TODO: these take a ShS because there are KnownNats inside IxS. + +ixsFromIxR :: ShS sh -> IxR (Rank sh) i -> IxS sh i +ixsFromIxR ZSS ZIR = ZIS +ixsFromIxR (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR sh idx +ixsFromIxR _ _ = error "unreachable" +-- | Performs a runtime check that @n@ matches @Rank sh@. Equivalent to the +-- following, but more efficient: +-- +-- > ixsFromIxR' sh idx = ixsFromIxR sh (ixrCast (shsRank sh) idx) +ixsFromIxR' :: ShS sh -> IxR n i -> IxS sh i +ixsFromIxR' ZSS ZIR = ZIS +ixsFromIxR' (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR' sh idx +ixsFromIxR' _ _ = error "ixsFromIxR': index rank does not match shape rank" + +-- TODO: this takes a ShS because there are KnownNats inside IxS. ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i ixsFromIxX ZSS ZIX = ZIS ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx +-- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to +-- the following, but more efficient: +-- +-- > ixsFromIxX' sh idx = ixsFromIxX sh (ixxCast (shxFromShS sh) idx) +ixsFromIxX' :: ShS sh -> IxX sh' i -> IxS sh i +ixsFromIxX' ZSS ZIX = ZIS +ixsFromIxX' (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX' sh idx +ixsFromIxX' _ _ = error "ixsFromIxX': index rank does not match shape rank" + -- shsFromShX re-exported +shsFromSSX :: StaticShX (MapJust sh) -> ShS sh +shsFromSSX = shsFromShX Prelude.. shxFromSSX + +-- ixsCast re-exported + -- * To mixed ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i @@ -97,6 +130,10 @@ shxFromShS :: ShS sh -> IShX (MapJust sh) shxFromShS ZSS = ZSX shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh +-- ixxCast re-exported +-- shxCast re-exported +-- shxCast' re-exported + -- * Array conversions diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 2ee3600..e63277f 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -244,6 +244,11 @@ ixxInit = coerce (listxInit @(Const i)) ixxLast :: forall n sh i. IxX (n : sh) i -> i ixxLast = coerce (listxLast @(Const i)) +ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i +ixxCast ZKX ZIX = ZIX +ixxCast (_ :!% sh) (i :.% idx) = i :.% ixxCast sh idx +ixxCast _ _ = error "ixxCast: ranks don't match" + ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j) ixxZip ZIX ZIX = ZIX ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js @@ -409,11 +414,18 @@ shxToList :: IShX sh -> [Int] shxToList ZSX = [] shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh +shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i +shxFromSSX ZKX = ZSX +shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh)) + | Refl <- lemMapJustCons @sh Refl + = SKnown n :$% shxFromSSX sh +shxFromSSX (SUnknown _ :!% _) = error "unreachable" + -- | This may fail if @sh@ has @Nothing@s in it. -shxFromSSX' :: StaticShX sh -> Maybe (ShX sh i) -shxFromSSX' ZKX = Just ZSX -shxFromSSX' (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX' sh -shxFromSSX' (SUnknown _ :!% _) = Nothing +shxFromSSX2 :: StaticShX sh -> Maybe (ShX sh i) +shxFromSSX2 ZKX = Just ZSX +shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh +shxFromSSX2 (SUnknown _ :!% _) = Nothing shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 3edebf6..8b670e5 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -131,6 +131,10 @@ listrLast (_ ::: sh@(_ ::: _)) = listrLast sh listrLast (n ::: ZR) = n listrLast ZR = error "unreachable" +-- | Performs a runtime check that the lengths are identical. +listrCast :: SNat n' -> ListR n i -> ListR n' i +listrCast = listrCastWithName "listrCast" + listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i listrIndex SZ (x ::: _) = x listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs @@ -230,6 +234,10 @@ ixrInit (IxR list) = IxR (listrInit list) ixrLast :: IxR (n + 1) i -> i ixrLast (IxR list) = listrLast list +-- | Performs a runtime check that the lengths are identical. +ixrCast :: SNat n' -> IxR n i -> IxR n' i +ixrCast n (IxR idx) = IxR (listrCastWithName "ixrCast" n idx) + ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i ixrAppend = coerce (listrAppend @_ @i) @@ -310,6 +318,10 @@ shrInit (ShR list) = ShR (listrInit list) shrLast :: ShR (n + 1) i -> i shrLast (ShR list) = listrLast list +-- | Performs a runtime check that the lengths are identical. +shrCast :: SNat n' -> ShR n i -> ShR n' i +shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh) + shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i shrAppend = coerce (listrAppend @_ @i) @@ -347,3 +359,11 @@ instance KnownNat n => IsList (ShR n i) where type Item (ShR n i) = i fromList = ShR . IsList.fromList toList = Foldable.toList + + +-- * Internal helper functions + +listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i +listrCastWithName _ SZ ZR = ZR +listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx +listrCastWithName name _ _ = error $ name ++ ": ranks don't match" diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index a24a91a..879e6b5 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -249,12 +249,7 @@ sshape (Shaped arr) = shsFromShX (mshape arr) shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) = - castWith (subst1 (lem Refl)) $ + castWith (subst1 (sym (lemMapJustCons Refl))) $ n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) idx) - where - lem :: forall sh1 sh' n. - Just n : sh1 :~: MapJust sh' - -> n : Tail sh' :~: sh' - lem Refl = unsafeCoerceRefl shsFromShX (SUnknown _ :$% _) = error "impossible" diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index d34f3ec..ab16422 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -248,6 +248,12 @@ ixsInit (IxS list) = IxS (listsInit list) ixsLast :: IxS (n : sh) i -> i ixsLast (IxS list) = getConst (listsLast list) +-- TODO: this takes a ShS because there are KnownNats inside IxS. +ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i +ixsCast ZSS ZIS = ZIS +ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx +ixsCast _ _ = error "ixsCast: ranks don't match" + ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i ixsAppend = coerce (listsAppend @_ @(Const i)) @@ -374,6 +380,17 @@ shsOrthotopeShape :: ShS sh -> Dict O.Shape sh shsOrthotopeShape ZSS = Dict shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict +-- | This function is a hack made possible by the 'KnownNat' inside 'ListS'. +-- This function may be removed in a future release. +shsFromListS :: ListS sh f -> ShS sh +shsFromListS ZS = ZSS +shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l + +-- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This +-- function may be removed in a future release. +shsFromIxS :: IxS sh i -> ShS sh +shsFromIxS (IxS l) = shsFromListS l + -- | Untyped: length is checked at runtime. instance KnownShS sh => IsList (ListS sh (Const i)) where diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs index b8a9aea..df466cf 100644 --- a/src/Data/Array/Nested/Types.hs +++ b/src/Data/Array/Nested/Types.hs @@ -31,6 +31,7 @@ module Data.Array.Nested.Types ( Replicate, lemReplicateSucc, MapJust, + lemMapJustEmpty, lemMapJustCons, Head, Tail, Init, @@ -116,6 +117,12 @@ type family MapJust l = r | r -> l where MapJust '[] = '[] MapJust (x : xs) = Just x : MapJust xs +lemMapJustEmpty :: MapJust sh :~: '[] -> sh :~: '[] +lemMapJustEmpty Refl = unsafeCoerceRefl + +lemMapJustCons :: MapJust sh :~: Just n : sh' -> sh :~: n : Tail sh +lemMapJustCons Refl = unsafeCoerceRefl + type family Head l where Head (x : _) = x -- cgit v1.2.3-70-g09d2