aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Nested/Convert.hs36
-rw-r--r--src/Data/Array/Nested/Mixed.hs2
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs278
-rw-r--r--src/Data/Array/Nested/Permutation.hs95
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs13
-rw-r--r--src/Data/Array/Nested/Shaped.hs4
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs14
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs154
-rw-r--r--src/Data/Array/XArray.hs2
9 files changed, 403 insertions, 195 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index 8c88d23..3706105 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -38,9 +38,11 @@ module Data.Array.Nested.Convert (
) where
import Control.Category
+import Data.Coerce (coerce)
import Data.Proxy
import Data.Type.Equality
import GHC.TypeLits
+import Unsafe.Coerce (unsafeCoerce)
import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
@@ -56,12 +58,10 @@ import Data.Array.Nested.Types
-- * To ranked
ixrFromIxS :: IxS sh i -> IxR (Rank sh) i
-ixrFromIxS ZIS = ZIR
-ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix
+ixrFromIxS = unsafeCoerce
ixrFromIxX :: IxX sh i -> IxR (Rank sh) i
-ixrFromIxX ZIX = ZIR
-ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx
+ixrFromIxX = unsafeCoerce
shrFromShS :: ShS sh -> IShR (Rank sh)
shrFromShS ZSS = ZSR
@@ -75,12 +75,11 @@ shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh
-- * To shaped
--- TODO: these take a ShS because there are KnownNats inside IxS.
-
+-- TODO: remove the ShS now that no KnownNats is inside IxS.
ixsFromIxR :: ShS sh -> IxR (Rank sh) i -> IxS sh i
-ixsFromIxR ZSS ZIR = ZIS
-ixsFromIxR (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR sh idx
+ixsFromIxR _ = unsafeCoerce
+-- TODO: if possible, remove the ShS now that no KnownNats is inside IxS.
-- | Performs a runtime check that @n@ matches @Rank sh@. Equivalent to the
-- following, but more efficient:
--
@@ -90,11 +89,11 @@ ixsFromIxR' ZSS ZIR = ZIS
ixsFromIxR' (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR' sh idx
ixsFromIxR' _ _ = error "ixsFromIxR': index rank does not match shape rank"
--- TODO: this takes a ShS because there are KnownNats inside IxS.
+-- TODO: remove the ShS now that no KnownNats is inside IxS.
ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i
-ixsFromIxX ZSS ZIX = ZIS
-ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx
+ixsFromIxX _ = unsafeCoerce
+-- TODO: if possible, remove the ShS now that no KnownNats is inside IxS.
-- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to
-- the following, but more efficient:
--
@@ -113,7 +112,8 @@ withShsFromShR (n :$: sh) k =
Just sn@SNat -> k (sn :$$ sh')
Nothing -> error $ "withShsFromShR: negative dimension size (" ++ show n ++ ")"
--- shsFromShX re-exported
+shsFromShX :: IShX (MapJust sh) -> ShS sh
+shsFromShX = coerce
-- | Produce an existential 'ShS' from an 'IShX'. If you already know that
-- @sh'@ is @MapJust@ of something, use 'shsFromShX' instead.
@@ -128,6 +128,7 @@ withShsFromShX (SUnknown n :$% sh) k =
Just sn@SNat -> k (sn :$$ sh')
Nothing -> error $ "withShsFromShX: negative SUnknown dimension size (" ++ show n ++ ")"
+-- If it ever matters for performance, this is unsafeCoercible.
shsFromSSX :: StaticShX (MapJust sh) -> ShS sh
shsFromSSX = shsFromShX Prelude.. shxFromSSX
@@ -136,14 +137,10 @@ shsFromSSX = shsFromShX Prelude.. shxFromSSX
-- * To mixed
ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i
-ixxFromIxR ZIR = ZIX
-ixxFromIxR (n :.: (idx :: IxR m i)) =
- castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) (Proxy @m)))
- (n :.% ixxFromIxR idx)
+ixxFromIxR = unsafeCoerce
ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i
-ixxFromIxS ZIS = ZIX
-ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh
+ixxFromIxS = unsafeCoerce
shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i
shxFromShR ZSR = ZSX
@@ -152,8 +149,7 @@ shxFromShR (n :$: (idx :: ShR m i)) =
(SUnknown n :$% shxFromShR idx)
shxFromShS :: ShS sh -> IShX (MapJust sh)
-shxFromShS ZSS = ZSX
-shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh
+shxFromShS = coerce
-- ixxCast re-exported
-- shxCast re-exported
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 54b2a9f..eb05eaa 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -816,7 +816,7 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
sn :$% sh = mshape arr1
sm :$% _ = mshape arr2
ssh = ssxFromShX sh
- snm :: SMayNat () SNat (AddMaybe n m)
+ snm :: SMayNat () (AddMaybe n m)
snm = case (sn, sm) of
(SUnknown{}, _) -> SUnknown ()
(SKnown{}, SUnknown{}) -> SUnknown ()
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index b1b4f81..7c79f8b 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -1,9 +1,10 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NoStarIsType #-}
@@ -35,9 +36,9 @@ import Data.Functor.Const
import Data.Functor.Product
import Data.Kind (Constraint, Type)
import Data.Monoid (Sum(..))
+import Data.Proxy
import Data.Type.Equality
import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict)
-import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
@@ -56,7 +57,7 @@ type family Rank sh where
Rank (_ : sh) = Rank sh + 1
--- * Mixed lists
+-- * Mixed lists to be used IxX and shaped and ranked lists and indexes
type role ListX nominal representational
type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
@@ -189,7 +190,7 @@ listxZipWith f (i ::% is) (j ::% js) = f i j ::% listxZipWith f is js
type role IxX nominal representational
type IxX :: [Maybe Nat] -> Type -> Type
newtype IxX sh i = IxX (ListX sh (Const i))
- deriving (Eq, Ord, Generic)
+ deriving (Eq, Ord, NFData)
pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i
pattern ZIX = IxX ZX
@@ -229,8 +230,6 @@ instance Foldable (IxX sh) where
null ZIX = False
null _ = True
-instance NFData i => NFData (IxX sh i)
-
ixxLength :: IxX sh i -> Int
ixxLength (IxX l) = listxLength l
@@ -295,32 +294,32 @@ ixxToLinear = \sh i -> go sh i 0
go (n :$% sh) (i :.% ix) a = go sh ix (fromIntegral (fromSMayNat' n) * a + i)
--- * Mixed shapes
+-- * Mixed shape-like lists to be used for ShX and StaticShX
-data SMayNat i f n where
- SUnknown :: i -> SMayNat i f Nothing
- SKnown :: f n -> SMayNat i f (Just n)
-deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n)
-deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n)
-deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n)
+data SMayNat i n where
+ SUnknown :: i -> SMayNat i Nothing
+ SKnown :: SNat n -> SMayNat i (Just n)
+deriving instance Show i => Show (SMayNat i n)
+deriving instance Eq i => Eq (SMayNat i n)
+deriving instance Ord i => Ord (SMayNat i n)
-instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where
+instance (NFData i, forall m. NFData (SNat m)) => NFData (SMayNat i n) where
rnf (SUnknown i) = rnf i
rnf (SKnown x) = rnf x
-instance TestEquality f => TestEquality (SMayNat i f) where
+instance TestEquality (SMayNat i) where
testEquality SUnknown{} SUnknown{} = Just Refl
testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl
testEquality _ _ = Nothing
{-# INLINE fromSMayNat #-}
fromSMayNat :: (n ~ Nothing => i -> r)
- -> (forall m. n ~ Just m => f m -> r)
- -> SMayNat i f n -> r
+ -> (forall m. n ~ Just m => SNat m -> r)
+ -> SMayNat i n -> r
fromSMayNat f _ (SUnknown i) = f i
fromSMayNat _ g (SKnown s) = g s
-fromSMayNat' :: SMayNat Int SNat n -> Int
+fromSMayNat' :: SMayNat Int n -> Int
fromSMayNat' = fromSMayNat id fromSNat'
type family AddMaybe n m where
@@ -328,27 +327,155 @@ type family AddMaybe n m where
AddMaybe (Just _) Nothing = Nothing
AddMaybe (Just n) (Just m) = Just (n + m)
-smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
+smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m)
smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)
--- | This is a newtype over 'ListX'.
+type role ListH nominal representational
+type ListH :: [Maybe Nat] -> Type -> Type
+data ListH sh i where
+ ZH :: ListH '[] i
+ ConsUnknown :: forall sh i. i -> ListH sh i -> ListH (Nothing : sh) i
+-- TODO: bring this UNPACK back when GHC no longer crashes:
+-- ConsKnown :: forall n sh i. {-# UNPACK #-} SNat n -> ListH sh i -> ListH (Just n : sh) i
+ ConsKnown :: forall n sh i. SNat n -> ListH sh i -> ListH (Just n : sh) i
+deriving instance Eq i => Eq (ListH sh i)
+deriving instance Ord i => Ord (ListH sh i)
+
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (ListH sh i)
+#else
+instance Show i => Show (ListH sh i) where
+ showsPrec _ = listhShow shows
+#endif
+
+instance NFData i => NFData (ListH sh i) where
+ rnf ZH = ()
+ rnf (x `ConsUnknown` l) = rnf x `seq` rnf l
+ rnf (SNat `ConsKnown` l) = rnf l
+
+data UnconsListHRes i sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsListHRes (ListH sh i) (SMayNat i n)
+listhUncons :: ListH sh1 i -> Maybe (UnconsListHRes i sh1)
+listhUncons (i `ConsUnknown` shl') = Just (UnconsListHRes shl' (SUnknown i))
+listhUncons (i `ConsKnown` shl') = Just (UnconsListHRes shl' (SKnown i))
+listhUncons ZH = Nothing
+
+-- | This checks only whether the types are equal; if the elements of the list
+-- are not singletons, their values may still differ. This corresponds to
+-- 'testEquality', except on the penultimate type parameter.
+listhEqType :: ListH sh i -> ListH sh' i -> Maybe (sh :~: sh')
+listhEqType ZH ZH = Just Refl
+listhEqType (_ `ConsUnknown` sh) (_ `ConsUnknown` sh')
+ | Just Refl <- listhEqType sh sh'
+ = Just Refl
+listhEqType (n `ConsKnown` sh) (m `ConsKnown` sh')
+ | Just Refl <- testEquality n m
+ , Just Refl <- listhEqType sh sh'
+ = Just Refl
+listhEqType _ _ = Nothing
+
+-- | This checks whether the two lists actually contain equal values. This is
+-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
+-- in the @some@ package (except on the penultimate type parameter).
+listhEqual :: Eq i => ListH sh i -> ListH sh' i -> Maybe (sh :~: sh')
+listhEqual ZH ZH = Just Refl
+listhEqual (n `ConsUnknown` sh) (m `ConsUnknown` sh')
+ | n == m
+ , Just Refl <- listhEqual sh sh'
+ = Just Refl
+listhEqual (n `ConsKnown` sh) (m `ConsKnown` sh')
+ | Just Refl <- testEquality n m
+ , Just Refl <- listhEqual sh sh'
+ = Just Refl
+listhEqual _ _ = Nothing
+
+{-# INLINE listhFmap #-}
+listhFmap :: (forall n. SMayNat i n -> SMayNat j n) -> ListH sh i -> ListH sh j
+listhFmap _ ZH = ZH
+listhFmap f (x `ConsUnknown` xs) = case f (SUnknown x) of
+ SUnknown y -> y `ConsUnknown` listhFmap f xs
+listhFmap f (x `ConsKnown` xs) = case f (SKnown x) of
+ SKnown y -> y `ConsKnown` listhFmap f xs
+
+{-# INLINE listhFoldMap #-}
+listhFoldMap :: Monoid m => (forall n. SMayNat i n -> m) -> ListH sh i -> m
+listhFoldMap _ ZH = mempty
+listhFoldMap f (x `ConsUnknown` xs) = f (SUnknown x) <> listhFoldMap f xs
+listhFoldMap f (x `ConsKnown` xs) = f (SKnown x) <> listhFoldMap f xs
+
+listhLength :: ListH sh i -> Int
+listhLength = getSum . listhFoldMap (\_ -> Sum 1)
+
+listhRank :: ListH sh i -> SNat (Rank sh)
+listhRank ZH = SNat
+listhRank (_ `ConsUnknown` l) | SNat <- listhRank l = SNat
+listhRank (_ `ConsKnown` l) | SNat <- listhRank l = SNat
+
+{-# INLINE listhShow #-}
+listhShow :: forall sh i. (forall n. SMayNat i n -> ShowS) -> ListH sh i -> ShowS
+listhShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> ListH sh' i -> ShowS
+ go _ ZH = id
+ go prefix (x `ConsUnknown` xs) = showString prefix . f (SUnknown x) . go "," xs
+ go prefix (x `ConsKnown` xs) = showString prefix . f (SKnown x) . go "," xs
+
+listhHead :: ListH (mn ': sh) i -> SMayNat i mn
+listhHead (i `ConsUnknown` _) = SUnknown i
+listhHead (i `ConsKnown` _) = SKnown i
+
+listhTail :: ListH (n : sh) i -> ListH sh i
+listhTail (_ `ConsUnknown` sh) = sh
+listhTail (_ `ConsKnown` sh) = sh
+
+listhAppend :: ListH sh i -> ListH sh' i -> ListH (sh ++ sh') i
+listhAppend ZH idx' = idx'
+listhAppend (i `ConsUnknown` idx) idx' = i `ConsUnknown` listhAppend idx idx'
+listhAppend (i `ConsKnown` idx) idx' = i `ConsKnown` listhAppend idx idx'
+
+listhDrop :: forall i j sh sh'. ListH sh j -> ListH (sh ++ sh') i -> ListH sh' i
+listhDrop ZH long = long
+listhDrop (_ `ConsUnknown` short) long = case long of
+ _ `ConsUnknown` long' -> listhDrop short long'
+listhDrop (_ `ConsKnown` short) long = case long of
+ _ `ConsKnown` long' -> listhDrop short long'
+
+listhInit :: forall i n sh. ListH (n : sh) i -> ListH (Init (n : sh)) i
+listhInit (i `ConsUnknown` sh@(_ `ConsUnknown` _)) = i `ConsUnknown` listhInit sh
+listhInit (i `ConsUnknown` sh@(_ `ConsKnown` _)) = i `ConsUnknown` listhInit sh
+listhInit (_ `ConsUnknown` ZH) = ZH
+listhInit (i `ConsKnown` sh@(_ `ConsUnknown` _)) = i `ConsKnown` listhInit sh
+listhInit (i `ConsKnown` sh@(_ `ConsKnown` _)) = i `ConsKnown` listhInit sh
+listhInit (_ `ConsKnown` ZH) = ZH
+
+listhLast :: forall i n sh. ListH (n : sh) i -> SMayNat i (Last (n : sh))
+listhLast (_ `ConsUnknown` sh@(_ `ConsUnknown` _)) = listhLast sh
+listhLast (_ `ConsUnknown` sh@(_ `ConsKnown` _)) = listhLast sh
+listhLast (x `ConsUnknown` ZH) = SUnknown x
+listhLast (_ `ConsKnown` sh@(_ `ConsUnknown` _)) = listhLast sh
+listhLast (_ `ConsKnown` sh@(_ `ConsKnown` _)) = listhLast sh
+listhLast (x `ConsKnown` ZH) = SKnown x
+
+-- * Mixed shapes
+
+-- | This is a newtype over 'ListH'.
type role ShX nominal representational
type ShX :: [Maybe Nat] -> Type -> Type
-newtype ShX sh i = ShX (ListX sh (SMayNat i SNat))
- deriving (Eq, Ord, Generic)
+newtype ShX sh i = ShX (ListH sh i)
+ deriving (Eq, Ord, NFData)
pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i
-pattern ZSX = ShX ZX
+pattern ZSX = ShX ZH
pattern (:$%)
:: forall {sh1} {i}.
forall n sh. (n : sh ~ sh1)
- => SMayNat i SNat n -> ShX sh i -> ShX sh1 i
-pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i))
- where i :$% ShX shl = ShX (i ::% shl)
+ => SMayNat i n -> ShX sh i -> ShX sh1 i
+pattern i :$% shl <- ShX (listhUncons -> Just (UnconsListHRes (ShX -> shl) i))
+ where i :$% ShX shl = case i of; SUnknown x -> ShX (x `ConsUnknown` shl); SKnown x -> ShX (x `ConsKnown` shl)
infixr 3 :$%
{-# COMPLETE ZSX, (:$%) #-}
@@ -359,17 +486,12 @@ type IShX sh = ShX sh Int
deriving instance Show i => Show (ShX sh i)
#else
instance Show i => Show (ShX sh i) where
- showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
+ showsPrec _ (ShX l) = listhShow (fromSMayNat shows (shows . fromSNat)) l
#endif
instance Functor (ShX sh) where
{-# INLINE fmap #-}
- fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l)
-
-instance NFData i => NFData (ShX sh i) where
- rnf (ShX ZX) = ()
- rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l)
- rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l)
+ fmap f (ShX l) = ShX (listhFmap (fromSMayNat (SUnknown . f) SKnown) l)
-- | This checks only whether the types are equal; unknown dimensions might
-- still differ. This corresponds to 'testEquality', except on the penultimate
@@ -401,10 +523,10 @@ shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh')
shxEqual _ _ = Nothing
shxLength :: ShX sh i -> Int
-shxLength (ShX l) = listxLength l
+shxLength (ShX l) = listhLength l
shxRank :: ShX sh i -> SNat (Rank sh)
-shxRank (ShX l) = listxRank l
+shxRank (ShX l) = listhRank l
-- | The number of elements in an array described by this shape.
shxSize :: IShX sh -> Int
@@ -433,6 +555,7 @@ shxToList list = build (\(cons :: i -> is -> is) (nil :: is) ->
go (smn :$% sh) = fromSMayNat' smn `cons` go sh
in go list)
+-- If it ever matters for performance, this is unsafeCoercible.
shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i
shxFromSSX ZKX = ZSX
shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh))
@@ -447,35 +570,36 @@ shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh
shxFromSSX2 (SUnknown _ :!% _) = Nothing
shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
-shxAppend = coerce (listxAppend @_ @(SMayNat i SNat))
+shxAppend = coerce (listhAppend @_ @i)
-shxHead :: ShX (n : sh) i -> SMayNat i SNat n
-shxHead (ShX list) = listxHead list
+shxHead :: ShX (n : sh) i -> SMayNat i n
+shxHead (ShX list) = listhHead list
shxTail :: ShX (n : sh) i -> ShX sh i
-shxTail (ShX list) = ShX (listxTail list)
+shxTail (ShX list) = ShX (listhTail list)
shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i
-shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat))
+shxDropSSX = coerce (listhDrop @i @())
shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i
-shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j))
+shxDropIx (IxX ZX) long = long
+shxDropIx (IxX (_ ::% short)) long = case long of _ :$% long' -> shxDropIx (IxX short) long'
shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i
-shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))
+shxDropSh = coerce (listhDrop @i @i)
shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i
-shxInit = coerce (listxInit @(SMayNat i SNat))
+shxInit = coerce (listhInit @i)
-shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh))
-shxLast = coerce (listxLast @(SMayNat i SNat))
+shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i (Last (n : sh))
+shxLast = coerce (listhLast @i)
shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i
shxTakeSSX _ ZKX _ = ZSX
shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh
{-# INLINE shxZipWith #-}
-shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n)
+shxZipWith :: (forall n. SMayNat i n -> SMayNat j n -> SMayNat k n)
-> ShX sh i -> ShX sh j -> ShX sh k
shxZipWith _ ZSX ZSX = ZSX
shxZipWith f (i :$% is) (j :$% js) = f i j :$% shxZipWith f is js
@@ -523,20 +647,24 @@ shxCast' ssh sh = case shxCast ssh sh of
-- * Static mixed shapes
--- | The part of a shape that is statically known. (A newtype over 'ListX'.)
+-- | The part of a shape that is statically known. (A newtype over 'ListH'.)
type StaticShX :: [Maybe Nat] -> Type
-newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat))
- deriving (Eq, Ord)
+newtype StaticShX sh = StaticShX (ListH sh ())
+ deriving (NFData)
+
+instance Eq (StaticShX sh) where _ == _ = True
+instance Ord (StaticShX sh) where compare _ _ = EQ
pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
-pattern ZKX = StaticShX ZX
+pattern ZKX = StaticShX ZH
pattern (:!%)
:: forall {sh1}.
forall n sh. (n : sh ~ sh1)
- => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1
-pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i))
- where i :!% StaticShX shl = StaticShX (i ::% shl)
+ => SMayNat () n -> StaticShX sh -> StaticShX sh1
+pattern i :!% shl <- StaticShX (listhUncons -> Just (UnconsListHRes (StaticShX -> shl) i))
+ where i :!% StaticShX shl = case i of; SUnknown () -> StaticShX (() `ConsUnknown` shl); SKnown x -> StaticShX (x `ConsKnown` shl)
+
infixr 3 :!%
{-# COMPLETE ZKX, (:!%) #-}
@@ -545,22 +673,17 @@ infixr 3 :!%
deriving instance Show (StaticShX sh)
#else
instance Show (StaticShX sh) where
- showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
+ showsPrec _ (StaticShX l) = listhShow (fromSMayNat shows (shows . fromSNat)) l
#endif
-instance NFData (StaticShX sh) where
- rnf (StaticShX ZX) = ()
- rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l)
- rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l)
-
instance TestEquality StaticShX where
- testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2
+ testEquality (StaticShX l1) (StaticShX l2) = listhEqType l1 l2
ssxLength :: StaticShX sh -> Int
-ssxLength (StaticShX l) = listxLength l
+ssxLength (StaticShX l) = listhLength l
ssxRank :: StaticShX sh -> SNat (Rank sh)
-ssxRank (StaticShX l) = listxRank l
+ssxRank (StaticShX l) = listhRank l
-- | @ssxEqType = 'testEquality'@. Provided for consistency.
ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
@@ -570,26 +693,31 @@ ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
ssxAppend ZKX sh' = sh'
ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'
-ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n
-ssxHead (StaticShX list) = listxHead list
+ssxHead :: StaticShX (n : sh) -> SMayNat () n
+ssxHead (StaticShX list) = listhHead list
ssxTail :: StaticShX (n : sh) -> StaticShX sh
ssxTail (_ :!% ssh) = ssh
-ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh'
-ssxDropSSX = coerce (listxDrop @(SMayNat () SNat) @(SMayNat () SNat))
+ssxTakeIx :: forall sh sh' i. Proxy sh' -> IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh
+ssxTakeIx _ (IxX ZX) _ = ZKX
+ssxTakeIx proxy (IxX (_ ::% long)) short = case short of i :!% short' -> i :!% ssxTakeIx proxy (IxX long) short'
ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
-ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))
+ssxDropIx (IxX ZX) long = long
+ssxDropIx (IxX (_ ::% short)) long = case long of _ :!% long' -> ssxDropIx (IxX short) long'
ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
-ssxDropSh = coerce (listxDrop @(SMayNat () SNat) @(SMayNat i SNat))
+ssxDropSh = coerce (listhDrop @() @i)
+
+ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh'
+ssxDropSSX = coerce (listhDrop @() @())
ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh))
-ssxInit = coerce (listxInit @(SMayNat () SNat))
+ssxInit = coerce (listhInit @())
-ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh))
-ssxLast = coerce (listxLast @(SMayNat () SNat))
+ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () (Last (n : sh))
+ssxLast = coerce (listhLast @())
ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
ssxReplicate SZ = ZKX
@@ -632,18 +760,18 @@ type family Flatten' acc sh where
Flatten' acc (Just n : sh) = Flatten' (acc * n) sh
-- This function is currently unused
-ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh)
+ssxFlatten :: StaticShX sh -> SMayNat () (Flatten sh)
ssxFlatten = go (SNat @1)
where
- go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh)
+ go :: SNat acc -> StaticShX sh -> SMayNat () (Flatten' acc sh)
go acc ZKX = SKnown acc
go _ (SUnknown () :!% _) = SUnknown ()
go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh
-shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh)
+shxFlatten :: IShX sh -> SMayNat Int (Flatten sh)
shxFlatten = go (SNat @1)
where
- go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh)
+ go :: SNat acc -> IShX sh -> SMayNat Int (Flatten' acc sh)
go acc ZSX = SKnown acc
go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh)
go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs
index 065c9fd..2e0c1ca 100644
--- a/src/Data/Array/Nested/Permutation.hs
+++ b/src/Data/Array/Nested/Permutation.hs
@@ -172,6 +172,70 @@ type family DropLen ref l where
DropLen '[] l = l
DropLen (_ : ref) (_ : xs) = DropLen ref xs
+listhTakeLen :: forall i is sh. Perm is -> ListH sh i -> ListH (TakeLen is sh) i
+listhTakeLen PNil _ = ZH
+listhTakeLen (_ `PCons` is) (n `ConsUnknown` sh) = n `ConsUnknown` listhTakeLen is sh
+listhTakeLen (_ `PCons` is) (n `ConsKnown` sh) = n `ConsKnown` listhTakeLen is sh
+listhTakeLen (_ `PCons` _) ZH = error "Permutation longer than shape"
+
+listhDropLen :: forall i is sh. Perm is -> ListH sh i -> ListH (DropLen is sh) i
+listhDropLen PNil sh = sh
+listhDropLen (_ `PCons` is) (_ `ConsUnknown` sh) = listhDropLen is sh
+listhDropLen (_ `PCons` is) (_ `ConsKnown` sh) = listhDropLen is sh
+listhDropLen (_ `PCons` _) ZH = error "Permutation longer than shape"
+
+listhPermute :: forall i is sh. Perm is -> ListH sh i -> ListH (Permute is sh) i
+listhPermute PNil _ = ZH
+listhPermute (i `PCons` (is :: Perm is')) (sh :: ListH sh i) =
+ case listhIndex i sh of
+ SUnknown x -> x `ConsUnknown` listhPermute is sh
+ SKnown x -> x `ConsKnown` listhPermute is sh
+
+listhIndex :: forall i k sh. SNat k -> ListH sh i -> SMayNat i (Index k sh)
+listhIndex SZ (n `ConsUnknown` _) = SUnknown n
+listhIndex SZ (n `ConsKnown` _) = SKnown n
+listhIndex (SS (i :: SNat k')) ((_ :: i) `ConsUnknown` (sh :: ListH sh' i))
+ | Refl <- lemIndexSucc (Proxy @k') (Proxy @Nothing) (Proxy @sh')
+ = listhIndex i sh
+listhIndex (SS (i :: SNat k')) ((_ :: SNat n) `ConsKnown` (sh :: ListH sh' i))
+ | Refl <- lemIndexSucc (Proxy @k') (Proxy @(Just n)) (Proxy @sh')
+ = listhIndex i sh
+listhIndex _ ZH = error "Index into empty shape"
+
+listhPermutePrefix :: forall i is sh. Perm is -> ListH sh i -> ListH (PermutePrefix is sh) i
+listhPermutePrefix perm sh = listhAppend (listhPermute perm (listhTakeLen perm sh)) (listhDropLen perm sh)
+
+ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh)
+ssxTakeLen = coerce (listhTakeLen @())
+
+ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh)
+ssxDropLen = coerce (listhDropLen @())
+
+ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
+ssxPermute = coerce (listhPermute @())
+
+ssxIndex :: SNat i -> StaticShX sh -> SMayNat () (Index i sh)
+ssxIndex i = coerce (listhIndex @() i)
+
+ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
+ssxPermutePrefix = coerce (listhPermutePrefix @())
+
+shxTakeLen :: forall is sh. Perm is -> IShX sh -> IShX (TakeLen is sh)
+shxTakeLen = coerce (listhTakeLen @Int)
+
+shxDropLen :: Perm is -> IShX sh -> IShX (DropLen is sh)
+shxDropLen = coerce (listhDropLen @Int)
+
+shxPermute :: Perm is -> IShX sh -> IShX (Permute is sh)
+shxPermute = coerce (listhPermute @Int)
+
+shxIndex :: SNat i -> IShX sh -> SMayNat Int (Index i sh)
+shxIndex i = coerce (listhIndex @Int i)
+
+shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh)
+shxPermutePrefix = coerce (listhPermutePrefix @Int)
+
+
listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f
listxTakeLen PNil _ = ZX
listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh
@@ -185,14 +249,14 @@ listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape"
listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f
listxPermute PNil _ = ZX
listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) =
- listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh
+ listxIndex i sh ::% listxPermute is sh
-listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh)
-listxIndex _ _ SZ (n ::% _) = n
-listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f))
+listxIndex :: forall f i sh. SNat i -> ListX sh f -> f (Index i sh)
+listxIndex SZ (n ::% _) = n
+listxIndex (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f))
| Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = listxIndex p pT i sh
-listxIndex _ _ _ ZX = error "Index into empty shape"
+ = listxIndex i sh
+listxIndex _ ZX = error "Index into empty shape"
listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f
listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)
@@ -200,25 +264,6 @@ listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm s
ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i
ixxPermutePrefix = coerce (listxPermutePrefix @(Const i))
-ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh)
-ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat))
-
-ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh)
-ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))
-
-ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
-ssxPermute = coerce (listxPermute @(SMayNat () SNat))
-
-ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh)
-ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat () SNat) p1 p2 i)
-
-ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
-ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
-
-shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh)
-shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))
-
-
-- * Operations on permutations
permInverse :: Perm is
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index b6bee2e..36f49dc 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -1,8 +1,6 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
@@ -37,7 +35,6 @@ import Data.Kind (Type)
import Data.Proxy
import Data.Type.Equality
import GHC.Exts (Int(..), Int#, build, quotRemInt#)
-import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
@@ -216,8 +213,7 @@ listrPermutePrefix = \perm sh ->
type role IxR nominal representational
type IxR :: Nat -> Type -> Type
newtype IxR n i = IxR (ListR n i)
- deriving (Eq, Ord, Generic)
- deriving newtype (Functor, Foldable)
+ deriving (Eq, Ord, NFData, Functor, Foldable)
pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
pattern ZIR = IxR ZR
@@ -243,8 +239,6 @@ instance Show i => Show (IxR n i) where
showsPrec _ (IxR l) = listrShow shows l
#endif
-instance NFData i => NFData (IxR sh i)
-
ixrLength :: IxR sh i -> Int
ixrLength (IxR l) = listrLength l
@@ -310,8 +304,7 @@ ixrToLinear = \sh i -> go sh i 0
type role ShR nominal representational
type ShR :: Nat -> Type -> Type
newtype ShR n i = ShR (ListR n i)
- deriving (Eq, Ord, Generic)
- deriving newtype (Functor, Foldable)
+ deriving (Eq, Ord, NFData, Functor, Foldable)
pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
pattern ZSR = ShR ZR
@@ -335,8 +328,6 @@ instance Show i => Show (ShR n i) where
showsPrec _ (ShR l) = listrShow shows l
#endif
-instance NFData i => NFData (ShR sh i)
-
-- | This checks only whether the ranks are equal, not whether the actual
-- values are.
shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index d23a025..85042f2 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -246,9 +246,7 @@ sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shape
sreshape sh' (Shaped arr) = Shaped (mreshape (shxFromShS sh') arr)
sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a
-sflatten arr =
- case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff
- n@SNat -> sreshape (n :$$ ZSS) arr
+sflatten arr = sreshape (shsProduct (sshape arr) :$$ ZSS) arr
siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
siota sn = Shaped (miota sn)
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index e2ec416..b86bfe5 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -26,7 +26,6 @@ import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
-import Data.Type.Equality
import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
@@ -132,7 +131,7 @@ instance Elt a => Elt (Shaped sh a) where
type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
- mshapeTree (Shaped arr) = first shsFromShX (mshapeTree arr)
+ mshapeTree (Shaped arr) = first coerce (mshapeTree arr)
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
@@ -256,13 +255,4 @@ satan2Array = liftShaped2 matan2Array
sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
-sshape (Shaped arr) = shsFromShX (mshape arr)
-
--- Needed already here, but re-exported in Data.Array.Nested.Convert.
-shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh
-shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
-shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) =
- castWith (subst1 (sym (lemMapJustCons Refl))) $
- n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
- idx)
-shsFromShX (SUnknown _ :$% _) = error "impossible"
+sshape (Shaped arr) = coerce (mshape arr)
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index bfc6ad2..9d463a9 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -1,10 +1,9 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NoStarIsType #-}
@@ -39,7 +38,6 @@ import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Equality
import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict)
-import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
@@ -136,6 +134,17 @@ listsFromList topsh topl = go topsh topl
++ show (shsLength topsh) ++ ", list has length "
++ show (length topl) ++ ")"
+{-# INLINEABLE listsFromListS #-}
+listsFromListS :: ListS sh (Const i0) -> [i] -> ListS sh (Const i)
+listsFromListS topl0 topl = go topl0 topl
+ where
+ go :: ListS sh (Const i0) -> [i] -> ListS sh (Const i)
+ go ZS [] = ZS
+ go (_ ::$ l0) (i : is) = Const i ::$ go l0 is
+ go _ _ = error $ "listsFromListS: Mismatched list length (the model says "
+ ++ show (listsLength topl0) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+
{-# INLINEABLE listsToList #-}
listsToList :: ListS sh (Const i) -> [i]
listsToList list = build (\(cons :: i -> is -> is) (nil :: is) ->
@@ -185,16 +194,16 @@ listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
listsPermute PNil _ = ZS
listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
- case listsIndex (Proxy @is') (Proxy @sh) i sh of
+ case listsIndex i sh of
item -> item ::$ listsPermute is sh
--- TODO: remove this SNat when the KnownNat constaint in ListS is removed
-listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> f (Index i sh)
-listsIndex _ _ SZ (n ::$ _) = n
-listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
+-- TODO: try to remove this SNat now that the KnownNat constraint in ListS is removed
+listsIndex :: forall f i sh. SNat i -> ListS sh f -> f (Index i sh)
+listsIndex SZ (n ::$ _) = n
+listsIndex (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
| Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = listsIndex p pT i sh
-listsIndex _ _ _ ZS = error "Index into empty shape"
+ = listsIndex i sh
+listsIndex _ ZS = error "Index into empty shape"
listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh)
@@ -205,7 +214,7 @@ listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm pe
type role IxS nominal representational
type IxS :: [Nat] -> Type -> Type
newtype IxS sh i = IxS (ListS sh (Const i))
- deriving (Eq, Ord, Generic)
+ deriving (Eq, Ord, NFData)
pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
pattern ZIS = IxS ZS
@@ -247,8 +256,6 @@ instance Foldable (IxS sh) where
null ZIS = False
null _ = True
-instance NFData i => NFData (IxS sh i)
-
ixsLength :: IxS sh i -> Int
ixsLength (IxS l) = listsLength l
@@ -258,6 +265,10 @@ ixsRank (IxS l) = listsRank l
ixsFromList :: forall sh i. ShS sh -> [i] -> IxS sh i
ixsFromList = coerce (listsFromList @_ @i)
+{-# INLINEABLE ixsFromIxS #-}
+ixsFromIxS :: forall sh i0 i. IxS sh i0 -> [i] -> IxS sh i
+ixsFromIxS = coerce (listsFromListS @_ @i0 @i)
+
{-# INLINEABLE ixsToList #-}
ixsToList :: forall sh i. IxS sh i -> [i]
ixsToList = coerce (listsToList @_ @i)
@@ -278,11 +289,9 @@ ixsInit (IxS list) = IxS (listsInit list)
ixsLast :: IxS (n : sh) i -> i
ixsLast (IxS list) = getConst (listsLast list)
--- TODO: this takes a ShS because there are KnownNats inside IxS.
-ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i
-ixsCast ZSS ZIS = ZIS
-ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx
-ixsCast _ _ = error "ixsCast: ranks don't match"
+ixsCast :: IxS sh i -> IxS sh i
+ixsCast ZIS = ZIS
+ixsCast (i :.$ idx) = i :.$ ixsCast idx
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
ixsAppend = coerce (listsAppend @_ @(Const i))
@@ -318,21 +327,34 @@ ixsToLinear = \sh i -> go sh i 0
-- can also retrieve the array shape from a 'KnownShS' dictionary.
type role ShS nominal
type ShS :: [Nat] -> Type
-newtype ShS sh = ShS (ListS sh SNat)
- deriving (Generic)
+newtype ShS sh = ShS (ShX (MapJust sh) Int)
+ deriving (NFData)
instance Eq (ShS sh) where _ == _ = True
instance Ord (ShS sh) where compare _ _ = EQ
pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
-pattern ZSS = ShS ZS
+pattern ZSS <- ShS (matchZS -> Just Refl)
+ where ZSS = ShS ZSX
+
+matchZS :: forall sh f. ShX (MapJust sh) f -> Maybe (sh :~: '[])
+matchZS ZSX | Refl <- lemMapJustEmpty @sh Refl = Just Refl
+matchZS _ = Nothing
pattern (:$$)
:: forall {sh1}.
forall n sh. (n : sh ~ sh1)
=> SNat n -> ShS sh -> ShS sh1
-pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
- where i :$$ ShS shl = ShS (i ::$ shl)
+pattern i :$$ shl <- (shsUncons -> Just (UnconsShSRes i shl))
+ where i :$$ ShS shl = ShS (SKnown i :$% shl)
+
+data UnconsShSRes sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsShSRes (SNat n) (ShS sh)
+shsUncons :: forall sh1. ShS sh1 -> Maybe (UnconsShSRes sh1)
+shsUncons (ShS (SKnown x :$% sh'))
+ | Refl <- lemMapJustCons @sh1 Refl
+ = Just (UnconsShSRes x (ShS sh'))
+shsUncons (ShS _) = Nothing
infixr 3 :$$
@@ -342,15 +364,13 @@ infixr 3 :$$
deriving instance Show (ShS sh)
#else
instance Show (ShS sh) where
- showsPrec _ (ShS l) = listsShow (shows . fromSNat) l
+ showsPrec d (ShS shx) = showsPrec d shx
#endif
-instance NFData (ShS sh) where
- rnf (ShS ZS) = ()
- rnf (ShS (SNat ::$ l)) = rnf (ShS l)
-
instance TestEquality ShS where
- testEquality (ShS l1) (ShS l2) = listsEqType l1 l2
+ testEquality (ShS shx1) (ShS shx2) = case shxEqType shx1 shx2 of
+ Nothing -> Nothing
+ Just Refl -> Just unsafeCoerceRefl
-- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are
-- equal if and only if values are equal.)
@@ -358,10 +378,13 @@ shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh')
shsEqual = testEquality
shsLength :: ShS sh -> Int
-shsLength (ShS l) = listsLength l
+shsLength (ShS shx) = shxLength shx
-shsRank :: ShS sh -> SNat (Rank sh)
-shsRank (ShS l) = listsRank l
+shsRank :: forall sh. ShS sh -> SNat (Rank sh)
+shsRank (ShS shx) =
+ gcastWith (unsafeCoerceRefl
+ :: Rank (MapJust sh) :~: Rank sh) $
+ shxRank shx
shsSize :: ShS sh -> Int
shsSize ZSS = 1
@@ -391,31 +414,68 @@ shsToList topsh = build (\(cons :: Int -> is -> is) (nil :: is) ->
in go topsh)
shsHead :: ShS (n : sh) -> SNat n
-shsHead (ShS list) = listsHead list
+shsHead (ShS shx) = case shxHead shx of
+ SKnown SNat -> SNat
-shsTail :: ShS (n : sh) -> ShS sh
-shsTail (ShS list) = ShS (listsTail list)
+shsTail :: forall n sh. ShS (n : sh) -> ShS sh
+shsTail = coerce (shxTail @_ @_ @Int)
-shsInit :: ShS (n : sh) -> ShS (Init (n : sh))
-shsInit (ShS list) = ShS (listsInit list)
+shsInit :: forall n sh. ShS (n : sh) -> ShS (Init (n : sh))
+shsInit =
+ gcastWith (unsafeCoerceRefl
+ :: Init (Just n : MapJust sh) :~: MapJust (Init (n : sh))) $
+ coerce (shxInit @_ @_ @Int)
-shsLast :: ShS (n : sh) -> SNat (Last (n : sh))
-shsLast (ShS list) = listsLast list
+shsLast :: forall n sh. ShS (n : sh) -> SNat (Last (n : sh))
+shsLast (ShS shx) =
+ gcastWith (unsafeCoerceRefl
+ :: Last (Just n : MapJust sh) :~: Just (Last (n : sh))) $
+ case shxLast shx of
+ SKnown SNat -> SNat
shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')
-shsAppend = coerce (listsAppend @_ @SNat)
+shsAppend =
+ gcastWith (unsafeCoerceRefl
+ :: MapJust sh ++ MapJust sh' :~: MapJust (sh ++ sh')) $
+ coerce (shxAppend @_ @_ @Int)
+
+shsTakeLen :: forall is sh. Perm is -> ShS sh -> ShS (TakeLen is sh)
+shsTakeLen =
+ gcastWith (unsafeCoerceRefl
+ :: TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)) $
+ coerce shxTakeLen
-shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
-shsTakeLen = coerce (listsTakeLenPerm @SNat)
+shsDropLen :: forall is sh. Perm is -> ShS sh -> ShS (DropLen is sh)
+shsDropLen =
+ gcastWith (unsafeCoerceRefl
+ :: DropLen is (MapJust sh) :~: MapJust (DropLen is sh)) $
+ coerce shxDropLen
-shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
-shsPermute = coerce (listsPermute @SNat)
+shsPermute :: forall is sh. Perm is -> ShS sh -> ShS (Permute is sh)
+shsPermute =
+ gcastWith (unsafeCoerceRefl
+ :: Permute is (MapJust sh) :~: MapJust (Permute is sh)) $
+ coerce shxPermute
-shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
-shsIndex pis pshT i sh = coerce (listsIndex @SNat pis pshT i (coerce sh))
+shsIndex :: forall i sh. SNat i -> ShS sh -> SNat (Index i sh)
+shsIndex i sh =
+ gcastWith (unsafeCoerceRefl
+ :: Index i (MapJust sh) :~: Just (Index i sh)) $
+ case shxIndex i (coerce sh) of
+ SKnown SNat -> SNat
shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
-shsPermutePrefix = coerce (listsPermutePrefix @SNat)
+shsPermutePrefix perm (ShS shx)
+ {- TODO: here and elsewhere, solve the module dependency cycle and add this:
+ | Refl <- lemTakeLenMapJust perm sh
+ , Refl <- lemDropLenMapJust perm sh
+ , Refl <- lemPermuteMapJust perm sh
+ , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm sh)) (shsDropLen perm sh) -}
+ = gcastWith (unsafeCoerceRefl
+ :: Permute is (TakeLen is (MapJust sh))
+ ++ DropLen is (MapJust sh)
+ :~: MapJust (Permute is (TakeLen is sh) ++ DropLen is sh)) $
+ ShS (shxPermutePrefix perm shx)
type family Product sh where
Product '[] = 1
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index 6389e67..a9fc14c 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -285,7 +285,7 @@ sumInner ssh ssh' arr
go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a
go (XArray arr')
| Refl <- lemRankApp ssh ssh'F
- , let sn = listxRank (let StaticShX l = ssh in l)
+ , let sn = ssxRank ssh
= XArray (liftO1 (numEltSum1Inner sn) arr')
in go $