diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-05-17 10:40:28 +0200 | 
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-05-17 10:40:28 +0200 | 
| commit | ccbfa6fd2cd1225dfe9f0dc5a281437f3e302b15 (patch) | |
| tree | 6ee1e6e2f2bf46847c626d74e7ebc23779eca02a /src/Data/Array | |
| parent | 3361aa23c6a415adf50194d69680d7d2f519b512 (diff) | |
Move shape conversion ops to Data.Array.Nested.Convert
Diffstat (limited to 'src/Data/Array')
| -rw-r--r-- | src/Data/Array/Nested/Convert.hs | 51 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 18 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 27 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 20 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 25 | 
5 files changed, 81 insertions, 60 deletions
| diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 73055db..b3a2c63 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -8,7 +8,9 @@  {-# LANGUAGE TypeOperators #-}  module Data.Array.Nested.Convert (    -- * Shape/index/list casting functions -  ixrFromIxS, shrFromShS, +  ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShX, shrFromShX2, +  ixsFromIxX, shsFromShX, +  ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS,    -- * Array conversions    castCastable, @@ -27,6 +29,7 @@ module Data.Array.Nested.Convert (  import Control.Category  import Data.Proxy  import Data.Type.Equality +import GHC.TypeLits  import Data.Array.Nested.Lemmas  import Data.Array.Nested.Mixed @@ -39,10 +42,26 @@ import Data.Array.Nested.Types  -- * Shape/index/list casting functions +-- * To ranked +  ixrFromIxS :: IxS sh i -> IxR (Rank sh) i  ixrFromIxS ZIS = ZIR  ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix +ixrFromIxX :: IxX sh i -> IxR (Rank sh) i +ixrFromIxX ZIX = ZIR +ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx + +shrFromShS :: ShS sh -> IShR (Rank sh) +shrFromShS ZSS = ZSR +shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh + +-- shrFromShX re-exported + +-- shrFromShX2 re-exported + +-- * To shaped +  -- ixsFromIxR :: IIxR (Rank sh) -> IIxS sh  -- ixsFromIxR = \ix -> go ix _  --   where @@ -50,9 +69,33 @@ ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix  --     go ZIR k = k ZIS  --     go (i :.: ix) k = go ix (i :.$) -shrFromShS :: ShS sh -> IShR (Rank sh) -shrFromShS ZSS = ZSR -shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh +ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i +ixsFromIxX ZSS ZIX = ZIS +ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx + +-- shsFromShX re-exported + +-- * To mixed + +ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i +ixxFromIxR ZIR = ZIX +ixxFromIxR (n :.: (idx :: IxR m i)) = +  castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m)) +    (n :.% ixxFromIxR idx) + +ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i +ixxFromIxS ZIS = ZIX +ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh + +shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i +shxFromShR ZSR = ZSX +shxFromShR (n :$: (idx :: ShR m i)) = +  castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m)) +    (SUnknown n :$% shxFromShR idx) + +shxFromShS :: ShS sh -> IShX (MapJust sh) +shxFromShS ZSS = ZSX +shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh  -- * Array conversions diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index beb5b0e..babc809 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -5,6 +5,7 @@  {-# LANGUAGE FlexibleInstances #-}  {-# LANGUAGE ImportQualifiedPost #-}  {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-}  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE StandaloneDeriving #-} @@ -25,6 +26,7 @@ import Data.Coerce (coerce)  import Data.Kind (Type)  import Data.List.NonEmpty (NonEmpty)  import Data.Proxy +import Data.Type.Equality  import Foreign.Storable (Storable)  import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)  import GHC.Generics (Generic) @@ -35,12 +37,12 @@ import Data.Foldable (toList)  #endif  import Data.Array.Nested.Lemmas -import Data.Array.Nested.Types -import Data.Array.XArray (XArray(..))  import Data.Array.Nested.Mixed  import Data.Array.Nested.Mixed.Shape  import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types  import Data.Array.Strided.Arith +import Data.Array.XArray (XArray(..))  -- | A rank-typed array: the number of dimensions of the array (its /rank/) is @@ -252,3 +254,15 @@ rshape (Ranked arr) = shrFromShX2 (mshape arr)  rrank :: Elt a => Ranked n a -> SNat n  rrank = shrRank . rshape + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +shrFromShX :: forall sh. IShX sh -> IShR (Rank sh) +shrFromShX ZSX = ZSR +shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +-- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'. +shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n +shrFromShX2 sh +  | Refl <- lemRankReplicate (Proxy @n) +  = shrFromShX sh diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 75a1e5b..326bf61 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -40,7 +40,6 @@ import GHC.TypeLits  import GHC.TypeNats qualified as TN  import Data.Array.Nested.Lemmas -import Data.Array.Nested.Mixed.Shape  import Data.Array.Nested.Types @@ -213,16 +212,6 @@ ixrZero :: SNat n -> IIxR n  ixrZero SZ = ZIR  ixrZero (SS n) = 0 :.: ixrZero n -ixrFromIxX :: IxX sh i -> IxR (Rank sh) i -ixrFromIxX ZIX = ZIR -ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx - -ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i -ixxFromIxR ZIR = ZIX -ixxFromIxR (n :.: (idx :: IxR m i)) = -  castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m)) -    (n :.% ixxFromIxR idx) -  ixrHead :: IxR (n + 1) i -> i  ixrHead (IxR list) = listrHead list @@ -278,22 +267,6 @@ instance Show i => Show (ShR n i) where  instance NFData i => NFData (ShR sh i) -shrFromShX :: forall sh. IShX sh -> IShR (Rank sh) -shrFromShX ZSX = ZSR -shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx - --- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'. -shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n -shrFromShX2 sh -  | Refl <- lemRankReplicate (Proxy @n) -  = shrFromShX sh - -shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i -shxFromShR ZSR = ZSX -shxFromShR (n :$: (idx :: ShR m i)) = -  castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m)) -    (SUnknown n :$% shxFromShR idx) -  -- | This checks only whether the ranks are equal, not whether the actual  -- values are.  shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index fa84efe..ddd44bf 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -5,6 +5,7 @@  {-# LANGUAGE FlexibleInstances #-}  {-# LANGUAGE ImportQualifiedPost #-}  {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-}  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE StandaloneDeriving #-} @@ -25,18 +26,19 @@ import Data.Coerce (coerce)  import Data.Kind (Type)  import Data.List.NonEmpty (NonEmpty)  import Data.Proxy +import Data.Type.Equality  import Foreign.Storable (Storable)  import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)  import GHC.Generics (Generic)  import GHC.TypeLits -import Data.Array.Nested.Types -import Data.Array.XArray (XArray)  import Data.Array.Nested.Lemmas  import Data.Array.Nested.Mixed  import Data.Array.Nested.Mixed.Shape  import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types  import Data.Array.Strided.Arith +import Data.Array.XArray (XArray)  -- | A shape-typed array: the full shape of the array (the sizes of its @@ -242,3 +244,17 @@ satan2Array = liftShaped2 matan2Array  sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh  sshape (Shaped arr) = shsFromShX (mshape arr) + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh +shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS +shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) = +  castWith (subst1 (lem Refl)) $ +    n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) +                                   idx) +  where +    lem :: forall sh1 sh' n. +           Just n : sh1 :~: MapJust sh' +        -> n : Tail sh' :~: sh' +    lem Refl = unsafeCoerceRefl +shsFromShX (SUnknown _ :$% _) = error "impossible" diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 0b7d1c9..fbfc7f5 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -230,14 +230,6 @@ ixsZero :: ShS sh -> IIxS sh  ixsZero ZSS = ZIS  ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh -ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i -ixsFromIxX ZSS ZIX = ZIS -ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx - -ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i -ixxFromIxS ZIS = ZIX -ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh -  ixsHead :: IxS (n : sh) i -> i  ixsHead (IxS list) = getConst (listsHead list) @@ -321,23 +313,6 @@ shsToList :: ShS sh -> [Int]  shsToList ZSS = []  shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh -shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh -shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS -shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) = -  castWith (subst1 (lem Refl)) $ -    n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) -                                   idx) -  where -    lem :: forall sh1 sh' n. -           Just n : sh1 :~: MapJust sh' -        -> n : Tail sh' :~: sh' -    lem Refl = unsafeCoerceRefl -shsFromShX (SUnknown _ :$% _) = error "impossible" - -shxFromShS :: ShS sh -> IShX (MapJust sh) -shxFromShS ZSS = ZSX -shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh -  shsHead :: ShS (n : sh) -> SNat n  shsHead (ShS list) = listsHead list | 
