diff options
-rw-r--r-- | .stylish-haskell.yaml | 452 | ||||
-rw-r--r-- | bench/Main.hs | 2 | ||||
-rw-r--r-- | ox-arrays.cabal | 15 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Lemmas.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Permutation.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/Mixed/XArray.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/Nested.hs | 5 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Convert.hs | 4 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Lemmas.hs | 4 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 4 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 4 | ||||
-rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs (renamed from src/Data/Array/Mixed/Shape.hs) | 2 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 363 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs (renamed from src/Data/Array/Nested/Internal/Shape.hs) | 325 | ||||
-rw-r--r-- | test/Gen.hs | 3 | ||||
-rw-r--r-- | test/Tests/C.hs | 2 |
17 files changed, 846 insertions, 347 deletions
diff --git a/.stylish-haskell.yaml b/.stylish-haskell.yaml new file mode 100644 index 0000000..bfd25ea --- /dev/null +++ b/.stylish-haskell.yaml @@ -0,0 +1,452 @@ +# stylish-haskell configuration file +# ================================== + +# The stylish-haskell tool is mainly configured by specifying steps. These steps +# are a list, so they have an order, and one specific step may appear more than +# once (if needed). Each file is processed by these steps in the given order. +steps: + # Convert some ASCII sequences to their Unicode equivalents. This is disabled + # by default. + # - unicode_syntax: + # # In order to make this work, we also need to insert the UnicodeSyntax + # # language pragma. If this flag is set to true, we insert it when it's + # # not already present. You may want to disable it if you configure + # # language extensions using some other method than pragmas. Default: + # # true. + # add_language_pragma: true + + # Format module header + # + # Currently, this option is not configurable and will format all exports and + # module declarations to minimize diffs + # + # - module_header: + # # How many spaces use for indentation in the module header. + # indent: 4 + # + # # Should export lists be sorted? Sorting is only performed within the + # # export section, as delineated by Haddock comments. + # sort: true + # + # # See `separate_lists` for the `imports` step. + # separate_lists: true + + # Format record definitions. This is disabled by default. + # + # You can control the layout of record fields. The only rules that can't be configured + # are these: + # + # - "|" is always aligned with "=" + # - "," in fields is always aligned with "{" + # - "}" is likewise always aligned with "{" + # + # - records: + # # How to format equals sign between type constructor and data constructor. + # # Possible values: + # # - "same_line" -- leave "=" AND data constructor on the same line as the type constructor. + # # - "indent N" -- insert a new line and N spaces from the beginning of the next line. + # equals: "indent 2" + # + # # How to format first field of each record constructor. + # # Possible values: + # # - "same_line" -- "{" and first field goes on the same line as the data constructor. + # # - "indent N" -- insert a new line and N spaces from the beginning of the data constructor + # first_field: "indent 2" + # + # # How many spaces to insert between the column with "," and the beginning of the comment in the next line. + # field_comment: 2 + # + # # How many spaces to insert before "deriving" clause. Deriving clauses are always on separate lines. + # deriving: 2 + # + # # How many spaces to insert before "via" clause counted from indentation of deriving clause + # # Possible values: + # # - "same_line" -- "via" part goes on the same line as "deriving" keyword. + # # - "indent N" -- insert a new line and N spaces from the beginning of "deriving" keyword. + # via: "indent 2" + # + # # Sort typeclass names in the "deriving" list alphabetically. + # sort_deriving: true + # + # # Wheter or not to break enums onto several lines + # # + # # Default: false + # break_enums: false + # + # # Whether or not to break single constructor data types before `=` sign + # # + # # Default: true + # break_single_constructors: true + # + # # Whether or not to curry constraints on function. + # # + # # E.g: @allValues :: Enum a => Bounded a => Proxy a -> [a]@ + # # + # # Instead of @allValues :: (Enum a, Bounded a) => Proxy a -> [a]@ + # # + # # Default: false + # curried_context: false + + # Align the right hand side of some elements. This is quite conservative + # and only applies to statements where each element occupies a single + # line. + # Possible values: + # - always - Always align statements. + # - adjacent - Align statements that are on adjacent lines in groups. + # - never - Never align statements. + # All default to always. + - simple_align: + cases: never + top_level_patterns: never + records: never + multi_way_if: never + + # Import cleanup + - imports: + # There are different ways we can align names and lists. + # + # - global: Align the import names and import list throughout the entire + # file. + # + # - file: Like global, but don't add padding when there are no qualified + # imports in the file. + # + # - group: Only align the imports per group (a group is formed by adjacent + # import lines). + # + # - none: Do not perform any alignment. + # + # Default: global. + align: group + + # The following options affect only import list alignment. + # + # List align has following options: + # + # - after_alias: Import list is aligned with end of import including + # 'as' and 'hiding' keywords. + # + # > import qualified Data.List as List (concat, foldl, foldr, head, + # > init, last, length) + # + # - with_alias: Import list is aligned with start of alias or hiding. + # + # > import qualified Data.List as List (concat, foldl, foldr, head, + # > init, last, length) + # + # - with_module_name: Import list is aligned `list_padding` spaces after + # the module name. + # + # > import qualified Data.List as List (concat, foldl, foldr, head, + # init, last, length) + # + # This is mainly intended for use with `pad_module_names: false`. + # + # > import qualified Data.List as List (concat, foldl, foldr, head, + # init, last, length, scanl, scanr, take, drop, + # sort, nub) + # + # - new_line: Import list starts always on new line. + # + # > import qualified Data.List as List + # > (concat, foldl, foldr, head, init, last, length) + # + # - repeat: Repeat the module name to align the import list. + # + # > import qualified Data.List as List (concat, foldl, foldr, head) + # > import qualified Data.List as List (init, last, length) + # + # Default: after_alias + list_align: after_alias + + # Right-pad the module names to align imports in a group: + # + # - true: a little more readable + # + # > import qualified Data.List as List (concat, foldl, foldr, + # > init, last, length) + # > import qualified Data.List.Extra as List (concat, foldl, foldr, + # > init, last, length) + # + # - false: diff-safe + # + # > import qualified Data.List as List (concat, foldl, foldr, init, + # > last, length) + # > import qualified Data.List.Extra as List (concat, foldl, foldr, + # > init, last, length) + # + # Default: true + pad_module_names: false + + # Long list align style takes effect when import is too long. This is + # determined by 'columns' setting. + # + # - inline: This option will put as much specs on same line as possible. + # + # - new_line: Import list will start on new line. + # + # - new_line_multiline: Import list will start on new line when it's + # short enough to fit to single line. Otherwise it'll be multiline. + # + # - multiline: One line per import list entry. + # Type with constructor list acts like single import. + # + # > import qualified Data.Map as M + # > ( empty + # > , singleton + # > , ... + # > , delete + # > ) + # + # Default: inline + long_list_align: new_line_multiline + + # Align empty list (importing instances) + # + # Empty list align has following options + # + # - inherit: inherit list_align setting + # + # - right_after: () is right after the module name: + # + # > import Vector.Instances () + # + # Default: inherit + empty_list_align: inherit + + # List padding determines indentation of import list on lines after import. + # This option affects 'long_list_align'. + # + # - <integer>: constant value + # + # - module_name: align under start of module name. + # Useful for 'file' and 'group' align settings. + # + # Default: 4 + list_padding: 2 + + # Separate lists option affects formatting of import list for type + # or class. The only difference is single space between type and list + # of constructors, selectors and class functions. + # + # - true: There is single space between Foldable type and list of it's + # functions. + # + # > import Data.Foldable (Foldable (fold, foldl, foldMap)) + # + # - false: There is no space between Foldable type and list of it's + # functions. + # + # > import Data.Foldable (Foldable(fold, foldl, foldMap)) + # + # Default: true + separate_lists: false + + # Space surround option affects formatting of import lists on a single + # line. The only difference is single space after the initial + # parenthesis and a single space before the terminal parenthesis. + # + # - true: There is single space associated with the enclosing + # parenthesis. + # + # > import Data.Foo ( foo ) + # + # - false: There is no space associated with the enclosing parenthesis + # + # > import Data.Foo (foo) + # + # Default: false + space_surround: false + + # Enabling this argument will use the new GHC lib parse to format imports. + # + # This currently assumes a few things, it will assume that you want post + # qualified imports. It is also not as feature complete as the old + # imports formatting. + # + # It does not remove redundant lines or merge lines. As such, the full + # feature scope is still pending. + # + # It _is_ however, a fine alternative if you are using features that are + # not parseable by haskell src extensions and you're comfortable with the + # presets. + # + # Default: false + ghc_lib_parser: false + + # Post qualify option moves any qualifies found in import declarations + # to the end of the declaration. This also adjust padding for any + # unqualified import declarations. + # + # - true: Qualified as <module name> is moved to the end of the + # declaration. + # + # > import Data.Bar + # > import Data.Foo qualified as F + # + # - false: Qualified remains in the default location and unqualified + # imports are padded to align with qualified imports. + # + # > import Data.Bar + # > import qualified Data.Foo as F + # + # Default: false + post_qualify: true + + # A list of rules specifying how to group modules and how to + # order the groups. + # + # Each rule has a match field; the rule only applies to module + # names matched by this pattern. Patterns are POSIX extended + # regular expressions; see the documentation of Text.Regex.TDFA + # for details: + # https://hackage.haskell.org/package/regex-tdfa-1.3.1.2/docs/Text-Regex-TDFA.html + # + # Rules are processed in order, so only the *first* rule that + # matches a specific module will apply. Any module names that do + # not match a single rule will be put into a single group at the + # end of the import block. + # + # Example: group MyApp modules first, with everything else in + # one group at the end. + # + # group_rules: + # - match: "^MyApp\\>" + # + # > import MyApp + # > import MyApp.Foo + # > + # > import Control.Monad + # > import MyApps + # > import Test.MyApp + # + # A rule can also optionally have a sub_group pattern. Imports + # that match the rule will be broken up into further groups by + # the part of the module name matched by the sub_group pattern. + # + # Example: group MyApp modules first, then everything else + # sub-grouped by the first part of the module name. + # + # group_rules: + # - match: "^MyApp\\>" + # - match: "." + # sub_group: "^[^.]+" + # + # > import MyApp + # > import MyApp.Foo + # > + # > import Control.Applicative + # > import Control.Monad + # > + # > import Data.Map + # + # A pattern only needs to match part of the module name, which + # could be in the middle. You can use ^pattern to anchor to the + # beginning of the module name, pattern$ to anchor to the end + # and ^pattern$ to force a full match. Example: + # + # - "Test\\." would match "Test.Foo" and "Foo.Test.Lib" + # - "^Test\\." would match "Test.Foo" but not "Foo.Test.Lib" + # - "\\.Test$" would match "Foo.Test" but not "Foo.Test.Lib" + # - "^Test$" would *only* match "Test" + # + # You can use \\< and \\> to anchor against the beginning and + # end of words, respectively. For example: + # + # - "^Test\\." would match "Test.Foo" but not "Test" or "Tests" + # - "^Test\\>" would match "Test.Foo" and "Test", but not + # "Tests" + # + # The default is a single rule that matches everything and + # sub-groups based on the first component of the module name. + # + # Default: [{ "match" : ".*", "sub_group": "^[^.]+" }] +# group_rules: +# - match: ".*" +# sub_group: "^[^.]+" +# - match: "^Data.Array\\>" +# sub_group: "^[^.]+" +# - match: "^HordeAd\\>" + + # Language pragmas + - language_pragmas: + # We can generate different styles of language pragma lists. + # + # - vertical: Vertical-spaced language pragmas, one per line. + # + # - compact: A more compact style. + # + # - compact_line: Similar to compact, but wrap each line with + # `{-#LANGUAGE #-}'. + # + # Default: vertical. +# style: compact + + # Align affects alignment of closing pragma brackets. + # + # - true: Brackets are aligned in same column. + # + # - false: Brackets are not aligned together. There is only one space + # between actual import and closing bracket. + # + # Default: true + align: false + + # stylish-haskell can detect redundancy of some language pragmas. If this + # is set to true, it will remove those redundant pragmas. Default: true. + remove_redundant: true + + # Language prefix to be used for pragma declaration, this allows you to + # use other options non case-sensitive like "language" or "Language". + # If a non correct String is provided, it will default to: LANGUAGE. + language_prefix: LANGUAGE + + # Replace tabs by spaces. This is disabled by default. + # - tabs: + # # Number of spaces to use for each tab. Default: 8, as specified by the + # # Haskell report. + # spaces: 8 + + # Remove trailing whitespace + - trailing_whitespace: {} + + # Squash multiple spaces between the left and right hand sides of some + # elements into single spaces. Basically, this undoes the effect of + # simple_align but is a bit less conservative. + # - squash: {} + +# A common setting is the number of columns (parts of) code will be wrapped +# to. Different steps take this into account. +# +# Set this to null to disable all line wrapping. +# +# Default: 80. +columns: 200 + +# By default, line endings are converted according to the OS. You can override +# preferred format here. +# +# - native: Native newline format. CRLF on Windows, LF on other OSes. +# +# - lf: Convert to LF ("\n"). +# +# - crlf: Convert to CRLF ("\r\n"). +# +# Default: native. +newline: native + +# Sometimes, language extensions are specified in a cabal file or from the +# command line instead of using language pragmas in the file. stylish-haskell +# needs to be aware of these, so it can parse the file correctly. +# +# No language extensions are enabled by default. +#language_extensions: +# - NoStarIsType + # - TemplateHaskell + # - QuasiQuotes + +# Attempt to find the cabal file in ancestors of the current directory, and +# parse options (currently only language extensions) from that. +# +# Default: true +cabal: true diff --git a/bench/Main.hs b/bench/Main.hs index 5901d8b..276a38a 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -17,7 +17,7 @@ import Text.Show (showListWith) import Data.Array.Mixed.XArray (XArray(..)) import Data.Array.Nested -import Data.Array.Nested.Internal.Mixed (Mixed (M_Primitive), mliftPrim, mliftPrim2, toPrimitive) +import Data.Array.Nested.Internal.Mixed (Mixed(M_Primitive), mliftPrim, mliftPrim2, toPrimitive) import Data.Array.Nested.Internal.Ranked (liftRanked1, liftRanked2) import Data.Array.Strided.Arith.Internal qualified as Arith diff --git a/ox-arrays.cabal b/ox-arrays.cabal index c46e216..2a15cbb 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -2,9 +2,13 @@ cabal-version: 3.0 name: ox-arrays version: 0.1.0.0 synopsis: An efficient CPU-based multidimensional array (tensor) library -description: An efficient and richly typed CPU-based multidimensional array (tensor) library built upon the optimized tensor representation (strides list) implemented in the orthotope package. -author: Tom Smeding -maintainer: Tom Smeding +description: + An efficient and richly typed CPU-based multidimensional array (tensor) + library built upon the optimized tensor representation (strides list) + implemented in the orthotope package. +copyright: (c) 2025 Tom Smeding, Mikolaj Konarski +author: Tom Smeding, Mikolaj Konarski +maintainer: Tom Smeding <xhackage@tomsmeding.com> license: BSD-3-Clause category: Array, Tensors build-type: Simple @@ -46,15 +50,16 @@ library Data.Array.Mixed.Internal.Arith Data.Array.Mixed.Lemmas Data.Array.Mixed.Permutation - Data.Array.Mixed.Shape Data.Array.Mixed.Types Data.Array.Mixed.XArray Data.Array.Nested.Internal.Convert Data.Array.Nested.Internal.Mixed Data.Array.Nested.Internal.Lemmas Data.Array.Nested.Internal.Ranked - Data.Array.Nested.Internal.Shape Data.Array.Nested.Internal.Shaped + Data.Array.Nested.Mixed.Shape + Data.Array.Nested.Ranked.Shape + Data.Array.Nested.Shaped.Shape Data.Bag if flag(trace-wrappers) diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs index 560f762..7b3365b 100644 --- a/src/Data/Array/Mixed/Lemmas.hs +++ b/src/Data/Array/Mixed/Lemmas.hs @@ -13,8 +13,8 @@ import Data.Type.Equality import GHC.TypeLits import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape import Data.Array.Mixed.Types +import Data.Array.Nested.Mixed.Shape -- * Reasoning helpers diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs index cedfa22..ef0afe3 100644 --- a/src/Data/Array/Mixed/Permutation.hs +++ b/src/Data/Array/Mixed/Permutation.hs @@ -29,8 +29,8 @@ import GHC.TypeError import GHC.TypeLits import GHC.TypeNats qualified as TN -import Data.Array.Mixed.Shape import Data.Array.Mixed.Types +import Data.Array.Nested.Mixed.Shape -- * Permutations diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs index cb790e1..502d5d9 100644 --- a/src/Data/Array/Mixed/XArray.hs +++ b/src/Data/Array/Mixed/XArray.hs @@ -34,8 +34,8 @@ import GHC.TypeLits import Data.Array.Mixed.Internal.Arith import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape import Data.Array.Mixed.Types +import Data.Array.Nested.Mixed.Shape type XArray :: [Maybe Nat] -> Type -> Type diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 9869cba..8198a54 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -104,13 +104,14 @@ module Data.Array.Nested ( import Prelude hiding (mappend, mconcat) import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape import Data.Array.Mixed.Types import Data.Array.Nested.Internal.Convert import Data.Array.Nested.Internal.Mixed import Data.Array.Nested.Internal.Ranked -import Data.Array.Nested.Internal.Shape import Data.Array.Nested.Internal.Shaped +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Shaped.Shape import Data.Array.Strided.Arith import Foreign.Storable import GHC.TypeLits diff --git a/src/Data/Array/Nested/Internal/Convert.hs b/src/Data/Array/Nested/Internal/Convert.hs index c316161..4e0f17d 100644 --- a/src/Data/Array/Nested/Internal/Convert.hs +++ b/src/Data/Array/Nested/Internal/Convert.hs @@ -12,13 +12,13 @@ import Data.Proxy import Data.Type.Equality import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Shape import Data.Array.Mixed.Types import Data.Array.Nested.Internal.Lemmas import Data.Array.Nested.Internal.Mixed import Data.Array.Nested.Internal.Ranked -import Data.Array.Nested.Internal.Shape import Data.Array.Nested.Internal.Shaped +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Shaped.Shape stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a diff --git a/src/Data/Array/Nested/Internal/Lemmas.hs b/src/Data/Array/Nested/Internal/Lemmas.hs index f894f78..f4bad70 100644 --- a/src/Data/Array/Nested/Internal/Lemmas.hs +++ b/src/Data/Array/Nested/Internal/Lemmas.hs @@ -11,9 +11,9 @@ import GHC.TypeLits import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape import Data.Array.Mixed.Types -import Data.Array.Nested.Internal.Shape +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Shaped.Shape lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index a2f9737..a979d6c 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -45,10 +45,10 @@ import Unsafe.Coerce (unsafeCoerce) import Data.Array.Mixed.Internal.Arith import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape import Data.Array.Mixed.Types import Data.Array.Mixed.XArray (XArray(..)) import Data.Array.Mixed.XArray qualified as X +import Data.Array.Nested.Mixed.Shape import Data.Bag diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index daf0374..368e337 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -40,12 +40,12 @@ import GHC.TypeNats qualified as TN import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape import Data.Array.Mixed.Types import Data.Array.Mixed.XArray (XArray(..)) import Data.Array.Mixed.XArray qualified as X import Data.Array.Nested.Internal.Mixed -import Data.Array.Nested.Internal.Shape +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked.Shape import Data.Array.Strided.Arith diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index 372439f..1415815 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -40,13 +40,13 @@ import GHC.TypeLits import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape import Data.Array.Mixed.Types import Data.Array.Mixed.XArray (XArray) import Data.Array.Mixed.XArray qualified as X import Data.Array.Nested.Internal.Lemmas import Data.Array.Nested.Internal.Mixed -import Data.Array.Nested.Internal.Shape +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Shaped.Shape import Data.Array.Strided.Arith diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index eb8434f..5f4775c 100644 --- a/src/Data/Array/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -20,7 +20,7 @@ {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed.Shape where +module Data.Array.Nested.Mixed.Shape where import Control.DeepSeq (NFData(..)) import Data.Bifunctor (first) diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs new file mode 100644 index 0000000..1c0b9eb --- /dev/null +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -0,0 +1,363 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Ranked.Shape where + +import Control.DeepSeq (NFData(..)) +import Data.Array.Mixed.Types +import Data.Coerce (coerce) +import Data.Foldable qualified as Foldable +import Data.Kind (Type) +import Data.Proxy +import Data.Type.Equality +import GHC.Generics (Generic) +import GHC.IsList (IsList) +import GHC.IsList qualified as IsList +import GHC.TypeLits +import GHC.TypeNats qualified as TN + +import Data.Array.Mixed.Lemmas +import Data.Array.Nested.Mixed.Shape + + +type role ListR nominal representational +type ListR :: Nat -> Type -> Type +data ListR n i where + ZR :: ListR 0 i + (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i +deriving instance Eq i => Eq (ListR n i) +deriving instance Ord i => Ord (ListR n i) +deriving instance Functor (ListR n) +deriving instance Foldable (ListR n) +infixr 3 ::: + +instance Show i => Show (ListR n i) where + showsPrec _ = listrShow shows + +instance NFData i => NFData (ListR n i) where + rnf ZR = () + rnf (x ::: l) = rnf x `seq` rnf l + +data UnconsListRRes i n1 = + forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i +listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) +listrUncons (i ::: sh') = Just (UnconsListRRes sh' i) +listrUncons ZR = Nothing + +-- | This checks only whether the ranks are equal, not whether the actual +-- values are. +listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n') +listrEqRank ZR ZR = Just Refl +listrEqRank (_ ::: sh) (_ ::: sh') + | Just Refl <- listrEqRank sh sh' + = Just Refl +listrEqRank _ _ = Nothing + +-- | This compares the lists for value equality. +listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n') +listrEqual ZR ZR = Just Refl +listrEqual (i ::: sh) (j ::: sh') + | Just Refl <- listrEqual sh sh' + , i == j + = Just Refl +listrEqual _ _ = Nothing + +listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS +listrShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListR n' i -> ShowS + go _ ZR = id + go prefix (x ::: xs) = showString prefix . f x . go "," xs + +listrLength :: ListR n i -> Int +listrLength = length + +listrRank :: ListR n i -> SNat n +listrRank ZR = SNat +listrRank (_ ::: sh) = snatSucc (listrRank sh) + +listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i +listrAppend ZR sh = sh +listrAppend (x ::: xs) sh = x ::: listrAppend xs sh + +listrFromList :: [i] -> (forall n. ListR n i -> r) -> r +listrFromList [] k = k ZR +listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) + +listrHead :: ListR (n + 1) i -> i +listrHead (i ::: _) = i +listrHead ZR = error "unreachable" + +listrTail :: ListR (n + 1) i -> ListR n i +listrTail (_ ::: sh) = sh +listrTail ZR = error "unreachable" + +listrInit :: ListR (n + 1) i -> ListR n i +listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh +listrInit (_ ::: ZR) = ZR +listrInit ZR = error "unreachable" + +listrLast :: ListR (n + 1) i -> i +listrLast (_ ::: sh@(_ ::: _)) = listrLast sh +listrLast (n ::: ZR) = n +listrLast ZR = error "unreachable" + +listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i +listrIndex SZ (x ::: _) = x +listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs +listrIndex _ ZR = error "k + 1 <= 0" + +listrZip :: ListR n i -> ListR n j -> ListR n (i, j) +listrZip ZR ZR = ZR +listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest +listrZip _ _ = error "listrZip: impossible pattern needlessly required" + +listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k +listrZipWith _ ZR ZR = ZR +listrZipWith f (i ::: irest) (j ::: jrest) = + f i j ::: listrZipWith f irest jrest +listrZipWith _ _ _ = + error "listrZipWith: impossible pattern needlessly required" + +listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i +listrPermutePrefix = \perm sh -> + listrFromList perm $ \sperm -> + case (listrRank sperm, listrRank sh) of + (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of + LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post + EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post + GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" + ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" + where + listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) + listrSplitAt SZ sh = (ZR, sh) + listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) + listrSplitAt SS{} ZR = error "m' + 1 <= 0" + + applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i + applyPermRFull _ ZR _ = ZR + applyPermRFull sm@SNat (i ::: perm) l = + TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> + case cmpNat (SNat @(idx + 1)) sm of + LTI -> listrIndex si l ::: applyPermRFull sm perm l + EQI -> listrIndex si l ::: applyPermRFull sm perm l + GTI -> error "listrPermutePrefix: Index in permutation out of range" + + +-- | An index into a rank-typed array. +type role IxR nominal representational +type IxR :: Nat -> Type -> Type +newtype IxR n i = IxR (ListR n i) + deriving (Eq, Ord, Generic) + deriving newtype (Functor, Foldable) + +pattern ZIR :: forall n i. () => n ~ 0 => IxR n i +pattern ZIR = IxR ZR + +pattern (:.:) + :: forall {n1} {i}. + forall n. (n + 1 ~ n1) + => i -> IxR n i -> IxR n1 i +pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i)) + where i :.: IxR sh = IxR (i ::: sh) +infixr 3 :.: + +{-# COMPLETE ZIR, (:.:) #-} + +type IIxR n = IxR n Int + +instance Show i => Show (IxR n i) where + showsPrec _ (IxR l) = listrShow shows l + +instance NFData i => NFData (IxR sh i) + +ixrLength :: IxR sh i -> Int +ixrLength (IxR l) = listrLength l + +ixrRank :: IxR n i -> SNat n +ixrRank (IxR sh) = listrRank sh + +ixrZero :: SNat n -> IIxR n +ixrZero SZ = ZIR +ixrZero (SS n) = 0 :.: ixrZero n + +ixCvtXR :: IIxX sh -> IIxR (Rank sh) +ixCvtXR ZIX = ZIR +ixCvtXR (n :.% idx) = n :.: ixCvtXR idx + +ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) +ixCvtRX ZIR = ZIX +ixCvtRX (n :.: (idx :: IxR m Int)) = + castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) + (n :.% ixCvtRX idx) + +ixrHead :: IxR (n + 1) i -> i +ixrHead (IxR list) = listrHead list + +ixrTail :: IxR (n + 1) i -> IxR n i +ixrTail (IxR list) = IxR (listrTail list) + +ixrInit :: IxR (n + 1) i -> IxR n i +ixrInit (IxR list) = IxR (listrInit list) + +ixrLast :: IxR (n + 1) i -> i +ixrLast (IxR list) = listrLast list + +ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i +ixrAppend = coerce (listrAppend @_ @i) + +ixrZip :: IxR n i -> IxR n j -> IxR n (i, j) +ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 + +ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k +ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 + +ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i +ixrPermutePrefix = coerce (listrPermutePrefix @i) + + +type role ShR nominal representational +type ShR :: Nat -> Type -> Type +newtype ShR n i = ShR (ListR n i) + deriving (Eq, Ord, Generic) + deriving newtype (Functor, Foldable) + +pattern ZSR :: forall n i. () => n ~ 0 => ShR n i +pattern ZSR = ShR ZR + +pattern (:$:) + :: forall {n1} {i}. + forall n. (n + 1 ~ n1) + => i -> ShR n i -> ShR n1 i +pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) + where i :$: ShR sh = ShR (i ::: sh) +infixr 3 :$: + +{-# COMPLETE ZSR, (:$:) #-} + +type IShR n = ShR n Int + +instance Show i => Show (ShR n i) where + showsPrec _ (ShR l) = listrShow shows l + +instance NFData i => NFData (ShR sh i) + +shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n +shCvtXR' ZSX = + castWith (subst2 (unsafeCoerceRefl :: 0 :~: n)) + ZSR +shCvtXR' (n :$% (idx :: IShX sh)) + | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) = + castWith (subst2 (lem1 @sh Refl)) + (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) + where + lem1 :: forall sh' n' k. + k : sh' :~: Replicate n' Nothing + -> Rank sh' + 1 :~: n' + lem1 Refl = unsafeCoerceRefl + + lem2 :: k : sh :~: Replicate n Nothing + -> sh :~: Replicate (Rank sh) Nothing + lem2 Refl = unsafeCoerceRefl + +shCvtRX :: IShR n -> IShX (Replicate n Nothing) +shCvtRX ZSR = ZSX +shCvtRX (n :$: (idx :: ShR m Int)) = + castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) + (SUnknown n :$% shCvtRX idx) + +-- | This checks only whether the ranks are equal, not whether the actual +-- values are. +shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') +shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh' + +-- | This compares the shapes for value equality. +shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n') +shrEqual (ShR sh) (ShR sh') = listrEqual sh sh' + +shrLength :: ShR sh i -> Int +shrLength (ShR l) = listrLength l + +-- | This function can also be used to conjure up a 'KnownNat' dictionary; +-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern +-- synonym yields 'KnownNat' evidence. +shrRank :: ShR n i -> SNat n +shrRank (ShR sh) = listrRank sh + +-- | The number of elements in an array described by this shape. +shrSize :: IShR n -> Int +shrSize ZSR = 1 +shrSize (n :$: sh) = n * shrSize sh + +shrHead :: ShR (n + 1) i -> i +shrHead (ShR list) = listrHead list + +shrTail :: ShR (n + 1) i -> ShR n i +shrTail (ShR list) = ShR (listrTail list) + +shrInit :: ShR (n + 1) i -> ShR n i +shrInit (ShR list) = ShR (listrInit list) + +shrLast :: ShR (n + 1) i -> i +shrLast (ShR list) = listrLast list + +shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i +shrAppend = coerce (listrAppend @_ @i) + +shrZip :: ShR n i -> ShR n j -> ShR n (i, j) +shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2 + +shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k +shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2 + +shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i +shrPermutePrefix = coerce (listrPermutePrefix @i) + + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (ListR n i) where + type Item (ListR n i) = i + fromList topl = go (SNat @n) topl + where + go :: SNat n' -> [i] -> ListR n' i + go SZ [] = ZR + go (SS n) (i : is) = i ::: go n is + go _ _ = error $ "IsList(ListR): Mismatched list length (type says " + ++ show (fromSNat (SNat @n)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = Foldable.toList + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (IxR n i) where + type Item (IxR n i) = i + fromList = IxR . IsList.fromList + toList = Foldable.toList + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (ShR n i) where + type Item (ShR n i) = i + fromList = ShR . IsList.fromList + toList = Foldable.toList diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 97b9456..6c43fa7 100644 --- a/src/Data/Array/Nested/Internal/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -24,7 +24,7 @@ {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Internal.Shape where +module Data.Array.Nested.Shaped.Shape where import Control.DeepSeq (NFData(..)) import Data.Array.Mixed.Types @@ -42,331 +42,10 @@ import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits -import GHC.TypeNats qualified as TN import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape - - -type role ListR nominal representational -type ListR :: Nat -> Type -> Type -data ListR n i where - ZR :: ListR 0 i - (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i -deriving instance Eq i => Eq (ListR n i) -deriving instance Ord i => Ord (ListR n i) -deriving instance Functor (ListR n) -deriving instance Foldable (ListR n) -infixr 3 ::: - -instance Show i => Show (ListR n i) where - showsPrec _ = listrShow shows - -instance NFData i => NFData (ListR n i) where - rnf ZR = () - rnf (x ::: l) = rnf x `seq` rnf l - -data UnconsListRRes i n1 = - forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i -listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) -listrUncons (i ::: sh') = Just (UnconsListRRes sh' i) -listrUncons ZR = Nothing - --- | This checks only whether the ranks are equal, not whether the actual --- values are. -listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n') -listrEqRank ZR ZR = Just Refl -listrEqRank (_ ::: sh) (_ ::: sh') - | Just Refl <- listrEqRank sh sh' - = Just Refl -listrEqRank _ _ = Nothing - --- | This compares the lists for value equality. -listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n') -listrEqual ZR ZR = Just Refl -listrEqual (i ::: sh) (j ::: sh') - | Just Refl <- listrEqual sh sh' - , i == j - = Just Refl -listrEqual _ _ = Nothing - -listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS -listrShow f l = showString "[" . go "" l . showString "]" - where - go :: String -> ListR n' i -> ShowS - go _ ZR = id - go prefix (x ::: xs) = showString prefix . f x . go "," xs - -listrLength :: ListR n i -> Int -listrLength = length - -listrRank :: ListR n i -> SNat n -listrRank ZR = SNat -listrRank (_ ::: sh) = snatSucc (listrRank sh) - -listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i -listrAppend ZR sh = sh -listrAppend (x ::: xs) sh = x ::: listrAppend xs sh - -listrFromList :: [i] -> (forall n. ListR n i -> r) -> r -listrFromList [] k = k ZR -listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) - -listrHead :: ListR (n + 1) i -> i -listrHead (i ::: _) = i -listrHead ZR = error "unreachable" - -listrTail :: ListR (n + 1) i -> ListR n i -listrTail (_ ::: sh) = sh -listrTail ZR = error "unreachable" - -listrInit :: ListR (n + 1) i -> ListR n i -listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh -listrInit (_ ::: ZR) = ZR -listrInit ZR = error "unreachable" - -listrLast :: ListR (n + 1) i -> i -listrLast (_ ::: sh@(_ ::: _)) = listrLast sh -listrLast (n ::: ZR) = n -listrLast ZR = error "unreachable" - -listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i -listrIndex SZ (x ::: _) = x -listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs -listrIndex _ ZR = error "k + 1 <= 0" - -listrZip :: ListR n i -> ListR n j -> ListR n (i, j) -listrZip ZR ZR = ZR -listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest -listrZip _ _ = error "listrZip: impossible pattern needlessly required" - -listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k -listrZipWith _ ZR ZR = ZR -listrZipWith f (i ::: irest) (j ::: jrest) = - f i j ::: listrZipWith f irest jrest -listrZipWith _ _ _ = - error "listrZipWith: impossible pattern needlessly required" - -listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i -listrPermutePrefix = \perm sh -> - listrFromList perm $ \sperm -> - case (listrRank sperm, listrRank sh) of - (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of - LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" - ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" - where - listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) - listrSplitAt SZ sh = (ZR, sh) - listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) - listrSplitAt SS{} ZR = error "m' + 1 <= 0" - - applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i - applyPermRFull _ ZR _ = ZR - applyPermRFull sm@SNat (i ::: perm) l = - TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> - case cmpNat (SNat @(idx + 1)) sm of - LTI -> listrIndex si l ::: applyPermRFull sm perm l - EQI -> listrIndex si l ::: applyPermRFull sm perm l - GTI -> error "listrPermutePrefix: Index in permutation out of range" - - --- | An index into a rank-typed array. -type role IxR nominal representational -type IxR :: Nat -> Type -> Type -newtype IxR n i = IxR (ListR n i) - deriving (Eq, Ord, Generic) - deriving newtype (Functor, Foldable) - -pattern ZIR :: forall n i. () => n ~ 0 => IxR n i -pattern ZIR = IxR ZR - -pattern (:.:) - :: forall {n1} {i}. - forall n. (n + 1 ~ n1) - => i -> IxR n i -> IxR n1 i -pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i)) - where i :.: IxR sh = IxR (i ::: sh) -infixr 3 :.: - -{-# COMPLETE ZIR, (:.:) #-} - -type IIxR n = IxR n Int - -instance Show i => Show (IxR n i) where - showsPrec _ (IxR l) = listrShow shows l - -instance NFData i => NFData (IxR sh i) - -ixrLength :: IxR sh i -> Int -ixrLength (IxR l) = listrLength l - -ixrRank :: IxR n i -> SNat n -ixrRank (IxR sh) = listrRank sh - -ixrZero :: SNat n -> IIxR n -ixrZero SZ = ZIR -ixrZero (SS n) = 0 :.: ixrZero n - -ixCvtXR :: IIxX sh -> IIxR (Rank sh) -ixCvtXR ZIX = ZIR -ixCvtXR (n :.% idx) = n :.: ixCvtXR idx - -ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) -ixCvtRX ZIR = ZIX -ixCvtRX (n :.: (idx :: IxR m Int)) = - castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) - (n :.% ixCvtRX idx) - -ixrHead :: IxR (n + 1) i -> i -ixrHead (IxR list) = listrHead list - -ixrTail :: IxR (n + 1) i -> IxR n i -ixrTail (IxR list) = IxR (listrTail list) - -ixrInit :: IxR (n + 1) i -> IxR n i -ixrInit (IxR list) = IxR (listrInit list) - -ixrLast :: IxR (n + 1) i -> i -ixrLast (IxR list) = listrLast list - -ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i -ixrAppend = coerce (listrAppend @_ @i) - -ixrZip :: IxR n i -> IxR n j -> IxR n (i, j) -ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 - -ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k -ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 - -ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i -ixrPermutePrefix = coerce (listrPermutePrefix @i) - - -type role ShR nominal representational -type ShR :: Nat -> Type -> Type -newtype ShR n i = ShR (ListR n i) - deriving (Eq, Ord, Generic) - deriving newtype (Functor, Foldable) - -pattern ZSR :: forall n i. () => n ~ 0 => ShR n i -pattern ZSR = ShR ZR - -pattern (:$:) - :: forall {n1} {i}. - forall n. (n + 1 ~ n1) - => i -> ShR n i -> ShR n1 i -pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) - where i :$: ShR sh = ShR (i ::: sh) -infixr 3 :$: - -{-# COMPLETE ZSR, (:$:) #-} - -type IShR n = ShR n Int - -instance Show i => Show (ShR n i) where - showsPrec _ (ShR l) = listrShow shows l - -instance NFData i => NFData (ShR sh i) - -shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n -shCvtXR' ZSX = - castWith (subst2 (unsafeCoerceRefl :: 0 :~: n)) - ZSR -shCvtXR' (n :$% (idx :: IShX sh)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) = - castWith (subst2 (lem1 @sh Refl)) - (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) - where - lem1 :: forall sh' n' k. - k : sh' :~: Replicate n' Nothing - -> Rank sh' + 1 :~: n' - lem1 Refl = unsafeCoerceRefl - - lem2 :: k : sh :~: Replicate n Nothing - -> sh :~: Replicate (Rank sh) Nothing - lem2 Refl = unsafeCoerceRefl - -shCvtRX :: IShR n -> IShX (Replicate n Nothing) -shCvtRX ZSR = ZSX -shCvtRX (n :$: (idx :: ShR m Int)) = - castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) - (SUnknown n :$% shCvtRX idx) - --- | This checks only whether the ranks are equal, not whether the actual --- values are. -shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') -shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh' - --- | This compares the shapes for value equality. -shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n') -shrEqual (ShR sh) (ShR sh') = listrEqual sh sh' - -shrLength :: ShR sh i -> Int -shrLength (ShR l) = listrLength l - --- | This function can also be used to conjure up a 'KnownNat' dictionary; --- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern --- synonym yields 'KnownNat' evidence. -shrRank :: ShR n i -> SNat n -shrRank (ShR sh) = listrRank sh - --- | The number of elements in an array described by this shape. -shrSize :: IShR n -> Int -shrSize ZSR = 1 -shrSize (n :$: sh) = n * shrSize sh - -shrHead :: ShR (n + 1) i -> i -shrHead (ShR list) = listrHead list - -shrTail :: ShR (n + 1) i -> ShR n i -shrTail (ShR list) = ShR (listrTail list) - -shrInit :: ShR (n + 1) i -> ShR n i -shrInit (ShR list) = ShR (listrInit list) - -shrLast :: ShR (n + 1) i -> i -shrLast (ShR list) = listrLast list - -shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i -shrAppend = coerce (listrAppend @_ @i) - -shrZip :: ShR n i -> ShR n j -> ShR n (i, j) -shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2 - -shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k -shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2 - -shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i -shrPermutePrefix = coerce (listrPermutePrefix @i) - - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ListR n i) where - type Item (ListR n i) = i - fromList topl = go (SNat @n) topl - where - go :: SNat n' -> [i] -> ListR n' i - go SZ [] = ZR - go (SS n) (i : is) = i ::: go n is - go _ _ = error $ "IsList(ListR): Mismatched list length (type says " - ++ show (fromSNat (SNat @n)) ++ ", list has length " - ++ show (length topl) ++ ")" - toList = Foldable.toList - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (IxR n i) where - type Item (IxR n i) = i - fromList = IxR . IsList.fromList - toList = Foldable.toList - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ShR n i) where - type Item (ShR n i) = i - fromList = ShR . IsList.fromList - toList = Foldable.toList +import Data.Array.Nested.Mixed.Shape type role ListS nominal representational diff --git a/test/Gen.hs b/test/Gen.hs index bf002ca..50c671f 100644 --- a/test/Gen.hs +++ b/test/Gen.hs @@ -21,10 +21,9 @@ import GHC.TypeLits import GHC.TypeNats qualified as TN import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape import Data.Array.Mixed.Types import Data.Array.Nested -import Data.Array.Nested.Internal.Shape +import Data.Array.Nested.Ranked.Shape import Hedgehog import Hedgehog.Gen qualified as Gen diff --git a/test/Tests/C.hs b/test/Tests/C.hs index 4861eb1..72bf8f6 100644 --- a/test/Tests/C.hs +++ b/test/Tests/C.hs @@ -20,7 +20,7 @@ import GHC.TypeLits import Data.Array.Mixed.Types (fromSNat') import Data.Array.Nested -import Data.Array.Nested.Internal.Shape +import Data.Array.Nested.Ranked.Shape import Hedgehog import Hedgehog.Gen qualified as Gen |