diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-05-17 10:40:28 +0200 | 
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-05-17 10:40:28 +0200 | 
| commit | ccbfa6fd2cd1225dfe9f0dc5a281437f3e302b15 (patch) | |
| tree | 6ee1e6e2f2bf46847c626d74e7ebc23779eca02a /src/Data/Array/Nested/Shaped | |
| parent | 3361aa23c6a415adf50194d69680d7d2f519b512 (diff) | |
Move shape conversion ops to Data.Array.Nested.Convert
Diffstat (limited to 'src/Data/Array/Nested/Shaped')
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 20 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 25 | 
2 files changed, 18 insertions, 27 deletions
| diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index fa84efe..ddd44bf 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -5,6 +5,7 @@  {-# LANGUAGE FlexibleInstances #-}  {-# LANGUAGE ImportQualifiedPost #-}  {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-}  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE StandaloneDeriving #-} @@ -25,18 +26,19 @@ import Data.Coerce (coerce)  import Data.Kind (Type)  import Data.List.NonEmpty (NonEmpty)  import Data.Proxy +import Data.Type.Equality  import Foreign.Storable (Storable)  import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)  import GHC.Generics (Generic)  import GHC.TypeLits -import Data.Array.Nested.Types -import Data.Array.XArray (XArray)  import Data.Array.Nested.Lemmas  import Data.Array.Nested.Mixed  import Data.Array.Nested.Mixed.Shape  import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types  import Data.Array.Strided.Arith +import Data.Array.XArray (XArray)  -- | A shape-typed array: the full shape of the array (the sizes of its @@ -242,3 +244,17 @@ satan2Array = liftShaped2 matan2Array  sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh  sshape (Shaped arr) = shsFromShX (mshape arr) + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh +shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS +shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) = +  castWith (subst1 (lem Refl)) $ +    n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) +                                   idx) +  where +    lem :: forall sh1 sh' n. +           Just n : sh1 :~: MapJust sh' +        -> n : Tail sh' :~: sh' +    lem Refl = unsafeCoerceRefl +shsFromShX (SUnknown _ :$% _) = error "impossible" diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 0b7d1c9..fbfc7f5 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -230,14 +230,6 @@ ixsZero :: ShS sh -> IIxS sh  ixsZero ZSS = ZIS  ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh -ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i -ixsFromIxX ZSS ZIX = ZIS -ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx - -ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i -ixxFromIxS ZIS = ZIX -ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh -  ixsHead :: IxS (n : sh) i -> i  ixsHead (IxS list) = getConst (listsHead list) @@ -321,23 +313,6 @@ shsToList :: ShS sh -> [Int]  shsToList ZSS = []  shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh -shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh -shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS -shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) = -  castWith (subst1 (lem Refl)) $ -    n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) -                                   idx) -  where -    lem :: forall sh1 sh' n. -           Just n : sh1 :~: MapJust sh' -        -> n : Tail sh' :~: sh' -    lem Refl = unsafeCoerceRefl -shsFromShX (SUnknown _ :$% _) = error "impossible" - -shxFromShS :: ShS sh -> IShX (MapJust sh) -shxFromShS ZSS = ZSX -shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh -  shsHead :: ShS (n : sh) -> SNat n  shsHead (ShS list) = listsHead list | 
