diff options
36 files changed, 2199 insertions, 1022 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..009d267 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,7 @@ +# Changelog for `ox-arrays` + +This package intends to follow the [PVP](https://pvp.haskell.org/). + +## 0.1.0.0 +- Initial release +- Various aspects of the API are still experimental, and breaking changes are expected in the future. @@ -1,56 +1,173 @@ -Wrapper library around `orthotope` that defines nested arrays, including -tuples, of (eventually) unboxed values. The arrays are represented in -struct-of-arrays form via the `Data.Vector.Unboxed` data family trick. Below -the surface layer, there is a more low-level wrapper around `orthotope` that -defines an array type type-indexed by `[Maybe Nat]`: some dimensions are -shape-typed (i.e. have their size statically known), and some not. +## ox-arrays -An overview of the API: +ox-arrays is an array library that defines nested arrays, including tuples, of +(eventually) unboxed values. The arrays are represented in struct-of-arrays +form via the `Data.Vector.Unboxed` data family trick; the component arrays are +`orthotope` arrays +([RankedS](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-RankedS.html)) +which describe elements using a _stride vector_ or +[LMAD](https://dl.acm.org/doi/pdf/10.1145/509705.509708) so that `transpose` +and `replicate` need only modify array metadata, not actually move around data. + +Because of the struct-of-arrays representation, nested arrays are not fully +general: indeed, arrays are not actually nested under the hood, so if one has an +array of arrays, those element arrays must all have the same shape (length, +width, etc.). If one has an array of tuples of arrays, then all the `fst` +components must have the same shape and all the `snd` components must have the +same shape, but the two pair components themselves can be different. + +However, the nesting functionality of ox-arrays can be completely ignored if you +only care about other parts of its API, or the vectorised arithmetic operations +(using hand-written C code). Nesting support mostly does not get in the way, and +has essentially no overhead (both when it's used and when it's not used). + +ox-arrays defines three array types: `Ranked`, `Shaped` and `Mixed`. +- `Ranked` corresponds to `orthotope`'s + [RankedS](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-RankedS.html) + and has the _rank_ of the array (its number of dimensions) on the type level. + For example, `Ranked 2 Float` is a two-dimensional array of `Float`s, i.e. a + matrix. +- `Shaped` corresponds to `orthotope`'s + [ShapedS](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-ShapedS.html). + and has the full _shape_ of the array (its dimension sizes) on the type level + as a type-level list of `Nat`s. For example, `Shaped [2,3] Float` is a 2-by-3 + matrix. The innermost dimension correspond to the right-most element in the + list. +- `Mixed` is halfway between the two: it has a type parameter of kind + `[Maybe Nat]` whose length is the rank of the array; `Nothing` elements have + unknown size, whereas `Just` elements have the indicated size. The type + `Mixed [Nothing, Nothing] a` is equivalent to `Ranked 2 a`; the type + `Mixed [Just n, Just m] a` is equivalent to `Shaped [n, m] a`. + +In various places in the API of a library like ox-arrays, one can make a +decision between 1. requiring a type class constraint providing certain +information (e.g. +[KnownNat](https://hackage.haskell.org/package/base-4.21.0.0/docs/GHC-TypeLits.html#t:KnownNat) +or `orthotope`'s +[Shape](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-ShapedS.html#t:Shape)), +or 2. taking singleton _values_ that encode said information in a way that is +linked to the type level (e.g. +[SNat](https://hackage.haskell.org/package/base-4.21.0.0/docs/GHC-TypeLits.html#t:SNat)). +`orthotope` chooses the type class approach; ox-arrays chooses the singleton +approach. Singletons are more verbose at times, but give the programmer more +insight in what data is flowing where, and more importantly, more control: type +class inference is very nice and implicit, but if it's not powerful enough for +the trickery you're doing, you're out of luck. Singletons allow you to explain +as precisely as you want to GHC what exactly you're doing. + +Below the surface layer, there is a more low-level wrapper (`XArray`) around +`orthotope` that defines a non-nested `Mixed`-style array type. + +**Be aware**: `ox-arrays` attempts to preserve sharing as much as possible. +That is to say: if a function is able to avoid copying array data and return an +array that references the original underlying `Vector`, it may do so. For +example, this means that if you convert a nested array to a list of arrays, all +returned arrays reference part of the original array without copying. This +makes `mtoList` fast, but also means that memory may be retained longer than +you might expect. + +Here is a little taster of the API, to get a sense for the design: ```haskell -data Ranked (n :: INat) a {- e.g. -} Ranked 3 Float -data Shaped (sh :: '[Nat]) a {- e.g. -} Shaped [2,3,4] Float -data Mixed (xsh :: '[Maybe Nat]) a {- e.g. -} Mixed [Just 2, Nothing, Just 4] Float - -Ranked I0 a = Ranked Z a ~~= Acc.Array Z a = Acc.Scalar a -Ranked I1 a = Ranked (S Z) a ~~= Acc.Array (Z :. Int) a = Acc.Vector a -Ranked I2 a = Ranked (S (S Z)) a ~~= Acc.Array (Z :. Int :. Int) a = Acc.Matrix a +import GHC.TypeLits (Nat, SNat) +data Ranked (n :: Nat) a {- e.g. -} Ranked 3 Float +data Shaped (sh :: '[Nat]) a {- e.g. -} Shaped [2,3,4] Float +data Mixed (xsh :: '[Maybe Nat]) a {- e.g. -} Mixed [Just 2, Nothing, Just 4] Float -rshape :: (Elt a, KnownINat n) => Ranked n a -> IxR n -sshape :: (Elt a, KnownShape sh) => Shaped sh a -> IxS sh -mshape :: (Elt a, KnownShapeX xsh) => Mixed xsh a -> IxX xsh +-- Shape types are written Sh{R,S,X}. The 'I' prefix denotes a Int-filled shape; +-- ShR and ShX are more general containers. ShS is a singleton. +rshape :: Elt a => Ranked n a -> IShR n +sshape :: Elt a => Shaped sh a -> ShS sh +mshape :: Elt a => Mixed xsh a -> IShX xsh -rindex :: Elt a => Ranked n a -> IxR n -> a -sindex :: Elt a => Shaped sh a -> IxS sh -> a -mindex :: Elt a => Mixed xsh a -> IxX xsh -> a +-- Index types are written Ix{R,S,X}. +rindex :: Elt a => Ranked n a -> IIxR n -> a +sindex :: Elt a => Shaped sh a -> IIxS sh -> a +mindex :: Elt a => Mixed xsh a -> IIxX xsh -> a -data IxR n where - IZR :: IxR Z - (:::) :: Int -> IxR n -> IxR (S n) +-- The index types can be used as if they were defined as follows; pattern +-- synonyms are provided to construct the illusion. (The actual definitions are +-- a bit more general and indirect.) +data IIxR n where + ZIR :: IIxR 0 + (:.:) :: Int -> IIxR n -> IIxR (n + 1) -data IxS sh where - IZS :: IxS '[] - (::$) :: Int -> IxS sh -> IxS (n : sh) +data IIxS sh where + ZIS :: IIxS '[] + (:.$) :: Int -> IIxS sh -> IIxS (n : sh) -data IxX sh where - IZX :: IxX '[] - (::@) :: Int -> IxX sh -> IxX (Just n : sh) - (::?) :: Int -> IxX sh -> IxX (Nothing : sh) +data IIxX xsh where + ZIX :: IIxX '[] + (:.%) :: Int -> IIxX xsh -> IIxX (mn : xsh) +-- Similarly, the shape types can be used as if they were defined as follows. +data IShR n where + ZSR :: IShR 0 + (:$:) :: Int -> IShR n -> IShR (n + 1) + +data ShS sh where + ZSS :: ShS '[] + (:$$) :: SNat n -> ShS sh -> ShS (n : sh) + +data IShX xsh where + ZSX :: IShX '[] + (:$%) :: SMayNat Int mn -> IShX xsh -> IShX (mn : xsh) +-- where: +data SMayNat i n where + SUnknown :: i -> SMayNat i Nothing + SKnown :: SNat n -> SMayNat i (Just n) + +-- Occasionally one needs a singleton for only the _known_ dimensions of a mixed +-- shape -- that is to say, only the statically-known part of a mixed shape. +-- StaticShX provides for this need. It can be used as if defined as follows: +data StaticShX xsh where + ZKX :: StaticShX '[] + (:!%) :: SMayNat () mn -> StaticShX xsh -> StaticShX (mn : xsh) + +-- The Elt class describes types that can be used as elements of an array. While +-- it is technically possible to define new instances of this class, typical +-- usage should regard Elt as closed. The user-relevant instances are the +-- following: class Elt a -instance Elt () -instance Elt Double -instance Elt Int -instance (Elt a, Elt b) => Elt (a, b) -instance (Elt a, KnownINat n) => Elt (Ranked n a) -instance (Elt a, KnownShape sh) => Elt (Shaped sh a) -instance (Elt a, KnownShapeX xsh) => Elt (Mixed xsh a) +instance Elt () +instance Elt Bool +instance Elt Float +instance Elt Double +instance Elt Int +instance (Elt a, Elt b) => Elt (a, b) +instance Elt a => Elt (Ranked n a) +instance Elt a => Elt (Shaped sh a) +instance Elt a => Elt (Mixed xsh a) + +-- Essentially all functions that ox-arrays offers on arrays are first-order: +-- add two arrays elementwise, transpose an array, append arrays, compute +-- minima/maxima, zip/unzip, nest/unnest, etc. The first-order approach allows +-- operations, especially arithmetic ones, to be vectorised using hand-written +-- C code, without needing any sort of JIT compilation. +rappend :: Elt a => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a +sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a +mappend :: Elt a => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a -rgenerate :: Elt a => IxR n -> (IxR n -> a) -> Ranked n a -sgenerate :: (Elt a, KnownShape sh) => (IxS sh -> a) -> Shaped sh a -mgenerate :: (Elt a, KnownShapeX xsh) => IxX xsh -> (IxX xsh -> a) -> Mixed xsh a +-- Exceptionally, also one higher-order function is provided per array type: +-- 'generate'. These functions have the caveat that regularity of arrays must be +-- preserved: all returned 'a's must have equal shape. See the documentation of +-- 'mgenerate'. +-- Warning: because the invocations of the function you pass cannot be +-- vectorised, 'generate' is rather slow if 'a' is small. +-- The 'KnownElt' class captures an API infelicity where constraint-based shape +-- passing is the only practical option. +rgenerate :: KnownElt a => IShR n -> (IxR n -> a) -> Ranked n a +sgenerate :: KnownElt a => ShS sh -> (IxS sh -> a) -> Shaped sh a +mgenerate :: KnownElt a => IShX xsh -> (IxX xsh -> a) -> Mixed xsh a +-- Under the hood, Ranked and Shaped are both newtypes over Mixed. Mixed itself +-- is a data family over XArray, which is a newtype over orthotope's RankedS. newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) ``` + +About the name: when importing `orthotope` array modules, a possible naming +convention is to use qualified imports as `OR` for "orthotope ranked" arrays and +`OS` for "orthotope shaped" arrays. ox-arrays was started to fill the `OX` gap, +then grew out of proportion. diff --git a/bench/Main.hs b/bench/Main.hs index ef03b1a..ce3a9df 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -15,19 +15,12 @@ import Numeric.LinearAlgebra qualified as LA import Test.Tasty.Bench import Text.Show (showListWith) -import Data.Array.XArray (XArray(..)) import Data.Array.Nested import Data.Array.Nested.Mixed (Mixed(M_Primitive), mliftPrim, mliftPrim2, toPrimitive) +import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Ranked (liftRanked1, liftRanked2) import Data.Array.Strided.Arith.Internal qualified as Arith - - -enableMisc :: Bool -enableMisc = False - -bgroupIf :: Bool -> String -> [Benchmark] -> Benchmark -bgroupIf True = bgroup -bgroupIf False = \name _ -> bgroup name [] +import Data.Array.XArray (XArray(..)) main :: IO () @@ -51,7 +44,7 @@ main_tests = defaultMain " str " ++ showSh (stridesOf inp1) ++ " " ++ showSh (stridesOf inp2)) $ nf (\(a,b) -> rsumAllPrim (rdot1Inner a b)) (inp1, inp2) - iota n = riota @Double n + iota = riota @Double in [dotprodBench "dot 1D" (iota 10_000_000 @@ -104,7 +97,7 @@ main_tests = defaultMain in nf (\a -> RS.normalize a) (RS.rev [0] (RS.rev [0] (RS.iota @Double n))) ] - ,bgroupIf enableMisc "misc" + ,bgroup "misc" [let n = 1000 k = 1000 in bgroup ("fusion [" ++ show k ++ "]*" ++ show n) @@ -148,6 +141,16 @@ main_tests = defaultMain | ki <- [1::Int ..5]] ] ] + ,bench "ixxFromLinear 10000x" $ + let n = 10000 + sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX + in nf (\sh -> [ixxFromLinear @Int sh i | i <- [1..n]]) sh0 + ,bench "ixxFromLinear 1x" $ + let sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX + in nf (\sh -> ixxFromLinear @Int sh 1234) sh0 + ,bench "shxEnum" $ + let sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX + in nf (\sh -> shxEnum sh) sh0 ] ] @@ -156,45 +159,54 @@ tests_compare = let n = 1_000_000 in [bgroup "Num" [bench "sum(+) Double [1e6]" $ - nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (+)) a b))) + nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (+)) a b))) (riota @Double n, riota n) ,bench "sum(*) Double [1e6]" $ - nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (*)) a b))) + nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (*)) a b))) (riota @Double n, riota n) ,bench "sum(/) Double [1e6]" $ - nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (/)) a b))) + nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (/)) a b))) (riota @Double n, riota n) ,bench "sum(**) Double [1e6]" $ - nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (**)) a b))) + nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (**)) a b))) (riota @Double n, riota n) ,bench "sum(sin) Double [1e6]" $ - nf (\a -> runScalar (rsumOuter1 (liftRanked1 (mliftPrim sin) a))) + nf (\a -> runScalar (rsumOuter1Prim (liftRanked1 (mliftPrim sin) a))) (riota @Double n) ,bench "sum Double [1e6]" $ - nf (\a -> runScalar (rsumOuter1 a)) + nf (\a -> runScalar (rsumOuter1Prim a)) + (riota @Double n) + ,bench "sumAll iota [1e6]" $ + nf (\a -> rsumAllPrim a) (riota @Double n) + ,bench "sumAll rev1(iota) [1e6]" $ + nf (\a -> rsumAllPrim a) + (rrev1 $ riota @Double n) + ,bench "sumAll reshape(iota) [1e6]" $ + nf (\a -> rsumAllPrim a) + (rreshape (1 :$: n :$: 1 :$: ZSR) $ riota @Double n) ] ,bgroup "NumElt" [bench "sum(+) Double [1e6]" $ - nf (\(a, b) -> runScalar (rsumOuter1 (a + b))) + nf (\(a, b) -> runScalar (rsumOuter1Prim (a + b))) (riota @Double n, riota n) ,bench "sum(*) Double [1e6]" $ - nf (\(a, b) -> runScalar (rsumOuter1 (a * b))) + nf (\(a, b) -> runScalar (rsumOuter1Prim (a * b))) (riota @Double n, riota n) ,bench "sum(/) Double [1e6]" $ - nf (\(a, b) -> runScalar (rsumOuter1 (a / b))) + nf (\(a, b) -> runScalar (rsumOuter1Prim (a / b))) (riota @Double n, riota n) ,bench "sum(**) Double [1e6]" $ - nf (\(a, b) -> runScalar (rsumOuter1 (a ** b))) + nf (\(a, b) -> runScalar (rsumOuter1Prim (a ** b))) (riota @Double n, riota n) ,bench "sum(sin) Double [1e6]" $ - nf (\a -> runScalar (rsumOuter1 (sin a))) + nf (\a -> runScalar (rsumOuter1Prim (sin a))) (riota @Double n) ,bench "sum Double [1e6]" $ - nf (\a -> runScalar (rsumOuter1 a)) + nf (\a -> runScalar (rsumOuter1Prim a)) (riota @Double n) ,bench "sum(*) Double [1e6] stride 1; -1" $ - nf (\(a, b) -> runScalar (rsumOuter1 (a * b))) + nf (\(a, b) -> runScalar (rsumOuter1Prim (a * b))) (riota @Double n, rrev1 (riota n)) ,bench "dotprod Float [1e6]" $ nf (\(a, b) -> rdot a b) diff --git a/cabal.project b/cabal.project index d102ed6..d76d872 100644 --- a/cabal.project +++ b/cabal.project @@ -1,2 +1,2 @@ packages: . -with-compiler: ghc-9.8.4 +with-compiler: ghc-9.12.2 diff --git a/cbits/arith.c b/cbits/arith.c index f19b01e..ee248a4 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -20,6 +20,8 @@ // Shorter names, due to CPP used both in function names and in C types. +typedef int8_t i8; +typedef int16_t i16; typedef int32_t i32; typedef int64_t i64; @@ -248,6 +250,8 @@ void oxarrays_stats_print_all(void) { #define GEN_ABS(x) \ _Generic((x), \ + i8: abs, \ + i16: abs, \ int: abs, \ long: labs, \ long long: llabs, \ @@ -490,7 +494,9 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { if (rank == 0) return arr[0]; \ typ result = 0; \ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ - REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, result); \ + typ dest = 0; \ + REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, dest); \ + result = result op dest; \ }); \ return result; \ } @@ -738,7 +744,7 @@ enum redop_tag_t { * Generate all the functions * *****************************************************************************/ -#define INT_TYPES_XLIST X(i32) X(i64) +#define INT_TYPES_XLIST X(i8) X(i16) X(i32) X(i64) #define FLOAT_TYPES_XLIST X(double) X(float) #define NUM_TYPES_XLIST INT_TYPES_XLIST FLOAT_TYPES_XLIST diff --git a/cbits/arith_lists.h b/cbits/arith_lists.h index 432765c..dc9ad1a 100644 --- a/cbits/arith_lists.h +++ b/cbits/arith_lists.h @@ -2,38 +2,38 @@ LIST_BINOP(BO_ADD, 1, +) LIST_BINOP(BO_SUB, 2, -) LIST_BINOP(BO_MUL, 3, *) -LIST_IBINOP(IB_QUOT, 1, quot) -LIST_IBINOP(IB_REM, 2, rem) +LIST_IBINOP(IB_QUOT, 11, quot) +LIST_IBINOP(IB_REM, 12, rem) -LIST_FBINOP(FB_DIV, 1, /) -LIST_FBINOP(FB_POW, 2, **) -LIST_FBINOP(FB_LOGBASE, 3, logBase) -LIST_FBINOP(FB_ATAN2, 4, atan2) +LIST_FBINOP(FB_DIV, 21, /) +LIST_FBINOP(FB_POW, 22, **) +LIST_FBINOP(FB_LOGBASE, 23, logBase) +LIST_FBINOP(FB_ATAN2, 24, atan2) -LIST_UNOP(UO_NEG, 1,) -LIST_UNOP(UO_ABS, 2,) -LIST_UNOP(UO_SIGNUM, 3,) +LIST_UNOP(UO_NEG, 31,) +LIST_UNOP(UO_ABS, 32,) +LIST_UNOP(UO_SIGNUM, 33,) -LIST_FUNOP(FU_RECIP, 1,) -LIST_FUNOP(FU_EXP, 2,) -LIST_FUNOP(FU_LOG, 3,) -LIST_FUNOP(FU_SQRT, 4,) -LIST_FUNOP(FU_SIN, 5,) -LIST_FUNOP(FU_COS, 6,) -LIST_FUNOP(FU_TAN, 7,) -LIST_FUNOP(FU_ASIN, 8,) -LIST_FUNOP(FU_ACOS, 9,) -LIST_FUNOP(FU_ATAN, 10,) -LIST_FUNOP(FU_SINH, 11,) -LIST_FUNOP(FU_COSH, 12,) -LIST_FUNOP(FU_TANH, 13,) -LIST_FUNOP(FU_ASINH, 14,) -LIST_FUNOP(FU_ACOSH, 15,) -LIST_FUNOP(FU_ATANH, 16,) -LIST_FUNOP(FU_LOG1P, 17,) -LIST_FUNOP(FU_EXPM1, 18,) -LIST_FUNOP(FU_LOG1PEXP, 19,) -LIST_FUNOP(FU_LOG1MEXP, 20,) +LIST_FUNOP(FU_RECIP, 41,) +LIST_FUNOP(FU_EXP, 42,) +LIST_FUNOP(FU_LOG, 43,) +LIST_FUNOP(FU_SQRT, 44,) +LIST_FUNOP(FU_SIN, 45,) +LIST_FUNOP(FU_COS, 46,) +LIST_FUNOP(FU_TAN, 47,) +LIST_FUNOP(FU_ASIN, 48,) +LIST_FUNOP(FU_ACOS, 49,) +LIST_FUNOP(FU_ATAN, 50,) +LIST_FUNOP(FU_SINH, 51,) +LIST_FUNOP(FU_COSH, 52,) +LIST_FUNOP(FU_TANH, 53,) +LIST_FUNOP(FU_ASINH, 54,) +LIST_FUNOP(FU_ACOSH, 55,) +LIST_FUNOP(FU_ATANH, 56,) +LIST_FUNOP(FU_LOG1P, 57,) +LIST_FUNOP(FU_EXPM1, 58,) +LIST_FUNOP(FU_LOG1PEXP, 59,) +LIST_FUNOP(FU_LOG1MEXP, 60,) -LIST_REDOP(RO_SUM, 1,) -LIST_REDOP(RO_PRODUCT, 2,) +LIST_REDOP(RO_SUM, 81,) +LIST_REDOP(RO_PRODUCT, 82,) diff --git a/gentrace.sh b/gentrace.sh index 7be2b9c..c3f1240 100755 --- a/gentrace.sh +++ b/gentrace.sh @@ -8,7 +8,7 @@ module Data.Array.Nested.Trace ( -- * Re-exports from the plain "Data.Array.Nested" module EOF -sed -n '/^module/,/^) where/!d; /^\s*-- /d; s/ \b[a-z][a-zA-Z0-9_'"'"']*,//g; /^ $/d; s/(\.\., Z.., ([^)]*))/(..)/g; /^ /p; /^$/p' src/Data/Array/Nested.hs +sed -n '/^module/,/^) where/!d; /^\s*--\( \|$\)/d; s/ \b[a-z][a-zA-Z0-9_'"'"']*,//g; /^ $/d; s/(\.\., Z.., ([^)]*))/(..)/g; /^ /p; /^$/p' src/Data/Array/Nested.hs cat <<'EOF' ) where diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs index 5802573..2eb0666 100644 --- a/ops/Data/Array/Strided/Arith/Internal.hs +++ b/ops/Data/Array/Strided/Arith/Internal.hs @@ -27,7 +27,7 @@ import Foreign.C.Types import Foreign.Ptr import Foreign.Storable import GHC.TypeLits -import GHC.TypeNats qualified as TypeNats +import GHC.TypeNats qualified as TN import Language.Haskell.TH import System.IO (hFlush, stdout) import System.IO.Unsafe @@ -42,7 +42,7 @@ import Data.Array.Strided.Array -- TODO: move this to a utilities module fromSNat' :: SNat n -> Int -fromSNat' = fromIntegral . fromSNat +fromSNat' = fromEnum . TN.fromSNat data Dict c where Dict :: c => Dict c @@ -179,7 +179,7 @@ unreplicateStrides (Array sh strides offset vec) = unrepSize = product [n | (n, True) <- zip sh replDims] - in TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> + in TN.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> Unreplicated (Array @lenshF shF stridesF offset vec) unrepSize (reinsertZeros replDims) simplifyArray :: Array n a @@ -200,7 +200,7 @@ simplifyArray :: Array n a -> r) -> r simplifyArray array k - | let revDims = map (<0) (arrStrides array) + | let revDims = map (< 0) (arrStrides array) , Unreplicated array' unrepSize rereplicate <- unreplicateStrides (arrayRevDims revDims array) = k array' unrepSize @@ -258,7 +258,7 @@ simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k , let unrepSize = product [n | (n, True) <- zip sh replDims] - = TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> + = TN.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> k @lenshF (Array shF strides1F offset1 vec1) (Array shF strides2F offset2 vec2) @@ -386,7 +386,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides off VS.unsafeWith vec' $ \pv -> let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a)) in fred (fromIntegral ndims') (ptrconv poutv) psh pstrides (ptrconv pv') - TypeNats.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do + TN.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of LTI -> pure Dict EQI -> pure Dict @@ -396,6 +396,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides off Nothing -> error "impossible" -- TODO: test handling of negative strides +-- TODO: simplify away normalised dimensions -- | Reduce full array {-# NOINLINE vectorRedFullOp #-} vectorRedFullOp :: forall a b n. (Num a, Storable a) @@ -490,7 +491,7 @@ vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv) pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1')) pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2')) - TypeNats.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do + TN.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of LTI -> pure Dict EQI -> pure Dict @@ -714,6 +715,36 @@ class NumElt a where numEltMaxIndex :: SNat n -> Array n a -> [Int] numEltDotprodInner :: SNat n -> Array (n + 1) a -> Array (n + 1) a -> Array n a +instance NumElt Int8 where + numEltAdd = addVectorInt8 + numEltSub = subVectorInt8 + numEltMul = mulVectorInt8 + numEltNeg = negVectorInt8 + numEltAbs = absVectorInt8 + numEltSignum = signumVectorInt8 + numEltSum1Inner = sum1VectorInt8 + numEltProduct1Inner = product1VectorInt8 + numEltSumFull = sumFullVectorInt8 + numEltProductFull = productFullVectorInt8 + numEltMinIndex _ = minindexVectorInt8 + numEltMaxIndex _ = maxindexVectorInt8 + numEltDotprodInner = dotprodinnerVectorInt8 + +instance NumElt Int16 where + numEltAdd = addVectorInt16 + numEltSub = subVectorInt16 + numEltMul = mulVectorInt16 + numEltNeg = negVectorInt16 + numEltAbs = absVectorInt16 + numEltSignum = signumVectorInt16 + numEltSum1Inner = sum1VectorInt16 + numEltProduct1Inner = product1VectorInt16 + numEltSumFull = sumFullVectorInt16 + numEltProductFull = productFullVectorInt16 + numEltMinIndex _ = minindexVectorInt16 + numEltMaxIndex _ = maxindexVectorInt16 + numEltDotprodInner = dotprodinnerVectorInt16 + instance NumElt Int32 where numEltAdd = addVectorInt32 numEltSub = subVectorInt32 @@ -830,6 +861,14 @@ class NumElt a => IntElt a where intEltQuot :: SNat n -> Array n a -> Array n a -> Array n a intEltRem :: SNat n -> Array n a -> Array n a -> Array n a +instance IntElt Int8 where + intEltQuot = quotVectorInt8 + intEltRem = remVectorInt8 + +instance IntElt Int16 where + intEltQuot = quotVectorInt16 + intEltRem = remVectorInt16 + instance IntElt Int32 where intEltQuot = quotVectorInt32 intEltRem = remVectorInt32 @@ -840,19 +879,19 @@ instance IntElt Int64 where instance IntElt Int where intEltQuot = intWidBranch2 @Int quot - (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT)) - (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT)) + (c_ibinary_i32_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vv_strided (aiboEnum IB_QUOT)) + (c_ibinary_i64_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vv_strided (aiboEnum IB_QUOT)) intEltRem = intWidBranch2 @Int rem - (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM)) - (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM)) + (c_ibinary_i32_sv_strided (aiboEnum IB_REM)) (c_ibinary_i32_vs_strided (aiboEnum IB_REM)) (c_ibinary_i32_vv_strided (aiboEnum IB_REM)) + (c_ibinary_i64_sv_strided (aiboEnum IB_REM)) (c_ibinary_i64_vs_strided (aiboEnum IB_REM)) (c_ibinary_i64_vv_strided (aiboEnum IB_REM)) instance IntElt CInt where intEltQuot = intWidBranch2 @CInt quot - (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT)) - (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT)) + (c_ibinary_i32_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vv_strided (aiboEnum IB_QUOT)) + (c_ibinary_i64_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vv_strided (aiboEnum IB_QUOT)) intEltRem = intWidBranch2 @CInt rem - (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM)) - (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM)) + (c_ibinary_i32_sv_strided (aiboEnum IB_REM)) (c_ibinary_i32_vs_strided (aiboEnum IB_REM)) (c_ibinary_i32_vv_strided (aiboEnum IB_REM)) + (c_ibinary_i64_sv_strided (aiboEnum IB_REM)) (c_ibinary_i64_vs_strided (aiboEnum IB_REM)) (c_ibinary_i64_vv_strided (aiboEnum IB_REM)) class NumElt a => FloatElt a where floatEltDiv :: SNat n -> Array n a -> Array n a -> Array n a diff --git a/ops/Data/Array/Strided/Arith/Internal/Lists.hs b/ops/Data/Array/Strided/Arith/Internal/Lists.hs index 910a77c..27204d2 100644 --- a/ops/Data/Array/Strided/Arith/Internal/Lists.hs +++ b/ops/Data/Array/Strided/Arith/Internal/Lists.hs @@ -16,7 +16,9 @@ data ArithType = ArithType intTypesList :: [ArithType] intTypesList = - [ArithType ''Int32 "i32" + [ArithType ''Int8 "i8" + ,ArithType ''Int16 "i16" + ,ArithType ''Int32 "i32" ,ArithType ''Int64 "i64" ] diff --git a/ox-arrays.cabal b/ox-arrays.cabal index d9a345b..83fe93f 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -5,7 +5,11 @@ synopsis: An efficient CPU-based multidimensional array (tensor) library description: An efficient and richly typed CPU-based multidimensional array (tensor) library built upon the optimized tensor representation (strides list) - implemented in the orthotope package. + implemented in the orthotope package. See the README. + + If you use this package: let me know (e.g. via email) if you find it useful! + Both positive feedback (keep this!) and negative feedback (I needed this but + ox-arrays doesn't provide it) is welcome. copyright: (c) 2025 Tom Smeding, Mikolaj Konarski author: Tom Smeding, Mikolaj Konarski maintainer: Tom Smeding <xhackage@tomsmeding.com> @@ -13,6 +17,7 @@ license: BSD-3-Clause category: Array, Tensors build-type: Simple +extra-doc-files: README.md CHANGELOG.md extra-source-files: cbits/arith_lists.h flag trace-wrappers @@ -20,7 +25,7 @@ flag trace-wrappers Compile modules that define wrappers around the array methods that trace their arguments and results. This is conditional on a flag because these modules make documentation generation fail. - (https://gitlab.haskell.org/ghc/ghc/-/issues/24964 , should be fixed in + (@https://gitlab.haskell.org/ghc/ghc/-/issues/24964@ , should be fixed in GHC 9.12) default: False manual: True @@ -50,16 +55,23 @@ flag default-show-instances default: False manual: True +common basics + default-language: Haskell2010 + ghc-options: -Wall -Wcompat -Widentities -Wunused-packages -Wpartial-fields -Wredundant-bang-patterns -Woperator-whitespace -Wredundant-strictness-flags + if impl(ghc >= 9.14) + ghc-options: -Wno-pattern-namespace-specifier + library + import: basics exposed-modules: -- put this module on top so ghci considers it the "main" module Data.Array.Nested - Data.Array.Mixed.Lemmas - Data.Array.Nested.Internal.Lemmas Data.Array.Nested.Convert Data.Array.Nested.Mixed Data.Array.Nested.Mixed.Shape + Data.Array.Nested.Mixed.Shape.Internal + Data.Array.Nested.Lemmas Data.Array.Nested.Permutation Data.Array.Nested.Ranked Data.Array.Nested.Ranked.Base @@ -71,6 +83,11 @@ library Data.Array.Strided.Orthotope Data.Array.XArray Data.Bag + Data.Vector.Generic.Checked + + if impl(ghc < 9.8) + exposed-modules: + GHC.TypeLits.Orphans if flag(trace-wrappers) exposed-modules: @@ -81,21 +98,22 @@ library cpp-options: -DOXAR_DEFAULT_SHOW_INSTANCES build-depends: - strided-array-ops, + ox-arrays:strided-array-ops, base, deepseq < 1.7, ghc-typelits-knownnat, ghc-typelits-natnormalise, - orthotope < 0.2, - vector + orthotope >= 0.1.8.0 && < 0.2, + template-haskell, + vector, + vector-stream hs-source-dirs: src - - default-language: Haskell2010 - ghc-options: -Wall -Wcompat -Widentities -Wunused-packages other-extensions: TemplateHaskell library strided-array-ops + import: basics + visibility: public exposed-modules: Data.Array.Strided Data.Array.Strided.Array @@ -105,9 +123,11 @@ library strided-array-ops Data.Array.Strided.Arith.Internal.Lists Data.Array.Strided.Arith.Internal.Lists.TH build-depends: - base >=4.18 && <4.22, - ghc-typelits-knownnat < 1, - ghc-typelits-natnormalise < 1, + base >=4.18 && <4.23, + ghc-typelits-knownnat >= 0.8.0 && < 1 + -- 0.9.0 is unsound: https://github.com/clash-lang/ghc-typelits-natnormalise/issues/105 + && (< 0.9.0 || > 0.9.0), + ghc-typelits-natnormalise >= 0.8.1 && < 1, template-haskell < 3, vector < 0.14 hs-source-dirs: ops @@ -122,11 +142,10 @@ library strided-array-ops -- hmatrix assumes sse2, so we can too cc-options: -msse2 - default-language: Haskell2010 - ghc-options: -Wall -Wcompat -Widentities -Wunused-packages other-extensions: TemplateHaskell test-suite test + import: basics type: exitcode-stdio-1.0 main-is: Main.hs other-modules: @@ -147,33 +166,29 @@ test-suite test tasty-hedgehog, vector hs-source-dirs: test - default-language: Haskell2010 - ghc-options: -Wall -Wcompat -Widentities -Wunused-packages test-suite example + import: basics type: exitcode-stdio-1.0 main-is: Main.hs build-depends: ox-arrays, base hs-source-dirs: example - default-language: Haskell2010 - ghc-options: -Wall -Wcompat -Widentities -Wunused-packages benchmark bench + import: basics type: exitcode-stdio-1.0 main-is: Main.hs build-depends: ox-arrays, - strided-array-ops, + ox-arrays:strided-array-ops, base, hmatrix, orthotope, tasty-bench, vector hs-source-dirs: bench - default-language: Haskell2010 - ghc-options: -Wall -Wcompat -Widentities -Wunused-packages source-repository head type: git diff --git a/release-hints.txt b/release-hints.txt index 259c671..2623caa 100644 --- a/release-hints.txt +++ b/release-hints.txt @@ -1,2 +1,5 @@ - Temporarily enable -Wredundant-constraints - Has too many false-positives to enable normally, but sometimes catches actual redundant constraints +- Don't forget to rerun gentrace.sh +- Test with GHC 9.6, it's rather picky around type-level nats + - Whenever we drop support for GHC 9.6, search for "9,8" and remove all the conditionals, as well as the GHC.TypeLits.Orphans module diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 9801529..c898a75 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -6,12 +6,17 @@ module Data.Array.Nested ( ListR(ZR, (:::)), IxR(.., ZIR, (:.:)), IIxR, ShR(.., ZSR, (:$:)), IShR, - rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rsumOuter1, rsumAllPrim, + rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rgeneratePrim, rsumOuter1Prim, rsumAllPrim, rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar, remptyArray, - rrerank, - rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1, - rfromListLinear, rfromListPrimLinear, rtoListLinear, + rrerankPrim, + rreplicate, rreplicatePrim, + rfromListOuter, rfromListOuterN, + rfromList1, rfromList1N, + rfromListLinear, + rfromList1Prim, rfromList1PrimN, + rfromListPrimLinear, + rtoListOuter, rtoList, rtoListLinear, rtoListPrim, rtoListPrimLinear, rslice, rrev1, rreshape, rflatten, riota, rminIndexPrim, rmaxIndexPrim, rdot1Inner, rdot, rnest, runNest, rzip, runzip, @@ -19,7 +24,7 @@ module Data.Array.Nested ( rlift, rlift2, -- ** Conversions rtoXArrayPrim, rfromXArrayPrim, - rcastToShaped, rtoMixed, rcastToMixed, + rtoMixed, rcastToMixed, rcastToShaped, rfromOrthotope, rtoOrthotope, -- ** Additional arithmetic operations -- @@ -31,13 +36,14 @@ module Data.Array.Nested ( ListS(ZS, (::$)), IxS(.., ZIS, (:.$)), IIxS, ShS(.., ZSS, (:$$)), KnownShS(..), - sshape, srank, ssize, sindex, sindexPartial, sgenerate, ssumOuter1, ssumAllPrim, + sshape, srank, ssize, sindex, sindexPartial, sgenerate, sgeneratePrim, ssumOuter1Prim, ssumAllPrim, stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar, -- TODO: sconcat? What should its type be? semptyArray, - srerank, - sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1, - sfromListLinear, sfromListPrimLinear, stoListLinear, + srerankPrim, + sreplicate, sreplicatePrim, + sfromListOuter, sfromList1, sfromListLinear, sfromList1Prim, sfromListPrimLinear, + stoListOuter, stoList, stoListLinear, stoListPrim, stoListPrimLinear, sslice, srev1, sreshape, sflatten, siota, sminIndexPrim, smaxIndexPrim, sdot1Inner, sdot, snest, sunNest, szip, sunzip, @@ -45,7 +51,7 @@ module Data.Array.Nested ( slift, slift2, -- ** Conversions stoXArrayPrim, sfromXArrayPrim, - stoRanked, stoMixed, scastToMixed, + stoMixed, scastToMixed, stoRanked, sfromOrthotope, stoOrthotope, -- ** Additional arithmetic operations -- @@ -59,13 +65,18 @@ module Data.Array.Nested ( ShX(.., ZSX, (:$%)), KnownShX(..), IShX, StaticShX(.., ZKX, (:!%)), SMayNat(..), - mshape, mrank, msize, mindex, mindexPartial, mgenerate, msumOuter1, msumAllPrim, + mshape, mrank, msize, mindex, mindexPartial, mgenerate, mgeneratePrim, msumOuter1Prim, msumAllPrim, mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar, memptyArray, - mrerank, - mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1, - mfromListLinear, mfromListPrimLinear, mtoListLinear, - mslice, mrev1, mreshape, mflatten, miota, + mrerankPrim, + mreplicate, mreplicatePrim, + mfromListOuter, mfromListOuterN, mfromListOuterSN, + mfromList1, mfromList1N, mfromList1SN, + mfromListLinear, + mfromList1Prim, mfromList1PrimN, mfromList1PrimSN, + mfromListPrimLinear, + mtoListOuter, mtoList, mtoListLinear, mtoListPrim, mtoListPrimLinear, + msliceN, msliceSN, mslice, mrev1, mreshape, mflatten, miota, mminIndexPrim, mmaxIndexPrim, mdot1Inner, mdot, mnest, munNest, mzip, munzip, -- ** Lifting orthotope operations to 'Mixed' arrays @@ -73,8 +84,8 @@ module Data.Array.Nested ( -- ** Conversions mtoXArrayPrim, mfromXArrayPrim, mcast, - mtoRanked, mcastToShaped, - castCastable, Castable(..), + mcastToShaped, mtoRanked, + convert, Conversion(..), -- ** Additional arithmetic operations -- -- $integralRealFloat @@ -91,7 +102,7 @@ module Data.Array.Nested ( Storable, SNat, pattern SNat, pattern SZ, pattern SS, - Perm(..), + Perm(..), PermR, IsPermutation, KnownPerm(..), NumElt, IntElt, FloatElt, @@ -102,23 +113,23 @@ module Data.Array.Nested ( import Prelude hiding (mappend, mconcat) -import Data.Array.Nested.Permutation -import Data.Array.Nested.Types import Data.Array.Nested.Convert import Data.Array.Nested.Mixed -import Data.Array.Nested.Ranked -import Data.Array.Nested.Shaped import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Ranked import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Shaped import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith import Foreign.Storable import GHC.TypeLits -- $integralRealFloat -- --- These functions separate top-level functions, and not exposed in instances --- for 'RealFloat' and 'Integral', because those classes include a variety of --- other functions that make no sense for arrays. +-- These functions are separate top-level functions, and not exposed in +-- instances for 'RealFloat' and 'Integral', because those classes include a +-- variety of other functions that make no sense for arrays. -- This problem already occurs with 'fromInteger', 'fromRational' and 'pi', but -- having 'Num', 'Fractional' and 'Floating' available is just too useful. diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index d5e6008..8c88d23 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -1,42 +1,293 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) {-# LANGUAGE TypeAbstractions #-} +#endif {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} module Data.Array.Nested.Convert ( - castCastable, - Castable(..), + -- * Shape\/index\/list casting functions + -- ** To ranked + ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShX, shrFromShX2, + listrCast, ixrCast, shrCast, + -- ** To shaped + ixsFromIxR, ixsFromIxR', ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX, + ixsCast, + -- ** To mixed + ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS, + ixxCast, shxCast, shxCast', - -- * Special cases + -- * Array conversions + convert, + Conversion(..), + + -- * Special cases of array conversions -- - -- | These functions can all be implemented using 'castCastable' in some way, + -- | These functions can all be implemented using 'convert' in some way, -- but some have fewer constraints. rtoMixed, rcastToMixed, rcastToShaped, stoMixed, scastToMixed, stoRanked, mcast, mcastToShaped, mtoRanked, - - -- * Additional index/shape casting functions - ixrFromIxS, shrFromShS, ) where import Control.Category import Data.Proxy import Data.Type.Equality +import GHC.TypeLits -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Types -import Data.Array.Nested.Internal.Lemmas +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Ranked.Base import Data.Array.Nested.Ranked.Shape import Data.Array.Nested.Shaped.Base import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types + +-- * Shape or index or list casting functions + +-- * To ranked + +ixrFromIxS :: IxS sh i -> IxR (Rank sh) i +ixrFromIxS ZIS = ZIR +ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix + +ixrFromIxX :: IxX sh i -> IxR (Rank sh) i +ixrFromIxX ZIX = ZIR +ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx + +shrFromShS :: ShS sh -> IShR (Rank sh) +shrFromShS ZSS = ZSR +shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh + +-- shrFromShX re-exported +-- shrFromShX2 re-exported +-- listrCast re-exported +-- ixrCast re-exported +-- shrCast re-exported + +-- * To shaped + +-- TODO: these take a ShS because there are KnownNats inside IxS. + +ixsFromIxR :: ShS sh -> IxR (Rank sh) i -> IxS sh i +ixsFromIxR ZSS ZIR = ZIS +ixsFromIxR (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR sh idx + +-- | Performs a runtime check that @n@ matches @Rank sh@. Equivalent to the +-- following, but more efficient: +-- +-- > ixsFromIxR' sh idx = ixsFromIxR sh (ixrCast (shsRank sh) idx) +ixsFromIxR' :: ShS sh -> IxR n i -> IxS sh i +ixsFromIxR' ZSS ZIR = ZIS +ixsFromIxR' (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR' sh idx +ixsFromIxR' _ _ = error "ixsFromIxR': index rank does not match shape rank" + +-- TODO: this takes a ShS because there are KnownNats inside IxS. +ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i +ixsFromIxX ZSS ZIX = ZIS +ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx + +-- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to +-- the following, but more efficient: +-- +-- > ixsFromIxX' sh idx = ixsFromIxX sh (ixxCast (shxFromShS sh) idx) +ixsFromIxX' :: ShS sh -> IxX sh' i -> IxS sh i +ixsFromIxX' ZSS ZIX = ZIS +ixsFromIxX' (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX' sh idx +ixsFromIxX' _ _ = error "ixsFromIxX': index rank does not match shape rank" + +-- | Produce an existential 'ShS' from an 'IShR'. +withShsFromShR :: IShR n -> (forall sh. Rank sh ~ n => ShS sh -> r) -> r +withShsFromShR ZSR k = k ZSS +withShsFromShR (n :$: sh) k = + withShsFromShR sh $ \sh' -> + withSomeSNat (fromIntegral @Int @Integer n) $ \case + Just sn@SNat -> k (sn :$$ sh') + Nothing -> error $ "withShsFromShR: negative dimension size (" ++ show n ++ ")" + +-- shsFromShX re-exported + +-- | Produce an existential 'ShS' from an 'IShX'. If you already know that +-- @sh'@ is @MapJust@ of something, use 'shsFromShX' instead. +withShsFromShX :: IShX sh' -> (forall sh. Rank sh ~ Rank sh' => ShS sh -> r) -> r +withShsFromShX ZSX k = k ZSS +withShsFromShX (SKnown sn@SNat :$% sh) k = + withShsFromShX sh $ \sh' -> + k (sn :$$ sh') +withShsFromShX (SUnknown n :$% sh) k = + withShsFromShX sh $ \sh' -> + withSomeSNat (fromIntegral @Int @Integer n) $ \case + Just sn@SNat -> k (sn :$$ sh') + Nothing -> error $ "withShsFromShX: negative SUnknown dimension size (" ++ show n ++ ")" + +shsFromSSX :: StaticShX (MapJust sh) -> ShS sh +shsFromSSX = shsFromShX Prelude.. shxFromSSX + +-- ixsCast re-exported + +-- * To mixed + +ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i +ixxFromIxR ZIR = ZIX +ixxFromIxR (n :.: (idx :: IxR m i)) = + castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) (Proxy @m))) + (n :.% ixxFromIxR idx) + +ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i +ixxFromIxS ZIS = ZIX +ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh + +shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i +shxFromShR ZSR = ZSX +shxFromShR (n :$: (idx :: ShR m i)) = + castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) (Proxy @m))) + (SUnknown n :$% shxFromShR idx) + +shxFromShS :: ShS sh -> IShX (MapJust sh) +shxFromShS ZSS = ZSX +shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh + +-- ixxCast re-exported +-- shxCast re-exported +-- shxCast' re-exported + + +-- * Array conversions + +-- | The constructors that perform runtime shape checking are marked with a +-- tick (@'@): 'ConvXS'' and 'ConvXX''. For the other constructors, the types +-- ensure that the shapes are already compatible. To convert between 'Ranked' +-- and 'Shaped', go via 'Mixed'. +-- +-- The guiding principle behind 'Conversion' is that it should represent the +-- array restructurings, or perhaps re-presentations, that do not change the +-- underlying 'XArray's. This leads to the inclusion of some operations that do +-- not look like simple conversions (casts) at first glance, like 'ConvZip'. +-- +-- /Note/: Haddock gleefully renames type variables in constructors so that +-- they match the data type head as much as possible. See the source for a more +-- readable presentation of this data type. +data Conversion a b where + ConvId :: Conversion a a + ConvCmp :: Conversion b c -> Conversion a b -> Conversion a c + + ConvRX :: Conversion (Ranked n a) (Mixed (Replicate n Nothing) a) + ConvSX :: Conversion (Shaped sh a) (Mixed (MapJust sh) a) + + ConvXR :: Elt a + => Conversion (Mixed sh a) (Ranked (Rank sh) a) + ConvXS :: Conversion (Mixed (MapJust sh) a) (Shaped sh a) + ConvXS' :: (Rank sh ~ Rank sh', Elt a) + => ShS sh' + -> Conversion (Mixed sh a) (Shaped sh' a) + + ConvXX' :: (Rank sh ~ Rank sh', Elt a) + => StaticShX sh' + -> Conversion (Mixed sh a) (Mixed sh' a) + + ConvRR :: Conversion a b + -> Conversion (Ranked n a) (Ranked n b) + ConvSS :: Conversion a b + -> Conversion (Shaped sh a) (Shaped sh b) + ConvXX :: Conversion a b + -> Conversion (Mixed sh a) (Mixed sh b) + ConvT2 :: Conversion a a' + -> Conversion b b' + -> Conversion (a, b) (a', b') + + Conv0X :: Elt a + => Conversion a (Mixed '[] a) + ConvX0 :: Conversion (Mixed '[] a) a + + ConvNest :: Elt a => StaticShX sh + -> Conversion (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a)) + ConvUnnest :: Conversion (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a) + + ConvZip :: (Elt a, Elt b) + => Conversion (Mixed sh a, Mixed sh b) (Mixed sh (a, b)) + ConvUnzip :: (Elt a, Elt b) + => Conversion (Mixed sh (a, b)) (Mixed sh a, Mixed sh b) +deriving instance Show (Conversion a b) + +instance Category Conversion where + id = ConvId + (.) = ConvCmp + +convert :: (Elt a, Elt b) => Conversion a b -> a -> b +convert = \c x -> munScalar (go c (mscalar x)) + where + -- The 'esh' is the extension shape: the conversion happens under a whole + -- bunch of additional dimensions that it does not touch. These dimensions + -- are 'esh'. + -- The strategy is to unwind step-by-step to a large Mixed array, and to + -- perform the required checks and conversions when re-nesting back up. + go :: Conversion a b -> Mixed esh a -> Mixed esh b + go ConvId x = x + go (ConvCmp c1 c2) x = go c1 (go c2 x) + go ConvRX (M_Ranked x) = x + go ConvSX (M_Shaped x) = x + go (ConvXR @_ @sh) (M_Nest @esh esh x) + | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) + = let ssx' = ssxAppend (ssxFromShX esh) + (ssxReplicate (shxRank (shxDropSSX @esh @sh (ssxFromShX esh) (mshape x)))) + in M_Ranked (M_Nest esh (mcast ssx' x)) + go ConvXS (M_Nest esh x) = M_Shaped (M_Nest esh x) + go (ConvXS' @sh @sh' sh') (M_Nest @esh esh x) + | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh'))) + x)) + go (ConvXX' @sh @sh' ssx) (M_Nest @esh esh x) + | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x + go (ConvRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) + go (ConvSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) + go (ConvXX c) (M_Nest esh x) = M_Nest esh (go c x) + go (ConvT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2) + go Conv0X (x :: Mixed esh a) + | Refl <- lemAppNil @esh + = M_Nest (mshape x) x + go ConvX0 (M_Nest @esh _ x) + | Refl <- lemAppNil @esh + = x + go (ConvNest @_ @sh @sh' ssh) (M_Nest @esh esh x) + | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (ssxFromShX esh `ssxAppend` ssh) (mshape x)) x) + go (ConvUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x)) + | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh x + go ConvZip x = + -- no need to check that the two esh's are equal because they were zipped previously + let (M_Nest esh x1, M_Nest _ x2) = munzip x + in M_Nest esh (mzip x1 x2) + go ConvUnzip (M_Nest esh x) = + let (x1, x2) = munzip x + in mzip (M_Nest esh x1) (M_Nest esh x2) + + lemRankAppRankEq :: Rank sh ~ Rank sh' + => Proxy esh -> Proxy sh -> Proxy sh' + -> Rank (esh ++ sh) :~: Rank (esh ++ sh') + lemRankAppRankEq _ _ _ = unsafeCoerceRefl + + lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh + -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing) + lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl + + lemRankAppRankEqMapJust :: Rank sh ~ Rank sh' + => Proxy esh -> Proxy sh -> Proxy sh' + -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') + lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl +-- * Special cases of array conversions + mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a mcast ssh2 arr @@ -45,7 +296,7 @@ mcast ssh2 arr = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a -mtoRanked = castCastable (CastXR CastId) +mtoRanked = convert ConvXR rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a rtoMixed (Ranked arr) = arr @@ -59,7 +310,7 @@ rcastToMixed sshx rarr@(Ranked arr) mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') => ShS sh' -> Mixed sh a -> Shaped sh' a -mcastToShaped targetsh = castCastable (CastXS' targetsh CastId) +mcastToShaped targetsh = convert (ConvXS' targetsh) stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a stoMixed (Shaped arr) = arr @@ -82,91 +333,3 @@ rcastToShaped (Ranked arr) targetsh | Refl <- lemRankReplicate (shxRank (shxFromShS targetsh)) , Refl <- lemRankMapJust targetsh = mcastToShaped targetsh arr - -ixrFromIxS :: IxS sh i -> IxR (Rank sh) i -ixrFromIxS ZIS = ZIR -ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix - --- ixsFromIxR :: IIxR (Rank sh) -> IIxS sh --- ixsFromIxR = \ix -> go ix _ --- where --- go :: IIxR n -> (forall sh. KnownShS sh => IIxS sh -> r) -> r --- go ZIR k = k ZIS --- go (i :.: ix) k = go ix (i :.$) - -shrFromShS :: ShS sh -> IShR (Rank sh) -shrFromShS ZSS = ZSR -shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh - --- | The constructors that perform runtime shape checking are marked with a --- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure --- that the shapes are already compatible. To convert between 'Ranked' and --- 'Shaped', go via 'Mixed'. -data Castable a b where - CastId :: Castable a a - CastCmp :: Castable b c -> Castable a b -> Castable a c - - CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b) - CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b) - - CastXR :: Elt b - => Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b) - CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b) - CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh' - -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b) - - CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b) - CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b) - CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b) - - CastXX' :: (Rank sh ~ Rank sh', Elt b) => StaticShX sh' - -> Castable a b -> Castable (Mixed sh a) (Mixed sh' b) - -instance Category Castable where - id = CastId - (.) = CastCmp - -castCastable :: (Elt a, Elt b) => Castable a b -> a -> b -castCastable = \c x -> munScalar (go c (mscalar x)) - where - -- The 'esh' is the extension shape: the casting happens under a whole - -- bunch of additional dimensions that it does not touch. These dimensions - -- are 'esh'. - -- The strategy is to unwind step-by-step to a large Mixed array, and to - -- perform the required checks and castings when re-nesting back up. - go :: Castable a b -> Mixed esh a -> Mixed esh b - go CastId x = x - go (CastCmp c1 c2) x = go c1 (go c2 x) - go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x) - go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x) - go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) - | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) - = let x' = go c x - ssx' = ssxAppend (ssxFromShX esh) - (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShX esh)))) - in M_Ranked (M_Nest esh (mcast ssx' x')) - go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x)) - go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x) - | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') - = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh'))) - (go c x))) - go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) - go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) - go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) - go (CastXX' @sh @sh' ssx c) (M_Nest @esh esh x) - | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') - = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) (go c x) - - lemRankAppRankEq :: Rank sh ~ Rank sh' - => Proxy esh -> Proxy sh -> Proxy sh' - -> Rank (esh ++ sh) :~: Rank (esh ++ sh') - lemRankAppRankEq _ _ _ = unsafeCoerceRefl - - lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh - -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing) - lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl - - lemRankAppRankEqMapJust :: Rank sh ~ Rank sh' - => Proxy esh -> Proxy sh -> Proxy sh' - -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') - lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Nested/Internal/Lemmas.hs b/src/Data/Array/Nested/Internal/Lemmas.hs deleted file mode 100644 index b1589e0..0000000 --- a/src/Data/Array/Nested/Internal/Lemmas.hs +++ /dev/null @@ -1,59 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -module Data.Array.Nested.Internal.Lemmas where - -import Data.Proxy -import Data.Type.Equality -import GHC.TypeLits - -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Permutation -import Data.Array.Nested.Types -import Data.Array.Nested.Mixed.Shape -import Data.Array.Nested.Shaped.Shape - - -lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh -lemRankMapJust ZSS = Refl -lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl - -lemMapJustApp :: ShS sh1 -> Proxy sh2 - -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 -lemMapJustApp ZSS _ = Refl -lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl - -lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) -lemTakeLenMapJust PNil _ = Refl -lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl -lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty" - -lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) -lemDropLenMapJust PNil _ = Refl -lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl -lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty" - -lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) -lemIndexMapJust SZ (_ :$$ _) = Refl -lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) - | Refl <- lemIndexMapJust i sh - , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) - , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = Refl -lemIndexMapJust _ ZSS = error "Index of empty" - -lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) -lemPermuteMapJust PNil _ = Refl -lemPermuteMapJust (i `PCons` is) sh - | Refl <- lemPermuteMapJust is sh - , Refl <- lemIndexMapJust i sh - = Refl - -lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) -lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) - where - go :: ShS sh' -> StaticShX (MapJust sh') - go ZSS = ZKX - go (n :$$ sh) = SKnown n :!% go sh diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs index e6d970c..e089479 100644 --- a/src/Data/Array/Mixed/Lemmas.hs +++ b/src/Data/Array/Nested/Lemmas.hs @@ -6,7 +6,7 @@ {-# LANGUAGE TypeOperators #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed.Lemmas where +module Data.Array.Nested.Lemmas where import Data.Proxy import Data.Type.Equality @@ -14,10 +14,11 @@ import GHC.TypeLits import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Permutation +import Data.Array.Nested.Shaped.Shape import Data.Array.Nested.Types --- * Lemmas +-- * Lemmas about numbers and lists -- ** Nat @@ -27,7 +28,6 @@ lemLeqSuccSucc _ _ = unsafeCoerceRefl lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True lemLeqPlus _ _ _ = Refl - -- ** Append lemAppNil :: l ++ '[] :~: l @@ -39,42 +39,22 @@ lemAppAssoc _ _ _ = unsafeCoerceRefl lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l lemAppLeft _ Refl = Refl - --- ** Rank - -lemRankApp :: forall sh1 sh2. - StaticShX sh1 -> StaticShX sh2 - -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 -lemRankApp ZKX _ = Refl -lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2 - = lem (Proxy @(Rank sh1T)) Proxy Proxy $ - sym (lemRankApp ssh1 ssh2) - where - lem :: proxy a -> proxy b -> proxy c - -> (a + b :~: c) - -> c + 1 :~: (a + 1 + b) - lem _ _ _ Refl = Refl - -lemRankAppComm :: proxy sh1 -> proxy sh2 - -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) -lemRankAppComm _ _ = unsafeCoerceRefl - -lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n -lemRankReplicate _ = unsafeCoerceRefl - - --- ** Various type families +-- ** Simple type families lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a +{- for now, the plugins can't derive a type for this code, see + https://github.com/clash-lang/ghc-typelits-natnormalise/pull/98#issuecomment-3332842214 lemReplicatePlusApp sn _ _ = go sn where go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a go SZ = Refl go (SS (n :: SNat n'm1)) - | Refl <- lemReplicateSucc @a @n'm1 + | Refl <- lemReplicateSucc @a n , Refl <- go n - = sym (lemReplicateSucc @a @(n'm1 + m)) + = sym (lemReplicateSucc @a (SNat @(n'm1 + m))) +-} +lemReplicatePlusApp _ _ _ = unsafeCoerceRefl lemDropLenApp :: Rank l1 <= Rank l2 => Proxy l1 -> Proxy l2 -> Proxy rest @@ -107,6 +87,8 @@ lemKnownNatRankSSX ZKX = Dict lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict +-- * Lemmas about shapes + -- ** Known shapes lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing) @@ -116,3 +98,69 @@ lemKnownShX :: StaticShX sh -> Dict KnownShX sh lemKnownShX ZKX = Dict lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict + +lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) +lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) + where + go :: ShS sh' -> StaticShX (MapJust sh') + go ZSS = ZKX + go (n :$$ sh) = SKnown n :!% go sh + +-- ** Rank + +lemRankApp :: forall sh1 sh2. + StaticShX sh1 -> StaticShX sh2 + -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 +lemRankApp ZKX _ = Refl +lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2 + = lem (Proxy @(Rank sh1T)) Proxy Proxy $ + sym (lemRankApp ssh1 ssh2) + where + lem :: proxy a -> proxy b -> proxy c + -> (a + b :~: c) + -> c + 1 :~: (a + 1 + b) + lem _ _ _ Refl = Refl + +lemRankAppComm :: proxy sh1 -> proxy sh2 + -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) +lemRankAppComm _ _ = unsafeCoerceRefl + +lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate _ = unsafeCoerceRefl + +lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh +lemRankMapJust ZSS = Refl +lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl + +-- ** Related to MapJust and/or Permutation + +lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) +lemTakeLenMapJust PNil _ = Refl +lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl +lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty" + +lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) +lemDropLenMapJust PNil _ = Refl +lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl +lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty" + +lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) +lemIndexMapJust SZ (_ :$$ _) = Refl +lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) + | Refl <- lemIndexMapJust i sh + , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) + , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = Refl +lemIndexMapJust _ ZSS = error "Index of empty" + +lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) +lemPermuteMapJust PNil _ = Refl +lemPermuteMapJust (i `PCons` is) sh + | Refl <- lemPermuteMapJust is sh + , Refl <- lemIndexMapJust i sh + = Refl + +lemMapJustApp :: ShS sh1 -> Proxy sh2 + -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 +lemMapJustApp ZSS _ = Refl +lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 54bd5f2..0766e8c 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -7,6 +7,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -22,12 +23,14 @@ module Data.Array.Nested.Mixed where import Prelude hiding (mconcat) import Control.DeepSeq (NFData(..)) -import Control.Monad (forM_, when) +import Control.Monad (foldM_, forM_, when) import Control.Monad.ST +import Data.Array.Internal qualified as OI +import Data.Array.Internal.RankedG qualified as ORG +import Data.Array.Internal.RankedS qualified as ORS import Data.Array.RankedS qualified as S import Data.Bifunctor (bimap) import Data.Coerce -import Data.Foldable (toList) import Data.Int import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty(..)) @@ -38,17 +41,18 @@ import Data.Vector.Storable qualified as VS import Data.Vector.Storable.Mutable qualified as VSM import Foreign.C.Types (CInt) import Foreign.Storable (Storable) +import Foreign.Storable qualified as Storable import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) import GHC.TypeLits -import Data.Array.Mixed.Lemmas +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Permutation import Data.Array.Nested.Types +import Data.Array.Strided.Orthotope import Data.Array.XArray (XArray(..)) import Data.Array.XArray qualified as X -import Data.Array.Nested.Mixed.Shape -import Data.Array.Strided.Orthotope import Data.Bag @@ -91,6 +95,9 @@ import Data.Bag -- Unfortunately, the setup of the library requires us to list these primitive -- element types multiple times; to aid in extending the list, all these lists -- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. +-- +-- NOTE: if you add PRIMITIVE types, be sure to also add NumElt and IntElt +-- instances for them! -- | Wrapper type used as a tag to attach instances on. The instances on arrays @@ -118,6 +125,8 @@ instance PrimElt Bool instance PrimElt Int instance PrimElt Int64 instance PrimElt Int32 +instance PrimElt Int16 +instance PrimElt Int8 instance PrimElt CInt instance PrimElt Float instance PrimElt Double @@ -154,6 +163,8 @@ newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic ANDSHOW) newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic ANDSHOW) newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Int16 = M_Int16 (Mixed sh (Primitive Int16)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Int8 = M_Int8 (Mixed sh (Primitive Int8)) deriving (Eq, Ord, Generic ANDSHOW) newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic ANDSHOW) newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic ANDSHOW) newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic ANDSHOW) @@ -190,6 +201,8 @@ newtype instance MixedVecs s sh Bool = MV_Bool (VS.MVector s Bool) newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64) newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32) +newtype instance MixedVecs s sh Int16 = MV_Int16 (VS.MVector s Int16) +newtype instance MixedVecs s sh Int8 = MV_Int8 (VS.MVector s Int8) newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt) newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float) @@ -227,11 +240,13 @@ instance Elt a => NFData (Mixed sh a) where rnf = mrnf +{-# INLINE mliftNumElt1 #-} mliftNumElt1 :: (PrimElt a, PrimElt b) => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b) -> Mixed sh a -> Mixed sh b mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr)) +{-# INLINE mliftNumElt2 #-} mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c) => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c) -> Mixed sh a -> Mixed sh b -> Mixed sh c @@ -247,15 +262,15 @@ instance (NumElt a, PrimElt a) => Num (Mixed sh a) where abs = mliftNumElt1 (liftO1 . numEltAbs) signum = mliftNumElt1 (liftO1 . numEltSignum) -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS - fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal" + fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicatePrim" instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where - fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" + fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicatePrim" recip = mliftNumElt1 (liftO1 . floatEltRecip) (/) = mliftNumElt2 (liftO2 . floatEltDiv) instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" + pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicatePrim" exp = mliftNumElt1 (liftO1 . floatEltExp) log = mliftNumElt1 (liftO1 . floatEltLog) sqrt = mliftNumElt1 (liftO1 . floatEltSqrt) @@ -298,15 +313,9 @@ class Elt a where mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a mscalar :: a -> Mixed '[] a - -- | All arrays in the list, even subarrays inside @a@, must have the same - -- shape; if they do not, a runtime error will be thrown. See the - -- documentation of 'mgenerate' for more information about this restriction. - -- Furthermore, the length of the list must correspond with @n@: if @n@ is - -- @Just m@ and @m@ does not equal the length of the list, a runtime error is - -- thrown. - -- - -- Consider also 'mfromListPrim', which can avoid intermediate arrays. - mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a + -- | See 'mfromListOuter'. If the list does not have the given length, a + -- runtime error is thrown. 'mfromListPrimSN' is faster if applicable. + mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] @@ -351,11 +360,14 @@ class Elt a where -- | Tree giving the shape of every array component. type ShapeTree a + -- | Produces an internal representation of a tree of shapes of (potentially) + -- nested arrays. If the argument is an array, it requires that the array + -- is not empty (otherwise, its' guaranteed to crash early, if non-trivial). mshapeTree :: a -> ShapeTree a mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool - mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool + mshapeTreeIsEmpty :: Proxy a -> ShapeTree a -> Bool mshowShapeTree :: Proxy a -> ShapeTree a -> String @@ -363,26 +375,28 @@ class Elt a where -- this mixed array. marrayStrides :: Mixed sh a -> Bag [Int] - -- | Given the shape of this array, an index and a value, write the value at + -- | Given a linear index and a value, write the value at -- that index in the vectors. - mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () + mvecsWriteLinear :: Int -> a -> MixedVecs s sh a -> ST s () - -- | Given the shape of this array, an index and a value, write the value at + -- | Given a linear index and a value, write the value at -- that index in the vectors. - mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + mvecsWritePartialLinear :: Proxy sh -> Int -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () -- | Given the shape of this array, finalise the vectors into 'XArray's. mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) + -- | 'mvecsFreeze' but without copying the mutable vectors before freezing + -- them. If the mutable vectors are mutated after this function, referential + -- transparency is broken. + mvecsUnsafeFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) -- | Element types for which we have evidence of the (static part of the) shape -- in a type class constraint. Compare the instance contexts of the instances -- of this class with those of 'Elt': some instances have an additional -- "known-shape" constraint. -- --- This class is (currently) only required for 'mgenerate', --- 'Data.Array.Nested.Ranked.rgenerate' and --- 'Data.Array.Nested.Shaped.sgenerate'. +-- This class is (currently) only required for `memptyArray` and 'mgenerate'. class Elt a => KnownElt a where -- | Create an empty array. The given shape must have size zero; this may or may not be checked. memptyArrayUnsafe :: IShX sh -> Mixed sh a @@ -391,20 +405,27 @@ class Elt a => KnownElt a where -- this vector and an example for the contents. mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) + -- | Create initialised vectors for this array type, given the shape of + -- this vector and the chosen element. + mvecsReplicate :: IShX sh -> a -> ST s (MixedVecs s sh a) + mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) -- Arrays of scalars are basically just arrays of scalars. instance Storable a => Elt (Primitive a) where mshape (M_Primitive sh _) = sh + {-# INLINEABLE mindex #-} mindex (M_Primitive _ a) i = Primitive (X.index a i) - mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) + {-# INLINEABLE mindexPartial #-} + mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i) mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) - mfromListOuter l@(arr1 :| _) = - let sh = SUnknown (length l) :$% mshape arr1 - in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l))) + mfromListOuterSN sn l@(arr1 :| _) = + let sh = mshape arr1 + in M_Primitive (SKnown sn :$% sh) (X.fromListOuterSN sn sh ((\(M_Primitive _ a) -> a) <$> l)) mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) @@ -415,6 +436,7 @@ instance Storable a => Elt (Primitive a) where , let result = f ZKX a = M_Primitive (X.shape ssh2 result) result + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) @@ -426,6 +448,7 @@ instance Storable a => Elt (Primitive a) where , let result = f ZKX a b = M_Primitive (X.shape ssh3 result) result + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) @@ -440,7 +463,7 @@ instance Storable a => Elt (Primitive a) where => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) = let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' - sh2 = shxCast' sh1 ssh2 + sh2 = shxCast' ssh2 sh1 in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr) mtranspose perm (M_Primitive sh arr) = @@ -457,27 +480,33 @@ instance Storable a => Elt (Primitive a) where type ShapeTree (Primitive a) = () mshapeTree _ = () mshapeTreeEq _ () () = True - mshapeTreeEmpty _ () = False + mshapeTreeIsEmpty _ () = False mshowShapeTree _ () = "()" marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x + mvecsWriteLinear i (Primitive x) (MV_Primitive v) = VSM.write v i x - -- TODO: this use of toVector is suboptimal - mvecsWritePartial + -- TODO: this use of toVectorListT is suboptimal + mvecsWritePartialLinear :: forall sh' sh s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do + Proxy sh -> Int -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + mvecsWritePartialLinear _ i (M_Primitive sh' arr@(XArray (ORS.A (ORG.A sht t)))) (MV_Primitive v) = do let arrsh = X.shape (ssxFromShX sh') arr - offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) - VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) + offset = i * shxSize arrsh + f off el = do + VS.copy (VSM.slice off (VS.length el) v) el + return $! off + VS.length el + foldM_ f offset (OI.toVectorListT sht t) mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v + mvecsUnsafeFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.unsafeFreeze v -- [PRIMITIVE ELEMENT TYPES LIST] deriving via Primitive Bool instance Elt Bool deriving via Primitive Int instance Elt Int deriving via Primitive Int64 instance Elt Int64 deriving via Primitive Int32 instance Elt Int32 +deriving via Primitive Int16 instance Elt Int16 +deriving via Primitive Int8 instance Elt Int8 deriving via Primitive CInt instance Elt CInt deriving via Primitive Double instance Elt Double deriving via Primitive Float instance Elt Float @@ -486,6 +515,7 @@ deriving via Primitive () instance Elt () instance Storable a => KnownElt (Primitive a) where memptyArrayUnsafe sh = M_Primitive sh (X.empty sh) mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) + mvecsReplicate sh (Primitive a) = MV_Primitive <$> VSM.replicate (shxSize sh) a mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 -- [PRIMITIVE ELEMENT TYPES LIST] @@ -493,6 +523,8 @@ deriving via Primitive Bool instance KnownElt Bool deriving via Primitive Int instance KnownElt Int deriving via Primitive Int64 instance KnownElt Int64 deriving via Primitive Int32 instance KnownElt Int32 +deriving via Primitive Int16 instance KnownElt Int16 +deriving via Primitive Int8 instance KnownElt Int8 deriving via Primitive CInt instance KnownElt CInt deriving via Primitive Double instance KnownElt Double deriving via Primitive Float instance KnownElt Float @@ -504,12 +536,15 @@ instance (Elt a, Elt b) => Elt (a, b) where mindex (M_Tup2 a b) i = (mindex a i, mindex b i) mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) - mfromListOuter l = - M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) - (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) + mfromListOuterSN sn l = + M_Tup2 (mfromListOuterSN sn ((\(M_Tup2 x _) -> x) <$> l)) + (mfromListOuterSN sn ((\(M_Tup2 _ y) -> y) <$> l)) mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) + {-# INLINE mlift #-} mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) + {-# INLINE mlift2 #-} mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) + {-# INLINE mliftL #-} mliftL ssh2 f = let unzipT2l [] = ([], []) unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) @@ -531,20 +566,22 @@ instance (Elt a, Elt b) => Elt (a, b) where type ShapeTree (a, b) = (ShapeTree a, ShapeTree b) mshapeTree (x, y) = (mshapeTree x, mshapeTree y) mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' - mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 + mshapeTreeIsEmpty _ (t1, t2) = mshapeTreeIsEmpty (Proxy @a) t1 && mshapeTreeIsEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b - mvecsWrite sh i (x, y) (MV_Tup2 a b) = do - mvecsWrite sh i x a - mvecsWrite sh i y b - mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do - mvecsWritePartial sh i x a - mvecsWritePartial sh i y b + mvecsWriteLinear i (x, y) (MV_Tup2 a b) = do + mvecsWriteLinear i x a + mvecsWriteLinear i y b + mvecsWritePartialLinear proxy i (M_Tup2 x y) (MV_Tup2 a b) = do + mvecsWritePartialLinear proxy i x a + mvecsWritePartialLinear proxy i y b mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b + mvecsUnsafeFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsUnsafeFreeze sh a <*> mvecsUnsafeFreeze sh b instance (KnownElt a, KnownElt b) => KnownElt (a, b) where memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh) mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y + mvecsReplicate sh (x, y) = MV_Tup2 <$> mvecsReplicate sh x <*> mvecsReplicate sh y mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) -- Arrays of arrays are just arrays, but with more dimensions. @@ -557,23 +594,23 @@ instance Elt a => Elt (Mixed sh' a) where = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr)) mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a - mindex (M_Nest _ arr) i = mindexPartial arr i + mindex (M_Nest _ arr) = mindexPartial arr mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) mindexPartial (M_Nest sh arr) i | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + = M_Nest (shxDropIx i sh) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) mscalar = M_Nest ZSX - mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) - mfromListOuter l@(arr :| _) = - M_Nest (SUnknown (length l) :$% mshape arr) - (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) + mfromListOuterSN sn l@(arr :| _) = + M_Nest (SKnown sn :$% mshape arr) + (mfromListOuterSN sn ((\(M_Nest _ a) -> a) <$> l)) mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) @@ -591,6 +628,7 @@ instance Elt a => Elt (Mixed sh' a) where , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) = f (ssxAppend ssh' sshT) + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) @@ -609,6 +647,7 @@ instance Elt a => Elt (Mixed sh' a) where , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) = f (ssxAppend ssh' sshT) + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b)) @@ -632,14 +671,14 @@ instance Elt a => Elt (Mixed sh' a) where | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T - sh2 = shxCast' sh1 ssh2 + sh2 = shxCast' ssh2 sh1 in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr) mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) => Perm is -> Mixed sh (Mixed sh' a) -> Mixed (PermutePrefix is sh) (Mixed sh' a) mtranspose perm (M_Nest sh arr) - | let sh' = shxDropSh @sh @sh' (mshape arr) sh + | let sh' = shxDropSh @sh @sh' sh (mshape arr) , Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh') , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') @@ -657,28 +696,33 @@ instance Elt a => Elt (Mixed sh' a) where type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a) + -- This requires that @arr@ is not empty. mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShX (mshape arr))))) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + -- the array is empty if either there are no subarrays, or the subarrays themselves are empty + mshapeTreeIsEmpty _ (sh, t) = shxSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" marrayStrides (M_Nest _ arr) = marrayStrides arr - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs + mvecsWriteLinear :: forall s sh. Int -> Mixed sh' a -> MixedVecs s sh (Mixed sh' a) -> ST s () + mvecsWriteLinear idx val (MV_Nest _ vecs) = mvecsWritePartialLinear (Proxy @sh) idx val vecs - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) - -> ST s () - mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) + mvecsWritePartialLinear + :: forall sh1 sh2 s. + Proxy sh1 -> Int -> Mixed sh2 (Mixed sh' a) + -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) + -> ST s () + mvecsWritePartialLinear proxy idx (M_Nest _ arr) (MV_Nest _ vecs) | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs + = mvecsWritePartialLinear proxy idx arr vecs mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs + mvecsUnsafeFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsUnsafeFreeze (shxAppend sh sh') vecs instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh')))) @@ -689,10 +733,30 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where where sh' = mshape example + mvecsReplicate sh example = do + vecs <- mvecsUnsafeNew sh example + forM_ [0 .. shxSize sh - 1] $ \idx -> mvecsWriteLinear idx example vecs + -- this is a slow case, but the alternative, mvecsUnsafeNew with manual + -- writing in a loop, leads to every case being as slow + return vecs + mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) -memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a +-- | Given the shape of this array, an index and a value, write the value at +-- that index in the vectors. +{-# INLINE mvecsWrite #-} +mvecsWrite :: Elt a => IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () +mvecsWrite sh idx val vecs = mvecsWriteLinear (ixxToLinear sh idx) val vecs + +-- | Given the shape of this array, an index and a value, write the value at +-- that index in the vectors. +{-# INLINE mvecsWritePartial #-} +mvecsWritePartial :: forall sh sh' s a. Elt a => IShX sh -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () +mvecsWritePartial sh idx val vecs = mvecsWritePartialLinear (Proxy @sh) (ixxToLinear sh idx) val vecs + +-- TODO: should we provide a function that's just memptyArrayUnsafe but with a size==0 check? That may save someone a transpose somewhere +memptyArray :: forall sh a. KnownElt a => IShX sh -> Mixed (Just 0 : sh) a memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh) mrank :: Elt a => Mixed sh a -> SNat (Rank sh) @@ -719,38 +783,52 @@ msize = shxSize . mshape -- the entire hierarchy (after distributing out tuples) must be a rectangular -- array. The type of 'mgenerate' allows this requirement to be broken very -- easily, hence the runtime check. +-- +-- If your element type @a@ is a scalar, use the faster 'mgeneratePrim'. mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a mgenerate sh f = case shxEnum sh of [] -> memptyArrayUnsafe sh firstidx : restidxs -> let firstelem = f (ixxZero' sh) shapetree = mshapeTree firstelem - in if mshapeTreeEmpty (Proxy @a) shapetree + in if mshapeTreeIsEmpty (Proxy @a) shapetree then memptyArrayUnsafe sh else runST $ do vecs <- mvecsUnsafeNew sh firstelem mvecsWrite sh firstidx firstelem vecs - -- TODO: This is likely fine if @a@ is big, but if @a@ is a - -- scalar this array copying inefficient. Should improve this. forM_ restidxs $ \idx -> do let val = f idx when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ error "Data.Array.Nested mgenerate: generated values do not have equal shapes" mvecsWrite sh idx val vecs - mvecsFreeze sh vecs + mvecsUnsafeFreeze sh vecs -msumOuter1P :: forall sh n a. (Storable a, NumElt a) - => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) -msumOuter1P (M_Primitive (n :$% sh) arr) = +-- | An optimized special case of 'mgenerate', where the function results +-- are of a primitive type and so there's not need to check that all shapes +-- are equal. This is also generalized to an arbitrary @Num@ index type +-- compared to @mgenerate@. +{-# INLINE mgeneratePrim #-} +mgeneratePrim :: forall sh a i. (PrimElt a, Num i) + => IShX sh -> (IxX sh i -> a) -> Mixed sh a +mgeneratePrim sh f = + let g i = f (ixxFromLinear sh i) + in mfromVector sh $ VS.generate (shxSize sh) g + +msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a) + => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) +msumOuter1PrimP (M_Primitive (n :$% sh) arr) = let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr) -msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) - => Mixed (n : sh) a -> Mixed sh a -msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive +msumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a) + => Mixed (n : sh) a -> Mixed sh a +msumOuter1Prim = fromPrimitive . msumOuter1PrimP @sh @n @a . toPrimitive + +msumAllPrimP :: (Storable a, NumElt a) => Mixed sh (Primitive a) -> a +msumAllPrimP (M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a -msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr +msumAllPrim arr = msumAllPrimP (toPrimitive arr) mappend :: forall n m sh a. Elt a => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a @@ -759,7 +837,7 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 sn :$% sh = mshape arr1 sm :$% _ = mshape arr2 ssh = ssxFromShX sh - snm :: SMayNat () SNat (AddMaybe n m) + snm :: SMayNat () (AddMaybe n m) snm = case (sn, sm) of (SUnknown{}, _) -> SUnknown () (SKnown{}, SUnknown{}) -> SUnknown () @@ -781,36 +859,93 @@ mtoVectorP (M_Primitive _ v) = X.toVector v mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a mtoVector arr = mtoVectorP (toPrimitive arr) +-- | All arrays in the list, even subarrays inside @a@, must have the same +-- shape; if they do not, a runtime error will be thrown. See the +-- documentation of 'mgenerate' for more information about this restriction. +-- +-- Because the length of the 'NonEmpty' list is unknown, its spine must be +-- materialised in memory in order to compute its length. If its length is +-- already known, use 'mfromListOuterN' or 'mfromListOuterSN' to be able to +-- stream the list. +-- +-- If your array is 1-dimensional and contains scalars, use 'mfromList1Prim'. +mfromListOuter :: Elt a => NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a +mfromListOuter l = mfromListOuterN (length l) l + +-- | See 'mfromListOuter'. If the list does not have the given length, a +-- runtime error is thrown. 'mfromList1PrimN' is faster if applicable. +mfromListOuterN :: Elt a => Int -> NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a +mfromListOuterN n l = + withSomeSNat (fromIntegral n) $ \case + Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromListOuterSN sn l) + Nothing -> error $ "mfromListOuterN: length negative (" ++ show n ++ ")" + +-- | Because the length of the 'NonEmpty' list is unknown, its spine must be +-- materialised in memory in order to compute its length. If its length is +-- already known, use 'mfromList1N' or 'mfromList1SN' to be able to stream the +-- list. +-- +-- If the elements are scalars, 'mfromList1Prim' is faster. mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a -mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise? +mfromList1 = mfromListOuter . fmap mscalar + +-- | If the elements are scalars, 'mfromList1PrimN' is faster. A runtime error +-- is thrown if the list length does not match the given length. +mfromList1N :: Elt a => Int -> NonEmpty a -> Mixed '[Nothing] a +mfromList1N n = mfromListOuterN n . fmap mscalar + +-- | If the elements are scalars, 'mfromList1PrimSN' is faster. A runtime error +-- is thrown if the list length does not match the given length. +mfromList1SN :: Elt a => SNat n -> NonEmpty a -> Mixed '[Just n] a +mfromList1SN sn = mfromListOuterSN sn . fmap mscalar +-- This forall is there so that a simple type application can constrain the +-- shape, in case the user wants to use OverloadedLists for the shape. +-- | If the elements are scalars, 'mfromListPrimLinear' is faster. +mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a +mfromListLinear sh l = mreshape sh (mfromList1N (shxSize sh) l) + +-- | Because the length of the list is unknown, its spine must be materialised +-- in memory in order to compute its length. If its length is already known, +-- use 'mfromList1PrimN' or 'mfromList1PrimSN' to be able to stream the list. mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a mfromList1Prim l = let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l + xarr = X.fromList1 l in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr -mtoList1 :: Elt a => Mixed '[n] a -> [a] -mtoList1 = map munScalar . mtoListOuter +mfromList1PrimN :: PrimElt a => Int -> [a] -> Mixed '[Nothing] a +mfromList1PrimN n l = + withSomeSNat (fromIntegral n) $ \case + Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromList1PrimSN sn l) + Nothing -> error $ "mfromList1PrimN: length negative (" ++ show n ++ ")" -mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromListPrim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr +mfromList1PrimSN :: forall n a. PrimElt a => SNat n -> [a] -> Mixed '[Just n] a +mfromList1PrimSN sn l = + let sh = SKnown sn :$% ZSX + in fromPrimitive $ M_Primitive sh + $ if Storable.sizeOf (undefined :: a) > 0 + then X.fromList1SN sn l + else case l of -- don't force the list if all elements are the same + a0 : _ -> X.replicateScal sh a0 + [] -> X.fromList1SN sn l -mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a +mfromListPrimLinear :: forall sh a. PrimElt a => IShX sh -> [a] -> Mixed sh a mfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + let M_Primitive _ xarr = toPrimitive (mfromList1PrimN (shxSize sh) l) in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) --- This forall is there so that a simple type application can constrain the --- shape, in case the user wants to use OverloadedLists for the shape. -mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a -mfromListLinear sh l = mreshape sh (mfromList1 l) +mtoList :: Elt a => Mixed '[n] a -> [a] +mtoList = map munScalar . mtoListOuter mtoListLinear :: Elt a => Mixed sh a -> [a] -mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr)) -- TODO: optimise +mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr)) + +mtoListPrim :: PrimElt a => Mixed '[n] a -> [a] +mtoListPrim (toPrimitive -> M_Primitive _ arr) = X.toListLinear arr + +mtoListPrimLinear :: PrimElt a => Mixed sh a -> [a] +mtoListPrimLinear (toPrimitive -> M_Primitive _ arr) = X.toListLinear arr munScalar :: Elt a => Mixed '[] a -> a munScalar arr = mindex arr ZIX @@ -821,30 +956,63 @@ mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a munNest (M_Nest _ arr) = arr -mzip :: Mixed sh a -> Mixed sh b -> Mixed sh (a, b) -mzip = M_Tup2 +-- | The arguments must have equal shapes. If they do not, an error is raised. +mzip :: (Elt a, Elt b) => Mixed sh a -> Mixed sh b -> Mixed sh (a, b) +mzip a b + | Just Refl <- shxEqual (mshape a) (mshape b) = M_Tup2 a b + | otherwise = error "mzip: unequal shapes" munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b) munzip (M_Tup2 a b) = (a, b) -mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) - => StaticShX sh -> IShX sh2 - -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) - -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) -mrerankP ssh sh2 f (M_Primitive sh arr) = - let sh1 = shxDropSSX sh ssh - in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) - (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2) - (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) - arr) +mrerankPrimP :: forall sh1 sh2 sh a b. (Storable a, Storable b) + => IShX sh2 + -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) + -> Mixed sh (Mixed sh1 (Primitive a)) -> Mixed sh (Mixed sh2 (Primitive b)) +mrerankPrimP sh2 f (M_Nest sh (M_Primitive shsh1 arr)) = + let sh1 = shxDropSh sh shsh1 + in M_Nest sh $ + M_Primitive (shxAppend sh sh2) + (X.rerank (ssxFromShX sh) (ssxFromShX sh1) (ssxFromShX sh2) + (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) + arr) --- | See the caveats at @X.rerank@. -mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) - => StaticShX sh -> IShX sh2 - -> (Mixed sh1 a -> Mixed sh2 b) - -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b -mrerank ssh sh2 f (toPrimitive -> arr) = - fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr +-- | If the shape of the outer array (@sh@) is empty (i.e. contains a zero), +-- then there is no way to deduce the full shape of the output array (more +-- precisely, the @sh2@ part): that could only come from calling @f@, and there +-- are no subarrays to call @f@ on. @orthotope@ errors out in this case; we +-- choose to fill the shape with zeros wherever we cannot deduce what it should +-- be. +-- +-- For example, if: +-- +-- @ +-- -- arr has shape [3, 0, 4] and the inner arrays have shape [2, 21] +-- arr :: Mixed '[Just 3, Just 0, Just 4] (Mixed '[Just 2, Nothing] Int) +-- f :: Mixed '[Just 2, Nothing] Int -> Mixed '[Just 5, Nothing, Just 17] Float +-- @ +-- +-- then: +-- +-- @ +-- mrerankPrim _ f arr :: Mixed '[Just 3, Just 0, Just 4] (Mixed '[Just 5, Nothing, Just 17] Float) +-- @ +-- +-- and the inner arrays of the result will have shape @[5, 0, 17]@. Note the +-- @0@ in this shape: we don't know if @f@ intended to return an array with +-- shape 0 here (it probably didn't), but there is no better number to put here +-- absent a subarray of the input to pass to @f@. +-- +-- In this particular case the fact that @sh@ is empty was evident from the +-- type-level information, but the same situation occurs when @sh@ consists of +-- @Nothing@s, and some of those happen to be zero at runtime. +mrerankPrim :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) + => IShX sh2 + -> (Mixed sh1 a -> Mixed sh2 b) + -> Mixed sh (Mixed sh1 a) -> Mixed sh (Mixed sh2 b) +mrerankPrim sh2 f (M_Nest sh arr) = + let M_Nest sh' arr' = mrerankPrimP sh2 (toPrimitive . f . fromPrimitive) (M_Nest sh (toPrimitive arr)) + in M_Nest sh' (fromPrimitive arr') mreplicate :: forall sh sh' a. Elt a => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a @@ -856,20 +1024,28 @@ mreplicate sh arr = Refl -> X.replicate sh (ssxAppend ssh' sshT)) arr -mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) -mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) +mreplicatePrimP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) +mreplicatePrimP sh x = M_Primitive sh (X.replicateScal sh x) -mreplicateScal :: forall sh a. PrimElt a +mreplicatePrim :: forall sh a. PrimElt a => IShX sh -> a -> Mixed sh a -mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) +mreplicatePrim sh x = fromPrimitive (mreplicatePrimP sh x) -mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a -mslice i n arr = +msliceN :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a +msliceN i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr + +msliceSN :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a +msliceSN i n arr = let _ :$% sh = mshape arr in mlift (SKnown n :!% ssxFromShX sh) (\_ -> X.slice i n) arr -msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr +mslice :: forall i n k sh a. Elt a + => SMayNat Int i -> SMayNat Int n -> SMayNat Int k -> Mixed (AddMaybe (AddMaybe i n) k : sh) a -> Mixed (n : sh) a +mslice i n k arr = + let _ :$% sh = mshape arr + uarr = mcastPartial (ssxFromShX $ smnAddMaybe (smnAddMaybe i n) k :$% ZSX) (SUnknown () :!% ZKX) Proxy arr + in mcastPartial (SUnknown () :!% ZKX) (ssxFromShX $ n :$% ZSX) Proxy + $ mlift (SUnknown () :!% ssxFromShX sh) (\_ -> X.sliceU (fromSMayNat' i) (fromSMayNat' n)) uarr mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a mrev1 arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.rev1) arr @@ -929,11 +1105,13 @@ mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP +{-# INLINE mliftPrim #-} mliftPrim :: (PrimElt a, PrimElt b) => (a -> b) -> Mixed sh a -> Mixed sh b mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) +{-# INLINE mliftPrim2 #-} mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c) => (a -> b -> c) -> Mixed sh a -> Mixed sh b -> Mixed sh c diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 2f35ff9..77256ab 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -1,9 +1,12 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE MagicHash #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} @@ -14,9 +17,11 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE StrictData #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} @@ -31,14 +36,17 @@ import Data.Functor.Const import Data.Functor.Product import Data.Kind (Constraint, Type) import Data.Monoid (Sum(..)) -import Data.Proxy import Data.Type.Equality -import GHC.Exts (withDict) +import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict) import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits +#if !MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) +import GHC.TypeLits.Orphans () +#endif +import Data.Array.Nested.Mixed.Shape.Internal import Data.Array.Nested.Types @@ -55,7 +63,7 @@ type role ListX nominal representational type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type data ListX sh f where ZX :: ListX '[] f - (::%) :: f n -> ListX sh f -> ListX (n : sh) f + (::%) :: forall n sh {f}. f n -> ListX sh f -> ListX (n : sh) f deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) infixr 3 ::% @@ -100,21 +108,24 @@ listxEqual (n ::% sh) (m ::% sh') = Just Refl listxEqual _ _ = Nothing +{-# INLINE listxFmap #-} listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g listxFmap _ ZX = ZX listxFmap f (x ::% xs) = f x ::% listxFmap f xs -listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m -listxFold _ ZX = mempty -listxFold f (x ::% xs) = f x <> listxFold f xs +{-# INLINE listxFoldMap #-} +listxFoldMap :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m +listxFoldMap _ ZX = mempty +listxFoldMap f (x ::% xs) = f x <> listxFoldMap f xs listxLength :: ListX sh f -> Int -listxLength = getSum . listxFold (\_ -> Sum 1) +listxLength = getSum . listxFoldMap (\_ -> Sum 1) listxRank :: ListX sh f -> SNat (Rank sh) listxRank ZX = SNat listxRank (_ ::% l) | SNat <- listxRank l = SNat +{-# INLINE listxShow #-} listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS listxShow f l = showString "[" . go "" l . showString "]" where @@ -132,9 +143,13 @@ listxFromList topssh topl = go topssh topl ++ show (ssxLength topssh) ++ ", list has length " ++ show (length topl) ++ ")" -listxToList :: ListX sh' (Const i) -> [i] -listxToList ZX = [] -listxToList (Const i ::% is) = i : listxToList is +{-# INLINEABLE listxToList #-} +listxToList :: ListX sh (Const i) -> [i] +listxToList list = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ListX sh (Const i) -> is + go ZX = nil + go (Const i ::% is) = i `cons` go is + in go list) listxHead :: ListX (mn ': sh) f -> f mn listxHead (i ::% _) = i @@ -146,9 +161,9 @@ listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f listxAppend ZX idx' = idx' listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' -listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f -listxDrop long ZX = long -listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short +listxDrop :: forall f g sh sh'. ListX sh g -> ListX (sh ++ sh') f -> ListX sh' f +listxDrop ZX long = long +listxDrop (_ ::% short) long = case long of _ ::% long' -> listxDrop short long' listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh @@ -160,19 +175,18 @@ listxLast (x ::% ZX) = x listxZip :: ListX sh f -> ListX sh g -> ListX sh (Product f g) listxZip ZX ZX = ZX -listxZip (i ::% irest) (j ::% jrest) = - Pair i j ::% listxZip irest jrest +listxZip (i ::% irest) (j ::% jrest) = Pair i j ::% listxZip irest jrest +{-# INLINE listxZipWith #-} listxZipWith :: (forall a. f a -> g a -> h a) -> ListX sh f -> ListX sh g -> ListX sh h listxZipWith _ ZX ZX = ZX -listxZipWith f (i ::% is) (j ::% js) = - f i j ::% listxZipWith f is js +listxZipWith f (i ::% is) (j ::% js) = f i j ::% listxZipWith f is js -- * Mixed indices --- | This is a newtype over 'ListX'. +-- | An index into a mixed-typed array. type role IxX nominal representational type IxX :: [Maybe Nat] -> Type -> Type newtype IxX sh i = IxX (ListX sh (Const i)) @@ -191,6 +205,8 @@ infixr 3 :.% {-# COMPLETE ZIX, (:.%) #-} +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). type IIxX sh = IxX sh Int #ifdef OXAR_DEFAULT_SHOW_INSTANCES @@ -201,10 +217,18 @@ instance Show i => Show (IxX sh i) where #endif instance Functor (IxX sh) where + {-# INLINE fmap #-} fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l) instance Foldable (IxX sh) where - foldMap f (IxX l) = listxFold (f . getConst) l + {-# INLINE foldMap #-} + foldMap f (IxX l) = listxFoldMap (f . getConst) l + {-# INLINE foldr #-} + foldr _ z ZIX = z + foldr f z (x :.% xs) = f x (foldr f z xs) + toList = ixxToList + null ZIX = False + null _ = True instance NFData i => NFData (IxX sh i) @@ -225,6 +249,10 @@ ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i ixxFromList = coerce (listxFromList @_ @i) +{-# INLINEABLE ixxToList #-} +ixxToList :: forall sh i. IxX sh i -> [i] +ixxToList = coerce (listxToList @_ @i) + ixxHead :: IxX (n : sh) i -> i ixxHead (IxX list) = getConst (listxHead list) @@ -234,7 +262,7 @@ ixxTail (IxX list) = IxX (listxTail list) ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i ixxAppend = coerce (listxAppend @_ @(Const i)) -ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i +ixxDrop :: forall sh sh' i. IxX sh i -> IxX (sh ++ sh') i -> IxX sh' i ixxDrop = coerce (listxDrop @(Const i) @(Const i)) ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i @@ -243,64 +271,58 @@ ixxInit = coerce (listxInit @(Const i)) ixxLast :: forall n sh i. IxX (n : sh) i -> i ixxLast = coerce (listxLast @(Const i)) +ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i +ixxCast ZKX ZIX = ZIX +ixxCast (_ :!% sh) (i :.% idx) = i :.% ixxCast sh idx +ixxCast _ _ = error "ixxCast: ranks don't match" + ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j) ixxZip ZIX ZIX = ZIX ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js +{-# INLINE ixxZipWith #-} ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k ixxZipWith _ ZIX ZIX = ZIX ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js -ixxFromLinear :: IShX sh -> Int -> IIxX sh -ixxFromLinear = \sh i -> case go sh i of - (idx, 0) -> idx - _ -> error $ "ixxFromLinear: out of range (" ++ show i ++ - " in array of shape " ++ show sh ++ ")" +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixxToLinear #-} +ixxToLinear :: Num i => IShX sh -> IxX sh i -> i +ixxToLinear = \sh i -> go sh i 0 where - -- returns (index in subarray, remaining index in enclosing array) - go :: IShX sh -> Int -> (IIxX sh, Int) - go ZSX i = (ZIX, i) - go (n :$% sh) i = - let (idx, i') = go sh i - (upi, locali) = i' `quotRem` fromSMayNat' n - in (locali :.% idx, upi) - -ixxToLinear :: IShX sh -> IIxX sh -> Int -ixxToLinear = \sh i -> fst (go sh i) - where - -- returns (index in subarray, size of subarray) - go :: IShX sh -> IIxX sh -> (Int, Int) - go ZSX ZIX = (0, 1) - go (n :$% sh) (i :.% ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, fromSMayNat' n * sz) + go :: Num i => IShX sh -> IxX sh i -> i -> i + go ZSX ZIX !a = a + go (n :$% sh) (i :.% ix) a = go sh ix (fromIntegral (fromSMayNat' n) * a + i) -- * Mixed shapes -data SMayNat i f n where - SUnknown :: i -> SMayNat i f Nothing - SKnown :: f n -> SMayNat i f (Just n) -deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) -deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) -deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) +data SMayNat i n where + SUnknown :: i -> SMayNat i Nothing + SKnown :: {-# UNPACK #-} SNat n -> SMayNat i (Just n) +deriving instance Show i => Show (SMayNat i n) +deriving instance Eq i => Eq (SMayNat i n) +deriving instance Ord i => Ord (SMayNat i n) -instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where +instance (NFData i, forall m. NFData (SNat m)) => NFData (SMayNat i n) where rnf (SUnknown i) = rnf i rnf (SKnown x) = rnf x -instance TestEquality f => TestEquality (SMayNat i f) where +instance TestEquality (SMayNat i) where testEquality SUnknown{} SUnknown{} = Just Refl testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl testEquality _ _ = Nothing +{-# INLINE fromSMayNat #-} fromSMayNat :: (n ~ Nothing => i -> r) - -> (forall m. n ~ Just m => f m -> r) - -> SMayNat i f n -> r + -> (forall m. n ~ Just m => SNat m -> r) + -> SMayNat i n -> r fromSMayNat f _ (SUnknown i) = f i fromSMayNat _ g (SKnown s) = g s -fromSMayNat' :: SMayNat Int SNat n -> Int +{-# INLINE fromSMayNat' #-} +fromSMayNat' :: SMayNat Int n -> Int fromSMayNat' = fromSMayNat id fromSNat' type family AddMaybe n m where @@ -308,7 +330,7 @@ type family AddMaybe n m where AddMaybe (Just _) Nothing = Nothing AddMaybe (Just n) (Just m) = Just (n + m) -smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) +smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m) smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) @@ -317,7 +339,7 @@ smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) -- | This is a newtype over 'ListX'. type role ShX nominal representational type ShX :: [Maybe Nat] -> Type -> Type -newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) +newtype ShX sh i = ShX (ListX sh (SMayNat i)) deriving (Eq, Ord, Generic) pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i @@ -326,7 +348,7 @@ pattern ZSX = ShX ZX pattern (:$%) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) - => SMayNat i SNat n -> ShX sh i -> ShX sh1 i + => SMayNat i n -> ShX sh i -> ShX sh1 i pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i)) where i :$% ShX shl = ShX (i ::% shl) infixr 3 :$% @@ -343,6 +365,7 @@ instance Show i => Show (ShX sh i) where #endif instance Functor (ShX sh) where + {-# INLINE fmap #-} fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l) instance NFData i => NFData (ShX sh i) where @@ -390,10 +413,10 @@ shxSize :: IShX sh -> Int shxSize ZSX = 1 shxSize (n :$% sh) = fromSMayNat' n * shxSize sh -shxFromList :: StaticShX sh -> [Int] -> ShX sh Int +shxFromList :: StaticShX sh -> [Int] -> IShX sh shxFromList topssh topl = go topssh topl where - go :: StaticShX sh' -> [Int] -> ShX sh' Int + go :: StaticShX sh' -> [Int] -> IShX sh' go ZKX [] = ZSX go (SKnown sn :!% sh) (i : is) | i == fromSNat' sn = SKnown sn :$% go sh is @@ -404,48 +427,57 @@ shxFromList topssh topl = go topssh topl ++ show (ssxLength topssh) ++ ", list has length " ++ show (length topl) ++ ")" +{-# INLINEABLE shxToList #-} shxToList :: IShX sh -> [Int] -shxToList ZSX = [] -shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh +shxToList list = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: IShX sh -> is + go ZSX = nil + go (smn :$% sh) = fromSMayNat' smn `cons` go sh + in go list) + +shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i +shxFromSSX ZKX = ZSX +shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh)) + | Refl <- lemMapJustCons @sh Refl + = SKnown n :$% shxFromSSX sh +shxFromSSX (SUnknown _ :!% _) = error "unreachable" -- | This may fail if @sh@ has @Nothing@s in it. -shxFromSSX' :: StaticShX sh -> Maybe (IShX sh) -shxFromSSX' ZKX = Just ZSX -shxFromSSX' (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX' sh -shxFromSSX' (SUnknown _ :!% _) = Nothing +shxFromSSX2 :: StaticShX sh -> Maybe (ShX sh i) +shxFromSSX2 ZKX = Just ZSX +shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh +shxFromSSX2 (SUnknown _ :!% _) = Nothing shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i -shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) +shxAppend = coerce (listxAppend @_ @(SMayNat i)) -shxHead :: ShX (n : sh) i -> SMayNat i SNat n +shxHead :: ShX (n : sh) i -> SMayNat i n shxHead (ShX list) = listxHead list shxTail :: ShX (n : sh) i -> ShX sh i shxTail (ShX list) = ShX (listxTail list) -shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i -shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) +shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i +shxDropSSX = coerce (listxDrop @(SMayNat i) @(SMayNat ())) -shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i -shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) +shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i +shxDropIx = coerce (listxDrop @(SMayNat i) @(Const j)) -shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i -shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) +shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i +shxDropSh = coerce (listxDrop @(SMayNat i) @(SMayNat i)) shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i -shxInit = coerce (listxInit @(SMayNat i SNat)) +shxInit = coerce (listxInit @(SMayNat i)) -shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh)) -shxLast = coerce (listxLast @(SMayNat i SNat)) +shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i (Last (n : sh)) +shxLast = coerce (listxLast @(SMayNat i)) -shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i -shxTakeSSX _ = flip go - where - go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i - go ZKX _ = ZSX - go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh +shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i +shxTakeSSX _ ZKX _ = ZSX +shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh -shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n) +{-# INLINE shxZipWith #-} +shxZipWith :: (forall n. SMayNat i n -> SMayNat j n -> SMayNat k n) -> ShX sh i -> ShX sh j -> ShX sh k shxZipWith _ ZSX ZSX = ZSX shxZipWith f (i :$% is) (j :$% js) = f i j :$% shxZipWith f is js @@ -456,28 +488,37 @@ shxCompleteZeros ZKX = ZSX shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh -shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) +shxSplitApp :: proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) shxSplitApp _ ZKX idx = (ZSX, idx) shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx) shxEnum :: IShX sh -> [IIxX sh] -shxEnum = \sh -> go sh id [] +shxEnum = shxEnum' + +{-# INLINABLE shxEnum' #-} -- ensure this can be specialised at use site +shxEnum' :: Num i => IShX sh -> [IxX sh i] +shxEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shxSize sh - 1]] where - go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] - go ZSX f = (f ZIX :) - go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] + suffixes = drop 1 (scanr (*) 1 (shxToList sh)) + + fromLin :: Num i => IShX sh -> [Int] -> Int# -> IxX sh i + fromLin ZSX _ _ = ZIX + fromLin (_ :$% sh') (I# suff# : suffs) i# = + let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh' + in fromIntegral (I# q#) :.% fromLin sh' suffs r# + fromLin _ _ _ = error "impossible" -shxCast :: IShX sh -> StaticShX sh' -> Maybe (IShX sh') -shxCast ZSX ZKX = Just ZSX -shxCast (SKnown n :$% sh) (SKnown m :!% ssh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast sh ssh -shxCast (SUnknown n :$% sh) (SKnown m :!% ssh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast sh ssh -shxCast (SKnown n :$% sh) (SUnknown () :!% ssh) = (SUnknown (fromSNat' n) :$%) <$> shxCast sh ssh -shxCast (SUnknown n :$% sh) (SUnknown () :!% ssh) = (SUnknown n :$%) <$> shxCast sh ssh +shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh') +shxCast ZKX ZSX = Just ZSX +shxCast (SKnown m :!% ssh) (SKnown n :$% sh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast ssh sh +shxCast (SKnown m :!% ssh) (SUnknown n :$% sh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SKnown n :$% sh) = (SUnknown (fromSNat' n) :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SUnknown n :$% sh) = (SUnknown n :$%) <$> shxCast ssh sh shxCast _ _ = Nothing -- | Partial version of 'shxCast'. -shxCast' :: IShX sh -> StaticShX sh' -> IShX sh' -shxCast' sh ssh = case shxCast sh ssh of +shxCast' :: StaticShX sh' -> IShX sh -> IShX sh' +shxCast' ssh sh = case shxCast ssh sh of Just sh' -> sh' Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")" @@ -486,7 +527,7 @@ shxCast' sh ssh = case shxCast sh ssh of -- | The part of a shape that is statically known. (A newtype over 'ListX'.) type StaticShX :: [Maybe Nat] -> Type -newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) +newtype StaticShX sh = StaticShX (ListX sh (SMayNat ())) deriving (Eq, Ord) pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh @@ -495,7 +536,7 @@ pattern ZKX = StaticShX ZX pattern (:!%) :: forall {sh1}. forall n sh. (n : sh ~ sh1) - => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 + => SMayNat () n -> StaticShX sh -> StaticShX sh1 pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i)) where i :!% StaticShX shl = StaticShX (i ::% shl) infixr 3 :!% @@ -531,38 +572,44 @@ ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') ssxAppend ZKX sh' = sh' ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' -ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n +ssxHead :: StaticShX (n : sh) -> SMayNat () n ssxHead (StaticShX list) = listxHead list ssxTail :: StaticShX (n : sh) -> StaticShX sh ssxTail (_ :!% ssh) = ssh -ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' -ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) +ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSSX = coerce (listxDrop @(SMayNat ()) @(SMayNat ())) + +ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropIx = coerce (listxDrop @(SMayNat ()) @(Const i)) + +ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSh = coerce (listxDrop @(SMayNat ()) @(SMayNat i)) ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) -ssxInit = coerce (listxInit @(SMayNat () SNat)) +ssxInit = coerce (listxInit @(SMayNat ())) -ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh)) -ssxLast = coerce (listxLast @(SMayNat () SNat)) +ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () (Last (n : sh)) +ssxLast = coerce (listxLast @(SMayNat ())) ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) ssxReplicate SZ = ZKX ssxReplicate (SS (n :: SNat n')) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n' + | Refl <- lemReplicateSucc @(Nothing @Nat) n = SUnknown () :!% ssxReplicate n -ssxIotaFrom :: Int -> StaticShX sh -> [Int] -ssxIotaFrom _ ZKX = [] -ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh +ssxIotaFrom :: StaticShX sh -> Int -> [Int] +ssxIotaFrom ZKX _ = [] +ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i + 1) -ssxFromShX :: IShX sh -> StaticShX sh +ssxFromShX :: ShX sh i -> StaticShX sh ssxFromShX ZSX = ZKX ssxFromShX (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShX sh ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing) ssxFromSNat SZ = ZKX -ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n +ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) n = SUnknown () :!% ssxFromSNat n -- | Evidence for the static part of a shape. This pops up only when you are @@ -574,7 +621,7 @@ instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SK instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r -withKnownShX k = withDict @(KnownShX sh) k +withKnownShX = withDict @(KnownShX sh) -- * Flattening @@ -587,18 +634,18 @@ type family Flatten' acc sh where Flatten' acc (Just n : sh) = Flatten' (acc * n) sh -- This function is currently unused -ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) +ssxFlatten :: StaticShX sh -> SMayNat () (Flatten sh) ssxFlatten = go (SNat @1) where - go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) + go :: SNat acc -> StaticShX sh -> SMayNat () (Flatten' acc sh) go acc ZKX = SKnown acc go _ (SUnknown () :!% _) = SUnknown () go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh -shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) +shxFlatten :: IShX sh -> SMayNat Int (Flatten sh) shxFlatten = go (SNat @1) where - go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) + go :: SNat acc -> IShX sh -> SMayNat Int (Flatten' acc sh) go acc ZSX = SKnown acc go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh @@ -626,3 +673,8 @@ instance KnownShX sh => IsList (ShX sh Int) where type Item (ShX sh Int) = Int fromList = shxFromList (knownShX @sh) toList = shxToList + +-- This needs to be at the bottom of the file to not split the file into +-- pieces; some of the shape/index stuff refers to StaticShX. +$(ixFromLinearStub "ixxFromLinear" [t| IShX |] [t| IxX |] [p| ZSX |] (\a b -> [p| (fromSMayNat' -> $a) :$% $b |]) [| ZIX |] [| (:.%) |] [| shxToList |]) +{-# INLINEABLE ixxFromLinear #-} diff --git a/src/Data/Array/Nested/Mixed/Shape/Internal.hs b/src/Data/Array/Nested/Mixed/Shape/Internal.hs new file mode 100644 index 0000000..2a86ac1 --- /dev/null +++ b/src/Data/Array/Nested/Mixed/Shape/Internal.hs @@ -0,0 +1,59 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +module Data.Array.Nested.Mixed.Shape.Internal where + +import Language.Haskell.TH + + +-- | A TH stub function to avoid having to write the same code three times for +-- the three kinds of shapes. +ixFromLinearStub :: String + -> TypeQ -> TypeQ + -> PatQ -> (PatQ -> PatQ -> PatQ) + -> ExpQ -> ExpQ + -> ExpQ + -> DecsQ +ixFromLinearStub fname' ishty ixty zshC consshC ixz ixcons shtolist = do + let fname = mkName fname' + typesig <- [t| forall i sh. Num i => $ishty sh -> Int -> $ixty sh i |] + + locals <- [d| + -- Unfold first iteration of fromLin to do the range check. + -- Don't inline this function at first to allow GHC to inline the outer + -- function and realise that 'suffixes' is shared. But then later inline it + -- anyway, to avoid the function call. Removing the pragma makes GHC + -- somehow unable to recognise that 'suffixes' can be shared in a loop. + {-# NOINLINE [0] fromLin0 #-} + fromLin0 :: Num i => $ishty sh -> [Int] -> Int -> $ixty sh i + fromLin0 sh suffixes i = + if i < 0 then outrange sh i else + case (sh, suffixes) of + ($zshC, _) | i > 0 -> outrange sh i + | otherwise -> $ixz + ($(consshC (varP (mkName "n")) (varP (mkName "sh'"))), suff : suffs) -> + let (q, r) = i `quotRem` suff + in if q >= n then outrange sh i else + $ixcons (fromIntegral q) (fromLin sh' suffs r) + _ -> error "impossible" + + fromLin :: Num i => $ishty sh -> [Int] -> Int -> $ixty sh i + fromLin $zshC _ !_ = $ixz + fromLin ($(consshC wildP (varP (mkName "sh'")))) (suff : suffs) i = + let (q, r) = i `quotRem` suff -- suff == shrSize sh' + in $ixcons (fromIntegral q) (fromLin sh' suffs r) + fromLin _ _ _ = error "impossible" + + {-# NOINLINE outrange #-} + outrange :: $ishty sh -> Int -> a + outrange sh i = error $ fname' ++ ": out of range (" ++ show i ++ + " in array of shape " ++ show sh ++ ")" |] + + body <- [| + \sh -> -- give this function arity 1 so that 'suffixes' is shared when + -- it's called many times + let suffixes = drop 1 (scanr (*) 1 ($shtolist sh)) + in fromLin0 sh suffixes |] + + return [SigD fname typesig + ,FunD fname [Clause [] (NormalB body) locals]] diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs index 031755f..6bebcfb 100644 --- a/src/Data/Array/Nested/Permutation.hs +++ b/src/Data/Array/Nested/Permutation.hs @@ -1,10 +1,10 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -25,6 +25,7 @@ import Data.Proxy import Data.Type.Bool import Data.Type.Equality import Data.Type.Ord +import GHC.Exts (withDict) import GHC.TypeError import GHC.TypeLits import GHC.TypeNats qualified as TN @@ -36,8 +37,8 @@ import Data.Array.Nested.Types -- * Permutations -- | A "backward" permutation of a dimension list. The operation on the --- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute' --- for code that implements this. +-- dimension list is most similar to @backpermute@ in the @vector@ package; see +-- 'Permute' for code that implements this. data Perm list where PNil :: Perm '[] PCons :: SNat a -> Perm l -> Perm (a : l) @@ -45,15 +46,22 @@ infixr 5 `PCons` deriving instance Show (Perm list) deriving instance Eq (Perm list) +instance TestEquality Perm where + testEquality PNil PNil = Just Refl + testEquality (x `PCons` xs) (y `PCons` ys) + | Just Refl <- testEquality x y + , Just Refl <- testEquality xs ys = Just Refl + testEquality _ _ = Nothing + permRank :: Perm list -> SNat (Rank list) permRank PNil = SNat permRank (_ `PCons` l) | SNat <- permRank l = SNat -permFromList :: [Int] -> (forall list. Perm list -> r) -> r -permFromList [] k = k PNil -permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case - Just sn -> permFromList xs $ \list -> k (sn `PCons` list) - Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x +permFromListCont :: [Int] -> (forall list. Perm list -> r) -> r +permFromListCont [] k = k PNil +permFromListCont (x : xs) k = withSomeSNat (fromIntegral x) $ \case + Just sn -> permFromListCont xs $ \list -> k (sn `PCons` list) + Nothing -> error $ "Data.Array.Nested.Permutation.permFromListCont: negative number in list: " ++ show x permToList :: Perm list -> [Natural] permToList PNil = mempty @@ -119,6 +127,9 @@ class KnownPerm l where makePerm :: Perm l instance KnownPerm '[] where makePerm = PNil instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm +withKnownPerm :: forall l r. Perm l -> (KnownPerm l => r) -> r +withKnownPerm = withDict @(KnownPerm l) + -- | Untyped permutations for ranked arrays type PermR = [Int] @@ -190,22 +201,22 @@ ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is ixxPermutePrefix = coerce (listxPermutePrefix @(Const i)) ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) -ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) +ssxTakeLen = coerce (listxTakeLen @(SMayNat ())) ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) -ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) +ssxDropLen = coerce (listxDropLen @(SMayNat ())) ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) -ssxPermute = coerce (listxPermute @(SMayNat () SNat)) +ssxPermute = coerce (listxPermute @(SMayNat ())) -ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh) -ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) +ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () (Index i sh) +ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat ()) p1 p2 i) ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) -ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) +ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat ())) shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) -shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) +shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int)) -- * Operations on permutations @@ -224,7 +235,7 @@ permInverse = \perm k -> ++ " ; invperm = " ++ show invperm) (permCheckPermutation invperm (k invperm - (\ssh -> case provePermInverse perm invperm ssh of + (\ssh -> case permCheckInverse perm invperm ssh of Just eq -> eq Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm ++ " ; invperm = " ++ show invperm))) @@ -238,9 +249,9 @@ permInverse = \perm k -> toHList [] k = k PNil toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l) - provePermInverse :: Perm is -> Perm is' -> StaticShX sh + permCheckInverse :: Perm is -> Perm is' -> StaticShX sh -> Maybe (Permute is' (Permute is sh) :~: sh) - provePermInverse perm perminv ssh = + permCheckInverse perm perminv ssh = ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh type family MapSucc is where @@ -264,7 +275,13 @@ lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl lemRankDropLen :: forall is sh. (Rank is <= Rank sh) => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is lemRankDropLen ZKX PNil = Refl -lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!% sh) (_ `PCons` is) + | Refl <- lemRankDropLen sh is +#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) + = Refl +#else + = unsafeCoerceRefl +#endif lemRankDropLen (_ :!% _) PNil = Refl lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0" diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index e5c51ef..f668c3e 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -29,17 +29,17 @@ import Foreign.Storable (Storable) import GHC.TypeLits import GHC.TypeNats qualified as TN -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Permutation -import Data.Array.Nested.Types -import Data.Array.XArray (XArray(..)) -import Data.Array.XArray qualified as X import Data.Array.Nested.Convert +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation import Data.Array.Nested.Ranked.Base import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith +import Data.Array.XArray (XArray(..)) +import Data.Array.XArray qualified as X remptyArray :: KnownElt a => Ranked 1 a @@ -49,9 +49,11 @@ remptyArray = mtoRanked (memptyArray ZSX) rsize :: Elt a => Ranked n a -> Int rsize = shrSize . rshape +{-# INLINEABLE rindex #-} rindex :: Elt a => Ranked n a -> IIxR n -> a rindex (Ranked arr) idx = mindex arr (ixxFromIxR idx) +{-# INLINEABLE rindexPartial #-} rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a rindexPartial (Ranked arr) idx = Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) @@ -59,7 +61,8 @@ rindexPartial (Ranked arr) idx = (ixxFromIxR idx)) -- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. +-- See the documentation of 'mgenerate' for more details; see also +-- 'rgeneratePrim'. rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a rgenerate sh f | sn@SNat <- shrRank sh @@ -67,7 +70,16 @@ rgenerate sh f , Refl <- lemRankReplicate sn = Ranked (mgenerate (shxFromShR sh) (f . ixrFromIxX)) +-- | See 'mgeneratePrim'. +{-# INLINE rgeneratePrim #-} +rgeneratePrim :: forall n a i. (PrimElt a, Num i) + => IShR n -> (IxR n i -> a) -> Ranked n a +rgeneratePrim sh f = + let g i = f (ixrFromLinear sh i) + in rfromVector sh $ VS.generate (shrSize sh) g + -- | See the documentation of 'mlift'. +{-# INLINE rlift #-} rlift :: forall n1 n2 a. Elt a => SNat n2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) @@ -75,22 +87,26 @@ rlift :: forall n1 n2 a. Elt a rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr) -- | See the documentation of 'mlift2'. +{-# INLINE rlift2 #-} rlift2 :: forall n1 n2 n3 a. Elt a => SNat n3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b) -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) -rsumOuter1P :: forall n a. - (Storable a, NumElt a) - => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) -rsumOuter1P (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = Ranked (msumOuter1P arr) +rsumOuter1PrimP :: forall n a. + (Storable a, NumElt a) + => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) +rsumOuter1PrimP (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = Ranked (msumOuter1PrimP arr) + +rsumOuter1Prim :: forall n a. (NumElt a, PrimElt a) + => Ranked (n + 1) a -> Ranked n a +rsumOuter1Prim = rfromPrimitive . rsumOuter1PrimP . rtoPrimitive -rsumOuter1 :: forall n a. (NumElt a, PrimElt a) - => Ranked (n + 1) a -> Ranked n a -rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive +rsumAllPrimP :: (Storable a, NumElt a) => Ranked n (Primitive a) -> a +rsumAllPrimP (Ranked arr) = msumAllPrimP arr rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a rsumAllPrim (Ranked arr) = msumAllPrim arr @@ -108,7 +124,7 @@ rtranspose perm arr rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a rconcat - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = coerce mconcat rappend :: forall n a. Elt a @@ -116,7 +132,7 @@ rappend :: forall n a. Elt a rappend arr1 arr2 | sn@SNat <- rrank arr1 , Dict <- lemKnownReplicate sn - , Refl <- lemReplicateSucc @(Nothing @Nat) @n + , Refl <- lemReplicateSucc @(Nothing @Nat) (SNat @n) = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) arr1 arr2 @@ -137,51 +153,82 @@ rtoVectorP = coerce mtoVectorP rtoVector :: PrimElt a => Ranked n a -> VS.Vector a rtoVector = coerce mtoVector +-- | All arrays in the list, even subarrays inside @a@, must have the same +-- shape; if they do not, a runtime error will be thrown. See the +-- documentation of 'mgenerate' for more information about this restriction. +-- +-- Because the length of the 'NonEmpty' list is unknown, its spine must be +-- materialised in memory in order to compute its length. If its length is +-- already known, use 'rfromListOuterN' to be able to stream the list. +-- +-- If your array is 1-dimensional and contains scalars, use 'rfromList1Prim'. rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a rfromListOuter l - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) +-- | See 'rfromListOuter'. If the list does not have the given length, a +-- runtime error is thrown. 'rfromList1PrimN' is faster if applicable. +rfromListOuterN :: forall n a. Elt a => Int -> NonEmpty (Ranked n a) -> Ranked (n + 1) a +rfromListOuterN n l + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = Ranked (mfromListOuterN n (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) + +-- | Because the length of the 'NonEmpty' list is unknown, its spine must be +-- materialised in memory in order to compute its length. If its length is +-- already known, use 'rfromList1N' to be able to stream the list. +-- +-- If the elements are scalars, 'rfromList1Prim' is faster. rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a -rfromList1 l = Ranked (mfromList1 l) +rfromList1 = coerce mfromList1 + +-- | If the elements are scalars, 'rfromList1PrimN' is faster. A runtime error +-- is thrown if the list length does not match the given length. +rfromList1N :: Elt a => Int -> NonEmpty a -> Ranked 1 a +rfromList1N = coerce mfromList1N + +-- | If the elements are scalars, 'rfromListPrimLinear' is faster. +rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a +rfromListLinear sh l = Ranked (mfromListLinear (shxFromShR sh) l) +-- | Because the length of the list is unknown, its spine must be materialised +-- in memory in order to compute its length. If its length is already known, +-- use 'rfromList1PrimN' to be able to stream the list. rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a -rfromList1Prim l = Ranked (mfromList1Prim l) +rfromList1Prim = coerce mfromList1Prim + +rfromList1PrimN :: PrimElt a => Int -> [a] -> Ranked 1 a +rfromList1PrimN = coerce mfromList1PrimN + +rfromListPrimLinear :: forall n a. PrimElt a => IShR n -> [a] -> Ranked n a +rfromListPrimLinear sh l = Ranked (mfromListPrimLinear (shxFromShR sh) l) rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] rtoListOuter (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) -rtoList1 :: Elt a => Ranked 1 a -> [a] -rtoList1 = map runScalar . rtoListOuter - -rfromListPrim :: PrimElt a => [a] -> Ranked 1 a -rfromListPrim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a -rfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in Ranked $ fromPrimitive $ M_Primitive (shxFromShR sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShR sh) xarr) - -rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a -rfromListLinear sh l = rreshape sh (rfromList1 l) +rtoList :: Elt a => Ranked 1 a -> [a] +rtoList = map runScalar . rtoListOuter rtoListLinear :: Elt a => Ranked n a -> [a] rtoListLinear (Ranked arr) = mtoListLinear arr +rtoListPrim :: PrimElt a => Ranked 1 a -> [a] +rtoListPrim (Ranked arr) = mtoListPrim arr + +rtoListPrimLinear :: PrimElt a => Ranked n a -> [a] +rtoListPrimLinear (Ranked arr) = mtoListPrimLinear arr + rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a rfromOrthotope sn arr | Refl <- lemRankReplicate sn = let xarr = XArray arr in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr)) -rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a +rtoOrthotope :: forall a n. PrimElt a => Ranked n a -> S.Array n a rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr))) - | Refl <- lemRankReplicate (shrRank $ shrFromShX2 sh) + | Refl <- lemRankReplicate (shrRank $ shrFromShX2 @n sh) = arr runScalar :: Elt a => Ranked 0 a -> a @@ -197,22 +244,20 @@ runNest rarr@(Ranked (M_Ranked (M_Nest _ arr))) | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat)) = Ranked arr -rzip :: Ranked n a -> Ranked n b -> Ranked n (a, b) +rzip :: (Elt a, Elt b) => Ranked n a -> Ranked n b -> Ranked n (a, b) rzip = coerce mzip runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b) runzip = coerce munzip -rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) - => SNat n -> IShR n2 - -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) - -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b) -rrerankP sn sh2 f (Ranked arr) - | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) - , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) - = Ranked (mrerankP (ssxFromSNat sn) (shxFromShR sh2) - (\a -> let Ranked r = f (Ranked a) in r) - arr) +rrerankPrimP :: forall n1 n2 n a b. (Storable a, Storable b) + => IShR n2 + -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) + -> Ranked n (Ranked n1 (Primitive a)) -> Ranked n (Ranked n2 (Primitive b)) +rrerankPrimP sh2 f (Ranked (M_Ranked arr)) + = Ranked (M_Ranked (mrerankPrimP (shxFromShR sh2) + (\a -> let Ranked r = f (Ranked a) in r) + arr)) -- | If there is a zero-sized dimension in the @n@-prefix of the shape of the -- input array, then there is no way to deduce the full shape of the output @@ -223,26 +268,28 @@ rrerankP sn sh2 f (Ranked arr) -- For example, if: -- -- @ --- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21] +-- arr :: Ranked 3 (Ranked 2 Int) -- outer array shape [3, 0, 4]; inner shape [2, 21] -- f :: Ranked 2 Int -> Ranked 3 Float -- @ -- -- then: -- -- @ --- rrerank _ _ _ f arr :: Ranked 5 Float +-- rrerank _ f arr :: Ranked 3 (Ranked 3 Float) -- @ -- --- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the --- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended --- to return an array with shape all-0 here (it probably didn't), but there is --- no better number to put here absent a subarray of the input to pass to @f@. -rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) - => SNat n -> IShR n2 - -> (Ranked n1 a -> Ranked n2 b) - -> Ranked (n + n1) a -> Ranked (n + n2) b -rrerank sn sh2 f (rtoPrimitive -> arr) = - rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr +-- and the inner arrays of the result will have shape @[0, 0, 0]@. We don't +-- know if @f@ intended to return an array with all-zero shape here (it +-- probably didn't), but there is no better number to put here absent a +-- subarray of the input to pass to @f@. +rrerankPrim :: forall n1 n2 n a b. (PrimElt a, PrimElt b) + => IShR n2 + -> (Ranked n1 a -> Ranked n2 b) + -> Ranked n (Ranked n1 a) -> Ranked n (Ranked n2 b) +rrerankPrim sh2 f (Ranked (M_Ranked arr)) = + Ranked (M_Ranked (mrerankPrim (shxFromShR sh2) + (\a -> let Ranked r = f (Ranked a) in r) + arr)) rreplicate :: forall n m a. Elt a => IShR n -> Ranked m a -> Ranked (n + m) a @@ -250,29 +297,24 @@ rreplicate sh (Ranked arr) | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat)) = Ranked (mreplicate (shxFromShR sh) arr) -rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) -rreplicateScalP sh x +rreplicatePrimP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) +rreplicatePrimP sh x | Dict <- lemKnownReplicate (shrRank sh) - = Ranked (mreplicateScalP (shxFromShR sh) x) + = Ranked (mreplicatePrimP (shxFromShR sh) x) -rreplicateScal :: forall n a. PrimElt a +rreplicatePrim :: forall n a. PrimElt a => IShR n -> a -> Ranked n a -rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) +rreplicatePrim sh x = rfromPrimitive (rreplicatePrimP sh x) rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a -rslice i n arr - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = rlift (rrank arr) - (\_ -> X.sliceU i n) - arr +rslice i n (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = Ranked (msliceN i n arr) rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a -rrev1 arr = - rlift (rrank arr) - (\(_ :: StaticShX sh') -> - case lemReplicateSucc @(Nothing @Nat) @n of - Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) - arr +rrev1 (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = Ranked (mrev1 arr) rreshape :: forall n n' a. Elt a => IShR n' -> Ranked n a -> Ranked n' a diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index f50f671..834e139 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -5,6 +5,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -25,6 +26,7 @@ import Data.Coerce (coerce) import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty) import Data.Proxy +import Data.Type.Equality import Foreign.Storable (Storable) import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) @@ -34,13 +36,13 @@ import GHC.TypeLits import Data.Foldable (toList) #endif -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Types -import Data.Array.XArray (XArray(..)) +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith +import Data.Array.XArray (XArray(..)) -- | A rank-typed array: the number of dimensions of the array (its /rank/) is @@ -95,13 +97,14 @@ instance Elt a => Elt (Ranked n a) where mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) - mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) - mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) + mfromListOuterSN :: SNat m -> NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Just m : sh) (Ranked n a) + mfromListOuterSN sn l = M_Ranked (mfromListOuterSN sn (coerce l)) mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] mtoListOuter (M_Ranked arr) = coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) @@ -110,6 +113,7 @@ instance Elt a => Elt (Ranked n a) where coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ mlift ssh2 f arr + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) @@ -118,6 +122,7 @@ instance Elt a => Elt (Ranked n a) where coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ mlift2 ssh3 f arr1 arr2 + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) @@ -141,24 +146,25 @@ instance Elt a => Elt (Ranked n a) where mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + mshapeTreeIsEmpty _ (sh, t) = shrSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" marrayStrides (M_Ranked arr) = marrayStrides arr - mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () - mvecsWrite sh idx (Ranked arr) vecs = - mvecsWrite sh idx arr + mvecsWriteLinear :: forall sh s. Int -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () + mvecsWriteLinear idx (Ranked arr) vecs = + mvecsWriteLinear idx arr (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) vecs) - mvecsWritePartial :: forall sh sh' s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) - -> MixedVecs s (sh ++ sh') (Ranked n a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx + mvecsWritePartialLinear + :: forall sh sh' s. + Proxy sh -> Int -> Mixed sh' (Ranked n a) + -> MixedVecs s (sh ++ sh') (Ranked n a) + -> ST s () + mvecsWritePartialLinear proxy idx arr vecs = + mvecsWritePartialLinear proxy idx (coerce @(Mixed sh' (Ranked n a)) @(Mixed sh' (Mixed (Replicate n Nothing) a)) arr) @@ -174,18 +180,30 @@ instance Elt a => Elt (Ranked n a) where (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) vecs) + mvecsUnsafeFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) + mvecsUnsafeFreeze sh vecs = + coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) + @(Mixed sh (Ranked n a)) + <$> mvecsUnsafeFreeze sh + (coerce @(MixedVecs s sh (Ranked n a)) + @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a) - memptyArrayUnsafe i + memptyArrayUnsafe sh | Dict <- lemKnownReplicate (SNat @n) = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ - memptyArrayUnsafe i + memptyArrayUnsafe sh mvecsUnsafeNew idx (Ranked arr) | Dict <- lemKnownReplicate (SNat @n) = MV_Ranked <$> mvecsUnsafeNew idx arr + mvecsReplicate idx (Ranked arr) + | Dict <- lemKnownReplicate (SNat @n) + = MV_Ranked <$> mvecsReplicate idx arr + mvecsNewEmpty _ | Dict <- lemKnownReplicate (SNat @n) = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) @@ -208,15 +226,15 @@ instance (NumElt a, PrimElt a) => Num (Ranked n a) where negate = liftRanked1 negate abs = liftRanked1 abs signum = liftRanked1 signum - fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal" + fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicatePrim" instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where - fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal" + fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicatePrim" recip = liftRanked1 recip (/) = liftRanked2 (/) instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where - pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal" + pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicatePrim" exp = liftRanked1 exp log = liftRanked1 log sqrt = liftRanked1 sqrt @@ -252,3 +270,15 @@ rshape (Ranked arr) = shrFromShX2 (mshape arr) rrank :: Elt a => Ranked n a -> SNat n rrank = shrRank . rshape + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +shrFromShX :: forall sh. IShX sh -> IShR (Rank sh) +shrFromShX ZSX = ZSR +shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +-- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'. +shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n +shrFromShX2 sh + | Refl <- lemRankReplicate (Proxy @n) + = shrFromShX sh diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index c0c4f17..b6bee2e 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -1,13 +1,14 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MagicHash #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} @@ -18,9 +19,11 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE StrictData #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} @@ -33,17 +36,20 @@ import Data.Foldable qualified as Foldable import Data.Kind (Type) import Data.Proxy import Data.Type.Equality +import GHC.Exts (Int(..), Int#, build, quotRemInt#) import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits import GHC.TypeNats qualified as TN -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed.Shape.Internal import Data.Array.Nested.Types +-- * Ranked lists + type role ListR nominal representational type ListR :: Nat -> Type -> Type data ListR n i where @@ -51,8 +57,6 @@ data ListR n i where (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i deriving instance Eq i => Eq (ListR n i) deriving instance Ord i => Ord (ListR n i) -deriving instance Functor (ListR n) -deriving instance Foldable (ListR n) infixr 3 ::: #ifdef OXAR_DEFAULT_SHOW_INSTANCES @@ -66,6 +70,22 @@ instance NFData i => NFData (ListR n i) where rnf ZR = () rnf (x ::: l) = rnf x `seq` rnf l +instance Functor (ListR n) where + {-# INLINE fmap #-} + fmap _ ZR = ZR + fmap f (x ::: xs) = f x ::: fmap f xs + +instance Foldable (ListR n) where + {-# INLINE foldMap #-} + foldMap _ ZR = mempty + foldMap f (x ::: xs) = f x <> foldMap f xs + {-# INLINE foldr #-} + foldr _ z ZR = z + foldr f z (x ::: xs) = f x (foldr f z xs) + toList = listrToList + null ZR = False + null _ = True + data UnconsListRRes i n1 = forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) @@ -90,6 +110,7 @@ listrEqual (i ::: sh) (j ::: sh') = Just Refl listrEqual _ _ = Nothing +{-# INLINE listrShow #-} listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS listrShow f l = showString "[" . go "" l . showString "]" where @@ -108,27 +129,41 @@ listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i listrAppend ZR sh = sh listrAppend (x ::: xs) sh = x ::: listrAppend xs sh -listrFromList :: [i] -> (forall n. ListR n i -> r) -> r -listrFromList [] k = k ZR -listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) +listrFromList :: SNat n -> [i] -> ListR n i +listrFromList topsn topl = go topsn topl + where + go :: SNat n' -> [i] -> ListR n' i + go SZ [] = ZR + go (SS n) (i : is) = i ::: go n is + go _ _ = error $ "listrFromList: Mismatched list length (type says " + ++ show (fromSNat topsn) ++ ", list has length " + ++ show (length topl) ++ ")" + +{-# INLINEABLE listrToList #-} +listrToList :: ListR n i -> [i] +listrToList list = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ListR n i -> is + go ZR = nil + go (i ::: is) = i `cons` go is + in go list) listrHead :: ListR (n + 1) i -> i listrHead (i ::: _) = i -listrHead ZR = error "unreachable" listrTail :: ListR (n + 1) i -> ListR n i listrTail (_ ::: sh) = sh -listrTail ZR = error "unreachable" listrInit :: ListR (n + 1) i -> ListR n i listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh listrInit (_ ::: ZR) = ZR -listrInit ZR = error "unreachable" listrLast :: ListR (n + 1) i -> i listrLast (_ ::: sh@(_ ::: _)) = listrLast sh listrLast (n ::: ZR) = n -listrLast ZR = error "unreachable" + +-- | Performs a runtime check that the lengths are identical. +listrCast :: SNat n' -> ListR n i -> ListR n' i +listrCast = listrCastWithName "listrCast" listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i listrIndex SZ (x ::: _) = x @@ -140,6 +175,7 @@ listrZip ZR ZR = ZR listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest listrZip _ _ = error "listrZip: impossible pattern needlessly required" +{-# INLINE listrZipWith #-} listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k listrZipWith _ ZR ZR = ZR listrZipWith f (i ::: irest) (j ::: jrest) = @@ -149,13 +185,15 @@ listrZipWith _ _ _ = listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i listrPermutePrefix = \perm sh -> - listrFromList perm $ \sperm -> - case (listrRank sperm, listrRank sh) of - (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of - LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" - ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" + TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat -> + case listrRank sh of { shlen@SNat -> + let sperm = listrFromList permlen perm in + case cmpNat permlen shlen of + LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post + EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post + GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" + ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" + } where listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) listrSplitAt SZ sh = (ZR, sh) @@ -172,6 +210,8 @@ listrPermutePrefix = \perm sh -> GTI -> error "listrPermutePrefix: Index in permutation out of range" +-- * Ranked indices + -- | An index into a rank-typed array. type role IxR nominal representational type IxR :: Nat -> Type -> Type @@ -192,6 +232,8 @@ infixr 3 :.: {-# COMPLETE ZIR, (:.:) #-} +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). type IIxR n = IxR n Int #ifdef OXAR_DEFAULT_SHOW_INSTANCES @@ -213,15 +255,12 @@ ixrZero :: SNat n -> IIxR n ixrZero SZ = ZIR ixrZero (SS n) = 0 :.: ixrZero n -ixrFromIxX :: IxX sh i -> IxR (Rank sh) i -ixrFromIxX ZIX = ZIR -ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx +ixrFromList :: forall n i. SNat n -> [i] -> IxR n i +ixrFromList = coerce (listrFromList @_ @i) -ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i -ixxFromIxR ZIR = ZIX -ixxFromIxR (n :.: (idx :: IxR m i)) = - castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m)) - (n :.% ixxFromIxR idx) +{-# INLINEABLE ixrToList #-} +ixrToList :: forall n i. IxR n i -> [i] +ixrToList = coerce (listrToList @_ @i) ixrHead :: IxR (n + 1) i -> i ixrHead (IxR list) = listrHead list @@ -235,18 +274,38 @@ ixrInit (IxR list) = IxR (listrInit list) ixrLast :: IxR (n + 1) i -> i ixrLast (IxR list) = listrLast list +-- | Performs a runtime check that the lengths are identical. +ixrCast :: SNat n' -> IxR n i -> IxR n' i +ixrCast n (IxR idx) = IxR (listrCastWithName "ixrCast" n idx) + ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i ixrAppend = coerce (listrAppend @_ @i) ixrZip :: IxR n i -> IxR n j -> IxR n (i, j) ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 +{-# INLINE ixrZipWith #-} ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i ixrPermutePrefix = coerce (listrPermutePrefix @i) +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixrToLinear #-} +ixrToLinear :: Num i => IShR m -> IxR m i -> i +ixrToLinear = \sh i -> go sh i 0 + where + -- Additional argument: index, in the @m - m1@ dimensional array so far, + -- of the @m - m1 + n@ dimensional tensor pointed to by the current + -- @m - m1@ dimensional index prefix. + go :: Num i => IShR m1 -> IxR m1 i -> i -> i + go ZSR ZIR a = a + go (n :$: sh) (i :.: ix) a = go sh ix (fromIntegral n * a + i) + + +-- * Ranked shapes type role ShR nominal representational type ShR :: Nat -> Type -> Type @@ -278,22 +337,6 @@ instance Show i => Show (ShR n i) where instance NFData i => NFData (ShR sh i) -shrFromShX :: forall sh. IShX sh -> IShR (Rank sh) -shrFromShX ZSX = ZSR -shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx - --- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'. -shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n -shrFromShX2 sh - | Refl <- lemRankReplicate (Proxy @n) - = shrFromShX sh - -shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i -shxFromShR ZSR = ZSX -shxFromShR (n :$: (idx :: ShR m i)) = - castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m)) - (SUnknown n :$% shxFromShR idx) - -- | This checks only whether the ranks are equal, not whether the actual -- values are. shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') @@ -317,6 +360,13 @@ shrSize :: IShR n -> Int shrSize ZSR = 1 shrSize (n :$: sh) = n * shrSize sh +shrFromList :: forall n i. SNat n -> [i] -> ShR n i +shrFromList = coerce (listrFromList @_ @i) + +{-# INLINEABLE shrToList #-} +shrToList :: forall n i. ShR n i -> [i] +shrToList = coerce (listrToList @_ @i) + shrHead :: ShR (n + 1) i -> i shrHead (ShR list) = listrHead list @@ -329,30 +379,44 @@ shrInit (ShR list) = ShR (listrInit list) shrLast :: ShR (n + 1) i -> i shrLast (ShR list) = listrLast list +-- | Performs a runtime check that the lengths are identical. +shrCast :: SNat n' -> ShR n i -> ShR n' i +shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh) + shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i shrAppend = coerce (listrAppend @_ @i) shrZip :: ShR n i -> ShR n j -> ShR n (i, j) shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2 +{-# INLINE shrZipWith #-} shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2 shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i shrPermutePrefix = coerce (listrPermutePrefix @i) +shrEnum :: IShR sh -> [IIxR sh] +shrEnum = shrEnum' + +{-# INLINABLE shrEnum' #-} -- ensure this can be specialised at use site +shrEnum' :: Num i => IShR sh -> [IxR sh i] +shrEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shrSize sh - 1]] + where + suffixes = drop 1 (scanr (*) 1 (shrToList sh)) + + fromLin :: Num i => IShR sh -> [Int] -> Int# -> IxR sh i + fromLin ZSR _ _ = ZIR + fromLin (_ :$: sh') (I# suff# : suffs) i# = + let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh' + in fromIntegral (I# q#) :.: fromLin sh' suffs r# + fromLin _ _ _ = error "impossible" + -- | Untyped: length is checked at runtime. instance KnownNat n => IsList (ListR n i) where type Item (ListR n i) = i - fromList topl = go (SNat @n) topl - where - go :: SNat n' -> [i] -> ListR n' i - go SZ [] = ZR - go (SS n) (i : is) = i ::: go n is - go _ _ = error $ "IsList(ListR): Mismatched list length (type says " - ++ show (fromSNat (SNat @n)) ++ ", list has length " - ++ show (length topl) ++ ")" + fromList = listrFromList (SNat @n) toList = Foldable.toList -- | Untyped: length is checked at runtime. @@ -366,3 +430,14 @@ instance KnownNat n => IsList (ShR n i) where type Item (ShR n i) = i fromList = ShR . IsList.fromList toList = Foldable.toList + + +-- * Internal helper functions + +listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i +listrCastWithName _ SZ ZR = ZR +listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx +listrCastWithName name _ _ = error $ name ++ ": ranks don't match" + +$(ixFromLinearStub "ixrFromLinear" [t| IShR |] [t| IxR |] [p| ZSR |] (\a b -> [p| $a :$: $b |]) [| ZIR |] [| (:.:) |] [| shrToList |]) +{-# INLINEABLE ixrFromLinear #-} diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 7e38aee..acb7c89 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -29,21 +29,20 @@ import Data.Vector.Storable qualified as VS import Foreign.Storable (Storable) import GHC.TypeLits -import Data.Array.Mixed.Lemmas -import Data.Array.Nested.Permutation -import Data.Array.Nested.Types -import Data.Array.XArray (XArray) -import Data.Array.XArray qualified as X -import Data.Array.Nested.Internal.Lemmas import Data.Array.Nested.Convert +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation import Data.Array.Nested.Shaped.Base import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith +import Data.Array.XArray (XArray) +import Data.Array.XArray qualified as X -semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a +semptyArray :: forall sh a. KnownElt a => ShS sh -> Shaped (0 : sh) a semptyArray sh = Shaped (memptyArray (shxFromShS sh)) srank :: Elt a => Shaped sh a -> SNat (Rank sh) @@ -53,13 +52,16 @@ srank = shsRank . sshape ssize :: Elt a => Shaped sh a -> Int ssize = shsSize . sshape +{-# INLINEABLE sindex #-} sindex :: Elt a => Shaped sh a -> IIxS sh -> a sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx) -shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh +{-# INLINEABLE shsTakeIx #-} +shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IxS sh i -> ShS sh shsTakeIx _ _ ZIS = ZSS shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx +{-# INLINEABLE sindexPartial #-} sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a sindexPartial sarr@(Shaped arr) idx = Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) @@ -71,7 +73,16 @@ sindexPartial sarr@(Shaped arr) idx = sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh)) +-- | See 'mgeneratePrim'. +{-# INLINE sgeneratePrim #-} +sgeneratePrim :: forall sh a i. (PrimElt a, Num i) + => ShS sh -> (IxS sh i -> a) -> Shaped sh a +sgeneratePrim sh f = + let g i = f (ixsFromLinear sh i) + in sfromVector sh $ VS.generate (shsSize sh) g + -- | See the documentation of 'mlift'. +{-# INLINE slift #-} slift :: forall sh1 sh2 a. Elt a => ShS sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) @@ -79,19 +90,23 @@ slift :: forall sh1 sh2 a. Elt a slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShX (shxFromShS sh2)) f arr) -- | See the documentation of 'mlift'. +{-# INLINE slift2 #-} slift2 :: forall sh1 sh2 sh3 a. Elt a => ShS sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2) -ssumOuter1P :: forall sh n a. (Storable a, NumElt a) - => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) -ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) +ssumOuter1PrimP :: forall sh n a. (Storable a, NumElt a) + => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) +ssumOuter1PrimP (Shaped arr) = Shaped (msumOuter1PrimP arr) -ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) - => Shaped (n : sh) a -> Shaped sh a -ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive +ssumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a) + => Shaped (n : sh) a -> Shaped sh a +ssumOuter1Prim = sfromPrimitive . ssumOuter1PrimP . stoPrimitive + +ssumAllPrimP :: (PrimElt a, NumElt a) => Shaped n (Primitive a) -> a +ssumAllPrimP (Shaped arr) = msumAllPrimP arr ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a ssumAllPrim (Shaped arr) = msumAllPrim arr @@ -124,39 +139,54 @@ stoVectorP = coerce mtoVectorP stoVector :: PrimElt a => Shaped sh a -> VS.Vector a stoVector = coerce mtoVector +-- | All arrays in the list, even subarrays inside @a@, must have the same +-- shape; if they do not, a runtime error will be thrown. See the +-- documentation of 'mgenerate' for more information about this restriction. +-- +-- Because the length of the 'NonEmpty' list is unknown, its spine must be +-- materialised in memory in order to compute its length. If its length is +-- already known, use 'sfromListOuterSN' to be able to stream the list. +-- +-- If your array is 1-dimensional and contains scalars, use 'sfromList1Prim'. sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l)) +sfromListOuter = coerce mfromListOuterSN +-- | Because the length of the 'NonEmpty' list is unknown, its spine must be +-- materialised in memory in order to compute its length. If its length is +-- already known, use 'sfromList1SN' to be able to stream the list. +-- +-- If the elements are scalars, 'sfromList1Prim' is faster. sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a -sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1 - -sfromList1Prim :: PrimElt a => SNat n -> [a] -> Shaped '[n] a -sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim +sfromList1 = coerce mfromList1SN -stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] -stoListOuter (Shaped arr) = coerce (mtoListOuter arr) +-- | If the elements are scalars, 'sfromListPrimLinear' is faster. +sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a +sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l) -stoList1 :: Elt a => Shaped '[n] a -> [a] -stoList1 = map sunScalar . stoListOuter +-- | Because the length of the list is unknown, its spine must be materialised +-- in memory in order to compute its length. If its length is already known, +-- use 'sfromList1PrimN' to be able to stream the list. +sfromList1Prim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a +sfromList1Prim = coerce mfromList1PrimSN -sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a -sfromListPrim sn l - | Refl <- lemAppNil @'[Just n] - = let ssh = SUnknown () :!% ZKX - xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) - in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr +sfromListPrimLinear :: forall sh a. PrimElt a => ShS sh -> [a] -> Shaped sh a +sfromListPrimLinear sh l = Shaped (mfromListPrimLinear (shxFromShS sh) l) -sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a -sfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in Shaped $ fromPrimitive $ M_Primitive (shxFromShS sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShS sh) xarr) +stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] +stoListOuter (Shaped arr) = coerce (mtoListOuter arr) -sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a -sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l) +stoList :: Elt a => Shaped '[n] a -> [a] +stoList = map sunScalar . stoListOuter stoListLinear :: Elt a => Shaped sh a -> [a] stoListLinear (Shaped arr) = mtoListLinear arr +stoListPrim :: PrimElt a => Shaped '[n] a -> [a] +stoListPrim (Shaped arr) = mtoListPrim arr + +stoListPrimLinear :: PrimElt a => Shaped sh a -> [a] +stoListPrimLinear (Shaped arr) = mtoListPrimLinear arr + sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a sfromOrthotope sh (SS.A (SG.A arr)) = Shaped (fromPrimitive (M_Primitive (shxFromShS sh) (X.XArray (RS.A (RG.A (shsToList sh) arr))))) @@ -177,41 +207,41 @@ sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr))) | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh') = Shaped arr -szip :: Shaped sh a -> Shaped sh b -> Shaped sh (a, b) +szip :: (Elt a, Elt b) => Shaped sh a -> Shaped sh b -> Shaped sh (a, b) szip = coerce mzip sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b) sunzip = coerce munzip -srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) - => ShS sh -> ShS sh2 - -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) - -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b) -srerankP sh sh2 f sarr@(Shaped arr) - | Refl <- lemMapJustApp sh (Proxy @sh1) - , Refl <- lemMapJustApp sh (Proxy @sh2) - = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (shxFromShS (sshape sarr)) (ssxFromShX (shxFromShS sh)))) - (shxFromShS sh2) - (\a -> let Shaped r = f (Shaped a) in r) - arr) +srerankPrimP :: forall sh1 sh2 sh a b. (Storable a, Storable b) + => ShS sh2 + -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) + -> Shaped sh (Shaped sh1 (Primitive a)) -> Shaped sh (Shaped sh2 (Primitive b)) +srerankPrimP sh2 f (Shaped (M_Shaped arr)) + = Shaped (M_Shaped (mrerankPrimP (shxFromShS sh2) + (\a -> let Shaped r = f (Shaped a) in r) + arr)) -srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) - => ShS sh -> ShS sh2 - -> (Shaped sh1 a -> Shaped sh2 b) - -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b -srerank sh sh2 f (stoPrimitive -> arr) = - sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr +-- | See the caveats at 'mrerankPrim'. +srerankPrim :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) + => ShS sh2 + -> (Shaped sh1 a -> Shaped sh2 b) + -> Shaped sh (Shaped sh1 a) -> Shaped sh (Shaped sh2 b) +srerankPrim sh2 f (Shaped (M_Shaped arr)) = + Shaped (M_Shaped (mrerankPrim (shxFromShS sh2) + (\a -> let Shaped r = f (Shaped a) in r) + arr)) sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a sreplicate sh (Shaped arr) | Refl <- lemMapJustApp sh (Proxy @sh') = Shaped (mreplicate (shxFromShS sh) arr) -sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) -sreplicateScalP sh x = Shaped (mreplicateScalP (shxFromShS sh) x) +sreplicatePrimP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) +sreplicatePrimP sh x = Shaped (mreplicatePrimP (shxFromShS sh) x) -sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a -sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) +sreplicatePrim :: forall sh a. PrimElt a => ShS sh -> a -> Shaped sh a +sreplicatePrim sh x = sfromPrimitive (sreplicatePrimP sh x) sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a sslice i n@SNat arr = diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index 529ac21..16c1b05 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -5,6 +5,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -25,18 +26,19 @@ import Data.Coerce (coerce) import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty) import Data.Proxy +import Data.Type.Equality import Foreign.Storable (Storable) import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) import GHC.TypeLits -import Data.Array.Nested.Types -import Data.Array.XArray (XArray) -import Data.Array.Nested.Internal.Lemmas +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith +import Data.Array.XArray (XArray) -- | A shape-typed array: the full shape of the array (the sizes of its @@ -88,13 +90,14 @@ instance Elt a => Elt (Shaped sh a) where mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) - mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) - mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) + mfromListOuterSN :: SNat n -> NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Just n : sh') (Shaped sh a) + mfromListOuterSN sn l = M_Shaped (mfromListOuterSN sn (coerce l)) mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] mtoListOuter (M_Shaped arr) = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) @@ -103,6 +106,7 @@ instance Elt a => Elt (Shaped sh a) where coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ mlift ssh2 f arr + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) @@ -111,6 +115,7 @@ instance Elt a => Elt (Shaped sh a) where coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ mlift2 ssh3 f arr1 arr2 + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) @@ -134,24 +139,25 @@ instance Elt a => Elt (Shaped sh a) where mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + mshapeTreeIsEmpty _ (sh, t) = shsSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" marrayStrides (M_Shaped arr) = marrayStrides arr - mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () - mvecsWrite sh idx (Shaped arr) vecs = - mvecsWrite sh idx arr + mvecsWriteLinear :: forall sh' s. Int -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWriteLinear idx (Shaped arr) vecs = + mvecsWriteLinear idx arr (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) vecs) - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx + mvecsWritePartialLinear + :: forall sh1 sh2 s. + Proxy sh1 -> Int -> Mixed sh2 (Shaped sh a) + -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) + -> ST s () + mvecsWritePartialLinear proxy idx arr vecs = + mvecsWritePartialLinear proxy idx (coerce @(Mixed sh2 (Shaped sh a)) @(Mixed sh2 (Mixed (MapJust sh) a)) arr) @@ -167,18 +173,30 @@ instance Elt a => Elt (Shaped sh a) where (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) vecs) + mvecsUnsafeFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsUnsafeFreeze sh vecs = + coerce @(Mixed sh' (Mixed (MapJust sh) a)) + @(Mixed sh' (Shaped sh a)) + <$> mvecsUnsafeFreeze sh + (coerce @(MixedVecs s sh' (Shaped sh a)) + @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) - memptyArrayUnsafe i + memptyArrayUnsafe sh | Dict <- lemKnownMapJust (Proxy @sh) = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ - memptyArrayUnsafe i + memptyArrayUnsafe sh mvecsUnsafeNew idx (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = MV_Shaped <$> mvecsUnsafeNew idx arr + mvecsReplicate idx (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsReplicate idx arr + mvecsNewEmpty _ | Dict <- lemKnownMapJust (Proxy @sh) = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) @@ -201,15 +219,15 @@ instance (NumElt a, PrimElt a) => Num (Shaped sh a) where negate = liftShaped1 negate abs = liftShaped1 abs signum = liftShaped1 signum - fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal" + fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicatePrim" instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where - fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" + fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicatePrim" recip = liftShaped1 recip (/) = liftShaped2 (/) instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" + pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicatePrim" exp = liftShaped1 exp log = liftShaped1 log sqrt = liftShaped1 sqrt @@ -242,3 +260,12 @@ satan2Array = liftShaped2 matan2Array sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh sshape (Shaped arr) = shsFromShX (mshape arr) + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh +shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS +shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) = + castWith (subst1 (sym (lemMapJustCons Refl))) $ + n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) + idx) +shsFromShX (SUnknown _ :$% _) = error "impossible" diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 0b7d1c9..0c042b7 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -1,13 +1,12 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE MagicHash #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} @@ -18,9 +17,11 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE StrictData #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} @@ -37,17 +38,22 @@ import Data.Kind (Constraint, Type) import Data.Monoid (Sum(..)) import Data.Proxy import Data.Type.Equality -import GHC.Exts (withDict) +import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict) import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Mixed.Shape.Internal import Data.Array.Nested.Permutation import Data.Array.Nested.Types +-- * Shaped lists + +-- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be +-- removed in a future release. type role ListS nominal representational type ListS :: [Nat] -> (Nat -> Type) -> Type data ListS sh f where @@ -98,13 +104,15 @@ listsEqual (n ::$ sh) (m ::$ sh') = Just Refl listsEqual _ _ = Nothing +{-# INLINE listsFmap #-} listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g listsFmap _ ZS = ZS listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs -listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m -listsFold _ ZS = mempty -listsFold f (x ::$ xs) = f x <> listsFold f xs +{-# INLINE listsFoldMap #-} +listsFoldMap :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m +listsFoldMap _ ZS = mempty +listsFoldMap f (x ::$ xs) = f x <> listsFoldMap f xs listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS listsShow f l = showString "[" . go "" l . showString "]" @@ -114,15 +122,40 @@ listsShow f l = showString "[" . go "" l . showString "]" go prefix (x ::$ xs) = showString prefix . f x . go "," xs listsLength :: ListS sh f -> Int -listsLength = getSum . listsFold (\_ -> Sum 1) +listsLength = getSum . listsFoldMap (\_ -> Sum 1) listsRank :: ListS sh f -> SNat (Rank sh) listsRank ZS = SNat listsRank (_ ::$ sh) = snatSucc (listsRank sh) +listsFromList :: ShS sh -> [i] -> ListS sh (Const i) +listsFromList topsh topl = go topsh topl + where + go :: ShS sh' -> [i] -> ListS sh' (Const i) + go ZSS [] = ZS + go (_ :$$ sh) (i : is) = Const i ::$ go sh is + go _ _ = error $ "listsFromList: Mismatched list length (type says " + ++ show (shsLength topsh) ++ ", list has length " + ++ show (length topl) ++ ")" + +{-# INLINEABLE listsFromListS #-} +listsFromListS :: ListS sh (Const i0) -> [i] -> ListS sh (Const i) +listsFromListS topl0 topl = go topl0 topl + where + go :: ListS sh (Const i0) -> [i] -> ListS sh (Const i) + go ZS [] = ZS + go (_ ::$ l0) (i : is) = Const i ::$ go l0 is + go _ _ = error $ "listsFromListS: Mismatched list length (the model says " + ++ show (listsLength topl0) ++ ", list has length " + ++ show (length topl) ++ ")" + +{-# INLINEABLE listsToList #-} listsToList :: ListS sh (Const i) -> [i] -listsToList ZS = [] -listsToList (Const i ::$ is) = i : listsToList is +listsToList list = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ListS sh (Const i) -> is + go ZS = nil + go (Const i ::$ is) = i `cons` go is + in go list) listsHead :: ListS (n : sh) f -> f n listsHead (i ::$ _) = i @@ -144,14 +177,13 @@ listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g) listsZip ZS ZS = ZS -listsZip (i ::$ is) (j ::$ js) = - Fun.Pair i j ::$ listsZip is js +listsZip (i ::$ is) (j ::$ js) = Fun.Pair i j ::$ listsZip is js +{-# INLINE listsZipWith #-} listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g -> ListS sh h listsZipWith _ ZS ZS = ZS -listsZipWith f (i ::$ is) (j ::$ js) = - f i j ::$ listsZipWith f is js +listsZipWith f (i ::$ is) (j ::$ js) = f i j ::$ listsZipWith f is js listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f listsTakeLenPerm PNil _ = ZS @@ -180,11 +212,9 @@ listsIndex _ _ _ ZS = error "Index into empty shape" listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) +-- * Shaped indices -- | An index into a shape-typed array. --- --- For convenience, this contains regular 'Int's instead of bounded integers --- (traditionally called \"@Fin@\"). type role IxS nominal representational type IxS :: [Nat] -> Type -> Type newtype IxS sh i = IxS (ListS sh (Const i)) @@ -193,6 +223,8 @@ newtype IxS sh i = IxS (ListS sh (Const i)) pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i pattern ZIS = IxS ZS +-- | Note: The 'KnownNat' constraint on '(:.$)' is deprecated and should be +-- removed in a future release. pattern (:.$) :: forall {sh1} {i}. forall n sh. (KnownNat n, n : sh ~ sh1) @@ -203,6 +235,8 @@ infixr 3 :.$ {-# COMPLETE ZIS, (:.$) #-} +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). type IIxS sh = IxS sh Int #ifdef OXAR_DEFAULT_SHOW_INSTANCES @@ -213,10 +247,18 @@ instance Show i => Show (IxS sh i) where #endif instance Functor (IxS sh) where + {-# INLINE fmap #-} fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l) instance Foldable (IxS sh) where - foldMap f (IxS l) = listsFold (f . getConst) l + {-# INLINE foldMap #-} + foldMap f (IxS l) = listsFoldMap (f . getConst) l + {-# INLINE foldr #-} + foldr _ z ZIS = z + foldr f z (x :.$ xs) = f x (foldr f z xs) + toList = ixsToList + null ZIS = False + null _ = True instance NFData i => NFData (IxS sh i) @@ -226,18 +268,21 @@ ixsLength (IxS l) = listsLength l ixsRank :: IxS sh i -> SNat (Rank sh) ixsRank (IxS l) = listsRank l +ixsFromList :: forall sh i. ShS sh -> [i] -> IxS sh i +ixsFromList = coerce (listsFromList @_ @i) + +{-# INLINEABLE ixsFromIxS #-} +ixsFromIxS :: forall sh i0 i. IxS sh i0 -> [i] -> IxS sh i +ixsFromIxS = coerce (listsFromListS @_ @i0 @i) + +{-# INLINEABLE ixsToList #-} +ixsToList :: forall sh i. IxS sh i -> [i] +ixsToList = coerce (listsToList @_ @i) + ixsZero :: ShS sh -> IIxS sh ixsZero ZSS = ZIS ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh -ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i -ixsFromIxX ZSS ZIX = ZIS -ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx - -ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i -ixxFromIxS ZIS = ZIX -ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh - ixsHead :: IxS (n : sh) i -> i ixsHead (IxS list) = getConst (listsHead list) @@ -250,20 +295,39 @@ ixsInit (IxS list) = IxS (listsInit list) ixsLast :: IxS (n : sh) i -> i ixsLast (IxS list) = getConst (listsLast list) +-- TODO: this takes a ShS because there are KnownNats inside IxS. +ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i +ixsCast ZSS ZIS = ZIS +ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx +ixsCast _ _ = error "ixsCast: ranks don't match" + ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i ixsAppend = coerce (listsAppend @_ @(Const i)) -ixsZip :: IxS n i -> IxS n j -> IxS n (i, j) +ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j) ixsZip ZIS ZIS = ZIS ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js -ixsZipWith :: (i -> j -> k) -> IxS n i -> IxS n j -> IxS n k +{-# INLINE ixsZipWith #-} +ixsZipWith :: (i -> j -> k) -> IxS sh i -> IxS sh j -> IxS sh k ixsZipWith _ ZIS ZIS = ZIS ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixsToLinear #-} +ixsToLinear :: Num i => ShS sh -> IxS sh i -> i +ixsToLinear = \sh i -> go sh i 0 + where + go :: Num i => ShS sh -> IxS sh i -> i -> i + go ZSS ZIS a = a + go (n :$$ sh) (i :.$ ix) a = go sh ix (fromIntegral (fromSNat' n) * a + i) + + +-- * Shaped shapes -- | The shape of a shape-typed array given as a list of 'SNat' values. -- @@ -272,7 +336,10 @@ ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) type role ShS nominal type ShS :: [Nat] -> Type newtype ShS sh = ShS (ListS sh SNat) - deriving (Eq, Ord, Generic) + deriving (Generic) + +instance Eq (ShS sh) where _ == _ = True +instance Ord (ShS sh) where compare _ _ = EQ pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh pattern ZSS = ShS ZS @@ -317,26 +384,28 @@ shsSize :: ShS sh -> Int shsSize ZSS = 1 shsSize (n :$$ sh) = fromSNat' n * shsSize sh -shsToList :: ShS sh -> [Int] -shsToList ZSS = [] -shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh - -shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh -shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS -shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) = - castWith (subst1 (lem Refl)) $ - n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) - idx) +-- | This is a partial @const@ that fails when the second argument +-- doesn't match the first. +shsFromList :: ShS sh -> [Int] -> ShS sh +shsFromList topsh topl = go topsh topl `seq` topsh where - lem :: forall sh1 sh' n. - Just n : sh1 :~: MapJust sh' - -> n : Tail sh' :~: sh' - lem Refl = unsafeCoerceRefl -shsFromShX (SUnknown _ :$% _) = error "impossible" + go :: ShS sh' -> [Int] -> () + go ZSS [] = () + go (sn :$$ sh) (i : is) + | i == fromSNat' sn = go sh is + | otherwise = error $ "shsFromList: Value does not match typing (type says " + ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" + go _ _ = error $ "shsFromList: Mismatched list length (type says " + ++ show (shsLength topsh) ++ ", list has length " + ++ show (length topl) ++ ")" -shxFromShS :: ShS sh -> IShX (MapJust sh) -shxFromShS ZSS = ZSX -shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh +{-# INLINEABLE shsToList #-} +shsToList :: ShS sh -> [Int] +shsToList topsh = build (\(cons :: Int -> is -> is) (nil :: is) -> + let go :: ShS sh -> is + go ZSS = nil + go (sn :$$ sh) = fromSNat' sn `cons` go sh + in go topsh) shsHead :: ShS (n : sh) -> SNat n shsHead (ShS list) = listsHead list @@ -381,7 +450,7 @@ instance KnownShS '[] where knownShS = ZSS instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r -withKnownShS k = withDict @(KnownShS sh) k +withKnownShS = withDict @(KnownShS sh) shsKnownShS :: ShS sh -> Dict KnownShS sh shsKnownShS ZSS = Dict @@ -391,18 +460,27 @@ shsOrthotopeShape :: ShS sh -> Dict O.Shape sh shsOrthotopeShape ZSS = Dict shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict +shsEnum :: ShS sh -> [IIxS sh] +shsEnum = shsEnum' + +{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site +shsEnum' :: Num i => ShS sh -> [IxS sh i] +shsEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shsSize sh - 1]] + where + suffixes = drop 1 (scanr (*) 1 (shsToList sh)) + + fromLin :: Num i => ShS sh -> [Int] -> Int# -> IxS sh i + fromLin ZSS _ _ = ZIS + fromLin (_ :$$ sh') (I# suff# : suffs) i# = + let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shsSize sh' + in fromIntegral (I# q#) :.$ fromLin sh' suffs r# + fromLin _ _ _ = error "impossible" + -- | Untyped: length is checked at runtime. instance KnownShS sh => IsList (ListS sh (Const i)) where type Item (ListS sh (Const i)) = i - fromList topl = go (knownShS @sh) topl - where - go :: ShS sh' -> [i] -> ListS sh' (Const i) - go ZSS [] = ZS - go (_ :$$ sh) (i : is) = Const i ::$ go sh is - go _ _ = error $ "IsList(ListS): Mismatched list length (type says " - ++ show (shsLength (knownShS @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" + fromList = listsFromList (knownShS @sh) toList = listsToList -- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. @@ -414,15 +492,8 @@ instance KnownShS sh => IsList (IxS sh i) where -- | Untyped: length and values are checked at runtime. instance KnownShS sh => IsList (ShS sh) where type Item (ShS sh) = Int - fromList topl = ShS (go (knownShS @sh) topl) - where - go :: ShS sh' -> [Int] -> ListS sh' SNat - go ZSS [] = ZS - go (sn :$$ sh) (i : is) - | i == fromSNat' sn = sn ::$ go sh is - | otherwise = error $ "IsList(ShS): Value does not match typing (type says " - ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" - go _ _ = error $ "IsList(ShS): Mismatched list length (type says " - ++ show (shsLength (knownShS @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" + fromList = shsFromList (knownShS @sh) toList = shsToList + +$(ixFromLinearStub "ixsFromLinear" [t| ShS |] [t| IxS |] [p| ZSS |] (\a b -> [p| (fromSNat' -> $a) :$$ $b |]) [| ZIS |] [| (:.$) |] [| shsToList |]) +{-# INLINEABLE ixsFromLinear #-} diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs index 838e2b0..dfa5129 100644 --- a/src/Data/Array/Nested/Trace.hs +++ b/src/Data/Array/Nested/Trace.hs @@ -5,21 +5,28 @@ {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE TemplateHaskell #-} +{-# OPTIONS -Wno-simplifiable-class-constraints #-} {-| This module is API-compatible with "Data.Array.Nested", except that inputs and -outputs of the methods are traced using 'Debug.Trace.trace'. Thus the methods -also have additional 'Show' constraints. +outputs of the methods are traced to 'stderr'. Thus the methods also have +additional 'Show' constraints. ->>> let res = rtranspose [1, 0] (rreshape (2 :$: 3 :$: ZSR) (riota @Int 6)) * rreshape (3 :$: 2 :$: ZSR) (rreplicate (6 :$: ZSR) (rscalar @Int 7)) ->>> length (show res) `seq` () -oxtrace: riota [Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [0,1,2,3,4,5]))))] -oxtrace: rreshape [[2,3], Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [0,1,2,3,4,5])))), Ranked (M_Int (M_Primitive [2,3] (XArray (fromList [2,3] [0,1,2,3,4,5]))))] -oxtrace: rtranspose [Ranked (M_Int (M_Primitive [2,3] (XArray (fromList [2,3] [0,1,2,3,4,5])))), Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [0,3,1,4,2,5]))))] -oxtrace: rscalar [Ranked (M_Int (M_Primitive [] (XArray (fromList [] [7]))))] -oxtrace: rreplicate [[6], Ranked (M_Int (M_Primitive [] (XArray (fromList [] [7])))), Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [7,7,7,7,7,7]))))] -oxtrace: rreshape [[3,2], Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [7,7,7,7,7,7])))), Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [7,7,7,7,7,7]))))] ->>> res -Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [0,21,7,28,14,35])))) +>>> rtranspose [1, 0] (rreshape (2 :$: 3 :$: ZSR) (riota @Int 6)) * rreshape (3 :$: 2 :$: ZSR) (rreplicate (6 :$: ZSR) (rscalar @Int 7)) +oxtrace: (riota _ ... = rfromListLinear [6] [0,1,2,3,4,5]) +oxtrace: (rreshape [2,3] (rfromListLinear [6] [0,1,2,3,4,5]) ... = rfromListLinear [2,3] [0,1,2,3,4,5]) +oxtrace: (rtranspose [1,0] (rfromListLinear [2,3] [0,1,2,3,4,5]) ... = rfromListLinear [3,2] [0,3,1,4,2,5]) +oxtrace: (rscalar _ ... = rfromListLinear [] [7]) +oxtrace: (rreplicate [6] (rfromListLinear [] [7]) ... = rreplicate [6] 7) +oxtrace: (rreshape [3,2] (rreplicate [6] 7) ... = rreplicate [3,2] 7) +rfromListLinear [3,2] [0,21,7,28,14,35] + +The part up until and including the @...@ is printed after @seq@ing the +arguments; the @=@ and further is printed after @seq@ing the result of the +operation. Do note that tracing means that the functions in this module are +potentially __stricter__ than the plain ones in "Data.Array.Nested". + +Arguments that this module does not know how to @show@, probably due to +laziness on my side, are printed as @_@. -} module Data.Array.Nested.Trace ( -- * Traced variants @@ -37,10 +44,12 @@ module Data.Array.Nested.Trace ( ShS(..), KnownShS(..), Mixed, + ListX(ZX, (::%)), IxX(..), IIxX, - ShX(..), KnownShX(..), + ShX(..), KnownShX(..), IShX, StaticShX(..), SMayNat(..), + Conversion(..), Elt, PrimElt, @@ -51,10 +60,10 @@ module Data.Array.Nested.Trace ( Storable, SNat, pattern SNat, pattern SZ, pattern SS, - Perm(..), + Perm(..), PermR, IsPermutation, KnownPerm(..), - NumElt, FloatElt, + NumElt, IntElt, FloatElt, Rank, Product, Replicate, MapJust, @@ -67,4 +76,4 @@ import Data.Array.Nested.Trace.TH $(concat <$> mapM convertFun - ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromListOuter, 'rfromList1, 'rfromList1Prim, 'rtoListOuter, 'rtoList1, 'rfromListLinear, 'rfromListPrimLinear, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rcastToShaped, 'rtoMixed, 'rfromOrthotope, 'rtoOrthotope, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromListOuter, 'sfromList1, 'sfromList1Prim, 'stoListOuter, 'stoList1, 'sfromListLinear, 'sfromListPrimLinear, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoRanked, 'stoMixed, 'sfromOrthotope, 'stoOrthotope, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromListOuter, 'mfromList1, 'mfromList1Prim, 'mtoListOuter, 'mtoList1, 'mfromListLinear, 'mfromListPrimLinear, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mtoRanked, 'mcastToShaped]) + ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rgeneratePrim, 'rsumOuter1Prim, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerankPrim, 'rreplicate, 'rreplicatePrim, 'rfromListOuter, 'rfromListOuterN, 'rfromList1, 'rfromList1N, 'rfromListLinear, 'rfromList1Prim, 'rfromList1PrimN, 'rfromListPrimLinear, 'rtoListOuter, 'rtoList, 'rtoListLinear, 'rtoListPrim, 'rtoListPrimLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'sgeneratePrim, 'ssumOuter1Prim, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerankPrim, 'sreplicate, 'sreplicatePrim, 'sfromListOuter, 'sfromList1, 'sfromListLinear, 'sfromList1Prim, 'sfromListPrimLinear, 'stoListOuter, 'stoList, 'stoListLinear, 'stoListPrim, 'stoListPrimLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'mgeneratePrim, 'msumOuter1Prim, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerankPrim, 'mreplicate, 'mreplicatePrim, 'mfromListOuter, 'mfromListOuterN, 'mfromListOuterSN, 'mfromList1, 'mfromList1N, 'mfromList1SN, 'mfromListLinear, 'mfromList1Prim, 'mfromList1PrimN, 'mfromList1PrimSN, 'mfromListPrimLinear, 'mtoListOuter, 'mtoList, 'mtoListLinear, 'mtoListPrim, 'mtoListPrimLinear, 'msliceN, 'msliceSN, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array]) diff --git a/src/Data/Array/Nested/Trace/TH.hs b/src/Data/Array/Nested/Trace/TH.hs index 4b388e3..644b4bd 100644 --- a/src/Data/Array/Nested/Trace/TH.hs +++ b/src/Data/Array/Nested/Trace/TH.hs @@ -4,11 +4,11 @@ module Data.Array.Nested.Trace.TH where import Control.Monad (zipWithM) -import Data.List (foldl', intersperse) +import Data.List (foldl') import Data.Maybe (isJust) import Language.Haskell.TH hiding (cxt) - -import Debug.Trace qualified as Debug +import System.IO (hPutStr, stderr) +import System.IO.Unsafe (unsafePerformIO) import Data.Array.Nested @@ -20,7 +20,7 @@ splitFunTy = \case in (vars, cx, t1 : args, ret) ForallT vs cx' t -> let (vars, cx, args, ret) = splitFunTy t - in (vars ++ vs, cx ++ cx', args, ret) + in (vs ++ vars, cx' ++ cx, args, ret) t -> ([], [], [], t) data Arg = RRanked Type Arg @@ -30,17 +30,27 @@ data Arg = RRanked Type Arg | ROther Type deriving (Show) --- TODO: always returns Just recognise :: Type -> Maybe Arg recognise (ConT name `AppT` sht `AppT` ty) - | name == ''Ranked = RRanked sht <$> recognise ty - | name == ''Shaped = RShaped sht <$> recognise ty - | name == ''Mixed = RMixed sht <$> recognise ty + | name == ''Ranked = Just (RRanked sht (recogniseElt ty)) + | name == ''Shaped = Just (RShaped sht (recogniseElt ty)) + | name == ''Mixed = Just (RMixed sht (recogniseElt ty)) + | name == ''Conversion = Just (RShowable ty) recognise ty@(ConT name `AppT` _) - | name `elem` [''IShR, ''IIxR, ''ShS, ''IIxS, ''SNat] = + | name `elem` [''IShR, ''IIxR, ''ShS, ''IIxS, ''SNat, ''Perm] = Just (RShowable ty) +recognise ty@(ConT name) + | name == ''PermR = Just (RShowable ty) +recognise (ListT `AppT` ty) = Just (ROther ty) recognise _ = Nothing +recogniseElt :: Type -> Arg +recogniseElt (ConT name `AppT` sht `AppT` ty) + | name == ''Ranked = RRanked sht (recogniseElt ty) + | name == ''Shaped = RShaped sht (recogniseElt ty) + | name == ''Mixed = RMixed sht (recogniseElt ty) +recogniseElt ty = ROther ty + realise :: Arg -> Type realise (RRanked sht ty) = ConT ''Ranked `AppT` sht `AppT` realise ty realise (RShaped sht ty) = ConT ''Shaped `AppT` sht `AppT` realise ty @@ -62,37 +72,58 @@ mkShowElt (RMixed sht ty) = [ConT ''Show `AppT` realise (RMixed sht ty), ConT '' mkShowElt (RShowable _ty) = [] -- [ConT ''Elt `AppT` ty] mkShowElt (ROther ty) = [ConT ''Show `AppT` ty, ConT ''Elt `AppT` ty] -convertType :: Type -> Q (Type, [Bool], Bool) +-- If you pass a polymorphic function to seq, GHC wants to monomorphise and +-- doesn't know how to instantiate the type variables. Just don't, I guess. +isSeqable :: Type -> Bool +isSeqable ForallT{} = False +isSeqable (AppT a b) = isSeqable a && isSeqable b +isSeqable _ = True -- yolo, I guess + +convertType :: Type -> Q (Type, [Bool], [Bool], Bool) convertType typ = let (tybndrs, cxt, args, ret) = splitFunTy typ - argrels = map recognise args - retrel = recognise ret + argdescrs = map recognise args + retdescr = recognise ret in return (ForallT tybndrs (cxt ++ [constr - | Just rel <- retrel : argrels + | Just rel <- retdescr : argdescrs , constr <- mkShow rel]) (foldr (\a b -> ArrowT `AppT` a `AppT` b) ret args) - ,map isJust argrels - ,isJust retrel) + ,map isJust argdescrs + ,map isSeqable args + ,isJust retdescr) convertFun :: Name -> Q [Dec] convertFun funname = do defname <- newName (nameBase funname) - (convty, argarrs, retarr) <- reifyType funname >>= convertType - names <- zipWithM (\b i -> newName ((if b then "t" else "x") ++ show i)) argarrs [1::Int ..] + -- "ok": whether we understand this type enough to be able to show it + (convty, argoks, argsseqable, retok) <- reifyType funname >>= convertType + names <- zipWithM (\_ i -> newName ('x' : show i)) argoks [1::Int ..] + -- let tracenames = map fst (filter snd (zip (names ++ [resname]) (argarrs ++ [retarr]))) resname <- newName "res" - let tracenames = map fst (filter snd (zip (names ++ [resname]) (argarrs ++ [retarr]))) + let traceCall str val = VarE 'traceNoNewline `AppE` str `AppE` val + let msg1 = [LitE (StringL ("oxtrace: (" ++ nameBase funname ++ " "))] ++ + [if ok + then VarE 'showsPrec `AppE` LitE (IntegerL 11) `AppE` VarE n `AppE` LitE (StringL " ") + else LitE (StringL "_ ") + | (n, ok) <- zip names argoks] ++ + [LitE (StringL "...")] + let msg2 | retok = [LitE (StringL " = "), VarE 'show `AppE` VarE resname, LitE (StringL ")\n")] + | otherwise = [LitE (StringL " = _)\n")] let ex = LetE [ValD (VarP resname) (NormalB (foldl' AppE (VarE funname) (map VarE names))) - []] - (VarE 'Debug.trace - `AppE` (VarE 'concat `AppE` ListE - ([LitE (StringL ("oxtrace: " ++ nameBase funname ++ " ["))] ++ - intersperse (LitE (StringL ", ")) - (map (\n -> VarE 'show `AppE` VarE n) tracenames) ++ - [LitE (StringL "]")])) - `AppE` VarE resname) + []] $ + flip (foldr AppE) [VarE 'seq `AppE` VarE n | (n, True) <- zip names argsseqable] $ + traceCall (VarE 'concat `AppE` ListE msg1) $ + VarE 'seq `AppE` VarE resname `AppE` + traceCall (VarE 'concat `AppE` ListE msg2) (VarE resname) return [SigD defname convty ,FunD defname [Clause (map VarP names) (NormalB ex) []]] + +{-# NOINLINE traceNoNewline #-} +traceNoNewline :: String -> a -> a +traceNoNewline str x = unsafePerformIO $ do + hPutStr stderr str + return x diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs index 4172fa0..1dce868 100644 --- a/src/Data/Array/Nested/Types.hs +++ b/src/Data/Array/Nested/Types.hs @@ -6,7 +6,7 @@ {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeFamilyDependencies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} @@ -30,6 +30,7 @@ module Data.Array.Nested.Types ( Replicate, lemReplicateSucc, MapJust, + lemMapJustEmpty, lemMapJustCons, Head, Tail, Init, @@ -45,7 +46,6 @@ import GHC.TypeLits import GHC.TypeNats qualified as TN import Unsafe.Coerce qualified - -- Reasoning helpers subst1 :: forall f a b. a :~: b -> f a :~: f b @@ -58,8 +58,9 @@ subst2 Refl = Refl data Dict c a where Dict :: c a => Dict c a +{-# INLINE fromSNat' #-} fromSNat' :: SNat n -> Int -fromSNat' = fromIntegral . fromSNat +fromSNat' = fromEnum . TN.fromSNat sameNat' :: SNat n -> SNat m -> Maybe (n :~: m) sameNat' n@SNat m@SNat = sameNat n m @@ -108,13 +109,20 @@ type family Replicate n a where Replicate 0 a = '[] Replicate n a = a : Replicate (n - 1) a -lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a -lemReplicateSucc = unsafeCoerceRefl +lemReplicateSucc :: forall a n proxy. + proxy n -> (a : Replicate n a) :~: Replicate (n + 1) a +lemReplicateSucc _ = unsafeCoerceRefl -type family MapJust l where +type family MapJust l = r | r -> l where MapJust '[] = '[] MapJust (x : xs) = Just x : MapJust xs +lemMapJustEmpty :: MapJust sh :~: '[] -> sh :~: '[] +lemMapJustEmpty Refl = unsafeCoerceRefl + +lemMapJustCons :: MapJust sh :~: Just n : sh' -> sh :~: n : Tail sh +lemMapJustCons Refl = unsafeCoerceRefl + type family Head l where Head (x : _) = x diff --git a/src/Data/Array/Strided/Orthotope.hs b/src/Data/Array/Strided/Orthotope.hs index 5c38d14..e2cd17c 100644 --- a/src/Data/Array/Strided/Orthotope.hs +++ b/src/Data/Array/Strided/Orthotope.hs @@ -24,14 +24,19 @@ fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset ve toO :: AS.Array n a -> RS.Array n a toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec)) +{-# INLINE liftO1 #-} liftO1 :: (AS.Array n a -> AS.Array n' b) -> RS.Array n a -> RS.Array n' b liftO1 f = toO . f . fromO +{-# INLINE liftO2 #-} liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c) -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c liftO2 f x y = toO (f (fromO x) (fromO y)) +-- We don't inline this lifting function, because its code is not just +-- a wrapper, being relatively long and expensive. +{-# INLINEABLE liftVEltwise1 #-} liftVEltwise1 :: (Storable a, Storable b) => SNat n -> (VS.Vector a -> VS.Vector b) diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index dde06e3..7dcfd5e 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -1,8 +1,11 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE MultiWayIf #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneKindSignatures #-} @@ -14,27 +17,33 @@ module Data.Array.XArray where import Control.DeepSeq (NFData) +import Control.Monad (foldM_, foldM) +import Control.Monad.ST import Data.Array.Internal qualified as OI import Data.Array.Internal.RankedG qualified as ORG import Data.Array.Internal.RankedS qualified as ORS -import Data.Array.Ranked qualified as ORB import Data.Array.RankedS qualified as S import Data.Coerce import Data.Foldable (toList) import Data.Kind -import Data.List.NonEmpty (NonEmpty) +import Data.List.NonEmpty (NonEmpty(..)) import Data.Proxy import Data.Type.Equality import Data.Type.Ord +import Data.Vector.Generic.Checked qualified as VGC import Data.Vector.Storable qualified as VS +import Data.Vector.Storable.Mutable qualified as VSM import Foreign.Storable (Storable) import GHC.Generics (Generic) import GHC.TypeLits +#if !MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) +import Unsafe.Coerce (unsafeCoerce) +#endif -import Data.Array.Mixed.Lemmas +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Permutation import Data.Array.Nested.Types -import Data.Array.Nested.Mixed.Shape import Data.Array.Strided.Orthotope @@ -108,15 +117,23 @@ generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh) -- XArray . S.fromVector (shxShapeL sh) -- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh) +{-# INLINEABLE indexPartial #-} indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a indexPartial (XArray arr) ZIX = XArray arr indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx +{- Strangely, this increases allocation and there's no noticeable speedup: +indexPartial (XArray (ORS.A (ORG.A sh t))) ix = + let linear = OI.offset t + sum (zipWith (*) (ixxToList ix) (OI.strides t)) + len = ixxLength ix + in XArray (ORS.A (ORG.A (drop len sh) + OI.T{ OI.strides = drop len (OI.strides t) + , OI.offset = linear + , OI.values = OI.values t })) -} +{-# INLINEABLE index #-} index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a -index xarr i - | Refl <- lemAppNil @sh - = let XArray arr' = indexPartial xarr i :: XArray '[] a - in S.unScalar arr' +index (XArray (ORS.A (ORG.A _ t))) i = + OI.values t VS.! (OI.offset t + sum (zipWith (*) (toList i) (OI.strides t))) append :: forall n m sh a. Storable a => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a @@ -217,7 +234,12 @@ transpose ssh perm (XArray arr) , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm , Refl <- lemRankDropLen ssh perm +#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) = XArray (S.transpose (permToList' perm) arr) +#else + = XArray (unsafeCoerce (S.transpose (permToList' perm) arr)) +#endif + -- | The list argument gives indices into the original dimension list. -- @@ -243,14 +265,10 @@ transpose2 ssh1 ssh2 (XArray arr) , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) , Refl <- lemRankAppComm ssh1 ssh2 , let n1 = ssxLength ssh1 - = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) + = XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr) sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a -sumFull _ (XArray arr) = - S.unScalar $ - liftO1 (numEltSum1Inner (SNat @0)) $ - S.fromVector [product (S.shapeL arr)] $ - S.toVector arr +sumFull ssx (XArray arr) = numEltSumFull (ssxRank ssx) $ fromO arr sumInner :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a @@ -283,33 +301,76 @@ sumOuter ssh ssh' arr reshapePartial ssh ssh' shF $ arr -fromListOuter :: forall n sh a. Storable a - => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a -fromListOuter ssh l - | Dict <- lemKnownNatRankSSX ssh - = case ssh of - SKnown m :!% _ | fromSNat' m /= length l -> - error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++ - "does not match the type (" ++ show (fromSNat' m) ++ ")" - _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) +-- | This creates an array from a list of arrays of one less dimension. +-- The list is streamed, its length is checked and it's verified +-- that all arrays on the list have the same shape. +{-# INLINE fromListOuterSN #-} +fromListOuterSN :: forall n sh a. Storable a + => SNat n -> IShX sh -> NonEmpty (XArray sh a) -> XArray (Just n : sh) a +fromListOuterSN m sh l + | Dict <- lemKnownNatRank sh + , let l' = coerce @(NonEmpty (XArray sh a)) @(NonEmpty (S.Array (Rank sh) a)) l + = case sh of + ZSX -> fromList1SN m (map S.unScalar (toList l')) + _ -> XArray (ravelOuterN (fromSNat' m) l') -toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a] -toListOuter (XArray arr) = - case S.shapeL arr of +-- | This checks that the list has the given length and that all shapes in the +-- list are equal. The list is streamed. +-- The first array in the list is forced early to potentially release some +-- memory, before allocating the (large) new array. +{-# INLINEABLE ravelOuterN #-} +ravelOuterN :: (KnownNat k, Storable a) + => Int -> NonEmpty (S.Array k a) -> S.Array (1 + k) a +ravelOuterN 0 _ = error "ravelOuterN: N == 0" +ravelOuterN k as@(!a0 :| _) = runST $ do + let sh0 = S.shapeL a0 + len = product sh0 + vecSize = k * len + vec <- VSM.unsafeNew vecSize + let f !n (ORS.A (ORG.A sht t)) = + if | n >= k -> + error $ "ravelOuterN: list too long " ++ show (n, k) + -- if we do this check just once at the end, we may + -- crash instead of producing an accurate error message + | sht == sh0 -> do + let g off el = do + VS.unsafeCopy (VSM.slice off (VS.length el) vec) el + return $! off + VS.length el + foldM_ g (n * len) (OI.toVectorListT sht t) + return $! n + 1 + | otherwise -> + error $ "ravelOuterN: unequal shapes " ++ show (sht, sh0) + nFinal <- foldM f 0 as + if nFinal == k + then S.fromVector (k : sh0) <$> VS.unsafeFreeze vec + else error $ "ravelOuterN: list too short " ++ show (nFinal, k) + +toListOuter :: forall a n sh. Storable a => XArray (n : sh) a -> [XArray sh a] +toListOuter (XArray arr@(ORS.A (ORG.A shArr t))) = + case shArr of + [] -> error "impossible" 0 : _ -> [] - _ -> coerce (ORB.toList (S.unravel arr)) + -- using orthotope's functions here would entail using rerank, which is slow, so we don't + [_] | Refl <- (unsafeCoerceRefl :: sh :~: '[]) -> coerce (map S.scalar $ S.toList arr) + n : sh -> coerce $ map (ORG.A sh . OI.indexT t) [0 .. n - 1] + +-- | Performance note: the list's spine is fully materialised to compute its +-- length before traversing it again to construct the array. +{-# INLINE fromList1 #-} +fromList1 :: Storable a => [a] -> XArray '[Nothing] a +fromList1 l = + let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself + in XArray (S.fromVector [n] (VS.fromListN n l)) -fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a -fromList1 ssh l = - let n = length l - in case ssh of - SKnown m :!% _ | fromSNat' m /= n -> - error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++ - "does not match the type (" ++ show (fromSNat' m) ++ ")" - _ -> XArray (S.fromVector [n] (VS.fromListN n l)) +-- | The list is streamed. +{-# INLINE fromList1SN #-} +fromList1SN :: Storable a => SNat n -> [a] -> XArray '[Just n] a +fromList1SN m l = + let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed + in XArray (S.fromVector [n] (VGC.fromListNChecked n l)) -toList1 :: Storable a => XArray '[n] a -> [a] -toList1 (XArray arr) = S.toList arr +toListLinear :: Storable a => XArray sh a -> [a] +toListLinear (XArray arr) = S.toList arr -- | Throws if the given shape is not, in fact, empty. empty :: forall sh a. Storable a => IShX sh -> XArray sh a diff --git a/src/Data/Vector/Generic/Checked.hs b/src/Data/Vector/Generic/Checked.hs new file mode 100644 index 0000000..d8aaaae --- /dev/null +++ b/src/Data/Vector/Generic/Checked.hs @@ -0,0 +1,40 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE ImportQualifiedPost #-} +module Data.Vector.Generic.Checked ( + fromListNChecked, +) where + +import Data.Stream.Monadic qualified as Stream +import Data.Vector.Fusion.Bundle.Monadic qualified as VBM +import Data.Vector.Fusion.Bundle.Size qualified as VBS +import Data.Vector.Fusion.Util qualified as VFU +import Data.Vector.Generic qualified as VG + +-- for INLINE_FUSED and INLINE_INNER +#include "vector.h" + + +-- These functions are copied over and lightly edited from the vector and +-- vector-stream packages, and thus inherit their BSD-3-Clause license with: +-- Copyright (c) 2008-2012, Roman Leshchinskiy +-- 2020-2022, Alexey Kuleshevich +-- 2020-2022, Aleksey Khudyakov +-- 2020-2022, Andrew Lelechenko + +fromListNChecked :: VG.Vector v a => Int -> [a] -> v a +{-# INLINE fromListNChecked #-} +fromListNChecked n = VG.unstream . bundleFromListNChecked n + +bundleFromListNChecked :: Int -> [a] -> VBM.Bundle VFU.Id v a +{-# INLINE_FUSED bundleFromListNChecked #-} +bundleFromListNChecked nTop xsTop + | nTop < 0 = error "fromListNChecked: length negative" + | otherwise = + VBM.fromStream (Stream.Stream step (xsTop, nTop)) (VBS.Max (VFU.delay_inline max nTop 0)) + where + {-# INLINE_INNER step #-} + step (xs,n) | n == 0 = case xs of + [] -> return Stream.Done + _:_ -> error "fromListNChecked: list too long" + step (x:xs,n) = return (Stream.Yield x (xs,n-1)) + step ([],_) = error "fromListNChecked: list too short" diff --git a/src/GHC/TypeLits/Orphans.hs b/src/GHC/TypeLits/Orphans.hs new file mode 100644 index 0000000..42f7293 --- /dev/null +++ b/src/GHC/TypeLits/Orphans.hs @@ -0,0 +1,13 @@ +-- | Compatibility module adding some additional instances not yet defined in +-- base-4.18 with GHC 9.6. +{-# OPTIONS -Wno-orphans #-} +module GHC.TypeLits.Orphans where + +import GHC.TypeLits + + +instance Eq (SNat n) where + _ == _ = True + +instance Ord (SNat n) where + compare _ _ = EQ diff --git a/test/Gen.hs b/test/Gen.hs index 281c620..952e8db 100644 --- a/test/Gen.hs +++ b/test/Gen.hs @@ -4,7 +4,6 @@ {-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} @@ -20,10 +19,10 @@ import Foreign import GHC.TypeLits import GHC.TypeNats qualified as TN -import Data.Array.Nested.Permutation -import Data.Array.Nested.Types import Data.Array.Nested +import Data.Array.Nested.Permutation import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types import Hedgehog import Hedgehog.Gen qualified as Gen @@ -59,9 +58,12 @@ shuffleShR = \sh -> go (length sh) (toList sh) sh (dim :$:) <$> go (nbag - 1) bag' sh genShR :: SNat n -> Gen (IShR n) -genShR sn = do +genShR = genShRwithTarget 100_000 + +genShRwithTarget :: Int -> SNat n -> Gen (IShR n) +genShRwithTarget targetMax sn = do let n = fromSNat' sn - targetSize <- Gen.int (Range.linear 0 100_000) + targetSize <- Gen.int (Range.linear 0 targetMax) let genDims :: SNat m -> Int -> Gen (IShR m) genDims SZ _ = return ZSR genDims (SS m) 0 = do @@ -76,9 +78,8 @@ genShR sn = do dims <- genDims m (if dim == 0 then 0 else tgt `div` dim) return (dim :$: dims) dims <- genDims sn targetSize - let dimsL = toList dims - maxdim = maximum dimsL - cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize) + let maxdim = maximum dims + cap = binarySearch (`div` 2) 1 maxdim (\cap' -> shrSize (min cap' <$> dims) <= targetSize) shuffleShR (min cap <$> dims) -- | Example: given 3 and 7, might return: @@ -93,10 +94,14 @@ genShR sn = do -- other dimensions might be zero. genReplicatedShR :: m <= n => SNat m -> SNat n -> Gen (IShR m, IShR n, IShR n) genReplicatedShR = \m n -> do - sh1 <- genShR m + let expectedSizeIncrease = round (repvalavg ^ (fromSNat' n - fromSNat' m)) + sh1 <- genShRwithTarget (1_000_000 `div` expectedSizeIncrease) m (sh2, sh3) <- injectOnes n sh1 sh1 return (sh1, sh2, sh3) where + repvalrange = (1::Int, 5) + repvalavg = let (lo, hi) = repvalrange in fromIntegral (lo + hi) / 2 :: Double + injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n) injectOnes n@SNat shOnes sh | m@SNat <- shrRank sh @@ -105,7 +110,7 @@ genReplicatedShR = \m n -> do EQI -> return (shOnes, sh) GTI -> do index <- Gen.int (Range.linear 0 (fromSNat' m)) - value <- Gen.int (Range.linear 1 5) + value <- Gen.int (uncurry Range.linear repvalrange) Refl <- return (lem n m) injectOnes n (inject index 1 shOnes) (inject index value sh) @@ -115,7 +120,7 @@ genReplicatedShR = \m n -> do inject :: Int -> Int -> IShR m -> IShR (m + 1) inject 0 v sh = v :$: sh inject i v (w :$: sh) = w :$: inject (i - 1) v sh - inject _ v ZSR = v :$: ZSR -- invalid input, but meh + inject _ _ ZSR = error "unreachable" genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a) genStorables rng f = do @@ -134,7 +139,7 @@ genStaticShX = \n k -> case n of genStaticShX n' $ \ssh -> k (item :!% ssh) where - genItem :: Monad m => (forall n. SMayNat () SNat n -> PropertyT m ()) -> PropertyT m () + genItem :: Monad m => (forall n. SMayNat () n -> PropertyT m ()) -> PropertyT m () genItem k = do b <- forAll Gen.bool if b @@ -157,7 +162,7 @@ genPermR n = Gen.shuffle [0 .. n-1] genPerm :: Monad m => SNat n -> (forall p. (IsPermutation p, Rank p ~ n) => Perm p -> PropertyT m r) -> PropertyT m r genPerm n@SNat k = do list <- forAll $ genPermR (fromSNat' n) - permFromList list $ \perm -> do + permFromListCont list $ \perm -> do case permCheckPermutation perm $ case sameNat' (permRank perm) n of Just Refl -> Just (k perm) diff --git a/test/Tests/C.hs b/test/Tests/C.hs index 3b78bc0..e26c3dd 100644 --- a/test/Tests/C.hs +++ b/test/Tests/C.hs @@ -1,9 +1,12 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) {-# LANGUAGE TypeAbstractions #-} +#endif {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} @@ -18,13 +21,13 @@ import Data.Type.Equality import Foreign import GHC.TypeLits -import Data.Array.Nested.Types (fromSNat') import Data.Array.Nested import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types (fromSNat') import Hedgehog import Hedgehog.Gen qualified as Gen -import Hedgehog.Internal.Property (forAllT) +import Hedgehog.Internal.Property (LabelName(..), forAllT) import Hedgehog.Range qualified as Range import Test.Tasty import Test.Tasty.Hedgehog @@ -39,8 +42,11 @@ import Util fineTol :: Double fineTol = 1e-8 -prop_sum_nonempty :: Property -prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do +debugCoverage :: Bool +debugCoverage = False + +gen_red_nonempty :: (forall n. SNat (n + 1) -> SNat n -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property +gen_red_nonempty f = property $ genRank $ \outrank@(SNat @n) -> do -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. let inrank = SNat @(n + 1) sh <- forAll $ genShR inrank @@ -49,11 +55,10 @@ prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> genStorables (Range.singleton (product sh)) (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) - let rarr = rfromOrthotope inrank arr - almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr) + f inrank outrank arr -prop_sum_empty :: Property -prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do +gen_red_empty :: (forall n. SNat (n + 1) -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property +gen_red_empty f = property $ genRank $ \outrankm1@(SNat @nm1) -> do -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. _outrank :: SNat n <- return $ SNat @(nm1 + 1) let inrank = SNat @(n + 1) @@ -62,14 +67,13 @@ prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do sht <- shuffleShR (0 :$: shtt) -- n n <- Gen.int (Range.linear 0 20) return (n :$: sht) -- n + 1 - guard (0 `elem` toList (shrTail sh)) + guard (0 `elem` shrTail sh) -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) let arr = OR.fromList @(n + 1) @Double (toList sh) [] - let rarr = rfromOrthotope inrank arr - OR.toList (rtoOrthotope (rsumOuter1 rarr)) === [] + f inrank arr -prop_sum_lasteq1 :: Property -prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do +gen_red_lasteq1 :: (forall n. SNat (n + 1) -> SNat n -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property +gen_red_lasteq1 f = property $ genRank $ \outrank@(SNat @n) -> do let inrank = SNat @(n + 1) outsh <- forAll $ genShR outrank guard (all (> 0) outsh) @@ -77,11 +81,10 @@ prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$> genStorables (Range.singleton (product insh)) (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) - let rarr = rfromOrthotope inrank arr - almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr) + f inrank outrank arr -prop_sum_replicated :: Bool -> Property -prop_sum_replicated doTranspose = property $ +gen_red_replicated :: Bool -> (forall n. SNat (n + 1) -> SNat n -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property +gen_red_replicated doTranspose f = property $ genRank $ \inrank1@(SNat @m) -> genRank $ \outrank@(SNat @nm1) -> do inrank2 :: SNat n <- return $ SNat @(nm1 + 1) @@ -89,6 +92,10 @@ prop_sum_replicated doTranspose = property $ LTI -> return Refl -- actually we only continue if m < n _ -> discard (sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2 + when debugCoverage $ do + label (LabelName ("rankdiff " ++ show (fromSNat' inrank2 - fromSNat' inrank1))) + label (LabelName ("size sh1 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh1) :: Double)) :: Int))) + label (LabelName ("size sh3 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh3) :: Double)) :: Int))) guard (all (> 0) sh3) arr <- forAllT $ OR.stretch (toList sh3) @@ -100,8 +107,50 @@ prop_sum_replicated doTranspose = property $ if doTranspose then do perm <- forAll $ genPermR (fromSNat' inrank2) return $ OR.transpose perm arr else return arr - let rarr = rfromOrthotope inrank2 arrTrans - almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans) + f inrank2 outrank arrTrans + + +prop_sum_nonempty :: Property +prop_sum_nonempty = gen_red_nonempty $ \inrank outrank arr -> do + let rarr = rfromOrthotope inrank arr + almostEq fineTol (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr) + +prop_sum_empty :: Property +prop_sum_empty = gen_red_empty $ \inrank arr -> do + let rarr = rfromOrthotope inrank arr + OR.toList (rtoOrthotope (rsumOuter1Prim rarr)) === [] + +prop_sum_lasteq1 :: Property +prop_sum_lasteq1 = gen_red_lasteq1 $ \inrank outrank arr -> do + let rarr = rfromOrthotope inrank arr + almostEq fineTol (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr) + +prop_sum_replicated :: Bool -> Property +prop_sum_replicated doTranspose = gen_red_replicated doTranspose $ \inrank outrank arr -> do + let rarr = rfromOrthotope inrank arr + almostEq 1e-8 (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr) + + +prop_sumall_nonempty :: Property +prop_sumall_nonempty = gen_red_nonempty $ \inrank _outrank arr -> do + let rarr = rfromOrthotope inrank arr + almostEq fineTol (rsumAllPrim rarr) (OR.sumA arr) + +prop_sumall_empty :: Property +prop_sumall_empty = gen_red_empty $ \inrank arr -> do + let rarr = rfromOrthotope inrank arr + rsumAllPrim rarr === 0.0 + +prop_sumall_lasteq1 :: Property +prop_sumall_lasteq1 = gen_red_lasteq1 $ \inrank _outrank arr -> do + let rarr = rfromOrthotope inrank arr + almostEq fineTol (rsumAllPrim rarr) (OR.sumA arr) + +prop_sumall_replicated :: Bool -> Property +prop_sumall_replicated doTranspose = gen_red_replicated doTranspose $ \inrank _outrank arr -> do + let rarr = rfromOrthotope inrank arr + almostEq 1e-6 (rsumAllPrim rarr) (OR.sumA arr) + prop_negate_with :: forall f b. Show b => ((forall n. f n -> SNat n -> PropertyT IO ()) -> PropertyT IO ()) @@ -130,6 +179,13 @@ tests = testGroup "C" ,testProperty "replicated" (prop_sum_replicated False) ,testProperty "replicated_transposed" (prop_sum_replicated True) ] + ,testGroup "sumAll" + [testProperty "nonempty" prop_sumall_nonempty + ,testProperty "empty" prop_sumall_empty + ,testProperty "last==1" prop_sumall_lasteq1 + ,testProperty "replicated" (prop_sumall_replicated False) + ,testProperty "replicated_transposed" (prop_sumall_replicated True) + ] ,testGroup "negate" [testProperty "normalised" $ prop_negate_with (\k -> genRank (k (Const ()))) diff --git a/test/Tests/Permutation.hs b/test/Tests/Permutation.hs index 98a6da5..4e75d64 100644 --- a/test/Tests/Permutation.hs +++ b/test/Tests/Permutation.hs @@ -24,7 +24,7 @@ tests = testGroup "Permutation" [testProperty "permCheckPermutation" $ property $ do n <- forAll $ Gen.int (Range.linear 0 10) list <- forAll $ genPermR n - let r = permFromList list $ \perm -> + let r = permFromListCont list $ \perm -> permCheckPermutation perm () case r of Just () -> return () diff --git a/test/Util.hs b/test/Util.hs index 8a5ba72..6514fbf 100644 --- a/test/Util.hs +++ b/test/Util.hs @@ -36,16 +36,20 @@ orSumOuter1 (sn@SNat :: SNat n) = let n = fromSNat' sn in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0]) -class AlmostEq f where - type AlmostEqConstr f :: Type -> Constraint +class AlmostEq t where + type EltOf t :: Type -- | absolute tolerance, lhs, rhs - almostEq :: (AlmostEqConstr f a, Ord a, Show a, Fractional a, MonadTest m) - => a -> f a -> f a -> m () + almostEq :: MonadTest m => EltOf t -> t -> t -> m () -instance AlmostEq (OR.Array n) where - type AlmostEqConstr (OR.Array n) = OR.Unbox +instance (OR.Unbox a, Ord a, Show a, Fractional a) => AlmostEq (OR.Array n a) where + type EltOf (OR.Array n a) = a almostEq atol lhs rhs | OR.allA (< atol) (OR.zipWithA (\a b -> abs (a - b)) rhs lhs) = success | otherwise = failDiff lhs rhs + +instance AlmostEq Double where + type EltOf Double = Double + almostEq atol lhs rhs | abs (lhs - rhs) < atol = success + | otherwise = failDiff lhs rhs |
