diff options
Diffstat (limited to 'src/Data/Array/Nested/Convert.hs')
| -rw-r--r-- | src/Data/Array/Nested/Convert.hs | 72 |
1 files changed, 28 insertions, 44 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 8c88d23..408bf8a 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -15,10 +15,10 @@ module Data.Array.Nested.Convert ( -- * Shape\/index\/list casting functions -- ** To ranked - ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShX, shrFromShX2, + ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShXAnyShape, shrFromShX, listrCast, ixrCast, shrCast, -- ** To shaped - ixsFromIxR, ixsFromIxR', ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX, + ixsFromIxR, ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX, ixsCast, -- ** To mixed ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS, @@ -38,9 +38,11 @@ module Data.Array.Nested.Convert ( ) where import Control.Category +import Data.Coerce (coerce) import Data.Proxy import Data.Type.Equality import GHC.TypeLits +import Unsafe.Coerce (unsafeCoerce) import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed @@ -55,48 +57,39 @@ import Data.Array.Nested.Types -- * To ranked +-- TODO: change all those unsafeCoerces into coerces by defining shaped +-- and ranekd index types as newtypes of the mixed index type +-- and similarly for the sized lists or, preferably, by defining +-- all as newtypes over [], exploiting fusion and getting free toList. ixrFromIxS :: IxS sh i -> IxR (Rank sh) i -ixrFromIxS ZIS = ZIR -ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix +ixrFromIxS = unsafeCoerce -ixrFromIxX :: IxX sh i -> IxR (Rank sh) i -ixrFromIxX ZIX = ZIR -ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx +-- ixrFromIxX re-exported shrFromShS :: ShS sh -> IShR (Rank sh) shrFromShS ZSS = ZSR shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh --- shrFromShX re-exported --- shrFromShX2 re-exported +shrFromShXAnyShape :: IShX sh -> IShR (Rank sh) +shrFromShXAnyShape ZSX = ZSR +shrFromShXAnyShape (n :$% idx) = fromSMayNat' n :$: shrFromShXAnyShape idx + +shrFromShX :: IShX (Replicate n Nothing) -> IShR n +shrFromShX = coerce + -- listrCast re-exported -- ixrCast re-exported -- shrCast re-exported -- * To shaped --- TODO: these take a ShS because there are KnownNats inside IxS. - -ixsFromIxR :: ShS sh -> IxR (Rank sh) i -> IxS sh i -ixsFromIxR ZSS ZIR = ZIS -ixsFromIxR (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR sh idx +ixsFromIxR :: IxR (Rank sh) i -> IxS sh i +ixsFromIxR = unsafeCoerce -- TODO: switch to coerce once newtypes overhauled --- | Performs a runtime check that @n@ matches @Rank sh@. Equivalent to the --- following, but more efficient: --- --- > ixsFromIxR' sh idx = ixsFromIxR sh (ixrCast (shsRank sh) idx) -ixsFromIxR' :: ShS sh -> IxR n i -> IxS sh i -ixsFromIxR' ZSS ZIR = ZIS -ixsFromIxR' (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR' sh idx -ixsFromIxR' _ _ = error "ixsFromIxR': index rank does not match shape rank" - --- TODO: this takes a ShS because there are KnownNats inside IxS. -ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i -ixsFromIxX ZSS ZIX = ZIS -ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx +-- ixsFromIxX re-exported -- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to --- the following, but more efficient: +-- the following, but less verbose: -- -- > ixsFromIxX' sh idx = ixsFromIxX sh (ixxCast (shxFromShS sh) idx) ixsFromIxX' :: ShS sh -> IxX sh' i -> IxS sh i @@ -113,7 +106,8 @@ withShsFromShR (n :$: sh) k = Just sn@SNat -> k (sn :$$ sh') Nothing -> error $ "withShsFromShR: negative dimension size (" ++ show n ++ ")" --- shsFromShX re-exported +shsFromShX :: IShX (MapJust sh) -> ShS sh +shsFromShX = coerce -- | Produce an existential 'ShS' from an 'IShX'. If you already know that -- @sh'@ is @MapJust@ of something, use 'shsFromShX' instead. @@ -128,6 +122,7 @@ withShsFromShX (SUnknown n :$% sh) k = Just sn@SNat -> k (sn :$$ sh') Nothing -> error $ "withShsFromShX: negative SUnknown dimension size (" ++ show n ++ ")" +-- If it ever matters for performance, this is unsafeCoercible. shsFromSSX :: StaticShX (MapJust sh) -> ShS sh shsFromSSX = shsFromShX Prelude.. shxFromSSX @@ -135,25 +130,14 @@ shsFromSSX = shsFromShX Prelude.. shxFromSSX -- * To mixed -ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i -ixxFromIxR ZIR = ZIX -ixxFromIxR (n :.: (idx :: IxR m i)) = - castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) (Proxy @m))) - (n :.% ixxFromIxR idx) - -ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i -ixxFromIxS ZIS = ZIX -ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh +-- ixxFromIxR re-exported +-- ixxFromIxS re-exported shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i -shxFromShR ZSR = ZSX -shxFromShR (n :$: (idx :: ShR m i)) = - castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) (Proxy @m))) - (SUnknown n :$% shxFromShR idx) +shxFromShR = coerce shxFromShS :: ShS sh -> IShX (MapJust sh) -shxFromShS ZSS = ZSX -shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh +shxFromShS = coerce -- ixxCast re-exported -- shxCast re-exported |
