aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.stylish-haskell.yaml452
-rw-r--r--CHANGELOG.md15
-rw-r--r--README.md197
-rw-r--r--bench/Main.hs68
-rw-r--r--cabal.project2
-rw-r--r--cbits/arith.c10
-rw-r--r--cbits/arith_lists.h62
-rwxr-xr-xgentrace.sh2
-rw-r--r--ops/Data/Array/Strided/Arith/Internal.hs69
-rw-r--r--ops/Data/Array/Strided/Arith/Internal/Lists.hs4
-rw-r--r--ox-arrays.cabal95
-rw-r--r--release-hints.txt3
-rw-r--r--src/Data/Array/Mixed/Lemmas.hs137
-rw-r--r--src/Data/Array/Nested.hs77
-rw-r--r--src/Data/Array/Nested/Convert.hs344
-rw-r--r--src/Data/Array/Nested/Internal/Convert.hs86
-rw-r--r--src/Data/Array/Nested/Internal/Lemmas.hs59
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs559
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs495
-rw-r--r--src/Data/Array/Nested/Lemmas.hs184
-rw-r--r--src/Data/Array/Nested/Mixed.hs (renamed from src/Data/Array/Nested/Internal/Mixed.hs)577
-rw-r--r--src/Data/Array/Nested/Mixed/ListX.hs139
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs707
-rw-r--r--src/Data/Array/Nested/Permutation.hs (renamed from src/Data/Array/Mixed/Permutation.hs)180
-rw-r--r--src/Data/Array/Nested/Ranked.hs381
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs271
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs528
-rw-r--r--src/Data/Array/Nested/Shaped.hs297
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs266
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs491
-rw-r--r--src/Data/Array/Nested/Trace.hs42
-rw-r--r--src/Data/Array/Nested/Trace/TH.hs83
-rw-r--r--src/Data/Array/Nested/Types.hs (renamed from src/Data/Array/Mixed/Types.hs)34
-rw-r--r--src/Data/Array/Strided/Orthotope.hs (renamed from src/Data/Array/Mixed/Internal/Arith.hs)9
-rw-r--r--src/Data/Array/XArray.hs (renamed from src/Data/Array/Mixed/XArray.hs)170
-rw-r--r--src/Data/Vector/Generic/Checked.hs40
-rw-r--r--src/GHC/TypeLits/Orphans.hs13
-rw-r--r--test/Gen.hs34
-rw-r--r--test/Tests/C.hs131
-rw-r--r--test/Tests/Permutation.hs4
-rw-r--r--test/Util.hs18
41 files changed, 4568 insertions, 2767 deletions
diff --git a/.stylish-haskell.yaml b/.stylish-haskell.yaml
new file mode 100644
index 0000000..bfd25ea
--- /dev/null
+++ b/.stylish-haskell.yaml
@@ -0,0 +1,452 @@
+# stylish-haskell configuration file
+# ==================================
+
+# The stylish-haskell tool is mainly configured by specifying steps. These steps
+# are a list, so they have an order, and one specific step may appear more than
+# once (if needed). Each file is processed by these steps in the given order.
+steps:
+ # Convert some ASCII sequences to their Unicode equivalents. This is disabled
+ # by default.
+ # - unicode_syntax:
+ # # In order to make this work, we also need to insert the UnicodeSyntax
+ # # language pragma. If this flag is set to true, we insert it when it's
+ # # not already present. You may want to disable it if you configure
+ # # language extensions using some other method than pragmas. Default:
+ # # true.
+ # add_language_pragma: true
+
+ # Format module header
+ #
+ # Currently, this option is not configurable and will format all exports and
+ # module declarations to minimize diffs
+ #
+ # - module_header:
+ # # How many spaces use for indentation in the module header.
+ # indent: 4
+ #
+ # # Should export lists be sorted? Sorting is only performed within the
+ # # export section, as delineated by Haddock comments.
+ # sort: true
+ #
+ # # See `separate_lists` for the `imports` step.
+ # separate_lists: true
+
+ # Format record definitions. This is disabled by default.
+ #
+ # You can control the layout of record fields. The only rules that can't be configured
+ # are these:
+ #
+ # - "|" is always aligned with "="
+ # - "," in fields is always aligned with "{"
+ # - "}" is likewise always aligned with "{"
+ #
+ # - records:
+ # # How to format equals sign between type constructor and data constructor.
+ # # Possible values:
+ # # - "same_line" -- leave "=" AND data constructor on the same line as the type constructor.
+ # # - "indent N" -- insert a new line and N spaces from the beginning of the next line.
+ # equals: "indent 2"
+ #
+ # # How to format first field of each record constructor.
+ # # Possible values:
+ # # - "same_line" -- "{" and first field goes on the same line as the data constructor.
+ # # - "indent N" -- insert a new line and N spaces from the beginning of the data constructor
+ # first_field: "indent 2"
+ #
+ # # How many spaces to insert between the column with "," and the beginning of the comment in the next line.
+ # field_comment: 2
+ #
+ # # How many spaces to insert before "deriving" clause. Deriving clauses are always on separate lines.
+ # deriving: 2
+ #
+ # # How many spaces to insert before "via" clause counted from indentation of deriving clause
+ # # Possible values:
+ # # - "same_line" -- "via" part goes on the same line as "deriving" keyword.
+ # # - "indent N" -- insert a new line and N spaces from the beginning of "deriving" keyword.
+ # via: "indent 2"
+ #
+ # # Sort typeclass names in the "deriving" list alphabetically.
+ # sort_deriving: true
+ #
+ # # Wheter or not to break enums onto several lines
+ # #
+ # # Default: false
+ # break_enums: false
+ #
+ # # Whether or not to break single constructor data types before `=` sign
+ # #
+ # # Default: true
+ # break_single_constructors: true
+ #
+ # # Whether or not to curry constraints on function.
+ # #
+ # # E.g: @allValues :: Enum a => Bounded a => Proxy a -> [a]@
+ # #
+ # # Instead of @allValues :: (Enum a, Bounded a) => Proxy a -> [a]@
+ # #
+ # # Default: false
+ # curried_context: false
+
+ # Align the right hand side of some elements. This is quite conservative
+ # and only applies to statements where each element occupies a single
+ # line.
+ # Possible values:
+ # - always - Always align statements.
+ # - adjacent - Align statements that are on adjacent lines in groups.
+ # - never - Never align statements.
+ # All default to always.
+ - simple_align:
+ cases: never
+ top_level_patterns: never
+ records: never
+ multi_way_if: never
+
+ # Import cleanup
+ - imports:
+ # There are different ways we can align names and lists.
+ #
+ # - global: Align the import names and import list throughout the entire
+ # file.
+ #
+ # - file: Like global, but don't add padding when there are no qualified
+ # imports in the file.
+ #
+ # - group: Only align the imports per group (a group is formed by adjacent
+ # import lines).
+ #
+ # - none: Do not perform any alignment.
+ #
+ # Default: global.
+ align: group
+
+ # The following options affect only import list alignment.
+ #
+ # List align has following options:
+ #
+ # - after_alias: Import list is aligned with end of import including
+ # 'as' and 'hiding' keywords.
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr, head,
+ # > init, last, length)
+ #
+ # - with_alias: Import list is aligned with start of alias or hiding.
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr, head,
+ # > init, last, length)
+ #
+ # - with_module_name: Import list is aligned `list_padding` spaces after
+ # the module name.
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr, head,
+ # init, last, length)
+ #
+ # This is mainly intended for use with `pad_module_names: false`.
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr, head,
+ # init, last, length, scanl, scanr, take, drop,
+ # sort, nub)
+ #
+ # - new_line: Import list starts always on new line.
+ #
+ # > import qualified Data.List as List
+ # > (concat, foldl, foldr, head, init, last, length)
+ #
+ # - repeat: Repeat the module name to align the import list.
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr, head)
+ # > import qualified Data.List as List (init, last, length)
+ #
+ # Default: after_alias
+ list_align: after_alias
+
+ # Right-pad the module names to align imports in a group:
+ #
+ # - true: a little more readable
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr,
+ # > init, last, length)
+ # > import qualified Data.List.Extra as List (concat, foldl, foldr,
+ # > init, last, length)
+ #
+ # - false: diff-safe
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr, init,
+ # > last, length)
+ # > import qualified Data.List.Extra as List (concat, foldl, foldr,
+ # > init, last, length)
+ #
+ # Default: true
+ pad_module_names: false
+
+ # Long list align style takes effect when import is too long. This is
+ # determined by 'columns' setting.
+ #
+ # - inline: This option will put as much specs on same line as possible.
+ #
+ # - new_line: Import list will start on new line.
+ #
+ # - new_line_multiline: Import list will start on new line when it's
+ # short enough to fit to single line. Otherwise it'll be multiline.
+ #
+ # - multiline: One line per import list entry.
+ # Type with constructor list acts like single import.
+ #
+ # > import qualified Data.Map as M
+ # > ( empty
+ # > , singleton
+ # > , ...
+ # > , delete
+ # > )
+ #
+ # Default: inline
+ long_list_align: new_line_multiline
+
+ # Align empty list (importing instances)
+ #
+ # Empty list align has following options
+ #
+ # - inherit: inherit list_align setting
+ #
+ # - right_after: () is right after the module name:
+ #
+ # > import Vector.Instances ()
+ #
+ # Default: inherit
+ empty_list_align: inherit
+
+ # List padding determines indentation of import list on lines after import.
+ # This option affects 'long_list_align'.
+ #
+ # - <integer>: constant value
+ #
+ # - module_name: align under start of module name.
+ # Useful for 'file' and 'group' align settings.
+ #
+ # Default: 4
+ list_padding: 2
+
+ # Separate lists option affects formatting of import list for type
+ # or class. The only difference is single space between type and list
+ # of constructors, selectors and class functions.
+ #
+ # - true: There is single space between Foldable type and list of it's
+ # functions.
+ #
+ # > import Data.Foldable (Foldable (fold, foldl, foldMap))
+ #
+ # - false: There is no space between Foldable type and list of it's
+ # functions.
+ #
+ # > import Data.Foldable (Foldable(fold, foldl, foldMap))
+ #
+ # Default: true
+ separate_lists: false
+
+ # Space surround option affects formatting of import lists on a single
+ # line. The only difference is single space after the initial
+ # parenthesis and a single space before the terminal parenthesis.
+ #
+ # - true: There is single space associated with the enclosing
+ # parenthesis.
+ #
+ # > import Data.Foo ( foo )
+ #
+ # - false: There is no space associated with the enclosing parenthesis
+ #
+ # > import Data.Foo (foo)
+ #
+ # Default: false
+ space_surround: false
+
+ # Enabling this argument will use the new GHC lib parse to format imports.
+ #
+ # This currently assumes a few things, it will assume that you want post
+ # qualified imports. It is also not as feature complete as the old
+ # imports formatting.
+ #
+ # It does not remove redundant lines or merge lines. As such, the full
+ # feature scope is still pending.
+ #
+ # It _is_ however, a fine alternative if you are using features that are
+ # not parseable by haskell src extensions and you're comfortable with the
+ # presets.
+ #
+ # Default: false
+ ghc_lib_parser: false
+
+ # Post qualify option moves any qualifies found in import declarations
+ # to the end of the declaration. This also adjust padding for any
+ # unqualified import declarations.
+ #
+ # - true: Qualified as <module name> is moved to the end of the
+ # declaration.
+ #
+ # > import Data.Bar
+ # > import Data.Foo qualified as F
+ #
+ # - false: Qualified remains in the default location and unqualified
+ # imports are padded to align with qualified imports.
+ #
+ # > import Data.Bar
+ # > import qualified Data.Foo as F
+ #
+ # Default: false
+ post_qualify: true
+
+ # A list of rules specifying how to group modules and how to
+ # order the groups.
+ #
+ # Each rule has a match field; the rule only applies to module
+ # names matched by this pattern. Patterns are POSIX extended
+ # regular expressions; see the documentation of Text.Regex.TDFA
+ # for details:
+ # https://hackage.haskell.org/package/regex-tdfa-1.3.1.2/docs/Text-Regex-TDFA.html
+ #
+ # Rules are processed in order, so only the *first* rule that
+ # matches a specific module will apply. Any module names that do
+ # not match a single rule will be put into a single group at the
+ # end of the import block.
+ #
+ # Example: group MyApp modules first, with everything else in
+ # one group at the end.
+ #
+ # group_rules:
+ # - match: "^MyApp\\>"
+ #
+ # > import MyApp
+ # > import MyApp.Foo
+ # >
+ # > import Control.Monad
+ # > import MyApps
+ # > import Test.MyApp
+ #
+ # A rule can also optionally have a sub_group pattern. Imports
+ # that match the rule will be broken up into further groups by
+ # the part of the module name matched by the sub_group pattern.
+ #
+ # Example: group MyApp modules first, then everything else
+ # sub-grouped by the first part of the module name.
+ #
+ # group_rules:
+ # - match: "^MyApp\\>"
+ # - match: "."
+ # sub_group: "^[^.]+"
+ #
+ # > import MyApp
+ # > import MyApp.Foo
+ # >
+ # > import Control.Applicative
+ # > import Control.Monad
+ # >
+ # > import Data.Map
+ #
+ # A pattern only needs to match part of the module name, which
+ # could be in the middle. You can use ^pattern to anchor to the
+ # beginning of the module name, pattern$ to anchor to the end
+ # and ^pattern$ to force a full match. Example:
+ #
+ # - "Test\\." would match "Test.Foo" and "Foo.Test.Lib"
+ # - "^Test\\." would match "Test.Foo" but not "Foo.Test.Lib"
+ # - "\\.Test$" would match "Foo.Test" but not "Foo.Test.Lib"
+ # - "^Test$" would *only* match "Test"
+ #
+ # You can use \\< and \\> to anchor against the beginning and
+ # end of words, respectively. For example:
+ #
+ # - "^Test\\." would match "Test.Foo" but not "Test" or "Tests"
+ # - "^Test\\>" would match "Test.Foo" and "Test", but not
+ # "Tests"
+ #
+ # The default is a single rule that matches everything and
+ # sub-groups based on the first component of the module name.
+ #
+ # Default: [{ "match" : ".*", "sub_group": "^[^.]+" }]
+# group_rules:
+# - match: ".*"
+# sub_group: "^[^.]+"
+# - match: "^Data.Array\\>"
+# sub_group: "^[^.]+"
+# - match: "^HordeAd\\>"
+
+ # Language pragmas
+ - language_pragmas:
+ # We can generate different styles of language pragma lists.
+ #
+ # - vertical: Vertical-spaced language pragmas, one per line.
+ #
+ # - compact: A more compact style.
+ #
+ # - compact_line: Similar to compact, but wrap each line with
+ # `{-#LANGUAGE #-}'.
+ #
+ # Default: vertical.
+# style: compact
+
+ # Align affects alignment of closing pragma brackets.
+ #
+ # - true: Brackets are aligned in same column.
+ #
+ # - false: Brackets are not aligned together. There is only one space
+ # between actual import and closing bracket.
+ #
+ # Default: true
+ align: false
+
+ # stylish-haskell can detect redundancy of some language pragmas. If this
+ # is set to true, it will remove those redundant pragmas. Default: true.
+ remove_redundant: true
+
+ # Language prefix to be used for pragma declaration, this allows you to
+ # use other options non case-sensitive like "language" or "Language".
+ # If a non correct String is provided, it will default to: LANGUAGE.
+ language_prefix: LANGUAGE
+
+ # Replace tabs by spaces. This is disabled by default.
+ # - tabs:
+ # # Number of spaces to use for each tab. Default: 8, as specified by the
+ # # Haskell report.
+ # spaces: 8
+
+ # Remove trailing whitespace
+ - trailing_whitespace: {}
+
+ # Squash multiple spaces between the left and right hand sides of some
+ # elements into single spaces. Basically, this undoes the effect of
+ # simple_align but is a bit less conservative.
+ # - squash: {}
+
+# A common setting is the number of columns (parts of) code will be wrapped
+# to. Different steps take this into account.
+#
+# Set this to null to disable all line wrapping.
+#
+# Default: 80.
+columns: 200
+
+# By default, line endings are converted according to the OS. You can override
+# preferred format here.
+#
+# - native: Native newline format. CRLF on Windows, LF on other OSes.
+#
+# - lf: Convert to LF ("\n").
+#
+# - crlf: Convert to CRLF ("\r\n").
+#
+# Default: native.
+newline: native
+
+# Sometimes, language extensions are specified in a cabal file or from the
+# command line instead of using language pragmas in the file. stylish-haskell
+# needs to be aware of these, so it can parse the file correctly.
+#
+# No language extensions are enabled by default.
+#language_extensions:
+# - NoStarIsType
+ # - TemplateHaskell
+ # - QuasiQuotes
+
+# Attempt to find the cabal file in ancestors of the current directory, and
+# parse options (currently only language extensions) from that.
+#
+# Default: true
+cabal: true
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..3c9b3c4
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,15 @@
+# Changelog for `ox-arrays`
+
+This package intends to follow the [PVP](https://pvp.haskell.org/).
+
+## 0.2.0.1
+- Fix the broken Data.Array.Nested.Trace export list (not compiled by default).
+
+## 0.2.0.0
+- Many performance improvements to the Haskell code (C code mostly left alone).
+- A few API changes and additions, mostly guided by the performance-improving work.
+- The API should be a little more stable now.
+
+## 0.1.0.0
+- Initial release
+- Various aspects of the API are still experimental, and breaking changes are expected in the future.
diff --git a/README.md b/README.md
index 9b8d543..032f7b8 100644
--- a/README.md
+++ b/README.md
@@ -1,56 +1,173 @@
-Wrapper library around `orthotope` that defines nested arrays, including
-tuples, of (eventually) unboxed values. The arrays are represented in
-struct-of-arrays form via the `Data.Vector.Unboxed` data family trick. Below
-the surface layer, there is a more low-level wrapper around `orthotope` that
-defines an array type type-indexed by `[Maybe Nat]`: some dimensions are
-shape-typed (i.e. have their size statically known), and some not.
+## ox-arrays
-An overview of the API:
+ox-arrays is an array library that defines nested arrays, including tuples, of
+(eventually) unboxed values. The arrays are represented in struct-of-arrays
+form via the `Data.Vector.Unboxed` data family trick; the component arrays are
+`orthotope` arrays
+([RankedS](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-RankedS.html))
+which describe elements using a _stride vector_ or
+[LMAD](https://dl.acm.org/doi/pdf/10.1145/509705.509708) so that `transpose`
+and `replicate` need only modify array metadata, not actually move around data.
+
+Because of the struct-of-arrays representation, nested arrays are not fully
+general: indeed, arrays are not actually nested under the hood, so if one has an
+array of arrays, those element arrays must all have the same shape (length,
+width, etc.). If one has an array of tuples of arrays, then all the `fst`
+components must have the same shape and all the `snd` components must have the
+same shape, but the two pair components themselves can be different.
+
+However, the nesting functionality of ox-arrays can be completely ignored if you
+only care about other parts of its API, or the vectorised arithmetic operations
+(using hand-written C code). Nesting support mostly does not get in the way, and
+has essentially no overhead (both when it's used and when it's not used).
+
+ox-arrays defines three array types: `Ranked`, `Shaped` and `Mixed`.
+- `Ranked` corresponds to `orthotope`'s
+ [RankedS](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-RankedS.html)
+ and has the _rank_ of the array (its number of dimensions) on the type level.
+ For example, `Ranked 2 Float` is a two-dimensional array of `Float`s, i.e. a
+ matrix.
+- `Shaped` corresponds to `orthotope`'s
+ [ShapedS](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-ShapedS.html).
+ and has the full _shape_ of the array (its dimension sizes) on the type level
+ as a type-level list of `Nat`s. For example, `Shaped [2,3] Float` is a 2-by-3
+ matrix. The innermost dimension correspond to the right-most element in the
+ list.
+- `Mixed` is halfway between the two: it has a type parameter of kind
+ `[Maybe Nat]` whose length is the rank of the array; `Nothing` elements have
+ unknown size, whereas `Just` elements have the indicated size. The type
+ `Mixed [Nothing, Nothing] a` is equivalent to `Ranked 2 a`; the type
+ `Mixed [Just n, Just m] a` is equivalent to `Shaped [n, m] a`.
+
+In various places in the API of a library like ox-arrays, one can make a
+decision between 1. requiring a type class constraint providing certain
+information (e.g.
+[KnownNat](https://hackage.haskell.org/package/base-4.21.0.0/docs/GHC-TypeLits.html#t:KnownNat)
+or `orthotope`'s
+[Shape](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-ShapedS.html#t:Shape)),
+or 2. taking singleton _values_ that encode said information in a way that is
+linked to the type level (e.g.
+[SNat](https://hackage.haskell.org/package/base-4.21.0.0/docs/GHC-TypeLits.html#t:SNat)).
+`orthotope` chooses the type class approach; ox-arrays chooses the singleton
+approach. Singletons are more verbose at times, but give the programmer more
+insight in what data is flowing where, and more importantly, more control: type
+class inference is very nice and implicit, but if it's not powerful enough for
+the trickery you're doing, you're out of luck. Singletons allow you to explain
+as precisely as you want to GHC what exactly you're doing.
+
+Below the surface layer, there is a more low-level wrapper (`XArray`) around
+`orthotope` that defines a non-nested `Mixed`-style array type.
+
+**Be aware**: `ox-arrays` attempts to preserve sharing as much as possible.
+That is to say: if a function is able to avoid copying array data and return an
+array that references the original underlying `Vector`, it may do so. For
+example, this means that if you convert a nested array to a list of arrays, all
+returned arrays reference part of the original array without copying. This
+makes `mtoList` fast, but also means that memory may be retained longer than
+you might expect.
+
+Here is a little taster of the API, to get a sense for the design:
```haskell
-data Ranked (n :: INat) a {- e.g. -} Ranked 3 Float
-data Shaped (sh :: '[Nat]) a {- e.g. -} Shaped [2,3,4] Float
-data Mixed (xsh :: '[Maybe Nat]) a {- e.g. -} Mixed [Just 2, Nothing, Just 4] Float
-
-Ranked I0 a = Ranked Z a ~~= Acc.Array Z a = Acc.Scalar a
-Ranked I1 a = Ranked (S Z) a ~~= Acc.Array (Z :. Int) a = Acc.Vector a
-Ranked I2 a = Ranked (S (S Z)) a ~~= Acc.Array (Z :. Int :. Int) a = Acc.Matrix a
+import GHC.TypeLits (Nat, SNat)
+data Ranked (n :: Nat) a {- e.g. -} Ranked 3 Float
+data Shaped (sh :: '[Nat]) a {- e.g. -} Shaped [2,3,4] Float
+data Mixed (xsh :: '[Maybe Nat]) a {- e.g. -} Mixed [Just 2, Nothing, Just 4] Float
-rshape :: (Elt a, KnownINat n) => Ranked n a -> IxR n
-sshape :: (Elt a, KnownShape sh) => Shaped sh a -> IxS sh
-mshape :: (Elt a, KnownShapeX xsh) => Mixed xsh a -> IxX xsh
+-- Shape types are written Sh{R,S,X}. The 'I' prefix denotes a Int-filled shape;
+-- ShR and ShX are more general containers. ShS is a singleton.
+rshape :: Elt a => Ranked n a -> IShR n
+sshape :: Elt a => Shaped sh a -> ShS sh
+mshape :: Elt a => Mixed xsh a -> IShX xsh
-rindex :: Elt a => Ranked n a -> IxR n -> a
-sindex :: Elt a => Shaped sh a -> IxS sh -> a
-mindex :: Elt a => Mixed xsh a -> IxX xsh -> a
+-- Index types are written Ix{R,S,X}.
+rindex :: Elt a => Ranked n a -> IIxR n -> a
+sindex :: Elt a => Shaped sh a -> IIxS sh -> a
+mindex :: Elt a => Mixed xsh a -> IIxX xsh -> a
-data IxR n where
- IZR :: IxR Z
- (:::) :: Int -> IxR n -> IxR (S n)
+-- The index types can be used as if they were defined as follows; pattern
+-- synonyms are provided to construct the illusion. (The actual definitions are
+-- indirect, with the benefit that they make many conversion operations free.)
+data IIxR n where
+ ZIR :: IIxR 0
+ (:.:) :: Int -> IIxR n -> IIxR (n + 1)
-data IxS sh where
- IZS :: IxS '[]
- (::$) :: Int -> IxS sh -> IxS (n : sh)
+data IIxS sh where
+ ZIS :: IIxS '[]
+ (:.$) :: Int -> IIxS sh -> IIxS (n : sh)
-data IxX sh where
- IZX :: IxX '[]
- (::@) :: Int -> IxX sh -> IxX (Just n : sh)
- (::?) :: Int -> IxX sh -> IxX (Nothing : sh)
+data IIxX xsh where
+ ZIX :: IIxX '[]
+ (:.%) :: Int -> IIxX xsh -> IIxX (mn : xsh)
+-- Similarly, the shape types can be used as if they were defined as follows.
+data IShR n where
+ ZSR :: IShR 0
+ (:$:) :: Int -> IShR n -> IShR (n + 1)
+
+data ShS sh where
+ ZSS :: ShS '[]
+ (:$$) :: SNat n -> ShS sh -> ShS (n : sh)
+
+data IShX xsh where
+ ZSX :: IShX '[]
+ (:$%) :: SMayNat Int mn -> IShX xsh -> IShX (mn : xsh)
+-- where:
+data SMayNat i n where
+ SUnknown :: i -> SMayNat i Nothing
+ SKnown :: SNat n -> SMayNat i (Just n)
+
+-- Occasionally one needs a singleton for only the _known_ dimensions of a mixed
+-- shape -- that is to say, only the statically-known part of a mixed shape.
+-- StaticShX provides for this need. It can be used as if defined as follows:
+data StaticShX xsh where
+ ZKX :: StaticShX '[]
+ (:!%) :: SMayNat () mn -> StaticShX xsh -> StaticShX (mn : xsh)
+
+-- The Elt class describes types that can be used as elements of an array. While
+-- it is technically possible to define new instances of this class, typical
+-- usage should regard Elt as closed. The user-relevant instances are the
+-- following:
class Elt a
-instance Elt ()
-instance Elt Double
-instance Elt Int
-instance (Elt a, Elt b) => Elt (a, b)
-instance (Elt a, KnownINat n) => Elt (Ranked n a)
-instance (Elt a, KnownShape sh) => Elt (Shaped sh a)
-instance (Elt a, KnownShapeX xsh) => Elt (Mixed xsh a)
+instance Elt ()
+instance Elt Bool
+instance Elt Float
+instance Elt Double
+instance Elt Int
+instance (Elt a, Elt b) => Elt (a, b)
+instance Elt a => Elt (Ranked n a)
+instance Elt a => Elt (Shaped sh a)
+instance Elt a => Elt (Mixed xsh a)
+
+-- Essentially all functions that ox-arrays offers on arrays are first-order:
+-- add two arrays elementwise, transpose an array, append arrays, compute
+-- minima/maxima, zip/unzip, nest/unnest, etc. The first-order approach allows
+-- operations, especially arithmetic ones, to be vectorised using hand-written
+-- C code, without needing any sort of JIT compilation.
+rappend :: Elt a => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
+sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
+mappend :: Elt a => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
-rgenerate :: Elt a => IxR n -> (IxR n -> a) -> Ranked n a
-sgenerate :: (Elt a, KnownShape sh) => (IxS sh -> a) -> Shaped sh a
-mgenerate :: (Elt a, KnownShapeX xsh) => IxX xsh -> (IxX xsh -> a) -> Mixed xsh a
+-- Exceptionally, also one higher-order function is provided per array type:
+-- 'generate'. These functions have the caveat that regularity of arrays must be
+-- preserved: all returned 'a's must have equal shape. See the documentation of
+-- 'mgenerate'.
+-- Warning: because the invocations of the function you pass cannot be
+-- vectorised, 'generate' is rather slow if 'a' is small.
+-- The 'KnownElt' class captures an API infelicity where constraint-based shape
+-- passing is the only practical option.
+rgenerate :: KnownElt a => IShR n -> (IxR n -> a) -> Ranked n a
+sgenerate :: KnownElt a => ShS sh -> (IxS sh -> a) -> Shaped sh a
+mgenerate :: KnownElt a => IShX xsh -> (IxX xsh -> a) -> Mixed xsh a
+-- Under the hood, Ranked and Shaped are both newtypes over Mixed. Mixed itself
+-- is a data family over XArray, which is a newtype over orthotope's RankedS.
newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
```
+
+About the name: when importing `orthotope` array modules, a possible naming
+convention is to use qualified imports as `OR` for "orthotope ranked" arrays and
+`OS` for "orthotope shaped" arrays. ox-arrays was started to fill the `OX` gap,
+then grew out of proportion.
diff --git a/bench/Main.hs b/bench/Main.hs
index 5901d8b..185bef0 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -9,25 +9,18 @@ import Control.Monad (when)
import Data.Array.Internal qualified as OI
import Data.Array.Internal.RankedG qualified as RG
import Data.Array.Internal.RankedS qualified as RS
-import Data.Foldable (toList)
import Data.Vector.Storable qualified as VS
import Numeric.LinearAlgebra qualified as LA
import Test.Tasty.Bench
import Text.Show (showListWith)
-import Data.Array.Mixed.XArray (XArray(..))
import Data.Array.Nested
-import Data.Array.Nested.Internal.Mixed (Mixed (M_Primitive), mliftPrim, mliftPrim2, toPrimitive)
-import Data.Array.Nested.Internal.Ranked (liftRanked1, liftRanked2)
+import Data.Array.Nested.Mixed (Mixed(M_Primitive), mliftPrim, mliftPrim2, toPrimitive)
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Ranked (liftRanked1, liftRanked2)
+import Data.Array.Nested.Ranked.Shape
import Data.Array.Strided.Arith.Internal qualified as Arith
-
-
-enableMisc :: Bool
-enableMisc = False
-
-bgroupIf :: Bool -> String -> [Benchmark] -> Benchmark
-bgroupIf True = bgroup
-bgroupIf False = \name _ -> bgroup name []
+import Data.Array.XArray (XArray(..))
main :: IO ()
@@ -47,11 +40,11 @@ main_tests = defaultMain
let showSh l = showListWith (\n -> let ln = round (logBase 10 (fromIntegral n :: Double)) :: Int
in if n > 1 && n == 10 ^ ln then showString ("1e" ++ show ln) else shows n)
l ""
- in bench (name ++ " " ++ showSh (toList (rshape inp1)) ++
+ in bench (name ++ " " ++ showSh (shrToList (rshape inp1)) ++
" str " ++ showSh (stridesOf inp1) ++ " " ++ showSh (stridesOf inp2)) $
nf (\(a,b) -> rsumAllPrim (rdot1Inner a b)) (inp1, inp2)
- iota n = riota @Double n
+ iota = riota @Double
in
[dotprodBench "dot 1D"
(iota 10_000_000
@@ -104,7 +97,7 @@ main_tests = defaultMain
in nf (\a -> RS.normalize a)
(RS.rev [0] (RS.rev [0] (RS.iota @Double n)))
]
- ,bgroupIf enableMisc "misc"
+ ,bgroup "misc"
[let n = 1000
k = 1000
in bgroup ("fusion [" ++ show k ++ "]*" ++ show n)
@@ -148,6 +141,16 @@ main_tests = defaultMain
| ki <- [1::Int ..5]]
]
]
+ ,bench "ixxFromLinear 10000x" $
+ let n = 10000
+ sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX
+ in nf (\sh -> [ixxFromLinear @Int sh i | i <- [1..n]]) sh0
+ ,bench "ixxFromLinear 1x" $
+ let sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX
+ in nf (\sh -> ixxFromLinear @Int sh 1234) sh0
+ ,bench "shxEnum" $
+ let sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX
+ in nf (\sh -> shxEnum sh) sh0
]
]
@@ -156,45 +159,54 @@ tests_compare =
let n = 1_000_000 in
[bgroup "Num"
[bench "sum(+) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (+)) a b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (+)) a b)))
(riota @Double n, riota n)
,bench "sum(*) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (*)) a b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (*)) a b)))
(riota @Double n, riota n)
,bench "sum(/) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (/)) a b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (/)) a b)))
(riota @Double n, riota n)
,bench "sum(**) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (**)) a b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (**)) a b)))
(riota @Double n, riota n)
,bench "sum(sin) Double [1e6]" $
- nf (\a -> runScalar (rsumOuter1 (liftRanked1 (mliftPrim sin) a)))
+ nf (\a -> runScalar (rsumOuter1Prim (liftRanked1 (mliftPrim sin) a)))
(riota @Double n)
,bench "sum Double [1e6]" $
- nf (\a -> runScalar (rsumOuter1 a))
+ nf (\a -> runScalar (rsumOuter1Prim a))
+ (riota @Double n)
+ ,bench "sumAll iota [1e6]" $
+ nf (\a -> rsumAllPrim a)
(riota @Double n)
+ ,bench "sumAll rev1(iota) [1e6]" $
+ nf (\a -> rsumAllPrim a)
+ (rrev1 $ riota @Double n)
+ ,bench "sumAll reshape(iota) [1e6]" $
+ nf (\a -> rsumAllPrim a)
+ (rreshape (1 :$: n :$: 1 :$: ZSR) $ riota @Double n)
]
,bgroup "NumElt"
[bench "sum(+) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (a + b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (a + b)))
(riota @Double n, riota n)
,bench "sum(*) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (a * b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (a * b)))
(riota @Double n, riota n)
,bench "sum(/) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (a / b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (a / b)))
(riota @Double n, riota n)
,bench "sum(**) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (a ** b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (a ** b)))
(riota @Double n, riota n)
,bench "sum(sin) Double [1e6]" $
- nf (\a -> runScalar (rsumOuter1 (sin a)))
+ nf (\a -> runScalar (rsumOuter1Prim (sin a)))
(riota @Double n)
,bench "sum Double [1e6]" $
- nf (\a -> runScalar (rsumOuter1 a))
+ nf (\a -> runScalar (rsumOuter1Prim a))
(riota @Double n)
,bench "sum(*) Double [1e6] stride 1; -1" $
- nf (\(a, b) -> runScalar (rsumOuter1 (a * b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (a * b)))
(riota @Double n, rrev1 (riota n))
,bench "dotprod Float [1e6]" $
nf (\(a, b) -> rdot a b)
diff --git a/cabal.project b/cabal.project
index d102ed6..d76d872 100644
--- a/cabal.project
+++ b/cabal.project
@@ -1,2 +1,2 @@
packages: .
-with-compiler: ghc-9.8.4
+with-compiler: ghc-9.12.2
diff --git a/cbits/arith.c b/cbits/arith.c
index f19b01e..ee248a4 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -20,6 +20,8 @@
// Shorter names, due to CPP used both in function names and in C types.
+typedef int8_t i8;
+typedef int16_t i16;
typedef int32_t i32;
typedef int64_t i64;
@@ -248,6 +250,8 @@ void oxarrays_stats_print_all(void) {
#define GEN_ABS(x) \
_Generic((x), \
+ i8: abs, \
+ i16: abs, \
int: abs, \
long: labs, \
long long: llabs, \
@@ -490,7 +494,9 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
if (rank == 0) return arr[0]; \
typ result = 0; \
TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
- REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, result); \
+ typ dest = 0; \
+ REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, dest); \
+ result = result op dest; \
}); \
return result; \
}
@@ -738,7 +744,7 @@ enum redop_tag_t {
* Generate all the functions *
*****************************************************************************/
-#define INT_TYPES_XLIST X(i32) X(i64)
+#define INT_TYPES_XLIST X(i8) X(i16) X(i32) X(i64)
#define FLOAT_TYPES_XLIST X(double) X(float)
#define NUM_TYPES_XLIST INT_TYPES_XLIST FLOAT_TYPES_XLIST
diff --git a/cbits/arith_lists.h b/cbits/arith_lists.h
index 432765c..dc9ad1a 100644
--- a/cbits/arith_lists.h
+++ b/cbits/arith_lists.h
@@ -2,38 +2,38 @@ LIST_BINOP(BO_ADD, 1, +)
LIST_BINOP(BO_SUB, 2, -)
LIST_BINOP(BO_MUL, 3, *)
-LIST_IBINOP(IB_QUOT, 1, quot)
-LIST_IBINOP(IB_REM, 2, rem)
+LIST_IBINOP(IB_QUOT, 11, quot)
+LIST_IBINOP(IB_REM, 12, rem)
-LIST_FBINOP(FB_DIV, 1, /)
-LIST_FBINOP(FB_POW, 2, **)
-LIST_FBINOP(FB_LOGBASE, 3, logBase)
-LIST_FBINOP(FB_ATAN2, 4, atan2)
+LIST_FBINOP(FB_DIV, 21, /)
+LIST_FBINOP(FB_POW, 22, **)
+LIST_FBINOP(FB_LOGBASE, 23, logBase)
+LIST_FBINOP(FB_ATAN2, 24, atan2)
-LIST_UNOP(UO_NEG, 1,)
-LIST_UNOP(UO_ABS, 2,)
-LIST_UNOP(UO_SIGNUM, 3,)
+LIST_UNOP(UO_NEG, 31,)
+LIST_UNOP(UO_ABS, 32,)
+LIST_UNOP(UO_SIGNUM, 33,)
-LIST_FUNOP(FU_RECIP, 1,)
-LIST_FUNOP(FU_EXP, 2,)
-LIST_FUNOP(FU_LOG, 3,)
-LIST_FUNOP(FU_SQRT, 4,)
-LIST_FUNOP(FU_SIN, 5,)
-LIST_FUNOP(FU_COS, 6,)
-LIST_FUNOP(FU_TAN, 7,)
-LIST_FUNOP(FU_ASIN, 8,)
-LIST_FUNOP(FU_ACOS, 9,)
-LIST_FUNOP(FU_ATAN, 10,)
-LIST_FUNOP(FU_SINH, 11,)
-LIST_FUNOP(FU_COSH, 12,)
-LIST_FUNOP(FU_TANH, 13,)
-LIST_FUNOP(FU_ASINH, 14,)
-LIST_FUNOP(FU_ACOSH, 15,)
-LIST_FUNOP(FU_ATANH, 16,)
-LIST_FUNOP(FU_LOG1P, 17,)
-LIST_FUNOP(FU_EXPM1, 18,)
-LIST_FUNOP(FU_LOG1PEXP, 19,)
-LIST_FUNOP(FU_LOG1MEXP, 20,)
+LIST_FUNOP(FU_RECIP, 41,)
+LIST_FUNOP(FU_EXP, 42,)
+LIST_FUNOP(FU_LOG, 43,)
+LIST_FUNOP(FU_SQRT, 44,)
+LIST_FUNOP(FU_SIN, 45,)
+LIST_FUNOP(FU_COS, 46,)
+LIST_FUNOP(FU_TAN, 47,)
+LIST_FUNOP(FU_ASIN, 48,)
+LIST_FUNOP(FU_ACOS, 49,)
+LIST_FUNOP(FU_ATAN, 50,)
+LIST_FUNOP(FU_SINH, 51,)
+LIST_FUNOP(FU_COSH, 52,)
+LIST_FUNOP(FU_TANH, 53,)
+LIST_FUNOP(FU_ASINH, 54,)
+LIST_FUNOP(FU_ACOSH, 55,)
+LIST_FUNOP(FU_ATANH, 56,)
+LIST_FUNOP(FU_LOG1P, 57,)
+LIST_FUNOP(FU_EXPM1, 58,)
+LIST_FUNOP(FU_LOG1PEXP, 59,)
+LIST_FUNOP(FU_LOG1MEXP, 60,)
-LIST_REDOP(RO_SUM, 1,)
-LIST_REDOP(RO_PRODUCT, 2,)
+LIST_REDOP(RO_SUM, 81,)
+LIST_REDOP(RO_PRODUCT, 82,)
diff --git a/gentrace.sh b/gentrace.sh
index 7be2b9c..c3f1240 100755
--- a/gentrace.sh
+++ b/gentrace.sh
@@ -8,7 +8,7 @@ module Data.Array.Nested.Trace (
-- * Re-exports from the plain "Data.Array.Nested" module
EOF
-sed -n '/^module/,/^) where/!d; /^\s*-- /d; s/ \b[a-z][a-zA-Z0-9_'"'"']*,//g; /^ $/d; s/(\.\., Z.., ([^)]*))/(..)/g; /^ /p; /^$/p' src/Data/Array/Nested.hs
+sed -n '/^module/,/^) where/!d; /^\s*--\( \|$\)/d; s/ \b[a-z][a-zA-Z0-9_'"'"']*,//g; /^ $/d; s/(\.\., Z.., ([^)]*))/(..)/g; /^ /p; /^$/p' src/Data/Array/Nested.hs
cat <<'EOF'
) where
diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs
index 5802573..2eb0666 100644
--- a/ops/Data/Array/Strided/Arith/Internal.hs
+++ b/ops/Data/Array/Strided/Arith/Internal.hs
@@ -27,7 +27,7 @@ import Foreign.C.Types
import Foreign.Ptr
import Foreign.Storable
import GHC.TypeLits
-import GHC.TypeNats qualified as TypeNats
+import GHC.TypeNats qualified as TN
import Language.Haskell.TH
import System.IO (hFlush, stdout)
import System.IO.Unsafe
@@ -42,7 +42,7 @@ import Data.Array.Strided.Array
-- TODO: move this to a utilities module
fromSNat' :: SNat n -> Int
-fromSNat' = fromIntegral . fromSNat
+fromSNat' = fromEnum . TN.fromSNat
data Dict c where
Dict :: c => Dict c
@@ -179,7 +179,7 @@ unreplicateStrides (Array sh strides offset vec) =
unrepSize = product [n | (n, True) <- zip sh replDims]
- in TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
+ in TN.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
Unreplicated (Array @lenshF shF stridesF offset vec) unrepSize (reinsertZeros replDims)
simplifyArray :: Array n a
@@ -200,7 +200,7 @@ simplifyArray :: Array n a
-> r)
-> r
simplifyArray array k
- | let revDims = map (<0) (arrStrides array)
+ | let revDims = map (< 0) (arrStrides array)
, Unreplicated array' unrepSize rereplicate <- unreplicateStrides (arrayRevDims revDims array)
= k array'
unrepSize
@@ -258,7 +258,7 @@ simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k
, let unrepSize = product [n | (n, True) <- zip sh replDims]
- = TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
+ = TN.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
k @lenshF
(Array shF strides1F offset1 vec1)
(Array shF strides2F offset2 vec2)
@@ -386,7 +386,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides off
VS.unsafeWith vec' $ \pv ->
let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a))
in fred (fromIntegral ndims') (ptrconv poutv) psh pstrides (ptrconv pv')
- TypeNats.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do
+ TN.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do
(Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
LTI -> pure Dict
EQI -> pure Dict
@@ -396,6 +396,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides off
Nothing -> error "impossible"
-- TODO: test handling of negative strides
+-- TODO: simplify away normalised dimensions
-- | Reduce full array
{-# NOINLINE vectorRedFullOp #-}
vectorRedFullOp :: forall a b n. (Num a, Storable a)
@@ -490,7 +491,7 @@ vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner
fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv)
pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1'))
pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2'))
- TypeNats.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do
+ TN.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do
(Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
LTI -> pure Dict
EQI -> pure Dict
@@ -714,6 +715,36 @@ class NumElt a where
numEltMaxIndex :: SNat n -> Array n a -> [Int]
numEltDotprodInner :: SNat n -> Array (n + 1) a -> Array (n + 1) a -> Array n a
+instance NumElt Int8 where
+ numEltAdd = addVectorInt8
+ numEltSub = subVectorInt8
+ numEltMul = mulVectorInt8
+ numEltNeg = negVectorInt8
+ numEltAbs = absVectorInt8
+ numEltSignum = signumVectorInt8
+ numEltSum1Inner = sum1VectorInt8
+ numEltProduct1Inner = product1VectorInt8
+ numEltSumFull = sumFullVectorInt8
+ numEltProductFull = productFullVectorInt8
+ numEltMinIndex _ = minindexVectorInt8
+ numEltMaxIndex _ = maxindexVectorInt8
+ numEltDotprodInner = dotprodinnerVectorInt8
+
+instance NumElt Int16 where
+ numEltAdd = addVectorInt16
+ numEltSub = subVectorInt16
+ numEltMul = mulVectorInt16
+ numEltNeg = negVectorInt16
+ numEltAbs = absVectorInt16
+ numEltSignum = signumVectorInt16
+ numEltSum1Inner = sum1VectorInt16
+ numEltProduct1Inner = product1VectorInt16
+ numEltSumFull = sumFullVectorInt16
+ numEltProductFull = productFullVectorInt16
+ numEltMinIndex _ = minindexVectorInt16
+ numEltMaxIndex _ = maxindexVectorInt16
+ numEltDotprodInner = dotprodinnerVectorInt16
+
instance NumElt Int32 where
numEltAdd = addVectorInt32
numEltSub = subVectorInt32
@@ -830,6 +861,14 @@ class NumElt a => IntElt a where
intEltQuot :: SNat n -> Array n a -> Array n a -> Array n a
intEltRem :: SNat n -> Array n a -> Array n a -> Array n a
+instance IntElt Int8 where
+ intEltQuot = quotVectorInt8
+ intEltRem = remVectorInt8
+
+instance IntElt Int16 where
+ intEltQuot = quotVectorInt16
+ intEltRem = remVectorInt16
+
instance IntElt Int32 where
intEltQuot = quotVectorInt32
intEltRem = remVectorInt32
@@ -840,19 +879,19 @@ instance IntElt Int64 where
instance IntElt Int where
intEltQuot = intWidBranch2 @Int quot
- (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
- (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+ (c_ibinary_i32_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vv_strided (aiboEnum IB_QUOT))
+ (c_ibinary_i64_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vv_strided (aiboEnum IB_QUOT))
intEltRem = intWidBranch2 @Int rem
- (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
- (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+ (c_ibinary_i32_sv_strided (aiboEnum IB_REM)) (c_ibinary_i32_vs_strided (aiboEnum IB_REM)) (c_ibinary_i32_vv_strided (aiboEnum IB_REM))
+ (c_ibinary_i64_sv_strided (aiboEnum IB_REM)) (c_ibinary_i64_vs_strided (aiboEnum IB_REM)) (c_ibinary_i64_vv_strided (aiboEnum IB_REM))
instance IntElt CInt where
intEltQuot = intWidBranch2 @CInt quot
- (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
- (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+ (c_ibinary_i32_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vv_strided (aiboEnum IB_QUOT))
+ (c_ibinary_i64_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vv_strided (aiboEnum IB_QUOT))
intEltRem = intWidBranch2 @CInt rem
- (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
- (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+ (c_ibinary_i32_sv_strided (aiboEnum IB_REM)) (c_ibinary_i32_vs_strided (aiboEnum IB_REM)) (c_ibinary_i32_vv_strided (aiboEnum IB_REM))
+ (c_ibinary_i64_sv_strided (aiboEnum IB_REM)) (c_ibinary_i64_vs_strided (aiboEnum IB_REM)) (c_ibinary_i64_vv_strided (aiboEnum IB_REM))
class NumElt a => FloatElt a where
floatEltDiv :: SNat n -> Array n a -> Array n a -> Array n a
diff --git a/ops/Data/Array/Strided/Arith/Internal/Lists.hs b/ops/Data/Array/Strided/Arith/Internal/Lists.hs
index 910a77c..27204d2 100644
--- a/ops/Data/Array/Strided/Arith/Internal/Lists.hs
+++ b/ops/Data/Array/Strided/Arith/Internal/Lists.hs
@@ -16,7 +16,9 @@ data ArithType = ArithType
intTypesList :: [ArithType]
intTypesList =
- [ArithType ''Int32 "i32"
+ [ArithType ''Int8 "i8"
+ ,ArithType ''Int16 "i16"
+ ,ArithType ''Int32 "i32"
,ArithType ''Int64 "i64"
]
diff --git a/ox-arrays.cabal b/ox-arrays.cabal
index 2a15cbb..b3e5a98 100644
--- a/ox-arrays.cabal
+++ b/ox-arrays.cabal
@@ -1,18 +1,23 @@
cabal-version: 3.0
name: ox-arrays
-version: 0.1.0.0
+version: 0.2.0.1
synopsis: An efficient CPU-based multidimensional array (tensor) library
description:
An efficient and richly typed CPU-based multidimensional array (tensor)
library built upon the optimized tensor representation (strides list)
- implemented in the orthotope package.
-copyright: (c) 2025 Tom Smeding, Mikolaj Konarski
+ implemented in the orthotope package. See the README.
+
+ If you use this package: let me know (e.g. via email) if you find it useful!
+ Both positive feedback (keep this!) and negative feedback (I needed this but
+ ox-arrays doesn't provide it) is welcome.
+copyright: (c) 2025-2026 Tom Smeding, Mikolaj Konarski
author: Tom Smeding, Mikolaj Konarski
maintainer: Tom Smeding <xhackage@tomsmeding.com>
license: BSD-3-Clause
category: Array, Tensors
build-type: Simple
+extra-doc-files: README.md CHANGELOG.md
extra-source-files: cbits/arith_lists.h
flag trace-wrappers
@@ -20,7 +25,7 @@ flag trace-wrappers
Compile modules that define wrappers around the array methods that trace
their arguments and results. This is conditional on a flag because these
modules make documentation generation fail.
- (https://gitlab.haskell.org/ghc/ghc/-/issues/24964 , should be fixed in
+ (@https://gitlab.haskell.org/ghc/ghc/-/issues/24964@ , should be fixed in
GHC 9.12)
default: False
manual: True
@@ -42,47 +47,74 @@ flag pedantic-c-warnings
default: False
manual: True
+flag default-show-instances
+ description:
+ Use default GHC-derived Show instances for arrays, shapes and indices. This
+ exposes the internal struct-of-arrays representation and is less readable,
+ but can be useful for ox-arrays debugging.
+ default: False
+ manual: True
+
+common basics
+ default-language: Haskell2010
+ ghc-options: -Wall -Wcompat -Widentities -Wunused-packages -Wpartial-fields -Wredundant-bang-patterns -Woperator-whitespace -Wredundant-strictness-flags
+ if impl(ghc >= 9.14)
+ ghc-options: -Wno-pattern-namespace-specifier
+
library
+ import: basics
exposed-modules:
-- put this module on top so ghci considers it the "main" module
Data.Array.Nested
- Data.Array.Mixed.Internal.Arith
- Data.Array.Mixed.Lemmas
- Data.Array.Mixed.Permutation
- Data.Array.Mixed.Types
- Data.Array.Mixed.XArray
- Data.Array.Nested.Internal.Convert
- Data.Array.Nested.Internal.Mixed
- Data.Array.Nested.Internal.Lemmas
- Data.Array.Nested.Internal.Ranked
- Data.Array.Nested.Internal.Shaped
+ Data.Array.Nested.Convert
+ Data.Array.Nested.Mixed
+ Data.Array.Nested.Mixed.ListX
Data.Array.Nested.Mixed.Shape
+ Data.Array.Nested.Lemmas
+ Data.Array.Nested.Permutation
+ Data.Array.Nested.Ranked
+ Data.Array.Nested.Ranked.Base
Data.Array.Nested.Ranked.Shape
+ Data.Array.Nested.Shaped
+ Data.Array.Nested.Shaped.Base
Data.Array.Nested.Shaped.Shape
+ Data.Array.Nested.Types
+ Data.Array.Strided.Orthotope
+ Data.Array.XArray
Data.Bag
+ Data.Vector.Generic.Checked
+
+ if impl(ghc < 9.8)
+ exposed-modules:
+ GHC.TypeLits.Orphans
if flag(trace-wrappers)
exposed-modules:
Data.Array.Nested.Trace
Data.Array.Nested.Trace.TH
+ build-depends:
+ template-haskell
+ other-extensions: TemplateHaskell
+
+ if flag(default-show-instances)
+ cpp-options: -DOXAR_DEFAULT_SHOW_INSTANCES
build-depends:
- strided-array-ops,
+ ox-arrays:strided-array-ops,
base,
deepseq < 1.7,
ghc-typelits-knownnat,
ghc-typelits-natnormalise,
- orthotope < 0.2,
- vector
+ orthotope >= 0.1.8.0 && < 0.2,
+ vector < 0.14,
+ vector-stream < 0.14
hs-source-dirs: src
- default-language: Haskell2010
- ghc-options: -Wall -Wcompat -Widentities -Wunused-packages
- other-extensions: TemplateHaskell
-
library strided-array-ops
+ import: basics
+ visibility: public
exposed-modules:
Data.Array.Strided
Data.Array.Strided.Array
@@ -92,9 +124,11 @@ library strided-array-ops
Data.Array.Strided.Arith.Internal.Lists
Data.Array.Strided.Arith.Internal.Lists.TH
build-depends:
- base >=4.18 && <4.22,
- ghc-typelits-knownnat < 1,
- ghc-typelits-natnormalise < 1,
+ base >=4.18 && <4.23,
+ ghc-typelits-knownnat >= 0.8 && < 1
+ -- 0.9.0 is unsound: https://github.com/clash-lang/ghc-typelits-natnormalise/issues/105
+ && (< 0.9 || >= 0.9.1),
+ ghc-typelits-natnormalise >= 0.8.1 && < 1,
template-haskell < 3,
vector < 0.14
hs-source-dirs: ops
@@ -109,11 +143,10 @@ library strided-array-ops
-- hmatrix assumes sse2, so we can too
cc-options: -msse2
- default-language: Haskell2010
- ghc-options: -Wall -Wcompat -Widentities -Wunused-packages
other-extensions: TemplateHaskell
test-suite test
+ import: basics
type: exitcode-stdio-1.0
main-is: Main.hs
other-modules:
@@ -134,33 +167,29 @@ test-suite test
tasty-hedgehog,
vector
hs-source-dirs: test
- default-language: Haskell2010
- ghc-options: -Wall -Wcompat -Widentities -Wunused-packages
test-suite example
+ import: basics
type: exitcode-stdio-1.0
main-is: Main.hs
build-depends:
ox-arrays,
base
hs-source-dirs: example
- default-language: Haskell2010
- ghc-options: -Wall -Wcompat -Widentities -Wunused-packages
benchmark bench
+ import: basics
type: exitcode-stdio-1.0
main-is: Main.hs
build-depends:
ox-arrays,
- strided-array-ops,
+ ox-arrays:strided-array-ops,
base,
hmatrix,
orthotope,
tasty-bench,
vector
hs-source-dirs: bench
- default-language: Haskell2010
- ghc-options: -Wall -Wcompat -Widentities -Wunused-packages
source-repository head
type: git
diff --git a/release-hints.txt b/release-hints.txt
index 259c671..2623caa 100644
--- a/release-hints.txt
+++ b/release-hints.txt
@@ -1,2 +1,5 @@
- Temporarily enable -Wredundant-constraints
- Has too many false-positives to enable normally, but sometimes catches actual redundant constraints
+- Don't forget to rerun gentrace.sh
+- Test with GHC 9.6, it's rather picky around type-level nats
+ - Whenever we drop support for GHC 9.6, search for "9,8" and remove all the conditionals, as well as the GHC.TypeLits.Orphans module
diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs
deleted file mode 100644
index ca82573..0000000
--- a/src/Data/Array/Mixed/Lemmas.hs
+++ /dev/null
@@ -1,137 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeOperators #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed.Lemmas where
-
-import Data.Proxy
-import Data.Type.Equality
-import GHC.TypeLits
-
-import Data.Array.Mixed.Permutation
-import Data.Array.Nested.Mixed.Shape
-import Data.Array.Mixed.Types
-
-
--- * Reasoning helpers
-
-subst1 :: forall f a b. a :~: b -> f a :~: f b
-subst1 Refl = Refl
-
-subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
-subst2 Refl = Refl
-
-
--- * Lemmas
-
--- ** Nat
-
-lemLeqSuccSucc :: k + 1 <= n => Proxy k -> Proxy n -> (k <=? n - 1) :~: True
-lemLeqSuccSucc _ _ = unsafeCoerceRefl
-
-lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True
-lemLeqPlus _ _ _ = Refl
-
-
--- ** Append
-
-lemAppNil :: l ++ '[] :~: l
-lemAppNil = unsafeCoerceRefl
-
-lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
-lemAppAssoc _ _ _ = unsafeCoerceRefl
-
-lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l
-lemAppLeft _ Refl = Refl
-
-
--- ** Rank
-
-lemRankApp :: forall sh1 sh2.
- StaticShX sh1 -> StaticShX sh2
- -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
-lemRankApp ZKX _ = Refl
-lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2
- = lem2 (Proxy @(Rank sh1T)) Proxy Proxy $
- lem (Proxy @(Rank sh2)) (Proxy @(Rank sh1T)) (Proxy @(Rank (sh1T ++ sh2))) $
- lemRankApp ssh1 ssh2
- where
- lem :: proxy a -> proxy b -> proxy c
- -> c :~: b + a
- -> b + a :~: c
- lem _ _ _ Refl = Refl
-
- lem2 :: proxy a -> proxy b -> proxy c
- -> (a + b :~: c)
- -> c + 1 :~: (a + 1 + b)
- lem2 _ _ _ Refl = Refl
-
-lemRankAppComm :: StaticShX sh1 -> StaticShX sh2
- -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
-lemRankAppComm _ _ = unsafeCoerceRefl -- TODO improve this
-
-lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n
-lemRankReplicate SZ = Refl
-lemRankReplicate (SS (n :: SNat nm1))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1
- , Refl <- lemRankReplicate n
- = Refl
-
-
--- ** Various type families
-
-lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a
- -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
-lemReplicatePlusApp sn _ _ = go sn
- where
- go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
- go SZ = Refl
- go (SS (n :: SNat n'm1))
- | Refl <- lemReplicateSucc @a @n'm1
- , Refl <- go n
- = sym (lemReplicateSucc @a @(n'm1 + m))
-
-lemDropLenApp :: Rank l1 <= Rank l2
- => Proxy l1 -> Proxy l2 -> Proxy rest
- -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest)
-lemDropLenApp _ _ _ = unsafeCoerceRefl
-
-lemTakeLenApp :: Rank l1 <= Rank l2
- => Proxy l1 -> Proxy l2 -> Proxy rest
- -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest)
-lemTakeLenApp _ _ _ = unsafeCoerceRefl
-
-lemInitApp :: Proxy l -> Proxy x -> Init (l ++ '[x]) :~: l
-lemInitApp _ _ = unsafeCoerceRefl
-
-lemLastApp :: Proxy l -> Proxy x -> Last (l ++ '[x]) :~: x
-lemLastApp _ _ = unsafeCoerceRefl
-
-
--- ** KnownNat
-
-lemKnownNatSucc :: KnownNat n => Dict KnownNat (n + 1)
-lemKnownNatSucc = Dict
-
-lemKnownNatRank :: ShX sh i -> Dict KnownNat (Rank sh)
-lemKnownNatRank ZSX = Dict
-lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict
-
-lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
-lemKnownNatRankSSX ZKX = Dict
-lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
-
-
--- ** Known shapes
-
-lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing)
-lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn)
-
-lemKnownShX :: StaticShX sh -> Dict KnownShX sh
-lemKnownShX ZKX = Dict
-lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict
-lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 8198a54..6d4ae78 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -3,15 +3,19 @@
module Data.Array.Nested (
-- * Ranked arrays
Ranked(Ranked),
- ListR(ZR, (:::)),
IxR(.., ZIR, (:.:)), IIxR,
ShR(.., ZSR, (:$:)), IShR,
- rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rsumOuter1, rsumAllPrim,
+ rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rgeneratePrim, rsumOuter1Prim, rsumAllPrim,
rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar,
remptyArray,
- rrerank,
- rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
- rfromListLinear, rfromListPrimLinear, rtoListLinear,
+ rrerankPrim,
+ rreplicate, rreplicatePrim,
+ rfromListOuter, rfromListOuterN,
+ rfromList1, rfromList1N,
+ rfromListLinear,
+ rfromList1Prim, rfromList1PrimN,
+ rfromListPrimLinear,
+ rtoListOuter, rtoList, rtoListLinear, rtoListPrim, rtoListPrimLinear,
rslice, rrev1, rreshape, rflatten, riota,
rminIndexPrim, rmaxIndexPrim, rdot1Inner, rdot,
rnest, runNest, rzip, runzip,
@@ -19,7 +23,7 @@ module Data.Array.Nested (
rlift, rlift2,
-- ** Conversions
rtoXArrayPrim, rfromXArrayPrim,
- rcastToShaped, rtoMixed, rcastToMixed,
+ rtoMixed, rcastToMixed, rcastToShaped,
rfromOrthotope, rtoOrthotope,
-- ** Additional arithmetic operations
--
@@ -28,16 +32,16 @@ module Data.Array.Nested (
-- * Shaped arrays
Shaped(Shaped),
- ListS(ZS, (::$)),
IxS(.., ZIS, (:.$)), IIxS,
ShS(.., ZSS, (:$$)), KnownShS(..),
- sshape, srank, ssize, sindex, sindexPartial, sgenerate, ssumOuter1, ssumAllPrim,
+ sshape, srank, ssize, sindex, sindexPartial, sgenerate, sgeneratePrim, ssumOuter1Prim, ssumAllPrim,
stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,
-- TODO: sconcat? What should its type be?
semptyArray,
- srerank,
- sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,
- sfromListLinear, sfromListPrimLinear, stoListLinear,
+ srerankPrim,
+ sreplicate, sreplicatePrim,
+ sfromListOuter, sfromList1, sfromListLinear, sfromList1Prim, sfromListPrimLinear,
+ stoListOuter, stoList, stoListLinear, stoListPrim, stoListPrimLinear,
sslice, srev1, sreshape, sflatten, siota,
sminIndexPrim, smaxIndexPrim, sdot1Inner, sdot,
snest, sunNest, szip, sunzip,
@@ -45,7 +49,7 @@ module Data.Array.Nested (
slift, slift2,
-- ** Conversions
stoXArrayPrim, sfromXArrayPrim,
- stoRanked, stoMixed, scastToMixed,
+ stoMixed, scastToMixed, stoRanked,
sfromOrthotope, stoOrthotope,
-- ** Additional arithmetic operations
--
@@ -54,18 +58,22 @@ module Data.Array.Nested (
-- * Mixed arrays
Mixed,
- ListX(ZX, (::%)),
IxX(.., ZIX, (:.%)), IIxX,
- ShX(.., ZSX, (:$%)), KnownShX(..), IShX,
+ ShX(.., (:$%)), KnownShX(..), IShX,
StaticShX(.., ZKX, (:!%)),
SMayNat(..),
- mshape, mrank, msize, mindex, mindexPartial, mgenerate, msumOuter1, msumAllPrim,
+ mshape, mrank, msize, mindex, mindexPartial, mgenerate, mgeneratePrim, msumOuter1Prim, msumAllPrim,
mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar,
memptyArray,
- mrerank,
- mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,
- mfromListLinear, mfromListPrimLinear, mtoListLinear,
- mslice, mrev1, mreshape, mflatten, miota,
+ mrerankPrim,
+ mreplicate, mreplicatePrim,
+ mfromListOuter, mfromListOuterN, mfromListOuterSN,
+ mfromList1, mfromList1N, mfromList1SN,
+ mfromListLinear,
+ mfromList1Prim, mfromList1PrimN, mfromList1PrimSN,
+ mfromListPrimLinear,
+ mtoListOuter, mtoList, mtoListLinear, mtoListPrim, mtoListPrimLinear,
+ msliceN, msliceSN, mslice, mrev1, mreshape, mflatten, miota,
mminIndexPrim, mmaxIndexPrim, mdot1Inner, mdot,
mnest, munNest, mzip, munzip,
-- ** Lifting orthotope operations to 'Mixed' arrays
@@ -73,9 +81,8 @@ module Data.Array.Nested (
-- ** Conversions
mtoXArrayPrim, mfromXArrayPrim,
mcast,
- mcastSafe, SafeMCast, SafeMCastSpec(..),
- mtoRanked, mcastToShaped,
- castCastable, Castable(..),
+ mcastToShaped, mtoRanked,
+ convert, Conversion(..),
-- ** Additional arithmetic operations
--
-- $integralRealFloat
@@ -92,7 +99,7 @@ module Data.Array.Nested (
Storable,
SNat, pattern SNat,
pattern SZ, pattern SS,
- Perm(..),
+ Perm(..), PermR,
IsPermutation,
KnownPerm(..),
NumElt, IntElt, FloatElt,
@@ -103,23 +110,25 @@ module Data.Array.Nested (
import Prelude hiding (mappend, mconcat)
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Types
-import Data.Array.Nested.Internal.Convert
-import Data.Array.Nested.Internal.Mixed
-import Data.Array.Nested.Internal.Ranked
-import Data.Array.Nested.Internal.Shaped
+import Foreign.Storable
+import GHC.TypeLits
+
+import Data.Array.Nested.Convert
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.ListX
import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Ranked
import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Shaped
import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
import Data.Array.Strided.Arith
-import Foreign.Storable
-import GHC.TypeLits
-- $integralRealFloat
--
--- These functions separate top-level functions, and not exposed in instances
--- for 'RealFloat' and 'Integral', because those classes include a variety of
--- other functions that make no sense for arrays.
+-- These functions are separate top-level functions, and not exposed in
+-- instances for 'RealFloat' and 'Integral', because those classes include a
+-- variety of other functions that make no sense for arrays.
-- This problem already occurs with 'fromInteger', 'fromRational' and 'pi', but
-- having 'Num', 'Fractional' and 'Floating' available is just too useful.
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
new file mode 100644
index 0000000..7619bdb
--- /dev/null
+++ b/src/Data/Array/Nested/Convert.hs
@@ -0,0 +1,344 @@
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
+{-# LANGUAGE TypeAbstractions #-}
+#endif
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+module Data.Array.Nested.Convert (
+ -- * Shape\/index\/list casting functions
+ -- ** To ranked
+ ixrFromIxS, ixrFromIxS', ixrFromIxX, shrFromShS, shrFromShXAnyShape, shrFromShX,
+ ixrCast, shrCast,
+ -- ** To shaped
+ ixsFromIxR, ixsFromIxR', ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX,
+ ixsCast,
+ -- ** To mixed
+ ixxFromIxR, ixxFromIxS, ixxFromIxS', shxFromShR, shxFromShS,
+ ixxCast, shxCast, shxCast',
+
+ -- * Array conversions
+ convert,
+ Conversion(..),
+
+ -- * Special cases of array conversions
+ --
+ -- | These functions can all be implemented using 'convert' in some way,
+ -- but some have fewer constraints.
+ rtoMixed, rcastToMixed, rcastToShaped,
+ stoMixed, scastToMixed, stoRanked,
+ mcast, mcastToShaped, mtoRanked,
+) where
+
+import Control.Category
+import Data.Coerce (coerce)
+import Data.Proxy
+import Data.Type.Equality
+import GHC.TypeLits
+
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.ListX
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Ranked.Base
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Shaped.Base
+import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
+
+-- * Shape or index or list casting functions
+
+-- * To ranked
+
+ixrFromIxS :: forall sh i. IxS sh i -> IxR (Rank sh) i
+ixrFromIxS
+ | Refl <- lemRankReplicate (Proxy @(Rank sh))
+ , Refl <- lemRankMapJust (Proxy @sh)
+ = coerce
+ Prelude.. (coerceEqualRankListX :: ListX (MapJust sh) i -> ListX (Replicate (Rank sh) Nothing) i)
+ Prelude.. coerce
+
+ixrFromIxS' :: forall sh i. SNat (Rank sh) -> IxS sh i -> IxR (Rank sh) i
+ixrFromIxS' _
+ | Refl <- lemRankReplicate (Proxy @(Rank sh))
+ , Refl <- lemRankMapJust (Proxy @sh)
+ = coerce
+ Prelude.. (coerceEqualRankListX :: ListX (MapJust sh) i -> ListX (Replicate (Rank sh) Nothing) i)
+ Prelude.. coerce
+
+-- ixrFromIxX re-exported
+
+shrFromShS :: ShS sh -> IShR (Rank sh)
+shrFromShS ZSS = ZSR
+shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh
+
+shrFromShXAnyShape :: IShX sh -> IShR (Rank sh)
+shrFromShXAnyShape ZSX = ZSR
+shrFromShXAnyShape (n :$% idx) = fromSMayNat' n :$: shrFromShXAnyShape idx
+
+shrFromShX :: IShX (Replicate n Nothing) -> IShR n
+shrFromShX = coerce
+
+-- ixrCast re-exported
+-- shrCast re-exported
+
+-- * To shaped
+
+ixsFromIxR :: forall sh i. IxR (Rank sh) i -> IxS sh i
+ixsFromIxR
+ | Refl <- lemRankReplicate (Proxy @(Rank sh))
+ , Refl <- lemRankMapJust (Proxy @sh)
+ = coerce
+ Prelude.. (coerceEqualRankListX :: ListX (Replicate (Rank sh) Nothing) i -> ListX (MapJust sh) i)
+ Prelude.. coerce
+
+ixsFromIxR' :: forall sh i. ShS sh -> IxR (Rank sh) i -> IxS sh i
+ixsFromIxR' _
+ | Refl <- lemRankReplicate (Proxy @(Rank sh))
+ , Refl <- lemRankMapJust (Proxy @sh)
+ = coerce
+ Prelude.. (coerceEqualRankListX :: ListX (Replicate (Rank sh) Nothing) i -> ListX (MapJust sh) i)
+ Prelude.. coerce
+
+-- ixsFromIxX re-exported
+
+-- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to
+-- the following, but less verbose:
+--
+-- > ixsFromIxX' sh idx = ixsFromIxX sh (ixxCast (shxFromShS sh) idx)
+ixsFromIxX' :: ShS sh -> IxX sh' i -> IxS sh i
+ixsFromIxX' ZSS ZIX = ZIS
+ixsFromIxX' (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX' sh idx
+ixsFromIxX' _ _ = error "ixsFromIxX': index rank does not match shape rank"
+
+-- | Produce an existential 'ShS' from an 'IShR'.
+withShsFromShR :: IShR n -> (forall sh. Rank sh ~ n => ShS sh -> r) -> r
+withShsFromShR ZSR k = k ZSS
+withShsFromShR (n :$: sh) k =
+ withShsFromShR sh $ \sh' ->
+ withSomeSNat (fromIntegral @Int @Integer n) $ \case
+ Just sn -> k (sn :$$ sh')
+ Nothing -> error $ "withShsFromShR: negative dimension size (" ++ show n ++ ")"
+
+shsFromShX :: IShX (MapJust sh) -> ShS sh
+shsFromShX = coerce
+
+-- | Produce an existential 'ShS' from an 'IShX'. If you already know that
+-- @sh'@ is @MapJust@ of something, use 'shsFromShX' instead.
+withShsFromShX :: IShX sh' -> (forall sh. Rank sh ~ Rank sh' => ShS sh -> r) -> r
+withShsFromShX ZSX k = k ZSS
+withShsFromShX (SKnown sn :$% sh) k =
+ withShsFromShX sh $ \sh' ->
+ k (sn :$$ sh')
+withShsFromShX (SUnknown n :$% sh) k =
+ withShsFromShX sh $ \sh' ->
+ withSomeSNat (fromIntegral @Int @Integer n) $ \case
+ Just sn -> k (sn :$$ sh')
+ Nothing -> error $ "withShsFromShX: negative SUnknown dimension size (" ++ show n ++ ")"
+
+-- If it ever matters for performance, this is unsafeCoercible.
+shsFromSSX :: StaticShX (MapJust sh) -> ShS sh
+shsFromSSX = shsFromShX Prelude.. shxFromSSX
+
+-- ixsCast re-exported
+
+-- * To mixed
+
+-- ixxFromIxR re-exported
+-- ixxFromIxS re-exported
+
+ixxFromIxS' :: StaticShX sh' -> IxS sh i -> IxX sh' i
+ixxFromIxS' sh' = ixxCast sh' Prelude.. ixxFromIxS
+
+shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i
+shxFromShR = coerce
+
+shxFromShS :: ShS sh -> IShX (MapJust sh)
+shxFromShS = coerce
+
+-- ixxCast re-exported
+-- shxCast re-exported
+-- shxCast' re-exported
+
+
+-- * Array conversions
+
+-- | The constructors that perform runtime shape checking are marked with a
+-- tick (@'@): 'ConvXS'' and 'ConvXX''. For the other constructors, the types
+-- ensure that the shapes are already compatible. To convert between 'Ranked'
+-- and 'Shaped', go via 'Mixed'.
+--
+-- The guiding principle behind 'Conversion' is that it should represent the
+-- array restructurings, or perhaps re-presentations, that do not change the
+-- underlying 'Data.Array.XArray.XArray's. This leads to the inclusion
+-- of some operations that do not look like simple conversions (casts)
+-- at first glance, like 'ConvZip'.
+--
+-- /Note/: Haddock gleefully renames type variables in constructors so that
+-- they match the data type head as much as possible. See the source for a more
+-- readable presentation of this data type.
+data Conversion a b where
+ ConvId :: Conversion a a
+ ConvCmp :: Conversion b c -> Conversion a b -> Conversion a c
+
+ ConvRX :: Conversion (Ranked n a) (Mixed (Replicate n Nothing) a)
+ ConvSX :: Conversion (Shaped sh a) (Mixed (MapJust sh) a)
+
+ ConvXR :: Elt a
+ => Conversion (Mixed sh a) (Ranked (Rank sh) a)
+ ConvXS :: Conversion (Mixed (MapJust sh) a) (Shaped sh a)
+ ConvXS' :: (Rank sh ~ Rank sh', Elt a)
+ => ShS sh'
+ -> Conversion (Mixed sh a) (Shaped sh' a)
+
+ ConvXX' :: (Rank sh ~ Rank sh', Elt a)
+ => StaticShX sh'
+ -> Conversion (Mixed sh a) (Mixed sh' a)
+
+ ConvRR :: Conversion a b
+ -> Conversion (Ranked n a) (Ranked n b)
+ ConvSS :: Conversion a b
+ -> Conversion (Shaped sh a) (Shaped sh b)
+ ConvXX :: Conversion a b
+ -> Conversion (Mixed sh a) (Mixed sh b)
+ ConvT2 :: Conversion a a'
+ -> Conversion b b'
+ -> Conversion (a, b) (a', b')
+
+ Conv0X :: Elt a
+ => Conversion a (Mixed '[] a)
+ ConvX0 :: Conversion (Mixed '[] a) a
+
+ ConvNest :: Elt a => StaticShX sh
+ -> Conversion (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a))
+ ConvUnnest :: Conversion (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a)
+
+ ConvZip :: (Elt a, Elt b)
+ => Conversion (Mixed sh a, Mixed sh b) (Mixed sh (a, b))
+ ConvUnzip :: (Elt a, Elt b)
+ => Conversion (Mixed sh (a, b)) (Mixed sh a, Mixed sh b)
+deriving instance Show (Conversion a b)
+
+instance Category Conversion where
+ id = ConvId
+ (.) = ConvCmp
+
+convert :: (Elt a, Elt b) => Conversion a b -> a -> b
+convert = \c x -> munScalar (go c (mscalar x))
+ where
+ -- The 'esh' is the extension shape: the conversion happens under a whole
+ -- bunch of additional dimensions that it does not touch. These dimensions
+ -- are 'esh'.
+ -- The strategy is to unwind step-by-step to a large Mixed array, and to
+ -- perform the required checks and conversions when re-nesting back up.
+ go :: Conversion a b -> Mixed esh a -> Mixed esh b
+ go ConvId x = x
+ go (ConvCmp c1 c2) x = go c1 (go c2 x)
+ go ConvRX (M_Ranked x) = x
+ go ConvSX (M_Shaped x) = x
+ go (ConvXR @_ @sh) (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)
+ = let ssx' = ssxAppend (ssxFromShX esh)
+ (ssxReplicate (shxRank (shxDropSSX @esh @sh (ssxFromShX esh) (mshape x))))
+ in M_Ranked (M_Nest esh (mcast ssx' x))
+ go ConvXS (M_Nest esh x) = M_Shaped (M_Nest esh x)
+ go (ConvXS' @sh @sh' sh') (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh')))
+ x))
+ go (ConvXX' @sh @sh' ssx) (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x
+ go (ConvRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
+ go (ConvSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
+ go (ConvXX c) (M_Nest esh x) = M_Nest esh (go c x)
+ go (ConvT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2)
+ go Conv0X (x :: Mixed esh a)
+ | Refl <- lemAppNil @esh
+ = M_Nest (mshape x) x
+ go ConvX0 (M_Nest @esh _ x)
+ | Refl <- lemAppNil @esh
+ = x
+ go (ConvNest @_ @sh @sh' ssh) (M_Nest @esh esh x)
+ | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (ssxFromShX esh `ssxAppend` ssh) (mshape x)) x)
+ go (ConvUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x))
+ | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh x
+ go ConvZip x =
+ -- no need to check that the two esh's are equal because they were zipped previously
+ let (M_Nest esh x1, M_Nest _ x2) = munzip x
+ in M_Nest esh (mzip x1 x2)
+ go ConvUnzip (M_Nest esh x) =
+ let (x1, x2) = munzip x
+ in mzip (M_Nest esh x1) (M_Nest esh x2)
+
+ lemRankAppRankEq :: Rank sh ~ Rank sh'
+ => Proxy esh -> Proxy sh -> Proxy sh'
+ -> Rank (esh ++ sh) :~: Rank (esh ++ sh')
+ lemRankAppRankEq _ _ _ = unsafeCoerceRefl
+
+ lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh
+ -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing)
+ lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl
+
+ lemRankAppRankEqMapJust :: Rank sh ~ Rank sh'
+ => Proxy esh -> Proxy sh -> Proxy sh'
+ -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
+ lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl
+
+
+-- * Special cases of array conversions
+
+mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)
+ => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a
+mcast ssh2 arr
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr
+
+mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
+mtoRanked = convert ConvXR
+
+rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
+rtoMixed (Ranked arr) = arr
+
+-- | A more weakly-typed version of 'rtoMixed' that does a runtime shape
+-- compatibility check.
+rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a
+rcastToMixed sshx rarr@(Ranked arr)
+ | Refl <- lemRankReplicate (rrank rarr)
+ = mcast sshx arr
+
+mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
+ => ShS sh' -> Mixed sh a -> Shaped sh' a
+mcastToShaped targetsh = convert (ConvXS' targetsh)
+
+stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
+stoMixed (Shaped arr) = arr
+
+-- | A more weakly-typed version of 'stoMixed' that does a runtime shape
+-- compatibility check.
+scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
+ => StaticShX sh' -> Shaped sh a -> Mixed sh' a
+scastToMixed sshx sarr@(Shaped arr)
+ | Refl <- lemRankMapJust (sshape sarr)
+ = mcast sshx arr
+
+stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
+stoRanked sarr@(Shaped arr)
+ | Refl <- lemRankMapJust (sshape sarr)
+ = mtoRanked arr
+
+rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
+rcastToShaped (Ranked arr) targetsh
+ | Refl <- lemRankReplicate (shxRank (shxFromShS targetsh))
+ , Refl <- lemRankMapJust targetsh
+ = mcastToShaped targetsh arr
diff --git a/src/Data/Array/Nested/Internal/Convert.hs b/src/Data/Array/Nested/Internal/Convert.hs
deleted file mode 100644
index 5d6cee4..0000000
--- a/src/Data/Array/Nested/Internal/Convert.hs
+++ /dev/null
@@ -1,86 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeAbstractions #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeOperators #-}
-module Data.Array.Nested.Internal.Convert where
-
-import Control.Category
-import Data.Proxy
-import Data.Type.Equality
-
-import Data.Array.Mixed.Lemmas
-import Data.Array.Nested.Mixed.Shape
-import Data.Array.Mixed.Types
-import Data.Array.Nested.Internal.Lemmas
-import Data.Array.Nested.Internal.Mixed
-import Data.Array.Nested.Internal.Ranked
-import Data.Array.Nested.Shaped.Shape
-import Data.Array.Nested.Internal.Shaped
-
-
-stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
-stoRanked sarr@(Shaped arr)
- | Refl <- lemRankMapJust (sshape sarr)
- = mtoRanked arr
-
-rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
-rcastToShaped (Ranked arr) targetsh
- | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh))
- , Refl <- lemRankMapJust targetsh
- = mcastToShaped arr targetsh
-
--- | The only constructor that performs runtime shape checking is 'CastXS''.
--- For the other construtors, the types ensure that the shapes are already
--- compatible. To convert between 'Ranked' and 'Shaped', go via 'Mixed'.
-data Castable a b where
- CastId :: Castable a a
- CastCmp :: Castable b c -> Castable a b -> Castable a c
-
- CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b)
- CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b)
-
- CastXR :: Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b)
- CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)
- CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh'
- -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b)
-
- CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b)
- CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b)
- CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b)
-
-instance Category Castable where
- id = CastId
- (.) = CastCmp
-
-castCastable :: (Elt a, Elt b) => Castable a b -> a -> b
-castCastable = \c x -> munScalar (go c (mscalar x))
- where
- -- The 'esh' is the extension shape: the casting happens under a whole
- -- bunch of additional dimensions that it does not touch. These dimensions
- -- are 'esh'.
- -- The strategy is to unwind step-by-step to a large Mixed array, and to
- -- perform the required checks and castings when re-nesting back up.
- go :: Castable a b -> Mixed esh a -> Mixed esh b
- go CastId x = x
- go (CastCmp c1 c2) x = go c1 (go c2 x)
- go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x)
- go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x)
- go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) =
- M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh sh esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy
- (go c x)))
- go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x))
- go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x)
- | Refl <- lemRankAppMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
- = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh')))
- (go c x)))
- go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
- go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
- go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
-
- lemRankAppMapJust :: Rank sh ~ Rank sh'
- => Proxy esh -> Proxy sh -> Proxy sh'
- -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
- lemRankAppMapJust _ _ _ = unsafeCoerceRefl
diff --git a/src/Data/Array/Nested/Internal/Lemmas.hs b/src/Data/Array/Nested/Internal/Lemmas.hs
deleted file mode 100644
index b8baf96..0000000
--- a/src/Data/Array/Nested/Internal/Lemmas.hs
+++ /dev/null
@@ -1,59 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeOperators #-}
-module Data.Array.Nested.Internal.Lemmas where
-
-import Data.Proxy
-import Data.Type.Equality
-import GHC.TypeLits
-
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
-import Data.Array.Nested.Mixed.Shape
-import Data.Array.Mixed.Types
-import Data.Array.Nested.Shaped.Shape
-
-
-lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh
-lemRankMapJust ZSS = Refl
-lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl
-
-lemMapJustApp :: ShS sh1 -> Proxy sh2
- -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
-lemMapJustApp ZSS _ = Refl
-lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl
-
-lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)
-lemTakeLenMapJust PNil _ = Refl
-lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl
-lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty"
-
-lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh)
-lemDropLenMapJust PNil _ = Refl
-lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl
-lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty"
-
-lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh)
-lemIndexMapJust SZ (_ :$$ _) = Refl
-lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh'))
- | Refl <- lemIndexMapJust i sh
- , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh'))
- , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = Refl
-lemIndexMapJust _ ZSS = error "Index of empty"
-
-lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh)
-lemPermuteMapJust PNil _ = Refl
-lemPermuteMapJust (i `PCons` is) sh
- | Refl <- lemPermuteMapJust is sh
- , Refl <- lemIndexMapJust i sh
- = Refl
-
-lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh)
-lemKnownMapJust _ = lemKnownShX (go (knownShS @sh))
- where
- go :: ShS sh' -> StaticShX (MapJust sh')
- go ZSS = ZKX
- go (n :$$ sh) = SKnown n :!% go sh
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
deleted file mode 100644
index 368e337..0000000
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ /dev/null
@@ -1,559 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingVia #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE InstanceSigs #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Nested.Internal.Ranked where
-
-import Prelude hiding (mappend, mconcat)
-
-import Control.DeepSeq (NFData(..))
-import Control.Monad.ST
-import Data.Array.RankedS qualified as S
-import Data.Bifunctor (first)
-import Data.Coerce (coerce)
-import Data.Foldable (toList)
-import Data.Kind (Type)
-import Data.List.NonEmpty (NonEmpty)
-import Data.Proxy
-import Data.Type.Equality
-import Data.Vector.Storable qualified as VS
-import Foreign.Storable (Storable)
-import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
-import GHC.Generics (Generic)
-import GHC.TypeLits
-import GHC.TypeNats qualified as TN
-
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Types
-import Data.Array.Mixed.XArray (XArray(..))
-import Data.Array.Mixed.XArray qualified as X
-import Data.Array.Nested.Internal.Mixed
-import Data.Array.Nested.Mixed.Shape
-import Data.Array.Nested.Ranked.Shape
-import Data.Array.Strided.Arith
-
-
--- | A rank-typed array: the number of dimensions of the array (its /rank/) is
--- represented on the type level as a 'Nat'.
---
--- Valid elements of a ranked arrays are described by the 'Elt' type class.
--- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
--- supported (and are represented as a single, flattened, struct-of-arrays
--- array internally).
---
--- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
-type Ranked :: Nat -> Type -> Type
-newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
-deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a)
-deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a)
-
-instance (Show a, Elt a) => Show (Ranked n a) where
- showsPrec d arr@(Ranked marr) =
- let sh = show (toList (rshape arr))
- in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr
-
-instance Elt a => NFData (Ranked n a) where
- rnf (Ranked arr) = rnf arr
-
--- just unwrap the newtype and defer to the general instance for nested arrays
-newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
- deriving (Generic)
-
-deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a))
-
-newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
-
--- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
--- these instances allow them to also be used as elements of arrays, thus
--- making them first-class in the API.
-instance Elt a => Elt (Ranked n a) where
- mshape (M_Ranked arr) = mshape arr
- mindex (M_Ranked arr) i = Ranked (mindex arr i)
-
- mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
- mindexPartial (M_Ranked arr) i =
- coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
- mindexPartial arr i
-
- mscalar (Ranked x) = M_Ranked (M_Nest ZSX x)
-
- mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a)
- mfromListOuter l = M_Ranked (mfromListOuter (coerce l))
-
- mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
- mtoListOuter (M_Ranked arr) =
- coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr)
-
- mlift :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
- -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
- mlift ssh2 f (M_Ranked arr) =
- coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
- mlift ssh2 f arr
-
- mlift2 :: forall sh1 sh2 sh3.
- StaticShX sh3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
- -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a)
- mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) =
- coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
- mlift2 ssh3 f arr1 arr2
-
- mliftL :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
- -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a))
- mliftL ssh2 f l =
- coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a)))
- @(NonEmpty (Mixed sh2 (Ranked n a))) $
- mliftL ssh2 f (coerce l)
-
- mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr)
-
- mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr)
-
- mconcat l = M_Ranked (mconcat (coerce l))
-
- mrnf (M_Ranked arr) = mrnf arr
-
- type ShapeTree (Ranked n a) = (IShR n, ShapeTree a)
-
- mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr)
-
- mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
-
- mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
-
- mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
-
- marrayStrides (M_Ranked arr) = marrayStrides arr
-
- mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
- mvecsWrite sh idx (Ranked arr) vecs =
- mvecsWrite sh idx arr
- (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
- vecs)
-
- mvecsWritePartial :: forall sh sh' s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
- -> MixedVecs s (sh ++ sh') (Ranked n a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs =
- mvecsWritePartial sh idx
- (coerce @(Mixed sh' (Ranked n a))
- @(Mixed sh' (Mixed (Replicate n Nothing) a))
- arr)
- (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
- @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
- vecs)
-
- mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
- mvecsFreeze sh vecs =
- coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
- @(Mixed sh (Ranked n a))
- <$> mvecsFreeze sh
- (coerce @(MixedVecs s sh (Ranked n a))
- @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
- vecs)
-
-instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
- memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a)
- memptyArrayUnsafe i
- | Dict <- lemKnownReplicate (SNat @n)
- = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
- memptyArrayUnsafe i
-
- mvecsUnsafeNew idx (Ranked arr)
- | Dict <- lemKnownReplicate (SNat @n)
- = MV_Ranked <$> mvecsUnsafeNew idx arr
-
- mvecsNewEmpty _
- | Dict <- lemKnownReplicate (SNat @n)
- = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
-
-
-liftRanked1 :: forall n a b.
- (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b)
- -> Ranked n a -> Ranked n b
-liftRanked1 = coerce
-
-liftRanked2 :: forall n a b c.
- (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c)
- -> Ranked n a -> Ranked n b -> Ranked n c
-liftRanked2 = coerce
-
-instance (NumElt a, PrimElt a) => Num (Ranked n a) where
- (+) = liftRanked2 (+)
- (-) = liftRanked2 (-)
- (*) = liftRanked2 (*)
- negate = liftRanked1 negate
- abs = liftRanked1 abs
- signum = liftRanked1 signum
- fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal"
-
-instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where
- fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal"
- recip = liftRanked1 recip
- (/) = liftRanked2 (/)
-
-instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where
- pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal"
- exp = liftRanked1 exp
- log = liftRanked1 log
- sqrt = liftRanked1 sqrt
- (**) = liftRanked2 (**)
- logBase = liftRanked2 logBase
- sin = liftRanked1 sin
- cos = liftRanked1 cos
- tan = liftRanked1 tan
- asin = liftRanked1 asin
- acos = liftRanked1 acos
- atan = liftRanked1 atan
- sinh = liftRanked1 sinh
- cosh = liftRanked1 cosh
- tanh = liftRanked1 tanh
- asinh = liftRanked1 asinh
- acosh = liftRanked1 acosh
- atanh = liftRanked1 atanh
- log1p = liftRanked1 GHC.Float.log1p
- expm1 = liftRanked1 GHC.Float.expm1
- log1pexp = liftRanked1 GHC.Float.log1pexp
- log1mexp = liftRanked1 GHC.Float.log1mexp
-
-rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
-rquotArray = liftRanked2 mquotArray
-rremArray = liftRanked2 mremArray
-
-ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
-ratan2Array = liftRanked2 matan2Array
-
-
-remptyArray :: KnownElt a => Ranked 1 a
-remptyArray = mtoRanked (memptyArray ZSX)
-
-rshape :: Elt a => Ranked n a -> IShR n
-rshape (Ranked arr) = shCvtXR' (mshape arr)
-
-rrank :: Elt a => Ranked n a -> SNat n
-rrank = shrRank . rshape
-
--- | The total number of elements in the array.
-rsize :: Elt a => Ranked n a -> Int
-rsize = shrSize . rshape
-
-rindex :: Elt a => Ranked n a -> IIxR n -> a
-rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
-
-rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a
-rindexPartial (Ranked arr) idx =
- Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
- (castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @Nothing))) arr)
- (ixCvtRX idx))
-
--- | __WARNING__: All values returned from the function must have equal shape.
--- See the documentation of 'mgenerate' for more details.
-rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a
-rgenerate sh f
- | sn@SNat <- shrRank sh
- , Dict <- lemKnownReplicate sn
- , Refl <- lemRankReplicate sn
- = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR))
-
--- | See the documentation of 'mlift'.
-rlift :: forall n1 n2 a. Elt a
- => SNat n2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
- -> Ranked n1 a -> Ranked n2 a
-rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr)
-
--- | See the documentation of 'mlift2'.
-rlift2 :: forall n1 n2 n3 a. Elt a
- => SNat n3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b)
- -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a
-rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)
-
-rsumOuter1P :: forall n a.
- (Storable a, NumElt a)
- => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
-rsumOuter1P (Ranked arr)
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = Ranked (msumOuter1P arr)
-
-rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
- => Ranked (n + 1) a -> Ranked n a
-rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive
-
-rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a
-rsumAllPrim (Ranked arr) = msumAllPrim arr
-
-rtranspose :: forall n a. Elt a => PermR -> Ranked n a -> Ranked n a
-rtranspose perm arr
- | sn@SNat <- rrank arr
- , Dict <- lemKnownReplicate sn
- , length perm <= fromIntegral (natVal (Proxy @n))
- = rlift sn
- (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm)
- arr
- | otherwise
- = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array"
-
-rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a
-rconcat
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = coerce mconcat
-
-rappend :: forall n a. Elt a
- => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
-rappend arr1 arr2
- | sn@SNat <- rrank arr1
- , Dict <- lemKnownReplicate sn
- , Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
- arr1 arr2
-
-rscalar :: Elt a => a -> Ranked 0 a
-rscalar x = Ranked (mscalar x)
-
-rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a)
-rfromVectorP sh v
- | Dict <- lemKnownReplicate (shrRank sh)
- = Ranked (mfromVectorP (shCvtRX sh) v)
-
-rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a
-rfromVector sh v = rfromPrimitive (rfromVectorP sh v)
-
-rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a
-rtoVectorP = coerce mtoVectorP
-
-rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
-rtoVector = coerce mtoVector
-
-rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
-rfromListOuter l
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
-
-rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
-rfromList1 l = Ranked (mfromList1 l)
-
-rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a
-rfromList1Prim l = Ranked (mfromList1Prim l)
-
-rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
-rtoListOuter (Ranked arr)
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
-
-rtoList1 :: Elt a => Ranked 1 a -> [a]
-rtoList1 = map runScalar . rtoListOuter
-
-rfromListPrim :: PrimElt a => [a] -> Ranked 1 a
-rfromListPrim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
-
-rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a
-rfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
- in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr)
-
-rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
-rfromListLinear sh l = rreshape sh (rfromList1 l)
-
-rtoListLinear :: Elt a => Ranked n a -> [a]
-rtoListLinear (Ranked arr) = mtoListLinear arr
-
-rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a
-rfromOrthotope sn arr
- | Refl <- lemRankReplicate sn
- = let xarr = XArray arr
- in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr))
-
-rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a
-rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr)))
- | Refl <- lemRankReplicate (shrRank $ shCvtXR' sh)
- = arr
-
-runScalar :: Elt a => Ranked 0 a -> a
-runScalar arr = rindex arr ZIR
-
-rnest :: forall n m a. Elt a => SNat n -> Ranked (n + m) a -> Ranked n (Ranked m a)
-rnest n arr
- | Refl <- lemReplicatePlusApp n (Proxy @m) (Proxy @(Nothing @Nat))
- = coerce (mnest (ssxFromSNat n) (coerce arr))
-
-runNest :: forall n m a. Elt a => Ranked n (Ranked m a) -> Ranked (n + m) a
-runNest rarr@(Ranked (M_Ranked (M_Nest _ arr)))
- | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat))
- = Ranked arr
-
-rzip :: Ranked n a -> Ranked n b -> Ranked n (a, b)
-rzip = coerce mzip
-
-runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b)
-runzip = coerce munzip
-
-rrerankP :: forall n1 n2 n a b. (Storable a, Storable b)
- => SNat n -> IShR n2
- -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b))
- -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b)
-rrerankP sn sh2 f (Ranked arr)
- | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat))
- , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat))
- = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2)
- (\a -> let Ranked r = f (Ranked a) in r)
- arr)
-
--- | If there is a zero-sized dimension in the @n@-prefix of the shape of the
--- input array, then there is no way to deduce the full shape of the output
--- array (more precisely, the @n2@ part): that could only come from calling
--- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in
--- this case; we choose to fill the @n2@ part of the output shape with zeros.
---
--- For example, if:
---
--- @
--- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21]
--- f :: Ranked 2 Int -> Ranked 3 Float
--- @
---
--- then:
---
--- @
--- rrerank _ _ _ f arr :: Ranked 5 Float
--- @
---
--- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the
--- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended
--- to return an array with shape all-0 here (it probably didn't), but there is
--- no better number to put here absent a subarray of the input to pass to @f@.
-rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
- => SNat n -> IShR n2
- -> (Ranked n1 a -> Ranked n2 b)
- -> Ranked (n + n1) a -> Ranked (n + n2) b
-rrerank sn sh2 f (rtoPrimitive -> arr) =
- rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr
-
-rreplicate :: forall n m a. Elt a
- => IShR n -> Ranked m a -> Ranked (n + m) a
-rreplicate sh (Ranked arr)
- | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat))
- = Ranked (mreplicate (shCvtRX sh) arr)
-
-rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
-rreplicateScalP sh x
- | Dict <- lemKnownReplicate (shrRank sh)
- = Ranked (mreplicateScalP (shCvtRX sh) x)
-
-rreplicateScal :: forall n a. PrimElt a
- => IShR n -> a -> Ranked n a
-rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
-
-rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
-rslice i n arr
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = rlift (rrank arr)
- (\_ -> X.sliceU i n)
- arr
-
-rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
-rrev1 arr =
- rlift (rrank arr)
- (\(_ :: StaticShX sh') ->
- case lemReplicateSucc @(Nothing @Nat) @n of
- Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))
- arr
-
-rreshape :: forall n n' a. Elt a
- => IShR n' -> Ranked n a -> Ranked n' a
-rreshape sh' rarr@(Ranked arr)
- | Dict <- lemKnownReplicate (rrank rarr)
- , Dict <- lemKnownReplicate (shrRank sh')
- = Ranked (mreshape (shCvtRX sh') arr)
-
-rflatten :: Elt a => Ranked n a -> Ranked 1 a
-rflatten (Ranked arr) = mtoRanked (mflatten arr)
-
-riota :: (Enum a, PrimElt a) => Int -> Ranked 1 a
-riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota
-
--- | Throws if the array is empty.
-rminIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
-rminIndexPrim rarr@(Ranked arr)
- | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
- = ixCvtXR (mminIndexPrim arr)
-
--- | Throws if the array is empty.
-rmaxIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
-rmaxIndexPrim rarr@(Ranked arr)
- | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
- = ixCvtXR (mmaxIndexPrim arr)
-
-rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a
-rdot1Inner arr1 arr2
- | SNat <- rrank arr1
- , Refl <- lemReplicatePlusApp (SNat @n) (Proxy @1) (Proxy @(Nothing @Nat))
- = coerce (mdot1Inner (Proxy @(Nothing @Nat))) arr1 arr2
-
--- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
--- Prefer 'rdot1Inner' if applicable.
-rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a
-rdot = coerce mdot
-
-rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
-rtoXArrayPrimP (Ranked arr) = first shCvtXR' (mtoXArrayPrimP arr)
-
-rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a)
-rtoXArrayPrim (Ranked arr) = first shCvtXR' (mtoXArrayPrim arr)
-
-rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a)
-rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr)
-
-rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a
-rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr)
-
-rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a
-rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr)
-
-rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a)
-rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr)
-
-mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
-mtoRanked arr
- | Refl <- lemRankReplicate (shxRank (mshape arr))
- = Ranked (mcast (ssxFromShape (convSh (mshape arr))) arr)
- where
- convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing)
- convSh ZSX = ZSX
- convSh (smn :$% (sh :: IShX sh'T))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T)
- = SUnknown (fromSMayNat' smn) :$% convSh sh
-
-rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
-rtoMixed (Ranked arr) = arr
-
--- | A more weakly-typed version of 'rtoMixed' that does a runtime shape
--- compatibility check.
-rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a
-rcastToMixed sshx rarr@(Ranked arr)
- | Refl <- lemRankReplicate (rrank rarr)
- = mcast sshx arr
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
deleted file mode 100644
index 86dcee2..0000000
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ /dev/null
@@ -1,495 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingVia #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE InstanceSigs #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Nested.Internal.Shaped where
-
-import Prelude hiding (mappend, mconcat)
-
-import Control.DeepSeq (NFData(..))
-import Control.Monad.ST
-import Data.Array.Internal.RankedG qualified as RG
-import Data.Array.Internal.RankedS qualified as RS
-import Data.Array.Internal.ShapedG qualified as SG
-import Data.Array.Internal.ShapedS qualified as SS
-import Data.Bifunctor (first)
-import Data.Coerce (coerce)
-import Data.Kind (Type)
-import Data.List.NonEmpty (NonEmpty)
-import Data.Proxy
-import Data.Type.Equality
-import Data.Vector.Storable qualified as VS
-import Foreign.Storable (Storable)
-import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
-import GHC.Generics (Generic)
-import GHC.TypeLits
-
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
-import Data.Array.Nested.Mixed.Shape
-import Data.Array.Mixed.Types
-import Data.Array.Mixed.XArray (XArray)
-import Data.Array.Mixed.XArray qualified as X
-import Data.Array.Nested.Internal.Lemmas
-import Data.Array.Nested.Internal.Mixed
-import Data.Array.Nested.Shaped.Shape
-import Data.Array.Strided.Arith
-
-
--- | A shape-typed array: the full shape of the array (the sizes of its
--- dimensions) is represented on the type level as a list of 'Nat's. Note that
--- these are "GHC.TypeLits" naturals, because we do not need induction over
--- them and we want very large arrays to be possible.
---
--- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
--- and 'Shaped' itself is again an instance of 'Elt' as well.
---
--- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
-type Shaped :: [Nat] -> Type -> Type
-newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
-deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a)
-deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a)
-
-instance (Show a, Elt a) => Show (Shaped n a) where
- showsPrec d arr@(Shaped marr) =
- let sh = show (shsToList (sshape arr))
- in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr
-
-instance Elt a => NFData (Shaped sh a) where
- rnf (Shaped arr) = rnf arr
-
--- just unwrap the newtype and defer to the general instance for nested arrays
-newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))
- deriving (Generic)
-
-deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a))
-
-newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))
-
-instance Elt a => Elt (Shaped sh a) where
- mshape (M_Shaped arr) = mshape arr
- mindex (M_Shaped arr) i = Shaped (mindex arr i)
-
- mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
- mindexPartial (M_Shaped arr) i =
- coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
- mindexPartial arr i
-
- mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)
-
- mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
- mfromListOuter l = M_Shaped (mfromListOuter (coerce l))
-
- mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
- mtoListOuter (M_Shaped arr)
- = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)
-
- mlift :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
- -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
- mlift ssh2 f (M_Shaped arr) =
- coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
- mlift ssh2 f arr
-
- mlift2 :: forall sh1 sh2 sh3.
- StaticShX sh3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
- -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a)
- mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) =
- coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
- mlift2 ssh3 f arr1 arr2
-
- mliftL :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
- -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a))
- mliftL ssh2 f l =
- coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a)))
- @(NonEmpty (Mixed sh2 (Shaped sh a))) $
- mliftL ssh2 f (coerce l)
-
- mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr)
-
- mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)
-
- mconcat l = M_Shaped (mconcat (coerce l))
-
- mrnf (M_Shaped arr) = mrnf arr
-
- type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
-
- mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr)
-
- mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
-
- mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
-
- mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
-
- marrayStrides (M_Shaped arr) = marrayStrides arr
-
- mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
- mvecsWrite sh idx (Shaped arr) vecs =
- mvecsWrite sh idx arr
- (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
- vecs)
-
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
- -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs =
- mvecsWritePartial sh idx
- (coerce @(Mixed sh2 (Shaped sh a))
- @(Mixed sh2 (Mixed (MapJust sh) a))
- arr)
- (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
- @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
- vecs)
-
- mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
- mvecsFreeze sh vecs =
- coerce @(Mixed sh' (Mixed (MapJust sh) a))
- @(Mixed sh' (Shaped sh a))
- <$> mvecsFreeze sh
- (coerce @(MixedVecs s sh' (Shaped sh a))
- @(MixedVecs s sh' (Mixed (MapJust sh) a))
- vecs)
-
-instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
- memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
- memptyArrayUnsafe i
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
- memptyArrayUnsafe i
-
- mvecsUnsafeNew idx (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = MV_Shaped <$> mvecsUnsafeNew idx arr
-
- mvecsNewEmpty _
- | Dict <- lemKnownMapJust (Proxy @sh)
- = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
-
-
-liftShaped1 :: forall sh a b.
- (Mixed (MapJust sh) a -> Mixed (MapJust sh) b)
- -> Shaped sh a -> Shaped sh b
-liftShaped1 = coerce
-
-liftShaped2 :: forall sh a b c.
- (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c)
- -> Shaped sh a -> Shaped sh b -> Shaped sh c
-liftShaped2 = coerce
-
-instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
- (+) = liftShaped2 (+)
- (-) = liftShaped2 (-)
- (*) = liftShaped2 (*)
- negate = liftShaped1 negate
- abs = liftShaped1 abs
- signum = liftShaped1 signum
- fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal"
-
-instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where
- fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"
- recip = liftShaped1 recip
- (/) = liftShaped2 (/)
-
-instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where
- pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"
- exp = liftShaped1 exp
- log = liftShaped1 log
- sqrt = liftShaped1 sqrt
- (**) = liftShaped2 (**)
- logBase = liftShaped2 logBase
- sin = liftShaped1 sin
- cos = liftShaped1 cos
- tan = liftShaped1 tan
- asin = liftShaped1 asin
- acos = liftShaped1 acos
- atan = liftShaped1 atan
- sinh = liftShaped1 sinh
- cosh = liftShaped1 cosh
- tanh = liftShaped1 tanh
- asinh = liftShaped1 asinh
- acosh = liftShaped1 acosh
- atanh = liftShaped1 atanh
- log1p = liftShaped1 GHC.Float.log1p
- expm1 = liftShaped1 GHC.Float.expm1
- log1pexp = liftShaped1 GHC.Float.log1pexp
- log1mexp = liftShaped1 GHC.Float.log1mexp
-
-squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
-squotArray = liftShaped2 mquotArray
-sremArray = liftShaped2 mremArray
-
-satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
-satan2Array = liftShaped2 matan2Array
-
-
-semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a
-semptyArray sh = Shaped (memptyArray (shCvtSX sh))
-
-sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
-sshape (Shaped arr) = shCvtXS' (mshape arr)
-
-srank :: Elt a => Shaped sh a -> SNat (Rank sh)
-srank = shsRank . sshape
-
--- | The total number of elements in the array.
-ssize :: Elt a => Shaped sh a -> Int
-ssize = shsSize . sshape
-
-sindex :: Elt a => Shaped sh a -> IIxS sh -> a
-sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)
-
-shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh
-shsTakeIx _ _ ZIS = ZSS
-shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx
-
-sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
-sindexPartial sarr@(Shaped arr) idx =
- Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
- (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr)
- (ixCvtSX idx))
-
--- | __WARNING__: All values returned from the function must have equal shape.
--- See the documentation of 'mgenerate' for more details.
-sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
-sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh))
-
--- | See the documentation of 'mlift'.
-slift :: forall sh1 sh2 a. Elt a
- => ShS sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
- -> Shaped sh1 a -> Shaped sh2 a
-slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr)
-
--- | See the documentation of 'mlift'.
-slift2 :: forall sh1 sh2 sh3 a. Elt a
- => ShS sh3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)
- -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
-slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2)
-
-ssumOuter1P :: forall sh n a. (Storable a, NumElt a)
- => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
-ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr)
-
-ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
- => Shaped (n : sh) a -> Shaped sh a
-ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive
-
-ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a
-ssumAllPrim (Shaped arr) = msumAllPrim arr
-
-stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a)
- => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a
-stranspose perm sarr@(Shaped arr)
- | Refl <- lemRankMapJust (sshape sarr)
- , Refl <- lemTakeLenMapJust perm (sshape sarr)
- , Refl <- lemDropLenMapJust perm (sshape sarr)
- , Refl <- lemPermuteMapJust perm (shsTakeLen perm (sshape sarr))
- , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh))
- = Shaped (mtranspose perm arr)
-
-sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
-sappend = coerce mappend
-
-sscalar :: Elt a => a -> Shaped '[] a
-sscalar x = Shaped (mscalar x)
-
-sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a)
-sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v)
-
-sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a
-sfromVector sh v = sfromPrimitive (sfromVectorP sh v)
-
-stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a
-stoVectorP = coerce mtoVectorP
-
-stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
-stoVector = coerce mtoVector
-
-sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
-sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l))
-
-sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
-sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1
-
-sfromList1Prim :: PrimElt a => SNat n -> [a] -> Shaped '[n] a
-sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim
-
-stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
-stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
-
-stoList1 :: Elt a => Shaped '[n] a -> [a]
-stoList1 = map sunScalar . stoListOuter
-
-sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
-sfromListPrim sn l
- | Refl <- lemAppNil @'[Just n]
- = let ssh = SUnknown () :!% ZKX
- xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l)
- in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr
-
-sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a
-sfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
- in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr)
-
-sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a
-sfromListLinear sh l = Shaped (mfromListLinear (shCvtSX sh) l)
-
-stoListLinear :: Elt a => Shaped sh a -> [a]
-stoListLinear (Shaped arr) = mtoListLinear arr
-
-sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a
-sfromOrthotope sh (SS.A (SG.A arr)) =
- Shaped (fromPrimitive (M_Primitive (shCvtSX sh) (X.XArray (RS.A (RG.A (shsToList sh) arr)))))
-
-stoOrthotope :: PrimElt a => Shaped sh a -> SS.Array sh a
-stoOrthotope (stoPrimitive -> Shaped (M_Primitive _ (X.XArray (RS.A (RG.A _ arr))))) = SS.A (SG.A arr)
-
-sunScalar :: Elt a => Shaped '[] a -> a
-sunScalar arr = sindex arr ZIS
-
-snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a)
-snest sh arr
- | Refl <- lemMapJustApp sh (Proxy @sh')
- = coerce (mnest (ssxFromShape (shCvtSX sh)) (coerce arr))
-
-sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a
-sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr)))
- | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh')
- = Shaped arr
-
-szip :: Shaped sh a -> Shaped sh b -> Shaped sh (a, b)
-szip = coerce mzip
-
-sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b)
-sunzip = coerce munzip
-
-srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
- => ShS sh -> ShS sh2
- -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))
- -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b)
-srerankP sh sh2 f sarr@(Shaped arr)
- | Refl <- lemMapJustApp sh (Proxy @sh1)
- , Refl <- lemMapJustApp sh (Proxy @sh2)
- = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh))))
- (shCvtSX sh2)
- (\a -> let Shaped r = f (Shaped a) in r)
- arr)
-
-srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
- => ShS sh -> ShS sh2
- -> (Shaped sh1 a -> Shaped sh2 b)
- -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b
-srerank sh sh2 f (stoPrimitive -> arr) =
- sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr
-
-sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a
-sreplicate sh (Shaped arr)
- | Refl <- lemMapJustApp sh (Proxy @sh')
- = Shaped (mreplicate (shCvtSX sh) arr)
-
-sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
-sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x)
-
-sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a
-sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x)
-
-sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
-sslice i n@SNat arr =
- let _ :$$ sh = sshape arr
- in slift (n :$$ sh) (\_ -> X.slice i n) arr
-
-srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a
-srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr
-
-sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shaped sh' a
-sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr)
-
-sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a
-sflatten arr =
- case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff
- n@SNat -> sreshape (n :$$ ZSS) arr
-
-siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
-siota sn = Shaped (miota sn)
-
--- | Throws if the array is empty.
-sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
-sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminIndexPrim arr)
-
--- | Throws if the array is empty.
-smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
-smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr)
-
-sdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
- => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a
-sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2)
- | Refl <- lemInitApp (Proxy @sh) (Proxy @n)
- , Refl <- lemLastApp (Proxy @sh) (Proxy @n)
- = case sshape sarr1 of
- _ :$$ _
- | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n])
- -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2)
- _ -> error "unreachable"
-
--- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
--- Prefer 'sdot1Inner' if applicable.
-sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a
-sdot = coerce mdot
-
-stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
-stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr)
-
-stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a)
-stoXArrayPrim (Shaped arr) = first shCvtXS' (mtoXArrayPrim arr)
-
-sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a)
-sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr)
-
-sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a
-sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr)
-
-sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a
-sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr)
-
-stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a)
-stoPrimitive (Shaped arr) = Shaped (toPrimitive arr)
-
-mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
- => Mixed sh a -> ShS sh' -> Shaped sh' a
-mcastToShaped arr targetsh
- | Refl <- lemRankMapJust targetsh
- = Shaped (mcast (ssxFromShape (shCvtSX targetsh)) arr)
-
-stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
-stoMixed (Shaped arr) = arr
-
--- | A more weakly-typed version of 'stoMixed' that does a runtime shape
--- compatibility check.
-scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
- => StaticShX sh' -> Shaped sh a -> Mixed sh' a
-scastToMixed sshx sarr@(Shaped arr)
- | Refl <- lemRankMapJust (sshape sarr)
- = mcast sshx arr
diff --git a/src/Data/Array/Nested/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs
new file mode 100644
index 0000000..850f8ea
--- /dev/null
+++ b/src/Data/Array/Nested/Lemmas.hs
@@ -0,0 +1,184 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Lemmas (
+ module Data.Array.Nested.Lemmas,
+ lemReplicateSucc, lemMapJustEmpty, lemMapJustCons, lemMapJustHead,
+ lemRankMapJust
+) where
+
+import Data.Proxy
+import Data.Type.Equality
+import GHC.TypeLits
+
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
+
+
+-- * Lemmas about numbers and lists
+
+-- ** Nat
+
+lemLeqSuccSucc :: k + 1 <= n => Proxy k -> Proxy n -> (k <=? n - 1) :~: True
+lemLeqSuccSucc _ _ = unsafeCoerceRefl
+
+lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True
+lemLeqPlus _ _ _ = Refl
+
+-- ** Append
+
+lemAppNil :: l ++ '[] :~: l
+lemAppNil = unsafeCoerceRefl
+
+lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
+lemAppAssoc _ _ _ = unsafeCoerceRefl
+
+lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l
+lemAppLeft _ Refl = Refl
+
+-- ** Simple type families
+
+lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a
+ -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
+{- for now, the plugins can't derive a type for this code, see
+ https://github.com/clash-lang/ghc-typelits-natnormalise/pull/98#issuecomment-3332842214
+lemReplicatePlusApp sn _ _ = go sn
+ where
+ go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
+ go SZ = Refl
+ go (SS (n :: SNat n'm1))
+ | Refl <- lemReplicateSucc @a n
+ , Refl <- go n
+ = sym (lemReplicateSucc @a (SNat @(n'm1 + m)))
+-}
+lemReplicatePlusApp _ _ _ = unsafeCoerceRefl
+
+lemReplicateEmpty :: proxy n -> Replicate n (Nothing @Nat) :~: '[] -> n :~: 0
+lemReplicateEmpty _ Refl = unsafeCoerceRefl
+
+-- TODO: make less ad-hoc and rename the following few:
+lemReplicateCons :: proxy sh -> proxy' n1 -> Nothing : sh :~: Replicate n1 Nothing -> n1 :~: Rank sh + 1
+lemReplicateCons _ _ Refl = unsafeCoerceRefl
+
+lemReplicateCons2 :: proxy sh -> proxy' n1 -> Nothing : sh :~: Replicate n1 Nothing -> sh :~: Replicate (Rank sh) Nothing
+lemReplicateCons2 _ _ Refl = unsafeCoerceRefl
+
+lemReplicateSucc2 :: forall n1 n proxy.
+ proxy n1 -> n + 1 :~: n1 -> Nothing @Nat : Replicate n Nothing :~: Replicate n1 Nothing
+lemReplicateSucc2 _ _ = unsafeCoerceRefl
+
+-- TODO: simplify, but GHC doesn't consistently use congruence nor transitivity
+lemReplicateHead :: proxy x -> proxy' sh -> proxy'' t -> proxy''' n -> x : sh :~: Replicate n t -> x :~: t
+lemReplicateHead _ _ _ _ Refl = unsafeCoerceRefl
+
+lemDropLenApp :: Rank l1 <= Rank l2
+ => Proxy l1 -> Proxy l2 -> Proxy rest
+ -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest)
+lemDropLenApp _ _ _ = unsafeCoerceRefl
+
+lemTakeLenApp :: Rank l1 <= Rank l2
+ => Proxy l1 -> Proxy l2 -> Proxy rest
+ -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest)
+lemTakeLenApp _ _ _ = unsafeCoerceRefl
+
+lemInitApp :: Proxy l -> Proxy x -> Init (l ++ '[x]) :~: l
+lemInitApp _ _ = unsafeCoerceRefl
+
+lemLastApp :: Proxy l -> Proxy x -> Last (l ++ '[x]) :~: x
+lemLastApp _ _ = unsafeCoerceRefl
+
+
+-- ** KnownNat
+
+lemKnownNatSucc :: KnownNat n => Dict KnownNat (n + 1)
+lemKnownNatSucc = Dict
+
+lemKnownNatRank :: ShX sh i -> Dict KnownNat (Rank sh)
+lemKnownNatRank ZSX = Dict
+lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict
+
+lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRankSSX ZKX = Dict
+lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
+
+
+-- * Lemmas about shapes
+
+-- ** Known shapes
+
+lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing)
+lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn)
+
+lemKnownShX :: StaticShX sh -> Dict KnownShX sh
+lemKnownShX ZKX = Dict
+lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict
+lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict
+
+lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh)
+lemKnownMapJust _ = lemKnownShX (go (knownShS @sh))
+ where
+ go :: ShS sh' -> StaticShX (MapJust sh')
+ go ZSS = ZKX
+ go (n :$$ sh) = SKnown n :!% go sh
+
+-- ** Rank
+
+lemRankApp :: forall sh1 sh2.
+ StaticShX sh1 -> StaticShX sh2
+ -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
+lemRankApp ZKX _ = Refl
+lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2
+ = lem (Proxy @(Rank sh1T)) Proxy Proxy $
+ sym (lemRankApp ssh1 ssh2)
+ where
+ lem :: proxy a -> proxy b -> proxy c
+ -> (a + b :~: c)
+ -> c + 1 :~: (a + 1 + b)
+ lem _ _ _ Refl = Refl
+
+lemRankAppComm :: proxy sh1 -> proxy sh2
+ -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
+lemRankAppComm _ _ = unsafeCoerceRefl
+
+lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n
+lemRankReplicate _ = unsafeCoerceRefl
+
+-- ** Related to MapJust and/or Permutation
+
+lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)
+lemTakeLenMapJust PNil _ = Refl
+lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl
+lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty"
+
+lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh)
+lemDropLenMapJust PNil _ = Refl
+lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl
+lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty"
+
+lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh)
+lemIndexMapJust SZ (_ :$$ _) = Refl
+lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh'))
+ | Refl <- lemIndexMapJust i sh
+ , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh'))
+ , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = Refl
+lemIndexMapJust _ ZSS = error "Index of empty"
+
+lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh)
+lemPermuteMapJust PNil _ = Refl
+lemPermuteMapJust (i `PCons` is) sh
+ | Refl <- lemPermuteMapJust is sh
+ , Refl <- lemIndexMapJust i sh
+ = Refl
+
+lemMapJustApp :: ShS sh1 -> Proxy sh2
+ -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
+lemMapJustApp ZSS _ = Refl
+lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index b76aa50..7371c4b 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveGeneric #-}
@@ -6,6 +7,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -16,19 +18,21 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
-module Data.Array.Nested.Internal.Mixed where
+module Data.Array.Nested.Mixed where
import Prelude hiding (mconcat)
import Control.DeepSeq (NFData(..))
-import Control.Monad (forM_, when)
+import Control.Monad (foldM_, forM_, when)
import Control.Monad.ST
+import Data.Array.Internal qualified as OI
+import Data.Array.Internal.RankedG qualified as ORG
+import Data.Array.Internal.RankedS qualified as ORS
import Data.Array.RankedS qualified as S
import Data.Bifunctor (bimap)
import Data.Coerce
-import Data.Foldable (toList)
import Data.Int
-import Data.Kind (Constraint, Type)
+import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty(..))
import Data.List.NonEmpty qualified as NE
import Data.Proxy
@@ -37,23 +41,22 @@ import Data.Vector.Storable qualified as VS
import Data.Vector.Storable.Mutable qualified as VSM
import Foreign.C.Types (CInt)
import Foreign.Storable (Storable)
+import Foreign.Storable qualified as Storable
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import GHC.TypeLits
-import Unsafe.Coerce (unsafeCoerce)
-import Data.Array.Mixed.Internal.Arith
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
+import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed.Shape
-import Data.Array.Mixed.Types
-import Data.Array.Mixed.XArray (XArray(..))
-import Data.Array.Mixed.XArray qualified as X
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Types
+import Data.Array.Strided.Orthotope
+import Data.Array.XArray (XArray(..))
+import Data.Array.XArray qualified as X
import Data.Bag
-- TODO:
--- sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
-- rminIndex1 :: Ranked (n + 1) a -> Ranked n Int
-- gather/scatter-like things (most generally, the higher-order variants: accelerate's backpermute/permute)
-- After benchmarking: matmul and matvec
@@ -91,6 +94,9 @@ import Data.Bag
-- Unfortunately, the setup of the library requires us to list these primitive
-- element types multiple times; to aid in extending the list, all these lists
-- have been marked with [PRIMITIVE ELEMENT TYPES LIST].
+--
+-- NOTE: if you add PRIMITIVE types, be sure to also add NumElt and IntElt
+-- instances for them!
-- | Wrapper type used as a tag to attach instances on. The instances on arrays
@@ -118,6 +124,8 @@ instance PrimElt Bool
instance PrimElt Int
instance PrimElt Int64
instance PrimElt Int32
+instance PrimElt Int16
+instance PrimElt Int8
instance PrimElt CInt
instance PrimElt Float
instance PrimElt Double
@@ -140,27 +148,41 @@ data family Mixed sh a
-- sizes of the elements of an empty array, which is information that should
-- ostensibly not exist; the full array is still empty.
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+#define ANDSHOW , Show
+#else
+#define ANDSHOW
+#endif
+
data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a)
- deriving (Eq, Ord, Generic)
+ deriving (Eq, Ord, Generic ANDSHOW)
-- [PRIMITIVE ELEMENT TYPES LIST]
-newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Eq, Ord, Generic) -- no content, orthotope optimises this (via Vector)
+newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Int16 = M_Int16 (Mixed sh (Primitive Int16)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Int8 = M_Int8 (Mixed sh (Primitive Int8)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Eq, Ord, Generic ANDSHOW) -- no content, orthotope optimises this (via Vector)
-- etc.
data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic)
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b))
+#endif
-- etc., larger tuples (perhaps use generics to allow arbitrary product types)
deriving instance (Eq (Mixed sh a), Eq (Mixed sh b)) => Eq (Mixed sh (a, b))
deriving instance (Ord (Mixed sh a), Ord (Mixed sh b)) => Ord (Mixed sh (a, b))
data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic)
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance (Show (Mixed (sh1 ++ sh2) a)) => Show (Mixed sh1 (Mixed sh2 a))
+#endif
deriving instance Eq (Mixed (sh1 ++ sh2) a) => Eq (Mixed sh1 (Mixed sh2 a))
deriving instance Ord (Mixed (sh1 ++ sh2) a) => Ord (Mixed sh1 (Mixed sh2 a))
@@ -178,6 +200,8 @@ newtype instance MixedVecs s sh Bool = MV_Bool (VS.MVector s Bool)
newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int)
newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64)
newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32)
+newtype instance MixedVecs s sh Int16 = MV_Int16 (VS.MVector s Int16)
+newtype instance MixedVecs s sh Int8 = MV_Int8 (VS.MVector s Int8)
newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt)
newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double)
newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float)
@@ -204,20 +228,24 @@ showsMixedArray fromlistPrefix replicatePrefix d arr =
_ ->
showString fromlistPrefix . showString " " . shows (mtoListLinear arr)
+#ifndef OXAR_DEFAULT_SHOW_INSTANCES
instance (Show a, Elt a) => Show (Mixed sh a) where
showsPrec d arr =
let sh = show (shxToList (mshape arr))
in showsMixedArray ("mfromListLinear " ++ sh) ("mreplicate " ++ sh) d arr
+#endif
instance Elt a => NFData (Mixed sh a) where
rnf = mrnf
+{-# INLINE mliftNumElt1 #-}
mliftNumElt1 :: (PrimElt a, PrimElt b)
=> (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b)
-> Mixed sh a -> Mixed sh b
mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr))
+{-# INLINE mliftNumElt2 #-}
mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c)
=> (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c)
-> Mixed sh a -> Mixed sh b -> Mixed sh c
@@ -233,15 +261,15 @@ instance (NumElt a, PrimElt a) => Num (Mixed sh a) where
abs = mliftNumElt1 (liftO1 . numEltAbs)
signum = mliftNumElt1 (liftO1 . numEltSignum)
-- TODO: THIS IS BAD, WE NEED TO REMOVE THIS
- fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal"
+ fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicatePrim"
instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where
- fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"
+ fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicatePrim"
recip = mliftNumElt1 (liftO1 . floatEltRecip)
(/) = mliftNumElt2 (liftO2 . floatEltDiv)
instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where
- pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicatePrim"
exp = mliftNumElt1 (liftO1 . floatEltExp)
log = mliftNumElt1 (liftO1 . floatEltLog)
sqrt = mliftNumElt1 (liftO1 . floatEltSqrt)
@@ -273,9 +301,10 @@ mremArray = mliftNumElt2 (liftO2 . intEltRem)
matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a
matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2)
--- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
--- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
--- a@; see the documentation for 'Primitive' for more details.
+-- | Allowable element types in a mixed array, and by extension
+-- in a 'Data.Array.Nested.Ranked.Ranked' or 'Data.Array.Nested.Shaped.Shaped'
+-- array. Note the polymorphic instance for 'Elt' of @'Primitive' a@;
+-- see the documentation for 'Primitive' for more details.
class Elt a where
-- ====== PUBLIC METHODS ====== --
@@ -284,15 +313,9 @@ class Elt a where
mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a
mscalar :: a -> Mixed '[] a
- -- | All arrays in the list, even subarrays inside @a@, must have the same
- -- shape; if they do not, a runtime error will be thrown. See the
- -- documentation of 'mgenerate' for more information about this restriction.
- -- Furthermore, the length of the list must correspond with @n@: if @n@ is
- -- @Just m@ and @m@ does not equal the length of the list, a runtime error is
- -- thrown.
- --
- -- Consider also 'mfromListPrim', which can avoid intermediate arrays.
- mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+ -- | See 'mfromListOuter'. If the list does not have the given length, a
+ -- runtime error is thrown.
+ mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a
mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]
@@ -326,8 +349,8 @@ class Elt a where
mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
=> Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a
- -- | All arrays in the input must have equal shapes, including subarrays
- -- inside their elements.
+ -- | All arrays in the input must have equal shapes (except possibly
+ -- for the outermost dimension) including subarrays inside their elements.
mconcat :: NonEmpty (Mixed (Nothing : sh) a) -> Mixed (Nothing : sh) a
mrnf :: Mixed sh a -> ()
@@ -337,11 +360,14 @@ class Elt a where
-- | Tree giving the shape of every array component.
type ShapeTree a
+ -- | Produces an internal representation of a tree of shapes of (potentially)
+ -- nested arrays. If the argument is an array, it requires that the array
+ -- is not empty (otherwise, its' guaranteed to crash early, if non-trivial).
mshapeTree :: a -> ShapeTree a
mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool
- mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool
+ mshapeTreeIsEmpty :: Proxy a -> ShapeTree a -> Bool
mshowShapeTree :: Proxy a -> ShapeTree a -> String
@@ -349,26 +375,28 @@ class Elt a where
-- this mixed array.
marrayStrides :: Mixed sh a -> Bag [Int]
- -- | Given the shape of this array, an index and a value, write the value at
+ -- | Given a linear index and a value, write the value at
-- that index in the vectors.
- mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+ mvecsWriteLinear :: Int -> a -> MixedVecs s sh a -> ST s ()
- -- | Given the shape of this array, an index and a value, write the value at
+ -- | Given a linear index and a value, write the value at
-- that index in the vectors.
- mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+ mvecsWritePartialLinear :: Proxy sh -> Int -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
-- | Given the shape of this array, finalise the vectors into 'XArray's.
mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
+ -- | 'mvecsFreeze' but without copying the mutable vectors before freezing
+ -- them. If the mutable vectors are mutated after this function, referential
+ -- transparency is broken.
+ mvecsUnsafeFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
-- | Element types for which we have evidence of the (static part of the) shape
-- in a type class constraint. Compare the instance contexts of the instances
-- of this class with those of 'Elt': some instances have an additional
-- "known-shape" constraint.
--
--- This class is (currently) only required for 'mgenerate',
--- 'Data.Array.Nested.Ranked.rgenerate' and
--- 'Data.Array.Nested.Shaped.sgenerate'.
+-- This class is (currently) only required for `memptyArray` and 'mgenerate'.
class Elt a => KnownElt a where
-- | Create an empty array. The given shape must have size zero; this may or may not be checked.
memptyArrayUnsafe :: IShX sh -> Mixed sh a
@@ -377,20 +405,30 @@ class Elt a => KnownElt a where
-- this vector and an example for the contents.
mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a)
+ -- | Create initialised vectors for this array type, given the shape of
+ -- this vector and the chosen element.
+ mvecsReplicate :: IShX sh -> a -> ST s (MixedVecs s sh a)
+
mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
-- Arrays of scalars are basically just arrays of scalars.
instance Storable a => Elt (Primitive a) where
+ -- Somehow, INLINE here can increase allocation with GHC 9.14.1.
+ -- Maybe that happens in void instances such as @Primitive ()@.
+ {-# INLINEABLE mshape #-}
mshape (M_Primitive sh _) = sh
+ {-# INLINEABLE mindex #-}
mindex (M_Primitive _ a) i = Primitive (X.index a i)
- mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i)
+ {-# INLINEABLE mindexPartial #-}
+ mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i)
mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
- mfromListOuter l@(arr1 :| _) =
- let sh = SUnknown (length l) :$% mshape arr1
- in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l)))
+ mfromListOuterSN sn l@(arr1 :| _) =
+ let sh = mshape arr1
+ in M_Primitive (SKnown sn :$% sh) (X.fromListOuterSN sn sh ((\(M_Primitive _ a) -> a) <$> l))
mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
+ {-# INLINE mlift #-}
mlift :: forall sh1 sh2.
StaticShX sh2
-> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
@@ -401,6 +439,7 @@ instance Storable a => Elt (Primitive a) where
, let result = f ZKX a
= M_Primitive (X.shape ssh2 result) result
+ {-# INLINE mlift2 #-}
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)
@@ -412,6 +451,7 @@ instance Storable a => Elt (Primitive a) where
, let result = f ZKX a b
= M_Primitive (X.shape ssh3 result) result
+ {-# INLINE mliftL #-}
mliftL :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
@@ -426,44 +466,50 @@ instance Storable a => Elt (Primitive a) where
=> StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) =
let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
- sh2 = shxCast' sh1 ssh2
- in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr)
+ sh2 = shxCast' ssh2 sh1
+ in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr)
mtranspose perm (M_Primitive sh arr) =
M_Primitive (shxPermutePrefix perm sh)
- (X.transpose (ssxFromShape sh) perm arr)
+ (X.transpose (ssxFromShX sh) perm arr)
mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a)
mconcat l@(M_Primitive (_ :$% sh) _ :| _) =
- let result = X.concat (ssxFromShape sh) (fmap (\(M_Primitive _ arr) -> arr) l)
- in M_Primitive (X.shape (SUnknown () :!% ssxFromShape sh) result) result
+ let result = X.concat (ssxFromShX sh) (fmap (\(M_Primitive _ arr) -> arr) l)
+ in M_Primitive (X.shape (SUnknown () :!% ssxFromShX sh) result) result
mrnf (M_Primitive sh a) = rnf sh `seq` rnf a
type ShapeTree (Primitive a) = ()
mshapeTree _ = ()
mshapeTreeEq _ () () = True
- mshapeTreeEmpty _ () = False
+ mshapeTreeIsEmpty _ () = False
mshowShapeTree _ () = "()"
marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr)
- mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
+ mvecsWriteLinear i (Primitive x) (MV_Primitive v) = VSM.write v i x
- -- TODO: this use of toVector is suboptimal
- mvecsWritePartial
+ -- TODO: this use of toVectorListT is suboptimal
+ mvecsWritePartialLinear
:: forall sh' sh s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
- mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
- let arrsh = X.shape (ssxFromShape sh') arr
- offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
- VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
+ Proxy sh -> Int -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
+ mvecsWritePartialLinear _ i (M_Primitive sh' arr@(XArray (ORS.A (ORG.A sht t)))) (MV_Primitive v) = do
+ let arrsh = X.shape (ssxFromShX sh') arr
+ offset = i * shxSize arrsh
+ f off el = do
+ VS.copy (VSM.slice off (VS.length el) v) el
+ return $! off + VS.length el
+ foldM_ f offset (OI.toVectorListT sht t)
mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v
+ mvecsUnsafeFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.unsafeFreeze v
-- [PRIMITIVE ELEMENT TYPES LIST]
deriving via Primitive Bool instance Elt Bool
deriving via Primitive Int instance Elt Int
deriving via Primitive Int64 instance Elt Int64
deriving via Primitive Int32 instance Elt Int32
+deriving via Primitive Int16 instance Elt Int16
+deriving via Primitive Int8 instance Elt Int8
deriving via Primitive CInt instance Elt CInt
deriving via Primitive Double instance Elt Double
deriving via Primitive Float instance Elt Float
@@ -472,6 +518,7 @@ deriving via Primitive () instance Elt ()
instance Storable a => KnownElt (Primitive a) where
memptyArrayUnsafe sh = M_Primitive sh (X.empty sh)
mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh)
+ mvecsReplicate sh (Primitive a) = MV_Primitive <$> VSM.replicate (shxSize sh) a
mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
-- [PRIMITIVE ELEMENT TYPES LIST]
@@ -479,6 +526,8 @@ deriving via Primitive Bool instance KnownElt Bool
deriving via Primitive Int instance KnownElt Int
deriving via Primitive Int64 instance KnownElt Int64
deriving via Primitive Int32 instance KnownElt Int32
+deriving via Primitive Int16 instance KnownElt Int16
+deriving via Primitive Int8 instance KnownElt Int8
deriving via Primitive CInt instance KnownElt CInt
deriving via Primitive Double instance KnownElt Double
deriving via Primitive Float instance KnownElt Float
@@ -486,16 +535,22 @@ deriving via Primitive () instance KnownElt ()
-- Arrays of pairs are pairs of arrays.
instance (Elt a, Elt b) => Elt (a, b) where
+ {-# INLINEABLE mshape #-}
mshape (M_Tup2 a _) = mshape a
+ {-# INLINEABLE mindex #-}
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
+ {-# INLINEABLE mindexPartial #-}
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
- mfromListOuter l =
- M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l))
- (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l))
+ mfromListOuterSN sn l =
+ M_Tup2 (mfromListOuterSN sn ((\(M_Tup2 x _) -> x) <$> l))
+ (mfromListOuterSN sn ((\(M_Tup2 _ y) -> y) <$> l))
mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)
+ {-# INLINE mlift #-}
mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
+ {-# INLINE mlift2 #-}
mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
+ {-# INLINE mliftL #-}
mliftL ssh2 f =
let unzipT2l [] = ([], [])
unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2)
@@ -517,20 +572,22 @@ instance (Elt a, Elt b) => Elt (a, b) where
type ShapeTree (a, b) = (ShapeTree a, ShapeTree b)
mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
- mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
+ mshapeTreeIsEmpty _ (t1, t2) = mshapeTreeIsEmpty (Proxy @a) t1 && mshapeTreeIsEmpty (Proxy @b) t2
mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b
- mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
- mvecsWrite sh i x a
- mvecsWrite sh i y b
- mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
- mvecsWritePartial sh i x a
- mvecsWritePartial sh i y b
+ mvecsWriteLinear i (x, y) (MV_Tup2 a b) = do
+ mvecsWriteLinear i x a
+ mvecsWriteLinear i y b
+ mvecsWritePartialLinear proxy i (M_Tup2 x y) (MV_Tup2 a b) = do
+ mvecsWritePartialLinear proxy i x a
+ mvecsWritePartialLinear proxy i y b
mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
+ mvecsUnsafeFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsUnsafeFreeze sh a <*> mvecsUnsafeFreeze sh b
instance (KnownElt a, KnownElt b) => KnownElt (a, b) where
memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh)
mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
+ mvecsReplicate sh (x, y) = MV_Tup2 <$> mvecsReplicate sh x <*> mvecsReplicate sh y
mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b)
-- Arrays of arrays are just arrays, but with more dimensions.
@@ -538,38 +595,41 @@ instance Elt a => Elt (Mixed sh' a) where
-- TODO: this is quadratic in the nesting depth because it repeatedly
-- truncates the shape vector to one a little shorter. Fix with a
-- moverlongShape method, a prefix of which is mshape.
+ {-# INLINEABLE mshape #-}
mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
mshape (M_Nest sh arr)
- = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr))
+ = shxTakeSh (Proxy @sh') sh (mshape arr)
+ {-# INLINEABLE mindex #-}
mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
- mindex (M_Nest _ arr) i = mindexPartial arr i
+ mindex (M_Nest _ arr) = mindexPartial arr
+ {-# INLINEABLE mindexPartial #-}
mindexPartial :: forall sh1 sh2.
Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
mindexPartial (M_Nest sh arr) i
| Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
+ = M_Nest (shxDropIx i sh) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
mscalar = M_Nest ZSX
- mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
- mfromListOuter l@(arr :| _) =
- M_Nest (SUnknown (length l) :$% mshape arr)
- (mfromListOuter ((\(M_Nest _ a) -> a) <$> l))
+ mfromListOuterSN sn l@(arr :| _) =
+ M_Nest (SKnown sn :$% mshape arr)
+ (mfromListOuterSN sn ((\(M_Nest _ a) -> a) <$> l))
mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr)
+ {-# INLINE mlift #-}
mlift :: forall sh1 sh2.
StaticShX sh2
-> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
mlift ssh2 f (M_Nest sh1 arr) =
let result = mlift (ssxAppend ssh2 ssh') f' arr
- (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result)
+ sh2 = shxTakeSSX (Proxy @sh') ssh2 (mshape result)
in M_Nest sh2 result
where
- ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr)))
+ ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr))
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
f' sshT
@@ -577,16 +637,17 @@ instance Elt a => Elt (Mixed sh' a) where
, Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
= f (ssxAppend ssh' sshT)
+ {-# INLINE mlift2 #-}
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)
mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) =
let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2
- (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result)
+ sh3 = shxTakeSSX (Proxy @sh') ssh3 (mshape result)
in M_Nest sh3 result
where
- ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1)))
+ ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr1))
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
f' sshT
@@ -595,16 +656,17 @@ instance Elt a => Elt (Mixed sh' a) where
, Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
= f (ssxAppend ssh' sshT)
+ {-# INLINE mliftL #-}
mliftL :: forall sh1 sh2.
StaticShX sh2
-> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b))
-> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a))
mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) =
let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l)
- (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result))
+ sh2 = shxTakeSSX (Proxy @sh') ssh2 (mshape (NE.head result))
in fmap (M_Nest sh2) result
where
- ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1)))
+ ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr1))
f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b)
f' sshT
@@ -618,15 +680,15 @@ instance Elt a => Elt (Mixed sh' a) where
| Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
, Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
= let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T
- sh2 = shxCast' sh1 ssh2
+ sh2 = shxCast' ssh2 sh1
in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr)
mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
=> Perm is -> Mixed sh (Mixed sh' a)
-> Mixed (PermutePrefix is sh) (Mixed sh' a)
mtranspose perm (M_Nest sh arr)
- | let sh' = shxDropSh @sh @sh' (mshape arr) sh
- , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh')
+ | let sh' = shxDropSh @sh @sh' sh (mshape arr)
+ , Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh')
, Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh'))
, Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
, Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
@@ -637,48 +699,73 @@ instance Elt a => Elt (Mixed sh' a) where
mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
mconcat l@(M_Nest sh1 _ :| _) =
let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l)
- in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape result))) result
+ in M_Nest (shxTakeSh (Proxy @sh') sh1 (mshape result)) result
mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr
type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a)
+ -- This requires that @arr@ is not empty.
mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
- mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr)))))
+ mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShX (mshape arr)))))
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
- mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+ -- the array is empty if either there are no subarrays, or the subarrays themselves are empty
+ mshapeTreeIsEmpty _ (sh, t) = shxSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
marrayStrides (M_Nest _ arr) = marrayStrides arr
- mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
+ mvecsWriteLinear :: forall s sh. Int -> Mixed sh' a -> MixedVecs s sh (Mixed sh' a) -> ST s ()
+ mvecsWriteLinear idx val (MV_Nest _ vecs) = mvecsWritePartialLinear (Proxy @sh) idx val vecs
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
- -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
- -> ST s ()
- mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
+ mvecsWritePartialLinear
+ :: forall sh1 sh2 s.
+ Proxy sh1 -> Int -> Mixed sh2 (Mixed sh' a)
+ -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
+ -> ST s ()
+ mvecsWritePartialLinear proxy idx (M_Nest _ arr) (MV_Nest _ vecs)
| Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs
+ = mvecsWritePartialLinear proxy idx arr vecs
mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs
+ mvecsUnsafeFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsUnsafeFreeze (shxAppend sh sh') vecs
instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh'))))
mvecsUnsafeNew sh example
| shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
- | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh')))
+ | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShX sh')))
where
sh' = mshape example
+ mvecsReplicate sh example = do
+ vecs <- mvecsUnsafeNew sh example
+ forM_ [0 .. shxSize sh - 1] $ \idx -> mvecsWriteLinear idx example vecs
+ -- this is a slow case, but the alternative, mvecsUnsafeNew with manual
+ -- writing in a loop, leads to every case being as slow
+ return vecs
+
mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
-memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a
+-- | Given the shape of this array, an index and a value, write the value at
+-- that index in the vectors.
+{-# INLINE mvecsWrite #-}
+mvecsWrite :: Elt a => IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+mvecsWrite sh idx = mvecsWriteLinear (ixxToLinear sh idx)
+
+-- | Given the shape of this array, an index and a value, write the value at
+-- that index in the vectors.
+{-# INLINE mvecsWritePartial #-}
+mvecsWritePartial :: forall sh sh' s a. Elt a => IShX sh -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+mvecsWritePartial sh idx = mvecsWritePartialLinear (Proxy @sh) (ixxToLinear sh idx)
+
+-- TODO: should we provide a function that's just memptyArrayUnsafe but with a size==0 check? That may save someone a transpose somewhere
+memptyArray :: forall sh a. KnownElt a => IShX sh -> Mixed (Just 0 : sh) a
memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh)
mrank :: Elt a => Mixed sh a -> SNat (Rank sh)
@@ -705,38 +792,56 @@ msize = shxSize . mshape
-- the entire hierarchy (after distributing out tuples) must be a rectangular
-- array. The type of 'mgenerate' allows this requirement to be broken very
-- easily, hence the runtime check.
+--
+-- If your element type @a@ is a scalar, use the faster 'mgeneratePrim'.
mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a
mgenerate sh f = case shxEnum sh of
[] -> memptyArrayUnsafe sh
firstidx : restidxs ->
let firstelem = f (ixxZero' sh)
shapetree = mshapeTree firstelem
- in if mshapeTreeEmpty (Proxy @a) shapetree
+ in if mshapeTreeIsEmpty (Proxy @a) shapetree
then memptyArrayUnsafe sh
else runST $ do
vecs <- mvecsUnsafeNew sh firstelem
mvecsWrite sh firstidx firstelem vecs
- -- TODO: This is likely fine if @a@ is big, but if @a@ is a
- -- scalar this array copying inefficient. Should improve this.
forM_ restidxs $ \idx -> do
let val = f idx
when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $
error "Data.Array.Nested mgenerate: generated values do not have equal shapes"
mvecsWrite sh idx val vecs
- mvecsFreeze sh vecs
+ mvecsUnsafeFreeze sh vecs
-msumOuter1P :: forall sh n a. (Storable a, NumElt a)
- => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
-msumOuter1P (M_Primitive (n :$% sh) arr) =
+-- | An optimized special case of 'mgenerate', where the function results
+-- are of a primitive type and so there's not need to check that all shapes
+-- are equal. This is also generalized to an arbitrary @Num@ index type
+-- compared to @mgenerate@.
+{-# INLINE mgeneratePrim #-}
+mgeneratePrim :: forall sh a i. (PrimElt a, Num i)
+ => IShX sh -> (IxX sh i -> a) -> Mixed sh a
+mgeneratePrim sh f =
+ let g i = f (ixxFromLinear sh i)
+ in mfromVector sh $ VS.generate (shxSize sh) g
+
+{-# INLINEABLE msumOuter1PrimP #-}
+msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
+ => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
+msumOuter1PrimP (M_Primitive (n :$% sh) arr) =
let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
- in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr)
+ in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr)
+
+{-# INLINEABLE msumOuter1Prim #-}
+msumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a)
+ => Mixed (n : sh) a -> Mixed sh a
+msumOuter1Prim = fromPrimitive . msumOuter1PrimP @sh @n @a . toPrimitive
-msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
- => Mixed (n : sh) a -> Mixed sh a
-msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
+{-# INLINEABLE msumAllPrimP #-}
+msumAllPrimP :: (Storable a, NumElt a) => Mixed sh (Primitive a) -> a
+msumAllPrimP (M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr
+{-# INLINEABLE msumAllPrim #-}
msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
-msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShape sh) arr
+msumAllPrim arr = msumAllPrimP (toPrimitive arr)
mappend :: forall n m sh a. Elt a
=> Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
@@ -744,8 +849,8 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
where
sn :$% sh = mshape arr1
sm :$% _ = mshape arr2
- ssh = ssxFromShape sh
- snm :: SMayNat () SNat (AddMaybe n m)
+ ssh = ssxFromShX sh
+ snm :: SMayNat () (AddMaybe n m)
snm = case (sn, sm) of
(SUnknown{}, _) -> SUnknown ()
(SKnown{}, SUnknown{}) -> SUnknown ()
@@ -755,115 +860,217 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
=> StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b
f ssh' = X.append (ssxAppend ssh ssh')
+{-# INLINEABLE mfromVectorP #-}
mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
mfromVectorP sh v = M_Primitive sh (X.fromVector sh v)
+{-# INLINEABLE mfromVector #-}
mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a
mfromVector sh v = fromPrimitive (mfromVectorP sh v)
+{-# INLINEABLE mtoVectorP #-}
mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a
mtoVectorP (M_Primitive _ v) = X.toVector v
+{-# INLINEABLE mtoVector #-}
mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a
mtoVector arr = mtoVectorP (toPrimitive arr)
+-- | All arrays in the list, even subarrays inside @a@, must have the same
+-- shape; if they do not, a runtime error will be thrown. See the
+-- documentation of 'mgenerate' for more information about this restriction.
+--
+-- Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'mfromListOuterN' or 'mfromListOuterSN' to be able to
+-- stream the list.
+--
+-- If your array is 1-dimensional and contains scalars, use 'mfromList1Prim'.
+mfromListOuter :: Elt a => NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+mfromListOuter l = mfromListOuterN (length l) l
+
+-- | See 'mfromListOuter'. If the list does not have the given length, a
+-- runtime error is thrown. 'mfromList1PrimN' is faster if applicable.
+mfromListOuterN :: Elt a => Int -> NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+mfromListOuterN n l =
+ withSomeSNat (fromIntegral n) $ \case
+ Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromListOuterSN sn l)
+ Nothing -> error $ "mfromListOuterN: length negative (" ++ show n ++ ")"
+
+-- | Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'mfromList1N' or 'mfromList1SN' to be able to stream the
+-- list.
+--
+-- If the elements are scalars, 'mfromList1Prim' is faster.
mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a
-mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise?
+mfromList1 = mfromListOuter . fmap mscalar
+-- | If the elements are scalars, 'mfromList1PrimN' is faster. A runtime error
+-- is thrown if the list length does not match the given length.
+mfromList1N :: Elt a => Int -> NonEmpty a -> Mixed '[Nothing] a
+mfromList1N n = mfromListOuterN n . fmap mscalar
+
+-- | If the elements are scalars, 'mfromList1PrimSN' is faster. A runtime error
+-- is thrown if the list length does not match the given length.
+mfromList1SN :: Elt a => SNat n -> NonEmpty a -> Mixed '[Just n] a
+mfromList1SN sn = mfromListOuterSN sn . fmap mscalar
+
+-- This forall is there so that a simple type application can constrain the
+-- shape, in case the user wants to use OverloadedLists for the shape.
+-- | If the elements are scalars, 'mfromListPrimLinear' is faster.
+mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a
+mfromListLinear sh l = mreshape sh (mfromList1N (shxSize sh) l)
+
+-- | Because the length of the list is unknown, its spine must be materialised
+-- in memory in order to compute its length. If its length is already known,
+-- use 'mfromList1PrimN' or 'mfromList1PrimSN' to be able to stream the list.
mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a
mfromList1Prim l =
let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
+ xarr = X.fromList1 l
in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
-mtoList1 :: Elt a => Mixed '[n] a -> [a]
-mtoList1 = map munScalar . mtoListOuter
+mfromList1PrimN :: PrimElt a => Int -> [a] -> Mixed '[Nothing] a
+mfromList1PrimN n l =
+ withSomeSNat (fromIntegral n) $ \case
+ Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromList1PrimSN sn l)
+ Nothing -> error $ "mfromList1PrimN: length negative (" ++ show n ++ ")"
-mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a
-mfromListPrim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+mfromList1PrimSN :: forall n a. PrimElt a => SNat n -> [a] -> Mixed '[Just n] a
+mfromList1PrimSN sn l =
+ let sh = SKnown sn :$% ZSX
+ in fromPrimitive $ M_Primitive sh
+ $ if Storable.sizeOf (undefined :: a) > 0
+ then X.fromList1SN sn l
+ else case l of -- don't force the list if all elements are the same
+ a0 : _ -> X.replicateScal sh a0
+ [] -> X.fromList1SN sn l
-mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a
+mfromListPrimLinear :: forall sh a. PrimElt a => IShX sh -> [a] -> Mixed sh a
mfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ let M_Primitive _ xarr = toPrimitive (mfromList1PrimN (shxSize sh) l)
in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)
--- This forall is there so that a simple type application can constrain the
--- shape, in case the user wants to use OverloadedLists for the shape.
-mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a
-mfromListLinear sh l = mreshape sh (mfromList1 l)
+mtoList :: Elt a => Mixed '[n] a -> [a]
+mtoList = map munScalar . mtoListOuter
mtoListLinear :: Elt a => Mixed sh a -> [a]
-mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr)) -- TODO: optimise
+mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr))
+
+mtoListPrim :: PrimElt a => Mixed '[n] a -> [a]
+mtoListPrim (toPrimitive -> M_Primitive _ arr) = X.toListLinear arr
+
+mtoListPrimLinear :: PrimElt a => Mixed sh a -> [a]
+mtoListPrimLinear (toPrimitive -> M_Primitive _ arr) = X.toListLinear arr
munScalar :: Elt a => Mixed '[] a -> a
munScalar arr = mindex arr ZIX
mnest :: forall sh sh' a. Elt a => StaticShX sh -> Mixed (sh ++ sh') a -> Mixed sh (Mixed sh' a)
-mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr
+mnest ssh arr = M_Nest (shxTakeSSX (Proxy @sh') ssh (mshape arr)) arr
munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a
munNest (M_Nest _ arr) = arr
-mzip :: Mixed sh a -> Mixed sh b -> Mixed sh (a, b)
-mzip = M_Tup2
+-- | The arguments must have equal shapes. If they do not, an error is raised.
+mzip :: (Elt a, Elt b) => Mixed sh a -> Mixed sh b -> Mixed sh (a, b)
+mzip a b
+ | Just Refl <- shxEqual (mshape a) (mshape b) = M_Tup2 a b
+ | otherwise = error "mzip: unequal shapes"
munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b)
munzip (M_Tup2 a b) = (a, b)
-mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
- => StaticShX sh -> IShX sh2
- -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
- -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b)
-mrerankP ssh sh2 f (M_Primitive sh arr) =
- let sh1 = shxDropSSX sh ssh
- in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2)
- (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2)
- (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
- arr)
+mrerankPrimP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
+ => IShX sh2
+ -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
+ -> Mixed sh (Mixed sh1 (Primitive a)) -> Mixed sh (Mixed sh2 (Primitive b))
+mrerankPrimP sh2 f (M_Nest sh (M_Primitive shsh1 arr)) =
+ let sh1 = shxDropSh sh shsh1
+ in M_Nest sh $
+ M_Primitive (shxAppend sh sh2)
+ (X.rerank (ssxFromShX sh) (ssxFromShX sh1) (ssxFromShX sh2)
+ (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
+ arr)
--- | See the caveats at @X.rerank@.
-mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
- => StaticShX sh -> IShX sh2
- -> (Mixed sh1 a -> Mixed sh2 b)
- -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b
-mrerank ssh sh2 f (toPrimitive -> arr) =
- fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr
+-- | If the shape of the outer array (@sh@) is empty (i.e. contains a zero),
+-- then there is no way to deduce the full shape of the output array (more
+-- precisely, the @sh2@ part): that could only come from calling @f@, and there
+-- are no subarrays to call @f@ on. @orthotope@ errors out in this case; we
+-- choose to fill the shape with zeros wherever we cannot deduce what it should
+-- be.
+--
+-- For example, if:
+--
+-- @
+-- -- arr has shape [3, 0, 4] and the inner arrays have shape [2, 21]
+-- arr :: Mixed '[Just 3, Just 0, Just 4] (Mixed '[Just 2, Nothing] Int)
+-- f :: Mixed '[Just 2, Nothing] Int -> Mixed '[Just 5, Nothing, Just 17] Float
+-- @
+--
+-- then:
+--
+-- @
+-- mrerankPrim _ f arr :: Mixed '[Just 3, Just 0, Just 4] (Mixed '[Just 5, Nothing, Just 17] Float)
+-- @
+--
+-- and the inner arrays of the result will have shape @[5, 0, 17]@. Note the
+-- @0@ in this shape: we don't know if @f@ intended to return an array with
+-- shape 0 here (it probably didn't), but there is no better number to put here
+-- absent a subarray of the input to pass to @f@.
+--
+-- In this particular case the fact that @sh@ is empty was evident from the
+-- type-level information, but the same situation occurs when @sh@ consists of
+-- @Nothing@s, and some of those happen to be zero at runtime.
+mrerankPrim :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
+ => IShX sh2
+ -> (Mixed sh1 a -> Mixed sh2 b)
+ -> Mixed sh (Mixed sh1 a) -> Mixed sh (Mixed sh2 b)
+mrerankPrim sh2 f (M_Nest sh arr) =
+ let M_Nest sh' arr' = mrerankPrimP sh2 (toPrimitive . f . fromPrimitive) (M_Nest sh (toPrimitive arr))
+ in M_Nest sh' (fromPrimitive arr')
mreplicate :: forall sh sh' a. Elt a
=> IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
mreplicate sh arr =
- let ssh' = ssxFromShape (mshape arr)
- in mlift (ssxAppend (ssxFromShape sh) ssh')
+ let ssh' = ssxFromShX (mshape arr)
+ in mlift (ssxAppend (ssxFromShX sh) ssh')
(\(sshT :: StaticShX shT) ->
case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of
Refl -> X.replicate sh (ssxAppend ssh' sshT))
arr
-mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
-mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x)
+mreplicatePrimP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
+mreplicatePrimP sh x = M_Primitive sh (X.replicateScal sh x)
-mreplicateScal :: forall sh a. PrimElt a
+mreplicatePrim :: forall sh a. PrimElt a
=> IShX sh -> a -> Mixed sh a
-mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
+mreplicatePrim sh x = fromPrimitive (mreplicatePrimP sh x)
+
+msliceN :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
+msliceN i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr
-mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
-mslice i n arr =
+msliceSN :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
+msliceSN i n arr =
let _ :$% sh = mshape arr
- in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr
+ in mlift (SKnown n :!% ssxFromShX sh) (\_ -> X.slice i n) arr
-msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
-msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr
+mslice :: forall i n k sh a. Elt a
+ => SMayNat Int i -> SMayNat Int n -> SMayNat Int k -> Mixed (AddMaybe (AddMaybe i n) k : sh) a -> Mixed (n : sh) a
+mslice i n k arr = case mshape arr of
+ _ :$% sh ->
+ let uarr = mcastPartial (ssxFromShX $ smnAddMaybe (smnAddMaybe i n) k :$% ZSX) (SUnknown () :!% ZKX) Proxy arr
+ in mcastPartial (SUnknown () :!% ZKX) (ssxFromShX $ n :$% ZSX) Proxy
+ $ mlift (SUnknown () :!% ssxFromShX sh) (\_ -> X.sliceU (fromSMayNat' i) (fromSMayNat' n)) uarr
mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a
-mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr
+mrev1 arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.rev1) arr
mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a
mreshape sh' arr =
- mlift (ssxFromShape sh')
- (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh')
+ mlift (ssxFromShX sh')
+ (\sshIn -> X.reshapePartial (ssxFromShX (mshape arr)) sshIn sh')
arr
mflatten :: Elt a => Mixed sh a -> Mixed '[Flatten sh] a
@@ -875,13 +1082,14 @@ miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
-- | Throws if the array is empty.
mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
- ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) (fromO arr))
+ ixxFromList (ssxFromShX sh) (numEltMinIndex (shxRank sh) (fromO arr))
-- | Throws if the array is empty.
mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
- ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) (fromO arr))
+ ixxFromList (ssxFromShX sh) (numEltMaxIndex (shxRank sh) (fromO arr))
+{-# INLINEABLE mdot1Inner #-}
mdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
=> Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a
mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b))
@@ -890,13 +1098,14 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi
= case sh1 of
_ :$% _
| sh1 == sh2
- , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) ->
+ , Refl <- lemRankApp (ssxInit (ssxFromShX sh1)) (ssxLast (ssxFromShX sh1) :!% ZKX) ->
fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b))
| otherwise -> error $ "mdot1Inner: Unequal shapes (" ++ show sh1 ++ " and " ++ show sh2 ++ ")"
ZSX -> error "unreachable"
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
-- Prefer 'mdot1Inner' if applicable.
+{-# INLINEABLE mdot #-}
mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a
mdot a b =
munScalar $
@@ -915,41 +1124,15 @@ mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr
mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a
mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP
+{-# INLINE mliftPrim #-}
mliftPrim :: (PrimElt a, PrimElt b)
=> (a -> b)
-> Mixed sh a -> Mixed sh b
mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr))
+{-# INLINE mliftPrim2 #-}
mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c)
=> (a -> b -> c)
-> Mixed sh a -> Mixed sh b -> Mixed sh c
mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =
fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2))
-
-mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)
- => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a
-mcast ssh2 arr
- | Refl <- lemAppNil @sh1
- , Refl <- lemAppNil @sh2
- = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr
-
--- TODO: This should be `type data` but a bug in GHC 9.10 means that that throws linker errors
-data SafeMCastSpec
- = MCastId
- | MCastApp [Maybe Nat] [Maybe Nat] [Maybe Nat] [Maybe Nat] SafeMCastSpec SafeMCastSpec
- | MCastForget
-
-type SafeMCast :: SafeMCastSpec -> [Maybe Nat] -> [Maybe Nat] -> Constraint
-type family SafeMCast spec sh1 sh2 where
- SafeMCast MCastId sh sh = ()
- SafeMCast (MCastApp sh1A sh1B sh2A sh2B specA specB) sh1 sh2 = (sh1 ~ sh1A ++ sh1B, sh2 ~ sh2A ++ sh2B, SafeMCast specA sh1A sh2A, SafeMCast specB sh1B sh2B)
- SafeMCast MCastForget sh1 sh2 = sh2 ~ Replicate (Rank sh1) Nothing
-
--- | This is an O(1) operation: the 'SafeMCast' constraint ensures that
--- type-level shape information can only be forgotten, not introduced, and thus
--- that no runtime shape checks are required. The @spec@ describes to
--- 'SafeMCast' how exactly you intend @sh2@ to be a weakening of @sh1@.
---
--- To see how to construct the spec, read the equations of 'SafeMCast' closely.
-mcastSafe :: forall spec sh1 sh2 a proxy. SafeMCast spec sh1 sh2 => proxy spec -> Mixed sh1 a -> Mixed sh2 a
-mcastSafe _ = unsafeCoerce @(Mixed sh1 a) @(Mixed sh2 a)
diff --git a/src/Data/Array/Nested/Mixed/ListX.hs b/src/Data/Array/Nested/Mixed/ListX.hs
new file mode 100644
index 0000000..2c8a9cc
--- /dev/null
+++ b/src/Data/Array/Nested/Mixed/ListX.hs
@@ -0,0 +1,139 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UnboxedTuples #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Mixed.ListX (ListX, pattern ZX, pattern (::%), listxShow, lazily, lazilyConcat, lazilyForce, Rank, coerceEqualRankListX) where
+
+import Control.DeepSeq (NFData(..))
+import Data.Foldable qualified as Foldable
+import Data.Kind (Type)
+import Data.Type.Equality
+import GHC.IsList (IsList)
+import GHC.IsList qualified as IsList
+import GHC.TypeLits
+#if !MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
+import GHC.TypeLits.Orphans ()
+#endif
+
+import Data.Array.Nested.Types
+
+
+-- | The length of a type-level list. If the argument is a shape, then the
+-- result is the rank of that shape.
+type family Rank sh where
+ Rank '[] = 0
+ Rank (_ : sh) = Rank sh + 1
+
+
+-- * Mixed lists implementation
+
+-- | Data invariant: each element on the list is in WHNF (the spine may be lazy)
+-- and the length of the list is the same as of the type-level shape.
+type role ListX nominal representational
+type ListX :: [Maybe Nat] -> Type -> Type
+newtype ListX sh i = ListX [i]
+ deriving (Eq, Ord, NFData, Foldable)
+
+{-# INLINE ZX #-}
+pattern ZX :: forall sh i. () => sh ~ '[] => ListX sh i
+pattern ZX <- (listxNull -> Just Refl)
+ where ZX = ListX []
+
+{-# INLINE listxNull #-}
+listxNull :: ListX sh i -> Maybe (sh :~: '[])
+listxNull (ListX []) = Just unsafeCoerceRefl
+listxNull (ListX (_ : _)) = Nothing
+
+{-# INLINE (::%) #-}
+pattern (::%)
+ :: forall {sh1} {i}.
+ forall n sh. (n : sh ~ sh1)
+ => i -> ListX sh i -> ListX sh1 i
+pattern i ::% l <- (listxUncons -> Just (UnconsListXRes i l))
+ where !i ::% ListX !l = ListX (i : l)
+infixr 3 ::%
+
+data UnconsListXRes i sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsListXRes i (ListX sh i)
+{-# INLINE listxUncons #-}
+listxUncons :: forall sh1 i. ListX sh1 i -> Maybe (UnconsListXRes i sh1)
+listxUncons (ListX (i : l)) = gcastWith (unsafeCoerceRefl :: Head sh1 ': Tail sh1 :~: sh1) $
+ Just (UnconsListXRes i (ListX @(Tail sh1) l))
+listxUncons (ListX []) = Nothing
+
+{-# COMPLETE ZX, (::%) #-}
+
+-- | This function makes no attempt to stop you from breaking the data invariant for 'ListX'. If you do so, you must later ensure that the invariant is reinstated, for example using @'lazilyForce' 'id'@.
+{-# INLINE lazily #-}
+lazily :: ([a] -> [b]) -> ListX sh a -> ListX sh b
+lazily f (ListX l) = ListX $ f l
+
+-- | This function makes no attempt to stop you from breaking the data invariant for 'ListX'. If you do so, you must later ensure that the invariant is reinstated, for example using @'lazilyForce' 'id'@.
+{-# INLINE lazilyConcat #-}
+lazilyConcat :: ([a] -> [b] -> [c]) -> ListX sh a -> ListX sh' b -> ListX (sh ++ sh') c
+lazilyConcat f (ListX l) (ListX k) = ListX $ f l k
+
+-- | This operation forces all elements of the @[b]@ list to restore the strictness part of the data invariant for 'ListX'. Note that ensuring the list has the right length is still the user's responsibility.
+{-# INLINE lazilyForce #-}
+lazilyForce :: ([a] -> [b]) -> ListX sh a -> ListX sh b
+lazilyForce f (ListX l) = let res = f l
+ in foldr seq () res `seq` ListX res
+
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (ListX sh i)
+#else
+instance Show i => Show (ListX sh i) where
+ showsPrec _ = listxShow shows
+#endif
+
+{-# INLINE listxShow #-}
+listxShow :: forall sh i. (i -> ShowS) -> ListX sh i -> ShowS
+listxShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> ListX sh' i -> ShowS
+ go _ ZX = id
+ go prefix (x ::% xs) = showString prefix . f x . go "," xs
+
+-- This can't be derived, becauses the list needs to be fully evaluated,
+-- per data invariant. This version is faster than versions defined using
+-- (::%) or lazilyForce.
+instance Functor (ListX l) where
+ {-# INLINE fmap #-}
+ fmap f (ListX l) =
+ let fmap' [] = []
+ fmap' (x : xs) = let y = f x
+ rest = fmap' xs
+ in y `seq` rest `seq` (y : rest)
+ in ListX $ fmap' l
+
+-- | Very untyped: not even length is checked (at runtime).
+instance IsList (ListX sh i) where
+ type Item (ListX sh i) = i
+ {-# INLINE fromList #-}
+ fromList l = foldr seq () l `seq` ListX l
+ {-# INLINE toList #-}
+ toList = Foldable.toList
+
+coerceEqualRankListX :: Rank sh ~ Rank sh' => ListX sh i -> ListX sh' i
+coerceEqualRankListX (ListX l) = ListX l
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index 5f4775c..a01e0f3 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -1,12 +1,15 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
@@ -16,162 +19,44 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Nested.Mixed.Shape where
+module Data.Array.Nested.Mixed.Shape (
+ module Data.Array.Nested.Mixed.Shape,
+ Rank,
+) where
import Control.DeepSeq (NFData(..))
+import Control.Exception (assert)
import Data.Bifunctor (first)
import Data.Coerce
import Data.Foldable qualified as Foldable
-import Data.Functor.Const
-import Data.Functor.Product
import Data.Kind (Constraint, Type)
import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Equality
-import GHC.Exts (withDict)
-import GHC.Generics (Generic)
+import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
+#if !MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
+import GHC.TypeLits.Orphans ()
+#endif
-import Data.Array.Mixed.Types
-
-
--- | The length of a type-level list. If the argument is a shape, then the
--- result is the rank of that shape.
-type family Rank sh where
- Rank '[] = 0
- Rank (_ : sh) = Rank sh + 1
-
-
--- * Mixed lists
-
-type role ListX nominal representational
-type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
-data ListX sh f where
- ZX :: ListX '[] f
- (::%) :: f n -> ListX sh f -> ListX (n : sh) f
-deriving instance (forall n. Eq (f n)) => Eq (ListX sh f)
-deriving instance (forall n. Ord (f n)) => Ord (ListX sh f)
-infixr 3 ::%
-
-instance (forall n. Show (f n)) => Show (ListX sh f) where
- showsPrec _ = listxShow shows
-
-instance (forall n. NFData (f n)) => NFData (ListX sh f) where
- rnf ZX = ()
- rnf (x ::% l) = rnf x `seq` rnf l
-
-data UnconsListXRes f sh1 =
- forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n)
-listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1)
-listxUncons (i ::% shl') = Just (UnconsListXRes shl' i)
-listxUncons ZX = Nothing
-
--- | This checks only whether the types are equal; if the elements of the list
--- are not singletons, their values may still differ. This corresponds to
--- 'testEquality', except on the penultimate type parameter.
-listxEqType :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh')
-listxEqType ZX ZX = Just Refl
-listxEqType (n ::% sh) (m ::% sh')
- | Just Refl <- testEquality n m
- , Just Refl <- listxEqType sh sh'
- = Just Refl
-listxEqType _ _ = Nothing
-
--- | This checks whether the two lists actually contain equal values. This is
--- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
--- in the @some@ package (except on the penultimate type parameter).
-listxEqual :: (TestEquality f, forall n. Eq (f n)) => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh')
-listxEqual ZX ZX = Just Refl
-listxEqual (n ::% sh) (m ::% sh')
- | Just Refl <- testEquality n m
- , n == m
- , Just Refl <- listxEqual sh sh'
- = Just Refl
-listxEqual _ _ = Nothing
-
-listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g
-listxFmap _ ZX = ZX
-listxFmap f (x ::% xs) = f x ::% listxFmap f xs
-
-listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m
-listxFold _ ZX = mempty
-listxFold f (x ::% xs) = f x <> listxFold f xs
-
-listxLength :: ListX sh f -> Int
-listxLength = getSum . listxFold (\_ -> Sum 1)
-
-listxRank :: ListX sh f -> SNat (Rank sh)
-listxRank ZX = SNat
-listxRank (_ ::% l) | SNat <- listxRank l = SNat
-
-listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
-listxShow f l = showString "[" . go "" l . showString "]"
- where
- go :: String -> ListX sh' f -> ShowS
- go _ ZX = id
- go prefix (x ::% xs) = showString prefix . f x . go "," xs
-
-listxFromList :: StaticShX sh -> [i] -> ListX sh (Const i)
-listxFromList topssh topl = go topssh topl
- where
- go :: StaticShX sh' -> [i] -> ListX sh' (Const i)
- go ZKX [] = ZX
- go (_ :!% sh) (i : is) = Const i ::% go sh is
- go _ _ = error $ "listxFromList: Mismatched list length (type says "
- ++ show (ssxLength topssh) ++ ", list has length "
- ++ show (length topl) ++ ")"
-
-listxToList :: ListX sh' (Const i) -> [i]
-listxToList ZX = []
-listxToList (Const i ::% is) = i : listxToList is
-
-listxHead :: ListX (mn ': sh) f -> f mn
-listxHead (i ::% _) = i
-
-listxTail :: ListX (n : sh) i -> ListX sh i
-listxTail (_ ::% sh) = sh
-
-listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f
-listxAppend ZX idx' = idx'
-listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx'
-
-listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f
-listxDrop long ZX = long
-listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short
-
-listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f
-listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh
-listxInit (_ ::% ZX) = ZX
-
-listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh))
-listxLast (_ ::% sh@(_ ::% _)) = listxLast sh
-listxLast (x ::% ZX) = x
-
-listxZip :: ListX sh f -> ListX sh g -> ListX sh (Product f g)
-listxZip ZX ZX = ZX
-listxZip (i ::% irest) (j ::% jrest) =
- Pair i j ::% listxZip irest jrest
-
-listxZipWith :: (forall a. f a -> g a -> h a) -> ListX sh f -> ListX sh g
- -> ListX sh h
-listxZipWith _ ZX ZX = ZX
-listxZipWith f (i ::% is) (j ::% js) =
- f i j ::% listxZipWith f is js
+import Data.Array.Nested.Mixed.ListX
+import Data.Array.Nested.Types
-- * Mixed indices
--- | This is a newtype over 'ListX'.
+-- | An index into a mixed-typed array.
type role IxX nominal representational
type IxX :: [Maybe Nat] -> Type -> Type
-newtype IxX sh i = IxX (ListX sh (Const i))
- deriving (Eq, Ord, Generic)
+newtype IxX sh i = IxX (ListX sh i)
+ deriving (Eq, Ord, NFData, Functor, Foldable)
pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i
pattern ZIX = IxX ZX
@@ -180,30 +65,30 @@ pattern (:.%)
:: forall {sh1} {i}.
forall n sh. (n : sh ~ sh1)
=> i -> IxX sh i -> IxX sh1 i
-pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i)))
- where i :.% IxX shl = IxX (Const i ::% shl)
+pattern i :.% l <- IxX (i ::% (IxX -> l))
+ where i :.% IxX l = IxX (i ::% l)
infixr 3 :.%
{-# COMPLETE ZIX, (:.%) #-}
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
type IIxX sh = IxX sh Int
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (IxX sh i)
+#else
instance Show i => Show (IxX sh i) where
- showsPrec _ (IxX l) = listxShow (\(Const i) -> shows i) l
-
-instance Functor (IxX sh) where
- fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l)
-
-instance Foldable (IxX sh) where
- foldMap f (IxX l) = listxFold (f . getConst) l
+ showsPrec _ (IxX l) = listxShow shows l
+#endif
-instance NFData i => NFData (IxX sh i)
-
-ixxLength :: IxX sh i -> Int
-ixxLength (IxX l) = listxLength l
+{-# INLINE ixxFromList #-}
+ixxFromList :: StaticShX sh -> [i] -> IxX sh i
+ixxFromList sh l = assert (ssxLength sh == length l) $ IsList.fromList l
ixxRank :: IxX sh i -> SNat (Rank sh)
-ixxRank (IxX l) = listxRank l
+ixxRank ZIX = SNat
+ixxRank (_ :.% l) | SNat <- ixxRank l = SNat
ixxZero :: StaticShX sh -> IIxX sh
ixxZero ZKX = ZIX
@@ -213,85 +98,131 @@ ixxZero' :: IShX sh -> IIxX sh
ixxZero' ZSX = ZIX
ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh
-ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i
-ixxFromList = coerce (listxFromList @_ @i)
-
-ixxHead :: IxX (n : sh) i -> i
-ixxHead (IxX list) = getConst (listxHead list)
+ixxHead :: IxX (mn ': sh) i -> i
+ixxHead (i :.% _) = i
ixxTail :: IxX (n : sh) i -> IxX sh i
-ixxTail (IxX list) = IxX (listxTail list)
+ixxTail (_ :.% sh) = sh
ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
-ixxAppend = coerce (listxAppend @_ @(Const i))
+ixxAppend (IxX l1) (IxX l2) = IxX $ lazilyConcat (++) l1 l2
-ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i
-ixxDrop = coerce (listxDrop @(Const i) @(Const i))
+ixxDrop :: forall i j sh sh'. IxX sh j -> IxX (sh ++ sh') i -> IxX sh' i
+ixxDrop ZIX long = long
+ixxDrop (_ :.% short) long = case long of _ :.% long' -> ixxDrop short long'
-ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i
-ixxInit = coerce (listxInit @(Const i))
+ixxInit :: forall i n sh. IxX (n : sh) i -> IxX (Init (n : sh)) i
+ixxInit (i :.% sh@(_ :.% _)) = i :.% ixxInit sh
+ixxInit (_ :.% ZIX) = ZIX
-ixxLast :: forall n sh i. IxX (n : sh) i -> i
-ixxLast = coerce (listxLast @(Const i))
+ixxLast :: forall i n sh. IxX (n : sh) i -> i
+ixxLast (_ :.% sh@(_ :.% _)) = ixxLast sh
+ixxLast (x :.% ZIX) = x
ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j)
ixxZip ZIX ZIX = ZIX
ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js
+{-# INLINE ixxZipWith #-}
ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k
ixxZipWith _ ZIX ZIX = ZIX
ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js
-ixxFromLinear :: IShX sh -> Int -> IIxX sh
-ixxFromLinear = \sh i -> case go sh i of
- (idx, 0) -> idx
- _ -> error $ "ixxFromLinear: out of range (" ++ show i ++
- " in array of shape " ++ show sh ++ ")"
+ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i
+ixxCast ZKX ZIX = ZIX
+ixxCast (_ :!% sh) (i :.% idx) = i :.% ixxCast sh idx
+ixxCast _ _ = error "ixxCast: ranks don't match"
+
+-- | Given a multidimensional index, get the corresponding linear
+-- index into the buffer.
+{-# INLINEABLE ixxToLinear #-}
+ixxToLinear :: Num i => IShX sh -> IxX sh i -> i
+ixxToLinear = \sh i -> go sh i 0
where
- -- returns (index in subarray, remaining index in enclosing array)
- go :: IShX sh -> Int -> (IIxX sh, Int)
- go ZSX i = (ZIX, i)
- go (n :$% sh) i =
- let (idx, i') = go sh i
- (upi, locali) = i' `quotRem` fromSMayNat' n
- in (locali :.% idx, upi)
+ go :: Num i => IShX sh -> IxX sh i -> i -> i
+ go ZSX ZIX !a = a
+ go (n :$% sh) (i :.% ix) a = go sh ix (fromIntegral (fromSMayNat' n) * a + i)
-ixxToLinear :: IShX sh -> IIxX sh -> Int
-ixxToLinear = \sh i -> fst (go sh i)
+{-# INLINEABLE ixxFromLinear #-}
+ixxFromLinear :: Num i => IShX sh -> Int -> IxX sh i
+ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared when it's called many times
+ let suffixes = drop 1 (scanr (*) 1 (shxToList sh))
+ in fromLin0 sh suffixes
where
- -- returns (index in subarray, size of subarray)
- go :: IShX sh -> IIxX sh -> (Int, Int)
- go ZSX ZIX = (0, 1)
- go (n :$% sh) (i :.% ix) =
- let (lidx, sz) = go sh ix
- in (sz * i + lidx, fromSMayNat' n * sz)
+ -- Unfold first iteration of fromLin to do the range check.
+ -- Don't inline this function at first to allow GHC to inline the outer
+ -- function and realise that 'suffixes' is shared. But then later inline it
+ -- anyway, to avoid the function call. Removing the pragma makes GHC
+ -- somehow unable to recognise that 'suffixes' can be shared in a loop.
+ {-# NOINLINE [0] fromLin0 #-}
+ fromLin0 :: Num i => IShX sh -> [Int] -> Int -> IxX sh i
+ fromLin0 sh suffixes i =
+ if i < 0 then outrange sh i else
+ case (sh, suffixes) of
+ (ZSX, _) | i > 0 -> outrange sh i
+ | otherwise -> ZIX
+ ((fromSMayNat' -> n) :$% sh', suff : suffs) ->
+ let (q, r) = i `quotRem` suff
+ in if q >= n then outrange sh i else
+ fromIntegral q :.% fromLin sh' suffs r
+ _ -> error "impossible"
+
+ fromLin :: Num i => IShX sh -> [Int] -> Int -> IxX sh i
+ fromLin ZSX _ !_ = ZIX
+ fromLin (_ :$% sh') (suff : suffs) i =
+ let (q, r) = i `quotRem` suff -- suff == shrSize sh'
+ in fromIntegral q :.% fromLin sh' suffs r
+ fromLin _ _ _ = error "impossible"
+
+ {-# NOINLINE outrange #-}
+ outrange :: IShX sh -> Int -> a
+ outrange sh i = error $ "ixxFromLinear: out of range (" ++ show i ++
+ " in array of shape " ++ show sh ++ ")"
+
+shxEnum :: IShX sh -> [IIxX sh]
+shxEnum = shxEnum'
+
+{-# INLINABLE shxEnum' #-} -- ensure this can be specialised at use site
+shxEnum' :: Num i => IShX sh -> [IxX sh i]
+shxEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shxSize sh - 1]]
+ where
+ suffixes = drop 1 (scanr (*) 1 (shxToList sh))
+
+ fromLin :: Num i => IShX sh -> [Int] -> Int# -> IxX sh i
+ fromLin ZSX _ _ = ZIX
+ fromLin (_ :$% sh') (I# suff# : suffs) i# =
+ let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh'
+ in fromIntegral (I# q#) :.% fromLin sh' suffs r#
+ fromLin _ _ _ = error "impossible"
-- * Mixed shapes
-data SMayNat i f n where
- SUnknown :: i -> SMayNat i f Nothing
- SKnown :: f n -> SMayNat i f (Just n)
-deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n)
-deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n)
-deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n)
+data SMayNat i n where
+ SUnknown :: i -> SMayNat i Nothing
+ SKnown :: SNat n -> SMayNat i (Just n)
+deriving instance Show i => Show (SMayNat i n)
+deriving instance Eq i => Eq (SMayNat i n)
+deriving instance Ord i => Ord (SMayNat i n)
-instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where
+instance NFData i => NFData (SMayNat i n) where
rnf (SUnknown i) = rnf i
- rnf (SKnown x) = rnf x
+ rnf (SKnown SNat) = ()
-instance TestEquality f => TestEquality (SMayNat i f) where
+instance TestEquality (SMayNat i) where
testEquality SUnknown{} SUnknown{} = Just Refl
testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl
testEquality _ _ = Nothing
+{-# INLINE fromSMayNat #-}
fromSMayNat :: (n ~ Nothing => i -> r)
- -> (forall m. n ~ Just m => f m -> r)
- -> SMayNat i f n -> r
+ -> (forall m. n ~ Just m => SNat m -> r)
+ -> SMayNat i n -> r
fromSMayNat f _ (SUnknown i) = f i
fromSMayNat _ g (SKnown s) = g s
-fromSMayNat' :: SMayNat Int SNat n -> Int
+{-# INLINE fromSMayNat' #-}
+fromSMayNat' :: SMayNat Int n -> Int
fromSMayNat' = fromSMayNat id fromSNat'
type family AddMaybe n m where
@@ -299,134 +230,226 @@ type family AddMaybe n m where
AddMaybe (Just _) Nothing = Nothing
AddMaybe (Just n) (Just m) = Just (n + m)
-smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
+smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m)
smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)
--- | This is a newtype over 'ListX'.
type role ShX nominal representational
type ShX :: [Maybe Nat] -> Type -> Type
-newtype ShX sh i = ShX (ListX sh (SMayNat i SNat))
- deriving (Eq, Ord, Generic)
-
-pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i
-pattern ZSX = ShX ZX
-
-pattern (:$%)
- :: forall {sh1} {i}.
- forall n sh. (n : sh ~ sh1)
- => SMayNat i SNat n -> ShX sh i -> ShX sh1 i
-pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i))
- where i :$% ShX shl = ShX (i ::% shl)
-infixr 3 :$%
-
-{-# COMPLETE ZSX, (:$%) #-}
+data ShX sh i where
+ ZSX :: ShX '[] i
+ ConsUnknown :: forall sh i. i -> ShX sh i -> ShX (Nothing : sh) i
+-- TODO: bring this UNPACK back when GHC no longer crashes:
+-- ConsKnown :: forall n sh i. {-# UNPACK #-} SNat n -> ShX sh i -> ShX (Just n : sh) i
+ ConsKnown :: forall n sh i. SNat n -> ShX sh i -> ShX (Just n : sh) i
+deriving instance Ord i => Ord (ShX sh i)
-type IShX sh = ShX sh Int
+-- A manually defined instance and this INLINEABLE is needed to specialize
+-- mdot1Inner (otherwise GHC warns specialization breaks down here).
+instance Eq i => Eq (ShX sh i) where
+ {-# INLINEABLE (==) #-}
+ ZSX == ZSX = True
+ ConsUnknown i1 sh1 == ConsUnknown i2 sh2 = i1 == i2 && sh1 == sh2
+ ConsKnown _ sh1 == ConsKnown _ sh2 = sh1 == sh2
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (ShX sh i)
+#else
instance Show i => Show (ShX sh i) where
- showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
+ showsPrec _ l = shxShow (fromSMayNat shows (shows . fromSNat)) l
+#endif
+
+instance NFData i => NFData (ShX sh i) where
+ rnf ZSX = ()
+ rnf (x `ConsUnknown` l) = rnf x `seq` rnf l
+ rnf (SNat `ConsKnown` l) = rnf l
instance Functor (ShX sh) where
- fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l)
+ {-# INLINE fmap #-}
+ fmap f l = shxFmap (fromSMayNat (SUnknown . f) SKnown) l
-instance NFData i => NFData (ShX sh i) where
- rnf (ShX ZX) = ()
- rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l)
- rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l)
+data UnconsShXRes i sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsShXRes (SMayNat i n) (ShX sh i)
+shxUncons :: ShX sh1 i -> Maybe (UnconsShXRes i sh1)
+shxUncons (i `ConsUnknown` shl') = Just (UnconsShXRes (SUnknown i) shl')
+shxUncons (i `ConsKnown` shl') = Just (UnconsShXRes (SKnown i) shl')
+shxUncons ZSX = Nothing
--- | This checks only whether the types are equal; unknown dimensions might
--- still differ. This corresponds to 'testEquality', except on the penultimate
--- type parameter.
+-- | This checks only whether the types are equal; if the elements of the list
+-- are not singletons, their values may still differ. This corresponds to
+-- 'testEquality', except on the penultimate type parameter.
shxEqType :: ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
shxEqType ZSX ZSX = Just Refl
-shxEqType (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh')
- | Just Refl <- sameNat n m
- , Just Refl <- shxEqType sh sh'
- = Just Refl
-shxEqType (SUnknown _ :$% sh) (SUnknown _ :$% sh')
+shxEqType (_ `ConsUnknown` sh) (_ `ConsUnknown` sh')
| Just Refl <- shxEqType sh sh'
= Just Refl
+shxEqType (n `ConsKnown` sh) (m `ConsKnown` sh')
+ | Just Refl <- testEquality n m
+ , Just Refl <- shxEqType sh sh'
+ = Just Refl
shxEqType _ _ = Nothing
--- | This checks whether all dimensions have the same value. This is more than
--- 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ in the
--- @some@ package (except on the penultimate type parameter).
+-- | This checks whether the two lists actually contain equal values. This is
+-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
+-- in the @some@ package (except on the penultimate type parameter).
shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
shxEqual ZSX ZSX = Just Refl
-shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh')
- | Just Refl <- sameNat n m
+shxEqual (n `ConsUnknown` sh) (m `ConsUnknown` sh')
+ | n == m
, Just Refl <- shxEqual sh sh'
= Just Refl
-shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh')
- | i == j
+shxEqual (n `ConsKnown` sh) (m `ConsKnown` sh')
+ | Just Refl <- testEquality n m
, Just Refl <- shxEqual sh sh'
= Just Refl
shxEqual _ _ = Nothing
+{-# INLINE shxFmap #-}
+shxFmap :: (forall n. SMayNat i n -> SMayNat j n) -> ShX sh i -> ShX sh j
+shxFmap _ ZSX = ZSX
+shxFmap f (x `ConsUnknown` xs) = case f (SUnknown x) of
+ SUnknown y -> y `ConsUnknown` shxFmap f xs
+shxFmap f (x `ConsKnown` xs) = case f (SKnown x) of
+ SKnown y -> y `ConsKnown` shxFmap f xs
+
+{-# INLINE shxFoldMap #-}
+shxFoldMap :: Monoid m => (forall n. SMayNat i n -> m) -> ShX sh i -> m
+shxFoldMap _ ZSX = mempty
+shxFoldMap f (x `ConsUnknown` xs) = f (SUnknown x) <> shxFoldMap f xs
+shxFoldMap f (x `ConsKnown` xs) = f (SKnown x) <> shxFoldMap f xs
+
shxLength :: ShX sh i -> Int
-shxLength (ShX l) = listxLength l
+shxLength = getSum . shxFoldMap (\_ -> Sum 1)
shxRank :: ShX sh i -> SNat (Rank sh)
-shxRank (ShX l) = listxRank l
+shxRank ZSX = SNat
+shxRank (_ `ConsUnknown` l) | SNat <- shxRank l = SNat
+shxRank (_ `ConsKnown` l) | SNat <- shxRank l = SNat
+
+{-# INLINE shxShow #-}
+shxShow :: forall sh i. (forall n. SMayNat i n -> ShowS) -> ShX sh i -> ShowS
+shxShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> ShX sh' i -> ShowS
+ go _ ZSX = id
+ go prefix (x `ConsUnknown` xs) = showString prefix . f (SUnknown x) . go "," xs
+ go prefix (x `ConsKnown` xs) = showString prefix . f (SKnown x) . go "," xs
+
+shxHead :: ShX (mn ': sh) i -> SMayNat i mn
+shxHead (i `ConsUnknown` _) = SUnknown i
+shxHead (i `ConsKnown` _) = SKnown i
+
+shxTail :: ShX (n : sh) i -> ShX sh i
+shxTail (_ `ConsUnknown` sh) = sh
+shxTail (_ `ConsKnown` sh) = sh
+
+shxAppend :: ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
+shxAppend ZSX idx' = idx'
+shxAppend (i `ConsUnknown` idx) idx' = i `ConsUnknown` shxAppend idx idx'
+shxAppend (i `ConsKnown` idx) idx' = i `ConsKnown` shxAppend idx idx'
+
+shxDropSh :: forall sh sh' i j. ShX sh j -> ShX (sh ++ sh') i -> ShX sh' i
+shxDropSh ZSX long = long
+shxDropSh (_ `ConsUnknown` short) long = case long of
+ _ `ConsUnknown` long' -> shxDropSh short long'
+shxDropSh (_ `ConsKnown` short) long = case long of
+ _ `ConsKnown` long' -> shxDropSh short long'
+
+shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i
+shxDropSSX = coerce (shxDropSh @_ @_ @i @())
+
+shxInit :: forall i n sh. ShX (n : sh) i -> ShX (Init (n : sh)) i
+shxInit (i `ConsUnknown` sh@(_ `ConsUnknown` _)) = i `ConsUnknown` shxInit sh
+shxInit (i `ConsUnknown` sh@(_ `ConsKnown` _)) = i `ConsUnknown` shxInit sh
+shxInit (_ `ConsUnknown` ZSX) = ZSX
+shxInit (i `ConsKnown` sh@(_ `ConsUnknown` _)) = i `ConsKnown` shxInit sh
+shxInit (i `ConsKnown` sh@(_ `ConsKnown` _)) = i `ConsKnown` shxInit sh
+shxInit (_ `ConsKnown` ZSX) = ZSX
+
+shxLast :: forall i n sh. ShX (n : sh) i -> SMayNat i (Last (n : sh))
+shxLast (_ `ConsUnknown` sh@(_ `ConsUnknown` _)) = shxLast sh
+shxLast (_ `ConsUnknown` sh@(_ `ConsKnown` _)) = shxLast sh
+shxLast (x `ConsUnknown` ZSX) = SUnknown x
+shxLast (_ `ConsKnown` sh@(_ `ConsUnknown` _)) = shxLast sh
+shxLast (_ `ConsKnown` sh@(_ `ConsKnown` _)) = shxLast sh
+shxLast (x `ConsKnown` ZSX) = SKnown x
+
+pattern (:$%)
+ :: forall {sh1} {i}.
+ forall n sh. (n : sh ~ sh1)
+ => SMayNat i n -> ShX sh i -> ShX sh1 i
+pattern i :$% shl <- (shxUncons -> Just (UnconsShXRes i shl))
+ where i :$% shl = case i of; SUnknown x -> x `ConsUnknown` shl; SKnown x -> x `ConsKnown` shl
+infixr 3 :$%
+
+{-# COMPLETE ZSX, (:$%) #-}
+
+type IShX sh = ShX sh Int
-- | The number of elements in an array described by this shape.
shxSize :: IShX sh -> Int
shxSize ZSX = 1
shxSize (n :$% sh) = fromSMayNat' n * shxSize sh
-shxFromList :: StaticShX sh -> [Int] -> ShX sh Int
-shxFromList topssh topl = go topssh topl
+-- We don't report the size of the list in case of errors in order not to retain the list.
+{-# INLINEABLE shxFromList #-}
+shxFromList :: StaticShX sh -> [Int] -> IShX sh
+shxFromList (StaticShX topssh) topl = go topssh topl
where
- go :: StaticShX sh' -> [Int] -> ShX sh' Int
- go ZKX [] = ZSX
- go (SKnown sn :!% sh) (i : is)
- | i == fromSNat' sn = SKnown sn :$% go sh is
- | otherwise = error $ "shxFromList: Value does not match typing (type says "
- ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
- go (SUnknown () :!% sh) (i : is) = SUnknown i :$% go sh is
- go _ _ = error $ "shxFromList: Mismatched list length (type says "
- ++ show (ssxLength topssh) ++ ", list has length "
- ++ show (length topl) ++ ")"
+ go :: ShX sh' () -> [Int] -> ShX sh' Int
+ go ZSX [] = ZSX
+ go ZSX _ = error $ "shxFromList: List too long (type says " ++ show (shxLength topssh) ++ ")"
+ go (ConsKnown sn sh) (i : is)
+ | i == fromSNat' sn = ConsKnown sn (go sh is)
+ | otherwise = error "shxFromList: Value does not match typing"
+ go (ConsUnknown () sh) (i : is) = ConsUnknown i (go sh is)
+ go _ _ = error $ "shxFromList: List too short (type says " ++ show (shxLength topssh) ++ ")"
+{-# INLINEABLE shxToList #-}
shxToList :: IShX sh -> [Int]
-shxToList ZSX = []
-shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh
-
-shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
-shxAppend = coerce (listxAppend @_ @(SMayNat i SNat))
-
-shxHead :: ShX (n : sh) i -> SMayNat i SNat n
-shxHead (ShX list) = listxHead list
-
-shxTail :: ShX (n : sh) i -> ShX sh i
-shxTail (ShX list) = ShX (listxTail list)
+shxToList l = build (\(cons :: i -> is -> is) (nil :: is) ->
+ let go :: ShX sh Int -> is
+ go ZSX = nil
+ go (ConsUnknown i rest) = i `cons` go rest
+ go (ConsKnown sn rest) = fromSNat' sn `cons` go rest
+ in go l)
-shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i
-shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat))
+-- If it ever matters for performance, this is unsafeCoercible.
+shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i
+shxFromSSX ZKX = ZSX
+shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh))
+ | Refl <- lemMapJustCons @sh Refl
+ = SKnown n :$% shxFromSSX sh
+shxFromSSX (SUnknown _ :!% _) = error "unreachable"
-shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i
-shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j))
+-- | This may fail if @sh@ has @Nothing@s in it.
+shxFromSSX2 :: StaticShX sh -> Maybe (ShX sh i)
+shxFromSSX2 ZKX = Just ZSX
+shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh
+shxFromSSX2 (SUnknown _ :!% _) = Nothing
-shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i
-shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))
+shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i
+shxTakeSSX _ ZKX _ = ZSX
+shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh
-shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i
-shxInit = coerce (listxInit @(SMayNat i SNat))
+shxTakeSh :: forall sh sh' i proxy. proxy sh' -> ShX sh i -> ShX (sh ++ sh') i -> ShX sh i
+shxTakeSh _ ZSX _ = ZSX
+shxTakeSh p (_ :$% ssh1) (n :$% sh) = n :$% shxTakeSh p ssh1 sh
-shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh))
-shxLast = coerce (listxLast @(SMayNat i SNat))
+{-# INLINEABLE shxTakeIx #-}
+shxTakeIx :: forall sh sh' i j. Proxy sh' -> IxX sh j -> ShX (sh ++ sh') i -> ShX sh i
+shxTakeIx _ (IxX ZX) _ = ZSX
+shxTakeIx proxy (IxX (_ ::% long)) short = case short of i :$% short' -> i :$% shxTakeIx proxy (IxX long) short'
-shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i
-shxTakeSSX _ = flip go
- where
- go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i
- go ZKX _ = ZSX
- go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh
+{-# INLINEABLE shxDropIx #-}
+shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i
+shxDropIx ZIX long = long
+shxDropIx (_ :.% short) long = case long of _ :$% long' -> shxDropIx short long'
-shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n)
+{-# INLINE shxZipWith #-}
+shxZipWith :: (forall n. SMayNat i n -> SMayNat j n -> SMayNat k n)
-> ShX sh i -> ShX sh j -> ShX sh k
shxZipWith _ ZSX ZSX = ZSX
shxZipWith f (i :$% is) (j :$% js) = f i j :$% shxZipWith f is js
@@ -437,115 +460,115 @@ shxCompleteZeros ZKX = ZSX
shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh
shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh
-shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i)
+shxSplitApp :: proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i)
shxSplitApp _ ZKX idx = (ZSX, idx)
shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx)
-shxEnum :: IShX sh -> [IIxX sh]
-shxEnum = \sh -> go sh id []
- where
- go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a]
- go ZSX f = (f ZIX :)
- go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]]
-
-shxCast :: IShX sh -> StaticShX sh' -> Maybe (IShX sh')
-shxCast ZSX ZKX = Just ZSX
-shxCast (SKnown n :$% sh) (SKnown m :!% ssh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast sh ssh
-shxCast (SUnknown n :$% sh) (SKnown m :!% ssh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast sh ssh
-shxCast (SKnown n :$% sh) (SUnknown () :!% ssh) = (SUnknown (fromSNat' n) :$%) <$> shxCast sh ssh
-shxCast (SUnknown n :$% sh) (SUnknown () :!% ssh) = (SUnknown n :$%) <$> shxCast sh ssh
+shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh')
+shxCast ZKX ZSX = Just ZSX
+shxCast (SKnown m :!% ssh) (SKnown n :$% sh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast ssh sh
+shxCast (SKnown m :!% ssh) (SUnknown n :$% sh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast ssh sh
+shxCast (SUnknown () :!% ssh) (SKnown n :$% sh) = (SUnknown (fromSNat' n) :$%) <$> shxCast ssh sh
+shxCast (SUnknown () :!% ssh) (SUnknown n :$% sh) = (SUnknown n :$%) <$> shxCast ssh sh
shxCast _ _ = Nothing
-- | Partial version of 'shxCast'.
-shxCast' :: IShX sh -> StaticShX sh' -> IShX sh'
-shxCast' sh ssh = case shxCast sh ssh of
+shxCast' :: StaticShX sh' -> IShX sh -> IShX sh'
+shxCast' ssh sh = case shxCast ssh sh of
Just sh' -> sh'
Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")"
-- * Static mixed shapes
--- | The part of a shape that is statically known. (A newtype over 'ListX'.)
+-- | The part of a shape that is statically known. (A newtype over 'ShX'.)
type StaticShX :: [Maybe Nat] -> Type
-newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat))
- deriving (Eq, Ord)
+newtype StaticShX sh = StaticShX (ShX sh ())
+ deriving (NFData)
+
+instance Eq (StaticShX sh) where _ == _ = True
+instance Ord (StaticShX sh) where compare _ _ = EQ
pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
-pattern ZKX = StaticShX ZX
+pattern ZKX = StaticShX ZSX
pattern (:!%)
:: forall {sh1}.
forall n sh. (n : sh ~ sh1)
- => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1
-pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i))
- where i :!% StaticShX shl = StaticShX (i ::% shl)
+ => SMayNat () n -> StaticShX sh -> StaticShX sh1
+pattern i :!% shl <- StaticShX (shxUncons -> Just (UnconsShXRes i (StaticShX -> shl)))
+ where i :!% StaticShX shl = case i of; SUnknown () -> StaticShX (() `ConsUnknown` shl); SKnown x -> StaticShX (x `ConsKnown` shl)
+
infixr 3 :!%
{-# COMPLETE ZKX, (:!%) #-}
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (StaticShX sh)
+#else
instance Show (StaticShX sh) where
- showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
-
-instance NFData (StaticShX sh) where
- rnf (StaticShX ZX) = ()
- rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l)
- rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l)
+ showsPrec _ (StaticShX l) = shxShow (fromSMayNat shows (shows . fromSNat)) l
+#endif
instance TestEquality StaticShX where
- testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2
+ testEquality (StaticShX l1) (StaticShX l2) = shxEqType l1 l2
ssxLength :: StaticShX sh -> Int
-ssxLength (StaticShX l) = listxLength l
+ssxLength (StaticShX l) = shxLength l
ssxRank :: StaticShX sh -> SNat (Rank sh)
-ssxRank (StaticShX l) = listxRank l
+ssxRank (StaticShX l) = shxRank l
-- | @ssxEqType = 'testEquality'@. Provided for consistency.
ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
ssxEqType = testEquality
ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
-ssxAppend ZKX sh' = sh'
-ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'
+ssxAppend = coerce (shxAppend @_ @())
-ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n
-ssxHead (StaticShX list) = listxHead list
+ssxHead :: StaticShX (n : sh) -> SMayNat () n
+ssxHead (StaticShX list) = shxHead list
ssxTail :: StaticShX (n : sh) -> StaticShX sh
-ssxTail (_ :!% ssh) = ssh
+ssxTail (StaticShX list) = StaticShX (shxTail list)
-ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'
-ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))
+ssxTakeIx :: forall sh sh' i. Proxy sh' -> IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh
+ssxTakeIx _ (IxX ZX) _ = ZKX
+ssxTakeIx proxy (IxX (_ ::% long)) short = case short of i :!% short' -> i :!% ssxTakeIx proxy (IxX long) short'
-ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh))
-ssxInit = coerce (listxInit @(SMayNat () SNat))
+ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
+ssxDropIx (IxX ZX) long = long
+ssxDropIx (IxX (_ ::% short)) long = case long of _ :!% long' -> ssxDropIx (IxX short) long'
-ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh))
-ssxLast = coerce (listxLast @(SMayNat () SNat))
+ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
+ssxDropSh = coerce (shxDropSh @_ @_ @() @i)
--- | This may fail if @sh@ has @Nothing@s in it.
-ssxToShX' :: StaticShX sh -> Maybe (IShX sh)
-ssxToShX' ZKX = Just ZSX
-ssxToShX' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShX' sh
-ssxToShX' (SUnknown _ :!% _) = Nothing
+ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh'
+ssxDropSSX = coerce (shxDropSh @_ @_ @() @())
+
+ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh))
+ssxInit = coerce (shxInit @())
+
+ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () (Last (n : sh))
+ssxLast = coerce (shxLast @())
ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
ssxReplicate SZ = ZKX
ssxReplicate (SS (n :: SNat n'))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n'
+ | Refl <- lemReplicateSucc @(Nothing @Nat) n
= SUnknown () :!% ssxReplicate n
-ssxIotaFrom :: Int -> StaticShX sh -> [Int]
-ssxIotaFrom _ ZKX = []
-ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh
+ssxIotaFrom :: StaticShX sh -> Int -> [Int]
+ssxIotaFrom ZKX _ = []
+ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i + 1)
-ssxFromShape :: IShX sh -> StaticShX sh
-ssxFromShape ZSX = ZKX
-ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh
+ssxFromShX :: ShX sh i -> StaticShX sh
+ssxFromShX ZSX = ZKX
+ssxFromShX (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShX sh
ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing)
ssxFromSNat SZ = ZKX
-ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n
+ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) n = SUnknown () :!% ssxFromSNat n
-- | Evidence for the static part of a shape. This pops up only when you are
@@ -557,7 +580,7 @@ instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SK
instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX
withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r
-withKnownShX k = withDict @(KnownShX sh) k
+withKnownShX = withDict @(KnownShX sh)
-- * Flattening
@@ -570,18 +593,18 @@ type family Flatten' acc sh where
Flatten' acc (Just n : sh) = Flatten' (acc * n) sh
-- This function is currently unused
-ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh)
+ssxFlatten :: StaticShX sh -> SMayNat () (Flatten sh)
ssxFlatten = go (SNat @1)
where
- go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh)
+ go :: SNat acc -> StaticShX sh -> SMayNat () (Flatten' acc sh)
go acc ZKX = SKnown acc
go _ (SUnknown () :!% _) = SUnknown ()
go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh
-shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh)
+shxFlatten :: IShX sh -> SMayNat Int (Flatten sh)
shxFlatten = go (SNat @1)
where
- go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh)
+ go :: SNat acc -> IShX sh -> SMayNat Int (Flatten' acc sh)
go acc ZSX = SKnown acc
go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh)
go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh
@@ -592,20 +615,14 @@ shxFlatten = go (SNat @1)
goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh
--- | Very untyped: only length is checked (at runtime).
-instance KnownShX sh => IsList (ListX sh (Const i)) where
- type Item (ListX sh (Const i)) = i
- fromList = listxFromList (knownShX @sh)
- toList = listxToList
-
-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
-instance KnownShX sh => IsList (IxX sh i) where
+instance IsList (IxX sh i) where
type Item (IxX sh i) = i
fromList = IxX . IsList.fromList
toList = Foldable.toList
-- | Untyped: length and known dimensions are checked (at runtime).
-instance KnownShX sh => IsList (ShX sh Int) where
- type Item (ShX sh Int) = Int
+instance KnownShX sh => IsList (IShX sh) where
+ type Item (IShX sh) = Int
fromList = shxFromList (knownShX @sh)
toList = shxToList
diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Nested/Permutation.hs
index 22672cb..85fbd89 100644
--- a/src/Data/Array/Mixed/Permutation.hs
+++ b/src/Data/Array/Nested/Permutation.hs
@@ -1,10 +1,10 @@
+{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -15,29 +15,29 @@
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed.Permutation where
+module Data.Array.Nested.Permutation where
import Data.Coerce (coerce)
-import Data.Functor.Const
import Data.List (sort)
import Data.Maybe (fromMaybe)
import Data.Proxy
import Data.Type.Bool
import Data.Type.Equality
import Data.Type.Ord
+import GHC.Exts (withDict)
import GHC.TypeError
import GHC.TypeLits
import GHC.TypeNats qualified as TN
import Data.Array.Nested.Mixed.Shape
-import Data.Array.Mixed.Types
+import Data.Array.Nested.Types
-- * Permutations
-- | A "backward" permutation of a dimension list. The operation on the
--- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute'
--- for code that implements this.
+-- dimension list is most similar to @backpermute@ in the @vector@ package; see
+-- 'Permute' for code that implements this.
data Perm list where
PNil :: Perm '[]
PCons :: SNat a -> Perm l -> Perm (a : l)
@@ -45,15 +45,22 @@ infixr 5 `PCons`
deriving instance Show (Perm list)
deriving instance Eq (Perm list)
+instance TestEquality Perm where
+ testEquality PNil PNil = Just Refl
+ testEquality (x `PCons` xs) (y `PCons` ys)
+ | Just Refl <- testEquality x y
+ , Just Refl <- testEquality xs ys = Just Refl
+ testEquality _ _ = Nothing
+
permRank :: Perm list -> SNat (Rank list)
permRank PNil = SNat
permRank (_ `PCons` l) | SNat <- permRank l = SNat
-permFromList :: [Int] -> (forall list. Perm list -> r) -> r
-permFromList [] k = k PNil
-permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case
- Just sn -> permFromList xs $ \list -> k (sn `PCons` list)
- Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x
+permFromListCont :: [Int] -> (forall list. Perm list -> r) -> r
+permFromListCont [] k = k PNil
+permFromListCont (x : xs) k = withSomeSNat (fromIntegral x) $ \case
+ Just sn -> permFromListCont xs $ \list -> k (sn `PCons` list)
+ Nothing -> error $ "Data.Array.Nested.Permutation.permFromListCont: negative number in list: " ++ show x
permToList :: Perm list -> [Natural]
permToList PNil = mempty
@@ -119,6 +126,9 @@ class KnownPerm l where makePerm :: Perm l
instance KnownPerm '[] where makePerm = PNil
instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm
+withKnownPerm :: forall l r. Perm l -> (KnownPerm l => r) -> r
+withKnownPerm = withDict @(KnownPerm l)
+
-- | Untyped permutations for ranked arrays
type PermR = [Int]
@@ -161,51 +171,78 @@ type family DropLen ref l where
DropLen '[] l = l
DropLen (_ : ref) (_ : xs) = DropLen ref xs
-listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f
-listxTakeLen PNil _ = ZX
-listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh
-listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape"
+shxTakeLenPerm :: forall i is sh. Perm is -> ShX sh i -> ShX (TakeLen is sh) i
+shxTakeLenPerm PNil _ = ZSX
+shxTakeLenPerm (_ `PCons` is) (n `ConsUnknown` sh) = n `ConsUnknown` shxTakeLenPerm is sh
+shxTakeLenPerm (_ `PCons` is) (n `ConsKnown` sh) = n `ConsKnown` shxTakeLenPerm is sh
+shxTakeLenPerm (_ `PCons` _) ZSX = error "Permutation longer than shape"
-listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f
-listxDropLen PNil sh = sh
-listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh
-listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape"
+shxDropLenPerm :: forall i is sh. Perm is -> ShX sh i -> ShX (DropLen is sh) i
+shxDropLenPerm PNil sh = sh
+shxDropLenPerm (_ `PCons` is) (_ `ConsUnknown` sh) = shxDropLenPerm is sh
+shxDropLenPerm (_ `PCons` is) (_ `ConsKnown` sh) = shxDropLenPerm is sh
+shxDropLenPerm (_ `PCons` _) ZSX = error "Permutation longer than shape"
-listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f
-listxPermute PNil _ = ZX
-listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) =
- listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh
+shxPermute :: forall i is sh. Perm is -> ShX sh i -> ShX (Permute is sh) i
+shxPermute PNil _ = ZSX
+shxPermute (i `PCons` (is :: Perm is')) (sh :: ShX sh i) =
+ case shxIndex i sh of
+ SUnknown x -> x `ConsUnknown` shxPermute is sh
+ SKnown x -> x `ConsKnown` shxPermute is sh
-listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh)
-listxIndex _ _ SZ (n ::% _) = n
-listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f))
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = listxIndex p pT i sh
-listxIndex _ _ _ ZX = error "Index into empty shape"
+shxIndex :: forall i k sh. SNat k -> ShX sh i -> SMayNat i (Index k sh)
+shxIndex SZ (n `ConsUnknown` _) = SUnknown n
+shxIndex SZ (n `ConsKnown` _) = SKnown n
+shxIndex (SS (i :: SNat k')) ((_ :: i) `ConsUnknown` (sh :: ShX sh' i))
+ | Refl <- lemIndexSucc (Proxy @k') (Proxy @Nothing) (Proxy @sh')
+ = shxIndex i sh
+shxIndex (SS (i :: SNat k')) ((_ :: SNat n) `ConsKnown` (sh :: ShX sh' i))
+ | Refl <- lemIndexSucc (Proxy @k') (Proxy @(Just n)) (Proxy @sh')
+ = shxIndex i sh
+shxIndex _ ZSX = error "Index into empty shape"
-listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f
-listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)
+shxPermutePrefix :: forall i is sh. Perm is -> ShX sh i -> ShX (PermutePrefix is sh) i
+shxPermutePrefix perm sh = shxAppend (shxPermute perm (shxTakeLenPerm perm sh)) (shxDropLenPerm perm sh)
-ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i
-ixxPermutePrefix = coerce (listxPermutePrefix @(Const i))
-ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh)
-ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat))
+ssxTakeLenPerm :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh)
+ssxTakeLenPerm = coerce (shxTakeLenPerm @())
-ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh)
-ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))
+ssxDropLenPerm :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh)
+ssxDropLenPerm = coerce (shxDropLenPerm @())
ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
-ssxPermute = coerce (listxPermute @(SMayNat () SNat))
+ssxPermute = coerce (shxPermute @())
-ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh)
-ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2)
+ssxIndex :: SNat k -> StaticShX sh -> SMayNat () (Index k sh)
+ssxIndex k = coerce (shxIndex @() k)
ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
-ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
+ssxPermutePrefix = coerce (shxPermutePrefix @())
+
+
+ixxTakeLenPerm :: forall i is sh. Perm is -> IxX sh i -> IxX (TakeLen is sh) i
+ixxTakeLenPerm PNil _ = ZIX
+ixxTakeLenPerm (_ `PCons` is) (n :.% sh) = n :.% ixxTakeLenPerm is sh
+ixxTakeLenPerm (_ `PCons` _) ZIX = error "Permutation longer than shape"
+
+ixxDropLenPerm :: forall i is sh. Perm is -> IxX sh i -> IxX (DropLen is sh) i
+ixxDropLenPerm PNil sh = sh
+ixxDropLenPerm (_ `PCons` is) (_ :.% sh) = ixxDropLenPerm is sh
+ixxDropLenPerm (_ `PCons` _) ZIX = error "Permutation longer than shape"
+
+ixxPermute :: forall i is sh. Perm is -> IxX sh i -> IxX (Permute is sh) i
+ixxPermute PNil _ = ZIX
+ixxPermute (i `PCons` (is :: Perm is')) (sh :: IxX sh f) =
+ ixxIndex i sh :.% ixxPermute is sh
+
+ixxIndex :: forall j i sh. SNat i -> IxX sh j -> j
+ixxIndex SZ (n :.% _) = n
+ixxIndex (SS i) (_ :.% sh) = ixxIndex i sh
+ixxIndex _ ZIX = error "Index into empty shape"
-shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh)
-shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))
+ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i
+ixxPermutePrefix perm sh = ixxAppend (ixxPermute perm (ixxTakeLenPerm perm sh)) (ixxDropLenPerm perm sh)
-- * Operations on permutations
@@ -224,7 +261,7 @@ permInverse = \perm k ->
++ " ; invperm = " ++ show invperm)
(permCheckPermutation invperm
(k invperm
- (\ssh -> case provePermInverse perm invperm ssh of
+ (\ssh -> case permCheckInverse perm invperm ssh of
Just eq -> eq
Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm
++ " ; invperm = " ++ show invperm)))
@@ -238,9 +275,9 @@ permInverse = \perm k ->
toHList [] k = k PNil
toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l)
- provePermInverse :: Perm is -> Perm is' -> StaticShX sh
+ permCheckInverse :: Perm is -> Perm is' -> StaticShX sh
-> Maybe (Permute is' (Permute is sh) :~: sh)
- provePermInverse perm perminv ssh =
+ permCheckInverse perm perminv ssh =
ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh
type family MapSucc is where
@@ -248,11 +285,50 @@ type family MapSucc is where
MapSucc (i : is) = i + 1 : MapSucc is
permShift1 :: Perm l -> Perm (0 : MapSucc l)
-permShift1 = (SNat @0 `PCons`) . permMapSucc
+permShift1 = (SZ `PCons`) . permMapSucc
where
permMapSucc :: Perm l -> Perm (MapSucc l)
permMapSucc PNil = PNil
- permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns
+ permMapSucc (sn `PCons` ns) = snatSucc sn `PCons` permMapSucc ns
+
+-- | @PermId n@ is the type of the identity permutation of length @n@.
+type family PermId n where
+ PermId 0 = '[]
+ PermId 1 = '[0]
+ PermId n = PermId (n - 1) ++ '[n - 1]
+
+{- Doesn't type-check:
+permId :: SNat n -> Perm (PermId n)
+permId SZ = PNil
+permId (SS SZ) = PCons SZ PNil
+permId (SS k) = permId k `permAppend` PCons k PNil
+-}
+permId :: forall n. SNat n -> Perm (PermId n)
+permId n = go SZ
+ where
+ go :: forall k l. SNat k -> Perm l
+ go k = if fromSNat' k >= fromSNat' n
+ then gcastWith (unsafeCoerceRefl :: (l :~: '[])) $
+ PNil
+ else gcastWith (unsafeCoerceRefl :: (l :~: k : anything)) $
+ k `PCons` go (SS k)
+
+-- | Note that the second argument is not a valid permutation.
+permAppend :: Perm l -> Perm l2 -> Perm (l ++ l2)
+permAppend PNil l2 = l2
+permAppend (n `PCons` rest) l2 = n `PCons` permAppend rest l2
+
+type family MapPlusN n is where
+ MapPlusN n '[] = '[]
+ MapPlusN n (i : is) = i + n : MapPlusN n is
+
+-- TODO: instead of permAppend and permShiftN define permComp :: Perm l1 -> Perm l2 -> Perm (l1 ++ MapPlusN (Rank l1) l2), where all three are valid permutations
+permShiftN :: forall n l. SNat n -> Perm l -> Perm (PermId n ++ MapPlusN n l)
+permShiftN n = (permId n `permAppend`) . permMapPlusN
+ where
+ permMapPlusN :: Perm l1 -> Perm (MapPlusN n l1)
+ permMapPlusN PNil = PNil
+ permMapPlusN (sn `PCons` ns) = snatPlus sn n `PCons` permMapPlusN ns
-- * Lemmas
@@ -264,7 +340,13 @@ lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl
lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
=> StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is
lemRankDropLen ZKX PNil = Refl
-lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl
+lemRankDropLen (_ :!% sh) (_ `PCons` is)
+ | Refl <- lemRankDropLen sh is
+#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
+ = Refl
+#else
+ = unsafeCoerceRefl
+#endif
lemRankDropLen (_ :!% _) PNil = Refl
lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0"
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
new file mode 100644
index 0000000..2d8b624
--- /dev/null
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -0,0 +1,381 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Ranked (
+ Ranked(Ranked),
+ rquotArray, rremArray, ratan2Array,
+ rshape, rrank,
+ module Data.Array.Nested.Ranked,
+ liftRanked1, liftRanked2,
+) where
+
+import Prelude hiding (mappend, mconcat)
+
+import Data.Array.RankedS qualified as S
+import Data.Bifunctor (first)
+import Data.Coerce (coerce)
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Data.Type.Equality
+import Data.Vector.Storable qualified as VS
+import Foreign.Storable (Storable)
+import GHC.TypeLits
+import GHC.TypeNats qualified as TN
+
+import Data.Array.Nested.Convert
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Ranked.Base
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
+import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray(..))
+import Data.Array.XArray qualified as X
+
+
+remptyArray :: KnownElt a => Ranked 1 a
+remptyArray = mtoRanked (memptyArray ZSX)
+
+-- | The total number of elements in the array.
+rsize :: Elt a => Ranked n a -> Int
+rsize = shrSize . rshape
+
+{-# INLINEABLE rindex #-}
+rindex :: Elt a => Ranked n a -> IIxR n -> a
+rindex (Ranked arr) idx = mindex arr (ixxFromIxR idx)
+
+{-# INLINEABLE rindexPartial #-}
+rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a
+rindexPartial (Ranked arr) idx =
+ Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
+ (castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @Nothing))) arr)
+ (ixxFromIxR idx))
+
+-- | __WARNING__: All values returned from the function must have equal shape.
+-- See the documentation of 'mgenerate' for more details; see also
+-- 'rgeneratePrim'.
+rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a
+rgenerate sh f
+ | sn <- shrRank sh
+ , Dict <- lemKnownReplicate sn
+ , Refl <- lemRankReplicate sn
+ = Ranked (mgenerate (shxFromShR sh) (f . ixrFromIxX))
+
+-- | See 'mgeneratePrim'.
+{-# INLINE rgeneratePrim #-}
+rgeneratePrim :: forall n a i. (PrimElt a, Num i)
+ => IShR n -> (IxR n i -> a) -> Ranked n a
+rgeneratePrim sh f =
+ let g i = f (ixrFromLinear sh i)
+ in rfromVector sh $ VS.generate (shrSize sh) g
+
+-- | See the documentation of 'mlift'.
+{-# INLINE rlift #-}
+rlift :: forall n1 n2 a. Elt a
+ => SNat n2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
+ -> Ranked n1 a -> Ranked n2 a
+rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr)
+
+-- | See the documentation of 'mlift2'.
+{-# INLINE rlift2 #-}
+rlift2 :: forall n1 n2 n3 a. Elt a
+ => SNat n3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b)
+ -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a
+rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)
+
+{-# INLINE rsumOuter1PrimP #-}
+rsumOuter1PrimP :: forall n a.
+ (Storable a, NumElt a)
+ => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
+rsumOuter1PrimP (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = Ranked (msumOuter1PrimP arr)
+
+{-# INLINEABLE rsumOuter1Prim #-}
+rsumOuter1Prim :: forall n a. (NumElt a, PrimElt a)
+ => Ranked (n + 1) a -> Ranked n a
+rsumOuter1Prim = rfromPrimitive . rsumOuter1PrimP . rtoPrimitive
+
+{-# INLINE rsumAllPrimP #-}
+rsumAllPrimP :: (Storable a, NumElt a) => Ranked n (Primitive a) -> a
+rsumAllPrimP (Ranked arr) = msumAllPrimP arr
+
+{-# INLINE rsumAllPrim #-}
+rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a
+rsumAllPrim (Ranked arr) = msumAllPrim arr
+
+rtranspose :: forall n a. Elt a => PermR -> Ranked n a -> Ranked n a
+rtranspose perm arr
+ | sn <- rrank arr
+ , Dict <- lemKnownReplicate sn
+ , length perm <= fromSNat' sn
+ = rlift sn
+ (\ssh' -> X.transposeUntyped sn ssh' perm)
+ arr
+ | otherwise
+ = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array"
+
+rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a
+rconcat
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = coerce mconcat
+
+rappend :: forall n a. Elt a
+ => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
+rappend arr1 arr2
+ | sn@SNat <- rrank arr1
+ , Dict <- lemKnownReplicate sn
+ , Refl <- lemReplicateSucc @(Nothing @Nat) (SNat @n)
+ = coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
+ arr1 arr2
+
+rscalar :: Elt a => a -> Ranked 0 a
+rscalar x = Ranked (mscalar x)
+
+{-# INLINEABLE rfromVectorP #-}
+rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a)
+rfromVectorP sh v
+ | Dict <- lemKnownReplicate (shrRank sh)
+ = Ranked (mfromVectorP (shxFromShR sh) v)
+
+{-# INLINEABLE rfromVector #-}
+rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a
+rfromVector sh v = rfromPrimitive (rfromVectorP sh v)
+
+{-# INLINEABLE rtoVectorP #-}
+rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a
+rtoVectorP = coerce mtoVectorP
+
+{-# INLINEABLE rtoVector #-}
+rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
+rtoVector = coerce mtoVector
+
+-- | All arrays in the list, even subarrays inside @a@, must have the same
+-- shape; if they do not, a runtime error will be thrown. See the
+-- documentation of 'mgenerate' for more information about this restriction.
+--
+-- Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'rfromListOuterN' to be able to stream the list.
+--
+-- If your array is 1-dimensional and contains scalars, use 'rfromList1Prim'.
+rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
+rfromListOuter l
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
+
+-- | See 'rfromListOuter'. If the list does not have the given length, a
+-- runtime error is thrown. 'rfromList1PrimN' is faster if applicable.
+rfromListOuterN :: forall n a. Elt a => Int -> NonEmpty (Ranked n a) -> Ranked (n + 1) a
+rfromListOuterN n l
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = Ranked (mfromListOuterN n (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
+
+-- | Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'rfromList1N' to be able to stream the list.
+--
+-- If the elements are scalars, 'rfromList1Prim' is faster.
+rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
+rfromList1 = coerce mfromList1
+
+-- | If the elements are scalars, 'rfromList1PrimN' is faster. A runtime error
+-- is thrown if the list length does not match the given length.
+rfromList1N :: Elt a => Int -> NonEmpty a -> Ranked 1 a
+rfromList1N = coerce mfromList1N
+
+-- | If the elements are scalars, 'rfromListPrimLinear' is faster.
+rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
+rfromListLinear sh l = Ranked (mfromListLinear (shxFromShR sh) l)
+
+-- | Because the length of the list is unknown, its spine must be materialised
+-- in memory in order to compute its length. If its length is already known,
+-- use 'rfromList1PrimN' to be able to stream the list.
+rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a
+rfromList1Prim = coerce mfromList1Prim
+
+rfromList1PrimN :: PrimElt a => Int -> [a] -> Ranked 1 a
+rfromList1PrimN = coerce mfromList1PrimN
+
+rfromListPrimLinear :: forall n a. PrimElt a => IShR n -> [a] -> Ranked n a
+rfromListPrimLinear sh l = Ranked (mfromListPrimLinear (shxFromShR sh) l)
+
+rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
+rtoListOuter (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
+
+rtoList :: Elt a => Ranked 1 a -> [a]
+rtoList = map runScalar . rtoListOuter
+
+rtoListLinear :: Elt a => Ranked n a -> [a]
+rtoListLinear (Ranked arr) = mtoListLinear arr
+
+rtoListPrim :: PrimElt a => Ranked 1 a -> [a]
+rtoListPrim (Ranked arr) = mtoListPrim arr
+
+rtoListPrimLinear :: PrimElt a => Ranked n a -> [a]
+rtoListPrimLinear (Ranked arr) = mtoListPrimLinear arr
+
+rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a
+rfromOrthotope sn arr
+ | Refl <- lemRankReplicate sn
+ = let xarr = XArray arr
+ in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr))
+
+rtoOrthotope :: forall a n. PrimElt a => Ranked n a -> S.Array n a
+rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr)))
+ | Refl <- lemRankReplicate (shrRank $ shrFromShX @n sh)
+ = arr
+
+runScalar :: Elt a => Ranked 0 a -> a
+runScalar arr = rindex arr ZIR
+
+rnest :: forall n m a. Elt a => SNat n -> Ranked (n + m) a -> Ranked n (Ranked m a)
+rnest n arr
+ | Refl <- lemReplicatePlusApp n (Proxy @m) (Proxy @(Nothing @Nat))
+ = coerce (mnest (ssxFromSNat n) (coerce arr))
+
+runNest :: forall n m a. Elt a => Ranked n (Ranked m a) -> Ranked (n + m) a
+runNest rarr@(Ranked (M_Ranked (M_Nest _ arr)))
+ | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat))
+ = Ranked arr
+
+rzip :: (Elt a, Elt b) => Ranked n a -> Ranked n b -> Ranked n (a, b)
+rzip = coerce mzip
+
+runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b)
+runzip = coerce munzip
+
+rrerankPrimP :: forall n1 n2 n a b. (Storable a, Storable b)
+ => IShR n2
+ -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b))
+ -> Ranked n (Ranked n1 (Primitive a)) -> Ranked n (Ranked n2 (Primitive b))
+rrerankPrimP sh2 f (Ranked (M_Ranked arr))
+ = Ranked (M_Ranked (mrerankPrimP (shxFromShR sh2)
+ (\a -> let Ranked r = f (Ranked a) in r)
+ arr))
+
+-- | If there is a zero-sized dimension in the @n@-prefix of the shape of the
+-- input array, then there is no way to deduce the full shape of the output
+-- array (more precisely, the @n2@ part): that could only come from calling
+-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in
+-- this case; we choose to fill the @n2@ part of the output shape with zeros.
+--
+-- For example, if:
+--
+-- @
+-- arr :: Ranked 3 (Ranked 2 Int) -- outer array shape [3, 0, 4]; inner shape [2, 21]
+-- f :: Ranked 2 Int -> Ranked 3 Float
+-- @
+--
+-- then:
+--
+-- @
+-- rrerank _ f arr :: Ranked 3 (Ranked 3 Float)
+-- @
+--
+-- and the inner arrays of the result will have shape @[0, 0, 0]@. We don't
+-- know if @f@ intended to return an array with all-zero shape here (it
+-- probably didn't), but there is no better number to put here absent a
+-- subarray of the input to pass to @f@.
+rrerankPrim :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
+ => IShR n2
+ -> (Ranked n1 a -> Ranked n2 b)
+ -> Ranked n (Ranked n1 a) -> Ranked n (Ranked n2 b)
+rrerankPrim sh2 f (Ranked (M_Ranked arr)) =
+ Ranked (M_Ranked (mrerankPrim (shxFromShR sh2)
+ (\a -> let Ranked r = f (Ranked a) in r)
+ arr))
+
+rreplicate :: forall n m a. Elt a
+ => IShR n -> Ranked m a -> Ranked (n + m) a
+rreplicate sh (Ranked arr)
+ | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat))
+ = Ranked (mreplicate (shxFromShR sh) arr)
+
+rreplicatePrimP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
+rreplicatePrimP sh x
+ | Dict <- lemKnownReplicate (shrRank sh)
+ = Ranked (mreplicatePrimP (shxFromShR sh) x)
+
+rreplicatePrim :: forall n a. PrimElt a
+ => IShR n -> a -> Ranked n a
+rreplicatePrim sh x = rfromPrimitive (rreplicatePrimP sh x)
+
+rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
+rslice i n (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = Ranked (msliceN i n arr)
+
+rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
+rrev1 (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = Ranked (mrev1 arr)
+
+rreshape :: forall n n' a. Elt a
+ => IShR n' -> Ranked n a -> Ranked n' a
+rreshape sh' rarr@(Ranked arr)
+ | Dict <- lemKnownReplicate (rrank rarr)
+ , Dict <- lemKnownReplicate (shrRank sh')
+ = Ranked (mreshape (shxFromShR sh') arr)
+
+rflatten :: Elt a => Ranked n a -> Ranked 1 a
+rflatten (Ranked arr) = mtoRanked (mflatten arr)
+
+riota :: (Enum a, PrimElt a) => Int -> Ranked 1 a
+riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota
+
+-- | Throws if the array is empty.
+rminIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
+rminIndexPrim rarr@(Ranked arr)
+ | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
+ = ixrFromIxX (mminIndexPrim arr)
+
+-- | Throws if the array is empty.
+rmaxIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
+rmaxIndexPrim rarr@(Ranked arr)
+ | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
+ = ixrFromIxX (mmaxIndexPrim arr)
+
+{-# INLINEABLE rdot1Inner #-}
+rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a
+rdot1Inner arr1 arr2
+ | SNat <- rrank arr1
+ , Refl <- lemReplicatePlusApp (SNat @n) (Proxy @1) (Proxy @(Nothing @Nat))
+ = coerce (mdot1Inner (Proxy @(Nothing @Nat))) arr1 arr2
+
+-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
+-- Prefer 'rdot1Inner' if applicable.
+{-# INLINE rdot #-}
+rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a
+rdot = coerce mdot
+
+rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
+rtoXArrayPrimP (Ranked arr) = first shrFromShX (mtoXArrayPrimP arr)
+
+rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a)
+rtoXArrayPrim (Ranked arr) = first shrFromShX (mtoXArrayPrim arr)
+
+rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a)
+rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr)
+
+rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a
+rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr)
+
+rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a
+rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr)
+
+rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a)
+rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr)
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
new file mode 100644
index 0000000..beedbcf
--- /dev/null
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -0,0 +1,271 @@
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_HADDOCK not-home #-}
+module Data.Array.Nested.Ranked.Base where
+
+import Prelude hiding (mappend, mconcat)
+
+import Control.DeepSeq (NFData(..))
+import Control.Monad.ST
+import Data.Bifunctor (first)
+import Data.Coerce (coerce)
+import Data.Kind (Type)
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Foreign.Storable (Storable)
+import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
+import GHC.Generics (Generic)
+import GHC.TypeLits
+
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
+import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray(..))
+
+
+-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
+-- represented on the type level as a 'Nat'.
+--
+-- Valid elements of a ranked arrays are described by the 'Elt' type class.
+-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
+-- supported (and are represented as a single, flattened, struct-of-arrays
+-- array internally).
+--
+-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
+type Ranked :: Nat -> Type -> Type
+newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)
+#endif
+deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a)
+deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a)
+
+#ifndef OXAR_DEFAULT_SHOW_INSTANCES
+instance (Show a, Elt a) => Show (Ranked n a) where
+ showsPrec d arr@(Ranked marr) =
+ let sh = show (shrToList (rshape arr))
+ in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr
+#endif
+
+instance Elt a => NFData (Ranked n a) where
+ rnf (Ranked arr) = rnf arr
+
+-- just unwrap the newtype and defer to the general instance for nested arrays
+newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
+ deriving (Generic)
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a))
+#endif
+
+deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a))
+
+newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
+
+-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
+-- these instances allow them to also be used as elements of arrays, thus
+-- making them first-class in the API.
+instance Elt a => Elt (Ranked n a) where
+ {-# INLINE mshape #-}
+ mshape (M_Ranked arr) = mshape arr
+ {-# INLINE mindex #-}
+ mindex (M_Ranked arr) i = Ranked (mindex arr i)
+
+ {-# INLINE mindexPartial #-}
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
+ mindexPartial (M_Ranked arr) i =
+ coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
+ mindexPartial arr i
+
+ mscalar (Ranked x) = M_Ranked (M_Nest ZSX x)
+
+ mfromListOuterSN :: SNat m -> NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Just m : sh) (Ranked n a)
+ mfromListOuterSN sn l = M_Ranked (mfromListOuterSN sn (coerce l))
+
+ mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
+ mtoListOuter (M_Ranked arr) =
+ coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr)
+
+ {-# INLINE mlift #-}
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
+ mlift ssh2 f (M_Ranked arr) =
+ coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
+ mlift ssh2 f arr
+
+ {-# INLINE mlift2 #-}
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a)
+ mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) =
+ coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
+ mlift2 ssh3 f arr1 arr2
+
+ {-# INLINE mliftL #-}
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
+ -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a))
+ mliftL ssh2 f l =
+ coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a)))
+ @(NonEmpty (Mixed sh2 (Ranked n a))) $
+ mliftL ssh2 f (coerce l)
+
+ mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr)
+
+ mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr)
+
+ mconcat l = M_Ranked (mconcat (coerce l))
+
+ mrnf (M_Ranked arr) = mrnf arr
+
+ type ShapeTree (Ranked n a) = (IShR n, ShapeTree a)
+
+ mshapeTree (Ranked arr) = first coerce (mshapeTree arr)
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+
+ mshapeTreeIsEmpty _ (sh, t) = shrSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
+ marrayStrides (M_Ranked arr) = marrayStrides arr
+
+ mvecsWriteLinear :: forall sh s. Int -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
+ mvecsWriteLinear idx (Ranked arr) vecs =
+ mvecsWriteLinear idx arr
+ (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsWritePartialLinear
+ :: forall sh sh' s.
+ Proxy sh -> Int -> Mixed sh' (Ranked n a)
+ -> MixedVecs s (sh ++ sh') (Ranked n a)
+ -> ST s ()
+ mvecsWritePartialLinear proxy idx arr vecs =
+ mvecsWritePartialLinear proxy idx
+ (coerce @(Mixed sh' (Ranked n a))
+ @(Mixed sh' (Mixed (Replicate n Nothing) a))
+ arr)
+ (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
+ @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
+ mvecsFreeze sh vecs =
+ coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
+ @(Mixed sh (Ranked n a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh (Ranked n a))
+ @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+ mvecsUnsafeFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
+ mvecsUnsafeFreeze sh vecs =
+ coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
+ @(Mixed sh (Ranked n a))
+ <$> mvecsUnsafeFreeze sh
+ (coerce @(MixedVecs s sh (Ranked n a))
+ @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
+ memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a)
+ memptyArrayUnsafe sh
+ | Dict <- lemKnownReplicate (SNat @n)
+ = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
+ memptyArrayUnsafe sh
+
+ mvecsUnsafeNew idx (Ranked arr)
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsUnsafeNew idx arr
+
+ mvecsReplicate idx (Ranked arr)
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsReplicate idx arr
+
+ mvecsNewEmpty _
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
+
+
+liftRanked1 :: forall n a b.
+ (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b)
+ -> Ranked n a -> Ranked n b
+liftRanked1 = coerce
+
+liftRanked2 :: forall n a b c.
+ (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c)
+ -> Ranked n a -> Ranked n b -> Ranked n c
+liftRanked2 = coerce
+
+instance (NumElt a, PrimElt a) => Num (Ranked n a) where
+ (+) = liftRanked2 (+)
+ (-) = liftRanked2 (-)
+ (*) = liftRanked2 (*)
+ negate = liftRanked1 negate
+ abs = liftRanked1 abs
+ signum = liftRanked1 signum
+ fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicatePrim"
+
+instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where
+ fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicatePrim"
+ recip = liftRanked1 recip
+ (/) = liftRanked2 (/)
+
+instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where
+ pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicatePrim"
+ exp = liftRanked1 exp
+ log = liftRanked1 log
+ sqrt = liftRanked1 sqrt
+ (**) = liftRanked2 (**)
+ logBase = liftRanked2 logBase
+ sin = liftRanked1 sin
+ cos = liftRanked1 cos
+ tan = liftRanked1 tan
+ asin = liftRanked1 asin
+ acos = liftRanked1 acos
+ atan = liftRanked1 atan
+ sinh = liftRanked1 sinh
+ cosh = liftRanked1 cosh
+ tanh = liftRanked1 tanh
+ asinh = liftRanked1 asinh
+ acosh = liftRanked1 acosh
+ atanh = liftRanked1 atanh
+ log1p = liftRanked1 GHC.Float.log1p
+ expm1 = liftRanked1 GHC.Float.expm1
+ log1pexp = liftRanked1 GHC.Float.log1pexp
+ log1mexp = liftRanked1 GHC.Float.log1mexp
+
+rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
+rquotArray = liftRanked2 mquotArray
+rremArray = liftRanked2 mremArray
+
+ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
+ratan2Array = liftRanked2 matan2Array
+
+
+{-# INLINE rshape #-}
+rshape :: Elt a => Ranked n a -> IShR n
+rshape (Ranked arr) = coerce (mshape arr)
+
+rrank :: Elt a => Ranked n a -> SNat n
+rrank = shrRank . rshape
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index 1c0b9eb..c04f39e 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -1,8 +1,5 @@
+{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveFoldable #-}
-{-# LANGUAGE DeriveFunctor #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
@@ -10,16 +7,19 @@
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
+#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
+{-# LANGUAGE TypeAbstractions #-}
+#endif
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
@@ -27,337 +27,401 @@
module Data.Array.Nested.Ranked.Shape where
import Control.DeepSeq (NFData(..))
-import Data.Array.Mixed.Types
+import Control.Exception (assert)
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
import Data.Kind (Type)
import Data.Proxy
import Data.Type.Equality
-import GHC.Generics (Generic)
+import GHC.Exts (build)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
import GHC.TypeNats qualified as TN
-import Data.Array.Mixed.Lemmas
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed.ListX
import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Types
-type role ListR nominal representational
-type ListR :: Nat -> Type -> Type
-data ListR n i where
- ZR :: ListR 0 i
- (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
-deriving instance Eq i => Eq (ListR n i)
-deriving instance Ord i => Ord (ListR n i)
-deriving instance Functor (ListR n)
-deriving instance Foldable (ListR n)
-infixr 3 :::
-
-instance Show i => Show (ListR n i) where
- showsPrec _ = listrShow shows
-
-instance NFData i => NFData (ListR n i) where
- rnf ZR = ()
- rnf (x ::: l) = rnf x `seq` rnf l
-
-data UnconsListRRes i n1 =
- forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
-listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
-listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)
-listrUncons ZR = Nothing
-
--- | This checks only whether the ranks are equal, not whether the actual
--- values are.
-listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n')
-listrEqRank ZR ZR = Just Refl
-listrEqRank (_ ::: sh) (_ ::: sh')
- | Just Refl <- listrEqRank sh sh'
- = Just Refl
-listrEqRank _ _ = Nothing
-
--- | This compares the lists for value equality.
-listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n')
-listrEqual ZR ZR = Just Refl
-listrEqual (i ::: sh) (j ::: sh')
- | Just Refl <- listrEqual sh sh'
- , i == j
- = Just Refl
-listrEqual _ _ = Nothing
-
-listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS
-listrShow f l = showString "[" . go "" l . showString "]"
- where
- go :: String -> ListR n' i -> ShowS
- go _ ZR = id
- go prefix (x ::: xs) = showString prefix . f x . go "," xs
-
-listrLength :: ListR n i -> Int
-listrLength = length
-
-listrRank :: ListR n i -> SNat n
-listrRank ZR = SNat
-listrRank (_ ::: sh) = snatSucc (listrRank sh)
-
-listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
-listrAppend ZR sh = sh
-listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
-
-listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
-listrFromList [] k = k ZR
-listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
-
-listrHead :: ListR (n + 1) i -> i
-listrHead (i ::: _) = i
-listrHead ZR = error "unreachable"
-
-listrTail :: ListR (n + 1) i -> ListR n i
-listrTail (_ ::: sh) = sh
-listrTail ZR = error "unreachable"
-
-listrInit :: ListR (n + 1) i -> ListR n i
-listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh
-listrInit (_ ::: ZR) = ZR
-listrInit ZR = error "unreachable"
-
-listrLast :: ListR (n + 1) i -> i
-listrLast (_ ::: sh@(_ ::: _)) = listrLast sh
-listrLast (n ::: ZR) = n
-listrLast ZR = error "unreachable"
-
-listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
-listrIndex SZ (x ::: _) = x
-listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
-listrIndex _ ZR = error "k + 1 <= 0"
-
-listrZip :: ListR n i -> ListR n j -> ListR n (i, j)
-listrZip ZR ZR = ZR
-listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest
-listrZip _ _ = error "listrZip: impossible pattern needlessly required"
-
-listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k
-listrZipWith _ ZR ZR = ZR
-listrZipWith f (i ::: irest) (j ::: jrest) =
- f i j ::: listrZipWith f irest jrest
-listrZipWith _ _ _ =
- error "listrZipWith: impossible pattern needlessly required"
-
-listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
-listrPermutePrefix = \perm sh ->
- listrFromList perm $ \sperm ->
- case (listrRank sperm, listrRank sh) of
- (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
- LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
- ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
- where
- listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
- listrSplitAt SZ sh = (ZR, sh)
- listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
- listrSplitAt SS{} ZR = error "m' + 1 <= 0"
-
- applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
- applyPermRFull _ ZR _ = ZR
- applyPermRFull sm@SNat (i ::: perm) l =
- TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
- case cmpNat (SNat @(idx + 1)) sm of
- LTI -> listrIndex si l ::: applyPermRFull sm perm l
- EQI -> listrIndex si l ::: applyPermRFull sm perm l
- GTI -> error "listrPermutePrefix: Index in permutation out of range"
-
+-- * Ranked indices
-- | An index into a rank-typed array.
type role IxR nominal representational
type IxR :: Nat -> Type -> Type
-newtype IxR n i = IxR (ListR n i)
- deriving (Eq, Ord, Generic)
- deriving newtype (Functor, Foldable)
+newtype IxR n i = IxR (IxX (Replicate n Nothing) i)
+ deriving (Eq, Ord, NFData, Functor, Foldable)
pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
-pattern ZIR = IxR ZR
+pattern ZIR <- IxR (matchZIX @n -> Just Refl)
+ where ZIR = IxR ZIX
+
+matchZIX :: forall n i. IxX (Replicate n Nothing) i -> Maybe (n :~: 0)
+matchZIX ZIX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl
+matchZIX _ = Nothing
pattern (:.:)
:: forall {n1} {i}.
forall n. (n + 1 ~ n1)
=> i -> IxR n i -> IxR n1 i
-pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i))
- where i :.: IxR sh = IxR (i ::: sh)
+pattern i :.: l <- (ixrUncons -> Just (UnconsIxRRes i l))
+ where i :.: IxR l | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = IxR (i :.% l)
infixr 3 :.:
+data UnconsIxRRes i n1 =
+ forall n. (n + 1 ~ n1) => UnconsIxRRes i (IxR n i)
+ixrUncons :: forall n1 i. IxR n1 i -> Maybe (UnconsIxRRes i n1)
+ixrUncons (IxR ((:.%) @n @sh i l))
+ | Refl <- lemReplicateHead (Proxy @n) (Proxy @sh) (Proxy @Nothing) (Proxy @n1) Refl
+ , Refl <- lemReplicateCons (Proxy @sh) (Proxy @n1) Refl
+ , Refl <- lemReplicateCons2 (Proxy @sh) (Proxy @n1) Refl =
+ Just (UnconsIxRRes i (IxR @(Rank sh) l))
+ixrUncons (IxR _) = Nothing
+
{-# COMPLETE ZIR, (:.:) #-}
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
type IIxR n = IxR n Int
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (IxR n i)
+#else
instance Show i => Show (IxR n i) where
- showsPrec _ (IxR l) = listrShow shows l
+ showsPrec _ = ixrShow shows
+#endif
-instance NFData i => NFData (IxR sh i)
+-- | This checks only whether the ranks are equal, not whether the actual
+-- values are.
+ixrEqRank :: IxR n i -> IxR n' i -> Maybe (n :~: n')
+ixrEqRank ZIR ZIR = Just Refl
+ixrEqRank (_ :.: sh) (_ :.: sh')
+ | Just Refl <- ixrEqRank sh sh'
+ = Just Refl
+ixrEqRank _ _ = Nothing
+
+-- | This compares the lists for value equality.
+ixrEqual :: Eq i => IxR n i -> IxR n' i -> Maybe (n :~: n')
+ixrEqual ZIR ZIR = Just Refl
+ixrEqual (i :.: sh) (j :.: sh')
+ | Just Refl <- ixrEqual sh sh'
+ , i == j
+ = Just Refl
+ixrEqual _ _ = Nothing
-ixrLength :: IxR sh i -> Int
-ixrLength (IxR l) = listrLength l
+{-# INLINE ixrShow #-}
+ixrShow :: forall n i. (i -> ShowS) -> IxR n i -> ShowS
+ixrShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> IxR n' i -> ShowS
+ go _ ZIR = id
+ go prefix (x :.: xs) = showString prefix . f x . go "," xs
ixrRank :: IxR n i -> SNat n
-ixrRank (IxR sh) = listrRank sh
+ixrRank ZIR = SNat
+ixrRank (_ :.: sh) = snatSucc (ixrRank sh)
ixrZero :: SNat n -> IIxR n
ixrZero SZ = ZIR
ixrZero (SS n) = 0 :.: ixrZero n
-ixCvtXR :: IIxX sh -> IIxR (Rank sh)
-ixCvtXR ZIX = ZIR
-ixCvtXR (n :.% idx) = n :.: ixCvtXR idx
-
-ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
-ixCvtRX ZIR = ZIX
-ixCvtRX (n :.: (idx :: IxR m Int)) =
- castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
- (n :.% ixCvtRX idx)
+{-# INLINE ixrFromList #-}
+ixrFromList :: SNat n -> [i] -> IxR n i
+ixrFromList topsn topl = assert (fromSNat' topsn == length topl)
+ $ IxR $ IsList.fromList topl
ixrHead :: IxR (n + 1) i -> i
-ixrHead (IxR list) = listrHead list
+ixrHead (i :.: _) = i
ixrTail :: IxR (n + 1) i -> IxR n i
-ixrTail (IxR list) = IxR (listrTail list)
+ixrTail (_ :.: sh) = sh
ixrInit :: IxR (n + 1) i -> IxR n i
-ixrInit (IxR list) = IxR (listrInit list)
+ixrInit (n :.: sh@(_ :.: _)) = n :.: ixrInit sh
+ixrInit (_ :.: ZIR) = ZIR
ixrLast :: IxR (n + 1) i -> i
-ixrLast (IxR list) = listrLast list
+ixrLast (_ :.: sh@(_ :.: _)) = ixrLast sh
+ixrLast (n :.: ZIR) = n
+-- | Performs a runtime check that the lengths are identical.
+ixrCast :: SNat n' -> IxR n i -> IxR n' i
+ixrCast SZ ZIR = ZIR
+ixrCast (SS n) (i :.: l) = i :.: ixrCast n l
+ixrCast _ _ = error "ixrCast: ranks don't match"
+
+-- lemReplicatePlusApp requires SNat that would cause overhead (not benchmarked)
ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
-ixrAppend = coerce (listrAppend @_ @i)
+ixrAppend = gcastWith (unsafeCoerceRefl :: Replicate (n + m) (Nothing @Nat) :~: Replicate n Nothing ++ Replicate m Nothing) $
+ coerce (ixxAppend @_ @_ @i)
+
+ixrIndex :: forall k n i. (k + 1 <= n) => SNat k -> IxR n i -> i
+ixrIndex SZ (x :.: _) = x
+ixrIndex (SS i) (_ :.: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = ixrIndex i xs
+ixrIndex _ ZIR = error "k + 1 <= 0"
ixrZip :: IxR n i -> IxR n j -> IxR n (i, j)
-ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2
+ixrZip ZIR ZIR = ZIR
+ixrZip (i :.: irest) (j :.: jrest) = (i, j) :.: ixrZip irest jrest
+ixrZip _ _ = error "ixrZip: impossible pattern needlessly required"
+{-# INLINE ixrZipWith #-}
ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
-ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
+ixrZipWith _ ZIR ZIR = ZIR
+ixrZipWith f (i :.: irest) (j :.: jrest) =
+ f i j :.: ixrZipWith f irest jrest
+ixrZipWith _ _ _ =
+ error "ixrZipWith: impossible pattern needlessly required"
+
+ixrSplitAt :: m <= n' => SNat m -> IxR n' i -> (IxR m i, IxR (n' - m) i)
+ixrSplitAt SZ sh = (ZIR, sh)
+ixrSplitAt (SS m) (n :.: sh) = (\(pre, post) -> (n :.: pre, post)) (ixrSplitAt m sh)
+ixrSplitAt SS{} ZIR = error "m' + 1 <= 0"
+
+ixrPermutePrefix :: forall n i. PermR -> IxR n i -> IxR n i
+ixrPermutePrefix = \perm sh ->
+ TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
+ case ixrRank sh of { shlen@SNat ->
+ let sperm = ixrFromList permlen perm in
+ case cmpNat permlen shlen of
+ LTI -> let (pre, post) = ixrSplitAt permlen sh in ixrAppend (applyPermRFull permlen sperm pre) post
+ EQI -> let (pre, post) = ixrSplitAt permlen sh in ixrAppend (applyPermRFull permlen sperm pre) post
+ GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
+ ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ }
+ where
+ applyPermRFull :: SNat m -> IxR k Int -> IxR m i -> IxR k i
+ applyPermRFull _ ZIR _ = ZIR
+ applyPermRFull sm@SNat (i :.: perm) l =
+ TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
+ case cmpNat (SNat @(idx + 1)) sm of
+ LTI -> ixrIndex si l :.: applyPermRFull sm perm l
+ EQI -> ixrIndex si l :.: applyPermRFull sm perm l
+ GTI -> error "ixrPermutePrefix: Index in permutation out of range"
-ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
-ixrPermutePrefix = coerce (listrPermutePrefix @i)
+-- | Given a multidimensional index, get the corresponding linear
+-- index into the buffer.
+{-# INLINEABLE ixrToLinear #-}
+ixrToLinear :: Num i => IShR m -> IxR m i -> i
+ixrToLinear (ShR sh) ix = ixxToLinear sh (ixxFromIxR ix)
+ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i
+ixxFromIxR = coerce
+
+{-# INLINEABLE ixrFromLinear #-}
+ixrFromLinear :: forall i m. Num i => IShR m -> Int -> IxR m i
+ixrFromLinear (ShR sh) i
+ | Refl <- lemRankReplicate (Proxy @m)
+ = ixrFromIxX $ ixxFromLinear sh i
+
+ixrFromIxX :: IxX (Replicate n Nothing) i -> IxR n i
+ixrFromIxX = coerce
+
+shrEnum :: IShR n -> [IIxR n]
+shrEnum = shrEnum'
+
+{-# INLINABLE shrEnum' #-} -- ensure this can be specialised at use site
+shrEnum' :: forall i n. Num i => IShR n -> [IxR n i]
+shrEnum' (ShR sh)
+ | Refl <- lemRankReplicate (Proxy @n)
+ = (coerce :: [IxX (Replicate n Nothing) i] -> [IxR n i]) $ shxEnum' sh
+
+-- * Ranked shapes
type role ShR nominal representational
type ShR :: Nat -> Type -> Type
-newtype ShR n i = ShR (ListR n i)
- deriving (Eq, Ord, Generic)
- deriving newtype (Functor, Foldable)
+newtype ShR n i = ShR (ShX (Replicate n Nothing) i)
+ deriving (Eq, Ord, NFData, Functor)
pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
-pattern ZSR = ShR ZR
+pattern ZSR <- ShR (matchZSR @n -> Just Refl)
+ where ZSR = ShR ZSX
+
+matchZSR :: forall n i. ShX (Replicate n Nothing) i -> Maybe (n :~: 0)
+matchZSR ZSX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl
+matchZSR _ = Nothing
pattern (:$:)
:: forall {n1} {i}.
forall n. (n + 1 ~ n1)
=> i -> ShR n i -> ShR n1 i
-pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
- where i :$: ShR sh = ShR (i ::: sh)
+pattern i :$: sh <- (shrUncons -> Just (UnconsShRRes i sh))
+ where i :$: ShR sh | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ShR (SUnknown i :$% sh)
infixr 3 :$:
+data UnconsShRRes i n1 =
+ forall n. (n + 1 ~ n1) => UnconsShRRes i (ShR n i)
+shrUncons :: forall n1 i. ShR n1 i -> Maybe (UnconsShRRes i n1)
+shrUncons (ShR (SUnknown x :$% (sh' :: ShX sh' i)))
+ | Refl <- lemReplicateCons (Proxy @sh') (Proxy @n1) Refl
+ , Refl <- lemReplicateCons2 (Proxy @sh') (Proxy @n1) Refl
+ = Just (UnconsShRRes x (ShR sh'))
+shrUncons (ShR _) = Nothing
+
{-# COMPLETE ZSR, (:$:) #-}
type IShR n = ShR n Int
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (ShR n i)
+#else
instance Show i => Show (ShR n i) where
- showsPrec _ (ShR l) = listrShow shows l
-
-instance NFData i => NFData (ShR sh i)
-
-shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n
-shCvtXR' ZSX =
- castWith (subst2 (unsafeCoerceRefl :: 0 :~: n))
- ZSR
-shCvtXR' (n :$% (idx :: IShX sh))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =
- castWith (subst2 (lem1 @sh Refl))
- (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
- where
- lem1 :: forall sh' n' k.
- k : sh' :~: Replicate n' Nothing
- -> Rank sh' + 1 :~: n'
- lem1 Refl = unsafeCoerceRefl
-
- lem2 :: k : sh :~: Replicate n Nothing
- -> sh :~: Replicate (Rank sh) Nothing
- lem2 Refl = unsafeCoerceRefl
-
-shCvtRX :: IShR n -> IShX (Replicate n Nothing)
-shCvtRX ZSR = ZSX
-shCvtRX (n :$: (idx :: ShR m Int)) =
- castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
- (SUnknown n :$% shCvtRX idx)
+ showsPrec d (ShR l) = showsPrec d l
+#endif
-- | This checks only whether the ranks are equal, not whether the actual
-- values are.
shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
-shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh'
+shrEqRank ZSR ZSR = Just Refl
+shrEqRank (_ :$: sh) (_ :$: sh')
+ | Just Refl <- shrEqRank sh sh'
+ = Just Refl
+shrEqRank _ _ = Nothing
-- | This compares the shapes for value equality.
shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n')
-shrEqual (ShR sh) (ShR sh') = listrEqual sh sh'
+shrEqual ZSR ZSR = Just Refl
+shrEqual (i :$: sh) (i' :$: sh')
+ | Just Refl <- shrEqual sh sh'
+ , i == i'
+ = Just Refl
+shrEqual _ _ = Nothing
shrLength :: ShR sh i -> Int
-shrLength (ShR l) = listrLength l
+shrLength (ShR l) = shxLength l
-- | This function can also be used to conjure up a 'KnownNat' dictionary;
-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern
-- synonym yields 'KnownNat' evidence.
-shrRank :: ShR n i -> SNat n
-shrRank (ShR sh) = listrRank sh
+shrRank :: forall n i. ShR n i -> SNat n
+shrRank (ShR sh) | Refl <- lemRankReplicate (Proxy @n) = shxRank sh
-- | The number of elements in an array described by this shape.
shrSize :: IShR n -> Int
-shrSize ZSR = 1
-shrSize (n :$: sh) = n * shrSize sh
+shrSize (ShR sh) = shxSize sh
-shrHead :: ShR (n + 1) i -> i
-shrHead (ShR list) = listrHead list
+-- This is equivalent to but faster than @coerce (shxFromList (ssxReplicate snat))@.
+-- We don't report the size of the list in case of errors in order not to retain the list.
+{-# INLINEABLE shrFromList #-}
+shrFromList :: SNat n -> [Int] -> IShR n
+shrFromList snat topl = ShR $ go snat topl
+ where
+ go :: SNat n -> [Int] -> ShX (Replicate n Nothing) Int
+ go SZ [] = ZSX
+ go SZ _ = error $ "shrFromList: List too long (type says " ++ show (fromSNat' snat) ++ ")"
+ go (SS sn :: SNat n1) (i : is) | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ConsUnknown i (go sn is)
+ go _ _ = error $ "shrFromList: List too short (type says " ++ show (fromSNat' snat) ++ ")"
-shrTail :: ShR (n + 1) i -> ShR n i
-shrTail (ShR list) = ShR (listrTail list)
+-- This is equivalent to but faster than @coerce shxToList@.
+{-# INLINEABLE shrToList #-}
+shrToList :: IShR n -> [Int]
+shrToList (ShR l) = build (\(cons :: i -> is -> is) (nil :: is) ->
+ let go :: ShX sh Int -> is
+ go ZSX = nil
+ go (ConsUnknown i rest) = i `cons` go rest
+ go ConsKnown{} = error "shrToList: impossible case"
+ in go l)
-shrInit :: ShR (n + 1) i -> ShR n i
-shrInit (ShR list) = ShR (listrInit list)
+shrHead :: forall n i. ShR (n + 1) i -> i
+shrHead (ShR sh)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = case shxHead @Nothing @(Replicate n Nothing) sh of
+ SUnknown i -> i
-shrLast :: ShR (n + 1) i -> i
-shrLast (ShR list) = listrLast list
+shrTail :: forall n i. ShR (n + 1) i -> ShR n i
+shrTail
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = coerce (shxTail @_ @_ @i)
-shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
-shrAppend = coerce (listrAppend @_ @i)
+{-# INLINEABLE shrTakeIx #-}
+shrTakeIx :: forall n n' i j. Proxy n' -> IxR n j -> ShR (n + n') i -> ShR n i
+shrTakeIx _ ZIR _ = ZSR
+shrTakeIx p (_ :.: idx) sh = case sh of n :$: sh' -> n :$: shrTakeIx p idx sh'
+
+{-# INLINEABLE shrDropIx #-}
+shrDropIx :: forall n n' i j. IxR n j -> ShR (n + n') i -> ShR n' i
+shrDropIx ZIR long = long
+shrDropIx (_ :.: short) long = case long of _ :$: long' -> shrDropIx short long'
+
+shrInit :: forall n i. ShR (n + 1) i -> ShR n i
+shrInit
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = -- TODO: change this and all other unsafeCoerceRefl to lemmas:
+ gcastWith (unsafeCoerceRefl
+ :: Init (Replicate (n + 1) (Nothing @Nat)) :~: Replicate n Nothing) $
+ coerce (shxInit @i)
+
+shrLast :: forall n i. ShR (n + 1) i -> i
+shrLast (ShR sh)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = case shxLast sh of
+ SUnknown i -> i
+ SKnown{} -> error "shrLast: impossible SKnown"
-shrZip :: ShR n i -> ShR n j -> ShR n (i, j)
-shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2
+-- | Performs a runtime check that the lengths are identical.
+shrCast :: SNat n' -> ShR n i -> ShR n' i
+shrCast SZ ZSR = ZSR
+shrCast (SS n) (i :$: sh) = i :$: shrCast n sh
+shrCast _ _ = error "shrCast: ranks don't match"
+shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
+shrAppend =
+ -- lemReplicatePlusApp requires an SNat
+ gcastWith (unsafeCoerceRefl
+ :: Replicate n (Nothing @Nat) ++ Replicate m Nothing :~: Replicate (n + m) Nothing) $
+ coerce (shxAppend @_ @i)
+
+{-# INLINE shrZipWith #-}
shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k
-shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2
+shrZipWith _ ZSR ZSR = ZSR
+shrZipWith f (i :$: irest) (j :$: jrest) =
+ f i j :$: shrZipWith f irest jrest
+shrZipWith _ _ _ =
+ error "shrZipWith: impossible pattern needlessly required"
-shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
-shrPermutePrefix = coerce (listrPermutePrefix @i)
+shrSplitAt :: m <= n' => SNat m -> ShR n' i -> (ShR m i, ShR (n' - m) i)
+shrSplitAt SZ sh = (ZSR, sh)
+shrSplitAt (SS m) (n :$: sh) = (\(pre, post) -> (n :$: pre, post)) (shrSplitAt m sh)
+shrSplitAt SS{} ZSR = error "m' + 1 <= 0"
+shrIndex :: forall k sh i. SNat k -> ShR sh i -> i
+shrIndex k (ShR sh) = case shxIndex @i k sh of
+ SUnknown i -> i
+ SKnown{} -> error "shrIndex: impossible SKnown"
+
+-- Copy-pasted from ixrPermutePrefix, probably unavoidably.
+shrPermutePrefix :: forall i n. PermR -> ShR n i -> ShR n i
+shrPermutePrefix = \perm sh ->
+ TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
+ case shrRank sh of { shlen@SNat ->
+ let sperm = shrFromList permlen perm in
+ case cmpNat permlen shlen of
+ LTI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post
+ EQI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post
+ GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
+ ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ }
+ where
+ applyPermRFull :: SNat m -> ShR k Int -> ShR m i -> ShR k i
+ applyPermRFull _ ZSR _ = ZSR
+ applyPermRFull sm@SNat (i :$: perm) l =
+ TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
+ case cmpNat (SNat @(idx + 1)) sm of
+ LTI -> shrIndex si l :$: applyPermRFull sm perm l
+ EQI -> shrIndex si l :$: applyPermRFull sm perm l
+ GTI -> error "shrPermutePrefix: Index in permutation out of range"
--- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (ListR n i) where
- type Item (ListR n i) = i
- fromList topl = go (SNat @n) topl
- where
- go :: SNat n' -> [i] -> ListR n' i
- go SZ [] = ZR
- go (SS n) (i : is) = i ::: go n is
- go _ _ = error $ "IsList(ListR): Mismatched list length (type says "
- ++ show (fromSNat (SNat @n)) ++ ", list has length "
- ++ show (length topl) ++ ")"
- toList = Foldable.toList
-- | Untyped: length is checked at runtime.
instance KnownNat n => IsList (IxR n i) where
type Item (IxR n i) = i
- fromList = IxR . IsList.fromList
+ fromList = ixrFromList (SNat @n)
toList = Foldable.toList
-- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (ShR n i) where
- type Item (ShR n i) = i
- fromList = ShR . IsList.fromList
- toList = Foldable.toList
+instance KnownNat n => IsList (IShR n) where
+ type Item (IShR n) = Int
+ fromList = shrFromList (SNat @n)
+ toList = shrToList
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
new file mode 100644
index 0000000..be1bfc5
--- /dev/null
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -0,0 +1,297 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+module Data.Array.Nested.Shaped (
+ Shaped(Shaped),
+ squotArray, sremArray, satan2Array,
+ sshape,
+ module Data.Array.Nested.Shaped,
+ liftShaped1, liftShaped2,
+) where
+
+import Prelude hiding (mappend)
+
+import Data.Array.Internal.RankedG qualified as RG
+import Data.Array.Internal.RankedS qualified as RS
+import Data.Array.Internal.ShapedG qualified as SG
+import Data.Array.Internal.ShapedS qualified as SS
+import Data.Bifunctor (first)
+import Data.Coerce (coerce)
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Data.Type.Equality
+import Data.Vector.Storable qualified as VS
+import Foreign.Storable (Storable)
+import GHC.TypeLits
+
+import Data.Array.Nested.Convert
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Shaped.Base
+import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
+import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray)
+import Data.Array.XArray qualified as X
+
+
+semptyArray :: forall sh a. KnownElt a => ShS sh -> Shaped (0 : sh) a
+semptyArray sh = Shaped (memptyArray (shxFromShS sh))
+
+srank :: Elt a => Shaped sh a -> SNat (Rank sh)
+srank = shsRank . sshape
+
+-- | The total number of elements in the array.
+ssize :: Elt a => Shaped sh a -> Int
+ssize = shsSize . sshape
+
+{-# INLINEABLE sindex #-}
+sindex :: Elt a => Shaped sh a -> IIxS sh -> a
+sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx)
+
+{-# INLINEABLE sindexPartial #-}
+sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
+sindexPartial sarr@(Shaped arr) idx =
+ Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
+ (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) idx (sshape sarr)) (Proxy @sh2))) arr)
+ (ixxFromIxS idx))
+
+-- | __WARNING__: All values returned from the function must have equal shape.
+-- See the documentation of 'mgenerate' for more details.
+sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
+sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX))
+
+-- | See 'mgeneratePrim'.
+{-# INLINE sgeneratePrim #-}
+sgeneratePrim :: forall sh a i. (PrimElt a, Num i)
+ => ShS sh -> (IxS sh i -> a) -> Shaped sh a
+sgeneratePrim sh f =
+ let g i = f (ixsFromLinear sh i)
+ in sfromVector sh $ VS.generate (shsSize sh) g
+
+-- | See the documentation of 'mlift'.
+{-# INLINE slift #-}
+slift :: forall sh1 sh2 a. Elt a
+ => ShS sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
+ -> Shaped sh1 a -> Shaped sh2 a
+slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShX (shxFromShS sh2)) f arr)
+
+-- | See the documentation of 'mlift'.
+{-# INLINE slift2 #-}
+slift2 :: forall sh1 sh2 sh3 a. Elt a
+ => ShS sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)
+ -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
+slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2)
+
+{-# INLINE ssumOuter1PrimP #-}
+ssumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
+ => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
+ssumOuter1PrimP (Shaped arr) = Shaped (msumOuter1PrimP arr)
+
+{-# INLINEABLE ssumOuter1Prim #-}
+ssumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a)
+ => Shaped (n : sh) a -> Shaped sh a
+ssumOuter1Prim = sfromPrimitive . ssumOuter1PrimP . stoPrimitive
+
+{-# INLINE ssumAllPrimP #-}
+ssumAllPrimP :: (PrimElt a, NumElt a) => Shaped n (Primitive a) -> a
+ssumAllPrimP (Shaped arr) = msumAllPrimP arr
+
+{-# INLINE ssumAllPrim #-}
+ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a
+ssumAllPrim (Shaped arr) = msumAllPrim arr
+
+stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a)
+ => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a
+stranspose perm sarr@(Shaped arr)
+ | Refl <- lemRankMapJust (sshape sarr)
+ , Refl <- lemTakeLenMapJust perm (sshape sarr)
+ , Refl <- lemDropLenMapJust perm (sshape sarr)
+ , Refl <- lemPermuteMapJust perm (shsTakeLenPerm perm (sshape sarr))
+ , Refl <- lemMapJustApp (shsPermute perm (shsTakeLenPerm perm (sshape sarr))) (Proxy @(DropLen is sh))
+ = Shaped (mtranspose perm arr)
+
+sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
+sappend = coerce mappend
+
+sscalar :: Elt a => a -> Shaped '[] a
+sscalar x = Shaped (mscalar x)
+
+{-# INLINEABLE sfromVectorP #-}
+sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a)
+sfromVectorP sh v = Shaped (mfromVectorP (shxFromShS sh) v)
+
+{-# INLINEABLE sfromVector #-}
+sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a
+sfromVector sh v = sfromPrimitive (sfromVectorP sh v)
+
+{-# INLINEABLE stoVectorP #-}
+stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a
+stoVectorP = coerce mtoVectorP
+
+{-# INLINEABLE stoVector #-}
+stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
+stoVector = coerce mtoVector
+
+-- | All arrays in the list, even subarrays inside @a@, must have the same
+-- shape; if they do not, a runtime error will be thrown. See the
+-- documentation of 'mgenerate' for more information about this restriction.
+--
+-- If your array is 1-dimensional and contains scalars, use 'sfromList1Prim'.
+sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
+sfromListOuter = coerce mfromListOuterSN
+
+-- | If the elements are scalars, 'sfromList1Prim' is faster.
+sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
+sfromList1 = coerce mfromList1SN
+
+-- | If the elements are scalars, 'sfromListPrimLinear' is faster.
+sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a
+sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l)
+
+sfromList1Prim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
+sfromList1Prim = coerce mfromList1PrimSN
+
+sfromListPrimLinear :: forall sh a. PrimElt a => ShS sh -> [a] -> Shaped sh a
+sfromListPrimLinear sh l = Shaped (mfromListPrimLinear (shxFromShS sh) l)
+
+stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
+stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
+
+stoList :: Elt a => Shaped '[n] a -> [a]
+stoList = map sunScalar . stoListOuter
+
+stoListLinear :: Elt a => Shaped sh a -> [a]
+stoListLinear (Shaped arr) = mtoListLinear arr
+
+stoListPrim :: PrimElt a => Shaped '[n] a -> [a]
+stoListPrim (Shaped arr) = mtoListPrim arr
+
+stoListPrimLinear :: PrimElt a => Shaped sh a -> [a]
+stoListPrimLinear (Shaped arr) = mtoListPrimLinear arr
+
+sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a
+sfromOrthotope sh (SS.A (SG.A arr)) =
+ Shaped (fromPrimitive (M_Primitive (shxFromShS sh) (X.XArray (RS.A (RG.A (shsToList sh) arr)))))
+
+stoOrthotope :: PrimElt a => Shaped sh a -> SS.Array sh a
+stoOrthotope (stoPrimitive -> Shaped (M_Primitive _ (X.XArray (RS.A (RG.A _ arr))))) = SS.A (SG.A arr)
+
+sunScalar :: Elt a => Shaped '[] a -> a
+sunScalar arr = sindex arr ZIS
+
+snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a)
+snest sh arr
+ | Refl <- lemMapJustApp sh (Proxy @sh')
+ = coerce (mnest (ssxFromShX (shxFromShS sh)) (coerce arr))
+
+sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a
+sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr)))
+ | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh')
+ = Shaped arr
+
+szip :: (Elt a, Elt b) => Shaped sh a -> Shaped sh b -> Shaped sh (a, b)
+szip = coerce mzip
+
+sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b)
+sunzip = coerce munzip
+
+srerankPrimP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
+ => ShS sh2
+ -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))
+ -> Shaped sh (Shaped sh1 (Primitive a)) -> Shaped sh (Shaped sh2 (Primitive b))
+srerankPrimP sh2 f (Shaped (M_Shaped arr))
+ = Shaped (M_Shaped (mrerankPrimP (shxFromShS sh2)
+ (\a -> let Shaped r = f (Shaped a) in r)
+ arr))
+
+-- | See the caveats at 'mrerankPrim'.
+srerankPrim :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
+ => ShS sh2
+ -> (Shaped sh1 a -> Shaped sh2 b)
+ -> Shaped sh (Shaped sh1 a) -> Shaped sh (Shaped sh2 b)
+srerankPrim sh2 f (Shaped (M_Shaped arr)) =
+ Shaped (M_Shaped (mrerankPrim (shxFromShS sh2)
+ (\a -> let Shaped r = f (Shaped a) in r)
+ arr))
+
+sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a
+sreplicate sh (Shaped arr)
+ | Refl <- lemMapJustApp sh (Proxy @sh')
+ = Shaped (mreplicate (shxFromShS sh) arr)
+
+sreplicatePrimP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
+sreplicatePrimP sh x = Shaped (mreplicatePrimP (shxFromShS sh) x)
+
+sreplicatePrim :: forall sh a. PrimElt a => ShS sh -> a -> Shaped sh a
+sreplicatePrim sh x = sfromPrimitive (sreplicatePrimP sh x)
+
+sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
+sslice i n arr =
+ let _ :$$ sh = sshape arr
+ in slift (n :$$ sh) (\_ -> X.slice i n) arr
+
+srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a
+srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr
+
+sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shaped sh' a
+sreshape sh' (Shaped arr) = Shaped (mreshape (shxFromShS sh') arr)
+
+sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a
+sflatten arr = sreshape (shsProduct (sshape arr) :$$ ZSS) arr
+
+siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
+siota sn = Shaped (miota sn)
+
+-- | Throws if the array is empty.
+sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
+sminIndexPrim (Shaped arr) = ixsFromIxX (mminIndexPrim arr)
+
+-- | Throws if the array is empty.
+smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
+smaxIndexPrim (Shaped arr) = ixsFromIxX (mmaxIndexPrim arr)
+
+{-# INLINEABLE sdot1Inner #-}
+sdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
+ => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a
+sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2)
+ | Refl <- lemInitApp (Proxy @sh) (Proxy @n)
+ , Refl <- lemLastApp (Proxy @sh) (Proxy @n)
+ = case sshape sarr1 of
+ _ :$$ _
+ | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n])
+ -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2)
+ _ -> error "unreachable"
+
+{-# INLINE sdot #-}
+-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
+-- Prefer 'sdot1Inner' if applicable.
+sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a
+sdot = coerce mdot
+
+stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
+stoXArrayPrimP (Shaped arr) = first shsFromShX (mtoXArrayPrimP arr)
+
+stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a)
+stoXArrayPrim (Shaped arr) = first shsFromShX (mtoXArrayPrim arr)
+
+sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a)
+sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShX (shxFromShS sh)) arr)
+
+sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a
+sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShX (shxFromShS sh)) arr)
+
+sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a
+sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr)
+
+stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a)
+stoPrimitive (Shaped arr) = Shaped (toPrimitive arr)
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
new file mode 100644
index 0000000..a5e6247
--- /dev/null
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -0,0 +1,266 @@
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_HADDOCK not-home #-}
+module Data.Array.Nested.Shaped.Base where
+
+import Prelude hiding (mappend, mconcat)
+
+import Control.DeepSeq (NFData(..))
+import Control.Monad.ST
+import Data.Bifunctor (first)
+import Data.Coerce (coerce)
+import Data.Kind (Type)
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Foreign.Storable (Storable)
+import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
+import GHC.Generics (Generic)
+import GHC.TypeLits
+
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
+import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray)
+
+
+-- | A shape-typed array: the full shape of the array (the sizes of its
+-- dimensions) is represented on the type level as a list of 'Nat's. Note that
+-- these are "GHC.TypeLits" naturals, because we do not need induction over
+-- them and we want very large arrays to be possible.
+--
+-- Like for 'Data.Array.Nested.Ranked.Base.Ranked',
+-- the valid elements are described by the 'Elt' type class,
+-- and 'Shaped' itself is again an instance of 'Elt' as well.
+--
+-- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
+type Shaped :: [Nat] -> Type -> Type
+newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a)
+#endif
+deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a)
+deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a)
+
+#ifndef OXAR_DEFAULT_SHOW_INSTANCES
+instance (Show a, Elt a) => Show (Shaped n a) where
+ showsPrec d arr@(Shaped marr) =
+ let sh = show (shsToList (sshape arr))
+ in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr
+#endif
+
+instance Elt a => NFData (Shaped sh a) where
+ rnf (Shaped arr) = rnf arr
+
+-- just unwrap the newtype and defer to the general instance for nested arrays
+newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))
+ deriving (Generic)
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (Mixed sh (Mixed (MapJust sh') a)) => Show (Mixed sh (Shaped sh' a))
+#endif
+
+deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a))
+
+newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))
+
+instance Elt a => Elt (Shaped sh a) where
+ {-# INLINE mshape #-}
+ mshape (M_Shaped arr) = mshape arr
+ {-# INLINE mindex #-}
+ mindex (M_Shaped arr) i = Shaped (mindex arr i)
+
+ {-# INLINE mindexPartial #-}
+ mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
+ mindexPartial (M_Shaped arr) i =
+ coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mindexPartial arr i
+
+ mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)
+
+ mfromListOuterSN :: SNat n -> NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Just n : sh') (Shaped sh a)
+ mfromListOuterSN sn l = M_Shaped (mfromListOuterSN sn (coerce l))
+
+ mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
+ mtoListOuter (M_Shaped arr)
+ = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)
+
+ {-# INLINE mlift #-}
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
+ mlift ssh2 f (M_Shaped arr) =
+ coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mlift ssh2 f arr
+
+ {-# INLINE mlift2 #-}
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a)
+ mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) =
+ coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
+ mlift2 ssh3 f arr1 arr2
+
+ {-# INLINE mliftL #-}
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
+ -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a))
+ mliftL ssh2 f l =
+ coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a)))
+ @(NonEmpty (Mixed sh2 (Shaped sh a))) $
+ mliftL ssh2 f (coerce l)
+
+ mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr)
+
+ mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)
+
+ mconcat l = M_Shaped (mconcat (coerce l))
+
+ mrnf (M_Shaped arr) = mrnf arr
+
+ type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
+
+ mshapeTree (Shaped arr) = first coerce (mshapeTree arr)
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+
+ mshapeTreeIsEmpty _ (sh, t) = shsSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
+ marrayStrides (M_Shaped arr) = marrayStrides arr
+
+ mvecsWriteLinear :: forall sh' s. Int -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
+ mvecsWriteLinear idx (Shaped arr) vecs =
+ mvecsWriteLinear idx arr
+ (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+
+ mvecsWritePartialLinear
+ :: forall sh1 sh2 s.
+ Proxy sh1 -> Int -> Mixed sh2 (Shaped sh a)
+ -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
+ -> ST s ()
+ mvecsWritePartialLinear proxy idx arr vecs =
+ mvecsWritePartialLinear proxy idx
+ (coerce @(Mixed sh2 (Shaped sh a))
+ @(Mixed sh2 (Mixed (MapJust sh) a))
+ arr)
+ (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
+ @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
+ vecs)
+
+ mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
+ mvecsFreeze sh vecs =
+ coerce @(Mixed sh' (Mixed (MapJust sh) a))
+ @(Mixed sh' (Shaped sh a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh' (Shaped sh a))
+ @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+ mvecsUnsafeFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
+ mvecsUnsafeFreeze sh vecs =
+ coerce @(Mixed sh' (Mixed (MapJust sh) a))
+ @(Mixed sh' (Shaped sh a))
+ <$> mvecsUnsafeFreeze sh
+ (coerce @(MixedVecs s sh' (Shaped sh a))
+ @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+
+instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
+ memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
+ memptyArrayUnsafe sh
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
+ memptyArrayUnsafe sh
+
+ mvecsUnsafeNew idx (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsUnsafeNew idx arr
+
+ mvecsReplicate idx (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsReplicate idx arr
+
+ mvecsNewEmpty _
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
+
+
+liftShaped1 :: forall sh a b.
+ (Mixed (MapJust sh) a -> Mixed (MapJust sh) b)
+ -> Shaped sh a -> Shaped sh b
+liftShaped1 = coerce
+
+liftShaped2 :: forall sh a b c.
+ (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c)
+ -> Shaped sh a -> Shaped sh b -> Shaped sh c
+liftShaped2 = coerce
+
+instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
+ (+) = liftShaped2 (+)
+ (-) = liftShaped2 (-)
+ (*) = liftShaped2 (*)
+ negate = liftShaped1 negate
+ abs = liftShaped1 abs
+ signum = liftShaped1 signum
+ fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicatePrim"
+
+instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where
+ fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicatePrim"
+ recip = liftShaped1 recip
+ (/) = liftShaped2 (/)
+
+instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicatePrim"
+ exp = liftShaped1 exp
+ log = liftShaped1 log
+ sqrt = liftShaped1 sqrt
+ (**) = liftShaped2 (**)
+ logBase = liftShaped2 logBase
+ sin = liftShaped1 sin
+ cos = liftShaped1 cos
+ tan = liftShaped1 tan
+ asin = liftShaped1 asin
+ acos = liftShaped1 acos
+ atan = liftShaped1 atan
+ sinh = liftShaped1 sinh
+ cosh = liftShaped1 cosh
+ tanh = liftShaped1 tanh
+ asinh = liftShaped1 asinh
+ acosh = liftShaped1 acosh
+ atanh = liftShaped1 atanh
+ log1p = liftShaped1 GHC.Float.log1p
+ expm1 = liftShaped1 GHC.Float.expm1
+ log1pexp = liftShaped1 GHC.Float.log1pexp
+ log1mexp = liftShaped1 GHC.Float.log1mexp
+
+squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
+squotArray = liftShaped2 mquotArray
+sremArray = liftShaped2 mremArray
+
+satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
+satan2Array = liftShaped2 matan2Array
+
+
+{-# INLINE sshape #-}
+sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
+sshape (Shaped arr) = coerce (mshape arr)
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index 6c43fa7..60e0252 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -1,8 +1,5 @@
+{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveFoldable #-}
-{-# LANGUAGE DeriveFunctor #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
@@ -10,7 +7,6 @@
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
@@ -20,6 +16,7 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
@@ -27,235 +24,173 @@
module Data.Array.Nested.Shaped.Shape where
import Control.DeepSeq (NFData(..))
-import Data.Array.Mixed.Types
+import Control.Exception (assert)
import Data.Array.Shape qualified as O
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
-import Data.Functor.Const
-import Data.Functor.Product qualified as Fun
import Data.Kind (Constraint, Type)
-import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Equality
-import GHC.Exts (withDict)
-import GHC.Generics (Generic)
+import GHC.Exts (build, withDict)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
+import Data.Array.Nested.Mixed.ListX
import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Types
-type role ListS nominal representational
-type ListS :: [Nat] -> (Nat -> Type) -> Type
-data ListS sh f where
- ZS :: ListS '[] f
- -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity
- (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
-deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
-deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
-infixr 3 ::$
-
-instance (forall n. Show (f n)) => Show (ListS sh f) where
- showsPrec _ = listsShow shows
-
-instance (forall m. NFData (f m)) => NFData (ListS n f) where
- rnf ZS = ()
- rnf (x ::$ l) = rnf x `seq` rnf l
-
-data UnconsListSRes f sh1 =
- forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
-listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
-listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)
-listsUncons ZS = Nothing
-
--- | This checks only whether the types are equal; if the elements of the list
--- are not singletons, their values may still differ. This corresponds to
--- 'testEquality', except on the penultimate type parameter.
-listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
-listsEqType ZS ZS = Just Refl
-listsEqType (n ::$ sh) (m ::$ sh')
- | Just Refl <- testEquality n m
- , Just Refl <- listsEqType sh sh'
- = Just Refl
-listsEqType _ _ = Nothing
-
--- | This checks whether the two lists actually contain equal values. This is
--- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
--- in the @some@ package (except on the penultimate type parameter).
-listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
-listsEqual ZS ZS = Just Refl
-listsEqual (n ::$ sh) (m ::$ sh')
- | Just Refl <- testEquality n m
- , n == m
- , Just Refl <- listsEqual sh sh'
- = Just Refl
-listsEqual _ _ = Nothing
-
-listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
-listsFmap _ ZS = ZS
-listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs
-
-listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
-listsFold _ ZS = mempty
-listsFold f (x ::$ xs) = f x <> listsFold f xs
-
-listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
-listsShow f l = showString "[" . go "" l . showString "]"
- where
- go :: String -> ListS sh' f -> ShowS
- go _ ZS = id
- go prefix (x ::$ xs) = showString prefix . f x . go "," xs
-
-listsLength :: ListS sh f -> Int
-listsLength = getSum . listsFold (\_ -> Sum 1)
-
-listsRank :: ListS sh f -> SNat (Rank sh)
-listsRank ZS = SNat
-listsRank (_ ::$ sh) = snatSucc (listsRank sh)
-
-listsToList :: ListS sh (Const i) -> [i]
-listsToList ZS = []
-listsToList (Const i ::$ is) = i : listsToList is
-
-listsHead :: ListS (n : sh) f -> f n
-listsHead (i ::$ _) = i
-
-listsTail :: ListS (n : sh) f -> ListS sh f
-listsTail (_ ::$ sh) = sh
-
-listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f
-listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh
-listsInit (_ ::$ ZS) = ZS
-
-listsLast :: ListS (n : sh) f -> f (Last (n : sh))
-listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh
-listsLast (n ::$ ZS) = n
-
-listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
-listsAppend ZS idx' = idx'
-listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
-
-listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g)
-listsZip ZS ZS = ZS
-listsZip (i ::$ is) (j ::$ js) =
- Fun.Pair i j ::$ listsZip is js
-
-listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g
- -> ListS sh h
-listsZipWith _ ZS ZS = ZS
-listsZipWith f (i ::$ is) (j ::$ js) =
- f i j ::$ listsZipWith f is js
-
-listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
-listsTakeLenPerm PNil _ = ZS
-listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh
-listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
-
-listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
-listsDropLenPerm PNil sh = sh
-listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh
-listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
-
-listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
-listsPermute PNil _ = ZS
-listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
- case listsIndex (Proxy @is') (Proxy @sh) i sh of
- (item, SNat) -> item ::$ listsPermute is sh
-
--- TODO: remove this SNat when the KnownNat constaint in ListS is removed
-listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh))
-listsIndex _ _ SZ (n ::$ _) = (n, SNat)
-listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = listsIndex p pT i sh
-listsIndex _ _ _ ZS = error "Index into empty shape"
-
-listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
-listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh)
-
+-- * Shaped indices
-- | An index into a shape-typed array.
---
--- For convenience, this contains regular 'Int's instead of bounded integers
--- (traditionally called \"@Fin@\").
type role IxS nominal representational
type IxS :: [Nat] -> Type -> Type
-newtype IxS sh i = IxS (ListS sh (Const i))
- deriving (Eq, Ord, Generic)
+newtype IxS sh i = IxS (IxX (MapJust sh) i)
+ deriving (Eq, Ord, NFData, Functor, Foldable)
pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
-pattern ZIS = IxS ZS
+pattern ZIS <- IxS (matchZIX -> Just Refl)
+ where ZIS = IxS ZIX
+
+matchZIX :: forall sh i. IxX (MapJust sh) i -> Maybe (sh :~: '[])
+matchZIX ZIX | Refl <- lemMapJustEmpty @sh Refl = Just Refl
+matchZIX _ = Nothing
pattern (:.$)
:: forall {sh1} {i}.
- forall n sh. (KnownNat n, n : sh ~ sh1)
+ forall n sh. (n : sh ~ sh1)
=> i -> IxS sh i -> IxS sh1 i
-pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
- where i :.$ IxS shl = IxS (Const i ::$ shl)
+pattern i :.$ l <- (ixsUncons -> Just (UnconsIxSRes i l))
+ where i :.$ IxS l = IxS (i :.% l)
infixr 3 :.$
+data UnconsIxSRes i sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsIxSRes i (IxS sh i)
+ixsUncons :: forall sh1 i. IxS sh1 i -> Maybe (UnconsIxSRes i sh1)
+ixsUncons (IxS (i :.% l)) | Refl <- lemMapJustHead (Proxy @sh1)
+ , Refl <- lemMapJustCons @sh1 Refl =
+ Just (UnconsIxSRes i (IxS l))
+ixsUncons (IxS _) = Nothing
+
{-# COMPLETE ZIS, (:.$) #-}
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
type IIxS sh = IxS sh Int
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (IxS sh i)
+#else
instance Show i => Show (IxS sh i) where
- showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l
-
-instance Functor (IxS sh) where
- fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l)
+ showsPrec _ l = ixsShow shows l
+#endif
-instance Foldable (IxS sh) where
- foldMap f (IxS l) = listsFold (f . getConst) l
+ixsShow :: forall sh i. (i -> ShowS) -> IxS sh i -> ShowS
+ixsShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> IxS sh' i -> ShowS
+ go _ ZIS = id
+ go prefix (x :.$ xs) = showString prefix . f x . go "," xs
-instance NFData i => NFData (IxS sh i)
+ixsRank :: IxS sh i -> SNat (Rank sh)
+ixsRank ZIS = SNat
+ixsRank (_ :.$ sh) = snatSucc (ixsRank sh)
-ixsLength :: IxS sh i -> Int
-ixsLength (IxS l) = listsLength l
+{-# INLINE ixsFromList #-}
+ixsFromList :: ShS sh -> [i] -> IxS sh i
+ixsFromList sh l = assert (shsLength sh == length l)
+ $ IxS $ IsList.fromList l
-ixsRank :: IxS sh i -> SNat (Rank sh)
-ixsRank (IxS l) = listsRank l
+{-# INLINE ixsFromIxS #-}
+ixsFromIxS :: IxS sh i0 -> [i] -> IxS sh i
+ixsFromIxS sh l = assert (length sh == length l)
+ $ IxS $ IsList.fromList l
ixsZero :: ShS sh -> IIxS sh
ixsZero ZSS = ZIS
ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
-ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh
-ixCvtXS ZSS ZIX = ZIS
-ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx
-
-ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
-ixCvtSX ZIS = ZIX
-ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh
-
ixsHead :: IxS (n : sh) i -> i
-ixsHead (IxS list) = getConst (listsHead list)
+ixsHead (i :.$ _) = i
ixsTail :: IxS (n : sh) i -> IxS sh i
-ixsTail (IxS list) = IxS (listsTail list)
+ixsTail (_ :.$ sh) = sh
ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i
-ixsInit (IxS list) = IxS (listsInit list)
+ixsInit (n :.$ sh@(_ :.$ _)) = n :.$ ixsInit sh
+ixsInit (_ :.$ ZIS) = ZIS
ixsLast :: IxS (n : sh) i -> i
-ixsLast (IxS list) = getConst (listsLast list)
+ixsLast (_ :.$ sh@(_ :.$ _)) = ixsLast sh
+ixsLast (n :.$ ZIS) = n
+
+ixsCast :: IxS sh i -> IxS sh i
+ixsCast ZIS = ZIS
+ixsCast (i :.$ idx) = i :.$ ixsCast idx
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
-ixsAppend = coerce (listsAppend @_ @(Const i))
+ixsAppend = gcastWith (unsafeCoerceRefl :: MapJust (sh ++ sh') :~: MapJust sh ++ MapJust sh') $
+ coerce (ixxAppend @_ @_ @i)
-ixsZip :: IxS n i -> IxS n j -> IxS n (i, j)
+ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j)
ixsZip ZIS ZIS = ZIS
ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js
-ixsZipWith :: (i -> j -> k) -> IxS n i -> IxS n j -> IxS n k
+{-# INLINE ixsZipWith #-}
+ixsZipWith :: (i -> j -> k) -> IxS sh i -> IxS sh j -> IxS sh k
ixsZipWith _ ZIS ZIS = ZIS
ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js
+ixsTakeLenPerm :: forall i is sh. Perm is -> IxS sh i -> IxS (TakeLen is sh) i
+ixsTakeLenPerm PNil _ = ZIS
+ixsTakeLenPerm (_ `PCons` is) (n :.$ sh) = n :.$ ixsTakeLenPerm is sh
+ixsTakeLenPerm (_ `PCons` _) ZIS = error "Permutation longer than shape"
+
+ixsDropLenPerm :: forall i is sh. Perm is -> IxS sh i -> IxS (DropLen is sh) i
+ixsDropLenPerm PNil sh = sh
+ixsDropLenPerm (_ `PCons` is) (_ :.$ sh) = ixsDropLenPerm is sh
+ixsDropLenPerm (_ `PCons` _) ZIS = error "Permutation longer than shape"
+
+ixsPermute :: forall i is sh. Perm is -> IxS sh i -> IxS (Permute is sh) i
+ixsPermute PNil _ = ZIS
+ixsPermute (i `PCons` (is :: Perm is')) (sh :: IxS sh f) =
+ case ixsIndex i sh of
+ item -> item :.$ ixsPermute is sh
+
+ixsIndex :: forall j i sh. SNat i -> IxS sh j -> j
+ixsIndex SZ (n :.$ _) = n
+ixsIndex (SS i) (_ :.$ sh) = ixsIndex i sh
+ixsIndex _ ZIS = error "Index into empty shape"
+
ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
-ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))
+ixsPermutePrefix perm sh = ixsAppend (ixsPermute perm (ixsTakeLenPerm perm sh)) (ixsDropLenPerm perm sh)
+
+-- | Given a multidimensional index, get the corresponding linear
+-- index into the buffer.
+{-# INLINEABLE ixsToLinear #-}
+ixsToLinear :: Num i => ShS sh -> IxS sh i -> i
+ixsToLinear (ShS sh) ix = ixxToLinear sh (ixxFromIxS ix)
+
+ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i
+ixxFromIxS = coerce
+
+{-# INLINEABLE ixsFromLinear #-}
+ixsFromLinear :: Num i => ShS sh -> Int -> IxS sh i
+ixsFromLinear (ShS sh) i = ixsFromIxX $ ixxFromLinear sh i
+
+ixsFromIxX :: IxX (MapJust sh) i -> IxS sh i
+ixsFromIxX = coerce
+shsEnum :: ShS sh -> [IIxS sh]
+shsEnum = shsEnum'
+
+{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site
+shsEnum' :: Num i => ShS sh -> [IxS sh i]
+shsEnum' (ShS sh) = (coerce :: [IxX (MapJust sh) i] -> [IxS sh i]) $ shxEnum' sh
+
+-- * Shaped shapes
-- | The shape of a shape-typed array given as a list of 'SNat' values.
--
@@ -263,32 +198,48 @@ ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))
-- can also retrieve the array shape from a 'KnownShS' dictionary.
type role ShS nominal
type ShS :: [Nat] -> Type
-newtype ShS sh = ShS (ListS sh SNat)
- deriving (Eq, Ord, Generic)
+newtype ShS sh = ShS (ShX (MapJust sh) Int)
+ deriving (NFData)
+
+instance Eq (ShS sh) where _ == _ = True
+instance Ord (ShS sh) where compare _ _ = EQ
pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
-pattern ZSS = ShS ZS
+pattern ZSS <- ShS (matchZSX -> Just Refl)
+ where ZSS = ShS ZSX
+
+matchZSX :: forall sh i. ShX (MapJust sh) i -> Maybe (sh :~: '[])
+matchZSX ZSX | Refl <- lemMapJustEmpty @sh Refl = Just Refl
+matchZSX _ = Nothing
pattern (:$$)
:: forall {sh1}.
- forall n sh. (KnownNat n, n : sh ~ sh1)
+ forall n sh. (n : sh ~ sh1)
=> SNat n -> ShS sh -> ShS sh1
-pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
- where i :$$ ShS shl = ShS (i ::$ shl)
-
+pattern i :$$ sh <- (shsUncons -> Just (UnconsShSRes i sh))
+ where i :$$ ShS sh = ShS (SKnown i :$% sh)
infixr 3 :$$
+data UnconsShSRes sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsShSRes (SNat n) (ShS sh)
+shsUncons :: forall sh1. ShS sh1 -> Maybe (UnconsShSRes sh1)
+shsUncons (ShS (SKnown x :$% sh')) | Refl <- lemMapJustCons @sh1 Refl
+ = Just (UnconsShSRes x (ShS sh'))
+shsUncons (ShS _) = Nothing
+
{-# COMPLETE ZSS, (:$$) #-}
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (ShS sh)
+#else
instance Show (ShS sh) where
- showsPrec _ (ShS l) = listsShow (shows . fromSNat) l
-
-instance NFData (ShS sh) where
- rnf (ShS ZS) = ()
- rnf (ShS (SNat ::$ l)) = rnf (ShS l)
+ showsPrec d (ShS shx) = showsPrec d shx
+#endif
instance TestEquality ShS where
- testEquality (ShS l1) (ShS l2) = listsEqType l1 l2
+ testEquality (ShS shx1) (ShS shx2) = case shxEqType shx1 shx2 of
+ Nothing -> Nothing
+ Just Refl -> Just unsafeCoerceRefl
-- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are
-- equal if and only if values are equal.)
@@ -296,62 +247,117 @@ shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh')
shsEqual = testEquality
shsLength :: ShS sh -> Int
-shsLength (ShS l) = listsLength l
+shsLength (ShS shx) = shxLength shx
-shsRank :: ShS sh -> SNat (Rank sh)
-shsRank (ShS l) = listsRank l
+shsRank :: forall sh. ShS sh -> SNat (Rank sh)
+shsRank (ShS shx) | Refl <- lemRankMapJust (Proxy @sh) =
+ shxRank shx
-shsSize :: ShS sh -> Int
-shsSize ZSS = 1
-shsSize (n :$$ sh) = fromSNat' n * shsSize sh
+lemRankMapJust :: proxy sh -> Rank (MapJust sh) :~: Rank sh
+lemRankMapJust _ = unsafeCoerceRefl
-shsToList :: ShS sh -> [Int]
-shsToList ZSS = []
-shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh
+shsSize :: ShS sh -> Int
+shsSize (ShS sh) = shxSize sh
-shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh
-shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
-shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) =
- castWith (subst1 (lem Refl)) $
- n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
- idx)
+-- | This is a partial @const@ that fails when the second argument
+-- doesn't match the first. We don't report the size of the list
+-- in case of errors in order not to retain the list.
+{-# INLINEABLE shsFromList #-}
+shsFromList :: ShS sh -> [Int] -> ShS sh
+shsFromList sh0@(ShS topsh) topl = go topsh topl `seq` sh0
where
- lem :: forall sh1 sh' n.
- Just n : sh1 :~: MapJust sh'
- -> n : Tail sh' :~: sh'
- lem Refl = unsafeCoerceRefl
-shCvtXS' (SUnknown _ :$% _) = error "impossible"
+ go :: ShX sh' Int -> [Int] -> ()
+ go ZSX [] = ()
+ go ZSX _ = error $ "shsFromList: List too long (type says " ++ show (shxLength topsh) ++ ")"
+ go (ConsKnown sn sh) (i : is)
+ | i == fromSNat' sn = go sh is
+ | otherwise = error "shsFromList: Value does not match typing"
+ go ConsUnknown{} _ = error "shsFromList: impossible case"
+ go _ _ = error $ "shsFromList: List too short (type says " ++ show (shxLength topsh) ++ ")"
-shCvtSX :: ShS sh -> IShX (MapJust sh)
-shCvtSX ZSS = ZSX
-shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh
+-- This is equivalent to but faster than @coerce shxToList@.
+{-# INLINEABLE shsToList #-}
+shsToList :: ShS sh -> [Int]
+shsToList (ShS l) = build (\(cons :: i -> is -> is) (nil :: is) ->
+ let go :: ShX sh Int -> is
+ go ZSX = nil
+ go ConsUnknown{} = error "shsToList: impossible case"
+ go (ConsKnown snat rest) = fromSNat' snat `cons` go rest
+ in go l)
shsHead :: ShS (n : sh) -> SNat n
-shsHead (ShS list) = listsHead list
+shsHead (ShS shx) = case shxHead shx of
+ SKnown SNat -> SNat
+
+shsTail :: forall n sh. ShS (n : sh) -> ShS sh
+shsTail = coerce (shxTail @_ @_ @Int)
-shsTail :: ShS (n : sh) -> ShS sh
-shsTail (ShS list) = ShS (listsTail list)
+{-# INLINEABLE shsTakeIx #-}
+shsTakeIx :: forall sh sh' j. Proxy sh' -> IxS sh j -> ShS (sh ++ sh') -> ShS sh
+shsTakeIx _ ZIS _ = ZSS
+shsTakeIx p (_ :.$ idx) sh = case sh of n :$$ sh' -> n :$$ shsTakeIx p idx sh'
-shsInit :: ShS (n : sh) -> ShS (Init (n : sh))
-shsInit (ShS list) = ShS (listsInit list)
+{-# INLINEABLE shsDropIx #-}
+shsDropIx :: forall sh sh' j. IxS sh j -> ShS (sh ++ sh') -> ShS sh'
+shsDropIx ZIS long = long
+shsDropIx (_ :.$ short) long = case long of _ :$$ long' -> shsDropIx short long'
-shsLast :: ShS (n : sh) -> SNat (Last (n : sh))
-shsLast (ShS list) = listsLast list
+shsInit :: forall n sh. ShS (n : sh) -> ShS (Init (n : sh))
+shsInit =
+ gcastWith (unsafeCoerceRefl
+ :: Init (Just n : MapJust sh) :~: MapJust (Init (n : sh))) $
+ coerce (shxInit @Int)
+
+shsLast :: forall n sh. ShS (n : sh) -> SNat (Last (n : sh))
+shsLast (ShS shx) =
+ gcastWith (unsafeCoerceRefl
+ :: Last (Just n : MapJust sh) :~: Just (Last (n : sh))) $
+ case shxLast shx of
+ SKnown SNat -> SNat
shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')
-shsAppend = coerce (listsAppend @_ @SNat)
+shsAppend =
+ gcastWith (unsafeCoerceRefl
+ :: MapJust sh ++ MapJust sh' :~: MapJust (sh ++ sh')) $
+ coerce (shxAppend @_ @Int)
+
+shsTakeLenPerm :: forall is sh. Perm is -> ShS sh -> ShS (TakeLen is sh)
+shsTakeLenPerm =
+ gcastWith (unsafeCoerceRefl
+ :: TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)) $
+ coerce (shxTakeLenPerm @Int)
-shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
-shsTakeLen = coerce (listsTakeLenPerm @SNat)
+shsDropLenPerm :: forall is sh. Perm is -> ShS sh -> ShS (DropLen is sh)
+shsDropLenPerm =
+ gcastWith (unsafeCoerceRefl
+ :: DropLen is (MapJust sh) :~: MapJust (DropLen is sh)) $
+ coerce (shxDropLenPerm @Int)
-shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
-shsPermute = coerce (listsPermute @SNat)
+shsPermute :: forall is sh. Perm is -> ShS sh -> ShS (Permute is sh)
+shsPermute =
+ gcastWith (unsafeCoerceRefl
+ :: Permute is (MapJust sh) :~: MapJust (Permute is sh)) $
+ coerce (shxPermute @Int)
-shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
-shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh)))
+shsIndex :: forall i sh. SNat i -> ShS sh -> SNat (Index i sh)
+shsIndex i (ShS sh) =
+ gcastWith (unsafeCoerceRefl
+ :: Index i (MapJust sh) :~: Just (Index i sh)) $
+ case shxIndex @Int i sh of
+ SKnown SNat -> SNat
shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
-shsPermutePrefix = coerce (listsPermutePrefix @SNat)
+shsPermutePrefix perm (ShS shx)
+ {- TODO: here and elsewhere, solve the module dependency cycle and add this:
+ | Refl <- lemTakeLenMapJust perm sh
+ , Refl <- lemDropLenMapJust perm sh
+ , Refl <- lemPermuteMapJust perm sh
+ , Refl <- lemMapJustApp (shsPermute perm (shsTakeLenPerm perm sh)) (shsDropLenPerm perm sh) -}
+ = gcastWith (unsafeCoerceRefl
+ :: Permute is (TakeLen is (MapJust sh))
+ ++ DropLen is (MapJust sh)
+ :~: MapJust (Permute is (TakeLen is sh) ++ DropLen is sh)) $
+ ShS (shxPermutePrefix perm shx)
type family Product sh where
Product '[] = 1
@@ -369,7 +375,7 @@ instance KnownShS '[] where knownShS = ZSS
instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS
withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r
-withKnownShS k = withDict @(KnownShS sh) k
+withKnownShS = withDict @(KnownShS sh)
shsKnownShS :: ShS sh -> Dict KnownShS sh
shsKnownShS ZSS = Dict
@@ -380,37 +386,14 @@ shsOrthotopeShape ZSS = Dict
shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict
--- | Untyped: length is checked at runtime.
-instance KnownShS sh => IsList (ListS sh (Const i)) where
- type Item (ListS sh (Const i)) = i
- fromList topl = go (knownShS @sh) topl
- where
- go :: ShS sh' -> [i] -> ListS sh' (Const i)
- go ZSS [] = ZS
- go (_ :$$ sh) (i : is) = Const i ::$ go sh is
- go _ _ = error $ "IsList(ListS): Mismatched list length (type says "
- ++ show (shsLength (knownShS @sh)) ++ ", list has length "
- ++ show (length topl) ++ ")"
- toList = listsToList
-
-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
instance KnownShS sh => IsList (IxS sh i) where
type Item (IxS sh i) = i
- fromList = IxS . IsList.fromList
+ fromList = ixsFromList (knownShS @sh)
toList = Foldable.toList
-- | Untyped: length and values are checked at runtime.
instance KnownShS sh => IsList (ShS sh) where
type Item (ShS sh) = Int
- fromList topl = ShS (go (knownShS @sh) topl)
- where
- go :: ShS sh' -> [Int] -> ListS sh' SNat
- go ZSS [] = ZS
- go (sn :$$ sh) (i : is)
- | i == fromSNat' sn = sn ::$ go sh is
- | otherwise = error $ "IsList(ShS): Value does not match typing (type says "
- ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
- go _ _ = error $ "IsList(ShS): Mismatched list length (type says "
- ++ show (shsLength (knownShS @sh)) ++ ", list has length "
- ++ show (length topl) ++ ")"
+ fromList = shsFromList (knownShS @sh)
toList = shsToList
diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs
index 838e2b0..8e98ff2 100644
--- a/src/Data/Array/Nested/Trace.hs
+++ b/src/Data/Array/Nested/Trace.hs
@@ -5,21 +5,28 @@
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
+{-# OPTIONS -Wno-simplifiable-class-constraints #-}
{-|
This module is API-compatible with "Data.Array.Nested", except that inputs and
-outputs of the methods are traced using 'Debug.Trace.trace'. Thus the methods
-also have additional 'Show' constraints.
+outputs of the methods are traced to 'stderr'. Thus the methods also have
+additional 'Show' constraints.
->>> let res = rtranspose [1, 0] (rreshape (2 :$: 3 :$: ZSR) (riota @Int 6)) * rreshape (3 :$: 2 :$: ZSR) (rreplicate (6 :$: ZSR) (rscalar @Int 7))
->>> length (show res) `seq` ()
-oxtrace: riota [Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [0,1,2,3,4,5]))))]
-oxtrace: rreshape [[2,3], Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [0,1,2,3,4,5])))), Ranked (M_Int (M_Primitive [2,3] (XArray (fromList [2,3] [0,1,2,3,4,5]))))]
-oxtrace: rtranspose [Ranked (M_Int (M_Primitive [2,3] (XArray (fromList [2,3] [0,1,2,3,4,5])))), Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [0,3,1,4,2,5]))))]
-oxtrace: rscalar [Ranked (M_Int (M_Primitive [] (XArray (fromList [] [7]))))]
-oxtrace: rreplicate [[6], Ranked (M_Int (M_Primitive [] (XArray (fromList [] [7])))), Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [7,7,7,7,7,7]))))]
-oxtrace: rreshape [[3,2], Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [7,7,7,7,7,7])))), Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [7,7,7,7,7,7]))))]
->>> res
-Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [0,21,7,28,14,35]))))
+>>> rtranspose [1, 0] (rreshape (2 :$: 3 :$: ZSR) (riota @Int 6)) * rreshape (3 :$: 2 :$: ZSR) (rreplicate (6 :$: ZSR) (rscalar @Int 7))
+oxtrace: (riota _ ... = rfromListLinear [6] [0,1,2,3,4,5])
+oxtrace: (rreshape [2,3] (rfromListLinear [6] [0,1,2,3,4,5]) ... = rfromListLinear [2,3] [0,1,2,3,4,5])
+oxtrace: (rtranspose [1,0] (rfromListLinear [2,3] [0,1,2,3,4,5]) ... = rfromListLinear [3,2] [0,3,1,4,2,5])
+oxtrace: (rscalar _ ... = rfromListLinear [] [7])
+oxtrace: (rreplicate [6] (rfromListLinear [] [7]) ... = rreplicate [6] 7)
+oxtrace: (rreshape [3,2] (rreplicate [6] 7) ... = rreplicate [3,2] 7)
+rfromListLinear [3,2] [0,21,7,28,14,35]
+
+The part up until and including the @...@ is printed after @seq@ing the
+arguments; the @=@ and further is printed after @seq@ing the result of the
+operation. Do note that tracing means that the functions in this module are
+potentially __stricter__ than the plain ones in "Data.Array.Nested".
+
+Arguments that this module does not know how to @show@, probably due to
+laziness on my side, are printed as @_@.
-}
module Data.Array.Nested.Trace (
-- * Traced variants
@@ -27,20 +34,19 @@ module Data.Array.Nested.Trace (
-- * Re-exports from the plain "Data.Array.Nested" module
Ranked(Ranked),
- ListR(ZR, (:::)),
IxR(..), IIxR,
ShR(..), IShR,
Shaped(Shaped),
- ListS(ZS, (::$)),
IxS(..), IIxS,
ShS(..), KnownShS(..),
Mixed,
IxX(..), IIxX,
- ShX(..), KnownShX(..),
+ ShX(..), KnownShX(..), IShX,
StaticShX(..),
SMayNat(..),
+ Conversion(..),
Elt,
PrimElt,
@@ -51,10 +57,10 @@ module Data.Array.Nested.Trace (
Storable,
SNat, pattern SNat,
pattern SZ, pattern SS,
- Perm(..),
+ Perm(..), PermR,
IsPermutation,
KnownPerm(..),
- NumElt, FloatElt,
+ NumElt, IntElt, FloatElt,
Rank, Product,
Replicate,
MapJust,
@@ -67,4 +73,4 @@ import Data.Array.Nested.Trace.TH
$(concat <$> mapM convertFun
- ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromListOuter, 'rfromList1, 'rfromList1Prim, 'rtoListOuter, 'rtoList1, 'rfromListLinear, 'rfromListPrimLinear, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rcastToShaped, 'rtoMixed, 'rfromOrthotope, 'rtoOrthotope, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromListOuter, 'sfromList1, 'sfromList1Prim, 'stoListOuter, 'stoList1, 'sfromListLinear, 'sfromListPrimLinear, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoRanked, 'stoMixed, 'sfromOrthotope, 'stoOrthotope, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromListOuter, 'mfromList1, 'mfromList1Prim, 'mtoListOuter, 'mtoList1, 'mfromListLinear, 'mfromListPrimLinear, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mtoRanked, 'mcastToShaped])
+ ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rgeneratePrim, 'rsumOuter1Prim, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerankPrim, 'rreplicate, 'rreplicatePrim, 'rfromListOuter, 'rfromListOuterN, 'rfromList1, 'rfromList1N, 'rfromListLinear, 'rfromList1Prim, 'rfromList1PrimN, 'rfromListPrimLinear, 'rtoListOuter, 'rtoList, 'rtoListLinear, 'rtoListPrim, 'rtoListPrimLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'sgeneratePrim, 'ssumOuter1Prim, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerankPrim, 'sreplicate, 'sreplicatePrim, 'sfromListOuter, 'sfromList1, 'sfromListLinear, 'sfromList1Prim, 'sfromListPrimLinear, 'stoListOuter, 'stoList, 'stoListLinear, 'stoListPrim, 'stoListPrimLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'mgeneratePrim, 'msumOuter1Prim, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerankPrim, 'mreplicate, 'mreplicatePrim, 'mfromListOuter, 'mfromListOuterN, 'mfromListOuterSN, 'mfromList1, 'mfromList1N, 'mfromList1SN, 'mfromListLinear, 'mfromList1Prim, 'mfromList1PrimN, 'mfromList1PrimSN, 'mfromListPrimLinear, 'mtoListOuter, 'mtoList, 'mtoListLinear, 'mtoListPrim, 'mtoListPrimLinear, 'msliceN, 'msliceSN, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array])
diff --git a/src/Data/Array/Nested/Trace/TH.hs b/src/Data/Array/Nested/Trace/TH.hs
index 4b388e3..644b4bd 100644
--- a/src/Data/Array/Nested/Trace/TH.hs
+++ b/src/Data/Array/Nested/Trace/TH.hs
@@ -4,11 +4,11 @@
module Data.Array.Nested.Trace.TH where
import Control.Monad (zipWithM)
-import Data.List (foldl', intersperse)
+import Data.List (foldl')
import Data.Maybe (isJust)
import Language.Haskell.TH hiding (cxt)
-
-import Debug.Trace qualified as Debug
+import System.IO (hPutStr, stderr)
+import System.IO.Unsafe (unsafePerformIO)
import Data.Array.Nested
@@ -20,7 +20,7 @@ splitFunTy = \case
in (vars, cx, t1 : args, ret)
ForallT vs cx' t ->
let (vars, cx, args, ret) = splitFunTy t
- in (vars ++ vs, cx ++ cx', args, ret)
+ in (vs ++ vars, cx' ++ cx, args, ret)
t -> ([], [], [], t)
data Arg = RRanked Type Arg
@@ -30,17 +30,27 @@ data Arg = RRanked Type Arg
| ROther Type
deriving (Show)
--- TODO: always returns Just
recognise :: Type -> Maybe Arg
recognise (ConT name `AppT` sht `AppT` ty)
- | name == ''Ranked = RRanked sht <$> recognise ty
- | name == ''Shaped = RShaped sht <$> recognise ty
- | name == ''Mixed = RMixed sht <$> recognise ty
+ | name == ''Ranked = Just (RRanked sht (recogniseElt ty))
+ | name == ''Shaped = Just (RShaped sht (recogniseElt ty))
+ | name == ''Mixed = Just (RMixed sht (recogniseElt ty))
+ | name == ''Conversion = Just (RShowable ty)
recognise ty@(ConT name `AppT` _)
- | name `elem` [''IShR, ''IIxR, ''ShS, ''IIxS, ''SNat] =
+ | name `elem` [''IShR, ''IIxR, ''ShS, ''IIxS, ''SNat, ''Perm] =
Just (RShowable ty)
+recognise ty@(ConT name)
+ | name == ''PermR = Just (RShowable ty)
+recognise (ListT `AppT` ty) = Just (ROther ty)
recognise _ = Nothing
+recogniseElt :: Type -> Arg
+recogniseElt (ConT name `AppT` sht `AppT` ty)
+ | name == ''Ranked = RRanked sht (recogniseElt ty)
+ | name == ''Shaped = RShaped sht (recogniseElt ty)
+ | name == ''Mixed = RMixed sht (recogniseElt ty)
+recogniseElt ty = ROther ty
+
realise :: Arg -> Type
realise (RRanked sht ty) = ConT ''Ranked `AppT` sht `AppT` realise ty
realise (RShaped sht ty) = ConT ''Shaped `AppT` sht `AppT` realise ty
@@ -62,37 +72,58 @@ mkShowElt (RMixed sht ty) = [ConT ''Show `AppT` realise (RMixed sht ty), ConT ''
mkShowElt (RShowable _ty) = [] -- [ConT ''Elt `AppT` ty]
mkShowElt (ROther ty) = [ConT ''Show `AppT` ty, ConT ''Elt `AppT` ty]
-convertType :: Type -> Q (Type, [Bool], Bool)
+-- If you pass a polymorphic function to seq, GHC wants to monomorphise and
+-- doesn't know how to instantiate the type variables. Just don't, I guess.
+isSeqable :: Type -> Bool
+isSeqable ForallT{} = False
+isSeqable (AppT a b) = isSeqable a && isSeqable b
+isSeqable _ = True -- yolo, I guess
+
+convertType :: Type -> Q (Type, [Bool], [Bool], Bool)
convertType typ =
let (tybndrs, cxt, args, ret) = splitFunTy typ
- argrels = map recognise args
- retrel = recognise ret
+ argdescrs = map recognise args
+ retdescr = recognise ret
in return
(ForallT tybndrs
(cxt ++ [constr
- | Just rel <- retrel : argrels
+ | Just rel <- retdescr : argdescrs
, constr <- mkShow rel])
(foldr (\a b -> ArrowT `AppT` a `AppT` b) ret args)
- ,map isJust argrels
- ,isJust retrel)
+ ,map isJust argdescrs
+ ,map isSeqable args
+ ,isJust retdescr)
convertFun :: Name -> Q [Dec]
convertFun funname = do
defname <- newName (nameBase funname)
- (convty, argarrs, retarr) <- reifyType funname >>= convertType
- names <- zipWithM (\b i -> newName ((if b then "t" else "x") ++ show i)) argarrs [1::Int ..]
+ -- "ok": whether we understand this type enough to be able to show it
+ (convty, argoks, argsseqable, retok) <- reifyType funname >>= convertType
+ names <- zipWithM (\_ i -> newName ('x' : show i)) argoks [1::Int ..]
+ -- let tracenames = map fst (filter snd (zip (names ++ [resname]) (argarrs ++ [retarr])))
resname <- newName "res"
- let tracenames = map fst (filter snd (zip (names ++ [resname]) (argarrs ++ [retarr])))
+ let traceCall str val = VarE 'traceNoNewline `AppE` str `AppE` val
+ let msg1 = [LitE (StringL ("oxtrace: (" ++ nameBase funname ++ " "))] ++
+ [if ok
+ then VarE 'showsPrec `AppE` LitE (IntegerL 11) `AppE` VarE n `AppE` LitE (StringL " ")
+ else LitE (StringL "_ ")
+ | (n, ok) <- zip names argoks] ++
+ [LitE (StringL "...")]
+ let msg2 | retok = [LitE (StringL " = "), VarE 'show `AppE` VarE resname, LitE (StringL ")\n")]
+ | otherwise = [LitE (StringL " = _)\n")]
let ex = LetE [ValD (VarP resname)
(NormalB (foldl' AppE (VarE funname) (map VarE names)))
- []]
- (VarE 'Debug.trace
- `AppE` (VarE 'concat `AppE` ListE
- ([LitE (StringL ("oxtrace: " ++ nameBase funname ++ " ["))] ++
- intersperse (LitE (StringL ", "))
- (map (\n -> VarE 'show `AppE` VarE n) tracenames) ++
- [LitE (StringL "]")]))
- `AppE` VarE resname)
+ []] $
+ flip (foldr AppE) [VarE 'seq `AppE` VarE n | (n, True) <- zip names argsseqable] $
+ traceCall (VarE 'concat `AppE` ListE msg1) $
+ VarE 'seq `AppE` VarE resname `AppE`
+ traceCall (VarE 'concat `AppE` ListE msg2) (VarE resname)
return
[SigD defname convty
,FunD defname [Clause (map VarP names) (NormalB ex) []]]
+
+{-# NOINLINE traceNoNewline #-}
+traceNoNewline :: String -> a -> a
+traceNoNewline str x = unsafePerformIO $ do
+ hPutStr stderr str
+ return x
diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Nested/Types.hs
index 3f5b1e7..ec1b3dc 100644
--- a/src/Data/Array/Mixed/Types.hs
+++ b/src/Data/Array/Nested/Types.hs
@@ -6,13 +6,16 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed.Types (
+module Data.Array.Nested.Types (
+ -- * Reasoning helpers
+ subst1, subst2,
+
-- * Reified evidence of a type class
Dict(..),
@@ -27,6 +30,7 @@ module Data.Array.Mixed.Types (
Replicate,
lemReplicateSucc,
MapJust,
+ lemMapJustEmpty, lemMapJustCons, lemMapJustHead,
Head,
Tail,
Init,
@@ -42,13 +46,21 @@ import GHC.TypeLits
import GHC.TypeNats qualified as TN
import Unsafe.Coerce qualified
+-- Reasoning helpers
+
+subst1 :: forall f a b. a :~: b -> f a :~: f b
+subst1 Refl = Refl
+
+subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
+subst2 Refl = Refl
-- | Evidence for the constraint @c a@.
data Dict c a where
Dict :: c a => Dict c a
+{-# INLINE fromSNat' #-}
fromSNat' :: SNat n -> Int
-fromSNat' = fromIntegral . fromSNat
+fromSNat' = fromEnum . TN.fromSNat
sameNat' :: SNat n -> SNat m -> Maybe (n :~: m)
sameNat' n@SNat m@SNat = sameNat n m
@@ -97,13 +109,23 @@ type family Replicate n a where
Replicate 0 a = '[]
Replicate n a = a : Replicate (n - 1) a
-lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
-lemReplicateSucc = unsafeCoerceRefl
+lemReplicateSucc :: forall a n proxy.
+ proxy n -> a : Replicate n a :~: Replicate (n + 1) a
+lemReplicateSucc _ = unsafeCoerceRefl
-type family MapJust l where
+type family MapJust l = r | r -> l where
MapJust '[] = '[]
MapJust (x : xs) = Just x : MapJust xs
+lemMapJustEmpty :: MapJust sh :~: '[] -> sh :~: '[]
+lemMapJustEmpty Refl = unsafeCoerceRefl
+
+lemMapJustCons :: MapJust sh :~: Just n : sh' -> sh :~: n : Tail sh
+lemMapJustCons Refl = unsafeCoerceRefl
+
+lemMapJustHead :: proxy sh1 -> Head (MapJust sh1) :~: Just (Head sh1)
+lemMapJustHead _ = unsafeCoerceRefl
+
type family Head l where
Head (x : _) = x
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Strided/Orthotope.hs
index ebda388..e2cd17c 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Strided/Orthotope.hs
@@ -1,6 +1,6 @@
{-# LANGUAGE ImportQualifiedPost #-}
-module Data.Array.Mixed.Internal.Arith (
- module Data.Array.Mixed.Internal.Arith,
+module Data.Array.Strided.Orthotope (
+ module Data.Array.Strided.Orthotope,
module Data.Array.Strided.Arith,
) where
@@ -24,14 +24,19 @@ fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset ve
toO :: AS.Array n a -> RS.Array n a
toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec))
+{-# INLINE liftO1 #-}
liftO1 :: (AS.Array n a -> AS.Array n' b)
-> RS.Array n a -> RS.Array n' b
liftO1 f = toO . f . fromO
+{-# INLINE liftO2 #-}
liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c)
-> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c
liftO2 f x y = toO (f (fromO x) (fromO y))
+-- We don't inline this lifting function, because its code is not just
+-- a wrapper, being relatively long and expensive.
+{-# INLINEABLE liftVEltwise1 #-}
liftVEltwise1 :: (Storable a, Storable b)
=> SNat n
-> (VS.Vector a -> VS.Vector b)
diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/XArray.hs
index 3e7a498..9f3ee34 100644
--- a/src/Data/Array/Mixed/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -1,8 +1,11 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
@@ -11,31 +14,37 @@
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed.XArray where
+module Data.Array.XArray where
import Control.DeepSeq (NFData)
+import Control.Monad (foldM_, foldM)
+import Control.Monad.ST
import Data.Array.Internal qualified as OI
import Data.Array.Internal.RankedG qualified as ORG
import Data.Array.Internal.RankedS qualified as ORS
-import Data.Array.Ranked qualified as ORB
import Data.Array.RankedS qualified as S
import Data.Coerce
import Data.Foldable (toList)
import Data.Kind
-import Data.List.NonEmpty (NonEmpty)
+import Data.List.NonEmpty (NonEmpty(..))
import Data.Proxy
import Data.Type.Equality
import Data.Type.Ord
+import Data.Vector.Generic.Checked qualified as VGC
import Data.Vector.Storable qualified as VS
+import Data.Vector.Storable.Mutable qualified as VSM
import Foreign.Storable (Storable)
import GHC.Generics (Generic)
import GHC.TypeLits
+#if !MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
+import Unsafe.Coerce (unsafeCoerce)
+#endif
-import Data.Array.Mixed.Internal.Arith
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
+import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed.Shape
-import Data.Array.Mixed.Types
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Types
+import Data.Array.Strided.Orthotope
type XArray :: [Maybe Nat] -> Type -> Type
@@ -53,6 +62,7 @@ shape = \ssh (XArray arr) -> go ssh (S.shapeL arr)
go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l
go _ _ = error "Invalid shapeL"
+{-# INLINEABLE fromVector #-}
fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
fromVector sh v
| Dict <- lemKnownNatRank sh
@@ -76,9 +86,9 @@ cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2
-> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
cast ssh1 sh2 ssh' (XArray arr)
| Refl <- lemRankApp ssh1 ssh'
- , Refl <- lemRankApp (ssxFromShape sh2) ssh'
+ , Refl <- lemRankApp (ssxFromShX sh2) ssh'
= let arrsh :: IShX sh1
- (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr))
+ arrsh = shxTakeSSX (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr))
in if shxToList arrsh == shxToList sh2
then XArray arr
else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")"
@@ -89,8 +99,8 @@ unScalar (XArray a) = S.unScalar a
replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a
replicate sh ssh' (XArray arr)
| Dict <- lemKnownNatRankSSX ssh'
- , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh')
- , Refl <- lemRankApp (ssxFromShape sh) ssh'
+ , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShX sh) ssh')
+ , Refl <- lemRankApp (ssxFromShX sh) ssh'
= XArray (S.stretch (shxToList sh ++ S.shapeL arr) $
S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr)
arr)
@@ -108,15 +118,23 @@ generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh)
-- XArray . S.fromVector (shxShapeL sh)
-- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh)
+{-# INLINEABLE indexPartial #-}
indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a
indexPartial (XArray arr) ZIX = XArray arr
indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx
+{- Strangely, this increases allocation and there's no noticeable speedup:
+indexPartial (XArray (ORS.A (ORG.A sh t))) ix =
+ let linear = OI.offset t + sum (zipWith (*) (ixxToList ix) (OI.strides t))
+ len = ixxLength ix
+ in XArray (ORS.A (ORG.A (drop len sh)
+ OI.T{ OI.strides = drop len (OI.strides t)
+ , OI.offset = linear
+ , OI.values = OI.values t })) -}
+{-# INLINEABLE index #-}
index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a
-index xarr i
- | Refl <- lemAppNil @sh
- = let XArray arr' = indexPartial xarr i :: XArray '[] a
- in S.unScalar arr'
+index (XArray (ORS.A (ORG.A _ t))) i =
+ OI.values t VS.! (OI.offset t + sum (zipWith (*) (toList i) (OI.strides t)))
append :: forall n m sh a. Storable a
=> StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
@@ -167,7 +185,7 @@ rerank :: forall sh sh1 sh2 a b.
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
rerank ssh ssh1 ssh2 f xarr@(XArray arr)
| Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
- = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr)
+ = let sh = shxTakeSSX (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr)
in if 0 `elem` shxToList sh
then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) [])
else case () of
@@ -194,7 +212,7 @@ rerank2 :: forall sh sh1 sh2 a b c.
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2)
| Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
- = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1)
+ = let sh = shxTakeSSX (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1)
in if 0 `elem` shxToList sh
then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) [])
else case () of
@@ -214,10 +232,15 @@ transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh)
-> XArray (PermutePrefix is sh) a
transpose ssh perm (XArray arr)
| Dict <- lemKnownNatRankSSX ssh
- , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh)
+ , Refl <- lemRankApp (ssxPermute perm (ssxTakeLenPerm perm ssh)) (ssxDropLenPerm perm ssh)
, Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm
, Refl <- lemRankDropLen ssh perm
+#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
= XArray (S.transpose (permToList' perm) arr)
+#else
+ = XArray (unsafeCoerce (S.transpose (permToList' perm) arr))
+#endif
+
-- | The list argument gives indices into the original dimension list.
--
@@ -243,27 +266,23 @@ transpose2 ssh1 ssh2 (XArray arr)
, Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
, Refl <- lemRankAppComm ssh1 ssh2
, let n1 = ssxLength ssh1
- = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
+ = XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr)
sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a
-sumFull _ (XArray arr) =
- S.unScalar $
- liftO1 (numEltSum1Inner (SNat @0)) $
- S.fromVector [product (S.shapeL arr)] $
- S.toVector arr
+sumFull ssx (XArray arr) = numEltSumFull (ssxRank ssx) $ fromO arr
sumInner :: forall sh sh' a. (Storable a, NumElt a)
=> StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a
sumInner ssh ssh' arr
| Refl <- lemAppNil @sh
- = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
+ = let sh' = shxDropSSX @sh @sh' ssh (shape (ssxAppend ssh ssh') arr)
sh'F = shxFlatten sh' :$% ZSX
- ssh'F = ssxFromShape sh'F
+ ssh'F = ssxFromShX sh'F
go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a
go (XArray arr')
| Refl <- lemRankApp ssh ssh'F
- , let sn = listxRank (let StaticShX l = ssh in l)
+ , let sn = ssxRank ssh
= XArray (liftO1 (numEltSum1Inner sn) arr')
in go $
@@ -276,40 +295,83 @@ sumOuter :: forall sh sh' a. (Storable a, NumElt a)
=> StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a
sumOuter ssh ssh' arr
| Refl <- lemAppNil @sh
- = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
+ = let sh = shxTakeSSX (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
shF = shxFlatten sh :$% ZSX
- in sumInner ssh' (ssxFromShape shF) $
- transpose2 (ssxFromShape shF) ssh' $
+ in sumInner ssh' (ssxFromShX shF) $
+ transpose2 (ssxFromShX shF) ssh' $
reshapePartial ssh ssh' shF $
arr
-fromListOuter :: forall n sh a. Storable a
- => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
-fromListOuter ssh l
- | Dict <- lemKnownNatRankSSX ssh
- = case ssh of
- SKnown m :!% _ | fromSNat' m /= length l ->
- error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++
- "does not match the type (" ++ show (fromSNat' m) ++ ")"
- _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))
+-- | This creates an array from a list of arrays of one less dimension.
+-- The list is streamed, its length is checked and it's verified
+-- that all arrays on the list have the same shape.
+{-# INLINE fromListOuterSN #-}
+fromListOuterSN :: forall n sh a. Storable a
+ => SNat n -> IShX sh -> NonEmpty (XArray sh a) -> XArray (Just n : sh) a
+fromListOuterSN m sh l
+ | Dict <- lemKnownNatRank sh
+ , let l' = coerce @(NonEmpty (XArray sh a)) @(NonEmpty (S.Array (Rank sh) a)) l
+ = case sh of
+ ZSX -> fromList1SN m (map S.unScalar (toList l'))
+ _ -> XArray (ravelOuterN (fromSNat' m) l')
+
+-- | This checks that the list has the given length and that all shapes in the
+-- list are equal. The list is streamed.
+-- The first array in the list is forced early to potentially release some
+-- memory, before allocating the (large) new array.
+{-# INLINEABLE ravelOuterN #-}
+ravelOuterN :: (KnownNat k, Storable a)
+ => Int -> NonEmpty (S.Array k a) -> S.Array (1 + k) a
+ravelOuterN 0 _ = error "ravelOuterN: N == 0"
+ravelOuterN k as@(!a0 :| _) = runST $ do
+ let sh0 = S.shapeL a0
+ len = product sh0
+ vecSize = k * len
+ vec <- VSM.unsafeNew vecSize
+ let f !n (ORS.A (ORG.A sht t)) =
+ if | n >= k ->
+ error $ "ravelOuterN: list too long " ++ show (n, k)
+ -- if we do this check just once at the end, we may
+ -- crash instead of producing an accurate error message
+ | sht == sh0 -> do
+ let g off el = do
+ VS.unsafeCopy (VSM.slice off (VS.length el) vec) el
+ return $! off + VS.length el
+ foldM_ g (n * len) (OI.toVectorListT sht t)
+ return $! n + 1
+ | otherwise ->
+ error $ "ravelOuterN: unequal shapes " ++ show (sht, sh0)
+ nFinal <- foldM f 0 as
+ if nFinal == k
+ then S.fromVector (k : sh0) <$> VS.unsafeFreeze vec
+ else error $ "ravelOuterN: list too short " ++ show (nFinal, k)
-toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a]
-toListOuter (XArray arr) =
- case S.shapeL arr of
+toListOuter :: forall a n sh. Storable a => XArray (n : sh) a -> [XArray sh a]
+toListOuter (XArray arr@(ORS.A (ORG.A shArr t))) =
+ case shArr of
+ [] -> error "impossible"
0 : _ -> []
- _ -> coerce (ORB.toList (S.unravel arr))
+ -- using orthotope's functions here would entail using rerank, which is slow, so we don't
+ [_] | Refl <- (unsafeCoerceRefl :: sh :~: '[]) -> coerce (map S.scalar $ S.toList arr)
+ n : sh -> coerce $ map (ORG.A sh . OI.indexT t) [0 .. n - 1]
+
+-- | Performance note: the list's spine is fully materialised to compute its
+-- length before traversing it again to construct the array.
+{-# INLINE fromList1 #-}
+fromList1 :: Storable a => [a] -> XArray '[Nothing] a
+fromList1 l =
+ let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself
+ in XArray (S.fromVector [n] (VS.fromListN n l))
-fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a
-fromList1 ssh l =
- let n = length l
- in case ssh of
- SKnown m :!% _ | fromSNat' m /= n ->
- error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++
- "does not match the type (" ++ show (fromSNat' m) ++ ")"
- _ -> XArray (S.fromVector [n] (VS.fromListN n l))
+-- | The list is streamed.
+{-# INLINE fromList1SN #-}
+fromList1SN :: Storable a => SNat n -> [a] -> XArray '[Just n] a
+fromList1SN m l =
+ let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed
+ in XArray (S.fromVector [n] (VGC.fromListNChecked n l))
-toList1 :: Storable a => XArray '[n] a -> [a]
-toList1 (XArray arr) = S.toList arr
+toListLinear :: Storable a => XArray sh a -> [a]
+toListLinear (XArray arr) = S.toList arr
-- | Throws if the given shape is not, in fact, empty.
empty :: forall sh a. Storable a => IShX sh -> XArray sh a
@@ -340,7 +402,7 @@ reshape ssh1 sh2 (XArray arr)
reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
reshapePartial ssh1 ssh' sh2 (XArray arr)
| Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh')
- , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh')
+ , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShX sh2) ssh')
= XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr)
-- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo).
diff --git a/src/Data/Vector/Generic/Checked.hs b/src/Data/Vector/Generic/Checked.hs
new file mode 100644
index 0000000..d8aaaae
--- /dev/null
+++ b/src/Data/Vector/Generic/Checked.hs
@@ -0,0 +1,40 @@
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+module Data.Vector.Generic.Checked (
+ fromListNChecked,
+) where
+
+import Data.Stream.Monadic qualified as Stream
+import Data.Vector.Fusion.Bundle.Monadic qualified as VBM
+import Data.Vector.Fusion.Bundle.Size qualified as VBS
+import Data.Vector.Fusion.Util qualified as VFU
+import Data.Vector.Generic qualified as VG
+
+-- for INLINE_FUSED and INLINE_INNER
+#include "vector.h"
+
+
+-- These functions are copied over and lightly edited from the vector and
+-- vector-stream packages, and thus inherit their BSD-3-Clause license with:
+-- Copyright (c) 2008-2012, Roman Leshchinskiy
+-- 2020-2022, Alexey Kuleshevich
+-- 2020-2022, Aleksey Khudyakov
+-- 2020-2022, Andrew Lelechenko
+
+fromListNChecked :: VG.Vector v a => Int -> [a] -> v a
+{-# INLINE fromListNChecked #-}
+fromListNChecked n = VG.unstream . bundleFromListNChecked n
+
+bundleFromListNChecked :: Int -> [a] -> VBM.Bundle VFU.Id v a
+{-# INLINE_FUSED bundleFromListNChecked #-}
+bundleFromListNChecked nTop xsTop
+ | nTop < 0 = error "fromListNChecked: length negative"
+ | otherwise =
+ VBM.fromStream (Stream.Stream step (xsTop, nTop)) (VBS.Max (VFU.delay_inline max nTop 0))
+ where
+ {-# INLINE_INNER step #-}
+ step (xs,n) | n == 0 = case xs of
+ [] -> return Stream.Done
+ _:_ -> error "fromListNChecked: list too long"
+ step (x:xs,n) = return (Stream.Yield x (xs,n-1))
+ step ([],_) = error "fromListNChecked: list too short"
diff --git a/src/GHC/TypeLits/Orphans.hs b/src/GHC/TypeLits/Orphans.hs
new file mode 100644
index 0000000..42f7293
--- /dev/null
+++ b/src/GHC/TypeLits/Orphans.hs
@@ -0,0 +1,13 @@
+-- | Compatibility module adding some additional instances not yet defined in
+-- base-4.18 with GHC 9.6.
+{-# OPTIONS -Wno-orphans #-}
+module GHC.TypeLits.Orphans where
+
+import GHC.TypeLits
+
+
+instance Eq (SNat n) where
+ _ == _ = True
+
+instance Ord (SNat n) where
+ compare _ _ = EQ
diff --git a/test/Gen.hs b/test/Gen.hs
index 50c671f..789a59c 100644
--- a/test/Gen.hs
+++ b/test/Gen.hs
@@ -4,7 +4,6 @@
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
@@ -12,7 +11,6 @@
module Gen where
import Data.ByteString qualified as BS
-import Data.Foldable (toList)
import Data.Type.Equality
import Data.Type.Ord
import Data.Vector.Storable qualified as VS
@@ -20,10 +18,10 @@ import Foreign
import GHC.TypeLits
import GHC.TypeNats qualified as TN
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Types
import Data.Array.Nested
+import Data.Array.Nested.Permutation
import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
import Hedgehog
import Hedgehog.Gen qualified as Gen
@@ -47,7 +45,7 @@ genLowBiased (lo, hi) = do
return (lo + x * x * x * (hi - lo))
shuffleShR :: IShR n -> Gen (IShR n)
-shuffleShR = \sh -> go (length sh) (toList sh) sh
+shuffleShR = \sh -> go (shrLength sh) (shrToList sh) sh
where
go :: Int -> [Int] -> IShR n -> Gen (IShR n)
go _ _ ZSR = return ZSR
@@ -59,9 +57,12 @@ shuffleShR = \sh -> go (length sh) (toList sh) sh
(dim :$:) <$> go (nbag - 1) bag' sh
genShR :: SNat n -> Gen (IShR n)
-genShR sn = do
+genShR = genShRwithTarget 100_000
+
+genShRwithTarget :: Int -> SNat n -> Gen (IShR n)
+genShRwithTarget targetMax sn = do
let n = fromSNat' sn
- targetSize <- Gen.int (Range.linear 0 100_000)
+ targetSize <- Gen.int (Range.linear 0 targetMax)
let genDims :: SNat m -> Int -> Gen (IShR m)
genDims SZ _ = return ZSR
genDims (SS m) 0 = do
@@ -76,9 +77,8 @@ genShR sn = do
dims <- genDims m (if dim == 0 then 0 else tgt `div` dim)
return (dim :$: dims)
dims <- genDims sn targetSize
- let dimsL = toList dims
- maxdim = maximum dimsL
- cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize)
+ let maxdim = maximum $ shrToList dims
+ cap = binarySearch (`div` 2) 1 maxdim (\cap' -> shrSize (min cap' <$> dims) <= targetSize)
shuffleShR (min cap <$> dims)
-- | Example: given 3 and 7, might return:
@@ -93,10 +93,14 @@ genShR sn = do
-- other dimensions might be zero.
genReplicatedShR :: m <= n => SNat m -> SNat n -> Gen (IShR m, IShR n, IShR n)
genReplicatedShR = \m n -> do
- sh1 <- genShR m
+ let expectedSizeIncrease = round (repvalavg ^ (fromSNat' n - fromSNat' m))
+ sh1 <- genShRwithTarget (1_000_000 `div` expectedSizeIncrease) m
(sh2, sh3) <- injectOnes n sh1 sh1
return (sh1, sh2, sh3)
where
+ repvalrange = (1::Int, 5)
+ repvalavg = let (lo, hi) = repvalrange in fromIntegral (lo + hi) / 2 :: Double
+
injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n)
injectOnes n@SNat shOnes sh
| m@SNat <- shrRank sh
@@ -105,7 +109,7 @@ genReplicatedShR = \m n -> do
EQI -> return (shOnes, sh)
GTI -> do
index <- Gen.int (Range.linear 0 (fromSNat' m))
- value <- Gen.int (Range.linear 1 5)
+ value <- Gen.int (uncurry Range.linear repvalrange)
Refl <- return (lem n m)
injectOnes n (inject index 1 shOnes) (inject index value sh)
@@ -115,7 +119,7 @@ genReplicatedShR = \m n -> do
inject :: Int -> Int -> IShR m -> IShR (m + 1)
inject 0 v sh = v :$: sh
inject i v (w :$: sh) = w :$: inject (i - 1) v sh
- inject _ v ZSR = v :$: ZSR -- invalid input, but meh
+ inject _ _ ZSR = error "unreachable"
genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
genStorables rng f = do
@@ -134,7 +138,7 @@ genStaticShX = \n k -> case n of
genStaticShX n' $ \ssh ->
k (item :!% ssh)
where
- genItem :: Monad m => (forall n. SMayNat () SNat n -> PropertyT m ()) -> PropertyT m ()
+ genItem :: Monad m => (forall n. SMayNat () n -> PropertyT m ()) -> PropertyT m ()
genItem k = do
b <- forAll Gen.bool
if b
@@ -157,7 +161,7 @@ genPermR n = Gen.shuffle [0 .. n-1]
genPerm :: Monad m => SNat n -> (forall p. (IsPermutation p, Rank p ~ n) => Perm p -> PropertyT m r) -> PropertyT m r
genPerm n@SNat k = do
list <- forAll $ genPermR (fromSNat' n)
- permFromList list $ \perm -> do
+ permFromListCont list $ \perm -> do
case permCheckPermutation perm $
case sameNat' (permRank perm) n of
Just Refl -> Just (k perm)
diff --git a/test/Tests/C.hs b/test/Tests/C.hs
index 72bf8f6..2d35cd9 100644
--- a/test/Tests/C.hs
+++ b/test/Tests/C.hs
@@ -1,9 +1,12 @@
+{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
+#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
{-# LANGUAGE TypeAbstractions #-}
+#endif
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
@@ -12,19 +15,18 @@ module Tests.C where
import Control.Monad
import Data.Array.RankedS qualified as OR
-import Data.Foldable (toList)
import Data.Functor.Const
import Data.Type.Equality
import Foreign
import GHC.TypeLits
-import Data.Array.Mixed.Types (fromSNat')
import Data.Array.Nested
import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types (fromSNat')
import Hedgehog
import Hedgehog.Gen qualified as Gen
-import Hedgehog.Internal.Property (forAllT)
+import Hedgehog.Internal.Property (LabelName(..), forAllT)
import Hedgehog.Range qualified as Range
import Test.Tasty
import Test.Tasty.Hedgehog
@@ -39,21 +41,23 @@ import Util
fineTol :: Double
fineTol = 1e-8
-prop_sum_nonempty :: Property
-prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do
+debugCoverage :: Bool
+debugCoverage = False
+
+gen_red_nonempty :: (forall n. SNat (n + 1) -> SNat n -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property
+gen_red_nonempty f = property $ genRank $ \outrank@(SNat @n) -> do
-- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
let inrank = SNat @(n + 1)
sh <- forAll $ genShR inrank
- -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
- guard (all (> 0) (shrTail sh)) -- only constrain the tail
- arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
- genStorables (Range.singleton (product sh))
+ -- traceM ("sh: " ++ show sh ++ " -> " ++ show (shrSize sh))
+ guard (all (> 0) (shrToList $ shrTail sh)) -- only constrain the tail
+ arr <- forAllT $ OR.fromVector @Double @(n + 1) (shrToList sh) <$>
+ genStorables (Range.singleton (shrSize sh))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
- let rarr = rfromOrthotope inrank arr
- almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr)
+ f inrank outrank arr
-prop_sum_empty :: Property
-prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do
+gen_red_empty :: (forall n. SNat (n + 1) -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property
+gen_red_empty f = property $ genRank $ \outrankm1@(SNat @nm1) -> do
-- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above.
_outrank :: SNat n <- return $ SNat @(nm1 + 1)
let inrank = SNat @(n + 1)
@@ -62,26 +66,24 @@ prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do
sht <- shuffleShR (0 :$: shtt) -- n
n <- Gen.int (Range.linear 0 20)
return (n :$: sht) -- n + 1
- guard (0 `elem` toList (shrTail sh))
- -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
- let arr = OR.fromList @(n + 1) @Double (toList sh) []
- let rarr = rfromOrthotope inrank arr
- OR.toList (rtoOrthotope (rsumOuter1 rarr)) === []
+ guard (0 `elem` (shrToList $ shrTail sh))
+ -- traceM ("sh: " ++ show sh ++ " -> " ++ show (shrSize sh))
+ let arr = OR.fromList @(n + 1) @Double (shrToList sh) []
+ f inrank arr
-prop_sum_lasteq1 :: Property
-prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do
+gen_red_lasteq1 :: (forall n. SNat (n + 1) -> SNat n -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property
+gen_red_lasteq1 f = property $ genRank $ \outrank@(SNat @n) -> do
let inrank = SNat @(n + 1)
outsh <- forAll $ genShR outrank
- guard (all (> 0) outsh)
+ guard (all (> 0) $ shrToList outsh)
let insh = shrAppend outsh (1 :$: ZSR)
- arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$>
- genStorables (Range.singleton (product insh))
+ arr <- forAllT $ OR.fromVector @Double @(n + 1) (shrToList insh) <$>
+ genStorables (Range.singleton (shrSize insh))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
- let rarr = rfromOrthotope inrank arr
- almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr)
+ f inrank outrank arr
-prop_sum_replicated :: Bool -> Property
-prop_sum_replicated doTranspose = property $
+gen_red_replicated :: Bool -> (forall n. SNat (n + 1) -> SNat n -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property
+gen_red_replicated doTranspose f = property $
genRank $ \inrank1@(SNat @m) ->
genRank $ \outrank@(SNat @nm1) -> do
inrank2 :: SNat n <- return $ SNat @(nm1 + 1)
@@ -89,19 +91,65 @@ prop_sum_replicated doTranspose = property $
LTI -> return Refl -- actually we only continue if m < n
_ -> discard
(sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2
- guard (all (> 0) sh3)
+ when debugCoverage $ do
+ label (LabelName ("rankdiff " ++ show (fromSNat' inrank2 - fromSNat' inrank1)))
+ label (LabelName ("size sh1 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh1) :: Double)) :: Int)))
+ label (LabelName ("size sh3 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh3) :: Double)) :: Int)))
+ guard (all (> 0) $ shrToList sh3)
arr <- forAllT $
- OR.stretch (toList sh3)
- . OR.reshape (toList sh2)
- . OR.fromVector @Double @m (toList sh1) <$>
- genStorables (Range.singleton (product sh1))
+ OR.stretch (shrToList sh3)
+ . OR.reshape (shrToList sh2)
+ . OR.fromVector @Double @m (shrToList sh1) <$>
+ genStorables (Range.singleton (shrSize sh1))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
arrTrans <-
if doTranspose then do perm <- forAll $ genPermR (fromSNat' inrank2)
return $ OR.transpose perm arr
else return arr
- let rarr = rfromOrthotope inrank2 arrTrans
- almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans)
+ f inrank2 outrank arrTrans
+
+
+prop_sum_nonempty :: Property
+prop_sum_nonempty = gen_red_nonempty $ \inrank outrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ almostEq fineTol (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr)
+
+prop_sum_empty :: Property
+prop_sum_empty = gen_red_empty $ \inrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ OR.toList (rtoOrthotope (rsumOuter1Prim rarr)) === []
+
+prop_sum_lasteq1 :: Property
+prop_sum_lasteq1 = gen_red_lasteq1 $ \inrank outrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ almostEq fineTol (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr)
+
+prop_sum_replicated :: Bool -> Property
+prop_sum_replicated doTranspose = gen_red_replicated doTranspose $ \inrank outrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ almostEq 1e-8 (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr)
+
+
+prop_sumall_nonempty :: Property
+prop_sumall_nonempty = gen_red_nonempty $ \inrank _outrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ almostEq fineTol (rsumAllPrim rarr) (OR.sumA arr)
+
+prop_sumall_empty :: Property
+prop_sumall_empty = gen_red_empty $ \inrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ rsumAllPrim rarr === 0.0
+
+prop_sumall_lasteq1 :: Property
+prop_sumall_lasteq1 = gen_red_lasteq1 $ \inrank _outrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ almostEq fineTol (rsumAllPrim rarr) (OR.sumA arr)
+
+prop_sumall_replicated :: Bool -> Property
+prop_sumall_replicated doTranspose = gen_red_replicated doTranspose $ \inrank _outrank arr -> do
+ let rarr = rfromOrthotope inrank arr
+ almostEq 1e-6 (rsumAllPrim rarr) (OR.sumA arr)
+
prop_negate_with :: forall f b. Show b
=> ((forall n. f n -> SNat n -> PropertyT IO ()) -> PropertyT IO ())
@@ -111,9 +159,9 @@ prop_negate_with :: forall f b. Show b
prop_negate_with genRank' genB preproc = property $
genRank' $ \extra rank@(SNat @n) -> do
sh <- forAll $ genShR rank
- guard (all (> 0) sh)
- arr <- forAllT $ OR.fromVector @Double @n (toList sh) <$>
- genStorables (Range.singleton (product sh))
+ guard (all (> 0) $ shrToList sh)
+ arr <- forAllT $ OR.fromVector @Double @n (shrToList sh) <$>
+ genStorables (Range.singleton (shrSize sh))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
bval <- forAll $ genB extra sh
let arr' = preproc extra bval arr
@@ -130,6 +178,13 @@ tests = testGroup "C"
,testProperty "replicated" (prop_sum_replicated False)
,testProperty "replicated_transposed" (prop_sum_replicated True)
]
+ ,testGroup "sumAll"
+ [testProperty "nonempty" prop_sumall_nonempty
+ ,testProperty "empty" prop_sumall_empty
+ ,testProperty "last==1" prop_sumall_lasteq1
+ ,testProperty "replicated" (prop_sumall_replicated False)
+ ,testProperty "replicated_transposed" (prop_sumall_replicated True)
+ ]
,testGroup "negate"
[testProperty "normalised" $ prop_negate_with
(\k -> genRank (k (Const ())))
@@ -146,7 +201,7 @@ tests = testGroup "C"
(\_ sh -> do let genPair n = do lo <- Gen.integral (Range.constant 0 (n-1))
len <- Gen.integral (Range.constant 0 (n-lo-1))
return (lo, len)
- pairs <- mapM genPair (toList sh)
+ pairs <- mapM genPair (shrToList sh)
return pairs)
(\_ -> OR.slice)
]
diff --git a/test/Tests/Permutation.hs b/test/Tests/Permutation.hs
index 1e7ad13..4e75d64 100644
--- a/test/Tests/Permutation.hs
+++ b/test/Tests/Permutation.hs
@@ -6,7 +6,7 @@ module Tests.Permutation where
import Data.Type.Equality
-import Data.Array.Mixed.Permutation
+import Data.Array.Nested.Permutation
import Hedgehog
import Hedgehog.Gen qualified as Gen
@@ -24,7 +24,7 @@ tests = testGroup "Permutation"
[testProperty "permCheckPermutation" $ property $ do
n <- forAll $ Gen.int (Range.linear 0 10)
list <- forAll $ genPermR n
- let r = permFromList list $ \perm ->
+ let r = permFromListCont list $ \perm ->
permCheckPermutation perm ()
case r of
Just () -> return ()
diff --git a/test/Util.hs b/test/Util.hs
index 34cf8ab..6514fbf 100644
--- a/test/Util.hs
+++ b/test/Util.hs
@@ -15,7 +15,7 @@ import GHC.TypeLits
import Hedgehog
import Hedgehog.Internal.Property (failDiff)
-import Data.Array.Mixed.Types (fromSNat')
+import Data.Array.Nested.Types (fromSNat')
-- Returns highest value that satisfies the predicate, or `lo` if none does
@@ -36,16 +36,20 @@ orSumOuter1 (sn@SNat :: SNat n) =
let n = fromSNat' sn
in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0])
-class AlmostEq f where
- type AlmostEqConstr f :: Type -> Constraint
+class AlmostEq t where
+ type EltOf t :: Type
-- | absolute tolerance, lhs, rhs
- almostEq :: (AlmostEqConstr f a, Ord a, Show a, Fractional a, MonadTest m)
- => a -> f a -> f a -> m ()
+ almostEq :: MonadTest m => EltOf t -> t -> t -> m ()
-instance AlmostEq (OR.Array n) where
- type AlmostEqConstr (OR.Array n) = OR.Unbox
+instance (OR.Unbox a, Ord a, Show a, Fractional a) => AlmostEq (OR.Array n a) where
+ type EltOf (OR.Array n a) = a
almostEq atol lhs rhs
| OR.allA (< atol) (OR.zipWithA (\a b -> abs (a - b)) rhs lhs) =
success
| otherwise =
failDiff lhs rhs
+
+instance AlmostEq Double where
+ type EltOf Double = Double
+ almostEq atol lhs rhs | abs (lhs - rhs) < atol = success
+ | otherwise = failDiff lhs rhs