aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Nested/Convert.hs28
1 files changed, 14 insertions, 14 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index cdd2b6d..17ccc4d 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -5,12 +5,22 @@
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-module Data.Array.Nested.Convert where
+module Data.Array.Nested.Convert (
+ castCastable,
+ Castable(..),
+
+ -- * Special cases
+ --
+ -- | These functions can all be implemented using 'castCastable' in some way,
+ -- but some have fewer constraints.
+ rtoMixed, rcastToMixed, rcastToShaped,
+ stoMixed, scastToMixed, stoRanked,
+ mcast, mcastToShaped, mtoRanked,
+) where
import Control.Category
import Data.Proxy
import Data.Type.Equality
-import GHC.TypeLits (Nat)
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Types
@@ -30,15 +40,7 @@ mcast ssh2 arr
= mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr
mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
-mtoRanked arr
- | Refl <- lemRankReplicate (shxRank (mshape arr))
- = Ranked (mcast (ssxFromShape (convSh (mshape arr))) arr)
- where
- convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing)
- convSh ZSX = ZSX
- convSh (smn :$% (sh :: IShX sh'T))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T)
- = SUnknown (fromSMayNat' smn) :$% convSh sh
+mtoRanked = castCastable (CastXR CastId)
rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
rtoMixed (Ranked arr) = arr
@@ -52,9 +54,7 @@ rcastToMixed sshx rarr@(Ranked arr)
mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
=> Mixed sh a -> ShS sh' -> Shaped sh' a
-mcastToShaped arr targetsh
- | Refl <- lemRankMapJust targetsh
- = Shaped (mcast (ssxFromShape (shCvtSX targetsh)) arr)
+mcastToShaped arr targetsh = castCastable (CastXS' targetsh CastId) arr
stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
stoMixed (Shaped arr) = arr