aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shaped/Base.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Shaped/Base.hs')
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs20
1 files changed, 18 insertions, 2 deletions
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index fa84efe..ddd44bf 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -5,6 +5,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -25,18 +26,19 @@ import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
+import Data.Type.Equality
import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import GHC.TypeLits
-import Data.Array.Nested.Types
-import Data.Array.XArray (XArray)
import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray)
-- | A shape-typed array: the full shape of the array (the sizes of its
@@ -242,3 +244,17 @@ satan2Array = liftShaped2 matan2Array
sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
sshape (Shaped arr) = shsFromShX (mshape arr)
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh
+shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
+shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) =
+ castWith (subst1 (lem Refl)) $
+ n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
+ idx)
+ where
+ lem :: forall sh1 sh' n.
+ Just n : sh1 :~: MapJust sh'
+ -> n : Tail sh' :~: sh'
+ lem Refl = unsafeCoerceRefl
+shsFromShX (SUnknown _ :$% _) = error "impossible"