diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-14 23:12:30 +0100 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-15 00:57:40 +0100 |
| commit | 5c85e5d5b6357ac3eb7a54ab6e7eccdc987004fa (patch) | |
| tree | 57d29f95ef3813b72e9a5e3375926286ab20910d /src | |
| parent | 6e841f3e7d19253db65874d87e2277c050dad984 (diff) | |
Implement shxFromShS and shsFromShX as a newtype coerce
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/Nested/Convert.hs | 8 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 1 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 14 |
3 files changed, 8 insertions, 15 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index da1c384..3706105 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -38,6 +38,7 @@ module Data.Array.Nested.Convert ( ) where import Control.Category +import Data.Coerce (coerce) import Data.Proxy import Data.Type.Equality import GHC.TypeLits @@ -111,7 +112,8 @@ withShsFromShR (n :$: sh) k = Just sn@SNat -> k (sn :$$ sh') Nothing -> error $ "withShsFromShR: negative dimension size (" ++ show n ++ ")" --- shsFromShX re-exported +shsFromShX :: IShX (MapJust sh) -> ShS sh +shsFromShX = coerce -- | Produce an existential 'ShS' from an 'IShX'. If you already know that -- @sh'@ is @MapJust@ of something, use 'shsFromShX' instead. @@ -126,6 +128,7 @@ withShsFromShX (SUnknown n :$% sh) k = Just sn@SNat -> k (sn :$$ sh') Nothing -> error $ "withShsFromShX: negative SUnknown dimension size (" ++ show n ++ ")" +-- If it ever matters for performance, this is unsafeCoercible. shsFromSSX :: StaticShX (MapJust sh) -> ShS sh shsFromSSX = shsFromShX Prelude.. shxFromSSX @@ -146,8 +149,7 @@ shxFromShR (n :$: (idx :: ShR m i)) = (SUnknown n :$% shxFromShR idx) shxFromShS :: ShS sh -> IShX (MapJust sh) -shxFromShS ZSS = ZSX -shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh +shxFromShS = coerce -- ixxCast re-exported -- shxCast re-exported diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 3f4ee9a..f08b8be 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -434,6 +434,7 @@ shxToList list = build (\(cons :: i -> is -> is) (nil :: is) -> go (smn :$% sh) = fromSMayNat' smn `cons` go sh in go list) +-- If it ever matters for performance, this is unsafeCoercible. shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i shxFromSSX ZKX = ZSX shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh)) diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index e2ec416..b86bfe5 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -26,7 +26,6 @@ import Data.Coerce (coerce) import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty) import Data.Proxy -import Data.Type.Equality import Foreign.Storable (Storable) import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) @@ -132,7 +131,7 @@ instance Elt a => Elt (Shaped sh a) where type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) - mshapeTree (Shaped arr) = first shsFromShX (mshapeTree arr) + mshapeTree (Shaped arr) = first coerce (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -256,13 +255,4 @@ satan2Array = liftShaped2 matan2Array sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh -sshape (Shaped arr) = shsFromShX (mshape arr) - --- Needed already here, but re-exported in Data.Array.Nested.Convert. -shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh -shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS -shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) = - castWith (subst1 (sym (lemMapJustCons Refl))) $ - n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) - idx) -shsFromShX (SUnknown _ :$% _) = error "impossible" +sshape (Shaped arr) = coerce (mshape arr) |
