aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Convert.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Convert.hs')
-rw-r--r--src/Data/Array/Nested/Convert.hs32
1 files changed, 26 insertions, 6 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index 01abae3..92bc3b4 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -1,6 +1,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
@@ -16,6 +17,9 @@ module Data.Array.Nested.Convert (
rtoMixed, rcastToMixed, rcastToShaped,
stoMixed, scastToMixed, stoRanked,
mcast, mcastToShaped, mtoRanked,
+
+ -- * Additional index/shape casting functions
+ ixrFromIxS, shrFromShS,
) where
import Control.Category
@@ -28,6 +32,7 @@ import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked.Base
+import Data.Array.Nested.Ranked.Shape
import Data.Array.Nested.Shaped.Base
import Data.Array.Nested.Shaped.Shape
@@ -37,7 +42,7 @@ mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)
mcast ssh2 arr
| Refl <- lemAppNil @sh1
, Refl <- lemAppNil @sh2
- = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr
+ = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr
mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
mtoRanked = castCastable (CastXR CastId)
@@ -74,10 +79,25 @@ stoRanked sarr@(Shaped arr)
rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
rcastToShaped (Ranked arr) targetsh
- | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh))
+ | Refl <- lemRankReplicate (shxRank (shxFromShS targetsh))
, Refl <- lemRankMapJust targetsh
= mcastToShaped targetsh arr
+ixrFromIxS :: IIxS sh -> IIxR (Rank sh)
+ixrFromIxS ZIS = ZIR
+ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix
+
+-- ixsFromIxR :: IIxR (Rank sh) -> IIxS sh
+-- ixsFromIxR = \ix -> go ix _
+-- where
+-- go :: IIxR n -> (forall sh. KnownShS sh => IIxS sh -> r) -> r
+-- go ZIR k = k ZIS
+-- go (i :.: ix) k = go ix (i :.$)
+
+shrFromShS :: ShS sh -> IShR (Rank sh)
+shrFromShS ZSS = ZSR
+shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh
+
-- | The constructors that perform runtime shape checking are marked with a
-- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure
-- that the shapes are already compatible. To convert between 'Ranked' and
@@ -122,20 +142,20 @@ castCastable = \c x -> munScalar (go c (mscalar x))
go (CastXR @_ @_ @sh c) (M_Nest @esh esh x)
| Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)
= let x' = go c x
- ssx' = ssxAppend (ssxFromShape esh)
- (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShape esh))))
+ ssx' = ssxAppend (ssxFromShX esh)
+ (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShX esh))))
in M_Ranked (M_Nest esh (mcast ssx' x'))
go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x))
go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x)
| Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
- = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh')))
+ = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh')))
(go c x)))
go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
go (CastXX' @sh @sh' ssx c) (M_Nest @esh esh x)
| Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')
- = M_Nest esh $ mcast (ssxFromShape esh `ssxAppend` ssx) (go c x)
+ = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) (go c x)
lemRankAppRankEq :: Rank sh ~ Rank sh'
=> Proxy esh -> Proxy sh -> Proxy sh'