aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.stylish-haskell.yaml452
-rw-r--r--CHANGELOG.md7
-rw-r--r--README.md191
-rw-r--r--bench/Main.hs8
-rwxr-xr-xgentrace.sh2
-rw-r--r--ox-arrays.cabal78
-rw-r--r--release-hints.txt1
-rw-r--r--src/Data/Array/Nested.hs47
-rw-r--r--src/Data/Array/Nested/Convert.hs333
-rw-r--r--src/Data/Array/Nested/Internal/Convert.hs86
-rw-r--r--src/Data/Array/Nested/Internal/Lemmas.hs59
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs559
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs495
-rw-r--r--src/Data/Array/Nested/Lemmas.hs (renamed from src/Data/Array/Mixed/Lemmas.hs)125
-rw-r--r--src/Data/Array/Nested/Mixed.hs (renamed from src/Data/Array/Nested/Internal/Mixed.hs)185
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs (renamed from src/Data/Array/Mixed/Shape.hs)119
-rw-r--r--src/Data/Array/Nested/Permutation.hs (renamed from src/Data/Array/Mixed/Permutation.hs)28
-rw-r--r--src/Data/Array/Nested/Ranked.hs323
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs268
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs369
-rw-r--r--src/Data/Array/Nested/Shaped.hs272
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs255
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs (renamed from src/Data/Array/Nested/Internal/Shape.hs)402
-rw-r--r--src/Data/Array/Nested/Trace.hs8
-rw-r--r--src/Data/Array/Nested/Types.hs (renamed from src/Data/Array/Mixed/Types.hs)24
-rw-r--r--src/Data/Array/Strided/Orthotope.hs (renamed from src/Data/Array/Mixed/Internal/Arith.hs)4
-rw-r--r--src/Data/Array/XArray.hs (renamed from src/Data/Array/Mixed/XArray.hs)28
-rw-r--r--test/Gen.hs24
-rw-r--r--test/Tests/C.hs15
-rw-r--r--test/Tests/Permutation.hs2
-rw-r--r--test/Util.hs2
31 files changed, 2878 insertions, 1893 deletions
diff --git a/.stylish-haskell.yaml b/.stylish-haskell.yaml
new file mode 100644
index 0000000..bfd25ea
--- /dev/null
+++ b/.stylish-haskell.yaml
@@ -0,0 +1,452 @@
+# stylish-haskell configuration file
+# ==================================
+
+# The stylish-haskell tool is mainly configured by specifying steps. These steps
+# are a list, so they have an order, and one specific step may appear more than
+# once (if needed). Each file is processed by these steps in the given order.
+steps:
+ # Convert some ASCII sequences to their Unicode equivalents. This is disabled
+ # by default.
+ # - unicode_syntax:
+ # # In order to make this work, we also need to insert the UnicodeSyntax
+ # # language pragma. If this flag is set to true, we insert it when it's
+ # # not already present. You may want to disable it if you configure
+ # # language extensions using some other method than pragmas. Default:
+ # # true.
+ # add_language_pragma: true
+
+ # Format module header
+ #
+ # Currently, this option is not configurable and will format all exports and
+ # module declarations to minimize diffs
+ #
+ # - module_header:
+ # # How many spaces use for indentation in the module header.
+ # indent: 4
+ #
+ # # Should export lists be sorted? Sorting is only performed within the
+ # # export section, as delineated by Haddock comments.
+ # sort: true
+ #
+ # # See `separate_lists` for the `imports` step.
+ # separate_lists: true
+
+ # Format record definitions. This is disabled by default.
+ #
+ # You can control the layout of record fields. The only rules that can't be configured
+ # are these:
+ #
+ # - "|" is always aligned with "="
+ # - "," in fields is always aligned with "{"
+ # - "}" is likewise always aligned with "{"
+ #
+ # - records:
+ # # How to format equals sign between type constructor and data constructor.
+ # # Possible values:
+ # # - "same_line" -- leave "=" AND data constructor on the same line as the type constructor.
+ # # - "indent N" -- insert a new line and N spaces from the beginning of the next line.
+ # equals: "indent 2"
+ #
+ # # How to format first field of each record constructor.
+ # # Possible values:
+ # # - "same_line" -- "{" and first field goes on the same line as the data constructor.
+ # # - "indent N" -- insert a new line and N spaces from the beginning of the data constructor
+ # first_field: "indent 2"
+ #
+ # # How many spaces to insert between the column with "," and the beginning of the comment in the next line.
+ # field_comment: 2
+ #
+ # # How many spaces to insert before "deriving" clause. Deriving clauses are always on separate lines.
+ # deriving: 2
+ #
+ # # How many spaces to insert before "via" clause counted from indentation of deriving clause
+ # # Possible values:
+ # # - "same_line" -- "via" part goes on the same line as "deriving" keyword.
+ # # - "indent N" -- insert a new line and N spaces from the beginning of "deriving" keyword.
+ # via: "indent 2"
+ #
+ # # Sort typeclass names in the "deriving" list alphabetically.
+ # sort_deriving: true
+ #
+ # # Wheter or not to break enums onto several lines
+ # #
+ # # Default: false
+ # break_enums: false
+ #
+ # # Whether or not to break single constructor data types before `=` sign
+ # #
+ # # Default: true
+ # break_single_constructors: true
+ #
+ # # Whether or not to curry constraints on function.
+ # #
+ # # E.g: @allValues :: Enum a => Bounded a => Proxy a -> [a]@
+ # #
+ # # Instead of @allValues :: (Enum a, Bounded a) => Proxy a -> [a]@
+ # #
+ # # Default: false
+ # curried_context: false
+
+ # Align the right hand side of some elements. This is quite conservative
+ # and only applies to statements where each element occupies a single
+ # line.
+ # Possible values:
+ # - always - Always align statements.
+ # - adjacent - Align statements that are on adjacent lines in groups.
+ # - never - Never align statements.
+ # All default to always.
+ - simple_align:
+ cases: never
+ top_level_patterns: never
+ records: never
+ multi_way_if: never
+
+ # Import cleanup
+ - imports:
+ # There are different ways we can align names and lists.
+ #
+ # - global: Align the import names and import list throughout the entire
+ # file.
+ #
+ # - file: Like global, but don't add padding when there are no qualified
+ # imports in the file.
+ #
+ # - group: Only align the imports per group (a group is formed by adjacent
+ # import lines).
+ #
+ # - none: Do not perform any alignment.
+ #
+ # Default: global.
+ align: group
+
+ # The following options affect only import list alignment.
+ #
+ # List align has following options:
+ #
+ # - after_alias: Import list is aligned with end of import including
+ # 'as' and 'hiding' keywords.
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr, head,
+ # > init, last, length)
+ #
+ # - with_alias: Import list is aligned with start of alias or hiding.
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr, head,
+ # > init, last, length)
+ #
+ # - with_module_name: Import list is aligned `list_padding` spaces after
+ # the module name.
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr, head,
+ # init, last, length)
+ #
+ # This is mainly intended for use with `pad_module_names: false`.
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr, head,
+ # init, last, length, scanl, scanr, take, drop,
+ # sort, nub)
+ #
+ # - new_line: Import list starts always on new line.
+ #
+ # > import qualified Data.List as List
+ # > (concat, foldl, foldr, head, init, last, length)
+ #
+ # - repeat: Repeat the module name to align the import list.
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr, head)
+ # > import qualified Data.List as List (init, last, length)
+ #
+ # Default: after_alias
+ list_align: after_alias
+
+ # Right-pad the module names to align imports in a group:
+ #
+ # - true: a little more readable
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr,
+ # > init, last, length)
+ # > import qualified Data.List.Extra as List (concat, foldl, foldr,
+ # > init, last, length)
+ #
+ # - false: diff-safe
+ #
+ # > import qualified Data.List as List (concat, foldl, foldr, init,
+ # > last, length)
+ # > import qualified Data.List.Extra as List (concat, foldl, foldr,
+ # > init, last, length)
+ #
+ # Default: true
+ pad_module_names: false
+
+ # Long list align style takes effect when import is too long. This is
+ # determined by 'columns' setting.
+ #
+ # - inline: This option will put as much specs on same line as possible.
+ #
+ # - new_line: Import list will start on new line.
+ #
+ # - new_line_multiline: Import list will start on new line when it's
+ # short enough to fit to single line. Otherwise it'll be multiline.
+ #
+ # - multiline: One line per import list entry.
+ # Type with constructor list acts like single import.
+ #
+ # > import qualified Data.Map as M
+ # > ( empty
+ # > , singleton
+ # > , ...
+ # > , delete
+ # > )
+ #
+ # Default: inline
+ long_list_align: new_line_multiline
+
+ # Align empty list (importing instances)
+ #
+ # Empty list align has following options
+ #
+ # - inherit: inherit list_align setting
+ #
+ # - right_after: () is right after the module name:
+ #
+ # > import Vector.Instances ()
+ #
+ # Default: inherit
+ empty_list_align: inherit
+
+ # List padding determines indentation of import list on lines after import.
+ # This option affects 'long_list_align'.
+ #
+ # - <integer>: constant value
+ #
+ # - module_name: align under start of module name.
+ # Useful for 'file' and 'group' align settings.
+ #
+ # Default: 4
+ list_padding: 2
+
+ # Separate lists option affects formatting of import list for type
+ # or class. The only difference is single space between type and list
+ # of constructors, selectors and class functions.
+ #
+ # - true: There is single space between Foldable type and list of it's
+ # functions.
+ #
+ # > import Data.Foldable (Foldable (fold, foldl, foldMap))
+ #
+ # - false: There is no space between Foldable type and list of it's
+ # functions.
+ #
+ # > import Data.Foldable (Foldable(fold, foldl, foldMap))
+ #
+ # Default: true
+ separate_lists: false
+
+ # Space surround option affects formatting of import lists on a single
+ # line. The only difference is single space after the initial
+ # parenthesis and a single space before the terminal parenthesis.
+ #
+ # - true: There is single space associated with the enclosing
+ # parenthesis.
+ #
+ # > import Data.Foo ( foo )
+ #
+ # - false: There is no space associated with the enclosing parenthesis
+ #
+ # > import Data.Foo (foo)
+ #
+ # Default: false
+ space_surround: false
+
+ # Enabling this argument will use the new GHC lib parse to format imports.
+ #
+ # This currently assumes a few things, it will assume that you want post
+ # qualified imports. It is also not as feature complete as the old
+ # imports formatting.
+ #
+ # It does not remove redundant lines or merge lines. As such, the full
+ # feature scope is still pending.
+ #
+ # It _is_ however, a fine alternative if you are using features that are
+ # not parseable by haskell src extensions and you're comfortable with the
+ # presets.
+ #
+ # Default: false
+ ghc_lib_parser: false
+
+ # Post qualify option moves any qualifies found in import declarations
+ # to the end of the declaration. This also adjust padding for any
+ # unqualified import declarations.
+ #
+ # - true: Qualified as <module name> is moved to the end of the
+ # declaration.
+ #
+ # > import Data.Bar
+ # > import Data.Foo qualified as F
+ #
+ # - false: Qualified remains in the default location and unqualified
+ # imports are padded to align with qualified imports.
+ #
+ # > import Data.Bar
+ # > import qualified Data.Foo as F
+ #
+ # Default: false
+ post_qualify: true
+
+ # A list of rules specifying how to group modules and how to
+ # order the groups.
+ #
+ # Each rule has a match field; the rule only applies to module
+ # names matched by this pattern. Patterns are POSIX extended
+ # regular expressions; see the documentation of Text.Regex.TDFA
+ # for details:
+ # https://hackage.haskell.org/package/regex-tdfa-1.3.1.2/docs/Text-Regex-TDFA.html
+ #
+ # Rules are processed in order, so only the *first* rule that
+ # matches a specific module will apply. Any module names that do
+ # not match a single rule will be put into a single group at the
+ # end of the import block.
+ #
+ # Example: group MyApp modules first, with everything else in
+ # one group at the end.
+ #
+ # group_rules:
+ # - match: "^MyApp\\>"
+ #
+ # > import MyApp
+ # > import MyApp.Foo
+ # >
+ # > import Control.Monad
+ # > import MyApps
+ # > import Test.MyApp
+ #
+ # A rule can also optionally have a sub_group pattern. Imports
+ # that match the rule will be broken up into further groups by
+ # the part of the module name matched by the sub_group pattern.
+ #
+ # Example: group MyApp modules first, then everything else
+ # sub-grouped by the first part of the module name.
+ #
+ # group_rules:
+ # - match: "^MyApp\\>"
+ # - match: "."
+ # sub_group: "^[^.]+"
+ #
+ # > import MyApp
+ # > import MyApp.Foo
+ # >
+ # > import Control.Applicative
+ # > import Control.Monad
+ # >
+ # > import Data.Map
+ #
+ # A pattern only needs to match part of the module name, which
+ # could be in the middle. You can use ^pattern to anchor to the
+ # beginning of the module name, pattern$ to anchor to the end
+ # and ^pattern$ to force a full match. Example:
+ #
+ # - "Test\\." would match "Test.Foo" and "Foo.Test.Lib"
+ # - "^Test\\." would match "Test.Foo" but not "Foo.Test.Lib"
+ # - "\\.Test$" would match "Foo.Test" but not "Foo.Test.Lib"
+ # - "^Test$" would *only* match "Test"
+ #
+ # You can use \\< and \\> to anchor against the beginning and
+ # end of words, respectively. For example:
+ #
+ # - "^Test\\." would match "Test.Foo" but not "Test" or "Tests"
+ # - "^Test\\>" would match "Test.Foo" and "Test", but not
+ # "Tests"
+ #
+ # The default is a single rule that matches everything and
+ # sub-groups based on the first component of the module name.
+ #
+ # Default: [{ "match" : ".*", "sub_group": "^[^.]+" }]
+# group_rules:
+# - match: ".*"
+# sub_group: "^[^.]+"
+# - match: "^Data.Array\\>"
+# sub_group: "^[^.]+"
+# - match: "^HordeAd\\>"
+
+ # Language pragmas
+ - language_pragmas:
+ # We can generate different styles of language pragma lists.
+ #
+ # - vertical: Vertical-spaced language pragmas, one per line.
+ #
+ # - compact: A more compact style.
+ #
+ # - compact_line: Similar to compact, but wrap each line with
+ # `{-#LANGUAGE #-}'.
+ #
+ # Default: vertical.
+# style: compact
+
+ # Align affects alignment of closing pragma brackets.
+ #
+ # - true: Brackets are aligned in same column.
+ #
+ # - false: Brackets are not aligned together. There is only one space
+ # between actual import and closing bracket.
+ #
+ # Default: true
+ align: false
+
+ # stylish-haskell can detect redundancy of some language pragmas. If this
+ # is set to true, it will remove those redundant pragmas. Default: true.
+ remove_redundant: true
+
+ # Language prefix to be used for pragma declaration, this allows you to
+ # use other options non case-sensitive like "language" or "Language".
+ # If a non correct String is provided, it will default to: LANGUAGE.
+ language_prefix: LANGUAGE
+
+ # Replace tabs by spaces. This is disabled by default.
+ # - tabs:
+ # # Number of spaces to use for each tab. Default: 8, as specified by the
+ # # Haskell report.
+ # spaces: 8
+
+ # Remove trailing whitespace
+ - trailing_whitespace: {}
+
+ # Squash multiple spaces between the left and right hand sides of some
+ # elements into single spaces. Basically, this undoes the effect of
+ # simple_align but is a bit less conservative.
+ # - squash: {}
+
+# A common setting is the number of columns (parts of) code will be wrapped
+# to. Different steps take this into account.
+#
+# Set this to null to disable all line wrapping.
+#
+# Default: 80.
+columns: 200
+
+# By default, line endings are converted according to the OS. You can override
+# preferred format here.
+#
+# - native: Native newline format. CRLF on Windows, LF on other OSes.
+#
+# - lf: Convert to LF ("\n").
+#
+# - crlf: Convert to CRLF ("\r\n").
+#
+# Default: native.
+newline: native
+
+# Sometimes, language extensions are specified in a cabal file or from the
+# command line instead of using language pragmas in the file. stylish-haskell
+# needs to be aware of these, so it can parse the file correctly.
+#
+# No language extensions are enabled by default.
+#language_extensions:
+# - NoStarIsType
+ # - TemplateHaskell
+ # - QuasiQuotes
+
+# Attempt to find the cabal file in ancestors of the current directory, and
+# parse options (currently only language extensions) from that.
+#
+# Default: true
+cabal: true
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..009d267
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,7 @@
+# Changelog for `ox-arrays`
+
+This package intends to follow the [PVP](https://pvp.haskell.org/).
+
+## 0.1.0.0
+- Initial release
+- Various aspects of the API are still experimental, and breaking changes are expected in the future.
diff --git a/README.md b/README.md
index 9b8d543..01bcbac 100644
--- a/README.md
+++ b/README.md
@@ -1,56 +1,165 @@
-Wrapper library around `orthotope` that defines nested arrays, including
-tuples, of (eventually) unboxed values. The arrays are represented in
-struct-of-arrays form via the `Data.Vector.Unboxed` data family trick. Below
-the surface layer, there is a more low-level wrapper around `orthotope` that
-defines an array type type-indexed by `[Maybe Nat]`: some dimensions are
-shape-typed (i.e. have their size statically known), and some not.
+## ox-arrays
-An overview of the API:
+ox-arrays is an array library that defines nested arrays, including tuples, of
+(eventually) unboxed values. The arrays are represented in struct-of-arrays
+form via the `Data.Vector.Unboxed` data family trick; the component arrays are
+`orthotope` arrays
+([RankedS](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-RankedS.html))
+which describe elements using a _stride vector_ or
+[LMAD](https://dl.acm.org/doi/pdf/10.1145/509705.509708) so that `transpose`
+and `replicate` need only modify array metadata, not actually move around data.
+
+Because of the struct-of-arrays representation, nested arrays are not fully
+general: indeed, arrays are not actually nested under the hood, so if one has an
+array of arrays, those element arrays must all have the same shape (length,
+width, etc.). If one has an array of tuples of arrays, then all the `fst`
+components must have the same shape and all the `snd` components must have the
+same shape, but the two pair components themselves can be different.
+
+However, the nesting functionality of ox-arrays can be completely ignored if you
+only care about other parts of its API, or the vectorised arithmetic operations
+(using hand-written C code). Nesting support mostly does not get in the way, and
+has essentially no overhead (both when it's used and when it's not used).
+
+ox-arrays defines three array types: `Ranked`, `Shaped` and `Mixed`.
+- `Ranked` corresponds to `orthotope`'s
+ [RankedS](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-RankedS.html)
+ and has the _rank_ of the array (its number of dimensions) on the type level.
+ For example, `Ranked 2 Float` is a two-dimensional array of `Float`s, i.e. a
+ matrix.
+- `Shaped` corresponds to `orthotope`'s
+ [ShapedS](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-ShapedS.html).
+ and has the full _shape_ of the array (its dimension sizes) on the type level
+ as a type-level list of `Nat`s. For example, `Shaped [2,3] Float` is a 2-by-3
+ matrix. The innermost dimension correspond to the right-most element in the
+ list.
+- `Mixed` is halfway between the two: it has a type parameter of kind
+ `[Maybe Nat]` whose length is the rank of the array; `Nothing` elements have
+ unknown size, whereas `Just` elements have the indicated size. The type
+ `Mixed [Nothing, Nothing] a` is equivalent to `Ranked 2 a`; the type
+ `Mixed [Just n, Just m] a` is equivalent to `Shaped [n, m] a`.
+
+In various places in the API of a library like ox-arrays, one can make a
+decision between 1. requiring a type class constraint providing certain
+information (e.g.
+[KnownNat](https://hackage.haskell.org/package/base-4.21.0.0/docs/GHC-TypeLits.html#t:KnownNat)
+or `orthotope`'s
+[Shape](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-ShapedS.html#t:Shape)),
+or 2. taking singleton _values_ that encode said information in a way that is
+linked to the type level (e.g.
+[SNat](https://hackage.haskell.org/package/base-4.21.0.0/docs/GHC-TypeLits.html#t:SNat)).
+`orthotope` chooses the type class approach; ox-arrays chooses the singleton
+approach. Singletons are more verbose at times, but give the programmer more
+insight in what data is flowing where, and more importantly, more control: type
+class inference is very nice and implicit, but if it's not powerful enough for
+the trickery you're doing, you're out of luck. Singletons allow you to explain
+as precisely as you want to GHC what exactly you're doing.
+
+Below the surface layer, there is a more low-level wrapper (`XArray`) around
+`orthotope` that defines a non-nested `Mixed`-style array type.
+
+Here is a little taster of the API, to get a sense for the design:
```haskell
-data Ranked (n :: INat) a {- e.g. -} Ranked 3 Float
-data Shaped (sh :: '[Nat]) a {- e.g. -} Shaped [2,3,4] Float
-data Mixed (xsh :: '[Maybe Nat]) a {- e.g. -} Mixed [Just 2, Nothing, Just 4] Float
-
-Ranked I0 a = Ranked Z a ~~= Acc.Array Z a = Acc.Scalar a
-Ranked I1 a = Ranked (S Z) a ~~= Acc.Array (Z :. Int) a = Acc.Vector a
-Ranked I2 a = Ranked (S (S Z)) a ~~= Acc.Array (Z :. Int :. Int) a = Acc.Matrix a
+import GHC.TypeLits (Nat, SNat)
+
+data Ranked (n :: Nat) a {- e.g. -} Ranked 3 Float
+data Shaped (sh :: '[Nat]) a {- e.g. -} Shaped [2,3,4] Float
+data Mixed (xsh :: '[Maybe Nat]) a {- e.g. -} Mixed [Just 2, Nothing, Just 4] Float
+-- Shape types are written Sh{R,S,X}. The 'I' prefix denotes a Int-filled shape;
+-- ShR and ShX are more general containers. ShS is a singleton.
+rshape :: Elt a => Ranked n a -> IShR n
+sshape :: Elt a => Shaped sh a -> ShS sh
+mshape :: Elt a => Mixed xsh a -> IShX xsh
-rshape :: (Elt a, KnownINat n) => Ranked n a -> IxR n
-sshape :: (Elt a, KnownShape sh) => Shaped sh a -> IxS sh
-mshape :: (Elt a, KnownShapeX xsh) => Mixed xsh a -> IxX xsh
+-- Index types are written Ix{R,S,X}.
+rindex :: Elt a => Ranked n a -> IIxR n -> a
+sindex :: Elt a => Shaped sh a -> IIxS sh -> a
+mindex :: Elt a => Mixed xsh a -> IIxX xsh -> a
-rindex :: Elt a => Ranked n a -> IxR n -> a
-sindex :: Elt a => Shaped sh a -> IxS sh -> a
-mindex :: Elt a => Mixed xsh a -> IxX xsh -> a
+-- The index types can be used as if they were defined as follows; pattern
+-- synonyms are provided to construct the illusion. (The actual definitions are
+-- a bit more general and indirect.)
+data IIxR n where
+ ZIR :: IIxR 0
+ (:.:) :: Int -> IIxR n -> IIxR (n + 1)
-data IxR n where
- IZR :: IxR Z
- (:::) :: Int -> IxR n -> IxR (S n)
+data IIxS sh where
+ ZIS :: IIxS '[]
+ (:.$) :: Int -> IIxS sh -> IIxS (n : sh)
-data IxS sh where
- IZS :: IxS '[]
- (::$) :: Int -> IxS sh -> IxS (n : sh)
+data IIxX xsh where
+ ZIX :: IIxX '[]
+ (:.%) :: Int -> IIxX xsh -> IIxX (mn : xsh)
-data IxX sh where
- IZX :: IxX '[]
- (::@) :: Int -> IxX sh -> IxX (Just n : sh)
- (::?) :: Int -> IxX sh -> IxX (Nothing : sh)
+-- Similarly, the shape types can be used as if they were defined as follows.
+data IShR n where
+ ZSR :: IShR 0
+ (:$:) :: Int -> IShR n -> IShR (n + 1)
+data ShS sh where
+ ZSS :: ShS '[]
+ (:$$) :: SNat n -> ShS sh -> ShS (n : sh)
+
+data IShX xsh where
+ ZSX :: IShX '[]
+ (:$%) :: SMayNat Int SNat mn -> IShX xsh -> IShX (mn : xsh)
+-- where:
+data SMayNat i f n where
+ SUnknown :: i -> SMayNat i f Nothing
+ SKnown :: f n -> SMayNat i f (Just n)
+
+-- Occasionally one needs a singleton for only the _known_ dimensions of a mixed
+-- shape -- that is to say, only the statically-known part of a mixed shape.
+-- StaticShX provides for this need. It can be used as if defined as follows:
+data StaticShX xsh where
+ ZKX :: StaticShX '[]
+ (:!%) :: SMayNat () SNat mn -> StaticShX xsh -> StaticShX (mn : xsh)
+
+-- The Elt class describes types that can be used as elements of an array. While
+-- it is technically possible to define new instances of this class, typical
+-- usage should regard Elt as closed. The user-relevant instances are the
+-- following:
class Elt a
-instance Elt ()
-instance Elt Double
-instance Elt Int
-instance (Elt a, Elt b) => Elt (a, b)
-instance (Elt a, KnownINat n) => Elt (Ranked n a)
-instance (Elt a, KnownShape sh) => Elt (Shaped sh a)
-instance (Elt a, KnownShapeX xsh) => Elt (Mixed xsh a)
-
-rgenerate :: Elt a => IxR n -> (IxR n -> a) -> Ranked n a
-sgenerate :: (Elt a, KnownShape sh) => (IxS sh -> a) -> Shaped sh a
-mgenerate :: (Elt a, KnownShapeX xsh) => IxX xsh -> (IxX xsh -> a) -> Mixed xsh a
+instance Elt ()
+instance Elt Bool
+instance Elt Float
+instance Elt Double
+instance Elt Int
+instance (Elt a, Elt b) => Elt (a, b)
+instance Elt a => Elt (Ranked n a)
+instance Elt a => Elt (Shaped sh a)
+instance Elt a => Elt (Mixed xsh a)
+-- Essentially all functions that ox-arrays offers on arrays are first-order:
+-- add two arrays elementwise, transpose an array, append arrays, compute
+-- minima/maxima, zip/unzip, nest/unnest, etc. The first-order approach allows
+-- operations, especially arithmetic ones, to be vectorised using hand-written
+-- C code, without needing any sort of JIT compilation.
+rappend :: Elt a => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
+sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
+mappend :: Elt a => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
+
+-- Exceptionally, also one higher-order function is provided per array type:
+-- 'generate'. These functions have the caveat that regularity of arrays must be
+-- preserved: all returned 'a's must have equal shape. See the documentation of
+-- 'mgenerate'.
+-- Warning: because the invocations of the function you pass cannot be
+-- vectorised, 'generate' is rather slow if 'a' is small.
+-- The 'KnownElt' class captures an API infelicity where constraint-based shape
+-- passing is the only practical option.
+rgenerate :: KnownElt a => IShR n -> (IxR n -> a) -> Ranked n a
+sgenerate :: KnownElt a => ShS sh -> (IxS sh -> a) -> Shaped sh a
+mgenerate :: KnownElt a => IShX xsh -> (IxX xsh -> a) -> Mixed xsh a
+
+-- Under the hood, Ranked and Shaped are both newtypes over Mixed. Mixed itself
+-- is a data family over XArray, which is a newtype over orthotope's RankedS.
newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
```
+
+About the name: when importing `orthotope` array modules, a possible naming
+convention is to use qualified imports as `OR` for "orthotope ranked" arrays and
+`OS` for "orthotope shaped" arrays. ox-arrays was started to fill the `OX` gap,
+then grew out of proportion.
diff --git a/bench/Main.hs b/bench/Main.hs
index 5901d8b..b604eb9 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -15,11 +15,11 @@ import Numeric.LinearAlgebra qualified as LA
import Test.Tasty.Bench
import Text.Show (showListWith)
-import Data.Array.Mixed.XArray (XArray(..))
import Data.Array.Nested
-import Data.Array.Nested.Internal.Mixed (Mixed (M_Primitive), mliftPrim, mliftPrim2, toPrimitive)
-import Data.Array.Nested.Internal.Ranked (liftRanked1, liftRanked2)
+import Data.Array.Nested.Mixed (Mixed(M_Primitive), mliftPrim, mliftPrim2, toPrimitive)
+import Data.Array.Nested.Ranked (liftRanked1, liftRanked2)
import Data.Array.Strided.Arith.Internal qualified as Arith
+import Data.Array.XArray (XArray(..))
enableMisc :: Bool
@@ -51,7 +51,7 @@ main_tests = defaultMain
" str " ++ showSh (stridesOf inp1) ++ " " ++ showSh (stridesOf inp2)) $
nf (\(a,b) -> rsumAllPrim (rdot1Inner a b)) (inp1, inp2)
- iota n = riota @Double n
+ iota = riota @Double
in
[dotprodBench "dot 1D"
(iota 10_000_000
diff --git a/gentrace.sh b/gentrace.sh
index 7be2b9c..c3f1240 100755
--- a/gentrace.sh
+++ b/gentrace.sh
@@ -8,7 +8,7 @@ module Data.Array.Nested.Trace (
-- * Re-exports from the plain "Data.Array.Nested" module
EOF
-sed -n '/^module/,/^) where/!d; /^\s*-- /d; s/ \b[a-z][a-zA-Z0-9_'"'"']*,//g; /^ $/d; s/(\.\., Z.., ([^)]*))/(..)/g; /^ /p; /^$/p' src/Data/Array/Nested.hs
+sed -n '/^module/,/^) where/!d; /^\s*--\( \|$\)/d; s/ \b[a-z][a-zA-Z0-9_'"'"']*,//g; /^ $/d; s/(\.\., Z.., ([^)]*))/(..)/g; /^ /p; /^$/p' src/Data/Array/Nested.hs
cat <<'EOF'
) where
diff --git a/ox-arrays.cabal b/ox-arrays.cabal
index c46e216..be4bb03 100644
--- a/ox-arrays.cabal
+++ b/ox-arrays.cabal
@@ -2,13 +2,22 @@ cabal-version: 3.0
name: ox-arrays
version: 0.1.0.0
synopsis: An efficient CPU-based multidimensional array (tensor) library
-description: An efficient and richly typed CPU-based multidimensional array (tensor) library built upon the optimized tensor representation (strides list) implemented in the orthotope package.
-author: Tom Smeding
-maintainer: Tom Smeding
+description:
+ An efficient and richly typed CPU-based multidimensional array (tensor)
+ library built upon the optimized tensor representation (strides list)
+ implemented in the orthotope package. See the README.
+
+ If you use this package: let me know (e.g. via email) if you find it useful!
+ Both positive feedback (keep this!) and negative feedback (I needed this but
+ ox-arrays doesn't provide it) is welcome.
+copyright: (c) 2025 Tom Smeding, Mikolaj Konarski
+author: Tom Smeding, Mikolaj Konarski
+maintainer: Tom Smeding <xhackage@tomsmeding.com>
license: BSD-3-Clause
category: Array, Tensors
build-type: Simple
+extra-doc-files: README.md CHANGELOG.md
extra-source-files: cbits/arith_lists.h
flag trace-wrappers
@@ -16,7 +25,7 @@ flag trace-wrappers
Compile modules that define wrappers around the array methods that trace
their arguments and results. This is conditional on a flag because these
modules make documentation generation fail.
- (https://gitlab.haskell.org/ghc/ghc/-/issues/24964 , should be fixed in
+ (@https://gitlab.haskell.org/ghc/ghc/-/issues/24964@ , should be fixed in
GHC 9.12)
default: False
manual: True
@@ -38,29 +47,50 @@ flag pedantic-c-warnings
default: False
manual: True
+flag default-show-instances
+ description:
+ Use default GHC-derived Show instances for arrays, shapes and indices. This
+ exposes the internal struct-of-arrays representation and is less readable,
+ but can be useful for ox-arrays debugging.
+ default: False
+ manual: True
+
+common basics
+ default-language: Haskell2010
+ ghc-options: -Wall -Wcompat -Widentities -Wunused-packages
+
library
+ import: basics
exposed-modules:
-- put this module on top so ghci considers it the "main" module
Data.Array.Nested
- Data.Array.Mixed.Internal.Arith
- Data.Array.Mixed.Lemmas
- Data.Array.Mixed.Permutation
- Data.Array.Mixed.Shape
- Data.Array.Mixed.Types
- Data.Array.Mixed.XArray
- Data.Array.Nested.Internal.Convert
- Data.Array.Nested.Internal.Mixed
- Data.Array.Nested.Internal.Lemmas
- Data.Array.Nested.Internal.Ranked
- Data.Array.Nested.Internal.Shape
- Data.Array.Nested.Internal.Shaped
+ Data.Array.Nested.Convert
+ Data.Array.Nested.Mixed
+ Data.Array.Nested.Mixed.Shape
+ Data.Array.Nested.Lemmas
+ Data.Array.Nested.Permutation
+ Data.Array.Nested.Ranked
+ Data.Array.Nested.Ranked.Base
+ Data.Array.Nested.Ranked.Shape
+ Data.Array.Nested.Shaped
+ Data.Array.Nested.Shaped.Base
+ Data.Array.Nested.Shaped.Shape
+ Data.Array.Nested.Types
+ Data.Array.Strided.Orthotope
+ Data.Array.XArray
Data.Bag
if flag(trace-wrappers)
exposed-modules:
Data.Array.Nested.Trace
Data.Array.Nested.Trace.TH
+ build-depends:
+ template-haskell
+ other-extensions: TemplateHaskell
+
+ if flag(default-show-instances)
+ cpp-options: -DOXAR_DEFAULT_SHOW_INSTANCES
build-depends:
strided-array-ops,
@@ -73,11 +103,8 @@ library
vector
hs-source-dirs: src
- default-language: Haskell2010
- ghc-options: -Wall -Wcompat -Widentities -Wunused-packages
- other-extensions: TemplateHaskell
-
library strided-array-ops
+ import: basics
exposed-modules:
Data.Array.Strided
Data.Array.Strided.Array
@@ -104,11 +131,10 @@ library strided-array-ops
-- hmatrix assumes sse2, so we can too
cc-options: -msse2
- default-language: Haskell2010
- ghc-options: -Wall -Wcompat -Widentities -Wunused-packages
other-extensions: TemplateHaskell
test-suite test
+ import: basics
type: exitcode-stdio-1.0
main-is: Main.hs
other-modules:
@@ -129,20 +155,18 @@ test-suite test
tasty-hedgehog,
vector
hs-source-dirs: test
- default-language: Haskell2010
- ghc-options: -Wall -Wcompat -Widentities -Wunused-packages
test-suite example
+ import: basics
type: exitcode-stdio-1.0
main-is: Main.hs
build-depends:
ox-arrays,
base
hs-source-dirs: example
- default-language: Haskell2010
- ghc-options: -Wall -Wcompat -Widentities -Wunused-packages
benchmark bench
+ import: basics
type: exitcode-stdio-1.0
main-is: Main.hs
build-depends:
@@ -154,8 +178,6 @@ benchmark bench
tasty-bench,
vector
hs-source-dirs: bench
- default-language: Haskell2010
- ghc-options: -Wall -Wcompat -Widentities -Wunused-packages
source-repository head
type: git
diff --git a/release-hints.txt b/release-hints.txt
index 259c671..d300da0 100644
--- a/release-hints.txt
+++ b/release-hints.txt
@@ -1,2 +1,3 @@
- Temporarily enable -Wredundant-constraints
- Has too many false-positives to enable normally, but sometimes catches actual redundant constraints
+- Don't forget to rerun gentrace.sh
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 9869cba..c3635e9 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -10,8 +10,9 @@ module Data.Array.Nested (
rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar,
remptyArray,
rrerank,
- rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
- rfromListLinear, rfromListPrimLinear, rtoListLinear,
+ rreplicate, rreplicateScal,
+ rfromList1, rfromListOuter, rfromListLinear, rfromListPrim, rfromListPrimLinear,
+ rtoList, rtoListOuter, rtoListLinear,
rslice, rrev1, rreshape, rflatten, riota,
rminIndexPrim, rmaxIndexPrim, rdot1Inner, rdot,
rnest, runNest, rzip, runzip,
@@ -19,7 +20,7 @@ module Data.Array.Nested (
rlift, rlift2,
-- ** Conversions
rtoXArrayPrim, rfromXArrayPrim,
- rcastToShaped, rtoMixed, rcastToMixed,
+ rtoMixed, rcastToMixed, rcastToShaped,
rfromOrthotope, rtoOrthotope,
-- ** Additional arithmetic operations
--
@@ -36,8 +37,9 @@ module Data.Array.Nested (
-- TODO: sconcat? What should its type be?
semptyArray,
srerank,
- sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,
- sfromListLinear, sfromListPrimLinear, stoListLinear,
+ sreplicate, sreplicateScal,
+ sfromList1, sfromListOuter, sfromListLinear, sfromListPrim, sfromListPrimLinear,
+ stoList, stoListOuter, stoListLinear,
sslice, srev1, sreshape, sflatten, siota,
sminIndexPrim, smaxIndexPrim, sdot1Inner, sdot,
snest, sunNest, szip, sunzip,
@@ -45,7 +47,7 @@ module Data.Array.Nested (
slift, slift2,
-- ** Conversions
stoXArrayPrim, sfromXArrayPrim,
- stoRanked, stoMixed, scastToMixed,
+ stoMixed, scastToMixed, stoRanked,
sfromOrthotope, stoOrthotope,
-- ** Additional arithmetic operations
--
@@ -63,8 +65,9 @@ module Data.Array.Nested (
mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar,
memptyArray,
mrerank,
- mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,
- mfromListLinear, mfromListPrimLinear, mtoListLinear,
+ mreplicate, mreplicateScal,
+ mfromList1, mfromListOuter, mfromListLinear, mfromListPrim, mfromListPrimLinear,
+ mtoList, mtoListOuter, mtoListLinear,
mslice, mrev1, mreshape, mflatten, miota,
mminIndexPrim, mmaxIndexPrim, mdot1Inner, mdot,
mnest, munNest, mzip, munzip,
@@ -73,9 +76,8 @@ module Data.Array.Nested (
-- ** Conversions
mtoXArrayPrim, mfromXArrayPrim,
mcast,
- mcastSafe, SafeMCast, SafeMCastSpec(..),
- mtoRanked, mcastToShaped,
- castCastable, Castable(..),
+ mcastToShaped, mtoRanked,
+ convert, Conversion(..),
-- ** Additional arithmetic operations
--
-- $integralRealFloat
@@ -103,22 +105,23 @@ module Data.Array.Nested (
import Prelude hiding (mappend, mconcat)
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
-import Data.Array.Mixed.Types
-import Data.Array.Nested.Internal.Convert
-import Data.Array.Nested.Internal.Mixed
-import Data.Array.Nested.Internal.Ranked
-import Data.Array.Nested.Internal.Shape
-import Data.Array.Nested.Internal.Shaped
+import Data.Array.Nested.Convert
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Ranked
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Shaped
+import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
import Data.Array.Strided.Arith
import Foreign.Storable
import GHC.TypeLits
-- $integralRealFloat
--
--- These functions separate top-level functions, and not exposed in instances
--- for 'RealFloat' and 'Integral', because those classes include a variety of
--- other functions that make no sense for arrays.
+-- These functions are separate top-level functions, and not exposed in
+-- instances for 'RealFloat' and 'Integral', because those classes include a
+-- variety of other functions that make no sense for arrays.
-- This problem already occurs with 'fromInteger', 'fromRational' and 'pi', but
-- having 'Num', 'Fractional' and 'Floating' available is just too useful.
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
new file mode 100644
index 0000000..2438f68
--- /dev/null
+++ b/src/Data/Array/Nested/Convert.hs
@@ -0,0 +1,333 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeAbstractions #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+module Data.Array.Nested.Convert (
+ -- * Shape\/index\/list casting functions
+ -- ** To ranked
+ ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShX, shrFromShX2,
+ listrCast, ixrCast, shrCast,
+ -- ** To shaped
+ ixsFromIxR, ixsFromIxR', ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX,
+ ixsCast,
+ -- ** To mixed
+ ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS,
+ ixxCast, shxCast, shxCast',
+
+ -- * Array conversions
+ convert,
+ Conversion(..),
+
+ -- * Special cases of array conversions
+ --
+ -- | These functions can all be implemented using 'convert' in some way,
+ -- but some have fewer constraints.
+ rtoMixed, rcastToMixed, rcastToShaped,
+ stoMixed, scastToMixed, stoRanked,
+ mcast, mcastToShaped, mtoRanked,
+) where
+
+import Control.Category
+import Data.Proxy
+import Data.Type.Equality
+import GHC.TypeLits
+
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Ranked.Base
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Shaped.Base
+import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
+
+-- * Shape or index or list casting functions
+
+-- * To ranked
+
+ixrFromIxS :: IxS sh i -> IxR (Rank sh) i
+ixrFromIxS ZIS = ZIR
+ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix
+
+ixrFromIxX :: IxX sh i -> IxR (Rank sh) i
+ixrFromIxX ZIX = ZIR
+ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx
+
+shrFromShS :: ShS sh -> IShR (Rank sh)
+shrFromShS ZSS = ZSR
+shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh
+
+-- shrFromShX re-exported
+-- shrFromShX2 re-exported
+-- listrCast re-exported
+-- ixrCast re-exported
+-- shrCast re-exported
+
+-- * To shaped
+
+-- TODO: these take a ShS because there are KnownNats inside IxS.
+
+ixsFromIxR :: ShS sh -> IxR (Rank sh) i -> IxS sh i
+ixsFromIxR ZSS ZIR = ZIS
+ixsFromIxR (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR sh idx
+ixsFromIxR _ _ = error "unreachable"
+
+-- | Performs a runtime check that @n@ matches @Rank sh@. Equivalent to the
+-- following, but more efficient:
+--
+-- > ixsFromIxR' sh idx = ixsFromIxR sh (ixrCast (shsRank sh) idx)
+ixsFromIxR' :: ShS sh -> IxR n i -> IxS sh i
+ixsFromIxR' ZSS ZIR = ZIS
+ixsFromIxR' (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR' sh idx
+ixsFromIxR' _ _ = error "ixsFromIxR': index rank does not match shape rank"
+
+-- TODO: this takes a ShS because there are KnownNats inside IxS.
+ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i
+ixsFromIxX ZSS ZIX = ZIS
+ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx
+
+-- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to
+-- the following, but more efficient:
+--
+-- > ixsFromIxX' sh idx = ixsFromIxX sh (ixxCast (shxFromShS sh) idx)
+ixsFromIxX' :: ShS sh -> IxX sh' i -> IxS sh i
+ixsFromIxX' ZSS ZIX = ZIS
+ixsFromIxX' (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX' sh idx
+ixsFromIxX' _ _ = error "ixsFromIxX': index rank does not match shape rank"
+
+-- | Produce an existential 'ShS' from an 'IShR'.
+withShsFromShR :: IShR n -> (forall sh. Rank sh ~ n => ShS sh -> r) -> r
+withShsFromShR ZSR k = k ZSS
+withShsFromShR (n :$: sh) k =
+ withShsFromShR sh $ \sh' ->
+ withSomeSNat (fromIntegral @Int @Integer n) $ \case
+ Just sn@SNat -> k (sn :$$ sh')
+ Nothing -> error $ "withShsFromShR: negative dimension size (" ++ show n ++ ")"
+
+-- shsFromShX re-exported
+
+-- | Produce an existential 'ShS' from an 'IShX'. If you already know that
+-- @sh'@ is @MapJust@ of something, use 'shsFromShX' instead.
+withShsFromShX :: IShX sh' -> (forall sh. Rank sh ~ Rank sh' => ShS sh -> r) -> r
+withShsFromShX ZSX k = k ZSS
+withShsFromShX (SKnown sn@SNat :$% sh) k =
+ withShsFromShX sh $ \sh' ->
+ k (sn :$$ sh')
+withShsFromShX (SUnknown n :$% sh) k =
+ withShsFromShX sh $ \sh' ->
+ withSomeSNat (fromIntegral @Int @Integer n) $ \case
+ Just sn@SNat -> k (sn :$$ sh')
+ Nothing -> error $ "withShsFromShX: negative SUnknown dimension size (" ++ show n ++ ")"
+
+shsFromSSX :: StaticShX (MapJust sh) -> ShS sh
+shsFromSSX = shsFromShX Prelude.. shxFromSSX
+
+-- ixsCast re-exported
+
+-- * To mixed
+
+ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i
+ixxFromIxR ZIR = ZIX
+ixxFromIxR (n :.: (idx :: IxR m i)) =
+ castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m))
+ (n :.% ixxFromIxR idx)
+
+ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i
+ixxFromIxS ZIS = ZIX
+ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh
+
+shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i
+shxFromShR ZSR = ZSX
+shxFromShR (n :$: (idx :: ShR m i)) =
+ castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m))
+ (SUnknown n :$% shxFromShR idx)
+
+shxFromShS :: ShS sh -> IShX (MapJust sh)
+shxFromShS ZSS = ZSX
+shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh
+
+-- ixxCast re-exported
+-- shxCast re-exported
+-- shxCast' re-exported
+
+
+-- * Array conversions
+
+-- | The constructors that perform runtime shape checking are marked with a
+-- tick (@'@): 'ConvXS'' and 'ConvXX''. For the other constructors, the types
+-- ensure that the shapes are already compatible. To convert between 'Ranked'
+-- and 'Shaped', go via 'Mixed'.
+--
+-- The guiding principle behind 'Conversion' is that it should represent the
+-- array restructurings, or perhaps re-presentations, that do not change the
+-- underlying 'XArray's. This leads to the inclusion of some operations that do
+-- not look like simple conversions (casts) at first glance, like 'ConvZip'.
+--
+-- /Note/: Haddock gleefully renames type variables in constructors so that
+-- they match the data type head as much as possible. See the source for a more
+-- readable presentation of this data type.
+data Conversion a b where
+ ConvId :: Conversion a a
+ ConvCmp :: Conversion b c -> Conversion a b -> Conversion a c
+
+ ConvRX :: Conversion (Ranked n a) (Mixed (Replicate n Nothing) a)
+ ConvSX :: Conversion (Shaped sh a) (Mixed (MapJust sh) a)
+
+ ConvXR :: Elt a
+ => Conversion (Mixed sh a) (Ranked (Rank sh) a)
+ ConvXS :: Conversion (Mixed (MapJust sh) a) (Shaped sh a)
+ ConvXS' :: (Rank sh ~ Rank sh', Elt a)
+ => ShS sh'
+ -> Conversion (Mixed sh a) (Shaped sh' a)
+
+ ConvXX' :: (Rank sh ~ Rank sh', Elt a)
+ => StaticShX sh'
+ -> Conversion (Mixed sh a) (Mixed sh' a)
+
+ ConvRR :: Conversion a b
+ -> Conversion (Ranked n a) (Ranked n b)
+ ConvSS :: Conversion a b
+ -> Conversion (Shaped sh a) (Shaped sh b)
+ ConvXX :: Conversion a b
+ -> Conversion (Mixed sh a) (Mixed sh b)
+ ConvT2 :: Conversion a a'
+ -> Conversion b b'
+ -> Conversion (a, b) (a', b')
+
+ Conv0X :: Elt a
+ => Conversion a (Mixed '[] a)
+ ConvX0 :: Conversion (Mixed '[] a) a
+
+ ConvNest :: Elt a => StaticShX sh
+ -> Conversion (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a))
+ ConvUnnest :: Conversion (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a)
+
+ ConvZip :: (Elt a, Elt b)
+ => Conversion (Mixed sh a, Mixed sh b) (Mixed sh (a, b))
+ ConvUnzip :: (Elt a, Elt b)
+ => Conversion (Mixed sh (a, b)) (Mixed sh a, Mixed sh b)
+deriving instance Show (Conversion a b)
+
+instance Category Conversion where
+ id = ConvId
+ (.) = ConvCmp
+
+convert :: (Elt a, Elt b) => Conversion a b -> a -> b
+convert = \c x -> munScalar (go c (mscalar x))
+ where
+ -- The 'esh' is the extension shape: the conversion happens under a whole
+ -- bunch of additional dimensions that it does not touch. These dimensions
+ -- are 'esh'.
+ -- The strategy is to unwind step-by-step to a large Mixed array, and to
+ -- perform the required checks and conversions when re-nesting back up.
+ go :: Conversion a b -> Mixed esh a -> Mixed esh b
+ go ConvId x = x
+ go (ConvCmp c1 c2) x = go c1 (go c2 x)
+ go ConvRX (M_Ranked x) = x
+ go ConvSX (M_Shaped x) = x
+ go (ConvXR @_ @sh) (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)
+ = let ssx' = ssxAppend (ssxFromShX esh)
+ (ssxReplicate (shxRank (shxDropSSX @esh @sh (ssxFromShX esh) (mshape x))))
+ in M_Ranked (M_Nest esh (mcast ssx' x))
+ go ConvXS (M_Nest esh x) = M_Shaped (M_Nest esh x)
+ go (ConvXS' @sh @sh' sh') (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh')))
+ x))
+ go (ConvXX' @sh @sh' ssx) (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x
+ go (ConvRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
+ go (ConvSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
+ go (ConvXX c) (M_Nest esh x) = M_Nest esh (go c x)
+ go (ConvT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2)
+ go Conv0X (x :: Mixed esh a)
+ | Refl <- lemAppNil @esh
+ = M_Nest (mshape x) x
+ go ConvX0 (M_Nest @esh _ x)
+ | Refl <- lemAppNil @esh
+ = x
+ go (ConvNest @_ @sh @sh' ssh) (M_Nest @esh esh x)
+ | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (ssxFromShX esh `ssxAppend` ssh) (mshape x)) x)
+ go (ConvUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x))
+ | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh x
+ go ConvZip x =
+ -- no need to check that the two esh's are equal because they were zipped previously
+ let (M_Nest esh x1, M_Nest _ x2) = munzip x
+ in M_Nest esh (mzip x1 x2)
+ go ConvUnzip (M_Nest esh x) =
+ let (x1, x2) = munzip x
+ in mzip (M_Nest esh x1) (M_Nest esh x2)
+
+ lemRankAppRankEq :: Rank sh ~ Rank sh'
+ => Proxy esh -> Proxy sh -> Proxy sh'
+ -> Rank (esh ++ sh) :~: Rank (esh ++ sh')
+ lemRankAppRankEq _ _ _ = unsafeCoerceRefl
+
+ lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh
+ -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing)
+ lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl
+
+ lemRankAppRankEqMapJust :: Rank sh ~ Rank sh'
+ => Proxy esh -> Proxy sh -> Proxy sh'
+ -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
+ lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl
+
+
+-- * Special cases of array conversions
+
+mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)
+ => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a
+mcast ssh2 arr
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr
+
+mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
+mtoRanked = convert ConvXR
+
+rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
+rtoMixed (Ranked arr) = arr
+
+-- | A more weakly-typed version of 'rtoMixed' that does a runtime shape
+-- compatibility check.
+rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a
+rcastToMixed sshx rarr@(Ranked arr)
+ | Refl <- lemRankReplicate (rrank rarr)
+ = mcast sshx arr
+
+mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
+ => ShS sh' -> Mixed sh a -> Shaped sh' a
+mcastToShaped targetsh = convert (ConvXS' targetsh)
+
+stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
+stoMixed (Shaped arr) = arr
+
+-- | A more weakly-typed version of 'stoMixed' that does a runtime shape
+-- compatibility check.
+scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
+ => StaticShX sh' -> Shaped sh a -> Mixed sh' a
+scastToMixed sshx sarr@(Shaped arr)
+ | Refl <- lemRankMapJust (sshape sarr)
+ = mcast sshx arr
+
+stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
+stoRanked sarr@(Shaped arr)
+ | Refl <- lemRankMapJust (sshape sarr)
+ = mtoRanked arr
+
+rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
+rcastToShaped (Ranked arr) targetsh
+ | Refl <- lemRankReplicate (shxRank (shxFromShS targetsh))
+ , Refl <- lemRankMapJust targetsh
+ = mcastToShaped targetsh arr
diff --git a/src/Data/Array/Nested/Internal/Convert.hs b/src/Data/Array/Nested/Internal/Convert.hs
deleted file mode 100644
index c316161..0000000
--- a/src/Data/Array/Nested/Internal/Convert.hs
+++ /dev/null
@@ -1,86 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeAbstractions #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeOperators #-}
-module Data.Array.Nested.Internal.Convert where
-
-import Control.Category
-import Data.Proxy
-import Data.Type.Equality
-
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Shape
-import Data.Array.Mixed.Types
-import Data.Array.Nested.Internal.Lemmas
-import Data.Array.Nested.Internal.Mixed
-import Data.Array.Nested.Internal.Ranked
-import Data.Array.Nested.Internal.Shape
-import Data.Array.Nested.Internal.Shaped
-
-
-stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
-stoRanked sarr@(Shaped arr)
- | Refl <- lemRankMapJust (sshape sarr)
- = mtoRanked arr
-
-rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
-rcastToShaped (Ranked arr) targetsh
- | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh))
- , Refl <- lemRankMapJust targetsh
- = mcastToShaped arr targetsh
-
--- | The only constructor that performs runtime shape checking is 'CastXS''.
--- For the other construtors, the types ensure that the shapes are already
--- compatible. To convert between 'Ranked' and 'Shaped', go via 'Mixed'.
-data Castable a b where
- CastId :: Castable a a
- CastCmp :: Castable b c -> Castable a b -> Castable a c
-
- CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b)
- CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b)
-
- CastXR :: Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b)
- CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)
- CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh'
- -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b)
-
- CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b)
- CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b)
- CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b)
-
-instance Category Castable where
- id = CastId
- (.) = CastCmp
-
-castCastable :: (Elt a, Elt b) => Castable a b -> a -> b
-castCastable = \c x -> munScalar (go c (mscalar x))
- where
- -- The 'esh' is the extension shape: the casting happens under a whole
- -- bunch of additional dimensions that it does not touch. These dimensions
- -- are 'esh'.
- -- The strategy is to unwind step-by-step to a large Mixed array, and to
- -- perform the required checks and castings when re-nesting back up.
- go :: Castable a b -> Mixed esh a -> Mixed esh b
- go CastId x = x
- go (CastCmp c1 c2) x = go c1 (go c2 x)
- go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x)
- go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x)
- go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) =
- M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh sh esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy
- (go c x)))
- go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x))
- go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x)
- | Refl <- lemRankAppMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
- = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh')))
- (go c x)))
- go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
- go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
- go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
-
- lemRankAppMapJust :: Rank sh ~ Rank sh'
- => Proxy esh -> Proxy sh -> Proxy sh'
- -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
- lemRankAppMapJust _ _ _ = unsafeCoerceRefl
diff --git a/src/Data/Array/Nested/Internal/Lemmas.hs b/src/Data/Array/Nested/Internal/Lemmas.hs
deleted file mode 100644
index f894f78..0000000
--- a/src/Data/Array/Nested/Internal/Lemmas.hs
+++ /dev/null
@@ -1,59 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeOperators #-}
-module Data.Array.Nested.Internal.Lemmas where
-
-import Data.Proxy
-import Data.Type.Equality
-import GHC.TypeLits
-
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
-import Data.Array.Mixed.Types
-import Data.Array.Nested.Internal.Shape
-
-
-lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh
-lemRankMapJust ZSS = Refl
-lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl
-
-lemMapJustApp :: ShS sh1 -> Proxy sh2
- -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
-lemMapJustApp ZSS _ = Refl
-lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl
-
-lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)
-lemTakeLenMapJust PNil _ = Refl
-lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl
-lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty"
-
-lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh)
-lemDropLenMapJust PNil _ = Refl
-lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl
-lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty"
-
-lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh)
-lemIndexMapJust SZ (_ :$$ _) = Refl
-lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh'))
- | Refl <- lemIndexMapJust i sh
- , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh'))
- , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = Refl
-lemIndexMapJust _ ZSS = error "Index of empty"
-
-lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh)
-lemPermuteMapJust PNil _ = Refl
-lemPermuteMapJust (i `PCons` is) sh
- | Refl <- lemPermuteMapJust is sh
- , Refl <- lemIndexMapJust i sh
- = Refl
-
-lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh)
-lemKnownMapJust _ = lemKnownShX (go (knownShS @sh))
- where
- go :: ShS sh' -> StaticShX (MapJust sh')
- go ZSS = ZKX
- go (n :$$ sh) = SKnown n :!% go sh
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
deleted file mode 100644
index daf0374..0000000
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ /dev/null
@@ -1,559 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingVia #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE InstanceSigs #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Nested.Internal.Ranked where
-
-import Prelude hiding (mappend, mconcat)
-
-import Control.DeepSeq (NFData(..))
-import Control.Monad.ST
-import Data.Array.RankedS qualified as S
-import Data.Bifunctor (first)
-import Data.Coerce (coerce)
-import Data.Foldable (toList)
-import Data.Kind (Type)
-import Data.List.NonEmpty (NonEmpty)
-import Data.Proxy
-import Data.Type.Equality
-import Data.Vector.Storable qualified as VS
-import Foreign.Storable (Storable)
-import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
-import GHC.Generics (Generic)
-import GHC.TypeLits
-import GHC.TypeNats qualified as TN
-
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
-import Data.Array.Mixed.Types
-import Data.Array.Mixed.XArray (XArray(..))
-import Data.Array.Mixed.XArray qualified as X
-import Data.Array.Nested.Internal.Mixed
-import Data.Array.Nested.Internal.Shape
-import Data.Array.Strided.Arith
-
-
--- | A rank-typed array: the number of dimensions of the array (its /rank/) is
--- represented on the type level as a 'Nat'.
---
--- Valid elements of a ranked arrays are described by the 'Elt' type class.
--- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
--- supported (and are represented as a single, flattened, struct-of-arrays
--- array internally).
---
--- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
-type Ranked :: Nat -> Type -> Type
-newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
-deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a)
-deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a)
-
-instance (Show a, Elt a) => Show (Ranked n a) where
- showsPrec d arr@(Ranked marr) =
- let sh = show (toList (rshape arr))
- in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr
-
-instance Elt a => NFData (Ranked n a) where
- rnf (Ranked arr) = rnf arr
-
--- just unwrap the newtype and defer to the general instance for nested arrays
-newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
- deriving (Generic)
-
-deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a))
-
-newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
-
--- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
--- these instances allow them to also be used as elements of arrays, thus
--- making them first-class in the API.
-instance Elt a => Elt (Ranked n a) where
- mshape (M_Ranked arr) = mshape arr
- mindex (M_Ranked arr) i = Ranked (mindex arr i)
-
- mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
- mindexPartial (M_Ranked arr) i =
- coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
- mindexPartial arr i
-
- mscalar (Ranked x) = M_Ranked (M_Nest ZSX x)
-
- mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a)
- mfromListOuter l = M_Ranked (mfromListOuter (coerce l))
-
- mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
- mtoListOuter (M_Ranked arr) =
- coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr)
-
- mlift :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
- -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
- mlift ssh2 f (M_Ranked arr) =
- coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
- mlift ssh2 f arr
-
- mlift2 :: forall sh1 sh2 sh3.
- StaticShX sh3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
- -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a)
- mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) =
- coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
- mlift2 ssh3 f arr1 arr2
-
- mliftL :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
- -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a))
- mliftL ssh2 f l =
- coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a)))
- @(NonEmpty (Mixed sh2 (Ranked n a))) $
- mliftL ssh2 f (coerce l)
-
- mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr)
-
- mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr)
-
- mconcat l = M_Ranked (mconcat (coerce l))
-
- mrnf (M_Ranked arr) = mrnf arr
-
- type ShapeTree (Ranked n a) = (IShR n, ShapeTree a)
-
- mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr)
-
- mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
-
- mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
-
- mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
-
- marrayStrides (M_Ranked arr) = marrayStrides arr
-
- mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
- mvecsWrite sh idx (Ranked arr) vecs =
- mvecsWrite sh idx arr
- (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
- vecs)
-
- mvecsWritePartial :: forall sh sh' s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
- -> MixedVecs s (sh ++ sh') (Ranked n a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs =
- mvecsWritePartial sh idx
- (coerce @(Mixed sh' (Ranked n a))
- @(Mixed sh' (Mixed (Replicate n Nothing) a))
- arr)
- (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
- @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
- vecs)
-
- mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
- mvecsFreeze sh vecs =
- coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
- @(Mixed sh (Ranked n a))
- <$> mvecsFreeze sh
- (coerce @(MixedVecs s sh (Ranked n a))
- @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
- vecs)
-
-instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
- memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a)
- memptyArrayUnsafe i
- | Dict <- lemKnownReplicate (SNat @n)
- = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
- memptyArrayUnsafe i
-
- mvecsUnsafeNew idx (Ranked arr)
- | Dict <- lemKnownReplicate (SNat @n)
- = MV_Ranked <$> mvecsUnsafeNew idx arr
-
- mvecsNewEmpty _
- | Dict <- lemKnownReplicate (SNat @n)
- = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
-
-
-liftRanked1 :: forall n a b.
- (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b)
- -> Ranked n a -> Ranked n b
-liftRanked1 = coerce
-
-liftRanked2 :: forall n a b c.
- (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c)
- -> Ranked n a -> Ranked n b -> Ranked n c
-liftRanked2 = coerce
-
-instance (NumElt a, PrimElt a) => Num (Ranked n a) where
- (+) = liftRanked2 (+)
- (-) = liftRanked2 (-)
- (*) = liftRanked2 (*)
- negate = liftRanked1 negate
- abs = liftRanked1 abs
- signum = liftRanked1 signum
- fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal"
-
-instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where
- fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal"
- recip = liftRanked1 recip
- (/) = liftRanked2 (/)
-
-instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where
- pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal"
- exp = liftRanked1 exp
- log = liftRanked1 log
- sqrt = liftRanked1 sqrt
- (**) = liftRanked2 (**)
- logBase = liftRanked2 logBase
- sin = liftRanked1 sin
- cos = liftRanked1 cos
- tan = liftRanked1 tan
- asin = liftRanked1 asin
- acos = liftRanked1 acos
- atan = liftRanked1 atan
- sinh = liftRanked1 sinh
- cosh = liftRanked1 cosh
- tanh = liftRanked1 tanh
- asinh = liftRanked1 asinh
- acosh = liftRanked1 acosh
- atanh = liftRanked1 atanh
- log1p = liftRanked1 GHC.Float.log1p
- expm1 = liftRanked1 GHC.Float.expm1
- log1pexp = liftRanked1 GHC.Float.log1pexp
- log1mexp = liftRanked1 GHC.Float.log1mexp
-
-rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
-rquotArray = liftRanked2 mquotArray
-rremArray = liftRanked2 mremArray
-
-ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
-ratan2Array = liftRanked2 matan2Array
-
-
-remptyArray :: KnownElt a => Ranked 1 a
-remptyArray = mtoRanked (memptyArray ZSX)
-
-rshape :: Elt a => Ranked n a -> IShR n
-rshape (Ranked arr) = shCvtXR' (mshape arr)
-
-rrank :: Elt a => Ranked n a -> SNat n
-rrank = shrRank . rshape
-
--- | The total number of elements in the array.
-rsize :: Elt a => Ranked n a -> Int
-rsize = shrSize . rshape
-
-rindex :: Elt a => Ranked n a -> IIxR n -> a
-rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
-
-rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a
-rindexPartial (Ranked arr) idx =
- Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
- (castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @Nothing))) arr)
- (ixCvtRX idx))
-
--- | __WARNING__: All values returned from the function must have equal shape.
--- See the documentation of 'mgenerate' for more details.
-rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a
-rgenerate sh f
- | sn@SNat <- shrRank sh
- , Dict <- lemKnownReplicate sn
- , Refl <- lemRankReplicate sn
- = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR))
-
--- | See the documentation of 'mlift'.
-rlift :: forall n1 n2 a. Elt a
- => SNat n2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
- -> Ranked n1 a -> Ranked n2 a
-rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr)
-
--- | See the documentation of 'mlift2'.
-rlift2 :: forall n1 n2 n3 a. Elt a
- => SNat n3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b)
- -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a
-rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)
-
-rsumOuter1P :: forall n a.
- (Storable a, NumElt a)
- => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
-rsumOuter1P (Ranked arr)
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = Ranked (msumOuter1P arr)
-
-rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
- => Ranked (n + 1) a -> Ranked n a
-rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive
-
-rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a
-rsumAllPrim (Ranked arr) = msumAllPrim arr
-
-rtranspose :: forall n a. Elt a => PermR -> Ranked n a -> Ranked n a
-rtranspose perm arr
- | sn@SNat <- rrank arr
- , Dict <- lemKnownReplicate sn
- , length perm <= fromIntegral (natVal (Proxy @n))
- = rlift sn
- (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm)
- arr
- | otherwise
- = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array"
-
-rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a
-rconcat
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = coerce mconcat
-
-rappend :: forall n a. Elt a
- => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
-rappend arr1 arr2
- | sn@SNat <- rrank arr1
- , Dict <- lemKnownReplicate sn
- , Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
- arr1 arr2
-
-rscalar :: Elt a => a -> Ranked 0 a
-rscalar x = Ranked (mscalar x)
-
-rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a)
-rfromVectorP sh v
- | Dict <- lemKnownReplicate (shrRank sh)
- = Ranked (mfromVectorP (shCvtRX sh) v)
-
-rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a
-rfromVector sh v = rfromPrimitive (rfromVectorP sh v)
-
-rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a
-rtoVectorP = coerce mtoVectorP
-
-rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
-rtoVector = coerce mtoVector
-
-rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
-rfromListOuter l
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
-
-rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
-rfromList1 l = Ranked (mfromList1 l)
-
-rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a
-rfromList1Prim l = Ranked (mfromList1Prim l)
-
-rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
-rtoListOuter (Ranked arr)
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
-
-rtoList1 :: Elt a => Ranked 1 a -> [a]
-rtoList1 = map runScalar . rtoListOuter
-
-rfromListPrim :: PrimElt a => [a] -> Ranked 1 a
-rfromListPrim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
-
-rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a
-rfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
- in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr)
-
-rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
-rfromListLinear sh l = rreshape sh (rfromList1 l)
-
-rtoListLinear :: Elt a => Ranked n a -> [a]
-rtoListLinear (Ranked arr) = mtoListLinear arr
-
-rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a
-rfromOrthotope sn arr
- | Refl <- lemRankReplicate sn
- = let xarr = XArray arr
- in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr))
-
-rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a
-rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr)))
- | Refl <- lemRankReplicate (shrRank $ shCvtXR' sh)
- = arr
-
-runScalar :: Elt a => Ranked 0 a -> a
-runScalar arr = rindex arr ZIR
-
-rnest :: forall n m a. Elt a => SNat n -> Ranked (n + m) a -> Ranked n (Ranked m a)
-rnest n arr
- | Refl <- lemReplicatePlusApp n (Proxy @m) (Proxy @(Nothing @Nat))
- = coerce (mnest (ssxFromSNat n) (coerce arr))
-
-runNest :: forall n m a. Elt a => Ranked n (Ranked m a) -> Ranked (n + m) a
-runNest rarr@(Ranked (M_Ranked (M_Nest _ arr)))
- | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat))
- = Ranked arr
-
-rzip :: Ranked n a -> Ranked n b -> Ranked n (a, b)
-rzip = coerce mzip
-
-runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b)
-runzip = coerce munzip
-
-rrerankP :: forall n1 n2 n a b. (Storable a, Storable b)
- => SNat n -> IShR n2
- -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b))
- -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b)
-rrerankP sn sh2 f (Ranked arr)
- | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat))
- , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat))
- = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2)
- (\a -> let Ranked r = f (Ranked a) in r)
- arr)
-
--- | If there is a zero-sized dimension in the @n@-prefix of the shape of the
--- input array, then there is no way to deduce the full shape of the output
--- array (more precisely, the @n2@ part): that could only come from calling
--- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in
--- this case; we choose to fill the @n2@ part of the output shape with zeros.
---
--- For example, if:
---
--- @
--- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21]
--- f :: Ranked 2 Int -> Ranked 3 Float
--- @
---
--- then:
---
--- @
--- rrerank _ _ _ f arr :: Ranked 5 Float
--- @
---
--- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the
--- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended
--- to return an array with shape all-0 here (it probably didn't), but there is
--- no better number to put here absent a subarray of the input to pass to @f@.
-rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
- => SNat n -> IShR n2
- -> (Ranked n1 a -> Ranked n2 b)
- -> Ranked (n + n1) a -> Ranked (n + n2) b
-rrerank sn sh2 f (rtoPrimitive -> arr) =
- rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr
-
-rreplicate :: forall n m a. Elt a
- => IShR n -> Ranked m a -> Ranked (n + m) a
-rreplicate sh (Ranked arr)
- | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat))
- = Ranked (mreplicate (shCvtRX sh) arr)
-
-rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
-rreplicateScalP sh x
- | Dict <- lemKnownReplicate (shrRank sh)
- = Ranked (mreplicateScalP (shCvtRX sh) x)
-
-rreplicateScal :: forall n a. PrimElt a
- => IShR n -> a -> Ranked n a
-rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
-
-rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
-rslice i n arr
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = rlift (rrank arr)
- (\_ -> X.sliceU i n)
- arr
-
-rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
-rrev1 arr =
- rlift (rrank arr)
- (\(_ :: StaticShX sh') ->
- case lemReplicateSucc @(Nothing @Nat) @n of
- Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))
- arr
-
-rreshape :: forall n n' a. Elt a
- => IShR n' -> Ranked n a -> Ranked n' a
-rreshape sh' rarr@(Ranked arr)
- | Dict <- lemKnownReplicate (rrank rarr)
- , Dict <- lemKnownReplicate (shrRank sh')
- = Ranked (mreshape (shCvtRX sh') arr)
-
-rflatten :: Elt a => Ranked n a -> Ranked 1 a
-rflatten (Ranked arr) = mtoRanked (mflatten arr)
-
-riota :: (Enum a, PrimElt a) => Int -> Ranked 1 a
-riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota
-
--- | Throws if the array is empty.
-rminIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
-rminIndexPrim rarr@(Ranked arr)
- | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
- = ixCvtXR (mminIndexPrim arr)
-
--- | Throws if the array is empty.
-rmaxIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
-rmaxIndexPrim rarr@(Ranked arr)
- | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
- = ixCvtXR (mmaxIndexPrim arr)
-
-rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a
-rdot1Inner arr1 arr2
- | SNat <- rrank arr1
- , Refl <- lemReplicatePlusApp (SNat @n) (Proxy @1) (Proxy @(Nothing @Nat))
- = coerce (mdot1Inner (Proxy @(Nothing @Nat))) arr1 arr2
-
--- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
--- Prefer 'rdot1Inner' if applicable.
-rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a
-rdot = coerce mdot
-
-rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
-rtoXArrayPrimP (Ranked arr) = first shCvtXR' (mtoXArrayPrimP arr)
-
-rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a)
-rtoXArrayPrim (Ranked arr) = first shCvtXR' (mtoXArrayPrim arr)
-
-rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a)
-rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr)
-
-rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a
-rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr)
-
-rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a
-rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr)
-
-rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a)
-rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr)
-
-mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
-mtoRanked arr
- | Refl <- lemRankReplicate (shxRank (mshape arr))
- = Ranked (mcast (ssxFromShape (convSh (mshape arr))) arr)
- where
- convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing)
- convSh ZSX = ZSX
- convSh (smn :$% (sh :: IShX sh'T))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T)
- = SUnknown (fromSMayNat' smn) :$% convSh sh
-
-rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
-rtoMixed (Ranked arr) = arr
-
--- | A more weakly-typed version of 'rtoMixed' that does a runtime shape
--- compatibility check.
-rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a
-rcastToMixed sshx rarr@(Ranked arr)
- | Refl <- lemRankReplicate (rrank rarr)
- = mcast sshx arr
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
deleted file mode 100644
index 372439f..0000000
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ /dev/null
@@ -1,495 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingVia #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE InstanceSigs #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Nested.Internal.Shaped where
-
-import Prelude hiding (mappend, mconcat)
-
-import Control.DeepSeq (NFData(..))
-import Control.Monad.ST
-import Data.Array.Internal.RankedG qualified as RG
-import Data.Array.Internal.RankedS qualified as RS
-import Data.Array.Internal.ShapedG qualified as SG
-import Data.Array.Internal.ShapedS qualified as SS
-import Data.Bifunctor (first)
-import Data.Coerce (coerce)
-import Data.Kind (Type)
-import Data.List.NonEmpty (NonEmpty)
-import Data.Proxy
-import Data.Type.Equality
-import Data.Vector.Storable qualified as VS
-import Foreign.Storable (Storable)
-import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
-import GHC.Generics (Generic)
-import GHC.TypeLits
-
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
-import Data.Array.Mixed.Types
-import Data.Array.Mixed.XArray (XArray)
-import Data.Array.Mixed.XArray qualified as X
-import Data.Array.Nested.Internal.Lemmas
-import Data.Array.Nested.Internal.Mixed
-import Data.Array.Nested.Internal.Shape
-import Data.Array.Strided.Arith
-
-
--- | A shape-typed array: the full shape of the array (the sizes of its
--- dimensions) is represented on the type level as a list of 'Nat's. Note that
--- these are "GHC.TypeLits" naturals, because we do not need induction over
--- them and we want very large arrays to be possible.
---
--- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
--- and 'Shaped' itself is again an instance of 'Elt' as well.
---
--- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
-type Shaped :: [Nat] -> Type -> Type
-newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
-deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a)
-deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a)
-
-instance (Show a, Elt a) => Show (Shaped n a) where
- showsPrec d arr@(Shaped marr) =
- let sh = show (shsToList (sshape arr))
- in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr
-
-instance Elt a => NFData (Shaped sh a) where
- rnf (Shaped arr) = rnf arr
-
--- just unwrap the newtype and defer to the general instance for nested arrays
-newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))
- deriving (Generic)
-
-deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a))
-
-newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))
-
-instance Elt a => Elt (Shaped sh a) where
- mshape (M_Shaped arr) = mshape arr
- mindex (M_Shaped arr) i = Shaped (mindex arr i)
-
- mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
- mindexPartial (M_Shaped arr) i =
- coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
- mindexPartial arr i
-
- mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)
-
- mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
- mfromListOuter l = M_Shaped (mfromListOuter (coerce l))
-
- mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
- mtoListOuter (M_Shaped arr)
- = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)
-
- mlift :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
- -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
- mlift ssh2 f (M_Shaped arr) =
- coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
- mlift ssh2 f arr
-
- mlift2 :: forall sh1 sh2 sh3.
- StaticShX sh3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
- -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a)
- mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) =
- coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
- mlift2 ssh3 f arr1 arr2
-
- mliftL :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
- -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a))
- mliftL ssh2 f l =
- coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a)))
- @(NonEmpty (Mixed sh2 (Shaped sh a))) $
- mliftL ssh2 f (coerce l)
-
- mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr)
-
- mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)
-
- mconcat l = M_Shaped (mconcat (coerce l))
-
- mrnf (M_Shaped arr) = mrnf arr
-
- type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
-
- mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr)
-
- mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
-
- mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
-
- mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
-
- marrayStrides (M_Shaped arr) = marrayStrides arr
-
- mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
- mvecsWrite sh idx (Shaped arr) vecs =
- mvecsWrite sh idx arr
- (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
- vecs)
-
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
- -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs =
- mvecsWritePartial sh idx
- (coerce @(Mixed sh2 (Shaped sh a))
- @(Mixed sh2 (Mixed (MapJust sh) a))
- arr)
- (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
- @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
- vecs)
-
- mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
- mvecsFreeze sh vecs =
- coerce @(Mixed sh' (Mixed (MapJust sh) a))
- @(Mixed sh' (Shaped sh a))
- <$> mvecsFreeze sh
- (coerce @(MixedVecs s sh' (Shaped sh a))
- @(MixedVecs s sh' (Mixed (MapJust sh) a))
- vecs)
-
-instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
- memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
- memptyArrayUnsafe i
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
- memptyArrayUnsafe i
-
- mvecsUnsafeNew idx (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = MV_Shaped <$> mvecsUnsafeNew idx arr
-
- mvecsNewEmpty _
- | Dict <- lemKnownMapJust (Proxy @sh)
- = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
-
-
-liftShaped1 :: forall sh a b.
- (Mixed (MapJust sh) a -> Mixed (MapJust sh) b)
- -> Shaped sh a -> Shaped sh b
-liftShaped1 = coerce
-
-liftShaped2 :: forall sh a b c.
- (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c)
- -> Shaped sh a -> Shaped sh b -> Shaped sh c
-liftShaped2 = coerce
-
-instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
- (+) = liftShaped2 (+)
- (-) = liftShaped2 (-)
- (*) = liftShaped2 (*)
- negate = liftShaped1 negate
- abs = liftShaped1 abs
- signum = liftShaped1 signum
- fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal"
-
-instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where
- fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"
- recip = liftShaped1 recip
- (/) = liftShaped2 (/)
-
-instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where
- pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"
- exp = liftShaped1 exp
- log = liftShaped1 log
- sqrt = liftShaped1 sqrt
- (**) = liftShaped2 (**)
- logBase = liftShaped2 logBase
- sin = liftShaped1 sin
- cos = liftShaped1 cos
- tan = liftShaped1 tan
- asin = liftShaped1 asin
- acos = liftShaped1 acos
- atan = liftShaped1 atan
- sinh = liftShaped1 sinh
- cosh = liftShaped1 cosh
- tanh = liftShaped1 tanh
- asinh = liftShaped1 asinh
- acosh = liftShaped1 acosh
- atanh = liftShaped1 atanh
- log1p = liftShaped1 GHC.Float.log1p
- expm1 = liftShaped1 GHC.Float.expm1
- log1pexp = liftShaped1 GHC.Float.log1pexp
- log1mexp = liftShaped1 GHC.Float.log1mexp
-
-squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
-squotArray = liftShaped2 mquotArray
-sremArray = liftShaped2 mremArray
-
-satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
-satan2Array = liftShaped2 matan2Array
-
-
-semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a
-semptyArray sh = Shaped (memptyArray (shCvtSX sh))
-
-sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
-sshape (Shaped arr) = shCvtXS' (mshape arr)
-
-srank :: Elt a => Shaped sh a -> SNat (Rank sh)
-srank = shsRank . sshape
-
--- | The total number of elements in the array.
-ssize :: Elt a => Shaped sh a -> Int
-ssize = shsSize . sshape
-
-sindex :: Elt a => Shaped sh a -> IIxS sh -> a
-sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)
-
-shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh
-shsTakeIx _ _ ZIS = ZSS
-shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx
-
-sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
-sindexPartial sarr@(Shaped arr) idx =
- Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
- (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr)
- (ixCvtSX idx))
-
--- | __WARNING__: All values returned from the function must have equal shape.
--- See the documentation of 'mgenerate' for more details.
-sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
-sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh))
-
--- | See the documentation of 'mlift'.
-slift :: forall sh1 sh2 a. Elt a
- => ShS sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
- -> Shaped sh1 a -> Shaped sh2 a
-slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr)
-
--- | See the documentation of 'mlift'.
-slift2 :: forall sh1 sh2 sh3 a. Elt a
- => ShS sh3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)
- -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
-slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2)
-
-ssumOuter1P :: forall sh n a. (Storable a, NumElt a)
- => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
-ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr)
-
-ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
- => Shaped (n : sh) a -> Shaped sh a
-ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive
-
-ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a
-ssumAllPrim (Shaped arr) = msumAllPrim arr
-
-stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a)
- => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a
-stranspose perm sarr@(Shaped arr)
- | Refl <- lemRankMapJust (sshape sarr)
- , Refl <- lemTakeLenMapJust perm (sshape sarr)
- , Refl <- lemDropLenMapJust perm (sshape sarr)
- , Refl <- lemPermuteMapJust perm (shsTakeLen perm (sshape sarr))
- , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh))
- = Shaped (mtranspose perm arr)
-
-sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
-sappend = coerce mappend
-
-sscalar :: Elt a => a -> Shaped '[] a
-sscalar x = Shaped (mscalar x)
-
-sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a)
-sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v)
-
-sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a
-sfromVector sh v = sfromPrimitive (sfromVectorP sh v)
-
-stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a
-stoVectorP = coerce mtoVectorP
-
-stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
-stoVector = coerce mtoVector
-
-sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
-sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l))
-
-sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
-sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1
-
-sfromList1Prim :: PrimElt a => SNat n -> [a] -> Shaped '[n] a
-sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim
-
-stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
-stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
-
-stoList1 :: Elt a => Shaped '[n] a -> [a]
-stoList1 = map sunScalar . stoListOuter
-
-sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
-sfromListPrim sn l
- | Refl <- lemAppNil @'[Just n]
- = let ssh = SUnknown () :!% ZKX
- xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l)
- in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr
-
-sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a
-sfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
- in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr)
-
-sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a
-sfromListLinear sh l = Shaped (mfromListLinear (shCvtSX sh) l)
-
-stoListLinear :: Elt a => Shaped sh a -> [a]
-stoListLinear (Shaped arr) = mtoListLinear arr
-
-sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a
-sfromOrthotope sh (SS.A (SG.A arr)) =
- Shaped (fromPrimitive (M_Primitive (shCvtSX sh) (X.XArray (RS.A (RG.A (shsToList sh) arr)))))
-
-stoOrthotope :: PrimElt a => Shaped sh a -> SS.Array sh a
-stoOrthotope (stoPrimitive -> Shaped (M_Primitive _ (X.XArray (RS.A (RG.A _ arr))))) = SS.A (SG.A arr)
-
-sunScalar :: Elt a => Shaped '[] a -> a
-sunScalar arr = sindex arr ZIS
-
-snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a)
-snest sh arr
- | Refl <- lemMapJustApp sh (Proxy @sh')
- = coerce (mnest (ssxFromShape (shCvtSX sh)) (coerce arr))
-
-sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a
-sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr)))
- | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh')
- = Shaped arr
-
-szip :: Shaped sh a -> Shaped sh b -> Shaped sh (a, b)
-szip = coerce mzip
-
-sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b)
-sunzip = coerce munzip
-
-srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
- => ShS sh -> ShS sh2
- -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))
- -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b)
-srerankP sh sh2 f sarr@(Shaped arr)
- | Refl <- lemMapJustApp sh (Proxy @sh1)
- , Refl <- lemMapJustApp sh (Proxy @sh2)
- = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh))))
- (shCvtSX sh2)
- (\a -> let Shaped r = f (Shaped a) in r)
- arr)
-
-srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
- => ShS sh -> ShS sh2
- -> (Shaped sh1 a -> Shaped sh2 b)
- -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b
-srerank sh sh2 f (stoPrimitive -> arr) =
- sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr
-
-sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a
-sreplicate sh (Shaped arr)
- | Refl <- lemMapJustApp sh (Proxy @sh')
- = Shaped (mreplicate (shCvtSX sh) arr)
-
-sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
-sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x)
-
-sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a
-sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x)
-
-sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
-sslice i n@SNat arr =
- let _ :$$ sh = sshape arr
- in slift (n :$$ sh) (\_ -> X.slice i n) arr
-
-srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a
-srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr
-
-sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shaped sh' a
-sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr)
-
-sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a
-sflatten arr =
- case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff
- n@SNat -> sreshape (n :$$ ZSS) arr
-
-siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
-siota sn = Shaped (miota sn)
-
--- | Throws if the array is empty.
-sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
-sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminIndexPrim arr)
-
--- | Throws if the array is empty.
-smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
-smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr)
-
-sdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
- => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a
-sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2)
- | Refl <- lemInitApp (Proxy @sh) (Proxy @n)
- , Refl <- lemLastApp (Proxy @sh) (Proxy @n)
- = case sshape sarr1 of
- _ :$$ _
- | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n])
- -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2)
- _ -> error "unreachable"
-
--- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
--- Prefer 'sdot1Inner' if applicable.
-sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a
-sdot = coerce mdot
-
-stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
-stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr)
-
-stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a)
-stoXArrayPrim (Shaped arr) = first shCvtXS' (mtoXArrayPrim arr)
-
-sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a)
-sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr)
-
-sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a
-sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr)
-
-sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a
-sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr)
-
-stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a)
-stoPrimitive (Shaped arr) = Shaped (toPrimitive arr)
-
-mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
- => Mixed sh a -> ShS sh' -> Shaped sh' a
-mcastToShaped arr targetsh
- | Refl <- lemRankMapJust targetsh
- = Shaped (mcast (ssxFromShape (shCvtSX targetsh)) arr)
-
-stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
-stoMixed (Shaped arr) = arr
-
--- | A more weakly-typed version of 'stoMixed' that does a runtime shape
--- compatibility check.
-scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
- => StaticShX sh' -> Shaped sh a -> Mixed sh' a
-scastToMixed sshx sarr@(Shaped arr)
- | Refl <- lemRankMapJust (sshape sarr)
- = mcast sshx arr
diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs
index 560f762..8cac298 100644
--- a/src/Data/Array/Mixed/Lemmas.hs
+++ b/src/Data/Array/Nested/Lemmas.hs
@@ -6,27 +6,19 @@
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed.Lemmas where
+module Data.Array.Nested.Lemmas where
import Data.Proxy
import Data.Type.Equality
import GHC.TypeLits
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
-import Data.Array.Mixed.Types
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
--- * Reasoning helpers
-
-subst1 :: forall f a b. a :~: b -> f a :~: f b
-subst1 Refl = Refl
-
-subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
-subst2 Refl = Refl
-
-
--- * Lemmas
+-- * Lemmas about numbers and lists
-- ** Nat
@@ -36,7 +28,6 @@ lemLeqSuccSucc _ _ = unsafeCoerceRefl
lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True
lemLeqPlus _ _ _ = Refl
-
-- ** Append
lemAppNil :: l ++ '[] :~: l
@@ -48,41 +39,7 @@ lemAppAssoc _ _ _ = unsafeCoerceRefl
lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l
lemAppLeft _ Refl = Refl
-
--- ** Rank
-
-lemRankApp :: forall sh1 sh2.
- StaticShX sh1 -> StaticShX sh2
- -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
-lemRankApp ZKX _ = Refl
-lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2
- = lem2 (Proxy @(Rank sh1T)) Proxy Proxy $
- lem (Proxy @(Rank sh2)) (Proxy @(Rank sh1T)) (Proxy @(Rank (sh1T ++ sh2))) $
- lemRankApp ssh1 ssh2
- where
- lem :: proxy a -> proxy b -> proxy c
- -> c :~: b + a
- -> b + a :~: c
- lem _ _ _ Refl = Refl
-
- lem2 :: proxy a -> proxy b -> proxy c
- -> (a + b :~: c)
- -> c + 1 :~: (a + 1 + b)
- lem2 _ _ _ Refl = Refl
-
-lemRankAppComm :: StaticShX sh1 -> StaticShX sh2
- -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
-lemRankAppComm _ _ = unsafeCoerceRefl -- TODO improve this
-
-lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n
-lemRankReplicate SZ = Refl
-lemRankReplicate (SS (n :: SNat nm1))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1
- , Refl <- lemRankReplicate n
- = Refl
-
-
--- ** Various type families
+-- ** Simple type families
lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a
-> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
@@ -126,6 +83,8 @@ lemKnownNatRankSSX ZKX = Dict
lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
+-- * Lemmas about shapes
+
-- ** Known shapes
lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing)
@@ -135,3 +94,69 @@ lemKnownShX :: StaticShX sh -> Dict KnownShX sh
lemKnownShX ZKX = Dict
lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict
lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict
+
+lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh)
+lemKnownMapJust _ = lemKnownShX (go (knownShS @sh))
+ where
+ go :: ShS sh' -> StaticShX (MapJust sh')
+ go ZSS = ZKX
+ go (n :$$ sh) = SKnown n :!% go sh
+
+-- ** Rank
+
+lemRankApp :: forall sh1 sh2.
+ StaticShX sh1 -> StaticShX sh2
+ -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
+lemRankApp ZKX _ = Refl
+lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2
+ = lem (Proxy @(Rank sh1T)) Proxy Proxy $
+ sym (lemRankApp ssh1 ssh2)
+ where
+ lem :: proxy a -> proxy b -> proxy c
+ -> (a + b :~: c)
+ -> c + 1 :~: (a + 1 + b)
+ lem _ _ _ Refl = Refl
+
+lemRankAppComm :: proxy sh1 -> proxy sh2
+ -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
+lemRankAppComm _ _ = unsafeCoerceRefl
+
+lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n
+lemRankReplicate _ = unsafeCoerceRefl
+
+lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh
+lemRankMapJust ZSS = Refl
+lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl
+
+-- ** Related to MapJust and/or Permutation
+
+lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)
+lemTakeLenMapJust PNil _ = Refl
+lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl
+lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty"
+
+lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh)
+lemDropLenMapJust PNil _ = Refl
+lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl
+lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty"
+
+lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh)
+lemIndexMapJust SZ (_ :$$ _) = Refl
+lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh'))
+ | Refl <- lemIndexMapJust i sh
+ , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh'))
+ , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = Refl
+lemIndexMapJust _ ZSS = error "Index of empty"
+
+lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh)
+lemPermuteMapJust PNil _ = Refl
+lemPermuteMapJust (i `PCons` is) sh
+ | Refl <- lemPermuteMapJust is sh
+ , Refl <- lemIndexMapJust i sh
+ = Refl
+
+lemMapJustApp :: ShS sh1 -> Proxy sh2
+ -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
+lemMapJustApp ZSS _ = Refl
+lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index a2f9737..144230e 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveGeneric #-}
@@ -16,7 +17,7 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
-module Data.Array.Nested.Internal.Mixed where
+module Data.Array.Nested.Mixed where
import Prelude hiding (mconcat)
@@ -28,7 +29,7 @@ import Data.Bifunctor (bimap)
import Data.Coerce
import Data.Foldable (toList)
import Data.Int
-import Data.Kind (Constraint, Type)
+import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty(..))
import Data.List.NonEmpty qualified as NE
import Data.Proxy
@@ -40,15 +41,14 @@ import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import GHC.TypeLits
-import Unsafe.Coerce (unsafeCoerce)
-
-import Data.Array.Mixed.Internal.Arith
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
-import Data.Array.Mixed.Types
-import Data.Array.Mixed.XArray (XArray(..))
-import Data.Array.Mixed.XArray qualified as X
+
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Types
+import Data.Array.Strided.Orthotope
+import Data.Array.XArray (XArray(..))
+import Data.Array.XArray qualified as X
import Data.Bag
@@ -140,27 +140,39 @@ data family Mixed sh a
-- sizes of the elements of an empty array, which is information that should
-- ostensibly not exist; the full array is still empty.
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+#define ANDSHOW , Show
+#else
+#define ANDSHOW
+#endif
+
data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a)
- deriving (Eq, Ord, Generic)
+ deriving (Eq, Ord, Generic ANDSHOW)
-- [PRIMITIVE ELEMENT TYPES LIST]
-newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic)
-newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Eq, Ord, Generic) -- no content, orthotope optimises this (via Vector)
+newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Eq, Ord, Generic ANDSHOW) -- no content, orthotope optimises this (via Vector)
-- etc.
data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic)
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b))
+#endif
-- etc., larger tuples (perhaps use generics to allow arbitrary product types)
deriving instance (Eq (Mixed sh a), Eq (Mixed sh b)) => Eq (Mixed sh (a, b))
deriving instance (Ord (Mixed sh a), Ord (Mixed sh b)) => Ord (Mixed sh (a, b))
data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic)
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance (Show (Mixed (sh1 ++ sh2) a)) => Show (Mixed sh1 (Mixed sh2 a))
+#endif
deriving instance Eq (Mixed (sh1 ++ sh2) a) => Eq (Mixed sh1 (Mixed sh2 a))
deriving instance Ord (Mixed (sh1 ++ sh2) a) => Ord (Mixed sh1 (Mixed sh2 a))
@@ -204,10 +216,12 @@ showsMixedArray fromlistPrefix replicatePrefix d arr =
_ ->
showString fromlistPrefix . showString " " . shows (mtoListLinear arr)
+#ifndef OXAR_DEFAULT_SHOW_INSTANCES
instance (Show a, Elt a) => Show (Mixed sh a) where
showsPrec d arr =
let sh = show (shxToList (mshape arr))
in showsMixedArray ("mfromListLinear " ++ sh) ("mreplicate " ++ sh) d arr
+#endif
instance Elt a => NFData (Mixed sh a) where
rnf = mrnf
@@ -366,9 +380,7 @@ class Elt a where
-- of this class with those of 'Elt': some instances have an additional
-- "known-shape" constraint.
--
--- This class is (currently) only required for 'mgenerate',
--- 'Data.Array.Nested.Ranked.rgenerate' and
--- 'Data.Array.Nested.Shaped.sgenerate'.
+-- This class is (currently) only required for `memptyArray` and 'mgenerate'.
class Elt a => KnownElt a where
-- | Create an empty array. The given shape must have size zero; this may or may not be checked.
memptyArrayUnsafe :: IShX sh -> Mixed sh a
@@ -384,11 +396,11 @@ class Elt a => KnownElt a where
instance Storable a => Elt (Primitive a) where
mshape (M_Primitive sh _) = sh
mindex (M_Primitive _ a) i = Primitive (X.index a i)
- mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i)
+ mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i)
mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
mfromListOuter l@(arr1 :| _) =
let sh = SUnknown (length l) :$% mshape arr1
- in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l)))
+ in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l)))
mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
mlift :: forall sh1 sh2.
@@ -426,17 +438,17 @@ instance Storable a => Elt (Primitive a) where
=> StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) =
let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
- sh2 = shxCast' sh1 ssh2
- in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr)
+ sh2 = shxCast' ssh2 sh1
+ in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr)
mtranspose perm (M_Primitive sh arr) =
M_Primitive (shxPermutePrefix perm sh)
- (X.transpose (ssxFromShape sh) perm arr)
+ (X.transpose (ssxFromShX sh) perm arr)
mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a)
mconcat l@(M_Primitive (_ :$% sh) _ :| _) =
- let result = X.concat (ssxFromShape sh) (fmap (\(M_Primitive _ arr) -> arr) l)
- in M_Primitive (X.shape (SUnknown () :!% ssxFromShape sh) result) result
+ let result = X.concat (ssxFromShX sh) (fmap (\(M_Primitive _ arr) -> arr) l)
+ in M_Primitive (X.shape (SUnknown () :!% ssxFromShX sh) result) result
mrnf (M_Primitive sh a) = rnf sh `seq` rnf a
@@ -453,7 +465,7 @@ instance Storable a => Elt (Primitive a) where
:: forall sh' sh s.
IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
- let arrsh = X.shape (ssxFromShape sh') arr
+ let arrsh = X.shape (ssxFromShX sh') arr
offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
@@ -540,16 +552,16 @@ instance Elt a => Elt (Mixed sh' a) where
-- moverlongShape method, a prefix of which is mshape.
mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
mshape (M_Nest sh arr)
- = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr))
+ = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr))
mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
- mindex (M_Nest _ arr) i = mindexPartial arr i
+ mindex (M_Nest _ arr) = mindexPartial arr
mindexPartial :: forall sh1 sh2.
Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
mindexPartial (M_Nest sh arr) i
| Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
+ = M_Nest (shxDropIx i sh) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
mscalar = M_Nest ZSX
@@ -569,7 +581,7 @@ instance Elt a => Elt (Mixed sh' a) where
(sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result)
in M_Nest sh2 result
where
- ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr)))
+ ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr)))
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
f' sshT
@@ -586,7 +598,7 @@ instance Elt a => Elt (Mixed sh' a) where
(sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result)
in M_Nest sh3 result
where
- ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1)))
+ ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1)))
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
f' sshT
@@ -604,7 +616,7 @@ instance Elt a => Elt (Mixed sh' a) where
(sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result))
in fmap (M_Nest sh2) result
where
- ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1)))
+ ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1)))
f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b)
f' sshT
@@ -618,15 +630,15 @@ instance Elt a => Elt (Mixed sh' a) where
| Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
, Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
= let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T
- sh2 = shxCast' sh1 ssh2
+ sh2 = shxCast' ssh2 sh1
in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr)
mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
=> Perm is -> Mixed sh (Mixed sh' a)
-> Mixed (PermutePrefix is sh) (Mixed sh' a)
mtranspose perm (M_Nest sh arr)
- | let sh' = shxDropSh @sh @sh' (mshape arr) sh
- , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh')
+ | let sh' = shxDropSh @sh @sh' sh (mshape arr)
+ , Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh')
, Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh'))
, Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
, Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
@@ -637,14 +649,14 @@ instance Elt a => Elt (Mixed sh' a) where
mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
mconcat l@(M_Nest sh1 _ :| _) =
let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l)
- in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape result))) result
+ in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape result))) result
mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr
type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a)
mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
- mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr)))))
+ mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShX (mshape arr)))))
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
@@ -671,7 +683,7 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
mvecsUnsafeNew sh example
| shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
- | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh')))
+ | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShX sh')))
where
sh' = mshape example
@@ -729,14 +741,14 @@ msumOuter1P :: forall sh n a. (Storable a, NumElt a)
=> Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
msumOuter1P (M_Primitive (n :$% sh) arr) =
let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
- in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr)
+ in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr)
msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
=> Mixed (n : sh) a -> Mixed sh a
msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
-msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShape sh) arr
+msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr
mappend :: forall n m sh a. Elt a
=> Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
@@ -744,7 +756,7 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
where
sn :$% sh = mshape arr1
sm :$% _ = mshape arr2
- ssh = ssxFromShape sh
+ ssh = ssxFromShX sh
snm :: SMayNat () SNat (AddMaybe n m)
snm = case (sn, sm) of
(SUnknown{}, _) -> SUnknown ()
@@ -770,14 +782,10 @@ mtoVector arr = mtoVectorP (toPrimitive arr)
mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a
mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise?
-mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a
-mfromList1Prim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
-
-mtoList1 :: Elt a => Mixed '[n] a -> [a]
-mtoList1 = map munScalar . mtoListOuter
+-- This forall is there so that a simple type application can constrain the
+-- shape, in case the user wants to use OverloadedLists for the shape.
+mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a
+mfromListLinear sh l = mreshape sh (mfromList1 l)
mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a
mfromListPrim l =
@@ -790,10 +798,8 @@ mfromListPrimLinear sh l =
let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)
--- This forall is there so that a simple type application can constrain the
--- shape, in case the user wants to use OverloadedLists for the shape.
-mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a
-mfromListLinear sh l = mreshape sh (mfromList1 l)
+mtoList :: Elt a => Mixed '[n] a -> [a]
+mtoList = map munScalar . mtoListOuter
mtoListLinear :: Elt a => Mixed sh a -> [a]
mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr)) -- TODO: optimise
@@ -807,8 +813,11 @@ mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr
munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a
munNest (M_Nest _ arr) = arr
-mzip :: Mixed sh a -> Mixed sh b -> Mixed sh (a, b)
-mzip = M_Tup2
+-- | The arguments must have equal shapes. If they do not, an error is raised.
+mzip :: (Elt a, Elt b) => Mixed sh a -> Mixed sh b -> Mixed sh (a, b)
+mzip a b
+ | Just Refl <- shxEqual (mshape a) (mshape b) = M_Tup2 a b
+ | otherwise = error "mzip: unequal shapes"
munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b)
munzip (M_Tup2 a b) = (a, b)
@@ -818,13 +827,13 @@ mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
-> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
-> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b)
mrerankP ssh sh2 f (M_Primitive sh arr) =
- let sh1 = shxDropSSX sh ssh
- in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2)
- (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2)
+ let sh1 = shxDropSSX ssh sh
+ in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) ssh sh) sh2)
+ (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2)
(\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
arr)
--- | See the caveats at @X.rerank@.
+-- | See the caveats at 'Data.Array.XArray.rerank'.
mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
=> StaticShX sh -> IShX sh2
-> (Mixed sh1 a -> Mixed sh2 b)
@@ -835,8 +844,8 @@ mrerank ssh sh2 f (toPrimitive -> arr) =
mreplicate :: forall sh sh' a. Elt a
=> IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
mreplicate sh arr =
- let ssh' = ssxFromShape (mshape arr)
- in mlift (ssxAppend (ssxFromShape sh) ssh')
+ let ssh' = ssxFromShX (mshape arr)
+ in mlift (ssxAppend (ssxFromShX sh) ssh')
(\(sshT :: StaticShX shT) ->
case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of
Refl -> X.replicate sh (ssxAppend ssh' sshT))
@@ -852,18 +861,18 @@ mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
mslice i n arr =
let _ :$% sh = mshape arr
- in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr
+ in mlift (SKnown n :!% ssxFromShX sh) (\_ -> X.slice i n) arr
msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
-msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr
+msliceU i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr
mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a
-mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr
+mrev1 arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.rev1) arr
mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a
mreshape sh' arr =
- mlift (ssxFromShape sh')
- (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh')
+ mlift (ssxFromShX sh')
+ (\sshIn -> X.reshapePartial (ssxFromShX (mshape arr)) sshIn sh')
arr
mflatten :: Elt a => Mixed sh a -> Mixed '[Flatten sh] a
@@ -875,12 +884,12 @@ miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
-- | Throws if the array is empty.
mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
- ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) (fromO arr))
+ ixxFromList (ssxFromShX sh) (numEltMinIndex (shxRank sh) (fromO arr))
-- | Throws if the array is empty.
mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
- ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) (fromO arr))
+ ixxFromList (ssxFromShX sh) (numEltMaxIndex (shxRank sh) (fromO arr))
mdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
=> Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a
@@ -890,7 +899,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi
= case sh1 of
_ :$% _
| sh1 == sh2
- , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) ->
+ , Refl <- lemRankApp (ssxInit (ssxFromShX sh1)) (ssxLast (ssxFromShX sh1) :!% ZKX) ->
fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b))
| otherwise -> error $ "mdot1Inner: Unequal shapes (" ++ show sh1 ++ " and " ++ show sh2 ++ ")"
ZSX -> error "unreachable"
@@ -925,31 +934,3 @@ mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c)
-> Mixed sh a -> Mixed sh b -> Mixed sh c
mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =
fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2))
-
-mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)
- => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a
-mcast ssh2 arr
- | Refl <- lemAppNil @sh1
- , Refl <- lemAppNil @sh2
- = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr
-
--- TODO: This should be `type data` but a bug in GHC 9.10 means that that throws linker errors
-data SafeMCastSpec
- = MCastId
- | MCastApp [Maybe Nat] [Maybe Nat] [Maybe Nat] [Maybe Nat] SafeMCastSpec SafeMCastSpec
- | MCastForget
-
-type SafeMCast :: SafeMCastSpec -> [Maybe Nat] -> [Maybe Nat] -> Constraint
-type family SafeMCast spec sh1 sh2 where
- SafeMCast MCastId sh sh = ()
- SafeMCast (MCastApp sh1A sh1B sh2A sh2B specA specB) sh1 sh2 = (sh1 ~ sh1A ++ sh1B, sh2 ~ sh2A ++ sh2B, SafeMCast specA sh1A sh2A, SafeMCast specB sh1B sh2B)
- SafeMCast MCastForget sh1 sh2 = sh2 ~ Replicate (Rank sh1) Nothing
-
--- | This is an O(1) operation: the 'SafeMCast' constraint ensures that
--- type-level shape information can only be forgotten, not introduced, and thus
--- that no runtime shape checks are required. The @spec@ describes to
--- 'SafeMCast' how exactly you intend @sh2@ to be a weakening of @sh1@.
---
--- To see how to construct the spec, read the equations of 'SafeMCast' closely.
-mcastSafe :: forall spec sh1 sh2 a proxy. SafeMCast spec sh1 sh2 => proxy spec -> Mixed sh1 a -> Mixed sh2 a
-mcastSafe _ = unsafeCoerce @(Mixed sh1 a) @(Mixed sh2 a)
diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index eb8434f..852dd5e 100644
--- a/src/Data/Array/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
@@ -20,7 +21,7 @@
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed.Shape where
+module Data.Array.Nested.Mixed.Shape where
import Control.DeepSeq (NFData(..))
import Data.Bifunctor (first)
@@ -30,7 +31,6 @@ import Data.Functor.Const
import Data.Functor.Product
import Data.Kind (Constraint, Type)
import Data.Monoid (Sum(..))
-import Data.Proxy
import Data.Type.Equality
import GHC.Exts (withDict)
import GHC.Generics (Generic)
@@ -38,7 +38,7 @@ import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
-import Data.Array.Mixed.Types
+import Data.Array.Nested.Types
-- | The length of a type-level list. If the argument is a shape, then the
@@ -59,8 +59,12 @@ deriving instance (forall n. Eq (f n)) => Eq (ListX sh f)
deriving instance (forall n. Ord (f n)) => Ord (ListX sh f)
infixr 3 ::%
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance (forall n. Show (f n)) => Show (ListX sh f)
+#else
instance (forall n. Show (f n)) => Show (ListX sh f) where
showsPrec _ = listxShow shows
+#endif
instance (forall n. NFData (f n)) => NFData (ListX sh f) where
rnf ZX = ()
@@ -141,9 +145,9 @@ listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f
listxAppend ZX idx' = idx'
listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx'
-listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f
-listxDrop long ZX = long
-listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short
+listxDrop :: forall f g sh sh'. ListX sh g -> ListX (sh ++ sh') f -> ListX sh' f
+listxDrop ZX long = long
+listxDrop (_ ::% short) long = case long of _ ::% long' -> listxDrop short long'
listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f
listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh
@@ -167,7 +171,7 @@ listxZipWith f (i ::% is) (j ::% js) =
-- * Mixed indices
--- | This is a newtype over 'ListX'.
+-- | An index into a mixed-typed array.
type role IxX nominal representational
type IxX :: [Maybe Nat] -> Type -> Type
newtype IxX sh i = IxX (ListX sh (Const i))
@@ -186,10 +190,16 @@ infixr 3 :.%
{-# COMPLETE ZIX, (:.%) #-}
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
type IIxX sh = IxX sh Int
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (IxX sh i)
+#else
instance Show i => Show (IxX sh i) where
- showsPrec _ (IxX l) = listxShow (\(Const i) -> shows i) l
+ showsPrec _ (IxX l) = listxShow (shows . getConst) l
+#endif
instance Functor (IxX sh) where
fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l)
@@ -225,7 +235,7 @@ ixxTail (IxX list) = IxX (listxTail list)
ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
ixxAppend = coerce (listxAppend @_ @(Const i))
-ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i
+ixxDrop :: forall sh sh' i. IxX sh i -> IxX (sh ++ sh') i -> IxX sh' i
ixxDrop = coerce (listxDrop @(Const i) @(Const i))
ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i
@@ -234,6 +244,11 @@ ixxInit = coerce (listxInit @(Const i))
ixxLast :: forall n sh i. IxX (n : sh) i -> i
ixxLast = coerce (listxLast @(Const i))
+ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i
+ixxCast ZKX ZIX = ZIX
+ixxCast (_ :!% sh) (i :.% idx) = i :.% ixxCast sh idx
+ixxCast _ _ = error "ixxCast: ranks don't match"
+
ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j)
ixxZip ZIX ZIX = ZIX
ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js
@@ -326,8 +341,12 @@ infixr 3 :$%
type IShX sh = ShX sh Int
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (ShX sh i)
+#else
instance Show i => Show (ShX sh i) where
showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
+#endif
instance Functor (ShX sh) where
fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l)
@@ -377,10 +396,10 @@ shxSize :: IShX sh -> Int
shxSize ZSX = 1
shxSize (n :$% sh) = fromSMayNat' n * shxSize sh
-shxFromList :: StaticShX sh -> [Int] -> ShX sh Int
+shxFromList :: StaticShX sh -> [Int] -> IShX sh
shxFromList topssh topl = go topssh topl
where
- go :: StaticShX sh' -> [Int] -> ShX sh' Int
+ go :: StaticShX sh' -> [Int] -> IShX sh'
go ZKX [] = ZSX
go (SKnown sn :!% sh) (i : is)
| i == fromSNat' sn = SKnown sn :$% go sh is
@@ -395,6 +414,19 @@ shxToList :: IShX sh -> [Int]
shxToList ZSX = []
shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh
+shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i
+shxFromSSX ZKX = ZSX
+shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh))
+ | Refl <- lemMapJustCons @sh Refl
+ = SKnown n :$% shxFromSSX sh
+shxFromSSX (SUnknown _ :!% _) = error "unreachable"
+
+-- | This may fail if @sh@ has @Nothing@s in it.
+shxFromSSX2 :: StaticShX sh -> Maybe (ShX sh i)
+shxFromSSX2 ZKX = Just ZSX
+shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh
+shxFromSSX2 (SUnknown _ :!% _) = Nothing
+
shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
shxAppend = coerce (listxAppend @_ @(SMayNat i SNat))
@@ -404,13 +436,13 @@ shxHead (ShX list) = listxHead list
shxTail :: ShX (n : sh) i -> ShX sh i
shxTail (ShX list) = ShX (listxTail list)
-shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i
+shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i
shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat))
-shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i
+shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i
shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j))
-shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i
+shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i
shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))
shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i
@@ -419,12 +451,9 @@ shxInit = coerce (listxInit @(SMayNat i SNat))
shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh))
shxLast = coerce (listxLast @(SMayNat i SNat))
-shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i
-shxTakeSSX _ = flip go
- where
- go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i
- go ZKX _ = ZSX
- go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh
+shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i
+shxTakeSSX _ ZKX _ = ZSX
+shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh
shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n)
-> ShX sh i -> ShX sh j -> ShX sh k
@@ -437,7 +466,7 @@ shxCompleteZeros ZKX = ZSX
shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh
shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh
-shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i)
+shxSplitApp :: proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i)
shxSplitApp _ ZKX idx = (ZSX, idx)
shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx)
@@ -448,17 +477,17 @@ shxEnum = \sh -> go sh id []
go ZSX f = (f ZIX :)
go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]]
-shxCast :: IShX sh -> StaticShX sh' -> Maybe (IShX sh')
-shxCast ZSX ZKX = Just ZSX
-shxCast (SKnown n :$% sh) (SKnown m :!% ssh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast sh ssh
-shxCast (SUnknown n :$% sh) (SKnown m :!% ssh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast sh ssh
-shxCast (SKnown n :$% sh) (SUnknown () :!% ssh) = (SUnknown (fromSNat' n) :$%) <$> shxCast sh ssh
-shxCast (SUnknown n :$% sh) (SUnknown () :!% ssh) = (SUnknown n :$%) <$> shxCast sh ssh
+shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh')
+shxCast ZKX ZSX = Just ZSX
+shxCast (SKnown m :!% ssh) (SKnown n :$% sh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast ssh sh
+shxCast (SKnown m :!% ssh) (SUnknown n :$% sh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast ssh sh
+shxCast (SUnknown () :!% ssh) (SKnown n :$% sh) = (SUnknown (fromSNat' n) :$%) <$> shxCast ssh sh
+shxCast (SUnknown () :!% ssh) (SUnknown n :$% sh) = (SUnknown n :$%) <$> shxCast ssh sh
shxCast _ _ = Nothing
-- | Partial version of 'shxCast'.
-shxCast' :: IShX sh -> StaticShX sh' -> IShX sh'
-shxCast' sh ssh = case shxCast sh ssh of
+shxCast' :: StaticShX sh' -> IShX sh -> IShX sh'
+shxCast' ssh sh = case shxCast ssh sh of
Just sh' -> sh'
Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")"
@@ -483,8 +512,12 @@ infixr 3 :!%
{-# COMPLETE ZKX, (:!%) #-}
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (StaticShX sh)
+#else
instance Show (StaticShX sh) where
showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
+#endif
instance NFData (StaticShX sh) where
rnf (StaticShX ZX) = ()
@@ -514,34 +547,34 @@ ssxHead (StaticShX list) = listxHead list
ssxTail :: StaticShX (n : sh) -> StaticShX sh
ssxTail (_ :!% ssh) = ssh
-ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'
+ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh'
+ssxDropSSX = coerce (listxDrop @(SMayNat () SNat) @(SMayNat () SNat))
+
+ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))
+ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
+ssxDropSh = coerce (listxDrop @(SMayNat () SNat) @(SMayNat i SNat))
+
ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh))
ssxInit = coerce (listxInit @(SMayNat () SNat))
ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh))
ssxLast = coerce (listxLast @(SMayNat () SNat))
--- | This may fail if @sh@ has @Nothing@s in it.
-ssxToShX' :: StaticShX sh -> Maybe (IShX sh)
-ssxToShX' ZKX = Just ZSX
-ssxToShX' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShX' sh
-ssxToShX' (SUnknown _ :!% _) = Nothing
-
ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
ssxReplicate SZ = ZKX
ssxReplicate (SS (n :: SNat n'))
| Refl <- lemReplicateSucc @(Nothing @Nat) @n'
= SUnknown () :!% ssxReplicate n
-ssxIotaFrom :: Int -> StaticShX sh -> [Int]
-ssxIotaFrom _ ZKX = []
-ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh
+ssxIotaFrom :: StaticShX sh -> Int -> [Int]
+ssxIotaFrom ZKX _ = []
+ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i+1)
-ssxFromShape :: IShX sh -> StaticShX sh
-ssxFromShape ZSX = ZKX
-ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh
+ssxFromShX :: ShX sh i -> StaticShX sh
+ssxFromShX ZSX = ZKX
+ssxFromShX (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShX sh
ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing)
ssxFromSNat SZ = ZKX
@@ -557,7 +590,7 @@ instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SK
instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX
withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r
-withKnownShX k = withDict @(KnownShX sh) k
+withKnownShX = withDict @(KnownShX sh)
-- * Flattening
diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Nested/Permutation.hs
index cedfa22..03d1640 100644
--- a/src/Data/Array/Mixed/Permutation.hs
+++ b/src/Data/Array/Nested/Permutation.hs
@@ -4,7 +4,6 @@
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -15,7 +14,7 @@
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed.Permutation where
+module Data.Array.Nested.Permutation where
import Data.Coerce (coerce)
import Data.Functor.Const
@@ -25,19 +24,20 @@ import Data.Proxy
import Data.Type.Bool
import Data.Type.Equality
import Data.Type.Ord
+import GHC.Exts (withDict)
import GHC.TypeError
import GHC.TypeLits
import GHC.TypeNats qualified as TN
-import Data.Array.Mixed.Shape
-import Data.Array.Mixed.Types
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Types
-- * Permutations
-- | A "backward" permutation of a dimension list. The operation on the
--- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute'
--- for code that implements this.
+-- dimension list is most similar to @backpermute@ in the @vector@ package; see
+-- 'Permute' for code that implements this.
data Perm list where
PNil :: Perm '[]
PCons :: SNat a -> Perm l -> Perm (a : l)
@@ -45,6 +45,13 @@ infixr 5 `PCons`
deriving instance Show (Perm list)
deriving instance Eq (Perm list)
+instance TestEquality Perm where
+ testEquality PNil PNil = Just Refl
+ testEquality (x `PCons` xs) (y `PCons` ys)
+ | Just Refl <- testEquality x y
+ , Just Refl <- testEquality xs ys = Just Refl
+ testEquality _ _ = Nothing
+
permRank :: Perm list -> SNat (Rank list)
permRank PNil = SNat
permRank (_ `PCons` l) | SNat <- permRank l = SNat
@@ -119,6 +126,9 @@ class KnownPerm l where makePerm :: Perm l
instance KnownPerm '[] where makePerm = PNil
instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm
+withKnownPerm :: forall l r. Perm l -> (KnownPerm l => r) -> r
+withKnownPerm = withDict @(KnownPerm l)
+
-- | Untyped permutations for ranked arrays
type PermR = [Int]
@@ -224,7 +234,7 @@ permInverse = \perm k ->
++ " ; invperm = " ++ show invperm)
(permCheckPermutation invperm
(k invperm
- (\ssh -> case provePermInverse perm invperm ssh of
+ (\ssh -> case permCheckInverse perm invperm ssh of
Just eq -> eq
Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm
++ " ; invperm = " ++ show invperm)))
@@ -238,9 +248,9 @@ permInverse = \perm k ->
toHList [] k = k PNil
toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l)
- provePermInverse :: Perm is -> Perm is' -> StaticShX sh
+ permCheckInverse :: Perm is -> Perm is' -> StaticShX sh
-> Maybe (Permute is' (Permute is sh) :~: sh)
- provePermInverse perm perminv ssh =
+ permCheckInverse perm perminv ssh =
ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh
type family MapSucc is where
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
new file mode 100644
index 0000000..9778c54
--- /dev/null
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -0,0 +1,323 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Ranked (
+ Ranked(Ranked),
+ rquotArray, rremArray, ratan2Array,
+ rshape, rrank,
+ module Data.Array.Nested.Ranked,
+ liftRanked1, liftRanked2,
+) where
+
+import Prelude hiding (mappend, mconcat)
+
+import Data.Array.RankedS qualified as S
+import Data.Bifunctor (first)
+import Data.Coerce (coerce)
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Data.Type.Equality
+import Data.Vector.Storable qualified as VS
+import Foreign.Storable (Storable)
+import GHC.TypeLits
+import GHC.TypeNats qualified as TN
+
+import Data.Array.Nested.Convert
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Ranked.Base
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
+import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray(..))
+import Data.Array.XArray qualified as X
+
+
+remptyArray :: KnownElt a => Ranked 1 a
+remptyArray = mtoRanked (memptyArray ZSX)
+
+-- | The total number of elements in the array.
+rsize :: Elt a => Ranked n a -> Int
+rsize = shrSize . rshape
+
+rindex :: Elt a => Ranked n a -> IIxR n -> a
+rindex (Ranked arr) idx = mindex arr (ixxFromIxR idx)
+
+rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a
+rindexPartial (Ranked arr) idx =
+ Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
+ (castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @Nothing))) arr)
+ (ixxFromIxR idx))
+
+-- | __WARNING__: All values returned from the function must have equal shape.
+-- See the documentation of 'mgenerate' for more details.
+rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a
+rgenerate sh f
+ | sn@SNat <- shrRank sh
+ , Dict <- lemKnownReplicate sn
+ , Refl <- lemRankReplicate sn
+ = Ranked (mgenerate (shxFromShR sh) (f . ixrFromIxX))
+
+-- | See the documentation of 'mlift'.
+rlift :: forall n1 n2 a. Elt a
+ => SNat n2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
+ -> Ranked n1 a -> Ranked n2 a
+rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr)
+
+-- | See the documentation of 'mlift2'.
+rlift2 :: forall n1 n2 n3 a. Elt a
+ => SNat n3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b)
+ -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a
+rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)
+
+rsumOuter1P :: forall n a.
+ (Storable a, NumElt a)
+ => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
+rsumOuter1P (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = Ranked (msumOuter1P arr)
+
+rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
+ => Ranked (n + 1) a -> Ranked n a
+rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive
+
+rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a
+rsumAllPrim (Ranked arr) = msumAllPrim arr
+
+rtranspose :: forall n a. Elt a => PermR -> Ranked n a -> Ranked n a
+rtranspose perm arr
+ | sn@SNat <- rrank arr
+ , Dict <- lemKnownReplicate sn
+ , length perm <= fromIntegral (natVal (Proxy @n))
+ = rlift sn
+ (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm)
+ arr
+ | otherwise
+ = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array"
+
+rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a
+rconcat
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = coerce mconcat
+
+rappend :: forall n a. Elt a
+ => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
+rappend arr1 arr2
+ | sn@SNat <- rrank arr1
+ , Dict <- lemKnownReplicate sn
+ , Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
+ arr1 arr2
+
+rscalar :: Elt a => a -> Ranked 0 a
+rscalar x = Ranked (mscalar x)
+
+rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a)
+rfromVectorP sh v
+ | Dict <- lemKnownReplicate (shrRank sh)
+ = Ranked (mfromVectorP (shxFromShR sh) v)
+
+rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a
+rfromVector sh v = rfromPrimitive (rfromVectorP sh v)
+
+rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a
+rtoVectorP = coerce mtoVectorP
+
+rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
+rtoVector = coerce mtoVector
+
+rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
+rfromList1 l = Ranked (mfromList1 l)
+
+rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
+rfromListOuter l
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
+
+rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
+rfromListLinear sh l = rreshape sh (rfromList1 l)
+
+rfromListPrim :: PrimElt a => [a] -> Ranked 1 a
+rfromListPrim l = Ranked (mfromListPrim l)
+
+rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a
+rfromListPrimLinear sh l =
+ let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ in Ranked $ fromPrimitive $ M_Primitive (shxFromShR sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShR sh) xarr)
+
+rtoList :: Elt a => Ranked 1 a -> [a]
+rtoList = map runScalar . rtoListOuter
+
+rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
+rtoListOuter (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
+
+rtoListLinear :: Elt a => Ranked n a -> [a]
+rtoListLinear (Ranked arr) = mtoListLinear arr
+
+rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a
+rfromOrthotope sn arr
+ | Refl <- lemRankReplicate sn
+ = let xarr = XArray arr
+ in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr))
+
+rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a
+rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr)))
+ | Refl <- lemRankReplicate (shrRank $ shrFromShX2 sh)
+ = arr
+
+runScalar :: Elt a => Ranked 0 a -> a
+runScalar arr = rindex arr ZIR
+
+rnest :: forall n m a. Elt a => SNat n -> Ranked (n + m) a -> Ranked n (Ranked m a)
+rnest n arr
+ | Refl <- lemReplicatePlusApp n (Proxy @m) (Proxy @(Nothing @Nat))
+ = coerce (mnest (ssxFromSNat n) (coerce arr))
+
+runNest :: forall n m a. Elt a => Ranked n (Ranked m a) -> Ranked (n + m) a
+runNest rarr@(Ranked (M_Ranked (M_Nest _ arr)))
+ | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat))
+ = Ranked arr
+
+rzip :: (Elt a, Elt b) => Ranked n a -> Ranked n b -> Ranked n (a, b)
+rzip = coerce mzip
+
+runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b)
+runzip = coerce munzip
+
+rrerankP :: forall n1 n2 n a b. (Storable a, Storable b)
+ => SNat n -> IShR n2
+ -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b))
+ -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b)
+rrerankP sn sh2 f (Ranked arr)
+ | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat))
+ , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat))
+ = Ranked (mrerankP (ssxFromSNat sn) (shxFromShR sh2)
+ (\a -> let Ranked r = f (Ranked a) in r)
+ arr)
+
+-- | If there is a zero-sized dimension in the @n@-prefix of the shape of the
+-- input array, then there is no way to deduce the full shape of the output
+-- array (more precisely, the @n2@ part): that could only come from calling
+-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in
+-- this case; we choose to fill the @n2@ part of the output shape with zeros.
+--
+-- For example, if:
+--
+-- @
+-- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21]
+-- f :: Ranked 2 Int -> Ranked 3 Float
+-- @
+--
+-- then:
+--
+-- @
+-- rrerank _ _ _ f arr :: Ranked 6 Float
+-- @
+--
+-- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the
+-- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended
+-- to return an array with shape all-0 here (it probably didn't), but there is
+-- no better number to put here absent a subarray of the input to pass to @f@.
+rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
+ => SNat n -> IShR n2
+ -> (Ranked n1 a -> Ranked n2 b)
+ -> Ranked (n + n1) a -> Ranked (n + n2) b
+rrerank sn sh2 f (rtoPrimitive -> arr) =
+ rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr
+
+rreplicate :: forall n m a. Elt a
+ => IShR n -> Ranked m a -> Ranked (n + m) a
+rreplicate sh (Ranked arr)
+ | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat))
+ = Ranked (mreplicate (shxFromShR sh) arr)
+
+rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
+rreplicateScalP sh x
+ | Dict <- lemKnownReplicate (shrRank sh)
+ = Ranked (mreplicateScalP (shxFromShR sh) x)
+
+rreplicateScal :: forall n a. PrimElt a
+ => IShR n -> a -> Ranked n a
+rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
+
+rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
+rslice i n arr
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = rlift (rrank arr)
+ (\_ -> X.sliceU i n)
+ arr
+
+rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
+rrev1 arr =
+ rlift (rrank arr)
+ (\(_ :: StaticShX sh') ->
+ case lemReplicateSucc @(Nothing @Nat) @n of
+ Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))
+ arr
+
+rreshape :: forall n n' a. Elt a
+ => IShR n' -> Ranked n a -> Ranked n' a
+rreshape sh' rarr@(Ranked arr)
+ | Dict <- lemKnownReplicate (rrank rarr)
+ , Dict <- lemKnownReplicate (shrRank sh')
+ = Ranked (mreshape (shxFromShR sh') arr)
+
+rflatten :: Elt a => Ranked n a -> Ranked 1 a
+rflatten (Ranked arr) = mtoRanked (mflatten arr)
+
+riota :: (Enum a, PrimElt a) => Int -> Ranked 1 a
+riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota
+
+-- | Throws if the array is empty.
+rminIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
+rminIndexPrim rarr@(Ranked arr)
+ | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
+ = ixrFromIxX (mminIndexPrim arr)
+
+-- | Throws if the array is empty.
+rmaxIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
+rmaxIndexPrim rarr@(Ranked arr)
+ | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
+ = ixrFromIxX (mmaxIndexPrim arr)
+
+rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a
+rdot1Inner arr1 arr2
+ | SNat <- rrank arr1
+ , Refl <- lemReplicatePlusApp (SNat @n) (Proxy @1) (Proxy @(Nothing @Nat))
+ = coerce (mdot1Inner (Proxy @(Nothing @Nat))) arr1 arr2
+
+-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
+-- Prefer 'rdot1Inner' if applicable.
+rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a
+rdot = coerce mdot
+
+rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
+rtoXArrayPrimP (Ranked arr) = first shrFromShX2 (mtoXArrayPrimP arr)
+
+rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a)
+rtoXArrayPrim (Ranked arr) = first shrFromShX2 (mtoXArrayPrim arr)
+
+rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a)
+rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr)
+
+rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a
+rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr)
+
+rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a
+rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr)
+
+rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a)
+rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr)
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
new file mode 100644
index 0000000..babc809
--- /dev/null
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -0,0 +1,268 @@
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_HADDOCK not-home #-}
+module Data.Array.Nested.Ranked.Base where
+
+import Prelude hiding (mappend, mconcat)
+
+import Control.DeepSeq (NFData(..))
+import Control.Monad.ST
+import Data.Bifunctor (first)
+import Data.Coerce (coerce)
+import Data.Kind (Type)
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Data.Type.Equality
+import Foreign.Storable (Storable)
+import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
+import GHC.Generics (Generic)
+import GHC.TypeLits
+
+#ifndef OXAR_DEFAULT_SHOW_INSTANCES
+import Data.Foldable (toList)
+#endif
+
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
+import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray(..))
+
+
+-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
+-- represented on the type level as a 'Nat'.
+--
+-- Valid elements of a ranked arrays are described by the 'Elt' type class.
+-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
+-- supported (and are represented as a single, flattened, struct-of-arrays
+-- array internally).
+--
+-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
+type Ranked :: Nat -> Type -> Type
+newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)
+#endif
+deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a)
+deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a)
+
+#ifndef OXAR_DEFAULT_SHOW_INSTANCES
+instance (Show a, Elt a) => Show (Ranked n a) where
+ showsPrec d arr@(Ranked marr) =
+ let sh = show (toList (rshape arr))
+ in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr
+#endif
+
+instance Elt a => NFData (Ranked n a) where
+ rnf (Ranked arr) = rnf arr
+
+-- just unwrap the newtype and defer to the general instance for nested arrays
+newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
+ deriving (Generic)
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a))
+#endif
+
+deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a))
+
+newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
+
+-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
+-- these instances allow them to also be used as elements of arrays, thus
+-- making them first-class in the API.
+instance Elt a => Elt (Ranked n a) where
+ mshape (M_Ranked arr) = mshape arr
+ mindex (M_Ranked arr) i = Ranked (mindex arr i)
+
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
+ mindexPartial (M_Ranked arr) i =
+ coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
+ mindexPartial arr i
+
+ mscalar (Ranked x) = M_Ranked (M_Nest ZSX x)
+
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a)
+ mfromListOuter l = M_Ranked (mfromListOuter (coerce l))
+
+ mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
+ mtoListOuter (M_Ranked arr) =
+ coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr)
+
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
+ mlift ssh2 f (M_Ranked arr) =
+ coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
+ mlift ssh2 f arr
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a)
+ mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) =
+ coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
+ mlift2 ssh3 f arr1 arr2
+
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
+ -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a))
+ mliftL ssh2 f l =
+ coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a)))
+ @(NonEmpty (Mixed sh2 (Ranked n a))) $
+ mliftL ssh2 f (coerce l)
+
+ mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr)
+
+ mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr)
+
+ mconcat l = M_Ranked (mconcat (coerce l))
+
+ mrnf (M_Ranked arr) = mrnf arr
+
+ type ShapeTree (Ranked n a) = (IShR n, ShapeTree a)
+
+ mshapeTree (Ranked arr) = first shrFromShX2 (mshapeTree arr)
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+
+ mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
+ marrayStrides (M_Ranked arr) = marrayStrides arr
+
+ mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
+ mvecsWrite sh idx (Ranked arr) vecs =
+ mvecsWrite sh idx arr
+ (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsWritePartial :: forall sh sh' s.
+ IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
+ -> MixedVecs s (sh ++ sh') (Ranked n a)
+ -> ST s ()
+ mvecsWritePartial sh idx arr vecs =
+ mvecsWritePartial sh idx
+ (coerce @(Mixed sh' (Ranked n a))
+ @(Mixed sh' (Mixed (Replicate n Nothing) a))
+ arr)
+ (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
+ @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
+ mvecsFreeze sh vecs =
+ coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
+ @(Mixed sh (Ranked n a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh (Ranked n a))
+ @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
+ memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a)
+ memptyArrayUnsafe i
+ | Dict <- lemKnownReplicate (SNat @n)
+ = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
+ memptyArrayUnsafe i
+
+ mvecsUnsafeNew idx (Ranked arr)
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsUnsafeNew idx arr
+
+ mvecsNewEmpty _
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
+
+
+liftRanked1 :: forall n a b.
+ (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b)
+ -> Ranked n a -> Ranked n b
+liftRanked1 = coerce
+
+liftRanked2 :: forall n a b c.
+ (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c)
+ -> Ranked n a -> Ranked n b -> Ranked n c
+liftRanked2 = coerce
+
+instance (NumElt a, PrimElt a) => Num (Ranked n a) where
+ (+) = liftRanked2 (+)
+ (-) = liftRanked2 (-)
+ (*) = liftRanked2 (*)
+ negate = liftRanked1 negate
+ abs = liftRanked1 abs
+ signum = liftRanked1 signum
+ fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal"
+
+instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where
+ fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal"
+ recip = liftRanked1 recip
+ (/) = liftRanked2 (/)
+
+instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where
+ pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal"
+ exp = liftRanked1 exp
+ log = liftRanked1 log
+ sqrt = liftRanked1 sqrt
+ (**) = liftRanked2 (**)
+ logBase = liftRanked2 logBase
+ sin = liftRanked1 sin
+ cos = liftRanked1 cos
+ tan = liftRanked1 tan
+ asin = liftRanked1 asin
+ acos = liftRanked1 acos
+ atan = liftRanked1 atan
+ sinh = liftRanked1 sinh
+ cosh = liftRanked1 cosh
+ tanh = liftRanked1 tanh
+ asinh = liftRanked1 asinh
+ acosh = liftRanked1 acosh
+ atanh = liftRanked1 atanh
+ log1p = liftRanked1 GHC.Float.log1p
+ expm1 = liftRanked1 GHC.Float.expm1
+ log1pexp = liftRanked1 GHC.Float.log1pexp
+ log1mexp = liftRanked1 GHC.Float.log1mexp
+
+rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
+rquotArray = liftRanked2 mquotArray
+rremArray = liftRanked2 mremArray
+
+ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
+ratan2Array = liftRanked2 matan2Array
+
+
+rshape :: Elt a => Ranked n a -> IShR n
+rshape (Ranked arr) = shrFromShX2 (mshape arr)
+
+rrank :: Elt a => Ranked n a -> SNat n
+rrank = shrRank . rshape
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+shrFromShX :: forall sh. IShX sh -> IShR (Rank sh)
+shrFromShX ZSX = ZSR
+shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+-- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'.
+shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n
+shrFromShX2 sh
+ | Refl <- lemRankReplicate (Proxy @n)
+ = shrFromShX sh
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
new file mode 100644
index 0000000..8b670e5
--- /dev/null
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -0,0 +1,369 @@
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveFoldable #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Ranked.Shape where
+
+import Control.DeepSeq (NFData(..))
+import Data.Coerce (coerce)
+import Data.Foldable qualified as Foldable
+import Data.Kind (Type)
+import Data.Proxy
+import Data.Type.Equality
+import GHC.Generics (Generic)
+import GHC.IsList (IsList)
+import GHC.IsList qualified as IsList
+import GHC.TypeLits
+import GHC.TypeNats qualified as TN
+
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Types
+
+
+-- * Ranked lists
+
+type role ListR nominal representational
+type ListR :: Nat -> Type -> Type
+data ListR n i where
+ ZR :: ListR 0 i
+ (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
+deriving instance Eq i => Eq (ListR n i)
+deriving instance Ord i => Ord (ListR n i)
+deriving instance Functor (ListR n)
+deriving instance Foldable (ListR n)
+infixr 3 :::
+
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (ListR n i)
+#else
+instance Show i => Show (ListR n i) where
+ showsPrec _ = listrShow shows
+#endif
+
+instance NFData i => NFData (ListR n i) where
+ rnf ZR = ()
+ rnf (x ::: l) = rnf x `seq` rnf l
+
+data UnconsListRRes i n1 =
+ forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
+listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
+listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)
+listrUncons ZR = Nothing
+
+-- | This checks only whether the ranks are equal, not whether the actual
+-- values are.
+listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n')
+listrEqRank ZR ZR = Just Refl
+listrEqRank (_ ::: sh) (_ ::: sh')
+ | Just Refl <- listrEqRank sh sh'
+ = Just Refl
+listrEqRank _ _ = Nothing
+
+-- | This compares the lists for value equality.
+listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n')
+listrEqual ZR ZR = Just Refl
+listrEqual (i ::: sh) (j ::: sh')
+ | Just Refl <- listrEqual sh sh'
+ , i == j
+ = Just Refl
+listrEqual _ _ = Nothing
+
+listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS
+listrShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> ListR n' i -> ShowS
+ go _ ZR = id
+ go prefix (x ::: xs) = showString prefix . f x . go "," xs
+
+listrLength :: ListR n i -> Int
+listrLength = length
+
+listrRank :: ListR n i -> SNat n
+listrRank ZR = SNat
+listrRank (_ ::: sh) = snatSucc (listrRank sh)
+
+listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
+listrAppend ZR sh = sh
+listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
+
+listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
+listrFromList [] k = k ZR
+listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
+
+listrHead :: ListR (n + 1) i -> i
+listrHead (i ::: _) = i
+listrHead ZR = error "unreachable"
+
+listrTail :: ListR (n + 1) i -> ListR n i
+listrTail (_ ::: sh) = sh
+listrTail ZR = error "unreachable"
+
+listrInit :: ListR (n + 1) i -> ListR n i
+listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh
+listrInit (_ ::: ZR) = ZR
+listrInit ZR = error "unreachable"
+
+listrLast :: ListR (n + 1) i -> i
+listrLast (_ ::: sh@(_ ::: _)) = listrLast sh
+listrLast (n ::: ZR) = n
+listrLast ZR = error "unreachable"
+
+-- | Performs a runtime check that the lengths are identical.
+listrCast :: SNat n' -> ListR n i -> ListR n' i
+listrCast = listrCastWithName "listrCast"
+
+listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
+listrIndex SZ (x ::: _) = x
+listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
+listrIndex _ ZR = error "k + 1 <= 0"
+
+listrZip :: ListR n i -> ListR n j -> ListR n (i, j)
+listrZip ZR ZR = ZR
+listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest
+listrZip _ _ = error "listrZip: impossible pattern needlessly required"
+
+listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k
+listrZipWith _ ZR ZR = ZR
+listrZipWith f (i ::: irest) (j ::: jrest) =
+ f i j ::: listrZipWith f irest jrest
+listrZipWith _ _ _ =
+ error "listrZipWith: impossible pattern needlessly required"
+
+listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
+listrPermutePrefix = \perm sh ->
+ listrFromList perm $ \sperm ->
+ case (listrRank sperm, listrRank sh) of
+ (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
+ LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
+ ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ where
+ listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
+ listrSplitAt SZ sh = (ZR, sh)
+ listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
+ listrSplitAt SS{} ZR = error "m' + 1 <= 0"
+
+ applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
+ applyPermRFull _ ZR _ = ZR
+ applyPermRFull sm@SNat (i ::: perm) l =
+ TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
+ case cmpNat (SNat @(idx + 1)) sm of
+ LTI -> listrIndex si l ::: applyPermRFull sm perm l
+ EQI -> listrIndex si l ::: applyPermRFull sm perm l
+ GTI -> error "listrPermutePrefix: Index in permutation out of range"
+
+
+-- * Ranked indices
+
+-- | An index into a rank-typed array.
+type role IxR nominal representational
+type IxR :: Nat -> Type -> Type
+newtype IxR n i = IxR (ListR n i)
+ deriving (Eq, Ord, Generic)
+ deriving newtype (Functor, Foldable)
+
+pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
+pattern ZIR = IxR ZR
+
+pattern (:.:)
+ :: forall {n1} {i}.
+ forall n. (n + 1 ~ n1)
+ => i -> IxR n i -> IxR n1 i
+pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i))
+ where i :.: IxR sh = IxR (i ::: sh)
+infixr 3 :.:
+
+{-# COMPLETE ZIR, (:.:) #-}
+
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
+type IIxR n = IxR n Int
+
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (IxR n i)
+#else
+instance Show i => Show (IxR n i) where
+ showsPrec _ (IxR l) = listrShow shows l
+#endif
+
+instance NFData i => NFData (IxR sh i)
+
+ixrLength :: IxR sh i -> Int
+ixrLength (IxR l) = listrLength l
+
+ixrRank :: IxR n i -> SNat n
+ixrRank (IxR sh) = listrRank sh
+
+ixrZero :: SNat n -> IIxR n
+ixrZero SZ = ZIR
+ixrZero (SS n) = 0 :.: ixrZero n
+
+ixrHead :: IxR (n + 1) i -> i
+ixrHead (IxR list) = listrHead list
+
+ixrTail :: IxR (n + 1) i -> IxR n i
+ixrTail (IxR list) = IxR (listrTail list)
+
+ixrInit :: IxR (n + 1) i -> IxR n i
+ixrInit (IxR list) = IxR (listrInit list)
+
+ixrLast :: IxR (n + 1) i -> i
+ixrLast (IxR list) = listrLast list
+
+-- | Performs a runtime check that the lengths are identical.
+ixrCast :: SNat n' -> IxR n i -> IxR n' i
+ixrCast n (IxR idx) = IxR (listrCastWithName "ixrCast" n idx)
+
+ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
+ixrAppend = coerce (listrAppend @_ @i)
+
+ixrZip :: IxR n i -> IxR n j -> IxR n (i, j)
+ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2
+
+ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
+ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
+
+ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
+ixrPermutePrefix = coerce (listrPermutePrefix @i)
+
+
+-- * Ranked shapes
+
+type role ShR nominal representational
+type ShR :: Nat -> Type -> Type
+newtype ShR n i = ShR (ListR n i)
+ deriving (Eq, Ord, Generic)
+ deriving newtype (Functor, Foldable)
+
+pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
+pattern ZSR = ShR ZR
+
+pattern (:$:)
+ :: forall {n1} {i}.
+ forall n. (n + 1 ~ n1)
+ => i -> ShR n i -> ShR n1 i
+pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
+ where i :$: ShR sh = ShR (i ::: sh)
+infixr 3 :$:
+
+{-# COMPLETE ZSR, (:$:) #-}
+
+type IShR n = ShR n Int
+
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (ShR n i)
+#else
+instance Show i => Show (ShR n i) where
+ showsPrec _ (ShR l) = listrShow shows l
+#endif
+
+instance NFData i => NFData (ShR sh i)
+
+-- | This checks only whether the ranks are equal, not whether the actual
+-- values are.
+shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
+shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh'
+
+-- | This compares the shapes for value equality.
+shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n')
+shrEqual (ShR sh) (ShR sh') = listrEqual sh sh'
+
+shrLength :: ShR sh i -> Int
+shrLength (ShR l) = listrLength l
+
+-- | This function can also be used to conjure up a 'KnownNat' dictionary;
+-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern
+-- synonym yields 'KnownNat' evidence.
+shrRank :: ShR n i -> SNat n
+shrRank (ShR sh) = listrRank sh
+
+-- | The number of elements in an array described by this shape.
+shrSize :: IShR n -> Int
+shrSize ZSR = 1
+shrSize (n :$: sh) = n * shrSize sh
+
+shrHead :: ShR (n + 1) i -> i
+shrHead (ShR list) = listrHead list
+
+shrTail :: ShR (n + 1) i -> ShR n i
+shrTail (ShR list) = ShR (listrTail list)
+
+shrInit :: ShR (n + 1) i -> ShR n i
+shrInit (ShR list) = ShR (listrInit list)
+
+shrLast :: ShR (n + 1) i -> i
+shrLast (ShR list) = listrLast list
+
+-- | Performs a runtime check that the lengths are identical.
+shrCast :: SNat n' -> ShR n i -> ShR n' i
+shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh)
+
+shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
+shrAppend = coerce (listrAppend @_ @i)
+
+shrZip :: ShR n i -> ShR n j -> ShR n (i, j)
+shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2
+
+shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k
+shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2
+
+shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
+shrPermutePrefix = coerce (listrPermutePrefix @i)
+
+
+-- | Untyped: length is checked at runtime.
+instance KnownNat n => IsList (ListR n i) where
+ type Item (ListR n i) = i
+ fromList topl = go (SNat @n) topl
+ where
+ go :: SNat n' -> [i] -> ListR n' i
+ go SZ [] = ZR
+ go (SS n) (i : is) = i ::: go n is
+ go _ _ = error $ "IsList(ListR): Mismatched list length (type says "
+ ++ show (fromSNat (SNat @n)) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+ toList = Foldable.toList
+
+-- | Untyped: length is checked at runtime.
+instance KnownNat n => IsList (IxR n i) where
+ type Item (IxR n i) = i
+ fromList = IxR . IsList.fromList
+ toList = Foldable.toList
+
+-- | Untyped: length is checked at runtime.
+instance KnownNat n => IsList (ShR n i) where
+ type Item (ShR n i) = i
+ fromList = ShR . IsList.fromList
+ toList = Foldable.toList
+
+
+-- * Internal helper functions
+
+listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i
+listrCastWithName _ SZ ZR = ZR
+listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx
+listrCastWithName name _ _ = error $ name ++ ": ranks don't match"
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
new file mode 100644
index 0000000..198a068
--- /dev/null
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -0,0 +1,272 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+module Data.Array.Nested.Shaped (
+ Shaped(Shaped),
+ squotArray, sremArray, satan2Array,
+ sshape,
+ module Data.Array.Nested.Shaped,
+ liftShaped1, liftShaped2,
+) where
+
+import Prelude hiding (mappend, mconcat)
+
+import Data.Array.Internal.RankedG qualified as RG
+import Data.Array.Internal.RankedS qualified as RS
+import Data.Array.Internal.ShapedG qualified as SG
+import Data.Array.Internal.ShapedS qualified as SS
+import Data.Bifunctor (first)
+import Data.Coerce (coerce)
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Data.Type.Equality
+import Data.Vector.Storable qualified as VS
+import Foreign.Storable (Storable)
+import GHC.TypeLits
+
+import Data.Array.Nested.Convert
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Shaped.Base
+import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
+import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray)
+import Data.Array.XArray qualified as X
+
+
+semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a
+semptyArray sh = Shaped (memptyArray (shxFromShS sh))
+
+srank :: Elt a => Shaped sh a -> SNat (Rank sh)
+srank = shsRank . sshape
+
+-- | The total number of elements in the array.
+ssize :: Elt a => Shaped sh a -> Int
+ssize = shsSize . sshape
+
+sindex :: Elt a => Shaped sh a -> IIxS sh -> a
+sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx)
+
+shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh
+shsTakeIx _ _ ZIS = ZSS
+shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx
+
+sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
+sindexPartial sarr@(Shaped arr) idx =
+ Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
+ (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr)
+ (ixxFromIxS idx))
+
+-- | __WARNING__: All values returned from the function must have equal shape.
+-- See the documentation of 'mgenerate' for more details.
+sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
+sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh))
+
+-- | See the documentation of 'mlift'.
+slift :: forall sh1 sh2 a. Elt a
+ => ShS sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
+ -> Shaped sh1 a -> Shaped sh2 a
+slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShX (shxFromShS sh2)) f arr)
+
+-- | See the documentation of 'mlift'.
+slift2 :: forall sh1 sh2 sh3 a. Elt a
+ => ShS sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)
+ -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
+slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2)
+
+ssumOuter1P :: forall sh n a. (Storable a, NumElt a)
+ => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
+ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr)
+
+ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
+ => Shaped (n : sh) a -> Shaped sh a
+ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive
+
+ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a
+ssumAllPrim (Shaped arr) = msumAllPrim arr
+
+stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a)
+ => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a
+stranspose perm sarr@(Shaped arr)
+ | Refl <- lemRankMapJust (sshape sarr)
+ , Refl <- lemTakeLenMapJust perm (sshape sarr)
+ , Refl <- lemDropLenMapJust perm (sshape sarr)
+ , Refl <- lemPermuteMapJust perm (shsTakeLen perm (sshape sarr))
+ , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh))
+ = Shaped (mtranspose perm arr)
+
+sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
+sappend = coerce mappend
+
+sscalar :: Elt a => a -> Shaped '[] a
+sscalar x = Shaped (mscalar x)
+
+sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a)
+sfromVectorP sh v = Shaped (mfromVectorP (shxFromShS sh) v)
+
+sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a
+sfromVector sh v = sfromPrimitive (sfromVectorP sh v)
+
+stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a
+stoVectorP = coerce mtoVectorP
+
+stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
+stoVector = coerce mtoVector
+
+sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
+sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1
+
+sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
+sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l))
+
+sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a
+sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l)
+
+sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
+sfromListPrim sn l
+ | Refl <- lemAppNil @'[Just n]
+ = let ssh = SUnknown () :!% ZKX
+ xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l)
+ in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr
+
+sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a
+sfromListPrimLinear sh l =
+ let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ in Shaped $ fromPrimitive $ M_Primitive (shxFromShS sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShS sh) xarr)
+
+stoList :: Elt a => Shaped '[n] a -> [a]
+stoList = map sunScalar . stoListOuter
+
+stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
+stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
+
+stoListLinear :: Elt a => Shaped sh a -> [a]
+stoListLinear (Shaped arr) = mtoListLinear arr
+
+sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a
+sfromOrthotope sh (SS.A (SG.A arr)) =
+ Shaped (fromPrimitive (M_Primitive (shxFromShS sh) (X.XArray (RS.A (RG.A (shsToList sh) arr)))))
+
+stoOrthotope :: PrimElt a => Shaped sh a -> SS.Array sh a
+stoOrthotope (stoPrimitive -> Shaped (M_Primitive _ (X.XArray (RS.A (RG.A _ arr))))) = SS.A (SG.A arr)
+
+sunScalar :: Elt a => Shaped '[] a -> a
+sunScalar arr = sindex arr ZIS
+
+snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a)
+snest sh arr
+ | Refl <- lemMapJustApp sh (Proxy @sh')
+ = coerce (mnest (ssxFromShX (shxFromShS sh)) (coerce arr))
+
+sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a
+sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr)))
+ | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh')
+ = Shaped arr
+
+szip :: (Elt a, Elt b) => Shaped sh a -> Shaped sh b -> Shaped sh (a, b)
+szip = coerce mzip
+
+sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b)
+sunzip = coerce munzip
+
+srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
+ => ShS sh -> ShS sh2
+ -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))
+ -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b)
+srerankP sh sh2 f sarr@(Shaped arr)
+ | Refl <- lemMapJustApp sh (Proxy @sh1)
+ , Refl <- lemMapJustApp sh (Proxy @sh2)
+ = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (ssxFromShX (shxFromShS sh)) (shxFromShS (sshape sarr))))
+ (shxFromShS sh2)
+ (\a -> let Shaped r = f (Shaped a) in r)
+ arr)
+
+-- | See the caveats at 'Data.Array.XArray.rerank'.
+srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
+ => ShS sh -> ShS sh2
+ -> (Shaped sh1 a -> Shaped sh2 b)
+ -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b
+srerank sh sh2 f (stoPrimitive -> arr) =
+ sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr
+
+sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a
+sreplicate sh (Shaped arr)
+ | Refl <- lemMapJustApp sh (Proxy @sh')
+ = Shaped (mreplicate (shxFromShS sh) arr)
+
+sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
+sreplicateScalP sh x = Shaped (mreplicateScalP (shxFromShS sh) x)
+
+sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a
+sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x)
+
+sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
+sslice i n@SNat arr =
+ let _ :$$ sh = sshape arr
+ in slift (n :$$ sh) (\_ -> X.slice i n) arr
+
+srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a
+srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr
+
+sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shaped sh' a
+sreshape sh' (Shaped arr) = Shaped (mreshape (shxFromShS sh') arr)
+
+sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a
+sflatten arr =
+ case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff
+ n@SNat -> sreshape (n :$$ ZSS) arr
+
+siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
+siota sn = Shaped (miota sn)
+
+-- | Throws if the array is empty.
+sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
+sminIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mminIndexPrim arr)
+
+-- | Throws if the array is empty.
+smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
+smaxIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr)
+
+sdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
+ => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a
+sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2)
+ | Refl <- lemInitApp (Proxy @sh) (Proxy @n)
+ , Refl <- lemLastApp (Proxy @sh) (Proxy @n)
+ = case sshape sarr1 of
+ _ :$$ _
+ | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n])
+ -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2)
+ _ -> error "unreachable"
+
+-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
+-- Prefer 'sdot1Inner' if applicable.
+sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a
+sdot = coerce mdot
+
+stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
+stoXArrayPrimP (Shaped arr) = first shsFromShX (mtoXArrayPrimP arr)
+
+stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a)
+stoXArrayPrim (Shaped arr) = first shsFromShX (mtoXArrayPrim arr)
+
+sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a)
+sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShX (shxFromShS sh)) arr)
+
+sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a
+sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShX (shxFromShS sh)) arr)
+
+sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a
+sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr)
+
+stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a)
+stoPrimitive (Shaped arr) = Shaped (toPrimitive arr)
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
new file mode 100644
index 0000000..879e6b5
--- /dev/null
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -0,0 +1,255 @@
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_HADDOCK not-home #-}
+module Data.Array.Nested.Shaped.Base where
+
+import Prelude hiding (mappend, mconcat)
+
+import Control.DeepSeq (NFData(..))
+import Control.Monad.ST
+import Data.Bifunctor (first)
+import Data.Coerce (coerce)
+import Data.Kind (Type)
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Data.Type.Equality
+import Foreign.Storable (Storable)
+import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
+import GHC.Generics (Generic)
+import GHC.TypeLits
+
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Shaped.Shape
+import Data.Array.Nested.Types
+import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray)
+
+
+-- | A shape-typed array: the full shape of the array (the sizes of its
+-- dimensions) is represented on the type level as a list of 'Nat's. Note that
+-- these are "GHC.TypeLits" naturals, because we do not need induction over
+-- them and we want very large arrays to be possible.
+--
+-- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
+-- and 'Shaped' itself is again an instance of 'Elt' as well.
+--
+-- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
+type Shaped :: [Nat] -> Type -> Type
+newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a)
+#endif
+deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a)
+deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a)
+
+#ifndef OXAR_DEFAULT_SHOW_INSTANCES
+instance (Show a, Elt a) => Show (Shaped n a) where
+ showsPrec d arr@(Shaped marr) =
+ let sh = show (shsToList (sshape arr))
+ in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr
+#endif
+
+instance Elt a => NFData (Shaped sh a) where
+ rnf (Shaped arr) = rnf arr
+
+-- just unwrap the newtype and defer to the general instance for nested arrays
+newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))
+ deriving (Generic)
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (Mixed sh (Mixed (MapJust sh') a)) => Show (Mixed sh (Shaped sh' a))
+#endif
+
+deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a))
+
+newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))
+
+instance Elt a => Elt (Shaped sh a) where
+ mshape (M_Shaped arr) = mshape arr
+ mindex (M_Shaped arr) i = Shaped (mindex arr i)
+
+ mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
+ mindexPartial (M_Shaped arr) i =
+ coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mindexPartial arr i
+
+ mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)
+
+ mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
+ mfromListOuter l = M_Shaped (mfromListOuter (coerce l))
+
+ mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
+ mtoListOuter (M_Shaped arr)
+ = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)
+
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
+ mlift ssh2 f (M_Shaped arr) =
+ coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mlift ssh2 f arr
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a)
+ mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) =
+ coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
+ mlift2 ssh3 f arr1 arr2
+
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
+ -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a))
+ mliftL ssh2 f l =
+ coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a)))
+ @(NonEmpty (Mixed sh2 (Shaped sh a))) $
+ mliftL ssh2 f (coerce l)
+
+ mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr)
+
+ mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)
+
+ mconcat l = M_Shaped (mconcat (coerce l))
+
+ mrnf (M_Shaped arr) = mrnf arr
+
+ type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
+
+ mshapeTree (Shaped arr) = first shsFromShX (mshapeTree arr)
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+
+ mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
+ marrayStrides (M_Shaped arr) = marrayStrides arr
+
+ mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
+ mvecsWrite sh idx (Shaped arr) vecs =
+ mvecsWrite sh idx arr
+ (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+
+ mvecsWritePartial :: forall sh1 sh2 s.
+ IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
+ -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
+ -> ST s ()
+ mvecsWritePartial sh idx arr vecs =
+ mvecsWritePartial sh idx
+ (coerce @(Mixed sh2 (Shaped sh a))
+ @(Mixed sh2 (Mixed (MapJust sh) a))
+ arr)
+ (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
+ @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
+ vecs)
+
+ mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
+ mvecsFreeze sh vecs =
+ coerce @(Mixed sh' (Mixed (MapJust sh) a))
+ @(Mixed sh' (Shaped sh a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh' (Shaped sh a))
+ @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+
+instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
+ memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
+ memptyArrayUnsafe i
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
+ memptyArrayUnsafe i
+
+ mvecsUnsafeNew idx (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsUnsafeNew idx arr
+
+ mvecsNewEmpty _
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
+
+
+liftShaped1 :: forall sh a b.
+ (Mixed (MapJust sh) a -> Mixed (MapJust sh) b)
+ -> Shaped sh a -> Shaped sh b
+liftShaped1 = coerce
+
+liftShaped2 :: forall sh a b c.
+ (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c)
+ -> Shaped sh a -> Shaped sh b -> Shaped sh c
+liftShaped2 = coerce
+
+instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
+ (+) = liftShaped2 (+)
+ (-) = liftShaped2 (-)
+ (*) = liftShaped2 (*)
+ negate = liftShaped1 negate
+ abs = liftShaped1 abs
+ signum = liftShaped1 signum
+ fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal"
+
+instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where
+ fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"
+ recip = liftShaped1 recip
+ (/) = liftShaped2 (/)
+
+instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"
+ exp = liftShaped1 exp
+ log = liftShaped1 log
+ sqrt = liftShaped1 sqrt
+ (**) = liftShaped2 (**)
+ logBase = liftShaped2 logBase
+ sin = liftShaped1 sin
+ cos = liftShaped1 cos
+ tan = liftShaped1 tan
+ asin = liftShaped1 asin
+ acos = liftShaped1 acos
+ atan = liftShaped1 atan
+ sinh = liftShaped1 sinh
+ cosh = liftShaped1 cosh
+ tanh = liftShaped1 tanh
+ asinh = liftShaped1 asinh
+ acosh = liftShaped1 acosh
+ atanh = liftShaped1 atanh
+ log1p = liftShaped1 GHC.Float.log1p
+ expm1 = liftShaped1 GHC.Float.expm1
+ log1pexp = liftShaped1 GHC.Float.log1pexp
+ log1mexp = liftShaped1 GHC.Float.log1mexp
+
+squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
+squotArray = liftShaped2 mquotArray
+sremArray = liftShaped2 mremArray
+
+satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
+satan2Array = liftShaped2 matan2Array
+
+
+sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
+sshape (Shaped arr) = shsFromShX (mshape arr)
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh
+shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
+shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) =
+ castWith (subst1 (sym (lemMapJustCons Refl))) $
+ n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
+ idx)
+shsFromShX (SUnknown _ :$% _) = error "impossible"
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index 97b9456..5f9ba79 100644
--- a/src/Data/Array/Nested/Internal/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -1,11 +1,9 @@
+{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveFoldable #-}
-{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
-{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
@@ -24,10 +22,9 @@
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Nested.Internal.Shape where
+module Data.Array.Nested.Shaped.Shape where
import Control.DeepSeq (NFData(..))
-import Data.Array.Mixed.Types
import Data.Array.Shape qualified as O
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
@@ -42,333 +39,16 @@ import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
-import GHC.TypeNats qualified as TN
-
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
-
-
-type role ListR nominal representational
-type ListR :: Nat -> Type -> Type
-data ListR n i where
- ZR :: ListR 0 i
- (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
-deriving instance Eq i => Eq (ListR n i)
-deriving instance Ord i => Ord (ListR n i)
-deriving instance Functor (ListR n)
-deriving instance Foldable (ListR n)
-infixr 3 :::
-
-instance Show i => Show (ListR n i) where
- showsPrec _ = listrShow shows
-
-instance NFData i => NFData (ListR n i) where
- rnf ZR = ()
- rnf (x ::: l) = rnf x `seq` rnf l
-
-data UnconsListRRes i n1 =
- forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
-listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
-listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)
-listrUncons ZR = Nothing
-
--- | This checks only whether the ranks are equal, not whether the actual
--- values are.
-listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n')
-listrEqRank ZR ZR = Just Refl
-listrEqRank (_ ::: sh) (_ ::: sh')
- | Just Refl <- listrEqRank sh sh'
- = Just Refl
-listrEqRank _ _ = Nothing
-
--- | This compares the lists for value equality.
-listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n')
-listrEqual ZR ZR = Just Refl
-listrEqual (i ::: sh) (j ::: sh')
- | Just Refl <- listrEqual sh sh'
- , i == j
- = Just Refl
-listrEqual _ _ = Nothing
-
-listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS
-listrShow f l = showString "[" . go "" l . showString "]"
- where
- go :: String -> ListR n' i -> ShowS
- go _ ZR = id
- go prefix (x ::: xs) = showString prefix . f x . go "," xs
-
-listrLength :: ListR n i -> Int
-listrLength = length
-
-listrRank :: ListR n i -> SNat n
-listrRank ZR = SNat
-listrRank (_ ::: sh) = snatSucc (listrRank sh)
-
-listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
-listrAppend ZR sh = sh
-listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
-
-listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
-listrFromList [] k = k ZR
-listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
-
-listrHead :: ListR (n + 1) i -> i
-listrHead (i ::: _) = i
-listrHead ZR = error "unreachable"
-
-listrTail :: ListR (n + 1) i -> ListR n i
-listrTail (_ ::: sh) = sh
-listrTail ZR = error "unreachable"
-
-listrInit :: ListR (n + 1) i -> ListR n i
-listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh
-listrInit (_ ::: ZR) = ZR
-listrInit ZR = error "unreachable"
-
-listrLast :: ListR (n + 1) i -> i
-listrLast (_ ::: sh@(_ ::: _)) = listrLast sh
-listrLast (n ::: ZR) = n
-listrLast ZR = error "unreachable"
-
-listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
-listrIndex SZ (x ::: _) = x
-listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
-listrIndex _ ZR = error "k + 1 <= 0"
-
-listrZip :: ListR n i -> ListR n j -> ListR n (i, j)
-listrZip ZR ZR = ZR
-listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest
-listrZip _ _ = error "listrZip: impossible pattern needlessly required"
-
-listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k
-listrZipWith _ ZR ZR = ZR
-listrZipWith f (i ::: irest) (j ::: jrest) =
- f i j ::: listrZipWith f irest jrest
-listrZipWith _ _ _ =
- error "listrZipWith: impossible pattern needlessly required"
-
-listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
-listrPermutePrefix = \perm sh ->
- listrFromList perm $ \sperm ->
- case (listrRank sperm, listrRank sh) of
- (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
- LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
- ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
- where
- listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
- listrSplitAt SZ sh = (ZR, sh)
- listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
- listrSplitAt SS{} ZR = error "m' + 1 <= 0"
-
- applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
- applyPermRFull _ ZR _ = ZR
- applyPermRFull sm@SNat (i ::: perm) l =
- TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
- case cmpNat (SNat @(idx + 1)) sm of
- LTI -> listrIndex si l ::: applyPermRFull sm perm l
- EQI -> listrIndex si l ::: applyPermRFull sm perm l
- GTI -> error "listrPermutePrefix: Index in permutation out of range"
-
-
--- | An index into a rank-typed array.
-type role IxR nominal representational
-type IxR :: Nat -> Type -> Type
-newtype IxR n i = IxR (ListR n i)
- deriving (Eq, Ord, Generic)
- deriving newtype (Functor, Foldable)
-
-pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
-pattern ZIR = IxR ZR
-
-pattern (:.:)
- :: forall {n1} {i}.
- forall n. (n + 1 ~ n1)
- => i -> IxR n i -> IxR n1 i
-pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i))
- where i :.: IxR sh = IxR (i ::: sh)
-infixr 3 :.:
-
-{-# COMPLETE ZIR, (:.:) #-}
-
-type IIxR n = IxR n Int
-
-instance Show i => Show (IxR n i) where
- showsPrec _ (IxR l) = listrShow shows l
-
-instance NFData i => NFData (IxR sh i)
-
-ixrLength :: IxR sh i -> Int
-ixrLength (IxR l) = listrLength l
-
-ixrRank :: IxR n i -> SNat n
-ixrRank (IxR sh) = listrRank sh
-
-ixrZero :: SNat n -> IIxR n
-ixrZero SZ = ZIR
-ixrZero (SS n) = 0 :.: ixrZero n
-
-ixCvtXR :: IIxX sh -> IIxR (Rank sh)
-ixCvtXR ZIX = ZIR
-ixCvtXR (n :.% idx) = n :.: ixCvtXR idx
-
-ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
-ixCvtRX ZIR = ZIX
-ixCvtRX (n :.: (idx :: IxR m Int)) =
- castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
- (n :.% ixCvtRX idx)
-
-ixrHead :: IxR (n + 1) i -> i
-ixrHead (IxR list) = listrHead list
-
-ixrTail :: IxR (n + 1) i -> IxR n i
-ixrTail (IxR list) = IxR (listrTail list)
-
-ixrInit :: IxR (n + 1) i -> IxR n i
-ixrInit (IxR list) = IxR (listrInit list)
-
-ixrLast :: IxR (n + 1) i -> i
-ixrLast (IxR list) = listrLast list
-
-ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
-ixrAppend = coerce (listrAppend @_ @i)
-
-ixrZip :: IxR n i -> IxR n j -> IxR n (i, j)
-ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2
-
-ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
-ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
-
-ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
-ixrPermutePrefix = coerce (listrPermutePrefix @i)
-
-
-type role ShR nominal representational
-type ShR :: Nat -> Type -> Type
-newtype ShR n i = ShR (ListR n i)
- deriving (Eq, Ord, Generic)
- deriving newtype (Functor, Foldable)
-
-pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
-pattern ZSR = ShR ZR
-
-pattern (:$:)
- :: forall {n1} {i}.
- forall n. (n + 1 ~ n1)
- => i -> ShR n i -> ShR n1 i
-pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
- where i :$: ShR sh = ShR (i ::: sh)
-infixr 3 :$:
-
-{-# COMPLETE ZSR, (:$:) #-}
-
-type IShR n = ShR n Int
-
-instance Show i => Show (ShR n i) where
- showsPrec _ (ShR l) = listrShow shows l
-
-instance NFData i => NFData (ShR sh i)
-
-shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n
-shCvtXR' ZSX =
- castWith (subst2 (unsafeCoerceRefl :: 0 :~: n))
- ZSR
-shCvtXR' (n :$% (idx :: IShX sh))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =
- castWith (subst2 (lem1 @sh Refl))
- (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
- where
- lem1 :: forall sh' n' k.
- k : sh' :~: Replicate n' Nothing
- -> Rank sh' + 1 :~: n'
- lem1 Refl = unsafeCoerceRefl
-
- lem2 :: k : sh :~: Replicate n Nothing
- -> sh :~: Replicate (Rank sh) Nothing
- lem2 Refl = unsafeCoerceRefl
-
-shCvtRX :: IShR n -> IShX (Replicate n Nothing)
-shCvtRX ZSR = ZSX
-shCvtRX (n :$: (idx :: ShR m Int)) =
- castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
- (SUnknown n :$% shCvtRX idx)
--- | This checks only whether the ranks are equal, not whether the actual
--- values are.
-shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
-shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh'
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Types
--- | This compares the shapes for value equality.
-shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n')
-shrEqual (ShR sh) (ShR sh') = listrEqual sh sh'
-
-shrLength :: ShR sh i -> Int
-shrLength (ShR l) = listrLength l
-
--- | This function can also be used to conjure up a 'KnownNat' dictionary;
--- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern
--- synonym yields 'KnownNat' evidence.
-shrRank :: ShR n i -> SNat n
-shrRank (ShR sh) = listrRank sh
-
--- | The number of elements in an array described by this shape.
-shrSize :: IShR n -> Int
-shrSize ZSR = 1
-shrSize (n :$: sh) = n * shrSize sh
-
-shrHead :: ShR (n + 1) i -> i
-shrHead (ShR list) = listrHead list
-
-shrTail :: ShR (n + 1) i -> ShR n i
-shrTail (ShR list) = ShR (listrTail list)
-
-shrInit :: ShR (n + 1) i -> ShR n i
-shrInit (ShR list) = ShR (listrInit list)
-
-shrLast :: ShR (n + 1) i -> i
-shrLast (ShR list) = listrLast list
-
-shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
-shrAppend = coerce (listrAppend @_ @i)
-
-shrZip :: ShR n i -> ShR n j -> ShR n (i, j)
-shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2
-
-shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k
-shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2
-
-shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
-shrPermutePrefix = coerce (listrPermutePrefix @i)
-
-
--- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (ListR n i) where
- type Item (ListR n i) = i
- fromList topl = go (SNat @n) topl
- where
- go :: SNat n' -> [i] -> ListR n' i
- go SZ [] = ZR
- go (SS n) (i : is) = i ::: go n is
- go _ _ = error $ "IsList(ListR): Mismatched list length (type says "
- ++ show (fromSNat (SNat @n)) ++ ", list has length "
- ++ show (length topl) ++ ")"
- toList = Foldable.toList
-
--- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (IxR n i) where
- type Item (IxR n i) = i
- fromList = IxR . IsList.fromList
- toList = Foldable.toList
-
--- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (ShR n i) where
- type Item (ShR n i) = i
- fromList = ShR . IsList.fromList
- toList = Foldable.toList
+-- * Shaped lists
+-- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be
+-- removed in a future release.
type role ListS nominal representational
type ListS :: [Nat] -> (Nat -> Type) -> Type
data ListS sh f where
@@ -379,8 +59,12 @@ deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
infixr 3 ::$
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance (forall n. Show (f n)) => Show (ListS sh f)
+#else
instance (forall n. Show (f n)) => Show (ListS sh f) where
showsPrec _ = listsShow shows
+#endif
instance (forall m. NFData (f m)) => NFData (ListS n f) where
rnf ZS = ()
@@ -497,11 +181,9 @@ listsIndex _ _ _ ZS = error "Index into empty shape"
listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh)
+-- * Shaped indices
-- | An index into a shape-typed array.
---
--- For convenience, this contains regular 'Int's instead of bounded integers
--- (traditionally called \"@Fin@\").
type role IxS nominal representational
type IxS :: [Nat] -> Type -> Type
newtype IxS sh i = IxS (ListS sh (Const i))
@@ -510,6 +192,8 @@ newtype IxS sh i = IxS (ListS sh (Const i))
pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
pattern ZIS = IxS ZS
+-- | Note: The 'KnownNat' constraint on '(:.$)' is deprecated and should be
+-- removed in a future release.
pattern (:.$)
:: forall {sh1} {i}.
forall n sh. (KnownNat n, n : sh ~ sh1)
@@ -520,10 +204,16 @@ infixr 3 :.$
{-# COMPLETE ZIS, (:.$) #-}
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
type IIxS sh = IxS sh Int
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (IxS sh i)
+#else
instance Show i => Show (IxS sh i) where
showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l
+#endif
instance Functor (IxS sh) where
fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l)
@@ -543,14 +233,6 @@ ixsZero :: ShS sh -> IIxS sh
ixsZero ZSS = ZIS
ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
-ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh
-ixCvtXS ZSS ZIX = ZIS
-ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx
-
-ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
-ixCvtSX ZIS = ZIX
-ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh
-
ixsHead :: IxS (n : sh) i -> i
ixsHead (IxS list) = getConst (listsHead list)
@@ -563,6 +245,12 @@ ixsInit (IxS list) = IxS (listsInit list)
ixsLast :: IxS (n : sh) i -> i
ixsLast (IxS list) = getConst (listsLast list)
+-- TODO: this takes a ShS because there are KnownNats inside IxS.
+ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i
+ixsCast ZSS ZIS = ZIS
+ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx
+ixsCast _ _ = error "ixsCast: ranks don't match"
+
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
ixsAppend = coerce (listsAppend @_ @(Const i))
@@ -578,6 +266,8 @@ ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is
ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))
+-- * Shaped shapes
+
-- | The shape of a shape-typed array given as a list of 'SNat' values.
--
-- Note that because the shape of a shape-typed array is known statically, you
@@ -601,8 +291,12 @@ infixr 3 :$$
{-# COMPLETE ZSS, (:$$) #-}
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show (ShS sh)
+#else
instance Show (ShS sh) where
showsPrec _ (ShS l) = listsShow (shows . fromSNat) l
+#endif
instance NFData (ShS sh) where
rnf (ShS ZS) = ()
@@ -630,23 +324,6 @@ shsToList :: ShS sh -> [Int]
shsToList ZSS = []
shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh
-shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh
-shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
-shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) =
- castWith (subst1 (lem Refl)) $
- n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
- idx)
- where
- lem :: forall sh1 sh' n.
- Just n : sh1 :~: MapJust sh'
- -> n : Tail sh' :~: sh'
- lem Refl = unsafeCoerceRefl
-shCvtXS' (SUnknown _ :$% _) = error "impossible"
-
-shCvtSX :: ShS sh -> IShX (MapJust sh)
-shCvtSX ZSS = ZSX
-shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh
-
shsHead :: ShS (n : sh) -> SNat n
shsHead (ShS list) = listsHead list
@@ -690,7 +367,7 @@ instance KnownShS '[] where knownShS = ZSS
instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS
withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r
-withKnownShS k = withDict @(KnownShS sh) k
+withKnownShS = withDict @(KnownShS sh)
shsKnownShS :: ShS sh -> Dict KnownShS sh
shsKnownShS ZSS = Dict
@@ -700,6 +377,17 @@ shsOrthotopeShape :: ShS sh -> Dict O.Shape sh
shsOrthotopeShape ZSS = Dict
shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict
+-- | This function is a hack made possible by the 'KnownNat' inside 'ListS'.
+-- This function may be removed in a future release.
+shsFromListS :: ListS sh f -> ShS sh
+shsFromListS ZS = ZSS
+shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l
+
+-- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This
+-- function may be removed in a future release.
+shsFromIxS :: IxS sh i -> ShS sh
+shsFromIxS (IxS l) = shsFromListS l
+
-- | Untyped: length is checked at runtime.
instance KnownShS sh => IsList (ListS sh (Const i)) where
diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs
index 838e2b0..8a29aa5 100644
--- a/src/Data/Array/Nested/Trace.hs
+++ b/src/Data/Array/Nested/Trace.hs
@@ -37,10 +37,12 @@ module Data.Array.Nested.Trace (
ShS(..), KnownShS(..),
Mixed,
+ ListX(ZX, (::%)),
IxX(..), IIxX,
- ShX(..), KnownShX(..),
+ ShX(..), KnownShX(..), IShX,
StaticShX(..),
SMayNat(..),
+ Conversion(..),
Elt,
PrimElt,
@@ -54,7 +56,7 @@ module Data.Array.Nested.Trace (
Perm(..),
IsPermutation,
KnownPerm(..),
- NumElt, FloatElt,
+ NumElt, IntElt, FloatElt,
Rank, Product,
Replicate,
MapJust,
@@ -67,4 +69,4 @@ import Data.Array.Nested.Trace.TH
$(concat <$> mapM convertFun
- ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromListOuter, 'rfromList1, 'rfromList1Prim, 'rtoListOuter, 'rtoList1, 'rfromListLinear, 'rfromListPrimLinear, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rcastToShaped, 'rtoMixed, 'rfromOrthotope, 'rtoOrthotope, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromListOuter, 'sfromList1, 'sfromList1Prim, 'stoListOuter, 'stoList1, 'sfromListLinear, 'sfromListPrimLinear, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoRanked, 'stoMixed, 'sfromOrthotope, 'stoOrthotope, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromListOuter, 'mfromList1, 'mfromList1Prim, 'mtoListOuter, 'mtoList1, 'mfromListLinear, 'mfromListPrimLinear, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mtoRanked, 'mcastToShaped])
+ ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromList1, 'rfromListOuter, 'rfromListLinear, 'rfromListPrim, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromList1, 'sfromListOuter, 'sfromListLinear, 'sfromListPrim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromList1, 'mfromListOuter, 'mfromListLinear, 'mfromListPrim, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array])
diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Nested/Types.hs
index 3f5b1e7..4444acd 100644
--- a/src/Data/Array/Mixed/Types.hs
+++ b/src/Data/Array/Nested/Types.hs
@@ -6,13 +6,16 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed.Types (
+module Data.Array.Nested.Types (
+ -- * Reasoning helpers
+ subst1, subst2,
+
-- * Reified evidence of a type class
Dict(..),
@@ -27,6 +30,7 @@ module Data.Array.Mixed.Types (
Replicate,
lemReplicateSucc,
MapJust,
+ lemMapJustEmpty, lemMapJustCons,
Head,
Tail,
Init,
@@ -43,6 +47,14 @@ import GHC.TypeNats qualified as TN
import Unsafe.Coerce qualified
+-- Reasoning helpers
+
+subst1 :: forall f a b. a :~: b -> f a :~: f b
+subst1 Refl = Refl
+
+subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
+subst2 Refl = Refl
+
-- | Evidence for the constraint @c a@.
data Dict c a where
Dict :: c a => Dict c a
@@ -100,10 +112,16 @@ type family Replicate n a where
lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
lemReplicateSucc = unsafeCoerceRefl
-type family MapJust l where
+type family MapJust l = r | r -> l where
MapJust '[] = '[]
MapJust (x : xs) = Just x : MapJust xs
+lemMapJustEmpty :: MapJust sh :~: '[] -> sh :~: '[]
+lemMapJustEmpty Refl = unsafeCoerceRefl
+
+lemMapJustCons :: MapJust sh :~: Just n : sh' -> sh :~: n : Tail sh
+lemMapJustCons Refl = unsafeCoerceRefl
+
type family Head l where
Head (x : _) = x
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Strided/Orthotope.hs
index ebda388..5c38d14 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Strided/Orthotope.hs
@@ -1,6 +1,6 @@
{-# LANGUAGE ImportQualifiedPost #-}
-module Data.Array.Mixed.Internal.Arith (
- module Data.Array.Mixed.Internal.Arith,
+module Data.Array.Strided.Orthotope (
+ module Data.Array.Strided.Orthotope,
module Data.Array.Strided.Arith,
) where
diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/XArray.hs
index cb790e1..bf47622 100644
--- a/src/Data/Array/Mixed/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -11,7 +11,7 @@
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed.XArray where
+module Data.Array.XArray where
import Control.DeepSeq (NFData)
import Data.Array.Internal qualified as OI
@@ -31,11 +31,11 @@ import Foreign.Storable (Storable)
import GHC.Generics (Generic)
import GHC.TypeLits
-import Data.Array.Mixed.Internal.Arith
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
-import Data.Array.Mixed.Types
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Types
+import Data.Array.Strided.Orthotope
type XArray :: [Maybe Nat] -> Type -> Type
@@ -76,7 +76,7 @@ cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2
-> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
cast ssh1 sh2 ssh' (XArray arr)
| Refl <- lemRankApp ssh1 ssh'
- , Refl <- lemRankApp (ssxFromShape sh2) ssh'
+ , Refl <- lemRankApp (ssxFromShX sh2) ssh'
= let arrsh :: IShX sh1
(arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr))
in if shxToList arrsh == shxToList sh2
@@ -89,8 +89,8 @@ unScalar (XArray a) = S.unScalar a
replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a
replicate sh ssh' (XArray arr)
| Dict <- lemKnownNatRankSSX ssh'
- , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh')
- , Refl <- lemRankApp (ssxFromShape sh) ssh'
+ , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShX sh) ssh')
+ , Refl <- lemRankApp (ssxFromShX sh) ssh'
= XArray (S.stretch (shxToList sh ++ S.shapeL arr) $
S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr)
arr)
@@ -243,7 +243,7 @@ transpose2 ssh1 ssh2 (XArray arr)
, Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
, Refl <- lemRankAppComm ssh1 ssh2
, let n1 = ssxLength ssh1
- = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
+ = XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr)
sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a
sumFull _ (XArray arr) =
@@ -258,7 +258,7 @@ sumInner ssh ssh' arr
| Refl <- lemAppNil @sh
= let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
sh'F = shxFlatten sh' :$% ZSX
- ssh'F = ssxFromShape sh'F
+ ssh'F = ssxFromShX sh'F
go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a
go (XArray arr')
@@ -278,8 +278,8 @@ sumOuter ssh ssh' arr
| Refl <- lemAppNil @sh
= let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
shF = shxFlatten sh :$% ZSX
- in sumInner ssh' (ssxFromShape shF) $
- transpose2 (ssxFromShape shF) ssh' $
+ in sumInner ssh' (ssxFromShX shF) $
+ transpose2 (ssxFromShX shF) ssh' $
reshapePartial ssh ssh' shF $
arr
@@ -340,7 +340,7 @@ reshape ssh1 sh2 (XArray arr)
reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
reshapePartial ssh1 ssh' sh2 (XArray arr)
| Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh')
- , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh')
+ , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShX sh2) ssh')
= XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr)
-- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo).
diff --git a/test/Gen.hs b/test/Gen.hs
index bf002ca..044de14 100644
--- a/test/Gen.hs
+++ b/test/Gen.hs
@@ -20,11 +20,10 @@ import Foreign
import GHC.TypeLits
import GHC.TypeNats qualified as TN
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
-import Data.Array.Mixed.Types
import Data.Array.Nested
-import Data.Array.Nested.Internal.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
import Hedgehog
import Hedgehog.Gen qualified as Gen
@@ -60,9 +59,12 @@ shuffleShR = \sh -> go (length sh) (toList sh) sh
(dim :$:) <$> go (nbag - 1) bag' sh
genShR :: SNat n -> Gen (IShR n)
-genShR sn = do
+genShR = genShRwithTarget 100_000
+
+genShRwithTarget :: Int -> SNat n -> Gen (IShR n)
+genShRwithTarget targetMax sn = do
let n = fromSNat' sn
- targetSize <- Gen.int (Range.linear 0 100_000)
+ targetSize <- Gen.int (Range.linear 0 targetMax)
let genDims :: SNat m -> Int -> Gen (IShR m)
genDims SZ _ = return ZSR
genDims (SS m) 0 = do
@@ -94,10 +96,14 @@ genShR sn = do
-- other dimensions might be zero.
genReplicatedShR :: m <= n => SNat m -> SNat n -> Gen (IShR m, IShR n, IShR n)
genReplicatedShR = \m n -> do
- sh1 <- genShR m
+ let expectedSizeIncrease = round (repvalavg ^ (fromSNat' n - fromSNat' m))
+ sh1 <- genShRwithTarget (1_000_000 `div` expectedSizeIncrease) m
(sh2, sh3) <- injectOnes n sh1 sh1
return (sh1, sh2, sh3)
where
+ repvalrange = (1::Int, 5)
+ repvalavg = let (lo, hi) = repvalrange in fromIntegral (lo + hi) / 2 :: Double
+
injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n)
injectOnes n@SNat shOnes sh
| m@SNat <- shrRank sh
@@ -106,7 +112,7 @@ genReplicatedShR = \m n -> do
EQI -> return (shOnes, sh)
GTI -> do
index <- Gen.int (Range.linear 0 (fromSNat' m))
- value <- Gen.int (Range.linear 1 5)
+ value <- Gen.int (uncurry Range.linear repvalrange)
Refl <- return (lem n m)
injectOnes n (inject index 1 shOnes) (inject index value sh)
@@ -116,7 +122,7 @@ genReplicatedShR = \m n -> do
inject :: Int -> Int -> IShR m -> IShR (m + 1)
inject 0 v sh = v :$: sh
inject i v (w :$: sh) = w :$: inject (i - 1) v sh
- inject _ v ZSR = v :$: ZSR -- invalid input, but meh
+ inject _ _ ZSR = error "unreachable"
genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
genStorables rng f = do
diff --git a/test/Tests/C.hs b/test/Tests/C.hs
index 4861eb1..9567393 100644
--- a/test/Tests/C.hs
+++ b/test/Tests/C.hs
@@ -18,13 +18,13 @@ import Data.Type.Equality
import Foreign
import GHC.TypeLits
-import Data.Array.Mixed.Types (fromSNat')
import Data.Array.Nested
-import Data.Array.Nested.Internal.Shape
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types (fromSNat')
import Hedgehog
import Hedgehog.Gen qualified as Gen
-import Hedgehog.Internal.Property (forAllT)
+import Hedgehog.Internal.Property (LabelName(..), forAllT)
import Hedgehog.Range qualified as Range
import Test.Tasty
import Test.Tasty.Hedgehog
@@ -39,6 +39,9 @@ import Util
fineTol :: Double
fineTol = 1e-8
+debugCoverage :: Bool
+debugCoverage = False
+
prop_sum_nonempty :: Property
prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do
-- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
@@ -62,7 +65,7 @@ prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do
sht <- shuffleShR (0 :$: shtt) -- n
n <- Gen.int (Range.linear 0 20)
return (n :$: sht) -- n + 1
- guard (0 `elem` toList (shrTail sh))
+ guard (0 `elem` shrTail sh)
-- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
let arr = OR.fromList @(n + 1) @Double (toList sh) []
let rarr = rfromOrthotope inrank arr
@@ -89,6 +92,10 @@ prop_sum_replicated doTranspose = property $
LTI -> return Refl -- actually we only continue if m < n
_ -> discard
(sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2
+ when debugCoverage $ do
+ label (LabelName ("rankdiff " ++ show (fromSNat' inrank2 - fromSNat' inrank1)))
+ label (LabelName ("size sh1 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh1) :: Double)) :: Int)))
+ label (LabelName ("size sh3 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh3) :: Double)) :: Int)))
guard (all (> 0) sh3)
arr <- forAllT $
OR.stretch (toList sh3)
diff --git a/test/Tests/Permutation.hs b/test/Tests/Permutation.hs
index 1e7ad13..98a6da5 100644
--- a/test/Tests/Permutation.hs
+++ b/test/Tests/Permutation.hs
@@ -6,7 +6,7 @@ module Tests.Permutation where
import Data.Type.Equality
-import Data.Array.Mixed.Permutation
+import Data.Array.Nested.Permutation
import Hedgehog
import Hedgehog.Gen qualified as Gen
diff --git a/test/Util.hs b/test/Util.hs
index 34cf8ab..8a5ba72 100644
--- a/test/Util.hs
+++ b/test/Util.hs
@@ -15,7 +15,7 @@ import GHC.TypeLits
import Hedgehog
import Hedgehog.Internal.Property (failDiff)
-import Data.Array.Mixed.Types (fromSNat')
+import Data.Array.Nested.Types (fromSNat')
-- Returns highest value that satisfies the predicate, or `lo` if none does