aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-14 23:12:30 +0100
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-15 00:57:40 +0100
commit5c85e5d5b6357ac3eb7a54ab6e7eccdc987004fa (patch)
tree57d29f95ef3813b72e9a5e3375926286ab20910d /src
parent6e841f3e7d19253db65874d87e2277c050dad984 (diff)
Implement shxFromShS and shsFromShX as a newtype coerce
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Nested/Convert.hs8
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs1
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs14
3 files changed, 8 insertions, 15 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index da1c384..3706105 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -38,6 +38,7 @@ module Data.Array.Nested.Convert (
) where
import Control.Category
+import Data.Coerce (coerce)
import Data.Proxy
import Data.Type.Equality
import GHC.TypeLits
@@ -111,7 +112,8 @@ withShsFromShR (n :$: sh) k =
Just sn@SNat -> k (sn :$$ sh')
Nothing -> error $ "withShsFromShR: negative dimension size (" ++ show n ++ ")"
--- shsFromShX re-exported
+shsFromShX :: IShX (MapJust sh) -> ShS sh
+shsFromShX = coerce
-- | Produce an existential 'ShS' from an 'IShX'. If you already know that
-- @sh'@ is @MapJust@ of something, use 'shsFromShX' instead.
@@ -126,6 +128,7 @@ withShsFromShX (SUnknown n :$% sh) k =
Just sn@SNat -> k (sn :$$ sh')
Nothing -> error $ "withShsFromShX: negative SUnknown dimension size (" ++ show n ++ ")"
+-- If it ever matters for performance, this is unsafeCoercible.
shsFromSSX :: StaticShX (MapJust sh) -> ShS sh
shsFromSSX = shsFromShX Prelude.. shxFromSSX
@@ -146,8 +149,7 @@ shxFromShR (n :$: (idx :: ShR m i)) =
(SUnknown n :$% shxFromShR idx)
shxFromShS :: ShS sh -> IShX (MapJust sh)
-shxFromShS ZSS = ZSX
-shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh
+shxFromShS = coerce
-- ixxCast re-exported
-- shxCast re-exported
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index 3f4ee9a..f08b8be 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -434,6 +434,7 @@ shxToList list = build (\(cons :: i -> is -> is) (nil :: is) ->
go (smn :$% sh) = fromSMayNat' smn `cons` go sh
in go list)
+-- If it ever matters for performance, this is unsafeCoercible.
shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i
shxFromSSX ZKX = ZSX
shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh))
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index e2ec416..b86bfe5 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -26,7 +26,6 @@ import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
-import Data.Type.Equality
import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
@@ -132,7 +131,7 @@ instance Elt a => Elt (Shaped sh a) where
type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
- mshapeTree (Shaped arr) = first shsFromShX (mshapeTree arr)
+ mshapeTree (Shaped arr) = first coerce (mshapeTree arr)
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
@@ -256,13 +255,4 @@ satan2Array = liftShaped2 matan2Array
sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
-sshape (Shaped arr) = shsFromShX (mshape arr)
-
--- Needed already here, but re-exported in Data.Array.Nested.Convert.
-shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh
-shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
-shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) =
- castWith (subst1 (sym (lemMapJustCons Refl))) $
- n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
- idx)
-shsFromShX (SUnknown _ :$% _) = error "impossible"
+sshape (Shaped arr) = coerce (mshape arr)