aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shaped
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Shaped')
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs20
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs25
2 files changed, 18 insertions, 27 deletions
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index fa84efe..ddd44bf 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -5,6 +5,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -25,18 +26,19 @@ import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
+import Data.Type.Equality
import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import GHC.TypeLits
-import Data.Array.Nested.Types
-import Data.Array.XArray (XArray)
import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray)
-- | A shape-typed array: the full shape of the array (the sizes of its
@@ -242,3 +244,17 @@ satan2Array = liftShaped2 matan2Array
sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
sshape (Shaped arr) = shsFromShX (mshape arr)
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh
+shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
+shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) =
+ castWith (subst1 (lem Refl)) $
+ n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
+ idx)
+ where
+ lem :: forall sh1 sh' n.
+ Just n : sh1 :~: MapJust sh'
+ -> n : Tail sh' :~: sh'
+ lem Refl = unsafeCoerceRefl
+shsFromShX (SUnknown _ :$% _) = error "impossible"
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index 0b7d1c9..fbfc7f5 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -230,14 +230,6 @@ ixsZero :: ShS sh -> IIxS sh
ixsZero ZSS = ZIS
ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
-ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i
-ixsFromIxX ZSS ZIX = ZIS
-ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx
-
-ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i
-ixxFromIxS ZIS = ZIX
-ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh
-
ixsHead :: IxS (n : sh) i -> i
ixsHead (IxS list) = getConst (listsHead list)
@@ -321,23 +313,6 @@ shsToList :: ShS sh -> [Int]
shsToList ZSS = []
shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh
-shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh
-shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
-shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) =
- castWith (subst1 (lem Refl)) $
- n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
- idx)
- where
- lem :: forall sh1 sh' n.
- Just n : sh1 :~: MapJust sh'
- -> n : Tail sh' :~: sh'
- lem Refl = unsafeCoerceRefl
-shsFromShX (SUnknown _ :$% _) = error "impossible"
-
-shxFromShS :: ShS sh -> IShX (MapJust sh)
-shxFromShS ZSS = ZSX
-shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh
-
shsHead :: ShS (n : sh) -> SNat n
shsHead (ShS list) = listsHead list