aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Nested/Convert.hs51
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs18
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs27
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs20
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs25
5 files changed, 81 insertions, 60 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index 73055db..b3a2c63 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -8,7 +8,9 @@
{-# LANGUAGE TypeOperators #-}
module Data.Array.Nested.Convert (
-- * Shape/index/list casting functions
- ixrFromIxS, shrFromShS,
+ ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShX, shrFromShX2,
+ ixsFromIxX, shsFromShX,
+ ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS,
-- * Array conversions
castCastable,
@@ -27,6 +29,7 @@ module Data.Array.Nested.Convert (
import Control.Category
import Data.Proxy
import Data.Type.Equality
+import GHC.TypeLits
import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
@@ -39,10 +42,26 @@ import Data.Array.Nested.Types
-- * Shape/index/list casting functions
+-- * To ranked
+
ixrFromIxS :: IxS sh i -> IxR (Rank sh) i
ixrFromIxS ZIS = ZIR
ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix
+ixrFromIxX :: IxX sh i -> IxR (Rank sh) i
+ixrFromIxX ZIX = ZIR
+ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx
+
+shrFromShS :: ShS sh -> IShR (Rank sh)
+shrFromShS ZSS = ZSR
+shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh
+
+-- shrFromShX re-exported
+
+-- shrFromShX2 re-exported
+
+-- * To shaped
+
-- ixsFromIxR :: IIxR (Rank sh) -> IIxS sh
-- ixsFromIxR = \ix -> go ix _
-- where
@@ -50,9 +69,33 @@ ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix
-- go ZIR k = k ZIS
-- go (i :.: ix) k = go ix (i :.$)
-shrFromShS :: ShS sh -> IShR (Rank sh)
-shrFromShS ZSS = ZSR
-shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh
+ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i
+ixsFromIxX ZSS ZIX = ZIS
+ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx
+
+-- shsFromShX re-exported
+
+-- * To mixed
+
+ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i
+ixxFromIxR ZIR = ZIX
+ixxFromIxR (n :.: (idx :: IxR m i)) =
+ castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m))
+ (n :.% ixxFromIxR idx)
+
+ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i
+ixxFromIxS ZIS = ZIX
+ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh
+
+shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i
+shxFromShR ZSR = ZSX
+shxFromShR (n :$: (idx :: ShR m i)) =
+ castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m))
+ (SUnknown n :$% shxFromShR idx)
+
+shxFromShS :: ShS sh -> IShX (MapJust sh)
+shxFromShS ZSS = ZSX
+shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh
-- * Array conversions
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
index beb5b0e..babc809 100644
--- a/src/Data/Array/Nested/Ranked/Base.hs
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -5,6 +5,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -25,6 +26,7 @@ import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
+import Data.Type.Equality
import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
@@ -35,12 +37,12 @@ import Data.Foldable (toList)
#endif
import Data.Array.Nested.Lemmas
-import Data.Array.Nested.Types
-import Data.Array.XArray (XArray(..))
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray(..))
-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
@@ -252,3 +254,15 @@ rshape (Ranked arr) = shrFromShX2 (mshape arr)
rrank :: Elt a => Ranked n a -> SNat n
rrank = shrRank . rshape
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+shrFromShX :: forall sh. IShX sh -> IShR (Rank sh)
+shrFromShX ZSX = ZSR
+shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+-- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'.
+shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n
+shrFromShX2 sh
+ | Refl <- lemRankReplicate (Proxy @n)
+ = shrFromShX sh
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index 75a1e5b..326bf61 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -40,7 +40,6 @@ import GHC.TypeLits
import GHC.TypeNats qualified as TN
import Data.Array.Nested.Lemmas
-import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Types
@@ -213,16 +212,6 @@ ixrZero :: SNat n -> IIxR n
ixrZero SZ = ZIR
ixrZero (SS n) = 0 :.: ixrZero n
-ixrFromIxX :: IxX sh i -> IxR (Rank sh) i
-ixrFromIxX ZIX = ZIR
-ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx
-
-ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i
-ixxFromIxR ZIR = ZIX
-ixxFromIxR (n :.: (idx :: IxR m i)) =
- castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m))
- (n :.% ixxFromIxR idx)
-
ixrHead :: IxR (n + 1) i -> i
ixrHead (IxR list) = listrHead list
@@ -278,22 +267,6 @@ instance Show i => Show (ShR n i) where
instance NFData i => NFData (ShR sh i)
-shrFromShX :: forall sh. IShX sh -> IShR (Rank sh)
-shrFromShX ZSX = ZSR
-shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx
-
--- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'.
-shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n
-shrFromShX2 sh
- | Refl <- lemRankReplicate (Proxy @n)
- = shrFromShX sh
-
-shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i
-shxFromShR ZSR = ZSX
-shxFromShR (n :$: (idx :: ShR m i)) =
- castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m))
- (SUnknown n :$% shxFromShR idx)
-
-- | This checks only whether the ranks are equal, not whether the actual
-- values are.
shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index fa84efe..ddd44bf 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -5,6 +5,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -25,18 +26,19 @@ import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
+import Data.Type.Equality
import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import GHC.TypeLits
-import Data.Array.Nested.Types
-import Data.Array.XArray (XArray)
import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray)
-- | A shape-typed array: the full shape of the array (the sizes of its
@@ -242,3 +244,17 @@ satan2Array = liftShaped2 matan2Array
sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
sshape (Shaped arr) = shsFromShX (mshape arr)
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh
+shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
+shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) =
+ castWith (subst1 (lem Refl)) $
+ n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
+ idx)
+ where
+ lem :: forall sh1 sh' n.
+ Just n : sh1 :~: MapJust sh'
+ -> n : Tail sh' :~: sh'
+ lem Refl = unsafeCoerceRefl
+shsFromShX (SUnknown _ :$% _) = error "impossible"
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index 0b7d1c9..fbfc7f5 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -230,14 +230,6 @@ ixsZero :: ShS sh -> IIxS sh
ixsZero ZSS = ZIS
ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
-ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i
-ixsFromIxX ZSS ZIX = ZIS
-ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx
-
-ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i
-ixxFromIxS ZIS = ZIX
-ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh
-
ixsHead :: IxS (n : sh) i -> i
ixsHead (IxS list) = getConst (listsHead list)
@@ -321,23 +313,6 @@ shsToList :: ShS sh -> [Int]
shsToList ZSS = []
shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh
-shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh
-shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
-shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) =
- castWith (subst1 (lem Refl)) $
- n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
- idx)
- where
- lem :: forall sh1 sh' n.
- Just n : sh1 :~: MapJust sh'
- -> n : Tail sh' :~: sh'
- lem Refl = unsafeCoerceRefl
-shsFromShX (SUnknown _ :$% _) = error "impossible"
-
-shxFromShS :: ShS sh -> IShX (MapJust sh)
-shxFromShS ZSS = ZSX
-shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh
-
shsHead :: ShS (n : sh) -> SNat n
shsHead (ShS list) = listsHead list