diff options
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 51 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 18 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 27 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 20 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 25 |
5 files changed, 81 insertions, 60 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 73055db..b3a2c63 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -8,7 +8,9 @@ {-# LANGUAGE TypeOperators #-} module Data.Array.Nested.Convert ( -- * Shape/index/list casting functions - ixrFromIxS, shrFromShS, + ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShX, shrFromShX2, + ixsFromIxX, shsFromShX, + ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS, -- * Array conversions castCastable, @@ -27,6 +29,7 @@ module Data.Array.Nested.Convert ( import Control.Category import Data.Proxy import Data.Type.Equality +import GHC.TypeLits import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed @@ -39,10 +42,26 @@ import Data.Array.Nested.Types -- * Shape/index/list casting functions +-- * To ranked + ixrFromIxS :: IxS sh i -> IxR (Rank sh) i ixrFromIxS ZIS = ZIR ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix +ixrFromIxX :: IxX sh i -> IxR (Rank sh) i +ixrFromIxX ZIX = ZIR +ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx + +shrFromShS :: ShS sh -> IShR (Rank sh) +shrFromShS ZSS = ZSR +shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh + +-- shrFromShX re-exported + +-- shrFromShX2 re-exported + +-- * To shaped + -- ixsFromIxR :: IIxR (Rank sh) -> IIxS sh -- ixsFromIxR = \ix -> go ix _ -- where @@ -50,9 +69,33 @@ ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix -- go ZIR k = k ZIS -- go (i :.: ix) k = go ix (i :.$) -shrFromShS :: ShS sh -> IShR (Rank sh) -shrFromShS ZSS = ZSR -shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh +ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i +ixsFromIxX ZSS ZIX = ZIS +ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx + +-- shsFromShX re-exported + +-- * To mixed + +ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i +ixxFromIxR ZIR = ZIX +ixxFromIxR (n :.: (idx :: IxR m i)) = + castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m)) + (n :.% ixxFromIxR idx) + +ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i +ixxFromIxS ZIS = ZIX +ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh + +shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i +shxFromShR ZSR = ZSX +shxFromShR (n :$: (idx :: ShR m i)) = + castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m)) + (SUnknown n :$% shxFromShR idx) + +shxFromShS :: ShS sh -> IShX (MapJust sh) +shxFromShS ZSS = ZSX +shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh -- * Array conversions diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index beb5b0e..babc809 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -5,6 +5,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -25,6 +26,7 @@ import Data.Coerce (coerce) import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty) import Data.Proxy +import Data.Type.Equality import Foreign.Storable (Storable) import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) @@ -35,12 +37,12 @@ import Data.Foldable (toList) #endif import Data.Array.Nested.Lemmas -import Data.Array.Nested.Types -import Data.Array.XArray (XArray(..)) import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith +import Data.Array.XArray (XArray(..)) -- | A rank-typed array: the number of dimensions of the array (its /rank/) is @@ -252,3 +254,15 @@ rshape (Ranked arr) = shrFromShX2 (mshape arr) rrank :: Elt a => Ranked n a -> SNat n rrank = shrRank . rshape + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +shrFromShX :: forall sh. IShX sh -> IShR (Rank sh) +shrFromShX ZSX = ZSR +shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +-- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'. +shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n +shrFromShX2 sh + | Refl <- lemRankReplicate (Proxy @n) + = shrFromShX sh diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 75a1e5b..326bf61 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -40,7 +40,6 @@ import GHC.TypeLits import GHC.TypeNats qualified as TN import Data.Array.Nested.Lemmas -import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Types @@ -213,16 +212,6 @@ ixrZero :: SNat n -> IIxR n ixrZero SZ = ZIR ixrZero (SS n) = 0 :.: ixrZero n -ixrFromIxX :: IxX sh i -> IxR (Rank sh) i -ixrFromIxX ZIX = ZIR -ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx - -ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i -ixxFromIxR ZIR = ZIX -ixxFromIxR (n :.: (idx :: IxR m i)) = - castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m)) - (n :.% ixxFromIxR idx) - ixrHead :: IxR (n + 1) i -> i ixrHead (IxR list) = listrHead list @@ -278,22 +267,6 @@ instance Show i => Show (ShR n i) where instance NFData i => NFData (ShR sh i) -shrFromShX :: forall sh. IShX sh -> IShR (Rank sh) -shrFromShX ZSX = ZSR -shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx - --- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'. -shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n -shrFromShX2 sh - | Refl <- lemRankReplicate (Proxy @n) - = shrFromShX sh - -shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i -shxFromShR ZSR = ZSX -shxFromShR (n :$: (idx :: ShR m i)) = - castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m)) - (SUnknown n :$% shxFromShR idx) - -- | This checks only whether the ranks are equal, not whether the actual -- values are. shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index fa84efe..ddd44bf 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -5,6 +5,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -25,18 +26,19 @@ import Data.Coerce (coerce) import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty) import Data.Proxy +import Data.Type.Equality import Foreign.Storable (Storable) import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) import GHC.TypeLits -import Data.Array.Nested.Types -import Data.Array.XArray (XArray) import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith +import Data.Array.XArray (XArray) -- | A shape-typed array: the full shape of the array (the sizes of its @@ -242,3 +244,17 @@ satan2Array = liftShaped2 matan2Array sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh sshape (Shaped arr) = shsFromShX (mshape arr) + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh +shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS +shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) = + castWith (subst1 (lem Refl)) $ + n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) + idx) + where + lem :: forall sh1 sh' n. + Just n : sh1 :~: MapJust sh' + -> n : Tail sh' :~: sh' + lem Refl = unsafeCoerceRefl +shsFromShX (SUnknown _ :$% _) = error "impossible" diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 0b7d1c9..fbfc7f5 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -230,14 +230,6 @@ ixsZero :: ShS sh -> IIxS sh ixsZero ZSS = ZIS ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh -ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i -ixsFromIxX ZSS ZIX = ZIS -ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx - -ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i -ixxFromIxS ZIS = ZIX -ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh - ixsHead :: IxS (n : sh) i -> i ixsHead (IxS list) = getConst (listsHead list) @@ -321,23 +313,6 @@ shsToList :: ShS sh -> [Int] shsToList ZSS = [] shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh -shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh -shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS -shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) = - castWith (subst1 (lem Refl)) $ - n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) - idx) - where - lem :: forall sh1 sh' n. - Just n : sh1 :~: MapJust sh' - -> n : Tail sh' :~: sh' - lem Refl = unsafeCoerceRefl -shsFromShX (SUnknown _ :$% _) = error "impossible" - -shxFromShS :: ShS sh -> IShX (MapJust sh) -shxFromShS ZSS = ZSX -shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh - shsHead :: ShS (n : sh) -> SNat n shsHead (ShS list) = listsHead list |