diff options
Diffstat (limited to 'src')
21 files changed, 2447 insertions, 1529 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 9801529..6d4ae78 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -3,15 +3,19 @@ module Data.Array.Nested ( -- * Ranked arrays Ranked(Ranked), - ListR(ZR, (:::)), IxR(.., ZIR, (:.:)), IIxR, ShR(.., ZSR, (:$:)), IShR, - rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rsumOuter1, rsumAllPrim, + rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rgeneratePrim, rsumOuter1Prim, rsumAllPrim, rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar, remptyArray, - rrerank, - rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1, - rfromListLinear, rfromListPrimLinear, rtoListLinear, + rrerankPrim, + rreplicate, rreplicatePrim, + rfromListOuter, rfromListOuterN, + rfromList1, rfromList1N, + rfromListLinear, + rfromList1Prim, rfromList1PrimN, + rfromListPrimLinear, + rtoListOuter, rtoList, rtoListLinear, rtoListPrim, rtoListPrimLinear, rslice, rrev1, rreshape, rflatten, riota, rminIndexPrim, rmaxIndexPrim, rdot1Inner, rdot, rnest, runNest, rzip, runzip, @@ -19,7 +23,7 @@ module Data.Array.Nested ( rlift, rlift2, -- ** Conversions rtoXArrayPrim, rfromXArrayPrim, - rcastToShaped, rtoMixed, rcastToMixed, + rtoMixed, rcastToMixed, rcastToShaped, rfromOrthotope, rtoOrthotope, -- ** Additional arithmetic operations -- @@ -28,16 +32,16 @@ module Data.Array.Nested ( -- * Shaped arrays Shaped(Shaped), - ListS(ZS, (::$)), IxS(.., ZIS, (:.$)), IIxS, ShS(.., ZSS, (:$$)), KnownShS(..), - sshape, srank, ssize, sindex, sindexPartial, sgenerate, ssumOuter1, ssumAllPrim, + sshape, srank, ssize, sindex, sindexPartial, sgenerate, sgeneratePrim, ssumOuter1Prim, ssumAllPrim, stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar, -- TODO: sconcat? What should its type be? semptyArray, - srerank, - sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1, - sfromListLinear, sfromListPrimLinear, stoListLinear, + srerankPrim, + sreplicate, sreplicatePrim, + sfromListOuter, sfromList1, sfromListLinear, sfromList1Prim, sfromListPrimLinear, + stoListOuter, stoList, stoListLinear, stoListPrim, stoListPrimLinear, sslice, srev1, sreshape, sflatten, siota, sminIndexPrim, smaxIndexPrim, sdot1Inner, sdot, snest, sunNest, szip, sunzip, @@ -45,7 +49,7 @@ module Data.Array.Nested ( slift, slift2, -- ** Conversions stoXArrayPrim, sfromXArrayPrim, - stoRanked, stoMixed, scastToMixed, + stoMixed, scastToMixed, stoRanked, sfromOrthotope, stoOrthotope, -- ** Additional arithmetic operations -- @@ -54,18 +58,22 @@ module Data.Array.Nested ( -- * Mixed arrays Mixed, - ListX(ZX, (::%)), IxX(.., ZIX, (:.%)), IIxX, - ShX(.., ZSX, (:$%)), KnownShX(..), IShX, + ShX(.., (:$%)), KnownShX(..), IShX, StaticShX(.., ZKX, (:!%)), SMayNat(..), - mshape, mrank, msize, mindex, mindexPartial, mgenerate, msumOuter1, msumAllPrim, + mshape, mrank, msize, mindex, mindexPartial, mgenerate, mgeneratePrim, msumOuter1Prim, msumAllPrim, mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar, memptyArray, - mrerank, - mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1, - mfromListLinear, mfromListPrimLinear, mtoListLinear, - mslice, mrev1, mreshape, mflatten, miota, + mrerankPrim, + mreplicate, mreplicatePrim, + mfromListOuter, mfromListOuterN, mfromListOuterSN, + mfromList1, mfromList1N, mfromList1SN, + mfromListLinear, + mfromList1Prim, mfromList1PrimN, mfromList1PrimSN, + mfromListPrimLinear, + mtoListOuter, mtoList, mtoListLinear, mtoListPrim, mtoListPrimLinear, + msliceN, msliceSN, mslice, mrev1, mreshape, mflatten, miota, mminIndexPrim, mmaxIndexPrim, mdot1Inner, mdot, mnest, munNest, mzip, munzip, -- ** Lifting orthotope operations to 'Mixed' arrays @@ -73,8 +81,8 @@ module Data.Array.Nested ( -- ** Conversions mtoXArrayPrim, mfromXArrayPrim, mcast, - mtoRanked, mcastToShaped, - castCastable, Castable(..), + mcastToShaped, mtoRanked, + convert, Conversion(..), -- ** Additional arithmetic operations -- -- $integralRealFloat @@ -91,7 +99,7 @@ module Data.Array.Nested ( Storable, SNat, pattern SNat, pattern SZ, pattern SS, - Perm(..), + Perm(..), PermR, IsPermutation, KnownPerm(..), NumElt, IntElt, FloatElt, @@ -102,23 +110,25 @@ module Data.Array.Nested ( import Prelude hiding (mappend, mconcat) -import Data.Array.Nested.Permutation -import Data.Array.Nested.Types +import Foreign.Storable +import GHC.TypeLits + import Data.Array.Nested.Convert import Data.Array.Nested.Mixed -import Data.Array.Nested.Ranked -import Data.Array.Nested.Shaped +import Data.Array.Nested.Mixed.ListX import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Ranked import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Shaped import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith -import Foreign.Storable -import GHC.TypeLits -- $integralRealFloat -- --- These functions separate top-level functions, and not exposed in instances --- for 'RealFloat' and 'Integral', because those classes include a variety of --- other functions that make no sense for arrays. +-- These functions are separate top-level functions, and not exposed in +-- instances for 'RealFloat' and 'Integral', because those classes include a +-- variety of other functions that make no sense for arrays. -- This problem already occurs with 'fromInteger', 'fromRational' and 'pi', but -- having 'Num', 'Fractional' and 'Floating' available is just too useful. diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index d5e6008..7619bdb 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -1,42 +1,302 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) {-# LANGUAGE TypeAbstractions #-} +#endif {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} module Data.Array.Nested.Convert ( - castCastable, - Castable(..), + -- * Shape\/index\/list casting functions + -- ** To ranked + ixrFromIxS, ixrFromIxS', ixrFromIxX, shrFromShS, shrFromShXAnyShape, shrFromShX, + ixrCast, shrCast, + -- ** To shaped + ixsFromIxR, ixsFromIxR', ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX, + ixsCast, + -- ** To mixed + ixxFromIxR, ixxFromIxS, ixxFromIxS', shxFromShR, shxFromShS, + ixxCast, shxCast, shxCast', - -- * Special cases + -- * Array conversions + convert, + Conversion(..), + + -- * Special cases of array conversions -- - -- | These functions can all be implemented using 'castCastable' in some way, + -- | These functions can all be implemented using 'convert' in some way, -- but some have fewer constraints. rtoMixed, rcastToMixed, rcastToShaped, stoMixed, scastToMixed, stoRanked, mcast, mcastToShaped, mtoRanked, - - -- * Additional index/shape casting functions - ixrFromIxS, shrFromShS, ) where import Control.Category +import Data.Coerce (coerce) import Data.Proxy import Data.Type.Equality +import GHC.TypeLits -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Types -import Data.Array.Nested.Internal.Lemmas +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.ListX import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Ranked.Base import Data.Array.Nested.Ranked.Shape import Data.Array.Nested.Shaped.Base import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types + +-- * Shape or index or list casting functions + +-- * To ranked + +ixrFromIxS :: forall sh i. IxS sh i -> IxR (Rank sh) i +ixrFromIxS + | Refl <- lemRankReplicate (Proxy @(Rank sh)) + , Refl <- lemRankMapJust (Proxy @sh) + = coerce + Prelude.. (coerceEqualRankListX :: ListX (MapJust sh) i -> ListX (Replicate (Rank sh) Nothing) i) + Prelude.. coerce + +ixrFromIxS' :: forall sh i. SNat (Rank sh) -> IxS sh i -> IxR (Rank sh) i +ixrFromIxS' _ + | Refl <- lemRankReplicate (Proxy @(Rank sh)) + , Refl <- lemRankMapJust (Proxy @sh) + = coerce + Prelude.. (coerceEqualRankListX :: ListX (MapJust sh) i -> ListX (Replicate (Rank sh) Nothing) i) + Prelude.. coerce + +-- ixrFromIxX re-exported + +shrFromShS :: ShS sh -> IShR (Rank sh) +shrFromShS ZSS = ZSR +shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh + +shrFromShXAnyShape :: IShX sh -> IShR (Rank sh) +shrFromShXAnyShape ZSX = ZSR +shrFromShXAnyShape (n :$% idx) = fromSMayNat' n :$: shrFromShXAnyShape idx + +shrFromShX :: IShX (Replicate n Nothing) -> IShR n +shrFromShX = coerce + +-- ixrCast re-exported +-- shrCast re-exported + +-- * To shaped + +ixsFromIxR :: forall sh i. IxR (Rank sh) i -> IxS sh i +ixsFromIxR + | Refl <- lemRankReplicate (Proxy @(Rank sh)) + , Refl <- lemRankMapJust (Proxy @sh) + = coerce + Prelude.. (coerceEqualRankListX :: ListX (Replicate (Rank sh) Nothing) i -> ListX (MapJust sh) i) + Prelude.. coerce + +ixsFromIxR' :: forall sh i. ShS sh -> IxR (Rank sh) i -> IxS sh i +ixsFromIxR' _ + | Refl <- lemRankReplicate (Proxy @(Rank sh)) + , Refl <- lemRankMapJust (Proxy @sh) + = coerce + Prelude.. (coerceEqualRankListX :: ListX (Replicate (Rank sh) Nothing) i -> ListX (MapJust sh) i) + Prelude.. coerce + +-- ixsFromIxX re-exported + +-- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to +-- the following, but less verbose: +-- +-- > ixsFromIxX' sh idx = ixsFromIxX sh (ixxCast (shxFromShS sh) idx) +ixsFromIxX' :: ShS sh -> IxX sh' i -> IxS sh i +ixsFromIxX' ZSS ZIX = ZIS +ixsFromIxX' (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX' sh idx +ixsFromIxX' _ _ = error "ixsFromIxX': index rank does not match shape rank" + +-- | Produce an existential 'ShS' from an 'IShR'. +withShsFromShR :: IShR n -> (forall sh. Rank sh ~ n => ShS sh -> r) -> r +withShsFromShR ZSR k = k ZSS +withShsFromShR (n :$: sh) k = + withShsFromShR sh $ \sh' -> + withSomeSNat (fromIntegral @Int @Integer n) $ \case + Just sn -> k (sn :$$ sh') + Nothing -> error $ "withShsFromShR: negative dimension size (" ++ show n ++ ")" + +shsFromShX :: IShX (MapJust sh) -> ShS sh +shsFromShX = coerce + +-- | Produce an existential 'ShS' from an 'IShX'. If you already know that +-- @sh'@ is @MapJust@ of something, use 'shsFromShX' instead. +withShsFromShX :: IShX sh' -> (forall sh. Rank sh ~ Rank sh' => ShS sh -> r) -> r +withShsFromShX ZSX k = k ZSS +withShsFromShX (SKnown sn :$% sh) k = + withShsFromShX sh $ \sh' -> + k (sn :$$ sh') +withShsFromShX (SUnknown n :$% sh) k = + withShsFromShX sh $ \sh' -> + withSomeSNat (fromIntegral @Int @Integer n) $ \case + Just sn -> k (sn :$$ sh') + Nothing -> error $ "withShsFromShX: negative SUnknown dimension size (" ++ show n ++ ")" + +-- If it ever matters for performance, this is unsafeCoercible. +shsFromSSX :: StaticShX (MapJust sh) -> ShS sh +shsFromSSX = shsFromShX Prelude.. shxFromSSX + +-- ixsCast re-exported + +-- * To mixed + +-- ixxFromIxR re-exported +-- ixxFromIxS re-exported + +ixxFromIxS' :: StaticShX sh' -> IxS sh i -> IxX sh' i +ixxFromIxS' sh' = ixxCast sh' Prelude.. ixxFromIxS + +shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i +shxFromShR = coerce + +shxFromShS :: ShS sh -> IShX (MapJust sh) +shxFromShS = coerce + +-- ixxCast re-exported +-- shxCast re-exported +-- shxCast' re-exported + + +-- * Array conversions + +-- | The constructors that perform runtime shape checking are marked with a +-- tick (@'@): 'ConvXS'' and 'ConvXX''. For the other constructors, the types +-- ensure that the shapes are already compatible. To convert between 'Ranked' +-- and 'Shaped', go via 'Mixed'. +-- +-- The guiding principle behind 'Conversion' is that it should represent the +-- array restructurings, or perhaps re-presentations, that do not change the +-- underlying 'Data.Array.XArray.XArray's. This leads to the inclusion +-- of some operations that do not look like simple conversions (casts) +-- at first glance, like 'ConvZip'. +-- +-- /Note/: Haddock gleefully renames type variables in constructors so that +-- they match the data type head as much as possible. See the source for a more +-- readable presentation of this data type. +data Conversion a b where + ConvId :: Conversion a a + ConvCmp :: Conversion b c -> Conversion a b -> Conversion a c + + ConvRX :: Conversion (Ranked n a) (Mixed (Replicate n Nothing) a) + ConvSX :: Conversion (Shaped sh a) (Mixed (MapJust sh) a) + + ConvXR :: Elt a + => Conversion (Mixed sh a) (Ranked (Rank sh) a) + ConvXS :: Conversion (Mixed (MapJust sh) a) (Shaped sh a) + ConvXS' :: (Rank sh ~ Rank sh', Elt a) + => ShS sh' + -> Conversion (Mixed sh a) (Shaped sh' a) + + ConvXX' :: (Rank sh ~ Rank sh', Elt a) + => StaticShX sh' + -> Conversion (Mixed sh a) (Mixed sh' a) + + ConvRR :: Conversion a b + -> Conversion (Ranked n a) (Ranked n b) + ConvSS :: Conversion a b + -> Conversion (Shaped sh a) (Shaped sh b) + ConvXX :: Conversion a b + -> Conversion (Mixed sh a) (Mixed sh b) + ConvT2 :: Conversion a a' + -> Conversion b b' + -> Conversion (a, b) (a', b') + + Conv0X :: Elt a + => Conversion a (Mixed '[] a) + ConvX0 :: Conversion (Mixed '[] a) a + + ConvNest :: Elt a => StaticShX sh + -> Conversion (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a)) + ConvUnnest :: Conversion (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a) + + ConvZip :: (Elt a, Elt b) + => Conversion (Mixed sh a, Mixed sh b) (Mixed sh (a, b)) + ConvUnzip :: (Elt a, Elt b) + => Conversion (Mixed sh (a, b)) (Mixed sh a, Mixed sh b) +deriving instance Show (Conversion a b) + +instance Category Conversion where + id = ConvId + (.) = ConvCmp + +convert :: (Elt a, Elt b) => Conversion a b -> a -> b +convert = \c x -> munScalar (go c (mscalar x)) + where + -- The 'esh' is the extension shape: the conversion happens under a whole + -- bunch of additional dimensions that it does not touch. These dimensions + -- are 'esh'. + -- The strategy is to unwind step-by-step to a large Mixed array, and to + -- perform the required checks and conversions when re-nesting back up. + go :: Conversion a b -> Mixed esh a -> Mixed esh b + go ConvId x = x + go (ConvCmp c1 c2) x = go c1 (go c2 x) + go ConvRX (M_Ranked x) = x + go ConvSX (M_Shaped x) = x + go (ConvXR @_ @sh) (M_Nest @esh esh x) + | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) + = let ssx' = ssxAppend (ssxFromShX esh) + (ssxReplicate (shxRank (shxDropSSX @esh @sh (ssxFromShX esh) (mshape x)))) + in M_Ranked (M_Nest esh (mcast ssx' x)) + go ConvXS (M_Nest esh x) = M_Shaped (M_Nest esh x) + go (ConvXS' @sh @sh' sh') (M_Nest @esh esh x) + | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh'))) + x)) + go (ConvXX' @sh @sh' ssx) (M_Nest @esh esh x) + | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x + go (ConvRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) + go (ConvSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) + go (ConvXX c) (M_Nest esh x) = M_Nest esh (go c x) + go (ConvT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2) + go Conv0X (x :: Mixed esh a) + | Refl <- lemAppNil @esh + = M_Nest (mshape x) x + go ConvX0 (M_Nest @esh _ x) + | Refl <- lemAppNil @esh + = x + go (ConvNest @_ @sh @sh' ssh) (M_Nest @esh esh x) + | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (ssxFromShX esh `ssxAppend` ssh) (mshape x)) x) + go (ConvUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x)) + | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh x + go ConvZip x = + -- no need to check that the two esh's are equal because they were zipped previously + let (M_Nest esh x1, M_Nest _ x2) = munzip x + in M_Nest esh (mzip x1 x2) + go ConvUnzip (M_Nest esh x) = + let (x1, x2) = munzip x + in mzip (M_Nest esh x1) (M_Nest esh x2) + + lemRankAppRankEq :: Rank sh ~ Rank sh' + => Proxy esh -> Proxy sh -> Proxy sh' + -> Rank (esh ++ sh) :~: Rank (esh ++ sh') + lemRankAppRankEq _ _ _ = unsafeCoerceRefl + + lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh + -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing) + lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl + + lemRankAppRankEqMapJust :: Rank sh ~ Rank sh' + => Proxy esh -> Proxy sh -> Proxy sh' + -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') + lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl +-- * Special cases of array conversions + mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a mcast ssh2 arr @@ -45,7 +305,7 @@ mcast ssh2 arr = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a -mtoRanked = castCastable (CastXR CastId) +mtoRanked = convert ConvXR rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a rtoMixed (Ranked arr) = arr @@ -59,7 +319,7 @@ rcastToMixed sshx rarr@(Ranked arr) mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') => ShS sh' -> Mixed sh a -> Shaped sh' a -mcastToShaped targetsh = castCastable (CastXS' targetsh CastId) +mcastToShaped targetsh = convert (ConvXS' targetsh) stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a stoMixed (Shaped arr) = arr @@ -82,91 +342,3 @@ rcastToShaped (Ranked arr) targetsh | Refl <- lemRankReplicate (shxRank (shxFromShS targetsh)) , Refl <- lemRankMapJust targetsh = mcastToShaped targetsh arr - -ixrFromIxS :: IxS sh i -> IxR (Rank sh) i -ixrFromIxS ZIS = ZIR -ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix - --- ixsFromIxR :: IIxR (Rank sh) -> IIxS sh --- ixsFromIxR = \ix -> go ix _ --- where --- go :: IIxR n -> (forall sh. KnownShS sh => IIxS sh -> r) -> r --- go ZIR k = k ZIS --- go (i :.: ix) k = go ix (i :.$) - -shrFromShS :: ShS sh -> IShR (Rank sh) -shrFromShS ZSS = ZSR -shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh - --- | The constructors that perform runtime shape checking are marked with a --- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure --- that the shapes are already compatible. To convert between 'Ranked' and --- 'Shaped', go via 'Mixed'. -data Castable a b where - CastId :: Castable a a - CastCmp :: Castable b c -> Castable a b -> Castable a c - - CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b) - CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b) - - CastXR :: Elt b - => Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b) - CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b) - CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh' - -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b) - - CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b) - CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b) - CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b) - - CastXX' :: (Rank sh ~ Rank sh', Elt b) => StaticShX sh' - -> Castable a b -> Castable (Mixed sh a) (Mixed sh' b) - -instance Category Castable where - id = CastId - (.) = CastCmp - -castCastable :: (Elt a, Elt b) => Castable a b -> a -> b -castCastable = \c x -> munScalar (go c (mscalar x)) - where - -- The 'esh' is the extension shape: the casting happens under a whole - -- bunch of additional dimensions that it does not touch. These dimensions - -- are 'esh'. - -- The strategy is to unwind step-by-step to a large Mixed array, and to - -- perform the required checks and castings when re-nesting back up. - go :: Castable a b -> Mixed esh a -> Mixed esh b - go CastId x = x - go (CastCmp c1 c2) x = go c1 (go c2 x) - go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x) - go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x) - go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) - | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) - = let x' = go c x - ssx' = ssxAppend (ssxFromShX esh) - (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShX esh)))) - in M_Ranked (M_Nest esh (mcast ssx' x')) - go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x)) - go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x) - | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') - = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh'))) - (go c x))) - go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) - go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) - go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) - go (CastXX' @sh @sh' ssx c) (M_Nest @esh esh x) - | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') - = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) (go c x) - - lemRankAppRankEq :: Rank sh ~ Rank sh' - => Proxy esh -> Proxy sh -> Proxy sh' - -> Rank (esh ++ sh) :~: Rank (esh ++ sh') - lemRankAppRankEq _ _ _ = unsafeCoerceRefl - - lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh - -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing) - lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl - - lemRankAppRankEqMapJust :: Rank sh ~ Rank sh' - => Proxy esh -> Proxy sh -> Proxy sh' - -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') - lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Nested/Internal/Lemmas.hs b/src/Data/Array/Nested/Internal/Lemmas.hs deleted file mode 100644 index b1589e0..0000000 --- a/src/Data/Array/Nested/Internal/Lemmas.hs +++ /dev/null @@ -1,59 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -module Data.Array.Nested.Internal.Lemmas where - -import Data.Proxy -import Data.Type.Equality -import GHC.TypeLits - -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Permutation -import Data.Array.Nested.Types -import Data.Array.Nested.Mixed.Shape -import Data.Array.Nested.Shaped.Shape - - -lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh -lemRankMapJust ZSS = Refl -lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl - -lemMapJustApp :: ShS sh1 -> Proxy sh2 - -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 -lemMapJustApp ZSS _ = Refl -lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl - -lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) -lemTakeLenMapJust PNil _ = Refl -lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl -lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty" - -lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) -lemDropLenMapJust PNil _ = Refl -lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl -lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty" - -lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) -lemIndexMapJust SZ (_ :$$ _) = Refl -lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) - | Refl <- lemIndexMapJust i sh - , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) - , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = Refl -lemIndexMapJust _ ZSS = error "Index of empty" - -lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) -lemPermuteMapJust PNil _ = Refl -lemPermuteMapJust (i `PCons` is) sh - | Refl <- lemPermuteMapJust is sh - , Refl <- lemIndexMapJust i sh - = Refl - -lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) -lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) - where - go :: ShS sh' -> StaticShX (MapJust sh') - go ZSS = ZKX - go (n :$$ sh) = SKnown n :!% go sh diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs index e6d970c..850f8ea 100644 --- a/src/Data/Array/Mixed/Lemmas.hs +++ b/src/Data/Array/Nested/Lemmas.hs @@ -6,7 +6,11 @@ {-# LANGUAGE TypeOperators #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed.Lemmas where +module Data.Array.Nested.Lemmas ( + module Data.Array.Nested.Lemmas, + lemReplicateSucc, lemMapJustEmpty, lemMapJustCons, lemMapJustHead, + lemRankMapJust +) where import Data.Proxy import Data.Type.Equality @@ -14,10 +18,11 @@ import GHC.TypeLits import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Permutation +import Data.Array.Nested.Shaped.Shape import Data.Array.Nested.Types --- * Lemmas +-- * Lemmas about numbers and lists -- ** Nat @@ -27,7 +32,6 @@ lemLeqSuccSucc _ _ = unsafeCoerceRefl lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True lemLeqPlus _ _ _ = Refl - -- ** Append lemAppNil :: l ++ '[] :~: l @@ -39,42 +43,40 @@ lemAppAssoc _ _ _ = unsafeCoerceRefl lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l lemAppLeft _ Refl = Refl - --- ** Rank - -lemRankApp :: forall sh1 sh2. - StaticShX sh1 -> StaticShX sh2 - -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 -lemRankApp ZKX _ = Refl -lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2 - = lem (Proxy @(Rank sh1T)) Proxy Proxy $ - sym (lemRankApp ssh1 ssh2) - where - lem :: proxy a -> proxy b -> proxy c - -> (a + b :~: c) - -> c + 1 :~: (a + 1 + b) - lem _ _ _ Refl = Refl - -lemRankAppComm :: proxy sh1 -> proxy sh2 - -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) -lemRankAppComm _ _ = unsafeCoerceRefl - -lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n -lemRankReplicate _ = unsafeCoerceRefl - - --- ** Various type families +-- ** Simple type families lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a +{- for now, the plugins can't derive a type for this code, see + https://github.com/clash-lang/ghc-typelits-natnormalise/pull/98#issuecomment-3332842214 lemReplicatePlusApp sn _ _ = go sn where go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a go SZ = Refl go (SS (n :: SNat n'm1)) - | Refl <- lemReplicateSucc @a @n'm1 + | Refl <- lemReplicateSucc @a n , Refl <- go n - = sym (lemReplicateSucc @a @(n'm1 + m)) + = sym (lemReplicateSucc @a (SNat @(n'm1 + m))) +-} +lemReplicatePlusApp _ _ _ = unsafeCoerceRefl + +lemReplicateEmpty :: proxy n -> Replicate n (Nothing @Nat) :~: '[] -> n :~: 0 +lemReplicateEmpty _ Refl = unsafeCoerceRefl + +-- TODO: make less ad-hoc and rename the following few: +lemReplicateCons :: proxy sh -> proxy' n1 -> Nothing : sh :~: Replicate n1 Nothing -> n1 :~: Rank sh + 1 +lemReplicateCons _ _ Refl = unsafeCoerceRefl + +lemReplicateCons2 :: proxy sh -> proxy' n1 -> Nothing : sh :~: Replicate n1 Nothing -> sh :~: Replicate (Rank sh) Nothing +lemReplicateCons2 _ _ Refl = unsafeCoerceRefl + +lemReplicateSucc2 :: forall n1 n proxy. + proxy n1 -> n + 1 :~: n1 -> Nothing @Nat : Replicate n Nothing :~: Replicate n1 Nothing +lemReplicateSucc2 _ _ = unsafeCoerceRefl + +-- TODO: simplify, but GHC doesn't consistently use congruence nor transitivity +lemReplicateHead :: proxy x -> proxy' sh -> proxy'' t -> proxy''' n -> x : sh :~: Replicate n t -> x :~: t +lemReplicateHead _ _ _ _ Refl = unsafeCoerceRefl lemDropLenApp :: Rank l1 <= Rank l2 => Proxy l1 -> Proxy l2 -> Proxy rest @@ -107,6 +109,8 @@ lemKnownNatRankSSX ZKX = Dict lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict +-- * Lemmas about shapes + -- ** Known shapes lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing) @@ -116,3 +120,65 @@ lemKnownShX :: StaticShX sh -> Dict KnownShX sh lemKnownShX ZKX = Dict lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict + +lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) +lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) + where + go :: ShS sh' -> StaticShX (MapJust sh') + go ZSS = ZKX + go (n :$$ sh) = SKnown n :!% go sh + +-- ** Rank + +lemRankApp :: forall sh1 sh2. + StaticShX sh1 -> StaticShX sh2 + -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 +lemRankApp ZKX _ = Refl +lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2 + = lem (Proxy @(Rank sh1T)) Proxy Proxy $ + sym (lemRankApp ssh1 ssh2) + where + lem :: proxy a -> proxy b -> proxy c + -> (a + b :~: c) + -> c + 1 :~: (a + 1 + b) + lem _ _ _ Refl = Refl + +lemRankAppComm :: proxy sh1 -> proxy sh2 + -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) +lemRankAppComm _ _ = unsafeCoerceRefl + +lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate _ = unsafeCoerceRefl + +-- ** Related to MapJust and/or Permutation + +lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) +lemTakeLenMapJust PNil _ = Refl +lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl +lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty" + +lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) +lemDropLenMapJust PNil _ = Refl +lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl +lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty" + +lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) +lemIndexMapJust SZ (_ :$$ _) = Refl +lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) + | Refl <- lemIndexMapJust i sh + , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) + , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = Refl +lemIndexMapJust _ ZSS = error "Index of empty" + +lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) +lemPermuteMapJust PNil _ = Refl +lemPermuteMapJust (i `PCons` is) sh + | Refl <- lemPermuteMapJust is sh + , Refl <- lemIndexMapJust i sh + = Refl + +lemMapJustApp :: ShS sh1 -> Proxy sh2 + -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 +lemMapJustApp ZSS _ = Refl +lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 54bd5f2..7371c4b 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -7,6 +7,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -22,12 +23,14 @@ module Data.Array.Nested.Mixed where import Prelude hiding (mconcat) import Control.DeepSeq (NFData(..)) -import Control.Monad (forM_, when) +import Control.Monad (foldM_, forM_, when) import Control.Monad.ST +import Data.Array.Internal qualified as OI +import Data.Array.Internal.RankedG qualified as ORG +import Data.Array.Internal.RankedS qualified as ORS import Data.Array.RankedS qualified as S import Data.Bifunctor (bimap) import Data.Coerce -import Data.Foldable (toList) import Data.Int import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty(..)) @@ -38,22 +41,22 @@ import Data.Vector.Storable qualified as VS import Data.Vector.Storable.Mutable qualified as VSM import Foreign.C.Types (CInt) import Foreign.Storable (Storable) +import Foreign.Storable qualified as Storable import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) import GHC.TypeLits -import Data.Array.Mixed.Lemmas +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Permutation import Data.Array.Nested.Types +import Data.Array.Strided.Orthotope import Data.Array.XArray (XArray(..)) import Data.Array.XArray qualified as X -import Data.Array.Nested.Mixed.Shape -import Data.Array.Strided.Orthotope import Data.Bag -- TODO: --- sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a -- rminIndex1 :: Ranked (n + 1) a -> Ranked n Int -- gather/scatter-like things (most generally, the higher-order variants: accelerate's backpermute/permute) -- After benchmarking: matmul and matvec @@ -91,6 +94,9 @@ import Data.Bag -- Unfortunately, the setup of the library requires us to list these primitive -- element types multiple times; to aid in extending the list, all these lists -- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. +-- +-- NOTE: if you add PRIMITIVE types, be sure to also add NumElt and IntElt +-- instances for them! -- | Wrapper type used as a tag to attach instances on. The instances on arrays @@ -118,6 +124,8 @@ instance PrimElt Bool instance PrimElt Int instance PrimElt Int64 instance PrimElt Int32 +instance PrimElt Int16 +instance PrimElt Int8 instance PrimElt CInt instance PrimElt Float instance PrimElt Double @@ -154,6 +162,8 @@ newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic ANDSHOW) newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic ANDSHOW) newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Int16 = M_Int16 (Mixed sh (Primitive Int16)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Int8 = M_Int8 (Mixed sh (Primitive Int8)) deriving (Eq, Ord, Generic ANDSHOW) newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic ANDSHOW) newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic ANDSHOW) newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic ANDSHOW) @@ -190,6 +200,8 @@ newtype instance MixedVecs s sh Bool = MV_Bool (VS.MVector s Bool) newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64) newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32) +newtype instance MixedVecs s sh Int16 = MV_Int16 (VS.MVector s Int16) +newtype instance MixedVecs s sh Int8 = MV_Int8 (VS.MVector s Int8) newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt) newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float) @@ -227,11 +239,13 @@ instance Elt a => NFData (Mixed sh a) where rnf = mrnf +{-# INLINE mliftNumElt1 #-} mliftNumElt1 :: (PrimElt a, PrimElt b) => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b) -> Mixed sh a -> Mixed sh b mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr)) +{-# INLINE mliftNumElt2 #-} mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c) => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c) -> Mixed sh a -> Mixed sh b -> Mixed sh c @@ -247,15 +261,15 @@ instance (NumElt a, PrimElt a) => Num (Mixed sh a) where abs = mliftNumElt1 (liftO1 . numEltAbs) signum = mliftNumElt1 (liftO1 . numEltSignum) -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS - fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal" + fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicatePrim" instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where - fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" + fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicatePrim" recip = mliftNumElt1 (liftO1 . floatEltRecip) (/) = mliftNumElt2 (liftO2 . floatEltDiv) instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" + pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicatePrim" exp = mliftNumElt1 (liftO1 . floatEltExp) log = mliftNumElt1 (liftO1 . floatEltLog) sqrt = mliftNumElt1 (liftO1 . floatEltSqrt) @@ -287,9 +301,10 @@ mremArray = mliftNumElt2 (liftO2 . intEltRem) matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2) --- | Allowable element types in a mixed array, and by extension in a 'Ranked' or --- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' --- a@; see the documentation for 'Primitive' for more details. +-- | Allowable element types in a mixed array, and by extension +-- in a 'Data.Array.Nested.Ranked.Ranked' or 'Data.Array.Nested.Shaped.Shaped' +-- array. Note the polymorphic instance for 'Elt' of @'Primitive' a@; +-- see the documentation for 'Primitive' for more details. class Elt a where -- ====== PUBLIC METHODS ====== -- @@ -298,15 +313,9 @@ class Elt a where mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a mscalar :: a -> Mixed '[] a - -- | All arrays in the list, even subarrays inside @a@, must have the same - -- shape; if they do not, a runtime error will be thrown. See the - -- documentation of 'mgenerate' for more information about this restriction. - -- Furthermore, the length of the list must correspond with @n@: if @n@ is - -- @Just m@ and @m@ does not equal the length of the list, a runtime error is - -- thrown. - -- - -- Consider also 'mfromListPrim', which can avoid intermediate arrays. - mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a + -- | See 'mfromListOuter'. If the list does not have the given length, a + -- runtime error is thrown. + mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] @@ -340,8 +349,8 @@ class Elt a where mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a - -- | All arrays in the input must have equal shapes, including subarrays - -- inside their elements. + -- | All arrays in the input must have equal shapes (except possibly + -- for the outermost dimension) including subarrays inside their elements. mconcat :: NonEmpty (Mixed (Nothing : sh) a) -> Mixed (Nothing : sh) a mrnf :: Mixed sh a -> () @@ -351,11 +360,14 @@ class Elt a where -- | Tree giving the shape of every array component. type ShapeTree a + -- | Produces an internal representation of a tree of shapes of (potentially) + -- nested arrays. If the argument is an array, it requires that the array + -- is not empty (otherwise, its' guaranteed to crash early, if non-trivial). mshapeTree :: a -> ShapeTree a mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool - mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool + mshapeTreeIsEmpty :: Proxy a -> ShapeTree a -> Bool mshowShapeTree :: Proxy a -> ShapeTree a -> String @@ -363,26 +375,28 @@ class Elt a where -- this mixed array. marrayStrides :: Mixed sh a -> Bag [Int] - -- | Given the shape of this array, an index and a value, write the value at + -- | Given a linear index and a value, write the value at -- that index in the vectors. - mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () + mvecsWriteLinear :: Int -> a -> MixedVecs s sh a -> ST s () - -- | Given the shape of this array, an index and a value, write the value at + -- | Given a linear index and a value, write the value at -- that index in the vectors. - mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + mvecsWritePartialLinear :: Proxy sh -> Int -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () -- | Given the shape of this array, finalise the vectors into 'XArray's. mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) + -- | 'mvecsFreeze' but without copying the mutable vectors before freezing + -- them. If the mutable vectors are mutated after this function, referential + -- transparency is broken. + mvecsUnsafeFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) -- | Element types for which we have evidence of the (static part of the) shape -- in a type class constraint. Compare the instance contexts of the instances -- of this class with those of 'Elt': some instances have an additional -- "known-shape" constraint. -- --- This class is (currently) only required for 'mgenerate', --- 'Data.Array.Nested.Ranked.rgenerate' and --- 'Data.Array.Nested.Shaped.sgenerate'. +-- This class is (currently) only required for `memptyArray` and 'mgenerate'. class Elt a => KnownElt a where -- | Create an empty array. The given shape must have size zero; this may or may not be checked. memptyArrayUnsafe :: IShX sh -> Mixed sh a @@ -391,20 +405,30 @@ class Elt a => KnownElt a where -- this vector and an example for the contents. mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) + -- | Create initialised vectors for this array type, given the shape of + -- this vector and the chosen element. + mvecsReplicate :: IShX sh -> a -> ST s (MixedVecs s sh a) + mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) -- Arrays of scalars are basically just arrays of scalars. instance Storable a => Elt (Primitive a) where + -- Somehow, INLINE here can increase allocation with GHC 9.14.1. + -- Maybe that happens in void instances such as @Primitive ()@. + {-# INLINEABLE mshape #-} mshape (M_Primitive sh _) = sh + {-# INLINEABLE mindex #-} mindex (M_Primitive _ a) i = Primitive (X.index a i) - mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) + {-# INLINEABLE mindexPartial #-} + mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i) mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) - mfromListOuter l@(arr1 :| _) = - let sh = SUnknown (length l) :$% mshape arr1 - in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l))) + mfromListOuterSN sn l@(arr1 :| _) = + let sh = mshape arr1 + in M_Primitive (SKnown sn :$% sh) (X.fromListOuterSN sn sh ((\(M_Primitive _ a) -> a) <$> l)) mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) @@ -415,6 +439,7 @@ instance Storable a => Elt (Primitive a) where , let result = f ZKX a = M_Primitive (X.shape ssh2 result) result + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) @@ -426,6 +451,7 @@ instance Storable a => Elt (Primitive a) where , let result = f ZKX a b = M_Primitive (X.shape ssh3 result) result + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) @@ -440,7 +466,7 @@ instance Storable a => Elt (Primitive a) where => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) = let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' - sh2 = shxCast' sh1 ssh2 + sh2 = shxCast' ssh2 sh1 in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr) mtranspose perm (M_Primitive sh arr) = @@ -457,27 +483,33 @@ instance Storable a => Elt (Primitive a) where type ShapeTree (Primitive a) = () mshapeTree _ = () mshapeTreeEq _ () () = True - mshapeTreeEmpty _ () = False + mshapeTreeIsEmpty _ () = False mshowShapeTree _ () = "()" marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x + mvecsWriteLinear i (Primitive x) (MV_Primitive v) = VSM.write v i x - -- TODO: this use of toVector is suboptimal - mvecsWritePartial + -- TODO: this use of toVectorListT is suboptimal + mvecsWritePartialLinear :: forall sh' sh s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do + Proxy sh -> Int -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + mvecsWritePartialLinear _ i (M_Primitive sh' arr@(XArray (ORS.A (ORG.A sht t)))) (MV_Primitive v) = do let arrsh = X.shape (ssxFromShX sh') arr - offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) - VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) + offset = i * shxSize arrsh + f off el = do + VS.copy (VSM.slice off (VS.length el) v) el + return $! off + VS.length el + foldM_ f offset (OI.toVectorListT sht t) mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v + mvecsUnsafeFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.unsafeFreeze v -- [PRIMITIVE ELEMENT TYPES LIST] deriving via Primitive Bool instance Elt Bool deriving via Primitive Int instance Elt Int deriving via Primitive Int64 instance Elt Int64 deriving via Primitive Int32 instance Elt Int32 +deriving via Primitive Int16 instance Elt Int16 +deriving via Primitive Int8 instance Elt Int8 deriving via Primitive CInt instance Elt CInt deriving via Primitive Double instance Elt Double deriving via Primitive Float instance Elt Float @@ -486,6 +518,7 @@ deriving via Primitive () instance Elt () instance Storable a => KnownElt (Primitive a) where memptyArrayUnsafe sh = M_Primitive sh (X.empty sh) mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) + mvecsReplicate sh (Primitive a) = MV_Primitive <$> VSM.replicate (shxSize sh) a mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 -- [PRIMITIVE ELEMENT TYPES LIST] @@ -493,6 +526,8 @@ deriving via Primitive Bool instance KnownElt Bool deriving via Primitive Int instance KnownElt Int deriving via Primitive Int64 instance KnownElt Int64 deriving via Primitive Int32 instance KnownElt Int32 +deriving via Primitive Int16 instance KnownElt Int16 +deriving via Primitive Int8 instance KnownElt Int8 deriving via Primitive CInt instance KnownElt CInt deriving via Primitive Double instance KnownElt Double deriving via Primitive Float instance KnownElt Float @@ -500,16 +535,22 @@ deriving via Primitive () instance KnownElt () -- Arrays of pairs are pairs of arrays. instance (Elt a, Elt b) => Elt (a, b) where + {-# INLINEABLE mshape #-} mshape (M_Tup2 a _) = mshape a + {-# INLINEABLE mindex #-} mindex (M_Tup2 a b) i = (mindex a i, mindex b i) + {-# INLINEABLE mindexPartial #-} mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) - mfromListOuter l = - M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) - (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) + mfromListOuterSN sn l = + M_Tup2 (mfromListOuterSN sn ((\(M_Tup2 x _) -> x) <$> l)) + (mfromListOuterSN sn ((\(M_Tup2 _ y) -> y) <$> l)) mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) + {-# INLINE mlift #-} mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) + {-# INLINE mlift2 #-} mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) + {-# INLINE mliftL #-} mliftL ssh2 f = let unzipT2l [] = ([], []) unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) @@ -531,20 +572,22 @@ instance (Elt a, Elt b) => Elt (a, b) where type ShapeTree (a, b) = (ShapeTree a, ShapeTree b) mshapeTree (x, y) = (mshapeTree x, mshapeTree y) mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' - mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 + mshapeTreeIsEmpty _ (t1, t2) = mshapeTreeIsEmpty (Proxy @a) t1 && mshapeTreeIsEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b - mvecsWrite sh i (x, y) (MV_Tup2 a b) = do - mvecsWrite sh i x a - mvecsWrite sh i y b - mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do - mvecsWritePartial sh i x a - mvecsWritePartial sh i y b + mvecsWriteLinear i (x, y) (MV_Tup2 a b) = do + mvecsWriteLinear i x a + mvecsWriteLinear i y b + mvecsWritePartialLinear proxy i (M_Tup2 x y) (MV_Tup2 a b) = do + mvecsWritePartialLinear proxy i x a + mvecsWritePartialLinear proxy i y b mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b + mvecsUnsafeFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsUnsafeFreeze sh a <*> mvecsUnsafeFreeze sh b instance (KnownElt a, KnownElt b) => KnownElt (a, b) where memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh) mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y + mvecsReplicate sh (x, y) = MV_Tup2 <$> mvecsReplicate sh x <*> mvecsReplicate sh y mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) -- Arrays of arrays are just arrays, but with more dimensions. @@ -552,38 +595,41 @@ instance Elt a => Elt (Mixed sh' a) where -- TODO: this is quadratic in the nesting depth because it repeatedly -- truncates the shape vector to one a little shorter. Fix with a -- moverlongShape method, a prefix of which is mshape. + {-# INLINEABLE mshape #-} mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh mshape (M_Nest sh arr) - = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr)) + = shxTakeSh (Proxy @sh') sh (mshape arr) + {-# INLINEABLE mindex #-} mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a - mindex (M_Nest _ arr) i = mindexPartial arr i + mindex (M_Nest _ arr) = mindexPartial arr + {-# INLINEABLE mindexPartial #-} mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) mindexPartial (M_Nest sh arr) i | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + = M_Nest (shxDropIx i sh) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) mscalar = M_Nest ZSX - mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) - mfromListOuter l@(arr :| _) = - M_Nest (SUnknown (length l) :$% mshape arr) - (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) + mfromListOuterSN sn l@(arr :| _) = + M_Nest (SKnown sn :$% mshape arr) + (mfromListOuterSN sn ((\(M_Nest _ a) -> a) <$> l)) mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) mlift ssh2 f (M_Nest sh1 arr) = let result = mlift (ssxAppend ssh2 ssh') f' arr - (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) + sh2 = shxTakeSSX (Proxy @sh') ssh2 (mshape result) in M_Nest sh2 result where - ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr))) + ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr)) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b f' sshT @@ -591,16 +637,17 @@ instance Elt a => Elt (Mixed sh' a) where , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) = f (ssxAppend ssh' sshT) + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 - (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) + sh3 = shxTakeSSX (Proxy @sh') ssh3 (mshape result) in M_Nest sh3 result where - ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1))) + ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr1)) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b f' sshT @@ -609,16 +656,17 @@ instance Elt a => Elt (Mixed sh' a) where , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) = f (ssxAppend ssh' sshT) + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b)) -> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a)) mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) = let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l) - (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result)) + sh2 = shxTakeSSX (Proxy @sh') ssh2 (mshape (NE.head result)) in fmap (M_Nest sh2) result where - ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1))) + ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr1)) f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b) f' sshT @@ -632,14 +680,14 @@ instance Elt a => Elt (Mixed sh' a) where | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T - sh2 = shxCast' sh1 ssh2 + sh2 = shxCast' ssh2 sh1 in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr) mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) => Perm is -> Mixed sh (Mixed sh' a) -> Mixed (PermutePrefix is sh) (Mixed sh' a) mtranspose perm (M_Nest sh arr) - | let sh' = shxDropSh @sh @sh' (mshape arr) sh + | let sh' = shxDropSh @sh @sh' sh (mshape arr) , Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh') , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') @@ -651,34 +699,39 @@ instance Elt a => Elt (Mixed sh' a) where mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) mconcat l@(M_Nest sh1 _ :| _) = let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l) - in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape result))) result + in M_Nest (shxTakeSh (Proxy @sh') sh1 (mshape result)) result mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a) + -- This requires that @arr@ is not empty. mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShX (mshape arr))))) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + -- the array is empty if either there are no subarrays, or the subarrays themselves are empty + mshapeTreeIsEmpty _ (sh, t) = shxSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" marrayStrides (M_Nest _ arr) = marrayStrides arr - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs + mvecsWriteLinear :: forall s sh. Int -> Mixed sh' a -> MixedVecs s sh (Mixed sh' a) -> ST s () + mvecsWriteLinear idx val (MV_Nest _ vecs) = mvecsWritePartialLinear (Proxy @sh) idx val vecs - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) - -> ST s () - mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) + mvecsWritePartialLinear + :: forall sh1 sh2 s. + Proxy sh1 -> Int -> Mixed sh2 (Mixed sh' a) + -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) + -> ST s () + mvecsWritePartialLinear proxy idx (M_Nest _ arr) (MV_Nest _ vecs) | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs + = mvecsWritePartialLinear proxy idx arr vecs mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs + mvecsUnsafeFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsUnsafeFreeze (shxAppend sh sh') vecs instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh')))) @@ -689,10 +742,30 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where where sh' = mshape example + mvecsReplicate sh example = do + vecs <- mvecsUnsafeNew sh example + forM_ [0 .. shxSize sh - 1] $ \idx -> mvecsWriteLinear idx example vecs + -- this is a slow case, but the alternative, mvecsUnsafeNew with manual + -- writing in a loop, leads to every case being as slow + return vecs + mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) -memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a +-- | Given the shape of this array, an index and a value, write the value at +-- that index in the vectors. +{-# INLINE mvecsWrite #-} +mvecsWrite :: Elt a => IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () +mvecsWrite sh idx = mvecsWriteLinear (ixxToLinear sh idx) + +-- | Given the shape of this array, an index and a value, write the value at +-- that index in the vectors. +{-# INLINE mvecsWritePartial #-} +mvecsWritePartial :: forall sh sh' s a. Elt a => IShX sh -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () +mvecsWritePartial sh idx = mvecsWritePartialLinear (Proxy @sh) (ixxToLinear sh idx) + +-- TODO: should we provide a function that's just memptyArrayUnsafe but with a size==0 check? That may save someone a transpose somewhere +memptyArray :: forall sh a. KnownElt a => IShX sh -> Mixed (Just 0 : sh) a memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh) mrank :: Elt a => Mixed sh a -> SNat (Rank sh) @@ -719,38 +792,56 @@ msize = shxSize . mshape -- the entire hierarchy (after distributing out tuples) must be a rectangular -- array. The type of 'mgenerate' allows this requirement to be broken very -- easily, hence the runtime check. +-- +-- If your element type @a@ is a scalar, use the faster 'mgeneratePrim'. mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a mgenerate sh f = case shxEnum sh of [] -> memptyArrayUnsafe sh firstidx : restidxs -> let firstelem = f (ixxZero' sh) shapetree = mshapeTree firstelem - in if mshapeTreeEmpty (Proxy @a) shapetree + in if mshapeTreeIsEmpty (Proxy @a) shapetree then memptyArrayUnsafe sh else runST $ do vecs <- mvecsUnsafeNew sh firstelem mvecsWrite sh firstidx firstelem vecs - -- TODO: This is likely fine if @a@ is big, but if @a@ is a - -- scalar this array copying inefficient. Should improve this. forM_ restidxs $ \idx -> do let val = f idx when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ error "Data.Array.Nested mgenerate: generated values do not have equal shapes" mvecsWrite sh idx val vecs - mvecsFreeze sh vecs + mvecsUnsafeFreeze sh vecs + +-- | An optimized special case of 'mgenerate', where the function results +-- are of a primitive type and so there's not need to check that all shapes +-- are equal. This is also generalized to an arbitrary @Num@ index type +-- compared to @mgenerate@. +{-# INLINE mgeneratePrim #-} +mgeneratePrim :: forall sh a i. (PrimElt a, Num i) + => IShX sh -> (IxX sh i -> a) -> Mixed sh a +mgeneratePrim sh f = + let g i = f (ixxFromLinear sh i) + in mfromVector sh $ VS.generate (shxSize sh) g -msumOuter1P :: forall sh n a. (Storable a, NumElt a) - => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) -msumOuter1P (M_Primitive (n :$% sh) arr) = +{-# INLINEABLE msumOuter1PrimP #-} +msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a) + => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) +msumOuter1PrimP (M_Primitive (n :$% sh) arr) = let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr) -msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) - => Mixed (n : sh) a -> Mixed sh a -msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive +{-# INLINEABLE msumOuter1Prim #-} +msumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a) + => Mixed (n : sh) a -> Mixed sh a +msumOuter1Prim = fromPrimitive . msumOuter1PrimP @sh @n @a . toPrimitive +{-# INLINEABLE msumAllPrimP #-} +msumAllPrimP :: (Storable a, NumElt a) => Mixed sh (Primitive a) -> a +msumAllPrimP (M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr + +{-# INLINEABLE msumAllPrim #-} msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a -msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr +msumAllPrim arr = msumAllPrimP (toPrimitive arr) mappend :: forall n m sh a. Elt a => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a @@ -759,7 +850,7 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 sn :$% sh = mshape arr1 sm :$% _ = mshape arr2 ssh = ssxFromShX sh - snm :: SMayNat () SNat (AddMaybe n m) + snm :: SMayNat () (AddMaybe n m) snm = case (sn, sm) of (SUnknown{}, _) -> SUnknown () (SKnown{}, SUnknown{}) -> SUnknown () @@ -769,82 +860,176 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b f ssh' = X.append (ssxAppend ssh ssh') +{-# INLINEABLE mfromVectorP #-} mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) +{-# INLINEABLE mfromVector #-} mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a mfromVector sh v = fromPrimitive (mfromVectorP sh v) +{-# INLINEABLE mtoVectorP #-} mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a mtoVectorP (M_Primitive _ v) = X.toVector v +{-# INLINEABLE mtoVector #-} mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a mtoVector arr = mtoVectorP (toPrimitive arr) +-- | All arrays in the list, even subarrays inside @a@, must have the same +-- shape; if they do not, a runtime error will be thrown. See the +-- documentation of 'mgenerate' for more information about this restriction. +-- +-- Because the length of the 'NonEmpty' list is unknown, its spine must be +-- materialised in memory in order to compute its length. If its length is +-- already known, use 'mfromListOuterN' or 'mfromListOuterSN' to be able to +-- stream the list. +-- +-- If your array is 1-dimensional and contains scalars, use 'mfromList1Prim'. +mfromListOuter :: Elt a => NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a +mfromListOuter l = mfromListOuterN (length l) l + +-- | See 'mfromListOuter'. If the list does not have the given length, a +-- runtime error is thrown. 'mfromList1PrimN' is faster if applicable. +mfromListOuterN :: Elt a => Int -> NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a +mfromListOuterN n l = + withSomeSNat (fromIntegral n) $ \case + Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromListOuterSN sn l) + Nothing -> error $ "mfromListOuterN: length negative (" ++ show n ++ ")" + +-- | Because the length of the 'NonEmpty' list is unknown, its spine must be +-- materialised in memory in order to compute its length. If its length is +-- already known, use 'mfromList1N' or 'mfromList1SN' to be able to stream the +-- list. +-- +-- If the elements are scalars, 'mfromList1Prim' is faster. mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a -mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise? +mfromList1 = mfromListOuter . fmap mscalar + +-- | If the elements are scalars, 'mfromList1PrimN' is faster. A runtime error +-- is thrown if the list length does not match the given length. +mfromList1N :: Elt a => Int -> NonEmpty a -> Mixed '[Nothing] a +mfromList1N n = mfromListOuterN n . fmap mscalar +-- | If the elements are scalars, 'mfromList1PrimSN' is faster. A runtime error +-- is thrown if the list length does not match the given length. +mfromList1SN :: Elt a => SNat n -> NonEmpty a -> Mixed '[Just n] a +mfromList1SN sn = mfromListOuterSN sn . fmap mscalar + +-- This forall is there so that a simple type application can constrain the +-- shape, in case the user wants to use OverloadedLists for the shape. +-- | If the elements are scalars, 'mfromListPrimLinear' is faster. +mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a +mfromListLinear sh l = mreshape sh (mfromList1N (shxSize sh) l) + +-- | Because the length of the list is unknown, its spine must be materialised +-- in memory in order to compute its length. If its length is already known, +-- use 'mfromList1PrimN' or 'mfromList1PrimSN' to be able to stream the list. mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a mfromList1Prim l = let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l + xarr = X.fromList1 l in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr -mtoList1 :: Elt a => Mixed '[n] a -> [a] -mtoList1 = map munScalar . mtoListOuter +mfromList1PrimN :: PrimElt a => Int -> [a] -> Mixed '[Nothing] a +mfromList1PrimN n l = + withSomeSNat (fromIntegral n) $ \case + Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromList1PrimSN sn l) + Nothing -> error $ "mfromList1PrimN: length negative (" ++ show n ++ ")" -mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromListPrim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr +mfromList1PrimSN :: forall n a. PrimElt a => SNat n -> [a] -> Mixed '[Just n] a +mfromList1PrimSN sn l = + let sh = SKnown sn :$% ZSX + in fromPrimitive $ M_Primitive sh + $ if Storable.sizeOf (undefined :: a) > 0 + then X.fromList1SN sn l + else case l of -- don't force the list if all elements are the same + a0 : _ -> X.replicateScal sh a0 + [] -> X.fromList1SN sn l -mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a +mfromListPrimLinear :: forall sh a. PrimElt a => IShX sh -> [a] -> Mixed sh a mfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + let M_Primitive _ xarr = toPrimitive (mfromList1PrimN (shxSize sh) l) in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) --- This forall is there so that a simple type application can constrain the --- shape, in case the user wants to use OverloadedLists for the shape. -mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a -mfromListLinear sh l = mreshape sh (mfromList1 l) +mtoList :: Elt a => Mixed '[n] a -> [a] +mtoList = map munScalar . mtoListOuter mtoListLinear :: Elt a => Mixed sh a -> [a] -mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr)) -- TODO: optimise +mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr)) + +mtoListPrim :: PrimElt a => Mixed '[n] a -> [a] +mtoListPrim (toPrimitive -> M_Primitive _ arr) = X.toListLinear arr + +mtoListPrimLinear :: PrimElt a => Mixed sh a -> [a] +mtoListPrimLinear (toPrimitive -> M_Primitive _ arr) = X.toListLinear arr munScalar :: Elt a => Mixed '[] a -> a munScalar arr = mindex arr ZIX mnest :: forall sh sh' a. Elt a => StaticShX sh -> Mixed (sh ++ sh') a -> Mixed sh (Mixed sh' a) -mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr +mnest ssh arr = M_Nest (shxTakeSSX (Proxy @sh') ssh (mshape arr)) arr munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a munNest (M_Nest _ arr) = arr -mzip :: Mixed sh a -> Mixed sh b -> Mixed sh (a, b) -mzip = M_Tup2 +-- | The arguments must have equal shapes. If they do not, an error is raised. +mzip :: (Elt a, Elt b) => Mixed sh a -> Mixed sh b -> Mixed sh (a, b) +mzip a b + | Just Refl <- shxEqual (mshape a) (mshape b) = M_Tup2 a b + | otherwise = error "mzip: unequal shapes" munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b) munzip (M_Tup2 a b) = (a, b) -mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) - => StaticShX sh -> IShX sh2 - -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) - -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) -mrerankP ssh sh2 f (M_Primitive sh arr) = - let sh1 = shxDropSSX sh ssh - in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) - (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2) - (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) - arr) +mrerankPrimP :: forall sh1 sh2 sh a b. (Storable a, Storable b) + => IShX sh2 + -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) + -> Mixed sh (Mixed sh1 (Primitive a)) -> Mixed sh (Mixed sh2 (Primitive b)) +mrerankPrimP sh2 f (M_Nest sh (M_Primitive shsh1 arr)) = + let sh1 = shxDropSh sh shsh1 + in M_Nest sh $ + M_Primitive (shxAppend sh sh2) + (X.rerank (ssxFromShX sh) (ssxFromShX sh1) (ssxFromShX sh2) + (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) + arr) --- | See the caveats at @X.rerank@. -mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) - => StaticShX sh -> IShX sh2 - -> (Mixed sh1 a -> Mixed sh2 b) - -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b -mrerank ssh sh2 f (toPrimitive -> arr) = - fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr +-- | If the shape of the outer array (@sh@) is empty (i.e. contains a zero), +-- then there is no way to deduce the full shape of the output array (more +-- precisely, the @sh2@ part): that could only come from calling @f@, and there +-- are no subarrays to call @f@ on. @orthotope@ errors out in this case; we +-- choose to fill the shape with zeros wherever we cannot deduce what it should +-- be. +-- +-- For example, if: +-- +-- @ +-- -- arr has shape [3, 0, 4] and the inner arrays have shape [2, 21] +-- arr :: Mixed '[Just 3, Just 0, Just 4] (Mixed '[Just 2, Nothing] Int) +-- f :: Mixed '[Just 2, Nothing] Int -> Mixed '[Just 5, Nothing, Just 17] Float +-- @ +-- +-- then: +-- +-- @ +-- mrerankPrim _ f arr :: Mixed '[Just 3, Just 0, Just 4] (Mixed '[Just 5, Nothing, Just 17] Float) +-- @ +-- +-- and the inner arrays of the result will have shape @[5, 0, 17]@. Note the +-- @0@ in this shape: we don't know if @f@ intended to return an array with +-- shape 0 here (it probably didn't), but there is no better number to put here +-- absent a subarray of the input to pass to @f@. +-- +-- In this particular case the fact that @sh@ is empty was evident from the +-- type-level information, but the same situation occurs when @sh@ consists of +-- @Nothing@s, and some of those happen to be zero at runtime. +mrerankPrim :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) + => IShX sh2 + -> (Mixed sh1 a -> Mixed sh2 b) + -> Mixed sh (Mixed sh1 a) -> Mixed sh (Mixed sh2 b) +mrerankPrim sh2 f (M_Nest sh arr) = + let M_Nest sh' arr' = mrerankPrimP sh2 (toPrimitive . f . fromPrimitive) (M_Nest sh (toPrimitive arr)) + in M_Nest sh' (fromPrimitive arr') mreplicate :: forall sh sh' a. Elt a => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a @@ -856,20 +1041,28 @@ mreplicate sh arr = Refl -> X.replicate sh (ssxAppend ssh' sshT)) arr -mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) -mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) +mreplicatePrimP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) +mreplicatePrimP sh x = M_Primitive sh (X.replicateScal sh x) -mreplicateScal :: forall sh a. PrimElt a +mreplicatePrim :: forall sh a. PrimElt a => IShX sh -> a -> Mixed sh a -mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) +mreplicatePrim sh x = fromPrimitive (mreplicatePrimP sh x) + +msliceN :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a +msliceN i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr -mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a -mslice i n arr = +msliceSN :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a +msliceSN i n arr = let _ :$% sh = mshape arr in mlift (SKnown n :!% ssxFromShX sh) (\_ -> X.slice i n) arr -msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr +mslice :: forall i n k sh a. Elt a + => SMayNat Int i -> SMayNat Int n -> SMayNat Int k -> Mixed (AddMaybe (AddMaybe i n) k : sh) a -> Mixed (n : sh) a +mslice i n k arr = case mshape arr of + _ :$% sh -> + let uarr = mcastPartial (ssxFromShX $ smnAddMaybe (smnAddMaybe i n) k :$% ZSX) (SUnknown () :!% ZKX) Proxy arr + in mcastPartial (SUnknown () :!% ZKX) (ssxFromShX $ n :$% ZSX) Proxy + $ mlift (SUnknown () :!% ssxFromShX sh) (\_ -> X.sliceU (fromSMayNat' i) (fromSMayNat' n)) uarr mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a mrev1 arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.rev1) arr @@ -896,6 +1089,7 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = ixxFromList (ssxFromShX sh) (numEltMaxIndex (shxRank sh) (fromO arr)) +{-# INLINEABLE mdot1Inner #-} mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b)) @@ -911,6 +1105,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. -- Prefer 'mdot1Inner' if applicable. +{-# INLINEABLE mdot #-} mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a mdot a b = munScalar $ @@ -929,11 +1124,13 @@ mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP +{-# INLINE mliftPrim #-} mliftPrim :: (PrimElt a, PrimElt b) => (a -> b) -> Mixed sh a -> Mixed sh b mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) +{-# INLINE mliftPrim2 #-} mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c) => (a -> b -> c) -> Mixed sh a -> Mixed sh b -> Mixed sh c diff --git a/src/Data/Array/Nested/Mixed/ListX.hs b/src/Data/Array/Nested/Mixed/ListX.hs new file mode 100644 index 0000000..2c8a9cc --- /dev/null +++ b/src/Data/Array/Nested/Mixed/ListX.hs @@ -0,0 +1,139 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Mixed.ListX (ListX, pattern ZX, pattern (::%), listxShow, lazily, lazilyConcat, lazilyForce, Rank, coerceEqualRankListX) where + +import Control.DeepSeq (NFData(..)) +import Data.Foldable qualified as Foldable +import Data.Kind (Type) +import Data.Type.Equality +import GHC.IsList (IsList) +import GHC.IsList qualified as IsList +import GHC.TypeLits +#if !MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) +import GHC.TypeLits.Orphans () +#endif + +import Data.Array.Nested.Types + + +-- | The length of a type-level list. If the argument is a shape, then the +-- result is the rank of that shape. +type family Rank sh where + Rank '[] = 0 + Rank (_ : sh) = Rank sh + 1 + + +-- * Mixed lists implementation + +-- | Data invariant: each element on the list is in WHNF (the spine may be lazy) +-- and the length of the list is the same as of the type-level shape. +type role ListX nominal representational +type ListX :: [Maybe Nat] -> Type -> Type +newtype ListX sh i = ListX [i] + deriving (Eq, Ord, NFData, Foldable) + +{-# INLINE ZX #-} +pattern ZX :: forall sh i. () => sh ~ '[] => ListX sh i +pattern ZX <- (listxNull -> Just Refl) + where ZX = ListX [] + +{-# INLINE listxNull #-} +listxNull :: ListX sh i -> Maybe (sh :~: '[]) +listxNull (ListX []) = Just unsafeCoerceRefl +listxNull (ListX (_ : _)) = Nothing + +{-# INLINE (::%) #-} +pattern (::%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => i -> ListX sh i -> ListX sh1 i +pattern i ::% l <- (listxUncons -> Just (UnconsListXRes i l)) + where !i ::% ListX !l = ListX (i : l) +infixr 3 ::% + +data UnconsListXRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListXRes i (ListX sh i) +{-# INLINE listxUncons #-} +listxUncons :: forall sh1 i. ListX sh1 i -> Maybe (UnconsListXRes i sh1) +listxUncons (ListX (i : l)) = gcastWith (unsafeCoerceRefl :: Head sh1 ': Tail sh1 :~: sh1) $ + Just (UnconsListXRes i (ListX @(Tail sh1) l)) +listxUncons (ListX []) = Nothing + +{-# COMPLETE ZX, (::%) #-} + +-- | This function makes no attempt to stop you from breaking the data invariant for 'ListX'. If you do so, you must later ensure that the invariant is reinstated, for example using @'lazilyForce' 'id'@. +{-# INLINE lazily #-} +lazily :: ([a] -> [b]) -> ListX sh a -> ListX sh b +lazily f (ListX l) = ListX $ f l + +-- | This function makes no attempt to stop you from breaking the data invariant for 'ListX'. If you do so, you must later ensure that the invariant is reinstated, for example using @'lazilyForce' 'id'@. +{-# INLINE lazilyConcat #-} +lazilyConcat :: ([a] -> [b] -> [c]) -> ListX sh a -> ListX sh' b -> ListX (sh ++ sh') c +lazilyConcat f (ListX l) (ListX k) = ListX $ f l k + +-- | This operation forces all elements of the @[b]@ list to restore the strictness part of the data invariant for 'ListX'. Note that ensuring the list has the right length is still the user's responsibility. +{-# INLINE lazilyForce #-} +lazilyForce :: ([a] -> [b]) -> ListX sh a -> ListX sh b +lazilyForce f (ListX l) = let res = f l + in foldr seq () res `seq` ListX res + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (ListX sh i) +#else +instance Show i => Show (ListX sh i) where + showsPrec _ = listxShow shows +#endif + +{-# INLINE listxShow #-} +listxShow :: forall sh i. (i -> ShowS) -> ListX sh i -> ShowS +listxShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListX sh' i -> ShowS + go _ ZX = id + go prefix (x ::% xs) = showString prefix . f x . go "," xs + +-- This can't be derived, becauses the list needs to be fully evaluated, +-- per data invariant. This version is faster than versions defined using +-- (::%) or lazilyForce. +instance Functor (ListX l) where + {-# INLINE fmap #-} + fmap f (ListX l) = + let fmap' [] = [] + fmap' (x : xs) = let y = f x + rest = fmap' xs + in y `seq` rest `seq` (y : rest) + in ListX $ fmap' l + +-- | Very untyped: not even length is checked (at runtime). +instance IsList (ListX sh i) where + type Item (ListX sh i) = i + {-# INLINE fromList #-} + fromList l = foldr seq () l `seq` ListX l + {-# INLINE toList #-} + toList = Foldable.toList + +coerceEqualRankListX :: Rank sh ~ Rank sh' => ListX sh i -> ListX sh' i +coerceEqualRankListX (ListX l) = ListX l diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 2f35ff9..a01e0f3 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -1,13 +1,15 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE MagicHash #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RoleAnnotations #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -17,166 +19,44 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Mixed.Shape where +module Data.Array.Nested.Mixed.Shape ( + module Data.Array.Nested.Mixed.Shape, + Rank, +) where import Control.DeepSeq (NFData(..)) +import Control.Exception (assert) import Data.Bifunctor (first) import Data.Coerce import Data.Foldable qualified as Foldable -import Data.Functor.Const -import Data.Functor.Product import Data.Kind (Constraint, Type) import Data.Monoid (Sum(..)) import Data.Proxy import Data.Type.Equality -import GHC.Exts (withDict) -import GHC.Generics (Generic) +import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits - -import Data.Array.Nested.Types - - --- | The length of a type-level list. If the argument is a shape, then the --- result is the rank of that shape. -type family Rank sh where - Rank '[] = 0 - Rank (_ : sh) = Rank sh + 1 - - --- * Mixed lists - -type role ListX nominal representational -type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type -data ListX sh f where - ZX :: ListX '[] f - (::%) :: f n -> ListX sh f -> ListX (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) -infixr 3 ::% - -#ifdef OXAR_DEFAULT_SHOW_INSTANCES -deriving instance (forall n. Show (f n)) => Show (ListX sh f) -#else -instance (forall n. Show (f n)) => Show (ListX sh f) where - showsPrec _ = listxShow shows +#if !MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) +import GHC.TypeLits.Orphans () #endif -instance (forall n. NFData (f n)) => NFData (ListX sh f) where - rnf ZX = () - rnf (x ::% l) = rnf x `seq` rnf l - -data UnconsListXRes f sh1 = - forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n) -listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) -listxUncons (i ::% shl') = Just (UnconsListXRes shl' i) -listxUncons ZX = Nothing - --- | This checks only whether the types are equal; if the elements of the list --- are not singletons, their values may still differ. This corresponds to --- 'testEquality', except on the penultimate type parameter. -listxEqType :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') -listxEqType ZX ZX = Just Refl -listxEqType (n ::% sh) (m ::% sh') - | Just Refl <- testEquality n m - , Just Refl <- listxEqType sh sh' - = Just Refl -listxEqType _ _ = Nothing - --- | This checks whether the two lists actually contain equal values. This is --- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ --- in the @some@ package (except on the penultimate type parameter). -listxEqual :: (TestEquality f, forall n. Eq (f n)) => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') -listxEqual ZX ZX = Just Refl -listxEqual (n ::% sh) (m ::% sh') - | Just Refl <- testEquality n m - , n == m - , Just Refl <- listxEqual sh sh' - = Just Refl -listxEqual _ _ = Nothing - -listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g -listxFmap _ ZX = ZX -listxFmap f (x ::% xs) = f x ::% listxFmap f xs - -listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m -listxFold _ ZX = mempty -listxFold f (x ::% xs) = f x <> listxFold f xs - -listxLength :: ListX sh f -> Int -listxLength = getSum . listxFold (\_ -> Sum 1) - -listxRank :: ListX sh f -> SNat (Rank sh) -listxRank ZX = SNat -listxRank (_ ::% l) | SNat <- listxRank l = SNat - -listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS -listxShow f l = showString "[" . go "" l . showString "]" - where - go :: String -> ListX sh' f -> ShowS - go _ ZX = id - go prefix (x ::% xs) = showString prefix . f x . go "," xs - -listxFromList :: StaticShX sh -> [i] -> ListX sh (Const i) -listxFromList topssh topl = go topssh topl - where - go :: StaticShX sh' -> [i] -> ListX sh' (Const i) - go ZKX [] = ZX - go (_ :!% sh) (i : is) = Const i ::% go sh is - go _ _ = error $ "listxFromList: Mismatched list length (type says " - ++ show (ssxLength topssh) ++ ", list has length " - ++ show (length topl) ++ ")" - -listxToList :: ListX sh' (Const i) -> [i] -listxToList ZX = [] -listxToList (Const i ::% is) = i : listxToList is - -listxHead :: ListX (mn ': sh) f -> f mn -listxHead (i ::% _) = i - -listxTail :: ListX (n : sh) i -> ListX sh i -listxTail (_ ::% sh) = sh - -listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f -listxAppend ZX idx' = idx' -listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' - -listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f -listxDrop long ZX = long -listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short - -listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f -listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh -listxInit (_ ::% ZX) = ZX - -listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh)) -listxLast (_ ::% sh@(_ ::% _)) = listxLast sh -listxLast (x ::% ZX) = x - -listxZip :: ListX sh f -> ListX sh g -> ListX sh (Product f g) -listxZip ZX ZX = ZX -listxZip (i ::% irest) (j ::% jrest) = - Pair i j ::% listxZip irest jrest - -listxZipWith :: (forall a. f a -> g a -> h a) -> ListX sh f -> ListX sh g - -> ListX sh h -listxZipWith _ ZX ZX = ZX -listxZipWith f (i ::% is) (j ::% js) = - f i j ::% listxZipWith f is js +import Data.Array.Nested.Mixed.ListX +import Data.Array.Nested.Types -- * Mixed indices --- | This is a newtype over 'ListX'. +-- | An index into a mixed-typed array. type role IxX nominal representational type IxX :: [Maybe Nat] -> Type -> Type -newtype IxX sh i = IxX (ListX sh (Const i)) - deriving (Eq, Ord, Generic) +newtype IxX sh i = IxX (ListX sh i) + deriving (Eq, Ord, NFData, Functor, Foldable) pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i pattern ZIX = IxX ZX @@ -185,34 +65,30 @@ pattern (:.%) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) => i -> IxX sh i -> IxX sh1 i -pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i))) - where i :.% IxX shl = IxX (Const i ::% shl) +pattern i :.% l <- IxX (i ::% (IxX -> l)) + where i :.% IxX l = IxX (i ::% l) infixr 3 :.% {-# COMPLETE ZIX, (:.%) #-} +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). type IIxX sh = IxX sh Int #ifdef OXAR_DEFAULT_SHOW_INSTANCES deriving instance Show i => Show (IxX sh i) #else instance Show i => Show (IxX sh i) where - showsPrec _ (IxX l) = listxShow (shows . getConst) l + showsPrec _ (IxX l) = listxShow shows l #endif -instance Functor (IxX sh) where - fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l) - -instance Foldable (IxX sh) where - foldMap f (IxX l) = listxFold (f . getConst) l - -instance NFData i => NFData (IxX sh i) - -ixxLength :: IxX sh i -> Int -ixxLength (IxX l) = listxLength l +{-# INLINE ixxFromList #-} +ixxFromList :: StaticShX sh -> [i] -> IxX sh i +ixxFromList sh l = assert (ssxLength sh == length l) $ IsList.fromList l ixxRank :: IxX sh i -> SNat (Rank sh) -ixxRank (IxX l) = listxRank l +ixxRank ZIX = SNat +ixxRank (_ :.% l) | SNat <- ixxRank l = SNat ixxZero :: StaticShX sh -> IIxX sh ixxZero ZKX = ZIX @@ -222,85 +98,131 @@ ixxZero' :: IShX sh -> IIxX sh ixxZero' ZSX = ZIX ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh -ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i -ixxFromList = coerce (listxFromList @_ @i) - -ixxHead :: IxX (n : sh) i -> i -ixxHead (IxX list) = getConst (listxHead list) +ixxHead :: IxX (mn ': sh) i -> i +ixxHead (i :.% _) = i ixxTail :: IxX (n : sh) i -> IxX sh i -ixxTail (IxX list) = IxX (listxTail list) +ixxTail (_ :.% sh) = sh ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i -ixxAppend = coerce (listxAppend @_ @(Const i)) +ixxAppend (IxX l1) (IxX l2) = IxX $ lazilyConcat (++) l1 l2 -ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i -ixxDrop = coerce (listxDrop @(Const i) @(Const i)) +ixxDrop :: forall i j sh sh'. IxX sh j -> IxX (sh ++ sh') i -> IxX sh' i +ixxDrop ZIX long = long +ixxDrop (_ :.% short) long = case long of _ :.% long' -> ixxDrop short long' -ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i -ixxInit = coerce (listxInit @(Const i)) +ixxInit :: forall i n sh. IxX (n : sh) i -> IxX (Init (n : sh)) i +ixxInit (i :.% sh@(_ :.% _)) = i :.% ixxInit sh +ixxInit (_ :.% ZIX) = ZIX -ixxLast :: forall n sh i. IxX (n : sh) i -> i -ixxLast = coerce (listxLast @(Const i)) +ixxLast :: forall i n sh. IxX (n : sh) i -> i +ixxLast (_ :.% sh@(_ :.% _)) = ixxLast sh +ixxLast (x :.% ZIX) = x ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j) ixxZip ZIX ZIX = ZIX ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js +{-# INLINE ixxZipWith #-} ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k ixxZipWith _ ZIX ZIX = ZIX ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js -ixxFromLinear :: IShX sh -> Int -> IIxX sh -ixxFromLinear = \sh i -> case go sh i of - (idx, 0) -> idx - _ -> error $ "ixxFromLinear: out of range (" ++ show i ++ - " in array of shape " ++ show sh ++ ")" +ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i +ixxCast ZKX ZIX = ZIX +ixxCast (_ :!% sh) (i :.% idx) = i :.% ixxCast sh idx +ixxCast _ _ = error "ixxCast: ranks don't match" + +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixxToLinear #-} +ixxToLinear :: Num i => IShX sh -> IxX sh i -> i +ixxToLinear = \sh i -> go sh i 0 + where + go :: Num i => IShX sh -> IxX sh i -> i -> i + go ZSX ZIX !a = a + go (n :$% sh) (i :.% ix) a = go sh ix (fromIntegral (fromSMayNat' n) * a + i) + +{-# INLINEABLE ixxFromLinear #-} +ixxFromLinear :: Num i => IShX sh -> Int -> IxX sh i +ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared when it's called many times + let suffixes = drop 1 (scanr (*) 1 (shxToList sh)) + in fromLin0 sh suffixes where - -- returns (index in subarray, remaining index in enclosing array) - go :: IShX sh -> Int -> (IIxX sh, Int) - go ZSX i = (ZIX, i) - go (n :$% sh) i = - let (idx, i') = go sh i - (upi, locali) = i' `quotRem` fromSMayNat' n - in (locali :.% idx, upi) + -- Unfold first iteration of fromLin to do the range check. + -- Don't inline this function at first to allow GHC to inline the outer + -- function and realise that 'suffixes' is shared. But then later inline it + -- anyway, to avoid the function call. Removing the pragma makes GHC + -- somehow unable to recognise that 'suffixes' can be shared in a loop. + {-# NOINLINE [0] fromLin0 #-} + fromLin0 :: Num i => IShX sh -> [Int] -> Int -> IxX sh i + fromLin0 sh suffixes i = + if i < 0 then outrange sh i else + case (sh, suffixes) of + (ZSX, _) | i > 0 -> outrange sh i + | otherwise -> ZIX + ((fromSMayNat' -> n) :$% sh', suff : suffs) -> + let (q, r) = i `quotRem` suff + in if q >= n then outrange sh i else + fromIntegral q :.% fromLin sh' suffs r + _ -> error "impossible" + + fromLin :: Num i => IShX sh -> [Int] -> Int -> IxX sh i + fromLin ZSX _ !_ = ZIX + fromLin (_ :$% sh') (suff : suffs) i = + let (q, r) = i `quotRem` suff -- suff == shrSize sh' + in fromIntegral q :.% fromLin sh' suffs r + fromLin _ _ _ = error "impossible" + + {-# NOINLINE outrange #-} + outrange :: IShX sh -> Int -> a + outrange sh i = error $ "ixxFromLinear: out of range (" ++ show i ++ + " in array of shape " ++ show sh ++ ")" + +shxEnum :: IShX sh -> [IIxX sh] +shxEnum = shxEnum' -ixxToLinear :: IShX sh -> IIxX sh -> Int -ixxToLinear = \sh i -> fst (go sh i) +{-# INLINABLE shxEnum' #-} -- ensure this can be specialised at use site +shxEnum' :: Num i => IShX sh -> [IxX sh i] +shxEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shxSize sh - 1]] where - -- returns (index in subarray, size of subarray) - go :: IShX sh -> IIxX sh -> (Int, Int) - go ZSX ZIX = (0, 1) - go (n :$% sh) (i :.% ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, fromSMayNat' n * sz) + suffixes = drop 1 (scanr (*) 1 (shxToList sh)) + + fromLin :: Num i => IShX sh -> [Int] -> Int# -> IxX sh i + fromLin ZSX _ _ = ZIX + fromLin (_ :$% sh') (I# suff# : suffs) i# = + let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh' + in fromIntegral (I# q#) :.% fromLin sh' suffs r# + fromLin _ _ _ = error "impossible" -- * Mixed shapes -data SMayNat i f n where - SUnknown :: i -> SMayNat i f Nothing - SKnown :: f n -> SMayNat i f (Just n) -deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) -deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) -deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) +data SMayNat i n where + SUnknown :: i -> SMayNat i Nothing + SKnown :: SNat n -> SMayNat i (Just n) +deriving instance Show i => Show (SMayNat i n) +deriving instance Eq i => Eq (SMayNat i n) +deriving instance Ord i => Ord (SMayNat i n) -instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where +instance NFData i => NFData (SMayNat i n) where rnf (SUnknown i) = rnf i - rnf (SKnown x) = rnf x + rnf (SKnown SNat) = () -instance TestEquality f => TestEquality (SMayNat i f) where +instance TestEquality (SMayNat i) where testEquality SUnknown{} SUnknown{} = Just Refl testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl testEquality _ _ = Nothing +{-# INLINE fromSMayNat #-} fromSMayNat :: (n ~ Nothing => i -> r) - -> (forall m. n ~ Just m => f m -> r) - -> SMayNat i f n -> r + -> (forall m. n ~ Just m => SNat m -> r) + -> SMayNat i n -> r fromSMayNat f _ (SUnknown i) = f i fromSMayNat _ g (SKnown s) = g s -fromSMayNat' :: SMayNat Int SNat n -> Int +{-# INLINE fromSMayNat' #-} +fromSMayNat' :: SMayNat Int n -> Int fromSMayNat' = fromSMayNat id fromSNat' type family AddMaybe n m where @@ -308,144 +230,226 @@ type family AddMaybe n m where AddMaybe (Just _) Nothing = Nothing AddMaybe (Just n) (Just m) = Just (n + m) -smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) +smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m) smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) --- | This is a newtype over 'ListX'. type role ShX nominal representational type ShX :: [Maybe Nat] -> Type -> Type -newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) - deriving (Eq, Ord, Generic) - -pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i -pattern ZSX = ShX ZX - -pattern (:$%) - :: forall {sh1} {i}. - forall n sh. (n : sh ~ sh1) - => SMayNat i SNat n -> ShX sh i -> ShX sh1 i -pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i)) - where i :$% ShX shl = ShX (i ::% shl) -infixr 3 :$% - -{-# COMPLETE ZSX, (:$%) #-} +data ShX sh i where + ZSX :: ShX '[] i + ConsUnknown :: forall sh i. i -> ShX sh i -> ShX (Nothing : sh) i +-- TODO: bring this UNPACK back when GHC no longer crashes: +-- ConsKnown :: forall n sh i. {-# UNPACK #-} SNat n -> ShX sh i -> ShX (Just n : sh) i + ConsKnown :: forall n sh i. SNat n -> ShX sh i -> ShX (Just n : sh) i +deriving instance Ord i => Ord (ShX sh i) -type IShX sh = ShX sh Int +-- A manually defined instance and this INLINEABLE is needed to specialize +-- mdot1Inner (otherwise GHC warns specialization breaks down here). +instance Eq i => Eq (ShX sh i) where + {-# INLINEABLE (==) #-} + ZSX == ZSX = True + ConsUnknown i1 sh1 == ConsUnknown i2 sh2 = i1 == i2 && sh1 == sh2 + ConsKnown _ sh1 == ConsKnown _ sh2 = sh1 == sh2 #ifdef OXAR_DEFAULT_SHOW_INSTANCES deriving instance Show i => Show (ShX sh i) #else instance Show i => Show (ShX sh i) where - showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l + showsPrec _ l = shxShow (fromSMayNat shows (shows . fromSNat)) l #endif +instance NFData i => NFData (ShX sh i) where + rnf ZSX = () + rnf (x `ConsUnknown` l) = rnf x `seq` rnf l + rnf (SNat `ConsKnown` l) = rnf l + instance Functor (ShX sh) where - fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l) + {-# INLINE fmap #-} + fmap f l = shxFmap (fromSMayNat (SUnknown . f) SKnown) l -instance NFData i => NFData (ShX sh i) where - rnf (ShX ZX) = () - rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) - rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) +data UnconsShXRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsShXRes (SMayNat i n) (ShX sh i) +shxUncons :: ShX sh1 i -> Maybe (UnconsShXRes i sh1) +shxUncons (i `ConsUnknown` shl') = Just (UnconsShXRes (SUnknown i) shl') +shxUncons (i `ConsKnown` shl') = Just (UnconsShXRes (SKnown i) shl') +shxUncons ZSX = Nothing --- | This checks only whether the types are equal; unknown dimensions might --- still differ. This corresponds to 'testEquality', except on the penultimate --- type parameter. +-- | This checks only whether the types are equal; if the elements of the list +-- are not singletons, their values may still differ. This corresponds to +-- 'testEquality', except on the penultimate type parameter. shxEqType :: ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') shxEqType ZSX ZSX = Just Refl -shxEqType (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') - | Just Refl <- sameNat n m - , Just Refl <- shxEqType sh sh' - = Just Refl -shxEqType (SUnknown _ :$% sh) (SUnknown _ :$% sh') +shxEqType (_ `ConsUnknown` sh) (_ `ConsUnknown` sh') | Just Refl <- shxEqType sh sh' = Just Refl +shxEqType (n `ConsKnown` sh) (m `ConsKnown` sh') + | Just Refl <- testEquality n m + , Just Refl <- shxEqType sh sh' + = Just Refl shxEqType _ _ = Nothing --- | This checks whether all dimensions have the same value. This is more than --- 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ in the --- @some@ package (except on the penultimate type parameter). +-- | This checks whether the two lists actually contain equal values. This is +-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ +-- in the @some@ package (except on the penultimate type parameter). shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') shxEqual ZSX ZSX = Just Refl -shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') - | Just Refl <- sameNat n m +shxEqual (n `ConsUnknown` sh) (m `ConsUnknown` sh') + | n == m , Just Refl <- shxEqual sh sh' = Just Refl -shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') - | i == j +shxEqual (n `ConsKnown` sh) (m `ConsKnown` sh') + | Just Refl <- testEquality n m , Just Refl <- shxEqual sh sh' = Just Refl shxEqual _ _ = Nothing +{-# INLINE shxFmap #-} +shxFmap :: (forall n. SMayNat i n -> SMayNat j n) -> ShX sh i -> ShX sh j +shxFmap _ ZSX = ZSX +shxFmap f (x `ConsUnknown` xs) = case f (SUnknown x) of + SUnknown y -> y `ConsUnknown` shxFmap f xs +shxFmap f (x `ConsKnown` xs) = case f (SKnown x) of + SKnown y -> y `ConsKnown` shxFmap f xs + +{-# INLINE shxFoldMap #-} +shxFoldMap :: Monoid m => (forall n. SMayNat i n -> m) -> ShX sh i -> m +shxFoldMap _ ZSX = mempty +shxFoldMap f (x `ConsUnknown` xs) = f (SUnknown x) <> shxFoldMap f xs +shxFoldMap f (x `ConsKnown` xs) = f (SKnown x) <> shxFoldMap f xs + shxLength :: ShX sh i -> Int -shxLength (ShX l) = listxLength l +shxLength = getSum . shxFoldMap (\_ -> Sum 1) shxRank :: ShX sh i -> SNat (Rank sh) -shxRank (ShX l) = listxRank l +shxRank ZSX = SNat +shxRank (_ `ConsUnknown` l) | SNat <- shxRank l = SNat +shxRank (_ `ConsKnown` l) | SNat <- shxRank l = SNat + +{-# INLINE shxShow #-} +shxShow :: forall sh i. (forall n. SMayNat i n -> ShowS) -> ShX sh i -> ShowS +shxShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ShX sh' i -> ShowS + go _ ZSX = id + go prefix (x `ConsUnknown` xs) = showString prefix . f (SUnknown x) . go "," xs + go prefix (x `ConsKnown` xs) = showString prefix . f (SKnown x) . go "," xs + +shxHead :: ShX (mn ': sh) i -> SMayNat i mn +shxHead (i `ConsUnknown` _) = SUnknown i +shxHead (i `ConsKnown` _) = SKnown i + +shxTail :: ShX (n : sh) i -> ShX sh i +shxTail (_ `ConsUnknown` sh) = sh +shxTail (_ `ConsKnown` sh) = sh + +shxAppend :: ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i +shxAppend ZSX idx' = idx' +shxAppend (i `ConsUnknown` idx) idx' = i `ConsUnknown` shxAppend idx idx' +shxAppend (i `ConsKnown` idx) idx' = i `ConsKnown` shxAppend idx idx' + +shxDropSh :: forall sh sh' i j. ShX sh j -> ShX (sh ++ sh') i -> ShX sh' i +shxDropSh ZSX long = long +shxDropSh (_ `ConsUnknown` short) long = case long of + _ `ConsUnknown` long' -> shxDropSh short long' +shxDropSh (_ `ConsKnown` short) long = case long of + _ `ConsKnown` long' -> shxDropSh short long' + +shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i +shxDropSSX = coerce (shxDropSh @_ @_ @i @()) + +shxInit :: forall i n sh. ShX (n : sh) i -> ShX (Init (n : sh)) i +shxInit (i `ConsUnknown` sh@(_ `ConsUnknown` _)) = i `ConsUnknown` shxInit sh +shxInit (i `ConsUnknown` sh@(_ `ConsKnown` _)) = i `ConsUnknown` shxInit sh +shxInit (_ `ConsUnknown` ZSX) = ZSX +shxInit (i `ConsKnown` sh@(_ `ConsUnknown` _)) = i `ConsKnown` shxInit sh +shxInit (i `ConsKnown` sh@(_ `ConsKnown` _)) = i `ConsKnown` shxInit sh +shxInit (_ `ConsKnown` ZSX) = ZSX + +shxLast :: forall i n sh. ShX (n : sh) i -> SMayNat i (Last (n : sh)) +shxLast (_ `ConsUnknown` sh@(_ `ConsUnknown` _)) = shxLast sh +shxLast (_ `ConsUnknown` sh@(_ `ConsKnown` _)) = shxLast sh +shxLast (x `ConsUnknown` ZSX) = SUnknown x +shxLast (_ `ConsKnown` sh@(_ `ConsUnknown` _)) = shxLast sh +shxLast (_ `ConsKnown` sh@(_ `ConsKnown` _)) = shxLast sh +shxLast (x `ConsKnown` ZSX) = SKnown x + +pattern (:$%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => SMayNat i n -> ShX sh i -> ShX sh1 i +pattern i :$% shl <- (shxUncons -> Just (UnconsShXRes i shl)) + where i :$% shl = case i of; SUnknown x -> x `ConsUnknown` shl; SKnown x -> x `ConsKnown` shl +infixr 3 :$% + +{-# COMPLETE ZSX, (:$%) #-} + +type IShX sh = ShX sh Int -- | The number of elements in an array described by this shape. shxSize :: IShX sh -> Int shxSize ZSX = 1 shxSize (n :$% sh) = fromSMayNat' n * shxSize sh -shxFromList :: StaticShX sh -> [Int] -> ShX sh Int -shxFromList topssh topl = go topssh topl +-- We don't report the size of the list in case of errors in order not to retain the list. +{-# INLINEABLE shxFromList #-} +shxFromList :: StaticShX sh -> [Int] -> IShX sh +shxFromList (StaticShX topssh) topl = go topssh topl where - go :: StaticShX sh' -> [Int] -> ShX sh' Int - go ZKX [] = ZSX - go (SKnown sn :!% sh) (i : is) - | i == fromSNat' sn = SKnown sn :$% go sh is - | otherwise = error $ "shxFromList: Value does not match typing (type says " - ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" - go (SUnknown () :!% sh) (i : is) = SUnknown i :$% go sh is - go _ _ = error $ "shxFromList: Mismatched list length (type says " - ++ show (ssxLength topssh) ++ ", list has length " - ++ show (length topl) ++ ")" + go :: ShX sh' () -> [Int] -> ShX sh' Int + go ZSX [] = ZSX + go ZSX _ = error $ "shxFromList: List too long (type says " ++ show (shxLength topssh) ++ ")" + go (ConsKnown sn sh) (i : is) + | i == fromSNat' sn = ConsKnown sn (go sh is) + | otherwise = error "shxFromList: Value does not match typing" + go (ConsUnknown () sh) (i : is) = ConsUnknown i (go sh is) + go _ _ = error $ "shxFromList: List too short (type says " ++ show (shxLength topssh) ++ ")" +{-# INLINEABLE shxToList #-} shxToList :: IShX sh -> [Int] -shxToList ZSX = [] -shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh +shxToList l = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ShX sh Int -> is + go ZSX = nil + go (ConsUnknown i rest) = i `cons` go rest + go (ConsKnown sn rest) = fromSNat' sn `cons` go rest + in go l) --- | This may fail if @sh@ has @Nothing@s in it. -shxFromSSX' :: StaticShX sh -> Maybe (IShX sh) -shxFromSSX' ZKX = Just ZSX -shxFromSSX' (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX' sh -shxFromSSX' (SUnknown _ :!% _) = Nothing +-- If it ever matters for performance, this is unsafeCoercible. +shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i +shxFromSSX ZKX = ZSX +shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh)) + | Refl <- lemMapJustCons @sh Refl + = SKnown n :$% shxFromSSX sh +shxFromSSX (SUnknown _ :!% _) = error "unreachable" -shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i -shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) - -shxHead :: ShX (n : sh) i -> SMayNat i SNat n -shxHead (ShX list) = listxHead list - -shxTail :: ShX (n : sh) i -> ShX sh i -shxTail (ShX list) = ShX (listxTail list) - -shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i -shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) - -shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i -shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) +-- | This may fail if @sh@ has @Nothing@s in it. +shxFromSSX2 :: StaticShX sh -> Maybe (ShX sh i) +shxFromSSX2 ZKX = Just ZSX +shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh +shxFromSSX2 (SUnknown _ :!% _) = Nothing -shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i -shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) +shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i +shxTakeSSX _ ZKX _ = ZSX +shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh -shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i -shxInit = coerce (listxInit @(SMayNat i SNat)) +shxTakeSh :: forall sh sh' i proxy. proxy sh' -> ShX sh i -> ShX (sh ++ sh') i -> ShX sh i +shxTakeSh _ ZSX _ = ZSX +shxTakeSh p (_ :$% ssh1) (n :$% sh) = n :$% shxTakeSh p ssh1 sh -shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh)) -shxLast = coerce (listxLast @(SMayNat i SNat)) +{-# INLINEABLE shxTakeIx #-} +shxTakeIx :: forall sh sh' i j. Proxy sh' -> IxX sh j -> ShX (sh ++ sh') i -> ShX sh i +shxTakeIx _ (IxX ZX) _ = ZSX +shxTakeIx proxy (IxX (_ ::% long)) short = case short of i :$% short' -> i :$% shxTakeIx proxy (IxX long) short' -shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i -shxTakeSSX _ = flip go - where - go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i - go ZKX _ = ZSX - go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh +{-# INLINEABLE shxDropIx #-} +shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i +shxDropIx ZIX long = long +shxDropIx (_ :.% short) long = case long of _ :$% long' -> shxDropIx short long' -shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n) +{-# INLINE shxZipWith #-} +shxZipWith :: (forall n. SMayNat i n -> SMayNat j n -> SMayNat k n) -> ShX sh i -> ShX sh j -> ShX sh k shxZipWith _ ZSX ZSX = ZSX shxZipWith f (i :$% is) (j :$% js) = f i j :$% shxZipWith f is js @@ -456,48 +460,45 @@ shxCompleteZeros ZKX = ZSX shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh -shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) +shxSplitApp :: proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) shxSplitApp _ ZKX idx = (ZSX, idx) shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx) -shxEnum :: IShX sh -> [IIxX sh] -shxEnum = \sh -> go sh id [] - where - go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] - go ZSX f = (f ZIX :) - go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] - -shxCast :: IShX sh -> StaticShX sh' -> Maybe (IShX sh') -shxCast ZSX ZKX = Just ZSX -shxCast (SKnown n :$% sh) (SKnown m :!% ssh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast sh ssh -shxCast (SUnknown n :$% sh) (SKnown m :!% ssh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast sh ssh -shxCast (SKnown n :$% sh) (SUnknown () :!% ssh) = (SUnknown (fromSNat' n) :$%) <$> shxCast sh ssh -shxCast (SUnknown n :$% sh) (SUnknown () :!% ssh) = (SUnknown n :$%) <$> shxCast sh ssh +shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh') +shxCast ZKX ZSX = Just ZSX +shxCast (SKnown m :!% ssh) (SKnown n :$% sh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast ssh sh +shxCast (SKnown m :!% ssh) (SUnknown n :$% sh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SKnown n :$% sh) = (SUnknown (fromSNat' n) :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SUnknown n :$% sh) = (SUnknown n :$%) <$> shxCast ssh sh shxCast _ _ = Nothing -- | Partial version of 'shxCast'. -shxCast' :: IShX sh -> StaticShX sh' -> IShX sh' -shxCast' sh ssh = case shxCast sh ssh of +shxCast' :: StaticShX sh' -> IShX sh -> IShX sh' +shxCast' ssh sh = case shxCast ssh sh of Just sh' -> sh' Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")" -- * Static mixed shapes --- | The part of a shape that is statically known. (A newtype over 'ListX'.) +-- | The part of a shape that is statically known. (A newtype over 'ShX'.) type StaticShX :: [Maybe Nat] -> Type -newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) - deriving (Eq, Ord) +newtype StaticShX sh = StaticShX (ShX sh ()) + deriving (NFData) + +instance Eq (StaticShX sh) where _ == _ = True +instance Ord (StaticShX sh) where compare _ _ = EQ pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh -pattern ZKX = StaticShX ZX +pattern ZKX = StaticShX ZSX pattern (:!%) :: forall {sh1}. forall n sh. (n : sh ~ sh1) - => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 -pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i)) - where i :!% StaticShX shl = StaticShX (i ::% shl) + => SMayNat () n -> StaticShX sh -> StaticShX sh1 +pattern i :!% shl <- StaticShX (shxUncons -> Just (UnconsShXRes i (StaticShX -> shl))) + where i :!% StaticShX shl = case i of; SUnknown () -> StaticShX (() `ConsUnknown` shl); SKnown x -> StaticShX (x `ConsKnown` shl) + infixr 3 :!% {-# COMPLETE ZKX, (:!%) #-} @@ -506,63 +507,68 @@ infixr 3 :!% deriving instance Show (StaticShX sh) #else instance Show (StaticShX sh) where - showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l + showsPrec _ (StaticShX l) = shxShow (fromSMayNat shows (shows . fromSNat)) l #endif -instance NFData (StaticShX sh) where - rnf (StaticShX ZX) = () - rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l) - rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l) - instance TestEquality StaticShX where - testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2 + testEquality (StaticShX l1) (StaticShX l2) = shxEqType l1 l2 ssxLength :: StaticShX sh -> Int -ssxLength (StaticShX l) = listxLength l +ssxLength (StaticShX l) = shxLength l ssxRank :: StaticShX sh -> SNat (Rank sh) -ssxRank (StaticShX l) = listxRank l +ssxRank (StaticShX l) = shxRank l -- | @ssxEqType = 'testEquality'@. Provided for consistency. ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') ssxEqType = testEquality ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') -ssxAppend ZKX sh' = sh' -ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' +ssxAppend = coerce (shxAppend @_ @()) -ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n -ssxHead (StaticShX list) = listxHead list +ssxHead :: StaticShX (n : sh) -> SMayNat () n +ssxHead (StaticShX list) = shxHead list ssxTail :: StaticShX (n : sh) -> StaticShX sh -ssxTail (_ :!% ssh) = ssh +ssxTail (StaticShX list) = StaticShX (shxTail list) + +ssxTakeIx :: forall sh sh' i. Proxy sh' -> IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh +ssxTakeIx _ (IxX ZX) _ = ZKX +ssxTakeIx proxy (IxX (_ ::% long)) short = case short of i :!% short' -> i :!% ssxTakeIx proxy (IxX long) short' -ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' -ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) +ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropIx (IxX ZX) long = long +ssxDropIx (IxX (_ ::% short)) long = case long of _ :!% long' -> ssxDropIx (IxX short) long' + +ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSh = coerce (shxDropSh @_ @_ @() @i) + +ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSSX = coerce (shxDropSh @_ @_ @() @()) ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) -ssxInit = coerce (listxInit @(SMayNat () SNat)) +ssxInit = coerce (shxInit @()) -ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh)) -ssxLast = coerce (listxLast @(SMayNat () SNat)) +ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () (Last (n : sh)) +ssxLast = coerce (shxLast @()) ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) ssxReplicate SZ = ZKX ssxReplicate (SS (n :: SNat n')) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n' + | Refl <- lemReplicateSucc @(Nothing @Nat) n = SUnknown () :!% ssxReplicate n -ssxIotaFrom :: Int -> StaticShX sh -> [Int] -ssxIotaFrom _ ZKX = [] -ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh +ssxIotaFrom :: StaticShX sh -> Int -> [Int] +ssxIotaFrom ZKX _ = [] +ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i + 1) -ssxFromShX :: IShX sh -> StaticShX sh +ssxFromShX :: ShX sh i -> StaticShX sh ssxFromShX ZSX = ZKX ssxFromShX (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShX sh ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing) ssxFromSNat SZ = ZKX -ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n +ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) n = SUnknown () :!% ssxFromSNat n -- | Evidence for the static part of a shape. This pops up only when you are @@ -574,7 +580,7 @@ instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SK instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r -withKnownShX k = withDict @(KnownShX sh) k +withKnownShX = withDict @(KnownShX sh) -- * Flattening @@ -587,18 +593,18 @@ type family Flatten' acc sh where Flatten' acc (Just n : sh) = Flatten' (acc * n) sh -- This function is currently unused -ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) +ssxFlatten :: StaticShX sh -> SMayNat () (Flatten sh) ssxFlatten = go (SNat @1) where - go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) + go :: SNat acc -> StaticShX sh -> SMayNat () (Flatten' acc sh) go acc ZKX = SKnown acc go _ (SUnknown () :!% _) = SUnknown () go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh -shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) +shxFlatten :: IShX sh -> SMayNat Int (Flatten sh) shxFlatten = go (SNat @1) where - go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) + go :: SNat acc -> IShX sh -> SMayNat Int (Flatten' acc sh) go acc ZSX = SKnown acc go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh @@ -609,20 +615,14 @@ shxFlatten = go (SNat @1) goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh --- | Very untyped: only length is checked (at runtime). -instance KnownShX sh => IsList (ListX sh (Const i)) where - type Item (ListX sh (Const i)) = i - fromList = listxFromList (knownShX @sh) - toList = listxToList - -- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. -instance KnownShX sh => IsList (IxX sh i) where +instance IsList (IxX sh i) where type Item (IxX sh i) = i fromList = IxX . IsList.fromList toList = Foldable.toList -- | Untyped: length and known dimensions are checked (at runtime). -instance KnownShX sh => IsList (ShX sh Int) where - type Item (ShX sh Int) = Int +instance KnownShX sh => IsList (IShX sh) where + type Item (IShX sh) = Int fromList = shxFromList (knownShX @sh) toList = shxToList diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs index 031755f..85fbd89 100644 --- a/src/Data/Array/Nested/Permutation.hs +++ b/src/Data/Array/Nested/Permutation.hs @@ -1,10 +1,10 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -18,13 +18,13 @@ module Data.Array.Nested.Permutation where import Data.Coerce (coerce) -import Data.Functor.Const import Data.List (sort) import Data.Maybe (fromMaybe) import Data.Proxy import Data.Type.Bool import Data.Type.Equality import Data.Type.Ord +import GHC.Exts (withDict) import GHC.TypeError import GHC.TypeLits import GHC.TypeNats qualified as TN @@ -36,8 +36,8 @@ import Data.Array.Nested.Types -- * Permutations -- | A "backward" permutation of a dimension list. The operation on the --- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute' --- for code that implements this. +-- dimension list is most similar to @backpermute@ in the @vector@ package; see +-- 'Permute' for code that implements this. data Perm list where PNil :: Perm '[] PCons :: SNat a -> Perm l -> Perm (a : l) @@ -45,15 +45,22 @@ infixr 5 `PCons` deriving instance Show (Perm list) deriving instance Eq (Perm list) +instance TestEquality Perm where + testEquality PNil PNil = Just Refl + testEquality (x `PCons` xs) (y `PCons` ys) + | Just Refl <- testEquality x y + , Just Refl <- testEquality xs ys = Just Refl + testEquality _ _ = Nothing + permRank :: Perm list -> SNat (Rank list) permRank PNil = SNat permRank (_ `PCons` l) | SNat <- permRank l = SNat -permFromList :: [Int] -> (forall list. Perm list -> r) -> r -permFromList [] k = k PNil -permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case - Just sn -> permFromList xs $ \list -> k (sn `PCons` list) - Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x +permFromListCont :: [Int] -> (forall list. Perm list -> r) -> r +permFromListCont [] k = k PNil +permFromListCont (x : xs) k = withSomeSNat (fromIntegral x) $ \case + Just sn -> permFromListCont xs $ \list -> k (sn `PCons` list) + Nothing -> error $ "Data.Array.Nested.Permutation.permFromListCont: negative number in list: " ++ show x permToList :: Perm list -> [Natural] permToList PNil = mempty @@ -119,6 +126,9 @@ class KnownPerm l where makePerm :: Perm l instance KnownPerm '[] where makePerm = PNil instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm +withKnownPerm :: forall l r. Perm l -> (KnownPerm l => r) -> r +withKnownPerm = withDict @(KnownPerm l) + -- | Untyped permutations for ranked arrays type PermR = [Int] @@ -161,51 +171,78 @@ type family DropLen ref l where DropLen '[] l = l DropLen (_ : ref) (_ : xs) = DropLen ref xs -listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f -listxTakeLen PNil _ = ZX -listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh -listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape" +shxTakeLenPerm :: forall i is sh. Perm is -> ShX sh i -> ShX (TakeLen is sh) i +shxTakeLenPerm PNil _ = ZSX +shxTakeLenPerm (_ `PCons` is) (n `ConsUnknown` sh) = n `ConsUnknown` shxTakeLenPerm is sh +shxTakeLenPerm (_ `PCons` is) (n `ConsKnown` sh) = n `ConsKnown` shxTakeLenPerm is sh +shxTakeLenPerm (_ `PCons` _) ZSX = error "Permutation longer than shape" -listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f -listxDropLen PNil sh = sh -listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh -listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape" +shxDropLenPerm :: forall i is sh. Perm is -> ShX sh i -> ShX (DropLen is sh) i +shxDropLenPerm PNil sh = sh +shxDropLenPerm (_ `PCons` is) (_ `ConsUnknown` sh) = shxDropLenPerm is sh +shxDropLenPerm (_ `PCons` is) (_ `ConsKnown` sh) = shxDropLenPerm is sh +shxDropLenPerm (_ `PCons` _) ZSX = error "Permutation longer than shape" -listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f -listxPermute PNil _ = ZX -listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = - listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh +shxPermute :: forall i is sh. Perm is -> ShX sh i -> ShX (Permute is sh) i +shxPermute PNil _ = ZSX +shxPermute (i `PCons` (is :: Perm is')) (sh :: ShX sh i) = + case shxIndex i sh of + SUnknown x -> x `ConsUnknown` shxPermute is sh + SKnown x -> x `ConsKnown` shxPermute is sh -listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh) -listxIndex _ _ SZ (n ::% _) = n -listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listxIndex p pT i sh -listxIndex _ _ _ ZX = error "Index into empty shape" +shxIndex :: forall i k sh. SNat k -> ShX sh i -> SMayNat i (Index k sh) +shxIndex SZ (n `ConsUnknown` _) = SUnknown n +shxIndex SZ (n `ConsKnown` _) = SKnown n +shxIndex (SS (i :: SNat k')) ((_ :: i) `ConsUnknown` (sh :: ShX sh' i)) + | Refl <- lemIndexSucc (Proxy @k') (Proxy @Nothing) (Proxy @sh') + = shxIndex i sh +shxIndex (SS (i :: SNat k')) ((_ :: SNat n) `ConsKnown` (sh :: ShX sh' i)) + | Refl <- lemIndexSucc (Proxy @k') (Proxy @(Just n)) (Proxy @sh') + = shxIndex i sh +shxIndex _ ZSX = error "Index into empty shape" -listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f -listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) +shxPermutePrefix :: forall i is sh. Perm is -> ShX sh i -> ShX (PermutePrefix is sh) i +shxPermutePrefix perm sh = shxAppend (shxPermute perm (shxTakeLenPerm perm sh)) (shxDropLenPerm perm sh) -ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i -ixxPermutePrefix = coerce (listxPermutePrefix @(Const i)) -ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) -ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) +ssxTakeLenPerm :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) +ssxTakeLenPerm = coerce (shxTakeLenPerm @()) -ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) -ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) +ssxDropLenPerm :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) +ssxDropLenPerm = coerce (shxDropLenPerm @()) ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) -ssxPermute = coerce (listxPermute @(SMayNat () SNat)) +ssxPermute = coerce (shxPermute @()) -ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh) -ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) +ssxIndex :: SNat k -> StaticShX sh -> SMayNat () (Index k sh) +ssxIndex k = coerce (shxIndex @() k) ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) -ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) +ssxPermutePrefix = coerce (shxPermutePrefix @()) + + +ixxTakeLenPerm :: forall i is sh. Perm is -> IxX sh i -> IxX (TakeLen is sh) i +ixxTakeLenPerm PNil _ = ZIX +ixxTakeLenPerm (_ `PCons` is) (n :.% sh) = n :.% ixxTakeLenPerm is sh +ixxTakeLenPerm (_ `PCons` _) ZIX = error "Permutation longer than shape" + +ixxDropLenPerm :: forall i is sh. Perm is -> IxX sh i -> IxX (DropLen is sh) i +ixxDropLenPerm PNil sh = sh +ixxDropLenPerm (_ `PCons` is) (_ :.% sh) = ixxDropLenPerm is sh +ixxDropLenPerm (_ `PCons` _) ZIX = error "Permutation longer than shape" + +ixxPermute :: forall i is sh. Perm is -> IxX sh i -> IxX (Permute is sh) i +ixxPermute PNil _ = ZIX +ixxPermute (i `PCons` (is :: Perm is')) (sh :: IxX sh f) = + ixxIndex i sh :.% ixxPermute is sh + +ixxIndex :: forall j i sh. SNat i -> IxX sh j -> j +ixxIndex SZ (n :.% _) = n +ixxIndex (SS i) (_ :.% sh) = ixxIndex i sh +ixxIndex _ ZIX = error "Index into empty shape" -shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) -shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) +ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i +ixxPermutePrefix perm sh = ixxAppend (ixxPermute perm (ixxTakeLenPerm perm sh)) (ixxDropLenPerm perm sh) -- * Operations on permutations @@ -224,7 +261,7 @@ permInverse = \perm k -> ++ " ; invperm = " ++ show invperm) (permCheckPermutation invperm (k invperm - (\ssh -> case provePermInverse perm invperm ssh of + (\ssh -> case permCheckInverse perm invperm ssh of Just eq -> eq Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm ++ " ; invperm = " ++ show invperm))) @@ -238,9 +275,9 @@ permInverse = \perm k -> toHList [] k = k PNil toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l) - provePermInverse :: Perm is -> Perm is' -> StaticShX sh + permCheckInverse :: Perm is -> Perm is' -> StaticShX sh -> Maybe (Permute is' (Permute is sh) :~: sh) - provePermInverse perm perminv ssh = + permCheckInverse perm perminv ssh = ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh type family MapSucc is where @@ -248,11 +285,50 @@ type family MapSucc is where MapSucc (i : is) = i + 1 : MapSucc is permShift1 :: Perm l -> Perm (0 : MapSucc l) -permShift1 = (SNat @0 `PCons`) . permMapSucc +permShift1 = (SZ `PCons`) . permMapSucc where permMapSucc :: Perm l -> Perm (MapSucc l) permMapSucc PNil = PNil - permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns + permMapSucc (sn `PCons` ns) = snatSucc sn `PCons` permMapSucc ns + +-- | @PermId n@ is the type of the identity permutation of length @n@. +type family PermId n where + PermId 0 = '[] + PermId 1 = '[0] + PermId n = PermId (n - 1) ++ '[n - 1] + +{- Doesn't type-check: +permId :: SNat n -> Perm (PermId n) +permId SZ = PNil +permId (SS SZ) = PCons SZ PNil +permId (SS k) = permId k `permAppend` PCons k PNil +-} +permId :: forall n. SNat n -> Perm (PermId n) +permId n = go SZ + where + go :: forall k l. SNat k -> Perm l + go k = if fromSNat' k >= fromSNat' n + then gcastWith (unsafeCoerceRefl :: (l :~: '[])) $ + PNil + else gcastWith (unsafeCoerceRefl :: (l :~: k : anything)) $ + k `PCons` go (SS k) + +-- | Note that the second argument is not a valid permutation. +permAppend :: Perm l -> Perm l2 -> Perm (l ++ l2) +permAppend PNil l2 = l2 +permAppend (n `PCons` rest) l2 = n `PCons` permAppend rest l2 + +type family MapPlusN n is where + MapPlusN n '[] = '[] + MapPlusN n (i : is) = i + n : MapPlusN n is + +-- TODO: instead of permAppend and permShiftN define permComp :: Perm l1 -> Perm l2 -> Perm (l1 ++ MapPlusN (Rank l1) l2), where all three are valid permutations +permShiftN :: forall n l. SNat n -> Perm l -> Perm (PermId n ++ MapPlusN n l) +permShiftN n = (permId n `permAppend`) . permMapPlusN + where + permMapPlusN :: Perm l1 -> Perm (MapPlusN n l1) + permMapPlusN PNil = PNil + permMapPlusN (sn `PCons` ns) = snatPlus sn n `PCons` permMapPlusN ns -- * Lemmas @@ -264,7 +340,13 @@ lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl lemRankDropLen :: forall is sh. (Rank is <= Rank sh) => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is lemRankDropLen ZKX PNil = Refl -lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!% sh) (_ `PCons` is) + | Refl <- lemRankDropLen sh is +#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) + = Refl +#else + = unsafeCoerceRefl +#endif lemRankDropLen (_ :!% _) PNil = Refl lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0" diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index e5c51ef..2d8b624 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -29,17 +29,17 @@ import Foreign.Storable (Storable) import GHC.TypeLits import GHC.TypeNats qualified as TN -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Permutation -import Data.Array.Nested.Types -import Data.Array.XArray (XArray(..)) -import Data.Array.XArray qualified as X import Data.Array.Nested.Convert +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation import Data.Array.Nested.Ranked.Base import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith +import Data.Array.XArray (XArray(..)) +import Data.Array.XArray qualified as X remptyArray :: KnownElt a => Ranked 1 a @@ -49,9 +49,11 @@ remptyArray = mtoRanked (memptyArray ZSX) rsize :: Elt a => Ranked n a -> Int rsize = shrSize . rshape +{-# INLINEABLE rindex #-} rindex :: Elt a => Ranked n a -> IIxR n -> a rindex (Ranked arr) idx = mindex arr (ixxFromIxR idx) +{-# INLINEABLE rindexPartial #-} rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a rindexPartial (Ranked arr) idx = Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) @@ -59,15 +61,25 @@ rindexPartial (Ranked arr) idx = (ixxFromIxR idx)) -- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. +-- See the documentation of 'mgenerate' for more details; see also +-- 'rgeneratePrim'. rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a rgenerate sh f - | sn@SNat <- shrRank sh + | sn <- shrRank sh , Dict <- lemKnownReplicate sn , Refl <- lemRankReplicate sn = Ranked (mgenerate (shxFromShR sh) (f . ixrFromIxX)) +-- | See 'mgeneratePrim'. +{-# INLINE rgeneratePrim #-} +rgeneratePrim :: forall n a i. (PrimElt a, Num i) + => IShR n -> (IxR n i -> a) -> Ranked n a +rgeneratePrim sh f = + let g i = f (ixrFromLinear sh i) + in rfromVector sh $ VS.generate (shrSize sh) g + -- | See the documentation of 'mlift'. +{-# INLINE rlift #-} rlift :: forall n1 n2 a. Elt a => SNat n2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) @@ -75,40 +87,48 @@ rlift :: forall n1 n2 a. Elt a rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr) -- | See the documentation of 'mlift2'. +{-# INLINE rlift2 #-} rlift2 :: forall n1 n2 n3 a. Elt a => SNat n3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b) -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) -rsumOuter1P :: forall n a. - (Storable a, NumElt a) - => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) -rsumOuter1P (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = Ranked (msumOuter1P arr) +{-# INLINE rsumOuter1PrimP #-} +rsumOuter1PrimP :: forall n a. + (Storable a, NumElt a) + => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) +rsumOuter1PrimP (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = Ranked (msumOuter1PrimP arr) + +{-# INLINEABLE rsumOuter1Prim #-} +rsumOuter1Prim :: forall n a. (NumElt a, PrimElt a) + => Ranked (n + 1) a -> Ranked n a +rsumOuter1Prim = rfromPrimitive . rsumOuter1PrimP . rtoPrimitive -rsumOuter1 :: forall n a. (NumElt a, PrimElt a) - => Ranked (n + 1) a -> Ranked n a -rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive +{-# INLINE rsumAllPrimP #-} +rsumAllPrimP :: (Storable a, NumElt a) => Ranked n (Primitive a) -> a +rsumAllPrimP (Ranked arr) = msumAllPrimP arr +{-# INLINE rsumAllPrim #-} rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a rsumAllPrim (Ranked arr) = msumAllPrim arr rtranspose :: forall n a. Elt a => PermR -> Ranked n a -> Ranked n a rtranspose perm arr - | sn@SNat <- rrank arr + | sn <- rrank arr , Dict <- lemKnownReplicate sn - , length perm <= fromIntegral (natVal (Proxy @n)) + , length perm <= fromSNat' sn = rlift sn - (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm) + (\ssh' -> X.transposeUntyped sn ssh' perm) arr | otherwise = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array" rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a rconcat - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = coerce mconcat rappend :: forall n a. Elt a @@ -116,72 +136,107 @@ rappend :: forall n a. Elt a rappend arr1 arr2 | sn@SNat <- rrank arr1 , Dict <- lemKnownReplicate sn - , Refl <- lemReplicateSucc @(Nothing @Nat) @n + , Refl <- lemReplicateSucc @(Nothing @Nat) (SNat @n) = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) arr1 arr2 rscalar :: Elt a => a -> Ranked 0 a rscalar x = Ranked (mscalar x) +{-# INLINEABLE rfromVectorP #-} rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a) rfromVectorP sh v | Dict <- lemKnownReplicate (shrRank sh) = Ranked (mfromVectorP (shxFromShR sh) v) +{-# INLINEABLE rfromVector #-} rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a rfromVector sh v = rfromPrimitive (rfromVectorP sh v) +{-# INLINEABLE rtoVectorP #-} rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a rtoVectorP = coerce mtoVectorP +{-# INLINEABLE rtoVector #-} rtoVector :: PrimElt a => Ranked n a -> VS.Vector a rtoVector = coerce mtoVector +-- | All arrays in the list, even subarrays inside @a@, must have the same +-- shape; if they do not, a runtime error will be thrown. See the +-- documentation of 'mgenerate' for more information about this restriction. +-- +-- Because the length of the 'NonEmpty' list is unknown, its spine must be +-- materialised in memory in order to compute its length. If its length is +-- already known, use 'rfromListOuterN' to be able to stream the list. +-- +-- If your array is 1-dimensional and contains scalars, use 'rfromList1Prim'. rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a rfromListOuter l - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) +-- | See 'rfromListOuter'. If the list does not have the given length, a +-- runtime error is thrown. 'rfromList1PrimN' is faster if applicable. +rfromListOuterN :: forall n a. Elt a => Int -> NonEmpty (Ranked n a) -> Ranked (n + 1) a +rfromListOuterN n l + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = Ranked (mfromListOuterN n (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) + +-- | Because the length of the 'NonEmpty' list is unknown, its spine must be +-- materialised in memory in order to compute its length. If its length is +-- already known, use 'rfromList1N' to be able to stream the list. +-- +-- If the elements are scalars, 'rfromList1Prim' is faster. rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a -rfromList1 l = Ranked (mfromList1 l) +rfromList1 = coerce mfromList1 + +-- | If the elements are scalars, 'rfromList1PrimN' is faster. A runtime error +-- is thrown if the list length does not match the given length. +rfromList1N :: Elt a => Int -> NonEmpty a -> Ranked 1 a +rfromList1N = coerce mfromList1N + +-- | If the elements are scalars, 'rfromListPrimLinear' is faster. +rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a +rfromListLinear sh l = Ranked (mfromListLinear (shxFromShR sh) l) +-- | Because the length of the list is unknown, its spine must be materialised +-- in memory in order to compute its length. If its length is already known, +-- use 'rfromList1PrimN' to be able to stream the list. rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a -rfromList1Prim l = Ranked (mfromList1Prim l) +rfromList1Prim = coerce mfromList1Prim + +rfromList1PrimN :: PrimElt a => Int -> [a] -> Ranked 1 a +rfromList1PrimN = coerce mfromList1PrimN + +rfromListPrimLinear :: forall n a. PrimElt a => IShR n -> [a] -> Ranked n a +rfromListPrimLinear sh l = Ranked (mfromListPrimLinear (shxFromShR sh) l) rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] rtoListOuter (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) -rtoList1 :: Elt a => Ranked 1 a -> [a] -rtoList1 = map runScalar . rtoListOuter - -rfromListPrim :: PrimElt a => [a] -> Ranked 1 a -rfromListPrim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a -rfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in Ranked $ fromPrimitive $ M_Primitive (shxFromShR sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShR sh) xarr) - -rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a -rfromListLinear sh l = rreshape sh (rfromList1 l) +rtoList :: Elt a => Ranked 1 a -> [a] +rtoList = map runScalar . rtoListOuter rtoListLinear :: Elt a => Ranked n a -> [a] rtoListLinear (Ranked arr) = mtoListLinear arr +rtoListPrim :: PrimElt a => Ranked 1 a -> [a] +rtoListPrim (Ranked arr) = mtoListPrim arr + +rtoListPrimLinear :: PrimElt a => Ranked n a -> [a] +rtoListPrimLinear (Ranked arr) = mtoListPrimLinear arr + rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a rfromOrthotope sn arr | Refl <- lemRankReplicate sn = let xarr = XArray arr in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr)) -rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a +rtoOrthotope :: forall a n. PrimElt a => Ranked n a -> S.Array n a rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr))) - | Refl <- lemRankReplicate (shrRank $ shrFromShX2 sh) + | Refl <- lemRankReplicate (shrRank $ shrFromShX @n sh) = arr runScalar :: Elt a => Ranked 0 a -> a @@ -197,22 +252,20 @@ runNest rarr@(Ranked (M_Ranked (M_Nest _ arr))) | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat)) = Ranked arr -rzip :: Ranked n a -> Ranked n b -> Ranked n (a, b) +rzip :: (Elt a, Elt b) => Ranked n a -> Ranked n b -> Ranked n (a, b) rzip = coerce mzip runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b) runzip = coerce munzip -rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) - => SNat n -> IShR n2 - -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) - -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b) -rrerankP sn sh2 f (Ranked arr) - | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) - , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) - = Ranked (mrerankP (ssxFromSNat sn) (shxFromShR sh2) - (\a -> let Ranked r = f (Ranked a) in r) - arr) +rrerankPrimP :: forall n1 n2 n a b. (Storable a, Storable b) + => IShR n2 + -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) + -> Ranked n (Ranked n1 (Primitive a)) -> Ranked n (Ranked n2 (Primitive b)) +rrerankPrimP sh2 f (Ranked (M_Ranked arr)) + = Ranked (M_Ranked (mrerankPrimP (shxFromShR sh2) + (\a -> let Ranked r = f (Ranked a) in r) + arr)) -- | If there is a zero-sized dimension in the @n@-prefix of the shape of the -- input array, then there is no way to deduce the full shape of the output @@ -223,26 +276,28 @@ rrerankP sn sh2 f (Ranked arr) -- For example, if: -- -- @ --- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21] +-- arr :: Ranked 3 (Ranked 2 Int) -- outer array shape [3, 0, 4]; inner shape [2, 21] -- f :: Ranked 2 Int -> Ranked 3 Float -- @ -- -- then: -- -- @ --- rrerank _ _ _ f arr :: Ranked 5 Float +-- rrerank _ f arr :: Ranked 3 (Ranked 3 Float) -- @ -- --- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the --- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended --- to return an array with shape all-0 here (it probably didn't), but there is --- no better number to put here absent a subarray of the input to pass to @f@. -rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) - => SNat n -> IShR n2 - -> (Ranked n1 a -> Ranked n2 b) - -> Ranked (n + n1) a -> Ranked (n + n2) b -rrerank sn sh2 f (rtoPrimitive -> arr) = - rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr +-- and the inner arrays of the result will have shape @[0, 0, 0]@. We don't +-- know if @f@ intended to return an array with all-zero shape here (it +-- probably didn't), but there is no better number to put here absent a +-- subarray of the input to pass to @f@. +rrerankPrim :: forall n1 n2 n a b. (PrimElt a, PrimElt b) + => IShR n2 + -> (Ranked n1 a -> Ranked n2 b) + -> Ranked n (Ranked n1 a) -> Ranked n (Ranked n2 b) +rrerankPrim sh2 f (Ranked (M_Ranked arr)) = + Ranked (M_Ranked (mrerankPrim (shxFromShR sh2) + (\a -> let Ranked r = f (Ranked a) in r) + arr)) rreplicate :: forall n m a. Elt a => IShR n -> Ranked m a -> Ranked (n + m) a @@ -250,29 +305,24 @@ rreplicate sh (Ranked arr) | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat)) = Ranked (mreplicate (shxFromShR sh) arr) -rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) -rreplicateScalP sh x +rreplicatePrimP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) +rreplicatePrimP sh x | Dict <- lemKnownReplicate (shrRank sh) - = Ranked (mreplicateScalP (shxFromShR sh) x) + = Ranked (mreplicatePrimP (shxFromShR sh) x) -rreplicateScal :: forall n a. PrimElt a +rreplicatePrim :: forall n a. PrimElt a => IShR n -> a -> Ranked n a -rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) +rreplicatePrim sh x = rfromPrimitive (rreplicatePrimP sh x) rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a -rslice i n arr - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = rlift (rrank arr) - (\_ -> X.sliceU i n) - arr +rslice i n (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = Ranked (msliceN i n arr) rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a -rrev1 arr = - rlift (rrank arr) - (\(_ :: StaticShX sh') -> - case lemReplicateSucc @(Nothing @Nat) @n of - Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) - arr +rrev1 (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = Ranked (mrev1 arr) rreshape :: forall n n' a. Elt a => IShR n' -> Ranked n a -> Ranked n' a @@ -299,6 +349,7 @@ rmaxIndexPrim rarr@(Ranked arr) | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) = ixrFromIxX (mmaxIndexPrim arr) +{-# INLINEABLE rdot1Inner #-} rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a rdot1Inner arr1 arr2 | SNat <- rrank arr1 @@ -307,14 +358,15 @@ rdot1Inner arr1 arr2 -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. -- Prefer 'rdot1Inner' if applicable. +{-# INLINE rdot #-} rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a rdot = coerce mdot rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) -rtoXArrayPrimP (Ranked arr) = first shrFromShX2 (mtoXArrayPrimP arr) +rtoXArrayPrimP (Ranked arr) = first shrFromShX (mtoXArrayPrimP arr) rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a) -rtoXArrayPrim (Ranked arr) = first shrFromShX2 (mtoXArrayPrim arr) +rtoXArrayPrim (Ranked arr) = first shrFromShX (mtoXArrayPrim arr) rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr) diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index f50f671..beedbcf 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -5,6 +5,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -30,17 +31,13 @@ import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) import GHC.TypeLits -#ifndef OXAR_DEFAULT_SHOW_INSTANCES -import Data.Foldable (toList) -#endif - -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Types -import Data.Array.XArray (XArray(..)) +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith +import Data.Array.XArray (XArray(..)) -- | A rank-typed array: the number of dimensions of the array (its /rank/) is @@ -63,7 +60,7 @@ deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a) #ifndef OXAR_DEFAULT_SHOW_INSTANCES instance (Show a, Elt a) => Show (Ranked n a) where showsPrec d arr@(Ranked marr) = - let sh = show (toList (rshape arr)) + let sh = show (shrToList (rshape arr)) in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr #endif @@ -85,9 +82,12 @@ newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed -- these instances allow them to also be used as elements of arrays, thus -- making them first-class in the API. instance Elt a => Elt (Ranked n a) where + {-# INLINE mshape #-} mshape (M_Ranked arr) = mshape arr + {-# INLINE mindex #-} mindex (M_Ranked arr) i = Ranked (mindex arr i) + {-# INLINE mindexPartial #-} mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) mindexPartial (M_Ranked arr) i = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ @@ -95,13 +95,14 @@ instance Elt a => Elt (Ranked n a) where mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) - mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) - mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) + mfromListOuterSN :: SNat m -> NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Just m : sh) (Ranked n a) + mfromListOuterSN sn l = M_Ranked (mfromListOuterSN sn (coerce l)) mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] mtoListOuter (M_Ranked arr) = coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) @@ -110,6 +111,7 @@ instance Elt a => Elt (Ranked n a) where coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ mlift ssh2 f arr + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) @@ -118,6 +120,7 @@ instance Elt a => Elt (Ranked n a) where coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ mlift2 ssh3 f arr1 arr2 + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) @@ -137,28 +140,29 @@ instance Elt a => Elt (Ranked n a) where type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) - mshapeTree (Ranked arr) = first shrFromShX2 (mshapeTree arr) + mshapeTree (Ranked arr) = first coerce (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + mshapeTreeIsEmpty _ (sh, t) = shrSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" marrayStrides (M_Ranked arr) = marrayStrides arr - mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () - mvecsWrite sh idx (Ranked arr) vecs = - mvecsWrite sh idx arr + mvecsWriteLinear :: forall sh s. Int -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () + mvecsWriteLinear idx (Ranked arr) vecs = + mvecsWriteLinear idx arr (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) vecs) - mvecsWritePartial :: forall sh sh' s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) - -> MixedVecs s (sh ++ sh') (Ranked n a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx + mvecsWritePartialLinear + :: forall sh sh' s. + Proxy sh -> Int -> Mixed sh' (Ranked n a) + -> MixedVecs s (sh ++ sh') (Ranked n a) + -> ST s () + mvecsWritePartialLinear proxy idx arr vecs = + mvecsWritePartialLinear proxy idx (coerce @(Mixed sh' (Ranked n a)) @(Mixed sh' (Mixed (Replicate n Nothing) a)) arr) @@ -174,18 +178,30 @@ instance Elt a => Elt (Ranked n a) where (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) vecs) + mvecsUnsafeFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) + mvecsUnsafeFreeze sh vecs = + coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) + @(Mixed sh (Ranked n a)) + <$> mvecsUnsafeFreeze sh + (coerce @(MixedVecs s sh (Ranked n a)) + @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a) - memptyArrayUnsafe i + memptyArrayUnsafe sh | Dict <- lemKnownReplicate (SNat @n) = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ - memptyArrayUnsafe i + memptyArrayUnsafe sh mvecsUnsafeNew idx (Ranked arr) | Dict <- lemKnownReplicate (SNat @n) = MV_Ranked <$> mvecsUnsafeNew idx arr + mvecsReplicate idx (Ranked arr) + | Dict <- lemKnownReplicate (SNat @n) + = MV_Ranked <$> mvecsReplicate idx arr + mvecsNewEmpty _ | Dict <- lemKnownReplicate (SNat @n) = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) @@ -208,15 +224,15 @@ instance (NumElt a, PrimElt a) => Num (Ranked n a) where negate = liftRanked1 negate abs = liftRanked1 abs signum = liftRanked1 signum - fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal" + fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicatePrim" instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where - fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal" + fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicatePrim" recip = liftRanked1 recip (/) = liftRanked2 (/) instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where - pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal" + pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicatePrim" exp = liftRanked1 exp log = liftRanked1 log sqrt = liftRanked1 sqrt @@ -247,8 +263,9 @@ ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a ratan2Array = liftRanked2 matan2Array +{-# INLINE rshape #-} rshape :: Elt a => Ranked n a -> IShR n -rshape (Ranked arr) = shrFromShX2 (mshape arr) +rshape (Ranked arr) = coerce (mshape arr) rrank :: Elt a => Ranked n a -> SNat n rrank = shrRank . rshape diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index c0c4f17..c04f39e 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -1,9 +1,5 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} @@ -11,16 +7,19 @@ {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RoleAnnotations #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE StrictData #-} +#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) +{-# LANGUAGE TypeAbstractions #-} +#endif {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} @@ -28,243 +27,239 @@ module Data.Array.Nested.Ranked.Shape where import Control.DeepSeq (NFData(..)) +import Control.Exception (assert) import Data.Coerce (coerce) import Data.Foldable qualified as Foldable import Data.Kind (Type) import Data.Proxy import Data.Type.Equality -import GHC.Generics (Generic) +import GHC.Exts (build) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits import GHC.TypeNats qualified as TN -import Data.Array.Mixed.Lemmas +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed.ListX import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation import Data.Array.Nested.Types -type role ListR nominal representational -type ListR :: Nat -> Type -> Type -data ListR n i where - ZR :: ListR 0 i - (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i -deriving instance Eq i => Eq (ListR n i) -deriving instance Ord i => Ord (ListR n i) -deriving instance Functor (ListR n) -deriving instance Foldable (ListR n) -infixr 3 ::: - -#ifdef OXAR_DEFAULT_SHOW_INSTANCES -deriving instance Show i => Show (ListR n i) -#else -instance Show i => Show (ListR n i) where - showsPrec _ = listrShow shows -#endif - -instance NFData i => NFData (ListR n i) where - rnf ZR = () - rnf (x ::: l) = rnf x `seq` rnf l - -data UnconsListRRes i n1 = - forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i -listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) -listrUncons (i ::: sh') = Just (UnconsListRRes sh' i) -listrUncons ZR = Nothing - --- | This checks only whether the ranks are equal, not whether the actual --- values are. -listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n') -listrEqRank ZR ZR = Just Refl -listrEqRank (_ ::: sh) (_ ::: sh') - | Just Refl <- listrEqRank sh sh' - = Just Refl -listrEqRank _ _ = Nothing - --- | This compares the lists for value equality. -listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n') -listrEqual ZR ZR = Just Refl -listrEqual (i ::: sh) (j ::: sh') - | Just Refl <- listrEqual sh sh' - , i == j - = Just Refl -listrEqual _ _ = Nothing - -listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS -listrShow f l = showString "[" . go "" l . showString "]" - where - go :: String -> ListR n' i -> ShowS - go _ ZR = id - go prefix (x ::: xs) = showString prefix . f x . go "," xs - -listrLength :: ListR n i -> Int -listrLength = length - -listrRank :: ListR n i -> SNat n -listrRank ZR = SNat -listrRank (_ ::: sh) = snatSucc (listrRank sh) - -listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i -listrAppend ZR sh = sh -listrAppend (x ::: xs) sh = x ::: listrAppend xs sh - -listrFromList :: [i] -> (forall n. ListR n i -> r) -> r -listrFromList [] k = k ZR -listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) - -listrHead :: ListR (n + 1) i -> i -listrHead (i ::: _) = i -listrHead ZR = error "unreachable" - -listrTail :: ListR (n + 1) i -> ListR n i -listrTail (_ ::: sh) = sh -listrTail ZR = error "unreachable" - -listrInit :: ListR (n + 1) i -> ListR n i -listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh -listrInit (_ ::: ZR) = ZR -listrInit ZR = error "unreachable" - -listrLast :: ListR (n + 1) i -> i -listrLast (_ ::: sh@(_ ::: _)) = listrLast sh -listrLast (n ::: ZR) = n -listrLast ZR = error "unreachable" - -listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i -listrIndex SZ (x ::: _) = x -listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs -listrIndex _ ZR = error "k + 1 <= 0" - -listrZip :: ListR n i -> ListR n j -> ListR n (i, j) -listrZip ZR ZR = ZR -listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest -listrZip _ _ = error "listrZip: impossible pattern needlessly required" - -listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k -listrZipWith _ ZR ZR = ZR -listrZipWith f (i ::: irest) (j ::: jrest) = - f i j ::: listrZipWith f irest jrest -listrZipWith _ _ _ = - error "listrZipWith: impossible pattern needlessly required" - -listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i -listrPermutePrefix = \perm sh -> - listrFromList perm $ \sperm -> - case (listrRank sperm, listrRank sh) of - (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of - LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" - ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" - where - listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) - listrSplitAt SZ sh = (ZR, sh) - listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) - listrSplitAt SS{} ZR = error "m' + 1 <= 0" - - applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i - applyPermRFull _ ZR _ = ZR - applyPermRFull sm@SNat (i ::: perm) l = - TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> - case cmpNat (SNat @(idx + 1)) sm of - LTI -> listrIndex si l ::: applyPermRFull sm perm l - EQI -> listrIndex si l ::: applyPermRFull sm perm l - GTI -> error "listrPermutePrefix: Index in permutation out of range" - +-- * Ranked indices -- | An index into a rank-typed array. type role IxR nominal representational type IxR :: Nat -> Type -> Type -newtype IxR n i = IxR (ListR n i) - deriving (Eq, Ord, Generic) - deriving newtype (Functor, Foldable) +newtype IxR n i = IxR (IxX (Replicate n Nothing) i) + deriving (Eq, Ord, NFData, Functor, Foldable) pattern ZIR :: forall n i. () => n ~ 0 => IxR n i -pattern ZIR = IxR ZR +pattern ZIR <- IxR (matchZIX @n -> Just Refl) + where ZIR = IxR ZIX + +matchZIX :: forall n i. IxX (Replicate n Nothing) i -> Maybe (n :~: 0) +matchZIX ZIX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl +matchZIX _ = Nothing pattern (:.:) :: forall {n1} {i}. forall n. (n + 1 ~ n1) => i -> IxR n i -> IxR n1 i -pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i)) - where i :.: IxR sh = IxR (i ::: sh) +pattern i :.: l <- (ixrUncons -> Just (UnconsIxRRes i l)) + where i :.: IxR l | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = IxR (i :.% l) infixr 3 :.: +data UnconsIxRRes i n1 = + forall n. (n + 1 ~ n1) => UnconsIxRRes i (IxR n i) +ixrUncons :: forall n1 i. IxR n1 i -> Maybe (UnconsIxRRes i n1) +ixrUncons (IxR ((:.%) @n @sh i l)) + | Refl <- lemReplicateHead (Proxy @n) (Proxy @sh) (Proxy @Nothing) (Proxy @n1) Refl + , Refl <- lemReplicateCons (Proxy @sh) (Proxy @n1) Refl + , Refl <- lemReplicateCons2 (Proxy @sh) (Proxy @n1) Refl = + Just (UnconsIxRRes i (IxR @(Rank sh) l)) +ixrUncons (IxR _) = Nothing + {-# COMPLETE ZIR, (:.:) #-} +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). type IIxR n = IxR n Int #ifdef OXAR_DEFAULT_SHOW_INSTANCES deriving instance Show i => Show (IxR n i) #else instance Show i => Show (IxR n i) where - showsPrec _ (IxR l) = listrShow shows l + showsPrec _ = ixrShow shows #endif -instance NFData i => NFData (IxR sh i) +-- | This checks only whether the ranks are equal, not whether the actual +-- values are. +ixrEqRank :: IxR n i -> IxR n' i -> Maybe (n :~: n') +ixrEqRank ZIR ZIR = Just Refl +ixrEqRank (_ :.: sh) (_ :.: sh') + | Just Refl <- ixrEqRank sh sh' + = Just Refl +ixrEqRank _ _ = Nothing -ixrLength :: IxR sh i -> Int -ixrLength (IxR l) = listrLength l +-- | This compares the lists for value equality. +ixrEqual :: Eq i => IxR n i -> IxR n' i -> Maybe (n :~: n') +ixrEqual ZIR ZIR = Just Refl +ixrEqual (i :.: sh) (j :.: sh') + | Just Refl <- ixrEqual sh sh' + , i == j + = Just Refl +ixrEqual _ _ = Nothing + +{-# INLINE ixrShow #-} +ixrShow :: forall n i. (i -> ShowS) -> IxR n i -> ShowS +ixrShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> IxR n' i -> ShowS + go _ ZIR = id + go prefix (x :.: xs) = showString prefix . f x . go "," xs ixrRank :: IxR n i -> SNat n -ixrRank (IxR sh) = listrRank sh +ixrRank ZIR = SNat +ixrRank (_ :.: sh) = snatSucc (ixrRank sh) ixrZero :: SNat n -> IIxR n ixrZero SZ = ZIR ixrZero (SS n) = 0 :.: ixrZero n -ixrFromIxX :: IxX sh i -> IxR (Rank sh) i -ixrFromIxX ZIX = ZIR -ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx - -ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i -ixxFromIxR ZIR = ZIX -ixxFromIxR (n :.: (idx :: IxR m i)) = - castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m)) - (n :.% ixxFromIxR idx) +{-# INLINE ixrFromList #-} +ixrFromList :: SNat n -> [i] -> IxR n i +ixrFromList topsn topl = assert (fromSNat' topsn == length topl) + $ IxR $ IsList.fromList topl ixrHead :: IxR (n + 1) i -> i -ixrHead (IxR list) = listrHead list +ixrHead (i :.: _) = i ixrTail :: IxR (n + 1) i -> IxR n i -ixrTail (IxR list) = IxR (listrTail list) +ixrTail (_ :.: sh) = sh ixrInit :: IxR (n + 1) i -> IxR n i -ixrInit (IxR list) = IxR (listrInit list) +ixrInit (n :.: sh@(_ :.: _)) = n :.: ixrInit sh +ixrInit (_ :.: ZIR) = ZIR ixrLast :: IxR (n + 1) i -> i -ixrLast (IxR list) = listrLast list +ixrLast (_ :.: sh@(_ :.: _)) = ixrLast sh +ixrLast (n :.: ZIR) = n + +-- | Performs a runtime check that the lengths are identical. +ixrCast :: SNat n' -> IxR n i -> IxR n' i +ixrCast SZ ZIR = ZIR +ixrCast (SS n) (i :.: l) = i :.: ixrCast n l +ixrCast _ _ = error "ixrCast: ranks don't match" +-- lemReplicatePlusApp requires SNat that would cause overhead (not benchmarked) ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i -ixrAppend = coerce (listrAppend @_ @i) +ixrAppend = gcastWith (unsafeCoerceRefl :: Replicate (n + m) (Nothing @Nat) :~: Replicate n Nothing ++ Replicate m Nothing) $ + coerce (ixxAppend @_ @_ @i) + +ixrIndex :: forall k n i. (k + 1 <= n) => SNat k -> IxR n i -> i +ixrIndex SZ (x :.: _) = x +ixrIndex (SS i) (_ :.: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = ixrIndex i xs +ixrIndex _ ZIR = error "k + 1 <= 0" ixrZip :: IxR n i -> IxR n j -> IxR n (i, j) -ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 +ixrZip ZIR ZIR = ZIR +ixrZip (i :.: irest) (j :.: jrest) = (i, j) :.: ixrZip irest jrest +ixrZip _ _ = error "ixrZip: impossible pattern needlessly required" +{-# INLINE ixrZipWith #-} ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k -ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 +ixrZipWith _ ZIR ZIR = ZIR +ixrZipWith f (i :.: irest) (j :.: jrest) = + f i j :.: ixrZipWith f irest jrest +ixrZipWith _ _ _ = + error "ixrZipWith: impossible pattern needlessly required" -ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i -ixrPermutePrefix = coerce (listrPermutePrefix @i) +ixrSplitAt :: m <= n' => SNat m -> IxR n' i -> (IxR m i, IxR (n' - m) i) +ixrSplitAt SZ sh = (ZIR, sh) +ixrSplitAt (SS m) (n :.: sh) = (\(pre, post) -> (n :.: pre, post)) (ixrSplitAt m sh) +ixrSplitAt SS{} ZIR = error "m' + 1 <= 0" +ixrPermutePrefix :: forall n i. PermR -> IxR n i -> IxR n i +ixrPermutePrefix = \perm sh -> + TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat -> + case ixrRank sh of { shlen@SNat -> + let sperm = ixrFromList permlen perm in + case cmpNat permlen shlen of + LTI -> let (pre, post) = ixrSplitAt permlen sh in ixrAppend (applyPermRFull permlen sperm pre) post + EQI -> let (pre, post) = ixrSplitAt permlen sh in ixrAppend (applyPermRFull permlen sperm pre) post + GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" + ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" + } + where + applyPermRFull :: SNat m -> IxR k Int -> IxR m i -> IxR k i + applyPermRFull _ ZIR _ = ZIR + applyPermRFull sm@SNat (i :.: perm) l = + TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> + case cmpNat (SNat @(idx + 1)) sm of + LTI -> ixrIndex si l :.: applyPermRFull sm perm l + EQI -> ixrIndex si l :.: applyPermRFull sm perm l + GTI -> error "ixrPermutePrefix: Index in permutation out of range" + +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixrToLinear #-} +ixrToLinear :: Num i => IShR m -> IxR m i -> i +ixrToLinear (ShR sh) ix = ixxToLinear sh (ixxFromIxR ix) + +ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i +ixxFromIxR = coerce + +{-# INLINEABLE ixrFromLinear #-} +ixrFromLinear :: forall i m. Num i => IShR m -> Int -> IxR m i +ixrFromLinear (ShR sh) i + | Refl <- lemRankReplicate (Proxy @m) + = ixrFromIxX $ ixxFromLinear sh i + +ixrFromIxX :: IxX (Replicate n Nothing) i -> IxR n i +ixrFromIxX = coerce + +shrEnum :: IShR n -> [IIxR n] +shrEnum = shrEnum' + +{-# INLINABLE shrEnum' #-} -- ensure this can be specialised at use site +shrEnum' :: forall i n. Num i => IShR n -> [IxR n i] +shrEnum' (ShR sh) + | Refl <- lemRankReplicate (Proxy @n) + = (coerce :: [IxX (Replicate n Nothing) i] -> [IxR n i]) $ shxEnum' sh + +-- * Ranked shapes type role ShR nominal representational type ShR :: Nat -> Type -> Type -newtype ShR n i = ShR (ListR n i) - deriving (Eq, Ord, Generic) - deriving newtype (Functor, Foldable) +newtype ShR n i = ShR (ShX (Replicate n Nothing) i) + deriving (Eq, Ord, NFData, Functor) pattern ZSR :: forall n i. () => n ~ 0 => ShR n i -pattern ZSR = ShR ZR +pattern ZSR <- ShR (matchZSR @n -> Just Refl) + where ZSR = ShR ZSX + +matchZSR :: forall n i. ShX (Replicate n Nothing) i -> Maybe (n :~: 0) +matchZSR ZSX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl +matchZSR _ = Nothing pattern (:$:) :: forall {n1} {i}. forall n. (n + 1 ~ n1) => i -> ShR n i -> ShR n1 i -pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) - where i :$: ShR sh = ShR (i ::: sh) +pattern i :$: sh <- (shrUncons -> Just (UnconsShRRes i sh)) + where i :$: ShR sh | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ShR (SUnknown i :$% sh) infixr 3 :$: +data UnconsShRRes i n1 = + forall n. (n + 1 ~ n1) => UnconsShRRes i (ShR n i) +shrUncons :: forall n1 i. ShR n1 i -> Maybe (UnconsShRRes i n1) +shrUncons (ShR (SUnknown x :$% (sh' :: ShX sh' i))) + | Refl <- lemReplicateCons (Proxy @sh') (Proxy @n1) Refl + , Refl <- lemReplicateCons2 (Proxy @sh') (Proxy @n1) Refl + = Just (UnconsShRRes x (ShR sh')) +shrUncons (ShR _) = Nothing + {-# COMPLETE ZSR, (:$:) #-} type IShR n = ShR n Int @@ -273,96 +268,160 @@ type IShR n = ShR n Int deriving instance Show i => Show (ShR n i) #else instance Show i => Show (ShR n i) where - showsPrec _ (ShR l) = listrShow shows l + showsPrec d (ShR l) = showsPrec d l #endif -instance NFData i => NFData (ShR sh i) - -shrFromShX :: forall sh. IShX sh -> IShR (Rank sh) -shrFromShX ZSX = ZSR -shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx - --- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'. -shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n -shrFromShX2 sh - | Refl <- lemRankReplicate (Proxy @n) - = shrFromShX sh - -shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i -shxFromShR ZSR = ZSX -shxFromShR (n :$: (idx :: ShR m i)) = - castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m)) - (SUnknown n :$% shxFromShR idx) - -- | This checks only whether the ranks are equal, not whether the actual -- values are. shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') -shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh' +shrEqRank ZSR ZSR = Just Refl +shrEqRank (_ :$: sh) (_ :$: sh') + | Just Refl <- shrEqRank sh sh' + = Just Refl +shrEqRank _ _ = Nothing -- | This compares the shapes for value equality. shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n') -shrEqual (ShR sh) (ShR sh') = listrEqual sh sh' +shrEqual ZSR ZSR = Just Refl +shrEqual (i :$: sh) (i' :$: sh') + | Just Refl <- shrEqual sh sh' + , i == i' + = Just Refl +shrEqual _ _ = Nothing shrLength :: ShR sh i -> Int -shrLength (ShR l) = listrLength l +shrLength (ShR l) = shxLength l -- | This function can also be used to conjure up a 'KnownNat' dictionary; -- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern -- synonym yields 'KnownNat' evidence. -shrRank :: ShR n i -> SNat n -shrRank (ShR sh) = listrRank sh +shrRank :: forall n i. ShR n i -> SNat n +shrRank (ShR sh) | Refl <- lemRankReplicate (Proxy @n) = shxRank sh -- | The number of elements in an array described by this shape. shrSize :: IShR n -> Int -shrSize ZSR = 1 -shrSize (n :$: sh) = n * shrSize sh +shrSize (ShR sh) = shxSize sh -shrHead :: ShR (n + 1) i -> i -shrHead (ShR list) = listrHead list +-- This is equivalent to but faster than @coerce (shxFromList (ssxReplicate snat))@. +-- We don't report the size of the list in case of errors in order not to retain the list. +{-# INLINEABLE shrFromList #-} +shrFromList :: SNat n -> [Int] -> IShR n +shrFromList snat topl = ShR $ go snat topl + where + go :: SNat n -> [Int] -> ShX (Replicate n Nothing) Int + go SZ [] = ZSX + go SZ _ = error $ "shrFromList: List too long (type says " ++ show (fromSNat' snat) ++ ")" + go (SS sn :: SNat n1) (i : is) | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ConsUnknown i (go sn is) + go _ _ = error $ "shrFromList: List too short (type says " ++ show (fromSNat' snat) ++ ")" -shrTail :: ShR (n + 1) i -> ShR n i -shrTail (ShR list) = ShR (listrTail list) +-- This is equivalent to but faster than @coerce shxToList@. +{-# INLINEABLE shrToList #-} +shrToList :: IShR n -> [Int] +shrToList (ShR l) = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ShX sh Int -> is + go ZSX = nil + go (ConsUnknown i rest) = i `cons` go rest + go ConsKnown{} = error "shrToList: impossible case" + in go l) -shrInit :: ShR (n + 1) i -> ShR n i -shrInit (ShR list) = ShR (listrInit list) +shrHead :: forall n i. ShR (n + 1) i -> i +shrHead (ShR sh) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = case shxHead @Nothing @(Replicate n Nothing) sh of + SUnknown i -> i -shrLast :: ShR (n + 1) i -> i -shrLast (ShR list) = listrLast list +shrTail :: forall n i. ShR (n + 1) i -> ShR n i +shrTail + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = coerce (shxTail @_ @_ @i) -shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i -shrAppend = coerce (listrAppend @_ @i) +{-# INLINEABLE shrTakeIx #-} +shrTakeIx :: forall n n' i j. Proxy n' -> IxR n j -> ShR (n + n') i -> ShR n i +shrTakeIx _ ZIR _ = ZSR +shrTakeIx p (_ :.: idx) sh = case sh of n :$: sh' -> n :$: shrTakeIx p idx sh' + +{-# INLINEABLE shrDropIx #-} +shrDropIx :: forall n n' i j. IxR n j -> ShR (n + n') i -> ShR n' i +shrDropIx ZIR long = long +shrDropIx (_ :.: short) long = case long of _ :$: long' -> shrDropIx short long' + +shrInit :: forall n i. ShR (n + 1) i -> ShR n i +shrInit + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = -- TODO: change this and all other unsafeCoerceRefl to lemmas: + gcastWith (unsafeCoerceRefl + :: Init (Replicate (n + 1) (Nothing @Nat)) :~: Replicate n Nothing) $ + coerce (shxInit @i) -shrZip :: ShR n i -> ShR n j -> ShR n (i, j) -shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2 +shrLast :: forall n i. ShR (n + 1) i -> i +shrLast (ShR sh) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = case shxLast sh of + SUnknown i -> i + SKnown{} -> error "shrLast: impossible SKnown" +-- | Performs a runtime check that the lengths are identical. +shrCast :: SNat n' -> ShR n i -> ShR n' i +shrCast SZ ZSR = ZSR +shrCast (SS n) (i :$: sh) = i :$: shrCast n sh +shrCast _ _ = error "shrCast: ranks don't match" + +shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i +shrAppend = + -- lemReplicatePlusApp requires an SNat + gcastWith (unsafeCoerceRefl + :: Replicate n (Nothing @Nat) ++ Replicate m Nothing :~: Replicate (n + m) Nothing) $ + coerce (shxAppend @_ @i) + +{-# INLINE shrZipWith #-} shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k -shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2 +shrZipWith _ ZSR ZSR = ZSR +shrZipWith f (i :$: irest) (j :$: jrest) = + f i j :$: shrZipWith f irest jrest +shrZipWith _ _ _ = + error "shrZipWith: impossible pattern needlessly required" -shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i -shrPermutePrefix = coerce (listrPermutePrefix @i) +shrSplitAt :: m <= n' => SNat m -> ShR n' i -> (ShR m i, ShR (n' - m) i) +shrSplitAt SZ sh = (ZSR, sh) +shrSplitAt (SS m) (n :$: sh) = (\(pre, post) -> (n :$: pre, post)) (shrSplitAt m sh) +shrSplitAt SS{} ZSR = error "m' + 1 <= 0" +shrIndex :: forall k sh i. SNat k -> ShR sh i -> i +shrIndex k (ShR sh) = case shxIndex @i k sh of + SUnknown i -> i + SKnown{} -> error "shrIndex: impossible SKnown" + +-- Copy-pasted from ixrPermutePrefix, probably unavoidably. +shrPermutePrefix :: forall i n. PermR -> ShR n i -> ShR n i +shrPermutePrefix = \perm sh -> + TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat -> + case shrRank sh of { shlen@SNat -> + let sperm = shrFromList permlen perm in + case cmpNat permlen shlen of + LTI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post + EQI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post + GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" + ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" + } + where + applyPermRFull :: SNat m -> ShR k Int -> ShR m i -> ShR k i + applyPermRFull _ ZSR _ = ZSR + applyPermRFull sm@SNat (i :$: perm) l = + TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> + case cmpNat (SNat @(idx + 1)) sm of + LTI -> shrIndex si l :$: applyPermRFull sm perm l + EQI -> shrIndex si l :$: applyPermRFull sm perm l + GTI -> error "shrPermutePrefix: Index in permutation out of range" --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ListR n i) where - type Item (ListR n i) = i - fromList topl = go (SNat @n) topl - where - go :: SNat n' -> [i] -> ListR n' i - go SZ [] = ZR - go (SS n) (i : is) = i ::: go n is - go _ _ = error $ "IsList(ListR): Mismatched list length (type says " - ++ show (fromSNat (SNat @n)) ++ ", list has length " - ++ show (length topl) ++ ")" - toList = Foldable.toList -- | Untyped: length is checked at runtime. instance KnownNat n => IsList (IxR n i) where type Item (IxR n i) = i - fromList = IxR . IsList.fromList + fromList = ixrFromList (SNat @n) toList = Foldable.toList -- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ShR n i) where - type Item (ShR n i) = i - fromList = ShR . IsList.fromList - toList = Foldable.toList +instance KnownNat n => IsList (IShR n) where + type Item (IShR n) = Int + fromList = shrFromList (SNat @n) + toList = shrToList diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 7e38aee..be1bfc5 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -14,7 +14,7 @@ module Data.Array.Nested.Shaped ( liftShaped1, liftShaped2, ) where -import Prelude hiding (mappend, mconcat) +import Prelude hiding (mappend) import Data.Array.Internal.RankedG qualified as RG import Data.Array.Internal.RankedS qualified as RS @@ -29,21 +29,20 @@ import Data.Vector.Storable qualified as VS import Foreign.Storable (Storable) import GHC.TypeLits -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Permutation -import Data.Array.Nested.Types -import Data.Array.XArray (XArray) -import Data.Array.XArray qualified as X -import Data.Array.Nested.Internal.Lemmas import Data.Array.Nested.Convert +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation import Data.Array.Nested.Shaped.Base import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith +import Data.Array.XArray (XArray) +import Data.Array.XArray qualified as X -semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a +semptyArray :: forall sh a. KnownElt a => ShS sh -> Shaped (0 : sh) a semptyArray sh = Shaped (memptyArray (shxFromShS sh)) srank :: Elt a => Shaped sh a -> SNat (Rank sh) @@ -53,25 +52,32 @@ srank = shsRank . sshape ssize :: Elt a => Shaped sh a -> Int ssize = shsSize . sshape +{-# INLINEABLE sindex #-} sindex :: Elt a => Shaped sh a -> IIxS sh -> a sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx) -shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh -shsTakeIx _ _ ZIS = ZSS -shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx - +{-# INLINEABLE sindexPartial #-} sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a sindexPartial sarr@(Shaped arr) idx = Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) - (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) + (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) idx (sshape sarr)) (Proxy @sh2))) arr) (ixxFromIxS idx)) -- | __WARNING__: All values returned from the function must have equal shape. -- See the documentation of 'mgenerate' for more details. sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a -sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh)) +sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX)) + +-- | See 'mgeneratePrim'. +{-# INLINE sgeneratePrim #-} +sgeneratePrim :: forall sh a i. (PrimElt a, Num i) + => ShS sh -> (IxS sh i -> a) -> Shaped sh a +sgeneratePrim sh f = + let g i = f (ixsFromLinear sh i) + in sfromVector sh $ VS.generate (shsSize sh) g -- | See the documentation of 'mlift'. +{-# INLINE slift #-} slift :: forall sh1 sh2 a. Elt a => ShS sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) @@ -79,20 +85,28 @@ slift :: forall sh1 sh2 a. Elt a slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShX (shxFromShS sh2)) f arr) -- | See the documentation of 'mlift'. +{-# INLINE slift2 #-} slift2 :: forall sh1 sh2 sh3 a. Elt a => ShS sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2) -ssumOuter1P :: forall sh n a. (Storable a, NumElt a) - => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) -ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) +{-# INLINE ssumOuter1PrimP #-} +ssumOuter1PrimP :: forall sh n a. (Storable a, NumElt a) + => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) +ssumOuter1PrimP (Shaped arr) = Shaped (msumOuter1PrimP arr) + +{-# INLINEABLE ssumOuter1Prim #-} +ssumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a) + => Shaped (n : sh) a -> Shaped sh a +ssumOuter1Prim = sfromPrimitive . ssumOuter1PrimP . stoPrimitive -ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) - => Shaped (n : sh) a -> Shaped sh a -ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive +{-# INLINE ssumAllPrimP #-} +ssumAllPrimP :: (PrimElt a, NumElt a) => Shaped n (Primitive a) -> a +ssumAllPrimP (Shaped arr) = msumAllPrimP arr +{-# INLINE ssumAllPrim #-} ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a ssumAllPrim (Shaped arr) = msumAllPrim arr @@ -102,8 +116,8 @@ stranspose perm sarr@(Shaped arr) | Refl <- lemRankMapJust (sshape sarr) , Refl <- lemTakeLenMapJust perm (sshape sarr) , Refl <- lemDropLenMapJust perm (sshape sarr) - , Refl <- lemPermuteMapJust perm (shsTakeLen perm (sshape sarr)) - , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh)) + , Refl <- lemPermuteMapJust perm (shsTakeLenPerm perm (sshape sarr)) + , Refl <- lemMapJustApp (shsPermute perm (shsTakeLenPerm perm (sshape sarr))) (Proxy @(DropLen is sh)) = Shaped (mtranspose perm arr) sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a @@ -112,51 +126,59 @@ sappend = coerce mappend sscalar :: Elt a => a -> Shaped '[] a sscalar x = Shaped (mscalar x) +{-# INLINEABLE sfromVectorP #-} sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) sfromVectorP sh v = Shaped (mfromVectorP (shxFromShS sh) v) +{-# INLINEABLE sfromVector #-} sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a sfromVector sh v = sfromPrimitive (sfromVectorP sh v) +{-# INLINEABLE stoVectorP #-} stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a stoVectorP = coerce mtoVectorP +{-# INLINEABLE stoVector #-} stoVector :: PrimElt a => Shaped sh a -> VS.Vector a stoVector = coerce mtoVector +-- | All arrays in the list, even subarrays inside @a@, must have the same +-- shape; if they do not, a runtime error will be thrown. See the +-- documentation of 'mgenerate' for more information about this restriction. +-- +-- If your array is 1-dimensional and contains scalars, use 'sfromList1Prim'. sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l)) +sfromListOuter = coerce mfromListOuterSN +-- | If the elements are scalars, 'sfromList1Prim' is faster. sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a -sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1 +sfromList1 = coerce mfromList1SN -sfromList1Prim :: PrimElt a => SNat n -> [a] -> Shaped '[n] a -sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim - -stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] -stoListOuter (Shaped arr) = coerce (mtoListOuter arr) +-- | If the elements are scalars, 'sfromListPrimLinear' is faster. +sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a +sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l) -stoList1 :: Elt a => Shaped '[n] a -> [a] -stoList1 = map sunScalar . stoListOuter +sfromList1Prim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a +sfromList1Prim = coerce mfromList1PrimSN -sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a -sfromListPrim sn l - | Refl <- lemAppNil @'[Just n] - = let ssh = SUnknown () :!% ZKX - xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) - in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr +sfromListPrimLinear :: forall sh a. PrimElt a => ShS sh -> [a] -> Shaped sh a +sfromListPrimLinear sh l = Shaped (mfromListPrimLinear (shxFromShS sh) l) -sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a -sfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in Shaped $ fromPrimitive $ M_Primitive (shxFromShS sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShS sh) xarr) +stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] +stoListOuter (Shaped arr) = coerce (mtoListOuter arr) -sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a -sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l) +stoList :: Elt a => Shaped '[n] a -> [a] +stoList = map sunScalar . stoListOuter stoListLinear :: Elt a => Shaped sh a -> [a] stoListLinear (Shaped arr) = mtoListLinear arr +stoListPrim :: PrimElt a => Shaped '[n] a -> [a] +stoListPrim (Shaped arr) = mtoListPrim arr + +stoListPrimLinear :: PrimElt a => Shaped sh a -> [a] +stoListPrimLinear (Shaped arr) = mtoListPrimLinear arr + sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a sfromOrthotope sh (SS.A (SG.A arr)) = Shaped (fromPrimitive (M_Primitive (shxFromShS sh) (X.XArray (RS.A (RG.A (shsToList sh) arr))))) @@ -177,44 +199,44 @@ sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr))) | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh') = Shaped arr -szip :: Shaped sh a -> Shaped sh b -> Shaped sh (a, b) +szip :: (Elt a, Elt b) => Shaped sh a -> Shaped sh b -> Shaped sh (a, b) szip = coerce mzip sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b) sunzip = coerce munzip -srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) - => ShS sh -> ShS sh2 - -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) - -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b) -srerankP sh sh2 f sarr@(Shaped arr) - | Refl <- lemMapJustApp sh (Proxy @sh1) - , Refl <- lemMapJustApp sh (Proxy @sh2) - = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (shxFromShS (sshape sarr)) (ssxFromShX (shxFromShS sh)))) - (shxFromShS sh2) - (\a -> let Shaped r = f (Shaped a) in r) - arr) +srerankPrimP :: forall sh1 sh2 sh a b. (Storable a, Storable b) + => ShS sh2 + -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) + -> Shaped sh (Shaped sh1 (Primitive a)) -> Shaped sh (Shaped sh2 (Primitive b)) +srerankPrimP sh2 f (Shaped (M_Shaped arr)) + = Shaped (M_Shaped (mrerankPrimP (shxFromShS sh2) + (\a -> let Shaped r = f (Shaped a) in r) + arr)) -srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) - => ShS sh -> ShS sh2 - -> (Shaped sh1 a -> Shaped sh2 b) - -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b -srerank sh sh2 f (stoPrimitive -> arr) = - sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr +-- | See the caveats at 'mrerankPrim'. +srerankPrim :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) + => ShS sh2 + -> (Shaped sh1 a -> Shaped sh2 b) + -> Shaped sh (Shaped sh1 a) -> Shaped sh (Shaped sh2 b) +srerankPrim sh2 f (Shaped (M_Shaped arr)) = + Shaped (M_Shaped (mrerankPrim (shxFromShS sh2) + (\a -> let Shaped r = f (Shaped a) in r) + arr)) sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a sreplicate sh (Shaped arr) | Refl <- lemMapJustApp sh (Proxy @sh') = Shaped (mreplicate (shxFromShS sh) arr) -sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) -sreplicateScalP sh x = Shaped (mreplicateScalP (shxFromShS sh) x) +sreplicatePrimP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) +sreplicatePrimP sh x = Shaped (mreplicatePrimP (shxFromShS sh) x) -sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a -sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) +sreplicatePrim :: forall sh a. PrimElt a => ShS sh -> a -> Shaped sh a +sreplicatePrim sh x = sfromPrimitive (sreplicatePrimP sh x) sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a -sslice i n@SNat arr = +sslice i n arr = let _ :$$ sh = sshape arr in slift (n :$$ sh) (\_ -> X.slice i n) arr @@ -225,21 +247,20 @@ sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shape sreshape sh' (Shaped arr) = Shaped (mreshape (shxFromShS sh') arr) sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a -sflatten arr = - case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff - n@SNat -> sreshape (n :$$ ZSS) arr +sflatten arr = sreshape (shsProduct (sshape arr) :$$ ZSS) arr siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a siota sn = Shaped (miota sn) -- | Throws if the array is empty. sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -sminIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mminIndexPrim arr) +sminIndexPrim (Shaped arr) = ixsFromIxX (mminIndexPrim arr) -- | Throws if the array is empty. smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -smaxIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) +smaxIndexPrim (Shaped arr) = ixsFromIxX (mmaxIndexPrim arr) +{-# INLINEABLE sdot1Inner #-} sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) @@ -251,6 +272,7 @@ sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2) _ -> error "unreachable" +{-# INLINE sdot #-} -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. -- Prefer 'sdot1Inner' if applicable. sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index 529ac21..a5e6247 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -5,6 +5,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -30,13 +31,13 @@ import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) import GHC.TypeLits -import Data.Array.Nested.Types -import Data.Array.XArray (XArray) -import Data.Array.Nested.Internal.Lemmas +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith +import Data.Array.XArray (XArray) -- | A shape-typed array: the full shape of the array (the sizes of its @@ -44,7 +45,8 @@ import Data.Array.Strided.Arith -- these are "GHC.TypeLits" naturals, because we do not need induction over -- them and we want very large arrays to be possible. -- --- Like for 'Ranked', the valid elements are described by the 'Elt' type class, +-- Like for 'Data.Array.Nested.Ranked.Base.Ranked', +-- the valid elements are described by the 'Elt' type class, -- and 'Shaped' itself is again an instance of 'Elt' as well. -- -- 'Shaped' is a newtype around a 'Mixed' of 'Just's. @@ -78,9 +80,12 @@ deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) instance Elt a => Elt (Shaped sh a) where + {-# INLINE mshape #-} mshape (M_Shaped arr) = mshape arr + {-# INLINE mindex #-} mindex (M_Shaped arr) i = Shaped (mindex arr i) + {-# INLINE mindexPartial #-} mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) mindexPartial (M_Shaped arr) i = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ @@ -88,13 +93,14 @@ instance Elt a => Elt (Shaped sh a) where mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) - mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) - mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) + mfromListOuterSN :: SNat n -> NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Just n : sh') (Shaped sh a) + mfromListOuterSN sn l = M_Shaped (mfromListOuterSN sn (coerce l)) mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] mtoListOuter (M_Shaped arr) = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) @@ -103,6 +109,7 @@ instance Elt a => Elt (Shaped sh a) where coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ mlift ssh2 f arr + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) @@ -111,6 +118,7 @@ instance Elt a => Elt (Shaped sh a) where coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ mlift2 ssh3 f arr1 arr2 + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) @@ -130,28 +138,29 @@ instance Elt a => Elt (Shaped sh a) where type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) - mshapeTree (Shaped arr) = first shsFromShX (mshapeTree arr) + mshapeTree (Shaped arr) = first coerce (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + mshapeTreeIsEmpty _ (sh, t) = shsSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" marrayStrides (M_Shaped arr) = marrayStrides arr - mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () - mvecsWrite sh idx (Shaped arr) vecs = - mvecsWrite sh idx arr + mvecsWriteLinear :: forall sh' s. Int -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWriteLinear idx (Shaped arr) vecs = + mvecsWriteLinear idx arr (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) vecs) - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx + mvecsWritePartialLinear + :: forall sh1 sh2 s. + Proxy sh1 -> Int -> Mixed sh2 (Shaped sh a) + -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) + -> ST s () + mvecsWritePartialLinear proxy idx arr vecs = + mvecsWritePartialLinear proxy idx (coerce @(Mixed sh2 (Shaped sh a)) @(Mixed sh2 (Mixed (MapJust sh) a)) arr) @@ -167,18 +176,30 @@ instance Elt a => Elt (Shaped sh a) where (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) vecs) + mvecsUnsafeFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsUnsafeFreeze sh vecs = + coerce @(Mixed sh' (Mixed (MapJust sh) a)) + @(Mixed sh' (Shaped sh a)) + <$> mvecsUnsafeFreeze sh + (coerce @(MixedVecs s sh' (Shaped sh a)) + @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) - memptyArrayUnsafe i + memptyArrayUnsafe sh | Dict <- lemKnownMapJust (Proxy @sh) = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ - memptyArrayUnsafe i + memptyArrayUnsafe sh mvecsUnsafeNew idx (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = MV_Shaped <$> mvecsUnsafeNew idx arr + mvecsReplicate idx (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsReplicate idx arr + mvecsNewEmpty _ | Dict <- lemKnownMapJust (Proxy @sh) = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) @@ -201,15 +222,15 @@ instance (NumElt a, PrimElt a) => Num (Shaped sh a) where negate = liftShaped1 negate abs = liftShaped1 abs signum = liftShaped1 signum - fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal" + fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicatePrim" instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where - fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" + fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicatePrim" recip = liftShaped1 recip (/) = liftShaped2 (/) instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" + pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicatePrim" exp = liftShaped1 exp log = liftShaped1 log sqrt = liftShaped1 sqrt @@ -240,5 +261,6 @@ satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped s satan2Array = liftShaped2 matan2Array +{-# INLINE sshape #-} sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh -sshape (Shaped arr) = shsFromShX (mshape arr) +sshape (Shaped arr) = coerce (mshape arr) diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 0b7d1c9..60e0252 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -1,9 +1,5 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} @@ -11,7 +7,6 @@ {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RoleAnnotations #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -21,6 +16,7 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} @@ -28,242 +24,173 @@ module Data.Array.Nested.Shaped.Shape where import Control.DeepSeq (NFData(..)) +import Control.Exception (assert) import Data.Array.Shape qualified as O import Data.Coerce (coerce) import Data.Foldable qualified as Foldable -import Data.Functor.Const -import Data.Functor.Product qualified as Fun import Data.Kind (Constraint, Type) -import Data.Monoid (Sum(..)) import Data.Proxy import Data.Type.Equality -import GHC.Exts (withDict) -import GHC.Generics (Generic) +import GHC.Exts (build, withDict) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits +import Data.Array.Nested.Mixed.ListX import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Permutation import Data.Array.Nested.Types -type role ListS nominal representational -type ListS :: [Nat] -> (Nat -> Type) -> Type -data ListS sh f where - ZS :: ListS '[] f - -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity - (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) -infixr 3 ::$ - -#ifdef OXAR_DEFAULT_SHOW_INSTANCES -deriving instance (forall n. Show (f n)) => Show (ListS sh f) -#else -instance (forall n. Show (f n)) => Show (ListS sh f) where - showsPrec _ = listsShow shows -#endif - -instance (forall m. NFData (f m)) => NFData (ListS n f) where - rnf ZS = () - rnf (x ::$ l) = rnf x `seq` rnf l - -data UnconsListSRes f sh1 = - forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) -listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) -listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x) -listsUncons ZS = Nothing - --- | This checks only whether the types are equal; if the elements of the list --- are not singletons, their values may still differ. This corresponds to --- 'testEquality', except on the penultimate type parameter. -listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') -listsEqType ZS ZS = Just Refl -listsEqType (n ::$ sh) (m ::$ sh') - | Just Refl <- testEquality n m - , Just Refl <- listsEqType sh sh' - = Just Refl -listsEqType _ _ = Nothing - --- | This checks whether the two lists actually contain equal values. This is --- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ --- in the @some@ package (except on the penultimate type parameter). -listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') -listsEqual ZS ZS = Just Refl -listsEqual (n ::$ sh) (m ::$ sh') - | Just Refl <- testEquality n m - , n == m - , Just Refl <- listsEqual sh sh' - = Just Refl -listsEqual _ _ = Nothing - -listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g -listsFmap _ ZS = ZS -listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs - -listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m -listsFold _ ZS = mempty -listsFold f (x ::$ xs) = f x <> listsFold f xs - -listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS -listsShow f l = showString "[" . go "" l . showString "]" - where - go :: String -> ListS sh' f -> ShowS - go _ ZS = id - go prefix (x ::$ xs) = showString prefix . f x . go "," xs - -listsLength :: ListS sh f -> Int -listsLength = getSum . listsFold (\_ -> Sum 1) - -listsRank :: ListS sh f -> SNat (Rank sh) -listsRank ZS = SNat -listsRank (_ ::$ sh) = snatSucc (listsRank sh) - -listsToList :: ListS sh (Const i) -> [i] -listsToList ZS = [] -listsToList (Const i ::$ is) = i : listsToList is - -listsHead :: ListS (n : sh) f -> f n -listsHead (i ::$ _) = i - -listsTail :: ListS (n : sh) f -> ListS sh f -listsTail (_ ::$ sh) = sh - -listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f -listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh -listsInit (_ ::$ ZS) = ZS - -listsLast :: ListS (n : sh) f -> f (Last (n : sh)) -listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh -listsLast (n ::$ ZS) = n - -listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f -listsAppend ZS idx' = idx' -listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' - -listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g) -listsZip ZS ZS = ZS -listsZip (i ::$ is) (j ::$ js) = - Fun.Pair i j ::$ listsZip is js - -listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g - -> ListS sh h -listsZipWith _ ZS ZS = ZS -listsZipWith f (i ::$ is) (j ::$ js) = - f i j ::$ listsZipWith f is js - -listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f -listsTakeLenPerm PNil _ = ZS -listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh -listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" - -listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f -listsDropLenPerm PNil sh = sh -listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh -listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" - -listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f -listsPermute PNil _ = ZS -listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = - case listsIndex (Proxy @is') (Proxy @sh) i sh of - (item, SNat) -> item ::$ listsPermute is sh - --- TODO: remove this SNat when the KnownNat constaint in ListS is removed -listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh)) -listsIndex _ _ SZ (n ::$ _) = (n, SNat) -listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listsIndex p pT i sh -listsIndex _ _ _ ZS = error "Index into empty shape" - -listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f -listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) - +-- * Shaped indices -- | An index into a shape-typed array. --- --- For convenience, this contains regular 'Int's instead of bounded integers --- (traditionally called \"@Fin@\"). type role IxS nominal representational type IxS :: [Nat] -> Type -> Type -newtype IxS sh i = IxS (ListS sh (Const i)) - deriving (Eq, Ord, Generic) +newtype IxS sh i = IxS (IxX (MapJust sh) i) + deriving (Eq, Ord, NFData, Functor, Foldable) pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i -pattern ZIS = IxS ZS +pattern ZIS <- IxS (matchZIX -> Just Refl) + where ZIS = IxS ZIX + +matchZIX :: forall sh i. IxX (MapJust sh) i -> Maybe (sh :~: '[]) +matchZIX ZIX | Refl <- lemMapJustEmpty @sh Refl = Just Refl +matchZIX _ = Nothing pattern (:.$) :: forall {sh1} {i}. - forall n sh. (KnownNat n, n : sh ~ sh1) + forall n sh. (n : sh ~ sh1) => i -> IxS sh i -> IxS sh1 i -pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i))) - where i :.$ IxS shl = IxS (Const i ::$ shl) +pattern i :.$ l <- (ixsUncons -> Just (UnconsIxSRes i l)) + where i :.$ IxS l = IxS (i :.% l) infixr 3 :.$ +data UnconsIxSRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsIxSRes i (IxS sh i) +ixsUncons :: forall sh1 i. IxS sh1 i -> Maybe (UnconsIxSRes i sh1) +ixsUncons (IxS (i :.% l)) | Refl <- lemMapJustHead (Proxy @sh1) + , Refl <- lemMapJustCons @sh1 Refl = + Just (UnconsIxSRes i (IxS l)) +ixsUncons (IxS _) = Nothing + {-# COMPLETE ZIS, (:.$) #-} +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). type IIxS sh = IxS sh Int #ifdef OXAR_DEFAULT_SHOW_INSTANCES deriving instance Show i => Show (IxS sh i) #else instance Show i => Show (IxS sh i) where - showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l + showsPrec _ l = ixsShow shows l #endif -instance Functor (IxS sh) where - fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l) - -instance Foldable (IxS sh) where - foldMap f (IxS l) = listsFold (f . getConst) l +ixsShow :: forall sh i. (i -> ShowS) -> IxS sh i -> ShowS +ixsShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> IxS sh' i -> ShowS + go _ ZIS = id + go prefix (x :.$ xs) = showString prefix . f x . go "," xs -instance NFData i => NFData (IxS sh i) +ixsRank :: IxS sh i -> SNat (Rank sh) +ixsRank ZIS = SNat +ixsRank (_ :.$ sh) = snatSucc (ixsRank sh) -ixsLength :: IxS sh i -> Int -ixsLength (IxS l) = listsLength l +{-# INLINE ixsFromList #-} +ixsFromList :: ShS sh -> [i] -> IxS sh i +ixsFromList sh l = assert (shsLength sh == length l) + $ IxS $ IsList.fromList l -ixsRank :: IxS sh i -> SNat (Rank sh) -ixsRank (IxS l) = listsRank l +{-# INLINE ixsFromIxS #-} +ixsFromIxS :: IxS sh i0 -> [i] -> IxS sh i +ixsFromIxS sh l = assert (length sh == length l) + $ IxS $ IsList.fromList l ixsZero :: ShS sh -> IIxS sh ixsZero ZSS = ZIS ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh -ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i -ixsFromIxX ZSS ZIX = ZIS -ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx - -ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i -ixxFromIxS ZIS = ZIX -ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh - ixsHead :: IxS (n : sh) i -> i -ixsHead (IxS list) = getConst (listsHead list) +ixsHead (i :.$ _) = i ixsTail :: IxS (n : sh) i -> IxS sh i -ixsTail (IxS list) = IxS (listsTail list) +ixsTail (_ :.$ sh) = sh ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i -ixsInit (IxS list) = IxS (listsInit list) +ixsInit (n :.$ sh@(_ :.$ _)) = n :.$ ixsInit sh +ixsInit (_ :.$ ZIS) = ZIS ixsLast :: IxS (n : sh) i -> i -ixsLast (IxS list) = getConst (listsLast list) +ixsLast (_ :.$ sh@(_ :.$ _)) = ixsLast sh +ixsLast (n :.$ ZIS) = n + +ixsCast :: IxS sh i -> IxS sh i +ixsCast ZIS = ZIS +ixsCast (i :.$ idx) = i :.$ ixsCast idx ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i -ixsAppend = coerce (listsAppend @_ @(Const i)) +ixsAppend = gcastWith (unsafeCoerceRefl :: MapJust (sh ++ sh') :~: MapJust sh ++ MapJust sh') $ + coerce (ixxAppend @_ @_ @i) -ixsZip :: IxS n i -> IxS n j -> IxS n (i, j) +ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j) ixsZip ZIS ZIS = ZIS ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js -ixsZipWith :: (i -> j -> k) -> IxS n i -> IxS n j -> IxS n k +{-# INLINE ixsZipWith #-} +ixsZipWith :: (i -> j -> k) -> IxS sh i -> IxS sh j -> IxS sh k ixsZipWith _ ZIS ZIS = ZIS ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js +ixsTakeLenPerm :: forall i is sh. Perm is -> IxS sh i -> IxS (TakeLen is sh) i +ixsTakeLenPerm PNil _ = ZIS +ixsTakeLenPerm (_ `PCons` is) (n :.$ sh) = n :.$ ixsTakeLenPerm is sh +ixsTakeLenPerm (_ `PCons` _) ZIS = error "Permutation longer than shape" + +ixsDropLenPerm :: forall i is sh. Perm is -> IxS sh i -> IxS (DropLen is sh) i +ixsDropLenPerm PNil sh = sh +ixsDropLenPerm (_ `PCons` is) (_ :.$ sh) = ixsDropLenPerm is sh +ixsDropLenPerm (_ `PCons` _) ZIS = error "Permutation longer than shape" + +ixsPermute :: forall i is sh. Perm is -> IxS sh i -> IxS (Permute is sh) i +ixsPermute PNil _ = ZIS +ixsPermute (i `PCons` (is :: Perm is')) (sh :: IxS sh f) = + case ixsIndex i sh of + item -> item :.$ ixsPermute is sh + +ixsIndex :: forall j i sh. SNat i -> IxS sh j -> j +ixsIndex SZ (n :.$ _) = n +ixsIndex (SS i) (_ :.$ sh) = ixsIndex i sh +ixsIndex _ ZIS = error "Index into empty shape" + ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i -ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) +ixsPermutePrefix perm sh = ixsAppend (ixsPermute perm (ixsTakeLenPerm perm sh)) (ixsDropLenPerm perm sh) + +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixsToLinear #-} +ixsToLinear :: Num i => ShS sh -> IxS sh i -> i +ixsToLinear (ShS sh) ix = ixxToLinear sh (ixxFromIxS ix) + +ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i +ixxFromIxS = coerce + +{-# INLINEABLE ixsFromLinear #-} +ixsFromLinear :: Num i => ShS sh -> Int -> IxS sh i +ixsFromLinear (ShS sh) i = ixsFromIxX $ ixxFromLinear sh i +ixsFromIxX :: IxX (MapJust sh) i -> IxS sh i +ixsFromIxX = coerce + +shsEnum :: ShS sh -> [IIxS sh] +shsEnum = shsEnum' + +{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site +shsEnum' :: Num i => ShS sh -> [IxS sh i] +shsEnum' (ShS sh) = (coerce :: [IxX (MapJust sh) i] -> [IxS sh i]) $ shxEnum' sh + +-- * Shaped shapes -- | The shape of a shape-typed array given as a list of 'SNat' values. -- @@ -271,36 +198,48 @@ ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) -- can also retrieve the array shape from a 'KnownShS' dictionary. type role ShS nominal type ShS :: [Nat] -> Type -newtype ShS sh = ShS (ListS sh SNat) - deriving (Eq, Ord, Generic) +newtype ShS sh = ShS (ShX (MapJust sh) Int) + deriving (NFData) + +instance Eq (ShS sh) where _ == _ = True +instance Ord (ShS sh) where compare _ _ = EQ pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh -pattern ZSS = ShS ZS +pattern ZSS <- ShS (matchZSX -> Just Refl) + where ZSS = ShS ZSX + +matchZSX :: forall sh i. ShX (MapJust sh) i -> Maybe (sh :~: '[]) +matchZSX ZSX | Refl <- lemMapJustEmpty @sh Refl = Just Refl +matchZSX _ = Nothing pattern (:$$) :: forall {sh1}. - forall n sh. (KnownNat n, n : sh ~ sh1) + forall n sh. (n : sh ~ sh1) => SNat n -> ShS sh -> ShS sh1 -pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i)) - where i :$$ ShS shl = ShS (i ::$ shl) - +pattern i :$$ sh <- (shsUncons -> Just (UnconsShSRes i sh)) + where i :$$ ShS sh = ShS (SKnown i :$% sh) infixr 3 :$$ +data UnconsShSRes sh1 = + forall n sh. (n : sh ~ sh1) => UnconsShSRes (SNat n) (ShS sh) +shsUncons :: forall sh1. ShS sh1 -> Maybe (UnconsShSRes sh1) +shsUncons (ShS (SKnown x :$% sh')) | Refl <- lemMapJustCons @sh1 Refl + = Just (UnconsShSRes x (ShS sh')) +shsUncons (ShS _) = Nothing + {-# COMPLETE ZSS, (:$$) #-} #ifdef OXAR_DEFAULT_SHOW_INSTANCES deriving instance Show (ShS sh) #else instance Show (ShS sh) where - showsPrec _ (ShS l) = listsShow (shows . fromSNat) l + showsPrec d (ShS shx) = showsPrec d shx #endif -instance NFData (ShS sh) where - rnf (ShS ZS) = () - rnf (ShS (SNat ::$ l)) = rnf (ShS l) - instance TestEquality ShS where - testEquality (ShS l1) (ShS l2) = listsEqType l1 l2 + testEquality (ShS shx1) (ShS shx2) = case shxEqType shx1 shx2 of + Nothing -> Nothing + Just Refl -> Just unsafeCoerceRefl -- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are -- equal if and only if values are equal.) @@ -308,62 +247,117 @@ shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh') shsEqual = testEquality shsLength :: ShS sh -> Int -shsLength (ShS l) = listsLength l +shsLength (ShS shx) = shxLength shx -shsRank :: ShS sh -> SNat (Rank sh) -shsRank (ShS l) = listsRank l +shsRank :: forall sh. ShS sh -> SNat (Rank sh) +shsRank (ShS shx) | Refl <- lemRankMapJust (Proxy @sh) = + shxRank shx -shsSize :: ShS sh -> Int -shsSize ZSS = 1 -shsSize (n :$$ sh) = fromSNat' n * shsSize sh +lemRankMapJust :: proxy sh -> Rank (MapJust sh) :~: Rank sh +lemRankMapJust _ = unsafeCoerceRefl -shsToList :: ShS sh -> [Int] -shsToList ZSS = [] -shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh +shsSize :: ShS sh -> Int +shsSize (ShS sh) = shxSize sh -shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh -shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS -shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) = - castWith (subst1 (lem Refl)) $ - n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) - idx) +-- | This is a partial @const@ that fails when the second argument +-- doesn't match the first. We don't report the size of the list +-- in case of errors in order not to retain the list. +{-# INLINEABLE shsFromList #-} +shsFromList :: ShS sh -> [Int] -> ShS sh +shsFromList sh0@(ShS topsh) topl = go topsh topl `seq` sh0 where - lem :: forall sh1 sh' n. - Just n : sh1 :~: MapJust sh' - -> n : Tail sh' :~: sh' - lem Refl = unsafeCoerceRefl -shsFromShX (SUnknown _ :$% _) = error "impossible" + go :: ShX sh' Int -> [Int] -> () + go ZSX [] = () + go ZSX _ = error $ "shsFromList: List too long (type says " ++ show (shxLength topsh) ++ ")" + go (ConsKnown sn sh) (i : is) + | i == fromSNat' sn = go sh is + | otherwise = error "shsFromList: Value does not match typing" + go ConsUnknown{} _ = error "shsFromList: impossible case" + go _ _ = error $ "shsFromList: List too short (type says " ++ show (shxLength topsh) ++ ")" -shxFromShS :: ShS sh -> IShX (MapJust sh) -shxFromShS ZSS = ZSX -shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh +-- This is equivalent to but faster than @coerce shxToList@. +{-# INLINEABLE shsToList #-} +shsToList :: ShS sh -> [Int] +shsToList (ShS l) = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ShX sh Int -> is + go ZSX = nil + go ConsUnknown{} = error "shsToList: impossible case" + go (ConsKnown snat rest) = fromSNat' snat `cons` go rest + in go l) shsHead :: ShS (n : sh) -> SNat n -shsHead (ShS list) = listsHead list +shsHead (ShS shx) = case shxHead shx of + SKnown SNat -> SNat -shsTail :: ShS (n : sh) -> ShS sh -shsTail (ShS list) = ShS (listsTail list) +shsTail :: forall n sh. ShS (n : sh) -> ShS sh +shsTail = coerce (shxTail @_ @_ @Int) -shsInit :: ShS (n : sh) -> ShS (Init (n : sh)) -shsInit (ShS list) = ShS (listsInit list) +{-# INLINEABLE shsTakeIx #-} +shsTakeIx :: forall sh sh' j. Proxy sh' -> IxS sh j -> ShS (sh ++ sh') -> ShS sh +shsTakeIx _ ZIS _ = ZSS +shsTakeIx p (_ :.$ idx) sh = case sh of n :$$ sh' -> n :$$ shsTakeIx p idx sh' -shsLast :: ShS (n : sh) -> SNat (Last (n : sh)) -shsLast (ShS list) = listsLast list +{-# INLINEABLE shsDropIx #-} +shsDropIx :: forall sh sh' j. IxS sh j -> ShS (sh ++ sh') -> ShS sh' +shsDropIx ZIS long = long +shsDropIx (_ :.$ short) long = case long of _ :$$ long' -> shsDropIx short long' + +shsInit :: forall n sh. ShS (n : sh) -> ShS (Init (n : sh)) +shsInit = + gcastWith (unsafeCoerceRefl + :: Init (Just n : MapJust sh) :~: MapJust (Init (n : sh))) $ + coerce (shxInit @Int) + +shsLast :: forall n sh. ShS (n : sh) -> SNat (Last (n : sh)) +shsLast (ShS shx) = + gcastWith (unsafeCoerceRefl + :: Last (Just n : MapJust sh) :~: Just (Last (n : sh))) $ + case shxLast shx of + SKnown SNat -> SNat shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh') -shsAppend = coerce (listsAppend @_ @SNat) +shsAppend = + gcastWith (unsafeCoerceRefl + :: MapJust sh ++ MapJust sh' :~: MapJust (sh ++ sh')) $ + coerce (shxAppend @_ @Int) -shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh) -shsTakeLen = coerce (listsTakeLenPerm @SNat) +shsTakeLenPerm :: forall is sh. Perm is -> ShS sh -> ShS (TakeLen is sh) +shsTakeLenPerm = + gcastWith (unsafeCoerceRefl + :: TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)) $ + coerce (shxTakeLenPerm @Int) -shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) -shsPermute = coerce (listsPermute @SNat) +shsDropLenPerm :: forall is sh. Perm is -> ShS sh -> ShS (DropLen is sh) +shsDropLenPerm = + gcastWith (unsafeCoerceRefl + :: DropLen is (MapJust sh) :~: MapJust (DropLen is sh)) $ + coerce (shxDropLenPerm @Int) -shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh) -shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh))) +shsPermute :: forall is sh. Perm is -> ShS sh -> ShS (Permute is sh) +shsPermute = + gcastWith (unsafeCoerceRefl + :: Permute is (MapJust sh) :~: MapJust (Permute is sh)) $ + coerce (shxPermute @Int) + +shsIndex :: forall i sh. SNat i -> ShS sh -> SNat (Index i sh) +shsIndex i (ShS sh) = + gcastWith (unsafeCoerceRefl + :: Index i (MapJust sh) :~: Just (Index i sh)) $ + case shxIndex @Int i sh of + SKnown SNat -> SNat shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) -shsPermutePrefix = coerce (listsPermutePrefix @SNat) +shsPermutePrefix perm (ShS shx) + {- TODO: here and elsewhere, solve the module dependency cycle and add this: + | Refl <- lemTakeLenMapJust perm sh + , Refl <- lemDropLenMapJust perm sh + , Refl <- lemPermuteMapJust perm sh + , Refl <- lemMapJustApp (shsPermute perm (shsTakeLenPerm perm sh)) (shsDropLenPerm perm sh) -} + = gcastWith (unsafeCoerceRefl + :: Permute is (TakeLen is (MapJust sh)) + ++ DropLen is (MapJust sh) + :~: MapJust (Permute is (TakeLen is sh) ++ DropLen is sh)) $ + ShS (shxPermutePrefix perm shx) type family Product sh where Product '[] = 1 @@ -381,7 +375,7 @@ instance KnownShS '[] where knownShS = ZSS instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r -withKnownShS k = withDict @(KnownShS sh) k +withKnownShS = withDict @(KnownShS sh) shsKnownShS :: ShS sh -> Dict KnownShS sh shsKnownShS ZSS = Dict @@ -392,37 +386,14 @@ shsOrthotopeShape ZSS = Dict shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict --- | Untyped: length is checked at runtime. -instance KnownShS sh => IsList (ListS sh (Const i)) where - type Item (ListS sh (Const i)) = i - fromList topl = go (knownShS @sh) topl - where - go :: ShS sh' -> [i] -> ListS sh' (Const i) - go ZSS [] = ZS - go (_ :$$ sh) (i : is) = Const i ::$ go sh is - go _ _ = error $ "IsList(ListS): Mismatched list length (type says " - ++ show (shsLength (knownShS @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" - toList = listsToList - -- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. instance KnownShS sh => IsList (IxS sh i) where type Item (IxS sh i) = i - fromList = IxS . IsList.fromList + fromList = ixsFromList (knownShS @sh) toList = Foldable.toList -- | Untyped: length and values are checked at runtime. instance KnownShS sh => IsList (ShS sh) where type Item (ShS sh) = Int - fromList topl = ShS (go (knownShS @sh) topl) - where - go :: ShS sh' -> [Int] -> ListS sh' SNat - go ZSS [] = ZS - go (sn :$$ sh) (i : is) - | i == fromSNat' sn = sn ::$ go sh is - | otherwise = error $ "IsList(ShS): Value does not match typing (type says " - ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" - go _ _ = error $ "IsList(ShS): Mismatched list length (type says " - ++ show (shsLength (knownShS @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" + fromList = shsFromList (knownShS @sh) toList = shsToList diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs index 838e2b0..8e98ff2 100644 --- a/src/Data/Array/Nested/Trace.hs +++ b/src/Data/Array/Nested/Trace.hs @@ -5,21 +5,28 @@ {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE TemplateHaskell #-} +{-# OPTIONS -Wno-simplifiable-class-constraints #-} {-| This module is API-compatible with "Data.Array.Nested", except that inputs and -outputs of the methods are traced using 'Debug.Trace.trace'. Thus the methods -also have additional 'Show' constraints. +outputs of the methods are traced to 'stderr'. Thus the methods also have +additional 'Show' constraints. ->>> let res = rtranspose [1, 0] (rreshape (2 :$: 3 :$: ZSR) (riota @Int 6)) * rreshape (3 :$: 2 :$: ZSR) (rreplicate (6 :$: ZSR) (rscalar @Int 7)) ->>> length (show res) `seq` () -oxtrace: riota [Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [0,1,2,3,4,5]))))] -oxtrace: rreshape [[2,3], Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [0,1,2,3,4,5])))), Ranked (M_Int (M_Primitive [2,3] (XArray (fromList [2,3] [0,1,2,3,4,5]))))] -oxtrace: rtranspose [Ranked (M_Int (M_Primitive [2,3] (XArray (fromList [2,3] [0,1,2,3,4,5])))), Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [0,3,1,4,2,5]))))] -oxtrace: rscalar [Ranked (M_Int (M_Primitive [] (XArray (fromList [] [7]))))] -oxtrace: rreplicate [[6], Ranked (M_Int (M_Primitive [] (XArray (fromList [] [7])))), Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [7,7,7,7,7,7]))))] -oxtrace: rreshape [[3,2], Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [7,7,7,7,7,7])))), Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [7,7,7,7,7,7]))))] ->>> res -Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [0,21,7,28,14,35])))) +>>> rtranspose [1, 0] (rreshape (2 :$: 3 :$: ZSR) (riota @Int 6)) * rreshape (3 :$: 2 :$: ZSR) (rreplicate (6 :$: ZSR) (rscalar @Int 7)) +oxtrace: (riota _ ... = rfromListLinear [6] [0,1,2,3,4,5]) +oxtrace: (rreshape [2,3] (rfromListLinear [6] [0,1,2,3,4,5]) ... = rfromListLinear [2,3] [0,1,2,3,4,5]) +oxtrace: (rtranspose [1,0] (rfromListLinear [2,3] [0,1,2,3,4,5]) ... = rfromListLinear [3,2] [0,3,1,4,2,5]) +oxtrace: (rscalar _ ... = rfromListLinear [] [7]) +oxtrace: (rreplicate [6] (rfromListLinear [] [7]) ... = rreplicate [6] 7) +oxtrace: (rreshape [3,2] (rreplicate [6] 7) ... = rreplicate [3,2] 7) +rfromListLinear [3,2] [0,21,7,28,14,35] + +The part up until and including the @...@ is printed after @seq@ing the +arguments; the @=@ and further is printed after @seq@ing the result of the +operation. Do note that tracing means that the functions in this module are +potentially __stricter__ than the plain ones in "Data.Array.Nested". + +Arguments that this module does not know how to @show@, probably due to +laziness on my side, are printed as @_@. -} module Data.Array.Nested.Trace ( -- * Traced variants @@ -27,20 +34,19 @@ module Data.Array.Nested.Trace ( -- * Re-exports from the plain "Data.Array.Nested" module Ranked(Ranked), - ListR(ZR, (:::)), IxR(..), IIxR, ShR(..), IShR, Shaped(Shaped), - ListS(ZS, (::$)), IxS(..), IIxS, ShS(..), KnownShS(..), Mixed, IxX(..), IIxX, - ShX(..), KnownShX(..), + ShX(..), KnownShX(..), IShX, StaticShX(..), SMayNat(..), + Conversion(..), Elt, PrimElt, @@ -51,10 +57,10 @@ module Data.Array.Nested.Trace ( Storable, SNat, pattern SNat, pattern SZ, pattern SS, - Perm(..), + Perm(..), PermR, IsPermutation, KnownPerm(..), - NumElt, FloatElt, + NumElt, IntElt, FloatElt, Rank, Product, Replicate, MapJust, @@ -67,4 +73,4 @@ import Data.Array.Nested.Trace.TH $(concat <$> mapM convertFun - ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromListOuter, 'rfromList1, 'rfromList1Prim, 'rtoListOuter, 'rtoList1, 'rfromListLinear, 'rfromListPrimLinear, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rcastToShaped, 'rtoMixed, 'rfromOrthotope, 'rtoOrthotope, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromListOuter, 'sfromList1, 'sfromList1Prim, 'stoListOuter, 'stoList1, 'sfromListLinear, 'sfromListPrimLinear, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoRanked, 'stoMixed, 'sfromOrthotope, 'stoOrthotope, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromListOuter, 'mfromList1, 'mfromList1Prim, 'mtoListOuter, 'mtoList1, 'mfromListLinear, 'mfromListPrimLinear, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mtoRanked, 'mcastToShaped]) + ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rgeneratePrim, 'rsumOuter1Prim, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerankPrim, 'rreplicate, 'rreplicatePrim, 'rfromListOuter, 'rfromListOuterN, 'rfromList1, 'rfromList1N, 'rfromListLinear, 'rfromList1Prim, 'rfromList1PrimN, 'rfromListPrimLinear, 'rtoListOuter, 'rtoList, 'rtoListLinear, 'rtoListPrim, 'rtoListPrimLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'sgeneratePrim, 'ssumOuter1Prim, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerankPrim, 'sreplicate, 'sreplicatePrim, 'sfromListOuter, 'sfromList1, 'sfromListLinear, 'sfromList1Prim, 'sfromListPrimLinear, 'stoListOuter, 'stoList, 'stoListLinear, 'stoListPrim, 'stoListPrimLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'mgeneratePrim, 'msumOuter1Prim, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerankPrim, 'mreplicate, 'mreplicatePrim, 'mfromListOuter, 'mfromListOuterN, 'mfromListOuterSN, 'mfromList1, 'mfromList1N, 'mfromList1SN, 'mfromListLinear, 'mfromList1Prim, 'mfromList1PrimN, 'mfromList1PrimSN, 'mfromListPrimLinear, 'mtoListOuter, 'mtoList, 'mtoListLinear, 'mtoListPrim, 'mtoListPrimLinear, 'msliceN, 'msliceSN, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array]) diff --git a/src/Data/Array/Nested/Trace/TH.hs b/src/Data/Array/Nested/Trace/TH.hs index 4b388e3..644b4bd 100644 --- a/src/Data/Array/Nested/Trace/TH.hs +++ b/src/Data/Array/Nested/Trace/TH.hs @@ -4,11 +4,11 @@ module Data.Array.Nested.Trace.TH where import Control.Monad (zipWithM) -import Data.List (foldl', intersperse) +import Data.List (foldl') import Data.Maybe (isJust) import Language.Haskell.TH hiding (cxt) - -import Debug.Trace qualified as Debug +import System.IO (hPutStr, stderr) +import System.IO.Unsafe (unsafePerformIO) import Data.Array.Nested @@ -20,7 +20,7 @@ splitFunTy = \case in (vars, cx, t1 : args, ret) ForallT vs cx' t -> let (vars, cx, args, ret) = splitFunTy t - in (vars ++ vs, cx ++ cx', args, ret) + in (vs ++ vars, cx' ++ cx, args, ret) t -> ([], [], [], t) data Arg = RRanked Type Arg @@ -30,17 +30,27 @@ data Arg = RRanked Type Arg | ROther Type deriving (Show) --- TODO: always returns Just recognise :: Type -> Maybe Arg recognise (ConT name `AppT` sht `AppT` ty) - | name == ''Ranked = RRanked sht <$> recognise ty - | name == ''Shaped = RShaped sht <$> recognise ty - | name == ''Mixed = RMixed sht <$> recognise ty + | name == ''Ranked = Just (RRanked sht (recogniseElt ty)) + | name == ''Shaped = Just (RShaped sht (recogniseElt ty)) + | name == ''Mixed = Just (RMixed sht (recogniseElt ty)) + | name == ''Conversion = Just (RShowable ty) recognise ty@(ConT name `AppT` _) - | name `elem` [''IShR, ''IIxR, ''ShS, ''IIxS, ''SNat] = + | name `elem` [''IShR, ''IIxR, ''ShS, ''IIxS, ''SNat, ''Perm] = Just (RShowable ty) +recognise ty@(ConT name) + | name == ''PermR = Just (RShowable ty) +recognise (ListT `AppT` ty) = Just (ROther ty) recognise _ = Nothing +recogniseElt :: Type -> Arg +recogniseElt (ConT name `AppT` sht `AppT` ty) + | name == ''Ranked = RRanked sht (recogniseElt ty) + | name == ''Shaped = RShaped sht (recogniseElt ty) + | name == ''Mixed = RMixed sht (recogniseElt ty) +recogniseElt ty = ROther ty + realise :: Arg -> Type realise (RRanked sht ty) = ConT ''Ranked `AppT` sht `AppT` realise ty realise (RShaped sht ty) = ConT ''Shaped `AppT` sht `AppT` realise ty @@ -62,37 +72,58 @@ mkShowElt (RMixed sht ty) = [ConT ''Show `AppT` realise (RMixed sht ty), ConT '' mkShowElt (RShowable _ty) = [] -- [ConT ''Elt `AppT` ty] mkShowElt (ROther ty) = [ConT ''Show `AppT` ty, ConT ''Elt `AppT` ty] -convertType :: Type -> Q (Type, [Bool], Bool) +-- If you pass a polymorphic function to seq, GHC wants to monomorphise and +-- doesn't know how to instantiate the type variables. Just don't, I guess. +isSeqable :: Type -> Bool +isSeqable ForallT{} = False +isSeqable (AppT a b) = isSeqable a && isSeqable b +isSeqable _ = True -- yolo, I guess + +convertType :: Type -> Q (Type, [Bool], [Bool], Bool) convertType typ = let (tybndrs, cxt, args, ret) = splitFunTy typ - argrels = map recognise args - retrel = recognise ret + argdescrs = map recognise args + retdescr = recognise ret in return (ForallT tybndrs (cxt ++ [constr - | Just rel <- retrel : argrels + | Just rel <- retdescr : argdescrs , constr <- mkShow rel]) (foldr (\a b -> ArrowT `AppT` a `AppT` b) ret args) - ,map isJust argrels - ,isJust retrel) + ,map isJust argdescrs + ,map isSeqable args + ,isJust retdescr) convertFun :: Name -> Q [Dec] convertFun funname = do defname <- newName (nameBase funname) - (convty, argarrs, retarr) <- reifyType funname >>= convertType - names <- zipWithM (\b i -> newName ((if b then "t" else "x") ++ show i)) argarrs [1::Int ..] + -- "ok": whether we understand this type enough to be able to show it + (convty, argoks, argsseqable, retok) <- reifyType funname >>= convertType + names <- zipWithM (\_ i -> newName ('x' : show i)) argoks [1::Int ..] + -- let tracenames = map fst (filter snd (zip (names ++ [resname]) (argarrs ++ [retarr]))) resname <- newName "res" - let tracenames = map fst (filter snd (zip (names ++ [resname]) (argarrs ++ [retarr]))) + let traceCall str val = VarE 'traceNoNewline `AppE` str `AppE` val + let msg1 = [LitE (StringL ("oxtrace: (" ++ nameBase funname ++ " "))] ++ + [if ok + then VarE 'showsPrec `AppE` LitE (IntegerL 11) `AppE` VarE n `AppE` LitE (StringL " ") + else LitE (StringL "_ ") + | (n, ok) <- zip names argoks] ++ + [LitE (StringL "...")] + let msg2 | retok = [LitE (StringL " = "), VarE 'show `AppE` VarE resname, LitE (StringL ")\n")] + | otherwise = [LitE (StringL " = _)\n")] let ex = LetE [ValD (VarP resname) (NormalB (foldl' AppE (VarE funname) (map VarE names))) - []] - (VarE 'Debug.trace - `AppE` (VarE 'concat `AppE` ListE - ([LitE (StringL ("oxtrace: " ++ nameBase funname ++ " ["))] ++ - intersperse (LitE (StringL ", ")) - (map (\n -> VarE 'show `AppE` VarE n) tracenames) ++ - [LitE (StringL "]")])) - `AppE` VarE resname) + []] $ + flip (foldr AppE) [VarE 'seq `AppE` VarE n | (n, True) <- zip names argsseqable] $ + traceCall (VarE 'concat `AppE` ListE msg1) $ + VarE 'seq `AppE` VarE resname `AppE` + traceCall (VarE 'concat `AppE` ListE msg2) (VarE resname) return [SigD defname convty ,FunD defname [Clause (map VarP names) (NormalB ex) []]] + +{-# NOINLINE traceNoNewline #-} +traceNoNewline :: String -> a -> a +traceNoNewline str x = unsafePerformIO $ do + hPutStr stderr str + return x diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs index 4172fa0..ec1b3dc 100644 --- a/src/Data/Array/Nested/Types.hs +++ b/src/Data/Array/Nested/Types.hs @@ -6,7 +6,7 @@ {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeFamilyDependencies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} @@ -30,6 +30,7 @@ module Data.Array.Nested.Types ( Replicate, lemReplicateSucc, MapJust, + lemMapJustEmpty, lemMapJustCons, lemMapJustHead, Head, Tail, Init, @@ -45,7 +46,6 @@ import GHC.TypeLits import GHC.TypeNats qualified as TN import Unsafe.Coerce qualified - -- Reasoning helpers subst1 :: forall f a b. a :~: b -> f a :~: f b @@ -58,8 +58,9 @@ subst2 Refl = Refl data Dict c a where Dict :: c a => Dict c a +{-# INLINE fromSNat' #-} fromSNat' :: SNat n -> Int -fromSNat' = fromIntegral . fromSNat +fromSNat' = fromEnum . TN.fromSNat sameNat' :: SNat n -> SNat m -> Maybe (n :~: m) sameNat' n@SNat m@SNat = sameNat n m @@ -108,13 +109,23 @@ type family Replicate n a where Replicate 0 a = '[] Replicate n a = a : Replicate (n - 1) a -lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a -lemReplicateSucc = unsafeCoerceRefl +lemReplicateSucc :: forall a n proxy. + proxy n -> a : Replicate n a :~: Replicate (n + 1) a +lemReplicateSucc _ = unsafeCoerceRefl -type family MapJust l where +type family MapJust l = r | r -> l where MapJust '[] = '[] MapJust (x : xs) = Just x : MapJust xs +lemMapJustEmpty :: MapJust sh :~: '[] -> sh :~: '[] +lemMapJustEmpty Refl = unsafeCoerceRefl + +lemMapJustCons :: MapJust sh :~: Just n : sh' -> sh :~: n : Tail sh +lemMapJustCons Refl = unsafeCoerceRefl + +lemMapJustHead :: proxy sh1 -> Head (MapJust sh1) :~: Just (Head sh1) +lemMapJustHead _ = unsafeCoerceRefl + type family Head l where Head (x : _) = x diff --git a/src/Data/Array/Strided/Orthotope.hs b/src/Data/Array/Strided/Orthotope.hs index 5c38d14..e2cd17c 100644 --- a/src/Data/Array/Strided/Orthotope.hs +++ b/src/Data/Array/Strided/Orthotope.hs @@ -24,14 +24,19 @@ fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset ve toO :: AS.Array n a -> RS.Array n a toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec)) +{-# INLINE liftO1 #-} liftO1 :: (AS.Array n a -> AS.Array n' b) -> RS.Array n a -> RS.Array n' b liftO1 f = toO . f . fromO +{-# INLINE liftO2 #-} liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c) -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c liftO2 f x y = toO (f (fromO x) (fromO y)) +-- We don't inline this lifting function, because its code is not just +-- a wrapper, being relatively long and expensive. +{-# INLINEABLE liftVEltwise1 #-} liftVEltwise1 :: (Storable a, Storable b) => SNat n -> (VS.Vector a -> VS.Vector b) diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index dde06e3..9f3ee34 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -1,8 +1,11 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE MultiWayIf #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneKindSignatures #-} @@ -14,27 +17,33 @@ module Data.Array.XArray where import Control.DeepSeq (NFData) +import Control.Monad (foldM_, foldM) +import Control.Monad.ST import Data.Array.Internal qualified as OI import Data.Array.Internal.RankedG qualified as ORG import Data.Array.Internal.RankedS qualified as ORS -import Data.Array.Ranked qualified as ORB import Data.Array.RankedS qualified as S import Data.Coerce import Data.Foldable (toList) import Data.Kind -import Data.List.NonEmpty (NonEmpty) +import Data.List.NonEmpty (NonEmpty(..)) import Data.Proxy import Data.Type.Equality import Data.Type.Ord +import Data.Vector.Generic.Checked qualified as VGC import Data.Vector.Storable qualified as VS +import Data.Vector.Storable.Mutable qualified as VSM import Foreign.Storable (Storable) import GHC.Generics (Generic) import GHC.TypeLits +#if !MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) +import Unsafe.Coerce (unsafeCoerce) +#endif -import Data.Array.Mixed.Lemmas +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Permutation import Data.Array.Nested.Types -import Data.Array.Nested.Mixed.Shape import Data.Array.Strided.Orthotope @@ -53,6 +62,7 @@ shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l go _ _ = error "Invalid shapeL" +{-# INLINEABLE fromVector #-} fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a fromVector sh v | Dict <- lemKnownNatRank sh @@ -78,7 +88,7 @@ cast ssh1 sh2 ssh' (XArray arr) | Refl <- lemRankApp ssh1 ssh' , Refl <- lemRankApp (ssxFromShX sh2) ssh' = let arrsh :: IShX sh1 - (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) + arrsh = shxTakeSSX (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) in if shxToList arrsh == shxToList sh2 then XArray arr else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" @@ -108,15 +118,23 @@ generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh) -- XArray . S.fromVector (shxShapeL sh) -- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh) +{-# INLINEABLE indexPartial #-} indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a indexPartial (XArray arr) ZIX = XArray arr indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx +{- Strangely, this increases allocation and there's no noticeable speedup: +indexPartial (XArray (ORS.A (ORG.A sh t))) ix = + let linear = OI.offset t + sum (zipWith (*) (ixxToList ix) (OI.strides t)) + len = ixxLength ix + in XArray (ORS.A (ORG.A (drop len sh) + OI.T{ OI.strides = drop len (OI.strides t) + , OI.offset = linear + , OI.values = OI.values t })) -} +{-# INLINEABLE index #-} index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a -index xarr i - | Refl <- lemAppNil @sh - = let XArray arr' = indexPartial xarr i :: XArray '[] a - in S.unScalar arr' +index (XArray (ORS.A (ORG.A _ t))) i = + OI.values t VS.! (OI.offset t + sum (zipWith (*) (toList i) (OI.strides t))) append :: forall n m sh a. Storable a => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a @@ -167,7 +185,7 @@ rerank :: forall sh sh1 sh2 a b. -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b rerank ssh ssh1 ssh2 f xarr@(XArray arr) | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) + = let sh = shxTakeSSX (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) in if 0 `elem` shxToList sh then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) else case () of @@ -194,7 +212,7 @@ rerank2 :: forall sh sh1 sh2 a b c. -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) + = let sh = shxTakeSSX (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) in if 0 `elem` shxToList sh then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) else case () of @@ -214,10 +232,15 @@ transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh) -> XArray (PermutePrefix is sh) a transpose ssh perm (XArray arr) | Dict <- lemKnownNatRankSSX ssh - , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) + , Refl <- lemRankApp (ssxPermute perm (ssxTakeLenPerm perm ssh)) (ssxDropLenPerm perm ssh) , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm , Refl <- lemRankDropLen ssh perm +#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) = XArray (S.transpose (permToList' perm) arr) +#else + = XArray (unsafeCoerce (S.transpose (permToList' perm) arr)) +#endif + -- | The list argument gives indices into the original dimension list. -- @@ -243,27 +266,23 @@ transpose2 ssh1 ssh2 (XArray arr) , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) , Refl <- lemRankAppComm ssh1 ssh2 , let n1 = ssxLength ssh1 - = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) + = XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr) sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a -sumFull _ (XArray arr) = - S.unScalar $ - liftO1 (numEltSum1Inner (SNat @0)) $ - S.fromVector [product (S.shapeL arr)] $ - S.toVector arr +sumFull ssx (XArray arr) = numEltSumFull (ssxRank ssx) $ fromO arr sumInner :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a sumInner ssh ssh' arr | Refl <- lemAppNil @sh - = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + = let sh' = shxDropSSX @sh @sh' ssh (shape (ssxAppend ssh ssh') arr) sh'F = shxFlatten sh' :$% ZSX ssh'F = ssxFromShX sh'F go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a go (XArray arr') | Refl <- lemRankApp ssh ssh'F - , let sn = listxRank (let StaticShX l = ssh in l) + , let sn = ssxRank ssh = XArray (liftO1 (numEltSum1Inner sn) arr') in go $ @@ -276,40 +295,83 @@ sumOuter :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a sumOuter ssh ssh' arr | Refl <- lemAppNil @sh - = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + = let sh = shxTakeSSX (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) shF = shxFlatten sh :$% ZSX in sumInner ssh' (ssxFromShX shF) $ transpose2 (ssxFromShX shF) ssh' $ reshapePartial ssh ssh' shF $ arr -fromListOuter :: forall n sh a. Storable a - => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a -fromListOuter ssh l - | Dict <- lemKnownNatRankSSX ssh - = case ssh of - SKnown m :!% _ | fromSNat' m /= length l -> - error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++ - "does not match the type (" ++ show (fromSNat' m) ++ ")" - _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) +-- | This creates an array from a list of arrays of one less dimension. +-- The list is streamed, its length is checked and it's verified +-- that all arrays on the list have the same shape. +{-# INLINE fromListOuterSN #-} +fromListOuterSN :: forall n sh a. Storable a + => SNat n -> IShX sh -> NonEmpty (XArray sh a) -> XArray (Just n : sh) a +fromListOuterSN m sh l + | Dict <- lemKnownNatRank sh + , let l' = coerce @(NonEmpty (XArray sh a)) @(NonEmpty (S.Array (Rank sh) a)) l + = case sh of + ZSX -> fromList1SN m (map S.unScalar (toList l')) + _ -> XArray (ravelOuterN (fromSNat' m) l') -toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a] -toListOuter (XArray arr) = - case S.shapeL arr of +-- | This checks that the list has the given length and that all shapes in the +-- list are equal. The list is streamed. +-- The first array in the list is forced early to potentially release some +-- memory, before allocating the (large) new array. +{-# INLINEABLE ravelOuterN #-} +ravelOuterN :: (KnownNat k, Storable a) + => Int -> NonEmpty (S.Array k a) -> S.Array (1 + k) a +ravelOuterN 0 _ = error "ravelOuterN: N == 0" +ravelOuterN k as@(!a0 :| _) = runST $ do + let sh0 = S.shapeL a0 + len = product sh0 + vecSize = k * len + vec <- VSM.unsafeNew vecSize + let f !n (ORS.A (ORG.A sht t)) = + if | n >= k -> + error $ "ravelOuterN: list too long " ++ show (n, k) + -- if we do this check just once at the end, we may + -- crash instead of producing an accurate error message + | sht == sh0 -> do + let g off el = do + VS.unsafeCopy (VSM.slice off (VS.length el) vec) el + return $! off + VS.length el + foldM_ g (n * len) (OI.toVectorListT sht t) + return $! n + 1 + | otherwise -> + error $ "ravelOuterN: unequal shapes " ++ show (sht, sh0) + nFinal <- foldM f 0 as + if nFinal == k + then S.fromVector (k : sh0) <$> VS.unsafeFreeze vec + else error $ "ravelOuterN: list too short " ++ show (nFinal, k) + +toListOuter :: forall a n sh. Storable a => XArray (n : sh) a -> [XArray sh a] +toListOuter (XArray arr@(ORS.A (ORG.A shArr t))) = + case shArr of + [] -> error "impossible" 0 : _ -> [] - _ -> coerce (ORB.toList (S.unravel arr)) + -- using orthotope's functions here would entail using rerank, which is slow, so we don't + [_] | Refl <- (unsafeCoerceRefl :: sh :~: '[]) -> coerce (map S.scalar $ S.toList arr) + n : sh -> coerce $ map (ORG.A sh . OI.indexT t) [0 .. n - 1] + +-- | Performance note: the list's spine is fully materialised to compute its +-- length before traversing it again to construct the array. +{-# INLINE fromList1 #-} +fromList1 :: Storable a => [a] -> XArray '[Nothing] a +fromList1 l = + let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself + in XArray (S.fromVector [n] (VS.fromListN n l)) -fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a -fromList1 ssh l = - let n = length l - in case ssh of - SKnown m :!% _ | fromSNat' m /= n -> - error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++ - "does not match the type (" ++ show (fromSNat' m) ++ ")" - _ -> XArray (S.fromVector [n] (VS.fromListN n l)) +-- | The list is streamed. +{-# INLINE fromList1SN #-} +fromList1SN :: Storable a => SNat n -> [a] -> XArray '[Just n] a +fromList1SN m l = + let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed + in XArray (S.fromVector [n] (VGC.fromListNChecked n l)) -toList1 :: Storable a => XArray '[n] a -> [a] -toList1 (XArray arr) = S.toList arr +toListLinear :: Storable a => XArray sh a -> [a] +toListLinear (XArray arr) = S.toList arr -- | Throws if the given shape is not, in fact, empty. empty :: forall sh a. Storable a => IShX sh -> XArray sh a diff --git a/src/Data/Vector/Generic/Checked.hs b/src/Data/Vector/Generic/Checked.hs new file mode 100644 index 0000000..d8aaaae --- /dev/null +++ b/src/Data/Vector/Generic/Checked.hs @@ -0,0 +1,40 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE ImportQualifiedPost #-} +module Data.Vector.Generic.Checked ( + fromListNChecked, +) where + +import Data.Stream.Monadic qualified as Stream +import Data.Vector.Fusion.Bundle.Monadic qualified as VBM +import Data.Vector.Fusion.Bundle.Size qualified as VBS +import Data.Vector.Fusion.Util qualified as VFU +import Data.Vector.Generic qualified as VG + +-- for INLINE_FUSED and INLINE_INNER +#include "vector.h" + + +-- These functions are copied over and lightly edited from the vector and +-- vector-stream packages, and thus inherit their BSD-3-Clause license with: +-- Copyright (c) 2008-2012, Roman Leshchinskiy +-- 2020-2022, Alexey Kuleshevich +-- 2020-2022, Aleksey Khudyakov +-- 2020-2022, Andrew Lelechenko + +fromListNChecked :: VG.Vector v a => Int -> [a] -> v a +{-# INLINE fromListNChecked #-} +fromListNChecked n = VG.unstream . bundleFromListNChecked n + +bundleFromListNChecked :: Int -> [a] -> VBM.Bundle VFU.Id v a +{-# INLINE_FUSED bundleFromListNChecked #-} +bundleFromListNChecked nTop xsTop + | nTop < 0 = error "fromListNChecked: length negative" + | otherwise = + VBM.fromStream (Stream.Stream step (xsTop, nTop)) (VBS.Max (VFU.delay_inline max nTop 0)) + where + {-# INLINE_INNER step #-} + step (xs,n) | n == 0 = case xs of + [] -> return Stream.Done + _:_ -> error "fromListNChecked: list too long" + step (x:xs,n) = return (Stream.Yield x (xs,n-1)) + step ([],_) = error "fromListNChecked: list too short" diff --git a/src/GHC/TypeLits/Orphans.hs b/src/GHC/TypeLits/Orphans.hs new file mode 100644 index 0000000..42f7293 --- /dev/null +++ b/src/GHC/TypeLits/Orphans.hs @@ -0,0 +1,13 @@ +-- | Compatibility module adding some additional instances not yet defined in +-- base-4.18 with GHC 9.6. +{-# OPTIONS -Wno-orphans #-} +module GHC.TypeLits.Orphans where + +import GHC.TypeLits + + +instance Eq (SNat n) where + _ == _ = True + +instance Ord (SNat n) where + compare _ _ = EQ |
