aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGELOG.md7
-rw-r--r--README.md197
-rw-r--r--bench/Main.hs60
-rw-r--r--cabal.project2
-rw-r--r--cbits/arith.c10
-rw-r--r--cbits/arith_lists.h62
-rwxr-xr-xgentrace.sh2
-rw-r--r--ops/Data/Array/Strided/Arith/Internal.hs69
-rw-r--r--ops/Data/Array/Strided/Arith/Internal/Lists.hs4
-rw-r--r--ox-arrays.cabal59
-rw-r--r--release-hints.txt3
-rw-r--r--src/Data/Array/Nested.hs61
-rw-r--r--src/Data/Array/Nested/Convert.hs363
-rw-r--r--src/Data/Array/Nested/Internal/Lemmas.hs59
-rw-r--r--src/Data/Array/Nested/Lemmas.hs (renamed from src/Data/Array/Mixed/Lemmas.hs)108
-rw-r--r--src/Data/Array/Nested/Mixed.hs418
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs280
-rw-r--r--src/Data/Array/Nested/Mixed/Shape/Internal.hs59
-rw-r--r--src/Data/Array/Nested/Permutation.hs55
-rw-r--r--src/Data/Array/Nested/Ranked.hs198
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs70
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs179
-rw-r--r--src/Data/Array/Nested/Shaped.hs144
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs67
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs205
-rw-r--r--src/Data/Array/Nested/Trace.hs41
-rw-r--r--src/Data/Array/Nested/Trace/TH.hs83
-rw-r--r--src/Data/Array/Nested/Types.hs20
-rw-r--r--src/Data/Array/Strided/Orthotope.hs5
-rw-r--r--src/Data/Array/XArray.hs135
-rw-r--r--src/Data/Vector/Generic/Checked.hs40
-rw-r--r--src/GHC/TypeLits/Orphans.hs13
-rw-r--r--test/Gen.hs31
-rw-r--r--test/Tests/C.hs94
-rw-r--r--test/Tests/Permutation.hs2
-rw-r--r--test/Util.hs16
36 files changed, 2199 insertions, 1022 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..009d267
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,7 @@
+# Changelog for `ox-arrays`
+
+This package intends to follow the [PVP](https://pvp.haskell.org/).
+
+## 0.1.0.0
+- Initial release
+- Various aspects of the API are still experimental, and breaking changes are expected in the future.
diff --git a/README.md b/README.md
index 9b8d543..667bf9d 100644
--- a/README.md
+++ b/README.md
@@ -1,56 +1,173 @@
-Wrapper library around `orthotope` that defines nested arrays, including
-tuples, of (eventually) unboxed values. The arrays are represented in
-struct-of-arrays form via the `Data.Vector.Unboxed` data family trick. Below
-the surface layer, there is a more low-level wrapper around `orthotope` that
-defines an array type type-indexed by `[Maybe Nat]`: some dimensions are
-shape-typed (i.e. have their size statically known), and some not.
+## ox-arrays
-An overview of the API:
+ox-arrays is an array library that defines nested arrays, including tuples, of
+(eventually) unboxed values. The arrays are represented in struct-of-arrays
+form via the `Data.Vector.Unboxed` data family trick; the component arrays are
+`orthotope` arrays
+([RankedS](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-RankedS.html))
+which describe elements using a _stride vector_ or
+[LMAD](https://dl.acm.org/doi/pdf/10.1145/509705.509708) so that `transpose`
+and `replicate` need only modify array metadata, not actually move around data.
+
+Because of the struct-of-arrays representation, nested arrays are not fully
+general: indeed, arrays are not actually nested under the hood, so if one has an
+array of arrays, those element arrays must all have the same shape (length,
+width, etc.). If one has an array of tuples of arrays, then all the `fst`
+components must have the same shape and all the `snd` components must have the
+same shape, but the two pair components themselves can be different.
+
+However, the nesting functionality of ox-arrays can be completely ignored if you
+only care about other parts of its API, or the vectorised arithmetic operations
+(using hand-written C code). Nesting support mostly does not get in the way, and
+has essentially no overhead (both when it's used and when it's not used).
+
+ox-arrays defines three array types: `Ranked`, `Shaped` and `Mixed`.
+- `Ranked` corresponds to `orthotope`'s
+ [RankedS](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-RankedS.html)
+ and has the _rank_ of the array (its number of dimensions) on the type level.
+ For example, `Ranked 2 Float` is a two-dimensional array of `Float`s, i.e. a
+ matrix.
+- `Shaped` corresponds to `orthotope`'s
+ [ShapedS](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-ShapedS.html).
+ and has the full _shape_ of the array (its dimension sizes) on the type level
+ as a type-level list of `Nat`s. For example, `Shaped [2,3] Float` is a 2-by-3
+ matrix. The innermost dimension correspond to the right-most element in the
+ list.
+- `Mixed` is halfway between the two: it has a type parameter of kind
+ `[Maybe Nat]` whose length is the rank of the array; `Nothing` elements have
+ unknown size, whereas `Just` elements have the indicated size. The type
+ `Mixed [Nothing, Nothing] a` is equivalent to `Ranked 2 a`; the type
+ `Mixed [Just n, Just m] a` is equivalent to `Shaped [n, m] a`.
+
+In various places in the API of a library like ox-arrays, one can make a
+decision between 1. requiring a type class constraint providing certain
+information (e.g.
+[KnownNat](https://hackage.haskell.org/package/base-4.21.0.0/docs/GHC-TypeLits.html#t:KnownNat)
+or `orthotope`'s
+[Shape](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-ShapedS.html#t:Shape)),
+or 2. taking singleton _values_ that encode said information in a way that is
+linked to the type level (e.g.
+[SNat](https://hackage.haskell.org/package/base-4.21.0.0/docs/GHC-TypeLits.html#t:SNat)).
+`orthotope` chooses the type class approach; ox-arrays chooses the singleton
+approach. Singletons are more verbose at times, but give the programmer more
+insight in what data is flowing where, and more importantly, more control: type
+class inference is very nice and implicit, but if it's not powerful enough for
+the trickery you're doing, you're out of luck. Singletons allow you to explain
+as precisely as you want to GHC what exactly you're doing.
+
+Below the surface layer, there is a more low-level wrapper (`XArray`) around
+`orthotope` that defines a non-nested `Mixed`-style array type.
+
+**Be aware**: `ox-arrays` attempts to preserve sharing as much as possible.
+That is to say: if a function is able to avoid copying array data and return an
+array that references the original underlying `Vector`, it may do so. For
+example, this means that if you convert a nested array to a list of arrays, all
+returned arrays reference part of the original array without copying. This
+makes `mtoList` fast, but also means that memory may be retained longer than
+you might expect.
+
+Here is a little taster of the API, to get a sense for the design:
```haskell
-data Ranked (n :: INat) a {- e.g. -} Ranked 3 Float
-data Shaped (sh :: '[Nat]) a {- e.g. -} Shaped [2,3,4] Float
-data Mixed (xsh :: '[Maybe Nat]) a {- e.g. -} Mixed [Just 2, Nothing, Just 4] Float
-
-Ranked I0 a = Ranked Z a ~~= Acc.Array Z a = Acc.Scalar a
-Ranked I1 a = Ranked (S Z) a ~~= Acc.Array (Z :. Int) a = Acc.Vector a
-Ranked I2 a = Ranked (S (S Z)) a ~~= Acc.Array (Z :. Int :. Int) a = Acc.Matrix a
+import GHC.TypeLits (Nat, SNat)
+data Ranked (n :: Nat) a {- e.g. -} Ranked 3 Float
+data Shaped (sh :: '[Nat]) a {- e.g. -} Shaped [2,3,4] Float
+data Mixed (xsh :: '[Maybe Nat]) a {- e.g. -} Mixed [Just 2, Nothing, Just 4] Float
-rshape :: (Elt a, KnownINat n) => Ranked n a -> IxR n
-sshape :: (Elt a, KnownShape sh) => Shaped sh a -> IxS sh
-mshape :: (Elt a, KnownShapeX xsh) => Mixed xsh a -> IxX xsh
+-- Shape types are written Sh{R,S,X}. The 'I' prefix denotes a Int-filled shape;
+-- ShR and ShX are more general containers. ShS is a singleton.
+rshape :: Elt a => Ranked n a -> IShR n
+sshape :: Elt a => Shaped sh a -> ShS sh
+mshape :: Elt a => Mixed xsh a -> IShX xsh
-rindex :: Elt a => Ranked n a -> IxR n -> a
-sindex :: Elt a => Shaped sh a -> IxS sh -> a
-mindex :: Elt a => Mixed xsh a -> IxX xsh -> a
+-- Index types are written Ix{R,S,X}.
+rindex :: Elt a => Ranked n a -> IIxR n -> a
+sindex :: Elt a => Shaped sh a -> IIxS sh -> a
+mindex :: Elt a => Mixed xsh a -> IIxX xsh -> a
-data IxR n where
- IZR :: IxR Z
- (:::) :: Int -> IxR n -> IxR (S n)
+-- The index types can be used as if they were defined as follows; pattern
+-- synonyms are provided to construct the illusion. (The actual definitions are
+-- a bit more general and indirect.)
+data IIxR n where
+ ZIR :: IIxR 0
+ (:.:) :: Int -> IIxR n -> IIxR (n + 1)
-data IxS sh where
- IZS :: IxS '[]
- (::$) :: Int -> IxS sh -> IxS (n : sh)
+data IIxS sh where
+ ZIS :: IIxS '[]
+ (:.$) :: Int -> IIxS sh -> IIxS (n : sh)
-data IxX sh where
- IZX :: IxX '[]
- (::@) :: Int -> IxX sh -> IxX (Just n : sh)
- (::?) :: Int -> IxX sh -> IxX (Nothing : sh)
+data IIxX xsh where
+ ZIX :: IIxX '[]
+ (:.%) :: Int -> IIxX xsh -> IIxX (mn : xsh)
+-- Similarly, the shape types can be used as if they were defined as follows.
+data IShR n where
+ ZSR :: IShR 0
+ (:$:) :: Int -> IShR n -> IShR (n + 1)
+
+data ShS sh where
+ ZSS :: ShS '[]
+ (:$$) :: SNat n -> ShS sh -> ShS (n : sh)
+
+data IShX xsh where
+ ZSX :: IShX '[]
+ (:$%) :: SMayNat Int mn -> IShX xsh -> IShX (mn : xsh)
+-- where:
+data SMayNat i n where
+ SUnknown :: i -> SMayNat i Nothing
+ SKnown :: SNat n -> SMayNat i (Just n)
+
+-- Occasionally one needs a singleton for only the _known_ dimensions of a mixed
+-- shape -- that is to say, only the statically-known part of a mixed shape.
+-- StaticShX provides for this need. It can be used as if defined as follows:
+data StaticShX xsh where
+ ZKX :: StaticShX '[]
+ (:!%) :: SMayNat () mn -> StaticShX xsh -> StaticShX (mn : xsh)
+
+-- The Elt class describes types that can be used as elements of an array. While
+-- it is technically possible to define new instances of this class, typical
+-- usage should regard Elt as closed. The user-relevant instances are the
+-- following:
class Elt a
-instance Elt ()
-instance Elt Double
-instance Elt Int
-instance (Elt a, Elt b) => Elt (a, b)
-instance (Elt a, KnownINat n) => Elt (Ranked n a)
-instance (Elt a, KnownShape sh) => Elt (Shaped sh a)
-instance (Elt a, KnownShapeX xsh) => Elt (Mixed xsh a)
+instance Elt ()
+instance Elt Bool
+instance Elt Float
+instance Elt Double
+instance Elt Int
+instance (Elt a, Elt b) => Elt (a, b)
+instance Elt a => Elt (Ranked n a)
+instance Elt a => Elt (Shaped sh a)
+instance Elt a => Elt (Mixed xsh a)
+
+-- Essentially all functions that ox-arrays offers on arrays are first-order:
+-- add two arrays elementwise, transpose an array, append arrays, compute
+-- minima/maxima, zip/unzip, nest/unnest, etc. The first-order approach allows
+-- operations, especially arithmetic ones, to be vectorised using hand-written
+-- C code, without needing any sort of JIT compilation.
+rappend :: Elt a => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
+sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
+mappend :: Elt a => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
-rgenerate :: Elt a => IxR n -> (IxR n -> a) -> Ranked n a
-sgenerate :: (Elt a, KnownShape sh) => (IxS sh -> a) -> Shaped sh a
-mgenerate :: (Elt a, KnownShapeX xsh) => IxX xsh -> (IxX xsh -> a) -> Mixed xsh a
+-- Exceptionally, also one higher-order function is provided per array type:
+-- 'generate'. These functions have the caveat that regularity of arrays must be
+-- preserved: all returned 'a's must have equal shape. See the documentation of
+-- 'mgenerate'.
+-- Warning: because the invocations of the function you pass cannot be
+-- vectorised, 'generate' is rather slow if 'a' is small.
+-- The 'KnownElt' class captures an API infelicity where constraint-based shape
+-- passing is the only practical option.
+rgenerate :: KnownElt a => IShR n -> (IxR n -> a) -> Ranked n a
+sgenerate :: KnownElt a => ShS sh -> (IxS sh -> a) -> Shaped sh a
+mgenerate :: KnownElt a => IShX xsh -> (IxX xsh -> a) -> Mixed xsh a
+-- Under the hood, Ranked and Shaped are both newtypes over Mixed. Mixed itself
+-- is a data family over XArray, which is a newtype over orthotope's RankedS.
newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
```
+
+About the name: when importing `orthotope` array modules, a possible naming
+convention is to use qualified imports as `OR` for "orthotope ranked" arrays and
+`OS` for "orthotope shaped" arrays. ox-arrays was started to fill the `OX` gap,
+then grew out of proportion.
diff --git a/bench/Main.hs b/bench/Main.hs
index ef03b1a..ce3a9df 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -15,19 +15,12 @@ import Numeric.LinearAlgebra qualified as LA
import Test.Tasty.Bench
import Text.Show (showListWith)
-import Data.Array.XArray (XArray(..))
import Data.Array.Nested
import Data.Array.Nested.Mixed (Mixed(M_Primitive), mliftPrim, mliftPrim2, toPrimitive)
+import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked (liftRanked1, liftRanked2)
import Data.Array.Strided.Arith.Internal qualified as Arith
-
-
-enableMisc :: Bool
-enableMisc = False
-
-bgroupIf :: Bool -> String -> [Benchmark] -> Benchmark
-bgroupIf True = bgroup
-bgroupIf False = \name _ -> bgroup name []
+import Data.Array.XArray (XArray(..))
main :: IO ()
@@ -51,7 +44,7 @@ main_tests = defaultMain
" str " ++ showSh (stridesOf inp1) ++ " " ++ showSh (stridesOf inp2)) $
nf (\(a,b) -> rsumAllPrim (rdot1Inner a b)) (inp1, inp2)
- iota n = riota @Double n
+ iota = riota @Double
in
[dotprodBench "dot 1D"
(iota 10_000_000
@@ -104,7 +97,7 @@ main_tests = defaultMain
in nf (\a -> RS.normalize a)
(RS.rev [0] (RS.rev [0] (RS.iota @Double n)))
]
- ,bgroupIf enableMisc "misc"
+ ,bgroup "misc"
[let n = 1000
k = 1000
in bgroup ("fusion [" ++ show k ++ "]*" ++ show n)
@@ -148,6 +141,16 @@ main_tests = defaultMain
| ki <- [1::Int ..5]]
]
]
+ ,bench "ixxFromLinear 10000x" $
+ let n = 10000
+ sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX
+ in nf (\sh -> [ixxFromLinear @Int sh i | i <- [1..n]]) sh0
+ ,bench "ixxFromLinear 1x" $
+ let sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX
+ in nf (\sh -> ixxFromLinear @Int sh 1234) sh0
+ ,bench "shxEnum" $
+ let sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX
+ in nf (\sh -> shxEnum sh) sh0
]
]
@@ -156,45 +159,54 @@ tests_compare =
let n = 1_000_000 in
[bgroup "Num"
[bench "sum(+) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (+)) a b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (+)) a b)))
(riota @Double n, riota n)
,bench "sum(*) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (*)) a b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (*)) a b)))
(riota @Double n, riota n)
,bench "sum(/) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (/)) a b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (/)) a b)))
(riota @Double n, riota n)
,bench "sum(**) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (**)) a b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (**)) a b)))
(riota @Double n, riota n)
,bench "sum(sin) Double [1e6]" $
- nf (\a -> runScalar (rsumOuter1 (liftRanked1 (mliftPrim sin) a)))
+ nf (\a -> runScalar (rsumOuter1Prim (liftRanked1 (mliftPrim sin) a)))
(riota @Double n)
,bench "sum Double [1e6]" $
- nf (\a -> runScalar (rsumOuter1 a))
+ nf (\a -> runScalar (rsumOuter1Prim a))
+ (riota @Double n)
+ ,bench "sumAll iota [1e6]" $
+ nf (\a -> rsumAllPrim a)
(riota @Double n)
+ ,bench "sumAll rev1(iota) [1e6]" $
+ nf (\a -> rsumAllPrim a)
+ (rrev1 $ riota @Double n)
+ ,bench "sumAll reshape(iota) [1e6]" $
+ nf (\a -> rsumAllPrim a)
+ (rreshape (1 :$: n :$: 1 :$: ZSR) $ riota @Double n)
]
,bgroup "NumElt"
[bench "sum(+) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (a + b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (a + b)))
(riota @Double n, riota n)
,bench "sum(*) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (a * b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (a * b)))
(riota @Double n, riota n)
,bench "sum(/) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (a / b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (a / b)))
(riota @Double n, riota n)
,bench "sum(**) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (a ** b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (a ** b)))
(riota @Double n, riota n)
,bench "sum(sin) Double [1e6]" $
- nf (\a -> runScalar (rsumOuter1 (sin a)))
+ nf (\a -> runScalar (rsumOuter1Prim (sin a)))
(riota @Double n)
,bench "sum Double [1e6]" $
- nf (\a -> runScalar (rsumOuter1 a))
+ nf (\a -> runScalar (rsumOuter1Prim a))
(riota @Double n)
,bench "sum(*) Double [1e6] stride 1; -1" $
- nf (\(a, b) -> runScalar (rsumOuter1 (a * b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (a * b)))
(riota @Double n, rrev1 (riota n))
,bench "dotprod Float [1e6]" $
nf (\(a, b) -> rdot a b)
diff --git a/cabal.project b/cabal.project
index d102ed6..d76d872 100644
--- a/cabal.project
+++ b/cabal.project
@@ -1,2 +1,2 @@
packages: .
-with-compiler: ghc-9.8.4
+with-compiler: ghc-9.12.2
diff --git a/cbits/arith.c b/cbits/arith.c
index f19b01e..ee248a4 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -20,6 +20,8 @@
// Shorter names, due to CPP used both in function names and in C types.
+typedef int8_t i8;
+typedef int16_t i16;
typedef int32_t i32;
typedef int64_t i64;
@@ -248,6 +250,8 @@ void oxarrays_stats_print_all(void) {
#define GEN_ABS(x) \
_Generic((x), \
+ i8: abs, \
+ i16: abs, \
int: abs, \
long: labs, \
long long: llabs, \
@@ -490,7 +494,9 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
if (rank == 0) return arr[0]; \
typ result = 0; \
TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
- REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, result); \
+ typ dest = 0; \
+ REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, dest); \
+ result = result op dest; \
}); \
return result; \
}
@@ -738,7 +744,7 @@ enum redop_tag_t {
* Generate all the functions *
*****************************************************************************/
-#define INT_TYPES_XLIST X(i32) X(i64)
+#define INT_TYPES_XLIST X(i8) X(i16) X(i32) X(i64)
#define FLOAT_TYPES_XLIST X(double) X(float)
#define NUM_TYPES_XLIST INT_TYPES_XLIST FLOAT_TYPES_XLIST
diff --git a/cbits/arith_lists.h b/cbits/arith_lists.h
index 432765c..dc9ad1a 100644
--- a/cbits/arith_lists.h
+++ b/cbits/arith_lists.h
@@ -2,38 +2,38 @@ LIST_BINOP(BO_ADD, 1, +)
LIST_BINOP(BO_SUB, 2, -)
LIST_BINOP(BO_MUL, 3, *)
-LIST_IBINOP(IB_QUOT, 1, quot)
-LIST_IBINOP(IB_REM, 2, rem)
+LIST_IBINOP(IB_QUOT, 11, quot)
+LIST_IBINOP(IB_REM, 12, rem)
-LIST_FBINOP(FB_DIV, 1, /)
-LIST_FBINOP(FB_POW, 2, **)
-LIST_FBINOP(FB_LOGBASE, 3, logBase)
-LIST_FBINOP(FB_ATAN2, 4, atan2)
+LIST_FBINOP(FB_DIV, 21, /)
+LIST_FBINOP(FB_POW, 22, **)
+LIST_FBINOP(FB_LOGBASE, 23, logBase)
+LIST_FBINOP(FB_ATAN2, 24, atan2)
-LIST_UNOP(UO_NEG, 1,)
-LIST_UNOP(UO_ABS, 2,)
-LIST_UNOP(UO_SIGNUM, 3,)
+LIST_UNOP(UO_NEG, 31,)
+LIST_UNOP(UO_ABS, 32,)
+LIST_UNOP(UO_SIGNUM, 33,)
-LIST_FUNOP(FU_RECIP, 1,)
-LIST_FUNOP(FU_EXP, 2,)
-LIST_FUNOP(FU_LOG, 3,)
-LIST_FUNOP(FU_SQRT, 4,)
-LIST_FUNOP(FU_SIN, 5,)
-LIST_FUNOP(FU_COS, 6,)
-LIST_FUNOP(FU_TAN, 7,)
-LIST_FUNOP(FU_ASIN, 8,)
-LIST_FUNOP(FU_ACOS, 9,)
-LIST_FUNOP(FU_ATAN, 10,)
-LIST_FUNOP(FU_SINH, 11,)
-LIST_FUNOP(FU_COSH, 12,)
-LIST_FUNOP(FU_TANH, 13,)
-LIST_FUNOP(FU_ASINH, 14,)
-LIST_FUNOP(FU_ACOSH, 15,)
-LIST_FUNOP(FU_ATANH, 16,)
-LIST_FUNOP(FU_LOG1P, 17,)
-LIST_FUNOP(FU_EXPM1, 18,)
-LIST_FUNOP(FU_LOG1PEXP, 19,)
-LIST_FUNOP(FU_LOG1MEXP, 20,)
+LIST_FUNOP(FU_RECIP, 41,)
+LIST_FUNOP(FU_EXP, 42,)
+LIST_FUNOP(FU_LOG, 43,)
+LIST_FUNOP(FU_SQRT, 44,)
+LIST_FUNOP(FU_SIN, 45,)
+LIST_FUNOP(FU_COS, 46,)
+LIST_FUNOP(FU_TAN, 47,)
+LIST_FUNOP(FU_ASIN, 48,)
+LIST_FUNOP(FU_ACOS, 49,)
+LIST_FUNOP(FU_ATAN, 50,)
+LIST_FUNOP(FU_SINH, 51,)
+LIST_FUNOP(FU_COSH, 52,)
+LIST_FUNOP(FU_TANH, 53,)
+LIST_FUNOP(FU_ASINH, 54,)
+LIST_FUNOP(FU_ACOSH, 55,)
+LIST_FUNOP(FU_ATANH, 56,)
+LIST_FUNOP(FU_LOG1P, 57,)
+LIST_FUNOP(FU_EXPM1, 58,)
+LIST_FUNOP(FU_LOG1PEXP, 59,)
+LIST_FUNOP(FU_LOG1MEXP, 60,)
-LIST_REDOP(RO_SUM, 1,)
-LIST_REDOP(RO_PRODUCT, 2,)
+LIST_REDOP(RO_SUM, 81,)
+LIST_REDOP(RO_PRODUCT, 82,)
diff --git a/gentrace.sh b/gentrace.sh
index 7be2b9c..c3f1240 100755
--- a/gentrace.sh
+++ b/gentrace.sh
@@ -8,7 +8,7 @@ module Data.Array.Nested.Trace (
-- * Re-exports from the plain "Data.Array.Nested" module
EOF
-sed -n '/^module/,/^) where/!d; /^\s*-- /d; s/ \b[a-z][a-zA-Z0-9_'"'"']*,//g; /^ $/d; s/(\.\., Z.., ([^)]*))/(..)/g; /^ /p; /^$/p' src/Data/Array/Nested.hs
+sed -n '/^module/,/^) where/!d; /^\s*--\( \|$\)/d; s/ \b[a-z][a-zA-Z0-9_'"'"']*,//g; /^ $/d; s/(\.\., Z.., ([^)]*))/(..)/g; /^ /p; /^$/p' src/Data/Array/Nested.hs
cat <<'EOF'
) where
diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs
index 5802573..2eb0666 100644
--- a/ops/Data/Array/Strided/Arith/Internal.hs
+++ b/ops/Data/Array/Strided/Arith/Internal.hs
@@ -27,7 +27,7 @@ import Foreign.C.Types
import Foreign.Ptr
import Foreign.Storable
import GHC.TypeLits
-import GHC.TypeNats qualified as TypeNats
+import GHC.TypeNats qualified as TN
import Language.Haskell.TH
import System.IO (hFlush, stdout)
import System.IO.Unsafe
@@ -42,7 +42,7 @@ import Data.Array.Strided.Array
-- TODO: move this to a utilities module
fromSNat' :: SNat n -> Int
-fromSNat' = fromIntegral . fromSNat
+fromSNat' = fromEnum . TN.fromSNat
data Dict c where
Dict :: c => Dict c
@@ -179,7 +179,7 @@ unreplicateStrides (Array sh strides offset vec) =
unrepSize = product [n | (n, True) <- zip sh replDims]
- in TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
+ in TN.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
Unreplicated (Array @lenshF shF stridesF offset vec) unrepSize (reinsertZeros replDims)
simplifyArray :: Array n a
@@ -200,7 +200,7 @@ simplifyArray :: Array n a
-> r)
-> r
simplifyArray array k
- | let revDims = map (<0) (arrStrides array)
+ | let revDims = map (< 0) (arrStrides array)
, Unreplicated array' unrepSize rereplicate <- unreplicateStrides (arrayRevDims revDims array)
= k array'
unrepSize
@@ -258,7 +258,7 @@ simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k
, let unrepSize = product [n | (n, True) <- zip sh replDims]
- = TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
+ = TN.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
k @lenshF
(Array shF strides1F offset1 vec1)
(Array shF strides2F offset2 vec2)
@@ -386,7 +386,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides off
VS.unsafeWith vec' $ \pv ->
let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a))
in fred (fromIntegral ndims') (ptrconv poutv) psh pstrides (ptrconv pv')
- TypeNats.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do
+ TN.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do
(Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
LTI -> pure Dict
EQI -> pure Dict
@@ -396,6 +396,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides off
Nothing -> error "impossible"
-- TODO: test handling of negative strides
+-- TODO: simplify away normalised dimensions
-- | Reduce full array
{-# NOINLINE vectorRedFullOp #-}
vectorRedFullOp :: forall a b n. (Num a, Storable a)
@@ -490,7 +491,7 @@ vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner
fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv)
pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1'))
pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2'))
- TypeNats.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do
+ TN.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do
(Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
LTI -> pure Dict
EQI -> pure Dict
@@ -714,6 +715,36 @@ class NumElt a where
numEltMaxIndex :: SNat n -> Array n a -> [Int]
numEltDotprodInner :: SNat n -> Array (n + 1) a -> Array (n + 1) a -> Array n a
+instance NumElt Int8 where
+ numEltAdd = addVectorInt8
+ numEltSub = subVectorInt8
+ numEltMul = mulVectorInt8
+ numEltNeg = negVectorInt8
+ numEltAbs = absVectorInt8
+ numEltSignum = signumVectorInt8
+ numEltSum1Inner = sum1VectorInt8
+ numEltProduct1Inner = product1VectorInt8
+ numEltSumFull = sumFullVectorInt8
+ numEltProductFull = productFullVectorInt8
+ numEltMinIndex _ = minindexVectorInt8
+ numEltMaxIndex _ = maxindexVectorInt8
+ numEltDotprodInner = dotprodinnerVectorInt8
+
+instance NumElt Int16 where
+ numEltAdd = addVectorInt16
+ numEltSub = subVectorInt16
+ numEltMul = mulVectorInt16
+ numEltNeg = negVectorInt16
+ numEltAbs = absVectorInt16
+ numEltSignum = signumVectorInt16
+ numEltSum1Inner = sum1VectorInt16
+ numEltProduct1Inner = product1VectorInt16
+ numEltSumFull = sumFullVectorInt16
+ numEltProductFull = productFullVectorInt16
+ numEltMinIndex _ = minindexVectorInt16
+ numEltMaxIndex _ = maxindexVectorInt16
+ numEltDotprodInner = dotprodinnerVectorInt16
+
instance NumElt Int32 where
numEltAdd = addVectorInt32
numEltSub = subVectorInt32
@@ -830,6 +861,14 @@ class NumElt a => IntElt a where
intEltQuot :: SNat n -> Array n a -> Array n a -> Array n a
intEltRem :: SNat n -> Array n a -> Array n a -> Array n a
+instance IntElt Int8 where
+ intEltQuot = quotVectorInt8
+ intEltRem = remVectorInt8
+
+instance IntElt Int16 where
+ intEltQuot = quotVectorInt16
+ intEltRem = remVectorInt16
+
instance IntElt Int32 where
intEltQuot = quotVectorInt32
intEltRem = remVectorInt32
@@ -840,19 +879,19 @@ instance IntElt Int64 where
instance IntElt Int where
intEltQuot = intWidBranch2 @Int quot
- (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
- (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+ (c_ibinary_i32_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vv_strided (aiboEnum IB_QUOT))
+ (c_ibinary_i64_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vv_strided (aiboEnum IB_QUOT))
intEltRem = intWidBranch2 @Int rem
- (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
- (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+ (c_ibinary_i32_sv_strided (aiboEnum IB_REM)) (c_ibinary_i32_vs_strided (aiboEnum IB_REM)) (c_ibinary_i32_vv_strided (aiboEnum IB_REM))
+ (c_ibinary_i64_sv_strided (aiboEnum IB_REM)) (c_ibinary_i64_vs_strided (aiboEnum IB_REM)) (c_ibinary_i64_vv_strided (aiboEnum IB_REM))
instance IntElt CInt where
intEltQuot = intWidBranch2 @CInt quot
- (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
- (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+ (c_ibinary_i32_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vv_strided (aiboEnum IB_QUOT))
+ (c_ibinary_i64_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vv_strided (aiboEnum IB_QUOT))
intEltRem = intWidBranch2 @CInt rem
- (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
- (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+ (c_ibinary_i32_sv_strided (aiboEnum IB_REM)) (c_ibinary_i32_vs_strided (aiboEnum IB_REM)) (c_ibinary_i32_vv_strided (aiboEnum IB_REM))
+ (c_ibinary_i64_sv_strided (aiboEnum IB_REM)) (c_ibinary_i64_vs_strided (aiboEnum IB_REM)) (c_ibinary_i64_vv_strided (aiboEnum IB_REM))
class NumElt a => FloatElt a where
floatEltDiv :: SNat n -> Array n a -> Array n a -> Array n a
diff --git a/ops/Data/Array/Strided/Arith/Internal/Lists.hs b/ops/Data/Array/Strided/Arith/Internal/Lists.hs
index 910a77c..27204d2 100644
--- a/ops/Data/Array/Strided/Arith/Internal/Lists.hs
+++ b/ops/Data/Array/Strided/Arith/Internal/Lists.hs
@@ -16,7 +16,9 @@ data ArithType = ArithType
intTypesList :: [ArithType]
intTypesList =
- [ArithType ''Int32 "i32"
+ [ArithType ''Int8 "i8"
+ ,ArithType ''Int16 "i16"
+ ,ArithType ''Int32 "i32"
,ArithType ''Int64 "i64"
]
diff --git a/ox-arrays.cabal b/ox-arrays.cabal
index d9a345b..83fe93f 100644
--- a/ox-arrays.cabal
+++ b/ox-arrays.cabal
@@ -5,7 +5,11 @@ synopsis: An efficient CPU-based multidimensional array (tensor) library
description:
An efficient and richly typed CPU-based multidimensional array (tensor)
library built upon the optimized tensor representation (strides list)
- implemented in the orthotope package.
+ implemented in the orthotope package. See the README.
+
+ If you use this package: let me know (e.g. via email) if you find it useful!
+ Both positive feedback (keep this!) and negative feedback (I needed this but
+ ox-arrays doesn't provide it) is welcome.
copyright: (c) 2025 Tom Smeding, Mikolaj Konarski
author: Tom Smeding, Mikolaj Konarski
maintainer: Tom Smeding <xhackage@tomsmeding.com>
@@ -13,6 +17,7 @@ license: BSD-3-Clause
category: Array, Tensors
build-type: Simple
+extra-doc-files: README.md CHANGELOG.md
extra-source-files: cbits/arith_lists.h
flag trace-wrappers
@@ -20,7 +25,7 @@ flag trace-wrappers
Compile modules that define wrappers around the array methods that trace
their arguments and results. This is conditional on a flag because these
modules make documentation generation fail.
- (https://gitlab.haskell.org/ghc/ghc/-/issues/24964 , should be fixed in
+ (@https://gitlab.haskell.org/ghc/ghc/-/issues/24964@ , should be fixed in
GHC 9.12)
default: False
manual: True
@@ -50,16 +55,23 @@ flag default-show-instances
default: False
manual: True
+common basics
+ default-language: Haskell2010
+ ghc-options: -Wall -Wcompat -Widentities -Wunused-packages -Wpartial-fields -Wredundant-bang-patterns -Woperator-whitespace -Wredundant-strictness-flags
+ if impl(ghc >= 9.14)
+ ghc-options: -Wno-pattern-namespace-specifier
+
library
+ import: basics
exposed-modules:
-- put this module on top so ghci considers it the "main" module
Data.Array.Nested
- Data.Array.Mixed.Lemmas
- Data.Array.Nested.Internal.Lemmas
Data.Array.Nested.Convert
Data.Array.Nested.Mixed
Data.Array.Nested.Mixed.Shape
+ Data.Array.Nested.Mixed.Shape.Internal
+ Data.Array.Nested.Lemmas
Data.Array.Nested.Permutation
Data.Array.Nested.Ranked
Data.Array.Nested.Ranked.Base
@@ -71,6 +83,11 @@ library
Data.Array.Strided.Orthotope
Data.Array.XArray
Data.Bag
+ Data.Vector.Generic.Checked
+
+ if impl(ghc < 9.8)
+ exposed-modules:
+ GHC.TypeLits.Orphans
if flag(trace-wrappers)
exposed-modules:
@@ -81,21 +98,22 @@ library
cpp-options: -DOXAR_DEFAULT_SHOW_INSTANCES
build-depends:
- strided-array-ops,
+ ox-arrays:strided-array-ops,
base,
deepseq < 1.7,
ghc-typelits-knownnat,
ghc-typelits-natnormalise,
- orthotope < 0.2,
- vector
+ orthotope >= 0.1.8.0 && < 0.2,
+ template-haskell,
+ vector,
+ vector-stream
hs-source-dirs: src
-
- default-language: Haskell2010
- ghc-options: -Wall -Wcompat -Widentities -Wunused-packages
other-extensions: TemplateHaskell
library strided-array-ops
+ import: basics
+ visibility: public
exposed-modules:
Data.Array.Strided
Data.Array.Strided.Array
@@ -105,9 +123,11 @@ library strided-array-ops
Data.Array.Strided.Arith.Internal.Lists
Data.Array.Strided.Arith.Internal.Lists.TH
build-depends:
- base >=4.18 && <4.22,
- ghc-typelits-knownnat < 1,
- ghc-typelits-natnormalise < 1,
+ base >=4.18 && <4.23,
+ ghc-typelits-knownnat >= 0.8.0 && < 1
+ -- 0.9.0 is unsound: https://github.com/clash-lang/ghc-typelits-natnormalise/issues/105
+ && (< 0.9.0 || > 0.9.0),
+ ghc-typelits-natnormalise >= 0.8.1 && < 1,
template-haskell < 3,
vector < 0.14
hs-source-dirs: ops
@@ -122,11 +142,10 @@ library strided-array-ops
-- hmatrix assumes sse2, so we can too
cc-options: -msse2
- default-language: Haskell2010
- ghc-options: -Wall -Wcompat -Widentities -Wunused-packages
other-extensions: TemplateHaskell
test-suite test
+ import: basics
type: exitcode-stdio-1.0
main-is: Main.hs
other-modules:
@@ -147,33 +166,29 @@ test-suite test
tasty-hedgehog,
vector
hs-source-dirs: test
- default-language: Haskell2010
- ghc-options: -Wall -Wcompat -Widentities -Wunused-packages
test-suite example
+ import: basics
type: exitcode-stdio-1.0
main-is: Main.hs
build-depends:
ox-arrays,
base
hs-source-dirs: example
- default-language: Haskell2010
- ghc-options: -Wall -Wcompat -Widentities -Wunused-packages
benchmark bench
+ import: basics
type: exitcode-stdio-1.0
main-is: Main.hs
build-depends:
ox-arrays,
- strided-array-ops,
+ ox-arrays:strided-array-ops,
base,
hmatrix,
orthotope,
tasty-bench,
vector
hs-source-dirs: bench
- default-language: Haskell2010
- ghc-options: -Wall -Wcompat -Widentities -Wunused-packages
source-repository head
type: git
diff --git a/release-hints.txt b/release-hints.txt
index 259c671..2623caa 100644
--- a/release-hints.txt
+++ b/release-hints.txt
@@ -1,2 +1,5 @@
- Temporarily enable -Wredundant-constraints
- Has too many false-positives to enable normally, but sometimes catches actual redundant constraints
+- Don't forget to rerun gentrace.sh
+- Test with GHC 9.6, it's rather picky around type-level nats
+ - Whenever we drop support for GHC 9.6, search for "9,8" and remove all the conditionals, as well as the GHC.TypeLits.Orphans module
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 9801529..c898a75 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -6,12 +6,17 @@ module Data.Array.Nested (
ListR(ZR, (:::)),
IxR(.., ZIR, (:.:)), IIxR,
ShR(.., ZSR, (:$:)), IShR,
- rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rsumOuter1, rsumAllPrim,
+ rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rgeneratePrim, rsumOuter1Prim, rsumAllPrim,
rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar,
remptyArray,
- rrerank,
- rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
- rfromListLinear, rfromListPrimLinear, rtoListLinear,
+ rrerankPrim,
+ rreplicate, rreplicatePrim,
+ rfromListOuter, rfromListOuterN,
+ rfromList1, rfromList1N,
+ rfromListLinear,
+ rfromList1Prim, rfromList1PrimN,
+ rfromListPrimLinear,
+ rtoListOuter, rtoList, rtoListLinear, rtoListPrim, rtoListPrimLinear,
rslice, rrev1, rreshape, rflatten, riota,
rminIndexPrim, rmaxIndexPrim, rdot1Inner, rdot,
rnest, runNest, rzip, runzip,
@@ -19,7 +24,7 @@ module Data.Array.Nested (
rlift, rlift2,
-- ** Conversions
rtoXArrayPrim, rfromXArrayPrim,
- rcastToShaped, rtoMixed, rcastToMixed,
+ rtoMixed, rcastToMixed, rcastToShaped,
rfromOrthotope, rtoOrthotope,
-- ** Additional arithmetic operations
--
@@ -31,13 +36,14 @@ module Data.Array.Nested (
ListS(ZS, (::$)),
IxS(.., ZIS, (:.$)), IIxS,
ShS(.., ZSS, (:$$)), KnownShS(..),
- sshape, srank, ssize, sindex, sindexPartial, sgenerate, ssumOuter1, ssumAllPrim,
+ sshape, srank, ssize, sindex, sindexPartial, sgenerate, sgeneratePrim, ssumOuter1Prim, ssumAllPrim,
stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,
-- TODO: sconcat? What should its type be?
semptyArray,
- srerank,
- sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,
- sfromListLinear, sfromListPrimLinear, stoListLinear,
+ srerankPrim,
+ sreplicate, sreplicatePrim,
+ sfromListOuter, sfromList1, sfromListLinear, sfromList1Prim, sfromListPrimLinear,
+ stoListOuter, stoList, stoListLinear, stoListPrim, stoListPrimLinear,
sslice, srev1, sreshape, sflatten, siota,
sminIndexPrim, smaxIndexPrim, sdot1Inner, sdot,
snest, sunNest, szip, sunzip,
@@ -45,7 +51,7 @@ module Data.Array.Nested (
slift, slift2,
-- ** Conversions
stoXArrayPrim, sfromXArrayPrim,
- stoRanked, stoMixed, scastToMixed,
+ stoMixed, scastToMixed, stoRanked,
sfromOrthotope, stoOrthotope,
-- ** Additional arithmetic operations
--
@@ -59,13 +65,18 @@ module Data.Array.Nested (
ShX(.., ZSX, (:$%)), KnownShX(..), IShX,
StaticShX(.., ZKX, (:!%)),
SMayNat(..),
- mshape, mrank, msize, mindex, mindexPartial, mgenerate, msumOuter1, msumAllPrim,
+ mshape, mrank, msize, mindex, mindexPartial, mgenerate, mgeneratePrim, msumOuter1Prim, msumAllPrim,
mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar,
memptyArray,
- mrerank,
- mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,
- mfromListLinear, mfromListPrimLinear, mtoListLinear,
- mslice, mrev1, mreshape, mflatten, miota,
+ mrerankPrim,
+ mreplicate, mreplicatePrim,
+ mfromListOuter, mfromListOuterN, mfromListOuterSN,
+ mfromList1, mfromList1N, mfromList1SN,
+ mfromListLinear,
+ mfromList1Prim, mfromList1PrimN, mfromList1PrimSN,
+ mfromListPrimLinear,
+ mtoListOuter, mtoList, mtoListLinear, mtoListPrim, mtoListPrimLinear,
+ msliceN, msliceSN, mslice, mrev1, mreshape, mflatten, miota,
mminIndexPrim, mmaxIndexPrim, mdot1Inner, mdot,
mnest, munNest, mzip, munzip,
-- ** Lifting orthotope operations to 'Mixed' arrays
@@ -73,8 +84,8 @@ module Data.Array.Nested (
-- ** Conversions
mtoXArrayPrim, mfromXArrayPrim,
mcast,
- mtoRanked, mcastToShaped,
- castCastable, Castable(..),
+ mcastToShaped, mtoRanked,
+ convert, Conversion(..),
-- ** Additional arithmetic operations
--
-- $integralRealFloat
@@ -91,7 +102,7 @@ module Data.Array.Nested (
Storable,
SNat, pattern SNat,
pattern SZ, pattern SS,
- Perm(..),
+ Perm(..), PermR,
IsPermutation,
KnownPerm(..),
NumElt, IntElt, FloatElt,
@@ -102,23 +113,23 @@ module Data.Array.Nested (
import Prelude hiding (mappend, mconcat)
-import Data.Array.Nested.Permutation
-import Data.Array.Nested.Types
import Data.Array.Nested.Convert
import Data.Array.Nested.Mixed
-import Data.Array.Nested.Ranked
-import Data.Array.Nested.Shaped
import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Ranked
import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Shaped
import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
import Data.Array.Strided.Arith
import Foreign.Storable
import GHC.TypeLits
-- $integralRealFloat
--
--- These functions separate top-level functions, and not exposed in instances
--- for 'RealFloat' and 'Integral', because those classes include a variety of
--- other functions that make no sense for arrays.
+-- These functions are separate top-level functions, and not exposed in
+-- instances for 'RealFloat' and 'Integral', because those classes include a
+-- variety of other functions that make no sense for arrays.
-- This problem already occurs with 'fromInteger', 'fromRational' and 'pi', but
-- having 'Num', 'Fractional' and 'Floating' available is just too useful.
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index d5e6008..8c88d23 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -1,42 +1,293 @@
+{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
{-# LANGUAGE TypeAbstractions #-}
+#endif
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module Data.Array.Nested.Convert (
- castCastable,
- Castable(..),
+ -- * Shape\/index\/list casting functions
+ -- ** To ranked
+ ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShX, shrFromShX2,
+ listrCast, ixrCast, shrCast,
+ -- ** To shaped
+ ixsFromIxR, ixsFromIxR', ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX,
+ ixsCast,
+ -- ** To mixed
+ ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS,
+ ixxCast, shxCast, shxCast',
- -- * Special cases
+ -- * Array conversions
+ convert,
+ Conversion(..),
+
+ -- * Special cases of array conversions
--
- -- | These functions can all be implemented using 'castCastable' in some way,
+ -- | These functions can all be implemented using 'convert' in some way,
-- but some have fewer constraints.
rtoMixed, rcastToMixed, rcastToShaped,
stoMixed, scastToMixed, stoRanked,
mcast, mcastToShaped, mtoRanked,
-
- -- * Additional index/shape casting functions
- ixrFromIxS, shrFromShS,
) where
import Control.Category
import Data.Proxy
import Data.Type.Equality
+import GHC.TypeLits
-import Data.Array.Mixed.Lemmas
-import Data.Array.Nested.Types
-import Data.Array.Nested.Internal.Lemmas
+import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked.Base
import Data.Array.Nested.Ranked.Shape
import Data.Array.Nested.Shaped.Base
import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
+
+-- * Shape or index or list casting functions
+
+-- * To ranked
+
+ixrFromIxS :: IxS sh i -> IxR (Rank sh) i
+ixrFromIxS ZIS = ZIR
+ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix
+
+ixrFromIxX :: IxX sh i -> IxR (Rank sh) i
+ixrFromIxX ZIX = ZIR
+ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx
+
+shrFromShS :: ShS sh -> IShR (Rank sh)
+shrFromShS ZSS = ZSR
+shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh
+
+-- shrFromShX re-exported
+-- shrFromShX2 re-exported
+-- listrCast re-exported
+-- ixrCast re-exported
+-- shrCast re-exported
+
+-- * To shaped
+
+-- TODO: these take a ShS because there are KnownNats inside IxS.
+
+ixsFromIxR :: ShS sh -> IxR (Rank sh) i -> IxS sh i
+ixsFromIxR ZSS ZIR = ZIS
+ixsFromIxR (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR sh idx
+
+-- | Performs a runtime check that @n@ matches @Rank sh@. Equivalent to the
+-- following, but more efficient:
+--
+-- > ixsFromIxR' sh idx = ixsFromIxR sh (ixrCast (shsRank sh) idx)
+ixsFromIxR' :: ShS sh -> IxR n i -> IxS sh i
+ixsFromIxR' ZSS ZIR = ZIS
+ixsFromIxR' (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR' sh idx
+ixsFromIxR' _ _ = error "ixsFromIxR': index rank does not match shape rank"
+
+-- TODO: this takes a ShS because there are KnownNats inside IxS.
+ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i
+ixsFromIxX ZSS ZIX = ZIS
+ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx
+
+-- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to
+-- the following, but more efficient:
+--
+-- > ixsFromIxX' sh idx = ixsFromIxX sh (ixxCast (shxFromShS sh) idx)
+ixsFromIxX' :: ShS sh -> IxX sh' i -> IxS sh i
+ixsFromIxX' ZSS ZIX = ZIS
+ixsFromIxX' (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX' sh idx
+ixsFromIxX' _ _ = error "ixsFromIxX': index rank does not match shape rank"
+
+-- | Produce an existential 'ShS' from an 'IShR'.
+withShsFromShR :: IShR n -> (forall sh. Rank sh ~ n => ShS sh -> r) -> r
+withShsFromShR ZSR k = k ZSS
+withShsFromShR (n :$: sh) k =
+ withShsFromShR sh $ \sh' ->
+ withSomeSNat (fromIntegral @Int @Integer n) $ \case
+ Just sn@SNat -> k (sn :$$ sh')
+ Nothing -> error $ "withShsFromShR: negative dimension size (" ++ show n ++ ")"
+
+-- shsFromShX re-exported
+
+-- | Produce an existential 'ShS' from an 'IShX'. If you already know that
+-- @sh'@ is @MapJust@ of something, use 'shsFromShX' instead.
+withShsFromShX :: IShX sh' -> (forall sh. Rank sh ~ Rank sh' => ShS sh -> r) -> r
+withShsFromShX ZSX k = k ZSS
+withShsFromShX (SKnown sn@SNat :$% sh) k =
+ withShsFromShX sh $ \sh' ->
+ k (sn :$$ sh')
+withShsFromShX (SUnknown n :$% sh) k =
+ withShsFromShX sh $ \sh' ->
+ withSomeSNat (fromIntegral @Int @Integer n) $ \case
+ Just sn@SNat -> k (sn :$$ sh')
+ Nothing -> error $ "withShsFromShX: negative SUnknown dimension size (" ++ show n ++ ")"
+
+shsFromSSX :: StaticShX (MapJust sh) -> ShS sh
+shsFromSSX = shsFromShX Prelude.. shxFromSSX
+
+-- ixsCast re-exported
+
+-- * To mixed
+
+ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i
+ixxFromIxR ZIR = ZIX
+ixxFromIxR (n :.: (idx :: IxR m i)) =
+ castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) (Proxy @m)))
+ (n :.% ixxFromIxR idx)
+
+ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i
+ixxFromIxS ZIS = ZIX
+ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh
+
+shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i
+shxFromShR ZSR = ZSX
+shxFromShR (n :$: (idx :: ShR m i)) =
+ castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) (Proxy @m)))
+ (SUnknown n :$% shxFromShR idx)
+
+shxFromShS :: ShS sh -> IShX (MapJust sh)
+shxFromShS ZSS = ZSX
+shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh
+
+-- ixxCast re-exported
+-- shxCast re-exported
+-- shxCast' re-exported
+
+
+-- * Array conversions
+
+-- | The constructors that perform runtime shape checking are marked with a
+-- tick (@'@): 'ConvXS'' and 'ConvXX''. For the other constructors, the types
+-- ensure that the shapes are already compatible. To convert between 'Ranked'
+-- and 'Shaped', go via 'Mixed'.
+--
+-- The guiding principle behind 'Conversion' is that it should represent the
+-- array restructurings, or perhaps re-presentations, that do not change the
+-- underlying 'XArray's. This leads to the inclusion of some operations that do
+-- not look like simple conversions (casts) at first glance, like 'ConvZip'.
+--
+-- /Note/: Haddock gleefully renames type variables in constructors so that
+-- they match the data type head as much as possible. See the source for a more
+-- readable presentation of this data type.
+data Conversion a b where
+ ConvId :: Conversion a a
+ ConvCmp :: Conversion b c -> Conversion a b -> Conversion a c
+
+ ConvRX :: Conversion (Ranked n a) (Mixed (Replicate n Nothing) a)
+ ConvSX :: Conversion (Shaped sh a) (Mixed (MapJust sh) a)
+
+ ConvXR :: Elt a
+ => Conversion (Mixed sh a) (Ranked (Rank sh) a)
+ ConvXS :: Conversion (Mixed (MapJust sh) a) (Shaped sh a)
+ ConvXS' :: (Rank sh ~ Rank sh', Elt a)
+ => ShS sh'
+ -> Conversion (Mixed sh a) (Shaped sh' a)
+
+ ConvXX' :: (Rank sh ~ Rank sh', Elt a)
+ => StaticShX sh'
+ -> Conversion (Mixed sh a) (Mixed sh' a)
+
+ ConvRR :: Conversion a b
+ -> Conversion (Ranked n a) (Ranked n b)
+ ConvSS :: Conversion a b
+ -> Conversion (Shaped sh a) (Shaped sh b)
+ ConvXX :: Conversion a b
+ -> Conversion (Mixed sh a) (Mixed sh b)
+ ConvT2 :: Conversion a a'
+ -> Conversion b b'
+ -> Conversion (a, b) (a', b')
+
+ Conv0X :: Elt a
+ => Conversion a (Mixed '[] a)
+ ConvX0 :: Conversion (Mixed '[] a) a
+
+ ConvNest :: Elt a => StaticShX sh
+ -> Conversion (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a))
+ ConvUnnest :: Conversion (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a)
+
+ ConvZip :: (Elt a, Elt b)
+ => Conversion (Mixed sh a, Mixed sh b) (Mixed sh (a, b))
+ ConvUnzip :: (Elt a, Elt b)
+ => Conversion (Mixed sh (a, b)) (Mixed sh a, Mixed sh b)
+deriving instance Show (Conversion a b)
+
+instance Category Conversion where
+ id = ConvId
+ (.) = ConvCmp
+
+convert :: (Elt a, Elt b) => Conversion a b -> a -> b
+convert = \c x -> munScalar (go c (mscalar x))
+ where
+ -- The 'esh' is the extension shape: the conversion happens under a whole
+ -- bunch of additional dimensions that it does not touch. These dimensions
+ -- are 'esh'.
+ -- The strategy is to unwind step-by-step to a large Mixed array, and to
+ -- perform the required checks and conversions when re-nesting back up.
+ go :: Conversion a b -> Mixed esh a -> Mixed esh b
+ go ConvId x = x
+ go (ConvCmp c1 c2) x = go c1 (go c2 x)
+ go ConvRX (M_Ranked x) = x
+ go ConvSX (M_Shaped x) = x
+ go (ConvXR @_ @sh) (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)
+ = let ssx' = ssxAppend (ssxFromShX esh)
+ (ssxReplicate (shxRank (shxDropSSX @esh @sh (ssxFromShX esh) (mshape x))))
+ in M_Ranked (M_Nest esh (mcast ssx' x))
+ go ConvXS (M_Nest esh x) = M_Shaped (M_Nest esh x)
+ go (ConvXS' @sh @sh' sh') (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh')))
+ x))
+ go (ConvXX' @sh @sh' ssx) (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x
+ go (ConvRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
+ go (ConvSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
+ go (ConvXX c) (M_Nest esh x) = M_Nest esh (go c x)
+ go (ConvT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2)
+ go Conv0X (x :: Mixed esh a)
+ | Refl <- lemAppNil @esh
+ = M_Nest (mshape x) x
+ go ConvX0 (M_Nest @esh _ x)
+ | Refl <- lemAppNil @esh
+ = x
+ go (ConvNest @_ @sh @sh' ssh) (M_Nest @esh esh x)
+ | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (ssxFromShX esh `ssxAppend` ssh) (mshape x)) x)
+ go (ConvUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x))
+ | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh x
+ go ConvZip x =
+ -- no need to check that the two esh's are equal because they were zipped previously
+ let (M_Nest esh x1, M_Nest _ x2) = munzip x
+ in M_Nest esh (mzip x1 x2)
+ go ConvUnzip (M_Nest esh x) =
+ let (x1, x2) = munzip x
+ in mzip (M_Nest esh x1) (M_Nest esh x2)
+
+ lemRankAppRankEq :: Rank sh ~ Rank sh'
+ => Proxy esh -> Proxy sh -> Proxy sh'
+ -> Rank (esh ++ sh) :~: Rank (esh ++ sh')
+ lemRankAppRankEq _ _ _ = unsafeCoerceRefl
+
+ lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh
+ -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing)
+ lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl
+
+ lemRankAppRankEqMapJust :: Rank sh ~ Rank sh'
+ => Proxy esh -> Proxy sh -> Proxy sh'
+ -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
+ lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl
+-- * Special cases of array conversions
+
mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)
=> StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a
mcast ssh2 arr
@@ -45,7 +296,7 @@ mcast ssh2 arr
= mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr
mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
-mtoRanked = castCastable (CastXR CastId)
+mtoRanked = convert ConvXR
rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
rtoMixed (Ranked arr) = arr
@@ -59,7 +310,7 @@ rcastToMixed sshx rarr@(Ranked arr)
mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
=> ShS sh' -> Mixed sh a -> Shaped sh' a
-mcastToShaped targetsh = castCastable (CastXS' targetsh CastId)
+mcastToShaped targetsh = convert (ConvXS' targetsh)
stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
stoMixed (Shaped arr) = arr
@@ -82,91 +333,3 @@ rcastToShaped (Ranked arr) targetsh
| Refl <- lemRankReplicate (shxRank (shxFromShS targetsh))
, Refl <- lemRankMapJust targetsh
= mcastToShaped targetsh arr
-
-ixrFromIxS :: IxS sh i -> IxR (Rank sh) i
-ixrFromIxS ZIS = ZIR
-ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix
-
--- ixsFromIxR :: IIxR (Rank sh) -> IIxS sh
--- ixsFromIxR = \ix -> go ix _
--- where
--- go :: IIxR n -> (forall sh. KnownShS sh => IIxS sh -> r) -> r
--- go ZIR k = k ZIS
--- go (i :.: ix) k = go ix (i :.$)
-
-shrFromShS :: ShS sh -> IShR (Rank sh)
-shrFromShS ZSS = ZSR
-shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh
-
--- | The constructors that perform runtime shape checking are marked with a
--- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure
--- that the shapes are already compatible. To convert between 'Ranked' and
--- 'Shaped', go via 'Mixed'.
-data Castable a b where
- CastId :: Castable a a
- CastCmp :: Castable b c -> Castable a b -> Castable a c
-
- CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b)
- CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b)
-
- CastXR :: Elt b
- => Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b)
- CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)
- CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh'
- -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b)
-
- CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b)
- CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b)
- CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b)
-
- CastXX' :: (Rank sh ~ Rank sh', Elt b) => StaticShX sh'
- -> Castable a b -> Castable (Mixed sh a) (Mixed sh' b)
-
-instance Category Castable where
- id = CastId
- (.) = CastCmp
-
-castCastable :: (Elt a, Elt b) => Castable a b -> a -> b
-castCastable = \c x -> munScalar (go c (mscalar x))
- where
- -- The 'esh' is the extension shape: the casting happens under a whole
- -- bunch of additional dimensions that it does not touch. These dimensions
- -- are 'esh'.
- -- The strategy is to unwind step-by-step to a large Mixed array, and to
- -- perform the required checks and castings when re-nesting back up.
- go :: Castable a b -> Mixed esh a -> Mixed esh b
- go CastId x = x
- go (CastCmp c1 c2) x = go c1 (go c2 x)
- go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x)
- go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x)
- go (CastXR @_ @_ @sh c) (M_Nest @esh esh x)
- | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)
- = let x' = go c x
- ssx' = ssxAppend (ssxFromShX esh)
- (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShX esh))))
- in M_Ranked (M_Nest esh (mcast ssx' x'))
- go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x))
- go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x)
- | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
- = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh')))
- (go c x)))
- go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
- go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
- go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
- go (CastXX' @sh @sh' ssx c) (M_Nest @esh esh x)
- | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')
- = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) (go c x)
-
- lemRankAppRankEq :: Rank sh ~ Rank sh'
- => Proxy esh -> Proxy sh -> Proxy sh'
- -> Rank (esh ++ sh) :~: Rank (esh ++ sh')
- lemRankAppRankEq _ _ _ = unsafeCoerceRefl
-
- lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh
- -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing)
- lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl
-
- lemRankAppRankEqMapJust :: Rank sh ~ Rank sh'
- => Proxy esh -> Proxy sh -> Proxy sh'
- -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
- lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl
diff --git a/src/Data/Array/Nested/Internal/Lemmas.hs b/src/Data/Array/Nested/Internal/Lemmas.hs
deleted file mode 100644
index b1589e0..0000000
--- a/src/Data/Array/Nested/Internal/Lemmas.hs
+++ /dev/null
@@ -1,59 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeOperators #-}
-module Data.Array.Nested.Internal.Lemmas where
-
-import Data.Proxy
-import Data.Type.Equality
-import GHC.TypeLits
-
-import Data.Array.Mixed.Lemmas
-import Data.Array.Nested.Permutation
-import Data.Array.Nested.Types
-import Data.Array.Nested.Mixed.Shape
-import Data.Array.Nested.Shaped.Shape
-
-
-lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh
-lemRankMapJust ZSS = Refl
-lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl
-
-lemMapJustApp :: ShS sh1 -> Proxy sh2
- -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
-lemMapJustApp ZSS _ = Refl
-lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl
-
-lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)
-lemTakeLenMapJust PNil _ = Refl
-lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl
-lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty"
-
-lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh)
-lemDropLenMapJust PNil _ = Refl
-lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl
-lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty"
-
-lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh)
-lemIndexMapJust SZ (_ :$$ _) = Refl
-lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh'))
- | Refl <- lemIndexMapJust i sh
- , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh'))
- , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = Refl
-lemIndexMapJust _ ZSS = error "Index of empty"
-
-lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh)
-lemPermuteMapJust PNil _ = Refl
-lemPermuteMapJust (i `PCons` is) sh
- | Refl <- lemPermuteMapJust is sh
- , Refl <- lemIndexMapJust i sh
- = Refl
-
-lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh)
-lemKnownMapJust _ = lemKnownShX (go (knownShS @sh))
- where
- go :: ShS sh' -> StaticShX (MapJust sh')
- go ZSS = ZKX
- go (n :$$ sh) = SKnown n :!% go sh
diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs
index e6d970c..e089479 100644
--- a/src/Data/Array/Mixed/Lemmas.hs
+++ b/src/Data/Array/Nested/Lemmas.hs
@@ -6,7 +6,7 @@
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed.Lemmas where
+module Data.Array.Nested.Lemmas where
import Data.Proxy
import Data.Type.Equality
@@ -14,10 +14,11 @@ import GHC.TypeLits
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation
+import Data.Array.Nested.Shaped.Shape
import Data.Array.Nested.Types
--- * Lemmas
+-- * Lemmas about numbers and lists
-- ** Nat
@@ -27,7 +28,6 @@ lemLeqSuccSucc _ _ = unsafeCoerceRefl
lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True
lemLeqPlus _ _ _ = Refl
-
-- ** Append
lemAppNil :: l ++ '[] :~: l
@@ -39,42 +39,22 @@ lemAppAssoc _ _ _ = unsafeCoerceRefl
lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l
lemAppLeft _ Refl = Refl
-
--- ** Rank
-
-lemRankApp :: forall sh1 sh2.
- StaticShX sh1 -> StaticShX sh2
- -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
-lemRankApp ZKX _ = Refl
-lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2
- = lem (Proxy @(Rank sh1T)) Proxy Proxy $
- sym (lemRankApp ssh1 ssh2)
- where
- lem :: proxy a -> proxy b -> proxy c
- -> (a + b :~: c)
- -> c + 1 :~: (a + 1 + b)
- lem _ _ _ Refl = Refl
-
-lemRankAppComm :: proxy sh1 -> proxy sh2
- -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
-lemRankAppComm _ _ = unsafeCoerceRefl
-
-lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n
-lemRankReplicate _ = unsafeCoerceRefl
-
-
--- ** Various type families
+-- ** Simple type families
lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a
-> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
+{- for now, the plugins can't derive a type for this code, see
+ https://github.com/clash-lang/ghc-typelits-natnormalise/pull/98#issuecomment-3332842214
lemReplicatePlusApp sn _ _ = go sn
where
go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
go SZ = Refl
go (SS (n :: SNat n'm1))
- | Refl <- lemReplicateSucc @a @n'm1
+ | Refl <- lemReplicateSucc @a n
, Refl <- go n
- = sym (lemReplicateSucc @a @(n'm1 + m))
+ = sym (lemReplicateSucc @a (SNat @(n'm1 + m)))
+-}
+lemReplicatePlusApp _ _ _ = unsafeCoerceRefl
lemDropLenApp :: Rank l1 <= Rank l2
=> Proxy l1 -> Proxy l2 -> Proxy rest
@@ -107,6 +87,8 @@ lemKnownNatRankSSX ZKX = Dict
lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
+-- * Lemmas about shapes
+
-- ** Known shapes
lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing)
@@ -116,3 +98,69 @@ lemKnownShX :: StaticShX sh -> Dict KnownShX sh
lemKnownShX ZKX = Dict
lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict
lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict
+
+lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh)
+lemKnownMapJust _ = lemKnownShX (go (knownShS @sh))
+ where
+ go :: ShS sh' -> StaticShX (MapJust sh')
+ go ZSS = ZKX
+ go (n :$$ sh) = SKnown n :!% go sh
+
+-- ** Rank
+
+lemRankApp :: forall sh1 sh2.
+ StaticShX sh1 -> StaticShX sh2
+ -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
+lemRankApp ZKX _ = Refl
+lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2
+ = lem (Proxy @(Rank sh1T)) Proxy Proxy $
+ sym (lemRankApp ssh1 ssh2)
+ where
+ lem :: proxy a -> proxy b -> proxy c
+ -> (a + b :~: c)
+ -> c + 1 :~: (a + 1 + b)
+ lem _ _ _ Refl = Refl
+
+lemRankAppComm :: proxy sh1 -> proxy sh2
+ -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
+lemRankAppComm _ _ = unsafeCoerceRefl
+
+lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n
+lemRankReplicate _ = unsafeCoerceRefl
+
+lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh
+lemRankMapJust ZSS = Refl
+lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl
+
+-- ** Related to MapJust and/or Permutation
+
+lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)
+lemTakeLenMapJust PNil _ = Refl
+lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl
+lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty"
+
+lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh)
+lemDropLenMapJust PNil _ = Refl
+lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl
+lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty"
+
+lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh)
+lemIndexMapJust SZ (_ :$$ _) = Refl
+lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh'))
+ | Refl <- lemIndexMapJust i sh
+ , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh'))
+ , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = Refl
+lemIndexMapJust _ ZSS = error "Index of empty"
+
+lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh)
+lemPermuteMapJust PNil _ = Refl
+lemPermuteMapJust (i `PCons` is) sh
+ | Refl <- lemPermuteMapJust is sh
+ , Refl <- lemIndexMapJust i sh
+ = Refl
+
+lemMapJustApp :: ShS sh1 -> Proxy sh2
+ -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
+lemMapJustApp ZSS _ = Refl
+lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 54bd5f2..0766e8c 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -7,6 +7,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -22,12 +23,14 @@ module Data.Array.Nested.Mixed where
import Prelude hiding (mconcat)
import Control.DeepSeq (NFData(..))
-import Control.Monad (forM_, when)
+import Control.Monad (foldM_, forM_, when)
import Control.Monad.ST
+import Data.Array.Internal qualified as OI
+import Data.Array.Internal.RankedG qualified as ORG
+import Data.Array.Internal.RankedS qualified as ORS
import Data.Array.RankedS qualified as S
import Data.Bifunctor (bimap)
import Data.Coerce
-import Data.Foldable (toList)
import Data.Int
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty(..))
@@ -38,17 +41,18 @@ import Data.Vector.Storable qualified as VS
import Data.Vector.Storable.Mutable qualified as VSM
import Foreign.C.Types (CInt)
import Foreign.Storable (Storable)
+import Foreign.Storable qualified as Storable
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import GHC.TypeLits
-import Data.Array.Mixed.Lemmas
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
+import Data.Array.Strided.Orthotope
import Data.Array.XArray (XArray(..))
import Data.Array.XArray qualified as X
-import Data.Array.Nested.Mixed.Shape
-import Data.Array.Strided.Orthotope
import Data.Bag
@@ -91,6 +95,9 @@ import Data.Bag
-- Unfortunately, the setup of the library requires us to list these primitive
-- element types multiple times; to aid in extending the list, all these lists
-- have been marked with [PRIMITIVE ELEMENT TYPES LIST].
+--
+-- NOTE: if you add PRIMITIVE types, be sure to also add NumElt and IntElt
+-- instances for them!
-- | Wrapper type used as a tag to attach instances on. The instances on arrays
@@ -118,6 +125,8 @@ instance PrimElt Bool
instance PrimElt Int
instance PrimElt Int64
instance PrimElt Int32
+instance PrimElt Int16
+instance PrimElt Int8
instance PrimElt CInt
instance PrimElt Float
instance PrimElt Double
@@ -154,6 +163,8 @@ newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq
newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic ANDSHOW)
newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic ANDSHOW)
newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Int16 = M_Int16 (Mixed sh (Primitive Int16)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Int8 = M_Int8 (Mixed sh (Primitive Int8)) deriving (Eq, Ord, Generic ANDSHOW)
newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic ANDSHOW)
newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic ANDSHOW)
newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic ANDSHOW)
@@ -190,6 +201,8 @@ newtype instance MixedVecs s sh Bool = MV_Bool (VS.MVector s Bool)
newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int)
newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64)
newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32)
+newtype instance MixedVecs s sh Int16 = MV_Int16 (VS.MVector s Int16)
+newtype instance MixedVecs s sh Int8 = MV_Int8 (VS.MVector s Int8)
newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt)
newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double)
newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float)
@@ -227,11 +240,13 @@ instance Elt a => NFData (Mixed sh a) where
rnf = mrnf
+{-# INLINE mliftNumElt1 #-}
mliftNumElt1 :: (PrimElt a, PrimElt b)
=> (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b)
-> Mixed sh a -> Mixed sh b
mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr))
+{-# INLINE mliftNumElt2 #-}
mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c)
=> (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c)
-> Mixed sh a -> Mixed sh b -> Mixed sh c
@@ -247,15 +262,15 @@ instance (NumElt a, PrimElt a) => Num (Mixed sh a) where
abs = mliftNumElt1 (liftO1 . numEltAbs)
signum = mliftNumElt1 (liftO1 . numEltSignum)
-- TODO: THIS IS BAD, WE NEED TO REMOVE THIS
- fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal"
+ fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicatePrim"
instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where
- fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"
+ fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicatePrim"
recip = mliftNumElt1 (liftO1 . floatEltRecip)
(/) = mliftNumElt2 (liftO2 . floatEltDiv)
instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where
- pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicatePrim"
exp = mliftNumElt1 (liftO1 . floatEltExp)
log = mliftNumElt1 (liftO1 . floatEltLog)
sqrt = mliftNumElt1 (liftO1 . floatEltSqrt)
@@ -298,15 +313,9 @@ class Elt a where
mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a
mscalar :: a -> Mixed '[] a
- -- | All arrays in the list, even subarrays inside @a@, must have the same
- -- shape; if they do not, a runtime error will be thrown. See the
- -- documentation of 'mgenerate' for more information about this restriction.
- -- Furthermore, the length of the list must correspond with @n@: if @n@ is
- -- @Just m@ and @m@ does not equal the length of the list, a runtime error is
- -- thrown.
- --
- -- Consider also 'mfromListPrim', which can avoid intermediate arrays.
- mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+ -- | See 'mfromListOuter'. If the list does not have the given length, a
+ -- runtime error is thrown. 'mfromListPrimSN' is faster if applicable.
+ mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a
mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]
@@ -351,11 +360,14 @@ class Elt a where
-- | Tree giving the shape of every array component.
type ShapeTree a
+ -- | Produces an internal representation of a tree of shapes of (potentially)
+ -- nested arrays. If the argument is an array, it requires that the array
+ -- is not empty (otherwise, its' guaranteed to crash early, if non-trivial).
mshapeTree :: a -> ShapeTree a
mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool
- mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool
+ mshapeTreeIsEmpty :: Proxy a -> ShapeTree a -> Bool
mshowShapeTree :: Proxy a -> ShapeTree a -> String
@@ -363,26 +375,28 @@ class Elt a where
-- this mixed array.
marrayStrides :: Mixed sh a -> Bag [Int]
- -- | Given the shape of this array, an index and a value, write the value at
+ -- | Given a linear index and a value, write the value at
-- that index in the vectors.
- mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+ mvecsWriteLinear :: Int -> a -> MixedVecs s sh a -> ST s ()
- -- | Given the shape of this array, an index and a value, write the value at
+ -- | Given a linear index and a value, write the value at
-- that index in the vectors.
- mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+ mvecsWritePartialLinear :: Proxy sh -> Int -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
-- | Given the shape of this array, finalise the vectors into 'XArray's.
mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
+ -- | 'mvecsFreeze' but without copying the mutable vectors before freezing
+ -- them. If the mutable vectors are mutated after this function, referential
+ -- transparency is broken.
+ mvecsUnsafeFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
-- | Element types for which we have evidence of the (static part of the) shape
-- in a type class constraint. Compare the instance contexts of the instances
-- of this class with those of 'Elt': some instances have an additional
-- "known-shape" constraint.
--
--- This class is (currently) only required for 'mgenerate',
--- 'Data.Array.Nested.Ranked.rgenerate' and
--- 'Data.Array.Nested.Shaped.sgenerate'.
+-- This class is (currently) only required for `memptyArray` and 'mgenerate'.
class Elt a => KnownElt a where
-- | Create an empty array. The given shape must have size zero; this may or may not be checked.
memptyArrayUnsafe :: IShX sh -> Mixed sh a
@@ -391,20 +405,27 @@ class Elt a => KnownElt a where
-- this vector and an example for the contents.
mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a)
+ -- | Create initialised vectors for this array type, given the shape of
+ -- this vector and the chosen element.
+ mvecsReplicate :: IShX sh -> a -> ST s (MixedVecs s sh a)
+
mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
-- Arrays of scalars are basically just arrays of scalars.
instance Storable a => Elt (Primitive a) where
mshape (M_Primitive sh _) = sh
+ {-# INLINEABLE mindex #-}
mindex (M_Primitive _ a) i = Primitive (X.index a i)
- mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i)
+ {-# INLINEABLE mindexPartial #-}
+ mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i)
mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
- mfromListOuter l@(arr1 :| _) =
- let sh = SUnknown (length l) :$% mshape arr1
- in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l)))
+ mfromListOuterSN sn l@(arr1 :| _) =
+ let sh = mshape arr1
+ in M_Primitive (SKnown sn :$% sh) (X.fromListOuterSN sn sh ((\(M_Primitive _ a) -> a) <$> l))
mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
+ {-# INLINE mlift #-}
mlift :: forall sh1 sh2.
StaticShX sh2
-> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
@@ -415,6 +436,7 @@ instance Storable a => Elt (Primitive a) where
, let result = f ZKX a
= M_Primitive (X.shape ssh2 result) result
+ {-# INLINE mlift2 #-}
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)
@@ -426,6 +448,7 @@ instance Storable a => Elt (Primitive a) where
, let result = f ZKX a b
= M_Primitive (X.shape ssh3 result) result
+ {-# INLINE mliftL #-}
mliftL :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
@@ -440,7 +463,7 @@ instance Storable a => Elt (Primitive a) where
=> StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) =
let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
- sh2 = shxCast' sh1 ssh2
+ sh2 = shxCast' ssh2 sh1
in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr)
mtranspose perm (M_Primitive sh arr) =
@@ -457,27 +480,33 @@ instance Storable a => Elt (Primitive a) where
type ShapeTree (Primitive a) = ()
mshapeTree _ = ()
mshapeTreeEq _ () () = True
- mshapeTreeEmpty _ () = False
+ mshapeTreeIsEmpty _ () = False
mshowShapeTree _ () = "()"
marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr)
- mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
+ mvecsWriteLinear i (Primitive x) (MV_Primitive v) = VSM.write v i x
- -- TODO: this use of toVector is suboptimal
- mvecsWritePartial
+ -- TODO: this use of toVectorListT is suboptimal
+ mvecsWritePartialLinear
:: forall sh' sh s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
- mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
+ Proxy sh -> Int -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
+ mvecsWritePartialLinear _ i (M_Primitive sh' arr@(XArray (ORS.A (ORG.A sht t)))) (MV_Primitive v) = do
let arrsh = X.shape (ssxFromShX sh') arr
- offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
- VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
+ offset = i * shxSize arrsh
+ f off el = do
+ VS.copy (VSM.slice off (VS.length el) v) el
+ return $! off + VS.length el
+ foldM_ f offset (OI.toVectorListT sht t)
mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v
+ mvecsUnsafeFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.unsafeFreeze v
-- [PRIMITIVE ELEMENT TYPES LIST]
deriving via Primitive Bool instance Elt Bool
deriving via Primitive Int instance Elt Int
deriving via Primitive Int64 instance Elt Int64
deriving via Primitive Int32 instance Elt Int32
+deriving via Primitive Int16 instance Elt Int16
+deriving via Primitive Int8 instance Elt Int8
deriving via Primitive CInt instance Elt CInt
deriving via Primitive Double instance Elt Double
deriving via Primitive Float instance Elt Float
@@ -486,6 +515,7 @@ deriving via Primitive () instance Elt ()
instance Storable a => KnownElt (Primitive a) where
memptyArrayUnsafe sh = M_Primitive sh (X.empty sh)
mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh)
+ mvecsReplicate sh (Primitive a) = MV_Primitive <$> VSM.replicate (shxSize sh) a
mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
-- [PRIMITIVE ELEMENT TYPES LIST]
@@ -493,6 +523,8 @@ deriving via Primitive Bool instance KnownElt Bool
deriving via Primitive Int instance KnownElt Int
deriving via Primitive Int64 instance KnownElt Int64
deriving via Primitive Int32 instance KnownElt Int32
+deriving via Primitive Int16 instance KnownElt Int16
+deriving via Primitive Int8 instance KnownElt Int8
deriving via Primitive CInt instance KnownElt CInt
deriving via Primitive Double instance KnownElt Double
deriving via Primitive Float instance KnownElt Float
@@ -504,12 +536,15 @@ instance (Elt a, Elt b) => Elt (a, b) where
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
- mfromListOuter l =
- M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l))
- (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l))
+ mfromListOuterSN sn l =
+ M_Tup2 (mfromListOuterSN sn ((\(M_Tup2 x _) -> x) <$> l))
+ (mfromListOuterSN sn ((\(M_Tup2 _ y) -> y) <$> l))
mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)
+ {-# INLINE mlift #-}
mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
+ {-# INLINE mlift2 #-}
mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
+ {-# INLINE mliftL #-}
mliftL ssh2 f =
let unzipT2l [] = ([], [])
unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2)
@@ -531,20 +566,22 @@ instance (Elt a, Elt b) => Elt (a, b) where
type ShapeTree (a, b) = (ShapeTree a, ShapeTree b)
mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
- mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
+ mshapeTreeIsEmpty _ (t1, t2) = mshapeTreeIsEmpty (Proxy @a) t1 && mshapeTreeIsEmpty (Proxy @b) t2
mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b
- mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
- mvecsWrite sh i x a
- mvecsWrite sh i y b
- mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
- mvecsWritePartial sh i x a
- mvecsWritePartial sh i y b
+ mvecsWriteLinear i (x, y) (MV_Tup2 a b) = do
+ mvecsWriteLinear i x a
+ mvecsWriteLinear i y b
+ mvecsWritePartialLinear proxy i (M_Tup2 x y) (MV_Tup2 a b) = do
+ mvecsWritePartialLinear proxy i x a
+ mvecsWritePartialLinear proxy i y b
mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
+ mvecsUnsafeFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsUnsafeFreeze sh a <*> mvecsUnsafeFreeze sh b
instance (KnownElt a, KnownElt b) => KnownElt (a, b) where
memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh)
mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
+ mvecsReplicate sh (x, y) = MV_Tup2 <$> mvecsReplicate sh x <*> mvecsReplicate sh y
mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b)
-- Arrays of arrays are just arrays, but with more dimensions.
@@ -557,23 +594,23 @@ instance Elt a => Elt (Mixed sh' a) where
= fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr))
mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
- mindex (M_Nest _ arr) i = mindexPartial arr i
+ mindex (M_Nest _ arr) = mindexPartial arr
mindexPartial :: forall sh1 sh2.
Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
mindexPartial (M_Nest sh arr) i
| Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
+ = M_Nest (shxDropIx i sh) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
mscalar = M_Nest ZSX
- mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
- mfromListOuter l@(arr :| _) =
- M_Nest (SUnknown (length l) :$% mshape arr)
- (mfromListOuter ((\(M_Nest _ a) -> a) <$> l))
+ mfromListOuterSN sn l@(arr :| _) =
+ M_Nest (SKnown sn :$% mshape arr)
+ (mfromListOuterSN sn ((\(M_Nest _ a) -> a) <$> l))
mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr)
+ {-# INLINE mlift #-}
mlift :: forall sh1 sh2.
StaticShX sh2
-> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
@@ -591,6 +628,7 @@ instance Elt a => Elt (Mixed sh' a) where
, Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
= f (ssxAppend ssh' sshT)
+ {-# INLINE mlift2 #-}
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
@@ -609,6 +647,7 @@ instance Elt a => Elt (Mixed sh' a) where
, Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
= f (ssxAppend ssh' sshT)
+ {-# INLINE mliftL #-}
mliftL :: forall sh1 sh2.
StaticShX sh2
-> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b))
@@ -632,14 +671,14 @@ instance Elt a => Elt (Mixed sh' a) where
| Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
, Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
= let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T
- sh2 = shxCast' sh1 ssh2
+ sh2 = shxCast' ssh2 sh1
in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr)
mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
=> Perm is -> Mixed sh (Mixed sh' a)
-> Mixed (PermutePrefix is sh) (Mixed sh' a)
mtranspose perm (M_Nest sh arr)
- | let sh' = shxDropSh @sh @sh' (mshape arr) sh
+ | let sh' = shxDropSh @sh @sh' sh (mshape arr)
, Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh')
, Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh'))
, Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
@@ -657,28 +696,33 @@ instance Elt a => Elt (Mixed sh' a) where
type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a)
+ -- This requires that @arr@ is not empty.
mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShX (mshape arr)))))
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
- mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+ -- the array is empty if either there are no subarrays, or the subarrays themselves are empty
+ mshapeTreeIsEmpty _ (sh, t) = shxSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
marrayStrides (M_Nest _ arr) = marrayStrides arr
- mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
+ mvecsWriteLinear :: forall s sh. Int -> Mixed sh' a -> MixedVecs s sh (Mixed sh' a) -> ST s ()
+ mvecsWriteLinear idx val (MV_Nest _ vecs) = mvecsWritePartialLinear (Proxy @sh) idx val vecs
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
- -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
- -> ST s ()
- mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
+ mvecsWritePartialLinear
+ :: forall sh1 sh2 s.
+ Proxy sh1 -> Int -> Mixed sh2 (Mixed sh' a)
+ -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
+ -> ST s ()
+ mvecsWritePartialLinear proxy idx (M_Nest _ arr) (MV_Nest _ vecs)
| Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs
+ = mvecsWritePartialLinear proxy idx arr vecs
mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs
+ mvecsUnsafeFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsUnsafeFreeze (shxAppend sh sh') vecs
instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh'))))
@@ -689,10 +733,30 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
where
sh' = mshape example
+ mvecsReplicate sh example = do
+ vecs <- mvecsUnsafeNew sh example
+ forM_ [0 .. shxSize sh - 1] $ \idx -> mvecsWriteLinear idx example vecs
+ -- this is a slow case, but the alternative, mvecsUnsafeNew with manual
+ -- writing in a loop, leads to every case being as slow
+ return vecs
+
mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
-memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a
+-- | Given the shape of this array, an index and a value, write the value at
+-- that index in the vectors.
+{-# INLINE mvecsWrite #-}
+mvecsWrite :: Elt a => IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+mvecsWrite sh idx val vecs = mvecsWriteLinear (ixxToLinear sh idx) val vecs
+
+-- | Given the shape of this array, an index and a value, write the value at
+-- that index in the vectors.
+{-# INLINE mvecsWritePartial #-}
+mvecsWritePartial :: forall sh sh' s a. Elt a => IShX sh -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+mvecsWritePartial sh idx val vecs = mvecsWritePartialLinear (Proxy @sh) (ixxToLinear sh idx) val vecs
+
+-- TODO: should we provide a function that's just memptyArrayUnsafe but with a size==0 check? That may save someone a transpose somewhere
+memptyArray :: forall sh a. KnownElt a => IShX sh -> Mixed (Just 0 : sh) a
memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh)
mrank :: Elt a => Mixed sh a -> SNat (Rank sh)
@@ -719,38 +783,52 @@ msize = shxSize . mshape
-- the entire hierarchy (after distributing out tuples) must be a rectangular
-- array. The type of 'mgenerate' allows this requirement to be broken very
-- easily, hence the runtime check.
+--
+-- If your element type @a@ is a scalar, use the faster 'mgeneratePrim'.
mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a
mgenerate sh f = case shxEnum sh of
[] -> memptyArrayUnsafe sh
firstidx : restidxs ->
let firstelem = f (ixxZero' sh)
shapetree = mshapeTree firstelem
- in if mshapeTreeEmpty (Proxy @a) shapetree
+ in if mshapeTreeIsEmpty (Proxy @a) shapetree
then memptyArrayUnsafe sh
else runST $ do
vecs <- mvecsUnsafeNew sh firstelem
mvecsWrite sh firstidx firstelem vecs
- -- TODO: This is likely fine if @a@ is big, but if @a@ is a
- -- scalar this array copying inefficient. Should improve this.
forM_ restidxs $ \idx -> do
let val = f idx
when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $
error "Data.Array.Nested mgenerate: generated values do not have equal shapes"
mvecsWrite sh idx val vecs
- mvecsFreeze sh vecs
+ mvecsUnsafeFreeze sh vecs
-msumOuter1P :: forall sh n a. (Storable a, NumElt a)
- => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
-msumOuter1P (M_Primitive (n :$% sh) arr) =
+-- | An optimized special case of 'mgenerate', where the function results
+-- are of a primitive type and so there's not need to check that all shapes
+-- are equal. This is also generalized to an arbitrary @Num@ index type
+-- compared to @mgenerate@.
+{-# INLINE mgeneratePrim #-}
+mgeneratePrim :: forall sh a i. (PrimElt a, Num i)
+ => IShX sh -> (IxX sh i -> a) -> Mixed sh a
+mgeneratePrim sh f =
+ let g i = f (ixxFromLinear sh i)
+ in mfromVector sh $ VS.generate (shxSize sh) g
+
+msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
+ => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
+msumOuter1PrimP (M_Primitive (n :$% sh) arr) =
let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr)
-msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
- => Mixed (n : sh) a -> Mixed sh a
-msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
+msumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a)
+ => Mixed (n : sh) a -> Mixed sh a
+msumOuter1Prim = fromPrimitive . msumOuter1PrimP @sh @n @a . toPrimitive
+
+msumAllPrimP :: (Storable a, NumElt a) => Mixed sh (Primitive a) -> a
+msumAllPrimP (M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr
msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
-msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr
+msumAllPrim arr = msumAllPrimP (toPrimitive arr)
mappend :: forall n m sh a. Elt a
=> Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
@@ -759,7 +837,7 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
sn :$% sh = mshape arr1
sm :$% _ = mshape arr2
ssh = ssxFromShX sh
- snm :: SMayNat () SNat (AddMaybe n m)
+ snm :: SMayNat () (AddMaybe n m)
snm = case (sn, sm) of
(SUnknown{}, _) -> SUnknown ()
(SKnown{}, SUnknown{}) -> SUnknown ()
@@ -781,36 +859,93 @@ mtoVectorP (M_Primitive _ v) = X.toVector v
mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a
mtoVector arr = mtoVectorP (toPrimitive arr)
+-- | All arrays in the list, even subarrays inside @a@, must have the same
+-- shape; if they do not, a runtime error will be thrown. See the
+-- documentation of 'mgenerate' for more information about this restriction.
+--
+-- Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'mfromListOuterN' or 'mfromListOuterSN' to be able to
+-- stream the list.
+--
+-- If your array is 1-dimensional and contains scalars, use 'mfromList1Prim'.
+mfromListOuter :: Elt a => NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+mfromListOuter l = mfromListOuterN (length l) l
+
+-- | See 'mfromListOuter'. If the list does not have the given length, a
+-- runtime error is thrown. 'mfromList1PrimN' is faster if applicable.
+mfromListOuterN :: Elt a => Int -> NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+mfromListOuterN n l =
+ withSomeSNat (fromIntegral n) $ \case
+ Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromListOuterSN sn l)
+ Nothing -> error $ "mfromListOuterN: length negative (" ++ show n ++ ")"
+
+-- | Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'mfromList1N' or 'mfromList1SN' to be able to stream the
+-- list.
+--
+-- If the elements are scalars, 'mfromList1Prim' is faster.
mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a
-mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise?
+mfromList1 = mfromListOuter . fmap mscalar
+
+-- | If the elements are scalars, 'mfromList1PrimN' is faster. A runtime error
+-- is thrown if the list length does not match the given length.
+mfromList1N :: Elt a => Int -> NonEmpty a -> Mixed '[Nothing] a
+mfromList1N n = mfromListOuterN n . fmap mscalar
+
+-- | If the elements are scalars, 'mfromList1PrimSN' is faster. A runtime error
+-- is thrown if the list length does not match the given length.
+mfromList1SN :: Elt a => SNat n -> NonEmpty a -> Mixed '[Just n] a
+mfromList1SN sn = mfromListOuterSN sn . fmap mscalar
+-- This forall is there so that a simple type application can constrain the
+-- shape, in case the user wants to use OverloadedLists for the shape.
+-- | If the elements are scalars, 'mfromListPrimLinear' is faster.
+mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a
+mfromListLinear sh l = mreshape sh (mfromList1N (shxSize sh) l)
+
+-- | Because the length of the list is unknown, its spine must be materialised
+-- in memory in order to compute its length. If its length is already known,
+-- use 'mfromList1PrimN' or 'mfromList1PrimSN' to be able to stream the list.
mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a
mfromList1Prim l =
let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
+ xarr = X.fromList1 l
in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
-mtoList1 :: Elt a => Mixed '[n] a -> [a]
-mtoList1 = map munScalar . mtoListOuter
+mfromList1PrimN :: PrimElt a => Int -> [a] -> Mixed '[Nothing] a
+mfromList1PrimN n l =
+ withSomeSNat (fromIntegral n) $ \case
+ Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromList1PrimSN sn l)
+ Nothing -> error $ "mfromList1PrimN: length negative (" ++ show n ++ ")"
-mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a
-mfromListPrim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+mfromList1PrimSN :: forall n a. PrimElt a => SNat n -> [a] -> Mixed '[Just n] a
+mfromList1PrimSN sn l =
+ let sh = SKnown sn :$% ZSX
+ in fromPrimitive $ M_Primitive sh
+ $ if Storable.sizeOf (undefined :: a) > 0
+ then X.fromList1SN sn l
+ else case l of -- don't force the list if all elements are the same
+ a0 : _ -> X.replicateScal sh a0
+ [] -> X.fromList1SN sn l
-mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a
+mfromListPrimLinear :: forall sh a. PrimElt a => IShX sh -> [a] -> Mixed sh a
mfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ let M_Primitive _ xarr = toPrimitive (mfromList1PrimN (shxSize sh) l)
in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)
--- This forall is there so that a simple type application can constrain the
--- shape, in case the user wants to use OverloadedLists for the shape.
-mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a
-mfromListLinear sh l = mreshape sh (mfromList1 l)
+mtoList :: Elt a => Mixed '[n] a -> [a]
+mtoList = map munScalar . mtoListOuter
mtoListLinear :: Elt a => Mixed sh a -> [a]
-mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr)) -- TODO: optimise
+mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr))
+
+mtoListPrim :: PrimElt a => Mixed '[n] a -> [a]
+mtoListPrim (toPrimitive -> M_Primitive _ arr) = X.toListLinear arr
+
+mtoListPrimLinear :: PrimElt a => Mixed sh a -> [a]
+mtoListPrimLinear (toPrimitive -> M_Primitive _ arr) = X.toListLinear arr
munScalar :: Elt a => Mixed '[] a -> a
munScalar arr = mindex arr ZIX
@@ -821,30 +956,63 @@ mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr
munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a
munNest (M_Nest _ arr) = arr
-mzip :: Mixed sh a -> Mixed sh b -> Mixed sh (a, b)
-mzip = M_Tup2
+-- | The arguments must have equal shapes. If they do not, an error is raised.
+mzip :: (Elt a, Elt b) => Mixed sh a -> Mixed sh b -> Mixed sh (a, b)
+mzip a b
+ | Just Refl <- shxEqual (mshape a) (mshape b) = M_Tup2 a b
+ | otherwise = error "mzip: unequal shapes"
munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b)
munzip (M_Tup2 a b) = (a, b)
-mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
- => StaticShX sh -> IShX sh2
- -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
- -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b)
-mrerankP ssh sh2 f (M_Primitive sh arr) =
- let sh1 = shxDropSSX sh ssh
- in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2)
- (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2)
- (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
- arr)
+mrerankPrimP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
+ => IShX sh2
+ -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
+ -> Mixed sh (Mixed sh1 (Primitive a)) -> Mixed sh (Mixed sh2 (Primitive b))
+mrerankPrimP sh2 f (M_Nest sh (M_Primitive shsh1 arr)) =
+ let sh1 = shxDropSh sh shsh1
+ in M_Nest sh $
+ M_Primitive (shxAppend sh sh2)
+ (X.rerank (ssxFromShX sh) (ssxFromShX sh1) (ssxFromShX sh2)
+ (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
+ arr)
--- | See the caveats at @X.rerank@.
-mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
- => StaticShX sh -> IShX sh2
- -> (Mixed sh1 a -> Mixed sh2 b)
- -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b
-mrerank ssh sh2 f (toPrimitive -> arr) =
- fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr
+-- | If the shape of the outer array (@sh@) is empty (i.e. contains a zero),
+-- then there is no way to deduce the full shape of the output array (more
+-- precisely, the @sh2@ part): that could only come from calling @f@, and there
+-- are no subarrays to call @f@ on. @orthotope@ errors out in this case; we
+-- choose to fill the shape with zeros wherever we cannot deduce what it should
+-- be.
+--
+-- For example, if:
+--
+-- @
+-- -- arr has shape [3, 0, 4] and the inner arrays have shape [2, 21]
+-- arr :: Mixed '[Just 3, Just 0, Just 4] (Mixed '[Just 2, Nothing] Int)
+-- f :: Mixed '[Just 2, Nothing] Int -> Mixed '[Just 5, Nothing, Just 17] Float
+-- @
+--
+-- then:
+--
+-- @
+-- mrerankPrim _ f arr :: Mixed '[Just 3, Just 0, Just 4] (Mixed '[Just 5, Nothing, Just 17] Float)
+-- @
+--
+-- and the inner arrays of the result will have shape @[5, 0, 17]@. Note the
+-- @0@ in this shape: we don't know if @f@ intended to return an array with
+-- shape 0 here (it probably didn't), but there is no better number to put here
+-- absent a subarray of the input to pass to @f@.
+--
+-- In this particular case the fact that @sh@ is empty was evident from the
+-- type-level information, but the same situation occurs when @sh@ consists of
+-- @Nothing@s, and some of those happen to be zero at runtime.
+mrerankPrim :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
+ => IShX sh2
+ -> (Mixed sh1 a -> Mixed sh2 b)
+ -> Mixed sh (Mixed sh1 a) -> Mixed sh (Mixed sh2 b)
+mrerankPrim sh2 f (M_Nest sh arr) =
+ let M_Nest sh' arr' = mrerankPrimP sh2 (toPrimitive . f . fromPrimitive) (M_Nest sh (toPrimitive arr))
+ in M_Nest sh' (fromPrimitive arr')
mreplicate :: forall sh sh' a. Elt a
=> IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
@@ -856,20 +1024,28 @@ mreplicate sh arr =
Refl -> X.replicate sh (ssxAppend ssh' sshT))
arr
-mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
-mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x)
+mreplicatePrimP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
+mreplicatePrimP sh x = M_Primitive sh (X.replicateScal sh x)
-mreplicateScal :: forall sh a. PrimElt a
+mreplicatePrim :: forall sh a. PrimElt a
=> IShX sh -> a -> Mixed sh a
-mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
+mreplicatePrim sh x = fromPrimitive (mreplicatePrimP sh x)
-mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
-mslice i n arr =
+msliceN :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
+msliceN i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr
+
+msliceSN :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
+msliceSN i n arr =
let _ :$% sh = mshape arr
in mlift (SKnown n :!% ssxFromShX sh) (\_ -> X.slice i n) arr
-msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
-msliceU i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr
+mslice :: forall i n k sh a. Elt a
+ => SMayNat Int i -> SMayNat Int n -> SMayNat Int k -> Mixed (AddMaybe (AddMaybe i n) k : sh) a -> Mixed (n : sh) a
+mslice i n k arr =
+ let _ :$% sh = mshape arr
+ uarr = mcastPartial (ssxFromShX $ smnAddMaybe (smnAddMaybe i n) k :$% ZSX) (SUnknown () :!% ZKX) Proxy arr
+ in mcastPartial (SUnknown () :!% ZKX) (ssxFromShX $ n :$% ZSX) Proxy
+ $ mlift (SUnknown () :!% ssxFromShX sh) (\_ -> X.sliceU (fromSMayNat' i) (fromSMayNat' n)) uarr
mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a
mrev1 arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.rev1) arr
@@ -929,11 +1105,13 @@ mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr
mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a
mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP
+{-# INLINE mliftPrim #-}
mliftPrim :: (PrimElt a, PrimElt b)
=> (a -> b)
-> Mixed sh a -> Mixed sh b
mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr))
+{-# INLINE mliftPrim2 #-}
mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c)
=> (a -> b -> c)
-> Mixed sh a -> Mixed sh b -> Mixed sh c
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index 2f35ff9..77256ab 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -1,9 +1,12 @@
+{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
@@ -14,9 +17,11 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
@@ -31,14 +36,17 @@ import Data.Functor.Const
import Data.Functor.Product
import Data.Kind (Constraint, Type)
import Data.Monoid (Sum(..))
-import Data.Proxy
import Data.Type.Equality
-import GHC.Exts (withDict)
+import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
+#if !MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
+import GHC.TypeLits.Orphans ()
+#endif
+import Data.Array.Nested.Mixed.Shape.Internal
import Data.Array.Nested.Types
@@ -55,7 +63,7 @@ type role ListX nominal representational
type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
data ListX sh f where
ZX :: ListX '[] f
- (::%) :: f n -> ListX sh f -> ListX (n : sh) f
+ (::%) :: forall n sh {f}. f n -> ListX sh f -> ListX (n : sh) f
deriving instance (forall n. Eq (f n)) => Eq (ListX sh f)
deriving instance (forall n. Ord (f n)) => Ord (ListX sh f)
infixr 3 ::%
@@ -100,21 +108,24 @@ listxEqual (n ::% sh) (m ::% sh')
= Just Refl
listxEqual _ _ = Nothing
+{-# INLINE listxFmap #-}
listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g
listxFmap _ ZX = ZX
listxFmap f (x ::% xs) = f x ::% listxFmap f xs
-listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m
-listxFold _ ZX = mempty
-listxFold f (x ::% xs) = f x <> listxFold f xs
+{-# INLINE listxFoldMap #-}
+listxFoldMap :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m
+listxFoldMap _ ZX = mempty
+listxFoldMap f (x ::% xs) = f x <> listxFoldMap f xs
listxLength :: ListX sh f -> Int
-listxLength = getSum . listxFold (\_ -> Sum 1)
+listxLength = getSum . listxFoldMap (\_ -> Sum 1)
listxRank :: ListX sh f -> SNat (Rank sh)
listxRank ZX = SNat
listxRank (_ ::% l) | SNat <- listxRank l = SNat
+{-# INLINE listxShow #-}
listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
listxShow f l = showString "[" . go "" l . showString "]"
where
@@ -132,9 +143,13 @@ listxFromList topssh topl = go topssh topl
++ show (ssxLength topssh) ++ ", list has length "
++ show (length topl) ++ ")"
-listxToList :: ListX sh' (Const i) -> [i]
-listxToList ZX = []
-listxToList (Const i ::% is) = i : listxToList is
+{-# INLINEABLE listxToList #-}
+listxToList :: ListX sh (Const i) -> [i]
+listxToList list = build (\(cons :: i -> is -> is) (nil :: is) ->
+ let go :: ListX sh (Const i) -> is
+ go ZX = nil
+ go (Const i ::% is) = i `cons` go is
+ in go list)
listxHead :: ListX (mn ': sh) f -> f mn
listxHead (i ::% _) = i
@@ -146,9 +161,9 @@ listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f
listxAppend ZX idx' = idx'
listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx'
-listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f
-listxDrop long ZX = long
-listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short
+listxDrop :: forall f g sh sh'. ListX sh g -> ListX (sh ++ sh') f -> ListX sh' f
+listxDrop ZX long = long
+listxDrop (_ ::% short) long = case long of _ ::% long' -> listxDrop short long'
listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f
listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh
@@ -160,19 +175,18 @@ listxLast (x ::% ZX) = x
listxZip :: ListX sh f -> ListX sh g -> ListX sh (Product f g)
listxZip ZX ZX = ZX
-listxZip (i ::% irest) (j ::% jrest) =
- Pair i j ::% listxZip irest jrest
+listxZip (i ::% irest) (j ::% jrest) = Pair i j ::% listxZip irest jrest
+{-# INLINE listxZipWith #-}
listxZipWith :: (forall a. f a -> g a -> h a) -> ListX sh f -> ListX sh g
-> ListX sh h
listxZipWith _ ZX ZX = ZX
-listxZipWith f (i ::% is) (j ::% js) =
- f i j ::% listxZipWith f is js
+listxZipWith f (i ::% is) (j ::% js) = f i j ::% listxZipWith f is js
-- * Mixed indices
--- | This is a newtype over 'ListX'.
+-- | An index into a mixed-typed array.
type role IxX nominal representational
type IxX :: [Maybe Nat] -> Type -> Type
newtype IxX sh i = IxX (ListX sh (Const i))
@@ -191,6 +205,8 @@ infixr 3 :.%
{-# COMPLETE ZIX, (:.%) #-}
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
type IIxX sh = IxX sh Int
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
@@ -201,10 +217,18 @@ instance Show i => Show (IxX sh i) where
#endif
instance Functor (IxX sh) where
+ {-# INLINE fmap #-}
fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l)
instance Foldable (IxX sh) where
- foldMap f (IxX l) = listxFold (f . getConst) l
+ {-# INLINE foldMap #-}
+ foldMap f (IxX l) = listxFoldMap (f . getConst) l
+ {-# INLINE foldr #-}
+ foldr _ z ZIX = z
+ foldr f z (x :.% xs) = f x (foldr f z xs)
+ toList = ixxToList
+ null ZIX = False
+ null _ = True
instance NFData i => NFData (IxX sh i)
@@ -225,6 +249,10 @@ ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh
ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i
ixxFromList = coerce (listxFromList @_ @i)
+{-# INLINEABLE ixxToList #-}
+ixxToList :: forall sh i. IxX sh i -> [i]
+ixxToList = coerce (listxToList @_ @i)
+
ixxHead :: IxX (n : sh) i -> i
ixxHead (IxX list) = getConst (listxHead list)
@@ -234,7 +262,7 @@ ixxTail (IxX list) = IxX (listxTail list)
ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
ixxAppend = coerce (listxAppend @_ @(Const i))
-ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i
+ixxDrop :: forall sh sh' i. IxX sh i -> IxX (sh ++ sh') i -> IxX sh' i
ixxDrop = coerce (listxDrop @(Const i) @(Const i))
ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i
@@ -243,64 +271,58 @@ ixxInit = coerce (listxInit @(Const i))
ixxLast :: forall n sh i. IxX (n : sh) i -> i
ixxLast = coerce (listxLast @(Const i))
+ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i
+ixxCast ZKX ZIX = ZIX
+ixxCast (_ :!% sh) (i :.% idx) = i :.% ixxCast sh idx
+ixxCast _ _ = error "ixxCast: ranks don't match"
+
ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j)
ixxZip ZIX ZIX = ZIX
ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js
+{-# INLINE ixxZipWith #-}
ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k
ixxZipWith _ ZIX ZIX = ZIX
ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js
-ixxFromLinear :: IShX sh -> Int -> IIxX sh
-ixxFromLinear = \sh i -> case go sh i of
- (idx, 0) -> idx
- _ -> error $ "ixxFromLinear: out of range (" ++ show i ++
- " in array of shape " ++ show sh ++ ")"
+-- | Given a multidimensional index, get the corresponding linear
+-- index into the buffer.
+{-# INLINEABLE ixxToLinear #-}
+ixxToLinear :: Num i => IShX sh -> IxX sh i -> i
+ixxToLinear = \sh i -> go sh i 0
where
- -- returns (index in subarray, remaining index in enclosing array)
- go :: IShX sh -> Int -> (IIxX sh, Int)
- go ZSX i = (ZIX, i)
- go (n :$% sh) i =
- let (idx, i') = go sh i
- (upi, locali) = i' `quotRem` fromSMayNat' n
- in (locali :.% idx, upi)
-
-ixxToLinear :: IShX sh -> IIxX sh -> Int
-ixxToLinear = \sh i -> fst (go sh i)
- where
- -- returns (index in subarray, size of subarray)
- go :: IShX sh -> IIxX sh -> (Int, Int)
- go ZSX ZIX = (0, 1)
- go (n :$% sh) (i :.% ix) =
- let (lidx, sz) = go sh ix
- in (sz * i + lidx, fromSMayNat' n * sz)
+ go :: Num i => IShX sh -> IxX sh i -> i -> i
+ go ZSX ZIX !a = a
+ go (n :$% sh) (i :.% ix) a = go sh ix (fromIntegral (fromSMayNat' n) * a + i)
-- * Mixed shapes
-data SMayNat i f n where
- SUnknown :: i -> SMayNat i f Nothing
- SKnown :: f n -> SMayNat i f (Just n)
-deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n)
-deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n)
-deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n)
+data SMayNat i n where
+ SUnknown :: i -> SMayNat i Nothing
+ SKnown :: {-# UNPACK #-} SNat n -> SMayNat i (Just n)
+deriving instance Show i => Show (SMayNat i n)
+deriving instance Eq i => Eq (SMayNat i n)
+deriving instance Ord i => Ord (SMayNat i n)
-instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where
+instance (NFData i, forall m. NFData (SNat m)) => NFData (SMayNat i n) where
rnf (SUnknown i) = rnf i
rnf (SKnown x) = rnf x
-instance TestEquality f => TestEquality (SMayNat i f) where
+instance TestEquality (SMayNat i) where
testEquality SUnknown{} SUnknown{} = Just Refl
testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl
testEquality _ _ = Nothing
+{-# INLINE fromSMayNat #-}
fromSMayNat :: (n ~ Nothing => i -> r)
- -> (forall m. n ~ Just m => f m -> r)
- -> SMayNat i f n -> r
+ -> (forall m. n ~ Just m => SNat m -> r)
+ -> SMayNat i n -> r
fromSMayNat f _ (SUnknown i) = f i
fromSMayNat _ g (SKnown s) = g s
-fromSMayNat' :: SMayNat Int SNat n -> Int
+{-# INLINE fromSMayNat' #-}
+fromSMayNat' :: SMayNat Int n -> Int
fromSMayNat' = fromSMayNat id fromSNat'
type family AddMaybe n m where
@@ -308,7 +330,7 @@ type family AddMaybe n m where
AddMaybe (Just _) Nothing = Nothing
AddMaybe (Just n) (Just m) = Just (n + m)
-smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
+smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m)
smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)
@@ -317,7 +339,7 @@ smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)
-- | This is a newtype over 'ListX'.
type role ShX nominal representational
type ShX :: [Maybe Nat] -> Type -> Type
-newtype ShX sh i = ShX (ListX sh (SMayNat i SNat))
+newtype ShX sh i = ShX (ListX sh (SMayNat i))
deriving (Eq, Ord, Generic)
pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i
@@ -326,7 +348,7 @@ pattern ZSX = ShX ZX
pattern (:$%)
:: forall {sh1} {i}.
forall n sh. (n : sh ~ sh1)
- => SMayNat i SNat n -> ShX sh i -> ShX sh1 i
+ => SMayNat i n -> ShX sh i -> ShX sh1 i
pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i))
where i :$% ShX shl = ShX (i ::% shl)
infixr 3 :$%
@@ -343,6 +365,7 @@ instance Show i => Show (ShX sh i) where
#endif
instance Functor (ShX sh) where
+ {-# INLINE fmap #-}
fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l)
instance NFData i => NFData (ShX sh i) where
@@ -390,10 +413,10 @@ shxSize :: IShX sh -> Int
shxSize ZSX = 1
shxSize (n :$% sh) = fromSMayNat' n * shxSize sh
-shxFromList :: StaticShX sh -> [Int] -> ShX sh Int
+shxFromList :: StaticShX sh -> [Int] -> IShX sh
shxFromList topssh topl = go topssh topl
where
- go :: StaticShX sh' -> [Int] -> ShX sh' Int
+ go :: StaticShX sh' -> [Int] -> IShX sh'
go ZKX [] = ZSX
go (SKnown sn :!% sh) (i : is)
| i == fromSNat' sn = SKnown sn :$% go sh is
@@ -404,48 +427,57 @@ shxFromList topssh topl = go topssh topl
++ show (ssxLength topssh) ++ ", list has length "
++ show (length topl) ++ ")"
+{-# INLINEABLE shxToList #-}
shxToList :: IShX sh -> [Int]
-shxToList ZSX = []
-shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh
+shxToList list = build (\(cons :: i -> is -> is) (nil :: is) ->
+ let go :: IShX sh -> is
+ go ZSX = nil
+ go (smn :$% sh) = fromSMayNat' smn `cons` go sh
+ in go list)
+
+shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i
+shxFromSSX ZKX = ZSX
+shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh))
+ | Refl <- lemMapJustCons @sh Refl
+ = SKnown n :$% shxFromSSX sh
+shxFromSSX (SUnknown _ :!% _) = error "unreachable"
-- | This may fail if @sh@ has @Nothing@s in it.
-shxFromSSX' :: StaticShX sh -> Maybe (IShX sh)
-shxFromSSX' ZKX = Just ZSX
-shxFromSSX' (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX' sh
-shxFromSSX' (SUnknown _ :!% _) = Nothing
+shxFromSSX2 :: StaticShX sh -> Maybe (ShX sh i)
+shxFromSSX2 ZKX = Just ZSX
+shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh
+shxFromSSX2 (SUnknown _ :!% _) = Nothing
shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
-shxAppend = coerce (listxAppend @_ @(SMayNat i SNat))
+shxAppend = coerce (listxAppend @_ @(SMayNat i))
-shxHead :: ShX (n : sh) i -> SMayNat i SNat n
+shxHead :: ShX (n : sh) i -> SMayNat i n
shxHead (ShX list) = listxHead list
shxTail :: ShX (n : sh) i -> ShX sh i
shxTail (ShX list) = ShX (listxTail list)
-shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i
-shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat))
+shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i
+shxDropSSX = coerce (listxDrop @(SMayNat i) @(SMayNat ()))
-shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i
-shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j))
+shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i
+shxDropIx = coerce (listxDrop @(SMayNat i) @(Const j))
-shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i
-shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))
+shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i
+shxDropSh = coerce (listxDrop @(SMayNat i) @(SMayNat i))
shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i
-shxInit = coerce (listxInit @(SMayNat i SNat))
+shxInit = coerce (listxInit @(SMayNat i))
-shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh))
-shxLast = coerce (listxLast @(SMayNat i SNat))
+shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i (Last (n : sh))
+shxLast = coerce (listxLast @(SMayNat i))
-shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i
-shxTakeSSX _ = flip go
- where
- go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i
- go ZKX _ = ZSX
- go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh
+shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i
+shxTakeSSX _ ZKX _ = ZSX
+shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh
-shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n)
+{-# INLINE shxZipWith #-}
+shxZipWith :: (forall n. SMayNat i n -> SMayNat j n -> SMayNat k n)
-> ShX sh i -> ShX sh j -> ShX sh k
shxZipWith _ ZSX ZSX = ZSX
shxZipWith f (i :$% is) (j :$% js) = f i j :$% shxZipWith f is js
@@ -456,28 +488,37 @@ shxCompleteZeros ZKX = ZSX
shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh
shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh
-shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i)
+shxSplitApp :: proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i)
shxSplitApp _ ZKX idx = (ZSX, idx)
shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx)
shxEnum :: IShX sh -> [IIxX sh]
-shxEnum = \sh -> go sh id []
+shxEnum = shxEnum'
+
+{-# INLINABLE shxEnum' #-} -- ensure this can be specialised at use site
+shxEnum' :: Num i => IShX sh -> [IxX sh i]
+shxEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shxSize sh - 1]]
where
- go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a]
- go ZSX f = (f ZIX :)
- go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]]
+ suffixes = drop 1 (scanr (*) 1 (shxToList sh))
+
+ fromLin :: Num i => IShX sh -> [Int] -> Int# -> IxX sh i
+ fromLin ZSX _ _ = ZIX
+ fromLin (_ :$% sh') (I# suff# : suffs) i# =
+ let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh'
+ in fromIntegral (I# q#) :.% fromLin sh' suffs r#
+ fromLin _ _ _ = error "impossible"
-shxCast :: IShX sh -> StaticShX sh' -> Maybe (IShX sh')
-shxCast ZSX ZKX = Just ZSX
-shxCast (SKnown n :$% sh) (SKnown m :!% ssh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast sh ssh
-shxCast (SUnknown n :$% sh) (SKnown m :!% ssh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast sh ssh
-shxCast (SKnown n :$% sh) (SUnknown () :!% ssh) = (SUnknown (fromSNat' n) :$%) <$> shxCast sh ssh
-shxCast (SUnknown n :$% sh) (SUnknown () :!% ssh) = (SUnknown n :$%) <$> shxCast sh ssh
+shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh')
+shxCast ZKX ZSX = Just ZSX
+shxCast (SKnown m :!% ssh) (SKnown n :$% sh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast ssh sh
+shxCast (SKnown m :!% ssh) (SUnknown n :$% sh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast ssh sh
+shxCast (SUnknown () :!% ssh) (SKnown n :$% sh) = (SUnknown (fromSNat' n) :$%) <$> shxCast ssh sh
+shxCast (SUnknown () :!% ssh) (SUnknown n :$% sh) = (SUnknown n :$%) <$> shxCast ssh sh
shxCast _ _ = Nothing
-- | Partial version of 'shxCast'.
-shxCast' :: IShX sh -> StaticShX sh' -> IShX sh'
-shxCast' sh ssh = case shxCast sh ssh of
+shxCast' :: StaticShX sh' -> IShX sh -> IShX sh'
+shxCast' ssh sh = case shxCast ssh sh of
Just sh' -> sh'
Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")"
@@ -486,7 +527,7 @@ shxCast' sh ssh = case shxCast sh ssh of
-- | The part of a shape that is statically known. (A newtype over 'ListX'.)
type StaticShX :: [Maybe Nat] -> Type
-newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat))
+newtype StaticShX sh = StaticShX (ListX sh (SMayNat ()))
deriving (Eq, Ord)
pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
@@ -495,7 +536,7 @@ pattern ZKX = StaticShX ZX
pattern (:!%)
:: forall {sh1}.
forall n sh. (n : sh ~ sh1)
- => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1
+ => SMayNat () n -> StaticShX sh -> StaticShX sh1
pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i))
where i :!% StaticShX shl = StaticShX (i ::% shl)
infixr 3 :!%
@@ -531,38 +572,44 @@ ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
ssxAppend ZKX sh' = sh'
ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'
-ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n
+ssxHead :: StaticShX (n : sh) -> SMayNat () n
ssxHead (StaticShX list) = listxHead list
ssxTail :: StaticShX (n : sh) -> StaticShX sh
ssxTail (_ :!% ssh) = ssh
-ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'
-ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))
+ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh'
+ssxDropSSX = coerce (listxDrop @(SMayNat ()) @(SMayNat ()))
+
+ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
+ssxDropIx = coerce (listxDrop @(SMayNat ()) @(Const i))
+
+ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
+ssxDropSh = coerce (listxDrop @(SMayNat ()) @(SMayNat i))
ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh))
-ssxInit = coerce (listxInit @(SMayNat () SNat))
+ssxInit = coerce (listxInit @(SMayNat ()))
-ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh))
-ssxLast = coerce (listxLast @(SMayNat () SNat))
+ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () (Last (n : sh))
+ssxLast = coerce (listxLast @(SMayNat ()))
ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
ssxReplicate SZ = ZKX
ssxReplicate (SS (n :: SNat n'))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n'
+ | Refl <- lemReplicateSucc @(Nothing @Nat) n
= SUnknown () :!% ssxReplicate n
-ssxIotaFrom :: Int -> StaticShX sh -> [Int]
-ssxIotaFrom _ ZKX = []
-ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh
+ssxIotaFrom :: StaticShX sh -> Int -> [Int]
+ssxIotaFrom ZKX _ = []
+ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i + 1)
-ssxFromShX :: IShX sh -> StaticShX sh
+ssxFromShX :: ShX sh i -> StaticShX sh
ssxFromShX ZSX = ZKX
ssxFromShX (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShX sh
ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing)
ssxFromSNat SZ = ZKX
-ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n
+ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) n = SUnknown () :!% ssxFromSNat n
-- | Evidence for the static part of a shape. This pops up only when you are
@@ -574,7 +621,7 @@ instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SK
instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX
withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r
-withKnownShX k = withDict @(KnownShX sh) k
+withKnownShX = withDict @(KnownShX sh)
-- * Flattening
@@ -587,18 +634,18 @@ type family Flatten' acc sh where
Flatten' acc (Just n : sh) = Flatten' (acc * n) sh
-- This function is currently unused
-ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh)
+ssxFlatten :: StaticShX sh -> SMayNat () (Flatten sh)
ssxFlatten = go (SNat @1)
where
- go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh)
+ go :: SNat acc -> StaticShX sh -> SMayNat () (Flatten' acc sh)
go acc ZKX = SKnown acc
go _ (SUnknown () :!% _) = SUnknown ()
go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh
-shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh)
+shxFlatten :: IShX sh -> SMayNat Int (Flatten sh)
shxFlatten = go (SNat @1)
where
- go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh)
+ go :: SNat acc -> IShX sh -> SMayNat Int (Flatten' acc sh)
go acc ZSX = SKnown acc
go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh)
go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh
@@ -626,3 +673,8 @@ instance KnownShX sh => IsList (ShX sh Int) where
type Item (ShX sh Int) = Int
fromList = shxFromList (knownShX @sh)
toList = shxToList
+
+-- This needs to be at the bottom of the file to not split the file into
+-- pieces; some of the shape/index stuff refers to StaticShX.
+$(ixFromLinearStub "ixxFromLinear" [t| IShX |] [t| IxX |] [p| ZSX |] (\a b -> [p| (fromSMayNat' -> $a) :$% $b |]) [| ZIX |] [| (:.%) |] [| shxToList |])
+{-# INLINEABLE ixxFromLinear #-}
diff --git a/src/Data/Array/Nested/Mixed/Shape/Internal.hs b/src/Data/Array/Nested/Mixed/Shape/Internal.hs
new file mode 100644
index 0000000..2a86ac1
--- /dev/null
+++ b/src/Data/Array/Nested/Mixed/Shape/Internal.hs
@@ -0,0 +1,59 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TemplateHaskell #-}
+module Data.Array.Nested.Mixed.Shape.Internal where
+
+import Language.Haskell.TH
+
+
+-- | A TH stub function to avoid having to write the same code three times for
+-- the three kinds of shapes.
+ixFromLinearStub :: String
+ -> TypeQ -> TypeQ
+ -> PatQ -> (PatQ -> PatQ -> PatQ)
+ -> ExpQ -> ExpQ
+ -> ExpQ
+ -> DecsQ
+ixFromLinearStub fname' ishty ixty zshC consshC ixz ixcons shtolist = do
+ let fname = mkName fname'
+ typesig <- [t| forall i sh. Num i => $ishty sh -> Int -> $ixty sh i |]
+
+ locals <- [d|
+ -- Unfold first iteration of fromLin to do the range check.
+ -- Don't inline this function at first to allow GHC to inline the outer
+ -- function and realise that 'suffixes' is shared. But then later inline it
+ -- anyway, to avoid the function call. Removing the pragma makes GHC
+ -- somehow unable to recognise that 'suffixes' can be shared in a loop.
+ {-# NOINLINE [0] fromLin0 #-}
+ fromLin0 :: Num i => $ishty sh -> [Int] -> Int -> $ixty sh i
+ fromLin0 sh suffixes i =
+ if i < 0 then outrange sh i else
+ case (sh, suffixes) of
+ ($zshC, _) | i > 0 -> outrange sh i
+ | otherwise -> $ixz
+ ($(consshC (varP (mkName "n")) (varP (mkName "sh'"))), suff : suffs) ->
+ let (q, r) = i `quotRem` suff
+ in if q >= n then outrange sh i else
+ $ixcons (fromIntegral q) (fromLin sh' suffs r)
+ _ -> error "impossible"
+
+ fromLin :: Num i => $ishty sh -> [Int] -> Int -> $ixty sh i
+ fromLin $zshC _ !_ = $ixz
+ fromLin ($(consshC wildP (varP (mkName "sh'")))) (suff : suffs) i =
+ let (q, r) = i `quotRem` suff -- suff == shrSize sh'
+ in $ixcons (fromIntegral q) (fromLin sh' suffs r)
+ fromLin _ _ _ = error "impossible"
+
+ {-# NOINLINE outrange #-}
+ outrange :: $ishty sh -> Int -> a
+ outrange sh i = error $ fname' ++ ": out of range (" ++ show i ++
+ " in array of shape " ++ show sh ++ ")" |]
+
+ body <- [|
+ \sh -> -- give this function arity 1 so that 'suffixes' is shared when
+ -- it's called many times
+ let suffixes = drop 1 (scanr (*) 1 ($shtolist sh))
+ in fromLin0 sh suffixes |]
+
+ return [SigD fname typesig
+ ,FunD fname [Clause [] (NormalB body) locals]]
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs
index 031755f..6bebcfb 100644
--- a/src/Data/Array/Nested/Permutation.hs
+++ b/src/Data/Array/Nested/Permutation.hs
@@ -1,10 +1,10 @@
+{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -25,6 +25,7 @@ import Data.Proxy
import Data.Type.Bool
import Data.Type.Equality
import Data.Type.Ord
+import GHC.Exts (withDict)
import GHC.TypeError
import GHC.TypeLits
import GHC.TypeNats qualified as TN
@@ -36,8 +37,8 @@ import Data.Array.Nested.Types
-- * Permutations
-- | A "backward" permutation of a dimension list. The operation on the
--- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute'
--- for code that implements this.
+-- dimension list is most similar to @backpermute@ in the @vector@ package; see
+-- 'Permute' for code that implements this.
data Perm list where
PNil :: Perm '[]
PCons :: SNat a -> Perm l -> Perm (a : l)
@@ -45,15 +46,22 @@ infixr 5 `PCons`
deriving instance Show (Perm list)
deriving instance Eq (Perm list)
+instance TestEquality Perm where
+ testEquality PNil PNil = Just Refl
+ testEquality (x `PCons` xs) (y `PCons` ys)
+ | Just Refl <- testEquality x y
+ , Just Refl <- testEquality xs ys = Just Refl
+ testEquality _ _ = Nothing
+
permRank :: Perm list -> SNat (Rank list)
permRank PNil = SNat
permRank (_ `PCons` l) | SNat <- permRank l = SNat
-permFromList :: [Int] -> (forall list. Perm list -> r) -> r
-permFromList [] k = k PNil
-permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case
- Just sn -> permFromList xs $ \list -> k (sn `PCons` list)
- Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x
+permFromListCont :: [Int] -> (forall list. Perm list -> r) -> r
+permFromListCont [] k = k PNil
+permFromListCont (x : xs) k = withSomeSNat (fromIntegral x) $ \case
+ Just sn -> permFromListCont xs $ \list -> k (sn `PCons` list)
+ Nothing -> error $ "Data.Array.Nested.Permutation.permFromListCont: negative number in list: " ++ show x
permToList :: Perm list -> [Natural]
permToList PNil = mempty
@@ -119,6 +127,9 @@ class KnownPerm l where makePerm :: Perm l
instance KnownPerm '[] where makePerm = PNil
instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm
+withKnownPerm :: forall l r. Perm l -> (KnownPerm l => r) -> r
+withKnownPerm = withDict @(KnownPerm l)
+
-- | Untyped permutations for ranked arrays
type PermR = [Int]
@@ -190,22 +201,22 @@ ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is
ixxPermutePrefix = coerce (listxPermutePrefix @(Const i))
ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh)
-ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat))
+ssxTakeLen = coerce (listxTakeLen @(SMayNat ()))
ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh)
-ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))
+ssxDropLen = coerce (listxDropLen @(SMayNat ()))
ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
-ssxPermute = coerce (listxPermute @(SMayNat () SNat))
+ssxPermute = coerce (listxPermute @(SMayNat ()))
-ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh)
-ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2)
+ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () (Index i sh)
+ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat ()) p1 p2 i)
ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
-ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
+ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat ()))
shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh)
-shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))
+shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int))
-- * Operations on permutations
@@ -224,7 +235,7 @@ permInverse = \perm k ->
++ " ; invperm = " ++ show invperm)
(permCheckPermutation invperm
(k invperm
- (\ssh -> case provePermInverse perm invperm ssh of
+ (\ssh -> case permCheckInverse perm invperm ssh of
Just eq -> eq
Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm
++ " ; invperm = " ++ show invperm)))
@@ -238,9 +249,9 @@ permInverse = \perm k ->
toHList [] k = k PNil
toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l)
- provePermInverse :: Perm is -> Perm is' -> StaticShX sh
+ permCheckInverse :: Perm is -> Perm is' -> StaticShX sh
-> Maybe (Permute is' (Permute is sh) :~: sh)
- provePermInverse perm perminv ssh =
+ permCheckInverse perm perminv ssh =
ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh
type family MapSucc is where
@@ -264,7 +275,13 @@ lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl
lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
=> StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is
lemRankDropLen ZKX PNil = Refl
-lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl
+lemRankDropLen (_ :!% sh) (_ `PCons` is)
+ | Refl <- lemRankDropLen sh is
+#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
+ = Refl
+#else
+ = unsafeCoerceRefl
+#endif
lemRankDropLen (_ :!% _) PNil = Refl
lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0"
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index e5c51ef..f668c3e 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -29,17 +29,17 @@ import Foreign.Storable (Storable)
import GHC.TypeLits
import GHC.TypeNats qualified as TN
-import Data.Array.Mixed.Lemmas
-import Data.Array.Nested.Permutation
-import Data.Array.Nested.Types
-import Data.Array.XArray (XArray(..))
-import Data.Array.XArray qualified as X
import Data.Array.Nested.Convert
+import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
import Data.Array.Nested.Ranked.Base
import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray(..))
+import Data.Array.XArray qualified as X
remptyArray :: KnownElt a => Ranked 1 a
@@ -49,9 +49,11 @@ remptyArray = mtoRanked (memptyArray ZSX)
rsize :: Elt a => Ranked n a -> Int
rsize = shrSize . rshape
+{-# INLINEABLE rindex #-}
rindex :: Elt a => Ranked n a -> IIxR n -> a
rindex (Ranked arr) idx = mindex arr (ixxFromIxR idx)
+{-# INLINEABLE rindexPartial #-}
rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a
rindexPartial (Ranked arr) idx =
Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
@@ -59,7 +61,8 @@ rindexPartial (Ranked arr) idx =
(ixxFromIxR idx))
-- | __WARNING__: All values returned from the function must have equal shape.
--- See the documentation of 'mgenerate' for more details.
+-- See the documentation of 'mgenerate' for more details; see also
+-- 'rgeneratePrim'.
rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a
rgenerate sh f
| sn@SNat <- shrRank sh
@@ -67,7 +70,16 @@ rgenerate sh f
, Refl <- lemRankReplicate sn
= Ranked (mgenerate (shxFromShR sh) (f . ixrFromIxX))
+-- | See 'mgeneratePrim'.
+{-# INLINE rgeneratePrim #-}
+rgeneratePrim :: forall n a i. (PrimElt a, Num i)
+ => IShR n -> (IxR n i -> a) -> Ranked n a
+rgeneratePrim sh f =
+ let g i = f (ixrFromLinear sh i)
+ in rfromVector sh $ VS.generate (shrSize sh) g
+
-- | See the documentation of 'mlift'.
+{-# INLINE rlift #-}
rlift :: forall n1 n2 a. Elt a
=> SNat n2
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
@@ -75,22 +87,26 @@ rlift :: forall n1 n2 a. Elt a
rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr)
-- | See the documentation of 'mlift2'.
+{-# INLINE rlift2 #-}
rlift2 :: forall n1 n2 n3 a. Elt a
=> SNat n3
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b)
-> Ranked n1 a -> Ranked n2 a -> Ranked n3 a
rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)
-rsumOuter1P :: forall n a.
- (Storable a, NumElt a)
- => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
-rsumOuter1P (Ranked arr)
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = Ranked (msumOuter1P arr)
+rsumOuter1PrimP :: forall n a.
+ (Storable a, NumElt a)
+ => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
+rsumOuter1PrimP (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = Ranked (msumOuter1PrimP arr)
+
+rsumOuter1Prim :: forall n a. (NumElt a, PrimElt a)
+ => Ranked (n + 1) a -> Ranked n a
+rsumOuter1Prim = rfromPrimitive . rsumOuter1PrimP . rtoPrimitive
-rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
- => Ranked (n + 1) a -> Ranked n a
-rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive
+rsumAllPrimP :: (Storable a, NumElt a) => Ranked n (Primitive a) -> a
+rsumAllPrimP (Ranked arr) = msumAllPrimP arr
rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a
rsumAllPrim (Ranked arr) = msumAllPrim arr
@@ -108,7 +124,7 @@ rtranspose perm arr
rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a
rconcat
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
= coerce mconcat
rappend :: forall n a. Elt a
@@ -116,7 +132,7 @@ rappend :: forall n a. Elt a
rappend arr1 arr2
| sn@SNat <- rrank arr1
, Dict <- lemKnownReplicate sn
- , Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ , Refl <- lemReplicateSucc @(Nothing @Nat) (SNat @n)
= coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
arr1 arr2
@@ -137,51 +153,82 @@ rtoVectorP = coerce mtoVectorP
rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
rtoVector = coerce mtoVector
+-- | All arrays in the list, even subarrays inside @a@, must have the same
+-- shape; if they do not, a runtime error will be thrown. See the
+-- documentation of 'mgenerate' for more information about this restriction.
+--
+-- Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'rfromListOuterN' to be able to stream the list.
+--
+-- If your array is 1-dimensional and contains scalars, use 'rfromList1Prim'.
rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
rfromListOuter l
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
= Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
+-- | See 'rfromListOuter'. If the list does not have the given length, a
+-- runtime error is thrown. 'rfromList1PrimN' is faster if applicable.
+rfromListOuterN :: forall n a. Elt a => Int -> NonEmpty (Ranked n a) -> Ranked (n + 1) a
+rfromListOuterN n l
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = Ranked (mfromListOuterN n (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
+
+-- | Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'rfromList1N' to be able to stream the list.
+--
+-- If the elements are scalars, 'rfromList1Prim' is faster.
rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
-rfromList1 l = Ranked (mfromList1 l)
+rfromList1 = coerce mfromList1
+
+-- | If the elements are scalars, 'rfromList1PrimN' is faster. A runtime error
+-- is thrown if the list length does not match the given length.
+rfromList1N :: Elt a => Int -> NonEmpty a -> Ranked 1 a
+rfromList1N = coerce mfromList1N
+
+-- | If the elements are scalars, 'rfromListPrimLinear' is faster.
+rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
+rfromListLinear sh l = Ranked (mfromListLinear (shxFromShR sh) l)
+-- | Because the length of the list is unknown, its spine must be materialised
+-- in memory in order to compute its length. If its length is already known,
+-- use 'rfromList1PrimN' to be able to stream the list.
rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a
-rfromList1Prim l = Ranked (mfromList1Prim l)
+rfromList1Prim = coerce mfromList1Prim
+
+rfromList1PrimN :: PrimElt a => Int -> [a] -> Ranked 1 a
+rfromList1PrimN = coerce mfromList1PrimN
+
+rfromListPrimLinear :: forall n a. PrimElt a => IShR n -> [a] -> Ranked n a
+rfromListPrimLinear sh l = Ranked (mfromListPrimLinear (shxFromShR sh) l)
rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
rtoListOuter (Ranked arr)
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
= coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
-rtoList1 :: Elt a => Ranked 1 a -> [a]
-rtoList1 = map runScalar . rtoListOuter
-
-rfromListPrim :: PrimElt a => [a] -> Ranked 1 a
-rfromListPrim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
-
-rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a
-rfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
- in Ranked $ fromPrimitive $ M_Primitive (shxFromShR sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShR sh) xarr)
-
-rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
-rfromListLinear sh l = rreshape sh (rfromList1 l)
+rtoList :: Elt a => Ranked 1 a -> [a]
+rtoList = map runScalar . rtoListOuter
rtoListLinear :: Elt a => Ranked n a -> [a]
rtoListLinear (Ranked arr) = mtoListLinear arr
+rtoListPrim :: PrimElt a => Ranked 1 a -> [a]
+rtoListPrim (Ranked arr) = mtoListPrim arr
+
+rtoListPrimLinear :: PrimElt a => Ranked n a -> [a]
+rtoListPrimLinear (Ranked arr) = mtoListPrimLinear arr
+
rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a
rfromOrthotope sn arr
| Refl <- lemRankReplicate sn
= let xarr = XArray arr
in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr))
-rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a
+rtoOrthotope :: forall a n. PrimElt a => Ranked n a -> S.Array n a
rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr)))
- | Refl <- lemRankReplicate (shrRank $ shrFromShX2 sh)
+ | Refl <- lemRankReplicate (shrRank $ shrFromShX2 @n sh)
= arr
runScalar :: Elt a => Ranked 0 a -> a
@@ -197,22 +244,20 @@ runNest rarr@(Ranked (M_Ranked (M_Nest _ arr)))
| Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat))
= Ranked arr
-rzip :: Ranked n a -> Ranked n b -> Ranked n (a, b)
+rzip :: (Elt a, Elt b) => Ranked n a -> Ranked n b -> Ranked n (a, b)
rzip = coerce mzip
runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b)
runzip = coerce munzip
-rrerankP :: forall n1 n2 n a b. (Storable a, Storable b)
- => SNat n -> IShR n2
- -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b))
- -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b)
-rrerankP sn sh2 f (Ranked arr)
- | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat))
- , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat))
- = Ranked (mrerankP (ssxFromSNat sn) (shxFromShR sh2)
- (\a -> let Ranked r = f (Ranked a) in r)
- arr)
+rrerankPrimP :: forall n1 n2 n a b. (Storable a, Storable b)
+ => IShR n2
+ -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b))
+ -> Ranked n (Ranked n1 (Primitive a)) -> Ranked n (Ranked n2 (Primitive b))
+rrerankPrimP sh2 f (Ranked (M_Ranked arr))
+ = Ranked (M_Ranked (mrerankPrimP (shxFromShR sh2)
+ (\a -> let Ranked r = f (Ranked a) in r)
+ arr))
-- | If there is a zero-sized dimension in the @n@-prefix of the shape of the
-- input array, then there is no way to deduce the full shape of the output
@@ -223,26 +268,28 @@ rrerankP sn sh2 f (Ranked arr)
-- For example, if:
--
-- @
--- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21]
+-- arr :: Ranked 3 (Ranked 2 Int) -- outer array shape [3, 0, 4]; inner shape [2, 21]
-- f :: Ranked 2 Int -> Ranked 3 Float
-- @
--
-- then:
--
-- @
--- rrerank _ _ _ f arr :: Ranked 5 Float
+-- rrerank _ f arr :: Ranked 3 (Ranked 3 Float)
-- @
--
--- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the
--- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended
--- to return an array with shape all-0 here (it probably didn't), but there is
--- no better number to put here absent a subarray of the input to pass to @f@.
-rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
- => SNat n -> IShR n2
- -> (Ranked n1 a -> Ranked n2 b)
- -> Ranked (n + n1) a -> Ranked (n + n2) b
-rrerank sn sh2 f (rtoPrimitive -> arr) =
- rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr
+-- and the inner arrays of the result will have shape @[0, 0, 0]@. We don't
+-- know if @f@ intended to return an array with all-zero shape here (it
+-- probably didn't), but there is no better number to put here absent a
+-- subarray of the input to pass to @f@.
+rrerankPrim :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
+ => IShR n2
+ -> (Ranked n1 a -> Ranked n2 b)
+ -> Ranked n (Ranked n1 a) -> Ranked n (Ranked n2 b)
+rrerankPrim sh2 f (Ranked (M_Ranked arr)) =
+ Ranked (M_Ranked (mrerankPrim (shxFromShR sh2)
+ (\a -> let Ranked r = f (Ranked a) in r)
+ arr))
rreplicate :: forall n m a. Elt a
=> IShR n -> Ranked m a -> Ranked (n + m) a
@@ -250,29 +297,24 @@ rreplicate sh (Ranked arr)
| Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat))
= Ranked (mreplicate (shxFromShR sh) arr)
-rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
-rreplicateScalP sh x
+rreplicatePrimP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
+rreplicatePrimP sh x
| Dict <- lemKnownReplicate (shrRank sh)
- = Ranked (mreplicateScalP (shxFromShR sh) x)
+ = Ranked (mreplicatePrimP (shxFromShR sh) x)
-rreplicateScal :: forall n a. PrimElt a
+rreplicatePrim :: forall n a. PrimElt a
=> IShR n -> a -> Ranked n a
-rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
+rreplicatePrim sh x = rfromPrimitive (rreplicatePrimP sh x)
rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
-rslice i n arr
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = rlift (rrank arr)
- (\_ -> X.sliceU i n)
- arr
+rslice i n (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = Ranked (msliceN i n arr)
rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
-rrev1 arr =
- rlift (rrank arr)
- (\(_ :: StaticShX sh') ->
- case lemReplicateSucc @(Nothing @Nat) @n of
- Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))
- arr
+rrev1 (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = Ranked (mrev1 arr)
rreshape :: forall n n' a. Elt a
=> IShR n' -> Ranked n a -> Ranked n' a
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
index f50f671..834e139 100644
--- a/src/Data/Array/Nested/Ranked/Base.hs
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -5,6 +5,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -25,6 +26,7 @@ import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
+import Data.Type.Equality
import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
@@ -34,13 +36,13 @@ import GHC.TypeLits
import Data.Foldable (toList)
#endif
-import Data.Array.Mixed.Lemmas
-import Data.Array.Nested.Types
-import Data.Array.XArray (XArray(..))
+import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray(..))
-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
@@ -95,13 +97,14 @@ instance Elt a => Elt (Ranked n a) where
mscalar (Ranked x) = M_Ranked (M_Nest ZSX x)
- mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a)
- mfromListOuter l = M_Ranked (mfromListOuter (coerce l))
+ mfromListOuterSN :: SNat m -> NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Just m : sh) (Ranked n a)
+ mfromListOuterSN sn l = M_Ranked (mfromListOuterSN sn (coerce l))
mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
mtoListOuter (M_Ranked arr) =
coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr)
+ {-# INLINE mlift #-}
mlift :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
@@ -110,6 +113,7 @@ instance Elt a => Elt (Ranked n a) where
coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
mlift ssh2 f arr
+ {-# INLINE mlift2 #-}
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
@@ -118,6 +122,7 @@ instance Elt a => Elt (Ranked n a) where
coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
mlift2 ssh3 f arr1 arr2
+ {-# INLINE mliftL #-}
mliftL :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
@@ -141,24 +146,25 @@ instance Elt a => Elt (Ranked n a) where
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
- mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+ mshapeTreeIsEmpty _ (sh, t) = shrSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
marrayStrides (M_Ranked arr) = marrayStrides arr
- mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
- mvecsWrite sh idx (Ranked arr) vecs =
- mvecsWrite sh idx arr
+ mvecsWriteLinear :: forall sh s. Int -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
+ mvecsWriteLinear idx (Ranked arr) vecs =
+ mvecsWriteLinear idx arr
(coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
vecs)
- mvecsWritePartial :: forall sh sh' s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
- -> MixedVecs s (sh ++ sh') (Ranked n a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs =
- mvecsWritePartial sh idx
+ mvecsWritePartialLinear
+ :: forall sh sh' s.
+ Proxy sh -> Int -> Mixed sh' (Ranked n a)
+ -> MixedVecs s (sh ++ sh') (Ranked n a)
+ -> ST s ()
+ mvecsWritePartialLinear proxy idx arr vecs =
+ mvecsWritePartialLinear proxy idx
(coerce @(Mixed sh' (Ranked n a))
@(Mixed sh' (Mixed (Replicate n Nothing) a))
arr)
@@ -174,18 +180,30 @@ instance Elt a => Elt (Ranked n a) where
(coerce @(MixedVecs s sh (Ranked n a))
@(MixedVecs s sh (Mixed (Replicate n Nothing) a))
vecs)
+ mvecsUnsafeFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
+ mvecsUnsafeFreeze sh vecs =
+ coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
+ @(Mixed sh (Ranked n a))
+ <$> mvecsUnsafeFreeze sh
+ (coerce @(MixedVecs s sh (Ranked n a))
+ @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a)
- memptyArrayUnsafe i
+ memptyArrayUnsafe sh
| Dict <- lemKnownReplicate (SNat @n)
= coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
- memptyArrayUnsafe i
+ memptyArrayUnsafe sh
mvecsUnsafeNew idx (Ranked arr)
| Dict <- lemKnownReplicate (SNat @n)
= MV_Ranked <$> mvecsUnsafeNew idx arr
+ mvecsReplicate idx (Ranked arr)
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsReplicate idx arr
+
mvecsNewEmpty _
| Dict <- lemKnownReplicate (SNat @n)
= MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
@@ -208,15 +226,15 @@ instance (NumElt a, PrimElt a) => Num (Ranked n a) where
negate = liftRanked1 negate
abs = liftRanked1 abs
signum = liftRanked1 signum
- fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal"
+ fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicatePrim"
instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where
- fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal"
+ fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicatePrim"
recip = liftRanked1 recip
(/) = liftRanked2 (/)
instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where
- pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal"
+ pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicatePrim"
exp = liftRanked1 exp
log = liftRanked1 log
sqrt = liftRanked1 sqrt
@@ -252,3 +270,15 @@ rshape (Ranked arr) = shrFromShX2 (mshape arr)
rrank :: Elt a => Ranked n a -> SNat n
rrank = shrRank . rshape
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+shrFromShX :: forall sh. IShX sh -> IShR (Rank sh)
+shrFromShX ZSX = ZSR
+shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+-- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'.
+shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n
+shrFromShX2 sh
+ | Refl <- lemRankReplicate (Proxy @n)
+ = shrFromShX sh
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index c0c4f17..b6bee2e 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -1,13 +1,14 @@
+{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveFoldable #-}
-{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
@@ -18,9 +19,11 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
@@ -33,17 +36,20 @@ import Data.Foldable qualified as Foldable
import Data.Kind (Type)
import Data.Proxy
import Data.Type.Equality
+import GHC.Exts (Int(..), Int#, build, quotRemInt#)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
import GHC.TypeNats qualified as TN
-import Data.Array.Mixed.Lemmas
-import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed.Shape.Internal
import Data.Array.Nested.Types
+-- * Ranked lists
+
type role ListR nominal representational
type ListR :: Nat -> Type -> Type
data ListR n i where
@@ -51,8 +57,6 @@ data ListR n i where
(:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
deriving instance Eq i => Eq (ListR n i)
deriving instance Ord i => Ord (ListR n i)
-deriving instance Functor (ListR n)
-deriving instance Foldable (ListR n)
infixr 3 :::
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
@@ -66,6 +70,22 @@ instance NFData i => NFData (ListR n i) where
rnf ZR = ()
rnf (x ::: l) = rnf x `seq` rnf l
+instance Functor (ListR n) where
+ {-# INLINE fmap #-}
+ fmap _ ZR = ZR
+ fmap f (x ::: xs) = f x ::: fmap f xs
+
+instance Foldable (ListR n) where
+ {-# INLINE foldMap #-}
+ foldMap _ ZR = mempty
+ foldMap f (x ::: xs) = f x <> foldMap f xs
+ {-# INLINE foldr #-}
+ foldr _ z ZR = z
+ foldr f z (x ::: xs) = f x (foldr f z xs)
+ toList = listrToList
+ null ZR = False
+ null _ = True
+
data UnconsListRRes i n1 =
forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
@@ -90,6 +110,7 @@ listrEqual (i ::: sh) (j ::: sh')
= Just Refl
listrEqual _ _ = Nothing
+{-# INLINE listrShow #-}
listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS
listrShow f l = showString "[" . go "" l . showString "]"
where
@@ -108,27 +129,41 @@ listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
listrAppend ZR sh = sh
listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
-listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
-listrFromList [] k = k ZR
-listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
+listrFromList :: SNat n -> [i] -> ListR n i
+listrFromList topsn topl = go topsn topl
+ where
+ go :: SNat n' -> [i] -> ListR n' i
+ go SZ [] = ZR
+ go (SS n) (i : is) = i ::: go n is
+ go _ _ = error $ "listrFromList: Mismatched list length (type says "
+ ++ show (fromSNat topsn) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+
+{-# INLINEABLE listrToList #-}
+listrToList :: ListR n i -> [i]
+listrToList list = build (\(cons :: i -> is -> is) (nil :: is) ->
+ let go :: ListR n i -> is
+ go ZR = nil
+ go (i ::: is) = i `cons` go is
+ in go list)
listrHead :: ListR (n + 1) i -> i
listrHead (i ::: _) = i
-listrHead ZR = error "unreachable"
listrTail :: ListR (n + 1) i -> ListR n i
listrTail (_ ::: sh) = sh
-listrTail ZR = error "unreachable"
listrInit :: ListR (n + 1) i -> ListR n i
listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh
listrInit (_ ::: ZR) = ZR
-listrInit ZR = error "unreachable"
listrLast :: ListR (n + 1) i -> i
listrLast (_ ::: sh@(_ ::: _)) = listrLast sh
listrLast (n ::: ZR) = n
-listrLast ZR = error "unreachable"
+
+-- | Performs a runtime check that the lengths are identical.
+listrCast :: SNat n' -> ListR n i -> ListR n' i
+listrCast = listrCastWithName "listrCast"
listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
listrIndex SZ (x ::: _) = x
@@ -140,6 +175,7 @@ listrZip ZR ZR = ZR
listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest
listrZip _ _ = error "listrZip: impossible pattern needlessly required"
+{-# INLINE listrZipWith #-}
listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k
listrZipWith _ ZR ZR = ZR
listrZipWith f (i ::: irest) (j ::: jrest) =
@@ -149,13 +185,15 @@ listrZipWith _ _ _ =
listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
listrPermutePrefix = \perm sh ->
- listrFromList perm $ \sperm ->
- case (listrRank sperm, listrRank sh) of
- (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
- LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
- ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
+ case listrRank sh of { shlen@SNat ->
+ let sperm = listrFromList permlen perm in
+ case cmpNat permlen shlen of
+ LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
+ ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ }
where
listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
listrSplitAt SZ sh = (ZR, sh)
@@ -172,6 +210,8 @@ listrPermutePrefix = \perm sh ->
GTI -> error "listrPermutePrefix: Index in permutation out of range"
+-- * Ranked indices
+
-- | An index into a rank-typed array.
type role IxR nominal representational
type IxR :: Nat -> Type -> Type
@@ -192,6 +232,8 @@ infixr 3 :.:
{-# COMPLETE ZIR, (:.:) #-}
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
type IIxR n = IxR n Int
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
@@ -213,15 +255,12 @@ ixrZero :: SNat n -> IIxR n
ixrZero SZ = ZIR
ixrZero (SS n) = 0 :.: ixrZero n
-ixrFromIxX :: IxX sh i -> IxR (Rank sh) i
-ixrFromIxX ZIX = ZIR
-ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx
+ixrFromList :: forall n i. SNat n -> [i] -> IxR n i
+ixrFromList = coerce (listrFromList @_ @i)
-ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i
-ixxFromIxR ZIR = ZIX
-ixxFromIxR (n :.: (idx :: IxR m i)) =
- castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m))
- (n :.% ixxFromIxR idx)
+{-# INLINEABLE ixrToList #-}
+ixrToList :: forall n i. IxR n i -> [i]
+ixrToList = coerce (listrToList @_ @i)
ixrHead :: IxR (n + 1) i -> i
ixrHead (IxR list) = listrHead list
@@ -235,18 +274,38 @@ ixrInit (IxR list) = IxR (listrInit list)
ixrLast :: IxR (n + 1) i -> i
ixrLast (IxR list) = listrLast list
+-- | Performs a runtime check that the lengths are identical.
+ixrCast :: SNat n' -> IxR n i -> IxR n' i
+ixrCast n (IxR idx) = IxR (listrCastWithName "ixrCast" n idx)
+
ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
ixrAppend = coerce (listrAppend @_ @i)
ixrZip :: IxR n i -> IxR n j -> IxR n (i, j)
ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2
+{-# INLINE ixrZipWith #-}
ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
ixrPermutePrefix = coerce (listrPermutePrefix @i)
+-- | Given a multidimensional index, get the corresponding linear
+-- index into the buffer.
+{-# INLINEABLE ixrToLinear #-}
+ixrToLinear :: Num i => IShR m -> IxR m i -> i
+ixrToLinear = \sh i -> go sh i 0
+ where
+ -- Additional argument: index, in the @m - m1@ dimensional array so far,
+ -- of the @m - m1 + n@ dimensional tensor pointed to by the current
+ -- @m - m1@ dimensional index prefix.
+ go :: Num i => IShR m1 -> IxR m1 i -> i -> i
+ go ZSR ZIR a = a
+ go (n :$: sh) (i :.: ix) a = go sh ix (fromIntegral n * a + i)
+
+
+-- * Ranked shapes
type role ShR nominal representational
type ShR :: Nat -> Type -> Type
@@ -278,22 +337,6 @@ instance Show i => Show (ShR n i) where
instance NFData i => NFData (ShR sh i)
-shrFromShX :: forall sh. IShX sh -> IShR (Rank sh)
-shrFromShX ZSX = ZSR
-shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx
-
--- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'.
-shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n
-shrFromShX2 sh
- | Refl <- lemRankReplicate (Proxy @n)
- = shrFromShX sh
-
-shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i
-shxFromShR ZSR = ZSX
-shxFromShR (n :$: (idx :: ShR m i)) =
- castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m))
- (SUnknown n :$% shxFromShR idx)
-
-- | This checks only whether the ranks are equal, not whether the actual
-- values are.
shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
@@ -317,6 +360,13 @@ shrSize :: IShR n -> Int
shrSize ZSR = 1
shrSize (n :$: sh) = n * shrSize sh
+shrFromList :: forall n i. SNat n -> [i] -> ShR n i
+shrFromList = coerce (listrFromList @_ @i)
+
+{-# INLINEABLE shrToList #-}
+shrToList :: forall n i. ShR n i -> [i]
+shrToList = coerce (listrToList @_ @i)
+
shrHead :: ShR (n + 1) i -> i
shrHead (ShR list) = listrHead list
@@ -329,30 +379,44 @@ shrInit (ShR list) = ShR (listrInit list)
shrLast :: ShR (n + 1) i -> i
shrLast (ShR list) = listrLast list
+-- | Performs a runtime check that the lengths are identical.
+shrCast :: SNat n' -> ShR n i -> ShR n' i
+shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh)
+
shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
shrAppend = coerce (listrAppend @_ @i)
shrZip :: ShR n i -> ShR n j -> ShR n (i, j)
shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2
+{-# INLINE shrZipWith #-}
shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k
shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2
shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
shrPermutePrefix = coerce (listrPermutePrefix @i)
+shrEnum :: IShR sh -> [IIxR sh]
+shrEnum = shrEnum'
+
+{-# INLINABLE shrEnum' #-} -- ensure this can be specialised at use site
+shrEnum' :: Num i => IShR sh -> [IxR sh i]
+shrEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shrSize sh - 1]]
+ where
+ suffixes = drop 1 (scanr (*) 1 (shrToList sh))
+
+ fromLin :: Num i => IShR sh -> [Int] -> Int# -> IxR sh i
+ fromLin ZSR _ _ = ZIR
+ fromLin (_ :$: sh') (I# suff# : suffs) i# =
+ let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh'
+ in fromIntegral (I# q#) :.: fromLin sh' suffs r#
+ fromLin _ _ _ = error "impossible"
+
-- | Untyped: length is checked at runtime.
instance KnownNat n => IsList (ListR n i) where
type Item (ListR n i) = i
- fromList topl = go (SNat @n) topl
- where
- go :: SNat n' -> [i] -> ListR n' i
- go SZ [] = ZR
- go (SS n) (i : is) = i ::: go n is
- go _ _ = error $ "IsList(ListR): Mismatched list length (type says "
- ++ show (fromSNat (SNat @n)) ++ ", list has length "
- ++ show (length topl) ++ ")"
+ fromList = listrFromList (SNat @n)
toList = Foldable.toList
-- | Untyped: length is checked at runtime.
@@ -366,3 +430,14 @@ instance KnownNat n => IsList (ShR n i) where
type Item (ShR n i) = i
fromList = ShR . IsList.fromList
toList = Foldable.toList
+
+
+-- * Internal helper functions
+
+listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i
+listrCastWithName _ SZ ZR = ZR
+listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx
+listrCastWithName name _ _ = error $ name ++ ": ranks don't match"
+
+$(ixFromLinearStub "ixrFromLinear" [t| IShR |] [t| IxR |] [p| ZSR |] (\a b -> [p| $a :$: $b |]) [| ZIR |] [| (:.:) |] [| shrToList |])
+{-# INLINEABLE ixrFromLinear #-}
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 7e38aee..acb7c89 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -29,21 +29,20 @@ import Data.Vector.Storable qualified as VS
import Foreign.Storable (Storable)
import GHC.TypeLits
-import Data.Array.Mixed.Lemmas
-import Data.Array.Nested.Permutation
-import Data.Array.Nested.Types
-import Data.Array.XArray (XArray)
-import Data.Array.XArray qualified as X
-import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Convert
+import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
import Data.Array.Nested.Shaped.Base
import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray)
+import Data.Array.XArray qualified as X
-semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a
+semptyArray :: forall sh a. KnownElt a => ShS sh -> Shaped (0 : sh) a
semptyArray sh = Shaped (memptyArray (shxFromShS sh))
srank :: Elt a => Shaped sh a -> SNat (Rank sh)
@@ -53,13 +52,16 @@ srank = shsRank . sshape
ssize :: Elt a => Shaped sh a -> Int
ssize = shsSize . sshape
+{-# INLINEABLE sindex #-}
sindex :: Elt a => Shaped sh a -> IIxS sh -> a
sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx)
-shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh
+{-# INLINEABLE shsTakeIx #-}
+shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IxS sh i -> ShS sh
shsTakeIx _ _ ZIS = ZSS
shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx
+{-# INLINEABLE sindexPartial #-}
sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
sindexPartial sarr@(Shaped arr) idx =
Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
@@ -71,7 +73,16 @@ sindexPartial sarr@(Shaped arr) idx =
sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh))
+-- | See 'mgeneratePrim'.
+{-# INLINE sgeneratePrim #-}
+sgeneratePrim :: forall sh a i. (PrimElt a, Num i)
+ => ShS sh -> (IxS sh i -> a) -> Shaped sh a
+sgeneratePrim sh f =
+ let g i = f (ixsFromLinear sh i)
+ in sfromVector sh $ VS.generate (shsSize sh) g
+
-- | See the documentation of 'mlift'.
+{-# INLINE slift #-}
slift :: forall sh1 sh2 a. Elt a
=> ShS sh2
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
@@ -79,19 +90,23 @@ slift :: forall sh1 sh2 a. Elt a
slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShX (shxFromShS sh2)) f arr)
-- | See the documentation of 'mlift'.
+{-# INLINE slift2 #-}
slift2 :: forall sh1 sh2 sh3 a. Elt a
=> ShS sh3
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)
-> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2)
-ssumOuter1P :: forall sh n a. (Storable a, NumElt a)
- => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
-ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr)
+ssumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
+ => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
+ssumOuter1PrimP (Shaped arr) = Shaped (msumOuter1PrimP arr)
-ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
- => Shaped (n : sh) a -> Shaped sh a
-ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive
+ssumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a)
+ => Shaped (n : sh) a -> Shaped sh a
+ssumOuter1Prim = sfromPrimitive . ssumOuter1PrimP . stoPrimitive
+
+ssumAllPrimP :: (PrimElt a, NumElt a) => Shaped n (Primitive a) -> a
+ssumAllPrimP (Shaped arr) = msumAllPrimP arr
ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a
ssumAllPrim (Shaped arr) = msumAllPrim arr
@@ -124,39 +139,54 @@ stoVectorP = coerce mtoVectorP
stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
stoVector = coerce mtoVector
+-- | All arrays in the list, even subarrays inside @a@, must have the same
+-- shape; if they do not, a runtime error will be thrown. See the
+-- documentation of 'mgenerate' for more information about this restriction.
+--
+-- Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'sfromListOuterSN' to be able to stream the list.
+--
+-- If your array is 1-dimensional and contains scalars, use 'sfromList1Prim'.
sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
-sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l))
+sfromListOuter = coerce mfromListOuterSN
+-- | Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'sfromList1SN' to be able to stream the list.
+--
+-- If the elements are scalars, 'sfromList1Prim' is faster.
sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
-sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1
-
-sfromList1Prim :: PrimElt a => SNat n -> [a] -> Shaped '[n] a
-sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim
+sfromList1 = coerce mfromList1SN
-stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
-stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
+-- | If the elements are scalars, 'sfromListPrimLinear' is faster.
+sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a
+sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l)
-stoList1 :: Elt a => Shaped '[n] a -> [a]
-stoList1 = map sunScalar . stoListOuter
+-- | Because the length of the list is unknown, its spine must be materialised
+-- in memory in order to compute its length. If its length is already known,
+-- use 'sfromList1PrimN' to be able to stream the list.
+sfromList1Prim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
+sfromList1Prim = coerce mfromList1PrimSN
-sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
-sfromListPrim sn l
- | Refl <- lemAppNil @'[Just n]
- = let ssh = SUnknown () :!% ZKX
- xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l)
- in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr
+sfromListPrimLinear :: forall sh a. PrimElt a => ShS sh -> [a] -> Shaped sh a
+sfromListPrimLinear sh l = Shaped (mfromListPrimLinear (shxFromShS sh) l)
-sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a
-sfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
- in Shaped $ fromPrimitive $ M_Primitive (shxFromShS sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShS sh) xarr)
+stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
+stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
-sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a
-sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l)
+stoList :: Elt a => Shaped '[n] a -> [a]
+stoList = map sunScalar . stoListOuter
stoListLinear :: Elt a => Shaped sh a -> [a]
stoListLinear (Shaped arr) = mtoListLinear arr
+stoListPrim :: PrimElt a => Shaped '[n] a -> [a]
+stoListPrim (Shaped arr) = mtoListPrim arr
+
+stoListPrimLinear :: PrimElt a => Shaped sh a -> [a]
+stoListPrimLinear (Shaped arr) = mtoListPrimLinear arr
+
sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a
sfromOrthotope sh (SS.A (SG.A arr)) =
Shaped (fromPrimitive (M_Primitive (shxFromShS sh) (X.XArray (RS.A (RG.A (shsToList sh) arr)))))
@@ -177,41 +207,41 @@ sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr)))
| Refl <- lemMapJustApp (sshape sarr) (Proxy @sh')
= Shaped arr
-szip :: Shaped sh a -> Shaped sh b -> Shaped sh (a, b)
+szip :: (Elt a, Elt b) => Shaped sh a -> Shaped sh b -> Shaped sh (a, b)
szip = coerce mzip
sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b)
sunzip = coerce munzip
-srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
- => ShS sh -> ShS sh2
- -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))
- -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b)
-srerankP sh sh2 f sarr@(Shaped arr)
- | Refl <- lemMapJustApp sh (Proxy @sh1)
- , Refl <- lemMapJustApp sh (Proxy @sh2)
- = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (shxFromShS (sshape sarr)) (ssxFromShX (shxFromShS sh))))
- (shxFromShS sh2)
- (\a -> let Shaped r = f (Shaped a) in r)
- arr)
+srerankPrimP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
+ => ShS sh2
+ -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))
+ -> Shaped sh (Shaped sh1 (Primitive a)) -> Shaped sh (Shaped sh2 (Primitive b))
+srerankPrimP sh2 f (Shaped (M_Shaped arr))
+ = Shaped (M_Shaped (mrerankPrimP (shxFromShS sh2)
+ (\a -> let Shaped r = f (Shaped a) in r)
+ arr))
-srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
- => ShS sh -> ShS sh2
- -> (Shaped sh1 a -> Shaped sh2 b)
- -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b
-srerank sh sh2 f (stoPrimitive -> arr) =
- sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr
+-- | See the caveats at 'mrerankPrim'.
+srerankPrim :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
+ => ShS sh2
+ -> (Shaped sh1 a -> Shaped sh2 b)
+ -> Shaped sh (Shaped sh1 a) -> Shaped sh (Shaped sh2 b)
+srerankPrim sh2 f (Shaped (M_Shaped arr)) =
+ Shaped (M_Shaped (mrerankPrim (shxFromShS sh2)
+ (\a -> let Shaped r = f (Shaped a) in r)
+ arr))
sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a
sreplicate sh (Shaped arr)
| Refl <- lemMapJustApp sh (Proxy @sh')
= Shaped (mreplicate (shxFromShS sh) arr)
-sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
-sreplicateScalP sh x = Shaped (mreplicateScalP (shxFromShS sh) x)
+sreplicatePrimP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
+sreplicatePrimP sh x = Shaped (mreplicatePrimP (shxFromShS sh) x)
-sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a
-sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x)
+sreplicatePrim :: forall sh a. PrimElt a => ShS sh -> a -> Shaped sh a
+sreplicatePrim sh x = sfromPrimitive (sreplicatePrimP sh x)
sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
sslice i n@SNat arr =
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index 529ac21..16c1b05 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -5,6 +5,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -25,18 +26,19 @@ import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
+import Data.Type.Equality
import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import GHC.TypeLits
-import Data.Array.Nested.Types
-import Data.Array.XArray (XArray)
-import Data.Array.Nested.Internal.Lemmas
+import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray)
-- | A shape-typed array: the full shape of the array (the sizes of its
@@ -88,13 +90,14 @@ instance Elt a => Elt (Shaped sh a) where
mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)
- mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
- mfromListOuter l = M_Shaped (mfromListOuter (coerce l))
+ mfromListOuterSN :: SNat n -> NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Just n : sh') (Shaped sh a)
+ mfromListOuterSN sn l = M_Shaped (mfromListOuterSN sn (coerce l))
mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
mtoListOuter (M_Shaped arr)
= coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)
+ {-# INLINE mlift #-}
mlift :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
@@ -103,6 +106,7 @@ instance Elt a => Elt (Shaped sh a) where
coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
mlift ssh2 f arr
+ {-# INLINE mlift2 #-}
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
@@ -111,6 +115,7 @@ instance Elt a => Elt (Shaped sh a) where
coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
mlift2 ssh3 f arr1 arr2
+ {-# INLINE mliftL #-}
mliftL :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
@@ -134,24 +139,25 @@ instance Elt a => Elt (Shaped sh a) where
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
- mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+ mshapeTreeIsEmpty _ (sh, t) = shsSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
marrayStrides (M_Shaped arr) = marrayStrides arr
- mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
- mvecsWrite sh idx (Shaped arr) vecs =
- mvecsWrite sh idx arr
+ mvecsWriteLinear :: forall sh' s. Int -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
+ mvecsWriteLinear idx (Shaped arr) vecs =
+ mvecsWriteLinear idx arr
(coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
vecs)
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
- -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs =
- mvecsWritePartial sh idx
+ mvecsWritePartialLinear
+ :: forall sh1 sh2 s.
+ Proxy sh1 -> Int -> Mixed sh2 (Shaped sh a)
+ -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
+ -> ST s ()
+ mvecsWritePartialLinear proxy idx arr vecs =
+ mvecsWritePartialLinear proxy idx
(coerce @(Mixed sh2 (Shaped sh a))
@(Mixed sh2 (Mixed (MapJust sh) a))
arr)
@@ -167,18 +173,30 @@ instance Elt a => Elt (Shaped sh a) where
(coerce @(MixedVecs s sh' (Shaped sh a))
@(MixedVecs s sh' (Mixed (MapJust sh) a))
vecs)
+ mvecsUnsafeFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
+ mvecsUnsafeFreeze sh vecs =
+ coerce @(Mixed sh' (Mixed (MapJust sh) a))
+ @(Mixed sh' (Shaped sh a))
+ <$> mvecsUnsafeFreeze sh
+ (coerce @(MixedVecs s sh' (Shaped sh a))
+ @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
- memptyArrayUnsafe i
+ memptyArrayUnsafe sh
| Dict <- lemKnownMapJust (Proxy @sh)
= coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
- memptyArrayUnsafe i
+ memptyArrayUnsafe sh
mvecsUnsafeNew idx (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
= MV_Shaped <$> mvecsUnsafeNew idx arr
+ mvecsReplicate idx (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsReplicate idx arr
+
mvecsNewEmpty _
| Dict <- lemKnownMapJust (Proxy @sh)
= MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
@@ -201,15 +219,15 @@ instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
negate = liftShaped1 negate
abs = liftShaped1 abs
signum = liftShaped1 signum
- fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal"
+ fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicatePrim"
instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where
- fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"
+ fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicatePrim"
recip = liftShaped1 recip
(/) = liftShaped2 (/)
instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where
- pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicatePrim"
exp = liftShaped1 exp
log = liftShaped1 log
sqrt = liftShaped1 sqrt
@@ -242,3 +260,12 @@ satan2Array = liftShaped2 matan2Array
sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
sshape (Shaped arr) = shsFromShX (mshape arr)
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh
+shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
+shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) =
+ castWith (subst1 (sym (lemMapJustCons Refl))) $
+ n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
+ idx)
+shsFromShX (SUnknown _ :$% _) = error "impossible"
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index 0b7d1c9..0c042b7 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -1,13 +1,12 @@
+{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveFoldable #-}
-{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
-{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
@@ -18,9 +17,11 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
@@ -37,17 +38,22 @@ import Data.Kind (Constraint, Type)
import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Equality
-import GHC.Exts (withDict)
+import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Mixed.Shape.Internal
import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
+-- * Shaped lists
+
+-- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be
+-- removed in a future release.
type role ListS nominal representational
type ListS :: [Nat] -> (Nat -> Type) -> Type
data ListS sh f where
@@ -98,13 +104,15 @@ listsEqual (n ::$ sh) (m ::$ sh')
= Just Refl
listsEqual _ _ = Nothing
+{-# INLINE listsFmap #-}
listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
listsFmap _ ZS = ZS
listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs
-listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
-listsFold _ ZS = mempty
-listsFold f (x ::$ xs) = f x <> listsFold f xs
+{-# INLINE listsFoldMap #-}
+listsFoldMap :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
+listsFoldMap _ ZS = mempty
+listsFoldMap f (x ::$ xs) = f x <> listsFoldMap f xs
listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
listsShow f l = showString "[" . go "" l . showString "]"
@@ -114,15 +122,40 @@ listsShow f l = showString "[" . go "" l . showString "]"
go prefix (x ::$ xs) = showString prefix . f x . go "," xs
listsLength :: ListS sh f -> Int
-listsLength = getSum . listsFold (\_ -> Sum 1)
+listsLength = getSum . listsFoldMap (\_ -> Sum 1)
listsRank :: ListS sh f -> SNat (Rank sh)
listsRank ZS = SNat
listsRank (_ ::$ sh) = snatSucc (listsRank sh)
+listsFromList :: ShS sh -> [i] -> ListS sh (Const i)
+listsFromList topsh topl = go topsh topl
+ where
+ go :: ShS sh' -> [i] -> ListS sh' (Const i)
+ go ZSS [] = ZS
+ go (_ :$$ sh) (i : is) = Const i ::$ go sh is
+ go _ _ = error $ "listsFromList: Mismatched list length (type says "
+ ++ show (shsLength topsh) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+
+{-# INLINEABLE listsFromListS #-}
+listsFromListS :: ListS sh (Const i0) -> [i] -> ListS sh (Const i)
+listsFromListS topl0 topl = go topl0 topl
+ where
+ go :: ListS sh (Const i0) -> [i] -> ListS sh (Const i)
+ go ZS [] = ZS
+ go (_ ::$ l0) (i : is) = Const i ::$ go l0 is
+ go _ _ = error $ "listsFromListS: Mismatched list length (the model says "
+ ++ show (listsLength topl0) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+
+{-# INLINEABLE listsToList #-}
listsToList :: ListS sh (Const i) -> [i]
-listsToList ZS = []
-listsToList (Const i ::$ is) = i : listsToList is
+listsToList list = build (\(cons :: i -> is -> is) (nil :: is) ->
+ let go :: ListS sh (Const i) -> is
+ go ZS = nil
+ go (Const i ::$ is) = i `cons` go is
+ in go list)
listsHead :: ListS (n : sh) f -> f n
listsHead (i ::$ _) = i
@@ -144,14 +177,13 @@ listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g)
listsZip ZS ZS = ZS
-listsZip (i ::$ is) (j ::$ js) =
- Fun.Pair i j ::$ listsZip is js
+listsZip (i ::$ is) (j ::$ js) = Fun.Pair i j ::$ listsZip is js
+{-# INLINE listsZipWith #-}
listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g
-> ListS sh h
listsZipWith _ ZS ZS = ZS
-listsZipWith f (i ::$ is) (j ::$ js) =
- f i j ::$ listsZipWith f is js
+listsZipWith f (i ::$ is) (j ::$ js) = f i j ::$ listsZipWith f is js
listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
listsTakeLenPerm PNil _ = ZS
@@ -180,11 +212,9 @@ listsIndex _ _ _ ZS = error "Index into empty shape"
listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh)
+-- * Shaped indices
-- | An index into a shape-typed array.
---
--- For convenience, this contains regular 'Int's instead of bounded integers
--- (traditionally called \"@Fin@\").
type role IxS nominal representational
type IxS :: [Nat] -> Type -> Type
newtype IxS sh i = IxS (ListS sh (Const i))
@@ -193,6 +223,8 @@ newtype IxS sh i = IxS (ListS sh (Const i))
pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
pattern ZIS = IxS ZS
+-- | Note: The 'KnownNat' constraint on '(:.$)' is deprecated and should be
+-- removed in a future release.
pattern (:.$)
:: forall {sh1} {i}.
forall n sh. (KnownNat n, n : sh ~ sh1)
@@ -203,6 +235,8 @@ infixr 3 :.$
{-# COMPLETE ZIS, (:.$) #-}
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
type IIxS sh = IxS sh Int
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
@@ -213,10 +247,18 @@ instance Show i => Show (IxS sh i) where
#endif
instance Functor (IxS sh) where
+ {-# INLINE fmap #-}
fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l)
instance Foldable (IxS sh) where
- foldMap f (IxS l) = listsFold (f . getConst) l
+ {-# INLINE foldMap #-}
+ foldMap f (IxS l) = listsFoldMap (f . getConst) l
+ {-# INLINE foldr #-}
+ foldr _ z ZIS = z
+ foldr f z (x :.$ xs) = f x (foldr f z xs)
+ toList = ixsToList
+ null ZIS = False
+ null _ = True
instance NFData i => NFData (IxS sh i)
@@ -226,18 +268,21 @@ ixsLength (IxS l) = listsLength l
ixsRank :: IxS sh i -> SNat (Rank sh)
ixsRank (IxS l) = listsRank l
+ixsFromList :: forall sh i. ShS sh -> [i] -> IxS sh i
+ixsFromList = coerce (listsFromList @_ @i)
+
+{-# INLINEABLE ixsFromIxS #-}
+ixsFromIxS :: forall sh i0 i. IxS sh i0 -> [i] -> IxS sh i
+ixsFromIxS = coerce (listsFromListS @_ @i0 @i)
+
+{-# INLINEABLE ixsToList #-}
+ixsToList :: forall sh i. IxS sh i -> [i]
+ixsToList = coerce (listsToList @_ @i)
+
ixsZero :: ShS sh -> IIxS sh
ixsZero ZSS = ZIS
ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
-ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i
-ixsFromIxX ZSS ZIX = ZIS
-ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx
-
-ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i
-ixxFromIxS ZIS = ZIX
-ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh
-
ixsHead :: IxS (n : sh) i -> i
ixsHead (IxS list) = getConst (listsHead list)
@@ -250,20 +295,39 @@ ixsInit (IxS list) = IxS (listsInit list)
ixsLast :: IxS (n : sh) i -> i
ixsLast (IxS list) = getConst (listsLast list)
+-- TODO: this takes a ShS because there are KnownNats inside IxS.
+ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i
+ixsCast ZSS ZIS = ZIS
+ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx
+ixsCast _ _ = error "ixsCast: ranks don't match"
+
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
ixsAppend = coerce (listsAppend @_ @(Const i))
-ixsZip :: IxS n i -> IxS n j -> IxS n (i, j)
+ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j)
ixsZip ZIS ZIS = ZIS
ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js
-ixsZipWith :: (i -> j -> k) -> IxS n i -> IxS n j -> IxS n k
+{-# INLINE ixsZipWith #-}
+ixsZipWith :: (i -> j -> k) -> IxS sh i -> IxS sh j -> IxS sh k
ixsZipWith _ ZIS ZIS = ZIS
ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js
ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))
+-- | Given a multidimensional index, get the corresponding linear
+-- index into the buffer.
+{-# INLINEABLE ixsToLinear #-}
+ixsToLinear :: Num i => ShS sh -> IxS sh i -> i
+ixsToLinear = \sh i -> go sh i 0
+ where
+ go :: Num i => ShS sh -> IxS sh i -> i -> i
+ go ZSS ZIS a = a
+ go (n :$$ sh) (i :.$ ix) a = go sh ix (fromIntegral (fromSNat' n) * a + i)
+
+
+-- * Shaped shapes
-- | The shape of a shape-typed array given as a list of 'SNat' values.
--
@@ -272,7 +336,10 @@ ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))
type role ShS nominal
type ShS :: [Nat] -> Type
newtype ShS sh = ShS (ListS sh SNat)
- deriving (Eq, Ord, Generic)
+ deriving (Generic)
+
+instance Eq (ShS sh) where _ == _ = True
+instance Ord (ShS sh) where compare _ _ = EQ
pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
pattern ZSS = ShS ZS
@@ -317,26 +384,28 @@ shsSize :: ShS sh -> Int
shsSize ZSS = 1
shsSize (n :$$ sh) = fromSNat' n * shsSize sh
-shsToList :: ShS sh -> [Int]
-shsToList ZSS = []
-shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh
-
-shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh
-shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
-shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) =
- castWith (subst1 (lem Refl)) $
- n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
- idx)
+-- | This is a partial @const@ that fails when the second argument
+-- doesn't match the first.
+shsFromList :: ShS sh -> [Int] -> ShS sh
+shsFromList topsh topl = go topsh topl `seq` topsh
where
- lem :: forall sh1 sh' n.
- Just n : sh1 :~: MapJust sh'
- -> n : Tail sh' :~: sh'
- lem Refl = unsafeCoerceRefl
-shsFromShX (SUnknown _ :$% _) = error "impossible"
+ go :: ShS sh' -> [Int] -> ()
+ go ZSS [] = ()
+ go (sn :$$ sh) (i : is)
+ | i == fromSNat' sn = go sh is
+ | otherwise = error $ "shsFromList: Value does not match typing (type says "
+ ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
+ go _ _ = error $ "shsFromList: Mismatched list length (type says "
+ ++ show (shsLength topsh) ++ ", list has length "
+ ++ show (length topl) ++ ")"
-shxFromShS :: ShS sh -> IShX (MapJust sh)
-shxFromShS ZSS = ZSX
-shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh
+{-# INLINEABLE shsToList #-}
+shsToList :: ShS sh -> [Int]
+shsToList topsh = build (\(cons :: Int -> is -> is) (nil :: is) ->
+ let go :: ShS sh -> is
+ go ZSS = nil
+ go (sn :$$ sh) = fromSNat' sn `cons` go sh
+ in go topsh)
shsHead :: ShS (n : sh) -> SNat n
shsHead (ShS list) = listsHead list
@@ -381,7 +450,7 @@ instance KnownShS '[] where knownShS = ZSS
instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS
withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r
-withKnownShS k = withDict @(KnownShS sh) k
+withKnownShS = withDict @(KnownShS sh)
shsKnownShS :: ShS sh -> Dict KnownShS sh
shsKnownShS ZSS = Dict
@@ -391,18 +460,27 @@ shsOrthotopeShape :: ShS sh -> Dict O.Shape sh
shsOrthotopeShape ZSS = Dict
shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict
+shsEnum :: ShS sh -> [IIxS sh]
+shsEnum = shsEnum'
+
+{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site
+shsEnum' :: Num i => ShS sh -> [IxS sh i]
+shsEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shsSize sh - 1]]
+ where
+ suffixes = drop 1 (scanr (*) 1 (shsToList sh))
+
+ fromLin :: Num i => ShS sh -> [Int] -> Int# -> IxS sh i
+ fromLin ZSS _ _ = ZIS
+ fromLin (_ :$$ sh') (I# suff# : suffs) i# =
+ let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shsSize sh'
+ in fromIntegral (I# q#) :.$ fromLin sh' suffs r#
+ fromLin _ _ _ = error "impossible"
+
-- | Untyped: length is checked at runtime.
instance KnownShS sh => IsList (ListS sh (Const i)) where
type Item (ListS sh (Const i)) = i
- fromList topl = go (knownShS @sh) topl
- where
- go :: ShS sh' -> [i] -> ListS sh' (Const i)
- go ZSS [] = ZS
- go (_ :$$ sh) (i : is) = Const i ::$ go sh is
- go _ _ = error $ "IsList(ListS): Mismatched list length (type says "
- ++ show (shsLength (knownShS @sh)) ++ ", list has length "
- ++ show (length topl) ++ ")"
+ fromList = listsFromList (knownShS @sh)
toList = listsToList
-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
@@ -414,15 +492,8 @@ instance KnownShS sh => IsList (IxS sh i) where
-- | Untyped: length and values are checked at runtime.
instance KnownShS sh => IsList (ShS sh) where
type Item (ShS sh) = Int
- fromList topl = ShS (go (knownShS @sh) topl)
- where
- go :: ShS sh' -> [Int] -> ListS sh' SNat
- go ZSS [] = ZS
- go (sn :$$ sh) (i : is)
- | i == fromSNat' sn = sn ::$ go sh is
- | otherwise = error $ "IsList(ShS): Value does not match typing (type says "
- ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
- go _ _ = error $ "IsList(ShS): Mismatched list length (type says "
- ++ show (shsLength (knownShS @sh)) ++ ", list has length "
- ++ show (length topl) ++ ")"
+ fromList = shsFromList (knownShS @sh)
toList = shsToList
+
+$(ixFromLinearStub "ixsFromLinear" [t| ShS |] [t| IxS |] [p| ZSS |] (\a b -> [p| (fromSNat' -> $a) :$$ $b |]) [| ZIS |] [| (:.$) |] [| shsToList |])
+{-# INLINEABLE ixsFromLinear #-}
diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs
index 838e2b0..dfa5129 100644
--- a/src/Data/Array/Nested/Trace.hs
+++ b/src/Data/Array/Nested/Trace.hs
@@ -5,21 +5,28 @@
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
+{-# OPTIONS -Wno-simplifiable-class-constraints #-}
{-|
This module is API-compatible with "Data.Array.Nested", except that inputs and
-outputs of the methods are traced using 'Debug.Trace.trace'. Thus the methods
-also have additional 'Show' constraints.
+outputs of the methods are traced to 'stderr'. Thus the methods also have
+additional 'Show' constraints.
->>> let res = rtranspose [1, 0] (rreshape (2 :$: 3 :$: ZSR) (riota @Int 6)) * rreshape (3 :$: 2 :$: ZSR) (rreplicate (6 :$: ZSR) (rscalar @Int 7))
->>> length (show res) `seq` ()
-oxtrace: riota [Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [0,1,2,3,4,5]))))]
-oxtrace: rreshape [[2,3], Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [0,1,2,3,4,5])))), Ranked (M_Int (M_Primitive [2,3] (XArray (fromList [2,3] [0,1,2,3,4,5]))))]
-oxtrace: rtranspose [Ranked (M_Int (M_Primitive [2,3] (XArray (fromList [2,3] [0,1,2,3,4,5])))), Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [0,3,1,4,2,5]))))]
-oxtrace: rscalar [Ranked (M_Int (M_Primitive [] (XArray (fromList [] [7]))))]
-oxtrace: rreplicate [[6], Ranked (M_Int (M_Primitive [] (XArray (fromList [] [7])))), Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [7,7,7,7,7,7]))))]
-oxtrace: rreshape [[3,2], Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [7,7,7,7,7,7])))), Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [7,7,7,7,7,7]))))]
->>> res
-Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [0,21,7,28,14,35]))))
+>>> rtranspose [1, 0] (rreshape (2 :$: 3 :$: ZSR) (riota @Int 6)) * rreshape (3 :$: 2 :$: ZSR) (rreplicate (6 :$: ZSR) (rscalar @Int 7))
+oxtrace: (riota _ ... = rfromListLinear [6] [0,1,2,3,4,5])
+oxtrace: (rreshape [2,3] (rfromListLinear [6] [0,1,2,3,4,5]) ... = rfromListLinear [2,3] [0,1,2,3,4,5])
+oxtrace: (rtranspose [1,0] (rfromListLinear [2,3] [0,1,2,3,4,5]) ... = rfromListLinear [3,2] [0,3,1,4,2,5])
+oxtrace: (rscalar _ ... = rfromListLinear [] [7])
+oxtrace: (rreplicate [6] (rfromListLinear [] [7]) ... = rreplicate [6] 7)
+oxtrace: (rreshape [3,2] (rreplicate [6] 7) ... = rreplicate [3,2] 7)
+rfromListLinear [3,2] [0,21,7,28,14,35]
+
+The part up until and including the @...@ is printed after @seq@ing the
+arguments; the @=@ and further is printed after @seq@ing the result of the
+operation. Do note that tracing means that the functions in this module are
+potentially __stricter__ than the plain ones in "Data.Array.Nested".
+
+Arguments that this module does not know how to @show@, probably due to
+laziness on my side, are printed as @_@.
-}
module Data.Array.Nested.Trace (
-- * Traced variants
@@ -37,10 +44,12 @@ module Data.Array.Nested.Trace (
ShS(..), KnownShS(..),
Mixed,
+ ListX(ZX, (::%)),
IxX(..), IIxX,
- ShX(..), KnownShX(..),
+ ShX(..), KnownShX(..), IShX,
StaticShX(..),
SMayNat(..),
+ Conversion(..),
Elt,
PrimElt,
@@ -51,10 +60,10 @@ module Data.Array.Nested.Trace (
Storable,
SNat, pattern SNat,
pattern SZ, pattern SS,
- Perm(..),
+ Perm(..), PermR,
IsPermutation,
KnownPerm(..),
- NumElt, FloatElt,
+ NumElt, IntElt, FloatElt,
Rank, Product,
Replicate,
MapJust,
@@ -67,4 +76,4 @@ import Data.Array.Nested.Trace.TH
$(concat <$> mapM convertFun
- ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromListOuter, 'rfromList1, 'rfromList1Prim, 'rtoListOuter, 'rtoList1, 'rfromListLinear, 'rfromListPrimLinear, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rcastToShaped, 'rtoMixed, 'rfromOrthotope, 'rtoOrthotope, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromListOuter, 'sfromList1, 'sfromList1Prim, 'stoListOuter, 'stoList1, 'sfromListLinear, 'sfromListPrimLinear, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoRanked, 'stoMixed, 'sfromOrthotope, 'stoOrthotope, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromListOuter, 'mfromList1, 'mfromList1Prim, 'mtoListOuter, 'mtoList1, 'mfromListLinear, 'mfromListPrimLinear, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mtoRanked, 'mcastToShaped])
+ ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rgeneratePrim, 'rsumOuter1Prim, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerankPrim, 'rreplicate, 'rreplicatePrim, 'rfromListOuter, 'rfromListOuterN, 'rfromList1, 'rfromList1N, 'rfromListLinear, 'rfromList1Prim, 'rfromList1PrimN, 'rfromListPrimLinear, 'rtoListOuter, 'rtoList, 'rtoListLinear, 'rtoListPrim, 'rtoListPrimLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'sgeneratePrim, 'ssumOuter1Prim, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerankPrim, 'sreplicate, 'sreplicatePrim, 'sfromListOuter, 'sfromList1, 'sfromListLinear, 'sfromList1Prim, 'sfromListPrimLinear, 'stoListOuter, 'stoList, 'stoListLinear, 'stoListPrim, 'stoListPrimLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'mgeneratePrim, 'msumOuter1Prim, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerankPrim, 'mreplicate, 'mreplicatePrim, 'mfromListOuter, 'mfromListOuterN, 'mfromListOuterSN, 'mfromList1, 'mfromList1N, 'mfromList1SN, 'mfromListLinear, 'mfromList1Prim, 'mfromList1PrimN, 'mfromList1PrimSN, 'mfromListPrimLinear, 'mtoListOuter, 'mtoList, 'mtoListLinear, 'mtoListPrim, 'mtoListPrimLinear, 'msliceN, 'msliceSN, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array])
diff --git a/src/Data/Array/Nested/Trace/TH.hs b/src/Data/Array/Nested/Trace/TH.hs
index 4b388e3..644b4bd 100644
--- a/src/Data/Array/Nested/Trace/TH.hs
+++ b/src/Data/Array/Nested/Trace/TH.hs
@@ -4,11 +4,11 @@
module Data.Array.Nested.Trace.TH where
import Control.Monad (zipWithM)
-import Data.List (foldl', intersperse)
+import Data.List (foldl')
import Data.Maybe (isJust)
import Language.Haskell.TH hiding (cxt)
-
-import Debug.Trace qualified as Debug
+import System.IO (hPutStr, stderr)
+import System.IO.Unsafe (unsafePerformIO)
import Data.Array.Nested
@@ -20,7 +20,7 @@ splitFunTy = \case
in (vars, cx, t1 : args, ret)
ForallT vs cx' t ->
let (vars, cx, args, ret) = splitFunTy t
- in (vars ++ vs, cx ++ cx', args, ret)
+ in (vs ++ vars, cx' ++ cx, args, ret)
t -> ([], [], [], t)
data Arg = RRanked Type Arg
@@ -30,17 +30,27 @@ data Arg = RRanked Type Arg
| ROther Type
deriving (Show)
--- TODO: always returns Just
recognise :: Type -> Maybe Arg
recognise (ConT name `AppT` sht `AppT` ty)
- | name == ''Ranked = RRanked sht <$> recognise ty
- | name == ''Shaped = RShaped sht <$> recognise ty
- | name == ''Mixed = RMixed sht <$> recognise ty
+ | name == ''Ranked = Just (RRanked sht (recogniseElt ty))
+ | name == ''Shaped = Just (RShaped sht (recogniseElt ty))
+ | name == ''Mixed = Just (RMixed sht (recogniseElt ty))
+ | name == ''Conversion = Just (RShowable ty)
recognise ty@(ConT name `AppT` _)
- | name `elem` [''IShR, ''IIxR, ''ShS, ''IIxS, ''SNat] =
+ | name `elem` [''IShR, ''IIxR, ''ShS, ''IIxS, ''SNat, ''Perm] =
Just (RShowable ty)
+recognise ty@(ConT name)
+ | name == ''PermR = Just (RShowable ty)
+recognise (ListT `AppT` ty) = Just (ROther ty)
recognise _ = Nothing
+recogniseElt :: Type -> Arg
+recogniseElt (ConT name `AppT` sht `AppT` ty)
+ | name == ''Ranked = RRanked sht (recogniseElt ty)
+ | name == ''Shaped = RShaped sht (recogniseElt ty)
+ | name == ''Mixed = RMixed sht (recogniseElt ty)
+recogniseElt ty = ROther ty
+
realise :: Arg -> Type
realise (RRanked sht ty) = ConT ''Ranked `AppT` sht `AppT` realise ty
realise (RShaped sht ty) = ConT ''Shaped `AppT` sht `AppT` realise ty
@@ -62,37 +72,58 @@ mkShowElt (RMixed sht ty) = [ConT ''Show `AppT` realise (RMixed sht ty), ConT ''
mkShowElt (RShowable _ty) = [] -- [ConT ''Elt `AppT` ty]
mkShowElt (ROther ty) = [ConT ''Show `AppT` ty, ConT ''Elt `AppT` ty]
-convertType :: Type -> Q (Type, [Bool], Bool)
+-- If you pass a polymorphic function to seq, GHC wants to monomorphise and
+-- doesn't know how to instantiate the type variables. Just don't, I guess.
+isSeqable :: Type -> Bool
+isSeqable ForallT{} = False
+isSeqable (AppT a b) = isSeqable a && isSeqable b
+isSeqable _ = True -- yolo, I guess
+
+convertType :: Type -> Q (Type, [Bool], [Bool], Bool)
convertType typ =
let (tybndrs, cxt, args, ret) = splitFunTy typ
- argrels = map recognise args
- retrel = recognise ret
+ argdescrs = map recognise args
+ retdescr = recognise ret
in return
(ForallT tybndrs
(cxt ++ [constr
- | Just rel <- retrel : argrels
+ | Just rel <- retdescr : argdescrs
, constr <- mkShow rel])
(foldr (\a b -> ArrowT `AppT` a `AppT` b) ret args)
- ,map isJust argrels
- ,isJust retrel)
+ ,map isJust argdescrs
+ ,map isSeqable args
+ ,isJust retdescr)
convertFun :: Name -> Q [Dec]
convertFun funname = do
defname <- newName (nameBase funname)
- (convty, argarrs, retarr) <- reifyType funname >>= convertType
- names <- zipWithM (\b i -> newName ((if b then "t" else "x") ++ show i)) argarrs [1::Int ..]
+ -- "ok": whether we understand this type enough to be able to show it
+ (convty, argoks, argsseqable, retok) <- reifyType funname >>= convertType
+ names <- zipWithM (\_ i -> newName ('x' : show i)) argoks [1::Int ..]
+ -- let tracenames = map fst (filter snd (zip (names ++ [resname]) (argarrs ++ [retarr])))
resname <- newName "res"
- let tracenames = map fst (filter snd (zip (names ++ [resname]) (argarrs ++ [retarr])))
+ let traceCall str val = VarE 'traceNoNewline `AppE` str `AppE` val
+ let msg1 = [LitE (StringL ("oxtrace: (" ++ nameBase funname ++ " "))] ++
+ [if ok
+ then VarE 'showsPrec `AppE` LitE (IntegerL 11) `AppE` VarE n `AppE` LitE (StringL " ")
+ else LitE (StringL "_ ")
+ | (n, ok) <- zip names argoks] ++
+ [LitE (StringL "...")]
+ let msg2 | retok = [LitE (StringL " = "), VarE 'show `AppE` VarE resname, LitE (StringL ")\n")]
+ | otherwise = [LitE (StringL " = _)\n")]
let ex = LetE [ValD (VarP resname)
(NormalB (foldl' AppE (VarE funname) (map VarE names)))
- []]
- (VarE 'Debug.trace
- `AppE` (VarE 'concat `AppE` ListE
- ([LitE (StringL ("oxtrace: " ++ nameBase funname ++ " ["))] ++
- intersperse (LitE (StringL ", "))
- (map (\n -> VarE 'show `AppE` VarE n) tracenames) ++
- [LitE (StringL "]")]))
- `AppE` VarE resname)
+ []] $
+ flip (foldr AppE) [VarE 'seq `AppE` VarE n | (n, True) <- zip names argsseqable] $
+ traceCall (VarE 'concat `AppE` ListE msg1) $
+ VarE 'seq `AppE` VarE resname `AppE`
+ traceCall (VarE 'concat `AppE` ListE msg2) (VarE resname)
return
[SigD defname convty
,FunD defname [Clause (map VarP names) (NormalB ex) []]]
+
+{-# NOINLINE traceNoNewline #-}
+traceNoNewline :: String -> a -> a
+traceNoNewline str x = unsafePerformIO $ do
+ hPutStr stderr str
+ return x
diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs
index 4172fa0..1dce868 100644
--- a/src/Data/Array/Nested/Types.hs
+++ b/src/Data/Array/Nested/Types.hs
@@ -6,7 +6,7 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
@@ -30,6 +30,7 @@ module Data.Array.Nested.Types (
Replicate,
lemReplicateSucc,
MapJust,
+ lemMapJustEmpty, lemMapJustCons,
Head,
Tail,
Init,
@@ -45,7 +46,6 @@ import GHC.TypeLits
import GHC.TypeNats qualified as TN
import Unsafe.Coerce qualified
-
-- Reasoning helpers
subst1 :: forall f a b. a :~: b -> f a :~: f b
@@ -58,8 +58,9 @@ subst2 Refl = Refl
data Dict c a where
Dict :: c a => Dict c a
+{-# INLINE fromSNat' #-}
fromSNat' :: SNat n -> Int
-fromSNat' = fromIntegral . fromSNat
+fromSNat' = fromEnum . TN.fromSNat
sameNat' :: SNat n -> SNat m -> Maybe (n :~: m)
sameNat' n@SNat m@SNat = sameNat n m
@@ -108,13 +109,20 @@ type family Replicate n a where
Replicate 0 a = '[]
Replicate n a = a : Replicate (n - 1) a
-lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
-lemReplicateSucc = unsafeCoerceRefl
+lemReplicateSucc :: forall a n proxy.
+ proxy n -> (a : Replicate n a) :~: Replicate (n + 1) a
+lemReplicateSucc _ = unsafeCoerceRefl
-type family MapJust l where
+type family MapJust l = r | r -> l where
MapJust '[] = '[]
MapJust (x : xs) = Just x : MapJust xs
+lemMapJustEmpty :: MapJust sh :~: '[] -> sh :~: '[]
+lemMapJustEmpty Refl = unsafeCoerceRefl
+
+lemMapJustCons :: MapJust sh :~: Just n : sh' -> sh :~: n : Tail sh
+lemMapJustCons Refl = unsafeCoerceRefl
+
type family Head l where
Head (x : _) = x
diff --git a/src/Data/Array/Strided/Orthotope.hs b/src/Data/Array/Strided/Orthotope.hs
index 5c38d14..e2cd17c 100644
--- a/src/Data/Array/Strided/Orthotope.hs
+++ b/src/Data/Array/Strided/Orthotope.hs
@@ -24,14 +24,19 @@ fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset ve
toO :: AS.Array n a -> RS.Array n a
toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec))
+{-# INLINE liftO1 #-}
liftO1 :: (AS.Array n a -> AS.Array n' b)
-> RS.Array n a -> RS.Array n' b
liftO1 f = toO . f . fromO
+{-# INLINE liftO2 #-}
liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c)
-> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c
liftO2 f x y = toO (f (fromO x) (fromO y))
+-- We don't inline this lifting function, because its code is not just
+-- a wrapper, being relatively long and expensive.
+{-# INLINEABLE liftVEltwise1 #-}
liftVEltwise1 :: (Storable a, Storable b)
=> SNat n
-> (VS.Vector a -> VS.Vector b)
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index dde06e3..7dcfd5e 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -1,8 +1,11 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
@@ -14,27 +17,33 @@
module Data.Array.XArray where
import Control.DeepSeq (NFData)
+import Control.Monad (foldM_, foldM)
+import Control.Monad.ST
import Data.Array.Internal qualified as OI
import Data.Array.Internal.RankedG qualified as ORG
import Data.Array.Internal.RankedS qualified as ORS
-import Data.Array.Ranked qualified as ORB
import Data.Array.RankedS qualified as S
import Data.Coerce
import Data.Foldable (toList)
import Data.Kind
-import Data.List.NonEmpty (NonEmpty)
+import Data.List.NonEmpty (NonEmpty(..))
import Data.Proxy
import Data.Type.Equality
import Data.Type.Ord
+import Data.Vector.Generic.Checked qualified as VGC
import Data.Vector.Storable qualified as VS
+import Data.Vector.Storable.Mutable qualified as VSM
import Foreign.Storable (Storable)
import GHC.Generics (Generic)
import GHC.TypeLits
+#if !MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
+import Unsafe.Coerce (unsafeCoerce)
+#endif
-import Data.Array.Mixed.Lemmas
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
-import Data.Array.Nested.Mixed.Shape
import Data.Array.Strided.Orthotope
@@ -108,15 +117,23 @@ generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh)
-- XArray . S.fromVector (shxShapeL sh)
-- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh)
+{-# INLINEABLE indexPartial #-}
indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a
indexPartial (XArray arr) ZIX = XArray arr
indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx
+{- Strangely, this increases allocation and there's no noticeable speedup:
+indexPartial (XArray (ORS.A (ORG.A sh t))) ix =
+ let linear = OI.offset t + sum (zipWith (*) (ixxToList ix) (OI.strides t))
+ len = ixxLength ix
+ in XArray (ORS.A (ORG.A (drop len sh)
+ OI.T{ OI.strides = drop len (OI.strides t)
+ , OI.offset = linear
+ , OI.values = OI.values t })) -}
+{-# INLINEABLE index #-}
index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a
-index xarr i
- | Refl <- lemAppNil @sh
- = let XArray arr' = indexPartial xarr i :: XArray '[] a
- in S.unScalar arr'
+index (XArray (ORS.A (ORG.A _ t))) i =
+ OI.values t VS.! (OI.offset t + sum (zipWith (*) (toList i) (OI.strides t)))
append :: forall n m sh a. Storable a
=> StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
@@ -217,7 +234,12 @@ transpose ssh perm (XArray arr)
, Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh)
, Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm
, Refl <- lemRankDropLen ssh perm
+#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
= XArray (S.transpose (permToList' perm) arr)
+#else
+ = XArray (unsafeCoerce (S.transpose (permToList' perm) arr))
+#endif
+
-- | The list argument gives indices into the original dimension list.
--
@@ -243,14 +265,10 @@ transpose2 ssh1 ssh2 (XArray arr)
, Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
, Refl <- lemRankAppComm ssh1 ssh2
, let n1 = ssxLength ssh1
- = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
+ = XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr)
sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a
-sumFull _ (XArray arr) =
- S.unScalar $
- liftO1 (numEltSum1Inner (SNat @0)) $
- S.fromVector [product (S.shapeL arr)] $
- S.toVector arr
+sumFull ssx (XArray arr) = numEltSumFull (ssxRank ssx) $ fromO arr
sumInner :: forall sh sh' a. (Storable a, NumElt a)
=> StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a
@@ -283,33 +301,76 @@ sumOuter ssh ssh' arr
reshapePartial ssh ssh' shF $
arr
-fromListOuter :: forall n sh a. Storable a
- => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
-fromListOuter ssh l
- | Dict <- lemKnownNatRankSSX ssh
- = case ssh of
- SKnown m :!% _ | fromSNat' m /= length l ->
- error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++
- "does not match the type (" ++ show (fromSNat' m) ++ ")"
- _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))
+-- | This creates an array from a list of arrays of one less dimension.
+-- The list is streamed, its length is checked and it's verified
+-- that all arrays on the list have the same shape.
+{-# INLINE fromListOuterSN #-}
+fromListOuterSN :: forall n sh a. Storable a
+ => SNat n -> IShX sh -> NonEmpty (XArray sh a) -> XArray (Just n : sh) a
+fromListOuterSN m sh l
+ | Dict <- lemKnownNatRank sh
+ , let l' = coerce @(NonEmpty (XArray sh a)) @(NonEmpty (S.Array (Rank sh) a)) l
+ = case sh of
+ ZSX -> fromList1SN m (map S.unScalar (toList l'))
+ _ -> XArray (ravelOuterN (fromSNat' m) l')
-toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a]
-toListOuter (XArray arr) =
- case S.shapeL arr of
+-- | This checks that the list has the given length and that all shapes in the
+-- list are equal. The list is streamed.
+-- The first array in the list is forced early to potentially release some
+-- memory, before allocating the (large) new array.
+{-# INLINEABLE ravelOuterN #-}
+ravelOuterN :: (KnownNat k, Storable a)
+ => Int -> NonEmpty (S.Array k a) -> S.Array (1 + k) a
+ravelOuterN 0 _ = error "ravelOuterN: N == 0"
+ravelOuterN k as@(!a0 :| _) = runST $ do
+ let sh0 = S.shapeL a0
+ len = product sh0
+ vecSize = k * len
+ vec <- VSM.unsafeNew vecSize
+ let f !n (ORS.A (ORG.A sht t)) =
+ if | n >= k ->
+ error $ "ravelOuterN: list too long " ++ show (n, k)
+ -- if we do this check just once at the end, we may
+ -- crash instead of producing an accurate error message
+ | sht == sh0 -> do
+ let g off el = do
+ VS.unsafeCopy (VSM.slice off (VS.length el) vec) el
+ return $! off + VS.length el
+ foldM_ g (n * len) (OI.toVectorListT sht t)
+ return $! n + 1
+ | otherwise ->
+ error $ "ravelOuterN: unequal shapes " ++ show (sht, sh0)
+ nFinal <- foldM f 0 as
+ if nFinal == k
+ then S.fromVector (k : sh0) <$> VS.unsafeFreeze vec
+ else error $ "ravelOuterN: list too short " ++ show (nFinal, k)
+
+toListOuter :: forall a n sh. Storable a => XArray (n : sh) a -> [XArray sh a]
+toListOuter (XArray arr@(ORS.A (ORG.A shArr t))) =
+ case shArr of
+ [] -> error "impossible"
0 : _ -> []
- _ -> coerce (ORB.toList (S.unravel arr))
+ -- using orthotope's functions here would entail using rerank, which is slow, so we don't
+ [_] | Refl <- (unsafeCoerceRefl :: sh :~: '[]) -> coerce (map S.scalar $ S.toList arr)
+ n : sh -> coerce $ map (ORG.A sh . OI.indexT t) [0 .. n - 1]
+
+-- | Performance note: the list's spine is fully materialised to compute its
+-- length before traversing it again to construct the array.
+{-# INLINE fromList1 #-}
+fromList1 :: Storable a => [a] -> XArray '[Nothing] a
+fromList1 l =
+ let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself
+ in XArray (S.fromVector [n] (VS.fromListN n l))
-fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a
-fromList1 ssh l =
- let n = length l
- in case ssh of
- SKnown m :!% _ | fromSNat' m /= n ->
- error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++
- "does not match the type (" ++ show (fromSNat' m) ++ ")"
- _ -> XArray (S.fromVector [n] (VS.fromListN n l))
+-- | The list is streamed.
+{-# INLINE fromList1SN #-}
+fromList1SN :: Storable a => SNat n -> [a] -> XArray '[Just n] a
+fromList1SN m l =
+ let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed
+ in XArray (S.fromVector [n] (VGC.fromListNChecked n l))
-toList1 :: Storable a => XArray '[n] a -> [a]
-toList1 (XArray arr) = S.toList arr
+toListLinear :: Storable a => XArray sh a -> [a]
+toListLinear (XArray arr) = S.toList arr
-- | Throws if the given shape is not, in fact, empty.
empty :: forall sh a. Storable a => IShX sh -> XArray sh a
diff --git a/src/Data/Vector/Generic/Checked.hs b/src/Data/Vector/Generic/Checked.hs
new file mode 100644
index 0000000..d8aaaae
--- /dev/null
+++ b/src/Data/Vector/Generic/Checked.hs
@@ -0,0 +1,40 @@
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+module Data.Vector.Generic.Checked (
+ fromListNChecked,
+) where
+
+import Data.Stream.Monadic qualified as Stream
+import Data.Vector.Fusion.Bundle.Monadic qualified as VBM
+import Data.Vector.Fusion.Bundle.Size qualified as VBS
+import Data.Vector.Fusion.Util qualified as VFU
+import Data.Vector.Generic qualified as VG
+
+-- for INLINE_FUSED and INLINE_INNER
+#include "vector.h"
+
+
+-- These functions are copied over and lightly edited from the vector and
+-- vector-stream packages, and thus inherit their BSD-3-Clause license with:
+-- Copyright (c) 2008-2012, Roman Leshchinskiy
+-- 2020-2022, Alexey Kuleshevich
+-- 2020-2022, Aleksey Khudyakov
+-- 2020-2022, Andrew Lelechenko
+
+fromListNChecked :: VG.Vector v a => Int -> [a] -> v a
+{-# INLINE fromListNChecked #-}
+fromListNChecked n = VG.unstream . bundleFromListNChecked n
+
+bundleFromListNChecked :: Int -> [a] -> VBM.Bundle VFU.Id v a
+{-# INLINE_FUSED bundleFromListNChecked #-}
+bundleFromListNChecked nTop xsTop
+ | nTop < 0 = error "fromListNChecked: length negative"
+ | otherwise =
+ VBM.fromStream (Stream.Stream step (xsTop, nTop)) (VBS.Max (VFU.delay_inline max nTop 0))
+ where
+ {-# INLINE_INNER step #-}
+ step (xs,n) | n == 0 = case xs of
+ [] -> return Stream.Done
+ _:_ -> error "fromListNChecked: list too long"
+ step (x:xs,n) = return (Stream.Yield x (xs,n-1))
+ step ([],_) = error "fromListNChecked: list too short"
diff --git a/src/GHC/TypeLits/Orphans.hs b/src/GHC/TypeLits/Orphans.hs
new file mode 100644
index 0000000..42f7293
--- /dev/null
+++ b/src/GHC/TypeLits/Orphans.hs
@@ -0,0 +1,13 @@
+-- | Compatibility module adding some additional instances not yet defined in
+-- base-4.18 with GHC 9.6.
+{-# OPTIONS -Wno-orphans #-}
+module GHC.TypeLits.Orphans where
+
+import GHC.TypeLits
+
+
+instance Eq (SNat n) where
+ _ == _ = True
+
+instance Ord (SNat n) where
+ compare _ _ = EQ
diff --git a/test/Gen.hs b/test/Gen.hs
index 281c620..952e8db 100644
--- a/test/Gen.hs
+++ b/test/Gen.hs
@@ -4,7 +4,6 @@
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
@@ -20,10 +19,10 @@ import Foreign
import GHC.TypeLits
import GHC.TypeNats qualified as TN
-import Data.Array.Nested.Permutation
-import Data.Array.Nested.Types
import Data.Array.Nested
+import Data.Array.Nested.Permutation
import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
import Hedgehog
import Hedgehog.Gen qualified as Gen
@@ -59,9 +58,12 @@ shuffleShR = \sh -> go (length sh) (toList sh) sh
(dim :$:) <$> go (nbag - 1) bag' sh
genShR :: SNat n -> Gen (IShR n)
-genShR sn = do
+genShR = genShRwithTarget 100_000
+
+genShRwithTarget :: Int -> SNat n -> Gen (IShR n)
+genShRwithTarget targetMax sn = do
let n = fromSNat' sn
- targetSize <- Gen.int (Range.linear 0 100_000)
+ targetSize <- Gen.int (Range.linear 0 targetMax)
let genDims :: SNat m -> Int -> Gen (IShR m)
genDims SZ _ = return ZSR
genDims (SS m) 0 = do
@@ -76,9 +78,8 @@ genShR sn = do
dims <- genDims m (if dim == 0 then 0 else tgt `div` dim)
return (dim :$: dims)
dims <- genDims sn targetSize
- let dimsL = toList dims
- maxdim = maximum dimsL
- cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize)
+ let maxdim = maximum dims
+ cap = binarySearch (`div` 2) 1 maxdim (\cap' -> shrSize (min cap' <$> dims) <= targetSize)
shuffleShR (min cap <$> dims)
-- | Example: given 3 and 7, might return:
@@ -93,10 +94,14 @@ genShR sn = do
-- other dimensions might be zero.
genReplicatedShR :: m <= n => SNat m -> SNat n -> Gen (IShR m, IShR n, IShR n)
genReplicatedShR = \m n -> do
- sh1 <- genShR m
+ let expectedSizeIncrease = round (repvalavg ^ (fromSNat' n - fromSNat' m))
+ sh1 <- genShRwithTarget (1_000_000 `div` expectedSizeIncrease) m
(sh2, sh3) <- injectOnes n sh1 sh1
return (sh1, sh2, sh3)
where
+ repvalrange = (1::Int, 5)
+ repvalavg = let (lo, hi) = repvalrange in fromIntegral (lo + hi) / 2 :: Double
+
injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n)
injectOnes n@SNat shOnes sh
| m@SNat <- shrRank sh
@@ -105,7 +110,7 @@ genReplicatedShR = \m n -> do
EQI -> return (shOnes, sh)
GTI -> do
index <- Gen.int (Range.linear 0 (fromSNat' m))
- value <- Gen.int (Range.linear 1 5)
+ value <- Gen.int (uncurry Range.linear repvalrange)
Refl <- return (lem n m)
injectOnes n (inject index 1 shOnes) (inject index value sh)
@@ -115,7 +120,7 @@ genReplicatedShR = \m n -> do
inject :: Int -> Int -> IShR m -> IShR (m + 1)
inject 0 v sh = v :$: sh
inject i v (w :$: sh) = w :$: inject (i - 1) v sh
- inject _ v ZSR = v :$: ZSR -- invalid input, but meh
+ inject _ _ ZSR = error "unreachable"
genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
genStorables rng f = do
@@ -134,7 +139,7 @@ genStaticShX = \n k -> case n of
genStaticShX n' $ \ssh ->
k (item :!% ssh)
where
- genItem :: Monad m => (forall n. SMayNat () SNat n -> PropertyT m ()) -> PropertyT m ()
+ genItem :: Monad m => (forall n. SMayNat () n -> PropertyT m ()) -> PropertyT m ()
genItem k = do
b <- forAll Gen.bool
if b
@@ -157,7 +162,7 @@ genPermR n = Gen.shuffle [0 .. n-1]
genPerm :: Monad m => SNat n -> (forall p. (IsPermutation p, Rank p ~ n) => Perm p -> PropertyT m r) -> PropertyT m r
genPerm n@SNat k = do
list <- forAll $ genPermR (fromSNat' n)
- permFromList list $ \perm -> do
+ permFromListCont list $ \perm -> do
case permCheckPermutation perm $
case sameNat' (permRank perm) n of
Just Refl -> Just (k perm)
diff --git a/test/Tests/C.hs b/test/Tests/C.hs
index 3b78bc0..e26c3dd 100644
--- a/test/Tests/C.hs
+++ b/test/Tests/C.hs
@@ -1,9 +1,12 @@
+{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
+#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
{-# LANGUAGE TypeAbstractions #-}
+#endif
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
@@ -18,13 +21,13 @@ import Data.Type.Equality
import Foreign
import GHC.TypeLits
-import Data.Array.Nested.Types (fromSNat')
import Data.Array.Nested
import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types (fromSNat')
import Hedgehog
import Hedgehog.Gen qualified as Gen
-import Hedgehog.Internal.Property (forAllT)
+import Hedgehog.Internal.Property (LabelName(..), forAllT)
import Hedgehog.Range qualified as Range
import Test.Tasty
import Test.Tasty.Hedgehog
@@ -39,8 +42,11 @@ import Util
fineTol :: Double
fineTol = 1e-8
-prop_sum_nonempty :: Property
-prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do
+debugCoverage :: Bool
+debugCoverage = False
+
+gen_red_nonempty :: (forall n. SNat (n + 1) -> SNat n -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property
+gen_red_nonempty f = property $ genRank $ \outrank@(SNat @n) -> do
-- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
let inrank = SNat @(n + 1)
sh <- forAll $ genShR inrank
@@ -49,11 +55,10 @@ prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do
arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
genStorables (Range.singleton (product sh))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
- let rarr = rfromOrthotope inrank arr
- almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr)
+ f inrank outrank arr
-prop_sum_empty :: Property
-prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do
+gen_red_empty :: (forall n. SNat (n + 1) -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property
+gen_red_empty f = property $ genRank $ \outrankm1@(SNat @nm1) -> do
-- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above.
_outrank :: SNat n <- return $ SNat @(nm1 + 1)
let inrank = SNat @(n + 1)
@@ -62,14 +67,13 @@ prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do
sht <- shuffleShR (0 :$: shtt) -- n
n <- Gen.int (Range.linear 0 20)
return (n :$: sht) -- n + 1
- guard (0 `elem` toList (shrTail sh))
+ guard (0 `elem` shrTail sh)
-- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
let arr = OR.fromList @(n + 1) @Double (toList sh) []
- let rarr = rfromOrthotope inrank arr
- OR.toList (rtoOrthotope (rsumOuter1 rarr)) === []
+ f inrank arr
-prop_sum_lasteq1 :: Property
-prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do
+gen_red_lasteq1 :: (forall n. SNat (n + 1) -> SNat n -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property
+gen_red_lasteq1 f = property $ genRank $ \outrank@(SNat @n) -> do
let inrank = SNat @(n + 1)
outsh <- forAll $ genShR outrank
guard (all (> 0) outsh)
@@ -77,11 +81,10 @@ prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do
arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$>
genStorables (Range.singleton (product insh))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
- let rarr = rfromOrthotope inrank arr
- almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr)
+ f inrank outrank arr
-prop_sum_replicated :: Bool -> Property
-prop_sum_replicated doTranspose = property $
+gen_red_replicated :: Bool -> (forall n. SNat (n + 1) -> SNat n -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property
+gen_red_replicated doTranspose f = property $
genRank $ \inrank1@(SNat @m) ->
genRank $ \outrank@(SNat @nm1) -> do
inrank2 :: SNat n <- return $ SNat @(nm1 + 1)
@@ -89,6 +92,10 @@ prop_sum_replicated doTranspose = property $
LTI -> return Refl -- actually we only continue if m < n
_ -> discard
(sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2
+ when debugCoverage $ do
+ label (LabelName ("rankdiff " ++ show (fromSNat' inrank2 - fromSNat' inrank1)))
+ label (LabelName ("size sh1 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh1) :: Double)) :: Int)))
+ label (LabelName ("size sh3 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh3) :: Double)) :: Int)))
guard (all (> 0) sh3)
arr <- forAllT $
OR.stretch (toList sh3)
@@ -100,8 +107,50 @@ prop_sum_replicated doTranspose = property $
if doTranspose then do perm <- forAll $ genPermR (fromSNat' inrank2)
return $ OR.transpose perm arr
else return arr
- let rarr = rfromOrthotope inrank2 arrTrans
- almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans)
+ f inrank2 outrank arrTrans
+
+
+prop_sum_nonempty :: Property
+prop_sum_nonempty = gen_red_nonempty $ \inrank outrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ almostEq fineTol (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr)
+
+prop_sum_empty :: Property
+prop_sum_empty = gen_red_empty $ \inrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ OR.toList (rtoOrthotope (rsumOuter1Prim rarr)) === []
+
+prop_sum_lasteq1 :: Property
+prop_sum_lasteq1 = gen_red_lasteq1 $ \inrank outrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ almostEq fineTol (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr)
+
+prop_sum_replicated :: Bool -> Property
+prop_sum_replicated doTranspose = gen_red_replicated doTranspose $ \inrank outrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ almostEq 1e-8 (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr)
+
+
+prop_sumall_nonempty :: Property
+prop_sumall_nonempty = gen_red_nonempty $ \inrank _outrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ almostEq fineTol (rsumAllPrim rarr) (OR.sumA arr)
+
+prop_sumall_empty :: Property
+prop_sumall_empty = gen_red_empty $ \inrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ rsumAllPrim rarr === 0.0
+
+prop_sumall_lasteq1 :: Property
+prop_sumall_lasteq1 = gen_red_lasteq1 $ \inrank _outrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ almostEq fineTol (rsumAllPrim rarr) (OR.sumA arr)
+
+prop_sumall_replicated :: Bool -> Property
+prop_sumall_replicated doTranspose = gen_red_replicated doTranspose $ \inrank _outrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ almostEq 1e-6 (rsumAllPrim rarr) (OR.sumA arr)
+
prop_negate_with :: forall f b. Show b
=> ((forall n. f n -> SNat n -> PropertyT IO ()) -> PropertyT IO ())
@@ -130,6 +179,13 @@ tests = testGroup "C"
,testProperty "replicated" (prop_sum_replicated False)
,testProperty "replicated_transposed" (prop_sum_replicated True)
]
+ ,testGroup "sumAll"
+ [testProperty "nonempty" prop_sumall_nonempty
+ ,testProperty "empty" prop_sumall_empty
+ ,testProperty "last==1" prop_sumall_lasteq1
+ ,testProperty "replicated" (prop_sumall_replicated False)
+ ,testProperty "replicated_transposed" (prop_sumall_replicated True)
+ ]
,testGroup "negate"
[testProperty "normalised" $ prop_negate_with
(\k -> genRank (k (Const ())))
diff --git a/test/Tests/Permutation.hs b/test/Tests/Permutation.hs
index 98a6da5..4e75d64 100644
--- a/test/Tests/Permutation.hs
+++ b/test/Tests/Permutation.hs
@@ -24,7 +24,7 @@ tests = testGroup "Permutation"
[testProperty "permCheckPermutation" $ property $ do
n <- forAll $ Gen.int (Range.linear 0 10)
list <- forAll $ genPermR n
- let r = permFromList list $ \perm ->
+ let r = permFromListCont list $ \perm ->
permCheckPermutation perm ()
case r of
Just () -> return ()
diff --git a/test/Util.hs b/test/Util.hs
index 8a5ba72..6514fbf 100644
--- a/test/Util.hs
+++ b/test/Util.hs
@@ -36,16 +36,20 @@ orSumOuter1 (sn@SNat :: SNat n) =
let n = fromSNat' sn
in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0])
-class AlmostEq f where
- type AlmostEqConstr f :: Type -> Constraint
+class AlmostEq t where
+ type EltOf t :: Type
-- | absolute tolerance, lhs, rhs
- almostEq :: (AlmostEqConstr f a, Ord a, Show a, Fractional a, MonadTest m)
- => a -> f a -> f a -> m ()
+ almostEq :: MonadTest m => EltOf t -> t -> t -> m ()
-instance AlmostEq (OR.Array n) where
- type AlmostEqConstr (OR.Array n) = OR.Unbox
+instance (OR.Unbox a, Ord a, Show a, Fractional a) => AlmostEq (OR.Array n a) where
+ type EltOf (OR.Array n a) = a
almostEq atol lhs rhs
| OR.allA (< atol) (OR.zipWithA (\a b -> abs (a - b)) rhs lhs) =
success
| otherwise =
failDiff lhs rhs
+
+instance AlmostEq Double where
+ type EltOf Double = Double
+ almostEq atol lhs rhs | abs (lhs - rhs) < atol = success
+ | otherwise = failDiff lhs rhs