aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Nested.hs37
-rw-r--r--src/Data/Array/Nested/Convert.hs361
-rw-r--r--src/Data/Array/Nested/Internal/Lemmas.hs59
-rw-r--r--src/Data/Array/Nested/Lemmas.hs (renamed from src/Data/Array/Mixed/Lemmas.hs)100
-rw-r--r--src/Data/Array/Nested/Mixed.hs53
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs88
-rw-r--r--src/Data/Array/Nested/Permutation.hs22
-rw-r--r--src/Data/Array/Nested/Ranked.hs46
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs20
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs57
-rw-r--r--src/Data/Array/Nested/Shaped.hs37
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs17
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs61
-rw-r--r--src/Data/Array/Nested/Trace.hs8
-rw-r--r--src/Data/Array/Nested/Types.hs11
-rw-r--r--src/Data/Array/XArray.hs6
16 files changed, 588 insertions, 395 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 9801529..c3635e9 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -10,8 +10,9 @@ module Data.Array.Nested (
rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar,
remptyArray,
rrerank,
- rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
- rfromListLinear, rfromListPrimLinear, rtoListLinear,
+ rreplicate, rreplicateScal,
+ rfromList1, rfromListOuter, rfromListLinear, rfromListPrim, rfromListPrimLinear,
+ rtoList, rtoListOuter, rtoListLinear,
rslice, rrev1, rreshape, rflatten, riota,
rminIndexPrim, rmaxIndexPrim, rdot1Inner, rdot,
rnest, runNest, rzip, runzip,
@@ -19,7 +20,7 @@ module Data.Array.Nested (
rlift, rlift2,
-- ** Conversions
rtoXArrayPrim, rfromXArrayPrim,
- rcastToShaped, rtoMixed, rcastToMixed,
+ rtoMixed, rcastToMixed, rcastToShaped,
rfromOrthotope, rtoOrthotope,
-- ** Additional arithmetic operations
--
@@ -36,8 +37,9 @@ module Data.Array.Nested (
-- TODO: sconcat? What should its type be?
semptyArray,
srerank,
- sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,
- sfromListLinear, sfromListPrimLinear, stoListLinear,
+ sreplicate, sreplicateScal,
+ sfromList1, sfromListOuter, sfromListLinear, sfromListPrim, sfromListPrimLinear,
+ stoList, stoListOuter, stoListLinear,
sslice, srev1, sreshape, sflatten, siota,
sminIndexPrim, smaxIndexPrim, sdot1Inner, sdot,
snest, sunNest, szip, sunzip,
@@ -45,7 +47,7 @@ module Data.Array.Nested (
slift, slift2,
-- ** Conversions
stoXArrayPrim, sfromXArrayPrim,
- stoRanked, stoMixed, scastToMixed,
+ stoMixed, scastToMixed, stoRanked,
sfromOrthotope, stoOrthotope,
-- ** Additional arithmetic operations
--
@@ -63,8 +65,9 @@ module Data.Array.Nested (
mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar,
memptyArray,
mrerank,
- mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,
- mfromListLinear, mfromListPrimLinear, mtoListLinear,
+ mreplicate, mreplicateScal,
+ mfromList1, mfromListOuter, mfromListLinear, mfromListPrim, mfromListPrimLinear,
+ mtoList, mtoListOuter, mtoListLinear,
mslice, mrev1, mreshape, mflatten, miota,
mminIndexPrim, mmaxIndexPrim, mdot1Inner, mdot,
mnest, munNest, mzip, munzip,
@@ -73,8 +76,8 @@ module Data.Array.Nested (
-- ** Conversions
mtoXArrayPrim, mfromXArrayPrim,
mcast,
- mtoRanked, mcastToShaped,
- castCastable, Castable(..),
+ mcastToShaped, mtoRanked,
+ convert, Conversion(..),
-- ** Additional arithmetic operations
--
-- $integralRealFloat
@@ -102,23 +105,23 @@ module Data.Array.Nested (
import Prelude hiding (mappend, mconcat)
-import Data.Array.Nested.Permutation
-import Data.Array.Nested.Types
import Data.Array.Nested.Convert
import Data.Array.Nested.Mixed
-import Data.Array.Nested.Ranked
-import Data.Array.Nested.Shaped
import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Ranked
import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Shaped
import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
import Data.Array.Strided.Arith
import Foreign.Storable
import GHC.TypeLits
-- $integralRealFloat
--
--- These functions separate top-level functions, and not exposed in instances
--- for 'RealFloat' and 'Integral', because those classes include a variety of
--- other functions that make no sense for arrays.
+-- These functions are separate top-level functions, and not exposed in
+-- instances for 'RealFloat' and 'Integral', because those classes include a
+-- variety of other functions that make no sense for arrays.
-- This problem already occurs with 'fromInteger', 'fromRational' and 'pi', but
-- having 'Num', 'Fractional' and 'Floating' available is just too useful.
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index d5e6008..2438f68 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -1,42 +1,291 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module Data.Array.Nested.Convert (
- castCastable,
- Castable(..),
+ -- * Shape\/index\/list casting functions
+ -- ** To ranked
+ ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShX, shrFromShX2,
+ listrCast, ixrCast, shrCast,
+ -- ** To shaped
+ ixsFromIxR, ixsFromIxR', ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX,
+ ixsCast,
+ -- ** To mixed
+ ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS,
+ ixxCast, shxCast, shxCast',
- -- * Special cases
+ -- * Array conversions
+ convert,
+ Conversion(..),
+
+ -- * Special cases of array conversions
--
- -- | These functions can all be implemented using 'castCastable' in some way,
+ -- | These functions can all be implemented using 'convert' in some way,
-- but some have fewer constraints.
rtoMixed, rcastToMixed, rcastToShaped,
stoMixed, scastToMixed, stoRanked,
mcast, mcastToShaped, mtoRanked,
-
- -- * Additional index/shape casting functions
- ixrFromIxS, shrFromShS,
) where
import Control.Category
import Data.Proxy
import Data.Type.Equality
+import GHC.TypeLits
-import Data.Array.Mixed.Lemmas
-import Data.Array.Nested.Types
-import Data.Array.Nested.Internal.Lemmas
+import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked.Base
import Data.Array.Nested.Ranked.Shape
import Data.Array.Nested.Shaped.Base
import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
+
+-- * Shape or index or list casting functions
+
+-- * To ranked
+
+ixrFromIxS :: IxS sh i -> IxR (Rank sh) i
+ixrFromIxS ZIS = ZIR
+ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix
+
+ixrFromIxX :: IxX sh i -> IxR (Rank sh) i
+ixrFromIxX ZIX = ZIR
+ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx
+
+shrFromShS :: ShS sh -> IShR (Rank sh)
+shrFromShS ZSS = ZSR
+shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh
+
+-- shrFromShX re-exported
+-- shrFromShX2 re-exported
+-- listrCast re-exported
+-- ixrCast re-exported
+-- shrCast re-exported
+
+-- * To shaped
+
+-- TODO: these take a ShS because there are KnownNats inside IxS.
+
+ixsFromIxR :: ShS sh -> IxR (Rank sh) i -> IxS sh i
+ixsFromIxR ZSS ZIR = ZIS
+ixsFromIxR (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR sh idx
+ixsFromIxR _ _ = error "unreachable"
+
+-- | Performs a runtime check that @n@ matches @Rank sh@. Equivalent to the
+-- following, but more efficient:
+--
+-- > ixsFromIxR' sh idx = ixsFromIxR sh (ixrCast (shsRank sh) idx)
+ixsFromIxR' :: ShS sh -> IxR n i -> IxS sh i
+ixsFromIxR' ZSS ZIR = ZIS
+ixsFromIxR' (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR' sh idx
+ixsFromIxR' _ _ = error "ixsFromIxR': index rank does not match shape rank"
+
+-- TODO: this takes a ShS because there are KnownNats inside IxS.
+ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i
+ixsFromIxX ZSS ZIX = ZIS
+ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx
+
+-- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to
+-- the following, but more efficient:
+--
+-- > ixsFromIxX' sh idx = ixsFromIxX sh (ixxCast (shxFromShS sh) idx)
+ixsFromIxX' :: ShS sh -> IxX sh' i -> IxS sh i
+ixsFromIxX' ZSS ZIX = ZIS
+ixsFromIxX' (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX' sh idx
+ixsFromIxX' _ _ = error "ixsFromIxX': index rank does not match shape rank"
+
+-- | Produce an existential 'ShS' from an 'IShR'.
+withShsFromShR :: IShR n -> (forall sh. Rank sh ~ n => ShS sh -> r) -> r
+withShsFromShR ZSR k = k ZSS
+withShsFromShR (n :$: sh) k =
+ withShsFromShR sh $ \sh' ->
+ withSomeSNat (fromIntegral @Int @Integer n) $ \case
+ Just sn@SNat -> k (sn :$$ sh')
+ Nothing -> error $ "withShsFromShR: negative dimension size (" ++ show n ++ ")"
+
+-- shsFromShX re-exported
+
+-- | Produce an existential 'ShS' from an 'IShX'. If you already know that
+-- @sh'@ is @MapJust@ of something, use 'shsFromShX' instead.
+withShsFromShX :: IShX sh' -> (forall sh. Rank sh ~ Rank sh' => ShS sh -> r) -> r
+withShsFromShX ZSX k = k ZSS
+withShsFromShX (SKnown sn@SNat :$% sh) k =
+ withShsFromShX sh $ \sh' ->
+ k (sn :$$ sh')
+withShsFromShX (SUnknown n :$% sh) k =
+ withShsFromShX sh $ \sh' ->
+ withSomeSNat (fromIntegral @Int @Integer n) $ \case
+ Just sn@SNat -> k (sn :$$ sh')
+ Nothing -> error $ "withShsFromShX: negative SUnknown dimension size (" ++ show n ++ ")"
+
+shsFromSSX :: StaticShX (MapJust sh) -> ShS sh
+shsFromSSX = shsFromShX Prelude.. shxFromSSX
+
+-- ixsCast re-exported
+
+-- * To mixed
+
+ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i
+ixxFromIxR ZIR = ZIX
+ixxFromIxR (n :.: (idx :: IxR m i)) =
+ castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m))
+ (n :.% ixxFromIxR idx)
+
+ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i
+ixxFromIxS ZIS = ZIX
+ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh
+
+shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i
+shxFromShR ZSR = ZSX
+shxFromShR (n :$: (idx :: ShR m i)) =
+ castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m))
+ (SUnknown n :$% shxFromShR idx)
+
+shxFromShS :: ShS sh -> IShX (MapJust sh)
+shxFromShS ZSS = ZSX
+shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh
+
+-- ixxCast re-exported
+-- shxCast re-exported
+-- shxCast' re-exported
+
+
+-- * Array conversions
+
+-- | The constructors that perform runtime shape checking are marked with a
+-- tick (@'@): 'ConvXS'' and 'ConvXX''. For the other constructors, the types
+-- ensure that the shapes are already compatible. To convert between 'Ranked'
+-- and 'Shaped', go via 'Mixed'.
+--
+-- The guiding principle behind 'Conversion' is that it should represent the
+-- array restructurings, or perhaps re-presentations, that do not change the
+-- underlying 'XArray's. This leads to the inclusion of some operations that do
+-- not look like simple conversions (casts) at first glance, like 'ConvZip'.
+--
+-- /Note/: Haddock gleefully renames type variables in constructors so that
+-- they match the data type head as much as possible. See the source for a more
+-- readable presentation of this data type.
+data Conversion a b where
+ ConvId :: Conversion a a
+ ConvCmp :: Conversion b c -> Conversion a b -> Conversion a c
+
+ ConvRX :: Conversion (Ranked n a) (Mixed (Replicate n Nothing) a)
+ ConvSX :: Conversion (Shaped sh a) (Mixed (MapJust sh) a)
+
+ ConvXR :: Elt a
+ => Conversion (Mixed sh a) (Ranked (Rank sh) a)
+ ConvXS :: Conversion (Mixed (MapJust sh) a) (Shaped sh a)
+ ConvXS' :: (Rank sh ~ Rank sh', Elt a)
+ => ShS sh'
+ -> Conversion (Mixed sh a) (Shaped sh' a)
+
+ ConvXX' :: (Rank sh ~ Rank sh', Elt a)
+ => StaticShX sh'
+ -> Conversion (Mixed sh a) (Mixed sh' a)
+
+ ConvRR :: Conversion a b
+ -> Conversion (Ranked n a) (Ranked n b)
+ ConvSS :: Conversion a b
+ -> Conversion (Shaped sh a) (Shaped sh b)
+ ConvXX :: Conversion a b
+ -> Conversion (Mixed sh a) (Mixed sh b)
+ ConvT2 :: Conversion a a'
+ -> Conversion b b'
+ -> Conversion (a, b) (a', b')
+
+ Conv0X :: Elt a
+ => Conversion a (Mixed '[] a)
+ ConvX0 :: Conversion (Mixed '[] a) a
+
+ ConvNest :: Elt a => StaticShX sh
+ -> Conversion (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a))
+ ConvUnnest :: Conversion (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a)
+
+ ConvZip :: (Elt a, Elt b)
+ => Conversion (Mixed sh a, Mixed sh b) (Mixed sh (a, b))
+ ConvUnzip :: (Elt a, Elt b)
+ => Conversion (Mixed sh (a, b)) (Mixed sh a, Mixed sh b)
+deriving instance Show (Conversion a b)
+
+instance Category Conversion where
+ id = ConvId
+ (.) = ConvCmp
+
+convert :: (Elt a, Elt b) => Conversion a b -> a -> b
+convert = \c x -> munScalar (go c (mscalar x))
+ where
+ -- The 'esh' is the extension shape: the conversion happens under a whole
+ -- bunch of additional dimensions that it does not touch. These dimensions
+ -- are 'esh'.
+ -- The strategy is to unwind step-by-step to a large Mixed array, and to
+ -- perform the required checks and conversions when re-nesting back up.
+ go :: Conversion a b -> Mixed esh a -> Mixed esh b
+ go ConvId x = x
+ go (ConvCmp c1 c2) x = go c1 (go c2 x)
+ go ConvRX (M_Ranked x) = x
+ go ConvSX (M_Shaped x) = x
+ go (ConvXR @_ @sh) (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)
+ = let ssx' = ssxAppend (ssxFromShX esh)
+ (ssxReplicate (shxRank (shxDropSSX @esh @sh (ssxFromShX esh) (mshape x))))
+ in M_Ranked (M_Nest esh (mcast ssx' x))
+ go ConvXS (M_Nest esh x) = M_Shaped (M_Nest esh x)
+ go (ConvXS' @sh @sh' sh') (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh')))
+ x))
+ go (ConvXX' @sh @sh' ssx) (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x
+ go (ConvRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
+ go (ConvSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
+ go (ConvXX c) (M_Nest esh x) = M_Nest esh (go c x)
+ go (ConvT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2)
+ go Conv0X (x :: Mixed esh a)
+ | Refl <- lemAppNil @esh
+ = M_Nest (mshape x) x
+ go ConvX0 (M_Nest @esh _ x)
+ | Refl <- lemAppNil @esh
+ = x
+ go (ConvNest @_ @sh @sh' ssh) (M_Nest @esh esh x)
+ | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (ssxFromShX esh `ssxAppend` ssh) (mshape x)) x)
+ go (ConvUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x))
+ | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh x
+ go ConvZip x =
+ -- no need to check that the two esh's are equal because they were zipped previously
+ let (M_Nest esh x1, M_Nest _ x2) = munzip x
+ in M_Nest esh (mzip x1 x2)
+ go ConvUnzip (M_Nest esh x) =
+ let (x1, x2) = munzip x
+ in mzip (M_Nest esh x1) (M_Nest esh x2)
+
+ lemRankAppRankEq :: Rank sh ~ Rank sh'
+ => Proxy esh -> Proxy sh -> Proxy sh'
+ -> Rank (esh ++ sh) :~: Rank (esh ++ sh')
+ lemRankAppRankEq _ _ _ = unsafeCoerceRefl
+
+ lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh
+ -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing)
+ lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl
+
+ lemRankAppRankEqMapJust :: Rank sh ~ Rank sh'
+ => Proxy esh -> Proxy sh -> Proxy sh'
+ -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
+ lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl
+-- * Special cases of array conversions
+
mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)
=> StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a
mcast ssh2 arr
@@ -45,7 +294,7 @@ mcast ssh2 arr
= mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr
mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
-mtoRanked = castCastable (CastXR CastId)
+mtoRanked = convert ConvXR
rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
rtoMixed (Ranked arr) = arr
@@ -59,7 +308,7 @@ rcastToMixed sshx rarr@(Ranked arr)
mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
=> ShS sh' -> Mixed sh a -> Shaped sh' a
-mcastToShaped targetsh = castCastable (CastXS' targetsh CastId)
+mcastToShaped targetsh = convert (ConvXS' targetsh)
stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
stoMixed (Shaped arr) = arr
@@ -82,91 +331,3 @@ rcastToShaped (Ranked arr) targetsh
| Refl <- lemRankReplicate (shxRank (shxFromShS targetsh))
, Refl <- lemRankMapJust targetsh
= mcastToShaped targetsh arr
-
-ixrFromIxS :: IxS sh i -> IxR (Rank sh) i
-ixrFromIxS ZIS = ZIR
-ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix
-
--- ixsFromIxR :: IIxR (Rank sh) -> IIxS sh
--- ixsFromIxR = \ix -> go ix _
--- where
--- go :: IIxR n -> (forall sh. KnownShS sh => IIxS sh -> r) -> r
--- go ZIR k = k ZIS
--- go (i :.: ix) k = go ix (i :.$)
-
-shrFromShS :: ShS sh -> IShR (Rank sh)
-shrFromShS ZSS = ZSR
-shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh
-
--- | The constructors that perform runtime shape checking are marked with a
--- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure
--- that the shapes are already compatible. To convert between 'Ranked' and
--- 'Shaped', go via 'Mixed'.
-data Castable a b where
- CastId :: Castable a a
- CastCmp :: Castable b c -> Castable a b -> Castable a c
-
- CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b)
- CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b)
-
- CastXR :: Elt b
- => Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b)
- CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)
- CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh'
- -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b)
-
- CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b)
- CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b)
- CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b)
-
- CastXX' :: (Rank sh ~ Rank sh', Elt b) => StaticShX sh'
- -> Castable a b -> Castable (Mixed sh a) (Mixed sh' b)
-
-instance Category Castable where
- id = CastId
- (.) = CastCmp
-
-castCastable :: (Elt a, Elt b) => Castable a b -> a -> b
-castCastable = \c x -> munScalar (go c (mscalar x))
- where
- -- The 'esh' is the extension shape: the casting happens under a whole
- -- bunch of additional dimensions that it does not touch. These dimensions
- -- are 'esh'.
- -- The strategy is to unwind step-by-step to a large Mixed array, and to
- -- perform the required checks and castings when re-nesting back up.
- go :: Castable a b -> Mixed esh a -> Mixed esh b
- go CastId x = x
- go (CastCmp c1 c2) x = go c1 (go c2 x)
- go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x)
- go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x)
- go (CastXR @_ @_ @sh c) (M_Nest @esh esh x)
- | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)
- = let x' = go c x
- ssx' = ssxAppend (ssxFromShX esh)
- (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShX esh))))
- in M_Ranked (M_Nest esh (mcast ssx' x'))
- go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x))
- go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x)
- | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
- = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh')))
- (go c x)))
- go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
- go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
- go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
- go (CastXX' @sh @sh' ssx c) (M_Nest @esh esh x)
- | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')
- = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) (go c x)
-
- lemRankAppRankEq :: Rank sh ~ Rank sh'
- => Proxy esh -> Proxy sh -> Proxy sh'
- -> Rank (esh ++ sh) :~: Rank (esh ++ sh')
- lemRankAppRankEq _ _ _ = unsafeCoerceRefl
-
- lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh
- -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing)
- lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl
-
- lemRankAppRankEqMapJust :: Rank sh ~ Rank sh'
- => Proxy esh -> Proxy sh -> Proxy sh'
- -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
- lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl
diff --git a/src/Data/Array/Nested/Internal/Lemmas.hs b/src/Data/Array/Nested/Internal/Lemmas.hs
deleted file mode 100644
index b1589e0..0000000
--- a/src/Data/Array/Nested/Internal/Lemmas.hs
+++ /dev/null
@@ -1,59 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeOperators #-}
-module Data.Array.Nested.Internal.Lemmas where
-
-import Data.Proxy
-import Data.Type.Equality
-import GHC.TypeLits
-
-import Data.Array.Mixed.Lemmas
-import Data.Array.Nested.Permutation
-import Data.Array.Nested.Types
-import Data.Array.Nested.Mixed.Shape
-import Data.Array.Nested.Shaped.Shape
-
-
-lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh
-lemRankMapJust ZSS = Refl
-lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl
-
-lemMapJustApp :: ShS sh1 -> Proxy sh2
- -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
-lemMapJustApp ZSS _ = Refl
-lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl
-
-lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)
-lemTakeLenMapJust PNil _ = Refl
-lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl
-lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty"
-
-lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh)
-lemDropLenMapJust PNil _ = Refl
-lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl
-lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty"
-
-lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh)
-lemIndexMapJust SZ (_ :$$ _) = Refl
-lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh'))
- | Refl <- lemIndexMapJust i sh
- , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh'))
- , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = Refl
-lemIndexMapJust _ ZSS = error "Index of empty"
-
-lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh)
-lemPermuteMapJust PNil _ = Refl
-lemPermuteMapJust (i `PCons` is) sh
- | Refl <- lemPermuteMapJust is sh
- , Refl <- lemIndexMapJust i sh
- = Refl
-
-lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh)
-lemKnownMapJust _ = lemKnownShX (go (knownShS @sh))
- where
- go :: ShS sh' -> StaticShX (MapJust sh')
- go ZSS = ZKX
- go (n :$$ sh) = SKnown n :!% go sh
diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs
index e6d970c..8cac298 100644
--- a/src/Data/Array/Mixed/Lemmas.hs
+++ b/src/Data/Array/Nested/Lemmas.hs
@@ -6,7 +6,7 @@
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed.Lemmas where
+module Data.Array.Nested.Lemmas where
import Data.Proxy
import Data.Type.Equality
@@ -14,10 +14,11 @@ import GHC.TypeLits
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation
+import Data.Array.Nested.Shaped.Shape
import Data.Array.Nested.Types
--- * Lemmas
+-- * Lemmas about numbers and lists
-- ** Nat
@@ -27,7 +28,6 @@ lemLeqSuccSucc _ _ = unsafeCoerceRefl
lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True
lemLeqPlus _ _ _ = Refl
-
-- ** Append
lemAppNil :: l ++ '[] :~: l
@@ -39,31 +39,7 @@ lemAppAssoc _ _ _ = unsafeCoerceRefl
lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l
lemAppLeft _ Refl = Refl
-
--- ** Rank
-
-lemRankApp :: forall sh1 sh2.
- StaticShX sh1 -> StaticShX sh2
- -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
-lemRankApp ZKX _ = Refl
-lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2
- = lem (Proxy @(Rank sh1T)) Proxy Proxy $
- sym (lemRankApp ssh1 ssh2)
- where
- lem :: proxy a -> proxy b -> proxy c
- -> (a + b :~: c)
- -> c + 1 :~: (a + 1 + b)
- lem _ _ _ Refl = Refl
-
-lemRankAppComm :: proxy sh1 -> proxy sh2
- -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
-lemRankAppComm _ _ = unsafeCoerceRefl
-
-lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n
-lemRankReplicate _ = unsafeCoerceRefl
-
-
--- ** Various type families
+-- ** Simple type families
lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a
-> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
@@ -107,6 +83,8 @@ lemKnownNatRankSSX ZKX = Dict
lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
+-- * Lemmas about shapes
+
-- ** Known shapes
lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing)
@@ -116,3 +94,69 @@ lemKnownShX :: StaticShX sh -> Dict KnownShX sh
lemKnownShX ZKX = Dict
lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict
lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict
+
+lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh)
+lemKnownMapJust _ = lemKnownShX (go (knownShS @sh))
+ where
+ go :: ShS sh' -> StaticShX (MapJust sh')
+ go ZSS = ZKX
+ go (n :$$ sh) = SKnown n :!% go sh
+
+-- ** Rank
+
+lemRankApp :: forall sh1 sh2.
+ StaticShX sh1 -> StaticShX sh2
+ -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
+lemRankApp ZKX _ = Refl
+lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2
+ = lem (Proxy @(Rank sh1T)) Proxy Proxy $
+ sym (lemRankApp ssh1 ssh2)
+ where
+ lem :: proxy a -> proxy b -> proxy c
+ -> (a + b :~: c)
+ -> c + 1 :~: (a + 1 + b)
+ lem _ _ _ Refl = Refl
+
+lemRankAppComm :: proxy sh1 -> proxy sh2
+ -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
+lemRankAppComm _ _ = unsafeCoerceRefl
+
+lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n
+lemRankReplicate _ = unsafeCoerceRefl
+
+lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh
+lemRankMapJust ZSS = Refl
+lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl
+
+-- ** Related to MapJust and/or Permutation
+
+lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)
+lemTakeLenMapJust PNil _ = Refl
+lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl
+lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty"
+
+lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh)
+lemDropLenMapJust PNil _ = Refl
+lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl
+lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty"
+
+lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh)
+lemIndexMapJust SZ (_ :$$ _) = Refl
+lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh'))
+ | Refl <- lemIndexMapJust i sh
+ , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh'))
+ , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = Refl
+lemIndexMapJust _ ZSS = error "Index of empty"
+
+lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh)
+lemPermuteMapJust PNil _ = Refl
+lemPermuteMapJust (i `PCons` is) sh
+ | Refl <- lemPermuteMapJust is sh
+ , Refl <- lemIndexMapJust i sh
+ = Refl
+
+lemMapJustApp :: ShS sh1 -> Proxy sh2
+ -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
+lemMapJustApp ZSS _ = Refl
+lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 54bd5f2..144230e 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -42,13 +42,13 @@ import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import GHC.TypeLits
-import Data.Array.Mixed.Lemmas
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
+import Data.Array.Strided.Orthotope
import Data.Array.XArray (XArray(..))
import Data.Array.XArray qualified as X
-import Data.Array.Nested.Mixed.Shape
-import Data.Array.Strided.Orthotope
import Data.Bag
@@ -380,9 +380,7 @@ class Elt a where
-- of this class with those of 'Elt': some instances have an additional
-- "known-shape" constraint.
--
--- This class is (currently) only required for 'mgenerate',
--- 'Data.Array.Nested.Ranked.rgenerate' and
--- 'Data.Array.Nested.Shaped.sgenerate'.
+-- This class is (currently) only required for `memptyArray` and 'mgenerate'.
class Elt a => KnownElt a where
-- | Create an empty array. The given shape must have size zero; this may or may not be checked.
memptyArrayUnsafe :: IShX sh -> Mixed sh a
@@ -398,7 +396,7 @@ class Elt a => KnownElt a where
instance Storable a => Elt (Primitive a) where
mshape (M_Primitive sh _) = sh
mindex (M_Primitive _ a) i = Primitive (X.index a i)
- mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i)
+ mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i)
mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
mfromListOuter l@(arr1 :| _) =
let sh = SUnknown (length l) :$% mshape arr1
@@ -440,7 +438,7 @@ instance Storable a => Elt (Primitive a) where
=> StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) =
let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
- sh2 = shxCast' sh1 ssh2
+ sh2 = shxCast' ssh2 sh1
in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr)
mtranspose perm (M_Primitive sh arr) =
@@ -557,13 +555,13 @@ instance Elt a => Elt (Mixed sh' a) where
= fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr))
mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
- mindex (M_Nest _ arr) i = mindexPartial arr i
+ mindex (M_Nest _ arr) = mindexPartial arr
mindexPartial :: forall sh1 sh2.
Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
mindexPartial (M_Nest sh arr) i
| Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
+ = M_Nest (shxDropIx i sh) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
mscalar = M_Nest ZSX
@@ -632,14 +630,14 @@ instance Elt a => Elt (Mixed sh' a) where
| Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
, Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
= let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T
- sh2 = shxCast' sh1 ssh2
+ sh2 = shxCast' ssh2 sh1
in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr)
mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
=> Perm is -> Mixed sh (Mixed sh' a)
-> Mixed (PermutePrefix is sh) (Mixed sh' a)
mtranspose perm (M_Nest sh arr)
- | let sh' = shxDropSh @sh @sh' (mshape arr) sh
+ | let sh' = shxDropSh @sh @sh' sh (mshape arr)
, Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh')
, Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh'))
, Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
@@ -784,14 +782,10 @@ mtoVector arr = mtoVectorP (toPrimitive arr)
mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a
mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise?
-mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a
-mfromList1Prim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
-
-mtoList1 :: Elt a => Mixed '[n] a -> [a]
-mtoList1 = map munScalar . mtoListOuter
+-- This forall is there so that a simple type application can constrain the
+-- shape, in case the user wants to use OverloadedLists for the shape.
+mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a
+mfromListLinear sh l = mreshape sh (mfromList1 l)
mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a
mfromListPrim l =
@@ -804,10 +798,8 @@ mfromListPrimLinear sh l =
let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)
--- This forall is there so that a simple type application can constrain the
--- shape, in case the user wants to use OverloadedLists for the shape.
-mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a
-mfromListLinear sh l = mreshape sh (mfromList1 l)
+mtoList :: Elt a => Mixed '[n] a -> [a]
+mtoList = map munScalar . mtoListOuter
mtoListLinear :: Elt a => Mixed sh a -> [a]
mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr)) -- TODO: optimise
@@ -821,8 +813,11 @@ mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr
munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a
munNest (M_Nest _ arr) = arr
-mzip :: Mixed sh a -> Mixed sh b -> Mixed sh (a, b)
-mzip = M_Tup2
+-- | The arguments must have equal shapes. If they do not, an error is raised.
+mzip :: (Elt a, Elt b) => Mixed sh a -> Mixed sh b -> Mixed sh (a, b)
+mzip a b
+ | Just Refl <- shxEqual (mshape a) (mshape b) = M_Tup2 a b
+ | otherwise = error "mzip: unequal shapes"
munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b)
munzip (M_Tup2 a b) = (a, b)
@@ -832,13 +827,13 @@ mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
-> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
-> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b)
mrerankP ssh sh2 f (M_Primitive sh arr) =
- let sh1 = shxDropSSX sh ssh
- in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2)
+ let sh1 = shxDropSSX ssh sh
+ in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) ssh sh) sh2)
(X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2)
(\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
arr)
--- | See the caveats at @X.rerank@.
+-- | See the caveats at 'Data.Array.XArray.rerank'.
mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
=> StaticShX sh -> IShX sh2
-> (Mixed sh1 a -> Mixed sh2 b)
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index 2f35ff9..852dd5e 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -31,7 +31,6 @@ import Data.Functor.Const
import Data.Functor.Product
import Data.Kind (Constraint, Type)
import Data.Monoid (Sum(..))
-import Data.Proxy
import Data.Type.Equality
import GHC.Exts (withDict)
import GHC.Generics (Generic)
@@ -146,9 +145,9 @@ listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f
listxAppend ZX idx' = idx'
listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx'
-listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f
-listxDrop long ZX = long
-listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short
+listxDrop :: forall f g sh sh'. ListX sh g -> ListX (sh ++ sh') f -> ListX sh' f
+listxDrop ZX long = long
+listxDrop (_ ::% short) long = case long of _ ::% long' -> listxDrop short long'
listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f
listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh
@@ -172,7 +171,7 @@ listxZipWith f (i ::% is) (j ::% js) =
-- * Mixed indices
--- | This is a newtype over 'ListX'.
+-- | An index into a mixed-typed array.
type role IxX nominal representational
type IxX :: [Maybe Nat] -> Type -> Type
newtype IxX sh i = IxX (ListX sh (Const i))
@@ -191,6 +190,8 @@ infixr 3 :.%
{-# COMPLETE ZIX, (:.%) #-}
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
type IIxX sh = IxX sh Int
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
@@ -234,7 +235,7 @@ ixxTail (IxX list) = IxX (listxTail list)
ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
ixxAppend = coerce (listxAppend @_ @(Const i))
-ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i
+ixxDrop :: forall sh sh' i. IxX sh i -> IxX (sh ++ sh') i -> IxX sh' i
ixxDrop = coerce (listxDrop @(Const i) @(Const i))
ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i
@@ -243,6 +244,11 @@ ixxInit = coerce (listxInit @(Const i))
ixxLast :: forall n sh i. IxX (n : sh) i -> i
ixxLast = coerce (listxLast @(Const i))
+ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i
+ixxCast ZKX ZIX = ZIX
+ixxCast (_ :!% sh) (i :.% idx) = i :.% ixxCast sh idx
+ixxCast _ _ = error "ixxCast: ranks don't match"
+
ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j)
ixxZip ZIX ZIX = ZIX
ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js
@@ -390,10 +396,10 @@ shxSize :: IShX sh -> Int
shxSize ZSX = 1
shxSize (n :$% sh) = fromSMayNat' n * shxSize sh
-shxFromList :: StaticShX sh -> [Int] -> ShX sh Int
+shxFromList :: StaticShX sh -> [Int] -> IShX sh
shxFromList topssh topl = go topssh topl
where
- go :: StaticShX sh' -> [Int] -> ShX sh' Int
+ go :: StaticShX sh' -> [Int] -> IShX sh'
go ZKX [] = ZSX
go (SKnown sn :!% sh) (i : is)
| i == fromSNat' sn = SKnown sn :$% go sh is
@@ -408,11 +414,18 @@ shxToList :: IShX sh -> [Int]
shxToList ZSX = []
shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh
+shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i
+shxFromSSX ZKX = ZSX
+shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh))
+ | Refl <- lemMapJustCons @sh Refl
+ = SKnown n :$% shxFromSSX sh
+shxFromSSX (SUnknown _ :!% _) = error "unreachable"
+
-- | This may fail if @sh@ has @Nothing@s in it.
-shxFromSSX' :: StaticShX sh -> Maybe (IShX sh)
-shxFromSSX' ZKX = Just ZSX
-shxFromSSX' (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX' sh
-shxFromSSX' (SUnknown _ :!% _) = Nothing
+shxFromSSX2 :: StaticShX sh -> Maybe (ShX sh i)
+shxFromSSX2 ZKX = Just ZSX
+shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh
+shxFromSSX2 (SUnknown _ :!% _) = Nothing
shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
shxAppend = coerce (listxAppend @_ @(SMayNat i SNat))
@@ -423,13 +436,13 @@ shxHead (ShX list) = listxHead list
shxTail :: ShX (n : sh) i -> ShX sh i
shxTail (ShX list) = ShX (listxTail list)
-shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i
+shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i
shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat))
-shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i
+shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i
shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j))
-shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i
+shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i
shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))
shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i
@@ -438,12 +451,9 @@ shxInit = coerce (listxInit @(SMayNat i SNat))
shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh))
shxLast = coerce (listxLast @(SMayNat i SNat))
-shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i
-shxTakeSSX _ = flip go
- where
- go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i
- go ZKX _ = ZSX
- go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh
+shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i
+shxTakeSSX _ ZKX _ = ZSX
+shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh
shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n)
-> ShX sh i -> ShX sh j -> ShX sh k
@@ -456,7 +466,7 @@ shxCompleteZeros ZKX = ZSX
shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh
shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh
-shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i)
+shxSplitApp :: proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i)
shxSplitApp _ ZKX idx = (ZSX, idx)
shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx)
@@ -467,17 +477,17 @@ shxEnum = \sh -> go sh id []
go ZSX f = (f ZIX :)
go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]]
-shxCast :: IShX sh -> StaticShX sh' -> Maybe (IShX sh')
-shxCast ZSX ZKX = Just ZSX
-shxCast (SKnown n :$% sh) (SKnown m :!% ssh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast sh ssh
-shxCast (SUnknown n :$% sh) (SKnown m :!% ssh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast sh ssh
-shxCast (SKnown n :$% sh) (SUnknown () :!% ssh) = (SUnknown (fromSNat' n) :$%) <$> shxCast sh ssh
-shxCast (SUnknown n :$% sh) (SUnknown () :!% ssh) = (SUnknown n :$%) <$> shxCast sh ssh
+shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh')
+shxCast ZKX ZSX = Just ZSX
+shxCast (SKnown m :!% ssh) (SKnown n :$% sh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast ssh sh
+shxCast (SKnown m :!% ssh) (SUnknown n :$% sh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast ssh sh
+shxCast (SUnknown () :!% ssh) (SKnown n :$% sh) = (SUnknown (fromSNat' n) :$%) <$> shxCast ssh sh
+shxCast (SUnknown () :!% ssh) (SUnknown n :$% sh) = (SUnknown n :$%) <$> shxCast ssh sh
shxCast _ _ = Nothing
-- | Partial version of 'shxCast'.
-shxCast' :: IShX sh -> StaticShX sh' -> IShX sh'
-shxCast' sh ssh = case shxCast sh ssh of
+shxCast' :: StaticShX sh' -> IShX sh -> IShX sh'
+shxCast' ssh sh = case shxCast ssh sh of
Just sh' -> sh'
Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")"
@@ -537,9 +547,15 @@ ssxHead (StaticShX list) = listxHead list
ssxTail :: StaticShX (n : sh) -> StaticShX sh
ssxTail (_ :!% ssh) = ssh
-ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'
+ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh'
+ssxDropSSX = coerce (listxDrop @(SMayNat () SNat) @(SMayNat () SNat))
+
+ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))
+ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
+ssxDropSh = coerce (listxDrop @(SMayNat () SNat) @(SMayNat i SNat))
+
ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh))
ssxInit = coerce (listxInit @(SMayNat () SNat))
@@ -552,11 +568,11 @@ ssxReplicate (SS (n :: SNat n'))
| Refl <- lemReplicateSucc @(Nothing @Nat) @n'
= SUnknown () :!% ssxReplicate n
-ssxIotaFrom :: Int -> StaticShX sh -> [Int]
-ssxIotaFrom _ ZKX = []
-ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh
+ssxIotaFrom :: StaticShX sh -> Int -> [Int]
+ssxIotaFrom ZKX _ = []
+ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i+1)
-ssxFromShX :: IShX sh -> StaticShX sh
+ssxFromShX :: ShX sh i -> StaticShX sh
ssxFromShX ZSX = ZKX
ssxFromShX (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShX sh
@@ -574,7 +590,7 @@ instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SK
instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX
withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r
-withKnownShX k = withDict @(KnownShX sh) k
+withKnownShX = withDict @(KnownShX sh)
-- * Flattening
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs
index 031755f..03d1640 100644
--- a/src/Data/Array/Nested/Permutation.hs
+++ b/src/Data/Array/Nested/Permutation.hs
@@ -4,7 +4,6 @@
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -25,6 +24,7 @@ import Data.Proxy
import Data.Type.Bool
import Data.Type.Equality
import Data.Type.Ord
+import GHC.Exts (withDict)
import GHC.TypeError
import GHC.TypeLits
import GHC.TypeNats qualified as TN
@@ -36,8 +36,8 @@ import Data.Array.Nested.Types
-- * Permutations
-- | A "backward" permutation of a dimension list. The operation on the
--- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute'
--- for code that implements this.
+-- dimension list is most similar to @backpermute@ in the @vector@ package; see
+-- 'Permute' for code that implements this.
data Perm list where
PNil :: Perm '[]
PCons :: SNat a -> Perm l -> Perm (a : l)
@@ -45,6 +45,13 @@ infixr 5 `PCons`
deriving instance Show (Perm list)
deriving instance Eq (Perm list)
+instance TestEquality Perm where
+ testEquality PNil PNil = Just Refl
+ testEquality (x `PCons` xs) (y `PCons` ys)
+ | Just Refl <- testEquality x y
+ , Just Refl <- testEquality xs ys = Just Refl
+ testEquality _ _ = Nothing
+
permRank :: Perm list -> SNat (Rank list)
permRank PNil = SNat
permRank (_ `PCons` l) | SNat <- permRank l = SNat
@@ -119,6 +126,9 @@ class KnownPerm l where makePerm :: Perm l
instance KnownPerm '[] where makePerm = PNil
instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm
+withKnownPerm :: forall l r. Perm l -> (KnownPerm l => r) -> r
+withKnownPerm = withDict @(KnownPerm l)
+
-- | Untyped permutations for ranked arrays
type PermR = [Int]
@@ -224,7 +234,7 @@ permInverse = \perm k ->
++ " ; invperm = " ++ show invperm)
(permCheckPermutation invperm
(k invperm
- (\ssh -> case provePermInverse perm invperm ssh of
+ (\ssh -> case permCheckInverse perm invperm ssh of
Just eq -> eq
Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm
++ " ; invperm = " ++ show invperm)))
@@ -238,9 +248,9 @@ permInverse = \perm k ->
toHList [] k = k PNil
toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l)
- provePermInverse :: Perm is -> Perm is' -> StaticShX sh
+ permCheckInverse :: Perm is -> Perm is' -> StaticShX sh
-> Maybe (Permute is' (Permute is sh) :~: sh)
- provePermInverse perm perminv ssh =
+ permCheckInverse perm perminv ssh =
ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh
type family MapSucc is where
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index e5c51ef..9778c54 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -29,17 +29,17 @@ import Foreign.Storable (Storable)
import GHC.TypeLits
import GHC.TypeNats qualified as TN
-import Data.Array.Mixed.Lemmas
-import Data.Array.Nested.Permutation
-import Data.Array.Nested.Types
-import Data.Array.XArray (XArray(..))
-import Data.Array.XArray qualified as X
import Data.Array.Nested.Convert
+import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
import Data.Array.Nested.Ranked.Base
import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray(..))
+import Data.Array.XArray qualified as X
remptyArray :: KnownElt a => Ranked 1 a
@@ -137,38 +137,32 @@ rtoVectorP = coerce mtoVectorP
rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
rtoVector = coerce mtoVector
-rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
-rfromListOuter l
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
-
rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
rfromList1 l = Ranked (mfromList1 l)
-rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a
-rfromList1Prim l = Ranked (mfromList1Prim l)
-
-rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
-rtoListOuter (Ranked arr)
+rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
+rfromListOuter l
| Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
+ = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
-rtoList1 :: Elt a => Ranked 1 a -> [a]
-rtoList1 = map runScalar . rtoListOuter
+rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
+rfromListLinear sh l = rreshape sh (rfromList1 l)
rfromListPrim :: PrimElt a => [a] -> Ranked 1 a
-rfromListPrim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+rfromListPrim l = Ranked (mfromListPrim l)
rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a
rfromListPrimLinear sh l =
let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
in Ranked $ fromPrimitive $ M_Primitive (shxFromShR sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShR sh) xarr)
-rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
-rfromListLinear sh l = rreshape sh (rfromList1 l)
+rtoList :: Elt a => Ranked 1 a -> [a]
+rtoList = map runScalar . rtoListOuter
+
+rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
+rtoListOuter (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
rtoListLinear :: Elt a => Ranked n a -> [a]
rtoListLinear (Ranked arr) = mtoListLinear arr
@@ -197,7 +191,7 @@ runNest rarr@(Ranked (M_Ranked (M_Nest _ arr)))
| Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat))
= Ranked arr
-rzip :: Ranked n a -> Ranked n b -> Ranked n (a, b)
+rzip :: (Elt a, Elt b) => Ranked n a -> Ranked n b -> Ranked n (a, b)
rzip = coerce mzip
runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b)
@@ -230,7 +224,7 @@ rrerankP sn sh2 f (Ranked arr)
-- then:
--
-- @
--- rrerank _ _ _ f arr :: Ranked 5 Float
+-- rrerank _ _ _ f arr :: Ranked 6 Float
-- @
--
-- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
index f50f671..babc809 100644
--- a/src/Data/Array/Nested/Ranked/Base.hs
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -5,6 +5,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -25,6 +26,7 @@ import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
+import Data.Type.Equality
import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
@@ -34,13 +36,13 @@ import GHC.TypeLits
import Data.Foldable (toList)
#endif
-import Data.Array.Mixed.Lemmas
-import Data.Array.Nested.Types
-import Data.Array.XArray (XArray(..))
+import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray(..))
-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
@@ -252,3 +254,15 @@ rshape (Ranked arr) = shrFromShX2 (mshape arr)
rrank :: Elt a => Ranked n a -> SNat n
rrank = shrRank . rshape
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+shrFromShX :: forall sh. IShX sh -> IShR (Rank sh)
+shrFromShX ZSX = ZSR
+shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+-- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'.
+shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n
+shrFromShX2 sh
+ | Refl <- lemRankReplicate (Proxy @n)
+ = shrFromShX sh
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index c0c4f17..8b670e5 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -39,11 +39,12 @@ import GHC.IsList qualified as IsList
import GHC.TypeLits
import GHC.TypeNats qualified as TN
-import Data.Array.Mixed.Lemmas
-import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Lemmas
import Data.Array.Nested.Types
+-- * Ranked lists
+
type role ListR nominal representational
type ListR :: Nat -> Type -> Type
data ListR n i where
@@ -130,6 +131,10 @@ listrLast (_ ::: sh@(_ ::: _)) = listrLast sh
listrLast (n ::: ZR) = n
listrLast ZR = error "unreachable"
+-- | Performs a runtime check that the lengths are identical.
+listrCast :: SNat n' -> ListR n i -> ListR n' i
+listrCast = listrCastWithName "listrCast"
+
listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
listrIndex SZ (x ::: _) = x
listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
@@ -172,6 +177,8 @@ listrPermutePrefix = \perm sh ->
GTI -> error "listrPermutePrefix: Index in permutation out of range"
+-- * Ranked indices
+
-- | An index into a rank-typed array.
type role IxR nominal representational
type IxR :: Nat -> Type -> Type
@@ -192,6 +199,8 @@ infixr 3 :.:
{-# COMPLETE ZIR, (:.:) #-}
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
type IIxR n = IxR n Int
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
@@ -213,16 +222,6 @@ ixrZero :: SNat n -> IIxR n
ixrZero SZ = ZIR
ixrZero (SS n) = 0 :.: ixrZero n
-ixrFromIxX :: IxX sh i -> IxR (Rank sh) i
-ixrFromIxX ZIX = ZIR
-ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx
-
-ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i
-ixxFromIxR ZIR = ZIX
-ixxFromIxR (n :.: (idx :: IxR m i)) =
- castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m))
- (n :.% ixxFromIxR idx)
-
ixrHead :: IxR (n + 1) i -> i
ixrHead (IxR list) = listrHead list
@@ -235,6 +234,10 @@ ixrInit (IxR list) = IxR (listrInit list)
ixrLast :: IxR (n + 1) i -> i
ixrLast (IxR list) = listrLast list
+-- | Performs a runtime check that the lengths are identical.
+ixrCast :: SNat n' -> IxR n i -> IxR n' i
+ixrCast n (IxR idx) = IxR (listrCastWithName "ixrCast" n idx)
+
ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
ixrAppend = coerce (listrAppend @_ @i)
@@ -248,6 +251,8 @@ ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
ixrPermutePrefix = coerce (listrPermutePrefix @i)
+-- * Ranked shapes
+
type role ShR nominal representational
type ShR :: Nat -> Type -> Type
newtype ShR n i = ShR (ListR n i)
@@ -278,22 +283,6 @@ instance Show i => Show (ShR n i) where
instance NFData i => NFData (ShR sh i)
-shrFromShX :: forall sh. IShX sh -> IShR (Rank sh)
-shrFromShX ZSX = ZSR
-shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx
-
--- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'.
-shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n
-shrFromShX2 sh
- | Refl <- lemRankReplicate (Proxy @n)
- = shrFromShX sh
-
-shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i
-shxFromShR ZSR = ZSX
-shxFromShR (n :$: (idx :: ShR m i)) =
- castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m))
- (SUnknown n :$% shxFromShR idx)
-
-- | This checks only whether the ranks are equal, not whether the actual
-- values are.
shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
@@ -329,6 +318,10 @@ shrInit (ShR list) = ShR (listrInit list)
shrLast :: ShR (n + 1) i -> i
shrLast (ShR list) = listrLast list
+-- | Performs a runtime check that the lengths are identical.
+shrCast :: SNat n' -> ShR n i -> ShR n' i
+shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh)
+
shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
shrAppend = coerce (listrAppend @_ @i)
@@ -366,3 +359,11 @@ instance KnownNat n => IsList (ShR n i) where
type Item (ShR n i) = i
fromList = ShR . IsList.fromList
toList = Foldable.toList
+
+
+-- * Internal helper functions
+
+listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i
+listrCastWithName _ SZ ZR = ZR
+listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx
+listrCastWithName name _ _ = error $ name ++ ": ranks don't match"
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 7e38aee..198a068 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -29,18 +29,17 @@ import Data.Vector.Storable qualified as VS
import Foreign.Storable (Storable)
import GHC.TypeLits
-import Data.Array.Mixed.Lemmas
-import Data.Array.Nested.Permutation
-import Data.Array.Nested.Types
-import Data.Array.XArray (XArray)
-import Data.Array.XArray qualified as X
-import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Convert
+import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
import Data.Array.Nested.Shaped.Base
import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray)
+import Data.Array.XArray qualified as X
semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a
@@ -124,20 +123,14 @@ stoVectorP = coerce mtoVectorP
stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
stoVector = coerce mtoVector
-sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
-sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l))
-
sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1
-sfromList1Prim :: PrimElt a => SNat n -> [a] -> Shaped '[n] a
-sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim
-
-stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
-stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
+sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
+sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l))
-stoList1 :: Elt a => Shaped '[n] a -> [a]
-stoList1 = map sunScalar . stoListOuter
+sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a
+sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l)
sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
sfromListPrim sn l
@@ -151,8 +144,11 @@ sfromListPrimLinear sh l =
let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
in Shaped $ fromPrimitive $ M_Primitive (shxFromShS sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShS sh) xarr)
-sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a
-sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l)
+stoList :: Elt a => Shaped '[n] a -> [a]
+stoList = map sunScalar . stoListOuter
+
+stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
+stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
stoListLinear :: Elt a => Shaped sh a -> [a]
stoListLinear (Shaped arr) = mtoListLinear arr
@@ -177,7 +173,7 @@ sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr)))
| Refl <- lemMapJustApp (sshape sarr) (Proxy @sh')
= Shaped arr
-szip :: Shaped sh a -> Shaped sh b -> Shaped sh (a, b)
+szip :: (Elt a, Elt b) => Shaped sh a -> Shaped sh b -> Shaped sh (a, b)
szip = coerce mzip
sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b)
@@ -190,11 +186,12 @@ srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
srerankP sh sh2 f sarr@(Shaped arr)
| Refl <- lemMapJustApp sh (Proxy @sh1)
, Refl <- lemMapJustApp sh (Proxy @sh2)
- = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (shxFromShS (sshape sarr)) (ssxFromShX (shxFromShS sh))))
+ = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (ssxFromShX (shxFromShS sh)) (shxFromShS (sshape sarr))))
(shxFromShS sh2)
(\a -> let Shaped r = f (Shaped a) in r)
arr)
+-- | See the caveats at 'Data.Array.XArray.rerank'.
srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
=> ShS sh -> ShS sh2
-> (Shaped sh1 a -> Shaped sh2 b)
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index 529ac21..879e6b5 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -5,6 +5,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -25,18 +26,19 @@ import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
+import Data.Type.Equality
import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import GHC.TypeLits
-import Data.Array.Nested.Types
-import Data.Array.XArray (XArray)
-import Data.Array.Nested.Internal.Lemmas
+import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray)
-- | A shape-typed array: the full shape of the array (the sizes of its
@@ -242,3 +244,12 @@ satan2Array = liftShaped2 matan2Array
sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
sshape (Shaped arr) = shsFromShX (mshape arr)
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh
+shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
+shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) =
+ castWith (subst1 (sym (lemMapJustCons Refl))) $
+ n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
+ idx)
+shsFromShX (SUnknown _ :$% _) = error "impossible"
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index 0b7d1c9..5f9ba79 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -1,12 +1,9 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveFoldable #-}
-{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
-{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
@@ -48,6 +45,10 @@ import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
+-- * Shaped lists
+
+-- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be
+-- removed in a future release.
type role ListS nominal representational
type ListS :: [Nat] -> (Nat -> Type) -> Type
data ListS sh f where
@@ -180,11 +181,9 @@ listsIndex _ _ _ ZS = error "Index into empty shape"
listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh)
+-- * Shaped indices
-- | An index into a shape-typed array.
---
--- For convenience, this contains regular 'Int's instead of bounded integers
--- (traditionally called \"@Fin@\").
type role IxS nominal representational
type IxS :: [Nat] -> Type -> Type
newtype IxS sh i = IxS (ListS sh (Const i))
@@ -193,6 +192,8 @@ newtype IxS sh i = IxS (ListS sh (Const i))
pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
pattern ZIS = IxS ZS
+-- | Note: The 'KnownNat' constraint on '(:.$)' is deprecated and should be
+-- removed in a future release.
pattern (:.$)
:: forall {sh1} {i}.
forall n sh. (KnownNat n, n : sh ~ sh1)
@@ -203,6 +204,8 @@ infixr 3 :.$
{-# COMPLETE ZIS, (:.$) #-}
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
type IIxS sh = IxS sh Int
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
@@ -230,14 +233,6 @@ ixsZero :: ShS sh -> IIxS sh
ixsZero ZSS = ZIS
ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
-ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i
-ixsFromIxX ZSS ZIX = ZIS
-ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx
-
-ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i
-ixxFromIxS ZIS = ZIX
-ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh
-
ixsHead :: IxS (n : sh) i -> i
ixsHead (IxS list) = getConst (listsHead list)
@@ -250,6 +245,12 @@ ixsInit (IxS list) = IxS (listsInit list)
ixsLast :: IxS (n : sh) i -> i
ixsLast (IxS list) = getConst (listsLast list)
+-- TODO: this takes a ShS because there are KnownNats inside IxS.
+ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i
+ixsCast ZSS ZIS = ZIS
+ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx
+ixsCast _ _ = error "ixsCast: ranks don't match"
+
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
ixsAppend = coerce (listsAppend @_ @(Const i))
@@ -265,6 +266,8 @@ ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is
ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))
+-- * Shaped shapes
+
-- | The shape of a shape-typed array given as a list of 'SNat' values.
--
-- Note that because the shape of a shape-typed array is known statically, you
@@ -321,23 +324,6 @@ shsToList :: ShS sh -> [Int]
shsToList ZSS = []
shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh
-shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh
-shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
-shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) =
- castWith (subst1 (lem Refl)) $
- n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
- idx)
- where
- lem :: forall sh1 sh' n.
- Just n : sh1 :~: MapJust sh'
- -> n : Tail sh' :~: sh'
- lem Refl = unsafeCoerceRefl
-shsFromShX (SUnknown _ :$% _) = error "impossible"
-
-shxFromShS :: ShS sh -> IShX (MapJust sh)
-shxFromShS ZSS = ZSX
-shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh
-
shsHead :: ShS (n : sh) -> SNat n
shsHead (ShS list) = listsHead list
@@ -381,7 +367,7 @@ instance KnownShS '[] where knownShS = ZSS
instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS
withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r
-withKnownShS k = withDict @(KnownShS sh) k
+withKnownShS = withDict @(KnownShS sh)
shsKnownShS :: ShS sh -> Dict KnownShS sh
shsKnownShS ZSS = Dict
@@ -391,6 +377,17 @@ shsOrthotopeShape :: ShS sh -> Dict O.Shape sh
shsOrthotopeShape ZSS = Dict
shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict
+-- | This function is a hack made possible by the 'KnownNat' inside 'ListS'.
+-- This function may be removed in a future release.
+shsFromListS :: ListS sh f -> ShS sh
+shsFromListS ZS = ZSS
+shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l
+
+-- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This
+-- function may be removed in a future release.
+shsFromIxS :: IxS sh i -> ShS sh
+shsFromIxS (IxS l) = shsFromListS l
+
-- | Untyped: length is checked at runtime.
instance KnownShS sh => IsList (ListS sh (Const i)) where
diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs
index 838e2b0..8a29aa5 100644
--- a/src/Data/Array/Nested/Trace.hs
+++ b/src/Data/Array/Nested/Trace.hs
@@ -37,10 +37,12 @@ module Data.Array.Nested.Trace (
ShS(..), KnownShS(..),
Mixed,
+ ListX(ZX, (::%)),
IxX(..), IIxX,
- ShX(..), KnownShX(..),
+ ShX(..), KnownShX(..), IShX,
StaticShX(..),
SMayNat(..),
+ Conversion(..),
Elt,
PrimElt,
@@ -54,7 +56,7 @@ module Data.Array.Nested.Trace (
Perm(..),
IsPermutation,
KnownPerm(..),
- NumElt, FloatElt,
+ NumElt, IntElt, FloatElt,
Rank, Product,
Replicate,
MapJust,
@@ -67,4 +69,4 @@ import Data.Array.Nested.Trace.TH
$(concat <$> mapM convertFun
- ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromListOuter, 'rfromList1, 'rfromList1Prim, 'rtoListOuter, 'rtoList1, 'rfromListLinear, 'rfromListPrimLinear, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rcastToShaped, 'rtoMixed, 'rfromOrthotope, 'rtoOrthotope, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromListOuter, 'sfromList1, 'sfromList1Prim, 'stoListOuter, 'stoList1, 'sfromListLinear, 'sfromListPrimLinear, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoRanked, 'stoMixed, 'sfromOrthotope, 'stoOrthotope, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromListOuter, 'mfromList1, 'mfromList1Prim, 'mtoListOuter, 'mtoList1, 'mfromListLinear, 'mfromListPrimLinear, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mtoRanked, 'mcastToShaped])
+ ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromList1, 'rfromListOuter, 'rfromListLinear, 'rfromListPrim, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromList1, 'sfromListOuter, 'sfromListLinear, 'sfromListPrim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromList1, 'mfromListOuter, 'mfromListLinear, 'mfromListPrim, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array])
diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs
index 4172fa0..4444acd 100644
--- a/src/Data/Array/Nested/Types.hs
+++ b/src/Data/Array/Nested/Types.hs
@@ -6,7 +6,7 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
@@ -30,6 +30,7 @@ module Data.Array.Nested.Types (
Replicate,
lemReplicateSucc,
MapJust,
+ lemMapJustEmpty, lemMapJustCons,
Head,
Tail,
Init,
@@ -111,10 +112,16 @@ type family Replicate n a where
lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
lemReplicateSucc = unsafeCoerceRefl
-type family MapJust l where
+type family MapJust l = r | r -> l where
MapJust '[] = '[]
MapJust (x : xs) = Just x : MapJust xs
+lemMapJustEmpty :: MapJust sh :~: '[] -> sh :~: '[]
+lemMapJustEmpty Refl = unsafeCoerceRefl
+
+lemMapJustCons :: MapJust sh :~: Just n : sh' -> sh :~: n : Tail sh
+lemMapJustCons Refl = unsafeCoerceRefl
+
type family Head l where
Head (x : _) = x
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index dde06e3..bf47622 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -31,10 +31,10 @@ import Foreign.Storable (Storable)
import GHC.Generics (Generic)
import GHC.TypeLits
-import Data.Array.Mixed.Lemmas
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
-import Data.Array.Nested.Mixed.Shape
import Data.Array.Strided.Orthotope
@@ -243,7 +243,7 @@ transpose2 ssh1 ssh2 (XArray arr)
, Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
, Refl <- lemRankAppComm ssh1 ssh2
, let n1 = ssxLength ssh1
- = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
+ = XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr)
sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a
sumFull _ (XArray arr) =