diff options
45 files changed, 8680 insertions, 1938 deletions
@@ -1,2 +1,3 @@ dist-newstyle/ cabal.project.local +.ccls-cache/ diff --git a/.stylish-haskell.yaml b/.stylish-haskell.yaml new file mode 100644 index 0000000..bfd25ea --- /dev/null +++ b/.stylish-haskell.yaml @@ -0,0 +1,452 @@ +# stylish-haskell configuration file +# ================================== + +# The stylish-haskell tool is mainly configured by specifying steps. These steps +# are a list, so they have an order, and one specific step may appear more than +# once (if needed). Each file is processed by these steps in the given order. +steps: + # Convert some ASCII sequences to their Unicode equivalents. This is disabled + # by default. + # - unicode_syntax: + # # In order to make this work, we also need to insert the UnicodeSyntax + # # language pragma. If this flag is set to true, we insert it when it's + # # not already present. You may want to disable it if you configure + # # language extensions using some other method than pragmas. Default: + # # true. + # add_language_pragma: true + + # Format module header + # + # Currently, this option is not configurable and will format all exports and + # module declarations to minimize diffs + # + # - module_header: + # # How many spaces use for indentation in the module header. + # indent: 4 + # + # # Should export lists be sorted? Sorting is only performed within the + # # export section, as delineated by Haddock comments. + # sort: true + # + # # See `separate_lists` for the `imports` step. + # separate_lists: true + + # Format record definitions. This is disabled by default. + # + # You can control the layout of record fields. The only rules that can't be configured + # are these: + # + # - "|" is always aligned with "=" + # - "," in fields is always aligned with "{" + # - "}" is likewise always aligned with "{" + # + # - records: + # # How to format equals sign between type constructor and data constructor. + # # Possible values: + # # - "same_line" -- leave "=" AND data constructor on the same line as the type constructor. + # # - "indent N" -- insert a new line and N spaces from the beginning of the next line. + # equals: "indent 2" + # + # # How to format first field of each record constructor. + # # Possible values: + # # - "same_line" -- "{" and first field goes on the same line as the data constructor. + # # - "indent N" -- insert a new line and N spaces from the beginning of the data constructor + # first_field: "indent 2" + # + # # How many spaces to insert between the column with "," and the beginning of the comment in the next line. + # field_comment: 2 + # + # # How many spaces to insert before "deriving" clause. Deriving clauses are always on separate lines. + # deriving: 2 + # + # # How many spaces to insert before "via" clause counted from indentation of deriving clause + # # Possible values: + # # - "same_line" -- "via" part goes on the same line as "deriving" keyword. + # # - "indent N" -- insert a new line and N spaces from the beginning of "deriving" keyword. + # via: "indent 2" + # + # # Sort typeclass names in the "deriving" list alphabetically. + # sort_deriving: true + # + # # Wheter or not to break enums onto several lines + # # + # # Default: false + # break_enums: false + # + # # Whether or not to break single constructor data types before `=` sign + # # + # # Default: true + # break_single_constructors: true + # + # # Whether or not to curry constraints on function. + # # + # # E.g: @allValues :: Enum a => Bounded a => Proxy a -> [a]@ + # # + # # Instead of @allValues :: (Enum a, Bounded a) => Proxy a -> [a]@ + # # + # # Default: false + # curried_context: false + + # Align the right hand side of some elements. This is quite conservative + # and only applies to statements where each element occupies a single + # line. + # Possible values: + # - always - Always align statements. + # - adjacent - Align statements that are on adjacent lines in groups. + # - never - Never align statements. + # All default to always. + - simple_align: + cases: never + top_level_patterns: never + records: never + multi_way_if: never + + # Import cleanup + - imports: + # There are different ways we can align names and lists. + # + # - global: Align the import names and import list throughout the entire + # file. + # + # - file: Like global, but don't add padding when there are no qualified + # imports in the file. + # + # - group: Only align the imports per group (a group is formed by adjacent + # import lines). + # + # - none: Do not perform any alignment. + # + # Default: global. + align: group + + # The following options affect only import list alignment. + # + # List align has following options: + # + # - after_alias: Import list is aligned with end of import including + # 'as' and 'hiding' keywords. + # + # > import qualified Data.List as List (concat, foldl, foldr, head, + # > init, last, length) + # + # - with_alias: Import list is aligned with start of alias or hiding. + # + # > import qualified Data.List as List (concat, foldl, foldr, head, + # > init, last, length) + # + # - with_module_name: Import list is aligned `list_padding` spaces after + # the module name. + # + # > import qualified Data.List as List (concat, foldl, foldr, head, + # init, last, length) + # + # This is mainly intended for use with `pad_module_names: false`. + # + # > import qualified Data.List as List (concat, foldl, foldr, head, + # init, last, length, scanl, scanr, take, drop, + # sort, nub) + # + # - new_line: Import list starts always on new line. + # + # > import qualified Data.List as List + # > (concat, foldl, foldr, head, init, last, length) + # + # - repeat: Repeat the module name to align the import list. + # + # > import qualified Data.List as List (concat, foldl, foldr, head) + # > import qualified Data.List as List (init, last, length) + # + # Default: after_alias + list_align: after_alias + + # Right-pad the module names to align imports in a group: + # + # - true: a little more readable + # + # > import qualified Data.List as List (concat, foldl, foldr, + # > init, last, length) + # > import qualified Data.List.Extra as List (concat, foldl, foldr, + # > init, last, length) + # + # - false: diff-safe + # + # > import qualified Data.List as List (concat, foldl, foldr, init, + # > last, length) + # > import qualified Data.List.Extra as List (concat, foldl, foldr, + # > init, last, length) + # + # Default: true + pad_module_names: false + + # Long list align style takes effect when import is too long. This is + # determined by 'columns' setting. + # + # - inline: This option will put as much specs on same line as possible. + # + # - new_line: Import list will start on new line. + # + # - new_line_multiline: Import list will start on new line when it's + # short enough to fit to single line. Otherwise it'll be multiline. + # + # - multiline: One line per import list entry. + # Type with constructor list acts like single import. + # + # > import qualified Data.Map as M + # > ( empty + # > , singleton + # > , ... + # > , delete + # > ) + # + # Default: inline + long_list_align: new_line_multiline + + # Align empty list (importing instances) + # + # Empty list align has following options + # + # - inherit: inherit list_align setting + # + # - right_after: () is right after the module name: + # + # > import Vector.Instances () + # + # Default: inherit + empty_list_align: inherit + + # List padding determines indentation of import list on lines after import. + # This option affects 'long_list_align'. + # + # - <integer>: constant value + # + # - module_name: align under start of module name. + # Useful for 'file' and 'group' align settings. + # + # Default: 4 + list_padding: 2 + + # Separate lists option affects formatting of import list for type + # or class. The only difference is single space between type and list + # of constructors, selectors and class functions. + # + # - true: There is single space between Foldable type and list of it's + # functions. + # + # > import Data.Foldable (Foldable (fold, foldl, foldMap)) + # + # - false: There is no space between Foldable type and list of it's + # functions. + # + # > import Data.Foldable (Foldable(fold, foldl, foldMap)) + # + # Default: true + separate_lists: false + + # Space surround option affects formatting of import lists on a single + # line. The only difference is single space after the initial + # parenthesis and a single space before the terminal parenthesis. + # + # - true: There is single space associated with the enclosing + # parenthesis. + # + # > import Data.Foo ( foo ) + # + # - false: There is no space associated with the enclosing parenthesis + # + # > import Data.Foo (foo) + # + # Default: false + space_surround: false + + # Enabling this argument will use the new GHC lib parse to format imports. + # + # This currently assumes a few things, it will assume that you want post + # qualified imports. It is also not as feature complete as the old + # imports formatting. + # + # It does not remove redundant lines or merge lines. As such, the full + # feature scope is still pending. + # + # It _is_ however, a fine alternative if you are using features that are + # not parseable by haskell src extensions and you're comfortable with the + # presets. + # + # Default: false + ghc_lib_parser: false + + # Post qualify option moves any qualifies found in import declarations + # to the end of the declaration. This also adjust padding for any + # unqualified import declarations. + # + # - true: Qualified as <module name> is moved to the end of the + # declaration. + # + # > import Data.Bar + # > import Data.Foo qualified as F + # + # - false: Qualified remains in the default location and unqualified + # imports are padded to align with qualified imports. + # + # > import Data.Bar + # > import qualified Data.Foo as F + # + # Default: false + post_qualify: true + + # A list of rules specifying how to group modules and how to + # order the groups. + # + # Each rule has a match field; the rule only applies to module + # names matched by this pattern. Patterns are POSIX extended + # regular expressions; see the documentation of Text.Regex.TDFA + # for details: + # https://hackage.haskell.org/package/regex-tdfa-1.3.1.2/docs/Text-Regex-TDFA.html + # + # Rules are processed in order, so only the *first* rule that + # matches a specific module will apply. Any module names that do + # not match a single rule will be put into a single group at the + # end of the import block. + # + # Example: group MyApp modules first, with everything else in + # one group at the end. + # + # group_rules: + # - match: "^MyApp\\>" + # + # > import MyApp + # > import MyApp.Foo + # > + # > import Control.Monad + # > import MyApps + # > import Test.MyApp + # + # A rule can also optionally have a sub_group pattern. Imports + # that match the rule will be broken up into further groups by + # the part of the module name matched by the sub_group pattern. + # + # Example: group MyApp modules first, then everything else + # sub-grouped by the first part of the module name. + # + # group_rules: + # - match: "^MyApp\\>" + # - match: "." + # sub_group: "^[^.]+" + # + # > import MyApp + # > import MyApp.Foo + # > + # > import Control.Applicative + # > import Control.Monad + # > + # > import Data.Map + # + # A pattern only needs to match part of the module name, which + # could be in the middle. You can use ^pattern to anchor to the + # beginning of the module name, pattern$ to anchor to the end + # and ^pattern$ to force a full match. Example: + # + # - "Test\\." would match "Test.Foo" and "Foo.Test.Lib" + # - "^Test\\." would match "Test.Foo" but not "Foo.Test.Lib" + # - "\\.Test$" would match "Foo.Test" but not "Foo.Test.Lib" + # - "^Test$" would *only* match "Test" + # + # You can use \\< and \\> to anchor against the beginning and + # end of words, respectively. For example: + # + # - "^Test\\." would match "Test.Foo" but not "Test" or "Tests" + # - "^Test\\>" would match "Test.Foo" and "Test", but not + # "Tests" + # + # The default is a single rule that matches everything and + # sub-groups based on the first component of the module name. + # + # Default: [{ "match" : ".*", "sub_group": "^[^.]+" }] +# group_rules: +# - match: ".*" +# sub_group: "^[^.]+" +# - match: "^Data.Array\\>" +# sub_group: "^[^.]+" +# - match: "^HordeAd\\>" + + # Language pragmas + - language_pragmas: + # We can generate different styles of language pragma lists. + # + # - vertical: Vertical-spaced language pragmas, one per line. + # + # - compact: A more compact style. + # + # - compact_line: Similar to compact, but wrap each line with + # `{-#LANGUAGE #-}'. + # + # Default: vertical. +# style: compact + + # Align affects alignment of closing pragma brackets. + # + # - true: Brackets are aligned in same column. + # + # - false: Brackets are not aligned together. There is only one space + # between actual import and closing bracket. + # + # Default: true + align: false + + # stylish-haskell can detect redundancy of some language pragmas. If this + # is set to true, it will remove those redundant pragmas. Default: true. + remove_redundant: true + + # Language prefix to be used for pragma declaration, this allows you to + # use other options non case-sensitive like "language" or "Language". + # If a non correct String is provided, it will default to: LANGUAGE. + language_prefix: LANGUAGE + + # Replace tabs by spaces. This is disabled by default. + # - tabs: + # # Number of spaces to use for each tab. Default: 8, as specified by the + # # Haskell report. + # spaces: 8 + + # Remove trailing whitespace + - trailing_whitespace: {} + + # Squash multiple spaces between the left and right hand sides of some + # elements into single spaces. Basically, this undoes the effect of + # simple_align but is a bit less conservative. + # - squash: {} + +# A common setting is the number of columns (parts of) code will be wrapped +# to. Different steps take this into account. +# +# Set this to null to disable all line wrapping. +# +# Default: 80. +columns: 200 + +# By default, line endings are converted according to the OS. You can override +# preferred format here. +# +# - native: Native newline format. CRLF on Windows, LF on other OSes. +# +# - lf: Convert to LF ("\n"). +# +# - crlf: Convert to CRLF ("\r\n"). +# +# Default: native. +newline: native + +# Sometimes, language extensions are specified in a cabal file or from the +# command line instead of using language pragmas in the file. stylish-haskell +# needs to be aware of these, so it can parse the file correctly. +# +# No language extensions are enabled by default. +#language_extensions: +# - NoStarIsType + # - TemplateHaskell + # - QuasiQuotes + +# Attempt to find the cabal file in ancestors of the current directory, and +# parse options (currently only language extensions) from that. +# +# Default: true +cabal: true diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..009d267 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,7 @@ +# Changelog for `ox-arrays` + +This package intends to follow the [PVP](https://pvp.haskell.org/). + +## 0.1.0.0 +- Initial release +- Various aspects of the API are still experimental, and breaking changes are expected in the future. @@ -1,56 +1,165 @@ -Wrapper library around `orthotope` that defines nested arrays, including -tuples, of (eventually) unboxed values. The arrays are represented in -struct-of-arrays form via the `Data.Vector.Unboxed` data family trick. Below -the surface layer, there is a more low-level wrapper around `orthotope` that -defines an array type type-indexed by `[Maybe Nat]`: some dimensions are -shape-typed (i.e. have their size statically known), and some not. +## ox-arrays -An overview of the API: +ox-arrays is an array library that defines nested arrays, including tuples, of +(eventually) unboxed values. The arrays are represented in struct-of-arrays +form via the `Data.Vector.Unboxed` data family trick; the component arrays are +`orthotope` arrays +([RankedS](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-RankedS.html)) +which describe elements using a _stride vector_ or +[LMAD](https://dl.acm.org/doi/pdf/10.1145/509705.509708) so that `transpose` +and `replicate` need only modify array metadata, not actually move around data. + +Because of the struct-of-arrays representation, nested arrays are not fully +general: indeed, arrays are not actually nested under the hood, so if one has an +array of arrays, those element arrays must all have the same shape (length, +width, etc.). If one has an array of tuples of arrays, then all the `fst` +components must have the same shape and all the `snd` components must have the +same shape, but the two pair components themselves can be different. + +However, the nesting functionality of ox-arrays can be completely ignored if you +only care about other parts of its API, or the vectorised arithmetic operations +(using hand-written C code). Nesting support mostly does not get in the way, and +has essentially no overhead (both when it's used and when it's not used). + +ox-arrays defines three array types: `Ranked`, `Shaped` and `Mixed`. +- `Ranked` corresponds to `orthotope`'s + [RankedS](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-RankedS.html) + and has the _rank_ of the array (its number of dimensions) on the type level. + For example, `Ranked 2 Float` is a two-dimensional array of `Float`s, i.e. a + matrix. +- `Shaped` corresponds to `orthotope`'s + [ShapedS](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-ShapedS.html). + and has the full _shape_ of the array (its dimension sizes) on the type level + as a type-level list of `Nat`s. For example, `Shaped [2,3] Float` is a 2-by-3 + matrix. The innermost dimension correspond to the right-most element in the + list. +- `Mixed` is halfway between the two: it has a type parameter of kind + `[Maybe Nat]` whose length is the rank of the array; `Nothing` elements have + unknown size, whereas `Just` elements have the indicated size. The type + `Mixed [Nothing, Nothing] a` is equivalent to `Ranked 2 a`; the type + `Mixed [Just n, Just m] a` is equivalent to `Shaped [n, m] a`. + +In various places in the API of a library like ox-arrays, one can make a +decision between 1. requiring a type class constraint providing certain +information (e.g. +[KnownNat](https://hackage.haskell.org/package/base-4.21.0.0/docs/GHC-TypeLits.html#t:KnownNat) +or `orthotope`'s +[Shape](https://hackage.haskell.org/package/orthotope-0.1.7.0/docs/Data-Array-ShapedS.html#t:Shape)), +or 2. taking singleton _values_ that encode said information in a way that is +linked to the type level (e.g. +[SNat](https://hackage.haskell.org/package/base-4.21.0.0/docs/GHC-TypeLits.html#t:SNat)). +`orthotope` chooses the type class approach; ox-arrays chooses the singleton +approach. Singletons are more verbose at times, but give the programmer more +insight in what data is flowing where, and more importantly, more control: type +class inference is very nice and implicit, but if it's not powerful enough for +the trickery you're doing, you're out of luck. Singletons allow you to explain +as precisely as you want to GHC what exactly you're doing. + +Below the surface layer, there is a more low-level wrapper (`XArray`) around +`orthotope` that defines a non-nested `Mixed`-style array type. + +Here is a little taster of the API, to get a sense for the design: ```haskell -data Ranked (n :: INat) a {- e.g. -} Ranked 3 Float -data Shaped (sh :: '[Nat]) a {- e.g. -} Shaped [2,3,4] Float -data Mixed (xsh :: '[Maybe Nat]) a {- e.g. -} Mixed [Just 2, Nothing, Just 4] Float - -Ranked I0 a = Ranked Z a ~~= Acc.Array Z a = Acc.Scalar a -Ranked I1 a = Ranked (S Z) a ~~= Acc.Array (Z :. Int) a = Acc.Vector a -Ranked I2 a = Ranked (S (S Z)) a ~~= Acc.Array (Z :. Int :. Int) a = Acc.Matrix a +import GHC.TypeLits (Nat, SNat) + +data Ranked (n :: Nat) a {- e.g. -} Ranked 3 Float +data Shaped (sh :: '[Nat]) a {- e.g. -} Shaped [2,3,4] Float +data Mixed (xsh :: '[Maybe Nat]) a {- e.g. -} Mixed [Just 2, Nothing, Just 4] Float +-- Shape types are written Sh{R,S,X}. The 'I' prefix denotes a Int-filled shape; +-- ShR and ShX are more general containers. ShS is a singleton. +rshape :: Elt a => Ranked n a -> IShR n +sshape :: Elt a => Shaped sh a -> ShS sh +mshape :: Elt a => Mixed xsh a -> IShX xsh -rshape :: (Elt a, KnownINat n) => Ranked n a -> IxR n -sshape :: (Elt a, KnownShape sh) => Shaped sh a -> IxS sh -mshape :: (Elt a, KnownShapeX xsh) => Mixed xsh a -> IxX xsh +-- Index types are written Ix{R,S,X}. +rindex :: Elt a => Ranked n a -> IIxR n -> a +sindex :: Elt a => Shaped sh a -> IIxS sh -> a +mindex :: Elt a => Mixed xsh a -> IIxX xsh -> a -rindex :: Elt a => Ranked n a -> IxR n -> a -sindex :: Elt a => Shaped sh a -> IxS sh -> a -mindex :: Elt a => Mixed xsh a -> IxX xsh -> a +-- The index types can be used as if they were defined as follows; pattern +-- synonyms are provided to construct the illusion. (The actual definitions are +-- a bit more general and indirect.) +data IIxR n where + ZIR :: IIxR 0 + (:.:) :: Int -> IIxR n -> IIxR (n + 1) -data IxR n where - IZR :: IxR Z - (:::) :: Int -> IxR n -> IxR (S n) +data IIxS sh where + ZIS :: IIxS '[] + (:.$) :: Int -> IIxS sh -> IIxS (n : sh) -data IxS sh where - IZS :: IxS '[] - (::$) :: Int -> IxS sh -> IxS (n : sh) +data IIxX xsh where + ZIX :: IIxX '[] + (:.%) :: Int -> IIxX xsh -> IIxX (mn : xsh) -data IxX sh where - IZX :: IxX '[] - (::@) :: Int -> IxX sh -> IxX (Just n : sh) - (::?) :: Int -> IxX sh -> IxX (Nothing : sh) +-- Similarly, the shape types can be used as if they were defined as follows. +data IShR n where + ZSR :: IShR 0 + (:$:) :: Int -> IShR n -> IShR (n + 1) +data ShS sh where + ZSS :: ShS '[] + (:$$) :: SNat n -> ShS sh -> ShS (n : sh) + +data IShX xsh where + ZSX :: IShX '[] + (:$%) :: SMayNat Int SNat mn -> IShX xsh -> IShX (mn : xsh) +-- where: +data SMayNat i f n where + SUnknown :: i -> SMayNat i f Nothing + SKnown :: f n -> SMayNat i f (Just n) + +-- Occasionally one needs a singleton for only the _known_ dimensions of a mixed +-- shape -- that is to say, only the statically-known part of a mixed shape. +-- StaticShX provides for this need. It can be used as if defined as follows: +data StaticShX xsh where + ZKX :: StaticShX '[] + (:!%) :: SMayNat () SNat mn -> StaticShX xsh -> StaticShX (mn : xsh) + +-- The Elt class describes types that can be used as elements of an array. While +-- it is technically possible to define new instances of this class, typical +-- usage should regard Elt as closed. The user-relevant instances are the +-- following: class Elt a -instance Elt () -instance Elt Double -instance Elt Int -instance (Elt a, Elt b) => Elt (a, b) -instance (Elt a, KnownINat n) => Elt (Ranked n a) -instance (Elt a, KnownShape sh) => Elt (Shaped sh a) -instance (Elt a, KnownShapeX xsh) => Elt (Mixed xsh a) - -rgenerate :: Elt a => IxR n -> (IxR n -> a) -> Ranked n a -sgenerate :: (Elt a, KnownShape sh) => (IxS sh -> a) -> Shaped sh a -mgenerate :: (Elt a, KnownShapeX xsh) => IxX xsh -> (IxX xsh -> a) -> Mixed xsh a +instance Elt () +instance Elt Bool +instance Elt Float +instance Elt Double +instance Elt Int +instance (Elt a, Elt b) => Elt (a, b) +instance Elt a => Elt (Ranked n a) +instance Elt a => Elt (Shaped sh a) +instance Elt a => Elt (Mixed xsh a) +-- Essentially all functions that ox-arrays offers on arrays are first-order: +-- add two arrays elementwise, transpose an array, append arrays, compute +-- minima/maxima, zip/unzip, nest/unnest, etc. The first-order approach allows +-- operations, especially arithmetic ones, to be vectorised using hand-written +-- C code, without needing any sort of JIT compilation. +rappend :: Elt a => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a +sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a +mappend :: Elt a => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a + +-- Exceptionally, also one higher-order function is provided per array type: +-- 'generate'. These functions have the caveat that regularity of arrays must be +-- preserved: all returned 'a's must have equal shape. See the documentation of +-- 'mgenerate'. +-- Warning: because the invocations of the function you pass cannot be +-- vectorised, 'generate' is rather slow if 'a' is small. +-- The 'KnownElt' class captures an API infelicity where constraint-based shape +-- passing is the only practical option. +rgenerate :: KnownElt a => IShR n -> (IxR n -> a) -> Ranked n a +sgenerate :: KnownElt a => ShS sh -> (IxS sh -> a) -> Shaped sh a +mgenerate :: KnownElt a => IShX xsh -> (IxX xsh -> a) -> Mixed xsh a + +-- Under the hood, Ranked and Shaped are both newtypes over Mixed. Mixed itself +-- is a data family over XArray, which is a newtype over orthotope's RankedS. newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) ``` + +About the name: when importing `orthotope` array modules, a possible naming +convention is to use qualified imports as `OR` for "orthotope ranked" arrays and +`OS` for "orthotope shaped" arrays. ox-arrays was started to fill the `OX` gap, +then grew out of proportion. diff --git a/bench/Main.hs b/bench/Main.hs new file mode 100644 index 0000000..b604eb9 --- /dev/null +++ b/bench/Main.hs @@ -0,0 +1,244 @@ +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} +module Main where + +import Control.Exception (bracket) +import Control.Monad (when) +import Data.Array.Internal qualified as OI +import Data.Array.Internal.RankedG qualified as RG +import Data.Array.Internal.RankedS qualified as RS +import Data.Foldable (toList) +import Data.Vector.Storable qualified as VS +import Numeric.LinearAlgebra qualified as LA +import Test.Tasty.Bench +import Text.Show (showListWith) + +import Data.Array.Nested +import Data.Array.Nested.Mixed (Mixed(M_Primitive), mliftPrim, mliftPrim2, toPrimitive) +import Data.Array.Nested.Ranked (liftRanked1, liftRanked2) +import Data.Array.Strided.Arith.Internal qualified as Arith +import Data.Array.XArray (XArray(..)) + + +enableMisc :: Bool +enableMisc = False + +bgroupIf :: Bool -> String -> [Benchmark] -> Benchmark +bgroupIf True = bgroup +bgroupIf False = \name _ -> bgroup name [] + + +main :: IO () +main = do + let enable = False + bracket (Arith.statisticsEnable enable) + (\() -> do Arith.statisticsEnable False + when enable Arith.statisticsPrintAll) + (\() -> main_tests) + +main_tests :: IO () +main_tests = defaultMain + [bgroup "compare" tests_compare + ,bgroup "dotprod" $ + let stridesOf (Ranked (toPrimitive -> M_Primitive _ (XArray (RS.A (RG.A _ (OI.T strides _ _)))))) = strides + dotprodBench name (inp1, inp2) = + let showSh l = showListWith (\n -> let ln = round (logBase 10 (fromIntegral n :: Double)) :: Int + in if n > 1 && n == 10 ^ ln then showString ("1e" ++ show ln) else shows n) + l "" + in bench (name ++ " " ++ showSh (toList (rshape inp1)) ++ + " str " ++ showSh (stridesOf inp1) ++ " " ++ showSh (stridesOf inp2)) $ + nf (\(a,b) -> rsumAllPrim (rdot1Inner a b)) (inp1, inp2) + + iota = riota @Double + in + [dotprodBench "dot 1D" + (iota 10_000_000 + ,iota 10_000_000) + ,dotprodBench "revdot" + (rrev1 (iota 10_000_000) + ,rrev1 (iota 10_000_000)) + ,dotprodBench "dot 2D" + (rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000) + ,rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000)) + ,dotprodBench "batched dot" + (rreplicate (1000 :$: ZSR) (iota 10_000) + ,rreplicate (1000 :$: ZSR) (iota 10_000)) + ,dotprodBench "transposed dot" $ + let (a, b) = (rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000) + ,rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000)) + in (rtranspose [1,0] a, rtranspose [1,0] b) + ,dotprodBench "repdot" $ + let (a, b) = (rreplicate (1000 :$: ZSR) (iota 10_000) + ,rreplicate (1000 :$: ZSR) (iota 10_000)) + in (rtranspose [1,0] a, rtranspose [1,0] b) + ,dotprodBench "matvec" $ + let (m, v) = (rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000) + ,iota 10_000) + in (m, rreplicate (1000 :$: ZSR) v) + ,dotprodBench "vecmat" $ + let (v, m) = (iota 1_000 + ,rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000)) + in (rreplicate (10_000 :$: ZSR) v, rtranspose [1,0] m) + ,dotprodBench "matmat" $ + let (n,m,k) = (100, 100, 1000) + (m1, m2) = (rreshape (n :$: m :$: ZSR) (iota (n*m)) + ,rreshape (m :$: k :$: ZSR) (iota (m*k))) + in (rtranspose [1,0] (rreplicate (k :$: ZSR) m1) + ,rreplicate (n :$: ZSR) (rtranspose [1,0] m2)) + ,dotprodBench "matmatT" $ + let (n,m,k) = (100, 100, 1000) + (m1, m2) = (rreshape (n :$: m :$: ZSR) (iota (n*m)) + ,rreshape (k :$: m :$: ZSR) (iota (m*k))) + in (rtranspose [1,0] (rreplicate (k :$: ZSR) m1) + ,rreplicate (n :$: ZSR) m2) + ] + ,bgroup "orthotope" + [bench "normalize [1e6]" $ + let n = 1_000_000 + in nf (\a -> RS.normalize a) + (RS.rev [0] (RS.iota @Double n)) + ,bench "normalize noop [1e6]" $ + let n = 1_000_000 + in nf (\a -> RS.normalize a) + (RS.rev [0] (RS.rev [0] (RS.iota @Double n))) + ] + ,bgroupIf enableMisc "misc" + [let n = 1000 + k = 1000 + in bgroup ("fusion [" ++ show k ++ "]*" ++ show n) + [bench "sum (concat)" $ + nf (\as -> VS.sum (VS.concat as)) + (replicate n (VS.enumFromTo (1::Int) k)) + ,bench "sum (force (concat))" $ + nf (\as -> VS.sum (VS.force (VS.concat as))) + (replicate n (VS.enumFromTo (1::Int) k))] + ,bgroup "concat" + [bgroup "N" + [bgroup "hmatrix" + [bench ("LA.vjoin [500]*1e" ++ show ni) $ + let n = 10 ^ ni + k = 500 + in nf (\as -> LA.vjoin as) + (replicate n (VS.enumFromTo (1::Int) k)) + | ni <- [1::Int ..5]] + ,bgroup "vectorStorable" + [bench ("VS.concat [500]*1e" ++ show ni) $ + let n = 10 ^ ni + k = 500 + in nf (\as -> VS.concat as) + (replicate n (VS.enumFromTo (1::Int) k)) + | ni <- [1::Int ..5]] + ] + ,bgroup "K" + [bgroup "hmatrix" + [bench ("LA.vjoin [1e" ++ show ki ++ "]*500") $ + let n = 500 + k = 10 ^ ki + in nf (\as -> LA.vjoin as) + (replicate n (VS.enumFromTo (1::Int) k)) + | ki <- [1::Int ..5]] + ,bgroup "vectorStorable" + [bench ("VS.concat [1e" ++ show ki ++ "]*500") $ + let n = 500 + k = 10 ^ ki + in nf (\as -> VS.concat as) + (replicate n (VS.enumFromTo (1::Int) k)) + | ki <- [1::Int ..5]] + ] + ] + ] + ] + +tests_compare :: [Benchmark] +tests_compare = + let n = 1_000_000 in + [bgroup "Num" + [bench "sum(+) Double [1e6]" $ + nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (+)) a b))) + (riota @Double n, riota n) + ,bench "sum(*) Double [1e6]" $ + nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (*)) a b))) + (riota @Double n, riota n) + ,bench "sum(/) Double [1e6]" $ + nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (/)) a b))) + (riota @Double n, riota n) + ,bench "sum(**) Double [1e6]" $ + nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (**)) a b))) + (riota @Double n, riota n) + ,bench "sum(sin) Double [1e6]" $ + nf (\a -> runScalar (rsumOuter1 (liftRanked1 (mliftPrim sin) a))) + (riota @Double n) + ,bench "sum Double [1e6]" $ + nf (\a -> runScalar (rsumOuter1 a)) + (riota @Double n) + ] + ,bgroup "NumElt" + [bench "sum(+) Double [1e6]" $ + nf (\(a, b) -> runScalar (rsumOuter1 (a + b))) + (riota @Double n, riota n) + ,bench "sum(*) Double [1e6]" $ + nf (\(a, b) -> runScalar (rsumOuter1 (a * b))) + (riota @Double n, riota n) + ,bench "sum(/) Double [1e6]" $ + nf (\(a, b) -> runScalar (rsumOuter1 (a / b))) + (riota @Double n, riota n) + ,bench "sum(**) Double [1e6]" $ + nf (\(a, b) -> runScalar (rsumOuter1 (a ** b))) + (riota @Double n, riota n) + ,bench "sum(sin) Double [1e6]" $ + nf (\a -> runScalar (rsumOuter1 (sin a))) + (riota @Double n) + ,bench "sum Double [1e6]" $ + nf (\a -> runScalar (rsumOuter1 a)) + (riota @Double n) + ,bench "sum(*) Double [1e6] stride 1; -1" $ + nf (\(a, b) -> runScalar (rsumOuter1 (a * b))) + (riota @Double n, rrev1 (riota n)) + ,bench "dotprod Float [1e6]" $ + nf (\(a, b) -> rdot a b) + (riota @Float n, riota @Float n) + ,bench "dotprod Float [1e6] stride 1; -1" $ + nf (\(a, b) -> rdot a b) + (riota @Float n, rrev1 (riota @Float n)) + ,bench "dotprod Double [1e6]" $ + nf (\(a, b) -> rdot a b) + (riota @Double n, riota @Double n) + ,bench "dotprod Double [1e6] stride 1; -1" $ + nf (\(a, b) -> rdot a b) + (riota @Double n, rrev1 (riota @Double n)) + ] + ,bgroup "hmatrix" + [bench "sum(+) Double [1e6]" $ + nf (\(a, b) -> LA.sumElements (a + b)) + (LA.linspace @Double n (0.0, fromIntegral (n - 1)) + ,LA.linspace @Double n (0.0, fromIntegral (n - 1))) + ,bench "sum(*) Double [1e6]" $ + nf (\(a, b) -> LA.sumElements (a * b)) + (LA.linspace @Double n (0.0, fromIntegral (n - 1)) + ,LA.linspace @Double n (0.0, fromIntegral (n - 1))) + ,bench "sum(/) Double [1e6]" $ + nf (\(a, b) -> LA.sumElements (a / b)) + (LA.linspace @Double n (0.0, fromIntegral (n - 1)) + ,LA.linspace @Double n (0.0, fromIntegral (n - 1))) + ,bench "sum(**) Double [1e6]" $ + nf (\(a, b) -> LA.sumElements (a ** b)) + (LA.linspace @Double n (0.0, fromIntegral (n - 1)) + ,LA.linspace @Double n (0.0, fromIntegral (n - 1))) + ,bench "sum(sin) Double [1e6]" $ + nf (\a -> LA.sumElements (sin a)) + (LA.linspace @Double n (0.0, fromIntegral (n - 1))) + ,bench "sum Double [1e6]" $ + nf (\a -> LA.sumElements a) + (LA.linspace @Double n (0.0, fromIntegral (n - 1))) + ,bench "dotprod Float [1e6]" $ + nf (\(a, b) -> a LA.<.> b) + (LA.linspace @Double n (0.0, fromIntegral (n - 1)) + ,LA.linspace @Double n (fromIntegral (n - 1), 0.0)) + ,bench "dotprod Double [1e6]" $ + nf (\(a, b) -> a LA.<.> b) + (LA.linspace @Double n (0.0, fromIntegral (n - 1)) + ,LA.linspace @Double n (fromIntegral (n - 1), 0.0)) + ] + ] diff --git a/cabal.project b/cabal.project index 697d3bd..d102ed6 100644 --- a/cabal.project +++ b/cabal.project @@ -1,5 +1,2 @@ packages: . -with-compiler: ghc-9.8.2 - -allow-newer: - orthotope:deepseq +with-compiler: ghc-9.8.4 diff --git a/cbits/arith.c b/cbits/arith.c new file mode 100644 index 0000000..f19b01e --- /dev/null +++ b/cbits/arith.c @@ -0,0 +1,808 @@ +#include <stdio.h> +#include <stdint.h> +#include <inttypes.h> +#include <stdlib.h> +#include <stdbool.h> +#include <stdatomic.h> +#include <string.h> +#include <math.h> +#include <threads.h> +#include <sys/time.h> + +// These are the wrapper macros used in arith_lists.h. Preset them to empty to +// avoid having to touch macros unrelated to the particular operation set below. +#define LIST_BINOP(name, id, hsop) +#define LIST_IBINOP(name, id, hsop) +#define LIST_FBINOP(name, id, hsop) +#define LIST_UNOP(name, id, _) +#define LIST_FUNOP(name, id, _) +#define LIST_REDOP(name, id, _) + + +// Shorter names, due to CPP used both in function names and in C types. +typedef int32_t i32; +typedef int64_t i64; + + +// PRECONDITIONS +// +// All strided array operations in this file assume that none of the shape +// components are zero -- that is, the input arrays are non-empty. This must +// be arranged on the Haskell side. +// +// Furthermore, note that while the Haskell side has an offset into the backing +// vector, the C side assumes that the offset is zero. Shift the pointer if +// necessary. + + +/***************************************************************************** + * Performance statistics * + *****************************************************************************/ + +// Each block holds a buffer with variable-length messages. Each message starts +// with a tag byte; the respective sublists below give the fields after that tag +// byte. +// - 1: unary operation performance measurement +// - u8: some identifier +// - i32: input rank +// - i64[rank]: input shape +// - i64[rank]: input strides +// - f64: seconds taken +// - 2: binary operation performance measurement +// - u8: a stats_binary_id +// - i32: input rank +// - i64[rank]: input shape +// - i64[rank]: input 1 strides +// - i64[rank]: input 2 strides +// - f64: seconds taken +// The 'prev' and 'cap' fields are set only once on creation of a block, and can +// thus be read without restrictions. The 'len' field is potentially mutated +// from different threads and must be handled with care. +struct stats_block { + struct stats_block *prev; // backwards linked list; NULL if first block + size_t cap; // bytes capacity of buffer in this block + atomic_size_t len; // bytes filled in this buffer + uint8_t buf[]; // trailing VLA +}; + +enum stats_binary_id { + sbi_dotprod = 1, +}; + +// Atomic because blocks may be allocated from different threads. +static _Atomic(struct stats_block*) stats_current = NULL; +static atomic_bool stats_enabled = false; + +void oxarrays_stats_enable(i32 yes) { atomic_store(&stats_enabled, yes == 1); } + +static uint8_t* stats_alloc(size_t nbytes) { +try_again: ; + struct stats_block *block = atomic_load(&stats_current); + size_t curlen = block != NULL ? atomic_load(&block->len) : 0; + size_t curcap = block != NULL ? block->cap : 0; + + if (block == NULL || curlen + nbytes > curcap) { + const size_t newcap = stats_current == NULL ? 4096 : 2 * stats_current->cap; + struct stats_block *new = malloc(sizeof(struct stats_block) + newcap); + new->prev = stats_current; + curcap = new->cap = newcap; + curlen = new->len = 0; + if (!atomic_compare_exchange_strong(&stats_current, &block, new)) { + // Race condition, simply free this memory block and try again + free(new); + goto try_again; + } + block = new; + } + + // Try to update the 'len' field of the block we captured at the start of the + // function. Note that it doesn't matter if someone else already allocated a + // new block in the meantime; we're still accessing the same block here, which + // may succeed or fail independently. + while (!atomic_compare_exchange_strong(&block->len, &curlen, curlen + nbytes)) { + // curlen was updated to the actual value. + // If the block got full in the meantime, try again from the start + if (curlen + nbytes > curcap) goto try_again; + } + + return block->buf + curlen; +} + +__attribute__((unused)) +static void stats_record_unary(enum stats_binary_id id, i32 rank, const i64 *shape, const i64 *strides, double secs) { + if (!atomic_load(&stats_enabled)) return; + uint8_t *buf = stats_alloc(1 + 1 + 4 + 2*rank*8 + 8); + *buf = 1; buf += 1; + *buf = id; buf += 1; + *(i32*)buf = rank; buf += 4; + memcpy((i64*)buf, shape, rank * 8); buf += rank * 8; + memcpy((i64*)buf, strides, rank * 8); buf += rank * 8; + *(double*)buf = secs; +} + +__attribute__((unused)) +static void stats_record_binary(enum stats_binary_id id, i32 rank, const i64 *shape, const i64 *strides1, const i64 *strides2, double secs) { + if (!atomic_load(&stats_enabled)) return; + uint8_t *buf = stats_alloc(1 + 1 + 4 + 3*rank*8 + 8); + *buf = 2; buf += 1; + *buf = id; buf += 1; + *(i32*)buf = rank; buf += 4; + memcpy((i64*)buf, shape, rank * 8); buf += rank * 8; + memcpy((i64*)buf, strides1, rank * 8); buf += rank * 8; + memcpy((i64*)buf, strides2, rank * 8); buf += rank * 8; + *(double*)buf = secs; +} + +#define TIME_START(varname_) \ + struct timeval varname_ ## _start, varname_ ## _end; \ + gettimeofday(&varname_ ## _start, NULL); +#define TIME_END(varname_) \ + (gettimeofday(&varname_ ## _end, NULL), \ + ((varname_ ## _end).tv_sec - (varname_ ## _start).tv_sec) + \ + ((varname_ ## _end).tv_usec - (varname_ ## _start).tv_usec) / (double)1e6) + +static size_t stats_print_unary(uint8_t *buf) { + uint8_t *orig_buf = buf; + + enum stats_binary_id id = *buf; buf += 1; + i32 rank = *(i32*)buf; buf += 4; + i64 *shape = (i64*)buf; buf += rank * 8; + i64 *strides = (i64*)buf; buf += rank * 8; + double secs = *(double*)buf; buf += 8; + + i64 shsize = 1; for (i32 i = 0; i < rank; i++) shsize *= shape[i]; + + printf("unary %d sz %" PRIi64 " ms %.3lf sh=[", (int)id, shsize, secs * 1000); + for (i32 i = 0; i < rank; i++) { if (i > 0) putchar(','); printf("%" PRIi64, shape[i]); } + printf("] str=["); + for (i32 i = 0; i < rank; i++) { if (i > 0) putchar(','); printf("%" PRIi64, strides[i]); } + printf("]\n"); + + return buf - orig_buf; +} + +static size_t stats_print_binary(uint8_t *buf) { + uint8_t *orig_buf = buf; + + enum stats_binary_id id = *buf; buf += 1; + i32 rank = *(i32*)buf; buf += 4; + i64 *shape = (i64*)buf; buf += rank * 8; + i64 *strides1 = (i64*)buf; buf += rank * 8; + i64 *strides2 = (i64*)buf; buf += rank * 8; + double secs = *(double*)buf; buf += 8; + + i64 shsize = 1; for (i32 i = 0; i < rank; i++) shsize *= shape[i]; + + printf("binary %d sz %" PRIi64 " ms %.3lf sh=[", (int)id, shsize, secs * 1000); + for (i32 i = 0; i < rank; i++) { if (i > 0) putchar(','); printf("%" PRIi64, shape[i]); } + printf("] str1=["); + for (i32 i = 0; i < rank; i++) { if (i > 0) putchar(','); printf("%" PRIi64, strides1[i]); } + printf("] str2=["); + for (i32 i = 0; i < rank; i++) { if (i > 0) putchar(','); printf("%" PRIi64, strides2[i]); } + printf("]\n"); + + return buf - orig_buf; +} + +// Also frees the printed log. +void oxarrays_stats_print_all(void) { + printf("=== ox-arrays-arith-stats start ===\n"); + + // Claim the entire chain and prevent new blocks from being added to it. + // (This is technically slightly wrong because a value may still be in the + // process of being recorded to some blocks in the chain while we're doing + // this printing, but yolo) + struct stats_block *last = atomic_exchange(&stats_current, NULL); + + // Reverse the linked list; after this loop, the 'prev' pointers point to the + // _next_ block, not the previous one. + struct stats_block *block = last; + if (last != NULL) { + struct stats_block *next = NULL; + // block next + // ##### <-##### <-##### NULL + while (block->prev != NULL) { + struct stats_block *prev = block->prev; + // prev block next + // ##### <-##### <-##### ##... + block->prev = next; + // prev block next + // ##### <-##### #####-> ##... + next = block; + // prev bl=nx + // ##### <-##### #####-> ##... + block = prev; + // block next + // ##### <-##### #####-> ##... + } + // block next + // NULL <-##### #####-> ##... + block->prev = next; + // block next + // NULL #####-> #####-> ##... + } + + while (block != NULL) { + for (size_t i = 0; i < block->len; ) { + switch (block->buf[i]) { + case 1: i += 1 + stats_print_unary(block->buf + i+1); break; + case 2: i += 1 + stats_print_binary(block->buf + i+1); break; + default: + printf("# UNKNOWN ENTRY WITH ID %d, SKIPPING BLOCK\n", (int)block->buf[i]); + i = block->len; + break; + } + } + struct stats_block *next = block->prev; // remember, reversed! + free(block); + block = next; + } + + printf("=== ox-arrays-arith-stats end ===\n"); +} + + +/***************************************************************************** + * Additional math functions * + *****************************************************************************/ + +#define GEN_ABS(x) \ + _Generic((x), \ + int: abs, \ + long: labs, \ + long long: llabs, \ + float: fabsf, \ + double: fabs)(x) + +// This does not result in multiple loads with GCC 13. +#define GEN_SIGNUM(x) ((x) < 0 ? -1 : (x) > 0 ? 1 : 0) + +#define GEN_POW(x, y) _Generic((x), float: powf, double: pow)(x, y) +#define GEN_LOGBASE(x, y) _Generic((x), float: logf(y) / logf(x), double: log(y) / log(x)) +#define GEN_ATAN2(y, x) _Generic((x), float: atan2f(y, x), double: atan2(y, x)) +#define GEN_EXP(x) _Generic((x), float: expf, double: exp)(x) +#define GEN_LOG(x) _Generic((x), float: logf, double: log)(x) +#define GEN_SQRT(x) _Generic((x), float: sqrtf, double: sqrt)(x) +#define GEN_SIN(x) _Generic((x), float: sinf, double: sin)(x) +#define GEN_COS(x) _Generic((x), float: cosf, double: cos)(x) +#define GEN_TAN(x) _Generic((x), float: tanf, double: tan)(x) +#define GEN_ASIN(x) _Generic((x), float: asinf, double: asin)(x) +#define GEN_ACOS(x) _Generic((x), float: acosf, double: acos)(x) +#define GEN_ATAN(x) _Generic((x), float: atanf, double: atan)(x) +#define GEN_SINH(x) _Generic((x), float: sinhf, double: sinh)(x) +#define GEN_COSH(x) _Generic((x), float: coshf, double: cosh)(x) +#define GEN_TANH(x) _Generic((x), float: tanhf, double: tanh)(x) +#define GEN_ASINH(x) _Generic((x), float: asinhf, double: asinh)(x) +#define GEN_ACOSH(x) _Generic((x), float: acoshf, double: acosh)(x) +#define GEN_ATANH(x) _Generic((x), float: atanhf, double: atanh)(x) +#define GEN_LOG1P(x) _Generic((x), float: log1pf, double: log1p)(x) +#define GEN_EXPM1(x) _Generic((x), float: expm1f, double: expm1)(x) + +// Taken from Haskell's implementation: +// https://hackage.haskell.org/package/ghc-internal-9.1001.0/docs/src//GHC.Internal.Float.html#log1mexpOrd +#define LOG1MEXP_IMPL(x) do { \ + if (x > _Generic((x), float: logf, double: log)(2)) return GEN_LOG(-GEN_EXPM1(x)); \ + else return GEN_LOG1P(-GEN_EXP(x)); \ + } while (0) + +static float log1mexp_float(float x) { LOG1MEXP_IMPL(x); } +static double log1mexp_double(double x) { LOG1MEXP_IMPL(x); } + +#define GEN_LOG1MEXP(x) _Generic((x), float: log1mexp_float, double: log1mexp_double)(x) + +// Taken from Haskell's implementation: +// https://hackage.haskell.org/package/ghc-internal-9.1001.0/docs/src//GHC.Internal.Float.html#line-595 +#define LOG1PEXP_IMPL(x) do { \ + if (x <= 18) return GEN_LOG1P(GEN_EXP(x)); \ + if (x <= 100) return x + GEN_EXP(-x); \ + return x; \ + } while (0) + +static float log1pexp_float(float x) { LOG1PEXP_IMPL(x); } +static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); } + +#define GEN_LOG1PEXP(x) _Generic((x), float: log1pexp_float, double: log1pexp_double)(x) + + +/***************************************************************************** + * Helper functions * + *****************************************************************************/ + +__attribute__((used)) +static void print_shape(FILE *stream, i64 rank, const i64 *shape) { + fputc('[', stream); + for (i64 i = 0; i < rank; i++) { + if (i != 0) fputc(',', stream); + fprintf(stream, "%" PRIi64, shape[i]); + } + fputc(']', stream); +} + + +/***************************************************************************** + * Skeletons * + *****************************************************************************/ + +// Walk a orthotope-style strided array, except for the inner dimension. The +// body is run for every "inner vector". +// Provides idx, outlinidx, arrlinidx. +#define TARRAY_WALK_NOINNER(again_label_name, rank, shape, strides, ...) \ + do { \ + i64 idx[(rank) /* - 1 */]; /* Note: [zero-length VLA] */ \ + memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \ + i64 arrlinidx = 0; \ + i64 outlinidx = 0; \ + again_label_name: \ + { \ + __VA_ARGS__ \ + } \ + for (i64 dim = (rank) - 2; dim >= 0; dim--) { \ + if (++idx[dim] < (shape)[dim]) { \ + arrlinidx += (strides)[dim]; \ + outlinidx++; \ + goto again_label_name; \ + } \ + arrlinidx -= (idx[dim] - 1) * (strides)[dim]; \ + idx[dim] = 0; \ + } \ + } while (false) + +// Walk TWO orthotope-style strided arrays simultaneously, except for their +// inner dimension. The arrays must have the same shape, but may have different +// strides. The body is run for every pair of "inner vectors". +// Provides idx, outlinidx, arrlinidx1, arrlinidx2. +#define TARRAY_WALK2_NOINNER(again_label_name, rank, shape, strides1, strides2, ...) \ + do { \ + i64 idx[(rank) /* - 1 */]; /* Note: [zero-length VLA] */ \ + memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \ + i64 arrlinidx1 = 0, arrlinidx2 = 0; \ + i64 outlinidx = 0; \ + again_label_name: \ + { \ + __VA_ARGS__ \ + } \ + for (i64 dim = (rank) - 2; dim >= 0; dim--) { \ + if (++idx[dim] < (shape)[dim]) { \ + arrlinidx1 += (strides1)[dim]; \ + arrlinidx2 += (strides2)[dim]; \ + outlinidx++; \ + goto again_label_name; \ + } \ + arrlinidx1 -= (idx[dim] - 1) * (strides1)[dim]; \ + arrlinidx2 -= (idx[dim] - 1) * (strides2)[dim]; \ + idx[dim] = 0; \ + } \ + } while (false) + + +/***************************************************************************** + * Kernel functions * + *****************************************************************************/ + +#define COMM_OP_STRIDED(name, op, typ) \ + static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \ + if (rank == 0) { out[0] = x op y[0]; return; } \ + TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ + for (i64 i = 0; i < shape[rank - 1]; i++) { \ + out[outlinidx * shape[rank - 1] + i] = x op y[arrlinidx + strides[rank - 1] * i]; \ + } \ + }); \ + } \ + static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ + if (rank == 0) { out[0] = x[0] op y[0]; return; } \ + TARRAY_WALK2_NOINNER(again, rank, shape, strides1, strides2, { \ + for (i64 i = 0; i < shape[rank - 1]; i++) { \ + out[outlinidx * shape[rank - 1] + i] = x[arrlinidx1 + strides1[rank - 1] * i] op y[arrlinidx2 + strides2[rank - 1] * i]; \ + } \ + }); \ + } + +#define NONCOMM_OP_STRIDED(name, op, typ) \ + COMM_OP_STRIDED(name, op, typ) \ + static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \ + if (rank == 0) { out[0] = x[0] op y; return; } \ + TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ + for (i64 i = 0; i < shape[rank - 1]; i++) { \ + out[outlinidx * shape[rank - 1] + i] = x[arrlinidx + strides[rank - 1] * i] op y; \ + } \ + }); \ + } + +#define PREFIX_BINOP_STRIDED(name, op, typ) \ + static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \ + if (rank == 0) { out[0] = op(x, y[0]); return; } \ + TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ + for (i64 i = 0; i < shape[rank - 1]; i++) { \ + out[outlinidx * shape[rank - 1] + i] = op(x, y[arrlinidx + strides[rank - 1] * i]); \ + } \ + }); \ + } \ + static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ + if (rank == 0) { out[0] = op(x[0], y[0]); return; } \ + TARRAY_WALK2_NOINNER(again, rank, shape, strides1, strides2, { \ + for (i64 i = 0; i < shape[rank - 1]; i++) { \ + out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx1 + strides1[rank - 1] * i], y[arrlinidx2 + strides2[rank - 1] * i]); \ + } \ + }); \ + } \ + static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \ + if (rank == 0) { out[0] = op(x[0], y); return; } \ + TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ + for (i64 i = 0; i < shape[rank - 1]; i++) { \ + out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx + strides[rank - 1] * i], y); \ + } \ + }); \ + } + +#define UNARY_OP_STRIDED(name, op, typ) \ + static void oxarop_op_ ## name ## _ ## typ ## _strided(i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \ + /* fprintf(stderr, "oxarop_op_" #name "_" #typ "_strided: rank=%ld shape=", rank); \ + print_shape(stderr, rank, shape); \ + fprintf(stderr, " strides="); \ + print_shape(stderr, rank, strides); \ + fprintf(stderr, "\n"); */ \ + if (rank == 0) { out[0] = op(arr[0]); return; } \ + TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ + for (i64 i = 0; i < shape[rank - 1]; i++) { \ + out[outlinidx * shape[rank - 1] + i] = op(arr[arrlinidx + strides[rank - 1] * i]); \ + } \ + }); \ + } + +// Used for reduction and dot product kernels below +#define MANUAL_VECT_WID 8 + +// Used in REDUCE1_OP and REDUCEFULL_OP below +#define REDUCE_BODY_CODE(op, typ, innerLen, innerStride, arr, arrlinidx, destination) \ + do { \ + const i64 n = innerLen; const i64 s = innerStride; \ + if (n < MANUAL_VECT_WID) { \ + typ accum = arr[arrlinidx]; \ + for (i64 i = 1; i < n; i++) accum = accum op arr[arrlinidx + s * i]; \ + destination = accum; \ + } else { \ + typ accum[MANUAL_VECT_WID]; \ + for (i64 j = 0; j < MANUAL_VECT_WID; j++) accum[j] = arr[arrlinidx + s * j]; \ + for (i64 i = 1; i < n / MANUAL_VECT_WID; i++) { \ + for (i64 j = 0; j < MANUAL_VECT_WID; j++) { \ + accum[j] = accum[j] op arr[arrlinidx + s * (MANUAL_VECT_WID * i + j)]; \ + } \ + } \ + typ res = accum[0]; \ + for (i64 j = 1; j < MANUAL_VECT_WID; j++) res = res op accum[j]; \ + for (i64 i = n / MANUAL_VECT_WID * MANUAL_VECT_WID; i < n; i++) \ + res = res op arr[arrlinidx + s * i]; \ + destination = res; \ + } \ + } while (0) + +// Reduces along the innermost dimension. +// 'out' will be filled densely in linearisation order. +#define REDUCE1_OP(name, op, typ) \ + static void oxarop_op_ ## name ## _ ## typ(i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \ + TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ + REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, out[outlinidx]); \ + }); \ + } + +#define REDUCEFULL_OP(name, op, typ) \ + typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \ + if (rank == 0) return arr[0]; \ + typ result = 0; \ + TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ + REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, result); \ + }); \ + return result; \ + } + +// Writes extreme index to outidx. If 'cmp' is '<', computes minindex ("argmin"); if '>', maxindex. +#define EXTREMUM_OP(name, cmp, typ) \ + void oxarop_extremum_ ## name ## _ ## typ(i64 *restrict outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \ + if (rank == 0) return; /* output index vector has length 0 anyways */ \ + typ best = arr[0]; \ + memset(outidx, 0, rank * sizeof(i64)); \ + TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ + bool found = false; \ + for (i64 i = 0; i < shape[rank - 1]; i++) { \ + if (arr[arrlinidx + i] cmp best) { \ + best = arr[arrlinidx + strides[rank - 1] * i]; \ + found = true; \ + outidx[rank - 1] = i; \ + } \ + } \ + if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \ + }); \ + } + +// Reduces along the innermost dimension. +// 'out' will be filled densely in linearisation order. +#define DOTPROD_INNER_OP(typ) \ + void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \ + TIME_START(tm); \ + TARRAY_WALK2_NOINNER(again3, rank, shape, strides1, strides2, { \ + const i64 length = shape[rank - 1], stride1 = strides1[rank - 1], stride2 = strides2[rank - 1]; \ + if (length < MANUAL_VECT_WID) { \ + typ res = 0; \ + for (i64 i = 0; i < length; i++) res += arr1[arrlinidx1 + stride1 * i] * arr2[arrlinidx2 + stride2 * i]; \ + out[outlinidx] = res; \ + } else { \ + typ accum[MANUAL_VECT_WID]; \ + for (i64 j = 0; j < MANUAL_VECT_WID; j++) accum[j] = arr1[arrlinidx1 + stride1 * j] * arr2[arrlinidx2 + stride2 * j]; \ + for (i64 i = 1; i < length / MANUAL_VECT_WID; i++) \ + for (i64 j = 0; j < MANUAL_VECT_WID; j++) \ + accum[j] += arr1[arrlinidx1 + stride1 * (MANUAL_VECT_WID * i + j)] * arr2[arrlinidx2 + stride2 * (MANUAL_VECT_WID * i + j)]; \ + typ res = accum[0]; \ + for (i64 j = 1; j < MANUAL_VECT_WID; j++) res += accum[j]; \ + for (i64 i = length / MANUAL_VECT_WID * MANUAL_VECT_WID; i < length; i++) \ + res += arr1[arrlinidx1 + stride1 * i] * arr2[arrlinidx2 + stride2 * i]; \ + out[outlinidx] = res; \ + } \ + }); \ + stats_record_binary(sbi_dotprod, rank, shape, strides1, strides2, TIME_END(tm)); \ + } + + +/***************************************************************************** + * Entry point functions * + *****************************************************************************/ + +__attribute__((noreturn, cold)) +static void wrong_op(const char *name, int tag) { + fprintf(stderr, "ox-arrays: Invalid operation tag passed to %s C code: %d\n", name, tag); + abort(); +} + +enum binop_tag_t { +#undef LIST_BINOP +#define LIST_BINOP(name, id, hsop) name = id, +#include "arith_lists.h" +#undef LIST_BINOP +#define LIST_BINOP(name, id, hsop) +}; + +#define ENTRY_BINARY_STRIDED_OPS(typ) \ + void oxarop_binary_ ## typ ## _sv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \ + switch (tag) { \ + case BO_ADD: oxarop_op_add_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ + case BO_SUB: oxarop_op_sub_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ + case BO_MUL: oxarop_op_mul_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ + default: wrong_op("binary_sv_strided", tag); \ + } \ + } \ + void oxarop_binary_ ## typ ## _vs_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \ + switch (tag) { \ + case BO_ADD: oxarop_op_add_ ## typ ## _sv_strided(rank, shape, out, y, strides, x); break; \ + case BO_SUB: oxarop_op_sub_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ + case BO_MUL: oxarop_op_mul_ ## typ ## _sv_strided(rank, shape, out, y, strides, x); break; \ + default: wrong_op("binary_vs_strided", tag); \ + } \ + } \ + void oxarop_binary_ ## typ ## _vv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ + switch (tag) { \ + case BO_ADD: oxarop_op_add_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ + case BO_SUB: oxarop_op_sub_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ + case BO_MUL: oxarop_op_mul_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ + default: wrong_op("binary_vv_strided", tag); \ + } \ + } + +enum ibinop_tag_t { +#undef LIST_IBINOP +#define LIST_IBINOP(name, id, hsop) name = id, +#include "arith_lists.h" +#undef LIST_IBINOP +#define LIST_IBINOP(name, id, hsop) +}; + +#define ENTRY_IBINARY_STRIDED_OPS(typ) \ + void oxarop_ibinary_ ## typ ## _sv_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \ + switch (tag) { \ + case IB_QUOT: oxarop_op_quot_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ + case IB_REM: oxarop_op_rem_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ + default: wrong_op("ibinary_sv_strided", tag); \ + } \ + } \ + void oxarop_ibinary_ ## typ ## _vs_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \ + switch (tag) { \ + case IB_QUOT: oxarop_op_quot_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ + case IB_REM: oxarop_op_rem_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ + default: wrong_op("ibinary_vs_strided", tag); \ + } \ + } \ + void oxarop_ibinary_ ## typ ## _vv_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ + switch (tag) { \ + case IB_QUOT: oxarop_op_quot_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ + case IB_REM: oxarop_op_rem_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ + default: wrong_op("ibinary_vv_strided", tag); \ + } \ + } + +enum fbinop_tag_t { +#undef LIST_FBINOP +#define LIST_FBINOP(name, id, hsop) name = id, +#include "arith_lists.h" +#undef LIST_FBINOP +#define LIST_FBINOP(name, id, hsop) +}; + +#define ENTRY_FBINARY_STRIDED_OPS(typ) \ + void oxarop_fbinary_ ## typ ## _sv_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \ + switch (tag) { \ + case FB_DIV: oxarop_op_fdiv_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ + case FB_POW: oxarop_op_pow_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ + case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ + case FB_ATAN2: oxarop_op_atan2_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ + default: wrong_op("fbinary_sv_strided", tag); \ + } \ + } \ + void oxarop_fbinary_ ## typ ## _vs_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \ + switch (tag) { \ + case FB_DIV: oxarop_op_fdiv_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ + case FB_POW: oxarop_op_pow_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ + case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ + case FB_ATAN2: oxarop_op_atan2_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ + default: wrong_op("fbinary_vs_strided", tag); \ + } \ + } \ + void oxarop_fbinary_ ## typ ## _vv_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ + switch (tag) { \ + case FB_DIV: oxarop_op_fdiv_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ + case FB_POW: oxarop_op_pow_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ + case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ + case FB_ATAN2: oxarop_op_atan2_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ + default: wrong_op("fbinary_vv_strided", tag); \ + } \ + } + +enum unop_tag_t { +#undef LIST_UNOP +#define LIST_UNOP(name, id, _) name = id, +#include "arith_lists.h" +#undef LIST_UNOP +#define LIST_UNOP(name, id, _) +}; + +#define ENTRY_UNARY_STRIDED_OPS(typ) \ + void oxarop_unary_ ## typ ## _strided(enum unop_tag_t tag, i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *x) { \ + switch (tag) { \ + case UO_NEG: oxarop_op_neg_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case UO_ABS: oxarop_op_abs_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case UO_SIGNUM: oxarop_op_signum_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + default: wrong_op("unary_strided", tag); \ + } \ + } + +enum funop_tag_t { +#undef LIST_FUNOP +#define LIST_FUNOP(name, id, _) name = id, +#include "arith_lists.h" +#undef LIST_FUNOP +#define LIST_FUNOP(name, id, _) +}; + +#define ENTRY_FUNARY_STRIDED_OPS(typ) \ + void oxarop_funary_ ## typ ## _strided(enum funop_tag_t tag, i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *x) { \ + switch (tag) { \ + case FU_RECIP: oxarop_op_recip_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case FU_EXP: oxarop_op_exp_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case FU_LOG: oxarop_op_log_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case FU_SQRT: oxarop_op_sqrt_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case FU_SIN: oxarop_op_sin_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case FU_COS: oxarop_op_cos_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case FU_TAN: oxarop_op_tan_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case FU_ASIN: oxarop_op_asin_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case FU_ACOS: oxarop_op_acos_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case FU_ATAN: oxarop_op_atan_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case FU_SINH: oxarop_op_sinh_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case FU_COSH: oxarop_op_cosh_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case FU_TANH: oxarop_op_tanh_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case FU_ASINH: oxarop_op_asinh_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case FU_ACOSH: oxarop_op_acosh_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case FU_ATANH: oxarop_op_atanh_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case FU_LOG1P: oxarop_op_log1p_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case FU_EXPM1: oxarop_op_expm1_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case FU_LOG1PEXP: oxarop_op_log1pexp_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case FU_LOG1MEXP: oxarop_op_log1mexp_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + default: wrong_op("funary_strided", tag); \ + } \ + } + +enum redop_tag_t { +#undef LIST_REDOP +#define LIST_REDOP(name, id, _) name = id, +#include "arith_lists.h" +#undef LIST_REDOP +#define LIST_REDOP(name, id, _) +}; + +#define ENTRY_REDUCE1_OPS(typ) \ + void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \ + switch (tag) { \ + case RO_SUM: oxarop_op_sum1_ ## typ(rank, out, shape, strides, arr); break; \ + case RO_PRODUCT: oxarop_op_product1_ ## typ(rank, out, shape, strides, arr); break; \ + default: wrong_op("reduce", tag); \ + } \ + } + +#define ENTRY_REDUCEFULL_OPS(typ) \ + typ oxarop_reducefull_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \ + switch (tag) { \ + case RO_SUM: return oxarop_op_sumfull_ ## typ(rank, shape, strides, arr); \ + case RO_PRODUCT: return oxarop_op_productfull_ ## typ(rank, shape, strides, arr); \ + default: wrong_op("reduce", tag); \ + } \ + } + + +/***************************************************************************** + * Generate all the functions * + *****************************************************************************/ + +#define INT_TYPES_XLIST X(i32) X(i64) +#define FLOAT_TYPES_XLIST X(double) X(float) +#define NUM_TYPES_XLIST INT_TYPES_XLIST FLOAT_TYPES_XLIST + +#define X(typ) \ + COMM_OP_STRIDED(add, +, typ) \ + NONCOMM_OP_STRIDED(sub, -, typ) \ + COMM_OP_STRIDED(mul, *, typ) \ + UNARY_OP_STRIDED(neg, -, typ) \ + UNARY_OP_STRIDED(abs, GEN_ABS, typ) \ + UNARY_OP_STRIDED(signum, GEN_SIGNUM, typ) \ + REDUCE1_OP(sum1, +, typ) \ + REDUCE1_OP(product1, *, typ) \ + REDUCEFULL_OP(sumfull, +, typ) \ + REDUCEFULL_OP(productfull, *, typ) \ + ENTRY_BINARY_STRIDED_OPS(typ) \ + ENTRY_UNARY_STRIDED_OPS(typ) \ + ENTRY_REDUCE1_OPS(typ) \ + ENTRY_REDUCEFULL_OPS(typ) \ + EXTREMUM_OP(min, <, typ) \ + EXTREMUM_OP(max, >, typ) \ + DOTPROD_INNER_OP(typ) +NUM_TYPES_XLIST +#undef X + +#define X(typ) \ + NONCOMM_OP_STRIDED(quot, /, typ) \ + NONCOMM_OP_STRIDED(rem, %, typ) \ + ENTRY_IBINARY_STRIDED_OPS(typ) +INT_TYPES_XLIST +#undef X + +#define X(typ) \ + NONCOMM_OP_STRIDED(fdiv, /, typ) \ + PREFIX_BINOP_STRIDED(pow, GEN_POW, typ) \ + PREFIX_BINOP_STRIDED(logbase, GEN_LOGBASE, typ) \ + PREFIX_BINOP_STRIDED(atan2, GEN_ATAN2, typ) \ + UNARY_OP_STRIDED(recip, 1.0/, typ) \ + UNARY_OP_STRIDED(exp, GEN_EXP, typ) \ + UNARY_OP_STRIDED(log, GEN_LOG, typ) \ + UNARY_OP_STRIDED(sqrt, GEN_SQRT, typ) \ + UNARY_OP_STRIDED(sin, GEN_SIN, typ) \ + UNARY_OP_STRIDED(cos, GEN_COS, typ) \ + UNARY_OP_STRIDED(tan, GEN_TAN, typ) \ + UNARY_OP_STRIDED(asin, GEN_ASIN, typ) \ + UNARY_OP_STRIDED(acos, GEN_ACOS, typ) \ + UNARY_OP_STRIDED(atan, GEN_ATAN, typ) \ + UNARY_OP_STRIDED(sinh, GEN_SINH, typ) \ + UNARY_OP_STRIDED(cosh, GEN_COSH, typ) \ + UNARY_OP_STRIDED(tanh, GEN_TANH, typ) \ + UNARY_OP_STRIDED(asinh, GEN_ASINH, typ) \ + UNARY_OP_STRIDED(acosh, GEN_ACOSH, typ) \ + UNARY_OP_STRIDED(atanh, GEN_ATANH, typ) \ + UNARY_OP_STRIDED(log1p, GEN_LOG1P, typ) \ + UNARY_OP_STRIDED(expm1, GEN_EXPM1, typ) \ + UNARY_OP_STRIDED(log1pexp, GEN_LOG1PEXP, typ) \ + UNARY_OP_STRIDED(log1mexp, GEN_LOG1MEXP, typ) \ + ENTRY_FBINARY_STRIDED_OPS(typ) \ + ENTRY_FUNARY_STRIDED_OPS(typ) +FLOAT_TYPES_XLIST +#undef X + +// Note: [zero-length VLA] +// +// Zero-length variable-length arrays are not allowed in C(99). Thus whenever we +// have a VLA that could sometimes suffice to be empty (e.g. `idx` in the +// TARRAY_WALK_NOINNER macros), we tweak the length formula (typically by just +// adding 1) so that it never ends up empty. diff --git a/cbits/arith_lists.h b/cbits/arith_lists.h new file mode 100644 index 0000000..432765c --- /dev/null +++ b/cbits/arith_lists.h @@ -0,0 +1,39 @@ +LIST_BINOP(BO_ADD, 1, +) +LIST_BINOP(BO_SUB, 2, -) +LIST_BINOP(BO_MUL, 3, *) + +LIST_IBINOP(IB_QUOT, 1, quot) +LIST_IBINOP(IB_REM, 2, rem) + +LIST_FBINOP(FB_DIV, 1, /) +LIST_FBINOP(FB_POW, 2, **) +LIST_FBINOP(FB_LOGBASE, 3, logBase) +LIST_FBINOP(FB_ATAN2, 4, atan2) + +LIST_UNOP(UO_NEG, 1,) +LIST_UNOP(UO_ABS, 2,) +LIST_UNOP(UO_SIGNUM, 3,) + +LIST_FUNOP(FU_RECIP, 1,) +LIST_FUNOP(FU_EXP, 2,) +LIST_FUNOP(FU_LOG, 3,) +LIST_FUNOP(FU_SQRT, 4,) +LIST_FUNOP(FU_SIN, 5,) +LIST_FUNOP(FU_COS, 6,) +LIST_FUNOP(FU_TAN, 7,) +LIST_FUNOP(FU_ASIN, 8,) +LIST_FUNOP(FU_ACOS, 9,) +LIST_FUNOP(FU_ATAN, 10,) +LIST_FUNOP(FU_SINH, 11,) +LIST_FUNOP(FU_COSH, 12,) +LIST_FUNOP(FU_TANH, 13,) +LIST_FUNOP(FU_ASINH, 14,) +LIST_FUNOP(FU_ACOSH, 15,) +LIST_FUNOP(FU_ATANH, 16,) +LIST_FUNOP(FU_LOG1P, 17,) +LIST_FUNOP(FU_EXPM1, 18,) +LIST_FUNOP(FU_LOG1PEXP, 19,) +LIST_FUNOP(FU_LOG1MEXP, 20,) + +LIST_REDOP(RO_SUM, 1,) +LIST_REDOP(RO_PRODUCT, 2,) diff --git a/example/Main.hs b/example/Main.hs new file mode 100644 index 0000000..76c75c2 --- /dev/null +++ b/example/Main.hs @@ -0,0 +1,29 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeApplications #-} +module Main where + +import Data.Array.Nested + + +arr :: Ranked 2 (Shaped [2, 3] (Double, Int)) +arr = rgenerate (3 :$: 4 :$: ZSR) $ \(i :.: j :.: ZIR) -> + sgenerate (SNat @2 :$$ SNat @3 :$$ ZSS) $ \(k :.$ l :.$ ZIS) -> + let s = 24*i + 6*j + 3*k + l + in (fromIntegral s, s) + +foo :: (Double, Int) +foo = arr `rindex` (2 :.: 1 :.: ZIR) `sindex` (1 :.$ 1 :.$ ZIS) + +bad :: Ranked 2 (Ranked 1 Double) +bad = rgenerate (3 :$: 4 :$: ZSR) $ \(i :.: j :.: ZIR) -> + rgenerate (i :$: ZSR) $ \(k :.: ZIR) -> + let s = 24*i + 6*j + 3*k + in fromIntegral s + +main :: IO () +main = do + print arr + print foo + print (rtranspose [1,0] arr) + -- print bad diff --git a/gentrace.sh b/gentrace.sh new file mode 100755 index 0000000..c3f1240 --- /dev/null +++ b/gentrace.sh @@ -0,0 +1,31 @@ +#!/usr/bin/env bash + +cat <<'EOF' +module Data.Array.Nested.Trace ( + -- * Traced variants + module Data.Array.Nested.Trace, + + -- * Re-exports from the plain "Data.Array.Nested" module +EOF + +sed -n '/^module/,/^) where/!d; /^\s*--\( \|$\)/d; s/ \b[a-z][a-zA-Z0-9_'"'"']*,//g; /^ $/d; s/(\.\., Z.., ([^)]*))/(..)/g; /^ /p; /^$/p' src/Data/Array/Nested.hs + +cat <<'EOF' +) where + +import Prelude hiding (mappend, mconcat) + +import Data.Array.Nested +import Data.Array.Nested.Trace.TH + + +EOF + +# shellcheck disable=SC2016 # dollar in single-quoted string +echo '$(concat <$> mapM convertFun' +sed -n '/^module/,/^) where/!d; /^\s*-- /d; /^ /p' src/Data/Array/Nested.hs | + grep -o '\b[a-z][a-zA-Z0-9_'"'"']*\b' | + grep -wv -e 'pattern' -e 'type' | + tr $'\n' ' ' | + sed 's/\([^ ]\+\)/'"'"'\1,/g; s/, $/])/; s/^/ [/' +echo diff --git a/ops/Data/Array/Strided.hs b/ops/Data/Array/Strided.hs new file mode 100644 index 0000000..7d8c2d0 --- /dev/null +++ b/ops/Data/Array/Strided.hs @@ -0,0 +1,7 @@ +module Data.Array.Strided ( + module Data.Array.Strided.Array, + module Data.Array.Strided.Arith, +) where + +import Data.Array.Strided.Arith +import Data.Array.Strided.Array diff --git a/ops/Data/Array/Strided/Arith.hs b/ops/Data/Array/Strided/Arith.hs new file mode 100644 index 0000000..7be6390 --- /dev/null +++ b/ops/Data/Array/Strided/Arith.hs @@ -0,0 +1,7 @@ +module Data.Array.Strided.Arith ( + NumElt(..), + IntElt(..), + FloatElt(..), +) where + +import Data.Array.Strided.Arith.Internal diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs new file mode 100644 index 0000000..5802573 --- /dev/null +++ b/ops/Data/Array/Strided/Arith/Internal.hs @@ -0,0 +1,933 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ExistentialQuantification #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Strided.Arith.Internal where + +import Control.Monad +import Data.Bifunctor (second) +import Data.Bits +import Data.Int +import Data.List (sort, zip4) +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Data.Vector.Storable.Mutable qualified as VSM +import Foreign.C.Types +import Foreign.Ptr +import Foreign.Storable +import GHC.TypeLits +import GHC.TypeNats qualified as TypeNats +import Language.Haskell.TH +import System.IO (hFlush, stdout) +import System.IO.Unsafe + +import Data.Array.Strided.Arith.Internal.Foreign +import Data.Array.Strided.Arith.Internal.Lists +import Data.Array.Strided.Array + + +-- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition + + +-- TODO: move this to a utilities module +fromSNat' :: SNat n -> Int +fromSNat' = fromIntegral . fromSNat + +data Dict c where + Dict :: c => Dict c + +debugShow :: forall n a. (Storable a, KnownNat n) => Array n a -> String +debugShow (Array sh strides offset vec) = + "Array @" ++ show (natVal (Proxy @n)) ++ " " ++ show sh ++ " " ++ show strides ++ " " ++ show offset ++ " <_*" ++ show (VS.length vec) ++ ">" + + +-- TODO: test all the cases of this thing with various input strides +liftOpEltwise1 :: Storable a + => SNat n + -> (Ptr a -> Ptr b) + -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) + -> Array n a -> Array n a +liftOpEltwise1 sn@SNat ptrconv cf_strided arr@(Array sh strides offset vec) + | Just (blockOff, blockSz) <- stridesDense sh offset strides = + if blockSz == 0 + then Array sh (map (const 0) strides) 0 VS.empty + else let resvec = arrValues $ wrapUnary sn ptrconv cf_strided (Array [blockSz] [1] blockOff vec) + in Array sh strides (offset - blockOff) resvec + | otherwise = wrapUnary sn ptrconv cf_strided arr + +-- TODO: test all the cases of this thing with various input strides +liftOpEltwise2 :: Storable a + => SNat n + -> (a -> b) + -> (Ptr a -> Ptr b) + -> (a -> a -> a) + -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ sv + -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ()) -- ^ vs + -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ vv + -> Array n a -> Array n a -> Array n a +liftOpEltwise2 sn@SNat valconv ptrconv f_ss f_sv f_vs f_vv + arr1@(Array sh1 strides1 offset1 vec1) + arr2@(Array sh2 strides2 offset2 vec2) + | sh1 /= sh2 = error $ "liftOpEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 + | any (<= 0) sh1 = Array sh1 (0 <$ strides1) 0 VS.empty + | otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of + (Just (_, 1), Just (_, 1)) -> -- both are a (potentially replicated) scalar; just apply f to the scalars + let vec' = VS.singleton (f_ss (vec1 VS.! offset1) (vec2 VS.! offset2)) + in Array sh1 strides1 0 vec' + + (Just (_, 1), Just (blockOff, blockSz)) -> -- scalar * dense + let arr2' = arrayFromVector [blockSz] (VS.slice blockOff blockSz vec2) + resvec = arrValues $ wrapBinarySV (SNat @1) valconv ptrconv f_sv (vec1 VS.! offset1) arr2' + in Array sh1 strides2 (offset2 - blockOff) resvec + + (Just (_, 1), Nothing) -> -- scalar * array + wrapBinarySV sn valconv ptrconv f_sv (vec1 VS.! offset1) arr2 + + (Just (blockOff, blockSz), Just (_, 1)) -> -- dense * scalar + let arr1' = arrayFromVector [blockSz] (VS.slice blockOff blockSz vec1) + resvec = arrValues $ wrapBinaryVS (SNat @1) valconv ptrconv f_vs arr1' (vec2 VS.! offset2) + in Array sh1 strides1 (offset1 - blockOff) resvec + + (Nothing, Just (_, 1)) -> -- array * scalar + wrapBinaryVS sn valconv ptrconv f_vs arr1 (vec2 VS.! offset2) + + (Just (blockOff1, blockSz1), Just (blockOff2, blockSz2)) + | strides1 == strides2 + -> -- dense * dense but the strides match + if blockSz1 /= blockSz2 || offset1 - blockOff1 /= offset2 - blockOff2 + then error $ "Data.Array.Strided.Ops.Internal(liftOpEltwise2): Internal error: cannot happen " ++ show (strides1, (blockOff1, blockSz1), strides2, (blockOff2, blockSz2)) + else + let arr1' = arrayFromVector [blockSz1] (VS.slice blockOff1 blockSz1 vec1) + arr2' = arrayFromVector [blockSz1] (VS.slice blockOff2 blockSz2 vec2) + resvec = arrValues $ wrapBinaryVV (SNat @1) ptrconv f_vv arr1' arr2' + in Array sh1 strides1 (offset1 - blockOff1) resvec + + (_, _) -> -- fallback case + wrapBinaryVV sn ptrconv f_vv arr1 arr2 + +-- | Given shape vector, offset and stride vector, check whether this virtual +-- vector uses a dense subarray of its backing array. If so, the first index +-- and the number of elements in this subarray is returned. +-- This excludes any offset. +stridesDense :: [Int] -> Int -> [Int] -> Maybe (Int, Int) +stridesDense sh offset _ | any (<= 0) sh = Just (offset, 0) +stridesDense sh offsetNeg stridesNeg = + -- First reverse all dimensions with negative stride, so that the first used + -- value is at 'offset' and the rest is >= offset. + let (offset, strides) = flipReverseds sh offsetNeg stridesNeg + in -- sort dimensions on their stride, ascending, dropping any zero strides + case filter ((/= 0) . fst) (sort (zip strides sh)) of + [] -> Just (offset, 1) + (1, n) : pairs -> (offset,) <$> checkCover n pairs + _ -> Nothing -- if the smallest stride is not 1, it will never be dense + where + -- Given size of currently densely covered region at beginning of the + -- array and the remaining (stride, size) pairs with all strides >=1, + -- return whether this all together covers a dense prefix of the array. If + -- it does, return the number of elements in this prefix. + checkCover :: Int -> [(Int, Int)] -> Maybe Int + checkCover block [] = Just block + checkCover block ((s, n) : pairs) = guard (s <= block) >> checkCover ((n-1) * s + block) pairs + + -- Given shape, offset and strides, returns new (offset, strides) such that all strides are >=0 + flipReverseds :: [Int] -> Int -> [Int] -> (Int, [Int]) + flipReverseds [] off [] = (off, []) + flipReverseds (n : sh') off (s : str') + | s >= 0 = second (s :) (flipReverseds sh' off str') + | otherwise = + let off' = off + (n - 1) * s + in second ((-s) :) (flipReverseds sh' off' str') + flipReverseds _ _ _ = error "flipReverseds: invalid arguments" + +data Unreplicated a = + forall n'. KnownNat n' => + -- | Let the original array, with replicated dimensions, be called A. + Unreplicated -- | An array with all strides /= 0. Call this array U. It has + -- the same shape as A, except with all the replicated (stride + -- == 0) dimensions removed. The shape of U is the + -- "unreplicated shape". + (Array n' a) + -- | Product of sizes of the unreplicated dimensions + Int + -- | Given the stride vector of an array with the unreplicated + -- shape, this function reinserts zeros so that it may be + -- combined with the original shape of A. + ([Int] -> [Int]) + +-- | Removes all replicated dimensions (i.e. those with stride == 0) from the array. +unreplicateStrides :: Array n a -> Unreplicated a +unreplicateStrides (Array sh strides offset vec) = + let replDims = map (== 0) strides + (shF, stridesF) = unzip [(n, s) | (n, s) <- zip sh strides, s /= 0] + + reinsertZeros (False : zeros) (s : strides') = s : reinsertZeros zeros strides' + reinsertZeros (True : zeros) strides' = 0 : reinsertZeros zeros strides' + reinsertZeros [] [] = [] + reinsertZeros (False : _) [] = error "unreplicateStrides: Internal error: reply strides too short" + reinsertZeros [] (_:_) = error "unreplicateStrides: Internal error: reply strides too long" + + unrepSize = product [n | (n, True) <- zip sh replDims] + + in TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> + Unreplicated (Array @lenshF shF stridesF offset vec) unrepSize (reinsertZeros replDims) + +simplifyArray :: Array n a + -> (forall n'. KnownNat n' + => Array n' a -- U + -- Product of sizes of the unreplicated dimensions + -> Int + -- Convert index in U back to index into original + -- array. Replicated dimensions get 0. + -> ([Int] -> [Int]) + -- Given a new array of the same shape as U, convert + -- it back to the original shape and iteration order. + -> (Array n' a -> Array n a) + -- Do the same except without the INNER dimension. + -- This throws an error if the inner dimension had + -- stride 0. + -> (Array (n' - 1) a -> Array (n - 1) a) + -> r) + -> r +simplifyArray array k + | let revDims = map (<0) (arrStrides array) + , Unreplicated array' unrepSize rereplicate <- unreplicateStrides (arrayRevDims revDims array) + = k array' + unrepSize + (\idx -> rereplicate (zipWith3 (\b n i -> if b then n - 1 - i else i) + revDims (arrShape array') idx)) + (\(Array sh' strides' offset' vec') -> + if sh' == arrShape array' + then arrayRevDims revDims (Array (arrShape array) (rereplicate strides') offset' vec') + else error $ "simplifyArray: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show (arrShape array') ++ ")") + (\(Array sh' strides' offset' vec') -> + if | sh' /= init (arrShape array') -> + error $ "simplifyArray: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show (arrShape array') ++ ")" + | last (arrStrides array) == 0 -> + error "simplifyArray: Internal error: reduction reply handler used while inner stride was 0" + | otherwise -> + arrayRevDims (init revDims) (Array (init (arrShape array)) (init (rereplicate (strides' ++ [0]))) offset' vec')) + +-- | The two input arrays must have the same shape. +simplifyArray2 :: Array n a -> Array n a + -> (forall n'. KnownNat n' + => Array n' a -- U1 + -> Array n' a -- U2 (same shape as U1) + -- Product of sizes of the dimensions that are + -- replicated in neither input + -> Int + -- Convert index in U{1,2} back to index into original + -- arrays. Dimensions that are replicated in both + -- inputs get 0. + -> ([Int] -> [Int]) + -- Given a new array of the same shape as U1 (& U2), + -- convert it back to the original shape and + -- iteration order. + -> (Array n' a -> Array n a) + -- Do the same except without the INNER dimension. + -- This throws an error if the inner dimension had + -- stride 0 in both inputs. + -> (Array (n' - 1) a -> Array (n - 1) a) + -> r) + -> r +simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k + | sh /= sh2 = error "simplifyArray2: Unequal shapes" + + | let revDims = zipWith (\s1 s2 -> s1 < 0 && s2 < 0) (arrStrides arr1) (arrStrides arr2) + , Array _ strides1 offset1 vec1 <- arrayRevDims revDims arr1 + , Array _ strides2 offset2 vec2 <- arrayRevDims revDims arr2 + + , let replDims = zipWith (\s1 s2 -> s1 == 0 && s2 == 0) strides1 strides2 + , let (shF, strides1F, strides2F) = unzip3 [(n, s1, s2) | (n, s1, s2, False) <- zip4 sh strides1 strides2 replDims] + + , let reinsertZeros (False : zeros) (s : strides') = s : reinsertZeros zeros strides' + reinsertZeros (True : zeros) strides' = 0 : reinsertZeros zeros strides' + reinsertZeros [] [] = [] + reinsertZeros (False : _) [] = error "simplifyArray2: Internal error: reply strides too short" + reinsertZeros [] (_:_) = error "simplifyArray2: Internal error: reply strides too long" + + , let unrepSize = product [n | (n, True) <- zip sh replDims] + + = TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> + k @lenshF + (Array shF strides1F offset1 vec1) + (Array shF strides2F offset2 vec2) + unrepSize + (\idx -> zipWith3 (\b n i -> if b then n - 1 - i else i) + revDims sh (reinsertZeros replDims idx)) + (\(Array sh' strides' offset' vec') -> + if sh' /= shF then error $ "simplifyArray2: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show shF ++ ")" + else arrayRevDims revDims (Array sh (reinsertZeros replDims strides') offset' vec')) + (\(Array sh' strides' offset' vec') -> + if | sh' /= init shF -> + error $ "simplifyArray2: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show shF ++ ")" + | last replDims -> + error "simplifyArray2: Internal error: reduction reply handler used while inner dimension was unreplicated" + | otherwise -> + arrayRevDims (init revDims) (Array (init sh) (reinsertZeros (init replDims) strides') offset' vec')) + +{-# NOINLINE wrapUnary #-} +wrapUnary :: forall a b n. Storable a + => SNat n + -> (Ptr a -> Ptr b) + -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) + -> Array n a + -> Array n a +wrapUnary _ ptrconv cf_strided array = + simplifyArray array $ \(Array sh strides offset vec) _ _ restore _ -> unsafePerformIO $ do + let ndims' = length sh + outv <- VSM.unsafeNew (product sh) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh)) $ \psh -> + VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides)) $ \pstrides -> + VS.unsafeWith vec $ \pv -> + let pv' = pv `plusPtr` (offset * sizeOf (undefined :: a)) + in cf_strided (fromIntegral ndims') (ptrconv poutv) psh pstrides pv' + restore . arrayFromVector sh <$> VS.unsafeFreeze outv + +{-# NOINLINE wrapBinarySV #-} +wrapBinarySV :: forall a b n. Storable a + => SNat n + -> (a -> b) + -> (Ptr a -> Ptr b) + -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) + -> a -> Array n a + -> Array n a +wrapBinarySV SNat valconv ptrconv cf_strided x array = + simplifyArray array $ \(Array sh strides offset vec) _ _ restore _ -> unsafePerformIO $ do + let ndims' = length sh + outv <- VSM.unsafeNew (product sh) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh)) $ \psh -> + VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides)) $ \pstrides -> + VS.unsafeWith vec $ \pv -> + let pv' = pv `plusPtr` (offset * sizeOf (undefined :: a)) + in cf_strided (fromIntegral ndims') psh (ptrconv poutv) (valconv x) pstrides pv' + restore . arrayFromVector sh <$> VS.unsafeFreeze outv + +wrapBinaryVS :: Storable a + => SNat n + -> (a -> b) + -> (Ptr a -> Ptr b) + -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ()) + -> Array n a -> a + -> Array n a +wrapBinaryVS sn valconv ptrconv cf_strided arr y = + wrapBinarySV sn valconv ptrconv + (\rank psh poutv y' pstrides pv -> cf_strided rank psh poutv pstrides pv y') y arr + +-- | The two shapes must be equal and non-empty. This is checked. +{-# NOINLINE wrapBinaryVV #-} +wrapBinaryVV :: forall a b n. Storable a + => SNat n + -> (Ptr a -> Ptr b) + -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) + -> Array n a -> Array n a + -> Array n a +-- TODO: do unreversing and unreplication on the input arrays (but +-- simultaneously: can only unreplicate if _both_ are replicated on that +-- dimension) +wrapBinaryVV sn@SNat ptrconv cf_strided + (Array sh strides1 offset1 vec1) + (Array sh2 strides2 offset2 vec2) + | sh /= sh2 = error $ "wrapBinaryVV: unequal shapes: " ++ show sh ++ " and " ++ show sh2 + | any (<= 0) sh = error $ "wrapBinaryVV: empty shape: " ++ show sh + | otherwise = unsafePerformIO $ do + outv <- VSM.unsafeNew (product sh) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh -> + VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides1)) $ \pstrides1 -> + VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides2)) $ \pstrides2 -> + VS.unsafeWith vec1 $ \pv1 -> + VS.unsafeWith vec2 $ \pv2 -> + let pv1' = pv1 `plusPtr` (offset1 * sizeOf (undefined :: a)) + pv2' = pv2 `plusPtr` (offset2 * sizeOf (undefined :: a)) + in cf_strided (fromIntegral (fromSNat' sn)) psh (ptrconv poutv) pstrides1 pv1' pstrides2 pv2' + arrayFromVector sh <$> VS.unsafeFreeze outv + +-- TODO: test handling of negative strides +-- | Reduce along the inner dimension +{-# NOINLINE vectorRedInnerOp #-} +vectorRedInnerOp :: forall a b n. (Num a, Storable a) + => SNat n + -> (a -> b) + -> (Ptr a -> Ptr b) + -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant + -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel + -> Array (n + 1) a -> Array n a +vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides offset vec) + | null sh = error "unreachable" + | last sh <= 0 = arrayFromConstant (init sh) 0 + | any (<= 0) (init sh) = Array (init sh) (0 <$ init strides) 0 VS.empty + -- now the input array is nonempty + | last sh == 1 = Array (init sh) (init strides) offset vec + | last strides == 0 = + wrapBinarySV sn valconv ptrconv fscale (fromIntegral @Int @a (last sh)) + (Array (init sh) (init strides) offset vec) + -- now there is useful work along the inner dimension + -- Note that unreplication keeps the inner dimension intact, because `last strides /= 0` at this point. + | otherwise = + simplifyArray array $ \(Array sh' strides' offset' vec' :: Array n' a) _ _ _ restore -> unsafePerformIO $ do + let ndims' = length sh' + outv <- VSM.unsafeNew (product (init sh')) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh')) $ \psh -> + VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides')) $ \pstrides -> + VS.unsafeWith vec' $ \pv -> + let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a)) + in fred (fromIntegral ndims') (ptrconv poutv) psh pstrides (ptrconv pv') + TypeNats.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do + (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of + LTI -> pure Dict + EQI -> pure Dict + _ -> error "impossible" -- because `last strides /= 0` + case sameNat (natSing @(n' - 1)) (natSing @n'm1) of + Just Refl -> restore . arrayFromVector @_ @n'm1 (init sh') <$> VS.unsafeFreeze outv + Nothing -> error "impossible" + +-- TODO: test handling of negative strides +-- | Reduce full array +{-# NOINLINE vectorRedFullOp #-} +vectorRedFullOp :: forall a b n. (Num a, Storable a) + => SNat n + -> (a -> Int -> a) + -> (b -> a) + -> (Ptr a -> Ptr b) + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b) -- ^ reduction kernel + -> Array n a -> a +vectorRedFullOp _ scaleval valbackconv ptrconv fred array@(Array sh strides offset vec) + | null sh = vec VS.! offset -- 0D array has one element + | any (<= 0) sh = 0 + -- now the input array is nonempty + | all (== 0) strides = fromIntegral (product sh) * vec VS.! offset + -- now there is at least one non-replicated dimension + | otherwise = + simplifyArray array $ \(Array sh' strides' offset' vec') unrepSize _ _ _ -> unsafePerformIO $ do + let ndims' = length sh' + VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh')) $ \psh -> + VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides')) $ \pstrides -> + VS.unsafeWith vec' $ \pv -> + let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a)) + in (`scaleval` unrepSize) . valbackconv + <$> fred (fromIntegral ndims') psh pstrides (ptrconv pv') + +-- TODO: test this function +-- | Find extremum (minindex ("argmin") or maxindex) in full array +{-# NOINLINE vectorExtremumOp #-} +vectorExtremumOp :: forall a b n. Storable a + => (Ptr a -> Ptr b) + -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ extremum kernel + -> Array n a -> [Int] -- result length: n +vectorExtremumOp ptrconv fextrem array@(Array sh strides _ _) + | null sh = [] + | any (<= 0) sh = error "Extremum (minindex/maxindex): empty array" + -- now the input array is nonempty + | all (== 0) strides = 0 <$ sh + -- now there is at least one non-replicated dimension + | otherwise = + simplifyArray array $ \(Array sh' strides' offset' vec') _ upindex _ _ -> unsafePerformIO $ do + let ndims' = length sh' + outvR <- VSM.unsafeNew (length sh') + VSM.unsafeWith outvR $ \poutv -> + VS.unsafeWith (VS.fromListN ndims' (map fromIntegral sh')) $ \psh -> + VS.unsafeWith (VS.fromListN ndims' (map fromIntegral strides')) $ \pstrides -> + VS.unsafeWith vec' $ \pv -> + let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a)) + in fextrem poutv (fromIntegral ndims') psh pstrides (ptrconv pv') + upindex . map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outvR + +{-# NOINLINE vectorDotprodInnerOp #-} +vectorDotprodInnerOp :: forall a b n. (Num a, Storable a) + => SNat n + -> (a -> b) + -> (Ptr a -> Ptr b) + -> (SNat n -> Array n a -> Array n a -> Array n a) -- ^ elementwise multiplication + -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant + -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel + -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ dotprod kernel + -> Array (n + 1) a -> Array (n + 1) a -> Array n a +vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner + arr1@(Array sh1 strides1 offset1 vec1) + arr2@(Array sh2 strides2 offset2 vec2) + | null sh1 || null sh2 = error "unreachable" + | sh1 /= sh2 = error $ "vectorDotprodInnerOp: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 + | last sh1 <= 0 = arrayFromConstant (init sh1) 0 + | any (<= 0) (init sh1) = Array (init sh1) (0 <$ init strides1) 0 VS.empty + -- now the input arrays are nonempty + | last sh1 == 1 = + fmul sn (Array (init sh1) (init strides1) offset1 vec1) + (Array (init sh2) (init strides2) offset2 vec2) + | last strides1 == 0 = + fmul sn + (Array (init sh1) (init strides1) offset1 vec1) + (vectorRedInnerOp sn valconv ptrconv fscale fred arr2) + | last strides2 == 0 = + fmul sn + (vectorRedInnerOp sn valconv ptrconv fscale fred arr1) + (Array (init sh2) (init strides2) offset2 vec2) + -- now there is useful dotprod work along the inner dimension + | otherwise = + simplifyArray2 arr1 arr2 $ \(Array sh' strides1' offset1' vec1' :: Array n' a) (Array _ strides2' offset2' vec2') _ _ _ restore -> + unsafePerformIO $ do + let inrank = length sh' + outv <- VSM.unsafeNew (product (init sh')) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh')) $ \psh -> + VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1')) $ \pstrides1 -> + VS.unsafeWith vec1' $ \pvec1 -> + VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2')) $ \pstrides2 -> + VS.unsafeWith vec2' $ \pvec2 -> + fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv) + pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1')) + pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2')) + TypeNats.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do + (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of + LTI -> pure Dict + EQI -> pure Dict + GTI -> error "impossible" -- because `last strides1 /= 0` + case sameNat (natSing @(n' - 1)) (natSing @n'm1) of + Just Refl -> restore . arrayFromVector (init sh') <$> VS.unsafeFreeze outv + Nothing -> error "impossible" + +mulWithInt :: Num a => a -> Int -> a +mulWithInt a i = a * fromIntegral i + + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype)) + cnamebase = "c_binary_" ++ atCName arithtype + c_ss_str = varE (aboNumOp arithop) + c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) + c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) + c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) + sequence [SigD name <$> + [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp -> Array n $ttyp |] + ,do body <- [| \sn -> liftOpEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |] + return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM intTypesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (aiboName arithop ++ "Vector" ++ nameBase (atType arithtype)) + cnamebase = "c_ibinary_" ++ atCName arithtype + c_ss_str = varE (aiboNumOp arithop) + c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop))) + c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop))) + c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop))) + sequence [SigD name <$> + [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp -> Array n $ttyp |] + ,do body <- [| \sn -> liftOpEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |] + return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM floatTypesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype)) + cnamebase = "c_fbinary_" ++ atCName arithtype + c_ss_str = varE (afboNumOp arithop) + c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) + c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) + c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) + sequence [SigD name <$> + [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp -> Array n $ttyp |] + ,do body <- [| \sn -> liftOpEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |] + return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) + c_op_strided = varE (mkName ("c_unary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (auoEnum arithop))) + sequence [SigD name <$> + [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp |] + ,do body <- [| \sn -> liftOpEltwise1 sn id $c_op_strided |] + return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM floatTypesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype)) + c_op_strided = varE (mkName ("c_funary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (afuoEnum arithop))) + sequence [SigD name <$> + [t| forall n. SNat n -> Array n $ttyp -> Array n $ttyp |] + ,do body <- [| \sn -> liftOpEltwise1 sn id $c_op_strided |] + return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let scaleVar = case arithop of + RO_SUM -> varE 'mulWithInt + RO_PRODUCT -> varE '(^) + let name1 = mkName (aroName arithop ++ "1Vector" ++ nameBase (atType arithtype)) + namefull = mkName (aroName arithop ++ "FullVector" ++ nameBase (atType arithtype)) + c_op1 = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) + c_opfull = varE (mkName ("c_reducefull_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) + c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) + sequence [SigD name1 <$> + [t| forall n. SNat n -> Array (n + 1) $ttyp -> Array n $ttyp |] + ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op1 |] + return $ FunD name1 [Clause [] (NormalB body) []] + ,SigD namefull <$> + [t| forall n. SNat n -> Array n $ttyp -> $ttyp |] + ,do body <- [| \sn -> vectorRedFullOp sn $scaleVar id id $c_opfull |] + return $ FunD namefull [Clause [] (NormalB body) []] + ]) + +$(fmap concat . forM typesList $ \arithtype -> + fmap concat . forM ["min", "max"] $ \fname -> do + let ttyp = conT (atType arithtype) + name = mkName (fname ++ "indexVector" ++ nameBase (atType arithtype)) + c_op = varE (mkName ("c_extremum_" ++ fname ++ "_" ++ atCName arithtype)) + sequence [SigD name <$> + [t| forall n. Array n $ttyp -> [Int] |] + ,do body <- [| vectorExtremumOp id $c_op |] + return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + name = mkName ("dotprodinnerVector" ++ nameBase (atType arithtype)) + c_op = varE (mkName ("c_dotprodinner_" ++ atCName arithtype)) + mul_op = varE (mkName ("mulVector" ++ nameBase (atType arithtype))) + c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) + c_red_op = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM))) + sequence [SigD name <$> + [t| forall n. SNat n -> Array (n + 1) $ttyp -> Array (n + 1) $ttyp -> Array n $ttyp |] + ,do body <- [| \sn -> vectorDotprodInnerOp sn id id $mul_op $c_scale_op $c_red_op $c_op |] + return $ FunD name [Clause [] (NormalB body) []]]) + +foreign import ccall unsafe "oxarrays_stats_enable" c_stats_enable :: Int32 -> IO () +foreign import ccall unsafe "oxarrays_stats_print_all" c_stats_print_all :: IO () + +statisticsEnable :: Bool -> IO () +statisticsEnable b = c_stats_enable (if b then 1 else 0) + +-- | Consumes the log: one particular event will only ever be printed once, +-- even if statisticsPrintAll is called multiple times. +statisticsPrintAll :: IO () +statisticsPrintAll = do + hFlush stdout -- lower the chance of overlapping output + c_stats_print_all + +-- This branch is ostensibly a runtime branch, but will (hopefully) be +-- constant-folded away by GHC. +intWidBranch1 :: forall i n. (FiniteBits i, Storable i) + => (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) + -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) + -> (SNat n -> Array n i -> Array n i) +intWidBranch1 f32 f64 sn + | finiteBitSize (undefined :: i) == 32 = liftOpEltwise1 sn castPtr f32 + | finiteBitSize (undefined :: i) == 64 = liftOpEltwise1 sn castPtr f64 + | otherwise = error "Unsupported Int width" + +intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i) + => (i -> i -> i) -- ss + -- int32 + -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- sv + -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ()) -- vs + -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- vv + -- int64 + -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- sv + -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ()) -- vs + -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- vv + -> (SNat n -> Array n i -> Array n i -> Array n i) +intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn + | finiteBitSize (undefined :: i) == 32 = liftOpEltwise2 sn fromIntegral castPtr ss sv32 vs32 vv32 + | finiteBitSize (undefined :: i) == 64 = liftOpEltwise2 sn fromIntegral castPtr ss sv64 vs64 vv64 + | otherwise = error "Unsupported Int width" + +intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i) + => -- int32 + (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant + -> (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel + -- int64 + -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant + -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel + -> (SNat n -> Array (n + 1) i -> Array n i) +intWidBranchRed1 fsc32 fred32 fsc64 fred64 sn + | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32 + | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64 + | otherwise = error "Unsupported Int width" + +intWidBranchRedFull :: forall i n. (FiniteBits i, Storable i, Integral i) + => (i -> Int -> i) -- ^ scale op + -- int32 + -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b) -- ^ reduction kernel + -- int64 + -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b) -- ^ reduction kernel + -> (SNat n -> Array n i -> i) +intWidBranchRedFull fsc fred32 fred64 sn + | finiteBitSize (undefined :: i) == 32 = vectorRedFullOp @i @Int32 sn fsc fromIntegral castPtr fred32 + | finiteBitSize (undefined :: i) == 64 = vectorRedFullOp @i @Int64 sn fsc fromIntegral castPtr fred64 + | otherwise = error "Unsupported Int width" + +intWidBranchExtr :: forall i n. (FiniteBits i, Storable i) + => -- int32 + (forall b. b ~ Int32 => Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ extremum kernel + -- int64 + -> (forall b. b ~ Int64 => Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ extremum kernel + -> (Array n i -> [Int]) +intWidBranchExtr fextr32 fextr64 + | finiteBitSize (undefined :: i) == 32 = vectorExtremumOp @i @Int32 castPtr fextr32 + | finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64 + | otherwise = error "Unsupported Int width" + +intWidBranchDotprod :: forall i n. (FiniteBits i, Storable i, Integral i, NumElt i) + => -- int32 + (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant + -> (forall b. b ~ Int32 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel + -> (forall b. b ~ Int32 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ dotprod kernel + -- int64 + -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ scale by constant + -> (forall b. b ~ Int64 => Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel + -> (forall b. b ~ Int64 => Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ dotprod kernel + -> (SNat n -> Array (n + 1) i -> Array (n + 1) i -> Array n i) +intWidBranchDotprod fsc32 fred32 fdot32 fsc64 fred64 fdot64 sn + | finiteBitSize (undefined :: i) == 32 = vectorDotprodInnerOp @i @Int32 sn fromIntegral castPtr numEltMul fsc32 fred32 fdot32 + | finiteBitSize (undefined :: i) == 64 = vectorDotprodInnerOp @i @Int64 sn fromIntegral castPtr numEltMul fsc64 fred64 fdot64 + | otherwise = error "Unsupported Int width" + +class NumElt a where + numEltAdd :: SNat n -> Array n a -> Array n a -> Array n a + numEltSub :: SNat n -> Array n a -> Array n a -> Array n a + numEltMul :: SNat n -> Array n a -> Array n a -> Array n a + numEltNeg :: SNat n -> Array n a -> Array n a + numEltAbs :: SNat n -> Array n a -> Array n a + numEltSignum :: SNat n -> Array n a -> Array n a + numEltSum1Inner :: SNat n -> Array (n + 1) a -> Array n a + numEltProduct1Inner :: SNat n -> Array (n + 1) a -> Array n a + numEltSumFull :: SNat n -> Array n a -> a + numEltProductFull :: SNat n -> Array n a -> a + numEltMinIndex :: SNat n -> Array n a -> [Int] + numEltMaxIndex :: SNat n -> Array n a -> [Int] + numEltDotprodInner :: SNat n -> Array (n + 1) a -> Array (n + 1) a -> Array n a + +instance NumElt Int32 where + numEltAdd = addVectorInt32 + numEltSub = subVectorInt32 + numEltMul = mulVectorInt32 + numEltNeg = negVectorInt32 + numEltAbs = absVectorInt32 + numEltSignum = signumVectorInt32 + numEltSum1Inner = sum1VectorInt32 + numEltProduct1Inner = product1VectorInt32 + numEltSumFull = sumFullVectorInt32 + numEltProductFull = productFullVectorInt32 + numEltMinIndex _ = minindexVectorInt32 + numEltMaxIndex _ = maxindexVectorInt32 + numEltDotprodInner = dotprodinnerVectorInt32 + +instance NumElt Int64 where + numEltAdd = addVectorInt64 + numEltSub = subVectorInt64 + numEltMul = mulVectorInt64 + numEltNeg = negVectorInt64 + numEltAbs = absVectorInt64 + numEltSignum = signumVectorInt64 + numEltSum1Inner = sum1VectorInt64 + numEltProduct1Inner = product1VectorInt64 + numEltSumFull = sumFullVectorInt64 + numEltProductFull = productFullVectorInt64 + numEltMinIndex _ = minindexVectorInt64 + numEltMaxIndex _ = maxindexVectorInt64 + numEltDotprodInner = dotprodinnerVectorInt64 + +instance NumElt Float where + numEltAdd = addVectorFloat + numEltSub = subVectorFloat + numEltMul = mulVectorFloat + numEltNeg = negVectorFloat + numEltAbs = absVectorFloat + numEltSignum = signumVectorFloat + numEltSum1Inner = sum1VectorFloat + numEltProduct1Inner = product1VectorFloat + numEltSumFull = sumFullVectorFloat + numEltProductFull = productFullVectorFloat + numEltMinIndex _ = minindexVectorFloat + numEltMaxIndex _ = maxindexVectorFloat + numEltDotprodInner = dotprodinnerVectorFloat + +instance NumElt Double where + numEltAdd = addVectorDouble + numEltSub = subVectorDouble + numEltMul = mulVectorDouble + numEltNeg = negVectorDouble + numEltAbs = absVectorDouble + numEltSignum = signumVectorDouble + numEltSum1Inner = sum1VectorDouble + numEltProduct1Inner = product1VectorDouble + numEltSumFull = sumFullVectorDouble + numEltProductFull = productFullVectorDouble + numEltMinIndex _ = minindexVectorDouble + numEltMaxIndex _ = maxindexVectorDouble + numEltDotprodInner = dotprodinnerVectorDouble + +instance NumElt Int where + numEltAdd = intWidBranch2 @Int (+) + (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD)) + (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD)) + numEltSub = intWidBranch2 @Int (-) + (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB)) + (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB)) + numEltMul = intWidBranch2 @Int (*) + (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL)) + (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL)) + numEltNeg = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG)) + numEltAbs = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS)) + numEltSignum = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM)) + numEltSum1Inner = intWidBranchRed1 @Int + (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) + (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) + numEltProduct1Inner = intWidBranchRed1 @Int + (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_PRODUCT)) + (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT)) + numEltSumFull = intWidBranchRedFull @Int (*) (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM)) + numEltProductFull = intWidBranchRedFull @Int (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT)) + numEltMinIndex _ = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64 + numEltMaxIndex _ = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64 + numEltDotprodInner = intWidBranchDotprod @Int (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32 + (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64 + +instance NumElt CInt where + numEltAdd = intWidBranch2 @CInt (+) + (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD)) + (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD)) + numEltSub = intWidBranch2 @CInt (-) + (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB)) + (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB)) + numEltMul = intWidBranch2 @CInt (*) + (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL)) + (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL)) + numEltNeg = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG)) + numEltAbs = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS)) + numEltSignum = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM)) + numEltSum1Inner = intWidBranchRed1 @CInt + (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) + (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) + numEltProduct1Inner = intWidBranchRed1 @CInt + (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_PRODUCT)) + (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT)) + numEltSumFull = intWidBranchRedFull @CInt mulWithInt (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM)) + numEltProductFull = intWidBranchRedFull @CInt (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT)) + numEltMinIndex _ = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64 + numEltMaxIndex _ = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64 + numEltDotprodInner = intWidBranchDotprod @CInt (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32 + (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64 + +class NumElt a => IntElt a where + intEltQuot :: SNat n -> Array n a -> Array n a -> Array n a + intEltRem :: SNat n -> Array n a -> Array n a -> Array n a + +instance IntElt Int32 where + intEltQuot = quotVectorInt32 + intEltRem = remVectorInt32 + +instance IntElt Int64 where + intEltQuot = quotVectorInt64 + intEltRem = remVectorInt64 + +instance IntElt Int where + intEltQuot = intWidBranch2 @Int quot + (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT)) + (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT)) + intEltRem = intWidBranch2 @Int rem + (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM)) + (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM)) + +instance IntElt CInt where + intEltQuot = intWidBranch2 @CInt quot + (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT)) + (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT)) + intEltRem = intWidBranch2 @CInt rem + (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM)) + (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM)) + +class NumElt a => FloatElt a where + floatEltDiv :: SNat n -> Array n a -> Array n a -> Array n a + floatEltPow :: SNat n -> Array n a -> Array n a -> Array n a + floatEltLogbase :: SNat n -> Array n a -> Array n a -> Array n a + floatEltRecip :: SNat n -> Array n a -> Array n a + floatEltExp :: SNat n -> Array n a -> Array n a + floatEltLog :: SNat n -> Array n a -> Array n a + floatEltSqrt :: SNat n -> Array n a -> Array n a + floatEltSin :: SNat n -> Array n a -> Array n a + floatEltCos :: SNat n -> Array n a -> Array n a + floatEltTan :: SNat n -> Array n a -> Array n a + floatEltAsin :: SNat n -> Array n a -> Array n a + floatEltAcos :: SNat n -> Array n a -> Array n a + floatEltAtan :: SNat n -> Array n a -> Array n a + floatEltSinh :: SNat n -> Array n a -> Array n a + floatEltCosh :: SNat n -> Array n a -> Array n a + floatEltTanh :: SNat n -> Array n a -> Array n a + floatEltAsinh :: SNat n -> Array n a -> Array n a + floatEltAcosh :: SNat n -> Array n a -> Array n a + floatEltAtanh :: SNat n -> Array n a -> Array n a + floatEltLog1p :: SNat n -> Array n a -> Array n a + floatEltExpm1 :: SNat n -> Array n a -> Array n a + floatEltLog1pexp :: SNat n -> Array n a -> Array n a + floatEltLog1mexp :: SNat n -> Array n a -> Array n a + floatEltAtan2 :: SNat n -> Array n a -> Array n a -> Array n a + +instance FloatElt Float where + floatEltDiv = divVectorFloat + floatEltPow = powVectorFloat + floatEltLogbase = logbaseVectorFloat + floatEltRecip = recipVectorFloat + floatEltExp = expVectorFloat + floatEltLog = logVectorFloat + floatEltSqrt = sqrtVectorFloat + floatEltSin = sinVectorFloat + floatEltCos = cosVectorFloat + floatEltTan = tanVectorFloat + floatEltAsin = asinVectorFloat + floatEltAcos = acosVectorFloat + floatEltAtan = atanVectorFloat + floatEltSinh = sinhVectorFloat + floatEltCosh = coshVectorFloat + floatEltTanh = tanhVectorFloat + floatEltAsinh = asinhVectorFloat + floatEltAcosh = acoshVectorFloat + floatEltAtanh = atanhVectorFloat + floatEltLog1p = log1pVectorFloat + floatEltExpm1 = expm1VectorFloat + floatEltLog1pexp = log1pexpVectorFloat + floatEltLog1mexp = log1mexpVectorFloat + floatEltAtan2 = atan2VectorFloat + +instance FloatElt Double where + floatEltDiv = divVectorDouble + floatEltPow = powVectorDouble + floatEltLogbase = logbaseVectorDouble + floatEltRecip = recipVectorDouble + floatEltExp = expVectorDouble + floatEltLog = logVectorDouble + floatEltSqrt = sqrtVectorDouble + floatEltSin = sinVectorDouble + floatEltCos = cosVectorDouble + floatEltTan = tanVectorDouble + floatEltAsin = asinVectorDouble + floatEltAcos = acosVectorDouble + floatEltAtan = atanVectorDouble + floatEltSinh = sinhVectorDouble + floatEltCosh = coshVectorDouble + floatEltTanh = tanhVectorDouble + floatEltAsinh = asinhVectorDouble + floatEltAcosh = acoshVectorDouble + floatEltAtanh = atanhVectorDouble + floatEltLog1p = log1pVectorDouble + floatEltExpm1 = expm1VectorDouble + floatEltLog1pexp = log1pexpVectorDouble + floatEltLog1mexp = log1mexpVectorDouble + floatEltAtan2 = atan2VectorDouble diff --git a/ops/Data/Array/Strided/Arith/Internal/Foreign.hs b/ops/Data/Array/Strided/Arith/Internal/Foreign.hs new file mode 100644 index 0000000..dad65f9 --- /dev/null +++ b/ops/Data/Array/Strided/Arith/Internal/Foreign.hs @@ -0,0 +1,47 @@ +{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE TemplateHaskell #-} +module Data.Array.Strided.Arith.Internal.Foreign where + +import Data.Int +import Foreign.C.Types +import Foreign.Ptr +import Language.Haskell.TH + +import Data.Array.Strided.Arith.Internal.Lists + + +$(do + let importsScal ttyp tyn = + [("binary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("binary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("binary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |]) + ,("unary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("reduce1_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("reducefull_" ++ tyn, [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]) + ,("extremum_min_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("extremum_max_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("dotprodinner_" ++ tyn, [t| Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ] + + let importsInt ttyp tyn = + [("ibinary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("ibinary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("ibinary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |]) + ] + + let importsFloat ttyp tyn = + [("fbinary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("fbinary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("fbinary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |]) + ,("funary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ] + + let generate types imports = + sequence + [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ name) (mkName ("c_" ++ name)) <$> typ + | arithtype <- types + , (name, typ) <- imports (conT (atType arithtype)) (atCName arithtype)] + decs1 <- generate typesList importsScal + decs2 <- generate intTypesList importsInt + decs3 <- generate floatTypesList importsFloat + return (decs1 ++ decs2 ++ decs3)) diff --git a/ops/Data/Array/Strided/Arith/Internal/Lists.hs b/ops/Data/Array/Strided/Arith/Internal/Lists.hs new file mode 100644 index 0000000..910a77c --- /dev/null +++ b/ops/Data/Array/Strided/Arith/Internal/Lists.hs @@ -0,0 +1,95 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TemplateHaskell #-} +module Data.Array.Strided.Arith.Internal.Lists where + +import Data.Char +import Data.Int +import Language.Haskell.TH + +import Data.Array.Strided.Arith.Internal.Lists.TH + + +data ArithType = ArithType + { atType :: Name -- ''Int32 + , atCName :: String -- "i32" + } + +intTypesList :: [ArithType] +intTypesList = + [ArithType ''Int32 "i32" + ,ArithType ''Int64 "i64" + ] + +floatTypesList :: [ArithType] +floatTypesList = + [ArithType ''Float "float" + ,ArithType ''Double "double" + ] + +typesList :: [ArithType] +typesList = intTypesList ++ floatTypesList + +-- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded) +$(genArithDataType Binop "ArithBOp") + +$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3)) +$(genArithEnumFun Binop ''ArithBOp "aboEnum") + +$(do clauses <- readArithLists Binop + (\name _num hsop -> return (Clause [ConP (mkName name) [] []] + (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) + [])) + return + sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |] + ,return $ FunD (mkName "aboNumOp") clauses]) + + +-- data ArithIBOp = IB_QUOT deriving (Show, Enum, Bounded) +$(genArithDataType IBinop "ArithIBOp") + +$(genArithNameFun IBinop ''ArithIBOp "aiboName" (map toLower . drop 3)) +$(genArithEnumFun IBinop ''ArithIBOp "aiboEnum") + +$(do clauses <- readArithLists IBinop + (\name _num hsop -> return (Clause [ConP (mkName name) [] []] + (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) + [])) + return + sequence [SigD (mkName "aiboNumOp") <$> [t| ArithIBOp -> Name |] + ,return $ FunD (mkName "aiboNumOp") clauses]) + + +-- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded) +$(genArithDataType FBinop "ArithFBOp") + +$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . drop 3)) +$(genArithEnumFun FBinop ''ArithFBOp "afboEnum") + +$(do clauses <- readArithLists FBinop + (\name _num hsop -> return (Clause [ConP (mkName name) [] []] + (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) + [])) + return + sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |] + ,return $ FunD (mkName "afboNumOp") clauses]) + + +-- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded) +$(genArithDataType Unop "ArithUOp") + +$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3)) +$(genArithEnumFun Unop ''ArithUOp "auoEnum") + + +-- data ArithFUOp = FU_RECIP | ... deriving (Show, Enum, Bounded) +$(genArithDataType FUnop "ArithFUOp") + +$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3)) +$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum") + + +-- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded) +$(genArithDataType Redop "ArithRedOp") + +$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . drop 3)) +$(genArithEnumFun Redop ''ArithRedOp "aroEnum") diff --git a/ops/Data/Array/Strided/Arith/Internal/Lists/TH.hs b/ops/Data/Array/Strided/Arith/Internal/Lists/TH.hs new file mode 100644 index 0000000..b8f6a3d --- /dev/null +++ b/ops/Data/Array/Strided/Arith/Internal/Lists/TH.hs @@ -0,0 +1,83 @@ +{-# LANGUAGE TemplateHaskellQuotes #-} +module Data.Array.Strided.Arith.Internal.Lists.TH where + +import Control.Monad +import Control.Monad.IO.Class +import Data.Maybe +import Foreign.C.Types +import Language.Haskell.TH +import Language.Haskell.TH.Syntax +import Text.Read + + +data OpKind = Binop | IBinop | FBinop | Unop | FUnop | Redop + deriving (Show, Eq) + +readArithLists :: OpKind + -> (String -> Int -> String -> Q a) + -> ([a] -> Q r) + -> Q r +readArithLists targetkind fop fcombine = do + addDependentFile "cbits/arith_lists.h" + lns <- liftIO $ lines <$> readFile "cbits/arith_lists.h" + + mvals <- forM lns $ \line -> do + if null (dropWhile (== ' ') line) + then return Nothing + else do let (kind, name, num, aux) = parseLine line + if kind == targetkind + then Just <$> fop name num aux + else return Nothing + + fcombine (catMaybes mvals) + where + parseLine s0 + | ("LIST_", s1) <- splitAt 5 s0 + , (kindstr, '(' : s2) <- break (== '(') s1 + , (f1, ',' : s3) <- parseField s2 + , (f2, ',' : s4) <- parseField s3 + , (f3, ')' : _) <- parseField s4 + , Just kind <- parseKind kindstr + , let name = f1 + , Just num <- readMaybe f2 + , let aux = f3 + = (kind, name, num, aux) + | otherwise + = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0 + + parseField s = break (`elem` ",)") (dropWhile (== ' ') s) + + parseKind "BINOP" = Just Binop + parseKind "IBINOP" = Just IBinop + parseKind "FBINOP" = Just FBinop + parseKind "UNOP" = Just Unop + parseKind "FUNOP" = Just FUnop + parseKind "REDOP" = Just Redop + parseKind _ = Nothing + +genArithDataType :: OpKind -> String -> Q [Dec] +genArithDataType kind dtname = do + cons <- readArithLists kind + (\name _num _ -> return $ NormalC (mkName name) []) + return + return [DataD [] (mkName dtname) [] Nothing cons [DerivClause Nothing [ConT ''Show, ConT ''Enum, ConT ''Bounded]]] + +genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec] +genArithNameFun kind dtname funname nametrans = do + clauses <- readArithLists kind + (\name _num _ -> return (Clause [ConP (mkName name) [] []] + (NormalB (LitE (StringL (nametrans name)))) + [])) + return + return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''String) + ,FunD (mkName funname) clauses] + +genArithEnumFun :: OpKind -> Name -> String -> Q [Dec] +genArithEnumFun kind dtname funname = do + clauses <- readArithLists kind + (\name num _ -> return (Clause [ConP (mkName name) [] []] + (NormalB (LitE (IntegerL (fromIntegral num)))) + [])) + return + return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''CInt) + ,FunD (mkName funname) clauses] diff --git a/ops/Data/Array/Strided/Array.hs b/ops/Data/Array/Strided/Array.hs new file mode 100644 index 0000000..9280fe0 --- /dev/null +++ b/ops/Data/Array/Strided/Array.hs @@ -0,0 +1,44 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +module Data.Array.Strided.Array where + +import Data.List.NonEmpty qualified as NE +import Data.Proxy +import Data.Vector.Storable qualified as VS +import Foreign.Storable +import GHC.TypeLits + + +data Array (n :: Nat) a = Array + { arrShape :: ![Int] + , arrStrides :: ![Int] + , arrOffset :: !Int + , arrValues :: !(VS.Vector a) + } + +-- | Takes a vector in normalised order (inner dimension, i.e. last in the +-- list, iterates fastest). +arrayFromVector :: forall a n. (Storable a, KnownNat n) => [Int] -> VS.Vector a -> Array n a +arrayFromVector sh vec + | VS.length vec == shsize + , length sh == fromIntegral (natVal (Proxy @n)) + = Array sh strides 0 vec + | otherwise = error $ "arrayFromVector: Shape " ++ show sh ++ " does not match vector length " ++ show (VS.length vec) + where + shsize = product sh + strides = NE.tail (NE.scanr (*) 1 sh) + +arrayFromConstant :: Storable a => [Int] -> a -> Array n a +arrayFromConstant sh x = Array sh (0 <$ sh) 0 (VS.singleton x) + +arrayRevDims :: [Bool] -> Array n a -> Array n a +arrayRevDims bs (Array sh strides offset vec) + | length bs == length sh = + Array sh + (zipWith (\b s -> if b then -s else s) bs strides) + (offset + sum (zipWith3 (\b n s -> if b then (n - 1) * s else 0) bs sh strides)) + vec + | otherwise = error $ "arrayRevDims: " ++ show (length bs) ++ " booleans given but rank " ++ show (length sh) diff --git a/ox-arrays.cabal b/ox-arrays.cabal index 0aa7001..be4bb03 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -1,32 +1,184 @@ cabal-version: 3.0 name: ox-arrays version: 0.1.0.0 -author: Tom Smeding +synopsis: An efficient CPU-based multidimensional array (tensor) library +description: + An efficient and richly typed CPU-based multidimensional array (tensor) + library built upon the optimized tensor representation (strides list) + implemented in the orthotope package. See the README. + + If you use this package: let me know (e.g. via email) if you find it useful! + Both positive feedback (keep this!) and negative feedback (I needed this but + ox-arrays doesn't provide it) is welcome. +copyright: (c) 2025 Tom Smeding, Mikolaj Konarski +author: Tom Smeding, Mikolaj Konarski +maintainer: Tom Smeding <xhackage@tomsmeding.com> license: BSD-3-Clause +category: Array, Tensors build-type: Simple +extra-doc-files: README.md CHANGELOG.md +extra-source-files: cbits/arith_lists.h + +flag trace-wrappers + description: + Compile modules that define wrappers around the array methods that trace + their arguments and results. This is conditional on a flag because these + modules make documentation generation fail. + (@https://gitlab.haskell.org/ghc/ghc/-/issues/24964@ , should be fixed in + GHC 9.12) + default: False + manual: True + +flag nonportable-simd + description: + Assume the binary will be run on the same CPU as where it is built. Setting + this flag causes `-march=native` to be passed to the C compiler when + compiling arithmetic operations. The result is generally much faster + arithmetic operations, but the executable is much less portable to + different computers. + default: False + manual: True + +flag pedantic-c-warnings + description: + Compile embedded C code with a high warning level. Only useful for + ox-arrays developers. + default: False + manual: True + +flag default-show-instances + description: + Use default GHC-derived Show instances for arrays, shapes and indices. This + exposes the internal struct-of-arrays representation and is less readable, + but can be useful for ox-arrays debugging. + default: False + manual: True + +common basics + default-language: Haskell2010 + ghc-options: -Wall -Wcompat -Widentities -Wunused-packages + library + import: basics exposed-modules: - Data.Array.Mixed + -- put this module on top so ghci considers it the "main" module Data.Array.Nested - Data.Array.Nested.Internal - Data.INat + + Data.Array.Nested.Convert + Data.Array.Nested.Mixed + Data.Array.Nested.Mixed.Shape + Data.Array.Nested.Lemmas + Data.Array.Nested.Permutation + Data.Array.Nested.Ranked + Data.Array.Nested.Ranked.Base + Data.Array.Nested.Ranked.Shape + Data.Array.Nested.Shaped + Data.Array.Nested.Shaped.Base + Data.Array.Nested.Shaped.Shape + Data.Array.Nested.Types + Data.Array.Strided.Orthotope + Data.Array.XArray + Data.Bag + + if flag(trace-wrappers) + exposed-modules: + Data.Array.Nested.Trace + Data.Array.Nested.Trace.TH + build-depends: + template-haskell + other-extensions: TemplateHaskell + + if flag(default-show-instances) + cpp-options: -DOXAR_DEFAULT_SHOW_INSTANCES + build-depends: - base >=4.18 && <4.20, + strided-array-ops, + + base, + deepseq < 1.7, ghc-typelits-knownnat, - -- ghc-typelits-natnormalise, - orthotope, + ghc-typelits-natnormalise, + orthotope < 0.2, vector hs-source-dirs: src - default-language: Haskell2010 - ghc-options: -Wall + +library strided-array-ops + import: basics + exposed-modules: + Data.Array.Strided + Data.Array.Strided.Array + Data.Array.Strided.Arith + Data.Array.Strided.Arith.Internal + Data.Array.Strided.Arith.Internal.Foreign + Data.Array.Strided.Arith.Internal.Lists + Data.Array.Strided.Arith.Internal.Lists.TH + build-depends: + base >=4.18 && <4.22, + ghc-typelits-knownnat < 1, + ghc-typelits-natnormalise < 1, + template-haskell < 3, + vector < 0.14 + hs-source-dirs: ops + c-sources: cbits/arith.c + + cc-options: -O3 -std=c11 + if flag(pedantic-c-warnings) + cc-options: -Wall -Wextra -pedantic + if flag(nonportable-simd) + cc-options: -march=native + elif arch(x86_64) || arch(i386) + -- hmatrix assumes sse2, so we can too + cc-options: -msse2 + + other-extensions: TemplateHaskell test-suite test + import: basics type: exitcode-stdio-1.0 main-is: Main.hs + other-modules: + Gen + Tests.C + Tests.Permutation + Util build-depends: ox-arrays, - base + base, + bytestring, + ghc-typelits-knownnat, + ghc-typelits-natnormalise, + hedgehog, + orthotope, + random >= 1.3.0, + tasty, + tasty-hedgehog, + vector hs-source-dirs: test - default-language: Haskell2010 - ghc-options: -Wall + +test-suite example + import: basics + type: exitcode-stdio-1.0 + main-is: Main.hs + build-depends: + ox-arrays, + base + hs-source-dirs: example + +benchmark bench + import: basics + type: exitcode-stdio-1.0 + main-is: Main.hs + build-depends: + ox-arrays, + strided-array-ops, + base, + hmatrix, + orthotope, + tasty-bench, + vector + hs-source-dirs: bench + +source-repository head + type: git + location: https://git.tomsmeding.com/ox-arrays diff --git a/release-hints.txt b/release-hints.txt new file mode 100644 index 0000000..d300da0 --- /dev/null +++ b/release-hints.txt @@ -0,0 +1,3 @@ +- Temporarily enable -Wredundant-constraints + - Has too many false-positives to enable normally, but sometimes catches actual redundant constraints +- Don't forget to rerun gentrace.sh diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs deleted file mode 100644 index 0351beb..0000000 --- a/src/Data/Array/Mixed.hs +++ /dev/null @@ -1,416 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed where - -import qualified Data.Array.RankedS as S -import qualified Data.Array.Ranked as ORB -import Data.Coerce -import Data.Kind -import Data.Proxy -import Data.Type.Equality -import qualified Data.Vector.Storable as VS -import Foreign.Storable (Storable) -import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce) - -import Data.INat - - --- | The 'SNat' pattern synonym is complete, but it doesn't have a --- @COMPLETE@ pragma. This copy of it does. -pattern GHC_SNat :: () => KnownNat n => SNat n -pattern GHC_SNat = SNat -{-# COMPLETE GHC_SNat #-} - -fromSNat' :: SNat n -> Int -fromSNat' = fromIntegral . fromSNat - - --- | Type-level list append. -type family l1 ++ l2 where - '[] ++ l2 = l2 - (x : xs) ++ l2 = x : xs ++ l2 - -lemAppNil :: l ++ '[] :~: l -lemAppNil = unsafeCoerce Refl - -lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) -lemAppAssoc _ _ _ = unsafeCoerce Refl - -type IxX :: [Maybe Nat] -> Type -> Type -data IxX sh i where - ZIX :: IxX '[] i - (:.@) :: forall n sh i. i -> IxX sh i -> IxX (Just n : sh) i - (:.?) :: forall sh i. i -> IxX sh i -> IxX (Nothing : sh) i -deriving instance Show i => Show (IxX sh i) -deriving instance Eq i => Eq (IxX sh i) -deriving instance Ord i => Ord (IxX sh i) -deriving instance Functor (IxX sh) -deriving instance Foldable (IxX sh) -infixr 3 :.@ -infixr 3 :.? - -type IIxX sh = IxX sh Int - -type ShX :: [Maybe Nat] -> Type -> Type -data ShX sh i where - ZSX :: ShX '[] i - (:$@) :: forall n sh i. SNat n -> ShX sh i -> ShX (Just n : sh) i - (:$?) :: forall sh i. i -> ShX sh i -> ShX (Nothing : sh) i -deriving instance Show i => Show (ShX sh i) -deriving instance Eq i => Eq (ShX sh i) -deriving instance Ord i => Ord (ShX sh i) -deriving instance Functor (ShX sh) -deriving instance Foldable (ShX sh) -infixr 3 :$@ -infixr 3 :$? - -type IShX sh = ShX sh Int - --- | The part of a shape that is statically known. -type StaticShX :: [Maybe Nat] -> Type -data StaticShX sh where - ZKSX :: StaticShX '[] - (:!$@) :: SNat n -> StaticShX sh -> StaticShX (Just n : sh) - (:!$?) :: () -> StaticShX sh -> StaticShX (Nothing : sh) -deriving instance Show (StaticShX sh) -infixr 3 :!$@ -infixr 3 :!$? - --- | Evidence for the static part of a shape. -type KnownShapeX :: [Maybe Nat] -> Constraint -class KnownShapeX sh where - knownShapeX :: StaticShX sh -instance KnownShapeX '[] where - knownShapeX = ZKSX -instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where - knownShapeX = natSing :!$@ knownShapeX -instance KnownShapeX sh => KnownShapeX (Nothing : sh) where - knownShapeX = () :!$? knownShapeX - -type family Rank sh where - Rank '[] = Z - Rank (_ : sh) = S (Rank sh) - -type XArray :: [Maybe Nat] -> Type -> Type -newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a) - deriving (Show) - -zeroIxX :: StaticShX sh -> IIxX sh -zeroIxX ZKSX = ZIX -zeroIxX (_ :!$@ ssh) = 0 :.@ zeroIxX ssh -zeroIxX (_ :!$? ssh) = 0 :.? zeroIxX ssh - -zeroIxX' :: IShX sh -> IIxX sh -zeroIxX' ZSX = ZIX -zeroIxX' (_ :$@ sh) = 0 :.@ zeroIxX' sh -zeroIxX' (_ :$? sh) = 0 :.? zeroIxX' sh - --- This is a weird operation, so it has a long name -completeShXzeros :: StaticShX sh -> IShX sh -completeShXzeros ZKSX = ZSX -completeShXzeros (n :!$@ ssh) = n :$@ completeShXzeros ssh -completeShXzeros (_ :!$? ssh) = 0 :$? completeShXzeros ssh - -ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh') -ixAppend ZIX idx' = idx' -ixAppend (i :.@ idx) idx' = i :.@ ixAppend idx idx' -ixAppend (i :.? idx) idx' = i :.? ixAppend idx idx' - -shAppend :: IShX sh -> IShX sh' -> IShX (sh ++ sh') -shAppend ZSX sh' = sh' -shAppend (n :$@ sh) sh' = n :$@ shAppend sh sh' -shAppend (n :$? sh) sh' = n :$? shAppend sh sh' - -ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh' -ixDrop sh ZIX = sh -ixDrop (_ :.@ sh) (_ :.@ idx) = ixDrop sh idx -ixDrop (_ :.? sh) (_ :.? idx) = ixDrop sh idx - -ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') -ssxAppend ZKSX sh' = sh' -ssxAppend (n :!$@ sh) sh' = n :!$@ ssxAppend sh sh' -ssxAppend (() :!$? sh) sh' = () :!$? ssxAppend sh sh' - -shapeSize :: IShX sh -> Int -shapeSize ZSX = 1 -shapeSize (n :$@ sh) = fromSNat' n * shapeSize sh -shapeSize (n :$? sh) = n * shapeSize sh - --- | This may fail if @sh@ has @Nothing@s in it. -ssxToShape' :: StaticShX sh -> Maybe (IShX sh) -ssxToShape' ZKSX = Just ZSX -ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh -ssxToShape' (_ :!$? _) = Nothing - -fromLinearIdx :: IShX sh -> Int -> IIxX sh -fromLinearIdx = \sh i -> case go sh i of - (idx, 0) -> idx - _ -> error $ "fromLinearIdx: out of range (" ++ show i ++ - " in array of shape " ++ show sh ++ ")" - where - -- returns (index in subarray, remaining index in enclosing array) - go :: IShX sh -> Int -> (IIxX sh, Int) - go ZSX i = (ZIX, i) - go (n :$@ sh) i = - let (idx, i') = go sh i - (upi, locali) = i' `quotRem` fromSNat' n - in (locali :.@ idx, upi) - go (n :$? sh) i = - let (idx, i') = go sh i - (upi, locali) = i' `quotRem` n - in (locali :.? idx, upi) - -toLinearIdx :: IShX sh -> IIxX sh -> Int -toLinearIdx = \sh i -> fst (go sh i) - where - -- returns (index in subarray, size of subarray) - go :: IShX sh -> IIxX sh -> (Int, Int) - go ZSX ZIX = (0, 1) - go (n :$@ sh) (i :.@ ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, fromSNat' n * sz) - go (n :$? sh) (i :.? ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, n * sz) - -enumShape :: IShX sh -> [IIxX sh] -enumShape = \sh -> go sh id [] - where - go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] - go ZSX f = (f ZIX :) - go (n :$@ sh) f = foldr (.) id [go sh (f . (i :.@)) | i <- [0 .. fromSNat' n - 1]] - go (n :$? sh) f = foldr (.) id [go sh (f . (i :.?)) | i <- [0 .. n-1]] - -shapeLshape :: IShX sh -> S.ShapeL -shapeLshape ZSX = [] -shapeLshape (n :$@ sh) = fromSNat' n : shapeLshape sh -shapeLshape (n :$? sh) = n : shapeLshape sh - -ssxLength :: StaticShX sh -> Int -ssxLength ZKSX = 0 -ssxLength (_ :!$@ ssh) = 1 + ssxLength ssh -ssxLength (_ :!$? ssh) = 1 + ssxLength ssh - -ssxIotaFrom :: Int -> StaticShX sh -> [Int] -ssxIotaFrom _ ZKSX = [] -ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh -ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh - -lemRankApp :: StaticShX sh1 -> StaticShX sh2 - -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2) -lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this - -lemRankAppComm :: StaticShX sh1 -> StaticShX sh2 - -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1)) -lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this - -lemKnownINatRank :: IShX sh -> Dict KnownINat (Rank sh) -lemKnownINatRank ZSX = Dict -lemKnownINatRank (_ :$@ sh) | Dict <- lemKnownINatRank sh = Dict -lemKnownINatRank (_ :$? sh) | Dict <- lemKnownINatRank sh = Dict - -lemKnownINatRankSSX :: StaticShX sh -> Dict KnownINat (Rank sh) -lemKnownINatRankSSX ZKSX = Dict -lemKnownINatRankSSX (_ :!$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict -lemKnownINatRankSSX (_ :!$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict - -lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh -lemKnownShapeX ZKSX = Dict -lemKnownShapeX (GHC_SNat :!$@ ssh) | Dict <- lemKnownShapeX ssh = Dict -lemKnownShapeX (() :!$? ssh) | Dict <- lemKnownShapeX ssh = Dict - -lemAppKnownShapeX :: StaticShX sh1 -> StaticShX sh2 -> Dict KnownShapeX (sh1 ++ sh2) -lemAppKnownShapeX ZKSX ssh' = lemKnownShapeX ssh' -lemAppKnownShapeX (GHC_SNat :!$@ ssh) ssh' - | Dict <- lemAppKnownShapeX ssh ssh' - = Dict -lemAppKnownShapeX (() :!$? ssh) ssh' - | Dict <- lemAppKnownShapeX ssh ssh' - = Dict - -shape :: forall sh a. KnownShapeX sh => XArray sh a -> IShX sh -shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr) - where - go :: StaticShX sh' -> [Int] -> IShX sh' - go ZKSX [] = ZSX - go (n :!$@ ssh) (_ : l) = n :$@ go ssh l - go (() :!$? ssh) (n : l) = n :$? go ssh l - go _ _ = error "Invalid shapeL" - -fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a -fromVector sh v - | Dict <- lemKnownINatRank sh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - = XArray (S.fromVector (shapeLshape sh) v) - -toVector :: Storable a => XArray sh a -> VS.Vector a -toVector (XArray arr) = S.toVector arr - -scalar :: Storable a => a -> XArray '[] a -scalar = XArray . S.scalar - -unScalar :: Storable a => XArray '[] a -> a -unScalar (XArray a) = S.unScalar a - -constant :: forall sh a. Storable a => IShX sh -> a -> XArray sh a -constant sh x - | Dict <- lemKnownINatRank sh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - = XArray (S.constant (shapeLshape sh) x) - -generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a -generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) - --- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) --- generateM sh f | Dict <- lemKnownINatRank sh = --- XArray . S.fromVector (shapeLshape sh) --- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh) - -indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a -indexPartial (XArray arr) ZIX = XArray arr -indexPartial (XArray arr) (i :.@ idx) = indexPartial (XArray (S.index arr i)) idx -indexPartial (XArray arr) (i :.? idx) = indexPartial (XArray (S.index arr i)) idx - -index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a -index xarr i - | Refl <- lemAppNil @sh - = let XArray arr' = indexPartial xarr i :: XArray '[] a - in S.unScalar arr' - -type family AddMaybe n m where - AddMaybe Nothing _ = Nothing - AddMaybe (Just _) Nothing = Nothing - AddMaybe (Just n) (Just m) = Just (n + m) - -append :: forall n m sh a. (KnownShapeX sh, Storable a) - => XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a -append (XArray a) (XArray b) - | Dict <- lemKnownINatRankSSX (knownShapeX @sh) - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - = XArray (S.append a b) - -rerank :: forall sh sh1 sh2 a b. - (Storable a, Storable b) - => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 - -> (XArray sh1 a -> XArray sh2 b) - -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b -rerank ssh ssh1 ssh2 f (XArray arr) - | Dict <- lemKnownINatRankSSX ssh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - , Dict <- lemKnownINatRankSSX ssh2 - , Dict <- knownNatFromINat (Proxy @(Rank sh2)) - , Refl <- lemRankApp ssh ssh1 - , Refl <- lemRankApp ssh ssh2 - , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the - , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough - = XArray (S.rerank @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2)) - (\a -> unXArray (f (XArray a))) - arr) - where - unXArray (XArray a) = a - -rerankTop :: forall sh sh1 sh2 a b. - (Storable a, Storable b) - => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh - -> (XArray sh1 a -> XArray sh2 b) - -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b -rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh - -rerank2 :: forall sh sh1 sh2 a b c. - (Storable a, Storable b, Storable c) - => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 - -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) - -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c -rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) - | Dict <- lemKnownINatRankSSX ssh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - , Dict <- lemKnownINatRankSSX ssh2 - , Dict <- knownNatFromINat (Proxy @(Rank sh2)) - , Refl <- lemRankApp ssh ssh1 - , Refl <- lemRankApp ssh ssh2 - , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the - , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough - = XArray (S.rerank2 @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2)) - (\a b -> unXArray (f (XArray a) (XArray b))) - arr1 arr2) - where - unXArray (XArray a) = a - --- | The list argument gives indices into the original dimension list. -transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a -transpose perm (XArray arr) - | Dict <- lemKnownINatRankSSX (knownShapeX @sh) - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - = XArray (S.transpose perm arr) - -transpose2 :: forall sh1 sh2 a. - StaticShX sh1 -> StaticShX sh2 - -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a -transpose2 ssh1 ssh2 (XArray arr) - | Refl <- lemRankApp ssh1 ssh2 - , Refl <- lemRankApp ssh2 ssh1 - , Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh2) - , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh2))) - , Dict <- lemKnownINatRankSSX (ssxAppend ssh2 ssh1) - , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh1))) - , Refl <- lemRankAppComm ssh1 ssh2 - , let n1 = ssxLength ssh1 - = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) - -sumFull :: (Storable a, Num a) => XArray sh a -> a -sumFull (XArray arr) = S.sumA arr - -sumInner :: forall sh sh' a. (Storable a, Num a) - => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a -sumInner ssh ssh' - | Refl <- lemAppNil @sh - = rerank ssh ssh' ZKSX (scalar . sumFull) - -sumOuter :: forall sh sh' a. (Storable a, Num a) - => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a -sumOuter ssh ssh' - | Refl <- lemAppNil @sh - = sumInner ssh' ssh . transpose2 ssh ssh' - -fromList1 :: forall n sh a. Storable a - => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a -fromList1 ssh l - | Dict <- lemKnownINatRankSSX ssh - , Dict <- knownNatFromINat (Proxy @(Rank (n : sh))) - = case ssh of - m@GHC_SNat :!$@ _ | natVal m /= fromIntegral (length l) -> - error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++ - "does not match the type (" ++ show (natVal m) ++ ")" - _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l))) - -toList1 :: Storable a => XArray (n : sh) a -> [XArray sh a] -toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr)) - --- | Throws if the given shape is not, in fact, empty. -empty :: forall sh a. Storable a => IShX sh -> XArray sh a -empty sh - | Dict <- lemKnownINatRank sh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - = XArray (S.constant (shapeLshape sh) - (error "Data.Array.Mixed.empty: shape was not empty")) - -slice :: [(Int, Int)] -> XArray sh a -> XArray sh a -slice ivs (XArray arr) = XArray (S.slice ivs arr) - -rev1 :: XArray (n : sh) a -> XArray (n : sh) a -rev1 (XArray arr) = XArray (S.rev [0] arr) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index ec5f0b5..c3635e9 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -2,52 +2,126 @@ {-# LANGUAGE PatternSynonyms #-} module Data.Array.Nested ( -- * Ranked arrays - Ranked, - ListR(ZR, (:::)), knownListR, - IxR(.., ZIR, (:.:)), IIxR, knownIxR, - ShR(.., ZSR, (:$:)), knownShR, - rshape, rindex, rindexPartial, rgenerate, rsumOuter1, - rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar, - rconstant, rfromList, rfromList1, rtoList, rtoList1, - rslice, rrev1, + Ranked(Ranked), + ListR(ZR, (:::)), + IxR(.., ZIR, (:.:)), IIxR, + ShR(.., ZSR, (:$:)), IShR, + rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rsumOuter1, rsumAllPrim, + rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar, + remptyArray, + rrerank, + rreplicate, rreplicateScal, + rfromList1, rfromListOuter, rfromListLinear, rfromListPrim, rfromListPrimLinear, + rtoList, rtoListOuter, rtoListLinear, + rslice, rrev1, rreshape, rflatten, riota, + rminIndexPrim, rmaxIndexPrim, rdot1Inner, rdot, + rnest, runNest, rzip, runzip, -- ** Lifting orthotope operations to 'Ranked' arrays - rlift, + rlift, rlift2, + -- ** Conversions + rtoXArrayPrim, rfromXArrayPrim, + rtoMixed, rcastToMixed, rcastToShaped, + rfromOrthotope, rtoOrthotope, + -- ** Additional arithmetic operations + -- + -- $integralRealFloat + rquotArray, rremArray, ratan2Array, -- * Shaped arrays - Shaped, + Shaped(Shaped), ListS(ZS, (::$)), IxS(.., ZIS, (:.$)), IIxS, - ShS(..), KnownShape(..), - sshape, sindex, sindexPartial, sgenerate, ssumOuter1, + ShS(.., ZSS, (:$$)), KnownShS(..), + sshape, srank, ssize, sindex, sindexPartial, sgenerate, ssumOuter1, ssumAllPrim, stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar, - sconstant, sfromList, sfromList1, stoList, stoList1, - sslice, srev1, + -- TODO: sconcat? What should its type be? + semptyArray, + srerank, + sreplicate, sreplicateScal, + sfromList1, sfromListOuter, sfromListLinear, sfromListPrim, sfromListPrimLinear, + stoList, stoListOuter, stoListLinear, + sslice, srev1, sreshape, sflatten, siota, + sminIndexPrim, smaxIndexPrim, sdot1Inner, sdot, + snest, sunNest, szip, sunzip, -- ** Lifting orthotope operations to 'Shaped' arrays - slift, + slift, slift2, + -- ** Conversions + stoXArrayPrim, sfromXArrayPrim, + stoMixed, scastToMixed, stoRanked, + sfromOrthotope, stoOrthotope, + -- ** Additional arithmetic operations + -- + -- $integralRealFloat + squotArray, sremArray, satan2Array, -- * Mixed arrays Mixed, - IxX(..), IIxX, - KnownShapeX(..), StaticShX(..), - mgenerate, mtranspose, mappend, mfromVector, mtoVector, munScalar, - mconstant, mfromList, mtoList, mslice, mrev1, + ListX(ZX, (::%)), + IxX(.., ZIX, (:.%)), IIxX, + ShX(.., ZSX, (:$%)), KnownShX(..), IShX, + StaticShX(.., ZKX, (:!%)), + SMayNat(..), + mshape, mrank, msize, mindex, mindexPartial, mgenerate, msumOuter1, msumAllPrim, + mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar, + memptyArray, + mrerank, + mreplicate, mreplicateScal, + mfromList1, mfromListOuter, mfromListLinear, mfromListPrim, mfromListPrimLinear, + mtoList, mtoListOuter, mtoListLinear, + mslice, mrev1, mreshape, mflatten, miota, + mminIndexPrim, mmaxIndexPrim, mdot1Inner, mdot, + mnest, munNest, mzip, munzip, + -- ** Lifting orthotope operations to 'Mixed' arrays + mlift, mlift2, + -- ** Conversions + mtoXArrayPrim, mfromXArrayPrim, + mcast, + mcastToShaped, mtoRanked, + convert, Conversion(..), + -- ** Additional arithmetic operations + -- + -- $integralRealFloat + mquotArray, mremArray, matan2Array, -- * Array elements - Elt(mshape, mindex, mindexPartial, mscalar, mfromList1, mtoList1, mlift, mlift2), + Elt, PrimElt, Primitive(..), - - -- * Inductive natural numbers - module Data.INat, + KnownElt, -- * Further utilities / re-exports type (++), Storable, + SNat, pattern SNat, + pattern SZ, pattern SS, + Perm(..), + IsPermutation, + KnownPerm(..), + NumElt, IntElt, FloatElt, + Rank, Product, + Replicate, + MapJust, ) where -import Prelude hiding (mappend) +import Prelude hiding (mappend, mconcat) -import Data.Array.Mixed -import Data.Array.Nested.Internal -import Data.INat +import Data.Array.Nested.Convert +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Ranked +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Shaped +import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types +import Data.Array.Strided.Arith import Foreign.Storable +import GHC.TypeLits + +-- $integralRealFloat +-- +-- These functions are separate top-level functions, and not exposed in +-- instances for 'RealFloat' and 'Integral', because those classes include a +-- variety of other functions that make no sense for arrays. +-- This problem already occurs with 'fromInteger', 'fromRational' and 'pi', but +-- having 'Num', 'Fractional' and 'Floating' available is just too useful. diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs new file mode 100644 index 0000000..2438f68 --- /dev/null +++ b/src/Data/Array/Nested/Convert.hs @@ -0,0 +1,333 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +module Data.Array.Nested.Convert ( + -- * Shape\/index\/list casting functions + -- ** To ranked + ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShX, shrFromShX2, + listrCast, ixrCast, shrCast, + -- ** To shaped + ixsFromIxR, ixsFromIxR', ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX, + ixsCast, + -- ** To mixed + ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS, + ixxCast, shxCast, shxCast', + + -- * Array conversions + convert, + Conversion(..), + + -- * Special cases of array conversions + -- + -- | These functions can all be implemented using 'convert' in some way, + -- but some have fewer constraints. + rtoMixed, rcastToMixed, rcastToShaped, + stoMixed, scastToMixed, stoRanked, + mcast, mcastToShaped, mtoRanked, +) where + +import Control.Category +import Data.Proxy +import Data.Type.Equality +import GHC.TypeLits + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked.Base +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Shaped.Base +import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types + +-- * Shape or index or list casting functions + +-- * To ranked + +ixrFromIxS :: IxS sh i -> IxR (Rank sh) i +ixrFromIxS ZIS = ZIR +ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix + +ixrFromIxX :: IxX sh i -> IxR (Rank sh) i +ixrFromIxX ZIX = ZIR +ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx + +shrFromShS :: ShS sh -> IShR (Rank sh) +shrFromShS ZSS = ZSR +shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh + +-- shrFromShX re-exported +-- shrFromShX2 re-exported +-- listrCast re-exported +-- ixrCast re-exported +-- shrCast re-exported + +-- * To shaped + +-- TODO: these take a ShS because there are KnownNats inside IxS. + +ixsFromIxR :: ShS sh -> IxR (Rank sh) i -> IxS sh i +ixsFromIxR ZSS ZIR = ZIS +ixsFromIxR (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR sh idx +ixsFromIxR _ _ = error "unreachable" + +-- | Performs a runtime check that @n@ matches @Rank sh@. Equivalent to the +-- following, but more efficient: +-- +-- > ixsFromIxR' sh idx = ixsFromIxR sh (ixrCast (shsRank sh) idx) +ixsFromIxR' :: ShS sh -> IxR n i -> IxS sh i +ixsFromIxR' ZSS ZIR = ZIS +ixsFromIxR' (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR' sh idx +ixsFromIxR' _ _ = error "ixsFromIxR': index rank does not match shape rank" + +-- TODO: this takes a ShS because there are KnownNats inside IxS. +ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i +ixsFromIxX ZSS ZIX = ZIS +ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx + +-- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to +-- the following, but more efficient: +-- +-- > ixsFromIxX' sh idx = ixsFromIxX sh (ixxCast (shxFromShS sh) idx) +ixsFromIxX' :: ShS sh -> IxX sh' i -> IxS sh i +ixsFromIxX' ZSS ZIX = ZIS +ixsFromIxX' (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX' sh idx +ixsFromIxX' _ _ = error "ixsFromIxX': index rank does not match shape rank" + +-- | Produce an existential 'ShS' from an 'IShR'. +withShsFromShR :: IShR n -> (forall sh. Rank sh ~ n => ShS sh -> r) -> r +withShsFromShR ZSR k = k ZSS +withShsFromShR (n :$: sh) k = + withShsFromShR sh $ \sh' -> + withSomeSNat (fromIntegral @Int @Integer n) $ \case + Just sn@SNat -> k (sn :$$ sh') + Nothing -> error $ "withShsFromShR: negative dimension size (" ++ show n ++ ")" + +-- shsFromShX re-exported + +-- | Produce an existential 'ShS' from an 'IShX'. If you already know that +-- @sh'@ is @MapJust@ of something, use 'shsFromShX' instead. +withShsFromShX :: IShX sh' -> (forall sh. Rank sh ~ Rank sh' => ShS sh -> r) -> r +withShsFromShX ZSX k = k ZSS +withShsFromShX (SKnown sn@SNat :$% sh) k = + withShsFromShX sh $ \sh' -> + k (sn :$$ sh') +withShsFromShX (SUnknown n :$% sh) k = + withShsFromShX sh $ \sh' -> + withSomeSNat (fromIntegral @Int @Integer n) $ \case + Just sn@SNat -> k (sn :$$ sh') + Nothing -> error $ "withShsFromShX: negative SUnknown dimension size (" ++ show n ++ ")" + +shsFromSSX :: StaticShX (MapJust sh) -> ShS sh +shsFromSSX = shsFromShX Prelude.. shxFromSSX + +-- ixsCast re-exported + +-- * To mixed + +ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i +ixxFromIxR ZIR = ZIX +ixxFromIxR (n :.: (idx :: IxR m i)) = + castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m)) + (n :.% ixxFromIxR idx) + +ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i +ixxFromIxS ZIS = ZIX +ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh + +shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i +shxFromShR ZSR = ZSX +shxFromShR (n :$: (idx :: ShR m i)) = + castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m)) + (SUnknown n :$% shxFromShR idx) + +shxFromShS :: ShS sh -> IShX (MapJust sh) +shxFromShS ZSS = ZSX +shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh + +-- ixxCast re-exported +-- shxCast re-exported +-- shxCast' re-exported + + +-- * Array conversions + +-- | The constructors that perform runtime shape checking are marked with a +-- tick (@'@): 'ConvXS'' and 'ConvXX''. For the other constructors, the types +-- ensure that the shapes are already compatible. To convert between 'Ranked' +-- and 'Shaped', go via 'Mixed'. +-- +-- The guiding principle behind 'Conversion' is that it should represent the +-- array restructurings, or perhaps re-presentations, that do not change the +-- underlying 'XArray's. This leads to the inclusion of some operations that do +-- not look like simple conversions (casts) at first glance, like 'ConvZip'. +-- +-- /Note/: Haddock gleefully renames type variables in constructors so that +-- they match the data type head as much as possible. See the source for a more +-- readable presentation of this data type. +data Conversion a b where + ConvId :: Conversion a a + ConvCmp :: Conversion b c -> Conversion a b -> Conversion a c + + ConvRX :: Conversion (Ranked n a) (Mixed (Replicate n Nothing) a) + ConvSX :: Conversion (Shaped sh a) (Mixed (MapJust sh) a) + + ConvXR :: Elt a + => Conversion (Mixed sh a) (Ranked (Rank sh) a) + ConvXS :: Conversion (Mixed (MapJust sh) a) (Shaped sh a) + ConvXS' :: (Rank sh ~ Rank sh', Elt a) + => ShS sh' + -> Conversion (Mixed sh a) (Shaped sh' a) + + ConvXX' :: (Rank sh ~ Rank sh', Elt a) + => StaticShX sh' + -> Conversion (Mixed sh a) (Mixed sh' a) + + ConvRR :: Conversion a b + -> Conversion (Ranked n a) (Ranked n b) + ConvSS :: Conversion a b + -> Conversion (Shaped sh a) (Shaped sh b) + ConvXX :: Conversion a b + -> Conversion (Mixed sh a) (Mixed sh b) + ConvT2 :: Conversion a a' + -> Conversion b b' + -> Conversion (a, b) (a', b') + + Conv0X :: Elt a + => Conversion a (Mixed '[] a) + ConvX0 :: Conversion (Mixed '[] a) a + + ConvNest :: Elt a => StaticShX sh + -> Conversion (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a)) + ConvUnnest :: Conversion (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a) + + ConvZip :: (Elt a, Elt b) + => Conversion (Mixed sh a, Mixed sh b) (Mixed sh (a, b)) + ConvUnzip :: (Elt a, Elt b) + => Conversion (Mixed sh (a, b)) (Mixed sh a, Mixed sh b) +deriving instance Show (Conversion a b) + +instance Category Conversion where + id = ConvId + (.) = ConvCmp + +convert :: (Elt a, Elt b) => Conversion a b -> a -> b +convert = \c x -> munScalar (go c (mscalar x)) + where + -- The 'esh' is the extension shape: the conversion happens under a whole + -- bunch of additional dimensions that it does not touch. These dimensions + -- are 'esh'. + -- The strategy is to unwind step-by-step to a large Mixed array, and to + -- perform the required checks and conversions when re-nesting back up. + go :: Conversion a b -> Mixed esh a -> Mixed esh b + go ConvId x = x + go (ConvCmp c1 c2) x = go c1 (go c2 x) + go ConvRX (M_Ranked x) = x + go ConvSX (M_Shaped x) = x + go (ConvXR @_ @sh) (M_Nest @esh esh x) + | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) + = let ssx' = ssxAppend (ssxFromShX esh) + (ssxReplicate (shxRank (shxDropSSX @esh @sh (ssxFromShX esh) (mshape x)))) + in M_Ranked (M_Nest esh (mcast ssx' x)) + go ConvXS (M_Nest esh x) = M_Shaped (M_Nest esh x) + go (ConvXS' @sh @sh' sh') (M_Nest @esh esh x) + | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh'))) + x)) + go (ConvXX' @sh @sh' ssx) (M_Nest @esh esh x) + | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x + go (ConvRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) + go (ConvSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) + go (ConvXX c) (M_Nest esh x) = M_Nest esh (go c x) + go (ConvT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2) + go Conv0X (x :: Mixed esh a) + | Refl <- lemAppNil @esh + = M_Nest (mshape x) x + go ConvX0 (M_Nest @esh _ x) + | Refl <- lemAppNil @esh + = x + go (ConvNest @_ @sh @sh' ssh) (M_Nest @esh esh x) + | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (ssxFromShX esh `ssxAppend` ssh) (mshape x)) x) + go (ConvUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x)) + | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh x + go ConvZip x = + -- no need to check that the two esh's are equal because they were zipped previously + let (M_Nest esh x1, M_Nest _ x2) = munzip x + in M_Nest esh (mzip x1 x2) + go ConvUnzip (M_Nest esh x) = + let (x1, x2) = munzip x + in mzip (M_Nest esh x1) (M_Nest esh x2) + + lemRankAppRankEq :: Rank sh ~ Rank sh' + => Proxy esh -> Proxy sh -> Proxy sh' + -> Rank (esh ++ sh) :~: Rank (esh ++ sh') + lemRankAppRankEq _ _ _ = unsafeCoerceRefl + + lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh + -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing) + lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl + + lemRankAppRankEqMapJust :: Rank sh ~ Rank sh' + => Proxy esh -> Proxy sh -> Proxy sh' + -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') + lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl + + +-- * Special cases of array conversions + +mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) + => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a +mcast ssh2 arr + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr + +mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a +mtoRanked = convert ConvXR + +rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a +rtoMixed (Ranked arr) = arr + +-- | A more weakly-typed version of 'rtoMixed' that does a runtime shape +-- compatibility check. +rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a +rcastToMixed sshx rarr@(Ranked arr) + | Refl <- lemRankReplicate (rrank rarr) + = mcast sshx arr + +mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') + => ShS sh' -> Mixed sh a -> Shaped sh' a +mcastToShaped targetsh = convert (ConvXS' targetsh) + +stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a +stoMixed (Shaped arr) = arr + +-- | A more weakly-typed version of 'stoMixed' that does a runtime shape +-- compatibility check. +scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') + => StaticShX sh' -> Shaped sh a -> Mixed sh' a +scastToMixed sshx sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mcast sshx arr + +stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a +stoRanked sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mtoRanked arr + +rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a +rcastToShaped (Ranked arr) targetsh + | Refl <- lemRankReplicate (shxRank (shxFromShS targetsh)) + , Refl <- lemRankMapJust targetsh + = mcastToShaped targetsh arr diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs deleted file mode 100644 index 350eb6f..0000000 --- a/src/Data/Array/Nested/Internal.hs +++ /dev/null @@ -1,1294 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DefaultSignatures #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE RoleAnnotations #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeAbstractions #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} - -{-| -TODO: -* We should be more consistent in whether functions take a 'StaticShX' - argument or a 'KnownShapeX' constraint. - -* Document the choice of using 'INat' for ranks and 'Nat' for shapes. Point - being that we need to do induction over the former, but the latter need to be - able to get large. - --} - -module Data.Array.Nested.Internal where - -import Prelude hiding (mappend) - -import Control.Monad (forM_, when) -import Control.Monad.ST -import qualified Data.Array.RankedS as S -import Data.Bifunctor (first) -import Data.Coerce (coerce, Coercible) -import Data.Foldable (toList) -import Data.Kind -import Data.List.NonEmpty (NonEmpty) -import Data.Proxy -import Data.Type.Equality -import qualified Data.Vector.Storable as VS -import qualified Data.Vector.Storable.Mutable as VSM -import Foreign.Storable (Storable) -import GHC.TypeLits - -import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat) -import qualified Data.Array.Mixed as X -import Data.INat - - --- Invariant in the API --- ==================== --- --- In the underlying XArray, there is some shape for elements of an empty --- array. For example, for this array: --- --- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float) --- rshape arr == 0 :.: 0 :.: 0 :.: ZIR --- --- the two underlying XArrays have a shape, and those shapes might be anything. --- The invariant is that these element shapes are unobservable in the API. --- (This is possible because you ought to not be able to get to such an element --- without indexing out of bounds.) --- --- Note, though, that the converse situation may arise: the outer array might --- be nonempty but then the inner arrays might. This is fine, an invariant only --- applies if the _outer_ array is empty. --- --- TODO: can we enforce that the elements of an empty (nested) array have --- all-zero shape? --- -> no, because mlift and also any kind of internals probing from outsiders - - --- Primitive element types --- ======================= --- --- There are a few primitive element types; arrays containing elements of such --- type are a newtype over an XArray, which it itself a newtype over a Vector. --- Unfortunately, the setup of the library requires us to list these primitive --- element types multiple times; to aid in extending the list, all these lists --- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. - - -type family Replicate n a where - Replicate Z a = '[] - Replicate (S n) a = a : Replicate n a - -type family MapJust l where - MapJust '[] = '[] - MapJust (x : xs) = Just x : MapJust xs - -lemKnownReplicate :: forall n. KnownINat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) -lemKnownReplicate _ = X.lemKnownShapeX (go (inatSing @n)) - where - go :: SINat m -> StaticShX (Replicate m Nothing) - go SZ = ZKSX - go (SS n) = () :!$? go n - -lemRankReplicate :: forall n. KnownINat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n -lemRankReplicate _ = go (inatSing @n) - where - go :: SINat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m - go SZ = Refl - go (SS n) | Refl <- go n = Refl - -lemReplicatePlusApp :: forall n m a. KnownINat n => Proxy n -> Proxy m -> Proxy a - -> Replicate (n +! m) a :~: Replicate n a ++ Replicate m a -lemReplicatePlusApp _ _ _ = go (inatSing @n) - where - go :: SINat n' -> Replicate (n' +! m) a :~: Replicate n' a ++ Replicate m a - go SZ = Refl - go (SS n) | Refl <- go n = Refl - -shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh') -shAppSplit _ ZKSX idx = (ZSX, idx) -shAppSplit p (_ :!$@ ssh) (i :$@ idx) = first (i :$@) (shAppSplit p ssh idx) -shAppSplit p (_ :!$? ssh) (i :$? idx) = first (i :$?) (shAppSplit p ssh idx) - - --- | Wrapper type used as a tag to attach instances on. The instances on arrays --- of @'Primitive' a@ are more polymorphic than the direct instances for arrays --- of scalars; this means that if @orthotope@ supports an element type @T@ that --- this library does not (directly), it may just work if you use an array of --- @'Primitive' T@ instead. -newtype Primitive a = Primitive a - --- | Element types that are primitive; arrays of these types are just a newtype --- wrapper over an array. -class PrimElt a where - fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a - toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) - - default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a - fromPrimitive = coerce - - default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a) - toPrimitive = coerce - --- [PRIMITIVE ELEMENT TYPES LIST] -instance PrimElt Int -instance PrimElt Double -instance PrimElt () - - --- | Mixed arrays: some dimensions are size-typed, some are not. Distributes --- over product-typed elements using a data family so that the full array is --- always in struct-of-arrays format. --- --- Built on top of 'XArray' which is built on top of @orthotope@, meaning that --- dimension permutations (e.g. 'mtranspose') are typically free. --- --- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type --- class. -type Mixed :: [Maybe Nat] -> Type -> Type -data family Mixed sh a --- NOTE: When opening up the Mixed abstraction, you might see dimension sizes --- that you're not supposed to see. In particular, you might see (nonempty) --- sizes of the elements of an empty array, which is information that should --- ostensibly not exist; the full array is still empty. - -newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a) - deriving (Show) - --- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance Mixed sh Int = M_Int (XArray sh Int) - deriving (Show) -newtype instance Mixed sh Double = M_Double (XArray sh Double) - deriving (Show) -newtype instance Mixed sh () = M_Nil (XArray sh ()) -- no content, orthotope optimises this (via Vector) - deriving (Show) --- etc. - -data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) -deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b)) --- etc. - -newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a) -deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a)) - - --- | Internal helper data family mirroring 'Mixed' that consists of mutable --- vectors instead of 'XArray's. -type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type -data family MixedVecs s sh a - -newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a) - --- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) -newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) -newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this --- etc. - -data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b) --- etc. - -data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) - - --- | Tree giving the shape of every array component. -type family ShapeTree a where - ShapeTree (Primitive _) = () - -- [PRIMITIVE ELEMENT TYPES LIST] - ShapeTree Int = () - ShapeTree Double = () - ShapeTree () = () - - ShapeTree (a, b) = (ShapeTree a, ShapeTree b) - ShapeTree (Mixed sh a) = (IShX sh, ShapeTree a) - ShapeTree (Ranked n a) = (IShR n, ShapeTree a) - ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) - - --- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or --- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' --- a@; see the documentation for 'Primitive' for more details. -class Elt a where - -- ====== PUBLIC METHODS ====== -- - - mshape :: KnownShapeX sh => Mixed sh a -> IShX sh - mindex :: Mixed sh a -> IIxX sh -> a - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a - mscalar :: a -> Mixed '[] a - - -- | All arrays in the list, even subarrays inside @a@, must have the same - -- shape; if they do not, a runtime error will be thrown. See the - -- documentation of 'mgenerate' for more information about this restriction. - -- Furthermore, the length of the list must correspond with @n@: if @n@ is - -- @Just m@ and @m@ does not equal the length of the list, a runtime error is - -- thrown. - -- - -- If you want a single-dimensional array from your list, map 'mscalar' - -- first. - mfromList1 :: forall n sh. KnownShapeX (n : sh) => NonEmpty (Mixed sh a) -> Mixed (n : sh) a - - mtoList1 :: Mixed (n : sh) a -> [Mixed sh a] - - -- | Note: this library makes no particular guarantees about the shapes of - -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the - -- full 'XArray' and as such you can distinguish different empty arrays by - -- the "shapes" of their elements. This information is meaningless, so you - -- should not use it. - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 a -> Mixed sh2 a - - -- | See the documentation for 'mlift'. - mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a - - -- ====== PRIVATE METHODS ====== -- - - -- | Create an empty array. The given shape must have size zero; this may or may not be checked. - memptyArray :: IShX sh -> Mixed sh a - - mshapeTree :: a -> ShapeTree a - - mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool - - mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool - - mshowShapeTree :: Proxy a -> ShapeTree a -> String - - -- | Create uninitialised vectors for this array type, given the shape of - -- this vector and an example for the contents. - mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) - - mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) - - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () - - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWritePartial :: KnownShapeX sh' => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () - - -- | Given the shape of this array, finalise the vectors into 'XArray's. - mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) - - --- Arrays of scalars are basically just arrays of scalars. -instance Storable a => Elt (Primitive a) where - mshape (M_Primitive a) = X.shape a - mindex (M_Primitive a) i = Primitive (X.index a i) - mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i) - mscalar (Primitive x) = M_Primitive (X.scalar x) - mfromList1 l = M_Primitive (X.fromList1 knownShapeX (coerce (toList l))) - mtoList1 (M_Primitive arr) = coerce (X.toList1 arr) - - mlift :: forall sh1 sh2. - (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) - -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) - mlift f (M_Primitive a) - | Refl <- X.lemAppNil @sh1 - , Refl <- X.lemAppNil @sh2 - = M_Primitive (f Proxy a) - - mlift2 :: forall sh1 sh2 sh3. - (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) - -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) - mlift2 f (M_Primitive a) (M_Primitive b) - | Refl <- X.lemAppNil @sh1 - , Refl <- X.lemAppNil @sh2 - , Refl <- X.lemAppNil @sh3 - = M_Primitive (f Proxy a b) - - memptyArray sh = M_Primitive (X.empty sh) - mshapeTree _ = () - mshapeTreeEq _ () () = True - mshapeTreeEmpty _ () = False - mshowShapeTree _ () = "()" - mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh) - mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x - - -- TODO: this use of toVector is suboptimal - mvecsWritePartial - :: forall sh' sh s. KnownShapeX sh' - => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do - let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' (X.shape arr))) - VS.copy (VSM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) - - mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VS.freeze v - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving via Primitive Int instance Elt Int -deriving via Primitive Double instance Elt Double -deriving via Primitive () instance Elt () - --- Arrays of pairs are pairs of arrays. -instance (Elt a, Elt b) => Elt (a, b) where - mshape (M_Tup2 a _) = mshape a - mindex (M_Tup2 a b) i = (mindex a i, mindex b i) - mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) - mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) - mfromList1 l = M_Tup2 (mfromList1 ((\(M_Tup2 x _) -> x) <$> l)) - (mfromList1 ((\(M_Tup2 _ y) -> y) <$> l)) - mtoList1 (M_Tup2 a b) = zipWith M_Tup2 (mtoList1 a) (mtoList1 b) - mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b) - mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y) - - memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) - mshapeTree (x, y) = (mshapeTree x, mshapeTree y) - mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' - mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 - mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" - mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y - mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) - mvecsWrite sh i (x, y) (MV_Tup2 a b) = do - mvecsWrite sh i x a - mvecsWrite sh i y b - mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do - mvecsWritePartial sh i x a - mvecsWritePartial sh i y b - mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b - --- Arrays of arrays are just arrays, but with more dimensions. -instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where - -- TODO: this is quadratic in the nesting depth because it repeatedly - -- truncates the shape vector to one a little shorter. Fix with a - -- moverlongShape method, a prefix of which is mshape. - mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IShX sh - mshape (M_Nest arr) - | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') - = fst (shAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr)) - - mindex (M_Nest arr) i = mindexPartial arr i - - mindexPartial :: forall sh1 sh2. - Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - mindexPartial (M_Nest arr) i - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) - - mscalar = M_Nest - - mfromList1 :: forall n sh. KnownShapeX (n : sh) - => NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (n : sh) (Mixed sh' a) - mfromList1 l - | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @(n : sh)) (knownShapeX @sh')) - = M_Nest (mfromList1 (coerce l)) - - mtoList1 (M_Nest arr) = coerce (mtoList1 arr) - - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) - -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) - mlift f (M_Nest arr) - | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) - = M_Nest (mlift f' arr) - where - f' :: forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b - f' _ - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT)) - = f (Proxy @(sh' ++ shT)) - - mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) - => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) - -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) - mlift2 f (M_Nest arr1) (M_Nest arr2) - | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) - , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh3) (knownShapeX @sh')) - = M_Nest (mlift2 f' arr1 arr2) - where - f' :: forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b - f' _ - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - , Refl <- X.lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) - , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT)) - = f (Proxy @(sh' ++ shT)) - - memptyArray sh = M_Nest (memptyArray (X.shAppend sh (X.completeShXzeros (knownShapeX @sh')))) - - mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (knownShapeX @sh')))) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = X.shapeSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - mvecsUnsafeNew sh example - | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) - | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh (mshape example)) - (mindex example (X.zeroIxX (knownShapeX @sh'))) - where - sh' = mshape example - - mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a) - - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.shAppend sh sh') idx val vecs - - mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 - => IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) - -> ST s () - mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs) - | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) - , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.shAppend sh12 sh') idx arr vecs - - mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.shAppend sh sh') vecs - - --- | Create an array given a size and a function that computes the element at a --- given index. --- --- __WARNING__: It is required that every @a@ returned by the argument to --- 'mgenerate' has the same shape. For example, the following will throw a --- runtime error: --- --- > foo :: Mixed [Nothing] (Mixed [Nothing] Double) --- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) -> --- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) -> --- > ... --- --- because the size of the inner 'mgenerate' is not always the same (it depends --- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so --- the entire hierarchy (after distributing out tuples) must be a rectangular --- array. The type of 'mgenerate' allows this requirement to be broken very --- easily, hence the runtime check. -mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IShX sh -> (IIxX sh -> a) -> Mixed sh a -mgenerate sh f = case X.enumShape sh of - [] -> memptyArray sh - firstidx : restidxs -> - let firstelem = f (X.zeroIxX' sh) - shapetree = mshapeTree firstelem - in if mshapeTreeEmpty (Proxy @a) shapetree - then memptyArray sh - else runST $ do - vecs <- mvecsUnsafeNew sh firstelem - mvecsWrite sh firstidx firstelem vecs - -- TODO: This is likely fine if @a@ is big, but if @a@ is a - -- scalar this array copying inefficient. Should improve this. - forM_ restidxs $ \idx -> do - let val = f idx - when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ - error "Data.Array.Nested mgenerate: generated values do not have equal shapes" - mvecsWrite sh idx val vecs - mvecsFreeze sh vecs - -mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a -mtranspose perm = - mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh') - (X.transpose perm)) - -mappend :: forall n m sh a. (KnownShapeX sh, KnownShapeX (n : sh), KnownShapeX (m : sh), KnownShapeX (X.AddMaybe n m : sh), Elt a) - => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a -mappend = mlift2 go - where go :: forall sh' b. (KnownShapeX sh', Storable b) - => Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b - go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append - -mfromVectorP :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) -mfromVectorP sh v = M_Primitive (X.fromVector sh v) - -mfromVector :: forall sh a. (KnownShapeX sh, Storable a, PrimElt a) => IShX sh -> VS.Vector a -> Mixed sh a -mfromVector sh v = fromPrimitive (mfromVectorP sh v) - -mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a -mtoVectorP (M_Primitive v) = X.toVector v - -mtoVector :: (Storable a, PrimElt a) => Mixed sh a -> VS.Vector a -mtoVector arr = mtoVectorP (coerce toPrimitive arr) - -mfromList :: (KnownShapeX '[n], Elt a) => NonEmpty a -> Mixed '[n] a -mfromList = mfromList1 . fmap mscalar - -mtoList :: Elt a => Mixed '[n] a -> [a] -mtoList = map munScalar . mtoList1 - -munScalar :: Elt a => Mixed '[] a -> a -munScalar arr = mindex arr ZIX - -mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> a -> Mixed sh (Primitive a) -mconstantP sh x = M_Primitive (X.constant sh x) - -mconstant :: forall sh a. (KnownShapeX sh, Storable a, PrimElt a) - => IShX sh -> a -> Mixed sh a -mconstant sh x = fromPrimitive (mconstantP sh x) - -mslice :: (KnownShapeX sh, Elt a) => [(Int, Int)] -> Mixed sh a -> Mixed sh a -mslice ivs = mlift $ \_ -> X.slice ivs - -mrev1 :: (KnownShapeX (n : sh), Elt a) => Mixed (n : sh) a -> Mixed (n : sh) a -mrev1 = mlift $ \_ -> X.rev1 - -mliftPrim :: (KnownShapeX sh, Storable a) - => (a -> a) - -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -mliftPrim f (M_Primitive (X.XArray arr)) = M_Primitive (X.XArray (S.mapA f arr)) - -mliftPrim2 :: (KnownShapeX sh, Storable a) - => (a -> a -> a) - -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -mliftPrim2 f (M_Primitive (X.XArray arr1)) (M_Primitive (X.XArray arr2)) = - M_Primitive (X.XArray (S.zipWithA f arr1 arr2)) - -instance (KnownShapeX sh, Storable a, Num a) => Num (Mixed sh (Primitive a)) where - (+) = mliftPrim2 (+) - (-) = mliftPrim2 (-) - (*) = mliftPrim2 (*) - negate = mliftPrim negate - abs = mliftPrim abs - signum = mliftPrim signum - fromInteger n = - case X.ssxToShape' (knownShapeX @sh) of - Just sh -> M_Primitive (X.constant sh (fromInteger n)) - Nothing -> error "Data.Array.Nested.fromIntegral: \ - \Unknown components in shape, use explicit mconstant" - --- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types) -deriving via Mixed sh (Primitive Int) instance KnownShapeX sh => Num (Mixed sh Int) -deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed sh Double) - - --- | A rank-typed array: the number of dimensions of the array (its /rank/) is --- represented on the type level as a 'INat'. --- --- Valid elements of a ranked arrays are described by the 'Elt' type class. --- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are --- supported (and are represented as a single, flattened, struct-of-arrays --- array internally). --- --- Note that this 'INat' is not a "GHC.TypeLits" natural, because we want a --- type-level natural that supports induction. --- --- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. -type Ranked :: INat -> Type -> Type -newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) -deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a) - --- | A shape-typed array: the full shape of the array (the sizes of its --- dimensions) is represented on the type level as a list of 'Nat's. Note that --- these are "GHC.TypeLits" naturals, because we do not need induction over --- them and we want very large arrays to be possible. --- --- Like for 'Ranked', the valid elements are described by the 'Elt' type class, --- and 'Shaped' itself is again an instance of 'Elt' as well. --- --- 'Shaped' is a newtype around a 'Mixed' of 'Just's. -type Shaped :: [Nat] -> Type -> Type -newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) -deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a) - --- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) -deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a)) -newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh' ) a)) -deriving instance Show (Mixed sh (Mixed (MapJust sh' ) a)) => Show (Mixed sh (Shaped sh' a)) - -newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) -newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a)) - - --- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; --- these instances allow them to also be used as elements of arrays, thus --- making them first-class in the API. -instance (Elt a, KnownINat n) => Elt (Ranked n a) where - mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr - mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i) - - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) - mindexPartial (M_Ranked arr) i - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ - mindexPartial arr i - - mscalar (Ranked x) = M_Ranked (M_Nest x) - - mfromList1 :: forall m sh. KnownShapeX (m : sh) - => NonEmpty (Mixed sh (Ranked n a)) -> Mixed (m : sh) (Ranked n a) - mfromList1 l - | Dict <- lemKnownReplicate (Proxy @n) - = M_Ranked (mfromList1 (coerce l)) - - mtoList1 :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] - mtoList1 (M_Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoList1 arr) - - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) - mlift f (M_Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ - mlift f arr - - mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) - mlift2 f (M_Ranked arr1) (M_Ranked arr2) - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ - mlift2 f arr1 arr2 - - memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a) - memptyArray i - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ - memptyArray i - - mshapeTree (Ranked arr) - | Refl <- lemRankReplicate (Proxy @n) - , Dict <- lemKnownReplicate (Proxy @n) - = first shCvtXR (mshapeTree arr) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shapeSizeR sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - mvecsUnsafeNew idx (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = MV_Ranked <$> mvecsUnsafeNew idx arr - - mvecsNewEmpty _ - | Dict <- lemKnownReplicate (Proxy @n) - = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) - - mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () - mvecsWrite sh idx (Ranked arr) vecs - | Dict <- lemKnownReplicate (Proxy @n) - = mvecsWrite sh idx arr - (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsWritePartial :: forall sh sh' s. KnownShapeX sh' - => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) - -> MixedVecs s (sh ++ sh') (Ranked n a) - -> ST s () - mvecsWritePartial sh idx arr vecs - | Dict <- lemKnownReplicate (Proxy @n) - = mvecsWritePartial sh idx - (coerce @(Mixed sh' (Ranked n a)) - @(Mixed sh' (Mixed (Replicate n Nothing) a)) - arr) - (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) - @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) - mvecsFreeze sh vecs - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) - @(Mixed sh (Ranked n a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh (Ranked n a)) - @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - - --- | The shape of a shape-typed array given as a list of 'SNat' values. -data ShS sh where - ZSS :: ShS '[] - (:$$) :: forall n sh. SNat n -> ShS sh -> ShS (n : sh) -deriving instance Show (ShS sh) -deriving instance Eq (ShS sh) -deriving instance Ord (ShS sh) -infixr 3 :$$ - --- | A statically-known shape of a shape-typed array. -class KnownShape sh where knownShape :: ShS sh -instance KnownShape '[] where knownShape = ZSS -instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = natSing :$$ knownShape - -sshapeKnown :: ShS sh -> Dict KnownShape sh -sshapeKnown ZSS = Dict -sshapeKnown (GHC_SNat :$$ sh) | Dict <- sshapeKnown sh = Dict - -lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh) -lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh)) - where - go :: ShS sh' -> StaticShX (MapJust sh') - go ZSS = ZKSX - go (n :$$ sh) = n :!$@ go sh - -lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2 - -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 -lemMapJustPlusApp _ _ = go (knownShape @sh1) - where - go :: ShS sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2 - go ZSS = Refl - go (_ :$$ sh) | Refl <- go sh = Refl - -instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where - mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr - mindex (M_Shaped arr) i | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mindex arr i) - - mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - mindexPartial (M_Shaped arr) i - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mindexPartial arr i - - mscalar (Shaped x) = M_Shaped (M_Nest x) - - mfromList1 :: forall n sh'. KnownShapeX (n : sh') - => NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (n : sh') (Shaped sh a) - mfromList1 l - | Dict <- lemKnownMapJust (Proxy @sh) - = M_Shaped (mfromList1 (coerce l)) - - mtoList1 :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] - mtoList1 (M_Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoList1 arr) - - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) - mlift f (M_Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mlift f arr - - mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) - mlift2 f (M_Shaped arr1) (M_Shaped arr2) - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ - mlift2 f arr1 arr2 - - memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) - memptyArray i - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ - memptyArray i - - mshapeTree (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = first (shCvtXS (knownShape @sh)) (mshapeTree arr) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shapeSizeS sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - mvecsUnsafeNew idx (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsUnsafeNew idx arr - - mvecsNewEmpty _ - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) - - mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () - mvecsWrite sh idx (Shaped arr) vecs - | Dict <- lemKnownMapJust (Proxy @sh) - = mvecsWrite sh idx arr - (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - - mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 - => IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) - -> ST s () - mvecsWritePartial sh idx arr vecs - | Dict <- lemKnownMapJust (Proxy @sh) - = mvecsWritePartial sh idx - (coerce @(Mixed sh2 (Shaped sh a)) - @(Mixed sh2 (Mixed (MapJust sh) a)) - arr) - (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) - @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) - vecs) - - mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) - mvecsFreeze sh vecs - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh' (Mixed (MapJust sh) a)) - @(Mixed sh' (Shaped sh a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh' (Shaped sh a)) - @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - - --- Utility functions to satisfy the type checker sometimes - -rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a -rewriteMixed Refl x = x - - --- ====== API OF RANKED ARRAYS ====== -- - -arithPromoteRanked :: forall n a. KnownINat n - => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a) - -> Ranked n a -> Ranked n a -arithPromoteRanked | Dict <- lemKnownReplicate (Proxy @n) = coerce - -arithPromoteRanked2 :: forall n a. KnownINat n - => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a -> Mixed sh a) - -> Ranked n a -> Ranked n a -> Ranked n a -arithPromoteRanked2 | Dict <- lemKnownReplicate (Proxy @n) = coerce - -instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where - (+) = arithPromoteRanked2 (+) - (-) = arithPromoteRanked2 (-) - (*) = arithPromoteRanked2 (*) - negate = arithPromoteRanked negate - abs = arithPromoteRanked abs - signum = arithPromoteRanked signum - fromInteger n = case inatSing @n of - SZ -> Ranked (M_Primitive (X.scalar (fromInteger n))) - SS _ -> error "Data.Array.Nested.fromIntegral(Ranked): \ - \Rank non-zero, use explicit mconstant" - --- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types) -deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int) -deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double) - -type role ListR nominal representational -type ListR :: INat -> Type -> Type -data ListR n i where - ZR :: ListR Z i - (:::) :: forall n {i}. i -> ListR n i -> ListR (S n) i -deriving instance Show i => Show (ListR n i) -deriving instance Eq i => Eq (ListR n i) -deriving instance Ord i => Ord (ListR n i) -deriving instance Functor (ListR n) -infixr 3 ::: - -instance Foldable (ListR n) where - foldr f z l = foldr f z (listRToList l) - -listRToList :: ListR n i -> [i] -listRToList ZR = [] -listRToList (i ::: is) = i : listRToList is - -knownListR :: ListR n i -> Dict KnownINat n -knownListR ZR = Dict -knownListR (_ ::: l) | Dict <- knownListR l = Dict - --- | An index into a rank-typed array. -type role IxR nominal representational -type IxR :: INat -> Type -> Type -newtype IxR n i = IxR (ListR n i) - deriving (Show, Eq, Ord) - deriving newtype (Functor, Foldable) - -pattern ZIR :: forall n i. () => n ~ Z => IxR n i -pattern ZIR = IxR ZR - -pattern (:.:) - :: forall {n1} {i}. - forall n. (S n ~ n1) - => i -> IxR n i -> IxR n1 i -pattern i :.: sh <- (unconsIxR -> Just (UnconsIxRRes sh i)) - where i :.: IxR sh = IxR (i ::: sh) -{-# COMPLETE ZIR, (:.:) #-} -infixr 3 :.: - -data UnconsIxRRes i n1 = - forall n. ((S n) ~ n1) => UnconsIxRRes (IxR n i) i -unconsIxR :: IxR n1 i -> Maybe (UnconsIxRRes i n1) -unconsIxR (IxR (i ::: sh')) = Just (UnconsIxRRes (IxR sh') i) -unconsIxR (IxR ZR) = Nothing - -type IIxR n = IxR n Int - -knownIxR :: IxR n i -> Dict KnownINat n -knownIxR (IxR sh) = knownListR sh - -type role ShR nominal representational -type ShR :: INat -> Type -> Type -newtype ShR n i = ShR (ListR n i) - deriving (Show, Eq, Ord) - deriving newtype (Functor, Foldable) - -type IShR n = ShR n Int - -pattern ZSR :: forall n i. () => n ~ Z => ShR n i -pattern ZSR = ShR ZR - -pattern (:$:) - :: forall {n1} {i}. - forall n. (S n ~ n1) - => i -> ShR n i -> ShR n1 i -pattern i :$: sh <- (unconsShR -> Just (UnconsShRRes sh i)) - where i :$: (ShR sh) = ShR (i ::: sh) -{-# COMPLETE ZSR, (:$:) #-} -infixr 3 :$: - -data UnconsShRRes i n1 = - forall n. S n ~ n1 => UnconsShRRes (ShR n i) i -unconsShR :: ShR n1 i -> Maybe (UnconsShRRes i n1) -unconsShR (ShR (i ::: sh')) = Just (UnconsShRRes (ShR sh') i) -unconsShR (ShR ZR) = Nothing - -knownShR :: ShR n i -> Dict KnownINat n -knownShR (ShR sh) = knownListR sh - -zeroIxR :: SINat n -> IIxR n -zeroIxR SZ = ZIR -zeroIxR (SS n) = 0 :.: zeroIxR n - -ixCvtXR :: IIxX sh -> IIxR (X.Rank sh) -ixCvtXR ZIX = ZIR -ixCvtXR (n :.@ idx) = n :.: ixCvtXR idx -ixCvtXR (n :.? idx) = n :.: ixCvtXR idx - -shCvtXR :: IShX sh -> IShR (X.Rank sh) -shCvtXR ZSX = ZSR -shCvtXR (n :$@ idx) = X.fromSNat' n :$: shCvtXR idx -shCvtXR (n :$? idx) = n :$: shCvtXR idx - -ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) -ixCvtRX ZIR = ZIX -ixCvtRX (n :.: idx) = n :.? ixCvtRX idx - -shCvtRX :: IShR n -> IShX (Replicate n Nothing) -shCvtRX ZSR = ZSX -shCvtRX (n :$: idx) = n :$? shCvtRX idx - -shapeSizeR :: IShR n -> Int -shapeSizeR ZSR = 1 -shapeSizeR (n :$: sh) = n * shapeSizeR sh - - -rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IShR n -rshape (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - , Refl <- lemRankReplicate (Proxy @n) - = shCvtXR (mshape arr) - -rindex :: Elt a => Ranked n a -> IIxR n -> a -rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) - -rindexPartial :: forall n m a. (KnownINat n, Elt a) => Ranked (n +! m) a -> IIxR n -> Ranked m a -rindexPartial (Ranked arr) idx = - Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) - (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr) - (ixCvtRX idx)) - --- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. -rgenerate :: forall n a. Elt a => IShR n -> (IIxR n -> a) -> Ranked n a -rgenerate sh f - | Dict <- knownShR sh - , Dict <- lemKnownReplicate (Proxy @n) - , Refl <- lemRankReplicate (Proxy @n) - = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) - --- | See the documentation of 'mlift'. -rlift :: forall n1 n2 a. (KnownINat n2, Elt a) - => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) - -> Ranked n1 a -> Ranked n2 a -rlift f (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n2) - = Ranked (mlift f arr) - -rsumOuter1P :: forall n a. - (Storable a, Num a, KnownINat n) - => Ranked (S n) (Primitive a) -> Ranked n (Primitive a) -rsumOuter1P (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = Ranked - . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a)) - . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing)) - . coerce @(Mixed (Replicate (S n) Nothing) (Primitive a)) @(XArray (Replicate (S n) Nothing) a) - $ arr - -rsumOuter1 :: forall n a. - (Storable a, Num a, PrimElt a, KnownINat n) - => Ranked (S n) a -> Ranked n a -rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive - -rtranspose :: forall n a. (KnownINat n, Elt a) => [Int] -> Ranked n a -> Ranked n a -rtranspose perm (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = Ranked (mtranspose perm arr) - -rappend :: forall n a. (KnownINat n, Elt a) - => Ranked (S n) a -> Ranked (S n) a -> Ranked (S n) a -rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend - -rscalar :: Elt a => a -> Ranked I0 a -rscalar x = Ranked (mscalar x) - -rfromVectorP :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a) -rfromVectorP sh v - | Dict <- lemKnownReplicate (Proxy @n) - = Ranked (mfromVectorP (shCvtRX sh) v) - -rfromVector :: forall n a. (KnownINat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a -rfromVector sh v = coerce fromPrimitive (rfromVectorP sh v) - -rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a -rtoVectorP = coerce mtoVectorP - -rtoVector :: (Storable a, PrimElt a) => Ranked n a -> VS.Vector a -rtoVector = coerce mtoVector - -rfromList1 :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a -rfromList1 l - | Dict <- lemKnownReplicate (Proxy @n) - = Ranked (mfromList1 (coerce l)) - -rfromList :: Elt a => NonEmpty a -> Ranked I1 a -rfromList = Ranked . mfromList1 . fmap mscalar - -rtoList :: Elt a => Ranked (S n) a -> [Ranked n a] -rtoList (Ranked arr) = coerce (mtoList1 arr) - -rtoList1 :: Elt a => Ranked I1 a -> [a] -rtoList1 = map runScalar . rtoList - -runScalar :: Elt a => Ranked I0 a -> a -runScalar arr = rindex arr ZIR - -rconstantP :: forall n a. (KnownINat n, Storable a) => IShR n -> a -> Ranked n (Primitive a) -rconstantP sh x - | Dict <- lemKnownReplicate (Proxy @n) - = Ranked (mconstantP (shCvtRX sh) x) - -rconstant :: forall n a. (KnownINat n, Storable a, PrimElt a) - => IShR n -> a -> Ranked n a -rconstant sh x = coerce fromPrimitive (rconstantP sh x) - -rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a -rslice ivs = rlift $ \_ -> X.slice ivs - -rrev1 :: (KnownINat n, Elt a) => Ranked (S n) a -> Ranked (S n) a -rrev1 = rlift $ \_ -> X.rev1 - - --- ====== API OF SHAPED ARRAYS ====== -- - -arithPromoteShaped :: forall sh a. KnownShape sh - => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a) - -> Shaped sh a -> Shaped sh a -arithPromoteShaped | Dict <- lemKnownMapJust (Proxy @sh) = coerce - -arithPromoteShaped2 :: forall sh a. KnownShape sh - => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a -> Mixed shx a) - -> Shaped sh a -> Shaped sh a -> Shaped sh a -arithPromoteShaped2 | Dict <- lemKnownMapJust (Proxy @sh) = coerce - -instance (KnownShape sh, Storable a, Num a) => Num (Shaped sh (Primitive a)) where - (+) = arithPromoteShaped2 (+) - (-) = arithPromoteShaped2 (-) - (*) = arithPromoteShaped2 (*) - negate = arithPromoteShaped negate - abs = arithPromoteShaped abs - signum = arithPromoteShaped signum - fromInteger n = sconstantP (fromInteger n) - --- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types) -deriving via Shaped sh (Primitive Int) instance KnownShape sh => Num (Shaped sh Int) -deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped sh Double) - -type role ListS nominal representational -type ListS :: [Nat] -> Type -> Type -data ListS sh i where - ZS :: ListS '[] i - (::$) :: forall n sh {i}. i -> ListS sh i -> ListS (n : sh) i -deriving instance Show i => Show (ListS sh i) -deriving instance Eq i => Eq (ListS sh i) -deriving instance Ord i => Ord (ListS sh i) -deriving instance Functor (ListS sh) -infixr 3 ::$ - -instance Foldable (ListS sh) where - foldr f z l = foldr f z (listSToList l) - -listSToList :: ListS sh i -> [i] -listSToList ZS = [] -listSToList (i ::$ is) = i : listSToList is - --- | An index into a shape-typed array. --- --- For convenience, this contains regular 'Int's instead of bounded integers --- (traditionally called \"@Fin@\"). Note that because the shape of a --- shape-typed array is known statically, you can also retrieve the array shape --- from a 'KnownShape' dictionary. -type role IxS nominal representational -type IxS :: [Nat] -> Type -> Type -newtype IxS sh i = IxS (ListS sh i) - deriving (Show, Eq, Ord) - deriving newtype (Functor, Foldable) - -pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i -pattern ZIS = IxS ZS - -pattern (:.$) - :: forall {sh1} {i}. - forall n sh. (n : sh ~ sh1) - => i -> IxS sh i -> IxS sh1 i -pattern i :.$ shl <- (unconsIxS -> Just (UnconsIxSRes shl i)) - where i :.$ IxS shl = IxS (i ::$ shl) -{-# COMPLETE ZIS, (:.$) #-} -infixr 3 :.$ - -data UnconsIxSRes i sh1 = - forall n sh. (n : sh ~ sh1) => UnconsIxSRes (IxS sh i) i -unconsIxS :: IxS sh1 i -> Maybe (UnconsIxSRes i sh1) -unconsIxS (IxS (i ::$ shl')) = Just (UnconsIxSRes (IxS shl') i) -unconsIxS (IxS ZS) = Nothing - -type IIxS sh = IxS sh Int - -data UnconsShSRes sh1 = - forall n sh. (n : sh ~ sh1) => UnconsShSRes (ShS sh) (SNat n) -unconsShS :: ShS sh1 -> Maybe (UnconsShSRes sh1) -unconsShS (i :$$ shl') = Just (UnconsShSRes shl' i) -unconsShS ZSS = Nothing - -zeroIxS :: ShS sh -> IIxS sh -zeroIxS ZSS = ZIS -zeroIxS (_ :$$ sh) = 0 :.$ zeroIxS sh - -ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh -ixCvtXS ZSS ZIX = ZIS -ixCvtXS (_ :$$ sh) (n :.@ idx) = n :.$ ixCvtXS sh idx - -shCvtXS :: ShS sh -> IShX (MapJust sh) -> ShS sh -shCvtXS ZSS ZSX = ZSS -shCvtXS (_ :$$ sh) (n :$@ idx) = n :$$ shCvtXS sh idx - -ixCvtSX :: IIxS sh -> IIxX (MapJust sh) -ixCvtSX ZIS = ZIX -ixCvtSX (n :.$ sh) = n :.@ ixCvtSX sh - -shCvtSX :: ShS sh -> IShX (MapJust sh) -shCvtSX ZSS = ZSX -shCvtSX (n :$$ sh) = n :$@ shCvtSX sh - -shapeSizeS :: ShS sh -> Int -shapeSizeS ZSS = 1 -shapeSizeS (n :$$ sh) = X.fromSNat' n * shapeSizeS sh - - --- | This does not touch the passed array, all information comes from 'KnownShape'. -sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> ShS sh -sshape _ = knownShape @sh - -sindex :: Elt a => Shaped sh a -> IIxS sh -> a -sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) - -sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a -sindexPartial (Shaped arr) idx = - Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) - (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr) - (ixCvtSX idx)) - --- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. -sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IIxS sh -> a) -> Shaped sh a -sgenerate f - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mgenerate (shCvtSX (knownShape @sh)) (f . ixCvtXS (knownShape @sh))) - --- | See the documentation of 'mlift'. -slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) - => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) - -> Shaped sh1 a -> Shaped sh2 a -slift f (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh2) - = Shaped (mlift f arr) - -ssumOuter1P :: forall sh n a. - (Storable a, Num a, KnownNat n, KnownShape sh) - => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) -ssumOuter1P (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped - . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) (Primitive a)) - . X.sumOuter (natSing @n :!$@ ZKSX) (knownShapeX @(MapJust sh)) - . coerce @(Mixed (Just n : MapJust sh) (Primitive a)) @(XArray (Just n : MapJust sh) a) - $ arr - -ssumOuter1 :: forall sh n a. - (Storable a, Num a, PrimElt a, KnownNat n, KnownShape sh) - => Shaped (n : sh) a -> Shaped sh a -ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive - -stranspose :: forall sh a. (KnownShape sh, Elt a) => [Int] -> Shaped sh a -> Shaped sh a -stranspose perm (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mtranspose perm arr) - -sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a) - => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a -sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend - -sscalar :: Elt a => a -> Shaped '[] a -sscalar x = Shaped (mscalar x) - -sfromVectorP :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a) -sfromVectorP v - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mfromVectorP (shCvtSX (knownShape @sh)) v) - -sfromVector :: forall sh a. (KnownShape sh, Storable a, PrimElt a) => VS.Vector a -> Shaped sh a -sfromVector v = coerce fromPrimitive (sfromVectorP @sh @a v) - -stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a -stoVectorP = coerce mtoVectorP - -stoVector :: (Storable a, PrimElt a) => Shaped sh a -> VS.Vector a -stoVector = coerce mtoVector - -sfromList1 :: forall n sh a. (KnownNat n, KnownShape sh, Elt a) - => NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromList1 l - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mfromList1 (coerce l)) - -sfromList :: (KnownNat n, Elt a) => NonEmpty a -> Shaped '[n] a -sfromList = Shaped . mfromList1 . fmap mscalar - -stoList :: Elt a => Shaped (n : sh) a -> [Shaped sh a] -stoList (Shaped arr) = coerce (mtoList1 arr) - -stoList1 :: Elt a => Shaped '[n] a -> [a] -stoList1 = map sunScalar . stoList - -sunScalar :: Elt a => Shaped '[] a -> a -sunScalar arr = sindex arr ZIS - -sconstantP :: forall sh a. (KnownShape sh, Storable a) => a -> Shaped sh (Primitive a) -sconstantP x - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mconstantP (shCvtSX (knownShape @sh)) x) - -sconstant :: forall sh a. (KnownShape sh, Storable a, PrimElt a) - => a -> Shaped sh a -sconstant x = coerce fromPrimitive (sconstantP @sh x) - -sslice :: (KnownShape sh, Elt a) => [(Int, Int)] -> Shaped sh a -> Shaped sh a -sslice ivs = slift $ \_ -> X.slice ivs - -srev1 :: (KnownNat n, KnownShape sh, Elt a) => Shaped (n : sh) a -> Shaped (n : sh) a -srev1 = slift $ \_ -> X.rev1 diff --git a/src/Data/Array/Nested/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs new file mode 100644 index 0000000..8cac298 --- /dev/null +++ b/src/Data/Array/Nested/Lemmas.hs @@ -0,0 +1,162 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Lemmas where + +import Data.Proxy +import Data.Type.Equality +import GHC.TypeLits + +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types + + +-- * Lemmas about numbers and lists + +-- ** Nat + +lemLeqSuccSucc :: k + 1 <= n => Proxy k -> Proxy n -> (k <=? n - 1) :~: True +lemLeqSuccSucc _ _ = unsafeCoerceRefl + +lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True +lemLeqPlus _ _ _ = Refl + +-- ** Append + +lemAppNil :: l ++ '[] :~: l +lemAppNil = unsafeCoerceRefl + +lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) +lemAppAssoc _ _ _ = unsafeCoerceRefl + +lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l +lemAppLeft _ Refl = Refl + +-- ** Simple type families + +lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a + -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a +lemReplicatePlusApp sn _ _ = go sn + where + go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a + go SZ = Refl + go (SS (n :: SNat n'm1)) + | Refl <- lemReplicateSucc @a @n'm1 + , Refl <- go n + = sym (lemReplicateSucc @a @(n'm1 + m)) + +lemDropLenApp :: Rank l1 <= Rank l2 + => Proxy l1 -> Proxy l2 -> Proxy rest + -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest) +lemDropLenApp _ _ _ = unsafeCoerceRefl + +lemTakeLenApp :: Rank l1 <= Rank l2 + => Proxy l1 -> Proxy l2 -> Proxy rest + -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest) +lemTakeLenApp _ _ _ = unsafeCoerceRefl + +lemInitApp :: Proxy l -> Proxy x -> Init (l ++ '[x]) :~: l +lemInitApp _ _ = unsafeCoerceRefl + +lemLastApp :: Proxy l -> Proxy x -> Last (l ++ '[x]) :~: x +lemLastApp _ _ = unsafeCoerceRefl + + +-- ** KnownNat + +lemKnownNatSucc :: KnownNat n => Dict KnownNat (n + 1) +lemKnownNatSucc = Dict + +lemKnownNatRank :: ShX sh i -> Dict KnownNat (Rank sh) +lemKnownNatRank ZSX = Dict +lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict + +lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh) +lemKnownNatRankSSX ZKX = Dict +lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict + + +-- * Lemmas about shapes + +-- ** Known shapes + +lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing) +lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn) + +lemKnownShX :: StaticShX sh -> Dict KnownShX sh +lemKnownShX ZKX = Dict +lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict +lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict + +lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) +lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) + where + go :: ShS sh' -> StaticShX (MapJust sh') + go ZSS = ZKX + go (n :$$ sh) = SKnown n :!% go sh + +-- ** Rank + +lemRankApp :: forall sh1 sh2. + StaticShX sh1 -> StaticShX sh2 + -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 +lemRankApp ZKX _ = Refl +lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2 + = lem (Proxy @(Rank sh1T)) Proxy Proxy $ + sym (lemRankApp ssh1 ssh2) + where + lem :: proxy a -> proxy b -> proxy c + -> (a + b :~: c) + -> c + 1 :~: (a + 1 + b) + lem _ _ _ Refl = Refl + +lemRankAppComm :: proxy sh1 -> proxy sh2 + -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) +lemRankAppComm _ _ = unsafeCoerceRefl + +lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate _ = unsafeCoerceRefl + +lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh +lemRankMapJust ZSS = Refl +lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl + +-- ** Related to MapJust and/or Permutation + +lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) +lemTakeLenMapJust PNil _ = Refl +lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl +lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty" + +lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) +lemDropLenMapJust PNil _ = Refl +lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl +lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty" + +lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) +lemIndexMapJust SZ (_ :$$ _) = Refl +lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) + | Refl <- lemIndexMapJust i sh + , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) + , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = Refl +lemIndexMapJust _ ZSS = error "Index of empty" + +lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) +lemPermuteMapJust PNil _ = Refl +lemPermuteMapJust (i `PCons` is) sh + | Refl <- lemPermuteMapJust is sh + , Refl <- lemIndexMapJust i sh + = Refl + +lemMapJustApp :: ShS sh1 -> Proxy sh2 + -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 +lemMapJustApp ZSS _ = Refl +lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs new file mode 100644 index 0000000..144230e --- /dev/null +++ b/src/Data/Array/Nested/Mixed.hs @@ -0,0 +1,936 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +module Data.Array.Nested.Mixed where + +import Prelude hiding (mconcat) + +import Control.DeepSeq (NFData(..)) +import Control.Monad (forM_, when) +import Control.Monad.ST +import Data.Array.RankedS qualified as S +import Data.Bifunctor (bimap) +import Data.Coerce +import Data.Foldable (toList) +import Data.Int +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty(..)) +import Data.List.NonEmpty qualified as NE +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Data.Vector.Storable.Mutable qualified as VSM +import Foreign.C.Types (CInt) +import Foreign.Storable (Storable) +import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Types +import Data.Array.Strided.Orthotope +import Data.Array.XArray (XArray(..)) +import Data.Array.XArray qualified as X +import Data.Bag + + +-- TODO: +-- sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a +-- rminIndex1 :: Ranked (n + 1) a -> Ranked n Int +-- gather/scatter-like things (most generally, the higher-order variants: accelerate's backpermute/permute) +-- After benchmarking: matmul and matvec + + + +-- Invariant in the API +-- ==================== +-- +-- In the underlying XArray, there is some shape for elements of an empty +-- array. For example, for this array: +-- +-- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float) +-- rshape arr == 0 :.: 0 :.: 0 :.: ZIR +-- +-- the two underlying XArrays have a shape, and those shapes might be anything. +-- The invariant is that these element shapes are unobservable in the API. +-- (This is possible because you ought to not be able to get to such an element +-- without indexing out of bounds.) +-- +-- Note, though, that the converse situation may arise: the outer array might +-- be nonempty but then the inner arrays might. This is fine, an invariant only +-- applies if the _outer_ array is empty. +-- +-- TODO: can we enforce that the elements of an empty (nested) array have +-- all-zero shape? +-- -> no, because mlift and also any kind of internals probing from outsiders + + +-- Primitive element types +-- ======================= +-- +-- There are a few primitive element types; arrays containing elements of such +-- type are a newtype over an XArray, which it itself a newtype over a Vector. +-- Unfortunately, the setup of the library requires us to list these primitive +-- element types multiple times; to aid in extending the list, all these lists +-- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. + + +-- | Wrapper type used as a tag to attach instances on. The instances on arrays +-- of @'Primitive' a@ are more polymorphic than the direct instances for arrays +-- of scalars; this means that if @orthotope@ supports an element type @T@ that +-- this library does not (directly), it may just work if you use an array of +-- @'Primitive' T@ instead. +newtype Primitive a = Primitive a + deriving (Show) + +-- | Element types that are primitive; arrays of these types are just a newtype +-- wrapper over an array. +class (Storable a, Elt a) => PrimElt a where + fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a + toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) + + default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a + fromPrimitive = coerce + + default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a) + toPrimitive = coerce + +-- [PRIMITIVE ELEMENT TYPES LIST] +instance PrimElt Bool +instance PrimElt Int +instance PrimElt Int64 +instance PrimElt Int32 +instance PrimElt CInt +instance PrimElt Float +instance PrimElt Double +instance PrimElt () + + +-- | Mixed arrays: some dimensions are size-typed, some are not. Distributes +-- over product-typed elements using a data family so that the full array is +-- always in struct-of-arrays format. +-- +-- Built on top of 'XArray' which is built on top of @orthotope@, meaning that +-- dimension permutations (e.g. 'mtranspose') are typically free. +-- +-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type +-- class. +type Mixed :: [Maybe Nat] -> Type -> Type +data family Mixed sh a +-- NOTE: When opening up the Mixed abstraction, you might see dimension sizes +-- that you're not supposed to see. In particular, you might see (nonempty) +-- sizes of the elements of an empty array, which is information that should +-- ostensibly not exist; the full array is still empty. + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +#define ANDSHOW , Show +#else +#define ANDSHOW +#endif + +data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a) + deriving (Eq, Ord, Generic ANDSHOW) + +-- [PRIMITIVE ELEMENT TYPES LIST] +newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Eq, Ord, Generic ANDSHOW) -- no content, orthotope optimises this (via Vector) +-- etc. + +data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b)) +#endif +-- etc., larger tuples (perhaps use generics to allow arbitrary product types) + +deriving instance (Eq (Mixed sh a), Eq (Mixed sh b)) => Eq (Mixed sh (a, b)) +deriving instance (Ord (Mixed sh a), Ord (Mixed sh b)) => Ord (Mixed sh (a, b)) + +data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance (Show (Mixed (sh1 ++ sh2) a)) => Show (Mixed sh1 (Mixed sh2 a)) +#endif + +deriving instance Eq (Mixed (sh1 ++ sh2) a) => Eq (Mixed sh1 (Mixed sh2 a)) +deriving instance Ord (Mixed (sh1 ++ sh2) a) => Ord (Mixed sh1 (Mixed sh2 a)) + + +-- | Internal helper data family mirroring 'Mixed' that consists of mutable +-- vectors instead of 'XArray's. +type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type +data family MixedVecs s sh a + +newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a) + +-- [PRIMITIVE ELEMENT TYPES LIST] +newtype instance MixedVecs s sh Bool = MV_Bool (VS.MVector s Bool) +newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) +newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64) +newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32) +newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt) +newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) +newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float) +newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this +-- etc. + +data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b) +-- etc. + +data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) + + +showsMixedArray :: (Show a, Elt a) + => String -- ^ fromList prefix: e.g. @rfromListLinear [2,3]@ + -> String -- ^ replicate prefix: e.g. @rreplicate [2,3]@ + -> Int -> Mixed sh a -> ShowS +showsMixedArray fromlistPrefix replicatePrefix d arr = + showParen (d > 10) $ + -- TODO: to avoid ambiguity, we should type-apply the shape to mfromListLinear here + case mtoListLinear arr of + hd : _ : _ + | all (all (== 0) . take (shxLength (mshape arr))) (marrayStrides arr) -> + showString replicatePrefix . showString " " . showsPrec 11 hd + _ -> + showString fromlistPrefix . showString " " . shows (mtoListLinear arr) + +#ifndef OXAR_DEFAULT_SHOW_INSTANCES +instance (Show a, Elt a) => Show (Mixed sh a) where + showsPrec d arr = + let sh = show (shxToList (mshape arr)) + in showsMixedArray ("mfromListLinear " ++ sh) ("mreplicate " ++ sh) d arr +#endif + +instance Elt a => NFData (Mixed sh a) where + rnf = mrnf + + +mliftNumElt1 :: (PrimElt a, PrimElt b) + => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b) + -> Mixed sh a -> Mixed sh b +mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr)) + +mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c) + => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c) + -> Mixed sh a -> Mixed sh b -> Mixed sh c +mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2)) + | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2)) + | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 + +instance (NumElt a, PrimElt a) => Num (Mixed sh a) where + (+) = mliftNumElt2 (liftO2 . numEltAdd) + (-) = mliftNumElt2 (liftO2 . numEltSub) + (*) = mliftNumElt2 (liftO2 . numEltMul) + negate = mliftNumElt1 (liftO1 . numEltNeg) + abs = mliftNumElt1 (liftO1 . numEltAbs) + signum = mliftNumElt1 (liftO1 . numEltSignum) + -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS + fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal" + +instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where + fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" + recip = mliftNumElt1 (liftO1 . floatEltRecip) + (/) = mliftNumElt2 (liftO2 . floatEltDiv) + +instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where + pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" + exp = mliftNumElt1 (liftO1 . floatEltExp) + log = mliftNumElt1 (liftO1 . floatEltLog) + sqrt = mliftNumElt1 (liftO1 . floatEltSqrt) + + (**) = mliftNumElt2 (liftO2 . floatEltPow) + logBase = mliftNumElt2 (liftO2 . floatEltLogbase) + + sin = mliftNumElt1 (liftO1 . floatEltSin) + cos = mliftNumElt1 (liftO1 . floatEltCos) + tan = mliftNumElt1 (liftO1 . floatEltTan) + asin = mliftNumElt1 (liftO1 . floatEltAsin) + acos = mliftNumElt1 (liftO1 . floatEltAcos) + atan = mliftNumElt1 (liftO1 . floatEltAtan) + sinh = mliftNumElt1 (liftO1 . floatEltSinh) + cosh = mliftNumElt1 (liftO1 . floatEltCosh) + tanh = mliftNumElt1 (liftO1 . floatEltTanh) + asinh = mliftNumElt1 (liftO1 . floatEltAsinh) + acosh = mliftNumElt1 (liftO1 . floatEltAcosh) + atanh = mliftNumElt1 (liftO1 . floatEltAtanh) + log1p = mliftNumElt1 (liftO1 . floatEltLog1p) + expm1 = mliftNumElt1 (liftO1 . floatEltExpm1) + log1pexp = mliftNumElt1 (liftO1 . floatEltLog1pexp) + log1mexp = mliftNumElt1 (liftO1 . floatEltLog1mexp) + +mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a +mquotArray = mliftNumElt2 (liftO2 . intEltQuot) +mremArray = mliftNumElt2 (liftO2 . intEltRem) + +matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a +matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2) + +-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or +-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' +-- a@; see the documentation for 'Primitive' for more details. +class Elt a where + -- ====== PUBLIC METHODS ====== -- + + mshape :: Mixed sh a -> IShX sh + mindex :: Mixed sh a -> IIxX sh -> a + mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a + mscalar :: a -> Mixed '[] a + + -- | All arrays in the list, even subarrays inside @a@, must have the same + -- shape; if they do not, a runtime error will be thrown. See the + -- documentation of 'mgenerate' for more information about this restriction. + -- Furthermore, the length of the list must correspond with @n@: if @n@ is + -- @Just m@ and @m@ does not equal the length of the list, a runtime error is + -- thrown. + -- + -- Consider also 'mfromListPrim', which can avoid intermediate arrays. + mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a + + mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] + + -- | Note: this library makes no particular guarantees about the shapes of + -- arrays "inside" an empty array. With 'mlift', 'mlift2' and 'mliftL' you can see the + -- full 'XArray' and as such you can distinguish different empty arrays by + -- the "shapes" of their elements. This information is meaningless, so you + -- should not use it. + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 a -> Mixed sh2 a + + -- | See the documentation for 'mlift'. + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a + + -- TODO: mliftL is currently unused. + -- | All arrays in the input must have equal shapes, including subarrays + -- inside their elements. + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 a) -> NonEmpty (Mixed sh2 a) + + mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a + + mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) + => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a + + -- | All arrays in the input must have equal shapes, including subarrays + -- inside their elements. + mconcat :: NonEmpty (Mixed (Nothing : sh) a) -> Mixed (Nothing : sh) a + + mrnf :: Mixed sh a -> () + + -- ====== PRIVATE METHODS ====== -- + + -- | Tree giving the shape of every array component. + type ShapeTree a + + mshapeTree :: a -> ShapeTree a + + mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool + + mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool + + mshowShapeTree :: Proxy a -> ShapeTree a -> String + + -- | Returns the stride vector of each underlying component array making up + -- this mixed array. + marrayStrides :: Mixed sh a -> Bag [Int] + + -- | Given the shape of this array, an index and a value, write the value at + -- that index in the vectors. + mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () + + -- | Given the shape of this array, an index and a value, write the value at + -- that index in the vectors. + mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + + -- | Given the shape of this array, finalise the vectors into 'XArray's. + mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) + + +-- | Element types for which we have evidence of the (static part of the) shape +-- in a type class constraint. Compare the instance contexts of the instances +-- of this class with those of 'Elt': some instances have an additional +-- "known-shape" constraint. +-- +-- This class is (currently) only required for `memptyArray` and 'mgenerate'. +class Elt a => KnownElt a where + -- | Create an empty array. The given shape must have size zero; this may or may not be checked. + memptyArrayUnsafe :: IShX sh -> Mixed sh a + + -- | Create uninitialised vectors for this array type, given the shape of + -- this vector and an example for the contents. + mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) + + mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) + + +-- Arrays of scalars are basically just arrays of scalars. +instance Storable a => Elt (Primitive a) where + mshape (M_Primitive sh _) = sh + mindex (M_Primitive _ a) i = Primitive (X.index a i) + mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i) + mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) + mfromListOuter l@(arr1 :| _) = + let sh = SUnknown (length l) :$% mshape arr1 + in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l))) + mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) + -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) + mlift ssh2 f (M_Primitive _ a) + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + , let result = f ZKX a + = M_Primitive (X.shape ssh2 result) result + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) + -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) + mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b) + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + , Refl <- lemAppNil @sh3 + , let result = f ZKX a b + = M_Primitive (X.shape ssh3 result) result + + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 (Primitive a)) -> NonEmpty (Mixed sh2 (Primitive a)) + mliftL ssh2 f l + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + = fmap (\arr -> M_Primitive (X.shape ssh2 arr) arr) $ + f ZKX (fmap (\(M_Primitive _ arr) -> arr) l) + + mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) + mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) = + let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' + sh2 = shxCast' ssh2 sh1 + in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr) + + mtranspose perm (M_Primitive sh arr) = + M_Primitive (shxPermutePrefix perm sh) + (X.transpose (ssxFromShX sh) perm arr) + + mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a) + mconcat l@(M_Primitive (_ :$% sh) _ :| _) = + let result = X.concat (ssxFromShX sh) (fmap (\(M_Primitive _ arr) -> arr) l) + in M_Primitive (X.shape (SUnknown () :!% ssxFromShX sh) result) result + + mrnf (M_Primitive sh a) = rnf sh `seq` rnf a + + type ShapeTree (Primitive a) = () + mshapeTree _ = () + mshapeTreeEq _ () () = True + mshapeTreeEmpty _ () = False + mshowShapeTree _ () = "()" + marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) + mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x + + -- TODO: this use of toVector is suboptimal + mvecsWritePartial + :: forall sh' sh s. + IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do + let arrsh = X.shape (ssxFromShX sh') arr + offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) + VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) + + mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v + +-- [PRIMITIVE ELEMENT TYPES LIST] +deriving via Primitive Bool instance Elt Bool +deriving via Primitive Int instance Elt Int +deriving via Primitive Int64 instance Elt Int64 +deriving via Primitive Int32 instance Elt Int32 +deriving via Primitive CInt instance Elt CInt +deriving via Primitive Double instance Elt Double +deriving via Primitive Float instance Elt Float +deriving via Primitive () instance Elt () + +instance Storable a => KnownElt (Primitive a) where + memptyArrayUnsafe sh = M_Primitive sh (X.empty sh) + mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) + mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 + +-- [PRIMITIVE ELEMENT TYPES LIST] +deriving via Primitive Bool instance KnownElt Bool +deriving via Primitive Int instance KnownElt Int +deriving via Primitive Int64 instance KnownElt Int64 +deriving via Primitive Int32 instance KnownElt Int32 +deriving via Primitive CInt instance KnownElt CInt +deriving via Primitive Double instance KnownElt Double +deriving via Primitive Float instance KnownElt Float +deriving via Primitive () instance KnownElt () + +-- Arrays of pairs are pairs of arrays. +instance (Elt a, Elt b) => Elt (a, b) where + mshape (M_Tup2 a _) = mshape a + mindex (M_Tup2 a b) i = (mindex a i, mindex b i) + mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) + mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) + mfromListOuter l = + M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) + (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) + mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) + mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) + mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) + mliftL ssh2 f = + let unzipT2l [] = ([], []) + unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) + unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2) + in uncurry (NE.zipWith M_Tup2) . bimap (mliftL ssh2 f) (mliftL ssh2 f) . unzipT2 + + mcastPartial ssh1 sh2 psh' (M_Tup2 a b) = + M_Tup2 (mcastPartial ssh1 sh2 psh' a) (mcastPartial ssh1 sh2 psh' b) + + mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) + mconcat = + let unzipT2l [] = ([], []) + unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) + unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2) + in uncurry M_Tup2 . bimap mconcat mconcat . unzipT2 + + mrnf (M_Tup2 a b) = mrnf a `seq` mrnf b + + type ShapeTree (a, b) = (ShapeTree a, ShapeTree b) + mshapeTree (x, y) = (mshapeTree x, mshapeTree y) + mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' + mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 + mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" + marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b + mvecsWrite sh i (x, y) (MV_Tup2 a b) = do + mvecsWrite sh i x a + mvecsWrite sh i y b + mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do + mvecsWritePartial sh i x a + mvecsWritePartial sh i y b + mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b + +instance (KnownElt a, KnownElt b) => KnownElt (a, b) where + memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh) + mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y + mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) + +-- Arrays of arrays are just arrays, but with more dimensions. +instance Elt a => Elt (Mixed sh' a) where + -- TODO: this is quadratic in the nesting depth because it repeatedly + -- truncates the shape vector to one a little shorter. Fix with a + -- moverlongShape method, a prefix of which is mshape. + mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh + mshape (M_Nest sh arr) + = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr)) + + mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a + mindex (M_Nest _ arr) = mindexPartial arr + + mindexPartial :: forall sh1 sh2. + Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) + mindexPartial (M_Nest sh arr) i + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = M_Nest (shxDropIx i sh) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + + mscalar = M_Nest ZSX + + mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) + mfromListOuter l@(arr :| _) = + M_Nest (SUnknown (length l) :$% mshape arr) + (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) + + mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) + -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) + mlift ssh2 f (M_Nest sh1 arr) = + let result = mlift (ssxAppend ssh2 ssh') f' arr + (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) + in M_Nest sh2 result + where + ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr))) + + f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b + f' sshT + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) + -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) + mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = + let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 + (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) + in M_Nest sh3 result + where + ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1))) + + f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b + f' sshT + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) + + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b)) + -> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a)) + mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) = + let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l) + (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result)) + in fmap (M_Nest sh2) result + where + ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1))) + + f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b) + f' sshT + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) + + mcastPartial :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> StaticShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) + mcastPartial ssh1 ssh2 _ (M_Nest sh1T arr) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') + = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T + sh2 = shxCast' ssh2 sh1 + in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr) + + mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) + => Perm is -> Mixed sh (Mixed sh' a) + -> Mixed (PermutePrefix is sh) (Mixed sh' a) + mtranspose perm (M_Nest sh arr) + | let sh' = shxDropSh @sh @sh' sh (mshape arr) + , Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh') + , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) + , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') + , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') + , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') + = M_Nest (shxPermutePrefix perm sh) + (mtranspose perm arr) + + mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) + mconcat l@(M_Nest sh1 _ :| _) = + let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l) + in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape result))) result + + mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr + + type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a) + + mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) + mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShX (mshape arr))))) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + marrayStrides (M_Nest _ arr) = marrayStrides arr + + mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs + + mvecsWritePartial :: forall sh1 sh2 s. + IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) + -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) + -> ST s () + mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs + + mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs + +instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where + memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh')))) + + mvecsUnsafeNew sh example + | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) + | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShX sh'))) + where + sh' = mshape example + + mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) + + +memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a +memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh) + +mrank :: Elt a => Mixed sh a -> SNat (Rank sh) +mrank = shxRank . mshape + +-- | The total number of elements in the array. +msize :: Elt a => Mixed sh a -> Int +msize = shxSize . mshape + +-- | Create an array given a size and a function that computes the element at a +-- given index. +-- +-- __WARNING__: It is required that every @a@ returned by the argument to +-- 'mgenerate' has the same shape. For example, the following will throw a +-- runtime error: +-- +-- > foo :: Mixed [Nothing] (Mixed [Nothing] Double) +-- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) -> +-- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) -> +-- > ... +-- +-- because the size of the inner 'mgenerate' is not always the same (it depends +-- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so +-- the entire hierarchy (after distributing out tuples) must be a rectangular +-- array. The type of 'mgenerate' allows this requirement to be broken very +-- easily, hence the runtime check. +mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a +mgenerate sh f = case shxEnum sh of + [] -> memptyArrayUnsafe sh + firstidx : restidxs -> + let firstelem = f (ixxZero' sh) + shapetree = mshapeTree firstelem + in if mshapeTreeEmpty (Proxy @a) shapetree + then memptyArrayUnsafe sh + else runST $ do + vecs <- mvecsUnsafeNew sh firstelem + mvecsWrite sh firstidx firstelem vecs + -- TODO: This is likely fine if @a@ is big, but if @a@ is a + -- scalar this array copying inefficient. Should improve this. + forM_ restidxs $ \idx -> do + let val = f idx + when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ + error "Data.Array.Nested mgenerate: generated values do not have equal shapes" + mvecsWrite sh idx val vecs + mvecsFreeze sh vecs + +msumOuter1P :: forall sh n a. (Storable a, NumElt a) + => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) +msumOuter1P (M_Primitive (n :$% sh) arr) = + let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX + in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr) + +msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) + => Mixed (n : sh) a -> Mixed sh a +msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive + +msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a +msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr + +mappend :: forall n m sh a. Elt a + => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a +mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 + where + sn :$% sh = mshape arr1 + sm :$% _ = mshape arr2 + ssh = ssxFromShX sh + snm :: SMayNat () SNat (AddMaybe n m) + snm = case (sn, sm) of + (SUnknown{}, _) -> SUnknown () + (SKnown{}, SUnknown{}) -> SUnknown () + (SKnown n, SKnown m) -> SKnown (snatPlus n m) + + f :: forall sh' b. Storable b + => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b + f ssh' = X.append (ssxAppend ssh ssh') + +mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) +mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) + +mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a +mfromVector sh v = fromPrimitive (mfromVectorP sh v) + +mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a +mtoVectorP (M_Primitive _ v) = X.toVector v + +mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a +mtoVector arr = mtoVectorP (toPrimitive arr) + +mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a +mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise? + +-- This forall is there so that a simple type application can constrain the +-- shape, in case the user wants to use OverloadedLists for the shape. +mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a +mfromListLinear sh l = mreshape sh (mfromList1 l) + +mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a +mfromListPrim l = + let ssh = SUnknown () :!% ZKX + xarr = X.fromList1 ssh l + in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a +mfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) + +mtoList :: Elt a => Mixed '[n] a -> [a] +mtoList = map munScalar . mtoListOuter + +mtoListLinear :: Elt a => Mixed sh a -> [a] +mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr)) -- TODO: optimise + +munScalar :: Elt a => Mixed '[] a -> a +munScalar arr = mindex arr ZIX + +mnest :: forall sh sh' a. Elt a => StaticShX sh -> Mixed (sh ++ sh') a -> Mixed sh (Mixed sh' a) +mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr + +munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a +munNest (M_Nest _ arr) = arr + +-- | The arguments must have equal shapes. If they do not, an error is raised. +mzip :: (Elt a, Elt b) => Mixed sh a -> Mixed sh b -> Mixed sh (a, b) +mzip a b + | Just Refl <- shxEqual (mshape a) (mshape b) = M_Tup2 a b + | otherwise = error "mzip: unequal shapes" + +munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b) +munzip (M_Tup2 a b) = (a, b) + +mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) + => StaticShX sh -> IShX sh2 + -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) + -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) +mrerankP ssh sh2 f (M_Primitive sh arr) = + let sh1 = shxDropSSX ssh sh + in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) ssh sh) sh2) + (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2) + (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) + arr) + +-- | See the caveats at 'Data.Array.XArray.rerank'. +mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) + => StaticShX sh -> IShX sh2 + -> (Mixed sh1 a -> Mixed sh2 b) + -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b +mrerank ssh sh2 f (toPrimitive -> arr) = + fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr + +mreplicate :: forall sh sh' a. Elt a + => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a +mreplicate sh arr = + let ssh' = ssxFromShX (mshape arr) + in mlift (ssxAppend (ssxFromShX sh) ssh') + (\(sshT :: StaticShX shT) -> + case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of + Refl -> X.replicate sh (ssxAppend ssh' sshT)) + arr + +mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) +mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) + +mreplicateScal :: forall sh a. PrimElt a + => IShX sh -> a -> Mixed sh a +mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) + +mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a +mslice i n arr = + let _ :$% sh = mshape arr + in mlift (SKnown n :!% ssxFromShX sh) (\_ -> X.slice i n) arr + +msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a +msliceU i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr + +mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a +mrev1 arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.rev1) arr + +mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a +mreshape sh' arr = + mlift (ssxFromShX sh') + (\sshIn -> X.reshapePartial (ssxFromShX (mshape arr)) sshIn sh') + arr + +mflatten :: Elt a => Mixed sh a -> Mixed '[Flatten sh] a +mflatten arr = mreshape (shxFlatten (mshape arr) :$% ZSX) arr + +miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a +miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) + +-- | Throws if the array is empty. +mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh +mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = + ixxFromList (ssxFromShX sh) (numEltMinIndex (shxRank sh) (fromO arr)) + +-- | Throws if the array is empty. +mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh +mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = + ixxFromList (ssxFromShX sh) (numEltMaxIndex (shxRank sh) (fromO arr)) + +mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) + => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a +mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b)) + | Refl <- lemInitApp (Proxy @sh) (Proxy @n) + , Refl <- lemLastApp (Proxy @sh) (Proxy @n) + = case sh1 of + _ :$% _ + | sh1 == sh2 + , Refl <- lemRankApp (ssxInit (ssxFromShX sh1)) (ssxLast (ssxFromShX sh1) :!% ZKX) -> + fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b)) + | otherwise -> error $ "mdot1Inner: Unequal shapes (" ++ show sh1 ++ " and " ++ show sh2 ++ ")" + ZSX -> error "unreachable" + +-- | This has a temporary, suboptimal implementation in terms of 'mflatten'. +-- Prefer 'mdot1Inner' if applicable. +mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a +mdot a b = + munScalar $ + mdot1Inner Proxy (fromPrimitive (mflatten (toPrimitive a))) + (fromPrimitive (mflatten (toPrimitive b))) + +mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) +mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr) + +mtoXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a) +mtoXArrayPrim = mtoXArrayPrimP . toPrimitive + +mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a) +mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr + +mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a +mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP + +mliftPrim :: (PrimElt a, PrimElt b) + => (a -> b) + -> Mixed sh a -> Mixed sh b +mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) + +mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c) + => (a -> b -> c) + -> Mixed sh a -> Mixed sh b -> Mixed sh c +mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = + fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs new file mode 100644 index 0000000..852dd5e --- /dev/null +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -0,0 +1,644 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Mixed.Shape where + +import Control.DeepSeq (NFData(..)) +import Data.Bifunctor (first) +import Data.Coerce +import Data.Foldable qualified as Foldable +import Data.Functor.Const +import Data.Functor.Product +import Data.Kind (Constraint, Type) +import Data.Monoid (Sum(..)) +import Data.Type.Equality +import GHC.Exts (withDict) +import GHC.Generics (Generic) +import GHC.IsList (IsList) +import GHC.IsList qualified as IsList +import GHC.TypeLits + +import Data.Array.Nested.Types + + +-- | The length of a type-level list. If the argument is a shape, then the +-- result is the rank of that shape. +type family Rank sh where + Rank '[] = 0 + Rank (_ : sh) = Rank sh + 1 + + +-- * Mixed lists + +type role ListX nominal representational +type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type +data ListX sh f where + ZX :: ListX '[] f + (::%) :: f n -> ListX sh f -> ListX (n : sh) f +deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) +deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) +infixr 3 ::% + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance (forall n. Show (f n)) => Show (ListX sh f) +#else +instance (forall n. Show (f n)) => Show (ListX sh f) where + showsPrec _ = listxShow shows +#endif + +instance (forall n. NFData (f n)) => NFData (ListX sh f) where + rnf ZX = () + rnf (x ::% l) = rnf x `seq` rnf l + +data UnconsListXRes f sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n) +listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) +listxUncons (i ::% shl') = Just (UnconsListXRes shl' i) +listxUncons ZX = Nothing + +-- | This checks only whether the types are equal; if the elements of the list +-- are not singletons, their values may still differ. This corresponds to +-- 'testEquality', except on the penultimate type parameter. +listxEqType :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') +listxEqType ZX ZX = Just Refl +listxEqType (n ::% sh) (m ::% sh') + | Just Refl <- testEquality n m + , Just Refl <- listxEqType sh sh' + = Just Refl +listxEqType _ _ = Nothing + +-- | This checks whether the two lists actually contain equal values. This is +-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ +-- in the @some@ package (except on the penultimate type parameter). +listxEqual :: (TestEquality f, forall n. Eq (f n)) => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') +listxEqual ZX ZX = Just Refl +listxEqual (n ::% sh) (m ::% sh') + | Just Refl <- testEquality n m + , n == m + , Just Refl <- listxEqual sh sh' + = Just Refl +listxEqual _ _ = Nothing + +listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g +listxFmap _ ZX = ZX +listxFmap f (x ::% xs) = f x ::% listxFmap f xs + +listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m +listxFold _ ZX = mempty +listxFold f (x ::% xs) = f x <> listxFold f xs + +listxLength :: ListX sh f -> Int +listxLength = getSum . listxFold (\_ -> Sum 1) + +listxRank :: ListX sh f -> SNat (Rank sh) +listxRank ZX = SNat +listxRank (_ ::% l) | SNat <- listxRank l = SNat + +listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS +listxShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListX sh' f -> ShowS + go _ ZX = id + go prefix (x ::% xs) = showString prefix . f x . go "," xs + +listxFromList :: StaticShX sh -> [i] -> ListX sh (Const i) +listxFromList topssh topl = go topssh topl + where + go :: StaticShX sh' -> [i] -> ListX sh' (Const i) + go ZKX [] = ZX + go (_ :!% sh) (i : is) = Const i ::% go sh is + go _ _ = error $ "listxFromList: Mismatched list length (type says " + ++ show (ssxLength topssh) ++ ", list has length " + ++ show (length topl) ++ ")" + +listxToList :: ListX sh' (Const i) -> [i] +listxToList ZX = [] +listxToList (Const i ::% is) = i : listxToList is + +listxHead :: ListX (mn ': sh) f -> f mn +listxHead (i ::% _) = i + +listxTail :: ListX (n : sh) i -> ListX sh i +listxTail (_ ::% sh) = sh + +listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f +listxAppend ZX idx' = idx' +listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' + +listxDrop :: forall f g sh sh'. ListX sh g -> ListX (sh ++ sh') f -> ListX sh' f +listxDrop ZX long = long +listxDrop (_ ::% short) long = case long of _ ::% long' -> listxDrop short long' + +listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f +listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh +listxInit (_ ::% ZX) = ZX + +listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh)) +listxLast (_ ::% sh@(_ ::% _)) = listxLast sh +listxLast (x ::% ZX) = x + +listxZip :: ListX sh f -> ListX sh g -> ListX sh (Product f g) +listxZip ZX ZX = ZX +listxZip (i ::% irest) (j ::% jrest) = + Pair i j ::% listxZip irest jrest + +listxZipWith :: (forall a. f a -> g a -> h a) -> ListX sh f -> ListX sh g + -> ListX sh h +listxZipWith _ ZX ZX = ZX +listxZipWith f (i ::% is) (j ::% js) = + f i j ::% listxZipWith f is js + + +-- * Mixed indices + +-- | An index into a mixed-typed array. +type role IxX nominal representational +type IxX :: [Maybe Nat] -> Type -> Type +newtype IxX sh i = IxX (ListX sh (Const i)) + deriving (Eq, Ord, Generic) + +pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i +pattern ZIX = IxX ZX + +pattern (:.%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => i -> IxX sh i -> IxX sh1 i +pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i))) + where i :.% IxX shl = IxX (Const i ::% shl) +infixr 3 :.% + +{-# COMPLETE ZIX, (:.%) #-} + +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). +type IIxX sh = IxX sh Int + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (IxX sh i) +#else +instance Show i => Show (IxX sh i) where + showsPrec _ (IxX l) = listxShow (shows . getConst) l +#endif + +instance Functor (IxX sh) where + fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l) + +instance Foldable (IxX sh) where + foldMap f (IxX l) = listxFold (f . getConst) l + +instance NFData i => NFData (IxX sh i) + +ixxLength :: IxX sh i -> Int +ixxLength (IxX l) = listxLength l + +ixxRank :: IxX sh i -> SNat (Rank sh) +ixxRank (IxX l) = listxRank l + +ixxZero :: StaticShX sh -> IIxX sh +ixxZero ZKX = ZIX +ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh + +ixxZero' :: IShX sh -> IIxX sh +ixxZero' ZSX = ZIX +ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh + +ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i +ixxFromList = coerce (listxFromList @_ @i) + +ixxHead :: IxX (n : sh) i -> i +ixxHead (IxX list) = getConst (listxHead list) + +ixxTail :: IxX (n : sh) i -> IxX sh i +ixxTail (IxX list) = IxX (listxTail list) + +ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i +ixxAppend = coerce (listxAppend @_ @(Const i)) + +ixxDrop :: forall sh sh' i. IxX sh i -> IxX (sh ++ sh') i -> IxX sh' i +ixxDrop = coerce (listxDrop @(Const i) @(Const i)) + +ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i +ixxInit = coerce (listxInit @(Const i)) + +ixxLast :: forall n sh i. IxX (n : sh) i -> i +ixxLast = coerce (listxLast @(Const i)) + +ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i +ixxCast ZKX ZIX = ZIX +ixxCast (_ :!% sh) (i :.% idx) = i :.% ixxCast sh idx +ixxCast _ _ = error "ixxCast: ranks don't match" + +ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j) +ixxZip ZIX ZIX = ZIX +ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js + +ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k +ixxZipWith _ ZIX ZIX = ZIX +ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js + +ixxFromLinear :: IShX sh -> Int -> IIxX sh +ixxFromLinear = \sh i -> case go sh i of + (idx, 0) -> idx + _ -> error $ "ixxFromLinear: out of range (" ++ show i ++ + " in array of shape " ++ show sh ++ ")" + where + -- returns (index in subarray, remaining index in enclosing array) + go :: IShX sh -> Int -> (IIxX sh, Int) + go ZSX i = (ZIX, i) + go (n :$% sh) i = + let (idx, i') = go sh i + (upi, locali) = i' `quotRem` fromSMayNat' n + in (locali :.% idx, upi) + +ixxToLinear :: IShX sh -> IIxX sh -> Int +ixxToLinear = \sh i -> fst (go sh i) + where + -- returns (index in subarray, size of subarray) + go :: IShX sh -> IIxX sh -> (Int, Int) + go ZSX ZIX = (0, 1) + go (n :$% sh) (i :.% ix) = + let (lidx, sz) = go sh ix + in (sz * i + lidx, fromSMayNat' n * sz) + + +-- * Mixed shapes + +data SMayNat i f n where + SUnknown :: i -> SMayNat i f Nothing + SKnown :: f n -> SMayNat i f (Just n) +deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) +deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) +deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) + +instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where + rnf (SUnknown i) = rnf i + rnf (SKnown x) = rnf x + +instance TestEquality f => TestEquality (SMayNat i f) where + testEquality SUnknown{} SUnknown{} = Just Refl + testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl + testEquality _ _ = Nothing + +fromSMayNat :: (n ~ Nothing => i -> r) + -> (forall m. n ~ Just m => f m -> r) + -> SMayNat i f n -> r +fromSMayNat f _ (SUnknown i) = f i +fromSMayNat _ g (SKnown s) = g s + +fromSMayNat' :: SMayNat Int SNat n -> Int +fromSMayNat' = fromSMayNat id fromSNat' + +type family AddMaybe n m where + AddMaybe Nothing _ = Nothing + AddMaybe (Just _) Nothing = Nothing + AddMaybe (Just n) (Just m) = Just (n + m) + +smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) +smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) +smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) +smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) + + +-- | This is a newtype over 'ListX'. +type role ShX nominal representational +type ShX :: [Maybe Nat] -> Type -> Type +newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) + deriving (Eq, Ord, Generic) + +pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i +pattern ZSX = ShX ZX + +pattern (:$%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => SMayNat i SNat n -> ShX sh i -> ShX sh1 i +pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i)) + where i :$% ShX shl = ShX (i ::% shl) +infixr 3 :$% + +{-# COMPLETE ZSX, (:$%) #-} + +type IShX sh = ShX sh Int + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (ShX sh i) +#else +instance Show i => Show (ShX sh i) where + showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l +#endif + +instance Functor (ShX sh) where + fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l) + +instance NFData i => NFData (ShX sh i) where + rnf (ShX ZX) = () + rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) + rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) + +-- | This checks only whether the types are equal; unknown dimensions might +-- still differ. This corresponds to 'testEquality', except on the penultimate +-- type parameter. +shxEqType :: ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') +shxEqType ZSX ZSX = Just Refl +shxEqType (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') + | Just Refl <- sameNat n m + , Just Refl <- shxEqType sh sh' + = Just Refl +shxEqType (SUnknown _ :$% sh) (SUnknown _ :$% sh') + | Just Refl <- shxEqType sh sh' + = Just Refl +shxEqType _ _ = Nothing + +-- | This checks whether all dimensions have the same value. This is more than +-- 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ in the +-- @some@ package (except on the penultimate type parameter). +shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') +shxEqual ZSX ZSX = Just Refl +shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') + | Just Refl <- sameNat n m + , Just Refl <- shxEqual sh sh' + = Just Refl +shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') + | i == j + , Just Refl <- shxEqual sh sh' + = Just Refl +shxEqual _ _ = Nothing + +shxLength :: ShX sh i -> Int +shxLength (ShX l) = listxLength l + +shxRank :: ShX sh i -> SNat (Rank sh) +shxRank (ShX l) = listxRank l + +-- | The number of elements in an array described by this shape. +shxSize :: IShX sh -> Int +shxSize ZSX = 1 +shxSize (n :$% sh) = fromSMayNat' n * shxSize sh + +shxFromList :: StaticShX sh -> [Int] -> IShX sh +shxFromList topssh topl = go topssh topl + where + go :: StaticShX sh' -> [Int] -> IShX sh' + go ZKX [] = ZSX + go (SKnown sn :!% sh) (i : is) + | i == fromSNat' sn = SKnown sn :$% go sh is + | otherwise = error $ "shxFromList: Value does not match typing (type says " + ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" + go (SUnknown () :!% sh) (i : is) = SUnknown i :$% go sh is + go _ _ = error $ "shxFromList: Mismatched list length (type says " + ++ show (ssxLength topssh) ++ ", list has length " + ++ show (length topl) ++ ")" + +shxToList :: IShX sh -> [Int] +shxToList ZSX = [] +shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh + +shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i +shxFromSSX ZKX = ZSX +shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh)) + | Refl <- lemMapJustCons @sh Refl + = SKnown n :$% shxFromSSX sh +shxFromSSX (SUnknown _ :!% _) = error "unreachable" + +-- | This may fail if @sh@ has @Nothing@s in it. +shxFromSSX2 :: StaticShX sh -> Maybe (ShX sh i) +shxFromSSX2 ZKX = Just ZSX +shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh +shxFromSSX2 (SUnknown _ :!% _) = Nothing + +shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i +shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) + +shxHead :: ShX (n : sh) i -> SMayNat i SNat n +shxHead (ShX list) = listxHead list + +shxTail :: ShX (n : sh) i -> ShX sh i +shxTail (ShX list) = ShX (listxTail list) + +shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i +shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) + +shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i +shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) + +shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i +shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) + +shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i +shxInit = coerce (listxInit @(SMayNat i SNat)) + +shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh)) +shxLast = coerce (listxLast @(SMayNat i SNat)) + +shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i +shxTakeSSX _ ZKX _ = ZSX +shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh + +shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n) + -> ShX sh i -> ShX sh j -> ShX sh k +shxZipWith _ ZSX ZSX = ZSX +shxZipWith f (i :$% is) (j :$% js) = f i j :$% shxZipWith f is js + +-- This is a weird operation, so it has a long name +shxCompleteZeros :: StaticShX sh -> IShX sh +shxCompleteZeros ZKX = ZSX +shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh +shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh + +shxSplitApp :: proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) +shxSplitApp _ ZKX idx = (ZSX, idx) +shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx) + +shxEnum :: IShX sh -> [IIxX sh] +shxEnum = \sh -> go sh id [] + where + go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] + go ZSX f = (f ZIX :) + go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] + +shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh') +shxCast ZKX ZSX = Just ZSX +shxCast (SKnown m :!% ssh) (SKnown n :$% sh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast ssh sh +shxCast (SKnown m :!% ssh) (SUnknown n :$% sh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SKnown n :$% sh) = (SUnknown (fromSNat' n) :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SUnknown n :$% sh) = (SUnknown n :$%) <$> shxCast ssh sh +shxCast _ _ = Nothing + +-- | Partial version of 'shxCast'. +shxCast' :: StaticShX sh' -> IShX sh -> IShX sh' +shxCast' ssh sh = case shxCast ssh sh of + Just sh' -> sh' + Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")" + + +-- * Static mixed shapes + +-- | The part of a shape that is statically known. (A newtype over 'ListX'.) +type StaticShX :: [Maybe Nat] -> Type +newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) + deriving (Eq, Ord) + +pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh +pattern ZKX = StaticShX ZX + +pattern (:!%) + :: forall {sh1}. + forall n sh. (n : sh ~ sh1) + => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 +pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i)) + where i :!% StaticShX shl = StaticShX (i ::% shl) +infixr 3 :!% + +{-# COMPLETE ZKX, (:!%) #-} + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (StaticShX sh) +#else +instance Show (StaticShX sh) where + showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l +#endif + +instance NFData (StaticShX sh) where + rnf (StaticShX ZX) = () + rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l) + rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l) + +instance TestEquality StaticShX where + testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2 + +ssxLength :: StaticShX sh -> Int +ssxLength (StaticShX l) = listxLength l + +ssxRank :: StaticShX sh -> SNat (Rank sh) +ssxRank (StaticShX l) = listxRank l + +-- | @ssxEqType = 'testEquality'@. Provided for consistency. +ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') +ssxEqType = testEquality + +ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') +ssxAppend ZKX sh' = sh' +ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' + +ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n +ssxHead (StaticShX list) = listxHead list + +ssxTail :: StaticShX (n : sh) -> StaticShX sh +ssxTail (_ :!% ssh) = ssh + +ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSSX = coerce (listxDrop @(SMayNat () SNat) @(SMayNat () SNat)) + +ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) + +ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSh = coerce (listxDrop @(SMayNat () SNat) @(SMayNat i SNat)) + +ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) +ssxInit = coerce (listxInit @(SMayNat () SNat)) + +ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh)) +ssxLast = coerce (listxLast @(SMayNat () SNat)) + +ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) +ssxReplicate SZ = ZKX +ssxReplicate (SS (n :: SNat n')) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n' + = SUnknown () :!% ssxReplicate n + +ssxIotaFrom :: StaticShX sh -> Int -> [Int] +ssxIotaFrom ZKX _ = [] +ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i+1) + +ssxFromShX :: ShX sh i -> StaticShX sh +ssxFromShX ZSX = ZKX +ssxFromShX (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShX sh + +ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing) +ssxFromSNat SZ = ZKX +ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n + + +-- | Evidence for the static part of a shape. This pops up only when you are +-- polymorphic in the element type of an array. +type KnownShX :: [Maybe Nat] -> Constraint +class KnownShX sh where knownShX :: StaticShX sh +instance KnownShX '[] where knownShX = ZKX +instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX +instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX + +withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r +withKnownShX = withDict @(KnownShX sh) + + +-- * Flattening + +type Flatten sh = Flatten' 1 sh + +type family Flatten' acc sh where + Flatten' acc '[] = Just acc + Flatten' acc (Nothing : sh) = Nothing + Flatten' acc (Just n : sh) = Flatten' (acc * n) sh + +-- This function is currently unused +ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) +ssxFlatten = go (SNat @1) + where + go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) + go acc ZKX = SKnown acc + go _ (SUnknown () :!% _) = SUnknown () + go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh + +shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) +shxFlatten = go (SNat @1) + where + go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) + go acc ZSX = SKnown acc + go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) + go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh + + goUnknown :: Int -> IShX sh -> Int + goUnknown acc ZSX = acc + goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh + goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh + + +-- | Very untyped: only length is checked (at runtime). +instance KnownShX sh => IsList (ListX sh (Const i)) where + type Item (ListX sh (Const i)) = i + fromList = listxFromList (knownShX @sh) + toList = listxToList + +-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. +instance KnownShX sh => IsList (IxX sh i) where + type Item (IxX sh i) = i + fromList = IxX . IsList.fromList + toList = Foldable.toList + +-- | Untyped: length and known dimensions are checked (at runtime). +instance KnownShX sh => IsList (ShX sh Int) where + type Item (ShX sh Int) = Int + fromList = shxFromList (knownShX @sh) + toList = shxToList diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs new file mode 100644 index 0000000..03d1640 --- /dev/null +++ b/src/Data/Array/Nested/Permutation.hs @@ -0,0 +1,283 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Permutation where + +import Data.Coerce (coerce) +import Data.Functor.Const +import Data.List (sort) +import Data.Maybe (fromMaybe) +import Data.Proxy +import Data.Type.Bool +import Data.Type.Equality +import Data.Type.Ord +import GHC.Exts (withDict) +import GHC.TypeError +import GHC.TypeLits +import GHC.TypeNats qualified as TN + +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Types + + +-- * Permutations + +-- | A "backward" permutation of a dimension list. The operation on the +-- dimension list is most similar to @backpermute@ in the @vector@ package; see +-- 'Permute' for code that implements this. +data Perm list where + PNil :: Perm '[] + PCons :: SNat a -> Perm l -> Perm (a : l) +infixr 5 `PCons` +deriving instance Show (Perm list) +deriving instance Eq (Perm list) + +instance TestEquality Perm where + testEquality PNil PNil = Just Refl + testEquality (x `PCons` xs) (y `PCons` ys) + | Just Refl <- testEquality x y + , Just Refl <- testEquality xs ys = Just Refl + testEquality _ _ = Nothing + +permRank :: Perm list -> SNat (Rank list) +permRank PNil = SNat +permRank (_ `PCons` l) | SNat <- permRank l = SNat + +permFromList :: [Int] -> (forall list. Perm list -> r) -> r +permFromList [] k = k PNil +permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case + Just sn -> permFromList xs $ \list -> k (sn `PCons` list) + Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x + +permToList :: Perm list -> [Natural] +permToList PNil = mempty +permToList (x `PCons` l) = TN.fromSNat x : permToList l + +permToList' :: Perm list -> [Int] +permToList' = map fromIntegral . permToList + +-- | When called as @permCheckPermutation p k@, if @p@ is a permutation of +-- @[0 .. 'length' ('permToList' p) - 1]@, @Just k@ is returned. If it isn't, +-- then @Nothing@ is returned. +permCheckPermutation :: forall r list. Perm list -> (IsPermutation list => r) -> Maybe r +permCheckPermutation = \p k -> + let n = permRank p + in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of + (Just Refl, Just Refl) -> Just k + _ -> Nothing + where + lemElemCount :: (0 <= n, Compare n m ~ LT) + => proxy n -> proxy m -> Elem n (Count 0 m) :~: True + lemElemCount _ _ = unsafeCoerceRefl + + lemCount :: (OrdCond (Compare i n) True False True ~ True) + => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n + lemCount _ _ = unsafeCoerceRefl + + lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True + lemElem _ _ = unsafeCoerceRefl + + provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is' + -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) + provePerm1 _ _ PNil = Just Refl + provePerm1 p rtop@SNat (PCons sn@SNat perm) + | Just Refl <- provePerm1 p rtop perm + = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of + (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl + (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl + _ -> Nothing + | otherwise + = Nothing + + provePerm2 :: SNat i -> SNat n -> Perm is' + -> Maybe (AllElem' (Count i n) is' :~: True) + provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> + case cmpNat i n of + EQI -> Just Refl + LTI | Refl <- lemCount i n + , Just Refl <- provePerm2 (SNat @(i + 1)) n perm + -> checkElem i perm + | otherwise -> Nothing + GTI -> error "unreachable" + where + checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True) + checkElem _ PNil = Nothing + checkElem i@SNat (PCons k@SNat perm :: Perm is') = + case sameNat i k of + Just Refl -> Just Refl + Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl + | otherwise -> Nothing + +-- | Utility class for generating permutations from type class information. +class KnownPerm l where makePerm :: Perm l +instance KnownPerm '[] where makePerm = PNil +instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm + +withKnownPerm :: forall l r. Perm l -> (KnownPerm l => r) -> r +withKnownPerm = withDict @(KnownPerm l) + +-- | Untyped permutations for ranked arrays +type PermR = [Int] + + +-- ** Applying permutations + +type family Elem x l where + Elem x '[] = 'False + Elem x (x : _) = 'True + Elem x (_ : ys) = Elem x ys + +type family AllElem' as bs where + AllElem' '[] bs = 'True + AllElem' (a : as) bs = Elem a bs && AllElem' as bs + +type AllElem as bs = Assert (AllElem' as bs) + (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs)) + +type family Count i n where + Count n n = '[] + Count i n = i : Count (i + 1) n + +type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) + +type family Index i sh where + Index 0 (n : sh) = n + Index i (_ : sh) = Index (i - 1) sh + +type family Permute is sh where + Permute '[] sh = '[] + Permute (i : is) sh = Index i sh : Permute is sh + +type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh + +type family TakeLen ref l where + TakeLen '[] l = '[] + TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs + +type family DropLen ref l where + DropLen '[] l = l + DropLen (_ : ref) (_ : xs) = DropLen ref xs + +listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f +listxTakeLen PNil _ = ZX +listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh +listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape" + +listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f +listxDropLen PNil sh = sh +listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh +listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape" + +listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f +listxPermute PNil _ = ZX +listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = + listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh + +listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh) +listxIndex _ _ SZ (n ::% _) = n +listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) + | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = listxIndex p pT i sh +listxIndex _ _ _ ZX = error "Index into empty shape" + +listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f +listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) + +ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i +ixxPermutePrefix = coerce (listxPermutePrefix @(Const i)) + +ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) +ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) + +ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) +ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) + +ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) +ssxPermute = coerce (listxPermute @(SMayNat () SNat)) + +ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh) +ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) + +ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) +ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) + +shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) +shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) + + +-- * Operations on permutations + +permInverse :: Perm is + -> (forall is'. + IsPermutation is' + => Perm is' + -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh) + -> r) + -> r +permInverse = \perm k -> + genPerm perm $ \(invperm :: Perm is') -> + fromMaybe + (error $ "permInverse: did not generate permutation? perm = " ++ show perm + ++ " ; invperm = " ++ show invperm) + (permCheckPermutation invperm + (k invperm + (\ssh -> case permCheckInverse perm invperm ssh of + Just eq -> eq + Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm + ++ " ; invperm = " ++ show invperm))) + where + genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r + genPerm perm = + let permList = permToList' perm + in toHList $ map snd (sort (zip permList [0..])) + where + toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r + toHList [] k = k PNil + toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l) + + permCheckInverse :: Perm is -> Perm is' -> StaticShX sh + -> Maybe (Permute is' (Permute is sh) :~: sh) + permCheckInverse perm perminv ssh = + ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh + +type family MapSucc is where + MapSucc '[] = '[] + MapSucc (i : is) = i + 1 : MapSucc is + +permShift1 :: Perm l -> Perm (0 : MapSucc l) +permShift1 = (SNat @0 `PCons`) . permMapSucc + where + permMapSucc :: Perm l -> Perm (MapSucc l) + permMapSucc PNil = PNil + permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns + + +-- * Lemmas + +lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is +lemRankPermute _ PNil = Refl +lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl + +lemRankDropLen :: forall is sh. (Rank is <= Rank sh) + => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is +lemRankDropLen ZKX PNil = Refl +lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!% _) PNil = Refl +lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0" + +lemIndexSucc :: Proxy i -> Proxy a -> Proxy l + -> Index (i + 1) (a : l) :~: Index i l +lemIndexSucc _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs new file mode 100644 index 0000000..9778c54 --- /dev/null +++ b/src/Data/Array/Nested/Ranked.hs @@ -0,0 +1,323 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Ranked ( + Ranked(Ranked), + rquotArray, rremArray, ratan2Array, + rshape, rrank, + module Data.Array.Nested.Ranked, + liftRanked1, liftRanked2, +) where + +import Prelude hiding (mappend, mconcat) + +import Data.Array.RankedS qualified as S +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Foreign.Storable (Storable) +import GHC.TypeLits +import GHC.TypeNats qualified as TN + +import Data.Array.Nested.Convert +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Ranked.Base +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types +import Data.Array.Strided.Arith +import Data.Array.XArray (XArray(..)) +import Data.Array.XArray qualified as X + + +remptyArray :: KnownElt a => Ranked 1 a +remptyArray = mtoRanked (memptyArray ZSX) + +-- | The total number of elements in the array. +rsize :: Elt a => Ranked n a -> Int +rsize = shrSize . rshape + +rindex :: Elt a => Ranked n a -> IIxR n -> a +rindex (Ranked arr) idx = mindex arr (ixxFromIxR idx) + +rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a +rindexPartial (Ranked arr) idx = + Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) + (castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @Nothing))) arr) + (ixxFromIxR idx)) + +-- | __WARNING__: All values returned from the function must have equal shape. +-- See the documentation of 'mgenerate' for more details. +rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a +rgenerate sh f + | sn@SNat <- shrRank sh + , Dict <- lemKnownReplicate sn + , Refl <- lemRankReplicate sn + = Ranked (mgenerate (shxFromShR sh) (f . ixrFromIxX)) + +-- | See the documentation of 'mlift'. +rlift :: forall n1 n2 a. Elt a + => SNat n2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) + -> Ranked n1 a -> Ranked n2 a +rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr) + +-- | See the documentation of 'mlift2'. +rlift2 :: forall n1 n2 n3 a. Elt a + => SNat n3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b) + -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a +rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) + +rsumOuter1P :: forall n a. + (Storable a, NumElt a) + => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) +rsumOuter1P (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = Ranked (msumOuter1P arr) + +rsumOuter1 :: forall n a. (NumElt a, PrimElt a) + => Ranked (n + 1) a -> Ranked n a +rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive + +rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a +rsumAllPrim (Ranked arr) = msumAllPrim arr + +rtranspose :: forall n a. Elt a => PermR -> Ranked n a -> Ranked n a +rtranspose perm arr + | sn@SNat <- rrank arr + , Dict <- lemKnownReplicate sn + , length perm <= fromIntegral (natVal (Proxy @n)) + = rlift sn + (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm) + arr + | otherwise + = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array" + +rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a +rconcat + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = coerce mconcat + +rappend :: forall n a. Elt a + => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a +rappend arr1 arr2 + | sn@SNat <- rrank arr1 + , Dict <- lemKnownReplicate sn + , Refl <- lemReplicateSucc @(Nothing @Nat) @n + = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) + arr1 arr2 + +rscalar :: Elt a => a -> Ranked 0 a +rscalar x = Ranked (mscalar x) + +rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a) +rfromVectorP sh v + | Dict <- lemKnownReplicate (shrRank sh) + = Ranked (mfromVectorP (shxFromShR sh) v) + +rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a +rfromVector sh v = rfromPrimitive (rfromVectorP sh v) + +rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a +rtoVectorP = coerce mtoVectorP + +rtoVector :: PrimElt a => Ranked n a -> VS.Vector a +rtoVector = coerce mtoVector + +rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a +rfromList1 l = Ranked (mfromList1 l) + +rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a +rfromListOuter l + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) + +rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a +rfromListLinear sh l = rreshape sh (rfromList1 l) + +rfromListPrim :: PrimElt a => [a] -> Ranked 1 a +rfromListPrim l = Ranked (mfromListPrim l) + +rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a +rfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in Ranked $ fromPrimitive $ M_Primitive (shxFromShR sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShR sh) xarr) + +rtoList :: Elt a => Ranked 1 a -> [a] +rtoList = map runScalar . rtoListOuter + +rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] +rtoListOuter (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) + +rtoListLinear :: Elt a => Ranked n a -> [a] +rtoListLinear (Ranked arr) = mtoListLinear arr + +rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a +rfromOrthotope sn arr + | Refl <- lemRankReplicate sn + = let xarr = XArray arr + in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr)) + +rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a +rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr))) + | Refl <- lemRankReplicate (shrRank $ shrFromShX2 sh) + = arr + +runScalar :: Elt a => Ranked 0 a -> a +runScalar arr = rindex arr ZIR + +rnest :: forall n m a. Elt a => SNat n -> Ranked (n + m) a -> Ranked n (Ranked m a) +rnest n arr + | Refl <- lemReplicatePlusApp n (Proxy @m) (Proxy @(Nothing @Nat)) + = coerce (mnest (ssxFromSNat n) (coerce arr)) + +runNest :: forall n m a. Elt a => Ranked n (Ranked m a) -> Ranked (n + m) a +runNest rarr@(Ranked (M_Ranked (M_Nest _ arr))) + | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat)) + = Ranked arr + +rzip :: (Elt a, Elt b) => Ranked n a -> Ranked n b -> Ranked n (a, b) +rzip = coerce mzip + +runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b) +runzip = coerce munzip + +rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) + => SNat n -> IShR n2 + -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) + -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b) +rrerankP sn sh2 f (Ranked arr) + | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) + , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) + = Ranked (mrerankP (ssxFromSNat sn) (shxFromShR sh2) + (\a -> let Ranked r = f (Ranked a) in r) + arr) + +-- | If there is a zero-sized dimension in the @n@-prefix of the shape of the +-- input array, then there is no way to deduce the full shape of the output +-- array (more precisely, the @n2@ part): that could only come from calling +-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in +-- this case; we choose to fill the @n2@ part of the output shape with zeros. +-- +-- For example, if: +-- +-- @ +-- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21] +-- f :: Ranked 2 Int -> Ranked 3 Float +-- @ +-- +-- then: +-- +-- @ +-- rrerank _ _ _ f arr :: Ranked 6 Float +-- @ +-- +-- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the +-- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended +-- to return an array with shape all-0 here (it probably didn't), but there is +-- no better number to put here absent a subarray of the input to pass to @f@. +rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) + => SNat n -> IShR n2 + -> (Ranked n1 a -> Ranked n2 b) + -> Ranked (n + n1) a -> Ranked (n + n2) b +rrerank sn sh2 f (rtoPrimitive -> arr) = + rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr + +rreplicate :: forall n m a. Elt a + => IShR n -> Ranked m a -> Ranked (n + m) a +rreplicate sh (Ranked arr) + | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat)) + = Ranked (mreplicate (shxFromShR sh) arr) + +rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) +rreplicateScalP sh x + | Dict <- lemKnownReplicate (shrRank sh) + = Ranked (mreplicateScalP (shxFromShR sh) x) + +rreplicateScal :: forall n a. PrimElt a + => IShR n -> a -> Ranked n a +rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) + +rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a +rslice i n arr + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = rlift (rrank arr) + (\_ -> X.sliceU i n) + arr + +rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a +rrev1 arr = + rlift (rrank arr) + (\(_ :: StaticShX sh') -> + case lemReplicateSucc @(Nothing @Nat) @n of + Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) + arr + +rreshape :: forall n n' a. Elt a + => IShR n' -> Ranked n a -> Ranked n' a +rreshape sh' rarr@(Ranked arr) + | Dict <- lemKnownReplicate (rrank rarr) + , Dict <- lemKnownReplicate (shrRank sh') + = Ranked (mreshape (shxFromShR sh') arr) + +rflatten :: Elt a => Ranked n a -> Ranked 1 a +rflatten (Ranked arr) = mtoRanked (mflatten arr) + +riota :: (Enum a, PrimElt a) => Int -> Ranked 1 a +riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota + +-- | Throws if the array is empty. +rminIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n +rminIndexPrim rarr@(Ranked arr) + | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) + = ixrFromIxX (mminIndexPrim arr) + +-- | Throws if the array is empty. +rmaxIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n +rmaxIndexPrim rarr@(Ranked arr) + | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) + = ixrFromIxX (mmaxIndexPrim arr) + +rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a +rdot1Inner arr1 arr2 + | SNat <- rrank arr1 + , Refl <- lemReplicatePlusApp (SNat @n) (Proxy @1) (Proxy @(Nothing @Nat)) + = coerce (mdot1Inner (Proxy @(Nothing @Nat))) arr1 arr2 + +-- | This has a temporary, suboptimal implementation in terms of 'mflatten'. +-- Prefer 'rdot1Inner' if applicable. +rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a +rdot = coerce mdot + +rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) +rtoXArrayPrimP (Ranked arr) = first shrFromShX2 (mtoXArrayPrimP arr) + +rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a) +rtoXArrayPrim (Ranked arr) = first shrFromShX2 (mtoXArrayPrim arr) + +rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) +rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr) + +rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a +rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr) + +rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a +rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr) + +rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a) +rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr) diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs new file mode 100644 index 0000000..babc809 --- /dev/null +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -0,0 +1,268 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_HADDOCK not-home #-} +module Data.Array.Nested.Ranked.Base where + +import Prelude hiding (mappend, mconcat) + +import Control.DeepSeq (NFData(..)) +import Control.Monad.ST +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Foreign.Storable (Storable) +import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) +import GHC.Generics (Generic) +import GHC.TypeLits + +#ifndef OXAR_DEFAULT_SHOW_INSTANCES +import Data.Foldable (toList) +#endif + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types +import Data.Array.Strided.Arith +import Data.Array.XArray (XArray(..)) + + +-- | A rank-typed array: the number of dimensions of the array (its /rank/) is +-- represented on the type level as a 'Nat'. +-- +-- Valid elements of a ranked arrays are described by the 'Elt' type class. +-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are +-- supported (and are represented as a single, flattened, struct-of-arrays +-- array internally). +-- +-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. +type Ranked :: Nat -> Type -> Type +newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a) +#endif +deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a) +deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a) + +#ifndef OXAR_DEFAULT_SHOW_INSTANCES +instance (Show a, Elt a) => Show (Ranked n a) where + showsPrec d arr@(Ranked marr) = + let sh = show (toList (rshape arr)) + in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr +#endif + +instance Elt a => NFData (Ranked n a) where + rnf (Ranked arr) = rnf arr + +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) + deriving (Generic) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a)) +#endif + +deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a)) + +newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) + +-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; +-- these instances allow them to also be used as elements of arrays, thus +-- making them first-class in the API. +instance Elt a => Elt (Ranked n a) where + mshape (M_Ranked arr) = mshape arr + mindex (M_Ranked arr) i = Ranked (mindex arr i) + + mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) + mindexPartial (M_Ranked arr) i = + coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ + mindexPartial arr i + + mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) + + mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) + mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) + + mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] + mtoListOuter (M_Ranked arr) = + coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) + mlift ssh2 f (M_Ranked arr) = + coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ + mlift ssh2 f arr + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) + mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) = + coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ + mlift2 ssh3 f arr1 arr2 + + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a)) + mliftL ssh2 f l = + coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a))) + @(NonEmpty (Mixed sh2 (Ranked n a))) $ + mliftL ssh2 f (coerce l) + + mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr) + + mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) + + mconcat l = M_Ranked (mconcat (coerce l)) + + mrnf (M_Ranked arr) = mrnf arr + + type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) + + mshapeTree (Ranked arr) = first shrFromShX2 (mshapeTree arr) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + marrayStrides (M_Ranked arr) = marrayStrides arr + + mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () + mvecsWrite sh idx (Ranked arr) vecs = + mvecsWrite sh idx arr + (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) + + mvecsWritePartial :: forall sh sh' s. + IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) + -> MixedVecs s (sh ++ sh') (Ranked n a) + -> ST s () + mvecsWritePartial sh idx arr vecs = + mvecsWritePartial sh idx + (coerce @(Mixed sh' (Ranked n a)) + @(Mixed sh' (Mixed (Replicate n Nothing) a)) + arr) + (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) + @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) + vecs) + + mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) + mvecsFreeze sh vecs = + coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) + @(Mixed sh (Ranked n a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh (Ranked n a)) + @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) + +instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where + memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a) + memptyArrayUnsafe i + | Dict <- lemKnownReplicate (SNat @n) + = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ + memptyArrayUnsafe i + + mvecsUnsafeNew idx (Ranked arr) + | Dict <- lemKnownReplicate (SNat @n) + = MV_Ranked <$> mvecsUnsafeNew idx arr + + mvecsNewEmpty _ + | Dict <- lemKnownReplicate (SNat @n) + = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) + + +liftRanked1 :: forall n a b. + (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b) + -> Ranked n a -> Ranked n b +liftRanked1 = coerce + +liftRanked2 :: forall n a b c. + (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c) + -> Ranked n a -> Ranked n b -> Ranked n c +liftRanked2 = coerce + +instance (NumElt a, PrimElt a) => Num (Ranked n a) where + (+) = liftRanked2 (+) + (-) = liftRanked2 (-) + (*) = liftRanked2 (*) + negate = liftRanked1 negate + abs = liftRanked1 abs + signum = liftRanked1 signum + fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal" + +instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where + fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal" + recip = liftRanked1 recip + (/) = liftRanked2 (/) + +instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where + pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal" + exp = liftRanked1 exp + log = liftRanked1 log + sqrt = liftRanked1 sqrt + (**) = liftRanked2 (**) + logBase = liftRanked2 logBase + sin = liftRanked1 sin + cos = liftRanked1 cos + tan = liftRanked1 tan + asin = liftRanked1 asin + acos = liftRanked1 acos + atan = liftRanked1 atan + sinh = liftRanked1 sinh + cosh = liftRanked1 cosh + tanh = liftRanked1 tanh + asinh = liftRanked1 asinh + acosh = liftRanked1 acosh + atanh = liftRanked1 atanh + log1p = liftRanked1 GHC.Float.log1p + expm1 = liftRanked1 GHC.Float.expm1 + log1pexp = liftRanked1 GHC.Float.log1pexp + log1mexp = liftRanked1 GHC.Float.log1mexp + +rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a +rquotArray = liftRanked2 mquotArray +rremArray = liftRanked2 mremArray + +ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a +ratan2Array = liftRanked2 matan2Array + + +rshape :: Elt a => Ranked n a -> IShR n +rshape (Ranked arr) = shrFromShX2 (mshape arr) + +rrank :: Elt a => Ranked n a -> SNat n +rrank = shrRank . rshape + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +shrFromShX :: forall sh. IShX sh -> IShR (Rank sh) +shrFromShX ZSX = ZSR +shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +-- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'. +shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n +shrFromShX2 sh + | Refl <- lemRankReplicate (Proxy @n) + = shrFromShX sh diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs new file mode 100644 index 0000000..8b670e5 --- /dev/null +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -0,0 +1,369 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Ranked.Shape where + +import Control.DeepSeq (NFData(..)) +import Data.Coerce (coerce) +import Data.Foldable qualified as Foldable +import Data.Kind (Type) +import Data.Proxy +import Data.Type.Equality +import GHC.Generics (Generic) +import GHC.IsList (IsList) +import GHC.IsList qualified as IsList +import GHC.TypeLits +import GHC.TypeNats qualified as TN + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Types + + +-- * Ranked lists + +type role ListR nominal representational +type ListR :: Nat -> Type -> Type +data ListR n i where + ZR :: ListR 0 i + (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i +deriving instance Eq i => Eq (ListR n i) +deriving instance Ord i => Ord (ListR n i) +deriving instance Functor (ListR n) +deriving instance Foldable (ListR n) +infixr 3 ::: + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (ListR n i) +#else +instance Show i => Show (ListR n i) where + showsPrec _ = listrShow shows +#endif + +instance NFData i => NFData (ListR n i) where + rnf ZR = () + rnf (x ::: l) = rnf x `seq` rnf l + +data UnconsListRRes i n1 = + forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i +listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) +listrUncons (i ::: sh') = Just (UnconsListRRes sh' i) +listrUncons ZR = Nothing + +-- | This checks only whether the ranks are equal, not whether the actual +-- values are. +listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n') +listrEqRank ZR ZR = Just Refl +listrEqRank (_ ::: sh) (_ ::: sh') + | Just Refl <- listrEqRank sh sh' + = Just Refl +listrEqRank _ _ = Nothing + +-- | This compares the lists for value equality. +listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n') +listrEqual ZR ZR = Just Refl +listrEqual (i ::: sh) (j ::: sh') + | Just Refl <- listrEqual sh sh' + , i == j + = Just Refl +listrEqual _ _ = Nothing + +listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS +listrShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListR n' i -> ShowS + go _ ZR = id + go prefix (x ::: xs) = showString prefix . f x . go "," xs + +listrLength :: ListR n i -> Int +listrLength = length + +listrRank :: ListR n i -> SNat n +listrRank ZR = SNat +listrRank (_ ::: sh) = snatSucc (listrRank sh) + +listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i +listrAppend ZR sh = sh +listrAppend (x ::: xs) sh = x ::: listrAppend xs sh + +listrFromList :: [i] -> (forall n. ListR n i -> r) -> r +listrFromList [] k = k ZR +listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) + +listrHead :: ListR (n + 1) i -> i +listrHead (i ::: _) = i +listrHead ZR = error "unreachable" + +listrTail :: ListR (n + 1) i -> ListR n i +listrTail (_ ::: sh) = sh +listrTail ZR = error "unreachable" + +listrInit :: ListR (n + 1) i -> ListR n i +listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh +listrInit (_ ::: ZR) = ZR +listrInit ZR = error "unreachable" + +listrLast :: ListR (n + 1) i -> i +listrLast (_ ::: sh@(_ ::: _)) = listrLast sh +listrLast (n ::: ZR) = n +listrLast ZR = error "unreachable" + +-- | Performs a runtime check that the lengths are identical. +listrCast :: SNat n' -> ListR n i -> ListR n' i +listrCast = listrCastWithName "listrCast" + +listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i +listrIndex SZ (x ::: _) = x +listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs +listrIndex _ ZR = error "k + 1 <= 0" + +listrZip :: ListR n i -> ListR n j -> ListR n (i, j) +listrZip ZR ZR = ZR +listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest +listrZip _ _ = error "listrZip: impossible pattern needlessly required" + +listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k +listrZipWith _ ZR ZR = ZR +listrZipWith f (i ::: irest) (j ::: jrest) = + f i j ::: listrZipWith f irest jrest +listrZipWith _ _ _ = + error "listrZipWith: impossible pattern needlessly required" + +listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i +listrPermutePrefix = \perm sh -> + listrFromList perm $ \sperm -> + case (listrRank sperm, listrRank sh) of + (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of + LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post + EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post + GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" + ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" + where + listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) + listrSplitAt SZ sh = (ZR, sh) + listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) + listrSplitAt SS{} ZR = error "m' + 1 <= 0" + + applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i + applyPermRFull _ ZR _ = ZR + applyPermRFull sm@SNat (i ::: perm) l = + TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> + case cmpNat (SNat @(idx + 1)) sm of + LTI -> listrIndex si l ::: applyPermRFull sm perm l + EQI -> listrIndex si l ::: applyPermRFull sm perm l + GTI -> error "listrPermutePrefix: Index in permutation out of range" + + +-- * Ranked indices + +-- | An index into a rank-typed array. +type role IxR nominal representational +type IxR :: Nat -> Type -> Type +newtype IxR n i = IxR (ListR n i) + deriving (Eq, Ord, Generic) + deriving newtype (Functor, Foldable) + +pattern ZIR :: forall n i. () => n ~ 0 => IxR n i +pattern ZIR = IxR ZR + +pattern (:.:) + :: forall {n1} {i}. + forall n. (n + 1 ~ n1) + => i -> IxR n i -> IxR n1 i +pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i)) + where i :.: IxR sh = IxR (i ::: sh) +infixr 3 :.: + +{-# COMPLETE ZIR, (:.:) #-} + +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). +type IIxR n = IxR n Int + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (IxR n i) +#else +instance Show i => Show (IxR n i) where + showsPrec _ (IxR l) = listrShow shows l +#endif + +instance NFData i => NFData (IxR sh i) + +ixrLength :: IxR sh i -> Int +ixrLength (IxR l) = listrLength l + +ixrRank :: IxR n i -> SNat n +ixrRank (IxR sh) = listrRank sh + +ixrZero :: SNat n -> IIxR n +ixrZero SZ = ZIR +ixrZero (SS n) = 0 :.: ixrZero n + +ixrHead :: IxR (n + 1) i -> i +ixrHead (IxR list) = listrHead list + +ixrTail :: IxR (n + 1) i -> IxR n i +ixrTail (IxR list) = IxR (listrTail list) + +ixrInit :: IxR (n + 1) i -> IxR n i +ixrInit (IxR list) = IxR (listrInit list) + +ixrLast :: IxR (n + 1) i -> i +ixrLast (IxR list) = listrLast list + +-- | Performs a runtime check that the lengths are identical. +ixrCast :: SNat n' -> IxR n i -> IxR n' i +ixrCast n (IxR idx) = IxR (listrCastWithName "ixrCast" n idx) + +ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i +ixrAppend = coerce (listrAppend @_ @i) + +ixrZip :: IxR n i -> IxR n j -> IxR n (i, j) +ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 + +ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k +ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 + +ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i +ixrPermutePrefix = coerce (listrPermutePrefix @i) + + +-- * Ranked shapes + +type role ShR nominal representational +type ShR :: Nat -> Type -> Type +newtype ShR n i = ShR (ListR n i) + deriving (Eq, Ord, Generic) + deriving newtype (Functor, Foldable) + +pattern ZSR :: forall n i. () => n ~ 0 => ShR n i +pattern ZSR = ShR ZR + +pattern (:$:) + :: forall {n1} {i}. + forall n. (n + 1 ~ n1) + => i -> ShR n i -> ShR n1 i +pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) + where i :$: ShR sh = ShR (i ::: sh) +infixr 3 :$: + +{-# COMPLETE ZSR, (:$:) #-} + +type IShR n = ShR n Int + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (ShR n i) +#else +instance Show i => Show (ShR n i) where + showsPrec _ (ShR l) = listrShow shows l +#endif + +instance NFData i => NFData (ShR sh i) + +-- | This checks only whether the ranks are equal, not whether the actual +-- values are. +shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') +shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh' + +-- | This compares the shapes for value equality. +shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n') +shrEqual (ShR sh) (ShR sh') = listrEqual sh sh' + +shrLength :: ShR sh i -> Int +shrLength (ShR l) = listrLength l + +-- | This function can also be used to conjure up a 'KnownNat' dictionary; +-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern +-- synonym yields 'KnownNat' evidence. +shrRank :: ShR n i -> SNat n +shrRank (ShR sh) = listrRank sh + +-- | The number of elements in an array described by this shape. +shrSize :: IShR n -> Int +shrSize ZSR = 1 +shrSize (n :$: sh) = n * shrSize sh + +shrHead :: ShR (n + 1) i -> i +shrHead (ShR list) = listrHead list + +shrTail :: ShR (n + 1) i -> ShR n i +shrTail (ShR list) = ShR (listrTail list) + +shrInit :: ShR (n + 1) i -> ShR n i +shrInit (ShR list) = ShR (listrInit list) + +shrLast :: ShR (n + 1) i -> i +shrLast (ShR list) = listrLast list + +-- | Performs a runtime check that the lengths are identical. +shrCast :: SNat n' -> ShR n i -> ShR n' i +shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh) + +shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i +shrAppend = coerce (listrAppend @_ @i) + +shrZip :: ShR n i -> ShR n j -> ShR n (i, j) +shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2 + +shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k +shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2 + +shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i +shrPermutePrefix = coerce (listrPermutePrefix @i) + + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (ListR n i) where + type Item (ListR n i) = i + fromList topl = go (SNat @n) topl + where + go :: SNat n' -> [i] -> ListR n' i + go SZ [] = ZR + go (SS n) (i : is) = i ::: go n is + go _ _ = error $ "IsList(ListR): Mismatched list length (type says " + ++ show (fromSNat (SNat @n)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = Foldable.toList + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (IxR n i) where + type Item (IxR n i) = i + fromList = IxR . IsList.fromList + toList = Foldable.toList + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (ShR n i) where + type Item (ShR n i) = i + fromList = ShR . IsList.fromList + toList = Foldable.toList + + +-- * Internal helper functions + +listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i +listrCastWithName _ SZ ZR = ZR +listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx +listrCastWithName name _ _ = error $ name ++ ": ranks don't match" diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs new file mode 100644 index 0000000..198a068 --- /dev/null +++ b/src/Data/Array/Nested/Shaped.hs @@ -0,0 +1,272 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +module Data.Array.Nested.Shaped ( + Shaped(Shaped), + squotArray, sremArray, satan2Array, + sshape, + module Data.Array.Nested.Shaped, + liftShaped1, liftShaped2, +) where + +import Prelude hiding (mappend, mconcat) + +import Data.Array.Internal.RankedG qualified as RG +import Data.Array.Internal.RankedS qualified as RS +import Data.Array.Internal.ShapedG qualified as SG +import Data.Array.Internal.ShapedS qualified as SS +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Foreign.Storable (Storable) +import GHC.TypeLits + +import Data.Array.Nested.Convert +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Shaped.Base +import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types +import Data.Array.Strided.Arith +import Data.Array.XArray (XArray) +import Data.Array.XArray qualified as X + + +semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a +semptyArray sh = Shaped (memptyArray (shxFromShS sh)) + +srank :: Elt a => Shaped sh a -> SNat (Rank sh) +srank = shsRank . sshape + +-- | The total number of elements in the array. +ssize :: Elt a => Shaped sh a -> Int +ssize = shsSize . sshape + +sindex :: Elt a => Shaped sh a -> IIxS sh -> a +sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx) + +shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh +shsTakeIx _ _ ZIS = ZSS +shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx + +sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a +sindexPartial sarr@(Shaped arr) idx = + Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) + (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) + (ixxFromIxS idx)) + +-- | __WARNING__: All values returned from the function must have equal shape. +-- See the documentation of 'mgenerate' for more details. +sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a +sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh)) + +-- | See the documentation of 'mlift'. +slift :: forall sh1 sh2 a. Elt a + => ShS sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) + -> Shaped sh1 a -> Shaped sh2 a +slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShX (shxFromShS sh2)) f arr) + +-- | See the documentation of 'mlift'. +slift2 :: forall sh1 sh2 sh3 a. Elt a + => ShS sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) + -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a +slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2) + +ssumOuter1P :: forall sh n a. (Storable a, NumElt a) + => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) +ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) + +ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) + => Shaped (n : sh) a -> Shaped sh a +ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive + +ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a +ssumAllPrim (Shaped arr) = msumAllPrim arr + +stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a) + => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a +stranspose perm sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + , Refl <- lemTakeLenMapJust perm (sshape sarr) + , Refl <- lemDropLenMapJust perm (sshape sarr) + , Refl <- lemPermuteMapJust perm (shsTakeLen perm (sshape sarr)) + , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh)) + = Shaped (mtranspose perm arr) + +sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a +sappend = coerce mappend + +sscalar :: Elt a => a -> Shaped '[] a +sscalar x = Shaped (mscalar x) + +sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) +sfromVectorP sh v = Shaped (mfromVectorP (shxFromShS sh) v) + +sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a +sfromVector sh v = sfromPrimitive (sfromVectorP sh v) + +stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a +stoVectorP = coerce mtoVectorP + +stoVector :: PrimElt a => Shaped sh a -> VS.Vector a +stoVector = coerce mtoVector + +sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a +sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1 + +sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a +sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l)) + +sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a +sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l) + +sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a +sfromListPrim sn l + | Refl <- lemAppNil @'[Just n] + = let ssh = SUnknown () :!% ZKX + xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) + in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr + +sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a +sfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in Shaped $ fromPrimitive $ M_Primitive (shxFromShS sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShS sh) xarr) + +stoList :: Elt a => Shaped '[n] a -> [a] +stoList = map sunScalar . stoListOuter + +stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] +stoListOuter (Shaped arr) = coerce (mtoListOuter arr) + +stoListLinear :: Elt a => Shaped sh a -> [a] +stoListLinear (Shaped arr) = mtoListLinear arr + +sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a +sfromOrthotope sh (SS.A (SG.A arr)) = + Shaped (fromPrimitive (M_Primitive (shxFromShS sh) (X.XArray (RS.A (RG.A (shsToList sh) arr))))) + +stoOrthotope :: PrimElt a => Shaped sh a -> SS.Array sh a +stoOrthotope (stoPrimitive -> Shaped (M_Primitive _ (X.XArray (RS.A (RG.A _ arr))))) = SS.A (SG.A arr) + +sunScalar :: Elt a => Shaped '[] a -> a +sunScalar arr = sindex arr ZIS + +snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a) +snest sh arr + | Refl <- lemMapJustApp sh (Proxy @sh') + = coerce (mnest (ssxFromShX (shxFromShS sh)) (coerce arr)) + +sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a +sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr))) + | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh') + = Shaped arr + +szip :: (Elt a, Elt b) => Shaped sh a -> Shaped sh b -> Shaped sh (a, b) +szip = coerce mzip + +sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b) +sunzip = coerce munzip + +srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) + => ShS sh -> ShS sh2 + -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) + -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b) +srerankP sh sh2 f sarr@(Shaped arr) + | Refl <- lemMapJustApp sh (Proxy @sh1) + , Refl <- lemMapJustApp sh (Proxy @sh2) + = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (ssxFromShX (shxFromShS sh)) (shxFromShS (sshape sarr)))) + (shxFromShS sh2) + (\a -> let Shaped r = f (Shaped a) in r) + arr) + +-- | See the caveats at 'Data.Array.XArray.rerank'. +srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) + => ShS sh -> ShS sh2 + -> (Shaped sh1 a -> Shaped sh2 b) + -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b +srerank sh sh2 f (stoPrimitive -> arr) = + sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr + +sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a +sreplicate sh (Shaped arr) + | Refl <- lemMapJustApp sh (Proxy @sh') + = Shaped (mreplicate (shxFromShS sh) arr) + +sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) +sreplicateScalP sh x = Shaped (mreplicateScalP (shxFromShS sh) x) + +sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a +sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) + +sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a +sslice i n@SNat arr = + let _ :$$ sh = sshape arr + in slift (n :$$ sh) (\_ -> X.slice i n) arr + +srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a +srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr + +sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shaped sh' a +sreshape sh' (Shaped arr) = Shaped (mreshape (shxFromShS sh') arr) + +sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a +sflatten arr = + case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff + n@SNat -> sreshape (n :$$ ZSS) arr + +siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a +siota sn = Shaped (miota sn) + +-- | Throws if the array is empty. +sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh +sminIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mminIndexPrim arr) + +-- | Throws if the array is empty. +smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh +smaxIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) + +sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) + => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a +sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) + | Refl <- lemInitApp (Proxy @sh) (Proxy @n) + , Refl <- lemLastApp (Proxy @sh) (Proxy @n) + = case sshape sarr1 of + _ :$$ _ + | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n]) + -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2) + _ -> error "unreachable" + +-- | This has a temporary, suboptimal implementation in terms of 'mflatten'. +-- Prefer 'sdot1Inner' if applicable. +sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a +sdot = coerce mdot + +stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) +stoXArrayPrimP (Shaped arr) = first shsFromShX (mtoXArrayPrimP arr) + +stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) +stoXArrayPrim (Shaped arr) = first shsFromShX (mtoXArrayPrim arr) + +sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) +sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShX (shxFromShS sh)) arr) + +sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a +sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShX (shxFromShS sh)) arr) + +sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a +sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) + +stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a) +stoPrimitive (Shaped arr) = Shaped (toPrimitive arr) diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs new file mode 100644 index 0000000..879e6b5 --- /dev/null +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -0,0 +1,255 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_HADDOCK not-home #-} +module Data.Array.Nested.Shaped.Base where + +import Prelude hiding (mappend, mconcat) + +import Control.DeepSeq (NFData(..)) +import Control.Monad.ST +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Foreign.Storable (Storable) +import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types +import Data.Array.Strided.Arith +import Data.Array.XArray (XArray) + + +-- | A shape-typed array: the full shape of the array (the sizes of its +-- dimensions) is represented on the type level as a list of 'Nat's. Note that +-- these are "GHC.TypeLits" naturals, because we do not need induction over +-- them and we want very large arrays to be possible. +-- +-- Like for 'Ranked', the valid elements are described by the 'Elt' type class, +-- and 'Shaped' itself is again an instance of 'Elt' as well. +-- +-- 'Shaped' is a newtype around a 'Mixed' of 'Just's. +type Shaped :: [Nat] -> Type -> Type +newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a) +#endif +deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) +deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a) + +#ifndef OXAR_DEFAULT_SHOW_INSTANCES +instance (Show a, Elt a) => Show (Shaped n a) where + showsPrec d arr@(Shaped marr) = + let sh = show (shsToList (sshape arr)) + in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr +#endif + +instance Elt a => NFData (Shaped sh a) where + rnf (Shaped arr) = rnf arr + +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a)) + deriving (Generic) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (Mixed sh (Mixed (MapJust sh') a)) => Show (Mixed sh (Shaped sh' a)) +#endif + +deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a)) + +newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) + +instance Elt a => Elt (Shaped sh a) where + mshape (M_Shaped arr) = mshape arr + mindex (M_Shaped arr) i = Shaped (mindex arr i) + + mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + mindexPartial (M_Shaped arr) i = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mindexPartial arr i + + mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) + + mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) + mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) + + mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] + mtoListOuter (M_Shaped arr) + = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) + mlift ssh2 f (M_Shaped arr) = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mlift ssh2 f arr + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) + mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) = + coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ + mlift2 ssh3 f arr1 arr2 + + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a)) + mliftL ssh2 f l = + coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a))) + @(NonEmpty (Mixed sh2 (Shaped sh a))) $ + mliftL ssh2 f (coerce l) + + mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr) + + mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) + + mconcat l = M_Shaped (mconcat (coerce l)) + + mrnf (M_Shaped arr) = mrnf arr + + type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) + + mshapeTree (Shaped arr) = first shsFromShX (mshapeTree arr) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + marrayStrides (M_Shaped arr) = marrayStrides arr + + mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWrite sh idx (Shaped arr) vecs = + mvecsWrite sh idx arr + (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + + mvecsWritePartial :: forall sh1 sh2 s. + IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) + -> ST s () + mvecsWritePartial sh idx arr vecs = + mvecsWritePartial sh idx + (coerce @(Mixed sh2 (Shaped sh a)) + @(Mixed sh2 (Mixed (MapJust sh) a)) + arr) + (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) + @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) + vecs) + + mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsFreeze sh vecs = + coerce @(Mixed sh' (Mixed (MapJust sh) a)) + @(Mixed sh' (Shaped sh a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh' (Shaped sh a)) + @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + +instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where + memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) + memptyArrayUnsafe i + | Dict <- lemKnownMapJust (Proxy @sh) + = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ + memptyArrayUnsafe i + + mvecsUnsafeNew idx (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsUnsafeNew idx arr + + mvecsNewEmpty _ + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) + + +liftShaped1 :: forall sh a b. + (Mixed (MapJust sh) a -> Mixed (MapJust sh) b) + -> Shaped sh a -> Shaped sh b +liftShaped1 = coerce + +liftShaped2 :: forall sh a b c. + (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c) + -> Shaped sh a -> Shaped sh b -> Shaped sh c +liftShaped2 = coerce + +instance (NumElt a, PrimElt a) => Num (Shaped sh a) where + (+) = liftShaped2 (+) + (-) = liftShaped2 (-) + (*) = liftShaped2 (*) + negate = liftShaped1 negate + abs = liftShaped1 abs + signum = liftShaped1 signum + fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal" + +instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where + fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" + recip = liftShaped1 recip + (/) = liftShaped2 (/) + +instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where + pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" + exp = liftShaped1 exp + log = liftShaped1 log + sqrt = liftShaped1 sqrt + (**) = liftShaped2 (**) + logBase = liftShaped2 logBase + sin = liftShaped1 sin + cos = liftShaped1 cos + tan = liftShaped1 tan + asin = liftShaped1 asin + acos = liftShaped1 acos + atan = liftShaped1 atan + sinh = liftShaped1 sinh + cosh = liftShaped1 cosh + tanh = liftShaped1 tanh + asinh = liftShaped1 asinh + acosh = liftShaped1 acosh + atanh = liftShaped1 atanh + log1p = liftShaped1 GHC.Float.log1p + expm1 = liftShaped1 GHC.Float.expm1 + log1pexp = liftShaped1 GHC.Float.log1pexp + log1mexp = liftShaped1 GHC.Float.log1mexp + +squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a +squotArray = liftShaped2 mquotArray +sremArray = liftShaped2 mremArray + +satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a +satan2Array = liftShaped2 matan2Array + + +sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh +sshape (Shaped arr) = shsFromShX (mshape arr) + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh +shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS +shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) = + castWith (subst1 (sym (lemMapJustCons Refl))) $ + n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) + idx) +shsFromShX (SUnknown _ :$% _) = error "impossible" diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs new file mode 100644 index 0000000..5f9ba79 --- /dev/null +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -0,0 +1,425 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Shaped.Shape where + +import Control.DeepSeq (NFData(..)) +import Data.Array.Shape qualified as O +import Data.Coerce (coerce) +import Data.Foldable qualified as Foldable +import Data.Functor.Const +import Data.Functor.Product qualified as Fun +import Data.Kind (Constraint, Type) +import Data.Monoid (Sum(..)) +import Data.Proxy +import Data.Type.Equality +import GHC.Exts (withDict) +import GHC.Generics (Generic) +import GHC.IsList (IsList) +import GHC.IsList qualified as IsList +import GHC.TypeLits + +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Types + + +-- * Shaped lists + +-- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be +-- removed in a future release. +type role ListS nominal representational +type ListS :: [Nat] -> (Nat -> Type) -> Type +data ListS sh f where + ZS :: ListS '[] f + -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity + (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f +deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) +deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) +infixr 3 ::$ + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance (forall n. Show (f n)) => Show (ListS sh f) +#else +instance (forall n. Show (f n)) => Show (ListS sh f) where + showsPrec _ = listsShow shows +#endif + +instance (forall m. NFData (f m)) => NFData (ListS n f) where + rnf ZS = () + rnf (x ::$ l) = rnf x `seq` rnf l + +data UnconsListSRes f sh1 = + forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) +listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) +listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x) +listsUncons ZS = Nothing + +-- | This checks only whether the types are equal; if the elements of the list +-- are not singletons, their values may still differ. This corresponds to +-- 'testEquality', except on the penultimate type parameter. +listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') +listsEqType ZS ZS = Just Refl +listsEqType (n ::$ sh) (m ::$ sh') + | Just Refl <- testEquality n m + , Just Refl <- listsEqType sh sh' + = Just Refl +listsEqType _ _ = Nothing + +-- | This checks whether the two lists actually contain equal values. This is +-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ +-- in the @some@ package (except on the penultimate type parameter). +listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') +listsEqual ZS ZS = Just Refl +listsEqual (n ::$ sh) (m ::$ sh') + | Just Refl <- testEquality n m + , n == m + , Just Refl <- listsEqual sh sh' + = Just Refl +listsEqual _ _ = Nothing + +listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g +listsFmap _ ZS = ZS +listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs + +listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m +listsFold _ ZS = mempty +listsFold f (x ::$ xs) = f x <> listsFold f xs + +listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS +listsShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListS sh' f -> ShowS + go _ ZS = id + go prefix (x ::$ xs) = showString prefix . f x . go "," xs + +listsLength :: ListS sh f -> Int +listsLength = getSum . listsFold (\_ -> Sum 1) + +listsRank :: ListS sh f -> SNat (Rank sh) +listsRank ZS = SNat +listsRank (_ ::$ sh) = snatSucc (listsRank sh) + +listsToList :: ListS sh (Const i) -> [i] +listsToList ZS = [] +listsToList (Const i ::$ is) = i : listsToList is + +listsHead :: ListS (n : sh) f -> f n +listsHead (i ::$ _) = i + +listsTail :: ListS (n : sh) f -> ListS sh f +listsTail (_ ::$ sh) = sh + +listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f +listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh +listsInit (_ ::$ ZS) = ZS + +listsLast :: ListS (n : sh) f -> f (Last (n : sh)) +listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh +listsLast (n ::$ ZS) = n + +listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f +listsAppend ZS idx' = idx' +listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' + +listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g) +listsZip ZS ZS = ZS +listsZip (i ::$ is) (j ::$ js) = + Fun.Pair i j ::$ listsZip is js + +listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g + -> ListS sh h +listsZipWith _ ZS ZS = ZS +listsZipWith f (i ::$ is) (j ::$ js) = + f i j ::$ listsZipWith f is js + +listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f +listsTakeLenPerm PNil _ = ZS +listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh +listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" + +listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f +listsDropLenPerm PNil sh = sh +listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh +listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" + +listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f +listsPermute PNil _ = ZS +listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = + case listsIndex (Proxy @is') (Proxy @sh) i sh of + (item, SNat) -> item ::$ listsPermute is sh + +-- TODO: remove this SNat when the KnownNat constaint in ListS is removed +listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh)) +listsIndex _ _ SZ (n ::$ _) = (n, SNat) +listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) + | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = listsIndex p pT i sh +listsIndex _ _ _ ZS = error "Index into empty shape" + +listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f +listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) + +-- * Shaped indices + +-- | An index into a shape-typed array. +type role IxS nominal representational +type IxS :: [Nat] -> Type -> Type +newtype IxS sh i = IxS (ListS sh (Const i)) + deriving (Eq, Ord, Generic) + +pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i +pattern ZIS = IxS ZS + +-- | Note: The 'KnownNat' constraint on '(:.$)' is deprecated and should be +-- removed in a future release. +pattern (:.$) + :: forall {sh1} {i}. + forall n sh. (KnownNat n, n : sh ~ sh1) + => i -> IxS sh i -> IxS sh1 i +pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i))) + where i :.$ IxS shl = IxS (Const i ::$ shl) +infixr 3 :.$ + +{-# COMPLETE ZIS, (:.$) #-} + +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). +type IIxS sh = IxS sh Int + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (IxS sh i) +#else +instance Show i => Show (IxS sh i) where + showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l +#endif + +instance Functor (IxS sh) where + fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l) + +instance Foldable (IxS sh) where + foldMap f (IxS l) = listsFold (f . getConst) l + +instance NFData i => NFData (IxS sh i) + +ixsLength :: IxS sh i -> Int +ixsLength (IxS l) = listsLength l + +ixsRank :: IxS sh i -> SNat (Rank sh) +ixsRank (IxS l) = listsRank l + +ixsZero :: ShS sh -> IIxS sh +ixsZero ZSS = ZIS +ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh + +ixsHead :: IxS (n : sh) i -> i +ixsHead (IxS list) = getConst (listsHead list) + +ixsTail :: IxS (n : sh) i -> IxS sh i +ixsTail (IxS list) = IxS (listsTail list) + +ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i +ixsInit (IxS list) = IxS (listsInit list) + +ixsLast :: IxS (n : sh) i -> i +ixsLast (IxS list) = getConst (listsLast list) + +-- TODO: this takes a ShS because there are KnownNats inside IxS. +ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i +ixsCast ZSS ZIS = ZIS +ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx +ixsCast _ _ = error "ixsCast: ranks don't match" + +ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i +ixsAppend = coerce (listsAppend @_ @(Const i)) + +ixsZip :: IxS n i -> IxS n j -> IxS n (i, j) +ixsZip ZIS ZIS = ZIS +ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js + +ixsZipWith :: (i -> j -> k) -> IxS n i -> IxS n j -> IxS n k +ixsZipWith _ ZIS ZIS = ZIS +ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js + +ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i +ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) + + +-- * Shaped shapes + +-- | The shape of a shape-typed array given as a list of 'SNat' values. +-- +-- Note that because the shape of a shape-typed array is known statically, you +-- can also retrieve the array shape from a 'KnownShS' dictionary. +type role ShS nominal +type ShS :: [Nat] -> Type +newtype ShS sh = ShS (ListS sh SNat) + deriving (Eq, Ord, Generic) + +pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh +pattern ZSS = ShS ZS + +pattern (:$$) + :: forall {sh1}. + forall n sh. (KnownNat n, n : sh ~ sh1) + => SNat n -> ShS sh -> ShS sh1 +pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i)) + where i :$$ ShS shl = ShS (i ::$ shl) + +infixr 3 :$$ + +{-# COMPLETE ZSS, (:$$) #-} + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (ShS sh) +#else +instance Show (ShS sh) where + showsPrec _ (ShS l) = listsShow (shows . fromSNat) l +#endif + +instance NFData (ShS sh) where + rnf (ShS ZS) = () + rnf (ShS (SNat ::$ l)) = rnf (ShS l) + +instance TestEquality ShS where + testEquality (ShS l1) (ShS l2) = listsEqType l1 l2 + +-- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are +-- equal if and only if values are equal.) +shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh') +shsEqual = testEquality + +shsLength :: ShS sh -> Int +shsLength (ShS l) = listsLength l + +shsRank :: ShS sh -> SNat (Rank sh) +shsRank (ShS l) = listsRank l + +shsSize :: ShS sh -> Int +shsSize ZSS = 1 +shsSize (n :$$ sh) = fromSNat' n * shsSize sh + +shsToList :: ShS sh -> [Int] +shsToList ZSS = [] +shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh + +shsHead :: ShS (n : sh) -> SNat n +shsHead (ShS list) = listsHead list + +shsTail :: ShS (n : sh) -> ShS sh +shsTail (ShS list) = ShS (listsTail list) + +shsInit :: ShS (n : sh) -> ShS (Init (n : sh)) +shsInit (ShS list) = ShS (listsInit list) + +shsLast :: ShS (n : sh) -> SNat (Last (n : sh)) +shsLast (ShS list) = listsLast list + +shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh') +shsAppend = coerce (listsAppend @_ @SNat) + +shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh) +shsTakeLen = coerce (listsTakeLenPerm @SNat) + +shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) +shsPermute = coerce (listsPermute @SNat) + +shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh) +shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh))) + +shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) +shsPermutePrefix = coerce (listsPermutePrefix @SNat) + +type family Product sh where + Product '[] = 1 + Product (n : ns) = n * Product ns + +shsProduct :: ShS sh -> SNat (Product sh) +shsProduct ZSS = SNat +shsProduct (n :$$ sh) = n `snatMul` shsProduct sh + +-- | Evidence for the static part of a shape. This pops up only when you are +-- polymorphic in the element type of an array. +type KnownShS :: [Nat] -> Constraint +class KnownShS sh where knownShS :: ShS sh +instance KnownShS '[] where knownShS = ZSS +instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS + +withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r +withKnownShS = withDict @(KnownShS sh) + +shsKnownShS :: ShS sh -> Dict KnownShS sh +shsKnownShS ZSS = Dict +shsKnownShS (SNat :$$ sh) | Dict <- shsKnownShS sh = Dict + +shsOrthotopeShape :: ShS sh -> Dict O.Shape sh +shsOrthotopeShape ZSS = Dict +shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict + +-- | This function is a hack made possible by the 'KnownNat' inside 'ListS'. +-- This function may be removed in a future release. +shsFromListS :: ListS sh f -> ShS sh +shsFromListS ZS = ZSS +shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l + +-- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This +-- function may be removed in a future release. +shsFromIxS :: IxS sh i -> ShS sh +shsFromIxS (IxS l) = shsFromListS l + + +-- | Untyped: length is checked at runtime. +instance KnownShS sh => IsList (ListS sh (Const i)) where + type Item (ListS sh (Const i)) = i + fromList topl = go (knownShS @sh) topl + where + go :: ShS sh' -> [i] -> ListS sh' (Const i) + go ZSS [] = ZS + go (_ :$$ sh) (i : is) = Const i ::$ go sh is + go _ _ = error $ "IsList(ListS): Mismatched list length (type says " + ++ show (shsLength (knownShS @sh)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = listsToList + +-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. +instance KnownShS sh => IsList (IxS sh i) where + type Item (IxS sh i) = i + fromList = IxS . IsList.fromList + toList = Foldable.toList + +-- | Untyped: length and values are checked at runtime. +instance KnownShS sh => IsList (ShS sh) where + type Item (ShS sh) = Int + fromList topl = ShS (go (knownShS @sh) topl) + where + go :: ShS sh' -> [Int] -> ListS sh' SNat + go ZSS [] = ZS + go (sn :$$ sh) (i : is) + | i == fromSNat' sn = sn ::$ go sh is + | otherwise = error $ "IsList(ShS): Value does not match typing (type says " + ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" + go _ _ = error $ "IsList(ShS): Mismatched list length (type says " + ++ show (shsLength (knownShS @sh)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = shsToList diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs new file mode 100644 index 0000000..8a29aa5 --- /dev/null +++ b/src/Data/Array/Nested/Trace.hs @@ -0,0 +1,72 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ExplicitNamespaces #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TemplateHaskell #-} +{-| +This module is API-compatible with "Data.Array.Nested", except that inputs and +outputs of the methods are traced using 'Debug.Trace.trace'. Thus the methods +also have additional 'Show' constraints. + +>>> let res = rtranspose [1, 0] (rreshape (2 :$: 3 :$: ZSR) (riota @Int 6)) * rreshape (3 :$: 2 :$: ZSR) (rreplicate (6 :$: ZSR) (rscalar @Int 7)) +>>> length (show res) `seq` () +oxtrace: riota [Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [0,1,2,3,4,5]))))] +oxtrace: rreshape [[2,3], Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [0,1,2,3,4,5])))), Ranked (M_Int (M_Primitive [2,3] (XArray (fromList [2,3] [0,1,2,3,4,5]))))] +oxtrace: rtranspose [Ranked (M_Int (M_Primitive [2,3] (XArray (fromList [2,3] [0,1,2,3,4,5])))), Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [0,3,1,4,2,5]))))] +oxtrace: rscalar [Ranked (M_Int (M_Primitive [] (XArray (fromList [] [7]))))] +oxtrace: rreplicate [[6], Ranked (M_Int (M_Primitive [] (XArray (fromList [] [7])))), Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [7,7,7,7,7,7]))))] +oxtrace: rreshape [[3,2], Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [7,7,7,7,7,7])))), Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [7,7,7,7,7,7]))))] +>>> res +Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [0,21,7,28,14,35])))) +-} +module Data.Array.Nested.Trace ( + -- * Traced variants + module Data.Array.Nested.Trace, + + -- * Re-exports from the plain "Data.Array.Nested" module + Ranked(Ranked), + ListR(ZR, (:::)), + IxR(..), IIxR, + ShR(..), IShR, + + Shaped(Shaped), + ListS(ZS, (::$)), + IxS(..), IIxS, + ShS(..), KnownShS(..), + + Mixed, + ListX(ZX, (::%)), + IxX(..), IIxX, + ShX(..), KnownShX(..), IShX, + StaticShX(..), + SMayNat(..), + Conversion(..), + + Elt, + PrimElt, + Primitive(..), + KnownElt, + + type (++), + Storable, + SNat, pattern SNat, + pattern SZ, pattern SS, + Perm(..), + IsPermutation, + KnownPerm(..), + NumElt, IntElt, FloatElt, + Rank, Product, + Replicate, + MapJust, +) where + +import Prelude hiding (mappend, mconcat) + +import Data.Array.Nested +import Data.Array.Nested.Trace.TH + + +$(concat <$> mapM convertFun + ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromList1, 'rfromListOuter, 'rfromListLinear, 'rfromListPrim, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromList1, 'sfromListOuter, 'sfromListLinear, 'sfromListPrim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromList1, 'mfromListOuter, 'mfromListLinear, 'mfromListPrim, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array]) diff --git a/src/Data/Array/Nested/Trace/TH.hs b/src/Data/Array/Nested/Trace/TH.hs new file mode 100644 index 0000000..4b388e3 --- /dev/null +++ b/src/Data/Array/Nested/Trace/TH.hs @@ -0,0 +1,98 @@ +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TemplateHaskellQuotes #-} +module Data.Array.Nested.Trace.TH where + +import Control.Monad (zipWithM) +import Data.List (foldl', intersperse) +import Data.Maybe (isJust) +import Language.Haskell.TH hiding (cxt) + +import Debug.Trace qualified as Debug + +import Data.Array.Nested + + +splitFunTy :: Type -> ([TyVarBndr Specificity], Cxt, [Type], Type) +splitFunTy = \case + ArrowT `AppT` t1 `AppT` t2 -> + let (vars, cx, args, ret) = splitFunTy t2 + in (vars, cx, t1 : args, ret) + ForallT vs cx' t -> + let (vars, cx, args, ret) = splitFunTy t + in (vars ++ vs, cx ++ cx', args, ret) + t -> ([], [], [], t) + +data Arg = RRanked Type Arg + | RShaped Type Arg + | RMixed Type Arg + | RShowable Type + | ROther Type + deriving (Show) + +-- TODO: always returns Just +recognise :: Type -> Maybe Arg +recognise (ConT name `AppT` sht `AppT` ty) + | name == ''Ranked = RRanked sht <$> recognise ty + | name == ''Shaped = RShaped sht <$> recognise ty + | name == ''Mixed = RMixed sht <$> recognise ty +recognise ty@(ConT name `AppT` _) + | name `elem` [''IShR, ''IIxR, ''ShS, ''IIxS, ''SNat] = + Just (RShowable ty) +recognise _ = Nothing + +realise :: Arg -> Type +realise (RRanked sht ty) = ConT ''Ranked `AppT` sht `AppT` realise ty +realise (RShaped sht ty) = ConT ''Shaped `AppT` sht `AppT` realise ty +realise (RMixed sht ty) = ConT ''Mixed `AppT` sht `AppT` realise ty +realise (RShowable ty) = ty +realise (ROther ty) = ty + +mkShow :: Arg -> Cxt +mkShow (RRanked _ ty) = mkShowElt ty +mkShow (RShaped _ ty) = mkShowElt ty +mkShow (RMixed sht ty) = [ConT ''Show `AppT` realise (RMixed sht ty)] +mkShow (RShowable _) = [] +mkShow (ROther ty) = [ConT ''Show `AppT` ty] + +mkShowElt :: Arg -> Cxt +mkShowElt (RRanked _ ty) = mkShowElt ty +mkShowElt (RShaped _ ty) = mkShowElt ty +mkShowElt (RMixed sht ty) = [ConT ''Show `AppT` realise (RMixed sht ty), ConT ''Elt `AppT` realise (RMixed sht ty)] +mkShowElt (RShowable _ty) = [] -- [ConT ''Elt `AppT` ty] +mkShowElt (ROther ty) = [ConT ''Show `AppT` ty, ConT ''Elt `AppT` ty] + +convertType :: Type -> Q (Type, [Bool], Bool) +convertType typ = + let (tybndrs, cxt, args, ret) = splitFunTy typ + argrels = map recognise args + retrel = recognise ret + in return + (ForallT tybndrs + (cxt ++ [constr + | Just rel <- retrel : argrels + , constr <- mkShow rel]) + (foldr (\a b -> ArrowT `AppT` a `AppT` b) ret args) + ,map isJust argrels + ,isJust retrel) + +convertFun :: Name -> Q [Dec] +convertFun funname = do + defname <- newName (nameBase funname) + (convty, argarrs, retarr) <- reifyType funname >>= convertType + names <- zipWithM (\b i -> newName ((if b then "t" else "x") ++ show i)) argarrs [1::Int ..] + resname <- newName "res" + let tracenames = map fst (filter snd (zip (names ++ [resname]) (argarrs ++ [retarr]))) + let ex = LetE [ValD (VarP resname) + (NormalB (foldl' AppE (VarE funname) (map VarE names))) + []] + (VarE 'Debug.trace + `AppE` (VarE 'concat `AppE` ListE + ([LitE (StringL ("oxtrace: " ++ nameBase funname ++ " ["))] ++ + intersperse (LitE (StringL ", ")) + (map (\n -> VarE 'show `AppE` VarE n) tracenames) ++ + [LitE (StringL "]")])) + `AppE` VarE resname) + return + [SigD defname convty + ,FunD defname [Clause (map VarP names) (NormalB ex) []]] diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs new file mode 100644 index 0000000..4444acd --- /dev/null +++ b/src/Data/Array/Nested/Types.hs @@ -0,0 +1,152 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilyDependencies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Types ( + -- * Reasoning helpers + subst1, subst2, + + -- * Reified evidence of a type class + Dict(..), + + -- * Type-level naturals + pattern SZ, pattern SS, + fromSNat', sameNat', + snatPlus, snatMinus, snatMul, + snatSucc, + + -- * Type-level lists + type (++), + Replicate, + lemReplicateSucc, + MapJust, + lemMapJustEmpty, lemMapJustCons, + Head, + Tail, + Init, + Last, + + -- * Unsafe + unsafeCoerceRefl, +) where + +import Data.Proxy +import Data.Type.Equality +import GHC.TypeLits +import GHC.TypeNats qualified as TN +import Unsafe.Coerce qualified + + +-- Reasoning helpers + +subst1 :: forall f a b. a :~: b -> f a :~: f b +subst1 Refl = Refl + +subst2 :: forall f c a b. a :~: b -> f a c :~: f b c +subst2 Refl = Refl + +-- | Evidence for the constraint @c a@. +data Dict c a where + Dict :: c a => Dict c a + +fromSNat' :: SNat n -> Int +fromSNat' = fromIntegral . fromSNat + +sameNat' :: SNat n -> SNat m -> Maybe (n :~: m) +sameNat' n@SNat m@SNat = sameNat n m + +pattern SZ :: () => (n ~ 0) => SNat n +pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) + where SZ = SNat + +pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 +pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) + where SS = snatSucc + +{-# COMPLETE SZ, SS #-} + +snatSucc :: SNat n -> SNat (n + 1) +snatSucc SNat = SNat + +data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) +snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) +snatPred snp1 = + withKnownNat snp1 $ + case cmpNat (Proxy @1) (Proxy @np1) of + LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + GTI -> Nothing + +-- This should be a function in base +snatPlus :: SNat n -> SNat m -> SNat (n + m) +snatPlus n m = TN.withSomeSNat (TN.fromSNat n + TN.fromSNat m) Unsafe.Coerce.unsafeCoerce + +-- This should be a function in base +snatMinus :: SNat n -> SNat m -> SNat (n - m) +snatMinus n m = let res = TN.fromSNat n - TN.fromSNat m in res `seq` TN.withSomeSNat res Unsafe.Coerce.unsafeCoerce + +-- This should be a function in base +snatMul :: SNat n -> SNat m -> SNat (n * m) +snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce + + +-- | Type-level list append. +type family l1 ++ l2 where + '[] ++ l2 = l2 + (x : xs) ++ l2 = x : xs ++ l2 + +type family Replicate n a where + Replicate 0 a = '[] + Replicate n a = a : Replicate (n - 1) a + +lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a +lemReplicateSucc = unsafeCoerceRefl + +type family MapJust l = r | r -> l where + MapJust '[] = '[] + MapJust (x : xs) = Just x : MapJust xs + +lemMapJustEmpty :: MapJust sh :~: '[] -> sh :~: '[] +lemMapJustEmpty Refl = unsafeCoerceRefl + +lemMapJustCons :: MapJust sh :~: Just n : sh' -> sh :~: n : Tail sh +lemMapJustCons Refl = unsafeCoerceRefl + +type family Head l where + Head (x : _) = x + +type family Tail l where + Tail (_ : xs) = xs + +type family Init l where + Init (x : y : xs) = x : Init (y : xs) + Init '[x] = '[] + +type family Last l where + Last (x : y : xs) = Last (y : xs) + Last '[x] = x + + +-- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to +-- only typecheck for actual type equalities. One cannot, e.g. accidentally +-- write this: +-- +-- @ +-- foo :: Proxy a -> Proxy b -> a :~: b +-- foo = unsafeCoerceRefl +-- @ +-- +-- which would have been permitted with normal 'Unsafe.Coerce.unsafeCoerce', +-- but would have resulted in interesting memory errors at runtime. +unsafeCoerceRefl :: a :~: b +unsafeCoerceRefl = Unsafe.Coerce.unsafeCoerce Refl diff --git a/src/Data/Array/Strided/Orthotope.hs b/src/Data/Array/Strided/Orthotope.hs new file mode 100644 index 0000000..5c38d14 --- /dev/null +++ b/src/Data/Array/Strided/Orthotope.hs @@ -0,0 +1,43 @@ +{-# LANGUAGE ImportQualifiedPost #-} +module Data.Array.Strided.Orthotope ( + module Data.Array.Strided.Orthotope, + module Data.Array.Strided.Arith, +) where + +import Data.Array.Internal qualified as OI +import Data.Array.Internal.RankedG qualified as RG +import Data.Array.Internal.RankedS qualified as RS + +import Data.Array.Strided qualified as AS +import Data.Array.Strided.Arith + +-- for liftVEltwise1 +import Data.Array.Strided.Arith.Internal (stridesDense) +import Data.Vector.Storable qualified as VS +import Foreign.Storable +import GHC.TypeLits + + +fromO :: RS.Array n a -> AS.Array n a +fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset vec + +toO :: AS.Array n a -> RS.Array n a +toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec)) + +liftO1 :: (AS.Array n a -> AS.Array n' b) + -> RS.Array n a -> RS.Array n' b +liftO1 f = toO . f . fromO + +liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c) + -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c +liftO2 f x y = toO (f (fromO x) (fromO y)) + +liftVEltwise1 :: (Storable a, Storable b) + => SNat n + -> (VS.Vector a -> VS.Vector b) + -> RS.Array n a -> RS.Array n b +liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) + | Just (blockOff, blockSz) <- stridesDense sh offset strides = + let vec' = f (VS.slice blockOff blockSz vec) + in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) + | otherwise = RS.fromVector sh (f (RS.toVector arr)) diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs new file mode 100644 index 0000000..bf47622 --- /dev/null +++ b/src/Data/Array/XArray.hs @@ -0,0 +1,348 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.XArray where + +import Control.DeepSeq (NFData) +import Data.Array.Internal qualified as OI +import Data.Array.Internal.RankedG qualified as ORG +import Data.Array.Internal.RankedS qualified as ORS +import Data.Array.Ranked qualified as ORB +import Data.Array.RankedS qualified as S +import Data.Coerce +import Data.Foldable (toList) +import Data.Kind +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Data.Type.Ord +import Data.Vector.Storable qualified as VS +import Foreign.Storable (Storable) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Types +import Data.Array.Strided.Orthotope + + +type XArray :: [Maybe Nat] -> Type -> Type +newtype XArray sh a = XArray (S.Array (Rank sh) a) + deriving (Show, Eq, Ord, Generic) + +instance NFData (XArray sh a) + + +shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh +shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) + where + go :: StaticShX sh' -> [Int] -> IShX sh' + go ZKX [] = ZSX + go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l + go _ _ = error "Invalid shapeL" + +fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a +fromVector sh v + | Dict <- lemKnownNatRank sh + = XArray (S.fromVector (shxToList sh) v) + +toVector :: Storable a => XArray sh a -> VS.Vector a +toVector (XArray arr) = S.toVector arr + +-- | This allows observing the strides in the underlying orthotope array. This +-- can be useful for optimisation, but should be considered an implementation +-- detail: strides may change in new versions of this library without notice. +arrayStrides :: XArray sh a -> [Int] +arrayStrides (XArray (ORS.A (ORG.A _ (OI.T strides _ _)))) = strides + +scalar :: Storable a => a -> XArray '[] a +scalar = XArray . S.scalar + +-- | Will throw if the array does not have the casted-to shape. +cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> IShX sh2 -> StaticShX sh' + -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a +cast ssh1 sh2 ssh' (XArray arr) + | Refl <- lemRankApp ssh1 ssh' + , Refl <- lemRankApp (ssxFromShX sh2) ssh' + = let arrsh :: IShX sh1 + (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) + in if shxToList arrsh == shxToList sh2 + then XArray arr + else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" + +unScalar :: Storable a => XArray '[] a -> a +unScalar (XArray a) = S.unScalar a + +replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a +replicate sh ssh' (XArray arr) + | Dict <- lemKnownNatRankSSX ssh' + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShX sh) ssh') + , Refl <- lemRankApp (ssxFromShX sh) ssh' + = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $ + S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr) + arr) + +replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a +replicateScal sh x + | Dict <- lemKnownNatRank sh + = XArray (S.constant (shxToList sh) x) + +generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a +generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh) + +-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) +-- generateM sh f | Dict <- lemKnownNatRank sh = +-- XArray . S.fromVector (shxShapeL sh) +-- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh) + +indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a +indexPartial (XArray arr) ZIX = XArray arr +indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx + +index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a +index xarr i + | Refl <- lemAppNil @sh + = let XArray arr' = indexPartial xarr i :: XArray '[] a + in S.unScalar arr' + +append :: forall n m sh a. Storable a + => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a +append ssh (XArray a) (XArray b) + | Dict <- lemKnownNatRankSSX ssh + = XArray (S.append a b) + +-- | All arrays must have the same shape, except possibly for the outermost +-- dimension. +concat :: Storable a + => StaticShX sh -> NonEmpty (XArray (Nothing : sh) a) -> XArray (Nothing : sh) a +concat ssh l + | Dict <- lemKnownNatRankSSX ssh + = XArray (S.concatOuter (coerce (toList l))) + +-- | If the prefix of the shape of the input array (@sh@) is empty (i.e. +-- contains a zero), then there is no way to deduce the full shape of the output +-- array (more precisely, the @sh2@ part): that could only come from calling +-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in +-- this case; we choose to fill the shape with zeros wherever we cannot deduce +-- what it should be. +-- +-- For example, if: +-- +-- @ +-- arr :: XArray '[Just 3, Just 0, Just 4, Just 2, Nothing] Int -- of shape [3, 0, 4, 2, 21] +-- f :: XArray '[Just 2, Nothing] Int -> XArray '[Just 5, Nothing, Just 17] Float +-- @ +-- +-- then: +-- +-- @ +-- rerank _ _ _ f arr :: XArray '[Just 3, Just 0, Just 4, Just 5, Nothing, Just 17] Float +-- @ +-- +-- and this result will have shape @[3, 0, 4, 5, 0, 17]@. Note the second @0@ +-- in this shape: we don't know if @f@ intended to return an array with shape 0 +-- here (it probably didn't), but there is no better number to put here absent +-- a subarray of the input to pass to @f@. +-- +-- In this particular case the fact that @sh@ is empty was evident from the +-- type-level information, but the same situation occurs when @sh@ consists of +-- @Nothing@s, and some of those happen to be zero at runtime. +rerank :: forall sh sh1 sh2 a b. + (Storable a, Storable b) + => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 + -> (XArray sh1 a -> XArray sh2 b) + -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b +rerank ssh ssh1 ssh2 f xarr@(XArray arr) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) + = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) + in if 0 `elem` shxToList sh + then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) + else case () of + () | Dict <- lemKnownNatRankSSX ssh + , Dict <- lemKnownNatRankSSX ssh2 + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + -> XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2) + (\a -> let XArray r = f (XArray a) in r) + arr) + +rerankTop :: forall sh1 sh2 sh a b. + (Storable a, Storable b) + => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh + -> (XArray sh1 a -> XArray sh2 b) + -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b +rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh + +-- | The caveat about empty arrays at @rerank@ applies here too. +rerank2 :: forall sh sh1 sh2 a b c. + (Storable a, Storable b, Storable c) + => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 + -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) + -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c +rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) + = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) + in if 0 `elem` shxToList sh + then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) + else case () of + () | Dict <- lemKnownNatRankSSX ssh + , Dict <- lemKnownNatRankSSX ssh2 + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + -> XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2) + (\a b -> let XArray r = f (XArray a) (XArray b) in r) + arr1 arr2) + +-- | The list argument gives indices into the original dimension list. +transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh) + => StaticShX sh + -> Perm is + -> XArray sh a + -> XArray (PermutePrefix is sh) a +transpose ssh perm (XArray arr) + | Dict <- lemKnownNatRankSSX ssh + , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) + , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm + , Refl <- lemRankDropLen ssh perm + = XArray (S.transpose (permToList' perm) arr) + +-- | The list argument gives indices into the original dimension list. +-- +-- The permutation (the list) must have length <= @n@. If it is longer, this +-- function throws. +transposeUntyped :: forall n sh a. + SNat n -> StaticShX sh -> [Int] + -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a +transposeUntyped sn ssh perm (XArray arr) + | length perm <= fromSNat' sn + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh) + = XArray (S.transpose perm arr) + | otherwise + = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type" + +transpose2 :: forall sh1 sh2 a. + StaticShX sh1 -> StaticShX sh2 + -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a +transpose2 ssh1 ssh2 (XArray arr) + | Refl <- lemRankApp ssh1 ssh2 + , Refl <- lemRankApp ssh2 ssh1 + , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) + , Refl <- lemRankAppComm ssh1 ssh2 + , let n1 = ssxLength ssh1 + = XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr) + +sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a +sumFull _ (XArray arr) = + S.unScalar $ + liftO1 (numEltSum1Inner (SNat @0)) $ + S.fromVector [product (S.shapeL arr)] $ + S.toVector arr + +sumInner :: forall sh sh' a. (Storable a, NumElt a) + => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a +sumInner ssh ssh' arr + | Refl <- lemAppNil @sh + = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + sh'F = shxFlatten sh' :$% ZSX + ssh'F = ssxFromShX sh'F + + go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a + go (XArray arr') + | Refl <- lemRankApp ssh ssh'F + , let sn = listxRank (let StaticShX l = ssh in l) + = XArray (liftO1 (numEltSum1Inner sn) arr') + + in go $ + transpose2 ssh'F ssh $ + reshapePartial ssh' ssh sh'F $ + transpose2 ssh ssh' $ + arr + +sumOuter :: forall sh sh' a. (Storable a, NumElt a) + => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a +sumOuter ssh ssh' arr + | Refl <- lemAppNil @sh + = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + shF = shxFlatten sh :$% ZSX + in sumInner ssh' (ssxFromShX shF) $ + transpose2 (ssxFromShX shF) ssh' $ + reshapePartial ssh ssh' shF $ + arr + +fromListOuter :: forall n sh a. Storable a + => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a +fromListOuter ssh l + | Dict <- lemKnownNatRankSSX ssh + = case ssh of + SKnown m :!% _ | fromSNat' m /= length l -> + error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++ + "does not match the type (" ++ show (fromSNat' m) ++ ")" + _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) + +toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a] +toListOuter (XArray arr) = + case S.shapeL arr of + 0 : _ -> [] + _ -> coerce (ORB.toList (S.unravel arr)) + +fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a +fromList1 ssh l = + let n = length l + in case ssh of + SKnown m :!% _ | fromSNat' m /= n -> + error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++ + "does not match the type (" ++ show (fromSNat' m) ++ ")" + _ -> XArray (S.fromVector [n] (VS.fromListN n l)) + +toList1 :: Storable a => XArray '[n] a -> [a] +toList1 (XArray arr) = S.toList arr + +-- | Throws if the given shape is not, in fact, empty. +empty :: forall sh a. Storable a => IShX sh -> XArray sh a +empty sh + | Dict <- lemKnownNatRank sh + , shxSize sh == 0 + = XArray (S.fromVector (shxToList sh) VS.empty) + | otherwise + = error $ "Data.Array.Mixed.empty: shape was not empty: " ++ show sh + +slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a +slice i n (XArray arr) = XArray (S.slice [(fromSNat' i, fromSNat' n)] arr) + +sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a +sliceU i n (XArray arr) = XArray (S.slice [(i, n)] arr) + +rev1 :: XArray (n : sh) a -> XArray (n : sh) a +rev1 (XArray arr) = XArray (S.rev [0] arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a +reshape ssh1 sh2 (XArray arr) + | Dict <- lemKnownNatRankSSX ssh1 + , Dict <- lemKnownNatRank sh2 + = XArray (S.reshape (shxToList sh2) arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a +reshapePartial ssh1 ssh' sh2 (XArray arr) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShX sh2) ssh') + = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr) + +-- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo). +iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a +iota sn = XArray (S.fromVector [fromSNat' sn] (VS.fromListN (fromSNat' sn) [toEnum 0 .. toEnum (fromSNat' sn - 1)])) diff --git a/src/Data/Bag.hs b/src/Data/Bag.hs new file mode 100644 index 0000000..b424857 --- /dev/null +++ b/src/Data/Bag.hs @@ -0,0 +1,18 @@ +{-# LANGUAGE DeriveTraversable #-} +module Data.Bag where + + +-- | An ordered sequence that can be folded over. +data Bag a = BZero | BOne a | BTwo (Bag a) (Bag a) | BList [Bag a] + deriving (Show, Functor, Foldable, Traversable) + +-- Really only here for 'pure' +instance Applicative Bag where + pure = BOne + BZero <*> _ = BZero + BOne f <*> t = f <$> t + BTwo f1 f2 <*> t = BTwo (f1 <*> t) (f2 <*> t) + BList fs <*> t = BList [f <*> t | f <- fs] + +instance Semigroup (Bag a) where (<>) = BTwo +instance Monoid (Bag a) where mempty = BZero diff --git a/src/Data/INat.hs b/src/Data/INat.hs deleted file mode 100644 index af8f18b..0000000 --- a/src/Data/INat.hs +++ /dev/null @@ -1,121 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeAbstractions #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.INat where - -import Data.Proxy -import Data.Type.Equality ((:~:) (Refl)) -import Numeric.Natural -import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce) - --- | Evidence for the constraint @c a@. -data Dict c a where - Dict :: c a => Dict c a - --- | An inductive peano natural number. Intended to be used at the type level. -data INat = Z | S INat - deriving (Show) - --- | Singleton for a 'INat'. -data SINat n where - SZ :: SINat Z - SS :: SINat n -> SINat (S n) -deriving instance Show (SINat n) - --- | A singleton 'SINat' corresponding to @n@. -class KnownINat n where inatSing :: SINat n -instance KnownINat Z where inatSing = SZ -instance KnownINat n => KnownINat (S n) where inatSing = SS inatSing - --- | Explicitly bidirectional pattern synonym that converts between a singleton --- 'SINat' and evidence of a 'KnownINat' constraint. Analogous to 'GHC.SNat'. -pattern SINat' :: () => KnownINat n => SINat n -pattern SINat' <- (snatKnown -> Dict) - where SINat' = inatSing - --- | A 'KnownINat' dictionary is just a singleton natural, so we can create --- evidence of 'KnownINat' given an 'SINat'. -snatKnown :: SINat n -> Dict KnownINat n -snatKnown SZ = Dict -snatKnown (SS n) | Dict <- snatKnown n = Dict - --- | Convert a 'INat' to a normal number. -fromINat :: INat -> Natural -fromINat Z = 0 -fromINat (S n) = 1 + fromINat n - --- | Convert an 'SINat' to a normal number. -fromSINat :: SINat n -> Natural -fromSINat SZ = 0 -fromSINat (SS n) = 1 + fromSINat n - --- | The value of a known inductive natural as a value-level integer. -inatVal :: forall n. KnownINat n => Proxy n -> Natural -inatVal _ = fromSINat (inatSing @n) - --- | Add two 'INat's -type family n +! m where - Z +! m = m - S n +! m = S (n +! m) - --- | Convert a 'INat' to a "GHC.TypeLits" 'G.Nat'. -type family FromINat n where - FromINat Z = 0 - FromINat (S n) = 1 + FromINat n - --- | Convert a "GHC.TypeLits" 'G.Nat' to a 'INat'. -type family ToINat (n :: Nat) where - ToINat 0 = Z - ToINat n = S (ToINat (n - 1)) - -lemInjectiveFromINat :: n :~: ToINat (FromINat n) -lemInjectiveFromINat = unsafeCoerce Refl - -lemSuccFromINat :: Proxy n -> 1 + FromINat n :~: FromINat (S n) -lemSuccFromINat _ = unsafeCoerce Refl - -lemAddFromINat :: Proxy m -> Proxy n - -> FromINat m + FromINat n :~: FromINat (m +! n) -lemAddFromINat _ = unsafeCoerce Refl - -lemInjectiveToINat :: n :~: FromINat (ToINat n) -lemInjectiveToINat = unsafeCoerce Refl - -lemSuccToINat :: Proxy n -> ToINat (1 + n) :~: S (ToINat n) -lemSuccToINat _ = unsafeCoerce Refl - -lemAddToINat :: Proxy m -> Proxy n -> ToINat (m + n) :~: ToINat m +! ToINat n -lemAddToINat _ _ = unsafeCoerce Refl - --- | If an inductive 'INat' is known, then the corresponding "GHC.TypeLits" --- 'G.Nat' is also known. -knownNatFromINat :: KnownINat n => Proxy n -> Dict KnownNat (FromINat n) -knownNatFromINat (Proxy @n) = go (SINat' @n) - where - go :: SINat m -> Dict KnownNat (FromINat m) - go SZ = Dict - go (SS n) | Dict <- go n = Dict - --- * Some type-level inductive naturals - -type I0 = Z -type I1 = S I0 -type I2 = S I1 -type I3 = S I2 -type I4 = S I3 -type I5 = S I4 -type I6 = S I5 -type I7 = S I6 -type I8 = S I7 -type I9 = S I8 diff --git a/test/Gen.hs b/test/Gen.hs new file mode 100644 index 0000000..044de14 --- /dev/null +++ b/test/Gen.hs @@ -0,0 +1,174 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Gen where + +import Data.ByteString qualified as BS +import Data.Foldable (toList) +import Data.Type.Equality +import Data.Type.Ord +import Data.Vector.Storable qualified as VS +import Foreign +import GHC.TypeLits +import GHC.TypeNats qualified as TN + +import Data.Array.Nested +import Data.Array.Nested.Permutation +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types + +import Hedgehog +import Hedgehog.Gen qualified as Gen +import Hedgehog.Range qualified as Range +import System.Random qualified as Random + +import Util + + +-- | Generates zero with small probability, because there's typically only one +-- interesting case for 0 anyway. +genRank :: Monad m => (forall n. SNat n -> PropertyT m ()) -> PropertyT m () +genRank k = do + rank <- forAll $ Gen.frequency [(1, return 0) + ,(49, Gen.int (Range.linear 1 8))] + TN.withSomeSNat (fromIntegral rank) k + +genLowBiased :: RealFloat a => (a, a) -> Gen a +genLowBiased (lo, hi) = do + x <- Gen.realFloat (Range.linearFrac 0 1) + return (lo + x * x * x * (hi - lo)) + +shuffleShR :: IShR n -> Gen (IShR n) +shuffleShR = \sh -> go (length sh) (toList sh) sh + where + go :: Int -> [Int] -> IShR n -> Gen (IShR n) + go _ _ ZSR = return ZSR + go nbag bag (_ :$: sh) = do + idx <- Gen.int (Range.linear 0 (nbag - 1)) + let (dim, bag') = case splitAt idx bag of + (pre, n : post) -> (n, pre ++ post) + _ -> error "unreachable" + (dim :$:) <$> go (nbag - 1) bag' sh + +genShR :: SNat n -> Gen (IShR n) +genShR = genShRwithTarget 100_000 + +genShRwithTarget :: Int -> SNat n -> Gen (IShR n) +genShRwithTarget targetMax sn = do + let n = fromSNat' sn + targetSize <- Gen.int (Range.linear 0 targetMax) + let genDims :: SNat m -> Int -> Gen (IShR m) + genDims SZ _ = return ZSR + genDims (SS m) 0 = do + dim <- Gen.int (Range.linear 0 20) + dims <- genDims m 0 + return (dim :$: dims) + genDims (SS m) tgt = do + dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt)))) + ,(2 , return tgt) + ,(4 , return 1) + ,(1 , return 0)] + dims <- genDims m (if dim == 0 then 0 else tgt `div` dim) + return (dim :$: dims) + dims <- genDims sn targetSize + let dimsL = toList dims + maxdim = maximum dimsL + cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize) + shuffleShR (min cap <$> dims) + +-- | Example: given 3 and 7, might return: +-- +-- @ +-- ([ 13, 4, 27 ] +-- ,[1, 13, 1, 1, 4, 27, 1] +-- ,[4, 13, 1, 3, 4, 27, 2]) +-- @ +-- +-- The up-replicated dimensions are always nonzero and not very large, but the +-- other dimensions might be zero. +genReplicatedShR :: m <= n => SNat m -> SNat n -> Gen (IShR m, IShR n, IShR n) +genReplicatedShR = \m n -> do + let expectedSizeIncrease = round (repvalavg ^ (fromSNat' n - fromSNat' m)) + sh1 <- genShRwithTarget (1_000_000 `div` expectedSizeIncrease) m + (sh2, sh3) <- injectOnes n sh1 sh1 + return (sh1, sh2, sh3) + where + repvalrange = (1::Int, 5) + repvalavg = let (lo, hi) = repvalrange in fromIntegral (lo + hi) / 2 :: Double + + injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n) + injectOnes n@SNat shOnes sh + | m@SNat <- shrRank sh + = case cmpNat n m of + LTI -> error "unreachable" + EQI -> return (shOnes, sh) + GTI -> do + index <- Gen.int (Range.linear 0 (fromSNat' m)) + value <- Gen.int (uncurry Range.linear repvalrange) + Refl <- return (lem n m) + injectOnes n (inject index 1 shOnes) (inject index value sh) + + lem :: forall n m proxy. n > m => proxy n -> proxy m -> (m + 1 <=? n) :~: True + lem _ _ = unsafeCoerceRefl + + inject :: Int -> Int -> IShR m -> IShR (m + 1) + inject 0 v sh = v :$: sh + inject i v (w :$: sh) = w :$: inject (i - 1) v sh + inject _ _ ZSR = error "unreachable" + +genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a) +genStorables rng f = do + n <- Gen.int rng + seed <- Gen.resize 99 $ Gen.int Range.linearBounded + let gen0 = Random.mkStdGen seed + (bs, _) = Random.uniformByteString (8 * n) gen0 + let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]]) + return $ VS.generate n (f . readW64) + +genStaticShX :: Monad m => SNat n -> (forall sh. Rank sh ~ n => StaticShX sh -> PropertyT m ()) -> PropertyT m () +genStaticShX = \n k -> case n of + SZ -> k ZKX + SS n' -> + genItem $ \item -> + genStaticShX n' $ \ssh -> + k (item :!% ssh) + where + genItem :: Monad m => (forall n. SMayNat () SNat n -> PropertyT m ()) -> PropertyT m () + genItem k = do + b <- forAll Gen.bool + if b + then do + n <- forAll $ Gen.frequency [(20, Gen.int (Range.linear 1 4)) + ,(1, return 0)] + TN.withSomeSNat (fromIntegral n) $ \sn -> k (SKnown sn) + else k (SUnknown ()) + +genShX :: StaticShX sh -> Gen (IShX sh) +genShX ZKX = return ZSX +genShX (SKnown sn :!% ssh) = (SKnown sn :$%) <$> genShX ssh +genShX (SUnknown () :!% ssh) = do + dim <- Gen.int (Range.linear 1 4) + (SUnknown dim :$%) <$> genShX ssh + +genPermR :: Int -> Gen PermR +genPermR n = Gen.shuffle [0 .. n-1] + +genPerm :: Monad m => SNat n -> (forall p. (IsPermutation p, Rank p ~ n) => Perm p -> PropertyT m r) -> PropertyT m r +genPerm n@SNat k = do + list <- forAll $ genPermR (fromSNat' n) + permFromList list $ \perm -> do + case permCheckPermutation perm $ + case sameNat' (permRank perm) n of + Just Refl -> Just (k perm) + Nothing -> Nothing + of + Just (Just act) -> act + _ -> error "" diff --git a/test/Main.hs b/test/Main.hs index 2363813..575bb15 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,29 +1,15 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ImportQualifiedPost #-} module Main where -import Data.Array.Nested +import Test.Tasty +import Tests.C qualified +import Tests.Permutation qualified -arr :: Ranked I2 (Shaped [2, 3] (Double, Int)) -arr = rgenerate (3 :$: 4 :$: ZSR) $ \(i :.: j :.: ZIR) -> - sgenerate @[2, 3] $ \(k :.$ l :.$ ZIS) -> - let s = 24*i + 6*j + 3*k + l - in (fromIntegral s, s) - -foo :: (Double, Int) -foo = arr `rindex` (2 :.: 1 :.: ZIR) `sindex` (1 :.$ 1 :.$ ZIS) - -bad :: Ranked I2 (Ranked I1 Double) -bad = rgenerate (3 :$: 4 :$: ZSR) $ \(i :.: j :.: ZIR) -> - rgenerate (i :$: ZSR) $ \(k :.: ZIR) -> - let s = 24*i + 6*j + 3*k - in fromIntegral s main :: IO () -main = do - print arr - print foo - print (rtranspose [1,0] arr) - -- print bad +main = defaultMain $ + testGroup "Tests" + [Tests.C.tests + ,Tests.Permutation.tests + ] diff --git a/test/Tests/C.hs b/test/Tests/C.hs new file mode 100644 index 0000000..9567393 --- /dev/null +++ b/test/Tests/C.hs @@ -0,0 +1,160 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Tests.C where + +import Control.Monad +import Data.Array.RankedS qualified as OR +import Data.Foldable (toList) +import Data.Functor.Const +import Data.Type.Equality +import Foreign +import GHC.TypeLits + +import Data.Array.Nested +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types (fromSNat') + +import Hedgehog +import Hedgehog.Gen qualified as Gen +import Hedgehog.Internal.Property (LabelName(..), forAllT) +import Hedgehog.Range qualified as Range +import Test.Tasty +import Test.Tasty.Hedgehog + +-- import Debug.Trace + +import Gen +import Util + + +-- | Appropriate for simple different summation orders +fineTol :: Double +fineTol = 1e-8 + +debugCoverage :: Bool +debugCoverage = False + +prop_sum_nonempty :: Property +prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do + -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. + let inrank = SNat @(n + 1) + sh <- forAll $ genShR inrank + -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) + guard (all (> 0) (shrTail sh)) -- only constrain the tail + arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> + genStorables (Range.singleton (product sh)) + (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) + let rarr = rfromOrthotope inrank arr + almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr) + +prop_sum_empty :: Property +prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do + -- We only need to test shapes where the _result_ is empty; the rest is handled by 'random nonempty' above. + _outrank :: SNat n <- return $ SNat @(nm1 + 1) + let inrank = SNat @(n + 1) + sh <- forAll $ do + shtt <- genShR outrankm1 -- nm1 + sht <- shuffleShR (0 :$: shtt) -- n + n <- Gen.int (Range.linear 0 20) + return (n :$: sht) -- n + 1 + guard (0 `elem` shrTail sh) + -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) + let arr = OR.fromList @(n + 1) @Double (toList sh) [] + let rarr = rfromOrthotope inrank arr + OR.toList (rtoOrthotope (rsumOuter1 rarr)) === [] + +prop_sum_lasteq1 :: Property +prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do + let inrank = SNat @(n + 1) + outsh <- forAll $ genShR outrank + guard (all (> 0) outsh) + let insh = shrAppend outsh (1 :$: ZSR) + arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$> + genStorables (Range.singleton (product insh)) + (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) + let rarr = rfromOrthotope inrank arr + almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr) + +prop_sum_replicated :: Bool -> Property +prop_sum_replicated doTranspose = property $ + genRank $ \inrank1@(SNat @m) -> + genRank $ \outrank@(SNat @nm1) -> do + inrank2 :: SNat n <- return $ SNat @(nm1 + 1) + (Refl :: (m <=? n) :~: True) <- case cmpNat inrank1 inrank2 of + LTI -> return Refl -- actually we only continue if m < n + _ -> discard + (sh1, sh2, sh3) <- forAll $ genReplicatedShR inrank1 inrank2 + when debugCoverage $ do + label (LabelName ("rankdiff " ++ show (fromSNat' inrank2 - fromSNat' inrank1))) + label (LabelName ("size sh1 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh1) :: Double)) :: Int))) + label (LabelName ("size sh3 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh3) :: Double)) :: Int))) + guard (all (> 0) sh3) + arr <- forAllT $ + OR.stretch (toList sh3) + . OR.reshape (toList sh2) + . OR.fromVector @Double @m (toList sh1) <$> + genStorables (Range.singleton (product sh1)) + (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) + arrTrans <- + if doTranspose then do perm <- forAll $ genPermR (fromSNat' inrank2) + return $ OR.transpose perm arr + else return arr + let rarr = rfromOrthotope inrank2 arrTrans + almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans) + +prop_negate_with :: forall f b. Show b + => ((forall n. f n -> SNat n -> PropertyT IO ()) -> PropertyT IO ()) + -> (forall n. f n -> IShR n -> Gen b) + -> (forall n. f n -> b -> OR.Array n Double -> OR.Array n Double) + -> Property +prop_negate_with genRank' genB preproc = property $ + genRank' $ \extra rank@(SNat @n) -> do + sh <- forAll $ genShR rank + guard (all (> 0) sh) + arr <- forAllT $ OR.fromVector @Double @n (toList sh) <$> + genStorables (Range.singleton (product sh)) + (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) + bval <- forAll $ genB extra sh + let arr' = preproc extra bval arr + annotate (show (OR.shapeL arr')) + let rarr = rfromOrthotope rank arr' + rtoOrthotope (negate rarr) === OR.mapA negate arr' + +tests :: TestTree +tests = testGroup "C" + [testGroup "sum" + [testProperty "nonempty" prop_sum_nonempty + ,testProperty "empty" prop_sum_empty + ,testProperty "last==1" prop_sum_lasteq1 + ,testProperty "replicated" (prop_sum_replicated False) + ,testProperty "replicated_transposed" (prop_sum_replicated True) + ] + ,testGroup "negate" + [testProperty "normalised" $ prop_negate_with + (\k -> genRank (k (Const ()))) + (\_ _ -> pure ()) + (\_ _ -> id) + ,testProperty "slice 1D" $ prop_negate_with @((:~:) 1) + (\k -> k Refl (SNat @1)) + (\Refl (n :$: _) -> do lo <- Gen.integral (Range.constant 0 (n-1)) + len <- Gen.integral (Range.constant 0 (n-lo)) + return [(lo, len)]) + (\_ -> OR.slice) + ,testProperty "slice nD" $ prop_negate_with + (\k -> genRank (k (Const ()))) + (\_ sh -> do let genPair n = do lo <- Gen.integral (Range.constant 0 (n-1)) + len <- Gen.integral (Range.constant 0 (n-lo-1)) + return (lo, len) + pairs <- mapM genPair (toList sh) + return pairs) + (\_ -> OR.slice) + ] + ] diff --git a/test/Tests/Permutation.hs b/test/Tests/Permutation.hs new file mode 100644 index 0000000..98a6da5 --- /dev/null +++ b/test/Tests/Permutation.hs @@ -0,0 +1,39 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Tests.Permutation where + +import Data.Type.Equality + +import Data.Array.Nested.Permutation + +import Hedgehog +import Hedgehog.Gen qualified as Gen +import Hedgehog.Range qualified as Range +import Test.Tasty +import Test.Tasty.Hedgehog + +-- import Debug.Trace + +import Gen + + +tests :: TestTree +tests = testGroup "Permutation" + [testProperty "permCheckPermutation" $ property $ do + n <- forAll $ Gen.int (Range.linear 0 10) + list <- forAll $ genPermR n + let r = permFromList list $ \perm -> + permCheckPermutation perm () + case r of + Just () -> return () + Nothing -> failure + ,testProperty "permInverse" $ property $ + genRank $ \n -> + genPerm n $ \perm -> + genStaticShX n $ \ssh -> + permInverse perm $ \_invperm proof -> + case proof ssh of + Refl -> return () + ] diff --git a/test/Util.hs b/test/Util.hs new file mode 100644 index 0000000..8a5ba72 --- /dev/null +++ b/test/Util.hs @@ -0,0 +1,51 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Util where + +import Data.Array.RankedS qualified as OR +import Data.Kind +import GHC.TypeLits +import Hedgehog +import Hedgehog.Internal.Property (failDiff) + +import Data.Array.Nested.Types (fromSNat') + + +-- Returns highest value that satisfies the predicate, or `lo` if none does +binarySearch :: (Num a, Eq a) => (a -> a) -> a -> a -> (a -> Bool) -> a +binarySearch div2 = \lo hi f -> case (f lo, f hi) of + (False, _) -> lo + (_, True) -> hi + (_, _ ) -> go lo hi f + where + go lo hi f = -- invariant: f lo && not (f hi) + let mid = lo + div2 (hi - lo) + in if mid `elem` [lo, hi] + then mid + else if f mid then go mid hi f else go lo mid f + +orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a +orSumOuter1 (sn@SNat :: SNat n) = + let n = fromSNat' sn + in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0]) + +class AlmostEq f where + type AlmostEqConstr f :: Type -> Constraint + -- | absolute tolerance, lhs, rhs + almostEq :: (AlmostEqConstr f a, Ord a, Show a, Fractional a, MonadTest m) + => a -> f a -> f a -> m () + +instance AlmostEq (OR.Array n) where + type AlmostEqConstr (OR.Array n) = OR.Unbox + almostEq atol lhs rhs + | OR.allA (< atol) (OR.zipWithA (\a b -> abs (a - b)) rhs lhs) = + success + | otherwise = + failDiff lhs rhs |