diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Data/Array/Nested.hs | 37 | ||||
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 363 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Lemmas.hs | 59 | ||||
-rw-r--r-- | src/Data/Array/Nested/Lemmas.hs (renamed from src/Data/Array/Mixed/Lemmas.hs) | 108 | ||||
-rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 53 | ||||
-rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 101 | ||||
-rw-r--r-- | src/Data/Array/Nested/Permutation.hs | 33 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked.hs | 62 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 20 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 61 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 37 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 17 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 71 | ||||
-rw-r--r-- | src/Data/Array/Nested/Trace.hs | 8 | ||||
-rw-r--r-- | src/Data/Array/Nested/Types.hs | 16 | ||||
-rw-r--r-- | src/Data/Array/XArray.hs | 14 | ||||
-rw-r--r-- | src/GHC/TypeLits/Orphans.hs | 13 |
17 files changed, 648 insertions, 425 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 9801529..c3635e9 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -10,8 +10,9 @@ module Data.Array.Nested ( rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar, remptyArray, rrerank, - rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1, - rfromListLinear, rfromListPrimLinear, rtoListLinear, + rreplicate, rreplicateScal, + rfromList1, rfromListOuter, rfromListLinear, rfromListPrim, rfromListPrimLinear, + rtoList, rtoListOuter, rtoListLinear, rslice, rrev1, rreshape, rflatten, riota, rminIndexPrim, rmaxIndexPrim, rdot1Inner, rdot, rnest, runNest, rzip, runzip, @@ -19,7 +20,7 @@ module Data.Array.Nested ( rlift, rlift2, -- ** Conversions rtoXArrayPrim, rfromXArrayPrim, - rcastToShaped, rtoMixed, rcastToMixed, + rtoMixed, rcastToMixed, rcastToShaped, rfromOrthotope, rtoOrthotope, -- ** Additional arithmetic operations -- @@ -36,8 +37,9 @@ module Data.Array.Nested ( -- TODO: sconcat? What should its type be? semptyArray, srerank, - sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1, - sfromListLinear, sfromListPrimLinear, stoListLinear, + sreplicate, sreplicateScal, + sfromList1, sfromListOuter, sfromListLinear, sfromListPrim, sfromListPrimLinear, + stoList, stoListOuter, stoListLinear, sslice, srev1, sreshape, sflatten, siota, sminIndexPrim, smaxIndexPrim, sdot1Inner, sdot, snest, sunNest, szip, sunzip, @@ -45,7 +47,7 @@ module Data.Array.Nested ( slift, slift2, -- ** Conversions stoXArrayPrim, sfromXArrayPrim, - stoRanked, stoMixed, scastToMixed, + stoMixed, scastToMixed, stoRanked, sfromOrthotope, stoOrthotope, -- ** Additional arithmetic operations -- @@ -63,8 +65,9 @@ module Data.Array.Nested ( mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar, memptyArray, mrerank, - mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1, - mfromListLinear, mfromListPrimLinear, mtoListLinear, + mreplicate, mreplicateScal, + mfromList1, mfromListOuter, mfromListLinear, mfromListPrim, mfromListPrimLinear, + mtoList, mtoListOuter, mtoListLinear, mslice, mrev1, mreshape, mflatten, miota, mminIndexPrim, mmaxIndexPrim, mdot1Inner, mdot, mnest, munNest, mzip, munzip, @@ -73,8 +76,8 @@ module Data.Array.Nested ( -- ** Conversions mtoXArrayPrim, mfromXArrayPrim, mcast, - mtoRanked, mcastToShaped, - castCastable, Castable(..), + mcastToShaped, mtoRanked, + convert, Conversion(..), -- ** Additional arithmetic operations -- -- $integralRealFloat @@ -102,23 +105,23 @@ module Data.Array.Nested ( import Prelude hiding (mappend, mconcat) -import Data.Array.Nested.Permutation -import Data.Array.Nested.Types import Data.Array.Nested.Convert import Data.Array.Nested.Mixed -import Data.Array.Nested.Ranked -import Data.Array.Nested.Shaped import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Ranked import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Shaped import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith import Foreign.Storable import GHC.TypeLits -- $integralRealFloat -- --- These functions separate top-level functions, and not exposed in instances --- for 'RealFloat' and 'Integral', because those classes include a variety of --- other functions that make no sense for arrays. +-- These functions are separate top-level functions, and not exposed in +-- instances for 'RealFloat' and 'Integral', because those classes include a +-- variety of other functions that make no sense for arrays. -- This problem already occurs with 'fromInteger', 'fromRational' and 'pi', but -- having 'Num', 'Fractional' and 'Floating' available is just too useful. diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index d5e6008..8c88d23 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -1,42 +1,293 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) {-# LANGUAGE TypeAbstractions #-} +#endif {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} module Data.Array.Nested.Convert ( - castCastable, - Castable(..), + -- * Shape\/index\/list casting functions + -- ** To ranked + ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShX, shrFromShX2, + listrCast, ixrCast, shrCast, + -- ** To shaped + ixsFromIxR, ixsFromIxR', ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX, + ixsCast, + -- ** To mixed + ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS, + ixxCast, shxCast, shxCast', - -- * Special cases + -- * Array conversions + convert, + Conversion(..), + + -- * Special cases of array conversions -- - -- | These functions can all be implemented using 'castCastable' in some way, + -- | These functions can all be implemented using 'convert' in some way, -- but some have fewer constraints. rtoMixed, rcastToMixed, rcastToShaped, stoMixed, scastToMixed, stoRanked, mcast, mcastToShaped, mtoRanked, - - -- * Additional index/shape casting functions - ixrFromIxS, shrFromShS, ) where import Control.Category import Data.Proxy import Data.Type.Equality +import GHC.TypeLits -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Types -import Data.Array.Nested.Internal.Lemmas +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Ranked.Base import Data.Array.Nested.Ranked.Shape import Data.Array.Nested.Shaped.Base import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types + +-- * Shape or index or list casting functions + +-- * To ranked + +ixrFromIxS :: IxS sh i -> IxR (Rank sh) i +ixrFromIxS ZIS = ZIR +ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix + +ixrFromIxX :: IxX sh i -> IxR (Rank sh) i +ixrFromIxX ZIX = ZIR +ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx + +shrFromShS :: ShS sh -> IShR (Rank sh) +shrFromShS ZSS = ZSR +shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh + +-- shrFromShX re-exported +-- shrFromShX2 re-exported +-- listrCast re-exported +-- ixrCast re-exported +-- shrCast re-exported + +-- * To shaped + +-- TODO: these take a ShS because there are KnownNats inside IxS. + +ixsFromIxR :: ShS sh -> IxR (Rank sh) i -> IxS sh i +ixsFromIxR ZSS ZIR = ZIS +ixsFromIxR (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR sh idx + +-- | Performs a runtime check that @n@ matches @Rank sh@. Equivalent to the +-- following, but more efficient: +-- +-- > ixsFromIxR' sh idx = ixsFromIxR sh (ixrCast (shsRank sh) idx) +ixsFromIxR' :: ShS sh -> IxR n i -> IxS sh i +ixsFromIxR' ZSS ZIR = ZIS +ixsFromIxR' (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR' sh idx +ixsFromIxR' _ _ = error "ixsFromIxR': index rank does not match shape rank" + +-- TODO: this takes a ShS because there are KnownNats inside IxS. +ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i +ixsFromIxX ZSS ZIX = ZIS +ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx + +-- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to +-- the following, but more efficient: +-- +-- > ixsFromIxX' sh idx = ixsFromIxX sh (ixxCast (shxFromShS sh) idx) +ixsFromIxX' :: ShS sh -> IxX sh' i -> IxS sh i +ixsFromIxX' ZSS ZIX = ZIS +ixsFromIxX' (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX' sh idx +ixsFromIxX' _ _ = error "ixsFromIxX': index rank does not match shape rank" + +-- | Produce an existential 'ShS' from an 'IShR'. +withShsFromShR :: IShR n -> (forall sh. Rank sh ~ n => ShS sh -> r) -> r +withShsFromShR ZSR k = k ZSS +withShsFromShR (n :$: sh) k = + withShsFromShR sh $ \sh' -> + withSomeSNat (fromIntegral @Int @Integer n) $ \case + Just sn@SNat -> k (sn :$$ sh') + Nothing -> error $ "withShsFromShR: negative dimension size (" ++ show n ++ ")" + +-- shsFromShX re-exported + +-- | Produce an existential 'ShS' from an 'IShX'. If you already know that +-- @sh'@ is @MapJust@ of something, use 'shsFromShX' instead. +withShsFromShX :: IShX sh' -> (forall sh. Rank sh ~ Rank sh' => ShS sh -> r) -> r +withShsFromShX ZSX k = k ZSS +withShsFromShX (SKnown sn@SNat :$% sh) k = + withShsFromShX sh $ \sh' -> + k (sn :$$ sh') +withShsFromShX (SUnknown n :$% sh) k = + withShsFromShX sh $ \sh' -> + withSomeSNat (fromIntegral @Int @Integer n) $ \case + Just sn@SNat -> k (sn :$$ sh') + Nothing -> error $ "withShsFromShX: negative SUnknown dimension size (" ++ show n ++ ")" + +shsFromSSX :: StaticShX (MapJust sh) -> ShS sh +shsFromSSX = shsFromShX Prelude.. shxFromSSX + +-- ixsCast re-exported + +-- * To mixed + +ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i +ixxFromIxR ZIR = ZIX +ixxFromIxR (n :.: (idx :: IxR m i)) = + castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) (Proxy @m))) + (n :.% ixxFromIxR idx) + +ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i +ixxFromIxS ZIS = ZIX +ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh + +shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i +shxFromShR ZSR = ZSX +shxFromShR (n :$: (idx :: ShR m i)) = + castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) (Proxy @m))) + (SUnknown n :$% shxFromShR idx) + +shxFromShS :: ShS sh -> IShX (MapJust sh) +shxFromShS ZSS = ZSX +shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh + +-- ixxCast re-exported +-- shxCast re-exported +-- shxCast' re-exported + + +-- * Array conversions + +-- | The constructors that perform runtime shape checking are marked with a +-- tick (@'@): 'ConvXS'' and 'ConvXX''. For the other constructors, the types +-- ensure that the shapes are already compatible. To convert between 'Ranked' +-- and 'Shaped', go via 'Mixed'. +-- +-- The guiding principle behind 'Conversion' is that it should represent the +-- array restructurings, or perhaps re-presentations, that do not change the +-- underlying 'XArray's. This leads to the inclusion of some operations that do +-- not look like simple conversions (casts) at first glance, like 'ConvZip'. +-- +-- /Note/: Haddock gleefully renames type variables in constructors so that +-- they match the data type head as much as possible. See the source for a more +-- readable presentation of this data type. +data Conversion a b where + ConvId :: Conversion a a + ConvCmp :: Conversion b c -> Conversion a b -> Conversion a c + + ConvRX :: Conversion (Ranked n a) (Mixed (Replicate n Nothing) a) + ConvSX :: Conversion (Shaped sh a) (Mixed (MapJust sh) a) + + ConvXR :: Elt a + => Conversion (Mixed sh a) (Ranked (Rank sh) a) + ConvXS :: Conversion (Mixed (MapJust sh) a) (Shaped sh a) + ConvXS' :: (Rank sh ~ Rank sh', Elt a) + => ShS sh' + -> Conversion (Mixed sh a) (Shaped sh' a) + + ConvXX' :: (Rank sh ~ Rank sh', Elt a) + => StaticShX sh' + -> Conversion (Mixed sh a) (Mixed sh' a) + + ConvRR :: Conversion a b + -> Conversion (Ranked n a) (Ranked n b) + ConvSS :: Conversion a b + -> Conversion (Shaped sh a) (Shaped sh b) + ConvXX :: Conversion a b + -> Conversion (Mixed sh a) (Mixed sh b) + ConvT2 :: Conversion a a' + -> Conversion b b' + -> Conversion (a, b) (a', b') + + Conv0X :: Elt a + => Conversion a (Mixed '[] a) + ConvX0 :: Conversion (Mixed '[] a) a + + ConvNest :: Elt a => StaticShX sh + -> Conversion (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a)) + ConvUnnest :: Conversion (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a) + + ConvZip :: (Elt a, Elt b) + => Conversion (Mixed sh a, Mixed sh b) (Mixed sh (a, b)) + ConvUnzip :: (Elt a, Elt b) + => Conversion (Mixed sh (a, b)) (Mixed sh a, Mixed sh b) +deriving instance Show (Conversion a b) + +instance Category Conversion where + id = ConvId + (.) = ConvCmp + +convert :: (Elt a, Elt b) => Conversion a b -> a -> b +convert = \c x -> munScalar (go c (mscalar x)) + where + -- The 'esh' is the extension shape: the conversion happens under a whole + -- bunch of additional dimensions that it does not touch. These dimensions + -- are 'esh'. + -- The strategy is to unwind step-by-step to a large Mixed array, and to + -- perform the required checks and conversions when re-nesting back up. + go :: Conversion a b -> Mixed esh a -> Mixed esh b + go ConvId x = x + go (ConvCmp c1 c2) x = go c1 (go c2 x) + go ConvRX (M_Ranked x) = x + go ConvSX (M_Shaped x) = x + go (ConvXR @_ @sh) (M_Nest @esh esh x) + | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) + = let ssx' = ssxAppend (ssxFromShX esh) + (ssxReplicate (shxRank (shxDropSSX @esh @sh (ssxFromShX esh) (mshape x)))) + in M_Ranked (M_Nest esh (mcast ssx' x)) + go ConvXS (M_Nest esh x) = M_Shaped (M_Nest esh x) + go (ConvXS' @sh @sh' sh') (M_Nest @esh esh x) + | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh'))) + x)) + go (ConvXX' @sh @sh' ssx) (M_Nest @esh esh x) + | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x + go (ConvRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) + go (ConvSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) + go (ConvXX c) (M_Nest esh x) = M_Nest esh (go c x) + go (ConvT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2) + go Conv0X (x :: Mixed esh a) + | Refl <- lemAppNil @esh + = M_Nest (mshape x) x + go ConvX0 (M_Nest @esh _ x) + | Refl <- lemAppNil @esh + = x + go (ConvNest @_ @sh @sh' ssh) (M_Nest @esh esh x) + | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (ssxFromShX esh `ssxAppend` ssh) (mshape x)) x) + go (ConvUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x)) + | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh x + go ConvZip x = + -- no need to check that the two esh's are equal because they were zipped previously + let (M_Nest esh x1, M_Nest _ x2) = munzip x + in M_Nest esh (mzip x1 x2) + go ConvUnzip (M_Nest esh x) = + let (x1, x2) = munzip x + in mzip (M_Nest esh x1) (M_Nest esh x2) + + lemRankAppRankEq :: Rank sh ~ Rank sh' + => Proxy esh -> Proxy sh -> Proxy sh' + -> Rank (esh ++ sh) :~: Rank (esh ++ sh') + lemRankAppRankEq _ _ _ = unsafeCoerceRefl + + lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh + -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing) + lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl + + lemRankAppRankEqMapJust :: Rank sh ~ Rank sh' + => Proxy esh -> Proxy sh -> Proxy sh' + -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') + lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl +-- * Special cases of array conversions + mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a mcast ssh2 arr @@ -45,7 +296,7 @@ mcast ssh2 arr = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a -mtoRanked = castCastable (CastXR CastId) +mtoRanked = convert ConvXR rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a rtoMixed (Ranked arr) = arr @@ -59,7 +310,7 @@ rcastToMixed sshx rarr@(Ranked arr) mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') => ShS sh' -> Mixed sh a -> Shaped sh' a -mcastToShaped targetsh = castCastable (CastXS' targetsh CastId) +mcastToShaped targetsh = convert (ConvXS' targetsh) stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a stoMixed (Shaped arr) = arr @@ -82,91 +333,3 @@ rcastToShaped (Ranked arr) targetsh | Refl <- lemRankReplicate (shxRank (shxFromShS targetsh)) , Refl <- lemRankMapJust targetsh = mcastToShaped targetsh arr - -ixrFromIxS :: IxS sh i -> IxR (Rank sh) i -ixrFromIxS ZIS = ZIR -ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix - --- ixsFromIxR :: IIxR (Rank sh) -> IIxS sh --- ixsFromIxR = \ix -> go ix _ --- where --- go :: IIxR n -> (forall sh. KnownShS sh => IIxS sh -> r) -> r --- go ZIR k = k ZIS --- go (i :.: ix) k = go ix (i :.$) - -shrFromShS :: ShS sh -> IShR (Rank sh) -shrFromShS ZSS = ZSR -shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh - --- | The constructors that perform runtime shape checking are marked with a --- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure --- that the shapes are already compatible. To convert between 'Ranked' and --- 'Shaped', go via 'Mixed'. -data Castable a b where - CastId :: Castable a a - CastCmp :: Castable b c -> Castable a b -> Castable a c - - CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b) - CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b) - - CastXR :: Elt b - => Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b) - CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b) - CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh' - -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b) - - CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b) - CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b) - CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b) - - CastXX' :: (Rank sh ~ Rank sh', Elt b) => StaticShX sh' - -> Castable a b -> Castable (Mixed sh a) (Mixed sh' b) - -instance Category Castable where - id = CastId - (.) = CastCmp - -castCastable :: (Elt a, Elt b) => Castable a b -> a -> b -castCastable = \c x -> munScalar (go c (mscalar x)) - where - -- The 'esh' is the extension shape: the casting happens under a whole - -- bunch of additional dimensions that it does not touch. These dimensions - -- are 'esh'. - -- The strategy is to unwind step-by-step to a large Mixed array, and to - -- perform the required checks and castings when re-nesting back up. - go :: Castable a b -> Mixed esh a -> Mixed esh b - go CastId x = x - go (CastCmp c1 c2) x = go c1 (go c2 x) - go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x) - go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x) - go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) - | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) - = let x' = go c x - ssx' = ssxAppend (ssxFromShX esh) - (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShX esh)))) - in M_Ranked (M_Nest esh (mcast ssx' x')) - go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x)) - go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x) - | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') - = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh'))) - (go c x))) - go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) - go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) - go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) - go (CastXX' @sh @sh' ssx c) (M_Nest @esh esh x) - | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') - = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) (go c x) - - lemRankAppRankEq :: Rank sh ~ Rank sh' - => Proxy esh -> Proxy sh -> Proxy sh' - -> Rank (esh ++ sh) :~: Rank (esh ++ sh') - lemRankAppRankEq _ _ _ = unsafeCoerceRefl - - lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh - -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing) - lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl - - lemRankAppRankEqMapJust :: Rank sh ~ Rank sh' - => Proxy esh -> Proxy sh -> Proxy sh' - -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') - lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Nested/Internal/Lemmas.hs b/src/Data/Array/Nested/Internal/Lemmas.hs deleted file mode 100644 index b1589e0..0000000 --- a/src/Data/Array/Nested/Internal/Lemmas.hs +++ /dev/null @@ -1,59 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -module Data.Array.Nested.Internal.Lemmas where - -import Data.Proxy -import Data.Type.Equality -import GHC.TypeLits - -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Permutation -import Data.Array.Nested.Types -import Data.Array.Nested.Mixed.Shape -import Data.Array.Nested.Shaped.Shape - - -lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh -lemRankMapJust ZSS = Refl -lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl - -lemMapJustApp :: ShS sh1 -> Proxy sh2 - -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 -lemMapJustApp ZSS _ = Refl -lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl - -lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) -lemTakeLenMapJust PNil _ = Refl -lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl -lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty" - -lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) -lemDropLenMapJust PNil _ = Refl -lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl -lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty" - -lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) -lemIndexMapJust SZ (_ :$$ _) = Refl -lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) - | Refl <- lemIndexMapJust i sh - , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) - , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = Refl -lemIndexMapJust _ ZSS = error "Index of empty" - -lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) -lemPermuteMapJust PNil _ = Refl -lemPermuteMapJust (i `PCons` is) sh - | Refl <- lemPermuteMapJust is sh - , Refl <- lemIndexMapJust i sh - = Refl - -lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) -lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) - where - go :: ShS sh' -> StaticShX (MapJust sh') - go ZSS = ZKX - go (n :$$ sh) = SKnown n :!% go sh diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs index e6d970c..e089479 100644 --- a/src/Data/Array/Mixed/Lemmas.hs +++ b/src/Data/Array/Nested/Lemmas.hs @@ -6,7 +6,7 @@ {-# LANGUAGE TypeOperators #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed.Lemmas where +module Data.Array.Nested.Lemmas where import Data.Proxy import Data.Type.Equality @@ -14,10 +14,11 @@ import GHC.TypeLits import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Permutation +import Data.Array.Nested.Shaped.Shape import Data.Array.Nested.Types --- * Lemmas +-- * Lemmas about numbers and lists -- ** Nat @@ -27,7 +28,6 @@ lemLeqSuccSucc _ _ = unsafeCoerceRefl lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True lemLeqPlus _ _ _ = Refl - -- ** Append lemAppNil :: l ++ '[] :~: l @@ -39,42 +39,22 @@ lemAppAssoc _ _ _ = unsafeCoerceRefl lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l lemAppLeft _ Refl = Refl - --- ** Rank - -lemRankApp :: forall sh1 sh2. - StaticShX sh1 -> StaticShX sh2 - -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 -lemRankApp ZKX _ = Refl -lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2 - = lem (Proxy @(Rank sh1T)) Proxy Proxy $ - sym (lemRankApp ssh1 ssh2) - where - lem :: proxy a -> proxy b -> proxy c - -> (a + b :~: c) - -> c + 1 :~: (a + 1 + b) - lem _ _ _ Refl = Refl - -lemRankAppComm :: proxy sh1 -> proxy sh2 - -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) -lemRankAppComm _ _ = unsafeCoerceRefl - -lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n -lemRankReplicate _ = unsafeCoerceRefl - - --- ** Various type families +-- ** Simple type families lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a +{- for now, the plugins can't derive a type for this code, see + https://github.com/clash-lang/ghc-typelits-natnormalise/pull/98#issuecomment-3332842214 lemReplicatePlusApp sn _ _ = go sn where go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a go SZ = Refl go (SS (n :: SNat n'm1)) - | Refl <- lemReplicateSucc @a @n'm1 + | Refl <- lemReplicateSucc @a n , Refl <- go n - = sym (lemReplicateSucc @a @(n'm1 + m)) + = sym (lemReplicateSucc @a (SNat @(n'm1 + m))) +-} +lemReplicatePlusApp _ _ _ = unsafeCoerceRefl lemDropLenApp :: Rank l1 <= Rank l2 => Proxy l1 -> Proxy l2 -> Proxy rest @@ -107,6 +87,8 @@ lemKnownNatRankSSX ZKX = Dict lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict +-- * Lemmas about shapes + -- ** Known shapes lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing) @@ -116,3 +98,69 @@ lemKnownShX :: StaticShX sh -> Dict KnownShX sh lemKnownShX ZKX = Dict lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict + +lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) +lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) + where + go :: ShS sh' -> StaticShX (MapJust sh') + go ZSS = ZKX + go (n :$$ sh) = SKnown n :!% go sh + +-- ** Rank + +lemRankApp :: forall sh1 sh2. + StaticShX sh1 -> StaticShX sh2 + -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 +lemRankApp ZKX _ = Refl +lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2 + = lem (Proxy @(Rank sh1T)) Proxy Proxy $ + sym (lemRankApp ssh1 ssh2) + where + lem :: proxy a -> proxy b -> proxy c + -> (a + b :~: c) + -> c + 1 :~: (a + 1 + b) + lem _ _ _ Refl = Refl + +lemRankAppComm :: proxy sh1 -> proxy sh2 + -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) +lemRankAppComm _ _ = unsafeCoerceRefl + +lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate _ = unsafeCoerceRefl + +lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh +lemRankMapJust ZSS = Refl +lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl + +-- ** Related to MapJust and/or Permutation + +lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) +lemTakeLenMapJust PNil _ = Refl +lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl +lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty" + +lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) +lemDropLenMapJust PNil _ = Refl +lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl +lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty" + +lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) +lemIndexMapJust SZ (_ :$$ _) = Refl +lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) + | Refl <- lemIndexMapJust i sh + , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) + , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = Refl +lemIndexMapJust _ ZSS = error "Index of empty" + +lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) +lemPermuteMapJust PNil _ = Refl +lemPermuteMapJust (i `PCons` is) sh + | Refl <- lemPermuteMapJust is sh + , Refl <- lemIndexMapJust i sh + = Refl + +lemMapJustApp :: ShS sh1 -> Proxy sh2 + -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 +lemMapJustApp ZSS _ = Refl +lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 54bd5f2..144230e 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -42,13 +42,13 @@ import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) import GHC.TypeLits -import Data.Array.Mixed.Lemmas +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Permutation import Data.Array.Nested.Types +import Data.Array.Strided.Orthotope import Data.Array.XArray (XArray(..)) import Data.Array.XArray qualified as X -import Data.Array.Nested.Mixed.Shape -import Data.Array.Strided.Orthotope import Data.Bag @@ -380,9 +380,7 @@ class Elt a where -- of this class with those of 'Elt': some instances have an additional -- "known-shape" constraint. -- --- This class is (currently) only required for 'mgenerate', --- 'Data.Array.Nested.Ranked.rgenerate' and --- 'Data.Array.Nested.Shaped.sgenerate'. +-- This class is (currently) only required for `memptyArray` and 'mgenerate'. class Elt a => KnownElt a where -- | Create an empty array. The given shape must have size zero; this may or may not be checked. memptyArrayUnsafe :: IShX sh -> Mixed sh a @@ -398,7 +396,7 @@ class Elt a => KnownElt a where instance Storable a => Elt (Primitive a) where mshape (M_Primitive sh _) = sh mindex (M_Primitive _ a) i = Primitive (X.index a i) - mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) + mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i) mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) mfromListOuter l@(arr1 :| _) = let sh = SUnknown (length l) :$% mshape arr1 @@ -440,7 +438,7 @@ instance Storable a => Elt (Primitive a) where => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) = let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' - sh2 = shxCast' sh1 ssh2 + sh2 = shxCast' ssh2 sh1 in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr) mtranspose perm (M_Primitive sh arr) = @@ -557,13 +555,13 @@ instance Elt a => Elt (Mixed sh' a) where = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr)) mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a - mindex (M_Nest _ arr) i = mindexPartial arr i + mindex (M_Nest _ arr) = mindexPartial arr mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) mindexPartial (M_Nest sh arr) i | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + = M_Nest (shxDropIx i sh) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) mscalar = M_Nest ZSX @@ -632,14 +630,14 @@ instance Elt a => Elt (Mixed sh' a) where | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T - sh2 = shxCast' sh1 ssh2 + sh2 = shxCast' ssh2 sh1 in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr) mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) => Perm is -> Mixed sh (Mixed sh' a) -> Mixed (PermutePrefix is sh) (Mixed sh' a) mtranspose perm (M_Nest sh arr) - | let sh' = shxDropSh @sh @sh' (mshape arr) sh + | let sh' = shxDropSh @sh @sh' sh (mshape arr) , Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh') , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') @@ -784,14 +782,10 @@ mtoVector arr = mtoVectorP (toPrimitive arr) mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise? -mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromList1Prim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -mtoList1 :: Elt a => Mixed '[n] a -> [a] -mtoList1 = map munScalar . mtoListOuter +-- This forall is there so that a simple type application can constrain the +-- shape, in case the user wants to use OverloadedLists for the shape. +mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a +mfromListLinear sh l = mreshape sh (mfromList1 l) mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a mfromListPrim l = @@ -804,10 +798,8 @@ mfromListPrimLinear sh l = let M_Primitive _ xarr = toPrimitive (mfromListPrim l) in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) --- This forall is there so that a simple type application can constrain the --- shape, in case the user wants to use OverloadedLists for the shape. -mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a -mfromListLinear sh l = mreshape sh (mfromList1 l) +mtoList :: Elt a => Mixed '[n] a -> [a] +mtoList = map munScalar . mtoListOuter mtoListLinear :: Elt a => Mixed sh a -> [a] mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr)) -- TODO: optimise @@ -821,8 +813,11 @@ mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a munNest (M_Nest _ arr) = arr -mzip :: Mixed sh a -> Mixed sh b -> Mixed sh (a, b) -mzip = M_Tup2 +-- | The arguments must have equal shapes. If they do not, an error is raised. +mzip :: (Elt a, Elt b) => Mixed sh a -> Mixed sh b -> Mixed sh (a, b) +mzip a b + | Just Refl <- shxEqual (mshape a) (mshape b) = M_Tup2 a b + | otherwise = error "mzip: unequal shapes" munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b) munzip (M_Tup2 a b) = (a, b) @@ -832,13 +827,13 @@ mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) mrerankP ssh sh2 f (M_Primitive sh arr) = - let sh1 = shxDropSSX sh ssh - in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) + let sh1 = shxDropSSX ssh sh + in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) ssh sh) sh2) (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2) (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) arr) --- | See the caveats at @X.rerank@. +-- | See the caveats at 'Data.Array.XArray.rerank'. mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) => StaticShX sh -> IShX sh2 -> (Mixed sh1 a -> Mixed sh2 b) diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 2f35ff9..8777739 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -31,13 +31,15 @@ import Data.Functor.Const import Data.Functor.Product import Data.Kind (Constraint, Type) import Data.Monoid (Sum(..)) -import Data.Proxy import Data.Type.Equality import GHC.Exts (withDict) import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits +#if !MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) +import GHC.TypeLits.Orphans () +#endif import Data.Array.Nested.Types @@ -146,9 +148,9 @@ listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f listxAppend ZX idx' = idx' listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' -listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f -listxDrop long ZX = long -listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short +listxDrop :: forall f g sh sh'. ListX sh g -> ListX (sh ++ sh') f -> ListX sh' f +listxDrop ZX long = long +listxDrop (_ ::% short) long = case long of _ ::% long' -> listxDrop short long' listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh @@ -160,19 +162,17 @@ listxLast (x ::% ZX) = x listxZip :: ListX sh f -> ListX sh g -> ListX sh (Product f g) listxZip ZX ZX = ZX -listxZip (i ::% irest) (j ::% jrest) = - Pair i j ::% listxZip irest jrest +listxZip (i ::% irest) (j ::% jrest) = Pair i j ::% listxZip irest jrest listxZipWith :: (forall a. f a -> g a -> h a) -> ListX sh f -> ListX sh g -> ListX sh h listxZipWith _ ZX ZX = ZX -listxZipWith f (i ::% is) (j ::% js) = - f i j ::% listxZipWith f is js +listxZipWith f (i ::% is) (j ::% js) = f i j ::% listxZipWith f is js -- * Mixed indices --- | This is a newtype over 'ListX'. +-- | An index into a mixed-typed array. type role IxX nominal representational type IxX :: [Maybe Nat] -> Type -> Type newtype IxX sh i = IxX (ListX sh (Const i)) @@ -191,6 +191,8 @@ infixr 3 :.% {-# COMPLETE ZIX, (:.%) #-} +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). type IIxX sh = IxX sh Int #ifdef OXAR_DEFAULT_SHOW_INSTANCES @@ -234,7 +236,7 @@ ixxTail (IxX list) = IxX (listxTail list) ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i ixxAppend = coerce (listxAppend @_ @(Const i)) -ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i +ixxDrop :: forall sh sh' i. IxX sh i -> IxX (sh ++ sh') i -> IxX sh' i ixxDrop = coerce (listxDrop @(Const i) @(Const i)) ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i @@ -243,6 +245,11 @@ ixxInit = coerce (listxInit @(Const i)) ixxLast :: forall n sh i. IxX (n : sh) i -> i ixxLast = coerce (listxLast @(Const i)) +ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i +ixxCast ZKX ZIX = ZIX +ixxCast (_ :!% sh) (i :.% idx) = i :.% ixxCast sh idx +ixxCast _ _ = error "ixxCast: ranks don't match" + ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j) ixxZip ZIX ZIX = ZIX ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js @@ -390,10 +397,10 @@ shxSize :: IShX sh -> Int shxSize ZSX = 1 shxSize (n :$% sh) = fromSMayNat' n * shxSize sh -shxFromList :: StaticShX sh -> [Int] -> ShX sh Int +shxFromList :: StaticShX sh -> [Int] -> IShX sh shxFromList topssh topl = go topssh topl where - go :: StaticShX sh' -> [Int] -> ShX sh' Int + go :: StaticShX sh' -> [Int] -> IShX sh' go ZKX [] = ZSX go (SKnown sn :!% sh) (i : is) | i == fromSNat' sn = SKnown sn :$% go sh is @@ -408,11 +415,18 @@ shxToList :: IShX sh -> [Int] shxToList ZSX = [] shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh +shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i +shxFromSSX ZKX = ZSX +shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh)) + | Refl <- lemMapJustCons @sh Refl + = SKnown n :$% shxFromSSX sh +shxFromSSX (SUnknown _ :!% _) = error "unreachable" + -- | This may fail if @sh@ has @Nothing@s in it. -shxFromSSX' :: StaticShX sh -> Maybe (IShX sh) -shxFromSSX' ZKX = Just ZSX -shxFromSSX' (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX' sh -shxFromSSX' (SUnknown _ :!% _) = Nothing +shxFromSSX2 :: StaticShX sh -> Maybe (ShX sh i) +shxFromSSX2 ZKX = Just ZSX +shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh +shxFromSSX2 (SUnknown _ :!% _) = Nothing shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) @@ -423,13 +437,13 @@ shxHead (ShX list) = listxHead list shxTail :: ShX (n : sh) i -> ShX sh i shxTail (ShX list) = ShX (listxTail list) -shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i +shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) -shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i +shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) -shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i +shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i @@ -438,12 +452,9 @@ shxInit = coerce (listxInit @(SMayNat i SNat)) shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh)) shxLast = coerce (listxLast @(SMayNat i SNat)) -shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i -shxTakeSSX _ = flip go - where - go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i - go ZKX _ = ZSX - go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh +shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i +shxTakeSSX _ ZKX _ = ZSX +shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n) -> ShX sh i -> ShX sh j -> ShX sh k @@ -456,7 +467,7 @@ shxCompleteZeros ZKX = ZSX shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh -shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) +shxSplitApp :: proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) shxSplitApp _ ZKX idx = (ZSX, idx) shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx) @@ -467,17 +478,17 @@ shxEnum = \sh -> go sh id [] go ZSX f = (f ZIX :) go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] -shxCast :: IShX sh -> StaticShX sh' -> Maybe (IShX sh') -shxCast ZSX ZKX = Just ZSX -shxCast (SKnown n :$% sh) (SKnown m :!% ssh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast sh ssh -shxCast (SUnknown n :$% sh) (SKnown m :!% ssh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast sh ssh -shxCast (SKnown n :$% sh) (SUnknown () :!% ssh) = (SUnknown (fromSNat' n) :$%) <$> shxCast sh ssh -shxCast (SUnknown n :$% sh) (SUnknown () :!% ssh) = (SUnknown n :$%) <$> shxCast sh ssh +shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh') +shxCast ZKX ZSX = Just ZSX +shxCast (SKnown m :!% ssh) (SKnown n :$% sh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast ssh sh +shxCast (SKnown m :!% ssh) (SUnknown n :$% sh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SKnown n :$% sh) = (SUnknown (fromSNat' n) :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SUnknown n :$% sh) = (SUnknown n :$%) <$> shxCast ssh sh shxCast _ _ = Nothing -- | Partial version of 'shxCast'. -shxCast' :: IShX sh -> StaticShX sh' -> IShX sh' -shxCast' sh ssh = case shxCast sh ssh of +shxCast' :: StaticShX sh' -> IShX sh -> IShX sh' +shxCast' ssh sh = case shxCast ssh sh of Just sh' -> sh' Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")" @@ -537,9 +548,15 @@ ssxHead (StaticShX list) = listxHead list ssxTail :: StaticShX (n : sh) -> StaticShX sh ssxTail (_ :!% ssh) = ssh -ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' +ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSSX = coerce (listxDrop @(SMayNat () SNat) @(SMayNat () SNat)) + +ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) +ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSh = coerce (listxDrop @(SMayNat () SNat) @(SMayNat i SNat)) + ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) ssxInit = coerce (listxInit @(SMayNat () SNat)) @@ -549,20 +566,20 @@ ssxLast = coerce (listxLast @(SMayNat () SNat)) ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) ssxReplicate SZ = ZKX ssxReplicate (SS (n :: SNat n')) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n' + | Refl <- lemReplicateSucc @(Nothing @Nat) n = SUnknown () :!% ssxReplicate n -ssxIotaFrom :: Int -> StaticShX sh -> [Int] -ssxIotaFrom _ ZKX = [] -ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh +ssxIotaFrom :: StaticShX sh -> Int -> [Int] +ssxIotaFrom ZKX _ = [] +ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i+1) -ssxFromShX :: IShX sh -> StaticShX sh +ssxFromShX :: ShX sh i -> StaticShX sh ssxFromShX ZSX = ZKX ssxFromShX (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShX sh ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing) ssxFromSNat SZ = ZKX -ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n +ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) n = SUnknown () :!% ssxFromSNat n -- | Evidence for the static part of a shape. This pops up only when you are @@ -574,7 +591,7 @@ instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SK instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r -withKnownShX k = withDict @(KnownShX sh) k +withKnownShX = withDict @(KnownShX sh) -- * Flattening diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs index 031755f..045b18f 100644 --- a/src/Data/Array/Nested/Permutation.hs +++ b/src/Data/Array/Nested/Permutation.hs @@ -1,10 +1,10 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -25,6 +25,7 @@ import Data.Proxy import Data.Type.Bool import Data.Type.Equality import Data.Type.Ord +import GHC.Exts (withDict) import GHC.TypeError import GHC.TypeLits import GHC.TypeNats qualified as TN @@ -36,8 +37,8 @@ import Data.Array.Nested.Types -- * Permutations -- | A "backward" permutation of a dimension list. The operation on the --- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute' --- for code that implements this. +-- dimension list is most similar to @backpermute@ in the @vector@ package; see +-- 'Permute' for code that implements this. data Perm list where PNil :: Perm '[] PCons :: SNat a -> Perm l -> Perm (a : l) @@ -45,6 +46,13 @@ infixr 5 `PCons` deriving instance Show (Perm list) deriving instance Eq (Perm list) +instance TestEquality Perm where + testEquality PNil PNil = Just Refl + testEquality (x `PCons` xs) (y `PCons` ys) + | Just Refl <- testEquality x y + , Just Refl <- testEquality xs ys = Just Refl + testEquality _ _ = Nothing + permRank :: Perm list -> SNat (Rank list) permRank PNil = SNat permRank (_ `PCons` l) | SNat <- permRank l = SNat @@ -119,6 +127,9 @@ class KnownPerm l where makePerm :: Perm l instance KnownPerm '[] where makePerm = PNil instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm +withKnownPerm :: forall l r. Perm l -> (KnownPerm l => r) -> r +withKnownPerm = withDict @(KnownPerm l) + -- | Untyped permutations for ranked arrays type PermR = [Int] @@ -199,7 +210,7 @@ ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) ssxPermute = coerce (listxPermute @(SMayNat () SNat)) ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh) -ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) +ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat () SNat) p1 p2 i) ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) @@ -224,7 +235,7 @@ permInverse = \perm k -> ++ " ; invperm = " ++ show invperm) (permCheckPermutation invperm (k invperm - (\ssh -> case provePermInverse perm invperm ssh of + (\ssh -> case permCheckInverse perm invperm ssh of Just eq -> eq Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm ++ " ; invperm = " ++ show invperm))) @@ -238,9 +249,9 @@ permInverse = \perm k -> toHList [] k = k PNil toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l) - provePermInverse :: Perm is -> Perm is' -> StaticShX sh + permCheckInverse :: Perm is -> Perm is' -> StaticShX sh -> Maybe (Permute is' (Permute is sh) :~: sh) - provePermInverse perm perminv ssh = + permCheckInverse perm perminv ssh = ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh type family MapSucc is where @@ -264,7 +275,13 @@ lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl lemRankDropLen :: forall is sh. (Rank is <= Rank sh) => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is lemRankDropLen ZKX PNil = Refl -lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!% sh) (_ `PCons` is) + | Refl <- lemRankDropLen sh is +#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) + = Refl +#else + = unsafeCoerceRefl +#endif lemRankDropLen (_ :!% _) PNil = Refl lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0" diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index e5c51ef..8b95d0f 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -29,17 +29,17 @@ import Foreign.Storable (Storable) import GHC.TypeLits import GHC.TypeNats qualified as TN -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Permutation -import Data.Array.Nested.Types -import Data.Array.XArray (XArray(..)) -import Data.Array.XArray qualified as X import Data.Array.Nested.Convert +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation import Data.Array.Nested.Ranked.Base import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith +import Data.Array.XArray (XArray(..)) +import Data.Array.XArray qualified as X remptyArray :: KnownElt a => Ranked 1 a @@ -85,7 +85,7 @@ rsumOuter1P :: forall n a. (Storable a, NumElt a) => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) rsumOuter1P (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = Ranked (msumOuter1P arr) rsumOuter1 :: forall n a. (NumElt a, PrimElt a) @@ -108,7 +108,7 @@ rtranspose perm arr rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a rconcat - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = coerce mconcat rappend :: forall n a. Elt a @@ -116,7 +116,7 @@ rappend :: forall n a. Elt a rappend arr1 arr2 | sn@SNat <- rrank arr1 , Dict <- lemKnownReplicate sn - , Refl <- lemReplicateSucc @(Nothing @Nat) @n + , Refl <- lemReplicateSucc @(Nothing @Nat) (SNat @n) = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) arr1 arr2 @@ -137,38 +137,32 @@ rtoVectorP = coerce mtoVectorP rtoVector :: PrimElt a => Ranked n a -> VS.Vector a rtoVector = coerce mtoVector -rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a -rfromListOuter l - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) - rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a rfromList1 l = Ranked (mfromList1 l) -rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a -rfromList1Prim l = Ranked (mfromList1Prim l) - -rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] -rtoListOuter (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) +rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a +rfromListOuter l + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) -rtoList1 :: Elt a => Ranked 1 a -> [a] -rtoList1 = map runScalar . rtoListOuter +rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a +rfromListLinear sh l = rreshape sh (rfromList1 l) rfromListPrim :: PrimElt a => [a] -> Ranked 1 a -rfromListPrim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr +rfromListPrim l = Ranked (mfromListPrim l) rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a rfromListPrimLinear sh l = let M_Primitive _ xarr = toPrimitive (mfromListPrim l) in Ranked $ fromPrimitive $ M_Primitive (shxFromShR sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShR sh) xarr) -rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a -rfromListLinear sh l = rreshape sh (rfromList1 l) +rtoList :: Elt a => Ranked 1 a -> [a] +rtoList = map runScalar . rtoListOuter + +rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] +rtoListOuter (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) rtoListLinear :: Elt a => Ranked n a -> [a] rtoListLinear (Ranked arr) = mtoListLinear arr @@ -179,9 +173,9 @@ rfromOrthotope sn arr = let xarr = XArray arr in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr)) -rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a +rtoOrthotope :: forall a n. PrimElt a => Ranked n a -> S.Array n a rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr))) - | Refl <- lemRankReplicate (shrRank $ shrFromShX2 sh) + | Refl <- lemRankReplicate (shrRank $ shrFromShX2 @n sh) = arr runScalar :: Elt a => Ranked 0 a -> a @@ -197,7 +191,7 @@ runNest rarr@(Ranked (M_Ranked (M_Nest _ arr))) | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat)) = Ranked arr -rzip :: Ranked n a -> Ranked n b -> Ranked n (a, b) +rzip :: (Elt a, Elt b) => Ranked n a -> Ranked n b -> Ranked n (a, b) rzip = coerce mzip runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b) @@ -230,7 +224,7 @@ rrerankP sn sh2 f (Ranked arr) -- then: -- -- @ --- rrerank _ _ _ f arr :: Ranked 5 Float +-- rrerank _ _ _ f arr :: Ranked 6 Float -- @ -- -- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the @@ -261,7 +255,7 @@ rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a rslice i n arr - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = rlift (rrank arr) (\_ -> X.sliceU i n) arr @@ -270,7 +264,7 @@ rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a rrev1 arr = rlift (rrank arr) (\(_ :: StaticShX sh') -> - case lemReplicateSucc @(Nothing @Nat) @n of + case lemReplicateSucc @(Nothing @Nat) (Proxy @n) of Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) arr diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index f50f671..babc809 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -5,6 +5,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -25,6 +26,7 @@ import Data.Coerce (coerce) import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty) import Data.Proxy +import Data.Type.Equality import Foreign.Storable (Storable) import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) @@ -34,13 +36,13 @@ import GHC.TypeLits import Data.Foldable (toList) #endif -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Types -import Data.Array.XArray (XArray(..)) +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith +import Data.Array.XArray (XArray(..)) -- | A rank-typed array: the number of dimensions of the array (its /rank/) is @@ -252,3 +254,15 @@ rshape (Ranked arr) = shrFromShX2 (mshape arr) rrank :: Elt a => Ranked n a -> SNat n rrank = shrRank . rshape + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +shrFromShX :: forall sh. IShX sh -> IShR (Rank sh) +shrFromShX ZSX = ZSR +shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +-- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'. +shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n +shrFromShX2 sh + | Refl <- lemRankReplicate (Proxy @n) + = shrFromShX sh diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index c0c4f17..488b4d8 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -39,11 +39,12 @@ import GHC.IsList qualified as IsList import GHC.TypeLits import GHC.TypeNats qualified as TN -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Lemmas import Data.Array.Nested.Types +-- * Ranked lists + type role ListR nominal representational type ListR :: Nat -> Type -> Type data ListR n i where @@ -114,21 +115,21 @@ listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) listrHead :: ListR (n + 1) i -> i listrHead (i ::: _) = i -listrHead ZR = error "unreachable" listrTail :: ListR (n + 1) i -> ListR n i listrTail (_ ::: sh) = sh -listrTail ZR = error "unreachable" listrInit :: ListR (n + 1) i -> ListR n i listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh listrInit (_ ::: ZR) = ZR -listrInit ZR = error "unreachable" listrLast :: ListR (n + 1) i -> i listrLast (_ ::: sh@(_ ::: _)) = listrLast sh listrLast (n ::: ZR) = n -listrLast ZR = error "unreachable" + +-- | Performs a runtime check that the lengths are identical. +listrCast :: SNat n' -> ListR n i -> ListR n' i +listrCast = listrCastWithName "listrCast" listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i listrIndex SZ (x ::: _) = x @@ -172,6 +173,8 @@ listrPermutePrefix = \perm sh -> GTI -> error "listrPermutePrefix: Index in permutation out of range" +-- * Ranked indices + -- | An index into a rank-typed array. type role IxR nominal representational type IxR :: Nat -> Type -> Type @@ -192,6 +195,8 @@ infixr 3 :.: {-# COMPLETE ZIR, (:.:) #-} +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). type IIxR n = IxR n Int #ifdef OXAR_DEFAULT_SHOW_INSTANCES @@ -213,16 +218,6 @@ ixrZero :: SNat n -> IIxR n ixrZero SZ = ZIR ixrZero (SS n) = 0 :.: ixrZero n -ixrFromIxX :: IxX sh i -> IxR (Rank sh) i -ixrFromIxX ZIX = ZIR -ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx - -ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i -ixxFromIxR ZIR = ZIX -ixxFromIxR (n :.: (idx :: IxR m i)) = - castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m)) - (n :.% ixxFromIxR idx) - ixrHead :: IxR (n + 1) i -> i ixrHead (IxR list) = listrHead list @@ -235,6 +230,10 @@ ixrInit (IxR list) = IxR (listrInit list) ixrLast :: IxR (n + 1) i -> i ixrLast (IxR list) = listrLast list +-- | Performs a runtime check that the lengths are identical. +ixrCast :: SNat n' -> IxR n i -> IxR n' i +ixrCast n (IxR idx) = IxR (listrCastWithName "ixrCast" n idx) + ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i ixrAppend = coerce (listrAppend @_ @i) @@ -248,6 +247,8 @@ ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i ixrPermutePrefix = coerce (listrPermutePrefix @i) +-- * Ranked shapes + type role ShR nominal representational type ShR :: Nat -> Type -> Type newtype ShR n i = ShR (ListR n i) @@ -278,22 +279,6 @@ instance Show i => Show (ShR n i) where instance NFData i => NFData (ShR sh i) -shrFromShX :: forall sh. IShX sh -> IShR (Rank sh) -shrFromShX ZSX = ZSR -shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx - --- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'. -shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n -shrFromShX2 sh - | Refl <- lemRankReplicate (Proxy @n) - = shrFromShX sh - -shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i -shxFromShR ZSR = ZSX -shxFromShR (n :$: (idx :: ShR m i)) = - castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m)) - (SUnknown n :$% shxFromShR idx) - -- | This checks only whether the ranks are equal, not whether the actual -- values are. shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') @@ -329,6 +314,10 @@ shrInit (ShR list) = ShR (listrInit list) shrLast :: ShR (n + 1) i -> i shrLast (ShR list) = listrLast list +-- | Performs a runtime check that the lengths are identical. +shrCast :: SNat n' -> ShR n i -> ShR n' i +shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh) + shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i shrAppend = coerce (listrAppend @_ @i) @@ -366,3 +355,11 @@ instance KnownNat n => IsList (ShR n i) where type Item (ShR n i) = i fromList = ShR . IsList.fromList toList = Foldable.toList + + +-- * Internal helper functions + +listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i +listrCastWithName _ SZ ZR = ZR +listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx +listrCastWithName name _ _ = error $ name ++ ": ranks don't match" diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 7e38aee..198a068 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -29,18 +29,17 @@ import Data.Vector.Storable qualified as VS import Foreign.Storable (Storable) import GHC.TypeLits -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Permutation -import Data.Array.Nested.Types -import Data.Array.XArray (XArray) -import Data.Array.XArray qualified as X -import Data.Array.Nested.Internal.Lemmas import Data.Array.Nested.Convert +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation import Data.Array.Nested.Shaped.Base import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith +import Data.Array.XArray (XArray) +import Data.Array.XArray qualified as X semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a @@ -124,20 +123,14 @@ stoVectorP = coerce mtoVectorP stoVector :: PrimElt a => Shaped sh a -> VS.Vector a stoVector = coerce mtoVector -sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l)) - sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1 -sfromList1Prim :: PrimElt a => SNat n -> [a] -> Shaped '[n] a -sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim - -stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] -stoListOuter (Shaped arr) = coerce (mtoListOuter arr) +sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a +sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l)) -stoList1 :: Elt a => Shaped '[n] a -> [a] -stoList1 = map sunScalar . stoListOuter +sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a +sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l) sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a sfromListPrim sn l @@ -151,8 +144,11 @@ sfromListPrimLinear sh l = let M_Primitive _ xarr = toPrimitive (mfromListPrim l) in Shaped $ fromPrimitive $ M_Primitive (shxFromShS sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShS sh) xarr) -sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a -sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l) +stoList :: Elt a => Shaped '[n] a -> [a] +stoList = map sunScalar . stoListOuter + +stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] +stoListOuter (Shaped arr) = coerce (mtoListOuter arr) stoListLinear :: Elt a => Shaped sh a -> [a] stoListLinear (Shaped arr) = mtoListLinear arr @@ -177,7 +173,7 @@ sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr))) | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh') = Shaped arr -szip :: Shaped sh a -> Shaped sh b -> Shaped sh (a, b) +szip :: (Elt a, Elt b) => Shaped sh a -> Shaped sh b -> Shaped sh (a, b) szip = coerce mzip sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b) @@ -190,11 +186,12 @@ srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) srerankP sh sh2 f sarr@(Shaped arr) | Refl <- lemMapJustApp sh (Proxy @sh1) , Refl <- lemMapJustApp sh (Proxy @sh2) - = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (shxFromShS (sshape sarr)) (ssxFromShX (shxFromShS sh)))) + = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (ssxFromShX (shxFromShS sh)) (shxFromShS (sshape sarr)))) (shxFromShS sh2) (\a -> let Shaped r = f (Shaped a) in r) arr) +-- | See the caveats at 'Data.Array.XArray.rerank'. srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) => ShS sh -> ShS sh2 -> (Shaped sh1 a -> Shaped sh2 b) diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index 529ac21..879e6b5 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -5,6 +5,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -25,18 +26,19 @@ import Data.Coerce (coerce) import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty) import Data.Proxy +import Data.Type.Equality import Foreign.Storable (Storable) import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) import GHC.TypeLits -import Data.Array.Nested.Types -import Data.Array.XArray (XArray) -import Data.Array.Nested.Internal.Lemmas +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith +import Data.Array.XArray (XArray) -- | A shape-typed array: the full shape of the array (the sizes of its @@ -242,3 +244,12 @@ satan2Array = liftShaped2 matan2Array sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh sshape (Shaped arr) = shsFromShX (mshape arr) + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh +shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS +shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) = + castWith (subst1 (sym (lemMapJustCons Refl))) $ + n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) + idx) +shsFromShX (SUnknown _ :$% _) = error "impossible" diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 0b7d1c9..aa98707 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -1,12 +1,9 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE PatternSynonyms #-} @@ -48,6 +45,10 @@ import Data.Array.Nested.Permutation import Data.Array.Nested.Types +-- * Shaped lists + +-- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be +-- removed in a future release. type role ListS nominal representational type ListS :: [Nat] -> (Nat -> Type) -> Type data ListS sh f where @@ -144,14 +145,12 @@ listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g) listsZip ZS ZS = ZS -listsZip (i ::$ is) (j ::$ js) = - Fun.Pair i j ::$ listsZip is js +listsZip (i ::$ is) (j ::$ js) = Fun.Pair i j ::$ listsZip is js listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g -> ListS sh h listsZipWith _ ZS ZS = ZS -listsZipWith f (i ::$ is) (j ::$ js) = - f i j ::$ listsZipWith f is js +listsZipWith f (i ::$ is) (j ::$ js) = f i j ::$ listsZipWith f is js listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f listsTakeLenPerm PNil _ = ZS @@ -180,11 +179,9 @@ listsIndex _ _ _ ZS = error "Index into empty shape" listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) +-- * Shaped indices -- | An index into a shape-typed array. --- --- For convenience, this contains regular 'Int's instead of bounded integers --- (traditionally called \"@Fin@\"). type role IxS nominal representational type IxS :: [Nat] -> Type -> Type newtype IxS sh i = IxS (ListS sh (Const i)) @@ -193,6 +190,8 @@ newtype IxS sh i = IxS (ListS sh (Const i)) pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i pattern ZIS = IxS ZS +-- | Note: The 'KnownNat' constraint on '(:.$)' is deprecated and should be +-- removed in a future release. pattern (:.$) :: forall {sh1} {i}. forall n sh. (KnownNat n, n : sh ~ sh1) @@ -203,6 +202,8 @@ infixr 3 :.$ {-# COMPLETE ZIS, (:.$) #-} +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). type IIxS sh = IxS sh Int #ifdef OXAR_DEFAULT_SHOW_INSTANCES @@ -230,14 +231,6 @@ ixsZero :: ShS sh -> IIxS sh ixsZero ZSS = ZIS ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh -ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i -ixsFromIxX ZSS ZIX = ZIS -ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx - -ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i -ixxFromIxS ZIS = ZIX -ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh - ixsHead :: IxS (n : sh) i -> i ixsHead (IxS list) = getConst (listsHead list) @@ -250,14 +243,20 @@ ixsInit (IxS list) = IxS (listsInit list) ixsLast :: IxS (n : sh) i -> i ixsLast (IxS list) = getConst (listsLast list) +-- TODO: this takes a ShS because there are KnownNats inside IxS. +ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i +ixsCast ZSS ZIS = ZIS +ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx +ixsCast _ _ = error "ixsCast: ranks don't match" + ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i ixsAppend = coerce (listsAppend @_ @(Const i)) -ixsZip :: IxS n i -> IxS n j -> IxS n (i, j) +ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j) ixsZip ZIS ZIS = ZIS ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js -ixsZipWith :: (i -> j -> k) -> IxS n i -> IxS n j -> IxS n k +ixsZipWith :: (i -> j -> k) -> IxS sh i -> IxS sh j -> IxS sh k ixsZipWith _ ZIS ZIS = ZIS ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js @@ -265,6 +264,8 @@ ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) +-- * Shaped shapes + -- | The shape of a shape-typed array given as a list of 'SNat' values. -- -- Note that because the shape of a shape-typed array is known statically, you @@ -321,23 +322,6 @@ shsToList :: ShS sh -> [Int] shsToList ZSS = [] shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh -shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh -shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS -shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) = - castWith (subst1 (lem Refl)) $ - n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) - idx) - where - lem :: forall sh1 sh' n. - Just n : sh1 :~: MapJust sh' - -> n : Tail sh' :~: sh' - lem Refl = unsafeCoerceRefl -shsFromShX (SUnknown _ :$% _) = error "impossible" - -shxFromShS :: ShS sh -> IShX (MapJust sh) -shxFromShS ZSS = ZSX -shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh - shsHead :: ShS (n : sh) -> SNat n shsHead (ShS list) = listsHead list @@ -381,7 +365,7 @@ instance KnownShS '[] where knownShS = ZSS instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r -withKnownShS k = withDict @(KnownShS sh) k +withKnownShS = withDict @(KnownShS sh) shsKnownShS :: ShS sh -> Dict KnownShS sh shsKnownShS ZSS = Dict @@ -391,6 +375,17 @@ shsOrthotopeShape :: ShS sh -> Dict O.Shape sh shsOrthotopeShape ZSS = Dict shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict +-- | This function is a hack made possible by the 'KnownNat' inside 'ListS'. +-- This function may be removed in a future release. +shsFromListS :: ListS sh f -> ShS sh +shsFromListS ZS = ZSS +shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l + +-- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This +-- function may be removed in a future release. +shsFromIxS :: IxS sh i -> ShS sh +shsFromIxS (IxS l) = shsFromListS l + -- | Untyped: length is checked at runtime. instance KnownShS sh => IsList (ListS sh (Const i)) where diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs index 838e2b0..8a29aa5 100644 --- a/src/Data/Array/Nested/Trace.hs +++ b/src/Data/Array/Nested/Trace.hs @@ -37,10 +37,12 @@ module Data.Array.Nested.Trace ( ShS(..), KnownShS(..), Mixed, + ListX(ZX, (::%)), IxX(..), IIxX, - ShX(..), KnownShX(..), + ShX(..), KnownShX(..), IShX, StaticShX(..), SMayNat(..), + Conversion(..), Elt, PrimElt, @@ -54,7 +56,7 @@ module Data.Array.Nested.Trace ( Perm(..), IsPermutation, KnownPerm(..), - NumElt, FloatElt, + NumElt, IntElt, FloatElt, Rank, Product, Replicate, MapJust, @@ -67,4 +69,4 @@ import Data.Array.Nested.Trace.TH $(concat <$> mapM convertFun - ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromListOuter, 'rfromList1, 'rfromList1Prim, 'rtoListOuter, 'rtoList1, 'rfromListLinear, 'rfromListPrimLinear, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rcastToShaped, 'rtoMixed, 'rfromOrthotope, 'rtoOrthotope, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromListOuter, 'sfromList1, 'sfromList1Prim, 'stoListOuter, 'stoList1, 'sfromListLinear, 'sfromListPrimLinear, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoRanked, 'stoMixed, 'sfromOrthotope, 'stoOrthotope, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromListOuter, 'mfromList1, 'mfromList1Prim, 'mtoListOuter, 'mtoList1, 'mfromListLinear, 'mfromListPrimLinear, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mtoRanked, 'mcastToShaped]) + ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromList1, 'rfromListOuter, 'rfromListLinear, 'rfromListPrim, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromList1, 'sfromListOuter, 'sfromListLinear, 'sfromListPrim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromList1, 'mfromListOuter, 'mfromListLinear, 'mfromListPrim, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array]) diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs index 4172fa0..a43ae0c 100644 --- a/src/Data/Array/Nested/Types.hs +++ b/src/Data/Array/Nested/Types.hs @@ -6,7 +6,7 @@ {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeFamilyDependencies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} @@ -30,6 +30,7 @@ module Data.Array.Nested.Types ( Replicate, lemReplicateSucc, MapJust, + lemMapJustEmpty, lemMapJustCons, Head, Tail, Init, @@ -108,13 +109,20 @@ type family Replicate n a where Replicate 0 a = '[] Replicate n a = a : Replicate (n - 1) a -lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a -lemReplicateSucc = unsafeCoerceRefl +lemReplicateSucc :: forall a n proxy. + proxy n -> (a : Replicate n a) :~: Replicate (n + 1) a +lemReplicateSucc _ = unsafeCoerceRefl -type family MapJust l where +type family MapJust l = r | r -> l where MapJust '[] = '[] MapJust (x : xs) = Just x : MapJust xs +lemMapJustEmpty :: MapJust sh :~: '[] -> sh :~: '[] +lemMapJustEmpty Refl = unsafeCoerceRefl + +lemMapJustCons :: MapJust sh :~: Just n : sh' -> sh :~: n : Tail sh +lemMapJustCons Refl = unsafeCoerceRefl + type family Head l where Head (x : _) = x diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index dde06e3..92d9e13 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleInstances #-} @@ -30,11 +31,14 @@ import Data.Vector.Storable qualified as VS import Foreign.Storable (Storable) import GHC.Generics (Generic) import GHC.TypeLits +#if !MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) +import Unsafe.Coerce (unsafeCoerce) +#endif -import Data.Array.Mixed.Lemmas +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Permutation import Data.Array.Nested.Types -import Data.Array.Nested.Mixed.Shape import Data.Array.Strided.Orthotope @@ -217,7 +221,11 @@ transpose ssh perm (XArray arr) , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm , Refl <- lemRankDropLen ssh perm +#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) = XArray (S.transpose (permToList' perm) arr) +#else + = XArray (unsafeCoerce (S.transpose (permToList' perm) arr)) +#endif -- | The list argument gives indices into the original dimension list. -- @@ -243,7 +251,7 @@ transpose2 ssh1 ssh2 (XArray arr) , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) , Refl <- lemRankAppComm ssh1 ssh2 , let n1 = ssxLength ssh1 - = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) + = XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr) sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a sumFull _ (XArray arr) = diff --git a/src/GHC/TypeLits/Orphans.hs b/src/GHC/TypeLits/Orphans.hs new file mode 100644 index 0000000..42f7293 --- /dev/null +++ b/src/GHC/TypeLits/Orphans.hs @@ -0,0 +1,13 @@ +-- | Compatibility module adding some additional instances not yet defined in +-- base-4.18 with GHC 9.6. +{-# OPTIONS -Wno-orphans #-} +module GHC.TypeLits.Orphans where + +import GHC.TypeLits + + +instance Eq (SNat n) where + _ == _ = True + +instance Ord (SNat n) where + compare _ _ = EQ |