diff options
Diffstat (limited to 'src/Data')
| -rw-r--r-- | src/Data/Array/Nested/Convert.hs | 36 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 2 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 278 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Permutation.hs | 95 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 13 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 4 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 14 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 154 | ||||
| -rw-r--r-- | src/Data/Array/XArray.hs | 2 |
9 files changed, 403 insertions, 195 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 8c88d23..3706105 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -38,9 +38,11 @@ module Data.Array.Nested.Convert ( ) where import Control.Category +import Data.Coerce (coerce) import Data.Proxy import Data.Type.Equality import GHC.TypeLits +import Unsafe.Coerce (unsafeCoerce) import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed @@ -56,12 +58,10 @@ import Data.Array.Nested.Types -- * To ranked ixrFromIxS :: IxS sh i -> IxR (Rank sh) i -ixrFromIxS ZIS = ZIR -ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix +ixrFromIxS = unsafeCoerce ixrFromIxX :: IxX sh i -> IxR (Rank sh) i -ixrFromIxX ZIX = ZIR -ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx +ixrFromIxX = unsafeCoerce shrFromShS :: ShS sh -> IShR (Rank sh) shrFromShS ZSS = ZSR @@ -75,12 +75,11 @@ shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh -- * To shaped --- TODO: these take a ShS because there are KnownNats inside IxS. - +-- TODO: remove the ShS now that no KnownNats is inside IxS. ixsFromIxR :: ShS sh -> IxR (Rank sh) i -> IxS sh i -ixsFromIxR ZSS ZIR = ZIS -ixsFromIxR (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR sh idx +ixsFromIxR _ = unsafeCoerce +-- TODO: if possible, remove the ShS now that no KnownNats is inside IxS. -- | Performs a runtime check that @n@ matches @Rank sh@. Equivalent to the -- following, but more efficient: -- @@ -90,11 +89,11 @@ ixsFromIxR' ZSS ZIR = ZIS ixsFromIxR' (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR' sh idx ixsFromIxR' _ _ = error "ixsFromIxR': index rank does not match shape rank" --- TODO: this takes a ShS because there are KnownNats inside IxS. +-- TODO: remove the ShS now that no KnownNats is inside IxS. ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i -ixsFromIxX ZSS ZIX = ZIS -ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx +ixsFromIxX _ = unsafeCoerce +-- TODO: if possible, remove the ShS now that no KnownNats is inside IxS. -- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to -- the following, but more efficient: -- @@ -113,7 +112,8 @@ withShsFromShR (n :$: sh) k = Just sn@SNat -> k (sn :$$ sh') Nothing -> error $ "withShsFromShR: negative dimension size (" ++ show n ++ ")" --- shsFromShX re-exported +shsFromShX :: IShX (MapJust sh) -> ShS sh +shsFromShX = coerce -- | Produce an existential 'ShS' from an 'IShX'. If you already know that -- @sh'@ is @MapJust@ of something, use 'shsFromShX' instead. @@ -128,6 +128,7 @@ withShsFromShX (SUnknown n :$% sh) k = Just sn@SNat -> k (sn :$$ sh') Nothing -> error $ "withShsFromShX: negative SUnknown dimension size (" ++ show n ++ ")" +-- If it ever matters for performance, this is unsafeCoercible. shsFromSSX :: StaticShX (MapJust sh) -> ShS sh shsFromSSX = shsFromShX Prelude.. shxFromSSX @@ -136,14 +137,10 @@ shsFromSSX = shsFromShX Prelude.. shxFromSSX -- * To mixed ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i -ixxFromIxR ZIR = ZIX -ixxFromIxR (n :.: (idx :: IxR m i)) = - castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) (Proxy @m))) - (n :.% ixxFromIxR idx) +ixxFromIxR = unsafeCoerce ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i -ixxFromIxS ZIS = ZIX -ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh +ixxFromIxS = unsafeCoerce shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i shxFromShR ZSR = ZSX @@ -152,8 +149,7 @@ shxFromShR (n :$: (idx :: ShR m i)) = (SUnknown n :$% shxFromShR idx) shxFromShS :: ShS sh -> IShX (MapJust sh) -shxFromShS ZSS = ZSX -shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh +shxFromShS = coerce -- ixxCast re-exported -- shxCast re-exported diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 54b2a9f..eb05eaa 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -816,7 +816,7 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 sn :$% sh = mshape arr1 sm :$% _ = mshape arr2 ssh = ssxFromShX sh - snm :: SMayNat () SNat (AddMaybe n m) + snm :: SMayNat () (AddMaybe n m) snm = case (sn, sm) of (SUnknown{}, _) -> SUnknown () (SKnown{}, SUnknown{}) -> SUnknown () diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index b1b4f81..7c79f8b 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -1,9 +1,10 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE NoStarIsType #-} @@ -35,9 +36,9 @@ import Data.Functor.Const import Data.Functor.Product import Data.Kind (Constraint, Type) import Data.Monoid (Sum(..)) +import Data.Proxy import Data.Type.Equality import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict) -import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits @@ -56,7 +57,7 @@ type family Rank sh where Rank (_ : sh) = Rank sh + 1 --- * Mixed lists +-- * Mixed lists to be used IxX and shaped and ranked lists and indexes type role ListX nominal representational type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type @@ -189,7 +190,7 @@ listxZipWith f (i ::% is) (j ::% js) = f i j ::% listxZipWith f is js type role IxX nominal representational type IxX :: [Maybe Nat] -> Type -> Type newtype IxX sh i = IxX (ListX sh (Const i)) - deriving (Eq, Ord, Generic) + deriving (Eq, Ord, NFData) pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i pattern ZIX = IxX ZX @@ -229,8 +230,6 @@ instance Foldable (IxX sh) where null ZIX = False null _ = True -instance NFData i => NFData (IxX sh i) - ixxLength :: IxX sh i -> Int ixxLength (IxX l) = listxLength l @@ -295,32 +294,32 @@ ixxToLinear = \sh i -> go sh i 0 go (n :$% sh) (i :.% ix) a = go sh ix (fromIntegral (fromSMayNat' n) * a + i) --- * Mixed shapes +-- * Mixed shape-like lists to be used for ShX and StaticShX -data SMayNat i f n where - SUnknown :: i -> SMayNat i f Nothing - SKnown :: f n -> SMayNat i f (Just n) -deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) -deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) -deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) +data SMayNat i n where + SUnknown :: i -> SMayNat i Nothing + SKnown :: SNat n -> SMayNat i (Just n) +deriving instance Show i => Show (SMayNat i n) +deriving instance Eq i => Eq (SMayNat i n) +deriving instance Ord i => Ord (SMayNat i n) -instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where +instance (NFData i, forall m. NFData (SNat m)) => NFData (SMayNat i n) where rnf (SUnknown i) = rnf i rnf (SKnown x) = rnf x -instance TestEquality f => TestEquality (SMayNat i f) where +instance TestEquality (SMayNat i) where testEquality SUnknown{} SUnknown{} = Just Refl testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl testEquality _ _ = Nothing {-# INLINE fromSMayNat #-} fromSMayNat :: (n ~ Nothing => i -> r) - -> (forall m. n ~ Just m => f m -> r) - -> SMayNat i f n -> r + -> (forall m. n ~ Just m => SNat m -> r) + -> SMayNat i n -> r fromSMayNat f _ (SUnknown i) = f i fromSMayNat _ g (SKnown s) = g s -fromSMayNat' :: SMayNat Int SNat n -> Int +fromSMayNat' :: SMayNat Int n -> Int fromSMayNat' = fromSMayNat id fromSNat' type family AddMaybe n m where @@ -328,27 +327,155 @@ type family AddMaybe n m where AddMaybe (Just _) Nothing = Nothing AddMaybe (Just n) (Just m) = Just (n + m) -smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) +smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m) smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) --- | This is a newtype over 'ListX'. +type role ListH nominal representational +type ListH :: [Maybe Nat] -> Type -> Type +data ListH sh i where + ZH :: ListH '[] i + ConsUnknown :: forall sh i. i -> ListH sh i -> ListH (Nothing : sh) i +-- TODO: bring this UNPACK back when GHC no longer crashes: +-- ConsKnown :: forall n sh i. {-# UNPACK #-} SNat n -> ListH sh i -> ListH (Just n : sh) i + ConsKnown :: forall n sh i. SNat n -> ListH sh i -> ListH (Just n : sh) i +deriving instance Eq i => Eq (ListH sh i) +deriving instance Ord i => Ord (ListH sh i) + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (ListH sh i) +#else +instance Show i => Show (ListH sh i) where + showsPrec _ = listhShow shows +#endif + +instance NFData i => NFData (ListH sh i) where + rnf ZH = () + rnf (x `ConsUnknown` l) = rnf x `seq` rnf l + rnf (SNat `ConsKnown` l) = rnf l + +data UnconsListHRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListHRes (ListH sh i) (SMayNat i n) +listhUncons :: ListH sh1 i -> Maybe (UnconsListHRes i sh1) +listhUncons (i `ConsUnknown` shl') = Just (UnconsListHRes shl' (SUnknown i)) +listhUncons (i `ConsKnown` shl') = Just (UnconsListHRes shl' (SKnown i)) +listhUncons ZH = Nothing + +-- | This checks only whether the types are equal; if the elements of the list +-- are not singletons, their values may still differ. This corresponds to +-- 'testEquality', except on the penultimate type parameter. +listhEqType :: ListH sh i -> ListH sh' i -> Maybe (sh :~: sh') +listhEqType ZH ZH = Just Refl +listhEqType (_ `ConsUnknown` sh) (_ `ConsUnknown` sh') + | Just Refl <- listhEqType sh sh' + = Just Refl +listhEqType (n `ConsKnown` sh) (m `ConsKnown` sh') + | Just Refl <- testEquality n m + , Just Refl <- listhEqType sh sh' + = Just Refl +listhEqType _ _ = Nothing + +-- | This checks whether the two lists actually contain equal values. This is +-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ +-- in the @some@ package (except on the penultimate type parameter). +listhEqual :: Eq i => ListH sh i -> ListH sh' i -> Maybe (sh :~: sh') +listhEqual ZH ZH = Just Refl +listhEqual (n `ConsUnknown` sh) (m `ConsUnknown` sh') + | n == m + , Just Refl <- listhEqual sh sh' + = Just Refl +listhEqual (n `ConsKnown` sh) (m `ConsKnown` sh') + | Just Refl <- testEquality n m + , Just Refl <- listhEqual sh sh' + = Just Refl +listhEqual _ _ = Nothing + +{-# INLINE listhFmap #-} +listhFmap :: (forall n. SMayNat i n -> SMayNat j n) -> ListH sh i -> ListH sh j +listhFmap _ ZH = ZH +listhFmap f (x `ConsUnknown` xs) = case f (SUnknown x) of + SUnknown y -> y `ConsUnknown` listhFmap f xs +listhFmap f (x `ConsKnown` xs) = case f (SKnown x) of + SKnown y -> y `ConsKnown` listhFmap f xs + +{-# INLINE listhFoldMap #-} +listhFoldMap :: Monoid m => (forall n. SMayNat i n -> m) -> ListH sh i -> m +listhFoldMap _ ZH = mempty +listhFoldMap f (x `ConsUnknown` xs) = f (SUnknown x) <> listhFoldMap f xs +listhFoldMap f (x `ConsKnown` xs) = f (SKnown x) <> listhFoldMap f xs + +listhLength :: ListH sh i -> Int +listhLength = getSum . listhFoldMap (\_ -> Sum 1) + +listhRank :: ListH sh i -> SNat (Rank sh) +listhRank ZH = SNat +listhRank (_ `ConsUnknown` l) | SNat <- listhRank l = SNat +listhRank (_ `ConsKnown` l) | SNat <- listhRank l = SNat + +{-# INLINE listhShow #-} +listhShow :: forall sh i. (forall n. SMayNat i n -> ShowS) -> ListH sh i -> ShowS +listhShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListH sh' i -> ShowS + go _ ZH = id + go prefix (x `ConsUnknown` xs) = showString prefix . f (SUnknown x) . go "," xs + go prefix (x `ConsKnown` xs) = showString prefix . f (SKnown x) . go "," xs + +listhHead :: ListH (mn ': sh) i -> SMayNat i mn +listhHead (i `ConsUnknown` _) = SUnknown i +listhHead (i `ConsKnown` _) = SKnown i + +listhTail :: ListH (n : sh) i -> ListH sh i +listhTail (_ `ConsUnknown` sh) = sh +listhTail (_ `ConsKnown` sh) = sh + +listhAppend :: ListH sh i -> ListH sh' i -> ListH (sh ++ sh') i +listhAppend ZH idx' = idx' +listhAppend (i `ConsUnknown` idx) idx' = i `ConsUnknown` listhAppend idx idx' +listhAppend (i `ConsKnown` idx) idx' = i `ConsKnown` listhAppend idx idx' + +listhDrop :: forall i j sh sh'. ListH sh j -> ListH (sh ++ sh') i -> ListH sh' i +listhDrop ZH long = long +listhDrop (_ `ConsUnknown` short) long = case long of + _ `ConsUnknown` long' -> listhDrop short long' +listhDrop (_ `ConsKnown` short) long = case long of + _ `ConsKnown` long' -> listhDrop short long' + +listhInit :: forall i n sh. ListH (n : sh) i -> ListH (Init (n : sh)) i +listhInit (i `ConsUnknown` sh@(_ `ConsUnknown` _)) = i `ConsUnknown` listhInit sh +listhInit (i `ConsUnknown` sh@(_ `ConsKnown` _)) = i `ConsUnknown` listhInit sh +listhInit (_ `ConsUnknown` ZH) = ZH +listhInit (i `ConsKnown` sh@(_ `ConsUnknown` _)) = i `ConsKnown` listhInit sh +listhInit (i `ConsKnown` sh@(_ `ConsKnown` _)) = i `ConsKnown` listhInit sh +listhInit (_ `ConsKnown` ZH) = ZH + +listhLast :: forall i n sh. ListH (n : sh) i -> SMayNat i (Last (n : sh)) +listhLast (_ `ConsUnknown` sh@(_ `ConsUnknown` _)) = listhLast sh +listhLast (_ `ConsUnknown` sh@(_ `ConsKnown` _)) = listhLast sh +listhLast (x `ConsUnknown` ZH) = SUnknown x +listhLast (_ `ConsKnown` sh@(_ `ConsUnknown` _)) = listhLast sh +listhLast (_ `ConsKnown` sh@(_ `ConsKnown` _)) = listhLast sh +listhLast (x `ConsKnown` ZH) = SKnown x + +-- * Mixed shapes + +-- | This is a newtype over 'ListH'. type role ShX nominal representational type ShX :: [Maybe Nat] -> Type -> Type -newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) - deriving (Eq, Ord, Generic) +newtype ShX sh i = ShX (ListH sh i) + deriving (Eq, Ord, NFData) pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i -pattern ZSX = ShX ZX +pattern ZSX = ShX ZH pattern (:$%) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) - => SMayNat i SNat n -> ShX sh i -> ShX sh1 i -pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i)) - where i :$% ShX shl = ShX (i ::% shl) + => SMayNat i n -> ShX sh i -> ShX sh1 i +pattern i :$% shl <- ShX (listhUncons -> Just (UnconsListHRes (ShX -> shl) i)) + where i :$% ShX shl = case i of; SUnknown x -> ShX (x `ConsUnknown` shl); SKnown x -> ShX (x `ConsKnown` shl) infixr 3 :$% {-# COMPLETE ZSX, (:$%) #-} @@ -359,17 +486,12 @@ type IShX sh = ShX sh Int deriving instance Show i => Show (ShX sh i) #else instance Show i => Show (ShX sh i) where - showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l + showsPrec _ (ShX l) = listhShow (fromSMayNat shows (shows . fromSNat)) l #endif instance Functor (ShX sh) where {-# INLINE fmap #-} - fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l) - -instance NFData i => NFData (ShX sh i) where - rnf (ShX ZX) = () - rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) - rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) + fmap f (ShX l) = ShX (listhFmap (fromSMayNat (SUnknown . f) SKnown) l) -- | This checks only whether the types are equal; unknown dimensions might -- still differ. This corresponds to 'testEquality', except on the penultimate @@ -401,10 +523,10 @@ shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') shxEqual _ _ = Nothing shxLength :: ShX sh i -> Int -shxLength (ShX l) = listxLength l +shxLength (ShX l) = listhLength l shxRank :: ShX sh i -> SNat (Rank sh) -shxRank (ShX l) = listxRank l +shxRank (ShX l) = listhRank l -- | The number of elements in an array described by this shape. shxSize :: IShX sh -> Int @@ -433,6 +555,7 @@ shxToList list = build (\(cons :: i -> is -> is) (nil :: is) -> go (smn :$% sh) = fromSMayNat' smn `cons` go sh in go list) +-- If it ever matters for performance, this is unsafeCoercible. shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i shxFromSSX ZKX = ZSX shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh)) @@ -447,35 +570,36 @@ shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh shxFromSSX2 (SUnknown _ :!% _) = Nothing shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i -shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) +shxAppend = coerce (listhAppend @_ @i) -shxHead :: ShX (n : sh) i -> SMayNat i SNat n -shxHead (ShX list) = listxHead list +shxHead :: ShX (n : sh) i -> SMayNat i n +shxHead (ShX list) = listhHead list shxTail :: ShX (n : sh) i -> ShX sh i -shxTail (ShX list) = ShX (listxTail list) +shxTail (ShX list) = ShX (listhTail list) shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i -shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) +shxDropSSX = coerce (listhDrop @i @()) shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i -shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) +shxDropIx (IxX ZX) long = long +shxDropIx (IxX (_ ::% short)) long = case long of _ :$% long' -> shxDropIx (IxX short) long' shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i -shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) +shxDropSh = coerce (listhDrop @i @i) shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i -shxInit = coerce (listxInit @(SMayNat i SNat)) +shxInit = coerce (listhInit @i) -shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh)) -shxLast = coerce (listxLast @(SMayNat i SNat)) +shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i (Last (n : sh)) +shxLast = coerce (listhLast @i) shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i shxTakeSSX _ ZKX _ = ZSX shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh {-# INLINE shxZipWith #-} -shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n) +shxZipWith :: (forall n. SMayNat i n -> SMayNat j n -> SMayNat k n) -> ShX sh i -> ShX sh j -> ShX sh k shxZipWith _ ZSX ZSX = ZSX shxZipWith f (i :$% is) (j :$% js) = f i j :$% shxZipWith f is js @@ -523,20 +647,24 @@ shxCast' ssh sh = case shxCast ssh sh of -- * Static mixed shapes --- | The part of a shape that is statically known. (A newtype over 'ListX'.) +-- | The part of a shape that is statically known. (A newtype over 'ListH'.) type StaticShX :: [Maybe Nat] -> Type -newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) - deriving (Eq, Ord) +newtype StaticShX sh = StaticShX (ListH sh ()) + deriving (NFData) + +instance Eq (StaticShX sh) where _ == _ = True +instance Ord (StaticShX sh) where compare _ _ = EQ pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh -pattern ZKX = StaticShX ZX +pattern ZKX = StaticShX ZH pattern (:!%) :: forall {sh1}. forall n sh. (n : sh ~ sh1) - => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 -pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i)) - where i :!% StaticShX shl = StaticShX (i ::% shl) + => SMayNat () n -> StaticShX sh -> StaticShX sh1 +pattern i :!% shl <- StaticShX (listhUncons -> Just (UnconsListHRes (StaticShX -> shl) i)) + where i :!% StaticShX shl = case i of; SUnknown () -> StaticShX (() `ConsUnknown` shl); SKnown x -> StaticShX (x `ConsKnown` shl) + infixr 3 :!% {-# COMPLETE ZKX, (:!%) #-} @@ -545,22 +673,17 @@ infixr 3 :!% deriving instance Show (StaticShX sh) #else instance Show (StaticShX sh) where - showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l + showsPrec _ (StaticShX l) = listhShow (fromSMayNat shows (shows . fromSNat)) l #endif -instance NFData (StaticShX sh) where - rnf (StaticShX ZX) = () - rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l) - rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l) - instance TestEquality StaticShX where - testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2 + testEquality (StaticShX l1) (StaticShX l2) = listhEqType l1 l2 ssxLength :: StaticShX sh -> Int -ssxLength (StaticShX l) = listxLength l +ssxLength (StaticShX l) = listhLength l ssxRank :: StaticShX sh -> SNat (Rank sh) -ssxRank (StaticShX l) = listxRank l +ssxRank (StaticShX l) = listhRank l -- | @ssxEqType = 'testEquality'@. Provided for consistency. ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') @@ -570,26 +693,31 @@ ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') ssxAppend ZKX sh' = sh' ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' -ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n -ssxHead (StaticShX list) = listxHead list +ssxHead :: StaticShX (n : sh) -> SMayNat () n +ssxHead (StaticShX list) = listhHead list ssxTail :: StaticShX (n : sh) -> StaticShX sh ssxTail (_ :!% ssh) = ssh -ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropSSX = coerce (listxDrop @(SMayNat () SNat) @(SMayNat () SNat)) +ssxTakeIx :: forall sh sh' i. Proxy sh' -> IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh +ssxTakeIx _ (IxX ZX) _ = ZKX +ssxTakeIx proxy (IxX (_ ::% long)) short = case short of i :!% short' -> i :!% ssxTakeIx proxy (IxX long) short' ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) +ssxDropIx (IxX ZX) long = long +ssxDropIx (IxX (_ ::% short)) long = case long of _ :!% long' -> ssxDropIx (IxX short) long' ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropSh = coerce (listxDrop @(SMayNat () SNat) @(SMayNat i SNat)) +ssxDropSh = coerce (listhDrop @() @i) + +ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSSX = coerce (listhDrop @() @()) ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) -ssxInit = coerce (listxInit @(SMayNat () SNat)) +ssxInit = coerce (listhInit @()) -ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh)) -ssxLast = coerce (listxLast @(SMayNat () SNat)) +ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () (Last (n : sh)) +ssxLast = coerce (listhLast @()) ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) ssxReplicate SZ = ZKX @@ -632,18 +760,18 @@ type family Flatten' acc sh where Flatten' acc (Just n : sh) = Flatten' (acc * n) sh -- This function is currently unused -ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) +ssxFlatten :: StaticShX sh -> SMayNat () (Flatten sh) ssxFlatten = go (SNat @1) where - go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) + go :: SNat acc -> StaticShX sh -> SMayNat () (Flatten' acc sh) go acc ZKX = SKnown acc go _ (SUnknown () :!% _) = SUnknown () go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh -shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) +shxFlatten :: IShX sh -> SMayNat Int (Flatten sh) shxFlatten = go (SNat @1) where - go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) + go :: SNat acc -> IShX sh -> SMayNat Int (Flatten' acc sh) go acc ZSX = SKnown acc go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs index 065c9fd..2e0c1ca 100644 --- a/src/Data/Array/Nested/Permutation.hs +++ b/src/Data/Array/Nested/Permutation.hs @@ -172,6 +172,70 @@ type family DropLen ref l where DropLen '[] l = l DropLen (_ : ref) (_ : xs) = DropLen ref xs +listhTakeLen :: forall i is sh. Perm is -> ListH sh i -> ListH (TakeLen is sh) i +listhTakeLen PNil _ = ZH +listhTakeLen (_ `PCons` is) (n `ConsUnknown` sh) = n `ConsUnknown` listhTakeLen is sh +listhTakeLen (_ `PCons` is) (n `ConsKnown` sh) = n `ConsKnown` listhTakeLen is sh +listhTakeLen (_ `PCons` _) ZH = error "Permutation longer than shape" + +listhDropLen :: forall i is sh. Perm is -> ListH sh i -> ListH (DropLen is sh) i +listhDropLen PNil sh = sh +listhDropLen (_ `PCons` is) (_ `ConsUnknown` sh) = listhDropLen is sh +listhDropLen (_ `PCons` is) (_ `ConsKnown` sh) = listhDropLen is sh +listhDropLen (_ `PCons` _) ZH = error "Permutation longer than shape" + +listhPermute :: forall i is sh. Perm is -> ListH sh i -> ListH (Permute is sh) i +listhPermute PNil _ = ZH +listhPermute (i `PCons` (is :: Perm is')) (sh :: ListH sh i) = + case listhIndex i sh of + SUnknown x -> x `ConsUnknown` listhPermute is sh + SKnown x -> x `ConsKnown` listhPermute is sh + +listhIndex :: forall i k sh. SNat k -> ListH sh i -> SMayNat i (Index k sh) +listhIndex SZ (n `ConsUnknown` _) = SUnknown n +listhIndex SZ (n `ConsKnown` _) = SKnown n +listhIndex (SS (i :: SNat k')) ((_ :: i) `ConsUnknown` (sh :: ListH sh' i)) + | Refl <- lemIndexSucc (Proxy @k') (Proxy @Nothing) (Proxy @sh') + = listhIndex i sh +listhIndex (SS (i :: SNat k')) ((_ :: SNat n) `ConsKnown` (sh :: ListH sh' i)) + | Refl <- lemIndexSucc (Proxy @k') (Proxy @(Just n)) (Proxy @sh') + = listhIndex i sh +listhIndex _ ZH = error "Index into empty shape" + +listhPermutePrefix :: forall i is sh. Perm is -> ListH sh i -> ListH (PermutePrefix is sh) i +listhPermutePrefix perm sh = listhAppend (listhPermute perm (listhTakeLen perm sh)) (listhDropLen perm sh) + +ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) +ssxTakeLen = coerce (listhTakeLen @()) + +ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) +ssxDropLen = coerce (listhDropLen @()) + +ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) +ssxPermute = coerce (listhPermute @()) + +ssxIndex :: SNat i -> StaticShX sh -> SMayNat () (Index i sh) +ssxIndex i = coerce (listhIndex @() i) + +ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) +ssxPermutePrefix = coerce (listhPermutePrefix @()) + +shxTakeLen :: forall is sh. Perm is -> IShX sh -> IShX (TakeLen is sh) +shxTakeLen = coerce (listhTakeLen @Int) + +shxDropLen :: Perm is -> IShX sh -> IShX (DropLen is sh) +shxDropLen = coerce (listhDropLen @Int) + +shxPermute :: Perm is -> IShX sh -> IShX (Permute is sh) +shxPermute = coerce (listhPermute @Int) + +shxIndex :: SNat i -> IShX sh -> SMayNat Int (Index i sh) +shxIndex i = coerce (listhIndex @Int i) + +shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) +shxPermutePrefix = coerce (listhPermutePrefix @Int) + + listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f listxTakeLen PNil _ = ZX listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh @@ -185,14 +249,14 @@ listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape" listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f listxPermute PNil _ = ZX listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = - listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh + listxIndex i sh ::% listxPermute is sh -listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh) -listxIndex _ _ SZ (n ::% _) = n -listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) +listxIndex :: forall f i sh. SNat i -> ListX sh f -> f (Index i sh) +listxIndex SZ (n ::% _) = n +listxIndex (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listxIndex p pT i sh -listxIndex _ _ _ ZX = error "Index into empty shape" + = listxIndex i sh +listxIndex _ ZX = error "Index into empty shape" listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) @@ -200,25 +264,6 @@ listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm s ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i ixxPermutePrefix = coerce (listxPermutePrefix @(Const i)) -ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) -ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) - -ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) -ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) - -ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) -ssxPermute = coerce (listxPermute @(SMayNat () SNat)) - -ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh) -ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat () SNat) p1 p2 i) - -ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) -ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) - -shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) -shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) - - -- * Operations on permutations permInverse :: Perm is diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index b6bee2e..36f49dc 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -1,8 +1,6 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} @@ -37,7 +35,6 @@ import Data.Kind (Type) import Data.Proxy import Data.Type.Equality import GHC.Exts (Int(..), Int#, build, quotRemInt#) -import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits @@ -216,8 +213,7 @@ listrPermutePrefix = \perm sh -> type role IxR nominal representational type IxR :: Nat -> Type -> Type newtype IxR n i = IxR (ListR n i) - deriving (Eq, Ord, Generic) - deriving newtype (Functor, Foldable) + deriving (Eq, Ord, NFData, Functor, Foldable) pattern ZIR :: forall n i. () => n ~ 0 => IxR n i pattern ZIR = IxR ZR @@ -243,8 +239,6 @@ instance Show i => Show (IxR n i) where showsPrec _ (IxR l) = listrShow shows l #endif -instance NFData i => NFData (IxR sh i) - ixrLength :: IxR sh i -> Int ixrLength (IxR l) = listrLength l @@ -310,8 +304,7 @@ ixrToLinear = \sh i -> go sh i 0 type role ShR nominal representational type ShR :: Nat -> Type -> Type newtype ShR n i = ShR (ListR n i) - deriving (Eq, Ord, Generic) - deriving newtype (Functor, Foldable) + deriving (Eq, Ord, NFData, Functor, Foldable) pattern ZSR :: forall n i. () => n ~ 0 => ShR n i pattern ZSR = ShR ZR @@ -335,8 +328,6 @@ instance Show i => Show (ShR n i) where showsPrec _ (ShR l) = listrShow shows l #endif -instance NFData i => NFData (ShR sh i) - -- | This checks only whether the ranks are equal, not whether the actual -- values are. shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index d23a025..85042f2 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -246,9 +246,7 @@ sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shape sreshape sh' (Shaped arr) = Shaped (mreshape (shxFromShS sh') arr) sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a -sflatten arr = - case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff - n@SNat -> sreshape (n :$$ ZSS) arr +sflatten arr = sreshape (shsProduct (sshape arr) :$$ ZSS) arr siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a siota sn = Shaped (miota sn) diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index e2ec416..b86bfe5 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -26,7 +26,6 @@ import Data.Coerce (coerce) import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty) import Data.Proxy -import Data.Type.Equality import Foreign.Storable (Storable) import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) @@ -132,7 +131,7 @@ instance Elt a => Elt (Shaped sh a) where type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) - mshapeTree (Shaped arr) = first shsFromShX (mshapeTree arr) + mshapeTree (Shaped arr) = first coerce (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -256,13 +255,4 @@ satan2Array = liftShaped2 matan2Array sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh -sshape (Shaped arr) = shsFromShX (mshape arr) - --- Needed already here, but re-exported in Data.Array.Nested.Convert. -shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh -shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS -shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) = - castWith (subst1 (sym (lemMapJustCons Refl))) $ - n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) - idx) -shsFromShX (SUnknown _ :$% _) = error "impossible" +sshape (Shaped arr) = coerce (mshape arr) diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index bfc6ad2..9d463a9 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -1,10 +1,9 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE NoStarIsType #-} @@ -39,7 +38,6 @@ import Data.Monoid (Sum(..)) import Data.Proxy import Data.Type.Equality import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict) -import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits @@ -136,6 +134,17 @@ listsFromList topsh topl = go topsh topl ++ show (shsLength topsh) ++ ", list has length " ++ show (length topl) ++ ")" +{-# INLINEABLE listsFromListS #-} +listsFromListS :: ListS sh (Const i0) -> [i] -> ListS sh (Const i) +listsFromListS topl0 topl = go topl0 topl + where + go :: ListS sh (Const i0) -> [i] -> ListS sh (Const i) + go ZS [] = ZS + go (_ ::$ l0) (i : is) = Const i ::$ go l0 is + go _ _ = error $ "listsFromListS: Mismatched list length (the model says " + ++ show (listsLength topl0) ++ ", list has length " + ++ show (length topl) ++ ")" + {-# INLINEABLE listsToList #-} listsToList :: ListS sh (Const i) -> [i] listsToList list = build (\(cons :: i -> is -> is) (nil :: is) -> @@ -185,16 +194,16 @@ listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f listsPermute PNil _ = ZS listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = - case listsIndex (Proxy @is') (Proxy @sh) i sh of + case listsIndex i sh of item -> item ::$ listsPermute is sh --- TODO: remove this SNat when the KnownNat constaint in ListS is removed -listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> f (Index i sh) -listsIndex _ _ SZ (n ::$ _) = n -listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) +-- TODO: try to remove this SNat now that the KnownNat constraint in ListS is removed +listsIndex :: forall f i sh. SNat i -> ListS sh f -> f (Index i sh) +listsIndex SZ (n ::$ _) = n +listsIndex (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listsIndex p pT i sh -listsIndex _ _ _ ZS = error "Index into empty shape" + = listsIndex i sh +listsIndex _ ZS = error "Index into empty shape" listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) @@ -205,7 +214,7 @@ listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm pe type role IxS nominal representational type IxS :: [Nat] -> Type -> Type newtype IxS sh i = IxS (ListS sh (Const i)) - deriving (Eq, Ord, Generic) + deriving (Eq, Ord, NFData) pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i pattern ZIS = IxS ZS @@ -247,8 +256,6 @@ instance Foldable (IxS sh) where null ZIS = False null _ = True -instance NFData i => NFData (IxS sh i) - ixsLength :: IxS sh i -> Int ixsLength (IxS l) = listsLength l @@ -258,6 +265,10 @@ ixsRank (IxS l) = listsRank l ixsFromList :: forall sh i. ShS sh -> [i] -> IxS sh i ixsFromList = coerce (listsFromList @_ @i) +{-# INLINEABLE ixsFromIxS #-} +ixsFromIxS :: forall sh i0 i. IxS sh i0 -> [i] -> IxS sh i +ixsFromIxS = coerce (listsFromListS @_ @i0 @i) + {-# INLINEABLE ixsToList #-} ixsToList :: forall sh i. IxS sh i -> [i] ixsToList = coerce (listsToList @_ @i) @@ -278,11 +289,9 @@ ixsInit (IxS list) = IxS (listsInit list) ixsLast :: IxS (n : sh) i -> i ixsLast (IxS list) = getConst (listsLast list) --- TODO: this takes a ShS because there are KnownNats inside IxS. -ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i -ixsCast ZSS ZIS = ZIS -ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx -ixsCast _ _ = error "ixsCast: ranks don't match" +ixsCast :: IxS sh i -> IxS sh i +ixsCast ZIS = ZIS +ixsCast (i :.$ idx) = i :.$ ixsCast idx ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i ixsAppend = coerce (listsAppend @_ @(Const i)) @@ -318,21 +327,34 @@ ixsToLinear = \sh i -> go sh i 0 -- can also retrieve the array shape from a 'KnownShS' dictionary. type role ShS nominal type ShS :: [Nat] -> Type -newtype ShS sh = ShS (ListS sh SNat) - deriving (Generic) +newtype ShS sh = ShS (ShX (MapJust sh) Int) + deriving (NFData) instance Eq (ShS sh) where _ == _ = True instance Ord (ShS sh) where compare _ _ = EQ pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh -pattern ZSS = ShS ZS +pattern ZSS <- ShS (matchZS -> Just Refl) + where ZSS = ShS ZSX + +matchZS :: forall sh f. ShX (MapJust sh) f -> Maybe (sh :~: '[]) +matchZS ZSX | Refl <- lemMapJustEmpty @sh Refl = Just Refl +matchZS _ = Nothing pattern (:$$) :: forall {sh1}. forall n sh. (n : sh ~ sh1) => SNat n -> ShS sh -> ShS sh1 -pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i)) - where i :$$ ShS shl = ShS (i ::$ shl) +pattern i :$$ shl <- (shsUncons -> Just (UnconsShSRes i shl)) + where i :$$ ShS shl = ShS (SKnown i :$% shl) + +data UnconsShSRes sh1 = + forall n sh. (n : sh ~ sh1) => UnconsShSRes (SNat n) (ShS sh) +shsUncons :: forall sh1. ShS sh1 -> Maybe (UnconsShSRes sh1) +shsUncons (ShS (SKnown x :$% sh')) + | Refl <- lemMapJustCons @sh1 Refl + = Just (UnconsShSRes x (ShS sh')) +shsUncons (ShS _) = Nothing infixr 3 :$$ @@ -342,15 +364,13 @@ infixr 3 :$$ deriving instance Show (ShS sh) #else instance Show (ShS sh) where - showsPrec _ (ShS l) = listsShow (shows . fromSNat) l + showsPrec d (ShS shx) = showsPrec d shx #endif -instance NFData (ShS sh) where - rnf (ShS ZS) = () - rnf (ShS (SNat ::$ l)) = rnf (ShS l) - instance TestEquality ShS where - testEquality (ShS l1) (ShS l2) = listsEqType l1 l2 + testEquality (ShS shx1) (ShS shx2) = case shxEqType shx1 shx2 of + Nothing -> Nothing + Just Refl -> Just unsafeCoerceRefl -- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are -- equal if and only if values are equal.) @@ -358,10 +378,13 @@ shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh') shsEqual = testEquality shsLength :: ShS sh -> Int -shsLength (ShS l) = listsLength l +shsLength (ShS shx) = shxLength shx -shsRank :: ShS sh -> SNat (Rank sh) -shsRank (ShS l) = listsRank l +shsRank :: forall sh. ShS sh -> SNat (Rank sh) +shsRank (ShS shx) = + gcastWith (unsafeCoerceRefl + :: Rank (MapJust sh) :~: Rank sh) $ + shxRank shx shsSize :: ShS sh -> Int shsSize ZSS = 1 @@ -391,31 +414,68 @@ shsToList topsh = build (\(cons :: Int -> is -> is) (nil :: is) -> in go topsh) shsHead :: ShS (n : sh) -> SNat n -shsHead (ShS list) = listsHead list +shsHead (ShS shx) = case shxHead shx of + SKnown SNat -> SNat -shsTail :: ShS (n : sh) -> ShS sh -shsTail (ShS list) = ShS (listsTail list) +shsTail :: forall n sh. ShS (n : sh) -> ShS sh +shsTail = coerce (shxTail @_ @_ @Int) -shsInit :: ShS (n : sh) -> ShS (Init (n : sh)) -shsInit (ShS list) = ShS (listsInit list) +shsInit :: forall n sh. ShS (n : sh) -> ShS (Init (n : sh)) +shsInit = + gcastWith (unsafeCoerceRefl + :: Init (Just n : MapJust sh) :~: MapJust (Init (n : sh))) $ + coerce (shxInit @_ @_ @Int) -shsLast :: ShS (n : sh) -> SNat (Last (n : sh)) -shsLast (ShS list) = listsLast list +shsLast :: forall n sh. ShS (n : sh) -> SNat (Last (n : sh)) +shsLast (ShS shx) = + gcastWith (unsafeCoerceRefl + :: Last (Just n : MapJust sh) :~: Just (Last (n : sh))) $ + case shxLast shx of + SKnown SNat -> SNat shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh') -shsAppend = coerce (listsAppend @_ @SNat) +shsAppend = + gcastWith (unsafeCoerceRefl + :: MapJust sh ++ MapJust sh' :~: MapJust (sh ++ sh')) $ + coerce (shxAppend @_ @_ @Int) + +shsTakeLen :: forall is sh. Perm is -> ShS sh -> ShS (TakeLen is sh) +shsTakeLen = + gcastWith (unsafeCoerceRefl + :: TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)) $ + coerce shxTakeLen -shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh) -shsTakeLen = coerce (listsTakeLenPerm @SNat) +shsDropLen :: forall is sh. Perm is -> ShS sh -> ShS (DropLen is sh) +shsDropLen = + gcastWith (unsafeCoerceRefl + :: DropLen is (MapJust sh) :~: MapJust (DropLen is sh)) $ + coerce shxDropLen -shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) -shsPermute = coerce (listsPermute @SNat) +shsPermute :: forall is sh. Perm is -> ShS sh -> ShS (Permute is sh) +shsPermute = + gcastWith (unsafeCoerceRefl + :: Permute is (MapJust sh) :~: MapJust (Permute is sh)) $ + coerce shxPermute -shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh) -shsIndex pis pshT i sh = coerce (listsIndex @SNat pis pshT i (coerce sh)) +shsIndex :: forall i sh. SNat i -> ShS sh -> SNat (Index i sh) +shsIndex i sh = + gcastWith (unsafeCoerceRefl + :: Index i (MapJust sh) :~: Just (Index i sh)) $ + case shxIndex i (coerce sh) of + SKnown SNat -> SNat shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) -shsPermutePrefix = coerce (listsPermutePrefix @SNat) +shsPermutePrefix perm (ShS shx) + {- TODO: here and elsewhere, solve the module dependency cycle and add this: + | Refl <- lemTakeLenMapJust perm sh + , Refl <- lemDropLenMapJust perm sh + , Refl <- lemPermuteMapJust perm sh + , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm sh)) (shsDropLen perm sh) -} + = gcastWith (unsafeCoerceRefl + :: Permute is (TakeLen is (MapJust sh)) + ++ DropLen is (MapJust sh) + :~: MapJust (Permute is (TakeLen is sh) ++ DropLen is sh)) $ + ShS (shxPermutePrefix perm shx) type family Product sh where Product '[] = 1 diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index 6389e67..a9fc14c 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -285,7 +285,7 @@ sumInner ssh ssh' arr go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a go (XArray arr') | Refl <- lemRankApp ssh ssh'F - , let sn = listxRank (let StaticShX l = ssh in l) + , let sn = ssxRank ssh = XArray (liftO1 (numEltSum1Inner sn) arr') in go $ |
