aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Nested.hs15
-rw-r--r--src/Data/Array/Nested/Convert.hs108
-rw-r--r--src/Data/Array/Nested/Mixed.hs25
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs6
-rw-r--r--src/Data/Array/Nested/Permutation.hs1
-rw-r--r--src/Data/Array/Nested/Ranked.hs34
-rw-r--r--src/Data/Array/Nested/Shaped.hs23
-rw-r--r--src/Data/Array/Nested/Trace.hs8
-rw-r--r--src/Data/Array/Nested/Types.hs3
9 files changed, 135 insertions, 88 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 1ad2559..bb22d29 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -10,8 +10,9 @@ module Data.Array.Nested (
rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar,
remptyArray,
rrerank,
- rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
- rfromListLinear, rfromListPrimLinear, rtoListLinear,
+ rreplicate, rreplicateScal,
+ rfromList1, rfromListOuter, rfromListLinear, rfromListPrim, rfromListPrimLinear,
+ rtoList, rtoListOuter, rtoListLinear,
rslice, rrev1, rreshape, rflatten, riota,
rminIndexPrim, rmaxIndexPrim, rdot1Inner, rdot,
rnest, runNest, rzip, runzip,
@@ -36,8 +37,9 @@ module Data.Array.Nested (
-- TODO: sconcat? What should its type be?
semptyArray,
srerank,
- sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,
- sfromListLinear, sfromListPrimLinear, stoListLinear,
+ sreplicate, sreplicateScal,
+ sfromList1, sfromListOuter, sfromListLinear, sfromListPrim, sfromListPrimLinear,
+ stoList, stoListOuter, stoListLinear,
sslice, srev1, sreshape, sflatten, siota,
sminIndexPrim, smaxIndexPrim, sdot1Inner, sdot,
snest, sunNest, szip, sunzip,
@@ -63,8 +65,9 @@ module Data.Array.Nested (
mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar,
memptyArray,
mrerank,
- mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,
- mfromListLinear, mfromListPrimLinear, mtoListLinear,
+ mreplicate, mreplicateScal,
+ mfromList1, mfromListOuter, mfromListLinear, mfromListPrim, mfromListPrimLinear,
+ mtoList, mtoListOuter, mtoListLinear,
mslice, mrev1, mreshape, mflatten, miota,
mminIndexPrim, mmaxIndexPrim, mdot1Inner, mdot,
mnest, munNest, mzip, munzip,
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index b3a2c63..d07bab9 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -3,6 +3,7 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
@@ -104,25 +105,53 @@ shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh
-- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure
-- that the shapes are already compatible. To convert between 'Ranked' and
-- 'Shaped', go via 'Mixed'.
+--
+-- The guiding principle behind 'Castable' is that it should represent the
+-- array restructurings, or perhaps re-presentations, that do not change the
+-- underlying 'XArray's. This leads to the inclusion of some operations that do
+-- not look like a cast at first glance, like 'CastZip'; with the underlying
+-- representation in mind, however, they are very much like a cast.
data Castable a b where
CastId :: Castable a a
CastCmp :: Castable b c -> Castable a b -> Castable a c
- CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b)
- CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b)
-
- CastXR :: Elt b
- => Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b)
- CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)
- CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh'
- -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b)
-
- CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b)
- CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b)
- CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b)
-
- CastXX' :: (Rank sh ~ Rank sh', Elt b) => StaticShX sh'
- -> Castable a b -> Castable (Mixed sh a) (Mixed sh' b)
+ CastRX :: Castable (Ranked n a) (Mixed (Replicate n Nothing) a)
+ CastSX :: Castable (Shaped sh a) (Mixed (MapJust sh) a)
+
+ CastXR :: Elt a
+ => Castable (Mixed sh a) (Ranked (Rank sh) a)
+ CastXS :: Castable (Mixed (MapJust sh) a) (Shaped sh a)
+ CastXS' :: (Rank sh ~ Rank sh', Elt a)
+ => ShS sh'
+ -> Castable (Mixed sh a) (Shaped sh' a)
+
+ CastXX' :: (Rank sh ~ Rank sh', Elt a)
+ => StaticShX sh'
+ -> Castable (Mixed sh a) (Mixed sh' a)
+
+ CastRR :: Castable a b
+ -> Castable (Ranked n a) (Ranked n b)
+ CastSS :: Castable a b
+ -> Castable (Shaped sh a) (Shaped sh b)
+ CastXX :: Castable a b
+ -> Castable (Mixed sh a) (Mixed sh b)
+ CastT2 :: Castable a a'
+ -> Castable b b'
+ -> Castable (a, b) (a', b')
+
+ Cast0X :: Elt a
+ => Castable a (Mixed '[] a)
+ CastX0 :: Castable (Mixed '[] a) a
+
+ CastNest :: Elt a => StaticShX sh
+ -> Castable (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a))
+ CastUnnest :: Castable (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a)
+
+ CastZip :: (Elt a, Elt b)
+ => Castable (Mixed sh a, Mixed sh b) (Mixed sh (a, b))
+ CastUnzip :: (Elt a, Elt b)
+ => Castable (Mixed sh (a, b)) (Mixed sh a, Mixed sh b)
+deriving instance Show (Castable a b)
instance Category Castable where
id = CastId
@@ -139,25 +168,44 @@ castCastable = \c x -> munScalar (go c (mscalar x))
go :: Castable a b -> Mixed esh a -> Mixed esh b
go CastId x = x
go (CastCmp c1 c2) x = go c1 (go c2 x)
- go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x)
- go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x)
- go (CastXR @_ @_ @sh c) (M_Nest @esh esh x)
+ go CastRX (M_Ranked x) = x
+ go CastSX (M_Shaped x) = x
+ go (CastXR @_ @sh) (M_Nest @esh esh x)
| Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)
- = let x' = go c x
- ssx' = ssxAppend (ssxFromShX esh)
- (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShX esh))))
- in M_Ranked (M_Nest esh (mcast ssx' x'))
- go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x))
- go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x)
+ = let ssx' = ssxAppend (ssxFromShX esh)
+ (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x) (ssxFromShX esh))))
+ in M_Ranked (M_Nest esh (mcast ssx' x))
+ go CastXS (M_Nest esh x) = M_Shaped (M_Nest esh x)
+ go (CastXS' @sh @sh' sh') (M_Nest @esh esh x)
| Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
= M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh')))
- (go c x)))
+ x))
+ go (CastXX' @sh @sh' ssx) (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x
go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
- go (CastXX' @sh @sh' ssx c) (M_Nest @esh esh x)
- | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')
- = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) (go c x)
+ go (CastT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2)
+ go Cast0X (x :: Mixed esh a)
+ | Refl <- lemAppNil @esh
+ = M_Nest (mshape x) x
+ go CastX0 (M_Nest @esh _ x)
+ | Refl <- lemAppNil @esh
+ = x
+ go (CastNest @_ @sh @sh' ssh) (M_Nest @esh esh x)
+ | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (mshape x) (ssxFromShX esh `ssxAppend` ssh)) x)
+ go (CastUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x))
+ | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh x
+ go CastZip x =
+ -- no need to check that the two esh's are equal because they were zipped previously
+ let (M_Nest esh x1, M_Nest _ x2) = munzip x
+ in M_Nest esh (mzip x1 x2)
+ go CastUnzip (M_Nest esh x) =
+ let (x1, x2) = munzip x
+ in mzip (M_Nest esh x1) (M_Nest esh x2)
lemRankAppRankEq :: Rank sh ~ Rank sh'
=> Proxy esh -> Proxy sh -> Proxy sh'
@@ -184,7 +232,7 @@ mcast ssh2 arr
= mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr
mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
-mtoRanked = castCastable (CastXR CastId)
+mtoRanked = castCastable CastXR
rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
rtoMixed (Ranked arr) = arr
@@ -198,7 +246,7 @@ rcastToMixed sshx rarr@(Ranked arr)
mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
=> ShS sh' -> Mixed sh a -> Shaped sh' a
-mcastToShaped targetsh = castCastable (CastXS' targetsh CastId)
+mcastToShaped targetsh = castCastable (CastXS' targetsh)
stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
stoMixed (Shaped arr) = arr
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 221393f..54f8fe6 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -784,14 +784,10 @@ mtoVector arr = mtoVectorP (toPrimitive arr)
mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a
mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise?
-mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a
-mfromList1Prim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
-
-mtoList1 :: Elt a => Mixed '[n] a -> [a]
-mtoList1 = map munScalar . mtoListOuter
+-- This forall is there so that a simple type application can constrain the
+-- shape, in case the user wants to use OverloadedLists for the shape.
+mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a
+mfromListLinear sh l = mreshape sh (mfromList1 l)
mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a
mfromListPrim l =
@@ -804,10 +800,8 @@ mfromListPrimLinear sh l =
let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)
--- This forall is there so that a simple type application can constrain the
--- shape, in case the user wants to use OverloadedLists for the shape.
-mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a
-mfromListLinear sh l = mreshape sh (mfromList1 l)
+mtoList :: Elt a => Mixed '[n] a -> [a]
+mtoList = map munScalar . mtoListOuter
mtoListLinear :: Elt a => Mixed sh a -> [a]
mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr)) -- TODO: optimise
@@ -821,8 +815,11 @@ mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr
munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a
munNest (M_Nest _ arr) = arr
-mzip :: Mixed sh a -> Mixed sh b -> Mixed sh (a, b)
-mzip = M_Tup2
+-- | The arguments must have equal shapes. If they do not, an error is raised.
+mzip :: (Elt a, Elt b) => Mixed sh a -> Mixed sh b -> Mixed sh (a, b)
+mzip a b
+ | Just Refl <- shxEqual (mshape a) (mshape b) = M_Tup2 a b
+ | otherwise = error "mzip: unequal shapes"
munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b)
munzip (M_Tup2 a b) = (a, b)
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index 2f35ff9..bf14bf5 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -537,9 +537,15 @@ ssxHead (StaticShX list) = listxHead list
ssxTail :: StaticShX (n : sh) -> StaticShX sh
ssxTail (_ :!% ssh) = ssh
+ssxDropSSX :: forall sh sh'. StaticShX (sh ++ sh') -> StaticShX sh -> StaticShX sh'
+ssxDropSSX = coerce (listxDrop @(SMayNat () SNat) @(SMayNat () SNat))
+
ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'
ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))
+ssxDropSh :: forall sh sh' i. StaticShX (sh ++ sh') -> ShX sh i -> StaticShX sh'
+ssxDropSh = coerce (listxDrop @(SMayNat () SNat) @(SMayNat i SNat))
+
ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh))
ssxInit = coerce (listxInit @(SMayNat () SNat))
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs
index 031755f..bed2877 100644
--- a/src/Data/Array/Nested/Permutation.hs
+++ b/src/Data/Array/Nested/Permutation.hs
@@ -4,7 +4,6 @@
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index e5b8970..97b4c7c 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -137,38 +137,32 @@ rtoVectorP = coerce mtoVectorP
rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
rtoVector = coerce mtoVector
-rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
-rfromListOuter l
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
-
rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
rfromList1 l = Ranked (mfromList1 l)
-rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a
-rfromList1Prim l = Ranked (mfromList1Prim l)
-
-rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
-rtoListOuter (Ranked arr)
+rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
+rfromListOuter l
| Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
+ = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
-rtoList1 :: Elt a => Ranked 1 a -> [a]
-rtoList1 = map runScalar . rtoListOuter
+rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
+rfromListLinear sh l = rreshape sh (rfromList1 l)
rfromListPrim :: PrimElt a => [a] -> Ranked 1 a
-rfromListPrim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+rfromListPrim l = Ranked (mfromListPrim l)
rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a
rfromListPrimLinear sh l =
let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
in Ranked $ fromPrimitive $ M_Primitive (shxFromShR sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShR sh) xarr)
-rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
-rfromListLinear sh l = rreshape sh (rfromList1 l)
+rtoList :: Elt a => Ranked 1 a -> [a]
+rtoList = map runScalar . rtoListOuter
+
+rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
+rtoListOuter (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
rtoListLinear :: Elt a => Ranked n a -> [a]
rtoListLinear (Ranked arr) = mtoListLinear arr
@@ -197,7 +191,7 @@ runNest rarr@(Ranked (M_Ranked (M_Nest _ arr)))
| Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat))
= Ranked arr
-rzip :: Ranked n a -> Ranked n b -> Ranked n (a, b)
+rzip :: (Elt a, Elt b) => Ranked n a -> Ranked n b -> Ranked n (a, b)
rzip = coerce mzip
runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b)
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 01982a8..0275aad 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -123,20 +123,14 @@ stoVectorP = coerce mtoVectorP
stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
stoVector = coerce mtoVector
-sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
-sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l))
-
sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1
-sfromList1Prim :: PrimElt a => SNat n -> [a] -> Shaped '[n] a
-sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim
-
-stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
-stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
+sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
+sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l))
-stoList1 :: Elt a => Shaped '[n] a -> [a]
-stoList1 = map sunScalar . stoListOuter
+sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a
+sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l)
sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
sfromListPrim sn l
@@ -150,8 +144,11 @@ sfromListPrimLinear sh l =
let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
in Shaped $ fromPrimitive $ M_Primitive (shxFromShS sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShS sh) xarr)
-sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a
-sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l)
+stoList :: Elt a => Shaped '[n] a -> [a]
+stoList = map sunScalar . stoListOuter
+
+stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
+stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
stoListLinear :: Elt a => Shaped sh a -> [a]
stoListLinear (Shaped arr) = mtoListLinear arr
@@ -176,7 +173,7 @@ sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr)))
| Refl <- lemMapJustApp (sshape sarr) (Proxy @sh')
= Shaped arr
-szip :: Shaped sh a -> Shaped sh b -> Shaped sh (a, b)
+szip :: (Elt a, Elt b) => Shaped sh a -> Shaped sh b -> Shaped sh (a, b)
szip = coerce mzip
sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b)
diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs
index 838e2b0..3581f10 100644
--- a/src/Data/Array/Nested/Trace.hs
+++ b/src/Data/Array/Nested/Trace.hs
@@ -37,10 +37,12 @@ module Data.Array.Nested.Trace (
ShS(..), KnownShS(..),
Mixed,
+ ListX(ZX, (::%)),
IxX(..), IIxX,
- ShX(..), KnownShX(..),
+ ShX(..), KnownShX(..), IShX,
StaticShX(..),
SMayNat(..),
+ Castable(..),
Elt,
PrimElt,
@@ -54,7 +56,7 @@ module Data.Array.Nested.Trace (
Perm(..),
IsPermutation,
KnownPerm(..),
- NumElt, FloatElt,
+ NumElt, IntElt, FloatElt,
Rank, Product,
Replicate,
MapJust,
@@ -67,4 +69,4 @@ import Data.Array.Nested.Trace.TH
$(concat <$> mapM convertFun
- ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromListOuter, 'rfromList1, 'rfromList1Prim, 'rtoListOuter, 'rtoList1, 'rfromListLinear, 'rfromListPrimLinear, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rcastToShaped, 'rtoMixed, 'rfromOrthotope, 'rtoOrthotope, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromListOuter, 'sfromList1, 'sfromList1Prim, 'stoListOuter, 'stoList1, 'sfromListLinear, 'sfromListPrimLinear, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoRanked, 'stoMixed, 'sfromOrthotope, 'stoOrthotope, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromListOuter, 'mfromList1, 'mfromList1Prim, 'mtoListOuter, 'mtoList1, 'mfromListLinear, 'mfromListPrimLinear, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mtoRanked, 'mcastToShaped])
+ ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromList1, 'rfromListOuter, 'rfromListLinear, 'rfromListPrim, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromList1, 'sfromListOuter, 'sfromListLinear, 'sfromListPrim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromList1, 'mfromListOuter, 'mfromListLinear, 'mfromListPrim, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'castCastable, 'mquotArray, 'mremArray, 'matan2Array])
diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs
index 4172fa0..b8a9aea 100644
--- a/src/Data/Array/Nested/Types.hs
+++ b/src/Data/Array/Nested/Types.hs
@@ -7,6 +7,7 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
@@ -111,7 +112,7 @@ type family Replicate n a where
lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
lemReplicateSucc = unsafeCoerceRefl
-type family MapJust l where
+type family MapJust l = r | r -> l where
MapJust '[] = '[]
MapJust (x : xs) = Just x : MapJust xs