aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--.stylish-haskell.yaml452
-rw-r--r--CHANGELOG.md7
-rw-r--r--README.md191
-rw-r--r--bench/Main.hs244
-rw-r--r--cabal.project5
-rw-r--r--cbits/arith.c808
-rw-r--r--cbits/arith_lists.h39
-rw-r--r--example/Main.hs29
-rwxr-xr-xgentrace.sh31
-rw-r--r--ops/Data/Array/Strided.hs7
-rw-r--r--ops/Data/Array/Strided/Arith.hs7
-rw-r--r--ops/Data/Array/Strided/Arith/Internal.hs933
-rw-r--r--ops/Data/Array/Strided/Arith/Internal/Foreign.hs47
-rw-r--r--ops/Data/Array/Strided/Arith/Internal/Lists.hs95
-rw-r--r--ops/Data/Array/Strided/Arith/Internal/Lists/TH.hs83
-rw-r--r--ops/Data/Array/Strided/Array.hs44
-rw-r--r--ox-arrays.cabal176
-rw-r--r--release-hints.txt3
-rw-r--r--src/Data/Array/Mixed.hs416
-rw-r--r--src/Data/Array/Nested.hs128
-rw-r--r--src/Data/Array/Nested/Convert.hs333
-rw-r--r--src/Data/Array/Nested/Internal.hs1294
-rw-r--r--src/Data/Array/Nested/Lemmas.hs162
-rw-r--r--src/Data/Array/Nested/Mixed.hs936
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs644
-rw-r--r--src/Data/Array/Nested/Permutation.hs283
-rw-r--r--src/Data/Array/Nested/Ranked.hs323
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs268
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs369
-rw-r--r--src/Data/Array/Nested/Shaped.hs272
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs255
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs425
-rw-r--r--src/Data/Array/Nested/Trace.hs72
-rw-r--r--src/Data/Array/Nested/Trace/TH.hs98
-rw-r--r--src/Data/Array/Nested/Types.hs152
-rw-r--r--src/Data/Array/Strided/Orthotope.hs43
-rw-r--r--src/Data/Array/XArray.hs348
-rw-r--r--src/Data/Bag.hs18
-rw-r--r--src/Data/INat.hs121
-rw-r--r--test/Gen.hs174
-rw-r--r--test/Main.hs32
-rw-r--r--test/Tests/C.hs160
-rw-r--r--test/Tests/Permutation.hs39
-rw-r--r--test/Util.hs51
45 files changed, 8680 insertions, 1938 deletions
diff --git a/.gitignore b/.gitignore
index a3ac1fc..56ab906 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
dist-newstyle/
cabal.project.local
+.ccls-cache/
diff --git a/.stylish-haskell.yaml b/.stylish-haskell.yaml
new file mode 100644
index 0000000..bfd25ea
--- /dev/null
+++ b/.stylish-haskell.yaml
@@ -0,0 +1,452 @@
+# stylish-haskell configuration file
+# ==================================
+
+# The stylish-haskell tool is mainly configured by specifying steps. These steps
+# are a list, so they have an order, and one specific step may appear more than
+# once (if needed). Each file is processed by these steps in the given order.
+steps:
+ # Convert some ASCII sequences to their Unicode equivalents. This is disabled
+ # by default.
+ # - unicode_syntax:
+ # # In order to make this work, we also need to insert the UnicodeSyntax
+ # # language pragma. If this flag is set to true, we insert it when it's
+ # # not already present. You may want to disable it if you configure
+ # # language extensions using some other method than pragmas. Default:
+ # # true.
+ # add_language_pragma: true
+
+ # Format module header
+ #
+ # Currently, this option is not configurable and will format all exports and
+ # module declarations to minimize diffs
+ #
+ # - module_header:
+ # # How many spaces use for indentation in the module header.
+ # indent: 4
+ #
+ # # Should export lists be sorted? Sorting is only performed within the
+ # # export section, as delineated by Haddock comments.
+ # sort: true
+ #
+ # # See `separate_lists` for the `imports` step.
+ # separate_lists: true
+
+ # Format record definitions. This is disabled by default.
+ #
+ # You can control the layout of record fields. The only rules that can't be configured
+ # are these:
+ #
+ # - "|" is always aligned with "="
+ # - "," in fields is always aligned with "{"
+ # - "}" is likewise always aligned with "{"
+ #
+ # - records:
+ # # How to format equals sign between type constructor and data constructor.
+ # # Possible values:
+ # # - "same_line" -- leave "=" AND data constructor on the same line as the type constructor.
+ # # - "indent N" -- insert a new line and N spaces from the beginning of the next line.
+ # equals: "indent 2"
+ #
+ # # How to format first field of each record constructor.
+ # # Possible values:
+ # # - "same_line" -- "{" and first field goes on the same line as the data constructor.
+ # # - "indent N" -- insert a new line and N spaces from the beginning of the data constructor
+ # first_field: "indent 2"
+ #
+ # # How many spaces to insert between the column with "," and the beginning of the comment in the next line.
+ # field_comment: 2
+ #
+ # # How many spaces to insert before "deriving" clause. Deriving clauses are always on separate lines.
+ # deriving: 2
+ #
+ # # How many spaces to insert before "via" clause counted from indentation of deriving clause
+ # # Possible values:
+ # # - "same_line" -- "via" part goes on the same line as "deriving" keyword.
+ # # - "indent N" -- insert a new line and N spaces from the beginning of "deriving" keyword.
+ # via: "indent 2"
+ #
+ # # Sort typeclass names in the "deriving" list alphabetically.
+ # sort_deriving: true
+ #
+ # # Whether or not to break enums onto several lines
+ # #
+ # # Default: false
+ # break_enums: false
+ #
+ # # Whether or not to break single constructor data types before `=` sign
+ # #
+ # # Default: true
+ # break_single_constructors: true
+ #
+ # # Whether or not to curry constraints on function.
+ # #
+ # # E.g: @allValues :: Enum a => Bounded a => Proxy a -> [a]@
+ # #
+ # # Instead of @allValues :: (Enum a, Bounded a) => Proxy a -> [a]@
+ # #
+ # # Default: false
+ # curried_context: false
+
+ # Align the right hand side of some elements. This is quite conservative
+ # and only applies to statements where each element occupies a single
+ # line.
+ # Possible values:
+ # - always - Always align statements.
+ # - adjacent - Align statements that are on adjacent lines in groups.
+ # - never - Never align statements.
+ # All default to always.
+ - simple_align:
+ cases: never
+ top_level_patterns: never
+ records: never
+ multi_way_if: never
+
+ # Import cleanup
+ - imports:
+ # There are different ways we can align names and lists.
+ #
+ # - global: Align the import names and import list throughout the entire
+ # file.
+ #
+ # - file: Like global, but don't add padding when there are no qualified
+ # imports in the file.
+ #
+ # - group: Only align the imports per group (a group is formed by adjacent
+ # import lines).
+ #
+ # - none: Do not perform any alignment.
+ #
+ # Default: global.
+ align: group
+
+ # The following options affect only import list alignment.
+ #
+ # List align has following options:
+ #
+ # - after_alias: Import list is aligned with end of import including
+ # 'as' and 'hiding' keywords.
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr, head,
+ # > init, last, length)
+ #
+ # - with_alias: Import list is aligned with start of alias or hiding.
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr, head,
+ # > init, last, length)
+ #
+ # - with_module_name: Import list is aligned `list_padding` spaces after
+ # the module name.
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr, head,
+ # init, last, length)
+ #
+ # This is mainly intended for use with `pad_module_names: false`.
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr, head,
+ # init, last, length, scanl, scanr, take, drop,
+ # sort, nub)
+ #
+ # - new_line: Import list starts always on new line.
+ #
+ # > import qualified Data.List as List
+ # > (concat, foldl, foldr, head, init, last, length)
+ #
+ # - repeat: Repeat the module name to align the import list.
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr, head)
+ # > import qualified Data.List as List (init, last, length)
+ #
+ # Default: after_alias
+ list_align: after_alias
+
+ # Right-pad the module names to align imports in a group:
+ #
+ # - true: a little more readable
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr,
+ # > init, last, length)
+ # > import qualified Data.List.Extra as List (concat, foldl, foldr,
+ # > init, last, length)
+ #
+ # - false: diff-safe
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr, init,
+ # > last, length)
+ # > import qualified Data.List.Extra as List (concat, foldl, foldr,
+ # > init, last, length)
+ #
+ # Default: true
+ pad_module_names: false
+
+ # Long list align style takes effect when import is too long. This is
+ # determined by 'columns' setting.
+ #
+ # - inline: This option will put as much specs on same line as possible.
+ #
+ # - new_line: Import list will start on new line.
+ #
+ # - new_line_multiline: Import list will start on new line when it's
+ # short enough to fit to single line. Otherwise it'll be multiline.
+ #
+ # - multiline: One line per import list entry.
+ # Type with constructor list acts like single import.
+ #
+ # > import qualified Data.Map as M
+ # > ( empty
+ # > , singleton
+ # > , ...
+ # > , delete
+ # > )
+ #
+ # Default: inline
+ long_list_align: new_line_multiline
+
+ # Align empty list (importing instances)
+ #
+ # Empty list align has following options
+ #
+ # - inherit: inherit list_align setting
+ #
+ # - right_after: () is right after the module name:
+ #
+ # > import Vector.Instances ()
+ #
+ # Default: inherit
+ empty_list_align: inherit
+
+ # List padding determines indentation of import list on lines after import.
+ # This option affects 'long_list_align'.
+ #
+ # - <integer>: constant value
+ #
+ # - module_name: align under start of module name.
+ # Useful for 'file' and 'group' align settings.
+ #
+ # Default: 4
+ list_padding: 2
+
+ # Separate lists option affects formatting of import list for type
+ # or class. The only difference is single space between type and list
+ # of constructors, selectors and class functions.
+ #
+ # - true: There is single space between Foldable type and list of its
+ # functions.
+ #
+ # > import Data.Foldable (Foldable (fold, foldl, foldMap))
+ #
+ # - false: There is no space between Foldable type and list of its
+ # functions.
+ #
+ # > import Data.Foldable (Foldable(fold, foldl, foldMap))
+ #
+ # Default: true
+ separate_lists: false
+
+ # Space surround option affects formatting of import lists on a single
+ # line. The only difference is single space after the initial
+ # parenthesis and a single space before the terminal parenthesis.
+ #
+ # - true: There is single space associated with the enclosing
+ # parenthesis.
+ #
+ # > import Data.Foo ( foo )
+ #
+ # - false: There is no space associated with the enclosing parenthesis
+ #
+ # > import Data.Foo (foo)
+ #
+ # Default: false
+ space_surround: false
+
+ # Enabling this argument will use the new GHC lib parse to format imports.
+ #
+ # This currently assumes a few things, it will assume that you want post
+ # qualified imports. It is also not as feature complete as the old
+ # imports formatting.
+ #
+ # It does not remove redundant lines or merge lines. As such, the full
+ # feature scope is still pending.
+ #
+ # It _is_ however, a fine alternative if you are using features that are
+ # not parseable by haskell src extensions and you're comfortable with the
+ # presets.
+ #
+ # Default: false
+ ghc_lib_parser: false
+
+ # Post qualify option moves any `qualified` keywords found in import
+ # declarations to the end of the declaration. This also adjusts padding for
+ # any unqualified import declarations.
+ #
+ # - true: Qualified as <module name> is moved to the end of the
+ # declaration.
+ #
+ # > import Data.Bar
+ # > import Data.Foo qualified as F
+ #
+ # - false: Qualified remains in the default location and unqualified
+ # imports are padded to align with qualified imports.
+ #
+ # > import Data.Bar
+ # > import qualified Data.Foo as F
+ #
+ # Default: false
+ post_qualify: true
+
+ # A list of rules specifying how to group modules and how to
+ # order the groups.
+ #
+ # Each rule has a match field; the rule only applies to module
+ # names matched by this pattern. Patterns are POSIX extended
+ # regular expressions; see the documentation of Text.Regex.TDFA
+ # for details:
+ # https://hackage.haskell.org/package/regex-tdfa-1.3.1.2/docs/Text-Regex-TDFA.html
+ #
+ # Rules are processed in order, so only the *first* rule that
+ # matches a specific module will apply. Any module names that do
+ # not match a single rule will be put into a single group at the
+ # end of the import block.
+ #
+ # Example: group MyApp modules first, with everything else in
+ # one group at the end.
+ #
+ # group_rules:
+ # - match: "^MyApp\\>"
+ #
+ # > import MyApp
+ # > import MyApp.Foo
+ # >
+ # > import Control.Monad
+ # > import MyApps
+ # > import Test.MyApp
+ #
+ # A rule can also optionally have a sub_group pattern. Imports
+ # that match the rule will be broken up into further groups by
+ # the part of the module name matched by the sub_group pattern.
+ #
+ # Example: group MyApp modules first, then everything else
+ # sub-grouped by the first part of the module name.
+ #
+ # group_rules:
+ # - match: "^MyApp\\>"
+ # - match: "."
+ # sub_group: "^[^.]+"
+ #
+ # > import MyApp
+ # > import MyApp.Foo
+ # >
+ # > import Control.Applicative
+ # > import Control.Monad
+ # >
+ # > import Data.Map
+ #
+ # A pattern only needs to match part of the module name, which
+ # could be in the middle. You can use ^pattern to anchor to the
+ # beginning of the module name, pattern$ to anchor to the end
+ # and ^pattern$ to force a full match. Example:
+ #
+ # - "Test\\." would match "Test.Foo" and "Foo.Test.Lib"
+ # - "^Test\\." would match "Test.Foo" but not "Foo.Test.Lib"
+ # - "\\.Test$" would match "Foo.Test" but not "Foo.Test.Lib"
+ # - "^Test$" would *only* match "Test"
+ #
+ # You can use \\< and \\> to anchor against the beginning and
+ # end of words, respectively. For example:
+ #
+ # - "^Test\\." would match "Test.Foo" but not "Test" or "Tests"
+ # - "^Test\\>" would match "Test.Foo" and "Test", but not
+ # "Tests"
+ #
+ # The default is a single rule that matches everything and
+ # sub-groups based on the first component of the module name.
+ #
+ # Default: [{ "match" : ".*", "sub_group": "^[^.]+" }]
+# group_rules:
+# - match: ".*"
+# sub_group: "^[^.]+"
+# - match: "^Data.Array\\>"
+# sub_group: "^[^.]+"
+# - match: "^HordeAd\\>"
+
+ # Language pragmas
+ - language_pragmas:
+ # We can generate different styles of language pragma lists.
+ #
+ # - vertical: Vertical-spaced language pragmas, one per line.
+ #
+ # - compact: A more compact style.
+ #
+ # - compact_line: Similar to compact, but wrap each line with
+ # `{-#LANGUAGE #-}'.
+ #
+ # Default: vertical.
+# style: compact
+
+ # Align affects alignment of closing pragma brackets.
+ #
+ # - true: Brackets are aligned in same column.
+ #
+ # - false: Brackets are not aligned together. There is only one space
+ # between actual import and closing bracket.
+ #
+ # Default: true
+ align: false
+
+ # stylish-haskell can detect redundancy of some language pragmas. If this
+ # is set to true, it will remove those redundant pragmas. Default: true.
+ remove_redundant: true
+
+ # Language prefix to be used for pragma declaration, this allows you to
+ # use other options non case-sensitive like "language" or "Language".
+ # If a non correct String is provided, it will default to: LANGUAGE.
+ language_prefix: LANGUAGE
+
+ # Replace tabs by spaces. This is disabled by default.
+ # - tabs:
+ # # Number of spaces to use for each tab. Default: 8, as specified by the
+ # # Haskell report.
+ # spaces: 8
+
+ # Remove trailing whitespace
+ - trailing_whitespace: {}
+
+ # Squash multiple spaces between the left and right hand sides of some
+ # elements into single spaces. Basically, this undoes the effect of
+ # simple_align but is a bit less conservative.
+ # - squash: {}
+
+# A common setting is the number of columns (parts of) code will be wrapped
+# to. Different steps take this into account.
+#
+# Set this to null to disable all line wrapping.
+#
+# Default: 80.
+columns: 200
+
+# By default, line endings are converted according to the OS. You can override
+# preferred format here.
+#
+# - native: Native newline format. CRLF on Windows, LF on other OSes.
+#
+# - lf: Convert to LF ("\n").
+#
+# - crlf: Convert to CRLF ("\r\n").
+#
+# Default: native.
+newline: native
+
+# Sometimes, language extensions are specified in a cabal file or from the
+# command line instead of using language pragmas in the file. stylish-haskell
+# needs to be aware of these, so it can parse the file correctly.
+#
+# No language extensions are enabled by default.
+#language_extensions:
+# - NoStarIsType
+ # - TemplateHaskell
+ # - QuasiQuotes
+
+# Attempt to find the cabal file in ancestors of the current directory, and
+# parse options (currently only language extensions) from that.
+#
+# Default: true
+cabal: true
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..009d267
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,7 @@
+# Changelog for `ox-arrays`
+
+This package intends to follow the [PVP](https://pvp.haskell.org/).
+
+## 0.1.0.0
+- Initial release
+- Various aspects of the API are still experimental, and breaking changes are expected in the future.
diff --git a/README.md b/README.md
index 9b8d543..01bcbac 100644
--- a/README.md
+++ b/README.md
@@ -1,56 +1,165 @@
-Wrapper library around `orthotope` that defines nested arrays, including
-tuples, of (eventually) unboxed values. The arrays are represented in
-struct-of-arrays form via the `Data.Vector.Unboxed` data family trick. Below
-the surface layer, there is a more low-level wrapper around `orthotope` that
-defines an array type type-indexed by `[Maybe Nat]`: some dimensions are
-shape-typed (i.e. have their size statically known), and some not.
+## ox-arrays
-An overview of the API:
+ox-arrays is an array library that defines nested arrays, including tuples, of
+(eventually) unboxed values. The arrays are represented in struct-of-arrays
+form via the `Data.Vector.Unboxed` data family trick; the component arrays are
+`orthotope` arrays
+([RankedS](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-RankedS.html))
+which describe elements using a _stride vector_ or
+[LMAD](https://dl.acm.org/doi/pdf/10.1145/509705.509708) so that `transpose`
+and `replicate` need only modify array metadata, not actually move around data.
+
+Because of the struct-of-arrays representation, nested arrays are not fully
+general: indeed, arrays are not actually nested under the hood, so if one has an
+array of arrays, those element arrays must all have the same shape (length,
+width, etc.). If one has an array of tuples of arrays, then all the `fst`
+components must have the same shape and all the `snd` components must have the
+same shape, but the two pair components themselves can be different.
+
+However, the nesting functionality of ox-arrays can be completely ignored if you
+only care about other parts of its API, or the vectorised arithmetic operations
+(using hand-written C code). Nesting support mostly does not get in the way, and
+has essentially no overhead (both when it's used and when it's not used).
+
+ox-arrays defines three array types: `Ranked`, `Shaped` and `Mixed`.
+- `Ranked` corresponds to `orthotope`'s
+ [RankedS](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-RankedS.html)
+ and has the _rank_ of the array (its number of dimensions) on the type level.
+ For example, `Ranked 2 Float` is a two-dimensional array of `Float`s, i.e. a
+ matrix.
+- `Shaped` corresponds to `orthotope`'s
+ [ShapedS](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-ShapedS.html)
+ and has the full _shape_ of the array (its dimension sizes) on the type level
+ as a type-level list of `Nat`s. For example, `Shaped [2,3] Float` is a 2-by-3
+ matrix. The innermost dimension corresponds to the right-most element in the
+ list.
+- `Mixed` is halfway between the two: it has a type parameter of kind
+ `[Maybe Nat]` whose length is the rank of the array; `Nothing` elements have
+ unknown size, whereas `Just` elements have the indicated size. The type
+ `Mixed [Nothing, Nothing] a` is equivalent to `Ranked 2 a`; the type
+ `Mixed [Just n, Just m] a` is equivalent to `Shaped [n, m] a`.
+
+In various places in the API of a library like ox-arrays, one can make a
+decision between 1. requiring a type class constraint providing certain
+information (e.g.
+[KnownNat](https://hackage.haskell.org/package/base-4.21.0.0/docs/GHC-TypeLits.html#t:KnownNat)
+or `orthotope`'s
+[Shape](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-ShapedS.html#t:Shape)),
+or 2. taking singleton _values_ that encode said information in a way that is
+linked to the type level (e.g.
+[SNat](https://hackage.haskell.org/package/base-4.21.0.0/docs/GHC-TypeLits.html#t:SNat)).
+`orthotope` chooses the type class approach; ox-arrays chooses the singleton
+approach. Singletons are more verbose at times, but give the programmer more
+insight into what data is flowing where, and more importantly, more control: type
+class inference is very nice and implicit, but if it's not powerful enough for
+the trickery you're doing, you're out of luck. Singletons allow you to explain
+as precisely as you want to GHC what exactly you're doing.
+
+Below the surface layer, there is a more low-level wrapper (`XArray`) around
+`orthotope` that defines a non-nested `Mixed`-style array type.
+
+Here is a little taster of the API, to get a sense for the design:
```haskell
-data Ranked (n :: INat) a {- e.g. -} Ranked 3 Float
-data Shaped (sh :: '[Nat]) a {- e.g. -} Shaped [2,3,4] Float
-data Mixed (xsh :: '[Maybe Nat]) a {- e.g. -} Mixed [Just 2, Nothing, Just 4] Float
-
-Ranked I0 a = Ranked Z a ~~= Acc.Array Z a = Acc.Scalar a
-Ranked I1 a = Ranked (S Z) a ~~= Acc.Array (Z :. Int) a = Acc.Vector a
-Ranked I2 a = Ranked (S (S Z)) a ~~= Acc.Array (Z :. Int :. Int) a = Acc.Matrix a
+import GHC.TypeLits (Nat, SNat)
+
+data Ranked (n :: Nat) a {- e.g. -} Ranked 3 Float
+data Shaped (sh :: '[Nat]) a {- e.g. -} Shaped [2,3,4] Float
+data Mixed (xsh :: '[Maybe Nat]) a {- e.g. -} Mixed [Just 2, Nothing, Just 4] Float
+-- Shape types are written Sh{R,S,X}. The 'I' prefix denotes an Int-filled shape;
+-- ShR and ShX are more general containers. ShS is a singleton.
+rshape :: Elt a => Ranked n a -> IShR n
+sshape :: Elt a => Shaped sh a -> ShS sh
+mshape :: Elt a => Mixed xsh a -> IShX xsh
-rshape :: (Elt a, KnownINat n) => Ranked n a -> IxR n
-sshape :: (Elt a, KnownShape sh) => Shaped sh a -> IxS sh
-mshape :: (Elt a, KnownShapeX xsh) => Mixed xsh a -> IxX xsh
+-- Index types are written Ix{R,S,X}.
+rindex :: Elt a => Ranked n a -> IIxR n -> a
+sindex :: Elt a => Shaped sh a -> IIxS sh -> a
+mindex :: Elt a => Mixed xsh a -> IIxX xsh -> a
-rindex :: Elt a => Ranked n a -> IxR n -> a
-sindex :: Elt a => Shaped sh a -> IxS sh -> a
-mindex :: Elt a => Mixed xsh a -> IxX xsh -> a
+-- The index types can be used as if they were defined as follows; pattern
+-- synonyms are provided to construct the illusion. (The actual definitions are
+-- a bit more general and indirect.)
+data IIxR n where
+ ZIR :: IIxR 0
+ (:.:) :: Int -> IIxR n -> IIxR (n + 1)
-data IxR n where
- IZR :: IxR Z
- (:::) :: Int -> IxR n -> IxR (S n)
+data IIxS sh where
+ ZIS :: IIxS '[]
+ (:.$) :: Int -> IIxS sh -> IIxS (n : sh)
-data IxS sh where
- IZS :: IxS '[]
- (::$) :: Int -> IxS sh -> IxS (n : sh)
+data IIxX xsh where
+ ZIX :: IIxX '[]
+ (:.%) :: Int -> IIxX xsh -> IIxX (mn : xsh)
-data IxX sh where
- IZX :: IxX '[]
- (::@) :: Int -> IxX sh -> IxX (Just n : sh)
- (::?) :: Int -> IxX sh -> IxX (Nothing : sh)
+-- Similarly, the shape types can be used as if they were defined as follows.
+data IShR n where
+ ZSR :: IShR 0
+ (:$:) :: Int -> IShR n -> IShR (n + 1)
+data ShS sh where
+ ZSS :: ShS '[]
+ (:$$) :: SNat n -> ShS sh -> ShS (n : sh)
+
+data IShX xsh where
+ ZSX :: IShX '[]
+ (:$%) :: SMayNat Int SNat mn -> IShX xsh -> IShX (mn : xsh)
+-- where:
+data SMayNat i f n where
+ SUnknown :: i -> SMayNat i f Nothing
+ SKnown :: f n -> SMayNat i f (Just n)
+
+-- Occasionally one needs a singleton for only the _known_ dimensions of a mixed
+-- shape -- that is to say, only the statically-known part of a mixed shape.
+-- StaticShX provides for this need. It can be used as if defined as follows:
+data StaticShX xsh where
+ ZKX :: StaticShX '[]
+ (:!%) :: SMayNat () SNat mn -> StaticShX xsh -> StaticShX (mn : xsh)
+
+-- The Elt class describes types that can be used as elements of an array. While
+-- it is technically possible to define new instances of this class, typical
+-- usage should regard Elt as closed. The user-relevant instances are the
+-- following:
class Elt a
-instance Elt ()
-instance Elt Double
-instance Elt Int
-instance (Elt a, Elt b) => Elt (a, b)
-instance (Elt a, KnownINat n) => Elt (Ranked n a)
-instance (Elt a, KnownShape sh) => Elt (Shaped sh a)
-instance (Elt a, KnownShapeX xsh) => Elt (Mixed xsh a)
-
-rgenerate :: Elt a => IxR n -> (IxR n -> a) -> Ranked n a
-sgenerate :: (Elt a, KnownShape sh) => (IxS sh -> a) -> Shaped sh a
-mgenerate :: (Elt a, KnownShapeX xsh) => IxX xsh -> (IxX xsh -> a) -> Mixed xsh a
+instance Elt ()
+instance Elt Bool
+instance Elt Float
+instance Elt Double
+instance Elt Int
+instance (Elt a, Elt b) => Elt (a, b)
+instance Elt a => Elt (Ranked n a)
+instance Elt a => Elt (Shaped sh a)
+instance Elt a => Elt (Mixed xsh a)
+-- Essentially all functions that ox-arrays offers on arrays are first-order:
+-- add two arrays elementwise, transpose an array, append arrays, compute
+-- minima/maxima, zip/unzip, nest/unnest, etc. The first-order approach allows
+-- operations, especially arithmetic ones, to be vectorised using hand-written
+-- C code, without needing any sort of JIT compilation.
+rappend :: Elt a => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
+sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
+mappend :: Elt a => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
+
+-- As an exception, one higher-order function is also provided per array type:
+-- 'generate'. These functions have the caveat that regularity of arrays must be
+-- preserved: all returned 'a's must have equal shape. See the documentation of
+-- 'mgenerate'.
+-- Warning: because the invocations of the function you pass cannot be
+-- vectorised, 'generate' is rather slow if 'a' is small.
+-- The 'KnownElt' class captures an API infelicity where constraint-based shape
+-- passing is the only practical option.
+rgenerate :: KnownElt a => IShR n -> (IxR n -> a) -> Ranked n a
+sgenerate :: KnownElt a => ShS sh -> (IxS sh -> a) -> Shaped sh a
+mgenerate :: KnownElt a => IShX xsh -> (IxX xsh -> a) -> Mixed xsh a
+
+-- Under the hood, Ranked and Shaped are both newtypes over Mixed. Mixed itself
+-- is a data family over XArray, which is a newtype over orthotope's RankedS.
newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
```
+
+About the name: when importing `orthotope` array modules, a possible naming
+convention is to use qualified imports as `OR` for "orthotope ranked" arrays and
+`OS` for "orthotope shaped" arrays. ox-arrays was started to fill the `OX` gap,
+then grew out of proportion.
diff --git a/bench/Main.hs b/bench/Main.hs
new file mode 100644
index 0000000..b604eb9
--- /dev/null
+++ b/bench/Main.hs
@@ -0,0 +1,244 @@
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE NumericUnderscores #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE ViewPatterns #-}
+module Main where
+
+import Control.Exception (bracket)
+import Control.Monad (when)
+import Data.Array.Internal qualified as OI
+import Data.Array.Internal.RankedG qualified as RG
+import Data.Array.Internal.RankedS qualified as RS
+import Data.Foldable (toList)
+import Data.Vector.Storable qualified as VS
+import Numeric.LinearAlgebra qualified as LA
+import Test.Tasty.Bench
+import Text.Show (showListWith)
+
+import Data.Array.Nested
+import Data.Array.Nested.Mixed (Mixed(M_Primitive), mliftPrim, mliftPrim2, toPrimitive)
+import Data.Array.Nested.Ranked (liftRanked1, liftRanked2)
+import Data.Array.Strided.Arith.Internal qualified as Arith
+import Data.Array.XArray (XArray(..))
+
+
+-- | Switch for the "misc" benchmark group: 'main_tests' passes it to
+-- 'bgroupIf', so while it is False that group is registered with no
+-- benchmarks in it.
+enableMisc :: Bool
+enableMisc = False
+
+-- | Like 'bgroup', but only keeps the given benchmarks when the flag is True.
+-- When the flag is False the group is still created under the same name, just
+-- with an empty benchmark list.
+bgroupIf :: Bool -> String -> [Benchmark] -> Benchmark
+bgroupIf enabled name benches
+  | enabled   = bgroup name benches
+  | otherwise = bgroup name []
+
+
+-- | Benchmark entry point. Runs 'main_tests' inside 'bracket', so the release
+-- action below is executed even if the benchmark run throws.
+main :: IO ()
+main = do
+ -- Hard-coded switch: passed to Arith.statisticsEnable before the run and
+ -- also decides whether Arith.statisticsPrintAll is called afterwards.
+ let enable = False
+ bracket (Arith.statisticsEnable enable)
+ -- Release: first call statisticsEnable False, then print only if enabled.
+ (\() -> do Arith.statisticsEnable False
+ when enable Arith.statisticsPrintAll)
+ (\() -> main_tests)
+
+-- | The full benchmark tree handed to 'defaultMain'. It has four top-level groups:
+--
+-- * "compare": the benchmarks from 'tests_compare';
+-- * "dotprod": 'rdot1Inner' followed by 'rsumAllPrim' on pairs of Double
+-- arrays built from 'riota' via different reshape / replicate / transpose /
+-- reverse combinations;
+-- * "orthotope": RS.normalize on a reversed and on a doubly-reversed iota;
+-- * "misc": vector concatenation benchmarks, only populated when
+-- 'enableMisc' is True (see 'bgroupIf').
+main_tests :: IO ()
+main_tests = defaultMain
+ [bgroup "compare" tests_compare
+ ,bgroup "dotprod" $
+ -- Pattern-matches through the array wrappers to reach the stride list
+ -- stored in the underlying OI.T value; used only for the benchmark name.
+ let stridesOf (Ranked (toPrimitive -> M_Primitive _ (XArray (RS.A (RG.A _ (OI.T strides _ _)))))) = strides
+ dotprodBench name (inp1, inp2) =
+ -- Shows a list of numbers, writing exact powers of ten greater than 1
+ -- as "1e<k>" and every other number with 'shows'.
+ let showSh l = showListWith (\n -> let ln = round (logBase 10 (fromIntegral n :: Double)) :: Int
+ in if n > 1 && n == 10 ^ ln then showString ("1e" ++ show ln) else shows n)
+ l ""
+ -- The benchmark name carries the shape of the first input and the
+ -- strides of both inputs.
+ in bench (name ++ " " ++ showSh (toList (rshape inp1)) ++
+ " str " ++ showSh (stridesOf inp1) ++ " " ++ showSh (stridesOf inp2)) $
+ nf (\(a,b) -> rsumAllPrim (rdot1Inner a b)) (inp1, inp2)
+
+ iota = riota @Double
+ in
+ [dotprodBench "dot 1D"
+ (iota 10_000_000
+ ,iota 10_000_000)
+ ,dotprodBench "revdot"
+ (rrev1 (iota 10_000_000)
+ ,rrev1 (iota 10_000_000))
+ ,dotprodBench "dot 2D"
+ (rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000)
+ ,rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000))
+ ,dotprodBench "batched dot"
+ (rreplicate (1000 :$: ZSR) (iota 10_000)
+ ,rreplicate (1000 :$: ZSR) (iota 10_000))
+ ,dotprodBench "transposed dot" $
+ let (a, b) = (rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000)
+ ,rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000))
+ in (rtranspose [1,0] a, rtranspose [1,0] b)
+ ,dotprodBench "repdot" $
+ let (a, b) = (rreplicate (1000 :$: ZSR) (iota 10_000)
+ ,rreplicate (1000 :$: ZSR) (iota 10_000))
+ in (rtranspose [1,0] a, rtranspose [1,0] b)
+ ,dotprodBench "matvec" $
+ let (m, v) = (rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000)
+ ,iota 10_000)
+ in (m, rreplicate (1000 :$: ZSR) v)
+ ,dotprodBench "vecmat" $
+ let (v, m) = (iota 1_000
+ ,rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000))
+ in (rreplicate (10_000 :$: ZSR) v, rtranspose [1,0] m)
+ ,dotprodBench "matmat" $
+ let (n,m,k) = (100, 100, 1000)
+ (m1, m2) = (rreshape (n :$: m :$: ZSR) (iota (n*m))
+ ,rreshape (m :$: k :$: ZSR) (iota (m*k)))
+ in (rtranspose [1,0] (rreplicate (k :$: ZSR) m1)
+ ,rreplicate (n :$: ZSR) (rtranspose [1,0] m2))
+ ,dotprodBench "matmatT" $
+ let (n,m,k) = (100, 100, 1000)
+ (m1, m2) = (rreshape (n :$: m :$: ZSR) (iota (n*m))
+ ,rreshape (k :$: m :$: ZSR) (iota (m*k)))
+ in (rtranspose [1,0] (rreplicate (k :$: ZSR) m1)
+ ,rreplicate (n :$: ZSR) m2)
+ ]
+ ,bgroup "orthotope"
+ [bench "normalize [1e6]" $
+ let n = 1_000_000
+ in nf (\a -> RS.normalize a)
+ (RS.rev [0] (RS.iota @Double n))
+ ,bench "normalize noop [1e6]" $
+ let n = 1_000_000
+ in nf (\a -> RS.normalize a)
+ (RS.rev [0] (RS.rev [0] (RS.iota @Double n)))
+ ]
+ -- Everything below is registered only when enableMisc is True.
+ ,bgroupIf enableMisc "misc"
+ [let n = 1000
+ k = 1000
+ in bgroup ("fusion [" ++ show k ++ "]*" ++ show n)
+ [bench "sum (concat)" $
+ nf (\as -> VS.sum (VS.concat as))
+ (replicate n (VS.enumFromTo (1::Int) k))
+ ,bench "sum (force (concat))" $
+ nf (\as -> VS.sum (VS.force (VS.concat as)))
+ (replicate n (VS.enumFromTo (1::Int) k))]
+ ,bgroup "concat"
+ -- "N": fixed chunk length 500, number of chunks 10^ni for ni in 1..5.
+ [bgroup "N"
+ [bgroup "hmatrix"
+ [bench ("LA.vjoin [500]*1e" ++ show ni) $
+ let n = 10 ^ ni
+ k = 500
+ in nf (\as -> LA.vjoin as)
+ (replicate n (VS.enumFromTo (1::Int) k))
+ | ni <- [1::Int ..5]]
+ ,bgroup "vectorStorable"
+ [bench ("VS.concat [500]*1e" ++ show ni) $
+ let n = 10 ^ ni
+ k = 500
+ in nf (\as -> VS.concat as)
+ (replicate n (VS.enumFromTo (1::Int) k))
+ | ni <- [1::Int ..5]]
+ ]
+ -- "K": fixed chunk count 500, chunk length 10^ki for ki in 1..5.
+ ,bgroup "K"
+ [bgroup "hmatrix"
+ [bench ("LA.vjoin [1e" ++ show ki ++ "]*500") $
+ let n = 500
+ k = 10 ^ ki
+ in nf (\as -> LA.vjoin as)
+ (replicate n (VS.enumFromTo (1::Int) k))
+ | ki <- [1::Int ..5]]
+ ,bgroup "vectorStorable"
+ [bench ("VS.concat [1e" ++ show ki ++ "]*500") $
+ let n = 500
+ k = 10 ^ ki
+ in nf (\as -> VS.concat as)
+ (replicate n (VS.enumFromTo (1::Int) k))
+ | ki <- [1::Int ..5]]
+ ]
+ ]
+ ]
+ ]
+
+-- | Benchmarks that run the same reductions over one-million-element arrays
+-- through three different code paths, so their timings can be compared:
+--
+-- * "Num": the scalar operation is lifted over the array with
+--   'liftRanked1' / 'liftRanked2' and 'mliftPrim' / 'mliftPrim2';
+-- * "NumElt": the operation is applied to the arrays directly (@a + b@,
+--   @sin a@, 'rdot', ...);
+-- * "hmatrix": the corresponding operations on hmatrix vectors.
+tests_compare :: [Benchmark]
+tests_compare =
+  let n = 1_000_000 in
+  [bgroup "Num"
+    [bench "sum(+) Double [1e6]" $
+      nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (+)) a b)))
+         (riota @Double n, riota n)
+    ,bench "sum(*) Double [1e6]" $
+      nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (*)) a b)))
+         (riota @Double n, riota n)
+    ,bench "sum(/) Double [1e6]" $
+      nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (/)) a b)))
+         (riota @Double n, riota n)
+    ,bench "sum(**) Double [1e6]" $
+      nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (**)) a b)))
+         (riota @Double n, riota n)
+    ,bench "sum(sin) Double [1e6]" $
+      nf (\a -> runScalar (rsumOuter1 (liftRanked1 (mliftPrim sin) a)))
+         (riota @Double n)
+    ,bench "sum Double [1e6]" $
+      nf (\a -> runScalar (rsumOuter1 a))
+         (riota @Double n)
+    ]
+  ,bgroup "NumElt"
+    [bench "sum(+) Double [1e6]" $
+      nf (\(a, b) -> runScalar (rsumOuter1 (a + b)))
+         (riota @Double n, riota n)
+    ,bench "sum(*) Double [1e6]" $
+      nf (\(a, b) -> runScalar (rsumOuter1 (a * b)))
+         (riota @Double n, riota n)
+    ,bench "sum(/) Double [1e6]" $
+      nf (\(a, b) -> runScalar (rsumOuter1 (a / b)))
+         (riota @Double n, riota n)
+    ,bench "sum(**) Double [1e6]" $
+      nf (\(a, b) -> runScalar (rsumOuter1 (a ** b)))
+         (riota @Double n, riota n)
+    ,bench "sum(sin) Double [1e6]" $
+      nf (\a -> runScalar (rsumOuter1 (sin a)))
+         (riota @Double n)
+    ,bench "sum Double [1e6]" $
+      nf (\a -> runScalar (rsumOuter1 a))
+         (riota @Double n)
+    -- The second operand is reversed with rrev1, as the "stride 1; -1" in
+    -- the benchmark names indicates.
+    ,bench "sum(*) Double [1e6] stride 1; -1" $
+      nf (\(a, b) -> runScalar (rsumOuter1 (a * b)))
+         (riota @Double n, rrev1 (riota n))
+    ,bench "dotprod Float [1e6]" $
+      nf (\(a, b) -> rdot a b)
+         (riota @Float n, riota @Float n)
+    ,bench "dotprod Float [1e6] stride 1; -1" $
+      nf (\(a, b) -> rdot a b)
+         (riota @Float n, rrev1 (riota @Float n))
+    ,bench "dotprod Double [1e6]" $
+      nf (\(a, b) -> rdot a b)
+         (riota @Double n, riota @Double n)
+    ,bench "dotprod Double [1e6] stride 1; -1" $
+      nf (\(a, b) -> rdot a b)
+         (riota @Double n, rrev1 (riota @Double n))
+    ]
+  ,bgroup "hmatrix"
+    [bench "sum(+) Double [1e6]" $
+      nf (\(a, b) -> LA.sumElements (a + b))
+         (LA.linspace @Double n (0.0, fromIntegral (n - 1))
+         ,LA.linspace @Double n (0.0, fromIntegral (n - 1)))
+    ,bench "sum(*) Double [1e6]" $
+      nf (\(a, b) -> LA.sumElements (a * b))
+         (LA.linspace @Double n (0.0, fromIntegral (n - 1))
+         ,LA.linspace @Double n (0.0, fromIntegral (n - 1)))
+    ,bench "sum(/) Double [1e6]" $
+      nf (\(a, b) -> LA.sumElements (a / b))
+         (LA.linspace @Double n (0.0, fromIntegral (n - 1))
+         ,LA.linspace @Double n (0.0, fromIntegral (n - 1)))
+    ,bench "sum(**) Double [1e6]" $
+      nf (\(a, b) -> LA.sumElements (a ** b))
+         (LA.linspace @Double n (0.0, fromIntegral (n - 1))
+         ,LA.linspace @Double n (0.0, fromIntegral (n - 1)))
+    ,bench "sum(sin) Double [1e6]" $
+      nf (\a -> LA.sumElements (sin a))
+         (LA.linspace @Double n (0.0, fromIntegral (n - 1)))
+    ,bench "sum Double [1e6]" $
+      nf (\a -> LA.sumElements a)
+         (LA.linspace @Double n (0.0, fromIntegral (n - 1)))
+    -- The inputs must really be Float vectors here: with @Double this
+    -- benchmark would just repeat "dotprod Double [1e6]" below under a
+    -- misleading name.
+    ,bench "dotprod Float [1e6]" $
+      nf (\(a, b) -> a LA.<.> b)
+         (LA.linspace @Float n (0.0, fromIntegral (n - 1))
+         ,LA.linspace @Float n (fromIntegral (n - 1), 0.0))
+    ,bench "dotprod Double [1e6]" $
+      nf (\(a, b) -> a LA.<.> b)
+         (LA.linspace @Double n (0.0, fromIntegral (n - 1))
+         ,LA.linspace @Double n (fromIntegral (n - 1), 0.0))
+    ]
+  ]
diff --git a/cabal.project b/cabal.project
index 697d3bd..d102ed6 100644
--- a/cabal.project
+++ b/cabal.project
@@ -1,5 +1,2 @@
packages: .
-with-compiler: ghc-9.8.2
-
-allow-newer:
- orthotope:deepseq
+with-compiler: ghc-9.8.4
diff --git a/cbits/arith.c b/cbits/arith.c
new file mode 100644
index 0000000..f19b01e
--- /dev/null
+++ b/cbits/arith.c
@@ -0,0 +1,808 @@
+#include <stdio.h>
+#include <stdint.h>
+#include <inttypes.h>
+#include <stdlib.h>
+#include <stdbool.h>
+#include <stdatomic.h>
+#include <string.h>
+#include <math.h>
+#include <threads.h>
+#include <sys/time.h>
+
+// These are the wrapper macros used in arith_lists.h. Preset them to empty to
+// avoid having to touch macros unrelated to the particular operation set below.
+#define LIST_BINOP(name, id, hsop)
+#define LIST_IBINOP(name, id, hsop)
+#define LIST_FBINOP(name, id, hsop)
+#define LIST_UNOP(name, id, _)
+#define LIST_FUNOP(name, id, _)
+#define LIST_REDOP(name, id, _)
+
+
+// Shorter names, due to CPP used both in function names and in C types.
+typedef int32_t i32;
+typedef int64_t i64;
+
+
+// PRECONDITIONS
+//
+// All strided array operations in this file assume that none of the shape
+// components are zero -- that is, the input arrays are non-empty. This must
+// be arranged on the Haskell side.
+//
+// Furthermore, note that while the Haskell side has an offset into the backing
+// vector, the C side assumes that the offset is zero. Shift the pointer if
+// necessary.
+
+
+/*****************************************************************************
+ * Performance statistics *
+ *****************************************************************************/
+
+// Each block holds a buffer with variable-length messages. Each message starts
+// with a tag byte; the respective sublists below give the fields after that tag
+// byte.
+// - 1: unary operation performance measurement
+// - u8: some identifier
+// - i32: input rank
+// - i64[rank]: input shape
+// - i64[rank]: input strides
+// - f64: seconds taken
+// - 2: binary operation performance measurement
+// - u8: a stats_binary_id
+// - i32: input rank
+// - i64[rank]: input shape
+// - i64[rank]: input 1 strides
+// - i64[rank]: input 2 strides
+// - f64: seconds taken
+// The 'prev' and 'cap' fields are set only once on creation of a block, and can
+// thus be read without restrictions. The 'len' field is potentially mutated
+// from different threads and must be handled with care.
+struct stats_block {
+ struct stats_block *prev; // backwards linked list; NULL if first block
+ size_t cap; // bytes capacity of buffer in this block
+ atomic_size_t len; // bytes filled in this buffer
+ uint8_t buf[]; // trailing VLA
+};
+
+enum stats_binary_id {
+ sbi_dotprod = 1,
+};
+
+// Atomic because blocks may be allocated from different threads.
+static _Atomic(struct stats_block*) stats_current = NULL;
+static atomic_bool stats_enabled = false;
+
+void oxarrays_stats_enable(i32 yes) { atomic_store(&stats_enabled, yes == 1); }
+
+// Reserves 'nbytes' contiguous bytes in the statistics log and returns a
+// pointer to them. May be called concurrently: a new block is published with a
+// compare-and-swap on 'stats_current', and space inside a block is claimed
+// with a compare-and-swap on its 'len' field. Aborts the process if memory
+// for a new block cannot be obtained.
+static uint8_t* stats_alloc(size_t nbytes) {
+try_again: ;
+  struct stats_block *block = atomic_load(&stats_current);
+  size_t curlen = block != NULL ? atomic_load(&block->len) : 0;
+  size_t curcap = block != NULL ? block->cap : 0;
+
+  if (block == NULL || curlen + nbytes > curcap) {
+    // Size and link the new block from the snapshot taken above, not from a
+    // fresh read of 'stats_current': the global may have been replaced
+    // (possibly by NULL) by another thread in the meantime.
+    size_t newcap = block == NULL ? 4096 : 2 * curcap;
+    // A single message must always fit in a fresh block.
+    while (newcap < nbytes) newcap *= 2;
+    struct stats_block *new = malloc(sizeof(struct stats_block) + newcap);
+    if (new == NULL) {
+      fprintf(stderr, "ox-arrays: Out of memory while allocating a statistics block\n");
+      abort();
+    }
+    new->prev = block;
+    new->cap = newcap;
+    atomic_init(&new->len, 0);
+    // Publish the block only if nobody replaced the one it was based on.
+    if (!atomic_compare_exchange_strong(&stats_current, &block, new)) {
+      // Race condition, simply free this memory block and try again
+      free(new);
+      goto try_again;
+    }
+    block = new;
+    curcap = newcap;
+    curlen = 0;
+  }
+
+  // Try to update the 'len' field of the block we captured at the start of the
+  // function. Note that it doesn't matter if someone else already allocated a
+  // new block in the meantime; we're still accessing the same block here, which
+  // may succeed or fail independently.
+  while (!atomic_compare_exchange_strong(&block->len, &curlen, curlen + nbytes)) {
+    // curlen was updated to the actual value.
+    // If the block got full in the meantime, try again from the start
+    if (curlen + nbytes > curcap) goto try_again;
+  }
+
+  return block->buf + curlen;
+}
+
+// Appends a unary-operation measurement (tag byte 1) to the statistics log.
+// Does nothing unless statistics are enabled.
+// Layout after the tag byte: u8 id, i32 rank, i64[rank] shape,
+// i64[rank] strides, f64 seconds.
+// The log is byte-packed, so the fields are generally not aligned for their
+// type; they are therefore written with memcpy() instead of through casted
+// pointers.
+__attribute__((unused))
+static void stats_record_unary(enum stats_binary_id id, i32 rank, const i64 *shape, const i64 *strides, double secs) {
+  if (!atomic_load(&stats_enabled)) return;
+  const size_t dimbytes = (size_t)rank * 8;
+  uint8_t *buf = stats_alloc(1 + 1 + 4 + 2*dimbytes + 8);
+  *buf = 1; buf += 1;
+  *buf = id; buf += 1;
+  memcpy(buf, &rank, 4); buf += 4;
+  memcpy(buf, shape, dimbytes); buf += dimbytes;
+  memcpy(buf, strides, dimbytes); buf += dimbytes;
+  memcpy(buf, &secs, 8);
+}
+
+// Appends a binary-operation measurement (tag byte 2) to the statistics log.
+// Does nothing unless statistics are enabled.
+// Layout after the tag byte: u8 id, i32 rank, i64[rank] shape,
+// i64[rank] strides of input 1, i64[rank] strides of input 2, f64 seconds.
+// The log is byte-packed, so the fields are generally not aligned for their
+// type; they are therefore written with memcpy() instead of through casted
+// pointers.
+__attribute__((unused))
+static void stats_record_binary(enum stats_binary_id id, i32 rank, const i64 *shape, const i64 *strides1, const i64 *strides2, double secs) {
+  if (!atomic_load(&stats_enabled)) return;
+  const size_t dimbytes = (size_t)rank * 8;
+  uint8_t *buf = stats_alloc(1 + 1 + 4 + 3*dimbytes + 8);
+  *buf = 2; buf += 1;
+  *buf = id; buf += 1;
+  memcpy(buf, &rank, 4); buf += 4;
+  memcpy(buf, shape, dimbytes); buf += dimbytes;
+  memcpy(buf, strides1, dimbytes); buf += dimbytes;
+  memcpy(buf, strides2, dimbytes); buf += dimbytes;
+  memcpy(buf, &secs, 8);
+}
+
+#define TIME_START(varname_) \
+ struct timeval varname_ ## _start, varname_ ## _end; \
+ gettimeofday(&varname_ ## _start, NULL);
+#define TIME_END(varname_) \
+ (gettimeofday(&varname_ ## _end, NULL), \
+ ((varname_ ## _end).tv_sec - (varname_ ## _start).tv_sec) + \
+ ((varname_ ## _end).tv_usec - (varname_ ## _start).tv_usec) / (double)1e6)
+
+// Reads the byte-packed i64 at position 'i' of the list starting at 'p'.
+// Log entries are not aligned, hence memcpy() rather than a pointer cast.
+static i64 stats_read_i64(const uint8_t *p, i32 i) {
+  i64 value;
+  memcpy(&value, p + 8 * (size_t)i, 8);
+  return value;
+}
+
+// Prints 'count' byte-packed i64 values starting at 'p', separated by commas.
+static void stats_print_i64_list(const uint8_t *p, i32 count) {
+  for (i32 i = 0; i < count; i++) {
+    if (i > 0) putchar(',');
+    printf("%" PRIi64, stats_read_i64(p, i));
+  }
+}
+
+// Prints one unary-operation entry. 'buf' points just past the tag byte.
+// Returns the number of bytes consumed (not counting the tag byte).
+static size_t stats_print_unary(uint8_t *buf) {
+  uint8_t *orig_buf = buf;
+
+  enum stats_binary_id id = *buf; buf += 1;
+  i32 rank; memcpy(&rank, buf, 4); buf += 4;
+  const uint8_t *shape = buf; buf += (size_t)rank * 8;
+  const uint8_t *strides = buf; buf += (size_t)rank * 8;
+  double secs; memcpy(&secs, buf, 8); buf += 8;
+
+  i64 shsize = 1; for (i32 i = 0; i < rank; i++) shsize *= stats_read_i64(shape, i);
+
+  printf("unary %d sz %" PRIi64 " ms %.3lf sh=[", (int)id, shsize, secs * 1000);
+  stats_print_i64_list(shape, rank);
+  printf("] str=[");
+  stats_print_i64_list(strides, rank);
+  printf("]\n");
+
+  return buf - orig_buf;
+}
+
+// Prints one binary-operation entry. 'buf' points just past the tag byte.
+// Returns the number of bytes consumed (not counting the tag byte).
+static size_t stats_print_binary(uint8_t *buf) {
+  uint8_t *orig_buf = buf;
+
+  enum stats_binary_id id = *buf; buf += 1;
+  i32 rank; memcpy(&rank, buf, 4); buf += 4;
+  const uint8_t *shape = buf; buf += (size_t)rank * 8;
+  const uint8_t *strides1 = buf; buf += (size_t)rank * 8;
+  const uint8_t *strides2 = buf; buf += (size_t)rank * 8;
+  double secs; memcpy(&secs, buf, 8); buf += 8;
+
+  i64 shsize = 1; for (i32 i = 0; i < rank; i++) shsize *= stats_read_i64(shape, i);
+
+  printf("binary %d sz %" PRIi64 " ms %.3lf sh=[", (int)id, shsize, secs * 1000);
+  stats_print_i64_list(shape, rank);
+  printf("] str1=[");
+  stats_print_i64_list(strides1, rank);
+  printf("] str2=[");
+  stats_print_i64_list(strides2, rank);
+  printf("]\n");
+
+  return buf - orig_buf;
+}
+
+// Prints every recorded statistics entry to stdout, starting with the block
+// that was allocated first, and frees each block once it has been printed.
+void oxarrays_stats_print_all(void) {
+  printf("=== ox-arrays-arith-stats start ===\n");
+
+  // Detach the whole chain so that no new blocks get linked onto it.
+  // (Strictly speaking this is not airtight: another thread may still be in
+  // the middle of writing an entry into one of these blocks while we print.)
+  struct stats_block *remaining = atomic_exchange(&stats_current, NULL);
+
+  // Reverse the chain in place. Afterwards 'reversed' is the block that was
+  // allocated first, and every 'prev' field points to the block allocated
+  // _after_ it rather than before it.
+  struct stats_block *reversed = NULL;
+  while (remaining != NULL) {
+    struct stats_block *older = remaining->prev;
+    remaining->prev = reversed;
+    reversed = remaining;
+    remaining = older;
+  }
+
+  for (struct stats_block *block = reversed; block != NULL; ) {
+    size_t pos = 0;
+    while (pos < block->len) {
+      const uint8_t tag = block->buf[pos];
+      if (tag == 1) {
+        pos += 1 + stats_print_unary(block->buf + pos+1);
+      } else if (tag == 2) {
+        pos += 1 + stats_print_binary(block->buf + pos+1);
+      } else {
+        printf("# UNKNOWN ENTRY WITH ID %d, SKIPPING BLOCK\n", (int)tag);
+        pos = block->len;
+      }
+    }
+    struct stats_block *following = block->prev;  // the links are reversed here
+    free(block);
+    block = following;
+  }
+
+  printf("=== ox-arrays-arith-stats end ===\n");
+}
+
+
+/*****************************************************************************
+ * Additional math functions *
+ *****************************************************************************/
+
+#define GEN_ABS(x) \
+ _Generic((x), \
+ int: abs, \
+ long: labs, \
+ long long: llabs, \
+ float: fabsf, \
+ double: fabs)(x)
+
+// This does not result in multiple loads with GCC 13.
+#define GEN_SIGNUM(x) ((x) < 0 ? -1 : (x) > 0 ? 1 : 0)
+
+#define GEN_POW(x, y) _Generic((x), float: powf, double: pow)(x, y)
+#define GEN_LOGBASE(x, y) _Generic((x), float: logf(y) / logf(x), double: log(y) / log(x))
+#define GEN_ATAN2(y, x) _Generic((x), float: atan2f(y, x), double: atan2(y, x))
+#define GEN_EXP(x) _Generic((x), float: expf, double: exp)(x)
+#define GEN_LOG(x) _Generic((x), float: logf, double: log)(x)
+#define GEN_SQRT(x) _Generic((x), float: sqrtf, double: sqrt)(x)
+#define GEN_SIN(x) _Generic((x), float: sinf, double: sin)(x)
+#define GEN_COS(x) _Generic((x), float: cosf, double: cos)(x)
+#define GEN_TAN(x) _Generic((x), float: tanf, double: tan)(x)
+#define GEN_ASIN(x) _Generic((x), float: asinf, double: asin)(x)
+#define GEN_ACOS(x) _Generic((x), float: acosf, double: acos)(x)
+#define GEN_ATAN(x) _Generic((x), float: atanf, double: atan)(x)
+#define GEN_SINH(x) _Generic((x), float: sinhf, double: sinh)(x)
+#define GEN_COSH(x) _Generic((x), float: coshf, double: cosh)(x)
+#define GEN_TANH(x) _Generic((x), float: tanhf, double: tanh)(x)
+#define GEN_ASINH(x) _Generic((x), float: asinhf, double: asinh)(x)
+#define GEN_ACOSH(x) _Generic((x), float: acoshf, double: acosh)(x)
+#define GEN_ATANH(x) _Generic((x), float: atanhf, double: atanh)(x)
+#define GEN_LOG1P(x) _Generic((x), float: log1pf, double: log1p)(x)
+#define GEN_EXPM1(x) _Generic((x), float: expm1f, double: expm1)(x)
+
+// log1mexp(x) = log(1 - exp(x)).
+// Taken from Haskell's implementation:
+// https://hackage.haskell.org/package/ghc-internal-9.1001.0/docs/src//GHC.Internal.Float.html#log1mexpOrd
+// The branch point is -log(2), as in that implementation: for x in
+// (-log(2), 0), exp(x) is close to 1, so 1 - exp(x) is computed via expm1 to
+// avoid cancellation; below -log(2), log1p(-exp(x)) is used.
+#define LOG1MEXP_IMPL(x) do { \
+    if ((x) > -_Generic((x), float: logf, double: log)(2)) return GEN_LOG(-GEN_EXPM1(x)); \
+    else return GEN_LOG1P(-GEN_EXP(x)); \
+  } while (0)
+
+static float log1mexp_float(float x) { LOG1MEXP_IMPL(x); }
+static double log1mexp_double(double x) { LOG1MEXP_IMPL(x); }
+
+#define GEN_LOG1MEXP(x) _Generic((x), float: log1mexp_float, double: log1mexp_double)(x)
+
+// Taken from Haskell's implementation:
+// https://hackage.haskell.org/package/ghc-internal-9.1001.0/docs/src//GHC.Internal.Float.html#line-595
+#define LOG1PEXP_IMPL(x) do { \
+ if (x <= 18) return GEN_LOG1P(GEN_EXP(x)); \
+ if (x <= 100) return x + GEN_EXP(-x); \
+ return x; \
+ } while (0)
+
+static float log1pexp_float(float x) { LOG1PEXP_IMPL(x); }
+static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
+
+#define GEN_LOG1PEXP(x) _Generic((x), float: log1pexp_float, double: log1pexp_double)(x)
+
+
+/*****************************************************************************
+ * Helper functions *
+ *****************************************************************************/
+
+// Writes the 'rank' entries of 'shape' to 'stream' as a bracketed list with
+// comma separators and no spaces, e.g. [2,3,4]. A rank of 0 prints [].
+__attribute__((used))
+static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
+  const char *separator = "";
+  fputc('[', stream);
+  for (i64 dim = 0; dim < rank; dim++) {
+    fprintf(stream, "%s%" PRIi64, separator, shape[dim]);
+    separator = ",";
+  }
+  fputc(']', stream);
+}
+
+
+/*****************************************************************************
+ * Skeletons *
+ *****************************************************************************/
+
+// Walk an orthotope-style strided array, except for the inner dimension. The
+// body is run for every "inner vector".
+// Provides idx, outlinidx, arrlinidx:
+// - idx: the current multi-index over dimensions 0 .. rank-2.
+// - arrlinidx: sum of idx[dim] * strides[dim] over those dimensions, i.e. the
+//   offset of the current inner vector's first element in the strided array.
+// - outlinidx: number of inner vectors visited before the current one.
+// The body is passed as the variadic argument and is followed by a `goto` back
+// to 'again_label_name', so that label must be unique within the enclosing
+// function.
+// rank == 0 is not handled here: the memset length is computed from rank - 1.
+#define TARRAY_WALK_NOINNER(again_label_name, rank, shape, strides, ...) \
+ do { \
+ i64 idx[(rank) /* - 1 */]; /* Note: [zero-length VLA] */ \
+ memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \
+ i64 arrlinidx = 0; \
+ i64 outlinidx = 0; \
+ again_label_name: \
+ { \
+ __VA_ARGS__ \
+ } \
+ for (i64 dim = (rank) - 2; dim >= 0; dim--) { \
+ if (++idx[dim] < (shape)[dim]) { \
+ arrlinidx += (strides)[dim]; \
+ outlinidx++; \
+ goto again_label_name; \
+ } \
+ /* This dimension wrapped around: undo its contribution and carry into the next-outer one. */ \
+ arrlinidx -= (idx[dim] - 1) * (strides)[dim]; \
+ idx[dim] = 0; \
+ } \
+ } while (false)
+
+// Walk TWO orthotope-style strided arrays simultaneously, except for their
+// inner dimension. The arrays must have the same shape, but may have different
+// strides. The body is run for every pair of "inner vectors".
+// Provides idx, outlinidx, arrlinidx1, arrlinidx2.
+#define TARRAY_WALK2_NOINNER(again_label_name, rank, shape, strides1, strides2, ...) \
+ do { \
+ i64 idx[(rank) /* - 1 */]; /* Note: [zero-length VLA] */ \
+ memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \
+ i64 arrlinidx1 = 0, arrlinidx2 = 0; \
+ i64 outlinidx = 0; \
+ again_label_name: \
+ { \
+ __VA_ARGS__ \
+ } \
+ for (i64 dim = (rank) - 2; dim >= 0; dim--) { \
+ if (++idx[dim] < (shape)[dim]) { \
+ arrlinidx1 += (strides1)[dim]; \
+ arrlinidx2 += (strides2)[dim]; \
+ outlinidx++; \
+ goto again_label_name; \
+ } \
+ arrlinidx1 -= (idx[dim] - 1) * (strides1)[dim]; \
+ arrlinidx2 -= (idx[dim] - 1) * (strides2)[dim]; \
+ idx[dim] = 0; \
+ } \
+ } while (false)
+
+
+/*****************************************************************************
+ * Kernel functions *
+ *****************************************************************************/
+
+#define COMM_OP_STRIDED(name, op, typ) \
+ static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \
+ if (rank == 0) { out[0] = x op y[0]; return; } \
+ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
+ for (i64 i = 0; i < shape[rank - 1]; i++) { \
+ out[outlinidx * shape[rank - 1] + i] = x op y[arrlinidx + strides[rank - 1] * i]; \
+ } \
+ }); \
+ } \
+ static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
+ if (rank == 0) { out[0] = x[0] op y[0]; return; } \
+ TARRAY_WALK2_NOINNER(again, rank, shape, strides1, strides2, { \
+ for (i64 i = 0; i < shape[rank - 1]; i++) { \
+ out[outlinidx * shape[rank - 1] + i] = x[arrlinidx1 + strides1[rank - 1] * i] op y[arrlinidx2 + strides2[rank - 1] * i]; \
+ } \
+ }); \
+ }
+
+#define NONCOMM_OP_STRIDED(name, op, typ) \
+ COMM_OP_STRIDED(name, op, typ) \
+ static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \
+ if (rank == 0) { out[0] = x[0] op y; return; } \
+ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
+ for (i64 i = 0; i < shape[rank - 1]; i++) { \
+ out[outlinidx * shape[rank - 1] + i] = x[arrlinidx + strides[rank - 1] * i] op y; \
+ } \
+ }); \
+ }
+
+#define PREFIX_BINOP_STRIDED(name, op, typ) \
+ static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \
+ if (rank == 0) { out[0] = op(x, y[0]); return; } \
+ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
+ for (i64 i = 0; i < shape[rank - 1]; i++) { \
+ out[outlinidx * shape[rank - 1] + i] = op(x, y[arrlinidx + strides[rank - 1] * i]); \
+ } \
+ }); \
+ } \
+ static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
+ if (rank == 0) { out[0] = op(x[0], y[0]); return; } \
+ TARRAY_WALK2_NOINNER(again, rank, shape, strides1, strides2, { \
+ for (i64 i = 0; i < shape[rank - 1]; i++) { \
+ out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx1 + strides1[rank - 1] * i], y[arrlinidx2 + strides2[rank - 1] * i]); \
+ } \
+ }); \
+ } \
+ static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \
+ if (rank == 0) { out[0] = op(x[0], y); return; } \
+ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
+ for (i64 i = 0; i < shape[rank - 1]; i++) { \
+ out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx + strides[rank - 1] * i], y); \
+ } \
+ }); \
+ }
+
+#define UNARY_OP_STRIDED(name, op, typ) \
+ static void oxarop_op_ ## name ## _ ## typ ## _strided(i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \
+ /* fprintf(stderr, "oxarop_op_" #name "_" #typ "_strided: rank=%ld shape=", rank); \
+ print_shape(stderr, rank, shape); \
+ fprintf(stderr, " strides="); \
+ print_shape(stderr, rank, strides); \
+ fprintf(stderr, "\n"); */ \
+ if (rank == 0) { out[0] = op(arr[0]); return; } \
+ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
+ for (i64 i = 0; i < shape[rank - 1]; i++) { \
+ out[outlinidx * shape[rank - 1] + i] = op(arr[arrlinidx + strides[rank - 1] * i]); \
+ } \
+ }); \
+ }
+
+// Used for reduction and dot product kernels below
+#define MANUAL_VECT_WID 8
+
+// Used in REDUCE1_OP and REDUCEFULL_OP below
+#define REDUCE_BODY_CODE(op, typ, innerLen, innerStride, arr, arrlinidx, destination) \
+ do { \
+ const i64 n = innerLen; const i64 s = innerStride; \
+ if (n < MANUAL_VECT_WID) { \
+ typ accum = arr[arrlinidx]; \
+ for (i64 i = 1; i < n; i++) accum = accum op arr[arrlinidx + s * i]; \
+ destination = accum; \
+ } else { \
+ typ accum[MANUAL_VECT_WID]; \
+ for (i64 j = 0; j < MANUAL_VECT_WID; j++) accum[j] = arr[arrlinidx + s * j]; \
+ for (i64 i = 1; i < n / MANUAL_VECT_WID; i++) { \
+ for (i64 j = 0; j < MANUAL_VECT_WID; j++) { \
+ accum[j] = accum[j] op arr[arrlinidx + s * (MANUAL_VECT_WID * i + j)]; \
+ } \
+ } \
+ typ res = accum[0]; \
+ for (i64 j = 1; j < MANUAL_VECT_WID; j++) res = res op accum[j]; \
+ for (i64 i = n / MANUAL_VECT_WID * MANUAL_VECT_WID; i < n; i++) \
+ res = res op arr[arrlinidx + s * i]; \
+ destination = res; \
+ } \
+ } while (0)
+
+// Reduces along the innermost dimension.
+// 'out' will be filled densely in linearisation order.
+#define REDUCE1_OP(name, op, typ) \
+ static void oxarop_op_ ## name ## _ ## typ(i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \
+ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
+ REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, out[outlinidx]); \
+ }); \
+ }
+
+// Reduces the entire array to a single value with 'op'.
+// Every inner vector is reduced with REDUCE_BODY_CODE and the per-vector
+// results are then combined with 'op', so that all elements contribute for any
+// rank >= 1. A rank-0 array yields its single element.
+#define REDUCEFULL_OP(name, op, typ) \
+  typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
+    if (rank == 0) return arr[0]; \
+    typ result = 0; \
+    bool have_result = false; \
+    TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
+      typ partial; \
+      REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, partial); \
+      /* The first inner vector seeds the result so that no identity element of 'op' is needed. */ \
+      result = have_result ? result op partial : partial; \
+      have_result = true; \
+    }); \
+    return result; \
+  }
+
+// Writes extreme index to outidx. If 'cmp' is '<', computes minindex ("argmin"); if '>', maxindex.
+// 'outidx' must have room for 'rank' entries. The comparison is strict, so
+// among equal extreme values the one visited first is reported.
+// Every element is addressed through its strides, including the innermost one.
+#define EXTREMUM_OP(name, cmp, typ) \
+  void oxarop_extremum_ ## name ## _ ## typ(i64 *restrict outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
+    if (rank == 0) return; /* output index vector has length 0 anyways */ \
+    typ best = arr[0]; \
+    memset(outidx, 0, rank * sizeof(i64)); \
+    TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
+      bool found = false; \
+      for (i64 i = 0; i < shape[rank - 1]; i++) { \
+        const typ candidate = arr[arrlinidx + strides[rank - 1] * i]; \
+        if (candidate cmp best) { \
+          best = candidate; \
+          found = true; \
+          outidx[rank - 1] = i; \
+        } \
+      } \
+      /* Only when this inner vector improved on 'best' do the outer indices change. */ \
+      if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \
+    }); \
+  }
+
+// Reduces along the innermost dimension.
+// 'out' will be filled densely in linearisation order.
+#define DOTPROD_INNER_OP(typ) \
+ void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \
+ TIME_START(tm); \
+ TARRAY_WALK2_NOINNER(again3, rank, shape, strides1, strides2, { \
+ const i64 length = shape[rank - 1], stride1 = strides1[rank - 1], stride2 = strides2[rank - 1]; \
+ if (length < MANUAL_VECT_WID) { \
+ typ res = 0; \
+ for (i64 i = 0; i < length; i++) res += arr1[arrlinidx1 + stride1 * i] * arr2[arrlinidx2 + stride2 * i]; \
+ out[outlinidx] = res; \
+ } else { \
+ typ accum[MANUAL_VECT_WID]; \
+ for (i64 j = 0; j < MANUAL_VECT_WID; j++) accum[j] = arr1[arrlinidx1 + stride1 * j] * arr2[arrlinidx2 + stride2 * j]; \
+ for (i64 i = 1; i < length / MANUAL_VECT_WID; i++) \
+ for (i64 j = 0; j < MANUAL_VECT_WID; j++) \
+ accum[j] += arr1[arrlinidx1 + stride1 * (MANUAL_VECT_WID * i + j)] * arr2[arrlinidx2 + stride2 * (MANUAL_VECT_WID * i + j)]; \
+ typ res = accum[0]; \
+ for (i64 j = 1; j < MANUAL_VECT_WID; j++) res += accum[j]; \
+ for (i64 i = length / MANUAL_VECT_WID * MANUAL_VECT_WID; i < length; i++) \
+ res += arr1[arrlinidx1 + stride1 * i] * arr2[arrlinidx2 + stride2 * i]; \
+ out[outlinidx] = res; \
+ } \
+ }); \
+ stats_record_binary(sbi_dotprod, rank, shape, strides1, strides2, TIME_END(tm)); \
+ }
+
+
+/*****************************************************************************
+ * Entry point functions *
+ *****************************************************************************/
+
+__attribute__((noreturn, cold))
+static void wrong_op(const char *name, int tag) {
+ fprintf(stderr, "ox-arrays: Invalid operation tag passed to %s C code: %d\n", name, tag);
+ abort();
+}
+
+enum binop_tag_t {
+#undef LIST_BINOP
+#define LIST_BINOP(name, id, hsop) name = id,
+#include "arith_lists.h"
+#undef LIST_BINOP
+#define LIST_BINOP(name, id, hsop)
+};
+
+#define ENTRY_BINARY_STRIDED_OPS(typ) \
+ void oxarop_binary_ ## typ ## _sv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \
+ switch (tag) { \
+ case BO_ADD: oxarop_op_add_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+ case BO_SUB: oxarop_op_sub_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+ case BO_MUL: oxarop_op_mul_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+ default: wrong_op("binary_sv_strided", tag); \
+ } \
+ } \
+ void oxarop_binary_ ## typ ## _vs_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \
+ switch (tag) { \
+ case BO_ADD: oxarop_op_add_ ## typ ## _sv_strided(rank, shape, out, y, strides, x); break; \
+ case BO_SUB: oxarop_op_sub_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
+ case BO_MUL: oxarop_op_mul_ ## typ ## _sv_strided(rank, shape, out, y, strides, x); break; \
+ default: wrong_op("binary_vs_strided", tag); \
+ } \
+ } \
+ void oxarop_binary_ ## typ ## _vv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
+ switch (tag) { \
+ case BO_ADD: oxarop_op_add_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+ case BO_SUB: oxarop_op_sub_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+ case BO_MUL: oxarop_op_mul_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+ default: wrong_op("binary_vv_strided", tag); \
+ } \
+ }
+
+enum ibinop_tag_t {
+#undef LIST_IBINOP
+#define LIST_IBINOP(name, id, hsop) name = id,
+#include "arith_lists.h"
+#undef LIST_IBINOP
+#define LIST_IBINOP(name, id, hsop)
+};
+
+#define ENTRY_IBINARY_STRIDED_OPS(typ) \
+ void oxarop_ibinary_ ## typ ## _sv_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \
+ switch (tag) { \
+ case IB_QUOT: oxarop_op_quot_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+ case IB_REM: oxarop_op_rem_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+ default: wrong_op("ibinary_sv_strided", tag); \
+ } \
+ } \
+ void oxarop_ibinary_ ## typ ## _vs_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \
+ switch (tag) { \
+ case IB_QUOT: oxarop_op_quot_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
+ case IB_REM: oxarop_op_rem_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
+ default: wrong_op("ibinary_vs_strided", tag); \
+ } \
+ } \
+ void oxarop_ibinary_ ## typ ## _vv_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
+ switch (tag) { \
+ case IB_QUOT: oxarop_op_quot_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+ case IB_REM: oxarop_op_rem_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+ default: wrong_op("ibinary_vv_strided", tag); \
+ } \
+ }
+
+enum fbinop_tag_t {
+#undef LIST_FBINOP
+#define LIST_FBINOP(name, id, hsop) name = id,
+#include "arith_lists.h"
+#undef LIST_FBINOP
+#define LIST_FBINOP(name, id, hsop)
+};
+
+#define ENTRY_FBINARY_STRIDED_OPS(typ) \
+ void oxarop_fbinary_ ## typ ## _sv_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \
+ switch (tag) { \
+ case FB_DIV: oxarop_op_fdiv_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+ case FB_POW: oxarop_op_pow_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+ case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+ case FB_ATAN2: oxarop_op_atan2_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+ default: wrong_op("fbinary_sv_strided", tag); \
+ } \
+ } \
+ void oxarop_fbinary_ ## typ ## _vs_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \
+ switch (tag) { \
+ case FB_DIV: oxarop_op_fdiv_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
+ case FB_POW: oxarop_op_pow_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
+ case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
+ case FB_ATAN2: oxarop_op_atan2_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
+ default: wrong_op("fbinary_vs_strided", tag); \
+ } \
+ } \
+ void oxarop_fbinary_ ## typ ## _vv_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
+ switch (tag) { \
+ case FB_DIV: oxarop_op_fdiv_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+ case FB_POW: oxarop_op_pow_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+ case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+ case FB_ATAN2: oxarop_op_atan2_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+ default: wrong_op("fbinary_vv_strided", tag); \
+ } \
+ }
+
+enum unop_tag_t {
+#undef LIST_UNOP
+#define LIST_UNOP(name, id, _) name = id,
+#include "arith_lists.h"
+#undef LIST_UNOP
+#define LIST_UNOP(name, id, _)
+};
+
+#define ENTRY_UNARY_STRIDED_OPS(typ) \
+ void oxarop_unary_ ## typ ## _strided(enum unop_tag_t tag, i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *x) { \
+ switch (tag) { \
+ case UO_NEG: oxarop_op_neg_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case UO_ABS: oxarop_op_abs_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case UO_SIGNUM: oxarop_op_signum_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ default: wrong_op("unary_strided", tag); \
+ } \
+ }
+
+enum funop_tag_t {
+#undef LIST_FUNOP
+#define LIST_FUNOP(name, id, _) name = id,
+#include "arith_lists.h"
+#undef LIST_FUNOP
+#define LIST_FUNOP(name, id, _)
+};
+
+#define ENTRY_FUNARY_STRIDED_OPS(typ) \
+ void oxarop_funary_ ## typ ## _strided(enum funop_tag_t tag, i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *x) { \
+ switch (tag) { \
+ case FU_RECIP: oxarop_op_recip_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_EXP: oxarop_op_exp_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_LOG: oxarop_op_log_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_SQRT: oxarop_op_sqrt_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_SIN: oxarop_op_sin_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_COS: oxarop_op_cos_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_TAN: oxarop_op_tan_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_ASIN: oxarop_op_asin_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_ACOS: oxarop_op_acos_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_ATAN: oxarop_op_atan_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_SINH: oxarop_op_sinh_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_COSH: oxarop_op_cosh_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_TANH: oxarop_op_tanh_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_ASINH: oxarop_op_asinh_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_ACOSH: oxarop_op_acosh_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_ATANH: oxarop_op_atanh_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_LOG1P: oxarop_op_log1p_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_EXPM1: oxarop_op_expm1_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_LOG1PEXP: oxarop_op_log1pexp_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_LOG1MEXP: oxarop_op_log1mexp_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ default: wrong_op("funary_strided", tag); \
+ } \
+ }
+
+enum redop_tag_t {
+#undef LIST_REDOP
+#define LIST_REDOP(name, id, _) name = id,
+#include "arith_lists.h"
+#undef LIST_REDOP
+#define LIST_REDOP(name, id, _)
+};
+
+#define ENTRY_REDUCE1_OPS(typ) \
+ void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \
+ switch (tag) { \
+ case RO_SUM: oxarop_op_sum1_ ## typ(rank, out, shape, strides, arr); break; \
+ case RO_PRODUCT: oxarop_op_product1_ ## typ(rank, out, shape, strides, arr); break; \
+ default: wrong_op("reduce", tag); \
+ } \
+ }
+
+#define ENTRY_REDUCEFULL_OPS(typ) \
+ typ oxarop_reducefull_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
+ switch (tag) { \
+ case RO_SUM: return oxarop_op_sumfull_ ## typ(rank, shape, strides, arr); \
+ case RO_PRODUCT: return oxarop_op_productfull_ ## typ(rank, shape, strides, arr); \
+ default: wrong_op("reduce", tag); \
+ } \
+ }
+
+
+/*****************************************************************************
+ * Generate all the functions *
+ *****************************************************************************/
+
+#define INT_TYPES_XLIST X(i32) X(i64)
+#define FLOAT_TYPES_XLIST X(double) X(float)
+#define NUM_TYPES_XLIST INT_TYPES_XLIST FLOAT_TYPES_XLIST
+
+#define X(typ) \
+ COMM_OP_STRIDED(add, +, typ) \
+ NONCOMM_OP_STRIDED(sub, -, typ) \
+ COMM_OP_STRIDED(mul, *, typ) \
+ UNARY_OP_STRIDED(neg, -, typ) \
+ UNARY_OP_STRIDED(abs, GEN_ABS, typ) \
+ UNARY_OP_STRIDED(signum, GEN_SIGNUM, typ) \
+ REDUCE1_OP(sum1, +, typ) \
+ REDUCE1_OP(product1, *, typ) \
+ REDUCEFULL_OP(sumfull, +, typ) \
+ REDUCEFULL_OP(productfull, *, typ) \
+ ENTRY_BINARY_STRIDED_OPS(typ) \
+ ENTRY_UNARY_STRIDED_OPS(typ) \
+ ENTRY_REDUCE1_OPS(typ) \
+ ENTRY_REDUCEFULL_OPS(typ) \
+ EXTREMUM_OP(min, <, typ) \
+ EXTREMUM_OP(max, >, typ) \
+ DOTPROD_INNER_OP(typ)
+NUM_TYPES_XLIST
+#undef X
+
+#define X(typ) \
+ NONCOMM_OP_STRIDED(quot, /, typ) \
+ NONCOMM_OP_STRIDED(rem, %, typ) \
+ ENTRY_IBINARY_STRIDED_OPS(typ)
+INT_TYPES_XLIST
+#undef X
+
+#define X(typ) \
+ NONCOMM_OP_STRIDED(fdiv, /, typ) \
+ PREFIX_BINOP_STRIDED(pow, GEN_POW, typ) \
+ PREFIX_BINOP_STRIDED(logbase, GEN_LOGBASE, typ) \
+ PREFIX_BINOP_STRIDED(atan2, GEN_ATAN2, typ) \
+ UNARY_OP_STRIDED(recip, 1.0/, typ) \
+ UNARY_OP_STRIDED(exp, GEN_EXP, typ) \
+ UNARY_OP_STRIDED(log, GEN_LOG, typ) \
+ UNARY_OP_STRIDED(sqrt, GEN_SQRT, typ) \
+ UNARY_OP_STRIDED(sin, GEN_SIN, typ) \
+ UNARY_OP_STRIDED(cos, GEN_COS, typ) \
+ UNARY_OP_STRIDED(tan, GEN_TAN, typ) \
+ UNARY_OP_STRIDED(asin, GEN_ASIN, typ) \
+ UNARY_OP_STRIDED(acos, GEN_ACOS, typ) \
+ UNARY_OP_STRIDED(atan, GEN_ATAN, typ) \
+ UNARY_OP_STRIDED(sinh, GEN_SINH, typ) \
+ UNARY_OP_STRIDED(cosh, GEN_COSH, typ) \
+ UNARY_OP_STRIDED(tanh, GEN_TANH, typ) \
+ UNARY_OP_STRIDED(asinh, GEN_ASINH, typ) \
+ UNARY_OP_STRIDED(acosh, GEN_ACOSH, typ) \
+ UNARY_OP_STRIDED(atanh, GEN_ATANH, typ) \
+ UNARY_OP_STRIDED(log1p, GEN_LOG1P, typ) \
+ UNARY_OP_STRIDED(expm1, GEN_EXPM1, typ) \
+ UNARY_OP_STRIDED(log1pexp, GEN_LOG1PEXP, typ) \
+ UNARY_OP_STRIDED(log1mexp, GEN_LOG1MEXP, typ) \
+ ENTRY_FBINARY_STRIDED_OPS(typ) \
+ ENTRY_FUNARY_STRIDED_OPS(typ)
+FLOAT_TYPES_XLIST
+#undef X
+
+// Note: [zero-length VLA]
+//
+// Zero-length variable-length arrays are not allowed in C(99). Thus whenever we
+// have a VLA that could sometimes suffice to be empty (e.g. `idx` in the
+// TARRAY_WALK_NOINNER macros), we tweak the length formula (typically by just
+// adding 1) so that it never ends up empty.
diff --git a/cbits/arith_lists.h b/cbits/arith_lists.h
new file mode 100644
index 0000000..432765c
--- /dev/null
+++ b/cbits/arith_lists.h
@@ -0,0 +1,39 @@
+LIST_BINOP(BO_ADD, 1, +)
+LIST_BINOP(BO_SUB, 2, -)
+LIST_BINOP(BO_MUL, 3, *)
+
+LIST_IBINOP(IB_QUOT, 1, quot)
+LIST_IBINOP(IB_REM, 2, rem)
+
+LIST_FBINOP(FB_DIV, 1, /)
+LIST_FBINOP(FB_POW, 2, **)
+LIST_FBINOP(FB_LOGBASE, 3, logBase)
+LIST_FBINOP(FB_ATAN2, 4, atan2)
+
+LIST_UNOP(UO_NEG, 1,)
+LIST_UNOP(UO_ABS, 2,)
+LIST_UNOP(UO_SIGNUM, 3,)
+
+LIST_FUNOP(FU_RECIP, 1,)
+LIST_FUNOP(FU_EXP, 2,)
+LIST_FUNOP(FU_LOG, 3,)
+LIST_FUNOP(FU_SQRT, 4,)
+LIST_FUNOP(FU_SIN, 5,)
+LIST_FUNOP(FU_COS, 6,)
+LIST_FUNOP(FU_TAN, 7,)
+LIST_FUNOP(FU_ASIN, 8,)
+LIST_FUNOP(FU_ACOS, 9,)
+LIST_FUNOP(FU_ATAN, 10,)
+LIST_FUNOP(FU_SINH, 11,)
+LIST_FUNOP(FU_COSH, 12,)
+LIST_FUNOP(FU_TANH, 13,)
+LIST_FUNOP(FU_ASINH, 14,)
+LIST_FUNOP(FU_ACOSH, 15,)
+LIST_FUNOP(FU_ATANH, 16,)
+LIST_FUNOP(FU_LOG1P, 17,)
+LIST_FUNOP(FU_EXPM1, 18,)
+LIST_FUNOP(FU_LOG1PEXP, 19,)
+LIST_FUNOP(FU_LOG1MEXP, 20,)
+
+LIST_REDOP(RO_SUM, 1,)
+LIST_REDOP(RO_PRODUCT, 2,)
diff --git a/example/Main.hs b/example/Main.hs
new file mode 100644
index 0000000..76c75c2
--- /dev/null
+++ b/example/Main.hs
@@ -0,0 +1,29 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE TypeApplications #-}
+module Main where
+
+import Data.Array.Nested
+
+
+-- | A 3x4 ranked array whose elements are 2x3 shaped arrays of pairs. Every
+-- pair holds the row-major position of its cell within the whole nested
+-- structure (outer strides 24 and 6, inner strides 3 and 1), once as a
+-- 'Double' and once as an 'Int'.
+arr :: Ranked 2 (Shaped [2, 3] (Double, Int))
+arr = rgenerate (3 :$: 4 :$: ZSR) $ \(row :.: col :.: ZIR) ->
+  sgenerate (SNat @2 :$$ SNat @3 :$$ ZSS) $ \(p :.$ q :.$ ZIS) ->
+    let flat = 24*row + 6*col + 3*p + q
+    in (fromIntegral flat, flat)
+
+-- | A single element of 'arr': outer position (2, 1), then inner position
+-- (1, 1).
+foo :: (Double, Int)
+foo = sindex (rindex arr (2 :.: 1 :.: ZIR)) (1 :.$ 1 :.$ ZIS)
+
+-- | A nested array whose inner length depends on the outer row index, so the
+-- inner arrays do not all have the same shape. NOTE(review): presumably kept
+-- as a deliberately ill-formed example -- printing it is commented out in
+-- 'main'.
+bad :: Ranked 2 (Ranked 1 Double)
+bad = rgenerate (3 :$: 4 :$: ZSR) $ \(row :.: col :.: ZIR) ->
+  rgenerate (row :$: ZSR) $ \(pos :.: ZIR) ->
+    let flat = 24*row + 6*col + 3*pos
+    in fromIntegral flat
+
+-- | Print the example array, one of its elements, and the array with its two
+-- outer dimensions swapped.
+main :: IO ()
+main =
+  print arr
+    >> print foo
+    >> print (rtranspose [1,0] arr)
+    -- print bad
diff --git a/gentrace.sh b/gentrace.sh
new file mode 100755
index 0000000..c3f1240
--- /dev/null
+++ b/gentrace.sh
@@ -0,0 +1,31 @@
+#!/usr/bin/env bash
+
+cat <<'EOF'
+module Data.Array.Nested.Trace (
+ -- * Traced variants
+ module Data.Array.Nested.Trace,
+
+ -- * Re-exports from the plain "Data.Array.Nested" module
+EOF
+
+sed -n '/^module/,/^) where/!d; /^\s*--\( \|$\)/d; s/ \b[a-z][a-zA-Z0-9_'"'"']*,//g; /^ $/d; s/(\.\., Z.., ([^)]*))/(..)/g; /^ /p; /^$/p' src/Data/Array/Nested.hs
+
+cat <<'EOF'
+) where
+
+import Prelude hiding (mappend, mconcat)
+
+import Data.Array.Nested
+import Data.Array.Nested.Trace.TH
+
+
+EOF
+
+# shellcheck disable=SC2016 # dollar in single-quoted string
+echo '$(concat <$> mapM convertFun'
+sed -n '/^module/,/^) where/!d; /^\s*-- /d; /^ /p' src/Data/Array/Nested.hs |
+ grep -o '\b[a-z][a-zA-Z0-9_'"'"']*\b' |
+ grep -wv -e 'pattern' -e 'type' |
+ tr $'\n' ' ' |
+ sed 's/\([^ ]\+\)/'"'"'\1,/g; s/, $/])/; s/^/ [/'
+echo
diff --git a/ops/Data/Array/Strided.hs b/ops/Data/Array/Strided.hs
new file mode 100644
index 0000000..7d8c2d0
--- /dev/null
+++ b/ops/Data/Array/Strided.hs
@@ -0,0 +1,7 @@
+module Data.Array.Strided (
+ module Data.Array.Strided.Array,
+ module Data.Array.Strided.Arith,
+) where
+
+import Data.Array.Strided.Arith
+import Data.Array.Strided.Array
diff --git a/ops/Data/Array/Strided/Arith.hs b/ops/Data/Array/Strided/Arith.hs
new file mode 100644
index 0000000..7be6390
--- /dev/null
+++ b/ops/Data/Array/Strided/Arith.hs
@@ -0,0 +1,7 @@
+module Data.Array.Strided.Arith (
+ NumElt(..),
+ IntElt(..),
+ FloatElt(..),
+) where
+
+import Data.Array.Strided.Arith.Internal
diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs
new file mode 100644
index 0000000..5802573
--- /dev/null
+++ b/ops/Data/Array/Strided/Arith/Internal.hs
@@ -0,0 +1,933 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ExistentialQuantification #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Strided.Arith.Internal where
+
+import Control.Monad
+import Data.Bifunctor (second)
+import Data.Bits
+import Data.Int
+import Data.List (sort, zip4)
+import Data.Proxy
+import Data.Type.Equality
+import Data.Vector.Storable qualified as VS
+import Data.Vector.Storable.Mutable qualified as VSM
+import Foreign.C.Types
+import Foreign.Ptr
+import Foreign.Storable
+import GHC.TypeLits
+import GHC.TypeNats qualified as TypeNats
+import Language.Haskell.TH
+import System.IO (hFlush, stdout)
+import System.IO.Unsafe
+
+import Data.Array.Strided.Arith.Internal.Foreign
+import Data.Array.Strided.Arith.Internal.Lists
+import Data.Array.Strided.Array
+
+
+-- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition
+
+
+-- TODO: move this to a utilities module
+-- | The natural number carried by an 'SNat', converted to 'Int'.
+fromSNat' :: SNat n -> Int
+fromSNat' sn = fromIntegral (fromSNat sn)
+
+-- | A reified constraint: constructing 'Dict' requires @c@ to hold, and
+-- pattern matching on it brings @c@ into scope.
+data Dict c where
+ Dict :: c => Dict c
+
+-- | Render the metadata of an array for debugging: type-level rank, shape,
+-- strides, offset, and the length of the backing vector (elements themselves
+-- are not shown).
+debugShow :: forall n a. (Storable a, KnownNat n) => Array n a -> String
+debugShow (Array sh strides offset vec) =
+  unwords
+    [ "Array @" ++ show (natVal (Proxy @n))
+    , show sh
+    , show strides
+    , show offset
+    , "<_*" ++ show (VS.length vec) ++ ">"
+    ]
+
+
+-- TODO: test all the cases of this thing with various input strides
+-- | Apply a strided unary kernel elementwise. When 'stridesDense' reports that
+-- the array occupies a dense block of its backing vector, the kernel is run
+-- once over that block as a rank-1 array and the original shape and strides
+-- are reused for the result; otherwise the kernel is run on the array as-is
+-- via 'wrapUnary'.
+liftOpEltwise1 :: Storable a
+               => SNat n
+               -> (Ptr a -> Ptr b)
+               -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
+               -> Array n a -> Array n a
+liftOpEltwise1 sn@SNat ptrconv cf_strided arr@(Array sh strides offset vec) =
+  case stridesDense sh offset strides of
+    Nothing -> wrapUnary sn ptrconv cf_strided arr
+    -- empty array: nothing to compute
+    Just (_, 0) -> Array sh (map (const 0) strides) 0 VS.empty
+    Just (blockOff, blockSz) ->
+      let block = Array [blockSz] [1] blockOff vec
+          resvec = arrValues (wrapUnary sn ptrconv cf_strided block)
+      in -- the result vector starts at the block, so shift the offset accordingly
+         Array sh strides (offset - blockOff) resvec
+
+-- TODO: test all the cases of this thing with various input strides
+-- | Apply a binary operation elementwise to two arrays of equal shape,
+-- choosing the cheapest applicable strategy:
+--
+-- * both inputs are a single (possibly replicated) value: apply the pure
+--   Haskell operation once;
+-- * one input is a single value: use the scalar-vector / vector-scalar kernel,
+--   on the dense block of the other input if it has one;
+-- * both inputs are dense with identical strides: run the vector-vector kernel
+--   on the two dense blocks as rank-1 arrays;
+-- * otherwise: run the vector-vector kernel on the arrays as-is.
+--
+-- Throws if the shapes differ. An empty shape yields an empty result.
+liftOpEltwise2 :: Storable a
+               => SNat n
+               -> (a -> b)
+               -> (Ptr a -> Ptr b)
+               -> (a -> a -> a)
+               -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ sv
+               -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())  -- ^ vs
+               -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ vv
+               -> Array n a -> Array n a -> Array n a
+liftOpEltwise2 sn@SNat valconv ptrconv f_ss f_sv f_vs f_vv
+    arr1@(Array sh1 strides1 offset1 vec1)
+    arr2@(Array sh2 strides2 offset2 vec2)
+  | sh1 /= sh2 = error $ "liftOpEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
+  | any (<= 0) sh1 = Array sh1 (0 <$ strides1) 0 VS.empty
+  | otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of
+      (Just (_, 1), Just (_, 1)) ->  -- both are a (potentially replicated) scalar; just apply f to the scalars
+        let vec' = VS.singleton (f_ss (vec1 VS.! offset1) (vec2 VS.! offset2))
+        in Array sh1 strides1 0 vec'
+
+      (Just (_, 1), Just (blockOff, blockSz)) ->  -- scalar * dense
+        let arr2' = arrayFromVector [blockSz] (VS.slice blockOff blockSz vec2)
+            resvec = arrValues $ wrapBinarySV (SNat @1) valconv ptrconv f_sv (vec1 VS.! offset1) arr2'
+        in Array sh1 strides2 (offset2 - blockOff) resvec
+
+      (Just (_, 1), Nothing) ->  -- scalar * array
+        wrapBinarySV sn valconv ptrconv f_sv (vec1 VS.! offset1) arr2
+
+      (Just (blockOff, blockSz), Just (_, 1)) ->  -- dense * scalar
+        let arr1' = arrayFromVector [blockSz] (VS.slice blockOff blockSz vec1)
+            resvec = arrValues $ wrapBinaryVS (SNat @1) valconv ptrconv f_vs arr1' (vec2 VS.! offset2)
+        in Array sh1 strides1 (offset1 - blockOff) resvec
+
+      (Nothing, Just (_, 1)) ->  -- array * scalar
+        wrapBinaryVS sn valconv ptrconv f_vs arr1 (vec2 VS.! offset2)
+
+      (Just (blockOff1, blockSz1), Just (blockOff2, blockSz2))
+        | strides1 == strides2
+        ->  -- dense * dense but the strides match
+          -- Equal shapes and equal strides imply equal block sizes and equal
+          -- offsets relative to the block start; anything else is a bug.
+          if blockSz1 /= blockSz2 || offset1 - blockOff1 /= offset2 - blockOff2
+            then error $ "Data.Array.Strided.Arith.Internal(liftOpEltwise2): Internal error: cannot happen " ++ show (strides1, (blockOff1, blockSz1), strides2, (blockOff2, blockSz2))
+            else
+              let arr1' = arrayFromVector [blockSz1] (VS.slice blockOff1 blockSz1 vec1)
+                  arr2' = arrayFromVector [blockSz1] (VS.slice blockOff2 blockSz2 vec2)
+                  resvec = arrValues $ wrapBinaryVV (SNat @1) ptrconv f_vv arr1' arr2'
+              in Array sh1 strides1 (offset1 - blockOff1) resvec
+
+      (_, _) ->  -- fallback case
+        wrapBinaryVV sn ptrconv f_vv arr1 arr2
+
+-- | Decide whether the elements addressed by a (shape, offset, strides) view
+-- form one contiguous block of the backing vector. If they do, the result is
+-- @Just (firstIndex, blockLength)@, where @firstIndex@ is the smallest backing
+-- index the view touches. An empty shape gives @Just (offset, 0)@; a view
+-- whose strides are all zero gives a block of length 1.
+stridesDense :: [Int] -> Int -> [Int] -> Maybe (Int, Int)
+stridesDense sh offsetNeg stridesNeg
+  | any (<= 0) sh = Just (offsetNeg, 0)
+  | otherwise =
+      -- Order the dimensions by ascending stride and ignore replicated
+      -- (stride 0) dimensions: they add no new elements.
+      case [dim | dim@(s, _) <- sort (zip strides sh), s /= 0] of
+        [] -> Just (offset, 1)
+        (1, n) : rest -> (,) offset <$> foldM extend n rest
+        _ -> Nothing  -- if the smallest stride is not 1, it will never be dense
+  where
+    -- View with all negative-stride dimensions reversed, so that 'offset' is
+    -- the lowest index used.
+    (offset, strides) = unreverse sh offsetNeg stridesNeg
+
+    -- @extend block (s, n)@: given that the first @block@ elements are already
+    -- covered, add a dimension of size @n@ and stride @s >= 1@. This leaves no
+    -- gap exactly when @s <= block@; the covered prefix then grows to
+    -- @(n-1)*s + block@.
+    extend :: Int -> (Int, Int) -> Maybe Int
+    extend block (s, n)
+      | s <= block = Just ((n - 1) * s + block)
+      | otherwise = Nothing
+
+    -- Given shape, offset and strides, produce an equivalent (offset, strides)
+    -- in which every stride is >= 0.
+    unreverse :: [Int] -> Int -> [Int] -> (Int, [Int])
+    unreverse (n : ns) off (s : ss)
+      | s >= 0 = keep s (unreverse ns off ss)
+      | otherwise = keep (-s) (unreverse ns (off + (n - 1) * s) ss)
+      where keep s' ~(off', ss') = (off', s' : ss')
+    unreverse [] off [] = (off, [])
+    unreverse _ _ _ = error "flipReverseds: invalid arguments"
+
+-- | Result of removing the replicated (stride == 0) dimensions from an array;
+-- produced by 'unreplicateStrides'. The rank of the reduced array is hidden
+-- existentially.
+data Unreplicated a =
+ forall n'. KnownNat n' =>
+ -- | Let the original array, with replicated dimensions, be called A.
+ Unreplicated -- | An array with all strides /= 0. Call this array U. It has
+ -- the same shape as A, except with all the replicated (stride
+ -- == 0) dimensions removed. The shape of U is the
+ -- "unreplicated shape".
+ (Array n' a)
+ -- | Product of the sizes of the dimensions of A that were
+ -- removed by unreplication, i.e. of the replicated (stride
+ -- == 0) dimensions. This is how many times each element of U
+ -- occurs in A.
+ Int
+ -- | Given the stride vector of an array with the unreplicated
+ -- shape, this function reinserts zeros so that it may be
+ -- combined with the original shape of A.
+ ([Int] -> [Int])
+
+-- | Drop every replicated dimension (stride == 0) from the array. Besides the
+-- reduced array, the result carries the product of the sizes of the dropped
+-- dimensions and a function that puts a 0 back at every dropped position of a
+-- list with one entry per kept dimension.
+unreplicateStrides :: Array n a -> Unreplicated a
+unreplicateStrides (Array sh strides offset vec) =
+  TypeNats.withSomeSNat (fromIntegral (length keptShape)) $ \(SNat :: SNat lenshF) ->
+    Unreplicated (Array @lenshF keptShape keptStrides offset vec)
+                 droppedSize
+                 (reinsertZeros isReplicated)
+  where
+    isReplicated = map (== 0) strides
+    (keptShape, keptStrides) = unzip (filter ((/= 0) . snd) (zip sh strides))
+    droppedSize = product [n | (n, True) <- zip sh isReplicated]
+
+    -- Walk the replication flags; emit 0 for a replicated dimension and the
+    -- next supplied entry otherwise. The supplied list must have exactly one
+    -- entry per non-replicated dimension.
+    reinsertZeros (True : flags) entries = 0 : reinsertZeros flags entries
+    reinsertZeros (False : flags) (e : entries) = e : reinsertZeros flags entries
+    reinsertZeros (False : _) [] = error "unreplicateStrides: Internal error: reply strides too short"
+    reinsertZeros [] [] = []
+    reinsertZeros [] (_:_) = error "unreplicateStrides: Internal error: reply strides too long"
+
+-- | Bring an array into a simpler form before handing it to a kernel: every
+-- dimension with a negative stride is reversed (so all strides become >= 0)
+-- and every replicated dimension (stride == 0) is removed. The continuation
+-- receives the simplified array U together with the tools needed to translate
+-- results computed on U back to the original array; see the comments on the
+-- continuation's arguments.
+simplifyArray :: Array n a
+              -> (forall n'. KnownNat n'
+                  => Array n' a  -- U
+                  -- Product of the sizes of the replicated dimensions that
+                  -- were removed to obtain U
+                  -> Int
+                  -- Convert index in U back to index into original
+                  -- array. Replicated dimensions get 0.
+                  -> ([Int] -> [Int])
+                  -- Given a new array of the same shape as U, convert
+                  -- it back to the original shape and iteration order.
+                  -> (Array n' a -> Array n a)
+                  -- Do the same except without the INNER dimension.
+                  -- This throws an error if the inner dimension had
+                  -- stride 0.
+                  -> (Array (n' - 1) a -> Array (n - 1) a)
+                  -> r)
+              -> r
+simplifyArray array k
+  | let revDims = map (<0) (arrStrides array)
+  , Unreplicated array' unrepSize rereplicate <- unreplicateStrides (arrayRevDims revDims array)
+  = k array'
+      unrepSize
+      -- 'revDims' and the original shape have one entry per ORIGINAL
+      -- dimension, whereas the incoming index has one entry per dimension of
+      -- U. Hence the replicated dimensions must be reinserted first; only then
+      -- do the reversal flags line up with the index components. (Replicated
+      -- dimensions have stride 0, so they are never flagged as reversed and
+      -- keep index 0.)
+      (\idx -> zipWith3 (\b n i -> if b then n - 1 - i else i)
+                        revDims (arrShape array) (rereplicate idx))
+      (\(Array sh' strides' offset' vec') ->
+         if sh' == arrShape array'
+           then arrayRevDims revDims (Array (arrShape array) (rereplicate strides') offset' vec')
+           else error $ "simplifyArray: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show (arrShape array') ++ ")")
+      (\(Array sh' strides' offset' vec') ->
+         if | sh' /= init (arrShape array') ->
+                error $ "simplifyArray: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show (arrShape array') ++ ")"
+            | last (arrStrides array) == 0 ->
+                error "simplifyArray: Internal error: reduction reply handler used while inner stride was 0"
+            | otherwise ->
+                -- Append a dummy stride for the (non-replicated) inner
+                -- dimension so that 'rereplicate' sees one entry per dimension
+                -- of U, then drop that dummy again.
+                arrayRevDims (init revDims) (Array (init (arrShape array)) (init (rereplicate (strides' ++ [0]))) offset' vec'))
+
+-- | Two-array analogue of 'simplifyArray': a dimension is reversed only if its
+-- stride is negative in both inputs, and removed only if its stride is 0 in
+-- both inputs, so that the two simplified arrays keep a common shape.
+-- The two input arrays must have the same shape.
+simplifyArray2 :: Array n a -> Array n a
+ -> (forall n'. KnownNat n'
+ => Array n' a -- U1
+ -> Array n' a -- U2 (same shape as U1)
+ -- Product of sizes of the dimensions that are
+ -- replicated (stride 0) in BOTH inputs, i.e. the
+ -- dimensions that were dropped from U1 and U2
+ -> Int
+ -- Convert index in U{1,2} back to index into original
+ -- arrays. Dimensions that are replicated in both
+ -- inputs get 0.
+ -> ([Int] -> [Int])
+ -- Given a new array of the same shape as U1 (& U2),
+ -- convert it back to the original shape and
+ -- iteration order.
+ -> (Array n' a -> Array n a)
+ -- Do the same except without the INNER dimension.
+ -- This throws an error if the inner dimension had
+ -- stride 0 in both inputs.
+ -> (Array (n' - 1) a -> Array (n - 1) a)
+ -> r)
+ -> r
+simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k
+ | sh /= sh2 = error "simplifyArray2: Unequal shapes"
+
+ -- reverse only the dimensions that run backwards in both inputs
+ | let revDims = zipWith (\s1 s2 -> s1 < 0 && s2 < 0) (arrStrides arr1) (arrStrides arr2)
+ , Array _ strides1 offset1 vec1 <- arrayRevDims revDims arr1
+ , Array _ strides2 offset2 vec2 <- arrayRevDims revDims arr2
+
+ -- drop only the dimensions that are replicated in both inputs
+ , let replDims = zipWith (\s1 s2 -> s1 == 0 && s2 == 0) strides1 strides2
+ , let (shF, strides1F, strides2F) = unzip3 [(n, s1, s2) | (n, s1, s2, False) <- zip4 sh strides1 strides2 replDims]
+
+ -- put a 0 at every dropped position of a list that has one entry per kept dimension
+ , let reinsertZeros (False : zeros) (s : strides') = s : reinsertZeros zeros strides'
+ reinsertZeros (True : zeros) strides' = 0 : reinsertZeros zeros strides'
+ reinsertZeros [] [] = []
+ reinsertZeros (False : _) [] = error "simplifyArray2: Internal error: reply strides too short"
+ reinsertZeros [] (_:_) = error "simplifyArray2: Internal error: reply strides too long"
+
+ , let unrepSize = product [n | (n, True) <- zip sh replDims]
+
+ = TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
+ k @lenshF
+ (Array shF strides1F offset1 vec1)
+ (Array shF strides2F offset2 vec2)
+ unrepSize
+ -- first reinsert the dropped dimensions, then undo the reversal per original dimension
+ (\idx -> zipWith3 (\b n i -> if b then n - 1 - i else i)
+ revDims sh (reinsertZeros replDims idx))
+ (\(Array sh' strides' offset' vec') ->
+ if sh' /= shF then error $ "simplifyArray2: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show shF ++ ")"
+ else arrayRevDims revDims (Array sh (reinsertZeros replDims strides') offset' vec'))
+ (\(Array sh' strides' offset' vec') ->
+ if | sh' /= init shF ->
+ error $ "simplifyArray2: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show shF ++ ")"
+ -- this branch fires when the inner dimension is replicated in both inputs
+ | last replDims ->
+ error "simplifyArray2: Internal error: reduction reply handler used while inner dimension was unreplicated"
+ | otherwise ->
+ arrayRevDims (init revDims) (Array (init sh) (reinsertZeros (init replDims) strides') offset' vec'))
+
+-- | Run a strided unary kernel over an array. The array is simplified with
+-- 'simplifyArray' first; the kernel writes into a freshly allocated output
+-- buffer with one element per element of the simplified array, and the result
+-- is mapped back to the original shape. The kernel is called with: rank,
+-- output pointer, shape pointer, strides pointer, input pointer (already
+-- advanced by the array's offset).
+{-# NOINLINE wrapUnary #-}
+wrapUnary :: forall a b n. Storable a
+          => SNat n
+          -> (Ptr a -> Ptr b)
+          -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
+          -> Array n a
+          -> Array n a
+wrapUnary _ ptrconv cf_strided array =
+  simplifyArray array $ \(Array sh strides offset vec) _ _ restore _ -> unsafePerformIO $ do
+    let rank = length sh
+        asI64 = VS.fromListN rank . map fromIntegral
+        byteOff = offset * sizeOf (undefined :: a)
+    outBuf <- VSM.unsafeNew (product sh)
+    VSM.unsafeWith outBuf $ \pOut ->
+      VS.unsafeWith (asI64 sh) $ \pShape ->
+        VS.unsafeWith (asI64 strides) $ \pStrides ->
+          VS.unsafeWith vec $ \pIn ->
+            cf_strided (fromIntegral rank) (ptrconv pOut) pShape pStrides (pIn `plusPtr` byteOff)
+    restore . arrayFromVector sh <$> VS.unsafeFreeze outBuf
+
+-- | Run a strided scalar-vector kernel: combine the scalar @x@ with every
+-- element of the array. The array is simplified with 'simplifyArray' first and
+-- the result is mapped back to the original shape. The kernel is called with:
+-- rank, shape pointer, output pointer, converted scalar, strides pointer,
+-- input pointer (already advanced by the array's offset).
+{-# NOINLINE wrapBinarySV #-}
+wrapBinarySV :: forall a b n. Storable a
+             => SNat n
+             -> (a -> b)
+             -> (Ptr a -> Ptr b)
+             -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
+             -> a -> Array n a
+             -> Array n a
+wrapBinarySV SNat valconv ptrconv cf_strided x array =
+  simplifyArray array $ \(Array sh strides offset vec) _ _ restore _ -> unsafePerformIO $ do
+    let rank = length sh
+        asI64 = VS.fromListN rank . map fromIntegral
+        byteOff = offset * sizeOf (undefined :: a)
+    outBuf <- VSM.unsafeNew (product sh)
+    VSM.unsafeWith outBuf $ \pOut ->
+      VS.unsafeWith (asI64 sh) $ \pShape ->
+        VS.unsafeWith (asI64 strides) $ \pStrides ->
+          VS.unsafeWith vec $ \pIn ->
+            cf_strided (fromIntegral rank) pShape (ptrconv pOut) (valconv x) pStrides (pIn `plusPtr` byteOff)
+    restore . arrayFromVector sh <$> VS.unsafeFreeze outBuf
+
+-- | Run a strided vector-scalar kernel. Implemented on top of 'wrapBinarySV'
+-- by adapting the kernel's argument order: the scalar moves from the last
+-- position to the position 'wrapBinarySV' supplies it in.
+wrapBinaryVS :: Storable a
+             => SNat n
+             -> (a -> b)
+             -> (Ptr a -> Ptr b)
+             -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
+             -> Array n a -> a
+             -> Array n a
+wrapBinaryVS sn valconv ptrconv cf_strided arr y =
+  wrapBinarySV sn valconv ptrconv scalarFirst y arr
+  where
+    scalarFirst rank psh poutv scalar pstrides pv =
+      cf_strided rank psh poutv pstrides pv scalar
+
+-- | Run a strided vector-vector kernel over two arrays, producing a freshly
+-- allocated array of the same shape. Both shapes have to be equal and
+-- non-empty; an error is raised otherwise. The kernel is called with: rank,
+-- shape pointer, output pointer, then (strides pointer, data pointer) for each
+-- input, the data pointers already advanced by the respective offsets.
+{-# NOINLINE wrapBinaryVV #-}
+wrapBinaryVV :: forall a b n. Storable a
+             => SNat n
+             -> (Ptr a -> Ptr b)
+             -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())
+             -> Array n a -> Array n a
+             -> Array n a
+-- TODO: do unreversing and unreplication on the input arrays (but
+-- simultaneously: can only unreplicate if _both_ are replicated on that
+-- dimension)
+wrapBinaryVV sn@SNat ptrconv cf_strided
+    (Array sh strides1 offset1 vec1)
+    (Array sh2 strides2 offset2 vec2)
+  | sh /= sh2 = error $ "wrapBinaryVV: unequal shapes: " ++ show sh ++ " and " ++ show sh2
+  | any (<= 0) sh = error $ "wrapBinaryVV: empty shape: " ++ show sh
+  | otherwise = unsafePerformIO $ do
+      let rank = fromSNat' sn
+          asI64 = VS.fromListN rank . map fromIntegral
+          eltSize = sizeOf (undefined :: a)
+      outBuf <- VSM.unsafeNew (product sh)
+      VSM.unsafeWith outBuf $ \pOut ->
+        VS.unsafeWith (asI64 sh) $ \pShape ->
+          VS.unsafeWith (asI64 strides1) $ \pStrides1 ->
+            VS.unsafeWith (asI64 strides2) $ \pStrides2 ->
+              VS.unsafeWith vec1 $ \pIn1 ->
+                VS.unsafeWith vec2 $ \pIn2 ->
+                  cf_strided (fromIntegral rank) pShape (ptrconv pOut)
+                             pStrides1 (pIn1 `plusPtr` (offset1 * eltSize))
+                             pStrides2 (pIn2 `plusPtr` (offset2 * eltSize))
+      arrayFromVector sh <$> VS.unsafeFreeze outBuf
+
+-- TODO: test handling of negative strides
+-- | Reduce along the inner dimension
+--
+-- Special cases are handled without calling the reduction kernel: an empty
+-- inner dimension yields a constant-0 array, an empty outer shape yields an
+-- empty array, an inner dimension of size 1 is simply dropped, and a
+-- replicated inner dimension (stride 0) is handled by scaling the remaining
+-- array by the inner size with the "scale by constant" kernel.
+--
+-- NOTE(review): the constant 0 and the scaling by the inner size are what a
+-- SUM reduction needs. If the reduction kernel passed in is a product, these
+-- two shortcuts presumably should yield 1 and an exponentiation instead --
+-- verify against the call sites.
+{-# NOINLINE vectorRedInnerOp #-}
+vectorRedInnerOp :: forall a b n. (Num a, Storable a)
+ => SNat n
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> Array (n + 1) a -> Array n a
+vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides offset vec)
+ | null sh = error "unreachable"
+ | last sh <= 0 = arrayFromConstant (init sh) 0
+ | any (<= 0) (init sh) = Array (init sh) (0 <$ init strides) 0 VS.empty
+ -- now the input array is nonempty
+ | last sh == 1 = Array (init sh) (init strides) offset vec
+ | last strides == 0 =
+ wrapBinarySV sn valconv ptrconv fscale (fromIntegral @Int @a (last sh))
+ (Array (init sh) (init strides) offset vec)
+ -- now there is useful work along the inner dimension
+ -- Note that unreplication keeps the inner dimension intact, because `last strides /= 0` at this point.
+ | otherwise =
+ simplifyArray array $ \(Array sh' strides' offset' vec' :: Array n' a) _ _ _ restore -> unsafePerformIO $ do
+ let ndims' = length sh'
+ -- the output has the simplified shape minus its inner dimension
+ outv <- VSM.unsafeNew (product (init sh'))
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh')) $ \psh ->
+ VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides')) $ \pstrides ->
+ VS.unsafeWith vec' $ \pv ->
+ let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a))
+ in fred (fromIntegral ndims') (ptrconv poutv) psh pstrides (ptrconv pv')
+ -- Produce the type-level evidence (1 <= n' and n' - 1 == rank of the
+ -- output) that `restore` needs, by checking the naturals at runtime.
+ TypeNats.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do
+ (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
+ LTI -> pure Dict
+ EQI -> pure Dict
+ _ -> error "impossible" -- because `last strides /= 0`
+ case sameNat (natSing @(n' - 1)) (natSing @n'm1) of
+ Just Refl -> restore . arrayFromVector @_ @n'm1 (init sh') <$> VS.unsafeFreeze outv
+ Nothing -> error "impossible"
+
+-- TODO: test handling of negative strides
+-- | Reduce full array
+--
+-- The array is simplified with 'simplifyArray' (replicated dimensions are
+-- removed), the kernel reduces the simplified array, and the kernel's result
+-- is then combined with the number of dropped repetitions through the
+-- scaling function.
+--
+-- The scaling function @scaleval x k@ must return the reduction of @k@ copies
+-- of @x@ (multiplication by @k@ for a sum, raising to the @k@-th power for a
+-- product). It is also used for the two kernel-free shortcuts, so that they
+-- agree with the reduction being performed:
+--
+-- * an empty array reduces to @scaleval 1 0@, the reduction of zero copies,
+--   i.e. the identity of the reduction (0 for sum, 1 for product);
+-- * an array in which every dimension is replicated reduces to
+--   @scaleval x (product sh)@ for its single underlying element @x@.
+{-# NOINLINE vectorRedFullOp #-}
+vectorRedFullOp :: forall a b n. (Num a, Storable a)
+                => SNat n
+                -> (a -> Int -> a)  -- ^ scaling function, see above
+                -> (b -> a)  -- ^ convert the kernel's result back
+                -> (Ptr a -> Ptr b)
+                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b)  -- ^ reduction kernel
+                -> Array n a -> a
+vectorRedFullOp _ scaleval valbackconv ptrconv fred array@(Array sh strides offset vec)
+  | null sh = vec VS.! offset  -- 0D array has one element
+  | any (<= 0) sh = scaleval 1 0
+  -- now the input array is nonempty
+  | all (== 0) strides = scaleval (vec VS.! offset) (product sh)
+  -- now there is at least one non-replicated dimension
+  | otherwise =
+      simplifyArray array $ \(Array sh' strides' offset' vec') unrepSize _ _ _ -> unsafePerformIO $ do
+        let ndims' = length sh'
+        VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh')) $ \psh ->
+          VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides')) $ \pstrides ->
+            VS.unsafeWith vec' $ \pv ->
+              let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a))
+              in (`scaleval` unrepSize) . valbackconv
+                   <$> fred (fromIntegral ndims') psh pstrides (ptrconv pv')
+
+-- TODO: test this function
+-- | Find extremum (minindex ("argmin") or maxindex) in full array
+--
+-- Returns the multi-dimensional index of the extremum, with one component per
+-- dimension of the input. A rank-0 array gives the empty index, an array in
+-- which every dimension is replicated gives the all-zero index, and an empty
+-- array is an error. Otherwise the kernel runs on the simplified array (see
+-- 'simplifyArray') and its answer is translated back to an index into the
+-- original array.
+{-# NOINLINE vectorExtremumOp #-}
+vectorExtremumOp :: forall a b n. Storable a
+                 => (Ptr a -> Ptr b)
+                 -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ extremum kernel
+                 -> Array n a -> [Int]  -- result length: n
+vectorExtremumOp ptrconv fextrem array@(Array sh strides _ _)
+  | null sh = []
+  | any (<= 0) sh = error "Extremum (minindex/maxindex): empty array"
+  -- from here on the array has at least one element
+  | all (== 0) strides = map (const 0) sh
+  -- from here on some dimension is not replicated
+  | otherwise =
+      simplifyArray array $ \(Array sh' strides' offset' vec') _ upindex _ _ -> unsafePerformIO $ do
+        let rank = length sh'
+            asI64 = VS.fromListN rank . map fromIntegral
+            byteOff = offset' * sizeOf (undefined :: a)
+        -- the kernel writes one index component per dimension
+        idxBuf <- VSM.unsafeNew rank
+        VSM.unsafeWith idxBuf $ \pIdx ->
+          VS.unsafeWith (asI64 sh') $ \pShape ->
+            VS.unsafeWith (asI64 strides') $ \pStrides ->
+              VS.unsafeWith vec' $ \pData ->
+                fextrem pIdx (fromIntegral rank) pShape pStrides (ptrconv (pData `plusPtr` byteOff))
+        upindex . map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze idxBuf
+
+-- | Dot product along the inner dimension of two arrays of equal shape.
+--
+-- Special cases are handled without the dot-product kernel: an empty inner
+-- dimension yields a constant-0 array, an empty outer shape yields an empty
+-- array, an inner dimension of size 1 reduces to elementwise multiplication,
+-- and if one input is replicated along the inner dimension the result is that
+-- input (without its inner dimension) multiplied elementwise by the inner
+-- reduction ('vectorRedInnerOp' with the given reduction kernel) of the other.
+-- Throws if the shapes differ.
+{-# NOINLINE vectorDotprodInnerOp #-}
+vectorDotprodInnerOp :: forall a b n. (Num a, Storable a)
+ => SNat n
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (SNat n -> Array n a -> Array n a -> Array n a) -- ^ elementwise multiplication
+ -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ dotprod kernel
+ -> Array (n + 1) a -> Array (n + 1) a -> Array n a
+vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner
+ arr1@(Array sh1 strides1 offset1 vec1)
+ arr2@(Array sh2 strides2 offset2 vec2)
+ | null sh1 || null sh2 = error "unreachable"
+ | sh1 /= sh2 = error $ "vectorDotprodInnerOp: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
+ | last sh1 <= 0 = arrayFromConstant (init sh1) 0
+ | any (<= 0) (init sh1) = Array (init sh1) (0 <$ init strides1) 0 VS.empty
+ -- now the input arrays are nonempty
+ | last sh1 == 1 =
+ fmul sn (Array (init sh1) (init strides1) offset1 vec1)
+ (Array (init sh2) (init strides2) offset2 vec2)
+ -- arr1 is constant along the inner dimension: factor it out of the sum
+ | last strides1 == 0 =
+ fmul sn
+ (Array (init sh1) (init strides1) offset1 vec1)
+ (vectorRedInnerOp sn valconv ptrconv fscale fred arr2)
+ -- arr2 is constant along the inner dimension: factor it out of the sum
+ | last strides2 == 0 =
+ fmul sn
+ (vectorRedInnerOp sn valconv ptrconv fscale fred arr1)
+ (Array (init sh2) (init strides2) offset2 vec2)
+ -- now there is useful dotprod work along the inner dimension
+ | otherwise =
+ simplifyArray2 arr1 arr2 $ \(Array sh' strides1' offset1' vec1' :: Array n' a) (Array _ strides2' offset2' vec2') _ _ _ restore ->
+ unsafePerformIO $ do
+ let inrank = length sh'
+ -- the output has the simplified shape minus its inner dimension
+ outv <- VSM.unsafeNew (product (init sh'))
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh')) $ \psh ->
+ VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1')) $ \pstrides1 ->
+ VS.unsafeWith vec1' $ \pvec1 ->
+ VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2')) $ \pstrides2 ->
+ VS.unsafeWith vec2' $ \pvec2 ->
+ fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv)
+ pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1'))
+ pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2'))
+ -- Produce the type-level evidence (1 <= n' and n' - 1 == rank of the
+ -- output) that `restore` needs, by checking the naturals at runtime.
+ TypeNats.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do
+ (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
+ LTI -> pure Dict
+ EQI -> pure Dict
+ GTI -> error "impossible" -- because `last strides1 /= 0`
+ case sameNat (natSing @(n' - 1)) (natSing @n'm1) of
+ Just Refl -> restore . arrayFromVector (init sh') <$> VS.unsafeFreeze outv
+ Nothing -> error "impossible"
+
+-- | Multiply a number by an 'Int', converting the 'Int' to the number's type
+-- first.
+mulWithInt :: Num a => a -> Int -> a
+mulWithInt a = (a *) . fromIntegral
+
+
+-- Generates, for every element type in 'typesList' and every binary operator
+-- (named by 'aboName'), a top-level function
+-- @<op>Vector<Type> :: SNat n -> Array n t -> Array n t -> Array n t@.
+-- Each one goes through 'liftOpEltwise2', with the Haskell operator
+-- 'aboNumOp' for the scalar-scalar case and the C entry points
+-- @c_binary_<ctype>_{sv,vs,vv}_strided@ applied to the operator's enum code
+-- ('aboEnum') for the other cases.
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ cnamebase = "c_binary_" ++ atCName arithtype
+ c_ss_str = varE (aboNumOp arithop)
+ c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> liftOpEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+-- Same scheme as for the general binary operators, restricted to the types in
+-- 'intTypesList' and the integer-only binary operators (named by 'aiboName'):
+-- generates @<op>Vector<Type>@ via 'liftOpEltwise2', using 'aiboNumOp' for
+-- the scalar-scalar case and @c_ibinary_<ctype>_{sv,vs,vv}_strided@ applied to
+-- 'aiboEnum' otherwise.
+$(fmap concat . forM intTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (aiboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ cnamebase = "c_ibinary_" ++ atCName arithtype
+ c_ss_str = varE (aiboNumOp arithop)
+ c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+ c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+ c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> liftOpEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+-- Same scheme as for the general binary operators, restricted to the types in
+-- 'floatTypesList' and the floating-only binary operators (named by
+-- 'afboName'): generates @<op>Vector<Type>@ via 'liftOpEltwise2', using
+-- 'afboNumOp' for the scalar-scalar case and
+-- @c_fbinary_<ctype>_{sv,vs,vv}_strided@ applied to 'afboEnum' otherwise.
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ cnamebase = "c_fbinary_" ++ atCName arithtype
+ c_ss_str = varE (afboNumOp arithop)
+ c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> liftOpEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+-- Generates, for every element type in 'typesList' and every unary operator
+-- (named by 'auoName'), a top-level function
+-- @<op>Vector<Type> :: SNat n -> Array n t -> Array n t@ that runs the C entry
+-- point @c_unary_<ctype>_strided@, applied to the operator's enum code
+-- ('auoEnum'), through 'liftOpEltwise1'.
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op_strided = varE (mkName ("c_unary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> liftOpEltwise1 sn id $c_op_strided |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+-- Same scheme as for the general unary operators, restricted to the types in
+-- 'floatTypesList' and the floating-only unary operators (named by
+-- 'afuoName'): generates @<op>Vector<Type>@ running
+-- @c_funary_<ctype>_strided@, applied to 'afuoEnum', through 'liftOpEltwise1'.
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op_strided = varE (mkName ("c_funary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (afuoEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> liftOpEltwise1 sn id $c_op_strided |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+-- Generates, for every element type in 'typesList' and every reduction
+-- operator (named by 'aroName'), two top-level functions:
+--
+-- * @<op>1Vector<Type> :: SNat n -> Array (n + 1) t -> Array n t@, reducing the
+--   inner dimension via 'vectorRedInnerOp' with @c_reduce1_<ctype>@;
+-- * @<op>FullVector<Type> :: SNat n -> Array n t -> t@, reducing the whole
+--   array via 'vectorRedFullOp' with @c_reducefull_<ctype>@.
+--
+-- The full reduction accounts for repeated elements with 'mulWithInt' for
+-- RO_SUM and with '^' for RO_PRODUCT. The inner reduction is given the
+-- multiply-by-scalar kernel (BO_MUL) as its scaling kernel for BOTH operators.
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let scaleVar = case arithop of
+ RO_SUM -> varE 'mulWithInt
+ RO_PRODUCT -> varE '(^)
+ let name1 = mkName (aroName arithop ++ "1Vector" ++ nameBase (atType arithtype))
+ namefull = mkName (aroName arithop ++ "FullVector" ++ nameBase (atType arithtype))
+ c_op1 = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
+ c_opfull = varE (mkName ("c_reducefull_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
+ c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
+ sequence [SigD name1 <$>
+ [t| forall n. SNat n -> Array (n + 1) $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op1 |]
+ return $ FunD name1 [Clause [] (NormalB body) []]
+ ,SigD namefull <$>
+ [t| forall n. SNat n -> Array n $ttyp -> $ttyp |]
+ ,do body <- [| \sn -> vectorRedFullOp sn $scaleVar id id $c_opfull |]
+ return $ FunD namefull [Clause [] (NormalB body) []]
+ ])
+
+-- For every type in 'typesList' and for both "min" and "max", this splice
+-- declares @<min|max>indexVector<Type> :: Array n t -> [Int]@, defined as
+-- 'vectorExtremumOp' applied to 'id' and the matching
+-- @c_extremum_<min|max>_<ctype>@ import.
+$(fmap concat . forM typesList $ \arithtype ->
+ fmap concat . forM ["min", "max"] $ \fname -> do
+ let ttyp = conT (atType arithtype)
+ name = mkName (fname ++ "indexVector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_extremum_" ++ fname ++ "_" ++ atCName arithtype))
+ sequence [SigD name <$>
+ [t| forall n. Array n $ttyp -> [Int] |]
+ ,do body <- [| vectorExtremumOp id $c_op |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+-- For every type in 'typesList', this splice declares
+-- @dotprodinnerVector<Type> :: SNat n -> Array (n + 1) t -> Array (n + 1) t -> Array n t@.
+-- The body hands 'vectorDotprodInnerOp' four helpers for that type: the
+-- generated @mulVector<Type>@ function, the BO_MUL scalar-vector kernel, the
+-- RO_SUM @c_reduce1@ kernel and the dedicated @c_dotprodinner_<ctype>@ kernel.
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ name = mkName ("dotprodinnerVector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_dotprodinner_" ++ atCName arithtype))
+ mul_op = varE (mkName ("mulVector" ++ nameBase (atType arithtype)))
+ c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
+ c_red_op = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> Array (n + 1) $ttyp -> Array (n + 1) $ttyp -> Array n $ttyp |]
+ ,do body <- [| \sn -> vectorDotprodInnerOp sn id id $mul_op $c_scale_op $c_red_op $c_op |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+-- Raw bindings to the C statistics entry points. 'c_stats_enable' takes an
+-- 'Int32' flag; 'c_stats_print_all' takes no arguments.
+foreign import ccall unsafe "oxarrays_stats_enable" c_stats_enable :: Int32 -> IO ()
+foreign import ccall unsafe "oxarrays_stats_print_all" c_stats_print_all :: IO ()
+
+-- | Forward the switch to the C side: 'True' is passed as 1, 'False' as 0.
+statisticsEnable :: Bool -> IO ()
+statisticsEnable b = c_stats_enable flag
+  where
+    flag
+      | b = 1
+      | otherwise = 0
+
+-- | Print everything in the statistics log. Printing consumes the log, so a
+-- given event is printed at most once even if 'statisticsPrintAll' is called
+-- several times.
+statisticsPrintAll :: IO ()
+statisticsPrintAll =
+  -- Flush Haskell's stdout first to lower the chance of overlapping output.
+  hFlush stdout >> c_stats_print_all
+
+-- | Lift a unary strided kernel to an integer type @i@ whose bit width is only
+-- known through 'finiteBitSize': the first argument is used when @i@ is 32
+-- bits wide, the second when it is 64 bits wide; any other width is an
+-- 'error'. 'castPtr' is passed to 'liftOpEltwise1' as the pointer conversion.
+--
+-- This branch is ostensibly a runtime branch, but will (hopefully) be
+-- constant-folded away by GHC.
+intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
+ => (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
+ -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
+ -> (SNat n -> Array n i -> Array n i)
+intWidBranch1 f32 f64 sn
+ | finiteBitSize (undefined :: i) == 32 = liftOpEltwise1 sn castPtr f32
+ | finiteBitSize (undefined :: i) == 64 = liftOpEltwise1 sn castPtr f64
+ | otherwise = error "Unsupported Int width"
+
+-- | Binary counterpart of 'intWidBranch1'. Takes a scalar-scalar Haskell
+-- operation plus the scalar-vector, vector-scalar and vector-vector kernels
+-- for 32-bit and for 64-bit integers, and selects the kernel triple matching
+-- @finiteBitSize@ of @i@ (any other width is an 'error'). 'fromIntegral' and
+-- 'castPtr' are passed through to 'liftOpEltwise2', presumably to convert
+-- values and pointers between @i@ and the fixed-width type (confirm against
+-- 'liftOpEltwise2').
+intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => (i -> i -> i) -- ss
+ -- int32
+ -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- sv
+ -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ()) -- vs
+ -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- vv
+ -- int64
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- sv
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ()) -- vs
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- vv
+ -> (SNat n -> Array n i -> Array n i -> Array n i)
+intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
+ | finiteBitSize (undefined :: i) == 32 = liftOpEltwise2 sn fromIntegral castPtr ss sv32 vs32 vv32
+ | finiteBitSize (undefined :: i) == 64 = liftOpEltwise2 sn fromIntegral castPtr ss sv64 vs64 vv64
+ | otherwise = error "Unsupported Int width"
+
+-- | Width dispatch for inner-dimension reductions: picks the 32-bit or 64-bit
+-- (scale kernel, reduction kernel) pair according to @finiteBitSize@ of @i@
+-- and hands it to 'vectorRedInnerOp', instantiated at 'Int32' or 'Int64'
+-- respectively. Any other width is an 'error'.
+intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => -- int32
+ (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
+ -- int64
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (SNat n -> Array (n + 1) i -> Array n i)
+intWidBranchRed1 fsc32 fred32 fsc64 fred64 sn
+ | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32
+ | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
+ | otherwise = error "Unsupported Int width"
+
+-- | Width dispatch for full reductions (array to single value): forwards the
+-- Haskell scale operation together with the 32-bit or 64-bit reduction kernel
+-- to 'vectorRedFullOp', chosen by @finiteBitSize@ of @i@. Any other width is
+-- an 'error'.
+intWidBranchRedFull :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => (i -> Int -> i) -- ^ scale op
+ -- int32
+ -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b) -- ^ reduction kernel
+ -- int64
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b) -- ^ reduction kernel
+ -> (SNat n -> Array n i -> i)
+intWidBranchRedFull fsc fred32 fred64 sn
+ | finiteBitSize (undefined :: i) == 32 = vectorRedFullOp @i @Int32 sn fsc fromIntegral castPtr fred32
+ | finiteBitSize (undefined :: i) == 64 = vectorRedFullOp @i @Int64 sn fsc fromIntegral castPtr fred64
+ | otherwise = error "Unsupported Int width"
+
+-- | Width dispatch for the min/max index operations: passes the 32-bit or
+-- 64-bit extremum kernel to 'vectorExtremumOp', chosen by @finiteBitSize@ of
+-- @i@. Any other width is an 'error'.
+intWidBranchExtr :: forall i n. (FiniteBits i, Storable i)
+ => -- int32
+ (forall b. b ~ Int32 => Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ extremum kernel
+ -- int64
+ -> (forall b. b ~ Int64 => Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ extremum kernel
+ -> (Array n i -> [Int])
+intWidBranchExtr fextr32 fextr64
+ | finiteBitSize (undefined :: i) == 32 = vectorExtremumOp @i @Int32 castPtr fextr32
+ | finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64
+ | otherwise = error "Unsupported Int width"
+
+-- | Width dispatch for the inner dot product: passes 'numEltMul' and the
+-- 32-bit or 64-bit (scale, reduction, dotprod) kernel triple to
+-- 'vectorDotprodInnerOp', chosen by @finiteBitSize@ of @i@. Any other width is
+-- an 'error'.
+intWidBranchDotprod :: forall i n. (FiniteBits i, Storable i, Integral i, NumElt i)
+ => -- int32
+ (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ dotprod kernel
+ -- int64
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ dotprod kernel
+ -> (SNat n -> Array (n + 1) i -> Array (n + 1) i -> Array n i)
+intWidBranchDotprod fsc32 fred32 fdot32 fsc64 fred64 fdot64 sn
+ | finiteBitSize (undefined :: i) == 32 = vectorDotprodInnerOp @i @Int32 sn fromIntegral castPtr numEltMul fsc32 fred32 fdot32
+ | finiteBitSize (undefined :: i) == 64 = vectorDotprodInnerOp @i @Int64 sn fromIntegral castPtr numEltMul fsc64 fred64 fdot64
+ | otherwise = error "Unsupported Int width"
+
+-- | Arithmetic on strided arrays for element types with numeric kernels.
+-- Methods that return an array take an 'SNat' for the result's type-level
+-- index @n@. The @1Inner@ reductions and 'numEltDotprodInner' consume arrays
+-- indexed @n + 1@ and produce arrays indexed @n@; the @Full@ reductions
+-- return a single element; the index methods return a list of 'Int'.
+class NumElt a where
+ numEltAdd :: SNat n -> Array n a -> Array n a -> Array n a
+ numEltSub :: SNat n -> Array n a -> Array n a -> Array n a
+ numEltMul :: SNat n -> Array n a -> Array n a -> Array n a
+ numEltNeg :: SNat n -> Array n a -> Array n a
+ numEltAbs :: SNat n -> Array n a -> Array n a
+ numEltSignum :: SNat n -> Array n a -> Array n a
+ numEltSum1Inner :: SNat n -> Array (n + 1) a -> Array n a
+ numEltProduct1Inner :: SNat n -> Array (n + 1) a -> Array n a
+ numEltSumFull :: SNat n -> Array n a -> a
+ numEltProductFull :: SNat n -> Array n a -> a
+ numEltMinIndex :: SNat n -> Array n a -> [Int]
+ numEltMaxIndex :: SNat n -> Array n a -> [Int]
+ numEltDotprodInner :: SNat n -> Array (n + 1) a -> Array (n + 1) a -> Array n a
+
+-- | Every method maps directly to the corresponding @...VectorInt32@ function.
+-- The index methods ignore their 'SNat' argument.
+instance NumElt Int32 where
+ numEltAdd = addVectorInt32
+ numEltSub = subVectorInt32
+ numEltMul = mulVectorInt32
+ numEltNeg = negVectorInt32
+ numEltAbs = absVectorInt32
+ numEltSignum = signumVectorInt32
+ numEltSum1Inner = sum1VectorInt32
+ numEltProduct1Inner = product1VectorInt32
+ numEltSumFull = sumFullVectorInt32
+ numEltProductFull = productFullVectorInt32
+ numEltMinIndex _ = minindexVectorInt32
+ numEltMaxIndex _ = maxindexVectorInt32
+ numEltDotprodInner = dotprodinnerVectorInt32
+
+-- | Every method maps directly to the corresponding @...VectorInt64@ function.
+-- The index methods ignore their 'SNat' argument.
+instance NumElt Int64 where
+ numEltAdd = addVectorInt64
+ numEltSub = subVectorInt64
+ numEltMul = mulVectorInt64
+ numEltNeg = negVectorInt64
+ numEltAbs = absVectorInt64
+ numEltSignum = signumVectorInt64
+ numEltSum1Inner = sum1VectorInt64
+ numEltProduct1Inner = product1VectorInt64
+ numEltSumFull = sumFullVectorInt64
+ numEltProductFull = productFullVectorInt64
+ numEltMinIndex _ = minindexVectorInt64
+ numEltMaxIndex _ = maxindexVectorInt64
+ numEltDotprodInner = dotprodinnerVectorInt64
+
+-- | Every method maps directly to the corresponding @...VectorFloat@ function.
+-- The index methods ignore their 'SNat' argument.
+instance NumElt Float where
+ numEltAdd = addVectorFloat
+ numEltSub = subVectorFloat
+ numEltMul = mulVectorFloat
+ numEltNeg = negVectorFloat
+ numEltAbs = absVectorFloat
+ numEltSignum = signumVectorFloat
+ numEltSum1Inner = sum1VectorFloat
+ numEltProduct1Inner = product1VectorFloat
+ numEltSumFull = sumFullVectorFloat
+ numEltProductFull = productFullVectorFloat
+ numEltMinIndex _ = minindexVectorFloat
+ numEltMaxIndex _ = maxindexVectorFloat
+ numEltDotprodInner = dotprodinnerVectorFloat
+
+-- | Every method maps directly to the corresponding @...VectorDouble@
+-- function. The index methods ignore their 'SNat' argument.
+instance NumElt Double where
+ numEltAdd = addVectorDouble
+ numEltSub = subVectorDouble
+ numEltMul = mulVectorDouble
+ numEltNeg = negVectorDouble
+ numEltAbs = absVectorDouble
+ numEltSignum = signumVectorDouble
+ numEltSum1Inner = sum1VectorDouble
+ numEltProduct1Inner = product1VectorDouble
+ numEltSumFull = sumFullVectorDouble
+ numEltProductFull = productFullVectorDouble
+ numEltMinIndex _ = minindexVectorDouble
+ numEltMaxIndex _ = maxindexVectorDouble
+ numEltDotprodInner = dotprodinnerVectorDouble
+
+-- | There are no dedicated kernels for 'Int'. Each method goes through one of
+-- the @intWidBranch*@ helpers, which pick the i32 or i64 kernel from the bit
+-- width of 'Int'. Every kernel is pre-applied to its operation tag
+-- ('aboEnum', 'auoEnum' or 'aroEnum').
+instance NumElt Int where
+ numEltAdd = intWidBranch2 @Int (+)
+ (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD))
+ (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD))
+ numEltSub = intWidBranch2 @Int (-)
+ (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB))
+ (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB))
+ numEltMul = intWidBranch2 @Int (*)
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL))
+ numEltNeg = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))
+ numEltAbs = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))
+ numEltSignum = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))
+ numEltSum1Inner = intWidBranchRed1 @Int
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM))
+ numEltProduct1Inner = intWidBranchRed1 @Int
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_PRODUCT))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT))
+ numEltSumFull = intWidBranchRedFull @Int (*) (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM))
+ numEltProductFull = intWidBranchRedFull @Int (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
+ numEltMinIndex _ = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
+ numEltMaxIndex _ = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64
+ numEltDotprodInner = intWidBranchDotprod @Int (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
+
+-- | Same scheme as the 'Int' instance: each method goes through an
+-- @intWidBranch*@ helper that picks the i32 or i64 kernel from the bit width
+-- of 'CInt'. The full-sum scale operation is 'mulWithInt' because the scale
+-- factor is an 'Int', not a 'CInt'.
+instance NumElt CInt where
+ numEltAdd = intWidBranch2 @CInt (+)
+ (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD))
+ (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD))
+ numEltSub = intWidBranch2 @CInt (-)
+ (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB))
+ (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB))
+ numEltMul = intWidBranch2 @CInt (*)
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL))
+ numEltNeg = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))
+ numEltAbs = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))
+ numEltSignum = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))
+ numEltSum1Inner = intWidBranchRed1 @CInt
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM))
+ numEltProduct1Inner = intWidBranchRed1 @CInt
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_PRODUCT))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT))
+ numEltSumFull = intWidBranchRedFull @CInt mulWithInt (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM))
+ numEltProductFull = intWidBranchRedFull @CInt (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
+ numEltMinIndex _ = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64
+ numEltMaxIndex _ = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64
+ numEltDotprodInner = intWidBranchDotprod @CInt (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
+
+-- | Integer-only binary operations on strided arrays (quot and rem), on top
+-- of 'NumElt'.
+class NumElt a => IntElt a where
+ intEltQuot :: SNat n -> Array n a -> Array n a -> Array n a
+ intEltRem :: SNat n -> Array n a -> Array n a -> Array n a
+
+-- | Maps directly to the @quotVectorInt32@ / @remVectorInt32@ functions.
+instance IntElt Int32 where
+ intEltQuot = quotVectorInt32
+ intEltRem = remVectorInt32
+
+-- | Maps directly to the @quotVectorInt64@ / @remVectorInt64@ functions.
+instance IntElt Int64 where
+ intEltQuot = quotVectorInt64
+ intEltRem = remVectorInt64
+
+-- | 'Int' has no dedicated kernels, so 'intWidBranch2' picks the i32 or i64
+-- kernels from the bit width of 'Int'.
+--
+-- The operation tags come from 'aiboEnum', so they must be handed to the
+-- integer-only @c_ibinary_*@ kernels. The generic @c_binary_*@ kernels are
+-- driven by 'aboEnum' tags, which belong to a separate enumeration; passing
+-- an 'aiboEnum' tag to them would select whatever generic operation happens
+-- to share that number.
+instance IntElt Int where
+  intEltQuot = intWidBranch2 @Int quot
+                 (c_ibinary_i32_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vv_strided (aiboEnum IB_QUOT))
+                 (c_ibinary_i64_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vv_strided (aiboEnum IB_QUOT))
+  intEltRem = intWidBranch2 @Int rem
+                 (c_ibinary_i32_sv_strided (aiboEnum IB_REM)) (c_ibinary_i32_vs_strided (aiboEnum IB_REM)) (c_ibinary_i32_vv_strided (aiboEnum IB_REM))
+                 (c_ibinary_i64_sv_strided (aiboEnum IB_REM)) (c_ibinary_i64_vs_strided (aiboEnum IB_REM)) (c_ibinary_i64_vv_strided (aiboEnum IB_REM))
+
+-- | 'CInt' has no dedicated kernels, so 'intWidBranch2' picks the i32 or i64
+-- kernels from the bit width of 'CInt'.
+--
+-- The operation tags come from 'aiboEnum', so they must be handed to the
+-- integer-only @c_ibinary_*@ kernels. The generic @c_binary_*@ kernels are
+-- driven by 'aboEnum' tags, which belong to a separate enumeration; passing
+-- an 'aiboEnum' tag to them would select whatever generic operation happens
+-- to share that number.
+instance IntElt CInt where
+  intEltQuot = intWidBranch2 @CInt quot
+                 (c_ibinary_i32_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vv_strided (aiboEnum IB_QUOT))
+                 (c_ibinary_i64_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vv_strided (aiboEnum IB_QUOT))
+  intEltRem = intWidBranch2 @CInt rem
+                 (c_ibinary_i32_sv_strided (aiboEnum IB_REM)) (c_ibinary_i32_vs_strided (aiboEnum IB_REM)) (c_ibinary_i32_vv_strided (aiboEnum IB_REM))
+                 (c_ibinary_i64_sv_strided (aiboEnum IB_REM)) (c_ibinary_i64_vs_strided (aiboEnum IB_REM)) (c_ibinary_i64_vv_strided (aiboEnum IB_REM))
+
+-- | Floating-point-only operations on strided arrays, on top of 'NumElt'.
+-- 'floatEltDiv', 'floatEltPow', 'floatEltLogbase' and 'floatEltAtan2' are
+-- binary; all other methods are unary. Every method takes an 'SNat' for the
+-- arrays' type-level index @n@.
+class NumElt a => FloatElt a where
+ floatEltDiv :: SNat n -> Array n a -> Array n a -> Array n a
+ floatEltPow :: SNat n -> Array n a -> Array n a -> Array n a
+ floatEltLogbase :: SNat n -> Array n a -> Array n a -> Array n a
+ floatEltRecip :: SNat n -> Array n a -> Array n a
+ floatEltExp :: SNat n -> Array n a -> Array n a
+ floatEltLog :: SNat n -> Array n a -> Array n a
+ floatEltSqrt :: SNat n -> Array n a -> Array n a
+ floatEltSin :: SNat n -> Array n a -> Array n a
+ floatEltCos :: SNat n -> Array n a -> Array n a
+ floatEltTan :: SNat n -> Array n a -> Array n a
+ floatEltAsin :: SNat n -> Array n a -> Array n a
+ floatEltAcos :: SNat n -> Array n a -> Array n a
+ floatEltAtan :: SNat n -> Array n a -> Array n a
+ floatEltSinh :: SNat n -> Array n a -> Array n a
+ floatEltCosh :: SNat n -> Array n a -> Array n a
+ floatEltTanh :: SNat n -> Array n a -> Array n a
+ floatEltAsinh :: SNat n -> Array n a -> Array n a
+ floatEltAcosh :: SNat n -> Array n a -> Array n a
+ floatEltAtanh :: SNat n -> Array n a -> Array n a
+ floatEltLog1p :: SNat n -> Array n a -> Array n a
+ floatEltExpm1 :: SNat n -> Array n a -> Array n a
+ floatEltLog1pexp :: SNat n -> Array n a -> Array n a
+ floatEltLog1mexp :: SNat n -> Array n a -> Array n a
+ floatEltAtan2 :: SNat n -> Array n a -> Array n a -> Array n a
+
+-- | Every method maps directly to the corresponding @...VectorFloat@ function.
+instance FloatElt Float where
+ floatEltDiv = divVectorFloat
+ floatEltPow = powVectorFloat
+ floatEltLogbase = logbaseVectorFloat
+ floatEltRecip = recipVectorFloat
+ floatEltExp = expVectorFloat
+ floatEltLog = logVectorFloat
+ floatEltSqrt = sqrtVectorFloat
+ floatEltSin = sinVectorFloat
+ floatEltCos = cosVectorFloat
+ floatEltTan = tanVectorFloat
+ floatEltAsin = asinVectorFloat
+ floatEltAcos = acosVectorFloat
+ floatEltAtan = atanVectorFloat
+ floatEltSinh = sinhVectorFloat
+ floatEltCosh = coshVectorFloat
+ floatEltTanh = tanhVectorFloat
+ floatEltAsinh = asinhVectorFloat
+ floatEltAcosh = acoshVectorFloat
+ floatEltAtanh = atanhVectorFloat
+ floatEltLog1p = log1pVectorFloat
+ floatEltExpm1 = expm1VectorFloat
+ floatEltLog1pexp = log1pexpVectorFloat
+ floatEltLog1mexp = log1mexpVectorFloat
+ floatEltAtan2 = atan2VectorFloat
+
+-- | Every method maps directly to the corresponding @...VectorDouble@
+-- function.
+instance FloatElt Double where
+ floatEltDiv = divVectorDouble
+ floatEltPow = powVectorDouble
+ floatEltLogbase = logbaseVectorDouble
+ floatEltRecip = recipVectorDouble
+ floatEltExp = expVectorDouble
+ floatEltLog = logVectorDouble
+ floatEltSqrt = sqrtVectorDouble
+ floatEltSin = sinVectorDouble
+ floatEltCos = cosVectorDouble
+ floatEltTan = tanVectorDouble
+ floatEltAsin = asinVectorDouble
+ floatEltAcos = acosVectorDouble
+ floatEltAtan = atanVectorDouble
+ floatEltSinh = sinhVectorDouble
+ floatEltCosh = coshVectorDouble
+ floatEltTanh = tanhVectorDouble
+ floatEltAsinh = asinhVectorDouble
+ floatEltAcosh = acoshVectorDouble
+ floatEltAtanh = atanhVectorDouble
+ floatEltLog1p = log1pVectorDouble
+ floatEltExpm1 = expm1VectorDouble
+ floatEltLog1pexp = log1pexpVectorDouble
+ floatEltLog1mexp = log1mexpVectorDouble
+ floatEltAtan2 = atan2VectorDouble
diff --git a/ops/Data/Array/Strided/Arith/Internal/Foreign.hs b/ops/Data/Array/Strided/Arith/Internal/Foreign.hs
new file mode 100644
index 0000000..dad65f9
--- /dev/null
+++ b/ops/Data/Array/Strided/Arith/Internal/Foreign.hs
@@ -0,0 +1,47 @@
+{-# LANGUAGE ForeignFunctionInterface #-}
+{-# LANGUAGE TemplateHaskell #-}
+module Data.Array.Strided.Arith.Internal.Foreign where
+
+import Data.Int
+import Foreign.C.Types
+import Foreign.Ptr
+import Language.Haskell.TH
+
+import Data.Array.Strided.Arith.Internal.Lists
+
+
+-- Generates all @foreign import ccall unsafe@ declarations for the arithmetic
+-- kernels. For each (element type, import) pair the C symbol is
+-- @oxarop_<name>@ and the Haskell binding is @c_<name>@, where @<name>@
+-- embeds the type's C name ('atCName').
+--
+-- * @importsScal@ is generated for every type in 'typesList';
+-- * @importsInt@ (the @ibinary_@ family) only for 'intTypesList';
+-- * @importsFloat@ (the @fbinary_@ / @funary_@ families) only for
+--   'floatTypesList'.
+--
+-- Kernels whose type starts with 'CInt' take an operation tag as first
+-- argument; the extremum and dotprodinner kernels do not.
+$(do
+ let importsScal ttyp tyn =
+ [("binary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("binary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("binary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |])
+ ,("unary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("reduce1_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("reducefull_" ++ tyn, [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |])
+ ,("extremum_min_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("extremum_max_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("dotprodinner_" ++ tyn, [t| Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ]
+
+ let importsInt ttyp tyn =
+ [("ibinary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("ibinary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("ibinary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |])
+ ]
+
+ let importsFloat ttyp tyn =
+ [("fbinary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("fbinary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("fbinary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |])
+ ,("funary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ]
+
+ let generate types imports =
+ sequence
+ [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ name) (mkName ("c_" ++ name)) <$> typ
+ | arithtype <- types
+ , (name, typ) <- imports (conT (atType arithtype)) (atCName arithtype)]
+ decs1 <- generate typesList importsScal
+ decs2 <- generate intTypesList importsInt
+ decs3 <- generate floatTypesList importsFloat
+ return (decs1 ++ decs2 ++ decs3))
diff --git a/ops/Data/Array/Strided/Arith/Internal/Lists.hs b/ops/Data/Array/Strided/Arith/Internal/Lists.hs
new file mode 100644
index 0000000..910a77c
--- /dev/null
+++ b/ops/Data/Array/Strided/Arith/Internal/Lists.hs
@@ -0,0 +1,95 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TemplateHaskell #-}
+module Data.Array.Strided.Arith.Internal.Lists where
+
+import Data.Char
+import Data.Int
+import Language.Haskell.TH
+
+import Data.Array.Strided.Arith.Internal.Lists.TH
+
+
+-- | An element type supported by the C kernels: the Haskell type name
+-- together with the abbreviation used for it in C symbol names.
+data ArithType = ArithType
+ { atType :: Name -- ^ Template Haskell name of the Haskell type, e.g. ''Int32
+ , atCName :: String -- ^ type abbreviation in C symbol names, e.g. "i32"
+ }
+
+-- | The integer element types that have C kernels, in generation order.
+intTypesList :: [ArithType]
+intTypesList =
+  [ ArithType { atType = ''Int32, atCName = "i32" }
+  , ArithType { atType = ''Int64, atCName = "i64" }
+  ]
+
+-- | The floating-point element types that have C kernels, in generation
+-- order.
+floatTypesList :: [ArithType]
+floatTypesList =
+  [ ArithType { atType = ''Float, atCName = "float" }
+  , ArithType { atType = ''Double, atCName = "double" }
+  ]
+
+-- | All element types that have C kernels: the integer types first, then the
+-- floating-point types.
+typesList :: [ArithType]
+typesList = concat [intTypesList, floatTypesList]
+
+-- All declarations in this group are generated from the BINOP entries of
+-- cbits/arith_lists.h.
+
+-- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded)
+$(genArithDataType Binop "ArithBOp")
+
+-- aboName :: ArithBOp -> String
+-- The constructor name without its first three characters, lower-cased
+-- (BO_ADD becomes "add").
+$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3))
+-- aboEnum :: ArithBOp -> CInt
+-- The numeric tag given for the constructor in the list file.
+$(genArithEnumFun Binop ''ArithBOp "aboEnum")
+
+-- aboNumOp :: ArithBOp -> Name
+-- Maps each constructor to 'mkName' applied to the third field of its list
+-- entry (the Haskell operator for that operation).
+$(do clauses <- readArithLists Binop
+ (\name _num hsop -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (VarE 'mkName `AppE` LitE (StringL hsop)))
+ []))
+ return
+ sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |]
+ ,return $ FunD (mkName "aboNumOp") clauses])
+
+
+-- All declarations in this group are generated from the IBINOP entries of
+-- cbits/arith_lists.h.
+
+-- data ArithIBOp = IB_QUOT | IB_REM deriving (Show, Enum, Bounded)
+$(genArithDataType IBinop "ArithIBOp")
+
+-- aiboName :: ArithIBOp -> String
+-- The constructor name without its first three characters, lower-cased.
+$(genArithNameFun IBinop ''ArithIBOp "aiboName" (map toLower . drop 3))
+-- aiboEnum :: ArithIBOp -> CInt
+-- The numeric tag given for the constructor in the list file.
+$(genArithEnumFun IBinop ''ArithIBOp "aiboEnum")
+
+-- aiboNumOp :: ArithIBOp -> Name
+-- Maps each constructor to 'mkName' applied to the third field of its list
+-- entry.
+$(do clauses <- readArithLists IBinop
+ (\name _num hsop -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (VarE 'mkName `AppE` LitE (StringL hsop)))
+ []))
+ return
+ sequence [SigD (mkName "aiboNumOp") <$> [t| ArithIBOp -> Name |]
+ ,return $ FunD (mkName "aiboNumOp") clauses])
+
+
+-- All declarations in this group are generated from the FBINOP entries of
+-- cbits/arith_lists.h.
+
+-- data ArithFBOp = FB_DIV | ... deriving (Show, Enum, Bounded)
+$(genArithDataType FBinop "ArithFBOp")
+
+-- afboName :: ArithFBOp -> String
+-- The constructor name without its first three characters, lower-cased.
+$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . drop 3))
+-- afboEnum :: ArithFBOp -> CInt
+-- The numeric tag given for the constructor in the list file.
+$(genArithEnumFun FBinop ''ArithFBOp "afboEnum")
+
+-- afboNumOp :: ArithFBOp -> Name
+-- Maps each constructor to 'mkName' applied to the third field of its list
+-- entry.
+$(do clauses <- readArithLists FBinop
+ (\name _num hsop -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (VarE 'mkName `AppE` LitE (StringL hsop)))
+ []))
+ return
+ sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |]
+ ,return $ FunD (mkName "afboNumOp") clauses])
+
+
+-- All declarations in this group are generated from the UNOP entries of
+-- cbits/arith_lists.h.
+
+-- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded)
+$(genArithDataType Unop "ArithUOp")
+
+-- auoName :: ArithUOp -> String
+-- The constructor name without its first three characters, lower-cased.
+$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3))
+-- auoEnum :: ArithUOp -> CInt
+-- The numeric tag given for the constructor in the list file.
+$(genArithEnumFun Unop ''ArithUOp "auoEnum")
+
+
+-- All declarations in this group are generated from the FUNOP entries of
+-- cbits/arith_lists.h.
+
+-- data ArithFUOp = FU_RECIP | ... deriving (Show, Enum, Bounded)
+$(genArithDataType FUnop "ArithFUOp")
+
+-- afuoName :: ArithFUOp -> String
+-- The constructor name without its first three characters, lower-cased.
+$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3))
+-- afuoEnum :: ArithFUOp -> CInt
+-- The numeric tag given for the constructor in the list file.
+$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum")
+
+
+-- All declarations in this group are generated from the REDOP entries of
+-- cbits/arith_lists.h.
+
+-- data ArithRedOp = RO_SUM | RO_PRODUCT deriving (Show, Enum, Bounded)
+$(genArithDataType Redop "ArithRedOp")
+
+-- aroName :: ArithRedOp -> String
+-- The constructor name without its first three characters, lower-cased
+-- (RO_SUM becomes "sum").
+$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . drop 3))
+-- aroEnum :: ArithRedOp -> CInt
+-- The numeric tag given for the constructor in the list file.
+$(genArithEnumFun Redop ''ArithRedOp "aroEnum")
diff --git a/ops/Data/Array/Strided/Arith/Internal/Lists/TH.hs b/ops/Data/Array/Strided/Arith/Internal/Lists/TH.hs
new file mode 100644
index 0000000..b8f6a3d
--- /dev/null
+++ b/ops/Data/Array/Strided/Arith/Internal/Lists/TH.hs
@@ -0,0 +1,83 @@
+{-# LANGUAGE TemplateHaskellQuotes #-}
+module Data.Array.Strided.Arith.Internal.Lists.TH where
+
+import Control.Monad
+import Control.Monad.IO.Class
+import Data.Maybe
+import Foreign.C.Types
+import Language.Haskell.TH
+import Language.Haskell.TH.Syntax
+import Text.Read
+
+
+-- | The operation families that can appear in cbits/arith_lists.h; each
+-- constructor corresponds to one @LIST_<KIND>@ macro name (see @parseKind@ in
+-- 'readArithLists').
+data OpKind = Binop | IBinop | FBinop | Unop | FUnop | Redop
+ deriving (Show, Eq)
+
+-- | Read @cbits/arith_lists.h@ at compile time and process the entries of one
+-- operation kind.
+--
+-- The file is registered with 'addDependentFile'. The path is relative, so it
+-- is resolved against the working directory of the compiler process.
+--
+-- Every line that is not empty or all spaces must have the form
+-- @LIST_<KIND>(<name>, <num>, <aux>)@, where @<KIND>@ is one of the strings
+-- known to @parseKind@ and @<num>@ parses as an 'Int'. Any other line aborts
+-- with 'error'. Leading spaces of a field are dropped; trailing spaces are
+-- not.
+--
+-- For each entry whose kind equals @targetkind@, @fop name num aux@ is run.
+-- The results are passed, in file order, to @fcombine@.
+readArithLists :: OpKind
+ -> (String -> Int -> String -> Q a)
+ -> ([a] -> Q r)
+ -> Q r
+readArithLists targetkind fop fcombine = do
+ addDependentFile "cbits/arith_lists.h"
+ lns <- liftIO $ lines <$> readFile "cbits/arith_lists.h"
+
+ mvals <- forM lns $ \line -> do
+ if null (dropWhile (== ' ') line)
+ then return Nothing
+ else do let (kind, name, num, aux) = parseLine line
+ if kind == targetkind
+ then Just <$> fop name num aux
+ else return Nothing
+
+ fcombine (catMaybes mvals)
+ where
+ -- Splits "LIST_KIND(f1, f2, f3)..." into (kind, f1, read f2, f3).
+ parseLine s0
+ | ("LIST_", s1) <- splitAt 5 s0
+ , (kindstr, '(' : s2) <- break (== '(') s1
+ , (f1, ',' : s3) <- parseField s2
+ , (f2, ',' : s4) <- parseField s3
+ , (f3, ')' : _) <- parseField s4
+ , Just kind <- parseKind kindstr
+ , let name = f1
+ , Just num <- readMaybe f2
+ , let aux = f3
+ = (kind, name, num, aux)
+ | otherwise
+ = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0
+
+ -- Skips leading spaces, then returns the text up to (not including) the
+ -- next ',' or ')' together with the remainder.
+ parseField s = break (`elem` ",)") (dropWhile (== ' ') s)
+
+ parseKind "BINOP" = Just Binop
+ parseKind "IBINOP" = Just IBinop
+ parseKind "FBINOP" = Just FBinop
+ parseKind "UNOP" = Just Unop
+ parseKind "FUNOP" = Just FUnop
+ parseKind "REDOP" = Just Redop
+ parseKind _ = Nothing
+
+-- | Declare a data type named @dtname@ with one nullary constructor per list
+-- entry of the given kind (in file order), deriving 'Show', 'Enum' and
+-- 'Bounded'.
+genArithDataType :: OpKind -> String -> Q [Dec]
+genArithDataType kind dtname =
+  readArithLists kind mkCon (return . mkDecs)
+  where
+    -- Only the entry's name matters here; number and aux field are unused.
+    mkCon name _num _aux = return (NormalC (mkName name) [])
+
+    derivs = DerivClause Nothing (map ConT [''Show, ''Enum, ''Bounded])
+
+    mkDecs cons = [DataD [] (mkName dtname) [] Nothing cons [derivs]]
+
+-- | Declare @funname :: dtname -> String@ with one clause per list entry of
+-- the given kind: the constructor named after the entry is mapped to the
+-- string literal @nametrans name@.
+genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec]
+genArithNameFun kind dtname funname nametrans = do
+  let fname = mkName funname
+      signature = SigD fname (AppT (AppT ArrowT (ConT dtname)) (ConT ''String))
+      mkClause name _num _aux =
+        return (Clause [ConP (mkName name) [] []]
+                       (NormalB (LitE (StringL (nametrans name))))
+                       [])
+  clauses <- readArithLists kind mkClause return
+  return [signature, FunD fname clauses]
+
+-- | Declare @funname :: dtname -> CInt@ with one clause per list entry of the
+-- given kind: the constructor named after the entry is mapped to the entry's
+-- number as an integer literal.
+genArithEnumFun :: OpKind -> Name -> String -> Q [Dec]
+genArithEnumFun kind dtname funname = do
+  let fname = mkName funname
+      signature = SigD fname (AppT (AppT ArrowT (ConT dtname)) (ConT ''CInt))
+      mkClause name num _aux =
+        return (Clause [ConP (mkName name) [] []]
+                       (NormalB (LitE (IntegerL (fromIntegral num))))
+                       [])
+  clauses <- readArithLists kind mkClause return
+  return [signature, FunD fname clauses]
diff --git a/ops/Data/Array/Strided/Array.hs b/ops/Data/Array/Strided/Array.hs
new file mode 100644
index 0000000..9280fe0
--- /dev/null
+++ b/ops/Data/Array/Strided/Array.hs
@@ -0,0 +1,44 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+module Data.Array.Strided.Array where
+
+import Data.List.NonEmpty qualified as NE
+import Data.Proxy
+import Data.Vector.Storable qualified as VS
+import Foreign.Storable
+import GHC.TypeLits
+
+
+-- | A strided view on a storable vector. The type-level @n@ is meant to be
+-- the rank (the length of 'arrShape'); among the functions in this module
+-- only 'arrayFromVector' checks that. All fields are strict.
+data Array (n :: Nat) a = Array
+ { arrShape :: ![Int] -- ^ size of each dimension
+ , arrStrides :: ![Int] -- ^ one stride per dimension; may be 0 (see 'arrayFromConstant') or negative (see 'arrayRevDims')
+ , arrOffset :: !Int -- ^ starting offset; 0 for freshly built arrays, adjusted by 'arrayRevDims'
+ , arrValues :: !(VS.Vector a) -- ^ the underlying element storage
+ }
+
+-- | Takes a vector in normalised order (inner dimension, i.e. last in the
+-- list, iterates fastest) and wraps it in an 'Array' with offset 0 and the
+-- matching row-major strides.
+--
+-- Calls 'error' if the vector length is not the product of the shape, or if
+-- the shape's length is not the type-level rank @n@. The two failures are
+-- reported with different messages so that a rank mismatch is not mistaken
+-- for a length mismatch.
+arrayFromVector :: forall a n. (Storable a, KnownNat n) => [Int] -> VS.Vector a -> Array n a
+arrayFromVector sh vec
+  | VS.length vec /= shsize
+  = error $ "arrayFromVector: Shape " ++ show sh ++ " does not match vector length " ++ show (VS.length vec)
+  | length sh /= rank
+  = error $ "arrayFromVector: Shape " ++ show sh ++ " does not have the expected rank " ++ show rank
+  | otherwise
+  = Array sh strides 0 vec
+  where
+    rank = fromIntegral (natVal (Proxy @n)) :: Int
+    shsize = product sh
+    -- scanr yields the suffix products [d1*...*dk, ..., dk, 1]; dropping the
+    -- head leaves one stride per dimension, with stride 1 for the last one.
+    strides = NE.tail (NE.scanr (*) 1 sh)
+
+-- | An array of the given shape in which every position refers to the same
+-- single stored element: all strides are 0 and the backing vector has length
+-- one.
+arrayFromConstant :: Storable a => [Int] -> a -> Array n a
+arrayFromConstant sh x = Array sh zeroStrides 0 (VS.singleton x)
+  where
+    zeroStrides = map (const 0) sh
+
+-- | Reverse the dimensions flagged with 'True', without touching the stored
+-- values: the stride of each flagged dimension is negated and the offset is
+-- moved by @(size - 1) * stride@ for that dimension. Calls 'error' if the
+-- number of flags differs from the length of the shape.
+arrayRevDims :: [Bool] -> Array n a -> Array n a
+arrayRevDims bs (Array sh strides offset vec)
+  | nflags /= rank
+  = error $ "arrayRevDims: " ++ show nflags ++ " booleans given but rank " ++ show rank
+  | otherwise
+  = Array sh strides' offset' vec
+  where
+    nflags = length bs
+    rank = length sh
+    strides' = [if b then negate s else s | (b, s) <- zip bs strides]
+    -- Only flagged dimensions contribute to the offset shift.
+    offset' = offset + sum [(n - 1) * s | (True, n, s) <- zip3 bs sh strides]
diff --git a/ox-arrays.cabal b/ox-arrays.cabal
index 0aa7001..be4bb03 100644
--- a/ox-arrays.cabal
+++ b/ox-arrays.cabal
@@ -1,32 +1,184 @@
cabal-version: 3.0
name: ox-arrays
version: 0.1.0.0
-author: Tom Smeding
+synopsis: An efficient CPU-based multidimensional array (tensor) library
+description:
+ An efficient and richly typed CPU-based multidimensional array (tensor)
+ library built upon the optimized tensor representation (strides list)
+ implemented in the orthotope package. See the README.
+
+ If you use this package: let me know (e.g. via email) if you find it useful!
+ Both positive feedback (keep this!) and negative feedback (I needed this but
+ ox-arrays doesn't provide it) are welcome.
+copyright: (c) 2025 Tom Smeding, Mikolaj Konarski
+author: Tom Smeding, Mikolaj Konarski
+maintainer: Tom Smeding <xhackage@tomsmeding.com>
license: BSD-3-Clause
+category: Array, Tensors
build-type: Simple
+extra-doc-files: README.md CHANGELOG.md
+extra-source-files: cbits/arith_lists.h
+
+flag trace-wrappers
+ description:
+ Compile modules that define wrappers around the array methods that trace
+ their arguments and results. This is conditional on a flag because these
+ modules make documentation generation fail.
+ (@https://gitlab.haskell.org/ghc/ghc/-/issues/24964@ , should be fixed in
+ GHC 9.12)
+ default: False
+ manual: True
+
+flag nonportable-simd
+ description:
+ Assume the binary will be run on the same CPU as where it is built. Setting
+ this flag causes `-march=native` to be passed to the C compiler when
+ compiling arithmetic operations. The result is generally much faster
+ arithmetic operations, but the executable is much less portable to
+ different computers.
+ default: False
+ manual: True
+
+flag pedantic-c-warnings
+ description:
+ Compile embedded C code with a high warning level. Only useful for
+ ox-arrays developers.
+ default: False
+ manual: True
+
+flag default-show-instances
+ description:
+ Use default GHC-derived Show instances for arrays, shapes and indices. This
+ exposes the internal struct-of-arrays representation and is less readable,
+ but can be useful for ox-arrays debugging.
+ default: False
+ manual: True
+
+common basics
+ default-language: Haskell2010
+ ghc-options: -Wall -Wcompat -Widentities -Wunused-packages
+
library
+ import: basics
exposed-modules:
- Data.Array.Mixed
+ -- put this module on top so ghci considers it the "main" module
Data.Array.Nested
- Data.Array.Nested.Internal
- Data.INat
+
+ Data.Array.Nested.Convert
+ Data.Array.Nested.Mixed
+ Data.Array.Nested.Mixed.Shape
+ Data.Array.Nested.Lemmas
+ Data.Array.Nested.Permutation
+ Data.Array.Nested.Ranked
+ Data.Array.Nested.Ranked.Base
+ Data.Array.Nested.Ranked.Shape
+ Data.Array.Nested.Shaped
+ Data.Array.Nested.Shaped.Base
+ Data.Array.Nested.Shaped.Shape
+ Data.Array.Nested.Types
+ Data.Array.Strided.Orthotope
+ Data.Array.XArray
+ Data.Bag
+
+ if flag(trace-wrappers)
+ exposed-modules:
+ Data.Array.Nested.Trace
+ Data.Array.Nested.Trace.TH
+ build-depends:
+ template-haskell
+ other-extensions: TemplateHaskell
+
+ if flag(default-show-instances)
+ cpp-options: -DOXAR_DEFAULT_SHOW_INSTANCES
+
build-depends:
- base >=4.18 && <4.20,
+ strided-array-ops,
+
+ base,
+ deepseq < 1.7,
ghc-typelits-knownnat,
- -- ghc-typelits-natnormalise,
- orthotope,
+ ghc-typelits-natnormalise,
+ orthotope < 0.2,
vector
hs-source-dirs: src
- default-language: Haskell2010
- ghc-options: -Wall
+
+library strided-array-ops
+ import: basics
+ exposed-modules:
+ Data.Array.Strided
+ Data.Array.Strided.Array
+ Data.Array.Strided.Arith
+ Data.Array.Strided.Arith.Internal
+ Data.Array.Strided.Arith.Internal.Foreign
+ Data.Array.Strided.Arith.Internal.Lists
+ Data.Array.Strided.Arith.Internal.Lists.TH
+ build-depends:
+ base >=4.18 && <4.22,
+ ghc-typelits-knownnat < 1,
+ ghc-typelits-natnormalise < 1,
+ template-haskell < 3,
+ vector < 0.14
+ hs-source-dirs: ops
+ c-sources: cbits/arith.c
+
+ cc-options: -O3 -std=c11
+ if flag(pedantic-c-warnings)
+ cc-options: -Wall -Wextra -pedantic
+ if flag(nonportable-simd)
+ cc-options: -march=native
+ elif arch(x86_64) || arch(i386)
+ -- hmatrix assumes sse2, so we can too
+ cc-options: -msse2
+
+ other-extensions: TemplateHaskell
test-suite test
+ import: basics
type: exitcode-stdio-1.0
main-is: Main.hs
+ other-modules:
+ Gen
+ Tests.C
+ Tests.Permutation
+ Util
build-depends:
ox-arrays,
- base
+ base,
+ bytestring,
+ ghc-typelits-knownnat,
+ ghc-typelits-natnormalise,
+ hedgehog,
+ orthotope,
+ random >= 1.3.0,
+ tasty,
+ tasty-hedgehog,
+ vector
hs-source-dirs: test
- default-language: Haskell2010
- ghc-options: -Wall
+
+test-suite example
+ import: basics
+ type: exitcode-stdio-1.0
+ main-is: Main.hs
+ build-depends:
+ ox-arrays,
+ base
+ hs-source-dirs: example
+
+benchmark bench
+ import: basics
+ type: exitcode-stdio-1.0
+ main-is: Main.hs
+ build-depends:
+ ox-arrays,
+ strided-array-ops,
+ base,
+ hmatrix,
+ orthotope,
+ tasty-bench,
+ vector
+ hs-source-dirs: bench
+
+source-repository head
+ type: git
+ location: https://git.tomsmeding.com/ox-arrays
diff --git a/release-hints.txt b/release-hints.txt
new file mode 100644
index 0000000..d300da0
--- /dev/null
+++ b/release-hints.txt
@@ -0,0 +1,3 @@
+- Temporarily enable -Wredundant-constraints
+ - Has too many false positives to enable normally, but sometimes catches actual redundant constraints
+- Don't forget to rerun gentrace.sh
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
deleted file mode 100644
index 0351beb..0000000
--- a/src/Data/Array/Mixed.hs
+++ /dev/null
@@ -1,416 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveFoldable #-}
-{-# LANGUAGE DeriveFunctor #-}
-{-# LANGUAGE DerivingStrategies #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE PatternSynonyms #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed where
-
-import qualified Data.Array.RankedS as S
-import qualified Data.Array.Ranked as ORB
-import Data.Coerce
-import Data.Kind
-import Data.Proxy
-import Data.Type.Equality
-import qualified Data.Vector.Storable as VS
-import Foreign.Storable (Storable)
-import GHC.TypeLits
-import Unsafe.Coerce (unsafeCoerce)
-
-import Data.INat
-
-
--- | The 'SNat' pattern synonym is complete, but it doesn't have a
--- @COMPLETE@ pragma. This copy of it does.
-pattern GHC_SNat :: () => KnownNat n => SNat n
-pattern GHC_SNat = SNat
-{-# COMPLETE GHC_SNat #-}
-
-fromSNat' :: SNat n -> Int
-fromSNat' = fromIntegral . fromSNat
-
-
--- | Type-level list append.
-type family l1 ++ l2 where
- '[] ++ l2 = l2
- (x : xs) ++ l2 = x : xs ++ l2
-
-lemAppNil :: l ++ '[] :~: l
-lemAppNil = unsafeCoerce Refl
-
-lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
-lemAppAssoc _ _ _ = unsafeCoerce Refl
-
-type IxX :: [Maybe Nat] -> Type -> Type
-data IxX sh i where
- ZIX :: IxX '[] i
- (:.@) :: forall n sh i. i -> IxX sh i -> IxX (Just n : sh) i
- (:.?) :: forall sh i. i -> IxX sh i -> IxX (Nothing : sh) i
-deriving instance Show i => Show (IxX sh i)
-deriving instance Eq i => Eq (IxX sh i)
-deriving instance Ord i => Ord (IxX sh i)
-deriving instance Functor (IxX sh)
-deriving instance Foldable (IxX sh)
-infixr 3 :.@
-infixr 3 :.?
-
-type IIxX sh = IxX sh Int
-
-type ShX :: [Maybe Nat] -> Type -> Type
-data ShX sh i where
- ZSX :: ShX '[] i
- (:$@) :: forall n sh i. SNat n -> ShX sh i -> ShX (Just n : sh) i
- (:$?) :: forall sh i. i -> ShX sh i -> ShX (Nothing : sh) i
-deriving instance Show i => Show (ShX sh i)
-deriving instance Eq i => Eq (ShX sh i)
-deriving instance Ord i => Ord (ShX sh i)
-deriving instance Functor (ShX sh)
-deriving instance Foldable (ShX sh)
-infixr 3 :$@
-infixr 3 :$?
-
-type IShX sh = ShX sh Int
-
--- | The part of a shape that is statically known.
-type StaticShX :: [Maybe Nat] -> Type
-data StaticShX sh where
- ZKSX :: StaticShX '[]
- (:!$@) :: SNat n -> StaticShX sh -> StaticShX (Just n : sh)
- (:!$?) :: () -> StaticShX sh -> StaticShX (Nothing : sh)
-deriving instance Show (StaticShX sh)
-infixr 3 :!$@
-infixr 3 :!$?
-
--- | Evidence for the static part of a shape.
-type KnownShapeX :: [Maybe Nat] -> Constraint
-class KnownShapeX sh where
- knownShapeX :: StaticShX sh
-instance KnownShapeX '[] where
- knownShapeX = ZKSX
-instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where
- knownShapeX = natSing :!$@ knownShapeX
-instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
- knownShapeX = () :!$? knownShapeX
-
-type family Rank sh where
- Rank '[] = Z
- Rank (_ : sh) = S (Rank sh)
-
-type XArray :: [Maybe Nat] -> Type -> Type
-newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a)
- deriving (Show)
-
-zeroIxX :: StaticShX sh -> IIxX sh
-zeroIxX ZKSX = ZIX
-zeroIxX (_ :!$@ ssh) = 0 :.@ zeroIxX ssh
-zeroIxX (_ :!$? ssh) = 0 :.? zeroIxX ssh
-
-zeroIxX' :: IShX sh -> IIxX sh
-zeroIxX' ZSX = ZIX
-zeroIxX' (_ :$@ sh) = 0 :.@ zeroIxX' sh
-zeroIxX' (_ :$? sh) = 0 :.? zeroIxX' sh
-
--- This is a weird operation, so it has a long name
-completeShXzeros :: StaticShX sh -> IShX sh
-completeShXzeros ZKSX = ZSX
-completeShXzeros (n :!$@ ssh) = n :$@ completeShXzeros ssh
-completeShXzeros (_ :!$? ssh) = 0 :$? completeShXzeros ssh
-
-ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh')
-ixAppend ZIX idx' = idx'
-ixAppend (i :.@ idx) idx' = i :.@ ixAppend idx idx'
-ixAppend (i :.? idx) idx' = i :.? ixAppend idx idx'
-
-shAppend :: IShX sh -> IShX sh' -> IShX (sh ++ sh')
-shAppend ZSX sh' = sh'
-shAppend (n :$@ sh) sh' = n :$@ shAppend sh sh'
-shAppend (n :$? sh) sh' = n :$? shAppend sh sh'
-
-ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh'
-ixDrop sh ZIX = sh
-ixDrop (_ :.@ sh) (_ :.@ idx) = ixDrop sh idx
-ixDrop (_ :.? sh) (_ :.? idx) = ixDrop sh idx
-
-ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
-ssxAppend ZKSX sh' = sh'
-ssxAppend (n :!$@ sh) sh' = n :!$@ ssxAppend sh sh'
-ssxAppend (() :!$? sh) sh' = () :!$? ssxAppend sh sh'
-
-shapeSize :: IShX sh -> Int
-shapeSize ZSX = 1
-shapeSize (n :$@ sh) = fromSNat' n * shapeSize sh
-shapeSize (n :$? sh) = n * shapeSize sh
-
--- | This may fail if @sh@ has @Nothing@s in it.
-ssxToShape' :: StaticShX sh -> Maybe (IShX sh)
-ssxToShape' ZKSX = Just ZSX
-ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh
-ssxToShape' (_ :!$? _) = Nothing
-
-fromLinearIdx :: IShX sh -> Int -> IIxX sh
-fromLinearIdx = \sh i -> case go sh i of
- (idx, 0) -> idx
- _ -> error $ "fromLinearIdx: out of range (" ++ show i ++
- " in array of shape " ++ show sh ++ ")"
- where
- -- returns (index in subarray, remaining index in enclosing array)
- go :: IShX sh -> Int -> (IIxX sh, Int)
- go ZSX i = (ZIX, i)
- go (n :$@ sh) i =
- let (idx, i') = go sh i
- (upi, locali) = i' `quotRem` fromSNat' n
- in (locali :.@ idx, upi)
- go (n :$? sh) i =
- let (idx, i') = go sh i
- (upi, locali) = i' `quotRem` n
- in (locali :.? idx, upi)
-
-toLinearIdx :: IShX sh -> IIxX sh -> Int
-toLinearIdx = \sh i -> fst (go sh i)
- where
- -- returns (index in subarray, size of subarray)
- go :: IShX sh -> IIxX sh -> (Int, Int)
- go ZSX ZIX = (0, 1)
- go (n :$@ sh) (i :.@ ix) =
- let (lidx, sz) = go sh ix
- in (sz * i + lidx, fromSNat' n * sz)
- go (n :$? sh) (i :.? ix) =
- let (lidx, sz) = go sh ix
- in (sz * i + lidx, n * sz)
-
-enumShape :: IShX sh -> [IIxX sh]
-enumShape = \sh -> go sh id []
- where
- go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a]
- go ZSX f = (f ZIX :)
- go (n :$@ sh) f = foldr (.) id [go sh (f . (i :.@)) | i <- [0 .. fromSNat' n - 1]]
- go (n :$? sh) f = foldr (.) id [go sh (f . (i :.?)) | i <- [0 .. n-1]]
-
-shapeLshape :: IShX sh -> S.ShapeL
-shapeLshape ZSX = []
-shapeLshape (n :$@ sh) = fromSNat' n : shapeLshape sh
-shapeLshape (n :$? sh) = n : shapeLshape sh
-
-ssxLength :: StaticShX sh -> Int
-ssxLength ZKSX = 0
-ssxLength (_ :!$@ ssh) = 1 + ssxLength ssh
-ssxLength (_ :!$? ssh) = 1 + ssxLength ssh
-
-ssxIotaFrom :: Int -> StaticShX sh -> [Int]
-ssxIotaFrom _ ZKSX = []
-ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh
-ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh
-
-lemRankApp :: StaticShX sh1 -> StaticShX sh2
- -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2)
-lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this
-
-lemRankAppComm :: StaticShX sh1 -> StaticShX sh2
- -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1))
-lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this
-
-lemKnownINatRank :: IShX sh -> Dict KnownINat (Rank sh)
-lemKnownINatRank ZSX = Dict
-lemKnownINatRank (_ :$@ sh) | Dict <- lemKnownINatRank sh = Dict
-lemKnownINatRank (_ :$? sh) | Dict <- lemKnownINatRank sh = Dict
-
-lemKnownINatRankSSX :: StaticShX sh -> Dict KnownINat (Rank sh)
-lemKnownINatRankSSX ZKSX = Dict
-lemKnownINatRankSSX (_ :!$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
-lemKnownINatRankSSX (_ :!$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
-
-lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh
-lemKnownShapeX ZKSX = Dict
-lemKnownShapeX (GHC_SNat :!$@ ssh) | Dict <- lemKnownShapeX ssh = Dict
-lemKnownShapeX (() :!$? ssh) | Dict <- lemKnownShapeX ssh = Dict
-
-lemAppKnownShapeX :: StaticShX sh1 -> StaticShX sh2 -> Dict KnownShapeX (sh1 ++ sh2)
-lemAppKnownShapeX ZKSX ssh' = lemKnownShapeX ssh'
-lemAppKnownShapeX (GHC_SNat :!$@ ssh) ssh'
- | Dict <- lemAppKnownShapeX ssh ssh'
- = Dict
-lemAppKnownShapeX (() :!$? ssh) ssh'
- | Dict <- lemAppKnownShapeX ssh ssh'
- = Dict
-
-shape :: forall sh a. KnownShapeX sh => XArray sh a -> IShX sh
-shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr)
- where
- go :: StaticShX sh' -> [Int] -> IShX sh'
- go ZKSX [] = ZSX
- go (n :!$@ ssh) (_ : l) = n :$@ go ssh l
- go (() :!$? ssh) (n : l) = n :$? go ssh l
- go _ _ = error "Invalid shapeL"
-
-fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
-fromVector sh v
- | Dict <- lemKnownINatRank sh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
- = XArray (S.fromVector (shapeLshape sh) v)
-
-toVector :: Storable a => XArray sh a -> VS.Vector a
-toVector (XArray arr) = S.toVector arr
-
-scalar :: Storable a => a -> XArray '[] a
-scalar = XArray . S.scalar
-
-unScalar :: Storable a => XArray '[] a -> a
-unScalar (XArray a) = S.unScalar a
-
-constant :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
-constant sh x
- | Dict <- lemKnownINatRank sh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
- = XArray (S.constant (shapeLshape sh) x)
-
-generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a
-generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)
-
--- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a)
--- generateM sh f | Dict <- lemKnownINatRank sh =
--- XArray . S.fromVector (shapeLshape sh)
--- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh)
-
-indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a
-indexPartial (XArray arr) ZIX = XArray arr
-indexPartial (XArray arr) (i :.@ idx) = indexPartial (XArray (S.index arr i)) idx
-indexPartial (XArray arr) (i :.? idx) = indexPartial (XArray (S.index arr i)) idx
-
-index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a
-index xarr i
- | Refl <- lemAppNil @sh
- = let XArray arr' = indexPartial xarr i :: XArray '[] a
- in S.unScalar arr'
-
-type family AddMaybe n m where
- AddMaybe Nothing _ = Nothing
- AddMaybe (Just _) Nothing = Nothing
- AddMaybe (Just n) (Just m) = Just (n + m)
-
-append :: forall n m sh a. (KnownShapeX sh, Storable a)
- => XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
-append (XArray a) (XArray b)
- | Dict <- lemKnownINatRankSSX (knownShapeX @sh)
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
- = XArray (S.append a b)
-
-rerank :: forall sh sh1 sh2 a b.
- (Storable a, Storable b)
- => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
- -> (XArray sh1 a -> XArray sh2 b)
- -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
-rerank ssh ssh1 ssh2 f (XArray arr)
- | Dict <- lemKnownINatRankSSX ssh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
- , Dict <- lemKnownINatRankSSX ssh2
- , Dict <- knownNatFromINat (Proxy @(Rank sh2))
- , Refl <- lemRankApp ssh ssh1
- , Refl <- lemRankApp ssh ssh2
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
- , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
- = XArray (S.rerank @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2))
- (\a -> unXArray (f (XArray a)))
- arr)
- where
- unXArray (XArray a) = a
-
-rerankTop :: forall sh sh1 sh2 a b.
- (Storable a, Storable b)
- => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh
- -> (XArray sh1 a -> XArray sh2 b)
- -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b
-rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh
-
-rerank2 :: forall sh sh1 sh2 a b c.
- (Storable a, Storable b, Storable c)
- => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
- -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
- -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
-rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
- | Dict <- lemKnownINatRankSSX ssh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
- , Dict <- lemKnownINatRankSSX ssh2
- , Dict <- knownNatFromINat (Proxy @(Rank sh2))
- , Refl <- lemRankApp ssh ssh1
- , Refl <- lemRankApp ssh ssh2
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
- , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
- = XArray (S.rerank2 @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2))
- (\a b -> unXArray (f (XArray a) (XArray b)))
- arr1 arr2)
- where
- unXArray (XArray a) = a
-
--- | The list argument gives indices into the original dimension list.
-transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a
-transpose perm (XArray arr)
- | Dict <- lemKnownINatRankSSX (knownShapeX @sh)
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
- = XArray (S.transpose perm arr)
-
-transpose2 :: forall sh1 sh2 a.
- StaticShX sh1 -> StaticShX sh2
- -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a
-transpose2 ssh1 ssh2 (XArray arr)
- | Refl <- lemRankApp ssh1 ssh2
- , Refl <- lemRankApp ssh2 ssh1
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh2)
- , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh2)))
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh2 ssh1)
- , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh1)))
- , Refl <- lemRankAppComm ssh1 ssh2
- , let n1 = ssxLength ssh1
- = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
-
-sumFull :: (Storable a, Num a) => XArray sh a -> a
-sumFull (XArray arr) = S.sumA arr
-
-sumInner :: forall sh sh' a. (Storable a, Num a)
- => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a
-sumInner ssh ssh'
- | Refl <- lemAppNil @sh
- = rerank ssh ssh' ZKSX (scalar . sumFull)
-
-sumOuter :: forall sh sh' a. (Storable a, Num a)
- => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a
-sumOuter ssh ssh'
- | Refl <- lemAppNil @sh
- = sumInner ssh' ssh . transpose2 ssh ssh'
-
-fromList1 :: forall n sh a. Storable a
- => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
-fromList1 ssh l
- | Dict <- lemKnownINatRankSSX ssh
- , Dict <- knownNatFromINat (Proxy @(Rank (n : sh)))
- = case ssh of
- m@GHC_SNat :!$@ _ | natVal m /= fromIntegral (length l) ->
- error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++
- "does not match the type (" ++ show (natVal m) ++ ")"
- _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l)))
-
-toList1 :: Storable a => XArray (n : sh) a -> [XArray sh a]
-toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr))
-
--- | Throws if the given shape is not, in fact, empty.
-empty :: forall sh a. Storable a => IShX sh -> XArray sh a
-empty sh
- | Dict <- lemKnownINatRank sh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
- = XArray (S.constant (shapeLshape sh)
- (error "Data.Array.Mixed.empty: shape was not empty"))
-
-slice :: [(Int, Int)] -> XArray sh a -> XArray sh a
-slice ivs (XArray arr) = XArray (S.slice ivs arr)
-
-rev1 :: XArray (n : sh) a -> XArray (n : sh) a
-rev1 (XArray arr) = XArray (S.rev [0] arr)
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index ec5f0b5..c3635e9 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -2,52 +2,126 @@
{-# LANGUAGE PatternSynonyms #-}
module Data.Array.Nested (
-- * Ranked arrays
- Ranked,
- ListR(ZR, (:::)), knownListR,
- IxR(.., ZIR, (:.:)), IIxR, knownIxR,
- ShR(.., ZSR, (:$:)), knownShR,
- rshape, rindex, rindexPartial, rgenerate, rsumOuter1,
- rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar,
- rconstant, rfromList, rfromList1, rtoList, rtoList1,
- rslice, rrev1,
+ Ranked(Ranked),
+ ListR(ZR, (:::)),
+ IxR(.., ZIR, (:.:)), IIxR,
+ ShR(.., ZSR, (:$:)), IShR,
+ rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rsumOuter1, rsumAllPrim,
+ rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar,
+ remptyArray,
+ rrerank,
+ rreplicate, rreplicateScal,
+ rfromList1, rfromListOuter, rfromListLinear, rfromListPrim, rfromListPrimLinear,
+ rtoList, rtoListOuter, rtoListLinear,
+ rslice, rrev1, rreshape, rflatten, riota,
+ rminIndexPrim, rmaxIndexPrim, rdot1Inner, rdot,
+ rnest, runNest, rzip, runzip,
-- ** Lifting orthotope operations to 'Ranked' arrays
- rlift,
+ rlift, rlift2,
+ -- ** Conversions
+ rtoXArrayPrim, rfromXArrayPrim,
+ rtoMixed, rcastToMixed, rcastToShaped,
+ rfromOrthotope, rtoOrthotope,
+ -- ** Additional arithmetic operations
+ --
+ -- $integralRealFloat
+ rquotArray, rremArray, ratan2Array,
-- * Shaped arrays
- Shaped,
+ Shaped(Shaped),
ListS(ZS, (::$)),
IxS(.., ZIS, (:.$)), IIxS,
- ShS(..), KnownShape(..),
- sshape, sindex, sindexPartial, sgenerate, ssumOuter1,
+ ShS(.., ZSS, (:$$)), KnownShS(..),
+ sshape, srank, ssize, sindex, sindexPartial, sgenerate, ssumOuter1, ssumAllPrim,
stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,
- sconstant, sfromList, sfromList1, stoList, stoList1,
- sslice, srev1,
+ -- TODO: sconcat? What should its type be?
+ semptyArray,
+ srerank,
+ sreplicate, sreplicateScal,
+ sfromList1, sfromListOuter, sfromListLinear, sfromListPrim, sfromListPrimLinear,
+ stoList, stoListOuter, stoListLinear,
+ sslice, srev1, sreshape, sflatten, siota,
+ sminIndexPrim, smaxIndexPrim, sdot1Inner, sdot,
+ snest, sunNest, szip, sunzip,
-- ** Lifting orthotope operations to 'Shaped' arrays
- slift,
+ slift, slift2,
+ -- ** Conversions
+ stoXArrayPrim, sfromXArrayPrim,
+ stoMixed, scastToMixed, stoRanked,
+ sfromOrthotope, stoOrthotope,
+ -- ** Additional arithmetic operations
+ --
+ -- $integralRealFloat
+ squotArray, sremArray, satan2Array,
-- * Mixed arrays
Mixed,
- IxX(..), IIxX,
- KnownShapeX(..), StaticShX(..),
- mgenerate, mtranspose, mappend, mfromVector, mtoVector, munScalar,
- mconstant, mfromList, mtoList, mslice, mrev1,
+ ListX(ZX, (::%)),
+ IxX(.., ZIX, (:.%)), IIxX,
+ ShX(.., ZSX, (:$%)), KnownShX(..), IShX,
+ StaticShX(.., ZKX, (:!%)),
+ SMayNat(..),
+ mshape, mrank, msize, mindex, mindexPartial, mgenerate, msumOuter1, msumAllPrim,
+ mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar,
+ memptyArray,
+ mrerank,
+ mreplicate, mreplicateScal,
+ mfromList1, mfromListOuter, mfromListLinear, mfromListPrim, mfromListPrimLinear,
+ mtoList, mtoListOuter, mtoListLinear,
+ mslice, mrev1, mreshape, mflatten, miota,
+ mminIndexPrim, mmaxIndexPrim, mdot1Inner, mdot,
+ mnest, munNest, mzip, munzip,
+ -- ** Lifting orthotope operations to 'Mixed' arrays
+ mlift, mlift2,
+ -- ** Conversions
+ mtoXArrayPrim, mfromXArrayPrim,
+ mcast,
+ mcastToShaped, mtoRanked,
+ convert, Conversion(..),
+ -- ** Additional arithmetic operations
+ --
+ -- $integralRealFloat
+ mquotArray, mremArray, matan2Array,
-- * Array elements
- Elt(mshape, mindex, mindexPartial, mscalar, mfromList1, mtoList1, mlift, mlift2),
+ Elt,
PrimElt,
Primitive(..),
-
- -- * Inductive natural numbers
- module Data.INat,
+ KnownElt,
-- * Further utilities / re-exports
type (++),
Storable,
+ SNat, pattern SNat,
+ pattern SZ, pattern SS,
+ Perm(..),
+ IsPermutation,
+ KnownPerm(..),
+ NumElt, IntElt, FloatElt,
+ Rank, Product,
+ Replicate,
+ MapJust,
) where
-import Prelude hiding (mappend)
+import Prelude hiding (mappend, mconcat)
-import Data.Array.Mixed
-import Data.Array.Nested.Internal
-import Data.INat
+import Data.Array.Nested.Convert
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Ranked
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Shaped
+import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
+import Data.Array.Strided.Arith
import Foreign.Storable
+import GHC.TypeLits
+
+-- $integralRealFloat
+--
+-- These functions are separate top-level functions, and not exposed in
+-- instances for 'RealFloat' and 'Integral', because those classes include a
+-- variety of other functions that make no sense for arrays.
+-- This problem already occurs with 'fromInteger', 'fromRational' and 'pi', but
+-- having 'Num', 'Fractional' and 'Floating' available is just too useful.
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
new file mode 100644
index 0000000..2438f68
--- /dev/null
+++ b/src/Data/Array/Nested/Convert.hs
@@ -0,0 +1,333 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeAbstractions #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+module Data.Array.Nested.Convert (
+ -- * Shape\/index\/list casting functions
+ -- ** To ranked
+ ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShX, shrFromShX2,
+ listrCast, ixrCast, shrCast,
+ -- ** To shaped
+ ixsFromIxR, ixsFromIxR', ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX,
+ ixsCast,
+ -- ** To mixed
+ ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS,
+ ixxCast, shxCast, shxCast',
+
+ -- * Array conversions
+ convert,
+ Conversion(..),
+
+ -- * Special cases of array conversions
+ --
+ -- | These functions can all be implemented using 'convert' in some way,
+ -- but some have fewer constraints.
+ rtoMixed, rcastToMixed, rcastToShaped,
+ stoMixed, scastToMixed, stoRanked,
+ mcast, mcastToShaped, mtoRanked,
+) where
+
+import Control.Category
+import Data.Proxy
+import Data.Type.Equality
+import GHC.TypeLits
+
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Ranked.Base
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Shaped.Base
+import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
+
+-- * Shape or index or list casting functions
+
+-- * To ranked
+
+-- | Forget the shape of a shaped index, keeping only its rank. The index
+-- components are copied unchanged.
+ixrFromIxS :: IxS sh i -> IxR (Rank sh) i
+ixrFromIxS = \case
+  ZIS -> ZIR
+  i :.$ rest -> i :.: ixrFromIxS rest
+
+-- | Forget the static shape of a mixed index, keeping only its rank. The
+-- index components are copied unchanged.
+ixrFromIxX :: IxX sh i -> IxR (Rank sh) i
+ixrFromIxX ix = case ix of
+  ZIX -> ZIR
+  i :.% rest -> i :.: ixrFromIxX rest
+
+-- | Convert a shaped shape into a ranked shape of 'Int's by applying
+-- 'fromSNat'' to every dimension.
+shrFromShS :: ShS sh -> IShR (Rank sh)
+shrFromShS = \case
+  ZSS -> ZSR
+  dim :$$ rest -> fromSNat' dim :$: shrFromShS rest
+
+-- shrFromShX re-exported
+-- shrFromShX2 re-exported
+-- listrCast re-exported
+-- ixrCast re-exported
+-- shrCast re-exported
+
+-- * To shaped
+
+-- TODO: these take a ShS because there are KnownNats inside IxS.
+
+-- | Convert a ranked index to a shaped index, given the shape to index into.
+-- The type ties the rank of the index to @Rank sh@; see 'ixsFromIxR'' for the
+-- variant that accepts any rank and checks it at runtime.
+ixsFromIxR :: ShS sh -> IxR (Rank sh) i -> IxS sh i
+ixsFromIxR ZSS ZIR = ZIS
+ixsFromIxR (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR sh idx
+-- The type signature forbids the remaining combinations (shape and index of
+-- different lengths). NOTE(review): the catch-all is presumably only here to
+-- satisfy the pattern-match checker -- confirm.
+ixsFromIxR _ _ = error "ixsFromIxR: unreachable: index rank does not match shape rank"
+
+-- | Like 'ixsFromIxR', but accepts an index of any rank @n@ and checks at
+-- runtime that @n@ matches @Rank sh@, calling 'error' if it does not.
+-- Equivalent to the following, but more efficient:
+--
+-- > ixsFromIxR' sh idx = ixsFromIxR sh (ixrCast (shsRank sh) idx)
+ixsFromIxR' :: ShS sh -> IxR n i -> IxS sh i
+ixsFromIxR' shape index = case (shape, index) of
+  (ZSS, ZIR) -> ZIS
+  (_ :$$ sh, i :.: idx) -> i :.$ ixsFromIxR' sh idx
+  _ -> error "ixsFromIxR': index rank does not match shape rank"
+
+-- TODO: this takes a ShS because there are KnownNats inside IxS.
+
+-- | Convert a mixed index whose static shape is @MapJust sh@ into a shaped
+-- index for @sh@. The shape and the index are walked in lockstep and the
+-- index components are copied unchanged; the shape argument is only used for
+-- its structure.
+--
+-- NOTE(review): there is no catch-all clause; presumably the @MapJust sh@
+-- index type lets the checker see that these two clauses are exhaustive --
+-- confirm.
+ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i
+ixsFromIxX ZSS ZIX = ZIS
+ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx
+
+-- | Like 'ixsFromIxX', but accepts an index of any static shape @sh'@ and
+-- checks at runtime that @Rank sh'@ matches @Rank sh@, calling 'error' if it
+-- does not. Equivalent to the following, but more efficient:
+--
+-- > ixsFromIxX' sh idx = ixsFromIxX sh (ixxCast (shxFromShS sh) idx)
+ixsFromIxX' :: ShS sh -> IxX sh' i -> IxS sh i
+ixsFromIxX' shape index = case (shape, index) of
+  (ZSS, ZIX) -> ZIS
+  (_ :$$ sh, i :.% idx) -> i :.$ ixsFromIxX' sh idx
+  _ -> error "ixsFromIxX': index rank does not match shape rank"
+
+-- | Turn an 'IShR' into a 'ShS' of the same rank and hand it to the
+-- continuation; the shape is existential because its dimensions are only
+-- known at runtime. Calls 'error' if a dimension size is negative.
+withShsFromShR :: IShR n -> (forall sh. Rank sh ~ n => ShS sh -> r) -> r
+withShsFromShR ZSR k = k ZSS
+withShsFromShR (n :$: sh) k =
+  withShsFromShR sh
+    (\sh' ->
+       withSomeSNat (toInteger n)
+         (\case
+            Just sn@SNat -> k (sn :$$ sh')
+            Nothing -> error $ "withShsFromShR: negative dimension size (" ++ show n ++ ")"))
+
+-- shsFromShX re-exported
+
+-- | Turn an 'IShX' into a 'ShS' of the same rank and hand it to the
+-- continuation. Statically known dimensions are reused as they are; unknown
+-- dimensions are promoted from their runtime value, and 'error' is called if
+-- such a value is negative. If you already know that @sh'@ is @MapJust@ of
+-- something, use 'shsFromShX' instead.
+withShsFromShX :: IShX sh' -> (forall sh. Rank sh ~ Rank sh' => ShS sh -> r) -> r
+withShsFromShX ZSX k = k ZSS
+withShsFromShX (SKnown sn@SNat :$% sh) k =
+  withShsFromShX sh (\sh' -> k (sn :$$ sh'))
+withShsFromShX (SUnknown n :$% sh) k =
+  withShsFromShX sh
+    (\sh' ->
+       withSomeSNat (toInteger n)
+         (\case
+            Just sn@SNat -> k (sn :$$ sh')
+            Nothing -> error $ "withShsFromShX: negative SUnknown dimension size (" ++ show n ++ ")"))
+
+-- | Convert a 'StaticShX' whose dimensions are all known (@MapJust sh@) into
+-- the corresponding 'ShS', by going through 'shxFromSSX' and then
+-- 'shsFromShX'.
+shsFromSSX :: StaticShX (MapJust sh) -> ShS sh
+shsFromSSX ssx = shsFromShX (shxFromSSX ssx)
+
+-- ixsCast re-exported
+
+-- * To mixed
+
+-- | Convert a ranked index into a mixed index whose static shape consists of
+-- @n@ unknown (@Nothing@) dimensions. The index components are copied
+-- unchanged.
+ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i
+ixxFromIxR ZIR = ZIX
+ixxFromIxR (n :.: (idx :: IxR m i)) =
+ -- The consed index is cast along 'lemReplicateSucc' to reach the result type.
+ -- NOTE(review): presumably the lemma relates @Replicate (m + 1) Nothing@ to
+ -- @Nothing : Replicate m Nothing@ -- confirm in Data.Array.Nested.Lemmas.
+ castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m))
+ (n :.% ixxFromIxR idx)
+
+-- | Convert a shaped index into a mixed index whose static shape is
+-- @MapJust sh@. The index components are copied unchanged.
+ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i
+ixxFromIxS = \case
+  ZIS -> ZIX
+  i :.$ idx -> i :.% ixxFromIxS idx
+
+-- | Convert a ranked shape into a mixed shape in which every dimension is
+-- statically unknown: each size is wrapped in 'SUnknown'.
+shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i
+shxFromShR ZSR = ZSX
+shxFromShR (dim :$: (rest :: ShR m i)) =
+  let consed = SUnknown dim :$% shxFromShR rest
+  -- The consed shape is cast along 'lemReplicateSucc' to reach the result type.
+  in castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m)) consed
+
+-- | Convert a shaped shape into a mixed shape in which every dimension is
+-- statically known: each dimension is wrapped in 'SKnown'.
+shxFromShS :: ShS sh -> IShX (MapJust sh)
+shxFromShS = \case
+  ZSS -> ZSX
+  dim :$$ rest -> SKnown dim :$% shxFromShS rest
+
+-- ixxCast re-exported
+-- shxCast re-exported
+-- shxCast' re-exported
+
+
+-- * Array conversions
+
+-- | The constructors that perform runtime shape checking are marked with a
+-- tick (@'@): 'ConvXS'' and 'ConvXX''. For the other constructors, the types
+-- ensure that the shapes are already compatible. To convert between 'Ranked'
+-- and 'Shaped', go via 'Mixed'.
+--
+-- The guiding principle behind 'Conversion' is that it should represent the
+-- array restructurings, or perhaps re-presentations, that do not change the
+-- underlying 'XArray's. This leads to the inclusion of some operations that do
+-- not look like simple conversions (casts) at first glance, like 'ConvZip'.
+--
+-- /Note/: Haddock gleefully renames type variables in constructors so that
+-- they match the data type head as much as possible. See the source for a more
+-- readable presentation of this data type.
+data Conversion a b where
+ -- Identity and composition; these back the 'Category' instance.
+ ConvId :: Conversion a a
+ ConvCmp :: Conversion b c -> Conversion a b -> Conversion a c
+
+ -- Ranked / Shaped to Mixed.
+ ConvRX :: Conversion (Ranked n a) (Mixed (Replicate n Nothing) a)
+ ConvSX :: Conversion (Shaped sh a) (Mixed (MapJust sh) a)
+
+ -- Mixed to Ranked / Shaped. The ticked variant takes the target shape as an
+ -- argument and only requires the ranks to agree in the types.
+ ConvXR :: Elt a
+ => Conversion (Mixed sh a) (Ranked (Rank sh) a)
+ ConvXS :: Conversion (Mixed (MapJust sh) a) (Shaped sh a)
+ ConvXS' :: (Rank sh ~ Rank sh', Elt a)
+ => ShS sh'
+ -> Conversion (Mixed sh a) (Shaped sh' a)
+
+ -- Mixed to Mixed with a different static shape of the same rank.
+ ConvXX' :: (Rank sh ~ Rank sh', Elt a)
+ => StaticShX sh'
+ -> Conversion (Mixed sh a) (Mixed sh' a)
+
+ -- Apply a conversion to the element type of an array, or to both components
+ -- of a pair.
+ ConvRR :: Conversion a b
+ -> Conversion (Ranked n a) (Ranked n b)
+ ConvSS :: Conversion a b
+ -> Conversion (Shaped sh a) (Shaped sh b)
+ ConvXX :: Conversion a b
+ -> Conversion (Mixed sh a) (Mixed sh b)
+ ConvT2 :: Conversion a a'
+ -> Conversion b b'
+ -> Conversion (a, b) (a', b')
+
+ -- Between a value and a rank-0 Mixed array holding it.
+ Conv0X :: Elt a
+ => Conversion a (Mixed '[] a)
+ ConvX0 :: Conversion (Mixed '[] a) a
+
+ -- Between a flat Mixed array and an array of arrays, split after the
+ -- prefix 'sh' of the shape.
+ ConvNest :: Elt a => StaticShX sh
+ -> Conversion (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a))
+ ConvUnnest :: Conversion (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a)
+
+ -- Between a pair of equally-shaped arrays and an array of pairs.
+ ConvZip :: (Elt a, Elt b)
+ => Conversion (Mixed sh a, Mixed sh b) (Mixed sh (a, b))
+ ConvUnzip :: (Elt a, Elt b)
+ => Conversion (Mixed sh (a, b)) (Mixed sh a, Mixed sh b)
+deriving instance Show (Conversion a b)
+
+-- | 'Conversion's form a category: the identity is 'ConvId' and composition is
+-- 'ConvCmp'.
+instance Category Conversion where
+  id = ConvId
+  f . g = ConvCmp f g
+
+-- | Apply a 'Conversion' to a value. The value is wrapped as a 'Mixed' array
+-- of empty shape with 'mscalar', converted by the local @go@, and unwrapped
+-- again with 'munScalar'.
+convert :: (Elt a, Elt b) => Conversion a b -> a -> b
+convert = \c x -> munScalar (go c (mscalar x))
+ where
+ -- The 'esh' is the extension shape: the conversion happens under a whole
+ -- bunch of additional dimensions that it does not touch. These dimensions
+ -- are 'esh'.
+ -- The strategy is to unwind step-by-step to a large Mixed array, and to
+ -- perform the required checks and conversions when re-nesting back up.
+ go :: Conversion a b -> Mixed esh a -> Mixed esh b
+ go ConvId x = x
+ go (ConvCmp c1 c2) x = go c1 (go c2 x)
+ go ConvRX (M_Ranked x) = x
+ go ConvSX (M_Shaped x) = x
+ go (ConvXR @_ @sh) (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)
+ = let ssx' = ssxAppend (ssxFromShX esh)
+ (ssxReplicate (shxRank (shxDropSSX @esh @sh (ssxFromShX esh) (mshape x))))
+ in M_Ranked (M_Nest esh (mcast ssx' x))
+ go ConvXS (M_Nest esh x) = M_Shaped (M_Nest esh x)
+ go (ConvXS' @sh @sh' sh') (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh')))
+ x))
+ go (ConvXX' @sh @sh' ssx) (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x
+ go (ConvRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
+ go (ConvSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
+ go (ConvXX c) (M_Nest esh x) = M_Nest esh (go c x)
+ go (ConvT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2)
+ go Conv0X (x :: Mixed esh a)
+ | Refl <- lemAppNil @esh
+ = M_Nest (mshape x) x
+ go ConvX0 (M_Nest @esh _ x)
+ | Refl <- lemAppNil @esh
+ = x
+ go (ConvNest @_ @sh @sh' ssh) (M_Nest @esh esh x)
+ | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (ssxFromShX esh `ssxAppend` ssh) (mshape x)) x)
+ go (ConvUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x))
+ | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh x
+ go ConvZip x =
+ -- no need to check that the two esh's are equal because they were zipped previously
+ let (M_Nest esh x1, M_Nest _ x2) = munzip x
+ in M_Nest esh (mzip x1 x2)
+ go ConvUnzip (M_Nest esh x) =
+ let (x1, x2) = munzip x
+ in mzip (M_Nest esh x1) (M_Nest esh x2)
+
+ -- The three lemmas below are asserted with 'unsafeCoerceRefl' rather than
+ -- proved; they are what lets 'go' call 'mcast' on the full @esh ++ ...@
+ -- shapes.
+ -- NOTE(review): they look sound provided the 'Rank' of an appended shape
+ -- depends only on the ranks of its two parts -- confirm against the
+ -- definitions of 'Rank', 'Replicate' and 'MapJust'.
+ lemRankAppRankEq :: Rank sh ~ Rank sh'
+ => Proxy esh -> Proxy sh -> Proxy sh'
+ -> Rank (esh ++ sh) :~: Rank (esh ++ sh')
+ lemRankAppRankEq _ _ _ = unsafeCoerceRefl
+
+ lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh
+ -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing)
+ lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl
+
+ lemRankAppRankEqMapJust :: Rank sh ~ Rank sh'
+ => Proxy esh -> Proxy sh -> Proxy sh'
+ -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
+ lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl
+
+
+-- * Special cases of array conversions
+
+-- | Change the static shape type of a 'Mixed' array to one of equal rank.
+-- Implemented as 'mcastPartial' with an empty (@'[]@) remaining shape; the
+-- two 'lemAppNil' matches supply the type equalities needed for that call.
+-- NOTE(review): presumably 'mcastPartial' checks at runtime that the actual
+-- shape fits @ssh2@ -- confirm at its definition.
+mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)
+ => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a
+mcast ssh2 arr
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr
+
+-- | Convert a 'Mixed' array to a 'Ranked' array of the same rank; this is
+-- 'convert' applied to 'ConvXR'.
+mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
+mtoRanked arr = convert ConvXR arr
+
+-- | Unwrap a 'Ranked' array to the 'Mixed' array it wraps.
+rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
+rtoMixed ranked = case ranked of
+  Ranked inner -> inner
+
+-- | Like 'rtoMixed', but with a caller-chosen target shape type of the same
+-- rank; shape compatibility is checked at runtime.
+rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a
+rcastToMixed targetssx ranked@(Ranked inner)
+  | Refl <- lemRankReplicate (rrank ranked)
+  = mcast targetssx inner
+
+-- | Convert a 'Mixed' array to a 'Shaped' array with the given target shape
+-- of equal rank; this is 'convert' applied to 'ConvXS''.
+mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
+              => ShS sh' -> Mixed sh a -> Shaped sh' a
+mcastToShaped targetsh arr = convert (ConvXS' targetsh) arr
+
+-- | Unwrap a 'Shaped' array to the 'Mixed' array it wraps.
+stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
+stoMixed shaped = case shaped of
+  Shaped inner -> inner
+
+-- | Like 'stoMixed', but with a caller-chosen target shape type of the same
+-- rank; shape compatibility is checked at runtime.
+scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
+             => StaticShX sh' -> Shaped sh a -> Mixed sh' a
+scastToMixed targetssx shaped@(Shaped inner)
+  | Refl <- lemRankMapJust (sshape shaped)
+  = mcast targetssx inner
+
+-- | Convert a 'Shaped' array to a 'Ranked' array of the same rank, by applying
+-- 'mtoRanked' to the underlying 'Mixed' array.
+stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
+stoRanked shaped@(Shaped inner)
+  | Refl <- lemRankMapJust (sshape shaped)
+  = mtoRanked inner
+
+-- | Convert a 'Ranked' array to a 'Shaped' array with the given target shape,
+-- by applying 'mcastToShaped' to the underlying 'Mixed' array. The two 'Refl'
+-- matches bring into scope the rank equalities that 'mcastToShaped' requires.
+rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
+rcastToShaped (Ranked arr) targetsh
+ | Refl <- lemRankReplicate (shxRank (shxFromShS targetsh))
+ , Refl <- lemRankMapJust targetsh
+ = mcastToShaped targetsh arr
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
deleted file mode 100644
index 350eb6f..0000000
--- a/src/Data/Array/Nested/Internal.hs
+++ /dev/null
@@ -1,1294 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DefaultSignatures #-}
-{-# LANGUAGE DeriveFunctor #-}
-{-# LANGUAGE DerivingVia #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-{-# LANGUAGE InstanceSigs #-}
-{-# LANGUAGE PatternSynonyms #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE RoleAnnotations #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeAbstractions #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-
-{-|
-TODO:
-* We should be more consistent in whether functions take a 'StaticShX'
- argument or a 'KnownShapeX' constraint.
-
-* Document the choice of using 'INat' for ranks and 'Nat' for shapes. Point
- being that we need to do induction over the former, but the latter need to be
- able to get large.
-
--}
-
-module Data.Array.Nested.Internal where
-
-import Prelude hiding (mappend)
-
-import Control.Monad (forM_, when)
-import Control.Monad.ST
-import qualified Data.Array.RankedS as S
-import Data.Bifunctor (first)
-import Data.Coerce (coerce, Coercible)
-import Data.Foldable (toList)
-import Data.Kind
-import Data.List.NonEmpty (NonEmpty)
-import Data.Proxy
-import Data.Type.Equality
-import qualified Data.Vector.Storable as VS
-import qualified Data.Vector.Storable.Mutable as VSM
-import Foreign.Storable (Storable)
-import GHC.TypeLits
-
-import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat)
-import qualified Data.Array.Mixed as X
-import Data.INat
-
-
--- Invariant in the API
--- ====================
---
--- In the underlying XArray, there is some shape for elements of an empty
--- array. For example, for this array:
---
--- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float)
--- rshape arr == 0 :.: 0 :.: 0 :.: ZIR
---
--- the two underlying XArrays have a shape, and those shapes might be anything.
--- The invariant is that these element shapes are unobservable in the API.
--- (This is possible because you ought to not be able to get to such an element
--- without indexing out of bounds.)
---
--- Note, though, that the converse situation may arise: the outer array might
--- be nonempty while the inner arrays are empty. This is fine: the invariant
--- only applies if the _outer_ array is empty.
---
--- TODO: can we enforce that the elements of an empty (nested) array have
--- all-zero shape?
--- -> no, because mlift and also any kind of internals probing from outsiders
-
-
--- Primitive element types
--- =======================
---
--- There are a few primitive element types; arrays containing elements of such
--- type are a newtype over an XArray, which is itself a newtype over a Vector.
--- Unfortunately, the setup of the library requires us to list these primitive
--- element types multiple times; to aid in extending the list, all these lists
--- have been marked with [PRIMITIVE ELEMENT TYPES LIST].
-
-
-type family Replicate n a where
- Replicate Z a = '[]
- Replicate (S n) a = a : Replicate n a
-
-type family MapJust l where
- MapJust '[] = '[]
- MapJust (x : xs) = Just x : MapJust xs
-
-lemKnownReplicate :: forall n. KnownINat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
-lemKnownReplicate _ = X.lemKnownShapeX (go (inatSing @n))
- where
- go :: SINat m -> StaticShX (Replicate m Nothing)
- go SZ = ZKSX
- go (SS n) = () :!$? go n
-
-lemRankReplicate :: forall n. KnownINat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
-lemRankReplicate _ = go (inatSing @n)
- where
- go :: SINat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
- go SZ = Refl
- go (SS n) | Refl <- go n = Refl
-
-lemReplicatePlusApp :: forall n m a. KnownINat n => Proxy n -> Proxy m -> Proxy a
- -> Replicate (n +! m) a :~: Replicate n a ++ Replicate m a
-lemReplicatePlusApp _ _ _ = go (inatSing @n)
- where
- go :: SINat n' -> Replicate (n' +! m) a :~: Replicate n' a ++ Replicate m a
- go SZ = Refl
- go (SS n) | Refl <- go n = Refl
-
-shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh')
-shAppSplit _ ZKSX idx = (ZSX, idx)
-shAppSplit p (_ :!$@ ssh) (i :$@ idx) = first (i :$@) (shAppSplit p ssh idx)
-shAppSplit p (_ :!$? ssh) (i :$? idx) = first (i :$?) (shAppSplit p ssh idx)
-
-
--- | Wrapper type used as a tag to attach instances on. The instances on arrays
--- of @'Primitive' a@ are more polymorphic than the direct instances for arrays
--- of scalars; this means that if @orthotope@ supports an element type @T@ that
--- this library does not (directly), it may just work if you use an array of
--- @'Primitive' T@ instead.
-newtype Primitive a = Primitive a
-
--- | Element types that are primitive; arrays of these types are just a newtype
--- wrapper over an array.
-class PrimElt a where
- fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a
- toPrimitive :: Mixed sh a -> Mixed sh (Primitive a)
-
- default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a
- fromPrimitive = coerce
-
- default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a)
- toPrimitive = coerce
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-instance PrimElt Int
-instance PrimElt Double
-instance PrimElt ()
-
-
--- | Mixed arrays: some dimensions are size-typed, some are not. Distributes
--- over product-typed elements using a data family so that the full array is
--- always in struct-of-arrays format.
---
--- Built on top of 'XArray' which is built on top of @orthotope@, meaning that
--- dimension permutations (e.g. 'mtranspose') are typically free.
---
--- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type
--- class.
-type Mixed :: [Maybe Nat] -> Type -> Type
-data family Mixed sh a
--- NOTE: When opening up the Mixed abstraction, you might see dimension sizes
--- that you're not supposed to see. In particular, you might see (nonempty)
--- sizes of the elements of an empty array, which is information that should
--- ostensibly not exist; the full array is still empty.
-
-newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a)
- deriving (Show)
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-newtype instance Mixed sh Int = M_Int (XArray sh Int)
- deriving (Show)
-newtype instance Mixed sh Double = M_Double (XArray sh Double)
- deriving (Show)
-newtype instance Mixed sh () = M_Nil (XArray sh ()) -- no content, orthotope optimises this (via Vector)
- deriving (Show)
--- etc.
-
-data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b)
-deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b))
--- etc.
-
-newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a)
-deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a))
-
-
--- | Internal helper data family mirroring 'Mixed' that consists of mutable
--- vectors instead of 'XArray's.
-type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
-data family MixedVecs s sh a
-
-newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a)
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int)
-newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double)
-newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this
--- etc.
-
-data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b)
--- etc.
-
-data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a)
-
-
--- | Tree giving the shape of every array component.
-type family ShapeTree a where
- ShapeTree (Primitive _) = ()
- -- [PRIMITIVE ELEMENT TYPES LIST]
- ShapeTree Int = ()
- ShapeTree Double = ()
- ShapeTree () = ()
-
- ShapeTree (a, b) = (ShapeTree a, ShapeTree b)
- ShapeTree (Mixed sh a) = (IShX sh, ShapeTree a)
- ShapeTree (Ranked n a) = (IShR n, ShapeTree a)
- ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
-
-
--- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or
--- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
--- a@; see the documentation for 'Primitive' for more details.
-class Elt a where
- -- ====== PUBLIC METHODS ====== --
-
- mshape :: KnownShapeX sh => Mixed sh a -> IShX sh
- mindex :: Mixed sh a -> IIxX sh -> a
- mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a
- mscalar :: a -> Mixed '[] a
-
- -- | All arrays in the list, even subarrays inside @a@, must have the same
- -- shape; if they do not, a runtime error will be thrown. See the
- -- documentation of 'mgenerate' for more information about this restriction.
- -- Furthermore, the length of the list must correspond with @n@: if @n@ is
- -- @Just m@ and @m@ does not equal the length of the list, a runtime error is
- -- thrown.
- --
- -- If you want a single-dimensional array from your list, map 'mscalar'
- -- first.
- mfromList1 :: forall n sh. KnownShapeX (n : sh) => NonEmpty (Mixed sh a) -> Mixed (n : sh) a
-
- mtoList1 :: Mixed (n : sh) a -> [Mixed sh a]
-
- -- | Note: this library makes no particular guarantees about the shapes of
- -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the
- -- full 'XArray' and as such you can distinguish different empty arrays by
- -- the "shapes" of their elements. This information is meaningless, so you
- -- should not use it.
- mlift :: forall sh1 sh2. KnownShapeX sh2
- => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
- -> Mixed sh1 a -> Mixed sh2 a
-
- -- | See the documentation for 'mlift'.
- mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3)
- => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
- -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a
-
- -- ====== PRIVATE METHODS ====== --
-
- -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
- memptyArray :: IShX sh -> Mixed sh a
-
- mshapeTree :: a -> ShapeTree a
-
- mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool
-
- mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool
-
- mshowShapeTree :: Proxy a -> ShapeTree a -> String
-
- -- | Create uninitialised vectors for this array type, given the shape of
- -- this vector and an example for the contents.
- mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a)
-
- mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
-
- -- | Given the shape of this array, an index and a value, write the value at
- -- that index in the vectors.
- mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
-
- -- | Given the shape of this array, an index prefix and a subarray, write the
- -- subarray at that position in the vectors.
- mvecsWritePartial :: KnownShapeX sh' => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
-
- -- | Given the shape of this array, finalise the vectors into 'XArray's.
- mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
-
-
--- Arrays of scalars are basically just arrays of scalars.
-instance Storable a => Elt (Primitive a) where
- mshape (M_Primitive a) = X.shape a
- mindex (M_Primitive a) i = Primitive (X.index a i)
- mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i)
- mscalar (Primitive x) = M_Primitive (X.scalar x)
- mfromList1 l = M_Primitive (X.fromList1 knownShapeX (coerce (toList l)))
- mtoList1 (M_Primitive arr) = coerce (X.toList1 arr)
-
- mlift :: forall sh1 sh2.
- (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
- -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)
- mlift f (M_Primitive a)
- | Refl <- X.lemAppNil @sh1
- , Refl <- X.lemAppNil @sh2
- = M_Primitive (f Proxy a)
-
- mlift2 :: forall sh1 sh2 sh3.
- (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)
- -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a)
- mlift2 f (M_Primitive a) (M_Primitive b)
- | Refl <- X.lemAppNil @sh1
- , Refl <- X.lemAppNil @sh2
- , Refl <- X.lemAppNil @sh3
- = M_Primitive (f Proxy a b)
-
- memptyArray sh = M_Primitive (X.empty sh)
- mshapeTree _ = ()
- mshapeTreeEq _ () () = True
- mshapeTreeEmpty _ () = False
- mshowShapeTree _ () = "()"
- mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh)
- mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
- mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x
-
- -- TODO: this use of toVector is suboptimal
- mvecsWritePartial
- :: forall sh' sh s. KnownShapeX sh'
- => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
- mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do
- let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' (X.shape arr)))
- VS.copy (VSM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr)
-
- mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VS.freeze v
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-deriving via Primitive Int instance Elt Int
-deriving via Primitive Double instance Elt Double
-deriving via Primitive () instance Elt ()
-
--- Arrays of pairs are pairs of arrays.
-instance (Elt a, Elt b) => Elt (a, b) where
- mshape (M_Tup2 a _) = mshape a
- mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
- mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
- mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
- mfromList1 l = M_Tup2 (mfromList1 ((\(M_Tup2 x _) -> x) <$> l))
- (mfromList1 ((\(M_Tup2 _ y) -> y) <$> l))
- mtoList1 (M_Tup2 a b) = zipWith M_Tup2 (mtoList1 a) (mtoList1 b)
- mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b)
- mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y)
-
- memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
- mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
- mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
- mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
- mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
- mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
- mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b)
- mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
- mvecsWrite sh i x a
- mvecsWrite sh i y b
- mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
- mvecsWritePartial sh i x a
- mvecsWritePartial sh i y b
- mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
-
--- Arrays of arrays are just arrays, but with more dimensions.
-instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
- -- TODO: this is quadratic in the nesting depth because it repeatedly
- -- truncates the shape vector to one a little shorter. Fix with a
- -- moverlongShape method, a prefix of which is mshape.
- mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IShX sh
- mshape (M_Nest arr)
- | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
- = fst (shAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr))
-
- mindex (M_Nest arr) i = mindexPartial arr i
-
- mindexPartial :: forall sh1 sh2.
- Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
- mindexPartial (M_Nest arr) i
- | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
-
- mscalar = M_Nest
-
- mfromList1 :: forall n sh. KnownShapeX (n : sh)
- => NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (n : sh) (Mixed sh' a)
- mfromList1 l
- | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @(n : sh)) (knownShapeX @sh'))
- = M_Nest (mfromList1 (coerce l))
-
- mtoList1 (M_Nest arr) = coerce (mtoList1 arr)
-
- mlift :: forall sh1 sh2. KnownShapeX sh2
- => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
- -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
- mlift f (M_Nest arr)
- | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
- = M_Nest (mlift f' arr)
- where
- f' :: forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
- f' _
- | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
- , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
- , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT))
- = f (Proxy @(sh' ++ shT))
-
- mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3)
- => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
- -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)
- mlift2 f (M_Nest arr1) (M_Nest arr2)
- | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
- , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh3) (knownShapeX @sh'))
- = M_Nest (mlift2 f' arr1 arr2)
- where
- f' :: forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
- f' _
- | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
- , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
- , Refl <- X.lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
- , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT))
- = f (Proxy @(sh' ++ shT))
-
- memptyArray sh = M_Nest (memptyArray (X.shAppend sh (X.completeShXzeros (knownShapeX @sh'))))
-
- mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (knownShapeX @sh'))))
-
- mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
-
- mshapeTreeEmpty _ (sh, t) = X.shapeSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
-
- mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
-
- mvecsUnsafeNew sh example
- | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
- | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh (mshape example))
- (mindex example (X.zeroIxX (knownShapeX @sh')))
- where
- sh' = mshape example
-
- mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a)
-
- mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.shAppend sh sh') idx val vecs
-
- mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
- => IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
- -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
- -> ST s ()
- mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)
- | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
- , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.shAppend sh12 sh') idx arr vecs
-
- mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.shAppend sh sh') vecs
-
-
--- | Create an array given a size and a function that computes the element at a
--- given index.
---
--- __WARNING__: It is required that every @a@ returned by the argument to
--- 'mgenerate' has the same shape. For example, the following will throw a
--- runtime error:
---
--- > foo :: Mixed [Nothing] (Mixed [Nothing] Double)
--- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) ->
--- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) ->
--- > ...
---
--- because the size of the inner 'mgenerate' is not always the same (it depends
--- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so
--- the entire hierarchy (after distributing out tuples) must be a rectangular
--- array. The type of 'mgenerate' allows this requirement to be broken very
--- easily, hence the runtime check.
-mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IShX sh -> (IIxX sh -> a) -> Mixed sh a
-mgenerate sh f = case X.enumShape sh of
- [] -> memptyArray sh
- firstidx : restidxs ->
- let firstelem = f (X.zeroIxX' sh)
- shapetree = mshapeTree firstelem
- in if mshapeTreeEmpty (Proxy @a) shapetree
- then memptyArray sh
- else runST $ do
- vecs <- mvecsUnsafeNew sh firstelem
- mvecsWrite sh firstidx firstelem vecs
- -- TODO: This is likely fine if @a@ is big, but if @a@ is a
- -- scalar this array copying is inefficient. Should improve this.
- forM_ restidxs $ \idx -> do
- let val = f idx
- when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $
- error "Data.Array.Nested mgenerate: generated values do not have equal shapes"
- mvecsWrite sh idx val vecs
- mvecsFreeze sh vecs
-
-mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a
-mtranspose perm =
- mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh')
- (X.transpose perm))
-
-mappend :: forall n m sh a. (KnownShapeX sh, KnownShapeX (n : sh), KnownShapeX (m : sh), KnownShapeX (X.AddMaybe n m : sh), Elt a)
- => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a
-mappend = mlift2 go
- where go :: forall sh' b. (KnownShapeX sh', Storable b)
- => Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b
- go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append
-
-mfromVectorP :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
-mfromVectorP sh v = M_Primitive (X.fromVector sh v)
-
-mfromVector :: forall sh a. (KnownShapeX sh, Storable a, PrimElt a) => IShX sh -> VS.Vector a -> Mixed sh a
-mfromVector sh v = fromPrimitive (mfromVectorP sh v)
-
-mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a
-mtoVectorP (M_Primitive v) = X.toVector v
-
-mtoVector :: (Storable a, PrimElt a) => Mixed sh a -> VS.Vector a
-mtoVector arr = mtoVectorP (coerce toPrimitive arr)
-
-mfromList :: (KnownShapeX '[n], Elt a) => NonEmpty a -> Mixed '[n] a
-mfromList = mfromList1 . fmap mscalar
-
-mtoList :: Elt a => Mixed '[n] a -> [a]
-mtoList = map munScalar . mtoList1
-
-munScalar :: Elt a => Mixed '[] a -> a
-munScalar arr = mindex arr ZIX
-
-mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> a -> Mixed sh (Primitive a)
-mconstantP sh x = M_Primitive (X.constant sh x)
-
-mconstant :: forall sh a. (KnownShapeX sh, Storable a, PrimElt a)
- => IShX sh -> a -> Mixed sh a
-mconstant sh x = fromPrimitive (mconstantP sh x)
-
-mslice :: (KnownShapeX sh, Elt a) => [(Int, Int)] -> Mixed sh a -> Mixed sh a
-mslice ivs = mlift $ \_ -> X.slice ivs
-
-mrev1 :: (KnownShapeX (n : sh), Elt a) => Mixed (n : sh) a -> Mixed (n : sh) a
-mrev1 = mlift $ \_ -> X.rev1
-
-mliftPrim :: (KnownShapeX sh, Storable a)
- => (a -> a)
- -> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
-mliftPrim f (M_Primitive (X.XArray arr)) = M_Primitive (X.XArray (S.mapA f arr))
-
-mliftPrim2 :: (KnownShapeX sh, Storable a)
- => (a -> a -> a)
- -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
-mliftPrim2 f (M_Primitive (X.XArray arr1)) (M_Primitive (X.XArray arr2)) =
- M_Primitive (X.XArray (S.zipWithA f arr1 arr2))
-
-instance (KnownShapeX sh, Storable a, Num a) => Num (Mixed sh (Primitive a)) where
- (+) = mliftPrim2 (+)
- (-) = mliftPrim2 (-)
- (*) = mliftPrim2 (*)
- negate = mliftPrim negate
- abs = mliftPrim abs
- signum = mliftPrim signum
- fromInteger n =
- case X.ssxToShape' (knownShapeX @sh) of
- Just sh -> M_Primitive (X.constant sh (fromInteger n))
- Nothing -> error "Data.Array.Nested.fromIntegral: \
- \Unknown components in shape, use explicit mconstant"
-
--- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types)
-deriving via Mixed sh (Primitive Int) instance KnownShapeX sh => Num (Mixed sh Int)
-deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed sh Double)
-
-
--- | A rank-typed array: the number of dimensions of the array (its /rank/) is
--- represented on the type level as a 'INat'.
---
--- Valid elements of a ranked array are described by the 'Elt' type class.
--- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
--- supported (and are represented as a single, flattened, struct-of-arrays
--- array internally).
---
--- Note that this 'INat' is not a "GHC.TypeLits" natural, because we want a
--- type-level natural that supports induction.
---
--- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
-type Ranked :: INat -> Type -> Type
-newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
-deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)
-
--- | A shape-typed array: the full shape of the array (the sizes of its
--- dimensions) is represented on the type level as a list of 'Nat's. Note that
--- these are "GHC.TypeLits" naturals, because we do not need induction over
--- them and we want very large arrays to be possible.
---
--- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
--- and 'Shaped' itself is again an instance of 'Elt' as well.
---
--- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
-type Shaped :: [Nat] -> Type -> Type
-newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
-deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a)
-
--- just unwrap the newtype and defer to the general instance for nested arrays
-newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
-deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a))
-newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh' ) a))
-deriving instance Show (Mixed sh (Mixed (MapJust sh' ) a)) => Show (Mixed sh (Shaped sh' a))
-
-newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
-newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a))
-
-
--- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
--- these instances allow them to also be used as elements of arrays, thus
--- making them first-class in the API.
-instance (Elt a, KnownINat n) => Elt (Ranked n a) where
- mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr
- mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i)
-
- mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
- mindexPartial (M_Ranked arr) i
- | Dict <- lemKnownReplicate (Proxy @n)
- = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
- mindexPartial arr i
-
- mscalar (Ranked x) = M_Ranked (M_Nest x)
-
- mfromList1 :: forall m sh. KnownShapeX (m : sh)
- => NonEmpty (Mixed sh (Ranked n a)) -> Mixed (m : sh) (Ranked n a)
- mfromList1 l
- | Dict <- lemKnownReplicate (Proxy @n)
- = M_Ranked (mfromList1 (coerce l))
-
- mtoList1 :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
- mtoList1 (M_Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n)
- = coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoList1 arr)
-
- mlift :: forall sh1 sh2. KnownShapeX sh2
- => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
- -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
- mlift f (M_Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n)
- = coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
- mlift f arr
-
- mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3)
- => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
- -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a)
- mlift2 f (M_Ranked arr1) (M_Ranked arr2)
- | Dict <- lemKnownReplicate (Proxy @n)
- = coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
- mlift2 f arr1 arr2
-
- memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a)
- memptyArray i
- | Dict <- lemKnownReplicate (Proxy @n)
- = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
- memptyArray i
-
- mshapeTree (Ranked arr)
- | Refl <- lemRankReplicate (Proxy @n)
- , Dict <- lemKnownReplicate (Proxy @n)
- = first shCvtXR (mshapeTree arr)
-
- mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
-
- mshapeTreeEmpty _ (sh, t) = shapeSizeR sh == 0 && mshapeTreeEmpty (Proxy @a) t
-
- mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
-
- mvecsUnsafeNew idx (Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n)
- = MV_Ranked <$> mvecsUnsafeNew idx arr
-
- mvecsNewEmpty _
- | Dict <- lemKnownReplicate (Proxy @n)
- = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
-
- mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
- mvecsWrite sh idx (Ranked arr) vecs
- | Dict <- lemKnownReplicate (Proxy @n)
- = mvecsWrite sh idx arr
- (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
- vecs)
-
- mvecsWritePartial :: forall sh sh' s. KnownShapeX sh'
- => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
- -> MixedVecs s (sh ++ sh') (Ranked n a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs
- | Dict <- lemKnownReplicate (Proxy @n)
- = mvecsWritePartial sh idx
- (coerce @(Mixed sh' (Ranked n a))
- @(Mixed sh' (Mixed (Replicate n Nothing) a))
- arr)
- (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
- @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
- vecs)
-
- mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
- mvecsFreeze sh vecs
- | Dict <- lemKnownReplicate (Proxy @n)
- = coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
- @(Mixed sh (Ranked n a))
- <$> mvecsFreeze sh
- (coerce @(MixedVecs s sh (Ranked n a))
- @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
- vecs)
-
-
--- | The shape of a shape-typed array given as a list of 'SNat' values.
-data ShS sh where
- ZSS :: ShS '[]
- (:$$) :: forall n sh. SNat n -> ShS sh -> ShS (n : sh)
-deriving instance Show (ShS sh)
-deriving instance Eq (ShS sh)
-deriving instance Ord (ShS sh)
-infixr 3 :$$
-
--- | A statically-known shape of a shape-typed array.
-class KnownShape sh where knownShape :: ShS sh
-instance KnownShape '[] where knownShape = ZSS
-instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = natSing :$$ knownShape
-
-sshapeKnown :: ShS sh -> Dict KnownShape sh
-sshapeKnown ZSS = Dict
-sshapeKnown (GHC_SNat :$$ sh) | Dict <- sshapeKnown sh = Dict
-
-lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh)
-lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh))
- where
- go :: ShS sh' -> StaticShX (MapJust sh')
- go ZSS = ZKSX
- go (n :$$ sh) = n :!$@ go sh
-
-lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2
- -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
-lemMapJustPlusApp _ _ = go (knownShape @sh1)
- where
- go :: ShS sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2
- go ZSS = Refl
- go (_ :$$ sh) | Refl <- go sh = Refl
-
-instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
- mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr
- mindex (M_Shaped arr) i | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mindex arr i)
-
- mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
- mindexPartial (M_Shaped arr) i
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
- mindexPartial arr i
-
- mscalar (Shaped x) = M_Shaped (M_Nest x)
-
- mfromList1 :: forall n sh'. KnownShapeX (n : sh')
- => NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (n : sh') (Shaped sh a)
- mfromList1 l
- | Dict <- lemKnownMapJust (Proxy @sh)
- = M_Shaped (mfromList1 (coerce l))
-
- mtoList1 :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
- mtoList1 (M_Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoList1 arr)
-
- mlift :: forall sh1 sh2. KnownShapeX sh2
- => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
- -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
- mlift f (M_Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
- mlift f arr
-
- mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3)
- => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
- -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a)
- mlift2 f (M_Shaped arr1) (M_Shaped arr2)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
- mlift2 f arr1 arr2
-
- memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
- memptyArray i
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
- memptyArray i
-
- mshapeTree (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = first (shCvtXS (knownShape @sh)) (mshapeTree arr)
-
- mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
-
- mshapeTreeEmpty _ (sh, t) = shapeSizeS sh == 0 && mshapeTreeEmpty (Proxy @a) t
-
- mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
-
- mvecsUnsafeNew idx (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = MV_Shaped <$> mvecsUnsafeNew idx arr
-
- mvecsNewEmpty _
- | Dict <- lemKnownMapJust (Proxy @sh)
- = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
-
- mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
- mvecsWrite sh idx (Shaped arr) vecs
- | Dict <- lemKnownMapJust (Proxy @sh)
- = mvecsWrite sh idx arr
- (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
- vecs)
-
- mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
- => IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
- -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs
- | Dict <- lemKnownMapJust (Proxy @sh)
- = mvecsWritePartial sh idx
- (coerce @(Mixed sh2 (Shaped sh a))
- @(Mixed sh2 (Mixed (MapJust sh) a))
- arr)
- (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
- @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
- vecs)
-
- mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
- mvecsFreeze sh vecs
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @(Mixed sh' (Mixed (MapJust sh) a))
- @(Mixed sh' (Shaped sh a))
- <$> mvecsFreeze sh
- (coerce @(MixedVecs s sh' (Shaped sh a))
- @(MixedVecs s sh' (Mixed (MapJust sh) a))
- vecs)
-
-
--- Utility functions to satisfy the type checker sometimes
-
-rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a
-rewriteMixed Refl x = x
-
-
--- ====== API OF RANKED ARRAYS ====== --
-
-arithPromoteRanked :: forall n a. KnownINat n
- => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a)
- -> Ranked n a -> Ranked n a
-arithPromoteRanked | Dict <- lemKnownReplicate (Proxy @n) = coerce
-
-arithPromoteRanked2 :: forall n a. KnownINat n
- => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a -> Mixed sh a)
- -> Ranked n a -> Ranked n a -> Ranked n a
-arithPromoteRanked2 | Dict <- lemKnownReplicate (Proxy @n) = coerce
-
-instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where
- (+) = arithPromoteRanked2 (+)
- (-) = arithPromoteRanked2 (-)
- (*) = arithPromoteRanked2 (*)
- negate = arithPromoteRanked negate
- abs = arithPromoteRanked abs
- signum = arithPromoteRanked signum
- fromInteger n = case inatSing @n of
- SZ -> Ranked (M_Primitive (X.scalar (fromInteger n)))
- SS _ -> error "Data.Array.Nested.fromIntegral(Ranked): \
- \Rank non-zero, use explicit mconstant"
-
--- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types)
-deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int)
-deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double)
-
-type role ListR nominal representational
-type ListR :: INat -> Type -> Type
-data ListR n i where
- ZR :: ListR Z i
- (:::) :: forall n {i}. i -> ListR n i -> ListR (S n) i
-deriving instance Show i => Show (ListR n i)
-deriving instance Eq i => Eq (ListR n i)
-deriving instance Ord i => Ord (ListR n i)
-deriving instance Functor (ListR n)
-infixr 3 :::
-
-instance Foldable (ListR n) where
- foldr f z l = foldr f z (listRToList l)
-
-listRToList :: ListR n i -> [i]
-listRToList ZR = []
-listRToList (i ::: is) = i : listRToList is
-
-knownListR :: ListR n i -> Dict KnownINat n
-knownListR ZR = Dict
-knownListR (_ ::: l) | Dict <- knownListR l = Dict
-
--- | An index into a rank-typed array.
-type role IxR nominal representational
-type IxR :: INat -> Type -> Type
-newtype IxR n i = IxR (ListR n i)
- deriving (Show, Eq, Ord)
- deriving newtype (Functor, Foldable)
-
-pattern ZIR :: forall n i. () => n ~ Z => IxR n i
-pattern ZIR = IxR ZR
-
-pattern (:.:)
- :: forall {n1} {i}.
- forall n. (S n ~ n1)
- => i -> IxR n i -> IxR n1 i
-pattern i :.: sh <- (unconsIxR -> Just (UnconsIxRRes sh i))
- where i :.: IxR sh = IxR (i ::: sh)
-{-# COMPLETE ZIR, (:.:) #-}
-infixr 3 :.:
-
-data UnconsIxRRes i n1 =
- forall n. ((S n) ~ n1) => UnconsIxRRes (IxR n i) i
-unconsIxR :: IxR n1 i -> Maybe (UnconsIxRRes i n1)
-unconsIxR (IxR (i ::: sh')) = Just (UnconsIxRRes (IxR sh') i)
-unconsIxR (IxR ZR) = Nothing
-
-type IIxR n = IxR n Int
-
-knownIxR :: IxR n i -> Dict KnownINat n
-knownIxR (IxR sh) = knownListR sh
-
-type role ShR nominal representational
-type ShR :: INat -> Type -> Type
-newtype ShR n i = ShR (ListR n i)
- deriving (Show, Eq, Ord)
- deriving newtype (Functor, Foldable)
-
-type IShR n = ShR n Int
-
-pattern ZSR :: forall n i. () => n ~ Z => ShR n i
-pattern ZSR = ShR ZR
-
-pattern (:$:)
- :: forall {n1} {i}.
- forall n. (S n ~ n1)
- => i -> ShR n i -> ShR n1 i
-pattern i :$: sh <- (unconsShR -> Just (UnconsShRRes sh i))
- where i :$: (ShR sh) = ShR (i ::: sh)
-{-# COMPLETE ZSR, (:$:) #-}
-infixr 3 :$:
-
-data UnconsShRRes i n1 =
- forall n. S n ~ n1 => UnconsShRRes (ShR n i) i
-unconsShR :: ShR n1 i -> Maybe (UnconsShRRes i n1)
-unconsShR (ShR (i ::: sh')) = Just (UnconsShRRes (ShR sh') i)
-unconsShR (ShR ZR) = Nothing
-
-knownShR :: ShR n i -> Dict KnownINat n
-knownShR (ShR sh) = knownListR sh
-
-zeroIxR :: SINat n -> IIxR n
-zeroIxR SZ = ZIR
-zeroIxR (SS n) = 0 :.: zeroIxR n
-
-ixCvtXR :: IIxX sh -> IIxR (X.Rank sh)
-ixCvtXR ZIX = ZIR
-ixCvtXR (n :.@ idx) = n :.: ixCvtXR idx
-ixCvtXR (n :.? idx) = n :.: ixCvtXR idx
-
-shCvtXR :: IShX sh -> IShR (X.Rank sh)
-shCvtXR ZSX = ZSR
-shCvtXR (n :$@ idx) = X.fromSNat' n :$: shCvtXR idx
-shCvtXR (n :$? idx) = n :$: shCvtXR idx
-
-ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
-ixCvtRX ZIR = ZIX
-ixCvtRX (n :.: idx) = n :.? ixCvtRX idx
-
-shCvtRX :: IShR n -> IShX (Replicate n Nothing)
-shCvtRX ZSR = ZSX
-shCvtRX (n :$: idx) = n :$? shCvtRX idx
-
-shapeSizeR :: IShR n -> Int
-shapeSizeR ZSR = 1
-shapeSizeR (n :$: sh) = n * shapeSizeR sh
-
-
-rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IShR n
-rshape (Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n)
- , Refl <- lemRankReplicate (Proxy @n)
- = shCvtXR (mshape arr)
-
-rindex :: Elt a => Ranked n a -> IIxR n -> a
-rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
-
-rindexPartial :: forall n m a. (KnownINat n, Elt a) => Ranked (n +! m) a -> IIxR n -> Ranked m a
-rindexPartial (Ranked arr) idx =
- Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
- (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr)
- (ixCvtRX idx))
-
--- | __WARNING__: All values returned from the function must have equal shape.
--- See the documentation of 'mgenerate' for more details.
-rgenerate :: forall n a. Elt a => IShR n -> (IIxR n -> a) -> Ranked n a
-rgenerate sh f
- | Dict <- knownShR sh
- , Dict <- lemKnownReplicate (Proxy @n)
- , Refl <- lemRankReplicate (Proxy @n)
- = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR))
-
--- | See the documentation of 'mlift'.
-rlift :: forall n1 n2 a. (KnownINat n2, Elt a)
- => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
- -> Ranked n1 a -> Ranked n2 a
-rlift f (Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n2)
- = Ranked (mlift f arr)
-
-rsumOuter1P :: forall n a.
- (Storable a, Num a, KnownINat n)
- => Ranked (S n) (Primitive a) -> Ranked n (Primitive a)
-rsumOuter1P (Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n)
- = Ranked
- . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a))
- . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing))
- . coerce @(Mixed (Replicate (S n) Nothing) (Primitive a)) @(XArray (Replicate (S n) Nothing) a)
- $ arr
-
-rsumOuter1 :: forall n a.
- (Storable a, Num a, PrimElt a, KnownINat n)
- => Ranked (S n) a -> Ranked n a
-rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive
-
-rtranspose :: forall n a. (KnownINat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
-rtranspose perm (Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n)
- = Ranked (mtranspose perm arr)
-
-rappend :: forall n a. (KnownINat n, Elt a)
- => Ranked (S n) a -> Ranked (S n) a -> Ranked (S n) a
-rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend
-
-rscalar :: Elt a => a -> Ranked I0 a
-rscalar x = Ranked (mscalar x)
-
-rfromVectorP :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a)
-rfromVectorP sh v
- | Dict <- lemKnownReplicate (Proxy @n)
- = Ranked (mfromVectorP (shCvtRX sh) v)
-
-rfromVector :: forall n a. (KnownINat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a
-rfromVector sh v = coerce fromPrimitive (rfromVectorP sh v)
-
-rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a
-rtoVectorP = coerce mtoVectorP
-
-rtoVector :: (Storable a, PrimElt a) => Ranked n a -> VS.Vector a
-rtoVector = coerce mtoVector
-
-rfromList1 :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a
-rfromList1 l
- | Dict <- lemKnownReplicate (Proxy @n)
- = Ranked (mfromList1 (coerce l))
-
-rfromList :: Elt a => NonEmpty a -> Ranked I1 a
-rfromList = Ranked . mfromList1 . fmap mscalar
-
-rtoList :: Elt a => Ranked (S n) a -> [Ranked n a]
-rtoList (Ranked arr) = coerce (mtoList1 arr)
-
-rtoList1 :: Elt a => Ranked I1 a -> [a]
-rtoList1 = map runScalar . rtoList
-
-runScalar :: Elt a => Ranked I0 a -> a
-runScalar arr = rindex arr ZIR
-
-rconstantP :: forall n a. (KnownINat n, Storable a) => IShR n -> a -> Ranked n (Primitive a)
-rconstantP sh x
- | Dict <- lemKnownReplicate (Proxy @n)
- = Ranked (mconstantP (shCvtRX sh) x)
-
-rconstant :: forall n a. (KnownINat n, Storable a, PrimElt a)
- => IShR n -> a -> Ranked n a
-rconstant sh x = coerce fromPrimitive (rconstantP sh x)
-
-rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a
-rslice ivs = rlift $ \_ -> X.slice ivs
-
-rrev1 :: (KnownINat n, Elt a) => Ranked (S n) a -> Ranked (S n) a
-rrev1 = rlift $ \_ -> X.rev1
-
-
--- ====== API OF SHAPED ARRAYS ====== --
-
-arithPromoteShaped :: forall sh a. KnownShape sh
- => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a)
- -> Shaped sh a -> Shaped sh a
-arithPromoteShaped | Dict <- lemKnownMapJust (Proxy @sh) = coerce
-
-arithPromoteShaped2 :: forall sh a. KnownShape sh
- => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a -> Mixed shx a)
- -> Shaped sh a -> Shaped sh a -> Shaped sh a
-arithPromoteShaped2 | Dict <- lemKnownMapJust (Proxy @sh) = coerce
-
-instance (KnownShape sh, Storable a, Num a) => Num (Shaped sh (Primitive a)) where
- (+) = arithPromoteShaped2 (+)
- (-) = arithPromoteShaped2 (-)
- (*) = arithPromoteShaped2 (*)
- negate = arithPromoteShaped negate
- abs = arithPromoteShaped abs
- signum = arithPromoteShaped signum
- fromInteger n = sconstantP (fromInteger n)
-
--- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types)
-deriving via Shaped sh (Primitive Int) instance KnownShape sh => Num (Shaped sh Int)
-deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped sh Double)
-
-type role ListS nominal representational
-type ListS :: [Nat] -> Type -> Type
-data ListS sh i where
- ZS :: ListS '[] i
- (::$) :: forall n sh {i}. i -> ListS sh i -> ListS (n : sh) i
-deriving instance Show i => Show (ListS sh i)
-deriving instance Eq i => Eq (ListS sh i)
-deriving instance Ord i => Ord (ListS sh i)
-deriving instance Functor (ListS sh)
-infixr 3 ::$
-
-instance Foldable (ListS sh) where
- foldr f z l = foldr f z (listSToList l)
-
-listSToList :: ListS sh i -> [i]
-listSToList ZS = []
-listSToList (i ::$ is) = i : listSToList is
-
--- | An index into a shape-typed array.
---
--- For convenience, this contains regular 'Int's instead of bounded integers
--- (traditionally called \"@Fin@\"). Note that because the shape of a
--- shape-typed array is known statically, you can also retrieve the array shape
--- from a 'KnownShape' dictionary.
-type role IxS nominal representational
-type IxS :: [Nat] -> Type -> Type
-newtype IxS sh i = IxS (ListS sh i)
- deriving (Show, Eq, Ord)
- deriving newtype (Functor, Foldable)
-
-pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
-pattern ZIS = IxS ZS
-
-pattern (:.$)
- :: forall {sh1} {i}.
- forall n sh. (n : sh ~ sh1)
- => i -> IxS sh i -> IxS sh1 i
-pattern i :.$ shl <- (unconsIxS -> Just (UnconsIxSRes shl i))
- where i :.$ IxS shl = IxS (i ::$ shl)
-{-# COMPLETE ZIS, (:.$) #-}
-infixr 3 :.$
-
-data UnconsIxSRes i sh1 =
- forall n sh. (n : sh ~ sh1) => UnconsIxSRes (IxS sh i) i
-unconsIxS :: IxS sh1 i -> Maybe (UnconsIxSRes i sh1)
-unconsIxS (IxS (i ::$ shl')) = Just (UnconsIxSRes (IxS shl') i)
-unconsIxS (IxS ZS) = Nothing
-
-type IIxS sh = IxS sh Int
-
-data UnconsShSRes sh1 =
- forall n sh. (n : sh ~ sh1) => UnconsShSRes (ShS sh) (SNat n)
-unconsShS :: ShS sh1 -> Maybe (UnconsShSRes sh1)
-unconsShS (i :$$ shl') = Just (UnconsShSRes shl' i)
-unconsShS ZSS = Nothing
-
-zeroIxS :: ShS sh -> IIxS sh
-zeroIxS ZSS = ZIS
-zeroIxS (_ :$$ sh) = 0 :.$ zeroIxS sh
-
-ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh
-ixCvtXS ZSS ZIX = ZIS
-ixCvtXS (_ :$$ sh) (n :.@ idx) = n :.$ ixCvtXS sh idx
-
-shCvtXS :: ShS sh -> IShX (MapJust sh) -> ShS sh
-shCvtXS ZSS ZSX = ZSS
-shCvtXS (_ :$$ sh) (n :$@ idx) = n :$$ shCvtXS sh idx
-
-ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
-ixCvtSX ZIS = ZIX
-ixCvtSX (n :.$ sh) = n :.@ ixCvtSX sh
-
-shCvtSX :: ShS sh -> IShX (MapJust sh)
-shCvtSX ZSS = ZSX
-shCvtSX (n :$$ sh) = n :$@ shCvtSX sh
-
-shapeSizeS :: ShS sh -> Int
-shapeSizeS ZSS = 1
-shapeSizeS (n :$$ sh) = X.fromSNat' n * shapeSizeS sh
-
-
--- | This does not touch the passed array, all information comes from 'KnownShape'.
-sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> ShS sh
-sshape _ = knownShape @sh
-
-sindex :: Elt a => Shaped sh a -> IIxS sh -> a
-sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)
-
-sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
-sindexPartial (Shaped arr) idx =
- Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
- (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr)
- (ixCvtSX idx))
-
--- | __WARNING__: All values returned from the function must have equal shape.
--- See the documentation of 'mgenerate' for more details.
-sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IIxS sh -> a) -> Shaped sh a
-sgenerate f
- | Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mgenerate (shCvtSX (knownShape @sh)) (f . ixCvtXS (knownShape @sh)))
-
--- | See the documentation of 'mlift'.
-slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a)
- => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
- -> Shaped sh1 a -> Shaped sh2 a
-slift f (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh2)
- = Shaped (mlift f arr)
-
-ssumOuter1P :: forall sh n a.
- (Storable a, Num a, KnownNat n, KnownShape sh)
- => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
-ssumOuter1P (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped
- . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) (Primitive a))
- . X.sumOuter (natSing @n :!$@ ZKSX) (knownShapeX @(MapJust sh))
- . coerce @(Mixed (Just n : MapJust sh) (Primitive a)) @(XArray (Just n : MapJust sh) a)
- $ arr
-
-ssumOuter1 :: forall sh n a.
- (Storable a, Num a, PrimElt a, KnownNat n, KnownShape sh)
- => Shaped (n : sh) a -> Shaped sh a
-ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive
-
-stranspose :: forall sh a. (KnownShape sh, Elt a) => [Int] -> Shaped sh a -> Shaped sh a
-stranspose perm (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mtranspose perm arr)
-
-sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a)
- => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
-sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend
-
-sscalar :: Elt a => a -> Shaped '[] a
-sscalar x = Shaped (mscalar x)
-
-sfromVectorP :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a)
-sfromVectorP v
- | Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mfromVectorP (shCvtSX (knownShape @sh)) v)
-
-sfromVector :: forall sh a. (KnownShape sh, Storable a, PrimElt a) => VS.Vector a -> Shaped sh a
-sfromVector v = coerce fromPrimitive (sfromVectorP @sh @a v)
-
-stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a
-stoVectorP = coerce mtoVectorP
-
-stoVector :: (Storable a, PrimElt a) => Shaped sh a -> VS.Vector a
-stoVector = coerce mtoVector
-
-sfromList1 :: forall n sh a. (KnownNat n, KnownShape sh, Elt a)
- => NonEmpty (Shaped sh a) -> Shaped (n : sh) a
-sfromList1 l
- | Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mfromList1 (coerce l))
-
-sfromList :: (KnownNat n, Elt a) => NonEmpty a -> Shaped '[n] a
-sfromList = Shaped . mfromList1 . fmap mscalar
-
-stoList :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
-stoList (Shaped arr) = coerce (mtoList1 arr)
-
-stoList1 :: Elt a => Shaped '[n] a -> [a]
-stoList1 = map sunScalar . stoList
-
-sunScalar :: Elt a => Shaped '[] a -> a
-sunScalar arr = sindex arr ZIS
-
-sconstantP :: forall sh a. (KnownShape sh, Storable a) => a -> Shaped sh (Primitive a)
-sconstantP x
- | Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mconstantP (shCvtSX (knownShape @sh)) x)
-
-sconstant :: forall sh a. (KnownShape sh, Storable a, PrimElt a)
- => a -> Shaped sh a
-sconstant x = coerce fromPrimitive (sconstantP @sh x)
-
-sslice :: (KnownShape sh, Elt a) => [(Int, Int)] -> Shaped sh a -> Shaped sh a
-sslice ivs = slift $ \_ -> X.slice ivs
-
-srev1 :: (KnownNat n, KnownShape sh, Elt a) => Shaped (n : sh) a -> Shaped (n : sh) a
-srev1 = slift $ \_ -> X.rev1
diff --git a/src/Data/Array/Nested/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs
new file mode 100644
index 0000000..8cac298
--- /dev/null
+++ b/src/Data/Array/Nested/Lemmas.hs
@@ -0,0 +1,162 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Lemmas where
+
+import Data.Proxy
+import Data.Type.Equality
+import GHC.TypeLits
+
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
+
+
+-- * Lemmas about numbers and lists
+
+-- ** Nat
+
+lemLeqSuccSucc :: k + 1 <= n => Proxy k -> Proxy n -> (k <=? n - 1) :~: True  -- from k + 1 <= n, conclude k <= n - 1
+lemLeqSuccSucc _ _ = unsafeCoerceRefl  -- proxies only pin k and n; the equality is asserted via unsafeCoerceRefl, not derived here
+
+lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True  -- n <= m implies n <= m + k
+lemLeqPlus _ _ _ = Refl  -- plain Refl; presumably discharged by the GHC.TypeLits.Normalise plugin enabled for this module (TODO confirm)
+
+-- ** Append
+
+lemAppNil :: l ++ '[] :~: l  -- '[] is a right identity for (++)
+lemAppNil = unsafeCoerceRefl  -- asserted via unsafeCoerceRefl; there is no induction on l here
+
+lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)  -- (++) on type-level lists is associative
+lemAppAssoc _ _ _ = unsafeCoerceRefl  -- proxies only pin a, b and c; the equality is asserted via unsafeCoerceRefl
+
+lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l  -- congruence: equal lists stay equal after appending the same l
+lemAppLeft _ prefixEq = case prefixEq of Refl -> Refl  -- unpacking a ~ b makes both sides the same type
+
+-- ** Simple type families
+
+lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a
+                    -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
+lemReplicatePlusApp count _ _ = induct count  -- induction on the singleton for n; m and a stay fixed
+  where
+    induct :: SNat k -> Replicate (k + m) a :~: Replicate k a ++ Replicate m a
+    induct SZ = Refl
+    induct (SS (prev :: SNat p)) =
+      case lemReplicateSucc @a @p of  -- successor lemma at the predecessor p
+        Refl -> case induct prev of  -- induction hypothesis for p
+          Refl -> sym (lemReplicateSucc @a @(p + m))  -- successor lemma at p + m, flipped
+
+lemDropLenApp :: Rank l1 <= Rank l2
+ => Proxy l1 -> Proxy l2 -> Proxy rest
+ -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest)  -- appending rest commutes with DropLen l1, given Rank l1 <= Rank l2
+lemDropLenApp _ _ _ = unsafeCoerceRefl  -- asserted via unsafeCoerceRefl; the Rank constraint is not inspected here
+
+lemTakeLenApp :: Rank l1 <= Rank l2
+ => Proxy l1 -> Proxy l2 -> Proxy rest
+ -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest)  -- TakeLen l1 is unchanged by appending rest to l2, given Rank l1 <= Rank l2
+lemTakeLenApp _ _ _ = unsafeCoerceRefl  -- asserted via unsafeCoerceRefl; the Rank constraint is not inspected here
+
+lemInitApp :: Proxy l -> Proxy x -> Init (l ++ '[x]) :~: l  -- Init undoes appending a single element
+lemInitApp _ _ = unsafeCoerceRefl  -- proxies only pin l and x; the equality is asserted via unsafeCoerceRefl
+
+lemLastApp :: Proxy l -> Proxy x -> Last (l ++ '[x]) :~: x  -- Last of a list with x appended is x
+lemLastApp _ _ = unsafeCoerceRefl  -- proxies only pin l and x; the equality is asserted via unsafeCoerceRefl
+
+
+-- ** KnownNat
+
+lemKnownNatSucc :: KnownNat n => Dict KnownNat (n + 1)  -- packages KnownNat (n + 1) given KnownNat n
+lemKnownNatSucc = Dict  -- the n + 1 instance is presumably found by the GHC.TypeLits.KnownNat.Solver plugin enabled for this module (TODO confirm)
+
+lemKnownNatRank :: ShX sh i -> Dict KnownNat (Rank sh)  -- the rank of any mixed shape value is a known natural
+lemKnownNatRank ZSX = Dict
+lemKnownNatRank (_ :$% rest) = case lemKnownNatRank rest of Dict -> Dict  -- recurse on the tail, then repackage
+
+lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)  -- same as lemKnownNatRank, for static shapes
+lemKnownNatRankSSX ZKX = Dict
+lemKnownNatRankSSX (_ :!% rest) = case lemKnownNatRankSSX rest of Dict -> Dict  -- recurse on the tail, then repackage
+
+
+-- * Lemmas about shapes
+
+-- ** Known shapes
+
+lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing)  -- a shape of n 'Nothing' entries is a known static shape
+lemKnownReplicate = lemKnownShX . ssxFromSNat  -- turn the singleton into a StaticShX, then into a dictionary
+
+lemKnownShX :: StaticShX sh -> Dict KnownShX sh  -- turn an explicit static shape into a KnownShX dictionary
+lemKnownShX ZKX = Dict
+lemKnownShX (SKnown SNat :!% rest) = case lemKnownShX rest of Dict -> Dict  -- matching SNat presumably exposes KnownNat for this entry
+lemKnownShX (SUnknown () :!% rest) = case lemKnownShX rest of Dict -> Dict  -- unknown entry: only the tail needs evidence
+
+lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh)
+lemKnownMapJust _ = lemKnownShX . toStatic $ knownShS @sh  -- the proxy only selects sh
+  where
+    toStatic :: ShS dims -> StaticShX (MapJust dims)  -- every entry of the shaped shape becomes SKnown
+    toStatic ZSS = ZKX
+    toStatic (dim :$$ rest) = SKnown dim :!% toStatic rest
+
+-- ** Rank
+
+lemRankApp :: forall sh1 sh2.
+              StaticShX sh1 -> StaticShX sh2
+           -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2  -- rank is additive over append
+lemRankApp ZKX _ = Refl
+lemRankApp (_ :!% (front :: StaticShX frontT)) back
+  = step (Proxy @(Rank frontT)) Proxy Proxy
+         (sym (lemRankApp front back))  -- induction hypothesis for the tail, flipped to the orientation step expects
+  where
+    step :: proxy a -> proxy b -> proxy c
+         -> (a + b :~: c)
+         -> c + 1 :~: (a + 1 + b)
+    step _ _ _ sumEq = case sumEq of Refl -> Refl  -- add one to both sides of a + b ~ c
+
+lemRankAppComm :: proxy sh1 -> proxy sh2
+ -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)  -- the rank of an append does not depend on operand order
+lemRankAppComm _ _ = unsafeCoerceRefl  -- asserted via unsafeCoerceRefl; the proxies only pin sh1 and sh2
+
+lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n  -- a list of n 'Nothing' entries has rank n
+lemRankReplicate _ = unsafeCoerceRefl  -- asserted via unsafeCoerceRefl; the proxy only pins n
+
+lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh  -- MapJust keeps the rank
+lemRankMapJust ZSS = Refl
+lemRankMapJust (_ :$$ rest) = case lemRankMapJust rest of Refl -> Refl  -- induction on the shape
+
+-- ** Related to MapJust and/or Permutation
+
+lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)  -- TakeLen commutes with MapJust
+lemTakeLenMapJust PNil _ = Refl
+lemTakeLenMapJust (PCons _ more) (_ :$$ rest) = case lemTakeLenMapJust more rest of Refl -> Refl  -- step both lists together
+lemTakeLenMapJust (PCons _ _) ZSS = error "TakeLen of empty"  -- permutation longer than the shape
+
+lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh)  -- DropLen commutes with MapJust
+lemDropLenMapJust PNil _ = Refl
+lemDropLenMapJust (PCons _ more) (_ :$$ rest) = case lemDropLenMapJust more rest of Refl -> Refl  -- step both lists together
+lemDropLenMapJust (PCons _ _) ZSS = error "DropLen of empty"  -- permutation longer than the shape
+
+lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh)  -- Index commutes with MapJust
+lemIndexMapJust SZ (_ :$$ _) = Refl
+lemIndexMapJust (SS (ix :: SNat j)) ((_ :: SNat d) :$$ (rest :: ShS ds)) =
+  case lemIndexMapJust ix rest of  -- induction hypothesis for the predecessor index and the tail
+    Refl -> case lemIndexSucc (Proxy @j) (Proxy @(Just d)) (Proxy @(MapJust ds)) of
+      Refl -> case lemIndexSucc (Proxy @j) (Proxy @d) (Proxy @ds) of
+        Refl -> Refl
+lemIndexMapJust _ ZSS = error "Index of empty"  -- any index into an empty shape
+
+lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh)  -- Permute commutes with MapJust
+lemPermuteMapJust PNil _ = Refl
+lemPermuteMapJust (PCons ix more) shape =
+  case lemPermuteMapJust more shape of  -- induction hypothesis for the remaining indices
+    Refl -> case lemIndexMapJust ix shape of  -- the head index picks the same entry on both sides
+      Refl -> Refl
+
+lemMapJustApp :: ShS sh1 -> Proxy sh2
+              -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2  -- MapJust distributes over append
+lemMapJustApp ZSS _ = Refl
+lemMapJustApp (_ :$$ rest) second = case lemMapJustApp rest second of Refl -> Refl  -- induction on sh1
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
new file mode 100644
index 0000000..144230e
--- /dev/null
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -0,0 +1,936 @@
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DefaultSignatures #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE DerivingVia #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+module Data.Array.Nested.Mixed where
+
+import Prelude hiding (mconcat)
+
+import Control.DeepSeq (NFData(..))
+import Control.Monad (forM_, when)
+import Control.Monad.ST
+import Data.Array.RankedS qualified as S
+import Data.Bifunctor (bimap)
+import Data.Coerce
+import Data.Foldable (toList)
+import Data.Int
+import Data.Kind (Type)
+import Data.List.NonEmpty (NonEmpty(..))
+import Data.List.NonEmpty qualified as NE
+import Data.Proxy
+import Data.Type.Equality
+import Data.Vector.Storable qualified as VS
+import Data.Vector.Storable.Mutable qualified as VSM
+import Foreign.C.Types (CInt)
+import Foreign.Storable (Storable)
+import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
+import GHC.Generics (Generic)
+import GHC.TypeLits
+
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Types
+import Data.Array.Strided.Orthotope
+import Data.Array.XArray (XArray(..))
+import Data.Array.XArray qualified as X
+import Data.Bag
+
+
+-- TODO:
+-- sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
+-- rminIndex1 :: Ranked (n + 1) a -> Ranked n Int
+-- gather/scatter-like things (most generally, the higher-order variants: accelerate's backpermute/permute)
+-- After benchmarking: matmul and matvec
+
+
+
+-- Invariant in the API
+-- ====================
+--
+-- In the underlying XArray, there is some shape for elements of an empty
+-- array. For example, for this array:
+--
+-- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float)
+-- rshape arr == 0 :.: 0 :.: 0 :.: ZIR
+--
+-- the two underlying XArrays have a shape, and those shapes might be anything.
+-- The invariant is that these element shapes are unobservable in the API.
+-- (This is possible because you ought to not be able to get to such an element
+-- without indexing out of bounds.)
+--
+-- Note, though, that the converse situation may arise: the outer array might
+-- be nonempty while the inner arrays are empty. This is fine: the invariant
+-- only applies if the _outer_ array is empty.
+--
+-- TODO: can we enforce that the elements of an empty (nested) array have
+-- all-zero shape?
+-- -> no, because mlift and also any kind of internals probing from outsiders
+
+
+-- Primitive element types
+-- =======================
+--
+-- There are a few primitive element types; arrays containing elements of such
+-- type are a newtype over an XArray, which is itself a newtype over a Vector.
+-- Unfortunately, the setup of the library requires us to list these primitive
+-- element types multiple times; to aid in extending the list, all these lists
+-- have been marked with [PRIMITIVE ELEMENT TYPES LIST].
+
+
+-- | Wrapper type used as a tag to attach instances on. The instances on arrays
+-- of @'Primitive' a@ are more polymorphic than the direct instances for arrays
+-- of scalars; this means that if @orthotope@ supports an element type @T@ that
+-- this library does not (directly), it may just work if you use an array of
+-- @'Primitive' T@ instead.
+newtype Primitive a = Primitive a
+ deriving (Show)
+
+-- | Element types that are primitive; arrays of these types are just a newtype
+-- wrapper over an array.
+--
+-- Both methods have default implementations in terms of 'coerce', so an
+-- instance can be declared without a body whenever the two array
+-- representations are 'Coercible'.
+class (Storable a, Elt a) => PrimElt a where
+ -- | Convert an array of @'Primitive' a@ into an array of @a@.
+ fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a
+ -- | Convert an array of @a@ into an array of @'Primitive' a@.
+ toPrimitive :: Mixed sh a -> Mixed sh (Primitive a)
+
+ default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a
+ fromPrimitive = coerce
+
+ default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a)
+ toPrimitive = coerce
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+instance PrimElt Bool
+instance PrimElt Int
+instance PrimElt Int64
+instance PrimElt Int32
+instance PrimElt CInt
+instance PrimElt Float
+instance PrimElt Double
+instance PrimElt ()
+
+
+-- | Mixed arrays: some dimensions are size-typed, some are not. Distributes
+-- over product-typed elements using a data family so that the full array is
+-- always in struct-of-arrays format.
+--
+-- Built on top of 'XArray' which is built on top of @orthotope@, meaning that
+-- dimension permutations (e.g. 'mtranspose') are typically free.
+--
+-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type
+-- class.
+type Mixed :: [Maybe Nat] -> Type -> Type
+data family Mixed sh a
+-- NOTE: When opening up the Mixed abstraction, you might see dimension sizes
+-- that you're not supposed to see. In particular, you might see (nonempty)
+-- sizes of the elements of an empty array, which is information that should
+-- ostensibly not exist; the full array is still empty.
+
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+#define ANDSHOW , Show
+#else
+#define ANDSHOW
+#endif
+
+data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a)
+ deriving (Eq, Ord, Generic ANDSHOW)
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Eq, Ord, Generic ANDSHOW) -- no content, orthotope optimises this (via Vector)
+-- etc.
+
+data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic)
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b))
+#endif
+-- etc., larger tuples (perhaps use generics to allow arbitrary product types)
+
+deriving instance (Eq (Mixed sh a), Eq (Mixed sh b)) => Eq (Mixed sh (a, b))
+deriving instance (Ord (Mixed sh a), Ord (Mixed sh b)) => Ord (Mixed sh (a, b))
+
+data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic)
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance (Show (Mixed (sh1 ++ sh2) a)) => Show (Mixed sh1 (Mixed sh2 a))
+#endif
+
+deriving instance Eq (Mixed (sh1 ++ sh2) a) => Eq (Mixed sh1 (Mixed sh2 a))
+deriving instance Ord (Mixed (sh1 ++ sh2) a) => Ord (Mixed sh1 (Mixed sh2 a))
+
+
+-- | Internal helper data family mirroring 'Mixed' that consists of mutable
+-- vectors instead of 'XArray's.
+type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
+data family MixedVecs s sh a
+
+newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a)
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+newtype instance MixedVecs s sh Bool = MV_Bool (VS.MVector s Bool)
+newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int)
+newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64)
+newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32)
+newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt)
+newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double)
+newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float)
+newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this
+-- etc.
+
+data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b)
+-- etc.
+
+data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a)
+
+
+-- | Shared implementation of 'showsPrec' for the array types.
+--
+-- If the array has at least two elements and every underlying component
+-- array has all-zero strides in its first @shxLength (mshape arr)@ dimensions,
+-- the array is rendered as the replicate prefix followed by its first element.
+-- Otherwise it is rendered as the fromList prefix followed by the full list of
+-- elements.
+showsMixedArray :: (Show a, Elt a)
+                => String  -- ^ fromList prefix: e.g. @rfromListLinear [2,3]@
+                -> String  -- ^ replicate prefix: e.g. @rreplicate [2,3]@
+                -> Int -> Mixed sh a -> ShowS
+showsMixedArray fromlistPrefix replicatePrefix d arr = showParen (d > 10) body
+  where
+    elems = mtoListLinear arr
+    outerRank = shxLength (mshape arr)
+    zeroStrided = all (all (== 0) . take outerRank) (marrayStrides arr)
+
+    -- TODO: to avoid ambiguity, we should type-apply the shape to mfromListLinear here
+    body = case elems of
+      hd : _ : _
+        | zeroStrided ->
+            showString replicatePrefix . showString " " . showsPrec 11 hd
+      _ ->
+        showString fromlistPrefix . showString " " . shows elems
+
+#ifndef OXAR_DEFAULT_SHOW_INSTANCES
+instance (Show a, Elt a) => Show (Mixed sh a) where
+  -- Both prefixes mention the runtime shape, rendered as a plain list.
+  showsPrec d arr =
+    showsMixedArray ("mfromListLinear " ++ shapeStr) ("mreplicate " ++ shapeStr) d arr
+    where
+      shapeStr = show (shxToList (mshape arr))
+#endif
+
+instance Elt a => NFData (Mixed sh a) where
+ rnf = mrnf
+
+
+-- | Lift a unary operation on the underlying orthotope arrays to 'Mixed'
+-- arrays of primitive elements. The operation also receives the rank of the
+-- array as a singleton; the shape of the result is that of the input.
+mliftNumElt1 :: (PrimElt a, PrimElt b)
+             => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b)
+             -> Mixed sh a -> Mixed sh b
+mliftNumElt1 f arr =
+  case toPrimitive arr of
+    M_Primitive sh (XArray inner) ->
+      fromPrimitive (M_Primitive sh (XArray (f (shxRank sh) inner)))
+
+-- | Lift a binary operation on the underlying orthotope arrays to 'Mixed'
+-- arrays of primitive elements. The two arrays must have equal shapes; if
+-- they do not, a runtime error mentioning both shapes is thrown.
+mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c)
+             => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c)
+             -> Mixed sh a -> Mixed sh b -> Mixed sh c
+mliftNumElt2 f lhs rhs =
+  case (toPrimitive lhs, toPrimitive rhs) of
+    (M_Primitive sh1 (XArray arr1), M_Primitive sh2 (XArray arr2))
+      | sh1 == sh2 ->
+          fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2))
+      | otherwise ->
+          error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2
+
+-- | Arithmetic on arrays of primitive elements, implemented through
+-- 'mliftNumElt1' and 'mliftNumElt2'. The binary operations therefore throw a
+-- runtime error if the two shapes are unequal.
+--
+-- Note that 'fromInteger' is /not/ implemented: using it (e.g. through an
+-- integer literal) throws a runtime error.
+instance (NumElt a, PrimElt a) => Num (Mixed sh a) where
+ (+) = mliftNumElt2 (liftO2 . numEltAdd)
+ (-) = mliftNumElt2 (liftO2 . numEltSub)
+ (*) = mliftNumElt2 (liftO2 . numEltMul)
+ negate = mliftNumElt1 (liftO1 . numEltNeg)
+ abs = mliftNumElt1 (liftO1 . numEltAbs)
+ signum = mliftNumElt1 (liftO1 . numEltSignum)
+ -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS
+ fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal"
+
+-- | Division on arrays of primitive floating-point elements, implemented
+-- through 'mliftNumElt1' and 'mliftNumElt2'.
+--
+-- Note that 'fromRational' is /not/ implemented: using it (e.g. through a
+-- fractional literal) throws a runtime error.
+instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where
+ fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"
+ recip = mliftNumElt1 (liftO1 . floatEltRecip)
+ (/) = mliftNumElt2 (liftO2 . floatEltDiv)
+
+-- | Floating-point functions on arrays of primitive elements, implemented
+-- through 'mliftNumElt1' and 'mliftNumElt2'.
+--
+-- Note that 'pi' is /not/ implemented: using it throws a runtime error.
+instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"
+ exp = mliftNumElt1 (liftO1 . floatEltExp)
+ log = mliftNumElt1 (liftO1 . floatEltLog)
+ sqrt = mliftNumElt1 (liftO1 . floatEltSqrt)
+
+ (**) = mliftNumElt2 (liftO2 . floatEltPow)
+ logBase = mliftNumElt2 (liftO2 . floatEltLogbase)
+
+ sin = mliftNumElt1 (liftO1 . floatEltSin)
+ cos = mliftNumElt1 (liftO1 . floatEltCos)
+ tan = mliftNumElt1 (liftO1 . floatEltTan)
+ asin = mliftNumElt1 (liftO1 . floatEltAsin)
+ acos = mliftNumElt1 (liftO1 . floatEltAcos)
+ atan = mliftNumElt1 (liftO1 . floatEltAtan)
+ sinh = mliftNumElt1 (liftO1 . floatEltSinh)
+ cosh = mliftNumElt1 (liftO1 . floatEltCosh)
+ tanh = mliftNumElt1 (liftO1 . floatEltTanh)
+ asinh = mliftNumElt1 (liftO1 . floatEltAsinh)
+ acosh = mliftNumElt1 (liftO1 . floatEltAcosh)
+ atanh = mliftNumElt1 (liftO1 . floatEltAtanh)
+ log1p = mliftNumElt1 (liftO1 . floatEltLog1p)
+ expm1 = mliftNumElt1 (liftO1 . floatEltExpm1)
+ log1pexp = mliftNumElt1 (liftO1 . floatEltLog1pexp)
+ log1mexp = mliftNumElt1 (liftO1 . floatEltLog1mexp)
+
+-- | Integer quotient and remainder of two arrays, lifted with 'mliftNumElt2'
+-- from 'intEltQuot' and 'intEltRem' respectively. Both arrays must have the
+-- same shape, otherwise a runtime error is thrown.
+mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a
+mquotArray = mliftNumElt2 (liftO2 . intEltQuot)
+mremArray = mliftNumElt2 (liftO2 . intEltRem)
+
+-- | Two-argument arctangent of two arrays, lifted with 'mliftNumElt2' from
+-- 'floatEltAtan2'. Both arrays must have the same shape, otherwise a runtime
+-- error is thrown.
+matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a
+matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2)
+
+-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
+-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
+-- a@; see the documentation for 'Primitive' for more details.
+class Elt a where
+ -- ====== PUBLIC METHODS ====== --
+
+ -- | The shape of the array.
+ mshape :: Mixed sh a -> IShX sh
+ -- | The element at the given (complete) index.
+ mindex :: Mixed sh a -> IIxX sh -> a
+ -- | Index with a prefix of the dimensions, yielding the subarray at that
+ -- position.
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a
+ -- | A zero-dimensional array containing just the given element.
+ mscalar :: a -> Mixed '[] a
+
+ -- | All arrays in the list, even subarrays inside @a@, must have the same
+ -- shape; if they do not, a runtime error will be thrown. See the
+ -- documentation of 'mgenerate' for more information about this restriction.
+ -- The outer dimension of the result is of statically unknown size
+ -- (@Nothing@); at runtime, its size is the length of the list.
+ --
+ -- Consider also 'mfromListPrim', which can avoid intermediate arrays.
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+
+ -- | The subarrays along the outermost dimension, as a list.
+ mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]
+
+ -- | Note: this library makes no particular guarantees about the shapes of
+ -- arrays "inside" an empty array. With 'mlift', 'mlift2' and 'mliftL' you can see the
+ -- full 'XArray' and as such you can distinguish different empty arrays by
+ -- the "shapes" of their elements. This information is meaningless, so you
+ -- should not use it.
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 a -> Mixed sh2 a
+
+ -- | See the documentation for 'mlift'.
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a
+
+ -- TODO: mliftL is currently unused.
+ -- | All arrays in the input must have equal shapes, including subarrays
+ -- inside their elements.
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
+ -> NonEmpty (Mixed sh1 a) -> NonEmpty (Mixed sh2 a)
+
+ -- | Replace the static shape prefix @sh1@ of the array type by @sh2@, which
+ -- must have the same rank; the remaining dimensions @sh'@ are left as-is.
+ mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a
+
+ -- | Permute a prefix of the dimensions of the array according to the given
+ -- permutation.
+ mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
+ => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a
+
+ -- | All arrays in the input must have equal shapes, including subarrays
+ -- inside their elements.
+ mconcat :: NonEmpty (Mixed (Nothing : sh) a) -> Mixed (Nothing : sh) a
+
+ -- | Implementation of 'rnf' for the 'NFData' instance of 'Mixed'.
+ mrnf :: Mixed sh a -> ()
+
+ -- ====== PRIVATE METHODS ====== --
+
+ -- | Tree giving the shape of every array component.
+ type ShapeTree a
+
+ -- | The shape tree of a single element.
+ mshapeTree :: a -> ShapeTree a
+
+ -- | Equality of shape trees.
+ mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool
+
+ -- | Used by 'mgenerate' on the shape tree of the first element: if this
+ -- returns 'True', the result is built with 'memptyArrayUnsafe' instead of
+ -- by writing the elements one by one.
+ mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool
+
+ -- | Render a shape tree as a string.
+ mshowShapeTree :: Proxy a -> ShapeTree a -> String
+
+ -- | Returns the stride vector of each underlying component array making up
+ -- this mixed array.
+ marrayStrides :: Mixed sh a -> Bag [Int]
+
+ -- | Given the shape of this array, an index and a value, write the value at
+ -- that index in the vectors.
+ mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+
+ -- | Given the shape of this array, an index prefix and a subarray, write
+ -- the contents of the subarray at that position in the vectors.
+ mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+
+ -- | Given the shape of this array, finalise the vectors into 'XArray's.
+ mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
+
+
+-- | Element types for which we have evidence of the (static part of the) shape
+-- in a type class constraint. Compare the instance contexts of the instances
+-- of this class with those of 'Elt': some instances have an additional
+-- "known-shape" constraint.
+--
+-- This class is (currently) only required for 'memptyArray' and 'mgenerate'.
+class Elt a => KnownElt a where
+ -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
+ memptyArrayUnsafe :: IShX sh -> Mixed sh a
+
+ -- | Create uninitialised vectors for this array type, given the shape of
+ -- this vector and an example for the contents.
+ mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a)
+
+ -- | Create vectors of length zero for this array type. No example element
+ -- is needed; the instance for nested arrays uses this when the example
+ -- passed to 'mvecsUnsafeNew' is itself an empty array.
+ mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
+
+
+-- Arrays of scalars are basically just arrays of scalars.
+--
+-- Every method unwraps the 'M_Primitive' constructor, delegates to the
+-- corresponding operation on 'XArray', and re-wraps the result together with
+-- its shape.
+instance Storable a => Elt (Primitive a) where
+ mshape (M_Primitive sh _) = sh
+ mindex (M_Primitive _ a) i = Primitive (X.index a i)
+ mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i)
+ mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
+ -- The outer dimension gets the length of the list; the inner shape is taken
+ -- from the first array only.
+ mfromListOuter l@(arr1 :| _) =
+ let sh = SUnknown (length l) :$% mshape arr1
+ in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l)))
+ mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
+
+ -- A primitive array has no inner dimensions, so the lifted function is
+ -- applied at @sh' = '[]@; 'lemAppNil' provides the type equalities needed to
+ -- pass an @XArray sh a@ where an @XArray (sh ++ '[]) a@ is expected.
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
+ -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)
+ mlift ssh2 f (M_Primitive _ a)
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ , let result = f ZKX a
+ = M_Primitive (X.shape ssh2 result) result
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)
+ -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a)
+ mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b)
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ , Refl <- lemAppNil @sh3
+ , let result = f ZKX a b
+ = M_Primitive (X.shape ssh3 result) result
+
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
+ -> NonEmpty (Mixed sh1 (Primitive a)) -> NonEmpty (Mixed sh2 (Primitive a))
+ mliftL ssh2 f l
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ = fmap (\arr -> M_Primitive (X.shape ssh2 arr) arr) $
+ f ZKX (fmap (\(M_Primitive _ arr) -> arr) l)
+
+ -- Split the runtime shape into the @sh1@ prefix and the @sh'@ rest, cast
+ -- the prefix to @sh2@, and glue the two parts back together.
+ mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
+ mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) =
+ let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
+ sh2 = shxCast' ssh2 sh1
+ in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr)
+
+ mtranspose perm (M_Primitive sh arr) =
+ M_Primitive (shxPermutePrefix perm sh)
+ (X.transpose (ssxFromShX sh) perm arr)
+
+ -- The inner shape is taken from the first array only; the shape of the
+ -- result is read back from the concatenated 'XArray'.
+ mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a)
+ mconcat l@(M_Primitive (_ :$% sh) _ :| _) =
+ let result = X.concat (ssxFromShX sh) (fmap (\(M_Primitive _ arr) -> arr) l)
+ in M_Primitive (X.shape (SUnknown () :!% ssxFromShX sh) result) result
+
+ mrnf (M_Primitive sh a) = rnf sh `seq` rnf a
+
+ -- A scalar has no shape information: its shape tree is the unit value, all
+ -- such trees are equal, and they are never reported as empty.
+ type ShapeTree (Primitive a) = ()
+ mshapeTree _ = ()
+ mshapeTreeEq _ () () = True
+ mshapeTreeEmpty _ () = False
+ mshowShapeTree _ () = "()"
+ marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr)
+ mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
+
+ -- TODO: this use of toVector is suboptimal
+ mvecsWritePartial
+ :: forall sh' sh s.
+ IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
+ mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
+ let arrsh = X.shape (ssxFromShX sh') arr
+ -- Linear position of the index prefix @i@ extended with an index for
+ -- the subarray dimensions (presumably all zeros, see 'ixxZero'').
+ offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
+ -- Copy the subarray's vector into a slice of as many elements starting
+ -- at that position.
+ VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
+
+ mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+deriving via Primitive Bool instance Elt Bool
+deriving via Primitive Int instance Elt Int
+deriving via Primitive Int64 instance Elt Int64
+deriving via Primitive Int32 instance Elt Int32
+deriving via Primitive CInt instance Elt CInt
+deriving via Primitive Double instance Elt Double
+deriving via Primitive Float instance Elt Float
+deriving via Primitive () instance Elt ()
+
+instance Storable a => KnownElt (Primitive a) where
+  memptyArrayUnsafe sh = M_Primitive sh (X.empty sh)
+  -- One uninitialised slot per element of the array; the example value is
+  -- not needed for scalars.
+  mvecsUnsafeNew sh _ = fmap MV_Primitive (VSM.unsafeNew (shxSize sh))
+  -- A vector of length zero.
+  mvecsNewEmpty _ = fmap MV_Primitive (VSM.unsafeNew 0)
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+deriving via Primitive Bool instance KnownElt Bool
+deriving via Primitive Int instance KnownElt Int
+deriving via Primitive Int64 instance KnownElt Int64
+deriving via Primitive Int32 instance KnownElt Int32
+deriving via Primitive CInt instance KnownElt CInt
+deriving via Primitive Double instance KnownElt Double
+deriving via Primitive Float instance KnownElt Float
+deriving via Primitive () instance KnownElt ()
+
+-- Arrays of pairs are pairs of arrays (struct-of-arrays): every method
+-- delegates to both component arrays and pairs up the results. Lists of pair
+-- arrays are split into two lists by projecting out each component.
+instance (Elt a, Elt b) => Elt (a, b) where
+  -- Both components have the same shape; the first one is used.
+  mshape (M_Tup2 l _) = mshape l
+  mindex (M_Tup2 l r) ix = (mindex l ix, mindex r ix)
+  mindexPartial (M_Tup2 l r) ix = M_Tup2 (mindexPartial l ix) (mindexPartial r ix)
+  mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
+  mfromListOuter arrs =
+    M_Tup2 (mfromListOuter (fmap (\(M_Tup2 l _) -> l) arrs))
+           (mfromListOuter (fmap (\(M_Tup2 _ r) -> r) arrs))
+  mtoListOuter (M_Tup2 l r) = zipWith M_Tup2 (mtoListOuter l) (mtoListOuter r)
+  mlift ssh2 f (M_Tup2 l r) = M_Tup2 (mlift ssh2 f l) (mlift ssh2 f r)
+  mlift2 ssh3 f (M_Tup2 l1 r1) (M_Tup2 l2 r2) =
+    M_Tup2 (mlift2 ssh3 f l1 l2) (mlift2 ssh3 f r1 r2)
+  mliftL ssh2 f arrs =
+    NE.zipWith M_Tup2
+      (mliftL ssh2 f (fmap (\(M_Tup2 l _) -> l) arrs))
+      (mliftL ssh2 f (fmap (\(M_Tup2 _ r) -> r) arrs))
+
+  mcastPartial ssh1 ssh2 psh' (M_Tup2 l r) =
+    M_Tup2 (mcastPartial ssh1 ssh2 psh' l) (mcastPartial ssh1 ssh2 psh' r)
+
+  mtranspose perm (M_Tup2 l r) = M_Tup2 (mtranspose perm l) (mtranspose perm r)
+  mconcat arrs =
+    M_Tup2 (mconcat (fmap (\(M_Tup2 l _) -> l) arrs))
+           (mconcat (fmap (\(M_Tup2 _ r) -> r) arrs))
+
+  mrnf (M_Tup2 l r) = mrnf l `seq` mrnf r
+
+  type ShapeTree (a, b) = (ShapeTree a, ShapeTree b)
+  mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
+  mshapeTreeEq _ (l1, r1) (l2, r2) =
+    mshapeTreeEq (Proxy @a) l1 l2 && mshapeTreeEq (Proxy @b) r1 r2
+  mshapeTreeEmpty _ (l, r) =
+    mshapeTreeEmpty (Proxy @a) l && mshapeTreeEmpty (Proxy @b) r
+  mshowShapeTree _ (l, r) =
+    concat ["(", mshowShapeTree (Proxy @a) l, ", ", mshowShapeTree (Proxy @b) r, ")"]
+  marrayStrides (M_Tup2 l r) = marrayStrides l <> marrayStrides r
+  mvecsWrite sh ix (x, y) (MV_Tup2 vl vr) =
+    mvecsWrite sh ix x vl >> mvecsWrite sh ix y vr
+  mvecsWritePartial sh ix (M_Tup2 x y) (MV_Tup2 vl vr) =
+    mvecsWritePartial sh ix x vl >> mvecsWritePartial sh ix y vr
+  mvecsFreeze sh (MV_Tup2 vl vr) = do
+    l <- mvecsFreeze sh vl
+    r <- mvecsFreeze sh vr
+    pure (M_Tup2 l r)
+
+-- Vectors for an array of pairs are allocated component by component, first
+-- for the left component and then for the right one.
+instance (KnownElt a, KnownElt b) => KnownElt (a, b) where
+  memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh)
+  mvecsUnsafeNew sh (x, y) = do
+    vl <- mvecsUnsafeNew sh x
+    vr <- mvecsUnsafeNew sh y
+    pure (MV_Tup2 vl vr)
+  mvecsNewEmpty _ = do
+    vl <- mvecsNewEmpty (Proxy @a)
+    vr <- mvecsNewEmpty (Proxy @b)
+    pure (MV_Tup2 vl vr)
+
+-- Arrays of arrays are just arrays, but with more dimensions.
+--
+-- An array of shape @sh@ with elements of shape @sh'@ is stored as a single
+-- array of shape @sh ++ sh'@ ('M_Nest'). Most methods delegate to the same
+-- method on that flattened array; the 'lemAppAssoc' matches supply the
+-- associativity of @++@ needed to make the types line up.
+instance Elt a => Elt (Mixed sh' a) where
+ -- TODO: this is quadratic in the nesting depth because it repeatedly
+ -- truncates the shape vector to one a little shorter. Fix with a
+ -- moverlongShape method, a prefix of which is mshape.
+ mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
+ mshape (M_Nest sh arr)
+ = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr))
+
+ -- Indexing a nested array with a full outer index is partial indexing of
+ -- the flattened array.
+ mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
+ mindex (M_Nest _ arr) = mindexPartial arr
+
+ mindexPartial :: forall sh1 sh2.
+ Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
+ mindexPartial (M_Nest sh arr) i
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = M_Nest (shxDropIx i sh) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
+
+ mscalar = M_Nest ZSX
+
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
+ mfromListOuter l@(arr :| _) =
+ M_Nest (SUnknown (length l) :$% mshape arr)
+ (mfromListOuter ((\(M_Nest _ a) -> a) <$> l))
+
+ mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr)
+
+ -- The lifted function is applied to the flattened array: the element
+ -- dimensions @sh'@ are prepended to the inner dimensions @shT@ that the
+ -- callback is told about.
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
+ -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
+ mlift ssh2 f (M_Nest sh1 arr) =
+ let result = mlift (ssxAppend ssh2 ssh') f' arr
+ (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result)
+ in M_Nest sh2 result
+ where
+ ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr)))
+
+ f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
+ f' sshT
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
+ = f (ssxAppend ssh' sshT)
+
+ -- Note: the element shape @ssh'@ is computed from the first array only.
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
+ -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)
+ mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) =
+ let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2
+ (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result)
+ in M_Nest sh3 result
+ where
+ ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1)))
+
+ f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
+ f' sshT
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
+ = f (ssxAppend ssh' sshT)
+
+ -- Note: the element shape and the outer shape of the results are both
+ -- computed from the head of the respective list only.
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b))
+ -> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a))
+ mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) =
+ let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l)
+ (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result))
+ in fmap (M_Nest sh2) result
+ where
+ ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1)))
+
+ f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b)
+ f' sshT
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
+ = f (ssxAppend ssh' sshT)
+
+ mcastPartial :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> StaticShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a)
+ mcastPartial ssh1 ssh2 _ (M_Nest sh1T arr)
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
+ = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T
+ sh2 = shxCast' ssh2 sh1
+ in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr)
+
+ -- The permutation only touches a prefix of the outer dimensions, so it can
+ -- be applied to the flattened array directly; the lemmas establish that
+ -- permuting a prefix of @sh ++ sh'@ is the same as permuting a prefix of
+ -- @sh@ and then appending @sh'@.
+ mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
+ => Perm is -> Mixed sh (Mixed sh' a)
+ -> Mixed (PermutePrefix is sh) (Mixed sh' a)
+ mtranspose perm (M_Nest sh arr)
+ | let sh' = shxDropSh @sh @sh' sh (mshape arr)
+ , Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh')
+ , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh'))
+ , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
+ , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
+ , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
+ = M_Nest (shxPermutePrefix perm sh)
+ (mtranspose perm arr)
+
+ mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
+ mconcat l@(M_Nest sh1 _ :| _) =
+ let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l)
+ in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape result))) result
+
+ mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr
+
+ type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a)
+
+ -- The shape tree of an array is its own shape paired with the shape tree of
+ -- the element at the all-zero index.
+ -- NOTE(review): if the array is empty, that index is out of bounds; the
+ -- second component is a lazy thunk, so this only matters if it is forced --
+ -- confirm that callers are fine with this.
+ mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
+ mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShX (mshape arr)))))
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+
+ -- NOTE(review): the instance for 'Primitive' always returns False here, so
+ -- with this conjunction the result is False for any array whose leaves are
+ -- primitives, even if the shape has size zero -- confirm that @&&@ (and not
+ -- @||@) is intended.
+ mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
+ marrayStrides (M_Nest _ arr) = marrayStrides arr
+
+ -- Writing one element of a nested array means writing a whole subarray into
+ -- the flattened vectors, whose shape is the outer shape followed by the
+ -- element shape stored in 'MV_Nest'.
+ mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
+
+ mvecsWritePartial :: forall sh1 sh2 s.
+ IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
+ -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
+ -> ST s ()
+ mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs
+
+ mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs
+
+-- Whenever no example element is available, or the example is itself empty,
+-- the element shape is made up from the statically known shape @sh'@ using
+-- 'shxCompleteZeros'.
+instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
+  memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe flatSh)
+    where
+      flatSh = shxAppend sh (shxCompleteZeros (knownShX @sh'))
+
+  mvecsUnsafeNew sh example =
+    if shxSize innerSh == 0
+      then mvecsNewEmpty (Proxy @(Mixed sh' a))
+      else fmap (MV_Nest innerSh) (mvecsUnsafeNew (shxAppend sh innerSh) firstInner)
+    where
+      -- Shape of the example element, and its element at the all-zero index
+      -- (only evaluated if the example is nonempty).
+      innerSh = mshape example
+      firstInner = mindex example (ixxZero (ssxFromShX innerSh))
+
+  mvecsNewEmpty _ =
+    fmap (MV_Nest (shxCompleteZeros (knownShX @sh'))) (mvecsNewEmpty (Proxy @a))
+
+
+-- | Create an empty array. The outermost dimension is statically of size 0,
+-- hence the array is empty regardless of the given shape of the remaining
+-- dimensions.
+memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a
+memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh)
+
+-- | The rank of the array as a singleton, computed from its shape.
+mrank :: Elt a => Mixed sh a -> SNat (Rank sh)
+mrank arr = shxRank (mshape arr)
+
+-- | The total number of elements in the array, i.e. the size of its shape.
+msize :: Elt a => Mixed sh a -> Int
+msize arr = shxSize (mshape arr)
+
+-- | Build an array of the given shape by calling the given function once for
+-- every index.
+--
+-- __WARNING__: Every @a@ returned by the function passed to 'mgenerate' must
+-- have the same shape. For example, the following will throw a runtime error:
+--
+-- > foo :: Mixed [Nothing] (Mixed [Nothing] Double)
+-- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) ->
+-- >         mgenerate (i :.: ZIR) $ \(j :.: ZIR) ->
+-- >           ...
+--
+-- because the size of the inner 'mgenerate' is not always the same (it depends
+-- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so
+-- the entire hierarchy (after distributing out tuples) must be a rectangular
+-- array. The type of 'mgenerate' makes it very easy to violate this
+-- requirement, hence the runtime check.
+mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a
+mgenerate sh f = case shxEnum sh of
+  [] -> memptyArrayUnsafe sh
+  firstidx : restidxs
+    | mshapeTreeEmpty (Proxy @a) shapetree -> memptyArrayUnsafe sh
+    | otherwise -> runST $ do
+        vecs <- mvecsUnsafeNew sh firstelem
+        mvecsWrite sh firstidx firstelem vecs
+        -- TODO: This is likely fine if @a@ is big, but if @a@ is a
+        -- scalar this array copying is inefficient. Should improve this.
+        forM_ restidxs $ \idx -> do
+          let val = f idx
+              sameShape = mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree
+          when (not sameShape) $
+            error "Data.Array.Nested mgenerate: generated values do not have equal shapes"
+          mvecsWrite sh idx val vecs
+        mvecsFreeze sh vecs
+  where
+    -- The element at the all-zero index serves as the example from which the
+    -- vectors are allocated and against which all other elements' shapes are
+    -- compared. Only evaluated if the array has at least one index.
+    firstelem = f (ixxZero' sh)
+    shapetree = mshapeTree firstelem
+
+msumOuter1P :: forall sh n a. (Storable a, NumElt a)
+ => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
+msumOuter1P (M_Primitive (n :$% sh) arr) =
+ let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
+ in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr)
+
+-- | Like 'msumOuter1P', but for any primitive element type: the array is
+-- converted to its 'Primitive' representation, reduced, and converted back.
+msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
+           => Mixed (n : sh) a -> Mixed sh a
+msumOuter1 arr = fromPrimitive (msumOuter1P @sh @n @a (toPrimitive arr))
+
+-- | Reduce a whole array of primitive elements to a single value using
+-- 'X.sumFull'.
+msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
+msumAllPrim arr =
+  case toPrimitive arr of
+    M_Primitive sh xarr -> X.sumFull (ssxFromShX sh) xarr
+
+-- | Append two arrays along their outermost dimension (via 'X.append'). The
+-- outer dimension of the result is statically known only if it is known for
+-- both arguments; see 'AddMaybe'.
+mappend :: forall n m sh a. Elt a
+ => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
+mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
+ where
+ sn :$% sh = mshape arr1
+ sm :$% _ = mshape arr2
+ ssh = ssxFromShX sh
+ -- Static description of the resulting outer dimension.
+ snm :: SMayNat () SNat (AddMaybe n m)
+ snm = case (sn, sm) of
+ (SUnknown{}, _) -> SUnknown ()
+ (SKnown{}, SUnknown{}) -> SUnknown ()
+ (SKnown n, SKnown m) -> SKnown (snatPlus n m)
+
+ -- The actual append on the underlying storage; @sh'@ is the extra shape
+ -- suffix that 'mlift2' supplies.
+ f :: forall sh' b. Storable b
+ => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b
+ f ssh' = X.append (ssxAppend ssh ssh')
+
+mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
+mfromVectorP sh v = M_Primitive sh (X.fromVector sh v)
+
+-- | Build an array of the given shape from a storable vector; wrapper around
+-- 'mfromVectorP' for any primitive element type.
+mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a
+mfromVector sh = fromPrimitive . mfromVectorP sh
+
+mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a
+mtoVectorP (M_Primitive _ v) = X.toVector v
+
+-- | Extract the elements of an array as a storable vector; wrapper around
+-- 'mtoVectorP' for any primitive element type.
+mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a
+mtoVector = mtoVectorP . toPrimitive
+
+mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a
+mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise?
+
+-- | Build an array of the given shape from a non-empty list of elements, by
+-- first building a one-dimensional array and then reshaping it.
+--
+-- The explicit forall is there so that a simple type application can
+-- constrain the shape, in case the user wants to use OverloadedLists for the
+-- shape.
+mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a
+mfromListLinear sh = mreshape sh . mfromList1
+
+mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a
+mfromListPrim l =
+ let ssh = SUnknown () :!% ZKX
+ xarr = X.fromList1 ssh l
+ in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+
+mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a
+mfromListPrimLinear sh l =
+ let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)
+
+-- | The elements of a one-dimensional array, as a list.
+mtoList :: Elt a => Mixed '[n] a -> [a]
+mtoList arr = [munScalar x | x <- mtoListOuter arr]
+
+-- | All elements of the array as a list, in the index order produced by
+-- 'shxEnum'.
+mtoListLinear :: Elt a => Mixed sh a -> [a]
+mtoListLinear arr = [mindex arr idx | idx <- shxEnum (mshape arr)]  -- TODO: optimise
+
+munScalar :: Elt a => Mixed '[] a -> a
+munScalar arr = mindex arr ZIX
+
+mnest :: forall sh sh' a. Elt a => StaticShX sh -> Mixed (sh ++ sh') a -> Mixed sh (Mixed sh' a)
+mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr
+
+munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a
+munNest (M_Nest _ arr) = arr
+
+-- | Pair up two arrays elementwise. The arguments must have equal shapes; if
+-- they do not, an error is raised.
+mzip :: (Elt a, Elt b) => Mixed sh a -> Mixed sh b -> Mixed sh (a, b)
+mzip a b =
+  case shxEqual (mshape a) (mshape b) of
+    Just Refl -> M_Tup2 a b
+    Nothing -> error "mzip: unequal shapes"
+
+munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b)
+munzip (M_Tup2 a b) = (a, b)
+
+mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
+ => StaticShX sh -> IShX sh2
+ -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
+ -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b)
+mrerankP ssh sh2 f (M_Primitive sh arr) =
+ let sh1 = shxDropSSX ssh sh
+ in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) ssh sh) sh2)
+ (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2)
+ (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
+ arr)
+
+-- | Wrapper around 'mrerankP' for any primitive element types: the argument
+-- array and the function are converted to and from the 'Primitive'
+-- representation. See the caveats at 'Data.Array.XArray.rerank'.
+mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
+        => StaticShX sh -> IShX sh2
+        -> (Mixed sh1 a -> Mixed sh2 b)
+        -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b
+mrerank ssh sh2 f arr =
+  let primArr = toPrimitive arr
+  in fromPrimitive (mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) primArr)
+
+mreplicate :: forall sh sh' a. Elt a
+ => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
+mreplicate sh arr =
+ let ssh' = ssxFromShX (mshape arr)
+ in mlift (ssxAppend (ssxFromShX sh) ssh')
+ (\(sshT :: StaticShX shT) ->
+ case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of
+ Refl -> X.replicate sh (ssxAppend ssh' sshT))
+ arr
+
+mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
+mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x)
+
+mreplicateScal :: forall sh a. PrimElt a
+ => IShX sh -> a -> Mixed sh a
+mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
+
+mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
+mslice i n arr =
+ let _ :$% sh = mshape arr
+ in mlift (SKnown n :!% ssxFromShX sh) (\_ -> X.slice i n) arr
+
+msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
+msliceU i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr
+
+mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a
+mrev1 arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.rev1) arr
+
+mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a
+mreshape sh' arr =
+ mlift (ssxFromShX sh')
+ (\sshIn -> X.reshapePartial (ssxFromShX (mshape arr)) sshIn sh')
+ arr
+
+mflatten :: Elt a => Mixed sh a -> Mixed '[Flatten sh] a
+mflatten arr = mreshape (shxFlatten (mshape arr) :$% ZSX) arr
+
+miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a
+miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
+
+-- | Throws if the array is empty.
+mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
+mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
+ ixxFromList (ssxFromShX sh) (numEltMinIndex (shxRank sh) (fromO arr))
+
+-- | Throws if the array is empty.
+mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
+mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
+ ixxFromList (ssxFromShX sh) (numEltMaxIndex (shxRank sh) (fromO arr))
+
+-- | Dot product along the innermost dimension (computed by
+-- 'numEltDotprodInner'); the result has the shape of the arguments with that
+-- last dimension removed. Raises an error if the two shapes differ.
+mdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
+ => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a
+mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b))
+ -- These lemmas relate 'Init' and 'Last' of @sh ++ '[n]@ back to @sh@ and @n@.
+ | Refl <- lemInitApp (Proxy @sh) (Proxy @n)
+ , Refl <- lemLastApp (Proxy @sh) (Proxy @n)
+ = case sh1 of
+ _ :$% _
+ | sh1 == sh2
+ , Refl <- lemRankApp (ssxInit (ssxFromShX sh1)) (ssxLast (ssxFromShX sh1) :!% ZKX) ->
+ fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b))
+ | otherwise -> error $ "mdot1Inner: Unequal shapes (" ++ show sh1 ++ " and " ++ show sh2 ++ ")"
+ -- A shape of the form @sh ++ '[n]@ always has at least one dimension.
+ ZSX -> error "unreachable"
+
+-- | Dot product of two whole arrays. This has a temporary, suboptimal
+-- implementation in terms of 'mflatten'. Prefer 'mdot1Inner' if applicable.
+mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a
+mdot a b =
+  let flatA = fromPrimitive (mflatten (toPrimitive a))
+      flatB = fromPrimitive (mflatten (toPrimitive b))
+  in munScalar (mdot1Inner Proxy flatA flatB)
+
+mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
+mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr)
+
+-- | The shape and the underlying 'XArray' of an array of primitive elements;
+-- wrapper around 'mtoXArrayPrimP'.
+mtoXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a)
+mtoXArrayPrim arr = mtoXArrayPrimP (toPrimitive arr)
+
+mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a)
+mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr
+
+-- | Wrap an 'XArray' of primitive elements as a 'Mixed' array; wrapper around
+-- 'mfromXArrayPrimP'.
+mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a
+mfromXArrayPrim ssh arr = fromPrimitive (mfromXArrayPrimP ssh arr)
+
+mliftPrim :: (PrimElt a, PrimElt b)
+ => (a -> b)
+ -> Mixed sh a -> Mixed sh b
+mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr))
+
+-- | Combine two arrays of primitive elements elementwise with the given
+-- function (using 'S.zipWithA' on the underlying arrays).
+--
+-- NOTE(review): the result takes its shape from the first argument and the
+-- shape stored in the second argument is ignored here; callers presumably
+-- have to ensure that both shapes agree -- confirm against 'S.zipWithA'.
+mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c)
+ => (a -> b -> c)
+ -> Mixed sh a -> Mixed sh b -> Mixed sh c
+mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =
+ fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2))
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
new file mode 100644
index 0000000..852dd5e
--- /dev/null
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -0,0 +1,644 @@
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Mixed.Shape where
+
+import Control.DeepSeq (NFData(..))
+import Data.Bifunctor (first)
+import Data.Coerce
+import Data.Foldable qualified as Foldable
+import Data.Functor.Const
+import Data.Functor.Product
+import Data.Kind (Constraint, Type)
+import Data.Monoid (Sum(..))
+import Data.Type.Equality
+import GHC.Exts (withDict)
+import GHC.Generics (Generic)
+import GHC.IsList (IsList)
+import GHC.IsList qualified as IsList
+import GHC.TypeLits
+
+import Data.Array.Nested.Types
+
+
+-- | The length of a type-level list. If the argument is a shape, then the
+-- result is the rank of that shape.
+type family Rank sh where
+ Rank '[] = 0
+ Rank (_ : sh) = Rank sh + 1
+
+
+-- * Mixed lists
+
+type role ListX nominal representational
+type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
+data ListX sh f where
+ ZX :: ListX '[] f
+ (::%) :: f n -> ListX sh f -> ListX (n : sh) f
+deriving instance (forall n. Eq (f n)) => Eq (ListX sh f)
+deriving instance (forall n. Ord (f n)) => Ord (ListX sh f)
+infixr 3 ::%
+
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance (forall n. Show (f n)) => Show (ListX sh f)
+#else
+instance (forall n. Show (f n)) => Show (ListX sh f) where
+ showsPrec _ = listxShow shows
+#endif
+
+instance (forall n. NFData (f n)) => NFData (ListX sh f) where
+ rnf ZX = ()
+ rnf (x ::% l) = rnf x `seq` rnf l
+
+-- | Result of 'listxUncons': the tail and the head of a non-empty 'ListX',
+-- together with the evidence that the list's index has the form @n : sh@.
+data UnconsListXRes f sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n)
+-- | Split a 'ListX' into its head and tail, or return 'Nothing' if it is
+-- empty. Used by the pattern synonyms on the newtypes over 'ListX'.
+listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1)
+listxUncons (i ::% shl') = Just (UnconsListXRes shl' i)
+listxUncons ZX = Nothing
+
+-- | This checks only whether the types are equal; if the elements of the list
+-- are not singletons, their values may still differ. This corresponds to
+-- 'testEquality', except on the penultimate type parameter.
+listxEqType :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh')
+listxEqType ZX ZX = Just Refl
+listxEqType (n ::% sh) (m ::% sh')
+ | Just Refl <- testEquality n m
+ , Just Refl <- listxEqType sh sh'
+ = Just Refl
+listxEqType _ _ = Nothing
+
+-- | This checks whether the two lists actually contain equal values. This is
+-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
+-- in the @some@ package (except on the penultimate type parameter).
+listxEqual :: (TestEquality f, forall n. Eq (f n)) => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh')
+listxEqual ZX ZX = Just Refl
+listxEqual (n ::% sh) (m ::% sh')
+ | Just Refl <- testEquality n m
+ , n == m
+ , Just Refl <- listxEqual sh sh'
+ = Just Refl
+listxEqual _ _ = Nothing
+
+listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g
+listxFmap _ ZX = ZX
+listxFmap f (x ::% xs) = f x ::% listxFmap f xs
+
+listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m
+listxFold _ ZX = mempty
+listxFold f (x ::% xs) = f x <> listxFold f xs
+
+-- | The number of elements in the list, as a plain 'Int'.
+listxLength :: ListX sh f -> Int
+listxLength l = getSum (listxFold (const (Sum 1)) l)
+
+listxRank :: ListX sh f -> SNat (Rank sh)
+listxRank ZX = SNat
+listxRank (_ ::% l) | SNat <- listxRank l = SNat
+
+-- | Render a 'ListX' as a comma-separated list in square brackets, using the
+-- given function to show each element.
+listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
+listxShow f l = showString "[" . go "" l . showString "]"
+ where
+ -- The first argument is the separator to emit before the next element:
+ -- empty for the first element, a comma afterwards.
+ go :: String -> ListX sh' f -> ShowS
+ go _ ZX = id
+ go prefix (x ::% xs) = showString prefix . f x . go "," xs
+
+-- | Convert a plain list to a 'ListX' whose length is dictated by the given
+-- static shape. Raises an error if the list does not have exactly that
+-- length; the dimension values in the static shape are not looked at.
+listxFromList :: StaticShX sh -> [i] -> ListX sh (Const i)
+listxFromList topssh topl = go topssh topl
+ where
+ go :: StaticShX sh' -> [i] -> ListX sh' (Const i)
+ go ZKX [] = ZX
+ go (_ :!% sh) (i : is) = Const i ::% go sh is
+ -- Either the shape or the list ran out first; report the original
+ -- (top-level) lengths rather than the remaining ones.
+ go _ _ = error $ "listxFromList: Mismatched list length (type says "
+ ++ show (ssxLength topssh) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+
+-- | Forget the type-level index and return the elements as a plain list.
+listxToList :: ListX sh' (Const i) -> [i]
+listxToList l = listxFold (\(Const i) -> [i]) l
+
+listxHead :: ListX (mn ': sh) f -> f mn
+listxHead (i ::% _) = i
+
+listxTail :: ListX (n : sh) i -> ListX sh i
+listxTail (_ ::% sh) = sh
+
+listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f
+listxAppend ZX idx' = idx'
+listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx'
+
+listxDrop :: forall f g sh sh'. ListX sh g -> ListX (sh ++ sh') f -> ListX sh' f
+listxDrop ZX long = long
+listxDrop (_ ::% short) long = case long of _ ::% long' -> listxDrop short long'
+
+listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f
+listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh
+listxInit (_ ::% ZX) = ZX
+
+listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh))
+listxLast (_ ::% sh@(_ ::% _)) = listxLast sh
+listxLast (x ::% ZX) = x
+
+listxZip :: ListX sh f -> ListX sh g -> ListX sh (Product f g)
+listxZip ZX ZX = ZX
+listxZip (i ::% irest) (j ::% jrest) =
+ Pair i j ::% listxZip irest jrest
+
+listxZipWith :: (forall a. f a -> g a -> h a) -> ListX sh f -> ListX sh g
+ -> ListX sh h
+listxZipWith _ ZX ZX = ZX
+listxZipWith f (i ::% is) (j ::% js) =
+ f i j ::% listxZipWith f is js
+
+
+-- * Mixed indices
+
+-- | An index into a mixed-typed array.
+type role IxX nominal representational
+type IxX :: [Maybe Nat] -> Type -> Type
+newtype IxX sh i = IxX (ListX sh (Const i))
+ deriving (Eq, Ord, Generic)
+
+pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i
+pattern ZIX = IxX ZX
+
+pattern (:.%)
+ :: forall {sh1} {i}.
+ forall n sh. (n : sh ~ sh1)
+ => i -> IxX sh i -> IxX sh1 i
+pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i)))
+ where i :.% IxX shl = IxX (Const i ::% shl)
+infixr 3 :.%
+
+{-# COMPLETE ZIX, (:.%) #-}
+
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
+type IIxX sh = IxX sh Int
+
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (IxX sh i)
+#else
+instance Show i => Show (IxX sh i) where
+ showsPrec _ (IxX l) = listxShow (shows . getConst) l
+#endif
+
+instance Functor (IxX sh) where
+ fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l)
+
+instance Foldable (IxX sh) where
+ foldMap f (IxX l) = listxFold (f . getConst) l
+
+instance NFData i => NFData (IxX sh i)
+
+ixxLength :: IxX sh i -> Int
+ixxLength (IxX l) = listxLength l
+
+ixxRank :: IxX sh i -> SNat (Rank sh)
+ixxRank (IxX l) = listxRank l
+
+ixxZero :: StaticShX sh -> IIxX sh
+ixxZero ZKX = ZIX
+ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh
+
+ixxZero' :: IShX sh -> IIxX sh
+ixxZero' ZSX = ZIX
+ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh
+
+-- | Convert a plain list to an index; see 'listxFromList' for the length
+-- check that is performed.
+ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i
+ixxFromList ssh l = IxX (listxFromList ssh l)
+
+ixxHead :: IxX (n : sh) i -> i
+ixxHead (IxX list) = getConst (listxHead list)
+
+ixxTail :: IxX (n : sh) i -> IxX sh i
+ixxTail (IxX list) = IxX (listxTail list)
+
+-- | Concatenate two indices.
+ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
+ixxAppend (IxX l1) (IxX l2) = IxX (listxAppend l1 l2)
+
+-- | Drop as many leading components from the second index as the first index
+-- has components; the values in the first index are not used.
+ixxDrop :: forall sh sh' i. IxX sh i -> IxX (sh ++ sh') i -> IxX sh' i
+ixxDrop (IxX short) (IxX long) = IxX (listxDrop short long)
+
+-- | All components of a non-empty index except the last one.
+ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i
+ixxInit (IxX l) = IxX (listxInit l)
+
+-- | The last component of a non-empty index.
+ixxLast :: forall n sh i. IxX (n : sh) i -> i
+ixxLast (IxX l) = getConst (listxLast l)
+
+ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i
+ixxCast ZKX ZIX = ZIX
+ixxCast (_ :!% sh) (i :.% idx) = i :.% ixxCast sh idx
+ixxCast _ _ = error "ixxCast: ranks don't match"
+
+ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j)
+ixxZip ZIX ZIX = ZIX
+ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js
+
+ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k
+ixxZipWith _ ZIX ZIX = ZIX
+ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js
+
+-- | Convert a linear offset into a multidimensional index for an array of the
+-- given shape. The last (innermost) component of the index varies fastest.
+--
+-- Raises an error if the offset is negative or not smaller than the number of
+-- elements described by the shape; in particular, every offset is out of
+-- range if some dimension is zero.
+ixxFromLinear :: IShX sh -> Int -> IIxX sh
+ixxFromLinear = \sh i ->
+  -- Check the range up front. This rejects negative offsets, which 'quotRem'
+  -- could otherwise turn into negative index components with a zero carry,
+  -- and it guarantees that 'go' never divides by a zero-sized dimension.
+  if i < 0 || i >= shxSize sh
+    then error $ "ixxFromLinear: out of range (" ++ show i ++
+                 " in array of shape " ++ show sh ++ ")"
+    else fst (go sh i)
+  where
+    -- returns (index in subarray, remaining index in enclosing array)
+    go :: IShX sh -> Int -> (IIxX sh, Int)
+    go ZSX i = (ZIX, i)
+    go (n :$% sh) i =
+      let (idx, i') = go sh i
+          (upi, locali) = i' `quotRem` fromSMayNat' n
+      in (locali :.% idx, upi)
+
+-- | Convert a multidimensional index into a linear offset for an array of the
+-- given shape; the last (innermost) component of the index has stride 1.
+-- No bounds checking is performed on the index components.
+ixxToLinear :: IShX sh -> IIxX sh -> Int
+ixxToLinear = \sh i -> fst (go sh i)
+ where
+ -- returns (index in subarray, size of subarray)
+ go :: IShX sh -> IIxX sh -> (Int, Int)
+ go ZSX ZIX = (0, 1)
+ go (n :$% sh) (i :.% ix) =
+ let (lidx, sz) = go sh ix
+ in (sz * i + lidx, fromSMayNat' n * sz)
+
+
+-- * Mixed shapes
+
+data SMayNat i f n where
+ SUnknown :: i -> SMayNat i f Nothing
+ SKnown :: f n -> SMayNat i f (Just n)
+deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n)
+deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n)
+deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n)
+
+instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where
+ rnf (SUnknown i) = rnf i
+ rnf (SKnown x) = rnf x
+
+instance TestEquality f => TestEquality (SMayNat i f) where
+ testEquality SUnknown{} SUnknown{} = Just Refl
+ testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl
+ testEquality _ _ = Nothing
+
+-- | Eliminator for 'SMayNat': the first function handles the 'SUnknown' case,
+-- the second the 'SKnown' case. Each function gets to use the type equality
+-- that its case establishes about @n@.
+fromSMayNat :: (n ~ Nothing => i -> r)
+ -> (forall m. n ~ Just m => f m -> r)
+ -> SMayNat i f n -> r
+fromSMayNat f _ (SUnknown i) = f i
+fromSMayNat _ g (SKnown s) = g s
+
+-- | The size of a dimension as a plain 'Int', regardless of whether it is
+-- statically known.
+fromSMayNat' :: SMayNat Int SNat n -> Int
+fromSMayNat' (SUnknown i) = i
+fromSMayNat' (SKnown sn) = fromSNat' sn
+
+type family AddMaybe n m where
+ AddMaybe Nothing _ = Nothing
+ AddMaybe (Just _) Nothing = Nothing
+ AddMaybe (Just n) (Just m) = Just (n + m)
+
+smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
+smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
+smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
+smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)
+
+
+-- | This is a newtype over 'ListX'.
+type role ShX nominal representational
+type ShX :: [Maybe Nat] -> Type -> Type
+newtype ShX sh i = ShX (ListX sh (SMayNat i SNat))
+ deriving (Eq, Ord, Generic)
+
+pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i
+pattern ZSX = ShX ZX
+
+pattern (:$%)
+ :: forall {sh1} {i}.
+ forall n sh. (n : sh ~ sh1)
+ => SMayNat i SNat n -> ShX sh i -> ShX sh1 i
+pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i))
+ where i :$% ShX shl = ShX (i ::% shl)
+infixr 3 :$%
+
+{-# COMPLETE ZSX, (:$%) #-}
+
+type IShX sh = ShX sh Int
+
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (ShX sh i)
+#else
+instance Show i => Show (ShX sh i) where
+ showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
+#endif
+
+instance Functor (ShX sh) where
+ fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l)
+
+instance NFData i => NFData (ShX sh i) where
+ rnf (ShX ZX) = ()
+ rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l)
+ rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l)
+
+-- | This checks only whether the types are equal; unknown dimensions might
+-- still differ. This corresponds to 'testEquality', except on the penultimate
+-- type parameter.
+shxEqType :: ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
+shxEqType ZSX ZSX = Just Refl
+shxEqType (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh')
+ | Just Refl <- sameNat n m
+ , Just Refl <- shxEqType sh sh'
+ = Just Refl
+shxEqType (SUnknown _ :$% sh) (SUnknown _ :$% sh')
+ | Just Refl <- shxEqType sh sh'
+ = Just Refl
+shxEqType _ _ = Nothing
+
+-- | This checks whether all dimensions have the same value. This is more than
+-- 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ in the
+-- @some@ package (except on the penultimate type parameter).
+shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
+shxEqual ZSX ZSX = Just Refl
+shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh')
+ | Just Refl <- sameNat n m
+ , Just Refl <- shxEqual sh sh'
+ = Just Refl
+shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh')
+ | i == j
+ , Just Refl <- shxEqual sh sh'
+ = Just Refl
+shxEqual _ _ = Nothing
+
+shxLength :: ShX sh i -> Int
+shxLength (ShX l) = listxLength l
+
+shxRank :: ShX sh i -> SNat (Rank sh)
+shxRank (ShX l) = listxRank l
+
+-- | The number of elements in an array described by this shape, i.e. the
+-- product of all dimensions (1 for the empty shape).
+shxSize :: IShX sh -> Int
+shxSize sh = product (shxToList sh)
+
+-- | Convert a plain list of dimension sizes to a shape matching the given
+-- static shape. Raises an error if the list has the wrong length, or if the
+-- value given for a statically known dimension differs from what the type
+-- says.
+shxFromList :: StaticShX sh -> [Int] -> IShX sh
+shxFromList topssh topl = go topssh topl
+ where
+ go :: StaticShX sh' -> [Int] -> IShX sh'
+ go ZKX [] = ZSX
+ -- Known dimension: the list value is only checked, not stored.
+ go (SKnown sn :!% sh) (i : is)
+ | i == fromSNat' sn = SKnown sn :$% go sh is
+ | otherwise = error $ "shxFromList: Value does not match typing (type says "
+ ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
+ -- Unknown dimension: the list value becomes the runtime size.
+ go (SUnknown () :!% sh) (i : is) = SUnknown i :$% go sh is
+ go _ _ = error $ "shxFromList: Mismatched list length (type says "
+ ++ show (ssxLength topssh) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+
+shxToList :: IShX sh -> [Int]
+shxToList ZSX = []
+shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh
+
+shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i
+shxFromSSX ZKX = ZSX
+shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh))
+ | Refl <- lemMapJustCons @sh Refl
+ = SKnown n :$% shxFromSSX sh
+shxFromSSX (SUnknown _ :!% _) = error "unreachable"
+
+-- | This may fail if @sh@ has @Nothing@s in it.
+shxFromSSX2 :: StaticShX sh -> Maybe (ShX sh i)
+shxFromSSX2 ZKX = Just ZSX
+shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh
+shxFromSSX2 (SUnknown _ :!% _) = Nothing
+
+-- | Concatenate two shapes.
+shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
+shxAppend (ShX l1) (ShX l2) = ShX (listxAppend l1 l2)
+
+shxHead :: ShX (n : sh) i -> SMayNat i SNat n
+shxHead (ShX list) = listxHead list
+
+shxTail :: ShX (n : sh) i -> ShX sh i
+shxTail (ShX list) = ShX (listxTail list)
+
+shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i
+shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat))
+
+shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i
+shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j))
+
+shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i
+shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))
+
+shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i
+shxInit = coerce (listxInit @(SMayNat i SNat))
+
+shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh))
+shxLast = coerce (listxLast @(SMayNat i SNat))
+
+shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i
+shxTakeSSX _ ZKX _ = ZSX
+shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh
+
+shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n)
+ -> ShX sh i -> ShX sh j -> ShX sh k
+shxZipWith _ ZSX ZSX = ZSX
+shxZipWith f (i :$% is) (j :$% js) = f i j :$% shxZipWith f is js
+
+-- | Turn a static shape into a full shape by giving every unknown dimension
+-- the size 0; known dimensions are kept as they are.
+--
+-- This is a weird operation, so it has a long name.
+shxCompleteZeros :: StaticShX sh -> IShX sh
+shxCompleteZeros ZKX = ZSX
+shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh
+shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh
+
+shxSplitApp :: proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i)
+shxSplitApp _ ZKX idx = (ZSX, idx)
+shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx)
+
+-- | All valid indices for the given shape, in lexicographic order (the last
+-- component varies fastest). The result is empty if some dimension is zero.
+shxEnum :: IShX sh -> [IIxX sh]
+shxEnum = \sh -> go sh id []
+ where
+ -- @go sh f@ prepends, to the list it is applied to, @f idx@ for every
+ -- index @idx@ of @sh@; @f@ puts the already chosen outer components in
+ -- front of the index.
+ go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a]
+ go ZSX f = (f ZIX :)
+ go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]]
+
+-- | Re-type a shape to a different static shape: dimensions may change
+-- between known and unknown, as long as the sizes are consistent. Returns
+-- 'Nothing' if the ranks differ or if a dimension's size does not equal the
+-- size that the target static shape requires.
+shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh')
+shxCast ZKX ZSX = Just ZSX
+shxCast (SKnown m :!% ssh) (SKnown n :$% sh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast ssh sh
+shxCast (SKnown m :!% ssh) (SUnknown n :$% sh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast ssh sh
+shxCast (SUnknown () :!% ssh) (SKnown n :$% sh) = (SUnknown (fromSNat' n) :$%) <$> shxCast ssh sh
+shxCast (SUnknown () :!% ssh) (SUnknown n :$% sh) = (SUnknown n :$%) <$> shxCast ssh sh
+shxCast _ _ = Nothing
+
+-- | Partial version of 'shxCast': raises an error instead of returning
+-- 'Nothing' when the shape does not match the target static shape.
+shxCast' :: StaticShX sh' -> IShX sh -> IShX sh'
+shxCast' ssh sh =
+  maybe (error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")")
+        id
+        (shxCast ssh sh)
+
+
+-- * Static mixed shapes
+
+-- | The part of a shape that is statically known. (A newtype over 'ListX'.)
+type StaticShX :: [Maybe Nat] -> Type
+newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat))
+ deriving (Eq, Ord)
+
+pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
+pattern ZKX = StaticShX ZX
+
+pattern (:!%)
+ :: forall {sh1}.
+ forall n sh. (n : sh ~ sh1)
+ => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1
+pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i))
+ where i :!% StaticShX shl = StaticShX (i ::% shl)
+infixr 3 :!%
+
+{-# COMPLETE ZKX, (:!%) #-}
+
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (StaticShX sh)
+#else
+instance Show (StaticShX sh) where
+ showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
+#endif
+
+instance NFData (StaticShX sh) where
+ rnf (StaticShX ZX) = ()
+ rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l)
+ rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l)
+
+instance TestEquality StaticShX where
+ testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2
+
+ssxLength :: StaticShX sh -> Int
+ssxLength (StaticShX l) = listxLength l
+
+ssxRank :: StaticShX sh -> SNat (Rank sh)
+ssxRank (StaticShX l) = listxRank l
+
+-- | @ssxEqType = 'testEquality'@. Provided for consistency.
+ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
+ssxEqType = testEquality
+
+-- | Concatenate two static shapes.
+ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
+ssxAppend (StaticShX l1) (StaticShX l2) = StaticShX (listxAppend l1 l2)
+
+ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n
+ssxHead (StaticShX list) = listxHead list
+
+-- | All dimensions of a non-empty static shape except the first one.
+ssxTail :: StaticShX (n : sh) -> StaticShX sh
+ssxTail (StaticShX list) = StaticShX (listxTail list)
+
+ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh'
+ssxDropSSX = coerce (listxDrop @(SMayNat () SNat) @(SMayNat () SNat))
+
+ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
+ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))
+
+ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
+ssxDropSh = coerce (listxDrop @(SMayNat () SNat) @(SMayNat i SNat))
+
+ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh))
+ssxInit = coerce (listxInit @(SMayNat () SNat))
+
+ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh))
+ssxLast = coerce (listxLast @(SMayNat () SNat))
+
+ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
+ssxReplicate SZ = ZKX
+ssxReplicate (SS (n :: SNat n'))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n'
+ = SUnknown () :!% ssxReplicate n
+
+-- | Consecutive integers starting at the given value, one for each dimension
+-- of the static shape.
+ssxIotaFrom :: StaticShX sh -> Int -> [Int]
+ssxIotaFrom ssh start = take (ssxLength ssh) (iterate (+ 1) start)
+
+ssxFromShX :: ShX sh i -> StaticShX sh
+ssxFromShX ZSX = ZKX
+ssxFromShX (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShX sh
+
+-- | The static shape of the given rank in which every dimension is unknown.
+-- This is the same function as 'ssxReplicate'; it delegates to it instead of
+-- duplicating the recursion.
+ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing)
+ssxFromSNat = ssxReplicate
+
+
+-- | Evidence for the static part of a shape. This pops up only when you are
+-- polymorphic in the element type of an array.
+type KnownShX :: [Maybe Nat] -> Constraint
+class KnownShX sh where knownShX :: StaticShX sh
+instance KnownShX '[] where knownShX = ZKX
+instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX
+instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX
+
+withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r
+withKnownShX = withDict @(KnownShX sh)
+
+
+-- * Flattening
+
+type Flatten sh = Flatten' 1 sh
+
+type family Flatten' acc sh where
+ Flatten' acc '[] = Just acc
+ Flatten' acc (Nothing : sh) = Nothing
+ Flatten' acc (Just n : sh) = Flatten' (acc * n) sh
+
+-- This function is currently unused
+ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh)
+ssxFlatten = go (SNat @1)
+ where
+ go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh)
+ go acc ZKX = SKnown acc
+ go _ (SUnknown () :!% _) = SUnknown ()
+ go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh
+
+-- | The single dimension obtained by flattening the shape: the product of all
+-- dimensions. It is statically known only if every dimension is known; see
+-- 'Flatten'.
+shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh)
+shxFlatten = go (SNat @1)
+ where
+ -- Multiply the known dimensions at the type level for as long as possible.
+ go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh)
+ go acc ZSX = SKnown acc
+ go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh)
+ go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh
+
+ -- After the first unknown dimension the result is unknown, so the rest
+ -- of the product is computed on plain 'Int's.
+ goUnknown :: Int -> IShX sh -> Int
+ goUnknown acc ZSX = acc
+ goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh
+ goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh
+
+
+-- | Very untyped: only length is checked (at runtime).
+instance KnownShX sh => IsList (ListX sh (Const i)) where
+ type Item (ListX sh (Const i)) = i
+ fromList = listxFromList (knownShX @sh)
+ toList = listxToList
+
+-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
+instance KnownShX sh => IsList (IxX sh i) where
+ type Item (IxX sh i) = i
+ fromList = IxX . IsList.fromList
+ toList = Foldable.toList
+
+-- | Untyped: length and known dimensions are checked (at runtime).
+instance KnownShX sh => IsList (ShX sh Int) where
+ type Item (ShX sh Int) = Int
+ fromList = shxFromList (knownShX @sh)
+ toList = shxToList
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs
new file mode 100644
index 0000000..03d1640
--- /dev/null
+++ b/src/Data/Array/Nested/Permutation.hs
@@ -0,0 +1,283 @@
+{-# LANGUAGE ConstraintKinds #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Permutation where
+
+import Data.Coerce (coerce)
+import Data.Functor.Const
+import Data.List (sort)
+import Data.Maybe (fromMaybe)
+import Data.Proxy
+import Data.Type.Bool
+import Data.Type.Equality
+import Data.Type.Ord
+import GHC.Exts (withDict)
+import GHC.TypeError
+import GHC.TypeLits
+import GHC.TypeNats qualified as TN
+
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Types
+
+
+-- * Permutations
+
+-- | A "backward" permutation of a dimension list. The operation on the
+-- dimension list is most similar to @backpermute@ in the @vector@ package; see
+-- 'Permute' for code that implements this.
+--
+-- The type index @list@ mirrors the entries of the permutation as type-level
+-- naturals. Note that the type itself does not enforce that the entries form
+-- a valid permutation; see 'IsPermutation' and 'permCheckPermutation'.
+data Perm list where
+ -- | The empty permutation.
+ PNil :: Perm '[]
+ -- | Prepend one entry (as a singleton natural) to a permutation.
+ PCons :: SNat a -> Perm l -> Perm (a : l)
+infixr 5 `PCons`
+deriving instance Show (Perm list)
+deriving instance Eq (Perm list)
+
+-- | Two permutations have equal type indices iff they have the same length
+-- and their entries are pairwise equal; this is checked entry by entry.
+instance TestEquality Perm where
+ testEquality PNil PNil = Just Refl
+ testEquality (x `PCons` xs) (y `PCons` ys)
+ | Just Refl <- testEquality x y
+ , Just Refl <- testEquality xs ys = Just Refl
+ -- Length mismatch, or some pair of entries differs.
+ testEquality _ _ = Nothing
+
+-- | The 'Rank' of the permutation's index list, as a singleton. Computed by
+-- recursion over the permutation.
+permRank :: Perm list -> SNat (Rank list)
+permRank PNil = SNat
+-- Matching on the recursive result brings its 'KnownNat' evidence into scope,
+-- from which the 'SNat' for the one-longer list is built.
+permRank (_ `PCons` l) | SNat <- permRank l = SNat
+
+-- | Convert an untyped list of indices into a 'Perm' and pass it to the
+-- continuation. The entries are promoted to type-level naturals one by one.
+--
+-- Only non-negativity of the entries is checked here; whether the result is
+-- an actual permutation is not (see 'permCheckPermutation' for that).
+--
+-- Calls 'error' if the list contains a negative number.
+permFromList :: [Int] -> (forall list. Perm list -> r) -> r
+permFromList [] k = k PNil
+permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case
+  Just sn -> permFromList xs $ \list -> k (sn `PCons` list)
+  -- 'withSomeSNat' yields 'Nothing' exactly for negative inputs.
+  Nothing -> error $ "Data.Array.Nested.Permutation.permFromList: negative number in list: " ++ show x
+
+-- | The entries of a permutation as plain naturals, in order.
+permToList :: Perm list -> [Natural]
+permToList = \case
+  PNil -> []
+  sn `PCons` rest -> TN.fromSNat sn : permToList rest
+
+-- | Like 'permToList', but with the entries converted to 'Int'.
+permToList' :: Perm list -> [Int]
+permToList' perm = [fromIntegral entry | entry <- permToList perm]
+
+-- | When called as @permCheckPermutation p k@, if @p@ is a permutation of
+-- @[0 .. 'length' ('permToList' p) - 1]@, @Just k@ is returned. If it isn't,
+-- then @Nothing@ is returned.
+--
+-- The two halves of 'IsPermutation' are established separately at runtime:
+-- @provePerm1@ shows that every entry of @p@ lies in range, and @provePerm2@
+-- shows that every number in range occurs in @p@.
+permCheckPermutation :: forall r list. Perm list -> (IsPermutation list => r) -> Maybe r
+permCheckPermutation = \p k ->
+ let n = permRank p
+ in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of
+ (Just Refl, Just Refl) -> Just k
+ _ -> Nothing
+ where
+ -- The three lemmas below are asserted with 'unsafeCoerceRefl' rather than
+ -- proved. Their constraints carry the evidence that the corresponding
+ -- runtime check has succeeded.
+
+ -- If @0 <= n < m@, then @n@ is an element of @[0 .. m-1]@.
+ lemElemCount :: (0 <= n, Compare n m ~ LT)
+ => proxy n -> proxy m -> Elem n (Count 0 m) :~: True
+ lemElemCount _ _ = unsafeCoerceRefl
+
+ -- Unfolds one step of 'Count' for the case handled by its second equation.
+ lemCount :: (OrdCond (Compare i n) True False True ~ True)
+ => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n
+ lemCount _ _ = unsafeCoerceRefl
+
+ -- Membership in the tail of a list implies membership in the whole list.
+ lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True
+ lemElem _ _ = unsafeCoerceRefl
+
+ -- Every entry of the permutation lies in @[0 .. Rank isTop - 1]@. The
+ -- comparison against 0 distinguishes the strictly-positive and the zero
+ -- case; in both, the entry must additionally be less than the rank.
+ provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is'
+ -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True)
+ provePerm1 _ _ PNil = Just Refl
+ provePerm1 p rtop@SNat (PCons sn@SNat perm)
+ | Just Refl <- provePerm1 p rtop perm
+ = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of
+ (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
+ (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
+ _ -> Nothing
+ | otherwise
+ = Nothing
+
+ -- Every number in @[i .. n-1]@ occurs in the permutation. Counts upwards
+ -- from @i@ and stops when @i@ reaches @n@.
+ provePerm2 :: SNat i -> SNat n -> Perm is'
+ -> Maybe (AllElem' (Count i n) is' :~: True)
+ provePerm2 = \i@(SNat :: SNat i) n@SNat perm ->
+ case cmpNat i n of
+ EQI -> Just Refl
+ LTI | Refl <- lemCount i n
+ , Just Refl <- provePerm2 (SNat @(i + 1)) n perm
+ -> checkElem i perm
+ | otherwise -> Nothing
+ -- The only call site starts at 0 and counts upwards, stopping at @n@.
+ GTI -> error "unreachable"
+ where
+ -- Linear search for @i@ among the entries of the permutation.
+ checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True)
+ checkElem _ PNil = Nothing
+ checkElem i@SNat (PCons k@SNat perm :: Perm is') =
+ case sameNat i k of
+ Just Refl -> Just Refl
+ Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl
+ | otherwise -> Nothing
+
+-- | Utility class for generating permutations from type class information.
+class KnownPerm l where makePerm :: Perm l
+instance KnownPerm '[] where makePerm = PNil
+instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm
+
+-- | Bring a 'KnownPerm' instance for @l@ into scope from an explicit 'Perm'
+-- value, using 'withDict'.
+withKnownPerm :: forall l r. Perm l -> (KnownPerm l => r) -> r
+withKnownPerm = withDict @(KnownPerm l)
+
+-- | Untyped permutations for ranked arrays
+type PermR = [Int]
+
+
+-- ** Applying permutations
+
+-- | Type-level list membership: @'True@ if @x@ occurs in @l@, @'False@ otherwise.
+type family Elem x l where
+ Elem x '[] = 'False
+ Elem x (x : _) = 'True
+ Elem x (_ : ys) = Elem x ys
+
+-- | @'True@ if every element of @as@ occurs in @bs@.
+type family AllElem' as bs where
+ AllElem' '[] bs = 'True
+ AllElem' (a : as) bs = Elem a bs && AllElem' as bs
+
+-- | Constraint version of 'AllElem'' that reports a custom type error when
+-- the check fails.
+type AllElem as bs = Assert (AllElem' as bs)
+ (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs))
+
+-- | @Count i n@ is the list @[i, i+1, ..., n-1]@; @Count n n@ is empty.
+type family Count i n where
+ Count n n = '[]
+ Count i n = i : Count (i + 1) n
+
+-- | @as@ is a permutation of @[0 .. Rank as - 1]@: all its elements are in
+-- that range, and every number in that range occurs in it.
+type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as)
+
+-- | The element at (zero-based) position @i@ of the list @sh@. There is no
+-- equation for the empty list, so out-of-range indices do not reduce.
+type family Index i sh where
+ Index 0 (n : sh) = n
+ Index i (_ : sh) = Index (i - 1) sh
+
+-- | Backward permutation: the result has, for each @i@ in @is@ (in order),
+-- the element of @sh@ at position @i@.
+type family Permute is sh where
+ Permute '[] sh = '[]
+ Permute (i : is) sh = Index i sh : Permute is sh
+
+-- | Apply the permutation @is@ to the prefix of @sh@ that has the same length
+-- as @is@, and leave the remainder of @sh@ as it is.
+type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh
+
+-- | The prefix of @l@ that is as long as @ref@. Does not reduce if @l@ is
+-- shorter than @ref@.
+type family TakeLen ref l where
+ TakeLen '[] l = '[]
+ TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs
+
+-- | What remains of @l@ after dropping as many elements as @ref@ has. Does
+-- not reduce if @l@ is shorter than @ref@.
+type family DropLen ref l where
+ DropLen '[] l = l
+ DropLen (_ : ref) (_ : xs) = DropLen ref xs
+
+-- | Value-level counterpart of 'TakeLen': keep as many leading elements of the
+-- list as the permutation has entries. The entries themselves are ignored.
+--
+-- Calls 'error' if the permutation is longer than the list.
+listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f
+listxTakeLen PNil _ = ZX
+listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh
+listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape"
+
+-- | Value-level counterpart of 'DropLen': drop as many leading elements of the
+-- list as the permutation has entries. The entries themselves are ignored.
+--
+-- Calls 'error' if the permutation is longer than the list.
+listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f
+listxDropLen PNil sh = sh
+listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh
+listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape"
+
+-- | Value-level counterpart of 'Permute': for each entry @i@ of the
+-- permutation, in order, select the element at position @i@ of the list
+-- (using 'listxIndex').
+listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f
+listxPermute PNil _ = ZX
+listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) =
+ listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh
+
+-- | Value-level counterpart of 'Index': the element at (zero-based) position
+-- @i@ of the list. The two proxy arguments are only passed along in the
+-- recursion and are not otherwise inspected.
+--
+-- Calls 'error' if the list runs out before position @i@ is reached.
+listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh)
+listxIndex _ _ SZ (n ::% _) = n
+listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f))
+ -- 'lemIndexSucc' rewrites @Index (i' + 1) (n : sh')@ to @Index i' sh'@.
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = listxIndex p pT i sh
+listxIndex _ _ _ ZX = error "Index into empty shape"
+
+-- | Value-level counterpart of 'PermutePrefix': permute the prefix of the list
+-- that is as long as the permutation, and append the untouched remainder.
+listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f
+listxPermutePrefix perm sh = listxAppend permutedPrefix remainder
+  where
+    permutedPrefix = listxPermute perm (listxTakeLen perm sh)
+    remainder = listxDropLen perm sh
+
+-- | 'listxPermutePrefix' on indices, via 'coerce'.
+ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i
+ixxPermutePrefix = coerce (listxPermutePrefix @(Const i))
+
+-- | 'listxTakeLen' on static shapes, via 'coerce'.
+ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh)
+ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat))
+
+-- | 'listxDropLen' on static shapes, via 'coerce'.
+ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh)
+ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))
+
+-- | 'listxPermute' on static shapes, via 'coerce'.
+ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
+ssxPermute = coerce (listxPermute @(SMayNat () SNat))
+
+-- | 'listxIndex' on static shapes, via 'coerce'.
+ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh)
+ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2)
+
+-- | 'listxPermutePrefix' on static shapes, via 'coerce'.
+ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
+ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
+
+-- | 'listxPermutePrefix' on shapes, via 'coerce'.
+shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh)
+shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))
+
+
+-- * Operations on permutations
+
+-- | Compute the inverse of a permutation and pass it to the continuation,
+-- together with a function that, given a static shape of matching rank,
+-- produces evidence that permuting by @is@ and then by the inverse yields the
+-- original shape.
+--
+-- Calls 'error' if the generated list fails 'permCheckPermutation' (which
+-- happens if the input was not a permutation to begin with), or, when the
+-- evidence is demanded, if the runtime inverse check fails.
+permInverse :: Perm is
+ -> (forall is'.
+ IsPermutation is'
+ => Perm is'
+ -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh)
+ -> r)
+ -> r
+permInverse = \perm k ->
+ genPerm perm $ \(invperm :: Perm is') ->
+ fromMaybe
+ (error $ "permInverse: did not generate permutation? perm = " ++ show perm
+ ++ " ; invperm = " ++ show invperm)
+ (permCheckPermutation invperm
+ (k invperm
+ (\ssh -> case permCheckInverse perm invperm ssh of
+ Just eq -> eq
+ Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm
+ ++ " ; invperm = " ++ show invperm)))
+ where
+ -- Pair every entry with its position, sort by entry, and read off the
+ -- positions; for a genuine permutation this is its inverse.
+ genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r
+ genPerm perm =
+ let permList = permToList' perm
+ in toHList $ map snd (sort (zip permList [0..]))
+ where
+ -- Promote a list of naturals to a 'Perm', entry by entry.
+ toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r
+ toHList [] k = k PNil
+ toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l)
+
+ -- The inverse property is not proved but checked at runtime on the given
+ -- static shape, by permuting it both ways and comparing with 'ssxEqType'.
+ permCheckInverse :: Perm is -> Perm is' -> StaticShX sh
+ -> Maybe (Permute is' (Permute is sh) :~: sh)
+ permCheckInverse perm perminv ssh =
+ ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh
+
+-- | Add one to every element of a type-level list of naturals.
+type family MapSucc is where
+ MapSucc '[] = '[]
+ MapSucc (i : is) = i + 1 : MapSucc is
+
+-- | Increment every entry of the permutation by one and put a 0 in front.
+permShift1 :: Perm l -> Perm (0 : MapSucc l)
+permShift1 = (SNat @0 `PCons`) . permMapSucc
+ where
+ -- Value-level counterpart of 'MapSucc'.
+ permMapSucc :: Perm l -> Perm (MapSucc l)
+ permMapSucc PNil = PNil
+ permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns
+
+
+-- * Lemmas
+
+-- | The result of 'Permute' has as many elements as the permutation. Proved
+-- by induction on the permutation.
+lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is
+lemRankPermute _ PNil = Refl
+lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl
+
+-- | 'DropLen' shortens a list by the length of the permutation, provided the
+-- permutation is not longer than the list. Proved by simultaneous induction
+-- on the shape and the permutation.
+lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
+ => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is
+lemRankDropLen ZKX PNil = Refl
+lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl
+lemRankDropLen (_ :!% _) PNil = Refl
+-- This case would need a non-empty permutation with an empty shape, which
+-- contradicts the @Rank is <= Rank sh@ constraint.
+lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0"
+
+-- | Indexing at a successor position skips the head of the list. This is
+-- asserted with 'unsafeCoerceRefl', not proved.
+lemIndexSucc :: Proxy i -> Proxy a -> Proxy l
+ -> Index (i + 1) (a : l) :~: Index i l
+lemIndexSucc _ _ _ = unsafeCoerceRefl
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
new file mode 100644
index 0000000..9778c54
--- /dev/null
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -0,0 +1,323 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Ranked (
+ Ranked(Ranked),
+ rquotArray, rremArray, ratan2Array,
+ rshape, rrank,
+ module Data.Array.Nested.Ranked,
+ liftRanked1, liftRanked2,
+) where
+
+import Prelude hiding (mappend, mconcat)
+
+import Data.Array.RankedS qualified as S
+import Data.Bifunctor (first)
+import Data.Coerce (coerce)
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Data.Type.Equality
+import Data.Vector.Storable qualified as VS
+import Foreign.Storable (Storable)
+import GHC.TypeLits
+import GHC.TypeNats qualified as TN
+
+import Data.Array.Nested.Convert
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Ranked.Base
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
+import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray(..))
+import Data.Array.XArray qualified as X
+
+
+remptyArray :: KnownElt a => Ranked 1 a
+remptyArray = mtoRanked (memptyArray ZSX)
+
+-- | The total number of elements in the array, i.e. the size of its shape.
+rsize :: Elt a => Ranked n a -> Int
+rsize arr = shrSize (rshape arr)
+
+-- | Look up the element at a full-rank index. The index is converted with
+-- 'ixxFromIxR' and the lookup is delegated to 'mindex'.
+rindex :: Elt a => Ranked n a -> IIxR n -> a
+rindex (Ranked arr) idx = mindex arr (ixxFromIxR idx)
+
+-- | Index into the first @n@ dimensions of a rank-@(n + m)@ array, yielding
+-- the rank-@m@ subarray at that position. Delegates to 'mindexPartial'.
+rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a
+rindexPartial (Ranked arr) idx =
+ Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
+ -- Rewrite the shape type from a replicate of @n + m@ into the append of
+ -- two replicates, which is the form 'mindexPartial' expects.
+ (castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @Nothing))) arr)
+ (ixxFromIxR idx))
+
+-- | __WARNING__: All values returned from the function must have equal shape.
+-- See the documentation of 'mgenerate' for more details.
+rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a
+rgenerate sh f
+ | sn@SNat <- shrRank sh
+ , Dict <- lemKnownReplicate sn
+ , Refl <- lemRankReplicate sn
+ = Ranked (mgenerate (shxFromShR sh) (f . ixrFromIxX))
+
+-- | See the documentation of 'mlift'.
+rlift :: forall n1 n2 a. Elt a
+ => SNat n2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
+ -> Ranked n1 a -> Ranked n2 a
+rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr)
+
+-- | See the documentation of 'mlift2'.
+rlift2 :: forall n1 n2 n3 a. Elt a
+ => SNat n3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b)
+ -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a
+rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)
+
+rsumOuter1P :: forall n a.
+ (Storable a, NumElt a)
+ => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
+rsumOuter1P (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = Ranked (msumOuter1P arr)
+
+-- | 'rsumOuter1P' for arbitrary primitive element types: the array is
+-- converted to its 'Primitive' representation, summed, and converted back.
+rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
+           => Ranked (n + 1) a -> Ranked n a
+rsumOuter1 arr = rfromPrimitive (rsumOuter1P (rtoPrimitive arr))
+
+rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a
+rsumAllPrim (Ranked arr) = msumAllPrim arr
+
+-- | Transpose the dimensions of the array according to an untyped
+-- permutation; the work is delegated to 'X.transposeUntyped'.
+--
+-- The permutation may be shorter than the rank of the array (presumably the
+-- remaining dimensions are then left in place -- confirm against
+-- 'X.transposeUntyped'). Calls 'error' if it is longer than the rank.
+rtranspose :: forall n a. Elt a => PermR -> Ranked n a -> Ranked n a
+rtranspose perm arr
+ | sn@SNat <- rrank arr
+ , Dict <- lemKnownReplicate sn
+ , length perm <= fromIntegral (natVal (Proxy @n))
+ = rlift sn
+ (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm)
+ arr
+ | otherwise
+ = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array"
+
+rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a
+rconcat
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = coerce mconcat
+
+rappend :: forall n a. Elt a
+ => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
+rappend arr1 arr2
+ | sn@SNat <- rrank arr1
+ , Dict <- lemKnownReplicate sn
+ , Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
+ arr1 arr2
+
+-- | Wrap a single element as a rank-0 array (via 'mscalar').
+rscalar :: Elt a => a -> Ranked 0 a
+rscalar = Ranked . mscalar
+
+rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a)
+rfromVectorP sh v
+ | Dict <- lemKnownReplicate (shrRank sh)
+ = Ranked (mfromVectorP (shxFromShR sh) v)
+
+-- | 'rfromVectorP', with the result converted out of its 'Primitive'
+-- representation.
+rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a
+rfromVector sh = rfromPrimitive . rfromVectorP sh
+
+rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a
+rtoVectorP = coerce mtoVectorP
+
+rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
+rtoVector = coerce mtoVector
+
+-- | Build a rank-1 array from a non-empty list of elements (via 'mfromList1').
+rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
+rfromList1 = Ranked . mfromList1
+
+rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
+rfromListOuter l
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
+
+-- | Build an array of the given shape from a non-empty list: the list is
+-- first turned into a rank-1 array with 'rfromList1' and then reshaped with
+-- 'rreshape'.
+rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
+rfromListLinear sh = rreshape sh . rfromList1
+
+rfromListPrim :: PrimElt a => [a] -> Ranked 1 a
+rfromListPrim l = Ranked (mfromListPrim l)
+
+-- | Build an array of the given shape from a list of primitive elements. The
+-- list is first turned into a rank-1 array with 'mfromListPrim'; its
+-- underlying 'XArray' is then reshaped to @sh@ with 'X.reshape'.
+--
+-- NOTE(review): no explicit check that the list length matches the size of
+-- @sh@ is made here; presumably 'X.reshape' rejects a mismatch -- confirm.
+rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a
+rfromListPrimLinear sh l =
+ let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ in Ranked $ fromPrimitive $ M_Primitive (shxFromShR sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShR sh) xarr)
+
+-- | The elements of a rank-1 array as a list: the array is split along its
+-- only dimension and each rank-0 piece is unwrapped with 'runScalar'.
+rtoList :: Elt a => Ranked 1 a -> [a]
+rtoList arr = [runScalar cell | cell <- rtoListOuter arr]
+
+rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
+rtoListOuter (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
+
+rtoListLinear :: Elt a => Ranked n a -> [a]
+rtoListLinear (Ranked arr) = mtoListLinear arr
+
+rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a
+rfromOrthotope sn arr
+ | Refl <- lemRankReplicate sn
+ = let xarr = XArray arr
+ in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr))
+
+rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a
+rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr)))
+ | Refl <- lemRankReplicate (shrRank $ shrFromShX2 sh)
+ = arr
+
+-- | Extract the single element of a rank-0 array by indexing it with the
+-- empty index.
+runScalar :: Elt a => Ranked 0 a -> a
+runScalar = (`rindex` ZIR)
+
+rnest :: forall n m a. Elt a => SNat n -> Ranked (n + m) a -> Ranked n (Ranked m a)
+rnest n arr
+ | Refl <- lemReplicatePlusApp n (Proxy @m) (Proxy @(Nothing @Nat))
+ = coerce (mnest (ssxFromSNat n) (coerce arr))
+
+runNest :: forall n m a. Elt a => Ranked n (Ranked m a) -> Ranked (n + m) a
+runNest rarr@(Ranked (M_Ranked (M_Nest _ arr)))
+ | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat))
+ = Ranked arr
+
+rzip :: (Elt a, Elt b) => Ranked n a -> Ranked n b -> Ranked n (a, b)
+rzip = coerce mzip
+
+runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b)
+runzip = coerce munzip
+
+rrerankP :: forall n1 n2 n a b. (Storable a, Storable b)
+ => SNat n -> IShR n2
+ -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b))
+ -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b)
+rrerankP sn sh2 f (Ranked arr)
+ | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat))
+ , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat))
+ = Ranked (mrerankP (ssxFromSNat sn) (shxFromShR sh2)
+ (\a -> let Ranked r = f (Ranked a) in r)
+ arr)
+
+-- | If there is a zero-sized dimension in the @n@-prefix of the shape of the
+-- input array, then there is no way to deduce the full shape of the output
+-- array (more precisely, the @n2@ part): that could only come from calling
+-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in
+-- this case; we choose to fill the @n2@ part of the output shape with zeros.
+--
+-- For example, if:
+--
+-- @
+-- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21]
+-- f :: Ranked 2 Int -> Ranked 3 Float
+-- @
+--
+-- then:
+--
+-- @
+-- rrerank _ _ f arr :: Ranked 6 Float
+-- @
+--
+-- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the
+-- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended
+-- to return an array with shape all-0 here (it probably didn't), but there is
+-- no better number to put here absent a subarray of the input to pass to @f@.
+--
+-- NOTE(review): this function takes an explicit @IShR n2@ argument that is
+-- forwarded to 'rrerankP' (and from there to 'mrerankP'). Whether that shape
+-- replaces the zero-filling described above cannot be told from this module;
+-- verify against 'mrerankP'.
+rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
+ => SNat n -> IShR n2
+ -> (Ranked n1 a -> Ranked n2 b)
+ -> Ranked (n + n1) a -> Ranked (n + n2) b
+rrerank sn sh2 f (rtoPrimitive -> arr) =
+ rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr
+
+rreplicate :: forall n m a. Elt a
+ => IShR n -> Ranked m a -> Ranked (n + m) a
+rreplicate sh (Ranked arr)
+ | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat))
+ = Ranked (mreplicate (shxFromShR sh) arr)
+
+rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
+rreplicateScalP sh x
+ | Dict <- lemKnownReplicate (shrRank sh)
+ = Ranked (mreplicateScalP (shxFromShR sh) x)
+
+-- | 'rreplicateScalP', with the result converted out of its 'Primitive'
+-- representation.
+rreplicateScal :: forall n a. PrimElt a
+               => IShR n -> a -> Ranked n a
+rreplicateScal sh = rfromPrimitive . rreplicateScalP sh
+
+-- | Slice the array along its outermost dimension; the two 'Int' arguments
+-- are passed on unchanged to 'X.sliceU'.
+--
+-- NOTE(review): presumably @i@ is the start offset and @n@ the length of the
+-- slice -- confirm against 'X.sliceU'.
+rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
+rslice i n arr
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = rlift (rrank arr)
+ (\_ -> X.sliceU i n)
+ arr
+
+rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
+rrev1 arr =
+ rlift (rrank arr)
+ (\(_ :: StaticShX sh') ->
+ case lemReplicateSucc @(Nothing @Nat) @n of
+ Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))
+ arr
+
+-- | Reshape an array to the given shape, possibly changing its rank.
+-- Delegates to 'mreshape'.
+--
+-- NOTE(review): presumably the new shape must have the same total size as the
+-- old one; that is not checked here -- confirm against 'mreshape'.
+rreshape :: forall n n' a. Elt a
+ => IShR n' -> Ranked n a -> Ranked n' a
+rreshape sh' rarr@(Ranked arr)
+ -- Bring the 'KnownShX' instances for the old and new shape types into scope.
+ | Dict <- lemKnownReplicate (rrank rarr)
+ , Dict <- lemKnownReplicate (shrRank sh')
+ = Ranked (mreshape (shxFromShR sh') arr)
+
+rflatten :: Elt a => Ranked n a -> Ranked 1 a
+rflatten (Ranked arr) = mtoRanked (mflatten arr)
+
+-- | @riota n@ builds a rank-1 array by promoting the length @n@ to a
+-- type-level natural and calling 'miota' (presumably yielding the values
+-- @0 .. n-1@ -- see 'miota').
+--
+-- Calls 'error' if @n@ is negative. Without this check, the conversion of a
+-- negative 'Int' to a natural number would fail with an uninformative
+-- arithmetic underflow exception.
+riota :: (Enum a, PrimElt a) => Int -> Ranked 1 a
+riota n
+  | n < 0 = error $ "Data.Array.Nested.riota: negative length: " ++ show n
+  | otherwise = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota
+
+-- | Throws if the array is empty.
+rminIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
+rminIndexPrim rarr@(Ranked arr)
+ | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
+ = ixrFromIxX (mminIndexPrim arr)
+
+-- | Throws if the array is empty.
+rmaxIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
+rmaxIndexPrim rarr@(Ranked arr)
+ | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
+ = ixrFromIxX (mmaxIndexPrim arr)
+
+rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a
+rdot1Inner arr1 arr2
+ | SNat <- rrank arr1
+ , Refl <- lemReplicatePlusApp (SNat @n) (Proxy @1) (Proxy @(Nothing @Nat))
+ = coerce (mdot1Inner (Proxy @(Nothing @Nat))) arr1 arr2
+
+-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
+-- Prefer 'rdot1Inner' if applicable.
+rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a
+rdot = coerce mdot
+
+rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
+rtoXArrayPrimP (Ranked arr) = first shrFromShX2 (mtoXArrayPrimP arr)
+
+rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a)
+rtoXArrayPrim (Ranked arr) = first shrFromShX2 (mtoXArrayPrim arr)
+
+rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a)
+rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr)
+
+rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a
+rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr)
+
+-- | Convert an array of 'Primitive'-wrapped elements to an array of the bare
+-- element type, by applying 'fromPrimitive' underneath the 'Ranked' newtype.
+rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a
+rfromPrimitive = liftRanked1 fromPrimitive
+
+-- | Convert an array of a primitive element type to its 'Primitive'-wrapped
+-- representation, by applying 'toPrimitive' underneath the 'Ranked' newtype.
+rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a)
+rtoPrimitive = liftRanked1 toPrimitive
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
new file mode 100644
index 0000000..babc809
--- /dev/null
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -0,0 +1,268 @@
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_HADDOCK not-home #-}
+module Data.Array.Nested.Ranked.Base where
+
+import Prelude hiding (mappend, mconcat)
+
+import Control.DeepSeq (NFData(..))
+import Control.Monad.ST
+import Data.Bifunctor (first)
+import Data.Coerce (coerce)
+import Data.Kind (Type)
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Data.Type.Equality
+import Foreign.Storable (Storable)
+import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
+import GHC.Generics (Generic)
+import GHC.TypeLits
+
+#ifndef OXAR_DEFAULT_SHOW_INSTANCES
+import Data.Foldable (toList)
+#endif
+
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
+import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray(..))
+
+
+-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
+-- represented on the type level as a 'Nat'.
+--
+-- Valid elements of a ranked arrays are described by the 'Elt' type class.
+-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
+-- supported (and are represented as a single, flattened, struct-of-arrays
+-- array internally).
+--
+-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
+type Ranked :: Nat -> Type -> Type
+newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)
+#endif
+deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a)
+deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a)
+
+#ifndef OXAR_DEFAULT_SHOW_INSTANCES
+instance (Show a, Elt a) => Show (Ranked n a) where
+ showsPrec d arr@(Ranked marr) =
+ let sh = show (toList (rshape arr))
+ in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr
+#endif
+
+instance Elt a => NFData (Ranked n a) where
+ rnf (Ranked arr) = rnf arr
+
+-- just unwrap the newtype and defer to the general instance for nested arrays
+newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
+ deriving (Generic)
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a))
+#endif
+
+deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a))
+
+newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
+
+-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
+-- these instances allow them to also be used as elements of arrays, thus
+-- making them first-class in the API.
+instance Elt a => Elt (Ranked n a) where
+ mshape (M_Ranked arr) = mshape arr
+ mindex (M_Ranked arr) i = Ranked (mindex arr i)
+
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
+ mindexPartial (M_Ranked arr) i =
+ coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
+ mindexPartial arr i
+
+ mscalar (Ranked x) = M_Ranked (M_Nest ZSX x)
+
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a)
+ mfromListOuter l = M_Ranked (mfromListOuter (coerce l))
+
+ mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
+ mtoListOuter (M_Ranked arr) =
+ coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr)
+
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
+ mlift ssh2 f (M_Ranked arr) =
+ coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
+ mlift ssh2 f arr
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a)
+ mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) =
+ coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
+ mlift2 ssh3 f arr1 arr2
+
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
+ -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a))
+ mliftL ssh2 f l =
+ coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a)))
+ @(NonEmpty (Mixed sh2 (Ranked n a))) $
+ mliftL ssh2 f (coerce l)
+
+ mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr)
+
+ mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr)
+
+ mconcat l = M_Ranked (mconcat (coerce l))
+
+ mrnf (M_Ranked arr) = mrnf arr
+
+ type ShapeTree (Ranked n a) = (IShR n, ShapeTree a)
+
+ mshapeTree (Ranked arr) = first shrFromShX2 (mshapeTree arr)
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+
+ mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
+ marrayStrides (M_Ranked arr) = marrayStrides arr
+
+ mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
+ mvecsWrite sh idx (Ranked arr) vecs =
+ mvecsWrite sh idx arr
+ (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsWritePartial :: forall sh sh' s.
+ IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
+ -> MixedVecs s (sh ++ sh') (Ranked n a)
+ -> ST s ()
+ mvecsWritePartial sh idx arr vecs =
+ mvecsWritePartial sh idx
+ (coerce @(Mixed sh' (Ranked n a))
+ @(Mixed sh' (Mixed (Replicate n Nothing) a))
+ arr)
+ (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
+ @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
+ mvecsFreeze sh vecs =
+ coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
+ @(Mixed sh (Ranked n a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh (Ranked n a))
+ @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
+ memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a)
+ memptyArrayUnsafe i
+ | Dict <- lemKnownReplicate (SNat @n)
+ = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
+ memptyArrayUnsafe i
+
+ mvecsUnsafeNew idx (Ranked arr)
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsUnsafeNew idx arr
+
+ mvecsNewEmpty _
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
+
+
+-- | Lift a unary function on the underlying 'Mixed' arrays to 'Ranked'
+-- arrays. This is just 'coerce' through the 'Ranked' newtype.
+liftRanked1 :: forall n a b.
+ (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b)
+ -> Ranked n a -> Ranked n b
+liftRanked1 = coerce
+
+-- | Lift a binary function on the underlying 'Mixed' arrays to 'Ranked'
+-- arrays. This is just 'coerce' through the 'Ranked' newtype.
+liftRanked2 :: forall n a b c.
+ (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c)
+ -> Ranked n a -> Ranked n b -> Ranked n c
+liftRanked2 = coerce
+
+-- | All operations delegate to the 'Num' instance of the underlying 'Mixed'
+-- array. 'fromInteger' is not supported and calls 'error': as its message
+-- says, no singletons are available to build an array from a literal.
+instance (NumElt a, PrimElt a) => Num (Ranked n a) where
+ (+) = liftRanked2 (+)
+ (-) = liftRanked2 (-)
+ (*) = liftRanked2 (*)
+ negate = liftRanked1 negate
+ abs = liftRanked1 abs
+ signum = liftRanked1 signum
+ fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal"
+
+-- | All operations delegate to the 'Fractional' instance of the underlying
+-- 'Mixed' array. 'fromRational' is not supported and calls 'error'.
+instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where
+ fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal"
+ recip = liftRanked1 recip
+ (/) = liftRanked2 (/)
+
+-- | All operations delegate to the 'Floating' instance of the underlying
+-- 'Mixed' array. 'pi' is not supported and calls 'error'.
+instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where
+ pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal"
+ exp = liftRanked1 exp
+ log = liftRanked1 log
+ sqrt = liftRanked1 sqrt
+ (**) = liftRanked2 (**)
+ logBase = liftRanked2 logBase
+ sin = liftRanked1 sin
+ cos = liftRanked1 cos
+ tan = liftRanked1 tan
+ asin = liftRanked1 asin
+ acos = liftRanked1 acos
+ atan = liftRanked1 atan
+ sinh = liftRanked1 sinh
+ cosh = liftRanked1 cosh
+ tanh = liftRanked1 tanh
+ asinh = liftRanked1 asinh
+ acosh = liftRanked1 acosh
+ atanh = liftRanked1 atanh
+ log1p = liftRanked1 GHC.Float.log1p
+ expm1 = liftRanked1 GHC.Float.expm1
+ log1pexp = liftRanked1 GHC.Float.log1pexp
+ log1mexp = liftRanked1 GHC.Float.log1mexp
+
+-- | 'mquotArray' and 'mremArray', lifted to 'Ranked' arrays.
+rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
+rquotArray = liftRanked2 mquotArray
+rremArray = liftRanked2 mremArray
+
+-- | 'matan2Array', lifted to 'Ranked' arrays.
+ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
+ratan2Array = liftRanked2 matan2Array
+
+
+-- | The shape of the array: the shape of the underlying 'Mixed' array,
+-- converted to a ranked shape with 'shrFromShX2'.
+rshape :: Elt a => Ranked n a -> IShR n
+rshape (Ranked arr) = shrFromShX2 (mshape arr)
+
+-- | The rank of the array as a singleton, obtained from its shape.
+rrank :: Elt a => Ranked n a -> SNat n
+rrank arr = shrRank (rshape arr)
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+-- | Convert a mixed shape to a ranked shape of the same rank, turning every
+-- dimension into a plain value with 'fromSMayNat''.
+shrFromShX :: forall sh. IShX sh -> IShR (Rank sh)
+shrFromShX ZSX = ZSR
+shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+-- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'.
+shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n
+shrFromShX2 sh
+ | Refl <- lemRankReplicate (Proxy @n)
+ = shrFromShX sh
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
new file mode 100644
index 0000000..8b670e5
--- /dev/null
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -0,0 +1,369 @@
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveFoldable #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Ranked.Shape where
+
+import Control.DeepSeq (NFData(..))
+import Data.Coerce (coerce)
+import Data.Foldable qualified as Foldable
+import Data.Kind (Type)
+import Data.Proxy
+import Data.Type.Equality
+import GHC.Generics (Generic)
+import GHC.IsList (IsList)
+import GHC.IsList qualified as IsList
+import GHC.TypeLits
+import GHC.TypeNats qualified as TN
+
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Types
+
+
+-- * Ranked lists
+
+-- | A list of values of type @i@ whose length @n@ is recorded in its type.
+-- 'ZR' is the empty list and '(:::)' prepends one element, incrementing the
+-- length index.
+type role ListR nominal representational
+type ListR :: Nat -> Type -> Type
+data ListR n i where
+ ZR :: ListR 0 i
+ (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
+deriving instance Eq i => Eq (ListR n i)
+deriving instance Ord i => Ord (ListR n i)
+deriving instance Functor (ListR n)
+deriving instance Foldable (ListR n)
+infixr 3 :::
+
+-- When OXAR_DEFAULT_SHOW_INSTANCES is defined the derived instance is used;
+-- otherwise the list is rendered by 'listrShow' (brackets, comma-separated)
+-- and the precedence argument is ignored.
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (ListR n i)
+#else
+instance Show i => Show (ListR n i) where
+ showsPrec _ = listrShow shows
+#endif
+
+-- | Forces every element in turn, from the head of the list to its end.
+instance NFData i => NFData (ListR n i) where
+ rnf ZR = ()
+ rnf (x ::: l) = rnf x `seq` rnf l
+
+-- | The result of 'listrUncons' on a non-empty list: the tail and the head,
+-- packaged with evidence that the original length @n1@ equals the tail's
+-- length plus one.
+data UnconsListRRes i n1 =
+ forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
+-- | Split a list into its head and tail; 'Nothing' for the empty list.
+listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
+listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)
+listrUncons ZR = Nothing
+
+-- | Compare two lists by length only (the elements are never inspected) and,
+-- if the lengths agree, return a proof that the length indices are equal.
+listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n')
+listrEqRank ZR ZR = Just Refl
+listrEqRank (_ ::: xs) (_ ::: ys) =
+  case listrEqRank xs ys of
+    Just Refl -> Just Refl
+    Nothing -> Nothing
+listrEqRank _ _ = Nothing
+
+-- | Compare two lists by length and by element values; on success return a
+-- proof that the length indices are equal.
+listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n')
+listrEqual ZR ZR = Just Refl
+listrEqual (i ::: sh) (j ::: sh') =
+  case listrEqual sh sh' of
+    -- The tails are compared first; only then are the heads checked.
+    Just Refl | i == j -> Just Refl
+    _ -> Nothing
+listrEqual _ _ = Nothing
+
+-- | Render a list as @[x1,x2,...]@, showing each element with the supplied
+-- function.
+listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS
+listrShow f l = showString "[" . firstElem l . showString "]"
+  where
+    -- The first element is printed without a separator ...
+    firstElem :: ListR n' i -> ShowS
+    firstElem ZR = id
+    firstElem (x ::: xs) = showString "" . f x . laterElems xs
+
+    -- ... and every later element is preceded by a comma.
+    laterElems :: ListR n' i -> ShowS
+    laterElems ZR = id
+    laterElems (x ::: xs) = showString "," . f x . laterElems xs
+
+-- | The number of elements as a runtime 'Int' (via the 'Foldable' instance).
+listrLength :: ListR n i -> Int
+listrLength = length
+
+-- | The number of elements as a singleton, rebuilt by walking the list.
+listrRank :: ListR n i -> SNat n
+listrRank ZR = SNat
+listrRank (_ ::: sh) = snatSucc (listrRank sh)
+
+-- | Concatenate two lists; the length of the result is the sum of the lengths.
+listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
+listrAppend front back = case front of
+  ZR -> back
+  x ::: xs -> x ::: listrAppend xs back
+
+-- | Convert an ordinary list to a 'ListR' and hand it to the continuation.
+-- The length is not known statically, so the continuation has to work for
+-- every @n@.
+listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
+listrFromList [] k = k ZR
+listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
+
+-- | The first element of a non-empty list. The 'ZR' equation is ruled out by
+-- the type and only calls 'error'.
+listrHead :: ListR (n + 1) i -> i
+listrHead (i ::: _) = i
+listrHead ZR = error "unreachable"
+
+-- | Everything but the first element of a non-empty list.
+listrTail :: ListR (n + 1) i -> ListR n i
+listrTail (_ ::: sh) = sh
+listrTail ZR = error "unreachable"
+
+-- | Everything but the last element of a non-empty list.
+listrInit :: ListR (n + 1) i -> ListR n i
+listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh
+listrInit (_ ::: ZR) = ZR
+listrInit ZR = error "unreachable"
+
+-- | The last element of a non-empty list.
+listrLast :: ListR (n + 1) i -> i
+listrLast (_ ::: sh@(_ ::: _)) = listrLast sh
+listrLast (n ::: ZR) = n
+listrLast ZR = error "unreachable"
+
+-- | Re-type a list at length @n'@. Whether the actual length is @n'@ is
+-- checked at runtime; on a mismatch 'error' is called.
+listrCast :: SNat n' -> ListR n i -> ListR n' i
+listrCast target l = listrCastWithName "listrCast" target l
+
+-- | The element at position @k@, counting from zero ('SZ' selects the head).
+-- The constraint @k + 1 <= n@ makes the 'ZR' equation unreachable.
+listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
+listrIndex SZ (x ::: _) = x
+listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
+listrIndex _ ZR = error "k + 1 <= 0"
+
+-- | Pair up the elements of two lists of the same length.
+listrZip :: ListR n i -> ListR n j -> ListR n (i, j)
+listrZip xs ys = case (xs, ys) of
+  (ZR, ZR) -> ZR
+  (i ::: irest, j ::: jrest) -> (i, j) ::: listrZip irest jrest
+  -- The lengths agree by type, so a mixed pairing cannot occur.
+  _ -> error "listrZip: impossible pattern needlessly required"
+
+-- | Combine the elements of two lists of the same length pointwise.
+listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k
+listrZipWith f xs ys = case (xs, ys) of
+  (ZR, ZR) -> ZR
+  (i ::: irest, j ::: jrest) -> f i j ::: listrZipWith f irest jrest
+  -- The lengths agree by type, so a mixed pairing cannot occur.
+  _ -> error "listrZipWith: impossible pattern needlessly required"
+
+-- | Permute the first @length perm@ elements of the list according to @perm@
+-- and leave the remaining elements in place: position @k@ of the result (for
+-- @k@ below the permutation length) holds element @perm !! k@ of the input.
+--
+-- Calls 'error' if the permutation is longer than the list, or if it contains
+-- an index that is negative or not smaller than the permutation length.
+listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
+listrPermutePrefix = \perm sh ->
+  listrFromList perm $ \sperm ->
+    case (listrRank sperm, listrRank sh) of
+      (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
+        -- LTI and EQI have identical bodies, but each branch carries its own
+        -- type-level evidence, so they cannot be merged.
+        LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+        EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+        GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
+                       ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+  where
+    -- Split a list into its first @m@ elements and the rest.
+    listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
+    listrSplitAt SZ sh = (ZR, sh)
+    listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
+    listrSplitAt SS{} ZR = error "m' + 1 <= 0"
+
+    -- Look up every index of the permutation in the given prefix.
+    applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
+    applyPermRFull _ ZR _ = ZR
+    applyPermRFull sm@SNat (i ::: perm) l
+      -- A negative index must be rejected before the conversion to a natural
+      -- number below, which would otherwise fail with an uninformative
+      -- arithmetic-underflow exception.
+      | i < 0 = error "listrPermutePrefix: Index in permutation out of range"
+      | otherwise =
+          TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
+            case cmpNat (SNat @(idx + 1)) sm of
+              LTI -> listrIndex si l ::: applyPermRFull sm perm l
+              EQI -> listrIndex si l ::: applyPermRFull sm perm l
+              GTI -> error "listrPermutePrefix: Index in permutation out of range"
+
+
+-- * Ranked indices
+
+-- | An index into a rank-typed array. This is a newtype over 'ListR'; use the
+-- 'ZIR' and '(:.:)' pattern synonyms to build and match values.
+type role IxR nominal representational
+type IxR :: Nat -> Type -> Type
+newtype IxR n i = IxR (ListR n i)
+ deriving (Eq, Ord, Generic)
+ deriving newtype (Functor, Foldable)
+
+-- | The empty index; matching on it reveals @n ~ 0@.
+pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
+pattern ZIR = IxR ZR
+
+-- | Prepend one component to an index. Matching goes through 'listrUncons'
+-- and reveals that the rank is one more than the rank of the tail.
+pattern (:.:)
+ :: forall {n1} {i}.
+ forall n. (n + 1 ~ n1)
+ => i -> IxR n i -> IxR n1 i
+pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i))
+ where i :.: IxR sh = IxR (i ::: sh)
+infixr 3 :.:
+
+{-# COMPLETE ZIR, (:.:) #-}
+
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
+type IIxR n = IxR n Int
+
+-- When OXAR_DEFAULT_SHOW_INSTANCES is defined the derived instance is used;
+-- otherwise an index is shown like the underlying list, via 'listrShow'.
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (IxR n i)
+#else
+instance Show i => Show (IxR n i) where
+ showsPrec _ (IxR l) = listrShow shows l
+#endif
+
+-- No methods are given: 'rnf' is the class default, which relies on the
+-- derived 'Generic' instance.
+instance NFData i => NFData (IxR sh i)
+
+-- | The number of components of the index, as an 'Int'.
+ixrLength :: IxR sh i -> Int
+ixrLength (IxR l) = listrLength l
+
+-- | The number of components of the index, as a singleton.
+ixrRank :: IxR n i -> SNat n
+ixrRank (IxR sh) = listrRank sh
+
+-- | The index of the given rank whose components are all 0.
+ixrZero :: SNat n -> IIxR n
+ixrZero SZ = ZIR
+ixrZero (SS n) = 0 :.: ixrZero n
+
+-- | First component of a non-empty index; see 'listrHead'.
+ixrHead :: IxR (n + 1) i -> i
+ixrHead (IxR list) = listrHead list
+
+-- | All components but the first; see 'listrTail'.
+ixrTail :: IxR (n + 1) i -> IxR n i
+ixrTail (IxR list) = IxR (listrTail list)
+
+-- | All components but the last; see 'listrInit'.
+ixrInit :: IxR (n + 1) i -> IxR n i
+ixrInit (IxR list) = IxR (listrInit list)
+
+-- | Last component of a non-empty index; see 'listrLast'.
+ixrLast :: IxR (n + 1) i -> i
+ixrLast (IxR list) = listrLast list
+
+-- | Performs a runtime check that the lengths are identical.
+ixrCast :: SNat n' -> IxR n i -> IxR n' i
+ixrCast n (IxR idx) = IxR (listrCastWithName "ixrCast" n idx)
+
+-- | 'listrAppend' on indices (a coercion of the list function).
+ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
+ixrAppend = coerce (listrAppend @_ @i)
+
+-- | 'listrZip' on indices.
+ixrZip :: IxR n i -> IxR n j -> IxR n (i, j)
+ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2
+
+-- | 'listrZipWith' on indices.
+ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
+ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
+
+-- | 'listrPermutePrefix' on indices (a coercion of the list function).
+ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
+ixrPermutePrefix = coerce (listrPermutePrefix @i)
+
+
+-- * Ranked shapes
+
+-- | The shape of a rank-typed array: one size per dimension. This is a
+-- newtype over 'ListR'; use the 'ZSR' and '(:$:)' pattern synonyms to build
+-- and match values.
+type role ShR nominal representational
+type ShR :: Nat -> Type -> Type
+newtype ShR n i = ShR (ListR n i)
+ deriving (Eq, Ord, Generic)
+ deriving newtype (Functor, Foldable)
+
+-- | The empty shape; matching on it reveals @n ~ 0@.
+pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
+pattern ZSR = ShR ZR
+
+-- | Prepend one dimension to a shape. Matching goes through 'listrUncons' and
+-- reveals that the rank is one more than the rank of the tail.
+pattern (:$:)
+ :: forall {n1} {i}.
+ forall n. (n + 1 ~ n1)
+ => i -> ShR n i -> ShR n1 i
+pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
+ where i :$: ShR sh = ShR (i ::: sh)
+infixr 3 :$:
+
+{-# COMPLETE ZSR, (:$:) #-}
+
+-- | A shape whose dimension sizes are plain 'Int's.
+type IShR n = ShR n Int
+
+-- When OXAR_DEFAULT_SHOW_INSTANCES is defined the derived instance is used;
+-- otherwise a shape is shown like the underlying list, via 'listrShow'.
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (ShR n i)
+#else
+instance Show i => Show (ShR n i) where
+ showsPrec _ (ShR l) = listrShow shows l
+#endif
+
+-- No methods are given: 'rnf' is the class default, which relies on the
+-- derived 'Generic' instance.
+instance NFData i => NFData (ShR sh i)
+
+-- | This checks only whether the ranks are equal, not whether the actual
+-- values are.
+shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
+shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh'
+
+-- | This compares the shapes for value equality.
+shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n')
+shrEqual (ShR sh) (ShR sh') = listrEqual sh sh'
+
+-- | The number of dimensions of the shape, as an 'Int'.
+shrLength :: ShR sh i -> Int
+shrLength (ShR l) = listrLength l
+
+-- | This function can also be used to conjure up a 'KnownNat' dictionary;
+-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern
+-- synonym yields 'KnownNat' evidence.
+shrRank :: ShR n i -> SNat n
+shrRank (ShR sh) = listrRank sh
+
+-- | The number of elements in an array described by this shape: the product
+-- of all dimension sizes, which is 1 for the empty shape.
+shrSize :: IShR n -> Int
+shrSize ZSR = 1
+shrSize (n :$: sh) = n * shrSize sh
+
+-- | Size of the first dimension of a non-empty shape; see 'listrHead'.
+shrHead :: ShR (n + 1) i -> i
+shrHead (ShR list) = listrHead list
+
+-- | All dimensions but the first; see 'listrTail'.
+shrTail :: ShR (n + 1) i -> ShR n i
+shrTail (ShR list) = ShR (listrTail list)
+
+-- | All dimensions but the last; see 'listrInit'.
+shrInit :: ShR (n + 1) i -> ShR n i
+shrInit (ShR list) = ShR (listrInit list)
+
+-- | Size of the last dimension of a non-empty shape; see 'listrLast'.
+shrLast :: ShR (n + 1) i -> i
+shrLast (ShR list) = listrLast list
+
+-- | Performs a runtime check that the lengths are identical.
+shrCast :: SNat n' -> ShR n i -> ShR n' i
+shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh)
+
+-- | 'listrAppend' on shapes (a coercion of the list function).
+shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
+shrAppend = coerce (listrAppend @_ @i)
+
+-- | 'listrZip' on shapes.
+shrZip :: ShR n i -> ShR n j -> ShR n (i, j)
+shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2
+
+-- | 'listrZipWith' on shapes.
+shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k
+shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2
+
+-- | 'listrPermutePrefix' on shapes (a coercion of the list function).
+shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
+shrPermutePrefix = coerce (listrPermutePrefix @i)
+
+
+-- | Untyped: the length of the input list is compared against @n@ at
+-- runtime, and a mismatch is reported with 'error'.
+instance KnownNat n => IsList (ListR n i) where
+  type Item (ListR n i) = i
+  fromList topl = build (SNat @n) topl
+    where
+      -- Consume the expected length and the list in lockstep.
+      build :: SNat n' -> [i] -> ListR n' i
+      build SZ [] = ZR
+      build (SS k) (x : rest) = x ::: build k rest
+      build _ _ = error $ "IsList(ListR): Mismatched list length (type says "
+                            ++ show (fromSNat (SNat @n)) ++ ", list has length "
+                            ++ show (length topl) ++ ")"
+  toList = Foldable.toList
+
+-- | Untyped: length is checked at runtime (by the 'ListR' instance).
+instance KnownNat n => IsList (IxR n i) where
+  type Item (IxR n i) = i
+  fromList xs = IxR (IsList.fromList xs)
+  toList idx = Foldable.toList idx
+
+-- | Untyped: length is checked at runtime (by the 'ListR' instance).
+instance KnownNat n => IsList (ShR n i) where
+  type Item (ShR n i) = i
+  fromList xs = ShR (IsList.fromList xs)
+  toList sh = Foldable.toList sh
+
+
+-- * Internal helper functions
+
+-- | Re-type a list at length @n'@, checking at runtime that it really has
+-- that length. The first argument is the function name used as the prefix of
+-- the error message on a mismatch.
+listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i
+listrCastWithName label target list = case (target, list) of
+  (SZ, ZR) -> ZR
+  (SS n, x ::: rest) -> x ::: listrCastWithName label n rest
+  _ -> error $ label ++ ": ranks don't match"
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
new file mode 100644
index 0000000..198a068
--- /dev/null
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -0,0 +1,272 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+module Data.Array.Nested.Shaped (
+ Shaped(Shaped),
+ squotArray, sremArray, satan2Array,
+ sshape,
+ module Data.Array.Nested.Shaped,
+ liftShaped1, liftShaped2,
+) where
+
+import Prelude hiding (mappend, mconcat)
+
+import Data.Array.Internal.RankedG qualified as RG
+import Data.Array.Internal.RankedS qualified as RS
+import Data.Array.Internal.ShapedG qualified as SG
+import Data.Array.Internal.ShapedS qualified as SS
+import Data.Bifunctor (first)
+import Data.Coerce (coerce)
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Data.Type.Equality
+import Data.Vector.Storable qualified as VS
+import Foreign.Storable (Storable)
+import GHC.TypeLits
+
+import Data.Array.Nested.Convert
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Shaped.Base
+import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
+import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray)
+import Data.Array.XArray qualified as X
+
+
+-- | An array whose outermost dimension has size 0 and whose remaining
+-- dimensions are given by the argument; built with 'memptyArray'.
+semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a
+semptyArray sh = Shaped (memptyArray (shxFromShS sh))
+
+-- | The rank of the array as a singleton, derived from its shape.
+srank :: Elt a => Shaped sh a -> SNat (Rank sh)
+srank arr = shsRank (sshape arr)
+
+-- | The total number of elements in the array.
+ssize :: Elt a => Shaped sh a -> Int
+ssize arr = shsSize (sshape arr)
+
+-- | The element at a full index; a wrapper around 'mindex'.
+sindex :: Elt a => Shaped sh a -> IIxS sh -> a
+sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx)
+
+-- | The leading part of a shape that has as many dimensions as the given
+-- index has components. The index values themselves are ignored.
+shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh
+shsTakeIx _ _ ZIS = ZSS
+shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx
+
+-- | Index with a prefix index, leaving a sub-array; a wrapper around
+-- 'mindexPartial'. The 'castWith' only rewrites the type of the underlying
+-- array using 'lemMapJustApp'.
+sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
+sindexPartial sarr@(Shaped arr) idx =
+ Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
+ (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr)
+ (ixxFromIxS idx))
+
+-- | Build an array of the given shape by calling the function at every index.
+--
+-- __WARNING__: All values returned from the function must have equal shape.
+-- See the documentation of 'mgenerate' for more details.
+sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
+sgenerate sh f =
+  Shaped $ mgenerate (shxFromShS sh) (\ix -> f (ixsFromIxX sh ix))
+
+-- | See the documentation of 'mlift'. The first argument is the shape of the
+-- result.
+slift :: forall sh1 sh2 a. Elt a
+ => ShS sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
+ -> Shaped sh1 a -> Shaped sh2 a
+slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShX (shxFromShS sh2)) f arr)
+
+-- | See the documentation of 'mlift'. Two-argument version of 'slift',
+-- implemented with 'mlift2'; the first argument is the shape of the result.
+slift2 :: forall sh1 sh2 sh3 a. Elt a
+ => ShS sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)
+ -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
+slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2)
+
+-- | Wrapper around 'msumOuter1P' for arrays of 'Primitive' elements; the
+-- result type drops the outermost dimension @n@.
+ssumOuter1P :: forall sh n a. (Storable a, NumElt a)
+ => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
+ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr)
+
+-- | 'ssumOuter1P' for any 'PrimElt' element type, converting through
+-- 'stoPrimitive' and 'sfromPrimitive'.
+ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
+ => Shaped (n : sh) a -> Shaped sh a
+ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive
+
+-- | Wrapper around 'msumAllPrim'. Note that the type variable @n@ here stands
+-- for the whole shape of the array.
+ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a
+ssumAllPrim (Shaped arr) = msumAllPrim arr
+
+-- | Wrapper around 'mtranspose'. The pattern guards compute nothing at the
+-- value level that ends up in the result; they bring type equalities into
+-- scope that relate the @MapJust@-encoded shape of the underlying mixed array
+-- to @PermutePrefix is sh@.
+stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a)
+ => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a
+stranspose perm sarr@(Shaped arr)
+ | Refl <- lemRankMapJust (sshape sarr)
+ , Refl <- lemTakeLenMapJust perm (sshape sarr)
+ , Refl <- lemDropLenMapJust perm (sshape sarr)
+ , Refl <- lemPermuteMapJust perm (shsTakeLen perm (sshape sarr))
+ , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh))
+ = Shaped (mtranspose perm arr)
+
+-- | A coercion of the mixed-array @mappend@; the outermost dimension of the
+-- result has size @n + m@.
+sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
+sappend = coerce mappend
+
+-- | Wrap a single element as an array of empty shape; see 'mscalar'.
+sscalar :: Elt a => a -> Shaped '[] a
+sscalar x = Shaped (mscalar x)
+
+-- | Wrapper around 'mfromVectorP' with the given shape.
+sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a)
+sfromVectorP sh v = Shaped (mfromVectorP (shxFromShS sh) v)
+
+-- | 'sfromVectorP' followed by 'sfromPrimitive'.
+sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a
+sfromVector sh v = sfromPrimitive (sfromVectorP sh v)
+
+-- | A coercion of 'mtoVectorP'.
+stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a
+stoVectorP = coerce mtoVectorP
+
+-- | A coercion of 'mtoVector'.
+stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
+stoVector = coerce mtoVector
+
+-- | One-dimensional array from a non-empty list: 'mfromList1' followed by an
+-- 'mcast' to the statically known length @n@.
+sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
+sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1
+
+-- | Stack a non-empty list of arrays along a new outermost dimension:
+-- 'mfromListOuter' followed by an 'mcastPartial' of that dimension to @n@.
+sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
+sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l))
+
+-- | Wrapper around 'mfromListLinear' with the given shape.
+sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a
+sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l)
+
+-- | One-dimensional array of primitive elements from a possibly empty list,
+-- built with @X.fromList1@ and then cast to the statically known length @n@.
+sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
+sfromListPrim sn l
+ | Refl <- lemAppNil @'[Just n]
+ = let ssh = SUnknown () :!% ZKX
+ xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l)
+ in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr
+
+-- | Array of primitive elements of the given shape from a possibly empty
+-- list: 'mfromListPrim' builds the data, which is then passed to @X.reshape@
+-- with the requested shape.
+sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a
+sfromListPrimLinear sh l =
+ let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ in Shaped $ fromPrimitive $ M_Primitive (shxFromShS sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShS sh) xarr)
+
+-- | The elements of a one-dimensional array: its outer slices, each unwrapped
+-- with 'sunScalar'.
+stoList :: Elt a => Shaped '[n] a -> [a]
+stoList arr = map sunScalar (stoListOuter arr)
+
+-- | The sub-arrays along the outermost dimension; see 'mtoListOuter'.
+stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
+stoListOuter (Shaped inner) = coerce (mtoListOuter inner)
+
+-- | All elements of the array as a flat list; see 'mtoListLinear'.
+stoListLinear :: Elt a => Shaped sh a -> [a]
+stoListLinear (Shaped inner) = mtoListLinear inner
+
+-- | Convert from an orthotope shaped array by rewrapping its underlying
+-- array, attaching the given shape as a runtime shape list.
+sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a
+sfromOrthotope sh (SS.A (SG.A arr)) =
+ Shaped (fromPrimitive (M_Primitive (shxFromShS sh) (X.XArray (RS.A (RG.A (shsToList sh) arr)))))
+
+-- | Convert to an orthotope shaped array by rewrapping the underlying array;
+-- the runtime shape list is dropped.
+stoOrthotope :: PrimElt a => Shaped sh a -> SS.Array sh a
+stoOrthotope (stoPrimitive -> Shaped (M_Primitive _ (X.XArray (RS.A (RG.A _ arr))))) = SS.A (SG.A arr)
+
+-- | The single element of an array of empty shape.
+sunScalar :: Elt a => Shaped '[] a -> a
+sunScalar arr = sindex arr ZIS
+
+-- | View an array of shape @sh ++ sh'@ as an array of shape @sh@ whose
+-- elements are arrays of shape @sh'@; a coercion of 'mnest'.
+snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a)
+snest sh arr
+ | Refl <- lemMapJustApp sh (Proxy @sh')
+ = coerce (mnest (ssxFromShX (shxFromShS sh)) (coerce arr))
+
+-- | Inverse direction of 'snest': unwrap the nested representation and
+-- re-type the inner array at the concatenated shape.
+sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a
+sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr)))
+ | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh')
+ = Shaped arr
+
+-- | A coercion of 'mzip'.
+szip :: (Elt a, Elt b) => Shaped sh a -> Shaped sh b -> Shaped sh (a, b)
+szip = coerce mzip
+
+-- | A coercion of 'munzip'.
+sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b)
+sunzip = coerce munzip
+
+-- | Wrapper around 'mrerankP': apply a function on arrays of shape @sh1@
+-- underneath the outer dimensions @sh@. The second argument is the shape
+-- @sh2@ of the function's results.
+srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
+ => ShS sh -> ShS sh2
+ -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))
+ -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b)
+srerankP sh sh2 f sarr@(Shaped arr)
+ | Refl <- lemMapJustApp sh (Proxy @sh1)
+ , Refl <- lemMapJustApp sh (Proxy @sh2)
+ = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (ssxFromShX (shxFromShS sh)) (shxFromShS (sshape sarr))))
+ (shxFromShS sh2)
+ (\a -> let Shaped r = f (Shaped a) in r)
+ arr)
+
+-- | See the caveats at 'Data.Array.XArray.rerank'.
+--
+-- This is 'srerankP' with conversions to and from 'Primitive' elements on the
+-- way in and out, both for the whole array and inside the supplied function.
+srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
+ => ShS sh -> ShS sh2
+ -> (Shaped sh1 a -> Shaped sh2 b)
+ -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b
+srerank sh sh2 f (stoPrimitive -> arr) =
+ sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr
+
+-- | Wrapper around 'mreplicate': the given shape @sh@ is prepended to the
+-- shape of the argument array.
+sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a
+sreplicate sh (Shaped arr)
+ | Refl <- lemMapJustApp sh (Proxy @sh')
+ = Shaped (mreplicate (shxFromShS sh) arr)
+
+-- | Wrapper around 'mreplicateScalP': an array of the given shape built from
+-- a single scalar.
+sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
+sreplicateScalP sh x = Shaped (mreplicateScalP (shxFromShS sh) x)
+
+-- | 'sreplicateScalP' followed by 'sfromPrimitive'.
+sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a
+sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x)
+
+-- | Lifts @X.slice i n@ over the array. Per the type, the outermost dimension
+-- goes from size @i + n + k@ to size @n@.
+sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
+sslice i n@SNat arr =
+ let _ :$$ sh = sshape arr
+ in slift (n :$$ sh) (\_ -> X.slice i n) arr
+
+-- | Lifts @X.rev1@ over the array; the shape is unchanged.
+srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a
+srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr
+
+-- | Wrapper around 'mreshape'. The constraint requires the old and new shapes
+-- to have the same number of elements.
+sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shaped sh' a
+sreshape sh' (Shaped arr) = Shaped (mreshape (shxFromShS sh') arr)
+
+-- | Reshape to a one-dimensional array of length @Product sh@.
+sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a
+sflatten arr =
+ case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff
+ n@SNat -> sreshape (n :$$ ZSS) arr
+
+-- | Wrapper around 'miota' for a one-dimensional array of length @n@.
+siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
+siota sn = Shaped (miota sn)
+
+-- | Wrapper around 'mminIndexPrim'; the mixed index is converted back to a
+-- shaped index using the array's shape.
+--
+-- Throws if the array is empty.
+sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
+sminIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mminIndexPrim arr)
+
+-- | Wrapper around 'mmaxIndexPrim'; the mixed index is converted back to a
+-- shaped index using the array's shape.
+--
+-- Throws if the array is empty.
+smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
+smaxIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr)
+
+-- | Wrapper around 'mdot1Inner'. Per the type, the innermost dimension @n@ of
+-- both arguments is removed in the result. The guards and the @case@ only
+-- establish type equalities; the fallback branch is not expected to be taken.
+sdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
+ => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a
+sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2)
+ | Refl <- lemInitApp (Proxy @sh) (Proxy @n)
+ , Refl <- lemLastApp (Proxy @sh) (Proxy @n)
+ = case sshape sarr1 of
+ _ :$$ _
+ | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n])
+ -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2)
+ _ -> error "unreachable"
+
+-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
+-- Prefer 'sdot1Inner' if applicable.
+sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a
+sdot = coerce mdot
+
+-- | Wrapper around 'mtoXArrayPrimP'; the shape component is converted with
+-- 'shsFromShX'.
+stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
+stoXArrayPrimP (Shaped arr) = first shsFromShX (mtoXArrayPrimP arr)
+
+-- | Wrapper around 'mtoXArrayPrim'; the shape component is converted with
+-- 'shsFromShX'.
+stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a)
+stoXArrayPrim (Shaped arr) = first shsFromShX (mtoXArrayPrim arr)
+
+-- | Wrapper around 'mfromXArrayPrimP' with the given shape.
+sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a)
+sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShX (shxFromShS sh)) arr)
+
+-- | Wrapper around 'mfromXArrayPrim' with the given shape.
+sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a
+sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShX (shxFromShS sh)) arr)
+
+-- | Applies 'fromPrimitive' to the underlying mixed array.
+sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a
+sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr)
+
+-- | Applies 'toPrimitive' to the underlying mixed array.
+stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a)
+stoPrimitive (Shaped arr) = Shaped (toPrimitive arr)
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
new file mode 100644
index 0000000..879e6b5
--- /dev/null
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -0,0 +1,255 @@
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_HADDOCK not-home #-}
+module Data.Array.Nested.Shaped.Base where
+
+import Prelude hiding (mappend, mconcat)
+
+import Control.DeepSeq (NFData(..))
+import Control.Monad.ST
+import Data.Bifunctor (first)
+import Data.Coerce (coerce)
+import Data.Kind (Type)
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Data.Type.Equality
+import Foreign.Storable (Storable)
+import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
+import GHC.Generics (Generic)
+import GHC.TypeLits
+
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
+import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray)
+
+
+-- | A shape-typed array: the full shape of the array (the sizes of its
+-- dimensions) is represented on the type level as a list of 'Nat's. Note that
+-- these are "GHC.TypeLits" naturals, because we do not need induction over
+-- them and we want very large arrays to be possible.
+--
+-- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
+-- and 'Shaped' itself is again an instance of 'Elt' as well.
+--
+-- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
+type Shaped :: [Nat] -> Type -> Type
+newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
+-- The derived 'Show' instance is only used when OXAR_DEFAULT_SHOW_INSTANCES is
+-- defined.
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a)
+#endif
+-- 'Eq' and 'Ord' are those of the underlying mixed array.
+deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a)
+deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a)
+
+-- Custom 'Show' instance, used unless OXAR_DEFAULT_SHOW_INSTANCES is defined.
+-- It delegates to 'showsMixedArray', passing the prefixes
+-- @sfromListLinear <shape>@ and @sreplicate <shape>@, where the shape is
+-- printed as a list. Note that the type variable @n@ here stands for the
+-- whole shape.
+#ifndef OXAR_DEFAULT_SHOW_INSTANCES
+instance (Show a, Elt a) => Show (Shaped n a) where
+ showsPrec d arr@(Shaped marr) =
+ let sh = show (shsToList (sshape arr))
+ in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr
+#endif
+
+-- | Forces the underlying mixed array.
+instance Elt a => NFData (Shaped sh a) where
+ rnf (Shaped arr) = rnf arr
+
+-- just unwrap the newtype and defer to the general instance for nested arrays
+newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))
+ deriving (Generic)
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (Mixed sh (Mixed (MapJust sh') a)) => Show (Mixed sh (Shaped sh' a))
+#endif
+
+deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a))
+
+-- Same approach for the mutable vectors: a newtype around the 'MixedVecs' for
+-- nested mixed arrays.
+newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))
+
+-- | Arrays of shaped arrays are stored as arrays of the underlying mixed
+-- arrays (see 'M_Shaped'), so most methods unwrap or coerce the newtypes,
+-- defer to the instance for nested 'Mixed' arrays and re-wrap the result. The
+-- shape-tree methods additionally convert the shape component to a 'ShS'.
+instance Elt a => Elt (Shaped sh a) where
+ mshape (M_Shaped arr) = mshape arr
+ mindex (M_Shaped arr) i = Shaped (mindex arr i)
+
+ mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
+ mindexPartial (M_Shaped arr) i =
+ coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mindexPartial arr i
+
+ mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)
+
+ mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
+ mfromListOuter l = M_Shaped (mfromListOuter (coerce l))
+
+ mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
+ mtoListOuter (M_Shaped arr)
+ = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)
+
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
+ mlift ssh2 f (M_Shaped arr) =
+ coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mlift ssh2 f arr
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a)
+ mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) =
+ coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
+ mlift2 ssh3 f arr1 arr2
+
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
+ -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a))
+ mliftL ssh2 f l =
+ coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a)))
+ @(NonEmpty (Mixed sh2 (Shaped sh a))) $
+ mliftL ssh2 f (coerce l)
+
+ mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr)
+
+ mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)
+
+ mconcat l = M_Shaped (mconcat (coerce l))
+
+ mrnf (M_Shaped arr) = mrnf arr
+
+ -- The shape tree of a shaped array pairs its own shape with the shape tree
+ -- of its elements.
+ type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
+
+ mshapeTree (Shaped arr) = first shsFromShX (mshapeTree arr)
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+
+ mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
+ marrayStrides (M_Shaped arr) = marrayStrides arr
+
+ mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
+ mvecsWrite sh idx (Shaped arr) vecs =
+ mvecsWrite sh idx arr
+ (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+
+ mvecsWritePartial :: forall sh1 sh2 s.
+ IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
+ -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
+ -> ST s ()
+ mvecsWritePartial sh idx arr vecs =
+ mvecsWritePartial sh idx
+ (coerce @(Mixed sh2 (Shaped sh a))
+ @(Mixed sh2 (Mixed (MapJust sh) a))
+ arr)
+ (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
+ @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
+ vecs)
+
+ mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
+ mvecsFreeze sh vecs =
+ coerce @(Mixed sh' (Mixed (MapJust sh) a))
+ @(Mixed sh' (Shaped sh a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh' (Shaped sh a))
+ @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+
+-- | Each method first obtains a dictionary from 'lemKnownMapJust' for the
+-- @MapJust sh@ shape and then defers to the instance for nested 'Mixed'
+-- arrays, re-wrapping the result.
+instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
+ memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
+ memptyArrayUnsafe i
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
+ memptyArrayUnsafe i
+
+ mvecsUnsafeNew idx (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsUnsafeNew idx arr
+
+ mvecsNewEmpty _
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
+
+
+-- | Turn a function on the underlying mixed arrays into a function on shaped
+-- arrays. This is only a coercion through the 'Shaped' newtype.
+liftShaped1 :: forall sh a b.
+ (Mixed (MapJust sh) a -> Mixed (MapJust sh) b)
+ -> Shaped sh a -> Shaped sh b
+liftShaped1 = coerce
+
+-- | Two-argument version of 'liftShaped1'.
+liftShaped2 :: forall sh a b c.
+ (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c)
+ -> Shaped sh a -> Shaped sh b -> Shaped sh c
+liftShaped2 = coerce
+
+-- | Arithmetic is that of the underlying mixed array, lifted through the
+-- newtype. 'fromInteger' is an 'error': a constant array cannot be built here
+-- because no shape singleton is available; use @sreplicateScal@ explicitly.
+instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
+  (+) = liftShaped2 (+)
+  (-) = liftShaped2 (-)
+  (*) = liftShaped2 (*)
+  negate = liftShaped1 negate
+  abs = liftShaped1 abs
+  signum = liftShaped1 signum
+  -- The message names the Shaped instance, mirroring the "(Ranked)" tag used
+  -- by the corresponding Ranked instances.
+  fromInteger _ = error "Data.Array.Nested(Shaped).fromInteger: No singletons available, use explicit sreplicateScal"
+
+-- | As for 'Num': everything is lifted from the underlying mixed array except
+-- 'fromRational', which is an 'error' for lack of a shape singleton.
+instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where
+  fromRational _ = error "Data.Array.Nested(Shaped).fromRational: No singletons available, use explicit sreplicateScal"
+  recip = liftShaped1 recip
+  (/) = liftShaped2 (/)
+
+-- | As for 'Num': everything is lifted from the underlying mixed array except
+-- 'pi', which is an 'error' for lack of a shape singleton.
+instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where
+  pi = error "Data.Array.Nested(Shaped).pi: No singletons available, use explicit sreplicateScal"
+  exp = liftShaped1 exp
+  log = liftShaped1 log
+  sqrt = liftShaped1 sqrt
+  (**) = liftShaped2 (**)
+  logBase = liftShaped2 logBase
+  sin = liftShaped1 sin
+  cos = liftShaped1 cos
+  tan = liftShaped1 tan
+  asin = liftShaped1 asin
+  acos = liftShaped1 acos
+  atan = liftShaped1 atan
+  sinh = liftShaped1 sinh
+  cosh = liftShaped1 cosh
+  tanh = liftShaped1 tanh
+  asinh = liftShaped1 asinh
+  acosh = liftShaped1 acosh
+  atanh = liftShaped1 atanh
+  log1p = liftShaped1 GHC.Float.log1p
+  expm1 = liftShaped1 GHC.Float.expm1
+  log1pexp = liftShaped1 GHC.Float.log1pexp
+  log1mexp = liftShaped1 GHC.Float.log1mexp
+
+-- | 'mquotArray' and 'mremArray', lifted to shaped arrays with 'liftShaped2'.
+squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
+squotArray = liftShaped2 mquotArray
+sremArray = liftShaped2 mremArray
+
+-- | 'matan2Array', lifted to shaped arrays with 'liftShaped2'.
+satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
+satan2Array = liftShaped2 matan2Array
+
+
+-- | The shape of a shaped array, obtained by converting the shape of the
+-- underlying mixed array with 'shsFromShX'.
+sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
+sshape (Shaped arr) = shsFromShX (mshape arr)
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+-- | Convert a mixed shape all of whose dimensions are statically known into
+-- a shaped shape. The runtime sizes are dropped; only the 'SNat' singletons
+-- are kept. The uses of 'unsafeCoerceRefl' assert type equalities between
+-- the matched shape and @MapJust sh@ that are not derived here. An 'SUnknown'
+-- dimension is not expected for a @MapJust@ shape and calls 'error'.
+shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh
+shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
+shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) =
+ castWith (subst1 (sym (lemMapJustCons Refl))) $
+ n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
+ idx)
+shsFromShX (SUnknown _ :$% _) = error "impossible"
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
new file mode 100644
index 0000000..5f9ba79
--- /dev/null
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -0,0 +1,425 @@
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Shaped.Shape where
+
+import Control.DeepSeq (NFData(..))
+import Data.Array.Shape qualified as O
+import Data.Coerce (coerce)
+import Data.Foldable qualified as Foldable
+import Data.Functor.Const
+import Data.Functor.Product qualified as Fun
+import Data.Kind (Constraint, Type)
+import Data.Monoid (Sum(..))
+import Data.Proxy
+import Data.Type.Equality
+import GHC.Exts (withDict)
+import GHC.Generics (Generic)
+import GHC.IsList (IsList)
+import GHC.IsList qualified as IsList
+import GHC.TypeLits
+
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Types
+
+
+-- * Shaped lists
+
+-- | A list holding one value of type @f n@ for every entry @n@ of the
+-- type-level list @sh@.
+--
+-- Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be
+-- removed in a future release.
+type role ListS nominal representational
+type ListS :: [Nat] -> (Nat -> Type) -> Type
+data ListS sh f where
+ -- The empty list; its index is @'[]@.
+ ZS :: ListS '[] f
+ -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity
+ -- Cons cell; it stores a @KnownNat n@ dictionary next to the element, so
+ -- matching on it brings that constraint into scope.
+ (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
+-- 'Eq' and 'Ord' are derived, requiring the instance for @f n@ at every @n@.
+deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
+deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
+infixr 3 ::$
+
+-- With OXAR_DEFAULT_SHOW_INSTANCES defined, the derived 'Show' instance is
+-- used. Otherwise the list is rendered by 'listsShow' with 'shows' for the
+-- elements; the precedence argument is ignored.
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance (forall n. Show (f n)) => Show (ListS sh f)
+#else
+instance (forall n. Show (f n)) => Show (ListS sh f) where
+ showsPrec _ = listsShow shows
+#endif
+
+-- | Forces every element with 'rnf', head first, then the rest of the list.
+instance (forall m. NFData (f m)) => NFData (ListS n f) where
+ rnf ZS = ()
+ rnf (x ::$ l) = rnf x `seq` rnf l
+
+-- | Result of 'listsUncons': the tail and the head (in that order) of a
+-- non-empty 'ListS', packaged with @KnownNat n@ and the evidence that the
+-- list's index is @n : sh@.
+data UnconsListSRes f sh1 =
+ forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
+-- | Split a 'ListS' into its tail and head; 'Nothing' for the empty list.
+listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
+listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)
+listsUncons ZS = Nothing
+
+-- | This checks only whether the types are equal; if the elements of the list
+-- are not singletons, their values may still differ. This corresponds to
+-- 'testEquality', except on the penultimate type parameter.
+--
+-- Elements are compared pairwise with 'testEquality'; the result is 'Nothing'
+-- as soon as one pair fails or the lists have different lengths.
+listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
+listsEqType ZS ZS = Just Refl
+listsEqType (n ::$ sh) (m ::$ sh')
+ | Just Refl <- testEquality n m
+ , Just Refl <- listsEqType sh sh'
+ = Just Refl
+-- Reached when a guard above fails or exactly one of the lists is empty.
+listsEqType _ _ = Nothing
+
+-- | This checks whether the two lists actually contain equal values. This is
+-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
+-- in the @some@ package (except on the penultimate type parameter).
+--
+-- For each pair of elements the types are compared with 'testEquality' first
+-- and, once they are known to agree, the values with '=='.
+listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
+listsEqual ZS ZS = Just Refl
+listsEqual (n ::$ sh) (m ::$ sh')
+ | Just Refl <- testEquality n m
+ , n == m
+ , Just Refl <- listsEqual sh sh'
+ = Just Refl
+-- Reached when a guard above fails or exactly one of the lists is empty.
+listsEqual _ _ = Nothing
+
+-- | Apply an index-preserving function to every element of the list.
+listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
+listsFmap _ ZS = ZS
+listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs
+
+-- | Map every element into a monoid and combine the results with '<>' in
+-- list order; the empty list yields 'mempty'.
+listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
+listsFold _ ZS = mempty
+listsFold f (x ::$ xs) = f x <> listsFold f xs
+
+-- | Render the list between square brackets, elements separated by commas,
+-- using the supplied function to show each element.
+listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
+listsShow f l = showString "[" . go "" l . showString "]"
+ where
+ -- The 'String' argument is emitted before the next element: it is empty
+ -- for the first element and a comma for all later ones.
+ go :: String -> ListS sh' f -> ShowS
+ go _ ZS = id
+ go prefix (x ::$ xs) = showString prefix . f x . go "," xs
+
+-- | Number of elements in the list, computed by folding every element to
+-- @'Sum' 1@ with 'listsFold'.
+listsLength :: ListS sh f -> Int
+listsLength l = getSum (listsFold (const (Sum 1)) l)
+
+-- | The length of the list as an 'SNat' of type @Rank sh@, built by applying
+-- 'snatSucc' once per element.
+listsRank :: ListS sh f -> SNat (Rank sh)
+listsRank ZS = SNat
+listsRank (_ ::$ sh) = snatSucc (listsRank sh)
+
+-- | Unwrap a list of 'Const' values into an ordinary list, in order.
+listsToList :: ListS sh (Const i) -> [i]
+listsToList ZS = []
+listsToList (Const i ::$ is) = i : listsToList is
+
+-- | First element of a non-empty list.
+listsHead :: ListS (n : sh) f -> f n
+listsHead (i ::$ _) = i
+
+-- | Everything after the first element of a non-empty list.
+listsTail :: ListS (n : sh) f -> ListS sh f
+listsTail (_ ::$ sh) = sh
+
+-- | All elements of a non-empty list except the last one.
+listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f
+listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh
+listsInit (_ ::$ ZS) = ZS
+
+-- | Last element of a non-empty list.
+listsLast :: ListS (n : sh) f -> f (Last (n : sh))
+listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh
+listsLast (n ::$ ZS) = n
+
+-- | Concatenate two lists; the index of the result is @sh ++ sh'@.
+listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
+listsAppend ZS idx' = idx'
+listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
+
+-- | Pair up the elements of two lists with the same index using 'Fun.Pair'.
+listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g)
+listsZip ZS ZS = ZS
+listsZip (i ::$ is) (j ::$ js) =
+ Fun.Pair i j ::$ listsZip is js
+
+-- | Combine the elements of two lists with the same index pointwise using
+-- the given function.
+listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g
+ -> ListS sh h
+listsZipWith _ ZS ZS = ZS
+listsZipWith f (i ::$ is) (j ::$ js) =
+ f i j ::$ listsZipWith f is js
+
+-- | Keep one leading element of the list per entry of the permutation; the
+-- entries themselves are ignored. Calls 'error' if the permutation has more
+-- entries than the list has elements.
+listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
+listsTakeLenPerm PNil _ = ZS
+listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh
+listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
+
+-- | Drop one leading element of the list per entry of the permutation; the
+-- entries themselves are ignored. Calls 'error' if the permutation has more
+-- entries than the list has elements.
+listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
+listsDropLenPerm PNil sh = sh
+listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh
+listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
+
+-- | Build a list with one element per entry of the permutation: each entry
+-- @i@ selects the element of the input found by 'listsIndex' at position @i@.
+listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
+listsPermute PNil _ = ZS
+listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
+ case listsIndex (Proxy @is') (Proxy @sh) i sh of
+ -- Matching on 'SNat' provides the 'KnownNat' required by '(::$)'.
+ (item, SNat) -> item ::$ listsPermute is sh
+
+-- | Look up the element at position @i@, returned together with an 'SNat'
+-- for its type-level index. The two 'Proxy' arguments are only passed along,
+-- never inspected. Calls 'error' when the list runs out before position @i@
+-- is reached.
+--
+-- TODO: remove this SNat when the KnownNat constraint in ListS is removed
+listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh))
+-- Position zero: the head; the 'SNat' is built from the 'KnownNat' stored in
+-- the cons cell.
+listsIndex _ _ SZ (n ::$ _) = (n, SNat)
+-- Successor position: step into the tail, using 'lemIndexSucc' to relate the
+-- 'Index' types.
+listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = listsIndex p pT i sh
+listsIndex _ _ _ ZS = error "Index into empty shape"
+
+-- | Permute a prefix of the list and leave the rest untouched: the prefix
+-- taken by 'listsTakeLenPerm' is permuted with 'listsPermute' and the
+-- remainder from 'listsDropLenPerm' is appended to it.
+listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
+listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh)
+
+-- * Shaped indices
+
+-- | An index into a shape-typed array. Represented as a 'ListS' whose
+-- elements are all of type @i@ (wrapped in 'Const').
+type role IxS nominal representational
+type IxS :: [Nat] -> Type -> Type
+newtype IxS sh i = IxS (ListS sh (Const i))
+ deriving (Eq, Ord, Generic)
+
+-- | The empty index; matching on it reveals @sh ~ '[]@.
+pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
+pattern ZIS = IxS ZS
+
+-- | Cons for indices. Matching uses 'listsUncons' and unwraps the 'Const';
+-- construction wraps the component in 'Const' and conses it with '(::$)'.
+--
+-- Note: The 'KnownNat' constraint on '(:.$)' is deprecated and should be
+-- removed in a future release.
+pattern (:.$)
+ :: forall {sh1} {i}.
+ forall n sh. (KnownNat n, n : sh ~ sh1)
+ => i -> IxS sh i -> IxS sh1 i
+pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
+ where i :.$ IxS shl = IxS (Const i ::$ shl)
+infixr 3 :.$
+
+{-# COMPLETE ZIS, (:.$) #-}
+
+-- | 'IxS' specialised to 'Int' components.
+--
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
+type IIxS sh = IxS sh Int
+
+-- With OXAR_DEFAULT_SHOW_INSTANCES defined, the derived 'Show' instance is
+-- used. Otherwise the index is rendered by 'listsShow', showing each
+-- component with 'shows'; the precedence argument is ignored.
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (IxS sh i)
+#else
+instance Show i => Show (IxS sh i) where
+ showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l
+#endif
+
+-- | Maps over the components, unwrapping and re-wrapping the 'Const'.
+instance Functor (IxS sh) where
+ fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l)
+
+-- | Folds over the components in order via 'listsFold'.
+instance Foldable (IxS sh) where
+ foldMap f (IxS l) = listsFold (f . getConst) l
+
+-- No method definitions are given, so the class defaults for 'NFData' apply.
+instance NFData i => NFData (IxS sh i)
+
+-- | Number of components of the index.
+ixsLength :: IxS sh i -> Int
+ixsLength (IxS l) = listsLength l
+
+-- | Number of components of the index as an 'SNat' of type @Rank sh@.
+ixsRank :: IxS sh i -> SNat (Rank sh)
+ixsRank (IxS l) = listsRank l
+
+-- | The index with a @0@ component for every dimension of the given shape.
+ixsZero :: ShS sh -> IIxS sh
+ixsZero ZSS = ZIS
+ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
+
+-- | First component of a non-empty index.
+ixsHead :: IxS (n : sh) i -> i
+ixsHead (IxS list) = getConst (listsHead list)
+
+-- | Everything after the first component of a non-empty index.
+ixsTail :: IxS (n : sh) i -> IxS sh i
+ixsTail (IxS list) = IxS (listsTail list)
+
+-- | All components of a non-empty index except the last.
+ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i
+ixsInit (IxS list) = IxS (listsInit list)
+
+-- | Last component of a non-empty index.
+ixsLast :: IxS (n : sh) i -> i
+ixsLast (IxS list) = getConst (listsLast list)
+
+-- | Re-type an index to the shape @sh'@, keeping its components unchanged.
+-- Only the number of components is checked: 'error' is called if it differs
+-- from the length of the given shape.
+--
+-- TODO: this takes a ShS because there are KnownNats inside IxS.
+ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i
+ixsCast ZSS ZIS = ZIS
+ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx
+ixsCast _ _ = error "ixsCast: ranks don't match"
+
+-- | Concatenate two indices; this is 'listsAppend' coerced through the
+-- 'IxS' newtype.
+ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
+ixsAppend = coerce (listsAppend @_ @(Const i))
+
+-- | Pair up the components of two indices of the same shape.
+ixsZip :: IxS n i -> IxS n j -> IxS n (i, j)
+ixsZip ZIS ZIS = ZIS
+ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js
+
+-- | Combine the components of two indices of the same shape pointwise.
+ixsZipWith :: (i -> j -> k) -> IxS n i -> IxS n j -> IxS n k
+ixsZipWith _ ZIS ZIS = ZIS
+ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js
+
+-- | 'listsPermutePrefix' coerced through the 'IxS' newtype.
+ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
+ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))
+
+
+-- * Shaped shapes
+
+-- | The shape of a shape-typed array given as a list of 'SNat' values.
+--
+-- Note that because the shape of a shape-typed array is known statically, you
+-- can also retrieve the array shape from a 'KnownShS' dictionary.
+type role ShS nominal
+type ShS :: [Nat] -> Type
+newtype ShS sh = ShS (ListS sh SNat)
+ deriving (Eq, Ord, Generic)
+
+-- | The empty shape; matching on it reveals @sh ~ '[]@.
+pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
+pattern ZSS = ShS ZS
+
+-- | Cons for shapes. Matching uses 'listsUncons' and so also provides
+-- @KnownNat n@; construction conses the 'SNat' with '(::$)'.
+pattern (:$$)
+ :: forall {sh1}.
+ forall n sh. (KnownNat n, n : sh ~ sh1)
+ => SNat n -> ShS sh -> ShS sh1
+pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
+ where i :$$ ShS shl = ShS (i ::$ shl)
+
+infixr 3 :$$
+
+{-# COMPLETE ZSS, (:$$) #-}
+
+-- With OXAR_DEFAULT_SHOW_INSTANCES defined, the derived 'Show' instance is
+-- used. Otherwise the shape is rendered by 'listsShow', showing each
+-- dimension as the number obtained with 'fromSNat'; the precedence argument
+-- is ignored.
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (ShS sh)
+#else
+instance Show (ShS sh) where
+ showsPrec _ (ShS l) = listsShow (shows . fromSNat) l
+#endif
+
+-- | Walks the whole shape, matching on every 'SNat' along the way.
+instance NFData (ShS sh) where
+ rnf (ShS ZS) = ()
+ rnf (ShS (SNat ::$ l)) = rnf (ShS l)
+
+-- | Type equality of shapes, decided by 'listsEqType' on the underlying
+-- lists.
+instance TestEquality ShS where
+ testEquality (ShS l1) (ShS l2) = listsEqType l1 l2
+
+-- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are
+-- equal if and only if values are equal.)
+shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh')
+shsEqual = testEquality
+
+-- | Number of dimensions of the shape.
+shsLength :: ShS sh -> Int
+shsLength (ShS l) = listsLength l
+
+-- | Number of dimensions of the shape as an 'SNat' of type @Rank sh@.
+shsRank :: ShS sh -> SNat (Rank sh)
+shsRank (ShS l) = listsRank l
+
+-- | Product of all dimensions as an 'Int'; @1@ for the empty shape.
+shsSize :: ShS sh -> Int
+shsSize ZSS = 1
+shsSize (n :$$ sh) = fromSNat' n * shsSize sh
+
+-- | The dimensions of the shape as a list of 'Int's, in order.
+shsToList :: ShS sh -> [Int]
+shsToList ZSS = []
+shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh
+
+-- | First dimension of a non-empty shape.
+shsHead :: ShS (n : sh) -> SNat n
+shsHead (ShS list) = listsHead list
+
+-- | Everything after the first dimension of a non-empty shape.
+shsTail :: ShS (n : sh) -> ShS sh
+shsTail (ShS list) = ShS (listsTail list)
+
+-- | All dimensions of a non-empty shape except the last.
+shsInit :: ShS (n : sh) -> ShS (Init (n : sh))
+shsInit (ShS list) = ShS (listsInit list)
+
+-- | Last dimension of a non-empty shape.
+shsLast :: ShS (n : sh) -> SNat (Last (n : sh))
+shsLast (ShS list) = listsLast list
+
+-- | 'listsAppend' coerced through the 'ShS' newtype.
+shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')
+shsAppend = coerce (listsAppend @_ @SNat)
+
+-- | 'listsTakeLenPerm' coerced through the 'ShS' newtype.
+shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
+shsTakeLen = coerce (listsTakeLenPerm @SNat)
+
+-- | 'listsPermute' coerced through the 'ShS' newtype.
+shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
+shsPermute = coerce (listsPermute @SNat)
+
+-- | The dimension at position @i@: the first component of the pair returned
+-- by 'listsIndex'. The 'Proxy' arguments are only forwarded.
+shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
+shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh)))
+
+-- | 'listsPermutePrefix' coerced through the 'ShS' newtype.
+shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
+shsPermutePrefix = coerce (listsPermutePrefix @SNat)
+
+-- | Type-level product of a list of naturals; the empty list gives @1@.
+type family Product sh where
+ Product '[] = 1
+ Product (n : ns) = n * Product ns
+
+-- | The product of all dimensions as an 'SNat', mirroring the 'Product' type
+-- family; the dimensions are multiplied with 'snatMul'.
+shsProduct :: ShS sh -> SNat (Product sh)
+shsProduct ZSS = SNat
+shsProduct (n :$$ sh) = n `snatMul` shsProduct sh
+
+-- | Evidence for the static part of a shape. This pops up only when you are
+-- polymorphic in the element type of an array.
+--
+-- The instances build the 'ShS' from 'KnownNat' evidence, one dimension at a
+-- time.
+type KnownShS :: [Nat] -> Constraint
+class KnownShS sh where knownShS :: ShS sh
+instance KnownShS '[] where knownShS = ZSS
+instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS
+
+-- | Run a computation that needs a 'KnownShS' instance, supplying the given
+-- 'ShS' as that instance via 'withDict'.
+withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r
+withKnownShS = withDict @(KnownShS sh)
+
+-- | Produce a 'KnownShS' dictionary from a shape by recursion; matching on
+-- 'SNat' supplies the 'KnownNat' needed at each step.
+shsKnownShS :: ShS sh -> Dict KnownShS sh
+shsKnownShS ZSS = Dict
+shsKnownShS (SNat :$$ sh) | Dict <- shsKnownShS sh = Dict
+
+-- | Produce an @O.Shape@ dictionary from a shape, by the same recursion as
+-- 'shsKnownShS'.
+shsOrthotopeShape :: ShS sh -> Dict O.Shape sh
+shsOrthotopeShape ZSS = Dict
+shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict
+
+-- | This function is a hack made possible by the 'KnownNat' inside 'ListS'.
+-- This function may be removed in a future release.
+--
+-- The elements of the list are ignored; each dimension is rebuilt from the
+-- 'KnownNat' stored in the corresponding cons cell.
+shsFromListS :: ListS sh f -> ShS sh
+shsFromListS ZS = ZSS
+shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l
+
+-- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This
+-- function may be removed in a future release.
+--
+-- Defined as 'shsFromListS' on the underlying list.
+shsFromIxS :: IxS sh i -> ShS sh
+shsFromIxS (IxS l) = shsFromListS l
+
+
+-- | Untyped: length is checked at runtime.
+--
+-- 'fromList' walks the statically known shape and the input list in
+-- lockstep and calls 'error' if one ends before the other.
+instance KnownShS sh => IsList (ListS sh (Const i)) where
+ type Item (ListS sh (Const i)) = i
+ fromList topl = go (knownShS @sh) topl
+ where
+ go :: ShS sh' -> [i] -> ListS sh' (Const i)
+ go ZSS [] = ZS
+ go (_ :$$ sh) (i : is) = Const i ::$ go sh is
+ -- Length mismatch; the message reports the lengths of the full shape
+ -- and of the full input list.
+ go _ _ = error $ "IsList(ListS): Mismatched list length (type says "
+ ++ show (shsLength (knownShS @sh)) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+ toList = listsToList
+
+-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
+--
+-- 'fromList' delegates to the 'IsList' instance of the underlying 'ListS';
+-- 'toList' is the 'Foldable' one.
+instance KnownShS sh => IsList (IxS sh i) where
+ type Item (IxS sh i) = i
+ fromList = IxS . IsList.fromList
+ toList = Foldable.toList
+
+-- | Untyped: length and values are checked at runtime.
+--
+-- 'fromList' walks the statically known shape and the input list in
+-- lockstep; every list entry must equal the dimension the type prescribes.
+instance KnownShS sh => IsList (ShS sh) where
+ type Item (ShS sh) = Int
+ fromList topl = ShS (go (knownShS @sh) topl)
+ where
+ go :: ShS sh' -> [Int] -> ListS sh' SNat
+ go ZSS [] = ZS
+ go (sn :$$ sh) (i : is)
+ -- The value from the list is only compared; the stored 'SNat' is the
+ -- one derived from the type.
+ | i == fromSNat' sn = sn ::$ go sh is
+ | otherwise = error $ "IsList(ShS): Value does not match typing (type says "
+ ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
+ -- Length mismatch; the message reports the lengths of the full shape
+ -- and of the full input list.
+ go _ _ = error $ "IsList(ShS): Mismatched list length (type says "
+ ++ show (shsLength (knownShS @sh)) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+ toList = shsToList
diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs
new file mode 100644
index 0000000..8a29aa5
--- /dev/null
+++ b/src/Data/Array/Nested/Trace.hs
@@ -0,0 +1,72 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ExplicitNamespaces #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-|
+This module is API-compatible with "Data.Array.Nested", except that inputs and
+outputs of the methods are traced using 'Debug.Trace.trace'. Thus the methods
+also have additional 'Show' constraints.
+
+>>> let res = rtranspose [1, 0] (rreshape (2 :$: 3 :$: ZSR) (riota @Int 6)) * rreshape (3 :$: 2 :$: ZSR) (rreplicate (6 :$: ZSR) (rscalar @Int 7))
+>>> length (show res) `seq` ()
+oxtrace: riota [Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [0,1,2,3,4,5]))))]
+oxtrace: rreshape [[2,3], Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [0,1,2,3,4,5])))), Ranked (M_Int (M_Primitive [2,3] (XArray (fromList [2,3] [0,1,2,3,4,5]))))]
+oxtrace: rtranspose [Ranked (M_Int (M_Primitive [2,3] (XArray (fromList [2,3] [0,1,2,3,4,5])))), Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [0,3,1,4,2,5]))))]
+oxtrace: rscalar [Ranked (M_Int (M_Primitive [] (XArray (fromList [] [7]))))]
+oxtrace: rreplicate [[6], Ranked (M_Int (M_Primitive [] (XArray (fromList [] [7])))), Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [7,7,7,7,7,7]))))]
+oxtrace: rreshape [[3,2], Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [7,7,7,7,7,7])))), Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [7,7,7,7,7,7]))))]
+>>> res
+Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [0,21,7,28,14,35]))))
+-}
+module Data.Array.Nested.Trace (
+ -- * Traced variants
+ module Data.Array.Nested.Trace,
+
+ -- * Re-exports from the plain "Data.Array.Nested" module
+ Ranked(Ranked),
+ ListR(ZR, (:::)),
+ IxR(..), IIxR,
+ ShR(..), IShR,
+
+ Shaped(Shaped),
+ ListS(ZS, (::$)),
+ IxS(..), IIxS,
+ ShS(..), KnownShS(..),
+
+ Mixed,
+ ListX(ZX, (::%)),
+ IxX(..), IIxX,
+ ShX(..), KnownShX(..), IShX,
+ StaticShX(..),
+ SMayNat(..),
+ Conversion(..),
+
+ Elt,
+ PrimElt,
+ Primitive(..),
+ KnownElt,
+
+ type (++),
+ Storable,
+ SNat, pattern SNat,
+ pattern SZ, pattern SS,
+ Perm(..),
+ IsPermutation,
+ KnownPerm(..),
+ NumElt, IntElt, FloatElt,
+ Rank, Product,
+ Replicate,
+ MapJust,
+) where
+
+import Prelude hiding (mappend, mconcat)
+
+import Data.Array.Nested
+import Data.Array.Nested.Trace.TH
+
+
+-- Generate the traced variants: 'convertFun' is run on every function named
+-- in the list below and the resulting declarations are concatenated and
+-- spliced into this module.
+$(concat <$> mapM convertFun
+ ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromList1, 'rfromListOuter, 'rfromListLinear, 'rfromListPrim, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromList1, 'sfromListOuter, 'sfromListLinear, 'sfromListPrim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromList1, 'mfromListOuter, 'mfromListLinear, 'mfromListPrim, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array])
diff --git a/src/Data/Array/Nested/Trace/TH.hs b/src/Data/Array/Nested/Trace/TH.hs
new file mode 100644
index 0000000..4b388e3
--- /dev/null
+++ b/src/Data/Array/Nested/Trace/TH.hs
@@ -0,0 +1,98 @@
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TemplateHaskellQuotes #-}
+module Data.Array.Nested.Trace.TH where
+
+import Control.Monad (zipWithM)
+import Data.List (foldl', intersperse)
+import Data.Maybe (isJust)
+import Language.Haskell.TH hiding (cxt)
+
+import Debug.Trace qualified as Debug
+
+import Data.Array.Nested
+
+
+-- | Decompose a function type into its quantified type variables, its
+-- context, its argument types and its result type.
+--
+-- Only the spine of the type is traversed: arrows are followed into their
+-- result, and @forall@s are opened wherever they occur on that spine.
+-- Anything else (including a @forall@ inside an argument) is left untouched.
+--
+-- Binders and constraints are returned outermost first, i.e. in the order in
+-- which they appear in the original type. Re-quantifying over the result
+-- therefore keeps the type-variable order (which matters for visible type
+-- application) and never places a binder before a variable it depends on.
+splitFunTy :: Type -> ([TyVarBndr Specificity], Cxt, [Type], Type)
+splitFunTy = \case
+  ArrowT `AppT` t1 `AppT` t2 ->
+    let (vars, cx, args, ret) = splitFunTy t2
+    in (vars, cx, t1 : args, ret)
+  -- @vs@ and @cx'@ belong to the outer quantifier, so they go in front of
+  -- whatever was collected from the body.
+  ForallT vs cx' t ->
+    let (vars, cx, args, ret) = splitFunTy t
+    in (vs ++ vars, cx' ++ cx, args, ret)
+  t -> ([], [], [], t)
+
+-- | Classification of a type occurring in a function signature, as produced
+-- by 'recognise'. The three array constructors carry the shape type and the
+-- classification of the element type; the other two carry the type itself.
+data Arg = RRanked Type Arg
+ | RShaped Type Arg
+ | RMixed Type Arg
+ | RShowable Type
+ | ROther Type
+ deriving (Show)
+
+-- | Classify a type for tracing purposes.
+--
+-- * @Ranked@, @Shaped@ and @Mixed@ applied to a shape and an element type are
+-- recognised only if the element type is recognised too (the result for
+-- the element is propagated with '<$>').
+-- * @IShR@, @IIxR@, @ShS@, @IIxS@ and @SNat@ applied to one argument are
+-- 'RShowable'.
+-- * Everything else yields 'Nothing'.
+--
+-- NOTE(review): the TODO below does not match the visible code, which has a
+-- 'Nothing' fallback and never constructs 'ROther' -- confirm which is
+-- intended.
+--
+-- TODO: always returns Just
+recognise :: Type -> Maybe Arg
+recognise (ConT name `AppT` sht `AppT` ty)
+ | name == ''Ranked = RRanked sht <$> recognise ty
+ | name == ''Shaped = RShaped sht <$> recognise ty
+ | name == ''Mixed = RMixed sht <$> recognise ty
+recognise ty@(ConT name `AppT` _)
+ | name `elem` [''IShR, ''IIxR, ''ShS, ''IIxS, ''SNat] =
+ Just (RShowable ty)
+recognise _ = Nothing
+
+-- | Turn a classification back into the Template Haskell 'Type' it stands
+-- for: array constructors are re-applied to their shape and (recursively
+-- rebuilt) element type, the other cases return the stored type.
+realise :: Arg -> Type
+realise = \case
+  RRanked sht elt -> arrayTy ''Ranked sht elt
+  RShaped sht elt -> arrayTy ''Shaped sht elt
+  RMixed sht elt -> arrayTy ''Mixed sht elt
+  RShowable ty -> ty
+  ROther ty -> ty
+  where
+    arrayTy con sht elt = ConT con `AppT` sht `AppT` realise elt
+
+-- | The extra constraints needed to 'show' a value of the classified type:
+-- for ranked and shaped arrays those of the element type (see 'mkShowElt'),
+-- for mixed arrays and other types a 'Show' constraint on the type itself,
+-- and none for the types known to be showable.
+mkShow :: Arg -> Cxt
+mkShow arg = case arg of
+  RRanked _ elt -> mkShowElt elt
+  RShaped _ elt -> mkShowElt elt
+  RMixed _ _ -> [ConT ''Show `AppT` realise arg]
+  RShowable _ -> []
+  ROther ty -> [ConT ''Show `AppT` ty]
+
+-- | The constraints required of an array element type: nested ranked and
+-- shaped arrays defer to their own element type, mixed arrays and other
+-- types need both 'Show' and 'Elt', and the known-showable types need
+-- nothing.
+mkShowElt :: Arg -> Cxt
+mkShowElt arg = case arg of
+  RRanked _ elt -> mkShowElt elt
+  RShaped _ elt -> mkShowElt elt
+  RMixed _ _ -> showAndElt (realise arg)
+  RShowable _ -> []  -- [ConT ''Elt `AppT` ty]
+  ROther ty -> showAndElt ty
+  where
+    showAndElt ty = [ConT ''Show `AppT` ty, ConT ''Elt `AppT` ty]
+
+-- | Compute the type of the traced variant of a function.
+--
+-- Returns the original type with its context extended by the 'mkShow'
+-- constraints of every recognised argument and result (result first), one
+-- flag per argument telling whether it was recognised, and the same flag for
+-- the result.
+convertType :: Type -> Q (Type, [Bool], Bool)
+convertType typ = return (tracedTy, map isJust argRecs, isJust retRec)
+  where
+    (binders, ctx, argTys, retTy) = splitFunTy typ
+    argRecs = map recognise argTys
+    retRec = recognise retTy
+    -- Unrecognised types contribute no constraints.
+    showCxt = concatMap (maybe [] mkShow) (retRec : argRecs)
+    arrow from to = ArrowT `AppT` from `AppT` to
+    tracedTy = ForallT binders (ctx ++ showCxt) (foldr arrow retTy argTys)
+
+-- | Generate the traced wrapper for one function: a type signature (from
+-- 'convertType') and a definition that calls the original function and
+-- reports the call through 'Debug.trace'.
+convertFun :: Name -> Q [Dec]
+convertFun funname = do
+ -- The wrapper gets a fresh name with the same base name as the original.
+ defname <- newName (nameBase funname)
+ (convty, argarrs, retarr) <- reifyType funname >>= convertType
+ -- One variable per argument: "t<i>" for arguments that will be traced,
+ -- "x<i>" for the others.
+ names <- zipWithM (\b i -> newName ((if b then "t" else "x") ++ show i)) argarrs [1::Int ..]
+ resname <- newName "res"
+ -- The values to show: the recognised arguments, followed by the result if
+ -- its type was recognised.
+ let tracenames = map fst (filter snd (zip (names ++ [resname]) (argarrs ++ [retarr])))
+ -- let res = <original> <args> in trace ("oxtrace: <name> [" ++ ... ++ "]") res
+ let ex = LetE [ValD (VarP resname)
+ (NormalB (foldl' AppE (VarE funname) (map VarE names)))
+ []]
+ (VarE 'Debug.trace
+ `AppE` (VarE 'concat `AppE` ListE
+ ([LitE (StringL ("oxtrace: " ++ nameBase funname ++ " ["))] ++
+ intersperse (LitE (StringL ", "))
+ (map (\n -> VarE 'show `AppE` VarE n) tracenames) ++
+ [LitE (StringL "]")]))
+ `AppE` VarE resname)
+ return
+ [SigD defname convty
+ ,FunD defname [Clause (map VarP names) (NormalB ex) []]]
diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs
new file mode 100644
index 0000000..4444acd
--- /dev/null
+++ b/src/Data/Array/Nested/Types.hs
@@ -0,0 +1,152 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilyDependencies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Types (
+ -- * Reasoning helpers
+ subst1, subst2,
+
+ -- * Reified evidence of a type class
+ Dict(..),
+
+ -- * Type-level naturals
+ pattern SZ, pattern SS,
+ fromSNat', sameNat',
+ snatPlus, snatMinus, snatMul,
+ snatSucc,
+
+ -- * Type-level lists
+ type (++),
+ Replicate,
+ lemReplicateSucc,
+ MapJust,
+ lemMapJustEmpty, lemMapJustCons,
+ Head,
+ Tail,
+ Init,
+ Last,
+
+ -- * Unsafe
+ unsafeCoerceRefl,
+) where
+
+import Data.Proxy
+import Data.Type.Equality
+import GHC.TypeLits
+import GHC.TypeNats qualified as TN
+import Unsafe.Coerce qualified
+
+
+-- Reasoning helpers
+
+-- | Lift an equality @a ~ b@ to @f a ~ f b@.
+subst1 :: forall f a b. a :~: b -> f a :~: f b
+subst1 eq = case eq of Refl -> Refl
+
+-- | Lift an equality @a ~ b@ to @f a c ~ f b c@, i.e. under the first
+-- argument of a two-parameter type constructor.
+subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
+subst2 eq = case eq of Refl -> Refl
+
+-- | Evidence for the constraint @c a@. Pattern matching on 'Dict' brings the
+-- constraint into scope.
+data Dict c a where
+ Dict :: c a => Dict c a
+
+-- | The value of an 'SNat' as an 'Int' ('fromSNat' followed by
+-- 'fromIntegral').
+fromSNat' :: SNat n -> Int
+fromSNat' sn = fromIntegral (fromSNat sn)
+
+-- | 'sameNat' for two 'SNat's; matching on the 'SNat' pattern provides the
+-- 'KnownNat' instances that 'sameNat' requires.
+sameNat' :: SNat n -> SNat m -> Maybe (n :~: m)
+sameNat' sn sm = case (sn, sm) of
+  (SNat, SNat) -> sameNat sn sm
+
+-- | Zero. Matching compares the 'SNat' against @SNat \@0@ with
+-- 'testEquality' and, on success, reveals @n ~ 0@.
+pattern SZ :: () => (n ~ 0) => SNat n
+pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl)
+ where SZ = SNat
+
+-- | Successor. Matching uses 'snatPred' to obtain the predecessor together
+-- with the proof that @n + 1 ~ np1@; construction is 'snatSucc'.
+pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
+pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
+ where SS = snatSucc
+
+{-# COMPLETE SZ, SS #-}
+
+-- | The successor of a type-level natural.
+snatSucc :: SNat n -> SNat (n + 1)
+snatSucc SNat = SNat
+
+-- | A predecessor @n@ of @np1@ together with the proof that @n + 1 ~ np1@.
+data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1)
+-- | The predecessor of a natural, or 'Nothing' if there is none.
+snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
+snatPred snp1 =
+ withKnownNat snp1 $
+ -- Compare 1 with np1: a predecessor exists when 1 <= np1.
+ case cmpNat (Proxy @1) (Proxy @np1) of
+ LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+ EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+ -- 1 > np1, so np1 is zero.
+ GTI -> Nothing
+
+-- | Addition of singleton naturals. The sum is computed on the runtime
+-- values and the resulting 'SNat' is coerced to the expected type.
+--
+-- This should be a function in base
+snatPlus :: SNat n -> SNat m -> SNat (n + m)
+snatPlus n m = TN.withSomeSNat total Unsafe.Coerce.unsafeCoerce
+  where
+    total = TN.fromSNat n + TN.fromSNat m
+
+-- | Subtraction of singleton naturals, computed on the runtime values. The
+-- difference is forced with 'seq' before the 'SNat' is built.
+--
+-- This should be a function in base
+snatMinus :: SNat n -> SNat m -> SNat (n - m)
+snatMinus n m = difference `seq` TN.withSomeSNat difference Unsafe.Coerce.unsafeCoerce
+  where
+    difference = TN.fromSNat n - TN.fromSNat m
+
+-- | Multiplication of singleton naturals. The product is computed on the
+-- runtime values and the resulting 'SNat' is coerced to the expected type.
+--
+-- This should be a function in base
+snatMul :: SNat n -> SNat m -> SNat (n * m)
+snatMul n m = TN.withSomeSNat result Unsafe.Coerce.unsafeCoerce
+  where
+    result = TN.fromSNat n * TN.fromSNat m
+
+
+-- | Type-level list append.
+type family l1 ++ l2 where
+ '[] ++ l2 = l2
+ (x : xs) ++ l2 = x : xs ++ l2
+
+-- | The type-level list consisting of @n@ copies of @a@.
+type family Replicate n a where
+ Replicate 0 a = '[]
+ Replicate n a = a : Replicate (n - 1) a
+
+-- | Unfolding 'Replicate' at a successor; asserted with 'unsafeCoerceRefl'
+-- rather than proved.
+lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
+lemReplicateSucc = unsafeCoerceRefl
+
+-- | Wrap every element of a type-level list in 'Just'. The family is
+-- declared injective (@r -> l@).
+type family MapJust l = r | r -> l where
+ MapJust '[] = '[]
+ MapJust (x : xs) = Just x : MapJust xs
+
+-- | If @MapJust sh@ is empty then so is @sh@; asserted with
+-- 'unsafeCoerceRefl'.
+lemMapJustEmpty :: MapJust sh :~: '[] -> sh :~: '[]
+lemMapJustEmpty Refl = unsafeCoerceRefl
+
+-- | If @MapJust sh@ starts with @Just n@ then @sh@ is @n : Tail sh@;
+-- asserted with 'unsafeCoerceRefl'.
+lemMapJustCons :: MapJust sh :~: Just n : sh' -> sh :~: n : Tail sh
+lemMapJustCons Refl = unsafeCoerceRefl
+
+-- | First element of a non-empty type-level list (no equation for @'[]@).
+type family Head l where
+ Head (x : _) = x
+
+-- | A non-empty type-level list without its first element (no equation for
+-- @'[]@).
+type family Tail l where
+ Tail (_ : xs) = xs
+
+-- | A non-empty type-level list without its last element (no equation for
+-- @'[]@).
+type family Init l where
+ Init (x : y : xs) = x : Init (y : xs)
+ Init '[x] = '[]
+
+-- | Last element of a non-empty type-level list (no equation for @'[]@).
+type family Last l where
+ Last (x : y : xs) = Last (y : xs)
+ Last '[x] = x
+
+
+-- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to
+-- only typecheck for actual type equalities. One cannot, e.g., accidentally
+-- write this:
+--
+-- @
+-- foo :: Proxy a -> Proxy b -> a :~: b
+-- foo = unsafeCoerceRefl
+-- @
+--
+-- which would have been permitted with normal 'Unsafe.Coerce.unsafeCoerce',
+-- but would have resulted in interesting memory errors at runtime.
+--
+-- The equality itself is not checked in any way; the caller is responsible
+-- for only requesting equalities that actually hold.
+unsafeCoerceRefl :: a :~: b
+unsafeCoerceRefl = Unsafe.Coerce.unsafeCoerce Refl
diff --git a/src/Data/Array/Strided/Orthotope.hs b/src/Data/Array/Strided/Orthotope.hs
new file mode 100644
index 0000000..5c38d14
--- /dev/null
+++ b/src/Data/Array/Strided/Orthotope.hs
@@ -0,0 +1,43 @@
+{-# LANGUAGE ImportQualifiedPost #-}
+module Data.Array.Strided.Orthotope (
+ module Data.Array.Strided.Orthotope,
+ module Data.Array.Strided.Arith,
+) where
+
+import Data.Array.Internal qualified as OI
+import Data.Array.Internal.RankedG qualified as RG
+import Data.Array.Internal.RankedS qualified as RS
+
+import Data.Array.Strided qualified as AS
+import Data.Array.Strided.Arith
+
+-- for liftVEltwise1
+import Data.Array.Strided.Arith.Internal (stridesDense)
+import Data.Vector.Storable qualified as VS
+import Foreign.Storable
+import GHC.TypeLits
+
+
+-- | Repackage an orthotope ranked array as a strided array, reusing its
+-- shape, strides, offset and backing vector.
+fromO :: RS.Array n a -> AS.Array n a
+fromO arr = case arr of
+  RS.A (RG.A sh (OI.T strides offset vec)) -> AS.Array sh strides offset vec
+
+-- | Repackage a strided array as an orthotope ranked array, reusing its
+-- shape, strides, offset and backing vector.
+toO :: AS.Array n a -> RS.Array n a
+toO arr = case arr of
+  AS.Array sh strides offset vec -> RS.A (RG.A sh (OI.T strides offset vec))
+
+-- | Run a unary function on strided arrays on an orthotope array,
+-- converting the argument with 'fromO' and the result with 'toO'.
+liftO1 :: (AS.Array n a -> AS.Array n' b)
+       -> RS.Array n a -> RS.Array n' b
+liftO1 f arr = toO (f (fromO arr))
+
+-- | Binary counterpart of 'liftO1': both arguments are converted with
+-- 'fromO' and the result with 'toO'.
+liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c)
+       -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c
+liftO2 f lhs rhs = liftO1 (f (fromO lhs)) rhs
+
+-- | Apply a function on storable vectors to the contents of an array.
+--
+-- If 'stridesDense' returns a block offset and size for the array
+-- (presumably meaning its elements occupy one contiguous block of the
+-- backing vector -- confirm against 'stridesDense'), the function is applied
+-- to just that slice and the shape and strides are reused, with the offset
+-- rebased relative to the slice. Otherwise the array is converted with
+-- 'RS.toVector', transformed, and rebuilt with 'RS.fromVector'.
+--
+-- NOTE(review): reusing shape and strides assumes the function returns a
+-- vector of the same length as its input -- confirm at the call sites.
+liftVEltwise1 :: (Storable a, Storable b)
+ => SNat n
+ -> (VS.Vector a -> VS.Vector b)
+ -> RS.Array n a -> RS.Array n b
+liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
+ | Just (blockOff, blockSz) <- stridesDense sh offset strides =
+ let vec' = f (VS.slice blockOff blockSz vec)
+ in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec'))
+ | otherwise = RS.fromVector sh (f (RS.toVector arr))
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
new file mode 100644
index 0000000..bf47622
--- /dev/null
+++ b/src/Data/Array/XArray.hs
@@ -0,0 +1,348 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.XArray where
+
+import Control.DeepSeq (NFData)
+import Data.Array.Internal qualified as OI
+import Data.Array.Internal.RankedG qualified as ORG
+import Data.Array.Internal.RankedS qualified as ORS
+import Data.Array.Ranked qualified as ORB
+import Data.Array.RankedS qualified as S
+import Data.Coerce
+import Data.Foldable (toList)
+import Data.Kind
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Data.Type.Equality
+import Data.Type.Ord
+import Data.Vector.Storable qualified as VS
+import Foreign.Storable (Storable)
+import GHC.Generics (Generic)
+import GHC.TypeLits
+
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Types
+import Data.Array.Strided.Orthotope
+
+
+-- | An array indexed by a type-level list @sh@ of @Maybe Nat@s, represented
+-- as an orthotope 'S.Array' whose rank is @Rank sh@. Only the rank of @sh@
+-- shows up in the representation.
+type XArray :: [Maybe Nat] -> Type -> Type
+newtype XArray sh a = XArray (S.Array (Rank sh) a)
+ deriving (Show, Eq, Ord, Generic)
+
+-- Uses the default (Generic-based) 'NFData' methods.
+instance NFData (XArray sh a)
+
+
+-- | Compute the shape of an array from the runtime dimension list of the
+-- underlying orthotope array ('S.shapeL'), walking that list in lockstep
+-- with the given 'StaticShX'. Calls 'error' if the two differ in length.
+shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh
+shape = \ssh (XArray arr) -> go ssh (S.shapeL arr)
+ where
+ go :: StaticShX sh' -> [Int] -> IShX sh'
+ go ZKX [] = ZSX
+ -- 'fromSMayNat' gets two handlers: one that ignores its argument and yields
+ -- @SUnknown i@ (the runtime size), and 'SKnown'. Presumably the former is
+ -- used for statically unknown dimensions and the latter for known ones.
+ go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l
+ go _ _ = error "Invalid shapeL"
+
+-- | Build an array of the given shape from a vector of elements. The shape
+-- is converted to a list and handed, with the vector, to 'S.fromVector'.
+fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
+fromVector sh v =
+  case lemKnownNatRank sh of
+    Dict -> XArray (S.fromVector (shxToList sh) v)
+
+-- | The elements of the array as a storable vector, via 'S.toVector'.
+toVector :: Storable a => XArray sh a -> VS.Vector a
+toVector (XArray inner) = S.toVector inner
+
+-- | Expose the strides of the underlying orthotope array. This can help when
+-- optimising, but treat it as an implementation detail: the strides may
+-- change between versions of this library without notice.
+arrayStrides :: XArray sh a -> [Int]
+arrayStrides (XArray (ORS.A (ORG.A _ (OI.T strs _ _)))) = strs
+
+-- | Wrap a single element as an array with the empty shape @'[]@.
+scalar :: Storable a => a -> XArray '[] a
+scalar x = XArray (S.scalar x)
+
+-- | Replace the @sh1@ prefix of the array's shape type by @sh2@, which must
+-- have the same rank. The runtime shape of that prefix is compared (as a
+-- list of sizes) against @sh2@; if they differ this calls 'error'. The
+-- underlying array is returned unchanged.
+cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> IShX sh2 -> StaticShX sh'
+ -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
+cast ssh1 sh2 ssh' (XArray arr)
+ | Refl <- lemRankApp ssh1 ssh'
+ , Refl <- lemRankApp (ssxFromShX sh2) ssh'
+ = let arrsh :: IShX sh1
+ -- runtime shape of the @sh1@ prefix of the input
+ (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr))
+ in if shxToList arrsh == shxToList sh2
+ then XArray arr
+ else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")"
+
+-- | Extract the single element of an array with the empty shape @'[]@.
+unScalar :: Storable a => XArray '[] a -> a
+unScalar (XArray inner) = S.unScalar inner
+
+-- | Add the dimensions @sh@ in front of an array's shape. The array is first
+-- reshaped so that it has a size-1 dimension for every entry of @sh@, and is
+-- then passed to 'S.stretch' with the full target shape.
+replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a
+replicate sh ssh' (XArray arr)
+ | Dict <- lemKnownNatRankSSX ssh'
+ , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShX sh) ssh')
+ , Refl <- lemRankApp (ssxFromShX sh) ssh'
+ = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $
+ S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr)
+ arr)
+
+-- | An array of the given shape built with 'S.constant' from a single
+-- element value.
+replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
+replicateScal sh x =
+  case lemKnownNatRank sh of
+    Dict -> XArray (S.constant (shxToList sh) x)
+
+-- | Build an array of the given shape by calling the function once per
+-- linear position, after converting that position to an index with
+-- 'ixxFromLinear'.
+generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a
+generate sh f =
+  let size = shxSize sh
+      elemAt lin = f (ixxFromLinear sh lin)
+  in fromVector sh (VS.generate size elemAt)
+
+-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a)
+-- generateM sh f | Dict <- lemKnownNatRank sh =
+-- XArray . S.fromVector (shxShapeL sh)
+-- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh)
+
+-- | Index into the leading @sh@ dimensions of an array, leaving an array
+-- over the remaining @sh'@ dimensions. The index is consumed one component
+-- at a time, each applied with 'S.index'.
+indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a
+indexPartial (XArray arr) ZIX = XArray arr
+indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx
+
+-- | Look up a single element: 'indexPartial' with a full index, followed by
+-- 'S.unScalar' on the resulting @'[]@-shaped array.
+index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a
+index xarr i
+ -- presumably proves @sh ++ '[] ~ sh@, so that 'indexPartial' applies
+ | Refl <- lemAppNil @sh
+ = let XArray arr' = indexPartial xarr i :: XArray '[] a
+ in S.unScalar arr'
+
+-- | Join two arrays with 'S.append'. As the types show, the arrays may
+-- differ only in their first dimension, and the first dimension of the
+-- result is @AddMaybe n m@.
+append :: forall n m sh a. Storable a
+       => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
+append ssh (XArray lhs) (XArray rhs) =
+  case lemKnownNatRankSSX ssh of
+    Dict -> XArray (S.append lhs rhs)
+
+-- | Concatenate a non-empty list of arrays with 'S.concatOuter'. All arrays
+-- must have the same shape, except possibly for the outermost dimension.
+concat :: Storable a
+       => StaticShX sh -> NonEmpty (XArray (Nothing : sh) a) -> XArray (Nothing : sh) a
+concat ssh arrays =
+  case lemKnownNatRankSSX ssh of
+    Dict -> XArray (S.concatOuter (coerce (toList arrays)))
+
+-- | If the prefix of the shape of the input array (@sh@) is empty (i.e.
+-- contains a zero), then there is no way to deduce the full shape of the output
+-- array (more precisely, the @sh2@ part): that could only come from calling
+-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in
+-- this case; we choose to fill the shape with zeros wherever we cannot deduce
+-- what it should be.
+--
+-- For example, if:
+--
+-- @
+-- arr :: XArray '[Just 3, Just 0, Just 4, Just 2, Nothing] Int -- of shape [3, 0, 4, 2, 21]
+-- f :: XArray '[Just 2, Nothing] Int -> XArray '[Just 5, Nothing, Just 17] Float
+-- @
+--
+-- then:
+--
+-- @
+-- rerank _ _ _ f arr :: XArray '[Just 3, Just 0, Just 4, Just 5, Nothing, Just 17] Float
+-- @
+--
+-- and this result will have shape @[3, 0, 4, 5, 0, 17]@. Note the second @0@
+-- in this shape: we don't know if @f@ intended to return an array with shape 0
+-- here (it probably didn't), but there is no better number to put here absent
+-- a subarray of the input to pass to @f@.
+--
+-- In this particular case the fact that @sh@ is empty was evident from the
+-- type-level information, but the same situation occurs when @sh@ consists of
+-- @Nothing@s, and some of those happen to be zero at runtime.
+rerank :: forall sh sh1 sh2 a b.
+ (Storable a, Storable b)
+ => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
+ -> (XArray sh1 a -> XArray sh2 b)
+ -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
+rerank ssh ssh1 ssh2 f xarr@(XArray arr)
+ | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
+ -- sh: runtime sizes of the outer (@sh@) dimensions of the input
+ = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr)
+ in if 0 `elem` shxToList sh
+ -- No subarrays exist to call f on: return an empty array whose inner
+ -- shape comes from 'shxCompleteZeros' applied to ssh2.
+ then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) [])
+ else case () of
+ () | Dict <- lemKnownNatRankSSX ssh
+ , Dict <- lemKnownNatRankSSX ssh2
+ , Refl <- lemRankApp ssh ssh1
+ , Refl <- lemRankApp ssh ssh2
+ -- unwrap/rewrap the newtype around f and defer to orthotope
+ -> XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2)
+ (\a -> let XArray r = f (XArray a) in r)
+ arr)
+
+-- | Like 'rerank', but the function is applied over the /leading/ @sh1@
+-- dimensions: the two halves of the shape are swapped with 'transpose2',
+-- 'rerank' is run, and the halves are swapped back.
+rerankTop :: forall sh1 sh2 sh a b.
+             (Storable a, Storable b)
+          => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh
+          -> (XArray sh1 a -> XArray sh2 b)
+          -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b
+rerankTop ssh1 ssh2 ssh f arr =
+  let swapped = transpose2 ssh1 ssh arr
+      mapped = rerank ssh ssh1 ssh2 f swapped
+  in transpose2 ssh ssh2 mapped
+
+-- | Two-argument variant of @rerank@, built on 'S.rerank2'. The caveat about
+-- empty arrays at @rerank@ applies here too. Only the shape of the first
+-- array is inspected to decide whether the outer (@sh@) part is empty.
+rerank2 :: forall sh sh1 sh2 a b c.
+ (Storable a, Storable b, Storable c)
+ => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
+ -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
+ -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
+rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2)
+ | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
+ = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1)
+ in if 0 `elem` shxToList sh
+ -- no subarrays to call f on: produce an empty result
+ then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) [])
+ else case () of
+ () | Dict <- lemKnownNatRankSSX ssh
+ , Dict <- lemKnownNatRankSSX ssh2
+ , Refl <- lemRankApp ssh ssh1
+ , Refl <- lemRankApp ssh ssh2
+ -> XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2)
+ (\a b -> let XArray r = f (XArray a) (XArray b) in r)
+ arr1 arr2)
+
+-- | Transpose an array according to a type-level permutation; the result
+-- shape is @PermutePrefix is sh@. The permutation is converted with
+-- 'permToList'' and passed to 'S.transpose'.
+--
+-- The list argument gives indices into the original dimension list.
+transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh)
+ => StaticShX sh
+ -> Perm is
+ -> XArray sh a
+ -> XArray (PermutePrefix is sh) a
+transpose ssh perm (XArray arr)
+ | Dict <- lemKnownNatRankSSX ssh
+ -- The three equalities below are only needed to convince the type checker
+ -- that the rank of the result equals the rank of the input.
+ , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh)
+ , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm
+ , Refl <- lemRankDropLen ssh perm
+ = XArray (S.transpose (permToList' perm) arr)
+
+-- | Transpose using a value-level permutation, which may only touch the
+-- first @n@ (statically unknown) dimensions of the array.
+--
+-- The list argument gives indices into the original dimension list. It must
+-- have length <= @n@; if it is longer, this function throws.
+transposeUntyped :: forall n sh a.
+                    SNat n -> StaticShX sh -> [Int]
+                 -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a
+transposeUntyped sn ssh perm (XArray arr) =
+  if length perm > fromSNat' sn
+    then error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type"
+    else case lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh) of
+      Dict -> XArray (S.transpose perm arr)
+
+-- | Swap the two halves of an array's shape: the @sh2@ dimensions move in
+-- front of the @sh1@ dimensions.
+transpose2 :: forall sh1 sh2 a.
+ StaticShX sh1 -> StaticShX sh2
+ -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a
+transpose2 ssh1 ssh2 (XArray arr)
+ | Refl <- lemRankApp ssh1 ssh2
+ , Refl <- lemRankApp ssh2 ssh1
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2)
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
+ , Refl <- lemRankAppComm ssh1 ssh2
+ , let n1 = ssxLength ssh1
+ -- Permutation: the indices from 'ssxIotaFrom' for ssh2 starting at n1,
+ -- followed by those for ssh1 starting at 0 (presumably consecutive
+ -- integers, one per dimension).
+ = XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr)
+
+-- | Sum of all elements. The array is flattened to a one-dimensional array
+-- (via 'S.toVector' / 'S.fromVector') and then reduced with
+-- 'numEltSum1Inner'.
+sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a
+sumFull _ (XArray arr) =
+  let total = product (S.shapeL arr)
+      flat = S.fromVector [total] (S.toVector arr)
+  in S.unScalar (liftO1 (numEltSum1Inner (SNat @0)) flat)
+
+-- | Sum over the inner @sh'@ dimensions, leaving an array of shape @sh@.
+-- The @sh'@ dimensions are collapsed into one innermost dimension (of size
+-- 'shxFlatten' of their shape), which is then summed with 'numEltSum1Inner'.
+sumInner :: forall sh sh' a. (Storable a, NumElt a)
+ => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a
+sumInner ssh ssh' arr
+ | Refl <- lemAppNil @sh
+ = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
+ -- the flattened inner shape, as a one-dimensional shape
+ sh'F = shxFlatten sh' :$% ZSX
+ ssh'F = ssxFromShX sh'F
+
+ -- sums away the single innermost (flattened) dimension
+ go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a
+ go (XArray arr')
+ | Refl <- lemRankApp ssh ssh'F
+ , let sn = listxRank (let StaticShX l = ssh in l)
+ = XArray (liftO1 (numEltSum1Inner sn) arr')
+
+ -- Read bottom-up: bring sh' to the front, reshape that prefix to the
+ -- flattened shape, move it back to the end, then sum it.
+ in go $
+ transpose2 ssh'F ssh $
+ reshapePartial ssh' ssh sh'F $
+ transpose2 ssh ssh' $
+ arr
+
+-- | Sum over the outer @sh@ dimensions, leaving an array of shape @sh'@.
+-- The @sh@ prefix is reshaped into a single dimension, moved to the inside
+-- with 'transpose2', and the work is then delegated to 'sumInner'.
+sumOuter :: forall sh sh' a. (Storable a, NumElt a)
+ => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a
+sumOuter ssh ssh' arr
+ | Refl <- lemAppNil @sh
+ = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
+ -- the flattened outer shape, as a one-dimensional shape
+ shF = shxFlatten sh :$% ZSX
+ in sumInner ssh' (ssxFromShX shF) $
+ transpose2 (ssxFromShX shF) ssh' $
+ reshapePartial ssh ssh' shF $
+ arr
+
+-- | Stack a list of arrays along a new outermost dimension using 'S.ravel'.
+--
+-- If the outermost dimension is statically known ('SKnown') and differs
+-- from the length of the list, this calls 'error'. When it is not known, the
+-- length of the list determines the outermost dimension.
+fromListOuter :: forall n sh a. Storable a
+              => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
+fromListOuter ssh l
+  | Dict <- lemKnownNatRankSSX ssh
+  = case ssh of
+      SKnown m :!% _ | fromSNat' m /= length l ->
+        -- note the space after ")": the two halves are separate literals
+        error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ") " ++
+                "does not match the type (" ++ show (fromSNat' m) ++ ")"
+      _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))
+
+-- | Split an array into the list of its subarrays along the outermost
+-- dimension, using 'S.unravel'. An array whose outermost dimension is 0
+-- yields the empty list without calling 'S.unravel'.
+toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a]
+toListOuter (XArray arr)
+  | (0 : _) <- S.shapeL arr = []
+  | otherwise = coerce (ORB.toList (S.unravel arr))
+
+-- | Build a one-dimensional array from a list of elements.
+--
+-- If the dimension is statically known ('SKnown') and differs from the
+-- length of the list, this calls 'error'. When it is not known, the length
+-- of the list determines the dimension.
+fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a
+fromList1 ssh l =
+  let n = length l
+  in case ssh of
+       SKnown m :!% _ | fromSNat' m /= n ->
+         -- note the space after ")": the two halves are separate literals
+         error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ") " ++
+                 "does not match the type (" ++ show (fromSNat' m) ++ ")"
+       _ -> XArray (S.fromVector [n] (VS.fromListN n l))
+
+-- | The elements of a one-dimensional array as a list, via 'S.toList'.
+toList1 :: Storable a => XArray '[n] a -> [a]
+toList1 (XArray inner) = S.toList inner
+
+-- | An array without elements of the given shape. Throws if the given shape
+-- is not, in fact, empty (i.e. if its size is nonzero).
+empty :: forall sh a. Storable a => IShX sh -> XArray sh a
+empty sh =
+  case lemKnownNatRank sh of
+    Dict
+      | shxSize sh == 0 -> XArray (S.fromVector (shxToList sh) VS.empty)
+      | otherwise -> error $ "Data.Array.Mixed.empty: shape was not empty: " ++ show sh
+
+-- | Slice the outermost dimension with statically known bounds: the pair
+-- @(i, n)@ is passed to 'S.slice', and the outermost dimension of the result
+-- has size @n@.
+slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a
+slice start len (XArray inner) =
+  XArray (S.slice [(fromSNat' start, fromSNat' len)] inner)
+
+-- | Slice the outermost dimension with runtime bounds: the pair @(i, n)@ is
+-- passed to 'S.slice' unchanged.
+sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a
+sliceU start len (XArray inner) = XArray (S.slice [(start, len)] inner)
+
+-- | Reverse the array along its outermost dimension (dimension index 0 is
+-- what gets passed to 'S.rev').
+rev1 :: XArray (n : sh) a -> XArray (n : sh) a
+rev1 (XArray inner) = XArray (S.rev [0] inner)
+
+-- | Reshape an array to a new shape using 'S.reshape'. Throws if the given
+-- array and the target shape do not have the same number of elements.
+reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a
+reshape ssh1 sh2 (XArray arr) =
+  case (lemKnownNatRankSSX ssh1, lemKnownNatRank sh2) of
+    (Dict, Dict) -> XArray (S.reshape (shxToList sh2) arr)
+
+-- | Reshape only the leading @sh1@ dimensions of an array to @sh2@; the
+-- trailing @sh'@ dimensions keep their runtime sizes. Throws if the given
+-- array and the target shape do not have the same number of elements.
+reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
+reshapePartial ssh1 ssh' sh2 (XArray arr) =
+  case ( lemKnownNatRankSSX (ssxAppend ssh1 ssh')
+       , lemKnownNatRankSSX (ssxAppend (ssxFromShX sh2) ssh') ) of
+    (Dict, Dict) ->
+      -- runtime sizes of the dimensions that are left alone
+      let keptDims = drop (ssxLength ssh1) (S.shapeL arr)
+      in XArray (S.reshape (shxToList sh2 ++ keptDims) arr)
+
+-- | The one-dimensional array @[toEnum 0 .. toEnum (n - 1)]@.
+--
+-- This was benchmarked to be (slightly) faster than S.iota, S.generate and
+-- S.fromVector(VS.enumFromTo).
+iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a
+iota sn =
+  let n = fromSNat' sn
+  in XArray (S.fromVector [n] (VS.fromListN n [toEnum 0 .. toEnum (n - 1)]))
diff --git a/src/Data/Bag.hs b/src/Data/Bag.hs
new file mode 100644
index 0000000..b424857
--- /dev/null
+++ b/src/Data/Bag.hs
@@ -0,0 +1,18 @@
+{-# LANGUAGE DeriveTraversable #-}
+module Data.Bag where
+
+
+-- | An ordered sequence that can be folded over.
+--
+-- 'BZero' holds nothing, 'BOne' a single element, 'BTwo' two sub-bags and
+-- 'BList' a list of sub-bags. The 'Foldable' and 'Traversable' instances are
+-- derived.
+data Bag a = BZero | BOne a | BTwo (Bag a) (Bag a) | BList [Bag a]
+ deriving (Show, Functor, Foldable, Traversable)
+
+-- Really only here for 'pure'. '<*>' recurses over the structure of the bag
+-- of functions and maps each function over the whole argument bag.
+instance Applicative Bag where
+  pure x = BOne x
+  fs <*> t = case fs of
+    BZero -> BZero
+    BOne f -> fmap f t
+    BTwo l r -> BTwo (l <*> t) (r <*> t)
+    BList gs -> BList (map (<*> t) gs)
+
+-- | Combining two bags just records both under a 'BTwo' node.
+instance Semigroup (Bag a) where
+  l <> r = BTwo l r
+
+-- | The empty bag is 'BZero'.
+instance Monoid (Bag a) where
+  mempty = BZero
diff --git a/src/Data/INat.hs b/src/Data/INat.hs
deleted file mode 100644
index af8f18b..0000000
--- a/src/Data/INat.hs
+++ /dev/null
@@ -1,121 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE PatternSynonyms #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE TypeAbstractions #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.INat where
-
-import Data.Proxy
-import Data.Type.Equality ((:~:) (Refl))
-import Numeric.Natural
-import GHC.TypeLits
-import Unsafe.Coerce (unsafeCoerce)
-
--- | Evidence for the constraint @c a@.
-data Dict c a where
- Dict :: c a => Dict c a
-
--- | An inductive peano natural number. Intended to be used at the type level.
-data INat = Z | S INat
- deriving (Show)
-
--- | Singleton for a 'INat'.
-data SINat n where
- SZ :: SINat Z
- SS :: SINat n -> SINat (S n)
-deriving instance Show (SINat n)
-
--- | A singleton 'SINat' corresponding to @n@.
-class KnownINat n where inatSing :: SINat n
-instance KnownINat Z where inatSing = SZ
-instance KnownINat n => KnownINat (S n) where inatSing = SS inatSing
-
--- | Explicitly bidirectional pattern synonym that converts between a singleton
--- 'SINat' and evidence of a 'KnownINat' constraint. Analogous to 'GHC.SNat'.
-pattern SINat' :: () => KnownINat n => SINat n
-pattern SINat' <- (snatKnown -> Dict)
- where SINat' = inatSing
-
--- | A 'KnownINat' dictionary is just a singleton natural, so we can create
--- evidence of 'KnownINat' given an 'SINat'.
-snatKnown :: SINat n -> Dict KnownINat n
-snatKnown SZ = Dict
-snatKnown (SS n) | Dict <- snatKnown n = Dict
-
--- | Convert a 'INat' to a normal number.
-fromINat :: INat -> Natural
-fromINat Z = 0
-fromINat (S n) = 1 + fromINat n
-
--- | Convert an 'SINat' to a normal number.
-fromSINat :: SINat n -> Natural
-fromSINat SZ = 0
-fromSINat (SS n) = 1 + fromSINat n
-
--- | The value of a known inductive natural as a value-level integer.
-inatVal :: forall n. KnownINat n => Proxy n -> Natural
-inatVal _ = fromSINat (inatSing @n)
-
--- | Add two 'INat's
-type family n +! m where
- Z +! m = m
- S n +! m = S (n +! m)
-
--- | Convert a 'INat' to a "GHC.TypeLits" 'G.Nat'.
-type family FromINat n where
- FromINat Z = 0
- FromINat (S n) = 1 + FromINat n
-
--- | Convert a "GHC.TypeLits" 'G.Nat' to a 'INat'.
-type family ToINat (n :: Nat) where
- ToINat 0 = Z
- ToINat n = S (ToINat (n - 1))
-
-lemInjectiveFromINat :: n :~: ToINat (FromINat n)
-lemInjectiveFromINat = unsafeCoerce Refl
-
-lemSuccFromINat :: Proxy n -> 1 + FromINat n :~: FromINat (S n)
-lemSuccFromINat _ = unsafeCoerce Refl
-
-lemAddFromINat :: Proxy m -> Proxy n
- -> FromINat m + FromINat n :~: FromINat (m +! n)
-lemAddFromINat _ = unsafeCoerce Refl
-
-lemInjectiveToINat :: n :~: FromINat (ToINat n)
-lemInjectiveToINat = unsafeCoerce Refl
-
-lemSuccToINat :: Proxy n -> ToINat (1 + n) :~: S (ToINat n)
-lemSuccToINat _ = unsafeCoerce Refl
-
-lemAddToINat :: Proxy m -> Proxy n -> ToINat (m + n) :~: ToINat m +! ToINat n
-lemAddToINat _ _ = unsafeCoerce Refl
-
--- | If an inductive 'INat' is known, then the corresponding "GHC.TypeLits"
--- 'G.Nat' is also known.
-knownNatFromINat :: KnownINat n => Proxy n -> Dict KnownNat (FromINat n)
-knownNatFromINat (Proxy @n) = go (SINat' @n)
- where
- go :: SINat m -> Dict KnownNat (FromINat m)
- go SZ = Dict
- go (SS n) | Dict <- go n = Dict
-
--- * Some type-level inductive naturals
-
-type I0 = Z
-type I1 = S I0
-type I2 = S I1
-type I3 = S I2
-type I4 = S I3
-type I5 = S I4
-type I6 = S I5
-type I7 = S I6
-type I8 = S I7
-type I9 = S I8
diff --git a/test/Gen.hs b/test/Gen.hs
new file mode 100644
index 0000000..044de14
--- /dev/null
+++ b/test/Gen.hs
@@ -0,0 +1,174 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE NumericUnderscores #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeAbstractions #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Gen where
+
+import Data.ByteString qualified as BS
+import Data.Foldable (toList)
+import Data.Type.Equality
+import Data.Type.Ord
+import Data.Vector.Storable qualified as VS
+import Foreign
+import GHC.TypeLits
+import GHC.TypeNats qualified as TN
+
+import Data.Array.Nested
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
+
+import Hedgehog
+import Hedgehog.Gen qualified as Gen
+import Hedgehog.Range qualified as Range
+import System.Random qualified as Random
+
+import Util
+
+
+-- | Generate a rank and pass it, as an 'SNat', to the continuation. Zero is
+-- generated with small probability (weight 1 against weight 49 for the range
+-- 1 to 8), because there's typically only one interesting case for 0 anyway.
+genRank :: Monad m => (forall n. SNat n -> PropertyT m ()) -> PropertyT m ()
+genRank k = do
+  let rankGen = Gen.frequency [ (1, return 0)
+                              , (49, Gen.int (Range.linear 1 8)) ]
+  rank <- forAll rankGen
+  TN.withSomeSNat (fromIntegral rank) k
+
+-- | Generate a value between @lo@ and @hi@ by cubing a number drawn from
+-- @[0, 1]@, which skews the result towards @lo@.
+genLowBiased :: RealFloat a => (a, a) -> Gen a
+genLowBiased (lo, hi) =
+  (\x -> lo + x * x * x * (hi - lo)) <$> Gen.realFloat (Range.linearFrac 0 1)
+
+-- | Randomly reorder the dimensions of a shape. The dimensions are put in a
+-- list (the "bag"); for every position of the shape, one element is picked
+-- from the bag at a random index and removed from it.
+shuffleShR :: IShR n -> Gen (IShR n)
+shuffleShR = \sh -> go (length sh) (toList sh) sh
+ where
+ -- Arguments: number of elements left in the bag, the bag itself, and the
+ -- remaining part of the shape (only its length is used).
+ go :: Int -> [Int] -> IShR n -> Gen (IShR n)
+ go _ _ ZSR = return ZSR
+ go nbag bag (_ :$: sh) = do
+ idx <- Gen.int (Range.linear 0 (nbag - 1))
+ let (dim, bag') = case splitAt idx bag of
+ (pre, n : post) -> (n, pre ++ post)
+ _ -> error "unreachable"
+ (dim :$:) <$> go (nbag - 1) bag' sh
+
+-- | 'genShRwithTarget' with a maximum target size of 100000.
+genShR :: SNat n -> Gen (IShR n)
+genShR sn = genShRwithTarget 100_000 sn
+
+-- | Generate a shape of the given rank. A target size is first drawn from
+-- @[0, targetMax]@; dimensions are then generated one by one, after which
+-- all dimensions are capped so that their product is at most the target
+-- size, and finally the dimensions are shuffled.
+genShRwithTarget :: Int -> SNat n -> Gen (IShR n)
+genShRwithTarget targetMax sn = do
+ let n = fromSNat' sn
+ targetSize <- Gen.int (Range.linear 0 targetMax)
+ -- genDims rank tgt: tgt is the size budget left for the remaining dims.
+ let genDims :: SNat m -> Int -> Gen (IShR m)
+ genDims SZ _ = return ZSR
+ -- no budget left: the remaining dimensions are drawn from 0..20
+ genDims (SS m) 0 = do
+ dim <- Gen.int (Range.linear 0 20)
+ dims <- genDims m 0
+ return (dim :$: dims)
+ genDims (SS m) tgt = do
+ -- mostly a low-biased value up to sqrt tgt; occasionally the whole
+ -- budget, 1, or 0
+ dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt))))
+ ,(2 , return tgt)
+ ,(4 , return 1)
+ ,(1 , return 0)]
+ dims <- genDims m (if dim == 0 then 0 else tgt `div` dim)
+ return (dim :$: dims)
+ dims <- genDims sn targetSize
+ let dimsL = toList dims
+ maxdim = maximum dimsL
+ -- the cap on individual dimensions chosen by 'binarySearch' for the
+ -- predicate "product of the capped dimensions <= targetSize"
+ cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize)
+ shuffleShR (min cap <$> dims)
+
+-- | Example: given 3 and 7, might return:
+--
+-- @
+-- ([ 13, 4, 27 ]
+-- ,[1, 13, 1, 1, 4, 27, 1]
+-- ,[4, 13, 1, 3, 4, 27, 2])
+-- @
+--
+-- The up-replicated dimensions are always nonzero and not very large, but the
+-- other dimensions might be zero.
+genReplicatedShR :: m <= n => SNat m -> SNat n -> Gen (IShR m, IShR n, IShR n)
+genReplicatedShR = \m n -> do
+ -- Each injected dimension multiplies the size by repvalavg on average, so
+ -- the size target for sh1 is reduced accordingly.
+ let expectedSizeIncrease = round (repvalavg ^ (fromSNat' n - fromSNat' m))
+ sh1 <- genShRwithTarget (1_000_000 `div` expectedSizeIncrease) m
+ (sh2, sh3) <- injectOnes n sh1 sh1
+ return (sh1, sh2, sh3)
+ where
+ -- range of sizes for the injected dimensions, and its midpoint
+ repvalrange = (1::Int, 5)
+ repvalavg = let (lo, hi) = repvalrange in fromIntegral (lo + hi) / 2 :: Double
+
+ -- Repeatedly insert one dimension at a random position into both shapes
+ -- until they have rank n: size 1 in the first shape, a random size from
+ -- repvalrange in the second.
+ injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n)
+ injectOnes n@SNat shOnes sh
+ | m@SNat <- shrRank sh
+ = case cmpNat n m of
+ LTI -> error "unreachable"
+ EQI -> return (shOnes, sh)
+ GTI -> do
+ index <- Gen.int (Range.linear 0 (fromSNat' m))
+ value <- Gen.int (uncurry Range.linear repvalrange)
+ Refl <- return (lem n m)
+ injectOnes n (inject index 1 shOnes) (inject index value sh)
+
+ -- asserted rather than proven (via unsafeCoerceRefl)
+ lem :: forall n m proxy. n > m => proxy n -> proxy m -> (m + 1 <=? n) :~: True
+ lem _ _ = unsafeCoerceRefl
+
+ -- inject i v sh: insert v into sh so that it ends up at position i
+ inject :: Int -> Int -> IShR m -> IShR (m + 1)
+ inject 0 v sh = v :$: sh
+ inject i v (w :$: sh) = w :$: inject (i - 1) v sh
+ inject _ _ ZSR = error "unreachable"
+
+-- | Generate a storable vector whose length is drawn from the given range.
+-- Instead of generating every element through Hedgehog, a single seed is
+-- generated and used to produce @8 * n@ random bytes; element @i@ is the
+-- given function applied to the 64-bit word assembled from bytes
+-- @8*i .. 8*i+7@.
+genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
+genStorables rng f = do
+ n <- Gen.int rng
+ seed <- Gen.resize 99 $ Gen.int Range.linearBounded
+ let gen0 = Random.mkStdGen seed
+ (bs, _) = Random.uniformByteString (8 * n) gen0
+ -- byte j of the group is weighted by 256^j, i.e. the least significant
+ -- byte comes first
+ let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]])
+ return $ VS.generate n (f . readW64)
+
+-- | Generate a 'StaticShX' of the given rank and pass it to the
+-- continuation. Each entry is, with equal chance, either statically known
+-- or unknown.
+genStaticShX :: Monad m => SNat n -> (forall sh. Rank sh ~ n => StaticShX sh -> PropertyT m ()) -> PropertyT m ()
+genStaticShX = \n k -> case n of
+ SZ -> k ZKX
+ SS n' ->
+ genItem $ \item ->
+ genStaticShX n' $ \ssh ->
+ k (item :!% ssh)
+ where
+ -- A single entry: a known size (usually 1 to 4, with weight 20; 0 with
+ -- weight 1) or an unknown one.
+ genItem :: Monad m => (forall n. SMayNat () SNat n -> PropertyT m ()) -> PropertyT m ()
+ genItem k = do
+ b <- forAll Gen.bool
+ if b
+ then do
+ n <- forAll $ Gen.frequency [(20, Gen.int (Range.linear 1 4))
+ ,(1, return 0)]
+ TN.withSomeSNat (fromIntegral n) $ \sn -> k (SKnown sn)
+ else k (SUnknown ())
+
+-- | Generate a shape matching a 'StaticShX': known dimensions are copied,
+-- unknown dimensions get a size drawn from 1 to 4.
+genShX :: StaticShX sh -> Gen (IShX sh)
+genShX ZKX = pure ZSX
+genShX (SKnown sn :!% rest) = fmap (SKnown sn :$%) (genShX rest)
+genShX (SUnknown () :!% rest) =
+  Gen.int (Range.linear 1 4) >>= \dim ->
+    fmap (SUnknown dim :$%) (genShX rest)
+
+-- | A random permutation of the numbers @0@ up to @n-1@, as a list.
+genPermR :: Int -> Gen PermR
+genPermR n = Gen.shuffle (enumFromTo 0 (n - 1))
+
+-- | Generate a random permutation of the given rank (using 'genPermR') and
+-- pass it to the continuation as a type-level 'Perm'.
+--
+-- The generated list is checked with 'permCheckPermutation' and its rank is
+-- compared against @n@ with 'sameNat''; if either check fails, which would
+-- indicate a bug in the generator, this calls 'error'.
+genPerm :: Monad m => SNat n -> (forall p. (IsPermutation p, Rank p ~ n) => Perm p -> PropertyT m r) -> PropertyT m r
+genPerm n@SNat k = do
+  list <- forAll $ genPermR (fromSNat' n)
+  permFromList list $ \perm -> do
+    case permCheckPermutation perm $
+           case sameNat' (permRank perm) n of
+             Just Refl -> Just (k perm)
+             Nothing -> Nothing
+      of
+        Just (Just act) -> act
+        -- previously an empty message, which made a failure here undiagnosable
+        _ -> error "Gen.genPerm: generated list is not a valid permutation of the requested rank"
diff --git a/test/Main.hs b/test/Main.hs
index 2363813..575bb15 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -1,29 +1,15 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE ImportQualifiedPost #-}
module Main where
-import Data.Array.Nested
+import Test.Tasty
+import Tests.C qualified
+import Tests.Permutation qualified
-arr :: Ranked I2 (Shaped [2, 3] (Double, Int))
-arr = rgenerate (3 :$: 4 :$: ZSR) $ \(i :.: j :.: ZIR) ->
- sgenerate @[2, 3] $ \(k :.$ l :.$ ZIS) ->
- let s = 24*i + 6*j + 3*k + l
- in (fromIntegral s, s)
-
-foo :: (Double, Int)
-foo = arr `rindex` (2 :.: 1 :.: ZIR) `sindex` (1 :.$ 1 :.$ ZIS)
-
-bad :: Ranked I2 (Ranked I1 Double)
-bad = rgenerate (3 :$: 4 :$: ZSR) $ \(i :.: j :.: ZIR) ->
- rgenerate (i :$: ZSR) $ \(k :.: ZIR) ->
- let s = 24*i + 6*j + 3*k
- in fromIntegral s
main :: IO ()
-main = do
- print arr
- print foo
- print (rtranspose [1,0] arr)
- -- print bad
+main = defaultMain $
+ testGroup "Tests"
+ [Tests.C.tests
+ ,Tests.Permutation.tests
+ ]
diff --git a/test/Tests/C.hs b/test/Tests/C.hs
new file mode 100644
index 0000000..9567393
--- /dev/null
+++ b/test/Tests/C.hs
@@ -0,0 +1,160 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeAbstractions #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Tests.C where
+
+import Control.Monad
+import Data.Array.RankedS qualified as OR
+import Data.Foldable (toList)
+import Data.Functor.Const
+import Data.Type.Equality
+import Foreign
+import GHC.TypeLits
+
+import Data.Array.Nested
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types (fromSNat')
+
+import Hedgehog
+import Hedgehog.Gen qualified as Gen
+import Hedgehog.Internal.Property (LabelName(..), forAllT)
+import Hedgehog.Range qualified as Range
+import Test.Tasty
+import Test.Tasty.Hedgehog
+
+-- import Debug.Trace
+
+import Gen
+import Util
+
+
+-- | Tolerance passed to 'almostEq'. Appropriate for results that differ
+-- only because of simple different summation orders.
+fineTol :: Double
+fineTol = 1e-8
+
+-- | When 'True', 'prop_sum_replicated' attaches labels (rank difference and
+-- order of magnitude of the shape sizes) to each test case.
+debugCoverage :: Bool
+debugCoverage = False
+
+-- | 'rsumOuter1' agrees (up to 'fineTol') with the reference 'orSumOuter1'
+-- on random arrays of rank @n + 1@ whose non-first dimensions are all
+-- nonzero.
+prop_sum_nonempty :: Property
+prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do
+ -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
+ let inrank = SNat @(n + 1)
+ sh <- forAll $ genShR inrank
+ -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
+ guard (all (> 0) (shrTail sh)) -- only constrain the tail
+ -- elements are the random words scaled by maxBound, i.e. in [0, 1]
+ arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
+ genStorables (Range.singleton (product sh))
+ (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+ let rarr = rfromOrthotope inrank arr
+ almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr)
+
+-- | 'rsumOuter1' of an array whose non-first dimensions contain a zero
+-- yields an array without elements.
+prop_sum_empty :: Property
+prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do
+ -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above.
+ _outrank :: SNat n <- return $ SNat @(nm1 + 1)
+ let inrank = SNat @(n + 1)
+ -- Shape: a random first dimension (0..20) followed by a shuffled tail
+ -- that is guaranteed to contain a 0.
+ sh <- forAll $ do
+ shtt <- genShR outrankm1 -- nm1
+ sht <- shuffleShR (0 :$: shtt) -- n
+ n <- Gen.int (Range.linear 0 20)
+ return (n :$: sht) -- n + 1
+ guard (0 `elem` shrTail sh)
+ -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
+ let arr = OR.fromList @(n + 1) @Double (toList sh) []
+ let rarr = rfromOrthotope inrank arr
+ OR.toList (rtoOrthotope (rsumOuter1 rarr)) === []
+
+-- | Like 'prop_sum_nonempty', but for input shapes whose last dimension is
+-- exactly 1 and whose other dimensions are all nonzero.
+prop_sum_lasteq1 :: Property
+prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do
+ let inrank = SNat @(n + 1)
+ outsh <- forAll $ genShR outrank
+ guard (all (> 0) outsh)
+ -- input shape: the output shape with a trailing dimension of size 1
+ let insh = shrAppend outsh (1 :$: ZSR)
+ arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$>
+ genStorables (Range.singleton (product insh))
+ (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+ let rarr = rfromOrthotope inrank arr
+ almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr)
+
+-- | 'rsumOuter1' agrees (up to 'fineTol') with the reference 'orSumOuter1'
+-- on arrays obtained by reshaping and stretching a smaller array of rank
+-- @m < n@ to rank @n@ (shapes from 'genReplicatedShR'). If the argument is
+-- 'True', the array is additionally transposed by a random permutation.
+prop_sum_replicated :: Bool -> Property
+prop_sum_replicated doTranspose = property $
+  genRank $ \inrank1@(SNat @m) ->
+  genRank $ \outrank@(SNat @nm1) -> do
+    inrank2 :: SNat n <- return $ SNat @(nm1 + 1)
+    (Refl :: (m <=? n) :~: True) <- case cmpNat inrank1 inrank2 of
+      LTI -> return Refl -- actually we only continue if m < n
+      _ -> discard
+    (sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2
+    when debugCoverage $ do
+      label (LabelName ("rankdiff " ++ show (fromSNat' inrank2 - fromSNat' inrank1)))
+      label (LabelName ("size sh1 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh1) :: Double)) :: Int)))
+      label (LabelName ("size sh3 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh3) :: Double)) :: Int)))
+    guard (all (> 0) sh3)
+    -- Build the small array, give it the rank-n shape with 1s (sh2), then
+    -- stretch those 1s to the sizes in sh3.
+    arr <- forAllT $
+      OR.stretch (toList sh3)
+      . OR.reshape (toList sh2)
+      . OR.fromVector @Double @m (toList sh1) <$>
+      genStorables (Range.singleton (product sh1))
+                   (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+    arrTrans <-
+      if doTranspose
+        then do perm <- forAll $ genPermR (fromSNat' inrank2)
+                return $ OR.transpose perm arr
+        else return arr
+    let rarr = rfromOrthotope inrank2 arrTrans
+    -- same tolerance as the other sum properties (fineTol is 1e-8)
+    almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans)
+
+-- | Template for properties checking that 'negate' on a ranked array agrees
+-- with @OR.mapA negate@ on the corresponding orthotope array.
+--
+-- The first argument supplies a rank together with some extra evidence
+-- @f n@; the second generates an auxiliary value from that evidence and the
+-- generated shape; the third uses both to preprocess the random input array
+-- before it is negated.
+prop_negate_with :: forall f b. Show b
+ => ((forall n. f n -> SNat n -> PropertyT IO ()) -> PropertyT IO ())
+ -> (forall n. f n -> IShR n -> Gen b)
+ -> (forall n. f n -> b -> OR.Array n Double -> OR.Array n Double)
+ -> Property
+prop_negate_with genRank' genB preproc = property $
+ genRank' $ \extra rank@(SNat @n) -> do
+ sh <- forAll $ genShR rank
+ guard (all (> 0) sh)
+ arr <- forAllT $ OR.fromVector @Double @n (toList sh) <$>
+ genStorables (Range.singleton (product sh))
+ (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
+ bval <- forAll $ genB extra sh
+ let arr' = preproc extra bval arr
+ -- record the shape after preprocessing in the test output
+ annotate (show (OR.shapeL arr'))
+ let rarr = rfromOrthotope rank arr'
+ rtoOrthotope (negate rarr) === OR.mapA negate arr'
+
+-- | All properties of this module, grouped by the operation under test.
+tests :: TestTree
+tests = testGroup "C"
+ [testGroup "sum"
+ [testProperty "nonempty" prop_sum_nonempty
+ ,testProperty "empty" prop_sum_empty
+ ,testProperty "last==1" prop_sum_lasteq1
+ ,testProperty "replicated" (prop_sum_replicated False)
+ ,testProperty "replicated_transposed" (prop_sum_replicated True)
+ ]
+ ,testGroup "negate"
+ -- no preprocessing: the array is used as generated
+ [testProperty "normalised" $ prop_negate_with
+ (\k -> genRank (k (Const ())))
+ (\_ _ -> pure ())
+ (\_ _ -> id)
+ -- rank fixed to 1; the array is sliced with a random (lo, len) pair
+ ,testProperty "slice 1D" $ prop_negate_with @((:~:) 1)
+ (\k -> k Refl (SNat @1))
+ (\Refl (n :$: _) -> do lo <- Gen.integral (Range.constant 0 (n-1))
+ len <- Gen.integral (Range.constant 0 (n-lo))
+ return [(lo, len)])
+ (\_ -> OR.slice)
+ -- random rank; one random (lo, len) pair per dimension
+ ,testProperty "slice nD" $ prop_negate_with
+ (\k -> genRank (k (Const ())))
+ (\_ sh -> do let genPair n = do lo <- Gen.integral (Range.constant 0 (n-1))
+ len <- Gen.integral (Range.constant 0 (n-lo-1))
+ return (lo, len)
+ pairs <- mapM genPair (toList sh)
+ return pairs)
+ (\_ -> OR.slice)
+ ]
+ ]
diff --git a/test/Tests/Permutation.hs b/test/Tests/Permutation.hs
new file mode 100644
index 0000000..98a6da5
--- /dev/null
+++ b/test/Tests/Permutation.hs
@@ -0,0 +1,39 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Tests.Permutation where
+
+import Data.Type.Equality
+
+import Data.Array.Nested.Permutation
+
+import Hedgehog
+import Hedgehog.Gen qualified as Gen
+import Hedgehog.Range qualified as Range
+import Test.Tasty
+import Test.Tasty.Hedgehog
+
+-- import Debug.Trace
+
+import Gen
+
+
+-- | Properties of the permutation helpers.
+--
+-- * @permCheckPermutation@: every list produced by 'genPermR' must be
+--   accepted by 'permCheckPermutation'.
+-- * @permInverse@: the proof returned by 'permInverse', applied to a
+--   generated static shape, must evaluate to 'Refl'.
+tests :: TestTree
+tests = testGroup "Permutation"
+  [testProperty "permCheckPermutation" $ property $ do
+    len <- forAll (Gen.int (Range.linear 0 10))
+    candidate <- forAll (genPermR len)
+    -- Nothing means the checker rejected the generated list
+    maybe failure return $
+      permFromList candidate (\perm -> permCheckPermutation perm ())
+  ,testProperty "permInverse" $ property $
+    genRank $ \rank ->
+      genPerm rank $ \perm ->
+        genStaticShX rank $ \ssh ->
+          permInverse perm $ \_inverse roundTrip ->
+            -- matching on Refl forces the proof term to be evaluated
+            case roundTrip ssh of
+              Refl -> return ()
+  ]
diff --git a/test/Util.hs b/test/Util.hs
new file mode 100644
index 0000000..8a5ba72
--- /dev/null
+++ b/test/Util.hs
@@ -0,0 +1,51 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Util where
+
+import Data.Array.RankedS qualified as OR
+import Data.Kind
+import GHC.TypeLits
+import Hedgehog
+import Hedgehog.Internal.Property (failDiff)
+
+import Data.Array.Nested.Types (fromSNat')
+
+
+-- | Binary search for the highest value that satisfies a predicate.
+--
+-- Arguments: a halving function @div2@, the lower bound @lo@, the upper
+-- bound @hi@, and the predicate. The search assumes the predicate is
+-- monotone on the interval: true up to some point and false from there on.
+-- @div2@ is applied to the distance @hi - lo@ to find the midpoint
+-- (presumably @(`div` 2)@ for integral types).
+--
+-- Returns the highest value found that satisfies the predicate, or @lo@ if
+-- @lo@ itself does not satisfy it.
+binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a
+binarySearch div2 = \lo hi f -> case (f lo, f hi) of
+  (False, _) -> lo
+  (_, True)  -> hi
+  (_, _)     -> go lo hi f
+  where
+    go lo hi f  -- invariant: f lo && not (f hi)
+      -- No further progress is possible. Always answer with lo: by the
+      -- invariant it satisfies the predicate, whereas hi does not. (mid can
+      -- coincide with hi when div2 rounds up or the addition rounds.)
+      | mid == lo || mid == hi = lo
+      | f mid = go mid hi f
+      | otherwise = go lo mid f
+      where
+        mid = lo + div2 (hi - lo)
+
+-- | Sum an orthotope array of rank @n + 1@ along its first (outermost)
+-- dimension, giving an array of rank @n@.
+--
+-- The array is first transposed so that dimension 0 becomes the last one;
+-- each rank-1 subarray along that last dimension is then replaced by the
+-- scalar sum of its elements.
+orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a
+orSumOuter1 (sn@SNat :: SNat n) arr =
+  OR.rerank @n @1 @0 (OR.scalar . OR.sumA) (OR.transpose outerLast arr)
+  where
+    -- permutation that moves dimension 0 behind dimensions 1 .. n
+    outerLast = [1 .. fromSNat' sn] ++ [0]
+
+-- | Containers whose contents can be compared for approximate equality
+-- inside a Hedgehog test.
+class AlmostEq f where
+  -- | Extra constraint that the element type has to satisfy for this
+  -- particular container.
+  type AlmostEqConstr f :: Type -> Constraint
+
+  -- | Assert that two containers are equal up to an absolute tolerance.
+  almostEq
+    :: (AlmostEqConstr f a, Ord a, Show a, Fractional a, MonadTest m)
+    => a    -- ^ absolute tolerance
+    -> f a  -- ^ left-hand side
+    -> f a  -- ^ right-hand side
+    -> m ()
+
+-- | Approximate equality for orthotope ranked arrays.
+--
+-- The check succeeds when both arrays have the same shape and every pair of
+-- corresponding elements differs by strictly less than the tolerance.
+-- Otherwise the test fails with a diff of the two arrays.
+instance AlmostEq (OR.Array n) where
+  type AlmostEqConstr (OR.Array n) = OR.Unbox
+  almostEq atol lhs rhs
+    | sameShape && withinTolerance = success
+    | otherwise = failDiff lhs rhs
+    where
+      -- Checked first ('&&' short-circuits) so that arrays of different
+      -- shapes are reported as a diff and never reach 'OR.zipWithA'.
+      sameShape = OR.shapeL lhs == OR.shapeL rhs
+      withinTolerance =
+        OR.allA (< atol) (OR.zipWithA (\a b -> abs (a - b)) rhs lhs)