From 27f5fd474a85bd0c404215a1ce38ed378594e54b Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Thu, 9 Apr 2026 17:48:51 +0200 Subject: Get rid of ListH --- src/Data/Array/Nested.hs | 2 +- src/Data/Array/Nested/Mixed/Shape.hs | 327 +++++++++++++--------------------- src/Data/Array/Nested/Permutation.hs | 84 ++++----- src/Data/Array/Nested/Ranked/Shape.hs | 18 +- src/Data/Array/Nested/Shaped/Shape.hs | 28 +-- 5 files changed, 187 insertions(+), 272 deletions(-) (limited to 'src/Data/Array') diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 14de7f9..f022fe0 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -62,7 +62,7 @@ module Data.Array.Nested ( Mixed, ListX(ZX, (::%)), IxX(.., ZIX, (:.%)), IIxX, - ShX(.., ZSX, (:$%)), KnownShX(..), IShX, + ShX(.., (:$%)), KnownShX(..), IShX, StaticShX(.., ZKX, (:!%)), SMayNat(..), mshape, mrank, msize, mindex, mindexPartial, mgenerate, mgeneratePrim, msumOuter1Prim, msumAllPrim, diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index a5e3ced..671df2c 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -233,7 +233,7 @@ shxEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shxSize sh - 1]] fromLin _ _ _ = error "impossible" --- * Mixed shape-like lists to be used for ShX and StaticShX +-- * Mixed shapes data SMayNat i n where SUnknown :: i -> SMayNat i Nothing @@ -273,208 +273,158 @@ smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) -type role ListH nominal representational -type ListH :: [Maybe Nat] -> Type -> Type -data ListH sh i where - ZH :: ListH '[] i - ConsUnknown :: forall sh i. i -> ListH sh i -> ListH (Nothing : sh) i +type role ShX nominal representational +type ShX :: [Maybe Nat] -> Type -> Type +data ShX sh i where + ZSX :: ShX '[] i + ConsUnknown :: forall sh i. i -> ShX sh i -> ShX (Nothing : sh) i -- TODO: bring this UNPACK back when GHC no longer crashes: --- ConsKnown :: forall n sh i. {-# UNPACK #-} SNat n -> ListH sh i -> ListH (Just n : sh) i - ConsKnown :: forall n sh i. SNat n -> ListH sh i -> ListH (Just n : sh) i -deriving instance Ord i => Ord (ListH sh i) +-- ConsKnown :: forall n sh i. {-# UNPACK #-} SNat n -> ShX sh i -> ShX (Just n : sh) i + ConsKnown :: forall n sh i. SNat n -> ShX sh i -> ShX (Just n : sh) i +deriving instance Ord i => Ord (ShX sh i) -- A manually defined instance and this INLINEABLE is needed to specialize -- mdot1Inner (otherwise GHC warns specialization breaks down here). -instance Eq i => Eq (ListH sh i) where +instance Eq i => Eq (ShX sh i) where {-# INLINEABLE (==) #-} - ZH == ZH = True + ZSX == ZSX = True ConsUnknown i1 sh1 == ConsUnknown i2 sh2 = i1 == i2 && sh1 == sh2 ConsKnown _ sh1 == ConsKnown _ sh2 = sh1 == sh2 #ifdef OXAR_DEFAULT_SHOW_INSTANCES -deriving instance Show i => Show (ListH sh i) +deriving instance Show i => Show (ShX sh i) #else -instance Show i => Show (ListH sh i) where - showsPrec _ = listhShow shows +instance Show i => Show (ShX sh i) where + showsPrec _ l = shxShow (fromSMayNat shows (shows . fromSNat)) l #endif -instance NFData i => NFData (ListH sh i) where - rnf ZH = () +instance NFData i => NFData (ShX sh i) where + rnf ZSX = () rnf (x `ConsUnknown` l) = rnf x `seq` rnf l rnf (SNat `ConsKnown` l) = rnf l -data UnconsListHRes i sh1 = - forall n sh. (n : sh ~ sh1) => UnconsListHRes (SMayNat i n) (ListH sh i) -listhUncons :: ListH sh1 i -> Maybe (UnconsListHRes i sh1) -listhUncons (i `ConsUnknown` shl') = Just (UnconsListHRes (SUnknown i) shl') -listhUncons (i `ConsKnown` shl') = Just (UnconsListHRes (SKnown i) shl') -listhUncons ZH = Nothing +instance Functor (ShX sh) where + {-# INLINE fmap #-} + fmap f l = shxFmap (fromSMayNat (SUnknown . f) SKnown) l + +data UnconsShXRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsShXRes (SMayNat i n) (ShX sh i) +shxUncons :: ShX sh1 i -> Maybe (UnconsShXRes i sh1) +shxUncons (i `ConsUnknown` shl') = Just (UnconsShXRes (SUnknown i) shl') +shxUncons (i `ConsKnown` shl') = Just (UnconsShXRes (SKnown i) shl') +shxUncons ZSX = Nothing -- | This checks only whether the types are equal; if the elements of the list -- are not singletons, their values may still differ. This corresponds to -- 'testEquality', except on the penultimate type parameter. -listhEqType :: ListH sh i -> ListH sh' i -> Maybe (sh :~: sh') -listhEqType ZH ZH = Just Refl -listhEqType (_ `ConsUnknown` sh) (_ `ConsUnknown` sh') - | Just Refl <- listhEqType sh sh' +shxEqType :: ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') +shxEqType ZSX ZSX = Just Refl +shxEqType (_ `ConsUnknown` sh) (_ `ConsUnknown` sh') + | Just Refl <- shxEqType sh sh' = Just Refl -listhEqType (n `ConsKnown` sh) (m `ConsKnown` sh') +shxEqType (n `ConsKnown` sh) (m `ConsKnown` sh') | Just Refl <- testEquality n m - , Just Refl <- listhEqType sh sh' + , Just Refl <- shxEqType sh sh' = Just Refl -listhEqType _ _ = Nothing +shxEqType _ _ = Nothing -- | This checks whether the two lists actually contain equal values. This is -- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ -- in the @some@ package (except on the penultimate type parameter). -listhEqual :: Eq i => ListH sh i -> ListH sh' i -> Maybe (sh :~: sh') -listhEqual ZH ZH = Just Refl -listhEqual (n `ConsUnknown` sh) (m `ConsUnknown` sh') +shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') +shxEqual ZSX ZSX = Just Refl +shxEqual (n `ConsUnknown` sh) (m `ConsUnknown` sh') | n == m - , Just Refl <- listhEqual sh sh' + , Just Refl <- shxEqual sh sh' = Just Refl -listhEqual (n `ConsKnown` sh) (m `ConsKnown` sh') +shxEqual (n `ConsKnown` sh) (m `ConsKnown` sh') | Just Refl <- testEquality n m - , Just Refl <- listhEqual sh sh' + , Just Refl <- shxEqual sh sh' = Just Refl -listhEqual _ _ = Nothing - -{-# INLINE listhFmap #-} -listhFmap :: (forall n. SMayNat i n -> SMayNat j n) -> ListH sh i -> ListH sh j -listhFmap _ ZH = ZH -listhFmap f (x `ConsUnknown` xs) = case f (SUnknown x) of - SUnknown y -> y `ConsUnknown` listhFmap f xs -listhFmap f (x `ConsKnown` xs) = case f (SKnown x) of - SKnown y -> y `ConsKnown` listhFmap f xs - -{-# INLINE listhFoldMap #-} -listhFoldMap :: Monoid m => (forall n. SMayNat i n -> m) -> ListH sh i -> m -listhFoldMap _ ZH = mempty -listhFoldMap f (x `ConsUnknown` xs) = f (SUnknown x) <> listhFoldMap f xs -listhFoldMap f (x `ConsKnown` xs) = f (SKnown x) <> listhFoldMap f xs - -listhLength :: ListH sh i -> Int -listhLength = getSum . listhFoldMap (\_ -> Sum 1) - -listhRank :: ListH sh i -> SNat (Rank sh) -listhRank ZH = SNat -listhRank (_ `ConsUnknown` l) | SNat <- listhRank l = SNat -listhRank (_ `ConsKnown` l) | SNat <- listhRank l = SNat - -{-# INLINE listhShow #-} -listhShow :: forall sh i. (forall n. SMayNat i n -> ShowS) -> ListH sh i -> ShowS -listhShow f l = showString "[" . go "" l . showString "]" +shxEqual _ _ = Nothing + +{-# INLINE shxFmap #-} +shxFmap :: (forall n. SMayNat i n -> SMayNat j n) -> ShX sh i -> ShX sh j +shxFmap _ ZSX = ZSX +shxFmap f (x `ConsUnknown` xs) = case f (SUnknown x) of + SUnknown y -> y `ConsUnknown` shxFmap f xs +shxFmap f (x `ConsKnown` xs) = case f (SKnown x) of + SKnown y -> y `ConsKnown` shxFmap f xs + +{-# INLINE shxFoldMap #-} +shxFoldMap :: Monoid m => (forall n. SMayNat i n -> m) -> ShX sh i -> m +shxFoldMap _ ZSX = mempty +shxFoldMap f (x `ConsUnknown` xs) = f (SUnknown x) <> shxFoldMap f xs +shxFoldMap f (x `ConsKnown` xs) = f (SKnown x) <> shxFoldMap f xs + +shxLength :: ShX sh i -> Int +shxLength = getSum . shxFoldMap (\_ -> Sum 1) + +shxRank :: ShX sh i -> SNat (Rank sh) +shxRank ZSX = SNat +shxRank (_ `ConsUnknown` l) | SNat <- shxRank l = SNat +shxRank (_ `ConsKnown` l) | SNat <- shxRank l = SNat + +{-# INLINE shxShow #-} +shxShow :: forall sh i. (forall n. SMayNat i n -> ShowS) -> ShX sh i -> ShowS +shxShow f l = showString "[" . go "" l . showString "]" where - go :: String -> ListH sh' i -> ShowS - go _ ZH = id + go :: String -> ShX sh' i -> ShowS + go _ ZSX = id go prefix (x `ConsUnknown` xs) = showString prefix . f (SUnknown x) . go "," xs go prefix (x `ConsKnown` xs) = showString prefix . f (SKnown x) . go "," xs -listhHead :: ListH (mn ': sh) i -> SMayNat i mn -listhHead (i `ConsUnknown` _) = SUnknown i -listhHead (i `ConsKnown` _) = SKnown i - -listhTail :: ListH (n : sh) i -> ListH sh i -listhTail (_ `ConsUnknown` sh) = sh -listhTail (_ `ConsKnown` sh) = sh - -listhAppend :: ListH sh i -> ListH sh' i -> ListH (sh ++ sh') i -listhAppend ZH idx' = idx' -listhAppend (i `ConsUnknown` idx) idx' = i `ConsUnknown` listhAppend idx idx' -listhAppend (i `ConsKnown` idx) idx' = i `ConsKnown` listhAppend idx idx' - -listhDrop :: forall i j sh sh'. ListH sh j -> ListH (sh ++ sh') i -> ListH sh' i -listhDrop ZH long = long -listhDrop (_ `ConsUnknown` short) long = case long of - _ `ConsUnknown` long' -> listhDrop short long' -listhDrop (_ `ConsKnown` short) long = case long of - _ `ConsKnown` long' -> listhDrop short long' - -listhInit :: forall i n sh. ListH (n : sh) i -> ListH (Init (n : sh)) i -listhInit (i `ConsUnknown` sh@(_ `ConsUnknown` _)) = i `ConsUnknown` listhInit sh -listhInit (i `ConsUnknown` sh@(_ `ConsKnown` _)) = i `ConsUnknown` listhInit sh -listhInit (_ `ConsUnknown` ZH) = ZH -listhInit (i `ConsKnown` sh@(_ `ConsUnknown` _)) = i `ConsKnown` listhInit sh -listhInit (i `ConsKnown` sh@(_ `ConsKnown` _)) = i `ConsKnown` listhInit sh -listhInit (_ `ConsKnown` ZH) = ZH - -listhLast :: forall i n sh. ListH (n : sh) i -> SMayNat i (Last (n : sh)) -listhLast (_ `ConsUnknown` sh@(_ `ConsUnknown` _)) = listhLast sh -listhLast (_ `ConsUnknown` sh@(_ `ConsKnown` _)) = listhLast sh -listhLast (x `ConsUnknown` ZH) = SUnknown x -listhLast (_ `ConsKnown` sh@(_ `ConsUnknown` _)) = listhLast sh -listhLast (_ `ConsKnown` sh@(_ `ConsKnown` _)) = listhLast sh -listhLast (x `ConsKnown` ZH) = SKnown x +shxHead :: ShX (mn ': sh) i -> SMayNat i mn +shxHead (i `ConsUnknown` _) = SUnknown i +shxHead (i `ConsKnown` _) = SKnown i --- * Mixed shapes +shxTail :: ShX (n : sh) i -> ShX sh i +shxTail (_ `ConsUnknown` sh) = sh +shxTail (_ `ConsKnown` sh) = sh --- | This is a newtype over 'ListH'. -type role ShX nominal representational -type ShX :: [Maybe Nat] -> Type -> Type -newtype ShX sh i = ShX (ListH sh i) - deriving (Eq, Ord, NFData) +shxAppend :: ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i +shxAppend ZSX idx' = idx' +shxAppend (i `ConsUnknown` idx) idx' = i `ConsUnknown` shxAppend idx idx' +shxAppend (i `ConsKnown` idx) idx' = i `ConsKnown` shxAppend idx idx' -pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i -pattern ZSX = ShX ZH +shxDropSh :: forall sh sh' i j. ShX sh j -> ShX (sh ++ sh') i -> ShX sh' i +shxDropSh ZSX long = long +shxDropSh (_ `ConsUnknown` short) long = case long of + _ `ConsUnknown` long' -> shxDropSh short long' +shxDropSh (_ `ConsKnown` short) long = case long of + _ `ConsKnown` long' -> shxDropSh short long' + +shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i +shxDropSSX = coerce (shxDropSh @_ @_ @i @()) + +shxInit :: forall i n sh. ShX (n : sh) i -> ShX (Init (n : sh)) i +shxInit (i `ConsUnknown` sh@(_ `ConsUnknown` _)) = i `ConsUnknown` shxInit sh +shxInit (i `ConsUnknown` sh@(_ `ConsKnown` _)) = i `ConsUnknown` shxInit sh +shxInit (_ `ConsUnknown` ZSX) = ZSX +shxInit (i `ConsKnown` sh@(_ `ConsUnknown` _)) = i `ConsKnown` shxInit sh +shxInit (i `ConsKnown` sh@(_ `ConsKnown` _)) = i `ConsKnown` shxInit sh +shxInit (_ `ConsKnown` ZSX) = ZSX + +shxLast :: forall i n sh. ShX (n : sh) i -> SMayNat i (Last (n : sh)) +shxLast (_ `ConsUnknown` sh@(_ `ConsUnknown` _)) = shxLast sh +shxLast (_ `ConsUnknown` sh@(_ `ConsKnown` _)) = shxLast sh +shxLast (x `ConsUnknown` ZSX) = SUnknown x +shxLast (_ `ConsKnown` sh@(_ `ConsUnknown` _)) = shxLast sh +shxLast (_ `ConsKnown` sh@(_ `ConsKnown` _)) = shxLast sh +shxLast (x `ConsKnown` ZSX) = SKnown x pattern (:$%) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) => SMayNat i n -> ShX sh i -> ShX sh1 i -pattern i :$% shl <- ShX (listhUncons -> Just (UnconsListHRes i (ShX -> shl))) - where i :$% ShX shl = case i of; SUnknown x -> ShX (x `ConsUnknown` shl); SKnown x -> ShX (x `ConsKnown` shl) +pattern i :$% shl <- (shxUncons -> Just (UnconsShXRes i shl)) + where i :$% shl = case i of; SUnknown x -> x `ConsUnknown` shl; SKnown x -> x `ConsKnown` shl infixr 3 :$% {-# COMPLETE ZSX, (:$%) #-} type IShX sh = ShX sh Int -#ifdef OXAR_DEFAULT_SHOW_INSTANCES -deriving instance Show i => Show (ShX sh i) -#else -instance Show i => Show (ShX sh i) where - showsPrec _ (ShX l) = listhShow (fromSMayNat shows (shows . fromSNat)) l -#endif - -instance Functor (ShX sh) where - {-# INLINE fmap #-} - fmap f (ShX l) = ShX (listhFmap (fromSMayNat (SUnknown . f) SKnown) l) - --- | This checks only whether the types are equal; unknown dimensions might --- still differ. This corresponds to 'testEquality', except on the penultimate --- type parameter. -shxEqType :: ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') -shxEqType ZSX ZSX = Just Refl -shxEqType (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') - | Just Refl <- sameNat n m - , Just Refl <- shxEqType sh sh' - = Just Refl -shxEqType (SUnknown _ :$% sh) (SUnknown _ :$% sh') - | Just Refl <- shxEqType sh sh' - = Just Refl -shxEqType _ _ = Nothing - --- | This checks whether all dimensions have the same value. This is more than --- 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ in the --- @some@ package (except on the penultimate type parameter). -shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') -shxEqual ZSX ZSX = Just Refl -shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') - | Just Refl <- sameNat n m - , Just Refl <- shxEqual sh sh' - = Just Refl -shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') - | i == j - , Just Refl <- shxEqual sh sh' - = Just Refl -shxEqual _ _ = Nothing - -shxLength :: ShX sh i -> Int -shxLength (ShX l) = listhLength l - -shxRank :: ShX sh i -> SNat (Rank sh) -shxRank (ShX l) = listhRank l - -- | The number of elements in an array described by this shape. shxSize :: IShX sh -> Int shxSize ZSX = 1 @@ -483,22 +433,22 @@ shxSize (n :$% sh) = fromSMayNat' n * shxSize sh -- We don't report the size of the list in case of errors in order not to retain the list. {-# INLINEABLE shxFromList #-} shxFromList :: StaticShX sh -> [Int] -> IShX sh -shxFromList (StaticShX topssh) topl = ShX $ go topssh topl +shxFromList (StaticShX topssh) topl = go topssh topl where - go :: ListH sh' () -> [Int] -> ListH sh' Int - go ZH [] = ZH - go ZH _ = error $ "shxFromList: List too long (type says " ++ show (listhLength topssh) ++ ")" + go :: ShX sh' () -> [Int] -> ShX sh' Int + go ZSX [] = ZSX + go ZSX _ = error $ "shxFromList: List too long (type says " ++ show (shxLength topssh) ++ ")" go (ConsKnown sn sh) (i : is) | i == fromSNat' sn = ConsKnown sn (go sh is) | otherwise = error "shxFromList: Value does not match typing" go (ConsUnknown () sh) (i : is) = ConsUnknown i (go sh is) - go _ _ = error $ "shxFromList: List too short (type says " ++ show (listhLength topssh) ++ ")" + go _ _ = error $ "shxFromList: List too short (type says " ++ show (shxLength topssh) ++ ")" {-# INLINEABLE shxToList #-} shxToList :: IShX sh -> [Int] -shxToList (ShX l) = build (\(cons :: i -> is -> is) (nil :: is) -> - let go :: ListH sh Int -> is - go ZH = nil +shxToList l = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ShX sh Int -> is + go ZSX = nil go (ConsUnknown i rest) = i `cons` go rest go (ConsKnown sn rest) = fromSNat' sn `cons` go rest in go l) @@ -517,15 +467,6 @@ shxFromSSX2 ZKX = Just ZSX shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh shxFromSSX2 (SUnknown _ :!% _) = Nothing -shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i -shxAppend = coerce (listhAppend @_ @i) - -shxHead :: ShX (n : sh) i -> SMayNat i n -shxHead (ShX list) = listhHead list - -shxTail :: ShX (n : sh) i -> ShX sh i -shxTail (ShX list) = ShX (listhTail list) - shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i shxTakeSSX _ ZKX _ = ZSX shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh @@ -534,12 +475,6 @@ shxTakeSh :: forall sh sh' i proxy. proxy sh' -> ShX sh i -> ShX (sh ++ sh') i - shxTakeSh _ ZSX _ = ZSX shxTakeSh p (_ :$% ssh1) (n :$% sh) = n :$% shxTakeSh p ssh1 sh -shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i -shxDropSSX = coerce (listhDrop @i @()) - -shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i -shxDropSh = coerce (listhDrop @i @i) - {-# INLINEABLE shxTakeIx #-} shxTakeIx :: forall sh sh' i j. Proxy sh' -> IxX sh j -> ShX (sh ++ sh') i -> ShX sh i shxTakeIx _ (IxX ZX) _ = ZSX @@ -550,12 +485,6 @@ shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i shxDropIx ZIX long = long shxDropIx (_ :.% short) long = case long of _ :$% long' -> shxDropIx short long' -shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i -shxInit = coerce (listhInit @i) - -shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i (Last (n : sh)) -shxLast = coerce (listhLast @i) - {-# INLINE shxZipWith #-} shxZipWith :: (forall n. SMayNat i n -> SMayNat j n -> SMayNat k n) -> ShX sh i -> ShX sh j -> ShX sh k @@ -589,22 +518,22 @@ shxCast' ssh sh = case shxCast ssh sh of -- * Static mixed shapes --- | The part of a shape that is statically known. (A newtype over 'ListH'.) +-- | The part of a shape that is statically known. (A newtype over 'ShX'.) type StaticShX :: [Maybe Nat] -> Type -newtype StaticShX sh = StaticShX (ListH sh ()) +newtype StaticShX sh = StaticShX (ShX sh ()) deriving (NFData) instance Eq (StaticShX sh) where _ == _ = True instance Ord (StaticShX sh) where compare _ _ = EQ pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh -pattern ZKX = StaticShX ZH +pattern ZKX = StaticShX ZSX pattern (:!%) :: forall {sh1}. forall n sh. (n : sh ~ sh1) => SMayNat () n -> StaticShX sh -> StaticShX sh1 -pattern i :!% shl <- StaticShX (listhUncons -> Just (UnconsListHRes i (StaticShX -> shl))) +pattern i :!% shl <- StaticShX (shxUncons -> Just (UnconsShXRes i (StaticShX -> shl))) where i :!% StaticShX shl = case i of; SUnknown () -> StaticShX (() `ConsUnknown` shl); SKnown x -> StaticShX (x `ConsKnown` shl) infixr 3 :!% @@ -615,30 +544,30 @@ infixr 3 :!% deriving instance Show (StaticShX sh) #else instance Show (StaticShX sh) where - showsPrec _ (StaticShX l) = listhShow (fromSMayNat shows (shows . fromSNat)) l + showsPrec _ (StaticShX l) = shxShow (fromSMayNat shows (shows . fromSNat)) l #endif instance TestEquality StaticShX where - testEquality (StaticShX l1) (StaticShX l2) = listhEqType l1 l2 + testEquality (StaticShX l1) (StaticShX l2) = shxEqType l1 l2 ssxLength :: StaticShX sh -> Int -ssxLength (StaticShX l) = listhLength l +ssxLength (StaticShX l) = shxLength l ssxRank :: StaticShX sh -> SNat (Rank sh) -ssxRank (StaticShX l) = listhRank l +ssxRank (StaticShX l) = shxRank l -- | @ssxEqType = 'testEquality'@. Provided for consistency. ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') ssxEqType = testEquality ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') -ssxAppend = coerce (listhAppend @_ @()) +ssxAppend = coerce (shxAppend @_ @()) ssxHead :: StaticShX (n : sh) -> SMayNat () n -ssxHead (StaticShX list) = listhHead list +ssxHead (StaticShX list) = shxHead list ssxTail :: StaticShX (n : sh) -> StaticShX sh -ssxTail (StaticShX list) = StaticShX (listhTail list) +ssxTail (StaticShX list) = StaticShX (shxTail list) ssxTakeIx :: forall sh sh' i. Proxy sh' -> IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh ssxTakeIx _ (IxX ZX) _ = ZKX @@ -649,16 +578,16 @@ ssxDropIx (IxX ZX) long = long ssxDropIx (IxX (_ ::% short)) long = case long of _ :!% long' -> ssxDropIx (IxX short) long' ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropSh = coerce (listhDrop @() @i) +ssxDropSh = coerce (shxDropSh @_ @_ @() @i) ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropSSX = coerce (listhDrop @() @()) +ssxDropSSX = coerce (shxDropSh @_ @_ @() @()) ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) -ssxInit = coerce (listhInit @()) +ssxInit = coerce (shxInit @()) ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () (Last (n : sh)) -ssxLast = coerce (listhLast @()) +ssxLast = coerce (shxLast @()) ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) ssxReplicate SZ = ZKX diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs index b6e5f47..ee79ecf 100644 --- a/src/Data/Array/Nested/Permutation.hs +++ b/src/Data/Array/Nested/Permutation.hs @@ -172,68 +172,54 @@ type family DropLen ref l where DropLen '[] l = l DropLen (_ : ref) (_ : xs) = DropLen ref xs -listhTakeLenPerm :: forall i is sh. Perm is -> ListH sh i -> ListH (TakeLen is sh) i -listhTakeLenPerm PNil _ = ZH -listhTakeLenPerm (_ `PCons` is) (n `ConsUnknown` sh) = n `ConsUnknown` listhTakeLenPerm is sh -listhTakeLenPerm (_ `PCons` is) (n `ConsKnown` sh) = n `ConsKnown` listhTakeLenPerm is sh -listhTakeLenPerm (_ `PCons` _) ZH = error "Permutation longer than shape" - -listhDropLenPerm :: forall i is sh. Perm is -> ListH sh i -> ListH (DropLen is sh) i -listhDropLenPerm PNil sh = sh -listhDropLenPerm (_ `PCons` is) (_ `ConsUnknown` sh) = listhDropLenPerm is sh -listhDropLenPerm (_ `PCons` is) (_ `ConsKnown` sh) = listhDropLenPerm is sh -listhDropLenPerm (_ `PCons` _) ZH = error "Permutation longer than shape" - -listhPermute :: forall i is sh. Perm is -> ListH sh i -> ListH (Permute is sh) i -listhPermute PNil _ = ZH -listhPermute (i `PCons` (is :: Perm is')) (sh :: ListH sh i) = - case listhIndex i sh of - SUnknown x -> x `ConsUnknown` listhPermute is sh - SKnown x -> x `ConsKnown` listhPermute is sh - -listhIndex :: forall i k sh. SNat k -> ListH sh i -> SMayNat i (Index k sh) -listhIndex SZ (n `ConsUnknown` _) = SUnknown n -listhIndex SZ (n `ConsKnown` _) = SKnown n -listhIndex (SS (i :: SNat k')) ((_ :: i) `ConsUnknown` (sh :: ListH sh' i)) +shxTakeLenPerm :: forall i is sh. Perm is -> ShX sh i -> ShX (TakeLen is sh) i +shxTakeLenPerm PNil _ = ZSX +shxTakeLenPerm (_ `PCons` is) (n `ConsUnknown` sh) = n `ConsUnknown` shxTakeLenPerm is sh +shxTakeLenPerm (_ `PCons` is) (n `ConsKnown` sh) = n `ConsKnown` shxTakeLenPerm is sh +shxTakeLenPerm (_ `PCons` _) ZSX = error "Permutation longer than shape" + +shxDropLenPerm :: forall i is sh. Perm is -> ShX sh i -> ShX (DropLen is sh) i +shxDropLenPerm PNil sh = sh +shxDropLenPerm (_ `PCons` is) (_ `ConsUnknown` sh) = shxDropLenPerm is sh +shxDropLenPerm (_ `PCons` is) (_ `ConsKnown` sh) = shxDropLenPerm is sh +shxDropLenPerm (_ `PCons` _) ZSX = error "Permutation longer than shape" + +shxPermute :: forall i is sh. Perm is -> ShX sh i -> ShX (Permute is sh) i +shxPermute PNil _ = ZSX +shxPermute (i `PCons` (is :: Perm is')) (sh :: ShX sh i) = + case shxIndex i sh of + SUnknown x -> x `ConsUnknown` shxPermute is sh + SKnown x -> x `ConsKnown` shxPermute is sh + +shxIndex :: forall i k sh. SNat k -> ShX sh i -> SMayNat i (Index k sh) +shxIndex SZ (n `ConsUnknown` _) = SUnknown n +shxIndex SZ (n `ConsKnown` _) = SKnown n +shxIndex (SS (i :: SNat k')) ((_ :: i) `ConsUnknown` (sh :: ShX sh' i)) | Refl <- lemIndexSucc (Proxy @k') (Proxy @Nothing) (Proxy @sh') - = listhIndex i sh -listhIndex (SS (i :: SNat k')) ((_ :: SNat n) `ConsKnown` (sh :: ListH sh' i)) + = shxIndex i sh +shxIndex (SS (i :: SNat k')) ((_ :: SNat n) `ConsKnown` (sh :: ShX sh' i)) | Refl <- lemIndexSucc (Proxy @k') (Proxy @(Just n)) (Proxy @sh') - = listhIndex i sh -listhIndex _ ZH = error "Index into empty shape" + = shxIndex i sh +shxIndex _ ZSX = error "Index into empty shape" + +shxPermutePrefix :: forall i is sh. Perm is -> ShX sh i -> ShX (PermutePrefix is sh) i +shxPermutePrefix perm sh = shxAppend (shxPermute perm (shxTakeLenPerm perm sh)) (shxDropLenPerm perm sh) -listhPermutePrefix :: forall i is sh. Perm is -> ListH sh i -> ListH (PermutePrefix is sh) i -listhPermutePrefix perm sh = listhAppend (listhPermute perm (listhTakeLenPerm perm sh)) (listhDropLenPerm perm sh) ssxTakeLenPerm :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) -ssxTakeLenPerm = coerce (listhTakeLenPerm @()) +ssxTakeLenPerm = coerce (shxTakeLenPerm @()) ssxDropLenPerm :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) -ssxDropLenPerm = coerce (listhDropLenPerm @()) +ssxDropLenPerm = coerce (shxDropLenPerm @()) ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) -ssxPermute = coerce (listhPermute @()) +ssxPermute = coerce (shxPermute @()) ssxIndex :: SNat k -> StaticShX sh -> SMayNat () (Index k sh) -ssxIndex k = coerce (listhIndex @() k) +ssxIndex k = coerce (shxIndex @() k) ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) -ssxPermutePrefix = coerce (listhPermutePrefix @()) - -shxTakeLenPerm :: forall is sh. Perm is -> IShX sh -> IShX (TakeLen is sh) -shxTakeLenPerm = coerce (listhTakeLenPerm @Int) - -shxDropLenPerm :: Perm is -> IShX sh -> IShX (DropLen is sh) -shxDropLenPerm = coerce (listhDropLenPerm @Int) - -shxPermute :: Perm is -> IShX sh -> IShX (Permute is sh) -shxPermute = coerce (listhPermute @Int) - -shxIndex :: forall k sh i. SNat k -> ShX sh i -> SMayNat i (Index k sh) -shxIndex k = coerce (listhIndex @i k) - -shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) -shxPermutePrefix = coerce (listhPermutePrefix @Int) +ssxPermutePrefix = coerce (shxPermutePrefix @()) listxTakeLenPerm :: forall i is sh. Perm is -> ListX sh i -> ListX (TakeLen is sh) i diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 690b7da..a352eb3 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -366,10 +366,10 @@ shrSize (ShR sh) = shxSize sh -- We don't report the size of the list in case of errors in order not to retain the list. {-# INLINEABLE shrFromList #-} shrFromList :: SNat n -> [Int] -> IShR n -shrFromList snat topl = ShR $ ShX $ go snat topl +shrFromList snat topl = ShR $ go snat topl where - go :: SNat n -> [Int] -> ListH (Replicate n Nothing) Int - go SZ [] = ZH + go :: SNat n -> [Int] -> ShX (Replicate n Nothing) Int + go SZ [] = ZSX go SZ _ = error $ "shrFromList: List too long (type says " ++ show (fromSNat' snat) ++ ")" go (SS sn :: SNat n1) (i : is) | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ConsUnknown i (go sn is) go _ _ = error $ "shrFromList: List too short (type says " ++ show (fromSNat' snat) ++ ")" @@ -377,9 +377,9 @@ shrFromList snat topl = ShR $ ShX $ go snat topl -- This is equivalent to but faster than @coerce shxToList@. {-# INLINEABLE shrToList #-} shrToList :: IShR n -> [Int] -shrToList (ShR (ShX l)) = build (\(cons :: i -> is -> is) (nil :: is) -> - let go :: ListH sh Int -> is - go ZH = nil +shrToList (ShR l) = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ShX sh Int -> is + go ZSX = nil go (ConsUnknown i rest) = i `cons` go rest go ConsKnown{} = error "shrToList: impossible case" in go l) @@ -411,7 +411,7 @@ shrInit = -- TODO: change this and all other unsafeCoerceRefl to lemmas: gcastWith (unsafeCoerceRefl :: Init (Replicate (n + 1) (Nothing @Nat)) :~: Replicate n Nothing) $ - coerce (shxInit @_ @_ @i) + coerce (shxInit @i) shrLast :: forall n i. ShR (n + 1) i -> i shrLast (ShR sh) @@ -431,7 +431,7 @@ shrAppend = -- lemReplicatePlusApp requires an SNat gcastWith (unsafeCoerceRefl :: Replicate n (Nothing @Nat) ++ Replicate m Nothing :~: Replicate (n + m) Nothing) $ - coerce (shxAppend @_ @_ @i) + coerce (shxAppend @_ @i) {-# INLINE shrZipWith #-} shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k @@ -447,7 +447,7 @@ shrSplitAt (SS m) (n :$: sh) = (\(pre, post) -> (n :$: pre, post)) (shrSplitAt m shrSplitAt SS{} ZSR = error "m' + 1 <= 0" shrIndex :: forall k sh i. SNat k -> ShR sh i -> i -shrIndex k (ShR sh) = case shxIndex @_ @_ @i k sh of +shrIndex k (ShR sh) = case shxIndex @i k sh of SUnknown i -> i SKnown{} -> error "shrIndex: impossible SKnown" diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 622ab97..3d4bf31 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -326,23 +326,23 @@ shsSize (ShS sh) = shxSize sh -- in case of errors in order not to retain the list. {-# INLINEABLE shsFromList #-} shsFromList :: ShS sh -> [Int] -> ShS sh -shsFromList sh0@(ShS (ShX topsh)) topl = go topsh topl `seq` sh0 +shsFromList sh0@(ShS topsh) topl = go topsh topl `seq` sh0 where - go :: ListH sh' Int -> [Int] -> () - go ZH [] = () - go ZH _ = error $ "shsFromList: List too long (type says " ++ show (listhLength topsh) ++ ")" + go :: ShX sh' Int -> [Int] -> () + go ZSX [] = () + go ZSX _ = error $ "shsFromList: List too long (type says " ++ show (shxLength topsh) ++ ")" go (ConsKnown sn sh) (i : is) | i == fromSNat' sn = go sh is | otherwise = error "shsFromList: Value does not match typing" go ConsUnknown{} _ = error "shsFromList: impossible case" - go _ _ = error $ "shsFromList: List too short (type says " ++ show (listhLength topsh) ++ ")" + go _ _ = error $ "shsFromList: List too short (type says " ++ show (shxLength topsh) ++ ")" -- This is equivalent to but faster than @coerce shxToList@. {-# INLINEABLE shsToList #-} shsToList :: ShS sh -> [Int] -shsToList (ShS (ShX l)) = build (\(cons :: i -> is -> is) (nil :: is) -> - let go :: ListH sh Int -> is - go ZH = nil +shsToList (ShS l) = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ShX sh Int -> is + go ZSX = nil go ConsUnknown{} = error "shsToList: impossible case" go (ConsKnown snat rest) = fromSNat' snat `cons` go rest in go l) @@ -368,7 +368,7 @@ shsInit :: forall n sh. ShS (n : sh) -> ShS (Init (n : sh)) shsInit = gcastWith (unsafeCoerceRefl :: Init (Just n : MapJust sh) :~: MapJust (Init (n : sh))) $ - coerce (shxInit @_ @_ @Int) + coerce (shxInit @Int) shsLast :: forall n sh. ShS (n : sh) -> SNat (Last (n : sh)) shsLast (ShS shx) = @@ -381,31 +381,31 @@ shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh') shsAppend = gcastWith (unsafeCoerceRefl :: MapJust sh ++ MapJust sh' :~: MapJust (sh ++ sh')) $ - coerce (shxAppend @_ @_ @Int) + coerce (shxAppend @_ @Int) shsTakeLenPerm :: forall is sh. Perm is -> ShS sh -> ShS (TakeLen is sh) shsTakeLenPerm = gcastWith (unsafeCoerceRefl :: TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)) $ - coerce shxTakeLenPerm + coerce (shxTakeLenPerm @Int) shsDropLenPerm :: forall is sh. Perm is -> ShS sh -> ShS (DropLen is sh) shsDropLenPerm = gcastWith (unsafeCoerceRefl :: DropLen is (MapJust sh) :~: MapJust (DropLen is sh)) $ - coerce shxDropLenPerm + coerce (shxDropLenPerm @Int) shsPermute :: forall is sh. Perm is -> ShS sh -> ShS (Permute is sh) shsPermute = gcastWith (unsafeCoerceRefl :: Permute is (MapJust sh) :~: MapJust (Permute is sh)) $ - coerce shxPermute + coerce (shxPermute @Int) shsIndex :: forall i sh. SNat i -> ShS sh -> SNat (Index i sh) shsIndex i (ShS sh) = gcastWith (unsafeCoerceRefl :: Index i (MapJust sh) :~: Just (Index i sh)) $ - case shxIndex @_ @_ @Int i sh of + case shxIndex @Int i sh of SKnown SNat -> SNat shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) -- cgit v1.3