aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md10
-rw-r--r--bench/Main.hs13
-rw-r--r--cbits/arith_lists.h62
-rw-r--r--ops/Data/Array/Strided/Arith/Internal.hs30
-rw-r--r--ox-arrays.cabal15
-rw-r--r--src/Data/Array/Nested/Convert.hs72
-rw-r--r--src/Data/Array/Nested/Lemmas.hs14
-rw-r--r--src/Data/Array/Nested/Mixed.hs170
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs584
-rw-r--r--src/Data/Array/Nested/Mixed/Shape/Internal.hs59
-rw-r--r--src/Data/Array/Nested/Permutation.hs104
-rw-r--r--src/Data/Array/Nested/Ranked.hs18
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs61
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs243
-rw-r--r--src/Data/Array/Nested/Shaped.hs24
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs52
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs421
-rw-r--r--src/Data/Array/Nested/Types.hs6
-rw-r--r--src/Data/Array/Strided/Orthotope.hs5
-rw-r--r--src/Data/Array/XArray.hs86
-rw-r--r--test/Gen.hs7
-rw-r--r--test/Tests/C.hs39
22 files changed, 1222 insertions, 873 deletions
diff --git a/README.md b/README.md
index 8393148..667bf9d 100644
--- a/README.md
+++ b/README.md
@@ -112,18 +112,18 @@ data ShS sh where
data IShX xsh where
ZSX :: IShX '[]
- (:$%) :: SMayNat Int SNat mn -> IShX xsh -> IShX (mn : xsh)
+ (:$%) :: SMayNat Int mn -> IShX xsh -> IShX (mn : xsh)
-- where:
-data SMayNat i f n where
- SUnknown :: i -> SMayNat i f Nothing
- SKnown :: f n -> SMayNat i f (Just n)
+data SMayNat i n where
+ SUnknown :: i -> SMayNat i Nothing
+ SKnown :: SNat n -> SMayNat i (Just n)
-- Occasionally one needs a singleton for only the _known_ dimensions of a mixed
-- shape -- that is to say, only the statically-known part of a mixed shape.
-- StaticShX provides for this need. It can be used as if defined as follows:
data StaticShX xsh where
ZKX :: StaticShX '[]
- (:!%) :: SMayNat () SNat mn -> StaticShX xsh -> StaticShX (mn : xsh)
+ (:!%) :: SMayNat () mn -> StaticShX xsh -> StaticShX (mn : xsh)
-- The Elt class describes types that can be used as elements of an array. While
-- it is technically possible to define new instances of this class, typical
diff --git a/bench/Main.hs b/bench/Main.hs
index 2058e77..185bef0 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -9,7 +9,6 @@ import Control.Monad (when)
import Data.Array.Internal qualified as OI
import Data.Array.Internal.RankedG qualified as RG
import Data.Array.Internal.RankedS qualified as RS
-import Data.Foldable (toList)
import Data.Vector.Storable qualified as VS
import Numeric.LinearAlgebra qualified as LA
import Test.Tasty.Bench
@@ -19,6 +18,7 @@ import Data.Array.Nested
import Data.Array.Nested.Mixed (Mixed(M_Primitive), mliftPrim, mliftPrim2, toPrimitive)
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked (liftRanked1, liftRanked2)
+import Data.Array.Nested.Ranked.Shape
import Data.Array.Strided.Arith.Internal qualified as Arith
import Data.Array.XArray (XArray(..))
@@ -40,7 +40,7 @@ main_tests = defaultMain
let showSh l = showListWith (\n -> let ln = round (logBase 10 (fromIntegral n :: Double)) :: Int
in if n > 1 && n == 10 ^ ln then showString ("1e" ++ show ln) else shows n)
l ""
- in bench (name ++ " " ++ showSh (toList (rshape inp1)) ++
+ in bench (name ++ " " ++ showSh (shrToList (rshape inp1)) ++
" str " ++ showSh (stridesOf inp1) ++ " " ++ showSh (stridesOf inp2)) $
nf (\(a,b) -> rsumAllPrim (rdot1Inner a b)) (inp1, inp2)
@@ -176,6 +176,15 @@ tests_compare =
,bench "sum Double [1e6]" $
nf (\a -> runScalar (rsumOuter1Prim a))
(riota @Double n)
+ ,bench "sumAll iota [1e6]" $
+ nf (\a -> rsumAllPrim a)
+ (riota @Double n)
+ ,bench "sumAll rev1(iota) [1e6]" $
+ nf (\a -> rsumAllPrim a)
+ (rrev1 $ riota @Double n)
+ ,bench "sumAll reshape(iota) [1e6]" $
+ nf (\a -> rsumAllPrim a)
+ (rreshape (1 :$: n :$: 1 :$: ZSR) $ riota @Double n)
]
,bgroup "NumElt"
[bench "sum(+) Double [1e6]" $
diff --git a/cbits/arith_lists.h b/cbits/arith_lists.h
index 432765c..dc9ad1a 100644
--- a/cbits/arith_lists.h
+++ b/cbits/arith_lists.h
@@ -2,38 +2,38 @@ LIST_BINOP(BO_ADD, 1, +)
LIST_BINOP(BO_SUB, 2, -)
LIST_BINOP(BO_MUL, 3, *)
-LIST_IBINOP(IB_QUOT, 1, quot)
-LIST_IBINOP(IB_REM, 2, rem)
+LIST_IBINOP(IB_QUOT, 11, quot)
+LIST_IBINOP(IB_REM, 12, rem)
-LIST_FBINOP(FB_DIV, 1, /)
-LIST_FBINOP(FB_POW, 2, **)
-LIST_FBINOP(FB_LOGBASE, 3, logBase)
-LIST_FBINOP(FB_ATAN2, 4, atan2)
+LIST_FBINOP(FB_DIV, 21, /)
+LIST_FBINOP(FB_POW, 22, **)
+LIST_FBINOP(FB_LOGBASE, 23, logBase)
+LIST_FBINOP(FB_ATAN2, 24, atan2)
-LIST_UNOP(UO_NEG, 1,)
-LIST_UNOP(UO_ABS, 2,)
-LIST_UNOP(UO_SIGNUM, 3,)
+LIST_UNOP(UO_NEG, 31,)
+LIST_UNOP(UO_ABS, 32,)
+LIST_UNOP(UO_SIGNUM, 33,)
-LIST_FUNOP(FU_RECIP, 1,)
-LIST_FUNOP(FU_EXP, 2,)
-LIST_FUNOP(FU_LOG, 3,)
-LIST_FUNOP(FU_SQRT, 4,)
-LIST_FUNOP(FU_SIN, 5,)
-LIST_FUNOP(FU_COS, 6,)
-LIST_FUNOP(FU_TAN, 7,)
-LIST_FUNOP(FU_ASIN, 8,)
-LIST_FUNOP(FU_ACOS, 9,)
-LIST_FUNOP(FU_ATAN, 10,)
-LIST_FUNOP(FU_SINH, 11,)
-LIST_FUNOP(FU_COSH, 12,)
-LIST_FUNOP(FU_TANH, 13,)
-LIST_FUNOP(FU_ASINH, 14,)
-LIST_FUNOP(FU_ACOSH, 15,)
-LIST_FUNOP(FU_ATANH, 16,)
-LIST_FUNOP(FU_LOG1P, 17,)
-LIST_FUNOP(FU_EXPM1, 18,)
-LIST_FUNOP(FU_LOG1PEXP, 19,)
-LIST_FUNOP(FU_LOG1MEXP, 20,)
+LIST_FUNOP(FU_RECIP, 41,)
+LIST_FUNOP(FU_EXP, 42,)
+LIST_FUNOP(FU_LOG, 43,)
+LIST_FUNOP(FU_SQRT, 44,)
+LIST_FUNOP(FU_SIN, 45,)
+LIST_FUNOP(FU_COS, 46,)
+LIST_FUNOP(FU_TAN, 47,)
+LIST_FUNOP(FU_ASIN, 48,)
+LIST_FUNOP(FU_ACOS, 49,)
+LIST_FUNOP(FU_ATAN, 50,)
+LIST_FUNOP(FU_SINH, 51,)
+LIST_FUNOP(FU_COSH, 52,)
+LIST_FUNOP(FU_TANH, 53,)
+LIST_FUNOP(FU_ASINH, 54,)
+LIST_FUNOP(FU_ACOSH, 55,)
+LIST_FUNOP(FU_ATANH, 56,)
+LIST_FUNOP(FU_LOG1P, 57,)
+LIST_FUNOP(FU_EXPM1, 58,)
+LIST_FUNOP(FU_LOG1PEXP, 59,)
+LIST_FUNOP(FU_LOG1MEXP, 60,)
-LIST_REDOP(RO_SUM, 1,)
-LIST_REDOP(RO_PRODUCT, 2,)
+LIST_REDOP(RO_SUM, 81,)
+LIST_REDOP(RO_PRODUCT, 82,)
diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs
index 7578dd8..2eb0666 100644
--- a/ops/Data/Array/Strided/Arith/Internal.hs
+++ b/ops/Data/Array/Strided/Arith/Internal.hs
@@ -27,7 +27,7 @@ import Foreign.C.Types
import Foreign.Ptr
import Foreign.Storable
import GHC.TypeLits
-import GHC.TypeNats qualified as TypeNats
+import GHC.TypeNats qualified as TN
import Language.Haskell.TH
import System.IO (hFlush, stdout)
import System.IO.Unsafe
@@ -42,7 +42,7 @@ import Data.Array.Strided.Array
-- TODO: move this to a utilities module
fromSNat' :: SNat n -> Int
-fromSNat' = fromIntegral . fromSNat
+fromSNat' = fromEnum . TN.fromSNat
data Dict c where
Dict :: c => Dict c
@@ -179,7 +179,7 @@ unreplicateStrides (Array sh strides offset vec) =
unrepSize = product [n | (n, True) <- zip sh replDims]
- in TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
+ in TN.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
Unreplicated (Array @lenshF shF stridesF offset vec) unrepSize (reinsertZeros replDims)
simplifyArray :: Array n a
@@ -200,7 +200,7 @@ simplifyArray :: Array n a
-> r)
-> r
simplifyArray array k
- | let revDims = map (<0) (arrStrides array)
+ | let revDims = map (< 0) (arrStrides array)
, Unreplicated array' unrepSize rereplicate <- unreplicateStrides (arrayRevDims revDims array)
= k array'
unrepSize
@@ -258,7 +258,7 @@ simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k
, let unrepSize = product [n | (n, True) <- zip sh replDims]
- = TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
+ = TN.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
k @lenshF
(Array shF strides1F offset1 vec1)
(Array shF strides2F offset2 vec2)
@@ -386,7 +386,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides off
VS.unsafeWith vec' $ \pv ->
let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a))
in fred (fromIntegral ndims') (ptrconv poutv) psh pstrides (ptrconv pv')
- TypeNats.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do
+ TN.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do
(Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
LTI -> pure Dict
EQI -> pure Dict
@@ -491,7 +491,7 @@ vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner
fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv)
pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1'))
pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2'))
- TypeNats.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do
+ TN.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do
(Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
LTI -> pure Dict
EQI -> pure Dict
@@ -879,19 +879,19 @@ instance IntElt Int64 where
instance IntElt Int where
intEltQuot = intWidBranch2 @Int quot
- (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
- (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+ (c_ibinary_i32_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vv_strided (aiboEnum IB_QUOT))
+ (c_ibinary_i64_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vv_strided (aiboEnum IB_QUOT))
intEltRem = intWidBranch2 @Int rem
- (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
- (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+ (c_ibinary_i32_sv_strided (aiboEnum IB_REM)) (c_ibinary_i32_vs_strided (aiboEnum IB_REM)) (c_ibinary_i32_vv_strided (aiboEnum IB_REM))
+ (c_ibinary_i64_sv_strided (aiboEnum IB_REM)) (c_ibinary_i64_vs_strided (aiboEnum IB_REM)) (c_ibinary_i64_vv_strided (aiboEnum IB_REM))
instance IntElt CInt where
intEltQuot = intWidBranch2 @CInt quot
- (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT))
- (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT))
+ (c_ibinary_i32_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vv_strided (aiboEnum IB_QUOT))
+ (c_ibinary_i64_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vv_strided (aiboEnum IB_QUOT))
intEltRem = intWidBranch2 @CInt rem
- (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM))
- (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM))
+ (c_ibinary_i32_sv_strided (aiboEnum IB_REM)) (c_ibinary_i32_vs_strided (aiboEnum IB_REM)) (c_ibinary_i32_vv_strided (aiboEnum IB_REM))
+ (c_ibinary_i64_sv_strided (aiboEnum IB_REM)) (c_ibinary_i64_vs_strided (aiboEnum IB_REM)) (c_ibinary_i64_vv_strided (aiboEnum IB_REM))
class NumElt a => FloatElt a where
floatEltDiv :: SNat n -> Array n a -> Array n a -> Array n a
diff --git a/ox-arrays.cabal b/ox-arrays.cabal
index 3a92f6e..eb65e18 100644
--- a/ox-arrays.cabal
+++ b/ox-arrays.cabal
@@ -57,7 +57,9 @@ flag default-show-instances
common basics
default-language: Haskell2010
- ghc-options: -Wall -Wcompat -Widentities -Wunused-packages
+ ghc-options: -Wall -Wcompat -Widentities -Wunused-packages -Wpartial-fields -Wredundant-bang-patterns -Woperator-whitespace -Wredundant-strictness-flags
+ if impl(ghc >= 9.14)
+ ghc-options: -Wno-pattern-namespace-specifier
library
import: basics
@@ -68,7 +70,6 @@ library
Data.Array.Nested.Convert
Data.Array.Nested.Mixed
Data.Array.Nested.Mixed.Shape
- Data.Array.Nested.Mixed.Shape.Internal
Data.Array.Nested.Lemmas
Data.Array.Nested.Permutation
Data.Array.Nested.Ranked
@@ -91,26 +92,28 @@ library
exposed-modules:
Data.Array.Nested.Trace
Data.Array.Nested.Trace.TH
+ build-depends:
+ template-haskell
+ other-extensions: TemplateHaskell
if flag(default-show-instances)
cpp-options: -DOXAR_DEFAULT_SHOW_INSTANCES
build-depends:
- strided-array-ops,
+ ox-arrays:strided-array-ops,
base,
deepseq < 1.7,
ghc-typelits-knownnat,
ghc-typelits-natnormalise,
orthotope < 0.2,
- template-haskell,
vector,
vector-stream
hs-source-dirs: src
- other-extensions: TemplateHaskell
library strided-array-ops
import: basics
+ visibility: public
exposed-modules:
Data.Array.Strided
Data.Array.Strided.Array
@@ -179,7 +182,7 @@ benchmark bench
main-is: Main.hs
build-depends:
ox-arrays,
- strided-array-ops,
+ ox-arrays:strided-array-ops,
base,
hmatrix,
orthotope,
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index 8c88d23..408bf8a 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -15,10 +15,10 @@
module Data.Array.Nested.Convert (
-- * Shape\/index\/list casting functions
-- ** To ranked
- ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShX, shrFromShX2,
+ ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShXAnyShape, shrFromShX,
listrCast, ixrCast, shrCast,
-- ** To shaped
- ixsFromIxR, ixsFromIxR', ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX,
+ ixsFromIxR, ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX,
ixsCast,
-- ** To mixed
ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS,
@@ -38,9 +38,11 @@ module Data.Array.Nested.Convert (
) where
import Control.Category
+import Data.Coerce (coerce)
import Data.Proxy
import Data.Type.Equality
import GHC.TypeLits
+import Unsafe.Coerce (unsafeCoerce)
import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
@@ -55,48 +57,39 @@ import Data.Array.Nested.Types
-- * To ranked
+-- TODO: change all those unsafeCoerces into coerces by defining shaped
+-- and ranekd index types as newtypes of the mixed index type
+-- and similarly for the sized lists or, preferably, by defining
+-- all as newtypes over [], exploiting fusion and getting free toList.
ixrFromIxS :: IxS sh i -> IxR (Rank sh) i
-ixrFromIxS ZIS = ZIR
-ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix
+ixrFromIxS = unsafeCoerce
-ixrFromIxX :: IxX sh i -> IxR (Rank sh) i
-ixrFromIxX ZIX = ZIR
-ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx
+-- ixrFromIxX re-exported
shrFromShS :: ShS sh -> IShR (Rank sh)
shrFromShS ZSS = ZSR
shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh
--- shrFromShX re-exported
--- shrFromShX2 re-exported
+shrFromShXAnyShape :: IShX sh -> IShR (Rank sh)
+shrFromShXAnyShape ZSX = ZSR
+shrFromShXAnyShape (n :$% idx) = fromSMayNat' n :$: shrFromShXAnyShape idx
+
+shrFromShX :: IShX (Replicate n Nothing) -> IShR n
+shrFromShX = coerce
+
-- listrCast re-exported
-- ixrCast re-exported
-- shrCast re-exported
-- * To shaped
--- TODO: these take a ShS because there are KnownNats inside IxS.
-
-ixsFromIxR :: ShS sh -> IxR (Rank sh) i -> IxS sh i
-ixsFromIxR ZSS ZIR = ZIS
-ixsFromIxR (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR sh idx
+ixsFromIxR :: IxR (Rank sh) i -> IxS sh i
+ixsFromIxR = unsafeCoerce -- TODO: switch to coerce once newtypes overhauled
--- | Performs a runtime check that @n@ matches @Rank sh@. Equivalent to the
--- following, but more efficient:
---
--- > ixsFromIxR' sh idx = ixsFromIxR sh (ixrCast (shsRank sh) idx)
-ixsFromIxR' :: ShS sh -> IxR n i -> IxS sh i
-ixsFromIxR' ZSS ZIR = ZIS
-ixsFromIxR' (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR' sh idx
-ixsFromIxR' _ _ = error "ixsFromIxR': index rank does not match shape rank"
-
--- TODO: this takes a ShS because there are KnownNats inside IxS.
-ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i
-ixsFromIxX ZSS ZIX = ZIS
-ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx
+-- ixsFromIxX re-exported
-- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to
--- the following, but more efficient:
+-- the following, but less verbose:
--
-- > ixsFromIxX' sh idx = ixsFromIxX sh (ixxCast (shxFromShS sh) idx)
ixsFromIxX' :: ShS sh -> IxX sh' i -> IxS sh i
@@ -113,7 +106,8 @@ withShsFromShR (n :$: sh) k =
Just sn@SNat -> k (sn :$$ sh')
Nothing -> error $ "withShsFromShR: negative dimension size (" ++ show n ++ ")"
--- shsFromShX re-exported
+shsFromShX :: IShX (MapJust sh) -> ShS sh
+shsFromShX = coerce
-- | Produce an existential 'ShS' from an 'IShX'. If you already know that
-- @sh'@ is @MapJust@ of something, use 'shsFromShX' instead.
@@ -128,6 +122,7 @@ withShsFromShX (SUnknown n :$% sh) k =
Just sn@SNat -> k (sn :$$ sh')
Nothing -> error $ "withShsFromShX: negative SUnknown dimension size (" ++ show n ++ ")"
+-- If it ever matters for performance, this is unsafeCoercible.
shsFromSSX :: StaticShX (MapJust sh) -> ShS sh
shsFromSSX = shsFromShX Prelude.. shxFromSSX
@@ -135,25 +130,14 @@ shsFromSSX = shsFromShX Prelude.. shxFromSSX
-- * To mixed
-ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i
-ixxFromIxR ZIR = ZIX
-ixxFromIxR (n :.: (idx :: IxR m i)) =
- castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) (Proxy @m)))
- (n :.% ixxFromIxR idx)
-
-ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i
-ixxFromIxS ZIS = ZIX
-ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh
+-- ixxFromIxR re-exported
+-- ixxFromIxS re-exported
shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i
-shxFromShR ZSR = ZSX
-shxFromShR (n :$: (idx :: ShR m i)) =
- castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) (Proxy @m)))
- (SUnknown n :$% shxFromShR idx)
+shxFromShR = coerce
shxFromShS :: ShS sh -> IShX (MapJust sh)
-shxFromShS ZSS = ZSX
-shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh
+shxFromShS = coerce
-- ixxCast re-exported
-- shxCast re-exported
diff --git a/src/Data/Array/Nested/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs
index e089479..fa5611b 100644
--- a/src/Data/Array/Nested/Lemmas.hs
+++ b/src/Data/Array/Nested/Lemmas.hs
@@ -56,6 +56,20 @@ lemReplicatePlusApp sn _ _ = go sn
-}
lemReplicatePlusApp _ _ _ = unsafeCoerceRefl
+lemReplicateEmpty :: proxy n -> Replicate n (Nothing @Nat) :~: '[] -> n :~: 0
+lemReplicateEmpty _ Refl = unsafeCoerceRefl
+
+-- TODO: make less ad-hoc and rename these three:
+lemReplicateCons :: proxy sh -> proxy' n1 -> Nothing : sh :~: Replicate n1 Nothing -> n1 :~: Rank sh + 1
+lemReplicateCons _ _ Refl = unsafeCoerceRefl
+
+lemReplicateCons2 :: proxy sh -> proxy' n1 -> Nothing : sh :~: Replicate n1 Nothing -> sh :~: Replicate (Rank sh) Nothing
+lemReplicateCons2 _ _ Refl = unsafeCoerceRefl
+
+lemReplicateSucc2 :: forall n1 n proxy.
+ proxy n1 -> n + 1 :~: n1 -> Nothing @Nat : Replicate n Nothing :~: Replicate n1 Nothing
+lemReplicateSucc2 _ _ = unsafeCoerceRefl
+
lemDropLenApp :: Rank l1 <= Rank l2
=> Proxy l1 -> Proxy l2 -> Proxy rest
-> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest)
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 182943d..39f00fa 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -23,12 +23,14 @@ module Data.Array.Nested.Mixed where
import Prelude hiding (mconcat)
import Control.DeepSeq (NFData(..))
-import Control.Monad (forM_, when)
+import Control.Monad (foldM_, forM_, when)
import Control.Monad.ST
+import Data.Array.Internal qualified as OI
+import Data.Array.Internal.RankedG qualified as ORG
+import Data.Array.Internal.RankedS qualified as ORS
import Data.Array.RankedS qualified as S
import Data.Bifunctor (bimap)
import Data.Coerce
-import Data.Foldable (toList)
import Data.Int
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty(..))
@@ -39,6 +41,7 @@ import Data.Vector.Storable qualified as VS
import Data.Vector.Storable.Mutable qualified as VSM
import Foreign.C.Types (CInt)
import Foreign.Storable (Storable)
+import Foreign.Storable qualified as Storable
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import GHC.TypeLits
@@ -237,11 +240,13 @@ instance Elt a => NFData (Mixed sh a) where
rnf = mrnf
+{-# INLINE mliftNumElt1 #-}
mliftNumElt1 :: (PrimElt a, PrimElt b)
=> (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b)
-> Mixed sh a -> Mixed sh b
mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr))
+{-# INLINE mliftNumElt2 #-}
mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c)
=> (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c)
-> Mixed sh a -> Mixed sh b -> Mixed sh c
@@ -310,7 +315,7 @@ class Elt a where
-- | See 'mfromListOuter'. If the list does not have the given length, a
-- runtime error is thrown. 'mfromListPrimSN' is faster if applicable.
- mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a
+ mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a
mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]
@@ -355,6 +360,9 @@ class Elt a where
-- | Tree giving the shape of every array component.
type ShapeTree a
+ -- | Produces an internal representation of a tree of shapes of (potentially)
+ -- nested arrays. If the argument is an array, it requires that the array
+ -- is not empty (it's guaranteed to crash early otherwise).
mshapeTree :: a -> ShapeTree a
mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool
@@ -367,17 +375,21 @@ class Elt a where
-- this mixed array.
marrayStrides :: Mixed sh a -> Bag [Int]
- -- | Given the shape of this array, an index and a value, write the value at
+ -- | Given a linear index and a value, write the value at
-- that index in the vectors.
- mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+ mvecsWriteLinear :: Int -> a -> MixedVecs s sh a -> ST s ()
- -- | Given the shape of this array, an index and a value, write the value at
+ -- | Given a linear index and a value, write the value at
-- that index in the vectors.
- mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+ mvecsWritePartialLinear :: Proxy sh -> Int -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
-- | Given the shape of this array, finalise the vectors into 'XArray's.
mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
+ -- | 'mvecsFreeze' but without copying the mutable vectors before freezing
+ -- them. If the mutable vectors are mutated after this function, referential
+ -- transparency is broken.
+ mvecsUnsafeFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
-- | Element types for which we have evidence of the (static part of the) shape
-- in a type class constraint. Compare the instance contexts of the instances
@@ -393,11 +405,18 @@ class Elt a => KnownElt a where
-- this vector and an example for the contents.
mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a)
+ -- | Create initialised vectors for this array type, given the shape of
+ -- this vector and the chosen element.
+ mvecsReplicate :: IShX sh -> a -> ST s (MixedVecs s sh a)
+
mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
-- Arrays of scalars are basically just arrays of scalars.
instance Storable a => Elt (Primitive a) where
+ -- Somehow, INLINE here can increase allocation with GHC 9.14.1.
+ -- Maybe that happens in void instances such as @Primitive ()@.
+ {-# INLINEABLE mshape #-}
mshape (M_Primitive sh _) = sh
{-# INLINEABLE mindex #-}
mindex (M_Primitive _ a) i = Primitive (X.index a i)
@@ -405,10 +424,11 @@ instance Storable a => Elt (Primitive a) where
mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i)
mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
mfromListOuterSN sn l@(arr1 :| _) =
- let sh = SKnown sn :$% mshape arr1
- in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l)))
+ let sh = mshape arr1
+ in M_Primitive (SKnown sn :$% sh) (X.fromListOuterSN sn sh ((\(M_Primitive _ a) -> a) <$> l))
mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
+ {-# INLINE mlift #-}
mlift :: forall sh1 sh2.
StaticShX sh2
-> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
@@ -419,6 +439,7 @@ instance Storable a => Elt (Primitive a) where
, let result = f ZKX a
= M_Primitive (X.shape ssh2 result) result
+ {-# INLINE mlift2 #-}
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)
@@ -430,6 +451,7 @@ instance Storable a => Elt (Primitive a) where
, let result = f ZKX a b
= M_Primitive (X.shape ssh3 result) result
+ {-# INLINE mliftL #-}
mliftL :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
@@ -464,18 +486,22 @@ instance Storable a => Elt (Primitive a) where
mshapeTreeIsEmpty _ () = False
mshowShapeTree _ () = "()"
marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr)
- mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
+ mvecsWriteLinear i (Primitive x) (MV_Primitive v) = VSM.write v i x
- -- TODO: this use of toVector is suboptimal
- mvecsWritePartial
+ -- TODO: this use of toVectorListT is suboptimal
+ mvecsWritePartialLinear
:: forall sh' sh s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
- mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
+ Proxy sh -> Int -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
+ mvecsWritePartialLinear _ i (M_Primitive sh' arr@(XArray (ORS.A (ORG.A sht t)))) (MV_Primitive v) = do
let arrsh = X.shape (ssxFromShX sh') arr
- offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
- VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
+ offset = i * shxSize arrsh
+ f off el = do
+ VS.copy (VSM.slice off (VS.length el) v) el
+ return $! off + VS.length el
+ foldM_ f offset (OI.toVectorListT sht t)
mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v
+ mvecsUnsafeFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.unsafeFreeze v
-- [PRIMITIVE ELEMENT TYPES LIST]
deriving via Primitive Bool instance Elt Bool
@@ -492,6 +518,7 @@ deriving via Primitive () instance Elt ()
instance Storable a => KnownElt (Primitive a) where
memptyArrayUnsafe sh = M_Primitive sh (X.empty sh)
mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh)
+ mvecsReplicate sh (Primitive a) = MV_Primitive <$> VSM.replicate (shxSize sh) a
mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
-- [PRIMITIVE ELEMENT TYPES LIST]
@@ -508,16 +535,22 @@ deriving via Primitive () instance KnownElt ()
-- Arrays of pairs are pairs of arrays.
instance (Elt a, Elt b) => Elt (a, b) where
+ {-# INLINEABLE mshape #-}
mshape (M_Tup2 a _) = mshape a
+ {-# INLINEABLE mindex #-}
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
+ {-# INLINEABLE mindexPartial #-}
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
mfromListOuterSN sn l =
M_Tup2 (mfromListOuterSN sn ((\(M_Tup2 x _) -> x) <$> l))
(mfromListOuterSN sn ((\(M_Tup2 _ y) -> y) <$> l))
mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)
+ {-# INLINE mlift #-}
mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
+ {-# INLINE mlift2 #-}
mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
+ {-# INLINE mliftL #-}
mliftL ssh2 f =
let unzipT2l [] = ([], [])
unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2)
@@ -542,17 +575,19 @@ instance (Elt a, Elt b) => Elt (a, b) where
mshapeTreeIsEmpty _ (t1, t2) = mshapeTreeIsEmpty (Proxy @a) t1 && mshapeTreeIsEmpty (Proxy @b) t2
mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b
- mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
- mvecsWrite sh i x a
- mvecsWrite sh i y b
- mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
- mvecsWritePartial sh i x a
- mvecsWritePartial sh i y b
+ mvecsWriteLinear i (x, y) (MV_Tup2 a b) = do
+ mvecsWriteLinear i x a
+ mvecsWriteLinear i y b
+ mvecsWritePartialLinear proxy i (M_Tup2 x y) (MV_Tup2 a b) = do
+ mvecsWritePartialLinear proxy i x a
+ mvecsWritePartialLinear proxy i y b
mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
+ mvecsUnsafeFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsUnsafeFreeze sh a <*> mvecsUnsafeFreeze sh b
instance (KnownElt a, KnownElt b) => KnownElt (a, b) where
memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh)
mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
+ mvecsReplicate sh (x, y) = MV_Tup2 <$> mvecsReplicate sh x <*> mvecsReplicate sh y
mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b)
-- Arrays of arrays are just arrays, but with more dimensions.
@@ -560,13 +595,16 @@ instance Elt a => Elt (Mixed sh' a) where
-- TODO: this is quadratic in the nesting depth because it repeatedly
-- truncates the shape vector to one a little shorter. Fix with a
-- moverlongShape method, a prefix of which is mshape.
+ {-# INLINEABLE mshape #-}
mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
mshape (M_Nest sh arr)
- = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr))
+ = shxTakeSh (Proxy @sh') sh (mshape arr)
+ {-# INLINEABLE mindex #-}
mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
mindex (M_Nest _ arr) = mindexPartial arr
+ {-# INLINEABLE mindexPartial #-}
mindexPartial :: forall sh1 sh2.
Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
mindexPartial (M_Nest sh arr) i
@@ -581,16 +619,17 @@ instance Elt a => Elt (Mixed sh' a) where
mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr)
+ {-# INLINE mlift #-}
mlift :: forall sh1 sh2.
StaticShX sh2
-> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
mlift ssh2 f (M_Nest sh1 arr) =
let result = mlift (ssxAppend ssh2 ssh') f' arr
- (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result)
+ sh2 = shxTakeSSX (Proxy @sh') ssh2 (mshape result)
in M_Nest sh2 result
where
- ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr)))
+ ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr))
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
f' sshT
@@ -598,16 +637,17 @@ instance Elt a => Elt (Mixed sh' a) where
, Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
= f (ssxAppend ssh' sshT)
+ {-# INLINE mlift2 #-}
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)
mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) =
let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2
- (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result)
+ sh3 = shxTakeSSX (Proxy @sh') ssh3 (mshape result)
in M_Nest sh3 result
where
- ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1)))
+ ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr1))
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
f' sshT
@@ -616,16 +656,17 @@ instance Elt a => Elt (Mixed sh' a) where
, Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
= f (ssxAppend ssh' sshT)
+ {-# INLINE mliftL #-}
mliftL :: forall sh1 sh2.
StaticShX sh2
-> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b))
-> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a))
mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) =
let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l)
- (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result))
+ sh2 = shxTakeSSX (Proxy @sh') ssh2 (mshape (NE.head result))
in fmap (M_Nest sh2) result
where
- ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1)))
+ ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr1))
f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b)
f' sshT
@@ -658,12 +699,13 @@ instance Elt a => Elt (Mixed sh' a) where
mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
mconcat l@(M_Nest sh1 _ :| _) =
let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l)
- in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape result))) result
+ in M_Nest (shxTakeSh (Proxy @sh') sh1 (mshape result)) result
mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr
type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a)
+ -- This requires that @arr@ is not empty.
mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShX (mshape arr)))))
@@ -676,17 +718,20 @@ instance Elt a => Elt (Mixed sh' a) where
marrayStrides (M_Nest _ arr) = marrayStrides arr
- mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
+ mvecsWriteLinear :: forall s sh. Int -> Mixed sh' a -> MixedVecs s sh (Mixed sh' a) -> ST s ()
+ mvecsWriteLinear idx val (MV_Nest _ vecs) = mvecsWritePartialLinear (Proxy @sh) idx val vecs
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
- -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
- -> ST s ()
- mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
+ mvecsWritePartialLinear
+ :: forall sh1 sh2 s.
+ Proxy sh1 -> Int -> Mixed sh2 (Mixed sh' a)
+ -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
+ -> ST s ()
+ mvecsWritePartialLinear proxy idx (M_Nest _ arr) (MV_Nest _ vecs)
| Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs
+ = mvecsWritePartialLinear proxy idx arr vecs
mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs
+ mvecsUnsafeFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsUnsafeFreeze (shxAppend sh sh') vecs
instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh'))))
@@ -697,9 +742,28 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
where
sh' = mshape example
+ mvecsReplicate sh example = do
+ vecs <- mvecsUnsafeNew sh example
+ forM_ [0 .. shxSize sh - 1] $ \idx -> mvecsWriteLinear idx example vecs
+ -- this is a slow case, but the alternative, mvecsUnsafeNew with manual
+ -- writing in a loop, leads to every case being as slow
+ return vecs
+
mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
+-- | Given the shape of this array, an index and a value, write the value at
+-- that index in the vectors.
+{-# INLINE mvecsWrite #-}
+mvecsWrite :: Elt a => IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+mvecsWrite sh idx val vecs = mvecsWriteLinear (ixxToLinear sh idx) val vecs
+
+-- | Given the shape of this array, an index and a value, write the value at
+-- that index in the vectors.
+{-# INLINE mvecsWritePartial #-}
+mvecsWritePartial :: forall sh sh' s a. Elt a => IShX sh -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+mvecsWritePartial sh idx val vecs = mvecsWritePartialLinear (Proxy @sh) (ixxToLinear sh idx) val vecs
+
-- TODO: should we provide a function that's just memptyArrayUnsafe but with a size==0 check? That may save someone a transpose somewhere
memptyArray :: forall sh a. KnownElt a => IShX sh -> Mixed (Just 0 : sh) a
memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh)
@@ -746,7 +810,7 @@ mgenerate sh f = case shxEnum sh of
when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $
error "Data.Array.Nested mgenerate: generated values do not have equal shapes"
mvecsWrite sh idx val vecs
- mvecsFreeze sh vecs
+ mvecsUnsafeFreeze sh vecs
-- | An optimized special case of 'mgenerate', where the function results
-- are of a primitive type and so there's not need to check that all shapes
@@ -759,19 +823,23 @@ mgeneratePrim sh f =
let g i = f (ixxFromLinear sh i)
in mfromVector sh $ VS.generate (shxSize sh) g
+{-# INLINEABLE msumOuter1PrimP #-}
msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
=> Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
msumOuter1PrimP (M_Primitive (n :$% sh) arr) =
let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr)
+{-# INLINEABLE msumOuter1Prim #-}
msumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a)
=> Mixed (n : sh) a -> Mixed sh a
msumOuter1Prim = fromPrimitive . msumOuter1PrimP @sh @n @a . toPrimitive
+{-# INLINEABLE msumAllPrimP #-}
msumAllPrimP :: (Storable a, NumElt a) => Mixed sh (Primitive a) -> a
msumAllPrimP (M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr
+{-# INLINEABLE msumAllPrim #-}
msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
msumAllPrim arr = msumAllPrimP (toPrimitive arr)
@@ -782,7 +850,7 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
sn :$% sh = mshape arr1
sm :$% _ = mshape arr2
ssh = ssxFromShX sh
- snm :: SMayNat () SNat (AddMaybe n m)
+ snm :: SMayNat () (AddMaybe n m)
snm = case (sn, sm) of
(SUnknown{}, _) -> SUnknown ()
(SKnown{}, SUnknown{}) -> SUnknown ()
@@ -792,15 +860,19 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
=> StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b
f ssh' = X.append (ssxAppend ssh ssh')
+{-# INLINEABLE mfromVectorP #-}
mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
mfromVectorP sh v = M_Primitive sh (X.fromVector sh v)
+{-# INLINEABLE mfromVector #-}
mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a
mfromVector sh v = fromPrimitive (mfromVectorP sh v)
+{-# INLINEABLE mtoVectorP #-}
mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a
mtoVectorP (M_Primitive _ v) = X.toVector v
+{-# INLINEABLE mtoVector #-}
mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a
mtoVector arr = mtoVectorP (toPrimitive arr)
@@ -856,7 +928,7 @@ mfromListLinear sh l = mreshape sh (mfromList1N (shxSize sh) l)
mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a
mfromList1Prim l =
let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
+ xarr = X.fromList1 l
in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
mfromList1PrimN :: PrimElt a => Int -> [a] -> Mixed '[Nothing] a
@@ -865,11 +937,15 @@ mfromList1PrimN n l =
Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromList1PrimSN sn l)
Nothing -> error $ "mfromList1PrimN: length negative (" ++ show n ++ ")"
-mfromList1PrimSN :: PrimElt a => SNat n -> [a] -> Mixed '[Just n] a
+mfromList1PrimSN :: forall n a. PrimElt a => SNat n -> [a] -> Mixed '[Just n] a
mfromList1PrimSN sn l =
- let ssh = SKnown sn :!% ZKX
- xarr = X.fromList1 ssh l
- in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+ let ssh = SKnown sn :$% ZSX
+ in fromPrimitive $ M_Primitive ssh
+ $ if Storable.sizeOf (undefined :: a) > 0
+ then X.fromList1SN sn l
+ else case l of -- don't force the list if all elements are the same
+ a0 : _ -> X.replicateScal ssh a0
+ [] -> X.fromList1SN sn l
mfromListPrimLinear :: forall sh a. PrimElt a => IShX sh -> [a] -> Mixed sh a
mfromListPrimLinear sh l =
@@ -886,7 +962,7 @@ munScalar :: Elt a => Mixed '[] a -> a
munScalar arr = mindex arr ZIX
mnest :: forall sh sh' a. Elt a => StaticShX sh -> Mixed (sh ++ sh') a -> Mixed sh (Mixed sh' a)
-mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr
+mnest ssh arr = M_Nest (shxTakeSSX (Proxy @sh') ssh (mshape arr)) arr
munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a
munNest (M_Nest _ arr) = arr
@@ -999,6 +1075,7 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
ixxFromList (ssxFromShX sh) (numEltMaxIndex (shxRank sh) (fromO arr))
+{-# INLINEABLE mdot1Inner #-}
mdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
=> Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a
mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b))
@@ -1014,6 +1091,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
-- Prefer 'mdot1Inner' if applicable.
+{-# INLINEABLE mdot #-}
mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a
mdot a b =
munScalar $
@@ -1032,11 +1110,13 @@ mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr
mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a
mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP
+{-# INLINE mliftPrim #-}
mliftPrim :: (PrimElt a, PrimElt b)
=> (a -> b)
-> Mixed sh a -> Mixed sh b
mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr))
+{-# INLINE mliftPrim2 #-}
mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c)
=> (a -> b -> c)
-> Mixed sh a -> Mixed sh b -> Mixed sh c
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index c999853..abcf3f8 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -1,9 +1,10 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NoStarIsType #-}
@@ -31,13 +32,11 @@ import Control.DeepSeq (NFData(..))
import Data.Bifunctor (first)
import Data.Coerce
import Data.Foldable qualified as Foldable
-import Data.Functor.Const
-import Data.Functor.Product
import Data.Kind (Constraint, Type)
import Data.Monoid (Sum(..))
+import Data.Proxy
import Data.Type.Equality
-import GHC.Exts (Int(..), Int#, quotRemInt#, withDict, build)
-import GHC.Generics (Generic)
+import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
@@ -45,7 +44,6 @@ import GHC.TypeLits
import GHC.TypeLits.Orphans ()
#endif
-import Data.Array.Nested.Mixed.Shape.Internal
import Data.Array.Nested.Types
@@ -56,129 +54,107 @@ type family Rank sh where
Rank (_ : sh) = Rank sh + 1
--- * Mixed lists
+-- * Mixed lists to be used IxX and shaped and ranked lists and indexes
type role ListX nominal representational
-type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
-data ListX sh f where
- ZX :: ListX '[] f
- (::%) :: f n -> ListX sh f -> ListX (n : sh) f
-deriving instance (forall n. Eq (f n)) => Eq (ListX sh f)
-deriving instance (forall n. Ord (f n)) => Ord (ListX sh f)
+type ListX :: [Maybe Nat] -> Type -> Type
+data ListX sh i where
+ ZX :: ListX '[] i
+ (::%) :: forall n sh {i}. i -> ListX sh i -> ListX (n : sh) i
+deriving instance Eq i => Eq (ListX sh i)
+deriving instance Ord i => Ord (ListX sh i)
infixr 3 ::%
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
-deriving instance (forall n. Show (f n)) => Show (ListX sh f)
+deriving instance Show i => Show (ListX sh i)
#else
-instance (forall n. Show (f n)) => Show (ListX sh f) where
+instance Show i => Show (ListX sh i) where
showsPrec _ = listxShow shows
#endif
-instance (forall n. NFData (f n)) => NFData (ListX sh f) where
+instance NFData i => NFData (ListX sh i) where
rnf ZX = ()
rnf (x ::% l) = rnf x `seq` rnf l
-data UnconsListXRes f sh1 =
- forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n)
+data UnconsListXRes i sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh i) i
listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1)
listxUncons (i ::% shl') = Just (UnconsListXRes shl' i)
listxUncons ZX = Nothing
--- | This checks only whether the types are equal; if the elements of the list
--- are not singletons, their values may still differ. This corresponds to
--- 'testEquality', except on the penultimate type parameter.
-listxEqType :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh')
-listxEqType ZX ZX = Just Refl
-listxEqType (n ::% sh) (m ::% sh')
- | Just Refl <- testEquality n m
- , Just Refl <- listxEqType sh sh'
- = Just Refl
-listxEqType _ _ = Nothing
-
--- | This checks whether the two lists actually contain equal values. This is
--- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
--- in the @some@ package (except on the penultimate type parameter).
-listxEqual :: (TestEquality f, forall n. Eq (f n)) => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh')
-listxEqual ZX ZX = Just Refl
-listxEqual (n ::% sh) (m ::% sh')
- | Just Refl <- testEquality n m
- , n == m
- , Just Refl <- listxEqual sh sh'
- = Just Refl
-listxEqual _ _ = Nothing
-
-{-# INLINE listxFmap #-}
-listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g
-listxFmap _ ZX = ZX
-listxFmap f (x ::% xs) = f x ::% listxFmap f xs
+instance Functor (ListX l) where
+ {-# INLINE fmap #-}
+ fmap _ ZX = ZX
+ fmap f (x ::% xs) = f x ::% fmap f xs
-{-# INLINE listxFoldMap #-}
-listxFoldMap :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m
-listxFoldMap _ ZX = mempty
-listxFoldMap f (x ::% xs) = f x <> listxFoldMap f xs
+instance Foldable (ListX l) where
+ {-# INLINE foldMap #-}
+ foldMap _ ZX = mempty
+ foldMap f (x ::% xs) = f x <> foldMap f xs
+ {-# INLINE foldr #-}
+ foldr _ z ZX = z
+ foldr f z (x ::% xs) = f x (foldr f z xs)
+ toList = listxToList
+ null ZX = False
+ null _ = True
-listxLength :: ListX sh f -> Int
-listxLength = getSum . listxFoldMap (\_ -> Sum 1)
+listxLength :: ListX sh i -> Int
+listxLength = length
-listxRank :: ListX sh f -> SNat (Rank sh)
+listxRank :: ListX sh i -> SNat (Rank sh)
listxRank ZX = SNat
listxRank (_ ::% l) | SNat <- listxRank l = SNat
{-# INLINE listxShow #-}
-listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
+listxShow :: forall sh i. (i -> ShowS) -> ListX sh i -> ShowS
listxShow f l = showString "[" . go "" l . showString "]"
where
- go :: String -> ListX sh' f -> ShowS
+ go :: String -> ListX sh' i -> ShowS
go _ ZX = id
go prefix (x ::% xs) = showString prefix . f x . go "," xs
-listxFromList :: StaticShX sh -> [i] -> ListX sh (Const i)
+listxFromList :: StaticShX sh -> [i] -> ListX sh i
listxFromList topssh topl = go topssh topl
where
- go :: StaticShX sh' -> [i] -> ListX sh' (Const i)
+ go :: StaticShX sh' -> [i] -> ListX sh' i
go ZKX [] = ZX
- go (_ :!% sh) (i : is) = Const i ::% go sh is
+ go (_ :!% sh) (i : is) = i ::% go sh is
go _ _ = error $ "listxFromList: Mismatched list length (type says "
++ show (ssxLength topssh) ++ ", list has length "
++ show (length topl) ++ ")"
{-# INLINEABLE listxToList #-}
-listxToList :: ListX sh (Const i) -> [i]
+listxToList :: ListX sh i -> [i]
listxToList list = build (\(cons :: i -> is -> is) (nil :: is) ->
- let go :: ListX sh (Const i) -> is
+ let go :: ListX sh i -> is
go ZX = nil
- go (Const i ::% is) = i `cons` go is
+ go (i ::% is) = i `cons` go is
in go list)
-listxHead :: ListX (mn ': sh) f -> f mn
+listxHead :: ListX (mn ': sh) i -> i
listxHead (i ::% _) = i
listxTail :: ListX (n : sh) i -> ListX sh i
listxTail (_ ::% sh) = sh
-listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f
+listxAppend :: ListX sh i -> ListX sh' i -> ListX (sh ++ sh') i
listxAppend ZX idx' = idx'
listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx'
-listxDrop :: forall f g sh sh'. ListX sh g -> ListX (sh ++ sh') f -> ListX sh' f
+listxDrop :: forall i j sh sh'. ListX sh j -> ListX (sh ++ sh') i -> ListX sh' i
listxDrop ZX long = long
listxDrop (_ ::% short) long = case long of _ ::% long' -> listxDrop short long'
-listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f
+listxInit :: forall i n sh. ListX (n : sh) i -> ListX (Init (n : sh)) i
listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh
listxInit (_ ::% ZX) = ZX
-listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh))
+listxLast :: forall i n sh. ListX (n : sh) i -> i
listxLast (_ ::% sh@(_ ::% _)) = listxLast sh
listxLast (x ::% ZX) = x
-listxZip :: ListX sh f -> ListX sh g -> ListX sh (Product f g)
-listxZip ZX ZX = ZX
-listxZip (i ::% irest) (j ::% jrest) = Pair i j ::% listxZip irest jrest
-
{-# INLINE listxZipWith #-}
-listxZipWith :: (forall a. f a -> g a -> h a) -> ListX sh f -> ListX sh g
- -> ListX sh h
+listxZipWith :: (i -> j -> k) -> ListX sh i -> ListX sh j -> ListX sh k
listxZipWith _ ZX ZX = ZX
listxZipWith f (i ::% is) (j ::% js) = f i j ::% listxZipWith f is js
@@ -188,8 +164,8 @@ listxZipWith f (i ::% is) (j ::% js) = f i j ::% listxZipWith f is js
-- | An index into a mixed-typed array.
type role IxX nominal representational
type IxX :: [Maybe Nat] -> Type -> Type
-newtype IxX sh i = IxX (ListX sh (Const i))
- deriving (Eq, Ord, Generic)
+newtype IxX sh i = IxX (ListX sh i)
+ deriving (Eq, Ord, NFData, Functor, Foldable)
pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i
pattern ZIX = IxX ZX
@@ -198,8 +174,8 @@ pattern (:.%)
:: forall {sh1} {i}.
forall n sh. (n : sh ~ sh1)
=> i -> IxX sh i -> IxX sh1 i
-pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i)))
- where i :.% IxX shl = IxX (Const i ::% shl)
+pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) i))
+ where i :.% IxX shl = IxX (i ::% shl)
infixr 3 :.%
{-# COMPLETE ZIX, (:.%) #-}
@@ -212,25 +188,9 @@ type IIxX sh = IxX sh Int
deriving instance Show i => Show (IxX sh i)
#else
instance Show i => Show (IxX sh i) where
- showsPrec _ (IxX l) = listxShow (shows . getConst) l
+ showsPrec _ (IxX l) = listxShow shows l
#endif
-instance Functor (IxX sh) where
- {-# INLINE fmap #-}
- fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l)
-
-instance Foldable (IxX sh) where
- {-# INLINE foldMap #-}
- foldMap f (IxX l) = listxFoldMap (f . getConst) l
- {-# INLINE foldr #-}
- foldr _ z ZIX = z
- foldr f z (x :.% xs) = f x (foldr f z xs)
- toList = ixxToList
- null ZIX = False
- null _ = True
-
-instance NFData i => NFData (IxX sh i)
-
ixxLength :: IxX sh i -> Int
ixxLength (IxX l) = listxLength l
@@ -245,30 +205,30 @@ ixxZero' :: IShX sh -> IIxX sh
ixxZero' ZSX = ZIX
ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh
+{-# INLINEABLE ixxFromList #-}
ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i
ixxFromList = coerce (listxFromList @_ @i)
-{-# INLINEABLE ixxToList #-}
-ixxToList :: forall sh i. IxX sh i -> [i]
-ixxToList = coerce (listxToList @_ @i)
+ixxToList :: IxX sh i -> [i]
+ixxToList = Foldable.toList
ixxHead :: IxX (n : sh) i -> i
-ixxHead (IxX list) = getConst (listxHead list)
+ixxHead (IxX list) = listxHead list
ixxTail :: IxX (n : sh) i -> IxX sh i
ixxTail (IxX list) = IxX (listxTail list)
ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
-ixxAppend = coerce (listxAppend @_ @(Const i))
+ixxAppend = coerce (listxAppend @_ @i)
ixxDrop :: forall sh sh' i. IxX sh i -> IxX (sh ++ sh') i -> IxX sh' i
-ixxDrop = coerce (listxDrop @(Const i) @(Const i))
+ixxDrop = coerce (listxDrop @i @i)
ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i
-ixxInit = coerce (listxInit @(Const i))
+ixxInit = coerce (listxInit @i)
ixxLast :: forall n sh i. IxX (n : sh) i -> i
-ixxLast = coerce (listxLast @(Const i))
+ixxLast = coerce (listxLast @i)
ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i
ixxCast ZKX ZIX = ZIX
@@ -284,43 +244,96 @@ ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k
ixxZipWith _ ZIX ZIX = ZIX
ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js
-ixxToLinear :: IShX sh -> IIxX sh -> Int
-ixxToLinear = \sh i -> fst (go sh i)
+-- | Given a multidimensional index, get the corresponding linear
+-- index into the buffer.
+{-# INLINEABLE ixxToLinear #-}
+ixxToLinear :: Num i => IShX sh -> IxX sh i -> i
+ixxToLinear = \sh i -> go sh i 0
where
- -- returns (index in subarray, size of subarray)
- go :: IShX sh -> IIxX sh -> (Int, Int)
- go ZSX ZIX = (0, 1)
- go (n :$% sh) (i :.% ix) =
- let (lidx, sz) = go sh ix
- in (sz * i + lidx, fromSMayNat' n * sz)
+ go :: Num i => IShX sh -> IxX sh i -> i -> i
+ go ZSX ZIX !a = a
+ go (n :$% sh) (i :.% ix) a = go sh ix (fromIntegral (fromSMayNat' n) * a + i)
+{-# INLINEABLE ixxFromLinear #-}
+ixxFromLinear :: Num i => IShX sh -> Int -> IxX sh i
+ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared when it's called many times
+ let suffixes = drop 1 (scanr (*) 1 (shxToList sh))
+ in fromLin0 sh suffixes
+ where
+ -- Unfold first iteration of fromLin to do the range check.
+ -- Don't inline this function at first to allow GHC to inline the outer
+ -- function and realise that 'suffixes' is shared. But then later inline it
+ -- anyway, to avoid the function call. Removing the pragma makes GHC
+ -- somehow unable to recognise that 'suffixes' can be shared in a loop.
+ {-# NOINLINE [0] fromLin0 #-}
+ fromLin0 :: Num i => IShX sh -> [Int] -> Int -> IxX sh i
+ fromLin0 sh suffixes i =
+ if i < 0 then outrange sh i else
+ case (sh, suffixes) of
+ (ZSX, _) | i > 0 -> outrange sh i
+ | otherwise -> ZIX
+ ((fromSMayNat' -> n) :$% sh', suff : suffs) ->
+ let (q, r) = i `quotRem` suff
+ in if q >= n then outrange sh i else
+ fromIntegral q :.% fromLin sh' suffs r
+ _ -> error "impossible"
--- * Mixed shapes
+ fromLin :: Num i => IShX sh -> [Int] -> Int -> IxX sh i
+ fromLin ZSX _ !_ = ZIX
+ fromLin (_ :$% sh') (suff : suffs) i =
+ let (q, r) = i `quotRem` suff -- suff == shrSize sh'
+ in fromIntegral q :.% fromLin sh' suffs r
+ fromLin _ _ _ = error "impossible"
-data SMayNat i f n where
- SUnknown :: i -> SMayNat i f Nothing
- SKnown :: f n -> SMayNat i f (Just n)
-deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n)
-deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n)
-deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n)
+ {-# NOINLINE outrange #-}
+ outrange :: IShX sh -> Int -> a
+ outrange sh i = error $ "ixxFromLinear: out of range (" ++ show i ++
+ " in array of shape " ++ show sh ++ ")"
-instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where
+shxEnum :: IShX sh -> [IIxX sh]
+shxEnum = shxEnum'
+
+{-# INLINABLE shxEnum' #-} -- ensure this can be specialised at use site
+shxEnum' :: Num i => IShX sh -> [IxX sh i]
+shxEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shxSize sh - 1]]
+ where
+ suffixes = drop 1 (scanr (*) 1 (shxToList sh))
+
+ fromLin :: Num i => IShX sh -> [Int] -> Int# -> IxX sh i
+ fromLin ZSX _ _ = ZIX
+ fromLin (_ :$% sh') (I# suff# : suffs) i# =
+ let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh'
+ in fromIntegral (I# q#) :.% fromLin sh' suffs r#
+ fromLin _ _ _ = error "impossible"
+
+
+-- * Mixed shape-like lists to be used for ShX and StaticShX
+
+data SMayNat i n where
+ SUnknown :: i -> SMayNat i Nothing
+ SKnown :: SNat n -> SMayNat i (Just n)
+deriving instance Show i => Show (SMayNat i n)
+deriving instance Eq i => Eq (SMayNat i n)
+deriving instance Ord i => Ord (SMayNat i n)
+
+instance (NFData i, forall m. NFData (SNat m)) => NFData (SMayNat i n) where
rnf (SUnknown i) = rnf i
rnf (SKnown x) = rnf x
-instance TestEquality f => TestEquality (SMayNat i f) where
+instance TestEquality (SMayNat i) where
testEquality SUnknown{} SUnknown{} = Just Refl
testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl
testEquality _ _ = Nothing
{-# INLINE fromSMayNat #-}
fromSMayNat :: (n ~ Nothing => i -> r)
- -> (forall m. n ~ Just m => f m -> r)
- -> SMayNat i f n -> r
+ -> (forall m. n ~ Just m => SNat m -> r)
+ -> SMayNat i n -> r
fromSMayNat f _ (SUnknown i) = f i
fromSMayNat _ g (SKnown s) = g s
-fromSMayNat' :: SMayNat Int SNat n -> Int
+{-# INLINE fromSMayNat' #-}
+fromSMayNat' :: SMayNat Int n -> Int
fromSMayNat' = fromSMayNat id fromSNat'
type family AddMaybe n m where
@@ -328,27 +341,162 @@ type family AddMaybe n m where
AddMaybe (Just _) Nothing = Nothing
AddMaybe (Just n) (Just m) = Just (n + m)
-smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
+smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m)
smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)
--- | This is a newtype over 'ListX'.
+type role ListH nominal representational
+type ListH :: [Maybe Nat] -> Type -> Type
+data ListH sh i where
+ ZH :: ListH '[] i
+ ConsUnknown :: forall sh i. i -> ListH sh i -> ListH (Nothing : sh) i
+-- TODO: bring this UNPACK back when GHC no longer crashes:
+-- ConsKnown :: forall n sh i. {-# UNPACK #-} SNat n -> ListH sh i -> ListH (Just n : sh) i
+ ConsKnown :: forall n sh i. SNat n -> ListH sh i -> ListH (Just n : sh) i
+deriving instance Ord i => Ord (ListH sh i)
+
+-- A manually defined instance and this INLINEABLE is needed to specialize
+-- mdot1Inner (otherwise GHC warns specialization breaks down here).
+instance Eq i => Eq (ListH sh i) where
+ {-# INLINEABLE (==) #-}
+ ZH == ZH = True
+ ConsUnknown i1 sh1 == ConsUnknown i2 sh2 = i1 == i2 && sh1 == sh2
+ ConsKnown _ sh1 == ConsKnown _ sh2 = sh1 == sh2
+
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (ListH sh i)
+#else
+instance Show i => Show (ListH sh i) where
+ showsPrec _ = listhShow shows
+#endif
+
+instance NFData i => NFData (ListH sh i) where
+ rnf ZH = ()
+ rnf (x `ConsUnknown` l) = rnf x `seq` rnf l
+ rnf (SNat `ConsKnown` l) = rnf l
+
+data UnconsListHRes i sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsListHRes (ListH sh i) (SMayNat i n)
+listhUncons :: ListH sh1 i -> Maybe (UnconsListHRes i sh1)
+listhUncons (i `ConsUnknown` shl') = Just (UnconsListHRes shl' (SUnknown i))
+listhUncons (i `ConsKnown` shl') = Just (UnconsListHRes shl' (SKnown i))
+listhUncons ZH = Nothing
+
+-- | This checks only whether the types are equal; if the elements of the list
+-- are not singletons, their values may still differ. This corresponds to
+-- 'testEquality', except on the penultimate type parameter.
+listhEqType :: ListH sh i -> ListH sh' i -> Maybe (sh :~: sh')
+listhEqType ZH ZH = Just Refl
+listhEqType (_ `ConsUnknown` sh) (_ `ConsUnknown` sh')
+ | Just Refl <- listhEqType sh sh'
+ = Just Refl
+listhEqType (n `ConsKnown` sh) (m `ConsKnown` sh')
+ | Just Refl <- testEquality n m
+ , Just Refl <- listhEqType sh sh'
+ = Just Refl
+listhEqType _ _ = Nothing
+
+-- | This checks whether the two lists actually contain equal values. This is
+-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
+-- in the @some@ package (except on the penultimate type parameter).
+listhEqual :: Eq i => ListH sh i -> ListH sh' i -> Maybe (sh :~: sh')
+listhEqual ZH ZH = Just Refl
+listhEqual (n `ConsUnknown` sh) (m `ConsUnknown` sh')
+ | n == m
+ , Just Refl <- listhEqual sh sh'
+ = Just Refl
+listhEqual (n `ConsKnown` sh) (m `ConsKnown` sh')
+ | Just Refl <- testEquality n m
+ , Just Refl <- listhEqual sh sh'
+ = Just Refl
+listhEqual _ _ = Nothing
+
+{-# INLINE listhFmap #-}
+listhFmap :: (forall n. SMayNat i n -> SMayNat j n) -> ListH sh i -> ListH sh j
+listhFmap _ ZH = ZH
+listhFmap f (x `ConsUnknown` xs) = case f (SUnknown x) of
+ SUnknown y -> y `ConsUnknown` listhFmap f xs
+listhFmap f (x `ConsKnown` xs) = case f (SKnown x) of
+ SKnown y -> y `ConsKnown` listhFmap f xs
+
+{-# INLINE listhFoldMap #-}
+listhFoldMap :: Monoid m => (forall n. SMayNat i n -> m) -> ListH sh i -> m
+listhFoldMap _ ZH = mempty
+listhFoldMap f (x `ConsUnknown` xs) = f (SUnknown x) <> listhFoldMap f xs
+listhFoldMap f (x `ConsKnown` xs) = f (SKnown x) <> listhFoldMap f xs
+
+listhLength :: ListH sh i -> Int
+listhLength = getSum . listhFoldMap (\_ -> Sum 1)
+
+listhRank :: ListH sh i -> SNat (Rank sh)
+listhRank ZH = SNat
+listhRank (_ `ConsUnknown` l) | SNat <- listhRank l = SNat
+listhRank (_ `ConsKnown` l) | SNat <- listhRank l = SNat
+
+{-# INLINE listhShow #-}
+listhShow :: forall sh i. (forall n. SMayNat i n -> ShowS) -> ListH sh i -> ShowS
+listhShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> ListH sh' i -> ShowS
+ go _ ZH = id
+ go prefix (x `ConsUnknown` xs) = showString prefix . f (SUnknown x) . go "," xs
+ go prefix (x `ConsKnown` xs) = showString prefix . f (SKnown x) . go "," xs
+
+listhHead :: ListH (mn ': sh) i -> SMayNat i mn
+listhHead (i `ConsUnknown` _) = SUnknown i
+listhHead (i `ConsKnown` _) = SKnown i
+
+listhTail :: ListH (n : sh) i -> ListH sh i
+listhTail (_ `ConsUnknown` sh) = sh
+listhTail (_ `ConsKnown` sh) = sh
+
+listhAppend :: ListH sh i -> ListH sh' i -> ListH (sh ++ sh') i
+listhAppend ZH idx' = idx'
+listhAppend (i `ConsUnknown` idx) idx' = i `ConsUnknown` listhAppend idx idx'
+listhAppend (i `ConsKnown` idx) idx' = i `ConsKnown` listhAppend idx idx'
+
+listhDrop :: forall i j sh sh'. ListH sh j -> ListH (sh ++ sh') i -> ListH sh' i
+listhDrop ZH long = long
+listhDrop (_ `ConsUnknown` short) long = case long of
+ _ `ConsUnknown` long' -> listhDrop short long'
+listhDrop (_ `ConsKnown` short) long = case long of
+ _ `ConsKnown` long' -> listhDrop short long'
+
+listhInit :: forall i n sh. ListH (n : sh) i -> ListH (Init (n : sh)) i
+listhInit (i `ConsUnknown` sh@(_ `ConsUnknown` _)) = i `ConsUnknown` listhInit sh
+listhInit (i `ConsUnknown` sh@(_ `ConsKnown` _)) = i `ConsUnknown` listhInit sh
+listhInit (_ `ConsUnknown` ZH) = ZH
+listhInit (i `ConsKnown` sh@(_ `ConsUnknown` _)) = i `ConsKnown` listhInit sh
+listhInit (i `ConsKnown` sh@(_ `ConsKnown` _)) = i `ConsKnown` listhInit sh
+listhInit (_ `ConsKnown` ZH) = ZH
+
+listhLast :: forall i n sh. ListH (n : sh) i -> SMayNat i (Last (n : sh))
+listhLast (_ `ConsUnknown` sh@(_ `ConsUnknown` _)) = listhLast sh
+listhLast (_ `ConsUnknown` sh@(_ `ConsKnown` _)) = listhLast sh
+listhLast (x `ConsUnknown` ZH) = SUnknown x
+listhLast (_ `ConsKnown` sh@(_ `ConsUnknown` _)) = listhLast sh
+listhLast (_ `ConsKnown` sh@(_ `ConsKnown` _)) = listhLast sh
+listhLast (x `ConsKnown` ZH) = SKnown x
+
+-- * Mixed shapes
+
+-- | This is a newtype over 'ListH'.
type role ShX nominal representational
type ShX :: [Maybe Nat] -> Type -> Type
-newtype ShX sh i = ShX (ListX sh (SMayNat i SNat))
- deriving (Eq, Ord, Generic)
+newtype ShX sh i = ShX (ListH sh i)
+ deriving (Eq, Ord, NFData)
pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i
-pattern ZSX = ShX ZX
+pattern ZSX = ShX ZH
pattern (:$%)
:: forall {sh1} {i}.
forall n sh. (n : sh ~ sh1)
- => SMayNat i SNat n -> ShX sh i -> ShX sh1 i
-pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i))
- where i :$% ShX shl = ShX (i ::% shl)
+ => SMayNat i n -> ShX sh i -> ShX sh1 i
+pattern i :$% shl <- ShX (listhUncons -> Just (UnconsListHRes (ShX -> shl) i))
+ where i :$% ShX shl = case i of; SUnknown x -> ShX (x `ConsUnknown` shl); SKnown x -> ShX (x `ConsKnown` shl)
infixr 3 :$%
{-# COMPLETE ZSX, (:$%) #-}
@@ -359,17 +507,12 @@ type IShX sh = ShX sh Int
deriving instance Show i => Show (ShX sh i)
#else
instance Show i => Show (ShX sh i) where
- showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
+ showsPrec _ (ShX l) = listhShow (fromSMayNat shows (shows . fromSNat)) l
#endif
instance Functor (ShX sh) where
{-# INLINE fmap #-}
- fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l)
-
-instance NFData i => NFData (ShX sh i) where
- rnf (ShX ZX) = ()
- rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l)
- rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l)
+ fmap f (ShX l) = ShX (listhFmap (fromSMayNat (SUnknown . f) SKnown) l)
-- | This checks only whether the types are equal; unknown dimensions might
-- still differ. This corresponds to 'testEquality', except on the penultimate
@@ -401,38 +544,40 @@ shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh')
shxEqual _ _ = Nothing
shxLength :: ShX sh i -> Int
-shxLength (ShX l) = listxLength l
+shxLength (ShX l) = listhLength l
shxRank :: ShX sh i -> SNat (Rank sh)
-shxRank (ShX l) = listxRank l
+shxRank (ShX l) = listhRank l
-- | The number of elements in an array described by this shape.
shxSize :: IShX sh -> Int
shxSize ZSX = 1
shxSize (n :$% sh) = fromSMayNat' n * shxSize sh
+-- We don't report the size of the list in case of errors in order not to retain the list.
+{-# INLINEABLE shxFromList #-}
shxFromList :: StaticShX sh -> [Int] -> IShX sh
-shxFromList topssh topl = go topssh topl
+shxFromList (StaticShX topssh) topl = ShX $ go topssh topl
where
- go :: StaticShX sh' -> [Int] -> IShX sh'
- go ZKX [] = ZSX
- go (SKnown sn :!% sh) (i : is)
- | i == fromSNat' sn = SKnown sn :$% go sh is
- | otherwise = error $ "shxFromList: Value does not match typing (type says "
- ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
- go (SUnknown () :!% sh) (i : is) = SUnknown i :$% go sh is
- go _ _ = error $ "shxFromList: Mismatched list length (type says "
- ++ show (ssxLength topssh) ++ ", list has length "
- ++ show (length topl) ++ ")"
+ go :: ListH sh' () -> [Int] -> ListH sh' Int
+ go ZH [] = ZH
+ go ZH _ = error $ "shxFromList: List too long (type says " ++ show (listhLength topssh) ++ ")"
+ go (ConsKnown sn sh) (i : is)
+ | i == fromSNat' sn = ConsKnown sn (go sh is)
+ | otherwise = error $ "shxFromList: Value does not match typing"
+ go (ConsUnknown () sh) (i : is) = ConsUnknown i (go sh is)
+ go _ _ = error $ "shxFromList: List too short (type says " ++ show (listhLength topssh) ++ ")"
{-# INLINEABLE shxToList #-}
shxToList :: IShX sh -> [Int]
-shxToList list = build (\(cons :: i -> is -> is) (nil :: is) ->
- let go :: IShX sh -> is
- go ZSX = nil
- go (smn :$% sh) = fromSMayNat' smn `cons` go sh
- in go list)
+shxToList (ShX l) = build (\(cons :: i -> is -> is) (nil :: is) ->
+ let go :: ListH sh Int -> is
+ go ZH = nil
+ go (ConsUnknown i rest) = i `cons` go rest
+ go (ConsKnown sn rest) = fromSNat' sn `cons` go rest
+ in go l)
+-- If it ever matters for performance, this is unsafeCoercible.
shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i
shxFromSSX ZKX = ZSX
shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh))
@@ -447,35 +592,40 @@ shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh
shxFromSSX2 (SUnknown _ :!% _) = Nothing
shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
-shxAppend = coerce (listxAppend @_ @(SMayNat i SNat))
+shxAppend = coerce (listhAppend @_ @i)
-shxHead :: ShX (n : sh) i -> SMayNat i SNat n
-shxHead (ShX list) = listxHead list
+shxHead :: ShX (n : sh) i -> SMayNat i n
+shxHead (ShX list) = listhHead list
shxTail :: ShX (n : sh) i -> ShX sh i
-shxTail (ShX list) = ShX (listxTail list)
+shxTail (ShX list) = ShX (listhTail list)
+
+shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i
+shxTakeSSX _ ZKX _ = ZSX
+shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh
+
+shxTakeSh :: forall sh sh' i proxy. proxy sh' -> ShX sh i -> ShX (sh ++ sh') i -> ShX sh i
+shxTakeSh _ ZSX _ = ZSX
+shxTakeSh p (_ :$% ssh1) (n :$% sh) = n :$% shxTakeSh p ssh1 sh
shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i
-shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat))
+shxDropSSX = coerce (listhDrop @i @())
shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i
-shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j))
+shxDropIx ZIX long = long
+shxDropIx (_ :.% short) long = case long of _ :$% long' -> shxDropIx short long'
shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i
-shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))
+shxDropSh = coerce (listhDrop @i @i)
shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i
-shxInit = coerce (listxInit @(SMayNat i SNat))
+shxInit = coerce (listhInit @i)
-shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh))
-shxLast = coerce (listxLast @(SMayNat i SNat))
-
-shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i
-shxTakeSSX _ ZKX _ = ZSX
-shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh
+shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i (Last (n : sh))
+shxLast = coerce (listhLast @i)
{-# INLINE shxZipWith #-}
-shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n)
+shxZipWith :: (forall n. SMayNat i n -> SMayNat j n -> SMayNat k n)
-> ShX sh i -> ShX sh j -> ShX sh k
shxZipWith _ ZSX ZSX = ZSX
shxZipWith f (i :$% is) (j :$% js) = f i j :$% shxZipWith f is js
@@ -490,22 +640,6 @@ shxSplitApp :: proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX
shxSplitApp _ ZKX idx = (ZSX, idx)
shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx)
-shxEnum :: IShX sh -> [IIxX sh]
-shxEnum = shxEnum'
-
-{-# INLINABLE shxEnum' #-} -- ensure this can be specialised at use site
-shxEnum' :: Num i => IShX sh -> [IxX sh i]
-shxEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shxSize sh - 1]]
- where
- suffixes = drop 1 (scanr (*) 1 (shxToList sh))
-
- fromLin :: Num i => IShX sh -> [Int] -> Int# -> IxX sh i
- fromLin ZSX _ _ = ZIX
- fromLin (_ :$% sh') (I# suff# : suffs) i# =
- let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh'
- in fromIntegral (I# q#) :.% fromLin sh' suffs r#
- fromLin _ _ _ = error "impossible"
-
shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh')
shxCast ZKX ZSX = Just ZSX
shxCast (SKnown m :!% ssh) (SKnown n :$% sh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast ssh sh
@@ -523,20 +657,24 @@ shxCast' ssh sh = case shxCast ssh sh of
-- * Static mixed shapes
--- | The part of a shape that is statically known. (A newtype over 'ListX'.)
+-- | The part of a shape that is statically known. (A newtype over 'ListH'.)
type StaticShX :: [Maybe Nat] -> Type
-newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat))
- deriving (Eq, Ord)
+newtype StaticShX sh = StaticShX (ListH sh ())
+ deriving (NFData)
+
+instance Eq (StaticShX sh) where _ == _ = True
+instance Ord (StaticShX sh) where compare _ _ = EQ
pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
-pattern ZKX = StaticShX ZX
+pattern ZKX = StaticShX ZH
pattern (:!%)
:: forall {sh1}.
forall n sh. (n : sh ~ sh1)
- => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1
-pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i))
- where i :!% StaticShX shl = StaticShX (i ::% shl)
+ => SMayNat () n -> StaticShX sh -> StaticShX sh1
+pattern i :!% shl <- StaticShX (listhUncons -> Just (UnconsListHRes (StaticShX -> shl) i))
+ where i :!% StaticShX shl = case i of; SUnknown () -> StaticShX (() `ConsUnknown` shl); SKnown x -> StaticShX (x `ConsKnown` shl)
+
infixr 3 :!%
{-# COMPLETE ZKX, (:!%) #-}
@@ -545,51 +683,50 @@ infixr 3 :!%
deriving instance Show (StaticShX sh)
#else
instance Show (StaticShX sh) where
- showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
+ showsPrec _ (StaticShX l) = listhShow (fromSMayNat shows (shows . fromSNat)) l
#endif
-instance NFData (StaticShX sh) where
- rnf (StaticShX ZX) = ()
- rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l)
- rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l)
-
instance TestEquality StaticShX where
- testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2
+ testEquality (StaticShX l1) (StaticShX l2) = listhEqType l1 l2
ssxLength :: StaticShX sh -> Int
-ssxLength (StaticShX l) = listxLength l
+ssxLength (StaticShX l) = listhLength l
ssxRank :: StaticShX sh -> SNat (Rank sh)
-ssxRank (StaticShX l) = listxRank l
+ssxRank (StaticShX l) = listhRank l
-- | @ssxEqType = 'testEquality'@. Provided for consistency.
ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
ssxEqType = testEquality
ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
-ssxAppend ZKX sh' = sh'
-ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'
+ssxAppend = coerce (listhAppend @_ @())
-ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n
-ssxHead (StaticShX list) = listxHead list
+ssxHead :: StaticShX (n : sh) -> SMayNat () n
+ssxHead (StaticShX list) = listhHead list
ssxTail :: StaticShX (n : sh) -> StaticShX sh
-ssxTail (_ :!% ssh) = ssh
+ssxTail (StaticShX list) = StaticShX (listhTail list)
-ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh'
-ssxDropSSX = coerce (listxDrop @(SMayNat () SNat) @(SMayNat () SNat))
+ssxTakeIx :: forall sh sh' i. Proxy sh' -> IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh
+ssxTakeIx _ (IxX ZX) _ = ZKX
+ssxTakeIx proxy (IxX (_ ::% long)) short = case short of i :!% short' -> i :!% ssxTakeIx proxy (IxX long) short'
ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
-ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))
+ssxDropIx (IxX ZX) long = long
+ssxDropIx (IxX (_ ::% short)) long = case long of _ :!% long' -> ssxDropIx (IxX short) long'
ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
-ssxDropSh = coerce (listxDrop @(SMayNat () SNat) @(SMayNat i SNat))
+ssxDropSh = coerce (listhDrop @() @i)
+
+ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh'
+ssxDropSSX = coerce (listhDrop @() @())
ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh))
-ssxInit = coerce (listxInit @(SMayNat () SNat))
+ssxInit = coerce (listhInit @())
-ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh))
-ssxLast = coerce (listxLast @(SMayNat () SNat))
+ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () (Last (n : sh))
+ssxLast = coerce (listhLast @())
ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
ssxReplicate SZ = ZKX
@@ -599,7 +736,7 @@ ssxReplicate (SS (n :: SNat n'))
ssxIotaFrom :: StaticShX sh -> Int -> [Int]
ssxIotaFrom ZKX _ = []
-ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i+1)
+ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i + 1)
ssxFromShX :: ShX sh i -> StaticShX sh
ssxFromShX ZSX = ZKX
@@ -632,18 +769,18 @@ type family Flatten' acc sh where
Flatten' acc (Just n : sh) = Flatten' (acc * n) sh
-- This function is currently unused
-ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh)
+ssxFlatten :: StaticShX sh -> SMayNat () (Flatten sh)
ssxFlatten = go (SNat @1)
where
- go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh)
+ go :: SNat acc -> StaticShX sh -> SMayNat () (Flatten' acc sh)
go acc ZKX = SKnown acc
go _ (SUnknown () :!% _) = SUnknown ()
go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh
-shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh)
+shxFlatten :: IShX sh -> SMayNat Int (Flatten sh)
shxFlatten = go (SNat @1)
where
- go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh)
+ go :: SNat acc -> IShX sh -> SMayNat Int (Flatten' acc sh)
go acc ZSX = SKnown acc
go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh)
go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh
@@ -655,8 +792,8 @@ shxFlatten = go (SNat @1)
-- | Very untyped: only length is checked (at runtime).
-instance KnownShX sh => IsList (ListX sh (Const i)) where
- type Item (ListX sh (Const i)) = i
+instance KnownShX sh => IsList (ListX sh i) where
+ type Item (ListX sh i) = i
fromList = listxFromList (knownShX @sh)
toList = listxToList
@@ -667,12 +804,7 @@ instance KnownShX sh => IsList (IxX sh i) where
toList = Foldable.toList
-- | Untyped: length and known dimensions are checked (at runtime).
-instance KnownShX sh => IsList (ShX sh Int) where
- type Item (ShX sh Int) = Int
+instance KnownShX sh => IsList (IShX sh) where
+ type Item (IShX sh) = Int
fromList = shxFromList (knownShX @sh)
toList = shxToList
-
--- This needs to be at the bottom of the file to not split the file into
--- pieces; some of the shape/index stuff refers to StaticShX.
-$(ixFromLinearStub "ixxFromLinear" [t| IShX |] [t| IxX |] [p| ZSX |] (\a b -> [p| (fromSMayNat' -> $a) :$% $b |]) [| ZIX |] [| (:.%) |] [| shxToList |])
-{-# INLINEABLE ixxFromLinear #-}
diff --git a/src/Data/Array/Nested/Mixed/Shape/Internal.hs b/src/Data/Array/Nested/Mixed/Shape/Internal.hs
deleted file mode 100644
index 2a86ac1..0000000
--- a/src/Data/Array/Nested/Mixed/Shape/Internal.hs
+++ /dev/null
@@ -1,59 +0,0 @@
-{-# LANGUAGE BangPatterns #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TemplateHaskell #-}
-module Data.Array.Nested.Mixed.Shape.Internal where
-
-import Language.Haskell.TH
-
-
--- | A TH stub function to avoid having to write the same code three times for
--- the three kinds of shapes.
-ixFromLinearStub :: String
- -> TypeQ -> TypeQ
- -> PatQ -> (PatQ -> PatQ -> PatQ)
- -> ExpQ -> ExpQ
- -> ExpQ
- -> DecsQ
-ixFromLinearStub fname' ishty ixty zshC consshC ixz ixcons shtolist = do
- let fname = mkName fname'
- typesig <- [t| forall i sh. Num i => $ishty sh -> Int -> $ixty sh i |]
-
- locals <- [d|
- -- Unfold first iteration of fromLin to do the range check.
- -- Don't inline this function at first to allow GHC to inline the outer
- -- function and realise that 'suffixes' is shared. But then later inline it
- -- anyway, to avoid the function call. Removing the pragma makes GHC
- -- somehow unable to recognise that 'suffixes' can be shared in a loop.
- {-# NOINLINE [0] fromLin0 #-}
- fromLin0 :: Num i => $ishty sh -> [Int] -> Int -> $ixty sh i
- fromLin0 sh suffixes i =
- if i < 0 then outrange sh i else
- case (sh, suffixes) of
- ($zshC, _) | i > 0 -> outrange sh i
- | otherwise -> $ixz
- ($(consshC (varP (mkName "n")) (varP (mkName "sh'"))), suff : suffs) ->
- let (q, r) = i `quotRem` suff
- in if q >= n then outrange sh i else
- $ixcons (fromIntegral q) (fromLin sh' suffs r)
- _ -> error "impossible"
-
- fromLin :: Num i => $ishty sh -> [Int] -> Int -> $ixty sh i
- fromLin $zshC _ !_ = $ixz
- fromLin ($(consshC wildP (varP (mkName "sh'")))) (suff : suffs) i =
- let (q, r) = i `quotRem` suff -- suff == shrSize sh'
- in $ixcons (fromIntegral q) (fromLin sh' suffs r)
- fromLin _ _ _ = error "impossible"
-
- {-# NOINLINE outrange #-}
- outrange :: $ishty sh -> Int -> a
- outrange sh i = error $ fname' ++ ": out of range (" ++ show i ++
- " in array of shape " ++ show sh ++ ")" |]
-
- body <- [|
- \sh -> -- give this function arity 1 so that 'suffixes' is shared when
- -- it's called many times
- let suffixes = drop 1 (scanr (*) 1 ($shtolist sh))
- in fromLin0 sh suffixes |]
-
- return [SigD fname typesig
- ,FunD fname [Clause [] (NormalB body) locals]]
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs
index 065c9fd..ecdb06d 100644
--- a/src/Data/Array/Nested/Permutation.hs
+++ b/src/Data/Array/Nested/Permutation.hs
@@ -18,7 +18,6 @@
module Data.Array.Nested.Permutation where
import Data.Coerce (coerce)
-import Data.Functor.Const
import Data.List (sort)
import Data.Maybe (fromMaybe)
import Data.Proxy
@@ -172,52 +171,95 @@ type family DropLen ref l where
DropLen '[] l = l
DropLen (_ : ref) (_ : xs) = DropLen ref xs
-listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f
-listxTakeLen PNil _ = ZX
-listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh
-listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape"
-
-listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f
-listxDropLen PNil sh = sh
-listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh
-listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape"
+listhTakeLen :: forall i is sh. Perm is -> ListH sh i -> ListH (TakeLen is sh) i
+listhTakeLen PNil _ = ZH
+listhTakeLen (_ `PCons` is) (n `ConsUnknown` sh) = n `ConsUnknown` listhTakeLen is sh
+listhTakeLen (_ `PCons` is) (n `ConsKnown` sh) = n `ConsKnown` listhTakeLen is sh
+listhTakeLen (_ `PCons` _) ZH = error "Permutation longer than shape"
-listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f
-listxPermute PNil _ = ZX
-listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) =
- listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh
+listhDropLen :: forall i is sh. Perm is -> ListH sh i -> ListH (DropLen is sh) i
+listhDropLen PNil sh = sh
+listhDropLen (_ `PCons` is) (_ `ConsUnknown` sh) = listhDropLen is sh
+listhDropLen (_ `PCons` is) (_ `ConsKnown` sh) = listhDropLen is sh
+listhDropLen (_ `PCons` _) ZH = error "Permutation longer than shape"
-listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh)
-listxIndex _ _ SZ (n ::% _) = n
-listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f))
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = listxIndex p pT i sh
-listxIndex _ _ _ ZX = error "Index into empty shape"
+listhPermute :: forall i is sh. Perm is -> ListH sh i -> ListH (Permute is sh) i
+listhPermute PNil _ = ZH
+listhPermute (i `PCons` (is :: Perm is')) (sh :: ListH sh i) =
+ case listhIndex i sh of
+ SUnknown x -> x `ConsUnknown` listhPermute is sh
+ SKnown x -> x `ConsKnown` listhPermute is sh
-listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f
-listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)
+listhIndex :: forall i k sh. SNat k -> ListH sh i -> SMayNat i (Index k sh)
+listhIndex SZ (n `ConsUnknown` _) = SUnknown n
+listhIndex SZ (n `ConsKnown` _) = SKnown n
+listhIndex (SS (i :: SNat k')) ((_ :: i) `ConsUnknown` (sh :: ListH sh' i))
+ | Refl <- lemIndexSucc (Proxy @k') (Proxy @Nothing) (Proxy @sh')
+ = listhIndex i sh
+listhIndex (SS (i :: SNat k')) ((_ :: SNat n) `ConsKnown` (sh :: ListH sh' i))
+ | Refl <- lemIndexSucc (Proxy @k') (Proxy @(Just n)) (Proxy @sh')
+ = listhIndex i sh
+listhIndex _ ZH = error "Index into empty shape"
-ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i
-ixxPermutePrefix = coerce (listxPermutePrefix @(Const i))
+listhPermutePrefix :: forall i is sh. Perm is -> ListH sh i -> ListH (PermutePrefix is sh) i
+listhPermutePrefix perm sh = listhAppend (listhPermute perm (listhTakeLen perm sh)) (listhDropLen perm sh)
ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh)
-ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat))
+ssxTakeLen = coerce (listhTakeLen @())
ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh)
-ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))
+ssxDropLen = coerce (listhDropLen @())
ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
-ssxPermute = coerce (listxPermute @(SMayNat () SNat))
+ssxPermute = coerce (listhPermute @())
-ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh)
-ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat () SNat) p1 p2 i)
+ssxIndex :: SNat k -> StaticShX sh -> SMayNat () (Index k sh)
+ssxIndex k = coerce (listhIndex @() k)
ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
-ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
+ssxPermutePrefix = coerce (listhPermutePrefix @())
+
+shxTakeLen :: forall is sh. Perm is -> IShX sh -> IShX (TakeLen is sh)
+shxTakeLen = coerce (listhTakeLen @Int)
+
+shxDropLen :: Perm is -> IShX sh -> IShX (DropLen is sh)
+shxDropLen = coerce (listhDropLen @Int)
+
+shxPermute :: Perm is -> IShX sh -> IShX (Permute is sh)
+shxPermute = coerce (listhPermute @Int)
+
+shxIndex :: forall k sh i. SNat k -> ShX sh i -> SMayNat i (Index k sh)
+shxIndex k = coerce (listhIndex @i k)
shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh)
-shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))
+shxPermutePrefix = coerce (listhPermutePrefix @Int)
+
+listxTakeLen :: forall i is sh. Perm is -> ListX sh i -> ListX (TakeLen is sh) i
+listxTakeLen PNil _ = ZX
+listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh
+listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape"
+
+listxDropLen :: forall i is sh. Perm is -> ListX sh i -> ListX (DropLen is sh) i
+listxDropLen PNil sh = sh
+listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh
+listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape"
+
+listxPermute :: forall i is sh. Perm is -> ListX sh i -> ListX (Permute is sh) i
+listxPermute PNil _ = ZX
+listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) =
+ listxIndex i sh ::% listxPermute is sh
+
+listxIndex :: forall j i sh. SNat i -> ListX sh j -> j
+listxIndex SZ (n ::% _) = n
+listxIndex (SS i) (_ ::% sh) = listxIndex i sh
+listxIndex _ ZX = error "Index into empty shape"
+
+listxPermutePrefix :: forall i is sh. Perm is -> ListX sh i -> ListX (PermutePrefix is sh) i
+listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)
+
+ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i
+ixxPermutePrefix = coerce (listxPermutePrefix @i)
-- * Operations on permutations
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index d687983..b448685 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -79,6 +79,7 @@ rgeneratePrim sh f =
in rfromVector sh $ VS.generate (shrSize sh) g
-- | See the documentation of 'mlift'.
+{-# INLINE rlift #-}
rlift :: forall n1 n2 a. Elt a
=> SNat n2
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
@@ -86,12 +87,14 @@ rlift :: forall n1 n2 a. Elt a
rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr)
-- | See the documentation of 'mlift2'.
+{-# INLINE rlift2 #-}
rlift2 :: forall n1 n2 n3 a. Elt a
=> SNat n3
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b)
-> Ranked n1 a -> Ranked n2 a -> Ranked n3 a
rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)
+{-# INLINE rsumOuter1PrimP #-}
rsumOuter1PrimP :: forall n a.
(Storable a, NumElt a)
=> Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
@@ -99,13 +102,16 @@ rsumOuter1PrimP (Ranked arr)
| Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
= Ranked (msumOuter1PrimP arr)
+{-# INLINEABLE rsumOuter1Prim #-}
rsumOuter1Prim :: forall n a. (NumElt a, PrimElt a)
=> Ranked (n + 1) a -> Ranked n a
rsumOuter1Prim = rfromPrimitive . rsumOuter1PrimP . rtoPrimitive
+{-# INLINE rsumAllPrimP #-}
rsumAllPrimP :: (Storable a, NumElt a) => Ranked n (Primitive a) -> a
rsumAllPrimP (Ranked arr) = msumAllPrimP arr
+{-# INLINE rsumAllPrim #-}
rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a
rsumAllPrim (Ranked arr) = msumAllPrim arr
@@ -137,17 +143,21 @@ rappend arr1 arr2
rscalar :: Elt a => a -> Ranked 0 a
rscalar x = Ranked (mscalar x)
+{-# INLINEABLE rfromVectorP #-}
rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a)
rfromVectorP sh v
| Dict <- lemKnownReplicate (shrRank sh)
= Ranked (mfromVectorP (shxFromShR sh) v)
+{-# INLINEABLE rfromVector #-}
rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a
rfromVector sh v = rfromPrimitive (rfromVectorP sh v)
+{-# INLINEABLE rtoVectorP #-}
rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a
rtoVectorP = coerce mtoVectorP
+{-# INLINEABLE rtoVector #-}
rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
rtoVector = coerce mtoVector
@@ -220,7 +230,7 @@ rfromOrthotope sn arr
rtoOrthotope :: forall a n. PrimElt a => Ranked n a -> S.Array n a
rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr)))
- | Refl <- lemRankReplicate (shrRank $ shrFromShX2 @n sh)
+ | Refl <- lemRankReplicate (shrRank $ shrFromShX @n sh)
= arr
runScalar :: Elt a => Ranked 0 a -> a
@@ -333,6 +343,7 @@ rmaxIndexPrim rarr@(Ranked arr)
| Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
= ixrFromIxX (mmaxIndexPrim arr)
+{-# INLINEABLE rdot1Inner #-}
rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a
rdot1Inner arr1 arr2
| SNat <- rrank arr1
@@ -341,14 +352,15 @@ rdot1Inner arr1 arr2
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
-- Prefer 'rdot1Inner' if applicable.
+{-# INLINE rdot #-}
rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a
rdot = coerce mdot
rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
-rtoXArrayPrimP (Ranked arr) = first shrFromShX2 (mtoXArrayPrimP arr)
+rtoXArrayPrimP (Ranked arr) = first shrFromShX (mtoXArrayPrimP arr)
rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a)
-rtoXArrayPrim (Ranked arr) = first shrFromShX2 (mtoXArrayPrim arr)
+rtoXArrayPrim (Ranked arr) = first shrFromShX (mtoXArrayPrim arr)
rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a)
rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr)
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
index 11a8ffb..beedbcf 100644
--- a/src/Data/Array/Nested/Ranked/Base.hs
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -26,16 +26,11 @@ import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
-import Data.Type.Equality
import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import GHC.TypeLits
-#ifndef OXAR_DEFAULT_SHOW_INSTANCES
-import Data.Foldable (toList)
-#endif
-
import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
@@ -65,7 +60,7 @@ deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a)
#ifndef OXAR_DEFAULT_SHOW_INSTANCES
instance (Show a, Elt a) => Show (Ranked n a) where
showsPrec d arr@(Ranked marr) =
- let sh = show (toList (rshape arr))
+ let sh = show (shrToList (rshape arr))
in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr
#endif
@@ -87,9 +82,12 @@ newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed
-- these instances allow them to also be used as elements of arrays, thus
-- making them first-class in the API.
instance Elt a => Elt (Ranked n a) where
+ {-# INLINE mshape #-}
mshape (M_Ranked arr) = mshape arr
+ {-# INLINE mindex #-}
mindex (M_Ranked arr) i = Ranked (mindex arr i)
+ {-# INLINE mindexPartial #-}
mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
mindexPartial (M_Ranked arr) i =
coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
@@ -104,6 +102,7 @@ instance Elt a => Elt (Ranked n a) where
mtoListOuter (M_Ranked arr) =
coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr)
+ {-# INLINE mlift #-}
mlift :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
@@ -112,6 +111,7 @@ instance Elt a => Elt (Ranked n a) where
coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
mlift ssh2 f arr
+ {-# INLINE mlift2 #-}
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
@@ -120,6 +120,7 @@ instance Elt a => Elt (Ranked n a) where
coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
mlift2 ssh3 f arr1 arr2
+ {-# INLINE mliftL #-}
mliftL :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
@@ -139,7 +140,7 @@ instance Elt a => Elt (Ranked n a) where
type ShapeTree (Ranked n a) = (IShR n, ShapeTree a)
- mshapeTree (Ranked arr) = first shrFromShX2 (mshapeTree arr)
+ mshapeTree (Ranked arr) = first coerce (mshapeTree arr)
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
@@ -149,18 +150,19 @@ instance Elt a => Elt (Ranked n a) where
marrayStrides (M_Ranked arr) = marrayStrides arr
- mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
- mvecsWrite sh idx (Ranked arr) vecs =
- mvecsWrite sh idx arr
+ mvecsWriteLinear :: forall sh s. Int -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
+ mvecsWriteLinear idx (Ranked arr) vecs =
+ mvecsWriteLinear idx arr
(coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
vecs)
- mvecsWritePartial :: forall sh sh' s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
- -> MixedVecs s (sh ++ sh') (Ranked n a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs =
- mvecsWritePartial sh idx
+ mvecsWritePartialLinear
+ :: forall sh sh' s.
+ Proxy sh -> Int -> Mixed sh' (Ranked n a)
+ -> MixedVecs s (sh ++ sh') (Ranked n a)
+ -> ST s ()
+ mvecsWritePartialLinear proxy idx arr vecs =
+ mvecsWritePartialLinear proxy idx
(coerce @(Mixed sh' (Ranked n a))
@(Mixed sh' (Mixed (Replicate n Nothing) a))
arr)
@@ -176,6 +178,14 @@ instance Elt a => Elt (Ranked n a) where
(coerce @(MixedVecs s sh (Ranked n a))
@(MixedVecs s sh (Mixed (Replicate n Nothing) a))
vecs)
+ mvecsUnsafeFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
+ mvecsUnsafeFreeze sh vecs =
+ coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
+ @(Mixed sh (Ranked n a))
+ <$> mvecsUnsafeFreeze sh
+ (coerce @(MixedVecs s sh (Ranked n a))
+ @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a)
@@ -188,6 +198,10 @@ instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
| Dict <- lemKnownReplicate (SNat @n)
= MV_Ranked <$> mvecsUnsafeNew idx arr
+ mvecsReplicate idx (Ranked arr)
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsReplicate idx arr
+
mvecsNewEmpty _
| Dict <- lemKnownReplicate (SNat @n)
= MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
@@ -249,20 +263,9 @@ ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
ratan2Array = liftRanked2 matan2Array
+{-# INLINE rshape #-}
rshape :: Elt a => Ranked n a -> IShR n
-rshape (Ranked arr) = shrFromShX2 (mshape arr)
+rshape (Ranked arr) = coerce (mshape arr)
rrank :: Elt a => Ranked n a -> SNat n
rrank = shrRank . rshape
-
--- Needed already here, but re-exported in Data.Array.Nested.Convert.
-shrFromShX :: forall sh. IShX sh -> IShR (Rank sh)
-shrFromShX ZSX = ZSR
-shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx
-
--- Needed already here, but re-exported in Data.Array.Nested.Convert.
--- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'.
-shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n
-shrFromShX2 sh
- | Refl <- lemRankReplicate (Proxy @n)
- = shrFromShX sh
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index 6d61bd5..6d47ade 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -1,8 +1,5 @@
-{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
@@ -36,15 +33,16 @@ import Data.Foldable qualified as Foldable
import Data.Kind (Type)
import Data.Proxy
import Data.Type.Equality
-import GHC.Exts (Int(..), Int#, quotRemInt#, build)
-import GHC.Generics (Generic)
+import GHC.Exts (build)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
import GHC.TypeNats qualified as TN
+import Unsafe.Coerce (unsafeCoerce)
import Data.Array.Nested.Lemmas
-import Data.Array.Nested.Mixed.Shape.Internal
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
@@ -183,7 +181,12 @@ listrZipWith f (i ::: irest) (j ::: jrest) =
listrZipWith _ _ _ =
error "listrZipWith: impossible pattern needlessly required"
-listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
+listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
+listrSplitAt SZ sh = (ZR, sh)
+listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
+listrSplitAt SS{} ZR = error "m' + 1 <= 0"
+
+listrPermutePrefix :: forall i n. PermR -> ListR n i -> ListR n i
listrPermutePrefix = \perm sh ->
TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
case listrRank sh of { shlen@SNat ->
@@ -195,11 +198,6 @@ listrPermutePrefix = \perm sh ->
++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
}
where
- listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
- listrSplitAt SZ sh = (ZR, sh)
- listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
- listrSplitAt SS{} ZR = error "m' + 1 <= 0"
-
applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
applyPermRFull _ ZR _ = ZR
applyPermRFull sm@SNat (i ::: perm) l =
@@ -216,8 +214,7 @@ listrPermutePrefix = \perm sh ->
type role IxR nominal representational
type IxR :: Nat -> Type -> Type
newtype IxR n i = IxR (ListR n i)
- deriving (Eq, Ord, Generic)
- deriving newtype (Functor, Foldable)
+ deriving (Eq, Ord, NFData, Functor, Foldable)
pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
pattern ZIR = IxR ZR
@@ -243,8 +240,6 @@ instance Show i => Show (IxR n i) where
showsPrec _ (IxR l) = listrShow shows l
#endif
-instance NFData i => NFData (IxR sh i)
-
ixrLength :: IxR sh i -> Int
ixrLength (IxR l) = listrLength l
@@ -255,12 +250,12 @@ ixrZero :: SNat n -> IIxR n
ixrZero SZ = ZIR
ixrZero (SS n) = 0 :.: ixrZero n
+{-# INLINEABLE ixrFromList #-}
ixrFromList :: forall n i. SNat n -> [i] -> IxR n i
ixrFromList = coerce (listrFromList @_ @i)
-{-# INLINEABLE ixrToList #-}
-ixrToList :: forall n i. IxR n i -> [i]
-ixrToList = coerce (listrToList @_ @i)
+ixrToList :: IxR n i -> [i]
+ixrToList = Foldable.toList
ixrHead :: IxR (n + 1) i -> i
ixrHead (IxR list) = listrHead list
@@ -288,27 +283,69 @@ ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2
ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
-ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
+ixrPermutePrefix :: forall n i. PermR -> IxR n i -> IxR n i
ixrPermutePrefix = coerce (listrPermutePrefix @i)
+-- | Given a multidimensional index, get the corresponding linear
+-- index into the buffer.
+{-# INLINEABLE ixrToLinear #-}
+ixrToLinear :: Num i => IShR m -> IxR m i -> i
+ixrToLinear (ShR sh) ix = ixxToLinear sh (ixxFromIxR ix)
+
+ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i
+ixxFromIxR = unsafeCoerce -- TODO: switch to coerce once newtypes overhauled
+
+{-# INLINEABLE ixrFromLinear #-}
+ixrFromLinear :: forall i m. Num i => IShR m -> Int -> IxR m i
+ixrFromLinear (ShR sh) i
+ | Refl <- lemRankReplicate (Proxy @m)
+ = ixrFromIxX $ ixxFromLinear sh i
+
+ixrFromIxX :: IxX sh i -> IxR (Rank sh) i
+ixrFromIxX = unsafeCoerce -- TODO: switch to coerce once newtypes overhauled
+
+shrEnum :: IShR n -> [IIxR n]
+shrEnum = shrEnum'
+
+{-# INLINABLE shrEnum' #-} -- ensure this can be specialised at use site
+shrEnum' :: forall i n. Num i => IShR n -> [IxR n i]
+shrEnum' (ShR sh)
+ | Refl <- lemRankReplicate (Proxy @n)
+ = (unsafeCoerce :: [IxX (Replicate n Nothing) i] -> [IxR n i]) $ shxEnum' sh
+ -- TODO: switch to coerce once newtypes overhauled
-- * Ranked shapes
type role ShR nominal representational
type ShR :: Nat -> Type -> Type
-newtype ShR n i = ShR (ListR n i)
- deriving (Eq, Ord, Generic)
- deriving newtype (Functor, Foldable)
+newtype ShR n i = ShR (ShX (Replicate n Nothing) i)
+ deriving (Eq, Ord, NFData, Functor)
pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
-pattern ZSR = ShR ZR
+pattern ZSR <- ShR (matchZSR @n -> Just Refl)
+ where ZSR = ShR ZSX
+
+matchZSR :: forall n i. ShX (Replicate n Nothing) i -> Maybe (n :~: 0)
+matchZSR ZSX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl
+matchZSR _ = Nothing
pattern (:$:)
:: forall {n1} {i}.
forall n. (n + 1 ~ n1)
=> i -> ShR n i -> ShR n1 i
-pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
- where i :$: ShR sh = ShR (i ::: sh)
+pattern i :$: shl <- (shrUncons -> Just (UnconsShRRes shl i))
+ where i :$: ShR shl | Refl <- lemReplicateSucc2 (Proxy @n1) Refl
+ = ShR (SUnknown i :$% shl)
+
+data UnconsShRRes i n1 =
+ forall n. (n + 1 ~ n1) => UnconsShRRes (ShR n i) i
+shrUncons :: forall n1 i. ShR n1 i -> Maybe (UnconsShRRes i n1)
+shrUncons (ShR (SUnknown x :$% (sh' :: ShX sh' i)))
+ | Refl <- lemReplicateCons (Proxy @sh') (Proxy @n1) Refl
+ , Refl <- lemReplicateCons2 (Proxy @sh') (Proxy @n1) Refl
+ = Just (UnconsShRRes (ShR sh') x)
+shrUncons (ShR _) = Nothing
+
infixr 3 :$:
{-# COMPLETE ZSR, (:$:) #-}
@@ -319,85 +356,140 @@ type IShR n = ShR n Int
deriving instance Show i => Show (ShR n i)
#else
instance Show i => Show (ShR n i) where
- showsPrec _ (ShR l) = listrShow shows l
+ showsPrec d (ShR l) = showsPrec d l
#endif
-instance NFData i => NFData (ShR sh i)
-
-- | This checks only whether the ranks are equal, not whether the actual
-- values are.
shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
-shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh'
+shrEqRank ZSR ZSR = Just Refl
+shrEqRank (_ :$: sh) (_ :$: sh')
+ | Just Refl <- shrEqRank sh sh'
+ = Just Refl
+shrEqRank _ _ = Nothing
-- | This compares the shapes for value equality.
shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n')
-shrEqual (ShR sh) (ShR sh') = listrEqual sh sh'
+shrEqual ZSR ZSR = Just Refl
+shrEqual (i :$: sh) (i' :$: sh')
+ | Just Refl <- shrEqual sh sh'
+ , i == i'
+ = Just Refl
+shrEqual _ _ = Nothing
shrLength :: ShR sh i -> Int
-shrLength (ShR l) = listrLength l
+shrLength (ShR l) = shxLength l
-- | This function can also be used to conjure up a 'KnownNat' dictionary;
-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern
-- synonym yields 'KnownNat' evidence.
-shrRank :: ShR n i -> SNat n
-shrRank (ShR sh) = listrRank sh
+shrRank :: forall n i. ShR n i -> SNat n
+shrRank (ShR sh) | Refl <- lemRankReplicate (Proxy @n) = shxRank sh
-- | The number of elements in an array described by this shape.
shrSize :: IShR n -> Int
-shrSize ZSR = 1
-shrSize (n :$: sh) = n * shrSize sh
+shrSize (ShR sh) = shxSize sh
-shrFromList :: forall n i. SNat n -> [i] -> ShR n i
-shrFromList = coerce (listrFromList @_ @i)
+-- This is equivalent to but faster than @coerce (shxFromList (ssxReplicate snat))@.
+-- We don't report the size of the list in case of errors in order not to retain the list.
+{-# INLINEABLE shrFromList #-}
+shrFromList :: SNat n -> [Int] -> IShR n
+shrFromList snat topl = ShR $ ShX $ go snat topl
+ where
+ go :: SNat n -> [Int] -> ListH (Replicate n Nothing) Int
+ go SZ [] = ZH
+ go SZ _ = error $ "shrFromList: List too long (type says " ++ show (fromSNat' snat) ++ ")"
+ go (SS sn :: SNat n1) (i : is) | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ConsUnknown i (go sn is)
+ go _ _ = error $ "shrFromList: List too short (type says " ++ show (fromSNat' snat) ++ ")"
+-- This is equivalent to but faster than @coerce shxToList@.
{-# INLINEABLE shrToList #-}
-shrToList :: forall n i. ShR n i -> [i]
-shrToList = coerce (listrToList @_ @i)
+shrToList :: IShR n -> [Int]
+shrToList (ShR (ShX l)) = build (\(cons :: i -> is -> is) (nil :: is) ->
+ let go :: ListH sh Int -> is
+ go ZH = nil
+ go (ConsUnknown i rest) = i `cons` go rest
+ go ConsKnown{} = error "shrToList: impossible case"
+ in go l)
-shrHead :: ShR (n + 1) i -> i
-shrHead (ShR list) = listrHead list
+shrHead :: forall n i. ShR (n + 1) i -> i
+shrHead (ShR sh)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = case shxHead @Nothing @(Replicate n Nothing) sh of
+ SUnknown i -> i
-shrTail :: ShR (n + 1) i -> ShR n i
-shrTail (ShR list) = ShR (listrTail list)
+shrTail :: forall n i. ShR (n + 1) i -> ShR n i
+shrTail
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = coerce (shxTail @_ @_ @i)
-shrInit :: ShR (n + 1) i -> ShR n i
-shrInit (ShR list) = ShR (listrInit list)
+shrInit :: forall n i. ShR (n + 1) i -> ShR n i
+shrInit
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = -- TODO: change this and all other unsafeCoerceRefl to lemmas:
+ gcastWith (unsafeCoerceRefl
+ :: Init (Replicate (n + 1) (Nothing @Nat)) :~: Replicate n Nothing) $
+ coerce (shxInit @_ @_ @i)
-shrLast :: ShR (n + 1) i -> i
-shrLast (ShR list) = listrLast list
+shrLast :: forall n i. ShR (n + 1) i -> i
+shrLast (ShR sh)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = case shxLast sh of
+ SUnknown i -> i
+ SKnown{} -> error "shrLast: impossible SKnown"
-- | Performs a runtime check that the lengths are identical.
shrCast :: SNat n' -> ShR n i -> ShR n' i
-shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh)
+shrCast SZ ZSR = ZSR
+shrCast (SS n) (i :$: sh) = i :$: shrCast n sh
+shrCast _ _ = error "shrCast: ranks don't match"
shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
-shrAppend = coerce (listrAppend @_ @i)
-
-shrZip :: ShR n i -> ShR n j -> ShR n (i, j)
-shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2
+shrAppend =
+ -- lemReplicatePlusApp requires an SNat
+ gcastWith (unsafeCoerceRefl
+ :: Replicate n (Nothing @Nat) ++ Replicate m Nothing :~: Replicate (n + m) Nothing) $
+ coerce (shxAppend @_ @_ @i)
{-# INLINE shrZipWith #-}
shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k
-shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2
+shrZipWith _ ZSR ZSR = ZSR
+shrZipWith f (i :$: irest) (j :$: jrest) =
+ f i j :$: shrZipWith f irest jrest
+shrZipWith _ _ _ =
+ error "shrZipWith: impossible pattern needlessly required"
-shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
-shrPermutePrefix = coerce (listrPermutePrefix @i)
+shrSplitAt :: m <= n' => SNat m -> ShR n' i -> (ShR m i, ShR (n' - m) i)
+shrSplitAt SZ sh = (ZSR, sh)
+shrSplitAt (SS m) (n :$: sh) = (\(pre, post) -> (n :$: pre, post)) (shrSplitAt m sh)
+shrSplitAt SS{} ZSR = error "m' + 1 <= 0"
-shrEnum :: IShR sh -> [IIxR sh]
-shrEnum = shrEnum'
+shrIndex :: forall k sh i. SNat k -> ShR sh i -> i
+shrIndex k (ShR sh) = case shxIndex @_ @_ @i k sh of
+ SUnknown i -> i
+ SKnown{} -> error "shrIndex: impossible SKnown"
-{-# INLINABLE shrEnum' #-} -- ensure this can be specialised at use site
-shrEnum' :: Num i => IShR sh -> [IxR sh i]
-shrEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shrSize sh - 1]]
+-- Copy-pasted from listrPermutePrefix, probably unavoidably.
+shrPermutePrefix :: forall i n. PermR -> ShR n i -> ShR n i
+shrPermutePrefix = \perm sh ->
+ TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
+ case shrRank sh of { shlen@SNat ->
+ let sperm = shrFromList permlen perm in
+ case cmpNat permlen shlen of
+ LTI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post
+ EQI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post
+ GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
+ ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ }
where
- suffixes = drop 1 (scanr (*) 1 (shrToList sh))
-
- fromLin :: Num i => IShR sh -> [Int] -> Int# -> IxR sh i
- fromLin ZSR _ _ = ZIR
- fromLin (_ :$: sh') (I# suff# : suffs) i# =
- let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh'
- in fromIntegral (I# q#) :.: fromLin sh' suffs r#
- fromLin _ _ _ = error "impossible"
+ applyPermRFull :: SNat m -> ShR k Int -> ShR m i -> ShR k i
+ applyPermRFull _ ZSR _ = ZSR
+ applyPermRFull sm@SNat (i :$: perm) l =
+ TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
+ case cmpNat (SNat @(idx + 1)) sm of
+ LTI -> shrIndex si l :$: applyPermRFull sm perm l
+ EQI -> shrIndex si l :$: applyPermRFull sm perm l
+ GTI -> error "shrPermutePrefix: Index in permutation out of range"
-- | Untyped: length is checked at runtime.
@@ -413,18 +505,15 @@ instance KnownNat n => IsList (IxR n i) where
toList = Foldable.toList
-- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (ShR n i) where
- type Item (ShR n i) = i
- fromList = ShR . IsList.fromList
- toList = Foldable.toList
+instance KnownNat n => IsList (IShR n) where
+ type Item (IShR n) = Int
+ fromList = shrFromList (SNat @n)
+ toList = shrToList
-- * Internal helper functions
listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i
listrCastWithName _ SZ ZR = ZR
-listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx
+listrCastWithName name (SS n) (i ::: l) = i ::: listrCastWithName name n l
listrCastWithName name _ _ = error $ name ++ ": ranks don't match"
-
-$(ixFromLinearStub "ixrFromLinear" [t| IShR |] [t| IxR |] [p| ZSR |] (\a b -> [p| $a :$: $b |]) [| ZIR |] [| (:.:) |] [| shrToList |])
-{-# INLINEABLE ixrFromLinear #-}
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 99ad590..36ef24a 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -56,7 +56,7 @@ ssize = shsSize . sshape
sindex :: Elt a => Shaped sh a -> IIxS sh -> a
sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx)
-shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh
+shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IxS sh i -> ShS sh
shsTakeIx _ _ ZIS = ZSS
shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx
@@ -70,7 +70,7 @@ sindexPartial sarr@(Shaped arr) idx =
-- | __WARNING__: All values returned from the function must have equal shape.
-- See the documentation of 'mgenerate' for more details.
sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
-sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh))
+sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX))
-- | See 'mgeneratePrim'.
{-# INLINE sgeneratePrim #-}
@@ -81,6 +81,7 @@ sgeneratePrim sh f =
in sfromVector sh $ VS.generate (shsSize sh) g
-- | See the documentation of 'mlift'.
+{-# INLINE slift #-}
slift :: forall sh1 sh2 a. Elt a
=> ShS sh2
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
@@ -88,23 +89,28 @@ slift :: forall sh1 sh2 a. Elt a
slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShX (shxFromShS sh2)) f arr)
-- | See the documentation of 'mlift'.
+{-# INLINE slift2 #-}
slift2 :: forall sh1 sh2 sh3 a. Elt a
=> ShS sh3
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)
-> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2)
+{-# INLINE ssumOuter1PrimP #-}
ssumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
=> Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
ssumOuter1PrimP (Shaped arr) = Shaped (msumOuter1PrimP arr)
+{-# INLINEABLE ssumOuter1Prim #-}
ssumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a)
=> Shaped (n : sh) a -> Shaped sh a
ssumOuter1Prim = sfromPrimitive . ssumOuter1PrimP . stoPrimitive
+{-# INLINE ssumAllPrimP #-}
ssumAllPrimP :: (PrimElt a, NumElt a) => Shaped n (Primitive a) -> a
ssumAllPrimP (Shaped arr) = msumAllPrimP arr
+{-# INLINE ssumAllPrim #-}
ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a
ssumAllPrim (Shaped arr) = msumAllPrim arr
@@ -124,15 +130,19 @@ sappend = coerce mappend
sscalar :: Elt a => a -> Shaped '[] a
sscalar x = Shaped (mscalar x)
+{-# INLINEABLE sfromVectorP #-}
sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a)
sfromVectorP sh v = Shaped (mfromVectorP (shxFromShS sh) v)
+{-# INLINEABLE sfromVector #-}
sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a
sfromVector sh v = sfromPrimitive (sfromVectorP sh v)
+{-# INLINEABLE stoVectorP #-}
stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a
stoVectorP = coerce mtoVectorP
+{-# INLINEABLE stoVector #-}
stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
stoVector = coerce mtoVector
@@ -246,21 +256,20 @@ sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shape
sreshape sh' (Shaped arr) = Shaped (mreshape (shxFromShS sh') arr)
sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a
-sflatten arr =
- case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff
- n@SNat -> sreshape (n :$$ ZSS) arr
+sflatten arr = sreshape (shsProduct (sshape arr) :$$ ZSS) arr
siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
siota sn = Shaped (miota sn)
-- | Throws if the array is empty.
sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
-sminIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mminIndexPrim arr)
+sminIndexPrim (Shaped arr) = ixsFromIxX (mminIndexPrim arr)
-- | Throws if the array is empty.
smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
-smaxIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr)
+smaxIndexPrim (Shaped arr) = ixsFromIxX (mmaxIndexPrim arr)
+{-# INLINEABLE sdot1Inner #-}
sdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
=> Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a
sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2)
@@ -272,6 +281,7 @@ sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2)
-> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2)
_ -> error "unreachable"
+{-# INLINE sdot #-}
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
-- Prefer 'sdot1Inner' if applicable.
sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index 98f1241..4b119c4 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -26,7 +26,6 @@ import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
-import Data.Type.Equality
import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
@@ -80,9 +79,12 @@ deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped
newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))
instance Elt a => Elt (Shaped sh a) where
+ {-# INLINE mshape #-}
mshape (M_Shaped arr) = mshape arr
+ {-# INLINE mindex #-}
mindex (M_Shaped arr) i = Shaped (mindex arr i)
+ {-# INLINE mindexPartial #-}
mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
mindexPartial (M_Shaped arr) i =
coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
@@ -97,6 +99,7 @@ instance Elt a => Elt (Shaped sh a) where
mtoListOuter (M_Shaped arr)
= coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)
+ {-# INLINE mlift #-}
mlift :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
@@ -105,6 +108,7 @@ instance Elt a => Elt (Shaped sh a) where
coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
mlift ssh2 f arr
+ {-# INLINE mlift2 #-}
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
@@ -113,6 +117,7 @@ instance Elt a => Elt (Shaped sh a) where
coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
mlift2 ssh3 f arr1 arr2
+ {-# INLINE mliftL #-}
mliftL :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
@@ -132,7 +137,7 @@ instance Elt a => Elt (Shaped sh a) where
type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
- mshapeTree (Shaped arr) = first shsFromShX (mshapeTree arr)
+ mshapeTree (Shaped arr) = first coerce (mshapeTree arr)
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
@@ -142,18 +147,19 @@ instance Elt a => Elt (Shaped sh a) where
marrayStrides (M_Shaped arr) = marrayStrides arr
- mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
- mvecsWrite sh idx (Shaped arr) vecs =
- mvecsWrite sh idx arr
+ mvecsWriteLinear :: forall sh' s. Int -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
+ mvecsWriteLinear idx (Shaped arr) vecs =
+ mvecsWriteLinear idx arr
(coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
vecs)
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
- -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs =
- mvecsWritePartial sh idx
+ mvecsWritePartialLinear
+ :: forall sh1 sh2 s.
+ Proxy sh1 -> Int -> Mixed sh2 (Shaped sh a)
+ -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
+ -> ST s ()
+ mvecsWritePartialLinear proxy idx arr vecs =
+ mvecsWritePartialLinear proxy idx
(coerce @(Mixed sh2 (Shaped sh a))
@(Mixed sh2 (Mixed (MapJust sh) a))
arr)
@@ -169,6 +175,14 @@ instance Elt a => Elt (Shaped sh a) where
(coerce @(MixedVecs s sh' (Shaped sh a))
@(MixedVecs s sh' (Mixed (MapJust sh) a))
vecs)
+ mvecsUnsafeFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
+ mvecsUnsafeFreeze sh vecs =
+ coerce @(Mixed sh' (Mixed (MapJust sh) a))
+ @(Mixed sh' (Shaped sh a))
+ <$> mvecsUnsafeFreeze sh
+ (coerce @(MixedVecs s sh' (Shaped sh a))
+ @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
@@ -181,6 +195,10 @@ instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
| Dict <- lemKnownMapJust (Proxy @sh)
= MV_Shaped <$> mvecsUnsafeNew idx arr
+ mvecsReplicate idx (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsReplicate idx arr
+
mvecsNewEmpty _
| Dict <- lemKnownMapJust (Proxy @sh)
= MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
@@ -242,14 +260,6 @@ satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped s
satan2Array = liftShaped2 matan2Array
+{-# INLINE sshape #-}
sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
-sshape (Shaped arr) = shsFromShX (mshape arr)
-
--- Needed already here, but re-exported in Data.Array.Nested.Convert.
-shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh
-shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
-shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) =
- castWith (subst1 (sym (lemMapJustCons Refl))) $
- n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
- idx)
-shsFromShX (SUnknown _ :$% _) = error "impossible"
+sshape (Shaped arr) = coerce (mshape arr)
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index 0d90e91..c5e3202 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -1,10 +1,8 @@
-{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NoStarIsType #-}
@@ -32,173 +30,157 @@ import Control.DeepSeq (NFData(..))
import Data.Array.Shape qualified as O
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
-import Data.Functor.Const
-import Data.Functor.Product qualified as Fun
import Data.Kind (Constraint, Type)
-import Data.Monoid (Sum(..))
-import Data.Proxy
import Data.Type.Equality
-import GHC.Exts (Int(..), Int#, quotRemInt#, withDict, build)
-import GHC.Generics (Generic)
+import GHC.Exts (build, withDict)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
+import Unsafe.Coerce (unsafeCoerce)
import Data.Array.Nested.Mixed.Shape
-import Data.Array.Nested.Mixed.Shape.Internal
import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
-- * Shaped lists
--- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be
--- removed in a future release.
type role ListS nominal representational
-type ListS :: [Nat] -> (Nat -> Type) -> Type
-data ListS sh f where
- ZS :: ListS '[] f
- -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity
- (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
-deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
-deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
+type ListS :: [Nat] -> Type -> Type
+data ListS sh i where
+ ZS :: ListS '[] i
+ (::$) :: forall n sh {i}. i -> ListS sh i -> ListS (n : sh) i
+deriving instance Eq i => Eq (ListS sh i)
+deriving instance Ord i => Ord (ListS sh i)
+
infixr 3 ::$
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
-deriving instance (forall n. Show (f n)) => Show (ListS sh f)
+deriving instance Show i => Show (ListS sh i)
#else
-instance (forall n. Show (f n)) => Show (ListS sh f) where
+instance Show i => Show (ListS sh i) where
showsPrec _ = listsShow shows
#endif
-instance (forall m. NFData (f m)) => NFData (ListS n f) where
+instance NFData i => NFData (ListS n i) where
rnf ZS = ()
rnf (x ::$ l) = rnf x `seq` rnf l
-data UnconsListSRes f sh1 =
- forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
-listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
+data UnconsListSRes i sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh i) i
+listsUncons :: ListS sh1 i -> Maybe (UnconsListSRes i sh1)
listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)
listsUncons ZS = Nothing
--- | This checks only whether the types are equal; if the elements of the list
--- are not singletons, their values may still differ. This corresponds to
--- 'testEquality', except on the penultimate type parameter.
-listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
-listsEqType ZS ZS = Just Refl
-listsEqType (n ::$ sh) (m ::$ sh')
- | Just Refl <- testEquality n m
- , Just Refl <- listsEqType sh sh'
- = Just Refl
-listsEqType _ _ = Nothing
-
--- | This checks whether the two lists actually contain equal values. This is
--- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
--- in the @some@ package (except on the penultimate type parameter).
-listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
-listsEqual ZS ZS = Just Refl
-listsEqual (n ::$ sh) (m ::$ sh')
- | Just Refl <- testEquality n m
- , n == m
- , Just Refl <- listsEqual sh sh'
- = Just Refl
-listsEqual _ _ = Nothing
-
-{-# INLINE listsFmap #-}
-listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
-listsFmap _ ZS = ZS
-listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs
-
-{-# INLINE listsFoldMap #-}
-listsFoldMap :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
-listsFoldMap _ ZS = mempty
-listsFoldMap f (x ::$ xs) = f x <> listsFoldMap f xs
-
-listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
+listsShow :: forall sh i. (i -> ShowS) -> ListS sh i -> ShowS
listsShow f l = showString "[" . go "" l . showString "]"
where
- go :: String -> ListS sh' f -> ShowS
+ go :: String -> ListS sh' i -> ShowS
go _ ZS = id
go prefix (x ::$ xs) = showString prefix . f x . go "," xs
-listsLength :: ListS sh f -> Int
-listsLength = getSum . listsFoldMap (\_ -> Sum 1)
+instance Functor (ListS l) where
+ {-# INLINE fmap #-}
+ fmap _ ZS = ZS
+ fmap f (x ::$ xs) = f x ::$ fmap f xs
+
+instance Foldable (ListS l) where
+ {-# INLINE foldMap #-}
+ foldMap _ ZS = mempty
+ foldMap f (x ::$ xs) = f x <> foldMap f xs
+ {-# INLINE foldr #-}
+ foldr _ z ZS = z
+ foldr f z (x ::$ xs) = f x (foldr f z xs)
+ toList = listsToList
+ null ZS = False
+ null _ = True
+
+listsLength :: ListS sh i -> Int
+listsLength = length
-listsRank :: ListS sh f -> SNat (Rank sh)
+listsRank :: ListS sh i -> SNat (Rank sh)
listsRank ZS = SNat
listsRank (_ ::$ sh) = snatSucc (listsRank sh)
-listsFromList :: ShS sh -> [i] -> ListS sh (Const i)
+listsFromList :: ShS sh -> [i] -> ListS sh i
listsFromList topsh topl = go topsh topl
where
- go :: ShS sh' -> [i] -> ListS sh' (Const i)
+ go :: ShS sh' -> [i] -> ListS sh' i
go ZSS [] = ZS
- go (_ :$$ sh) (i : is) = Const i ::$ go sh is
+ go (_ :$$ sh) (i : is) = i ::$ go sh is
go _ _ = error $ "listsFromList: Mismatched list length (type says "
++ show (shsLength topsh) ++ ", list has length "
++ show (length topl) ++ ")"
+{-# INLINEABLE listsFromListS #-}
+listsFromListS :: ListS sh i0 -> [i] -> ListS sh i
+listsFromListS topl0 topl = go topl0 topl
+ where
+ go :: ListS sh i0 -> [i] -> ListS sh i
+ go ZS [] = ZS
+ go (_ ::$ l0) (i : is) = i ::$ go l0 is
+ go _ _ = error $ "listsFromListS: Mismatched list length (the model says "
+ ++ show (listsLength topl0) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+
{-# INLINEABLE listsToList #-}
-listsToList :: ListS sh (Const i) -> [i]
+listsToList :: ListS sh i -> [i]
listsToList list = build (\(cons :: i -> is -> is) (nil :: is) ->
- let go :: ListS sh (Const i) -> is
+ let go :: ListS sh i -> is
go ZS = nil
- go (Const i ::$ is) = i `cons` go is
+ go (i ::$ is) = i `cons` go is
in go list)
-listsHead :: ListS (n : sh) f -> f n
+listsHead :: ListS (n : sh) i -> i
listsHead (i ::$ _) = i
-listsTail :: ListS (n : sh) f -> ListS sh f
+listsTail :: ListS (n : sh) i -> ListS sh i
listsTail (_ ::$ sh) = sh
-listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f
+listsInit :: ListS (n : sh) i -> ListS (Init (n : sh)) i
listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh
listsInit (_ ::$ ZS) = ZS
-listsLast :: ListS (n : sh) f -> f (Last (n : sh))
+listsLast :: ListS (n : sh) i -> i
listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh
listsLast (n ::$ ZS) = n
-listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
+listsAppend :: ListS sh i -> ListS sh' i -> ListS (sh ++ sh') i
listsAppend ZS idx' = idx'
listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
-listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g)
+listsZip :: ListS sh i -> ListS sh j -> ListS sh (i, j)
listsZip ZS ZS = ZS
-listsZip (i ::$ is) (j ::$ js) = Fun.Pair i j ::$ listsZip is js
+listsZip (i ::$ is) (j ::$ js) = (i, j) ::$ listsZip is js
{-# INLINE listsZipWith #-}
-listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g
- -> ListS sh h
+listsZipWith :: (i -> j -> k) -> ListS sh i -> ListS sh j -> ListS sh k
listsZipWith _ ZS ZS = ZS
listsZipWith f (i ::$ is) (j ::$ js) = f i j ::$ listsZipWith f is js
-listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
+listsTakeLenPerm :: forall i is sh. Perm is -> ListS sh i -> ListS (TakeLen is sh) i
listsTakeLenPerm PNil _ = ZS
listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh
listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
-listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
+listsDropLenPerm :: forall i is sh. Perm is -> ListS sh i -> ListS (DropLen is sh) i
listsDropLenPerm PNil sh = sh
listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh
listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
-listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
+listsPermute :: forall i is sh. Perm is -> ListS sh i -> ListS (Permute is sh) i
listsPermute PNil _ = ZS
listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
- case listsIndex (Proxy @is') (Proxy @sh) i sh of
- (item, SNat) -> item ::$ listsPermute is sh
+ case listsIndex i sh of
+ item -> item ::$ listsPermute is sh
--- TODO: remove this SNat when the KnownNat constaint in ListS is removed
-listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh))
-listsIndex _ _ SZ (n ::$ _) = (n, SNat)
-listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = listsIndex p pT i sh
-listsIndex _ _ _ ZS = error "Index into empty shape"
+-- TODO: try to remove this SNat now that the KnownNat constraint in ListS is removed
+listsIndex :: forall j i sh. SNat i -> ListS sh j -> j
+listsIndex SZ (n ::$ _) = n
+listsIndex (SS i) (_ ::$ sh) = listsIndex i sh
+listsIndex _ ZS = error "Index into empty shape"
-listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
+listsPermutePrefix :: forall i is sh. Perm is -> ListS sh i -> ListS (PermutePrefix is sh) i
listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh)
-- * Shaped indices
@@ -206,8 +188,8 @@ listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm pe
-- | An index into a shape-typed array.
type role IxS nominal representational
type IxS :: [Nat] -> Type -> Type
-newtype IxS sh i = IxS (ListS sh (Const i))
- deriving (Eq, Ord, Generic)
+newtype IxS sh i = IxS (ListS sh i)
+ deriving (Eq, Ord, NFData, Functor, Foldable)
pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
pattern ZIS = IxS ZS
@@ -216,10 +198,10 @@ pattern ZIS = IxS ZS
-- removed in a future release.
pattern (:.$)
:: forall {sh1} {i}.
- forall n sh. (KnownNat n, n : sh ~ sh1)
+ forall n sh. (n : sh ~ sh1)
=> i -> IxS sh i -> IxS sh1 i
-pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
- where i :.$ IxS shl = IxS (Const i ::$ shl)
+pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) i))
+ where i :.$ IxS shl = IxS (i ::$ shl)
infixr 3 :.$
{-# COMPLETE ZIS, (:.$) #-}
@@ -232,25 +214,9 @@ type IIxS sh = IxS sh Int
deriving instance Show i => Show (IxS sh i)
#else
instance Show i => Show (IxS sh i) where
- showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l
+ showsPrec _ (IxS l) = listsShow (\i -> shows i) l
#endif
-instance Functor (IxS sh) where
- {-# INLINE fmap #-}
- fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l)
-
-instance Foldable (IxS sh) where
- {-# INLINE foldMap #-}
- foldMap f (IxS l) = listsFoldMap (f . getConst) l
- {-# INLINE foldr #-}
- foldr _ z ZIS = z
- foldr f z (x :.$ xs) = f x (foldr f z xs)
- toList = ixsToList
- null ZIS = False
- null _ = True
-
-instance NFData i => NFData (IxS sh i)
-
ixsLength :: IxS sh i -> Int
ixsLength (IxS l) = listsLength l
@@ -260,16 +226,19 @@ ixsRank (IxS l) = listsRank l
ixsFromList :: forall sh i. ShS sh -> [i] -> IxS sh i
ixsFromList = coerce (listsFromList @_ @i)
-{-# INLINEABLE ixsToList #-}
-ixsToList :: forall sh i. IxS sh i -> [i]
-ixsToList = coerce (listsToList @_ @i)
+{-# INLINEABLE ixsFromIxS #-}
+ixsFromIxS :: forall sh i0 i. IxS sh i0 -> [i] -> IxS sh i
+ixsFromIxS = coerce (listsFromListS @_ @i0 @i)
+
+ixsToList :: IxS sh i -> [i]
+ixsToList = Foldable.toList
ixsZero :: ShS sh -> IIxS sh
ixsZero ZSS = ZIS
ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
ixsHead :: IxS (n : sh) i -> i
-ixsHead (IxS list) = getConst (listsHead list)
+ixsHead (IxS list) = listsHead list
ixsTail :: IxS (n : sh) i -> IxS sh i
ixsTail (IxS list) = IxS (listsTail list)
@@ -278,16 +247,14 @@ ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i
ixsInit (IxS list) = IxS (listsInit list)
ixsLast :: IxS (n : sh) i -> i
-ixsLast (IxS list) = getConst (listsLast list)
+ixsLast (IxS list) = listsLast list
--- TODO: this takes a ShS because there are KnownNats inside IxS.
-ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i
-ixsCast ZSS ZIS = ZIS
-ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx
-ixsCast _ _ = error "ixsCast: ranks don't match"
+ixsCast :: IxS sh i -> IxS sh i
+ixsCast ZIS = ZIS
+ixsCast (i :.$ idx) = i :.$ ixsCast idx
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
-ixsAppend = coerce (listsAppend @_ @(Const i))
+ixsAppend = coerce (listsAppend @_ @i)
ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j)
ixsZip ZIS ZIS = ZIS
@@ -299,8 +266,31 @@ ixsZipWith _ ZIS ZIS = ZIS
ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js
ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
-ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))
+ixsPermutePrefix = coerce (listsPermutePrefix @i)
+-- | Given a multidimensional index, get the corresponding linear
+-- index into the buffer.
+{-# INLINEABLE ixsToLinear #-}
+ixsToLinear :: Num i => ShS sh -> IxS sh i -> i
+ixsToLinear (ShS sh) ix = ixxToLinear sh (ixxFromIxS ix)
+
+ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i
+ixxFromIxS = unsafeCoerce -- TODO: switch to coerce once newtypes overhauled
+
+{-# INLINEABLE ixsFromLinear #-}
+ixsFromLinear :: Num i => ShS sh -> Int -> IxS sh i
+ixsFromLinear (ShS sh) i = ixsFromIxX $ ixxFromLinear sh i
+
+ixsFromIxX :: IxX (MapJust sh) i -> IxS sh i
+ixsFromIxX = unsafeCoerce -- TODO: switch to coerce once newtypes overhauled
+
+shsEnum :: ShS sh -> [IIxS sh]
+shsEnum = shsEnum'
+
+{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site
+shsEnum' :: Num i => ShS sh -> [IxS sh i]
+shsEnum' (ShS sh) = (unsafeCoerce :: [IxX (MapJust sh) i] -> [IxS sh i]) $ shxEnum' sh
+ -- TODO: switch to coerce once newtypes overhauled
-- * Shaped shapes
@@ -310,21 +300,34 @@ ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))
-- can also retrieve the array shape from a 'KnownShS' dictionary.
type role ShS nominal
type ShS :: [Nat] -> Type
-newtype ShS sh = ShS (ListS sh SNat)
- deriving (Generic)
+newtype ShS sh = ShS (ShX (MapJust sh) Int)
+ deriving (NFData)
instance Eq (ShS sh) where _ == _ = True
instance Ord (ShS sh) where compare _ _ = EQ
pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
-pattern ZSS = ShS ZS
+pattern ZSS <- ShS (matchZSX -> Just Refl)
+ where ZSS = ShS ZSX
+
+matchZSX :: forall sh i. ShX (MapJust sh) i -> Maybe (sh :~: '[])
+matchZSX ZSX | Refl <- lemMapJustEmpty @sh Refl = Just Refl
+matchZSX _ = Nothing
pattern (:$$)
:: forall {sh1}.
- forall n sh. (KnownNat n, n : sh ~ sh1)
+ forall n sh. (n : sh ~ sh1)
=> SNat n -> ShS sh -> ShS sh1
-pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
- where i :$$ ShS shl = ShS (i ::$ shl)
+pattern i :$$ shl <- (shsUncons -> Just (UnconsShSRes i shl))
+ where i :$$ ShS shl = ShS (SKnown i :$% shl)
+
+data UnconsShSRes sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsShSRes (SNat n) (ShS sh)
+shsUncons :: forall sh1. ShS sh1 -> Maybe (UnconsShSRes sh1)
+shsUncons (ShS (SKnown x :$% sh'))
+ | Refl <- lemMapJustCons @sh1 Refl
+ = Just (UnconsShSRes x (ShS sh'))
+shsUncons (ShS _) = Nothing
infixr 3 :$$
@@ -334,15 +337,13 @@ infixr 3 :$$
deriving instance Show (ShS sh)
#else
instance Show (ShS sh) where
- showsPrec _ (ShS l) = listsShow (shows . fromSNat) l
+ showsPrec d (ShS shx) = showsPrec d shx
#endif
-instance NFData (ShS sh) where
- rnf (ShS ZS) = ()
- rnf (ShS (SNat ::$ l)) = rnf (ShS l)
-
instance TestEquality ShS where
- testEquality (ShS l1) (ShS l2) = listsEqType l1 l2
+ testEquality (ShS shx1) (ShS shx2) = case shxEqType shx1 shx2 of
+ Nothing -> Nothing
+ Just Refl -> Just unsafeCoerceRefl
-- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are
-- equal if and only if values are equal.)
@@ -350,64 +351,106 @@ shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh')
shsEqual = testEquality
shsLength :: ShS sh -> Int
-shsLength (ShS l) = listsLength l
+shsLength (ShS shx) = shxLength shx
-shsRank :: ShS sh -> SNat (Rank sh)
-shsRank (ShS l) = listsRank l
+shsRank :: forall sh. ShS sh -> SNat (Rank sh)
+shsRank (ShS shx) =
+ gcastWith (unsafeCoerceRefl
+ :: Rank (MapJust sh) :~: Rank sh) $
+ shxRank shx
shsSize :: ShS sh -> Int
-shsSize ZSS = 1
-shsSize (n :$$ sh) = fromSNat' n * shsSize sh
+shsSize (ShS sh) = shxSize sh
-- | This is a partial @const@ that fails when the second argument
--- doesn't match the first.
+-- doesn't match the first. We don't report the size of the list
+-- in case of errors in order not to retain the list.
+{-# INLINEABLE shsFromList #-}
shsFromList :: ShS sh -> [Int] -> ShS sh
-shsFromList topsh topl = go topsh topl `seq` topsh
+shsFromList sh0@(ShS (ShX topsh)) topl = go topsh topl `seq` sh0
where
- go :: ShS sh' -> [Int] -> ()
- go ZSS [] = ()
- go (sn :$$ sh) (i : is)
+ go :: ListH sh' Int -> [Int] -> ()
+ go ZH [] = ()
+ go ZH _ = error $ "shsFromList: List too long (type says " ++ show (listhLength topsh) ++ ")"
+ go (ConsKnown sn sh) (i : is)
| i == fromSNat' sn = go sh is
- | otherwise = error $ "shsFromList: Value does not match typing (type says "
- ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
- go _ _ = error $ "shsFromList: Mismatched list length (type says "
- ++ show (shsLength topsh) ++ ", list has length "
- ++ show (length topl) ++ ")"
+ | otherwise = error $ "shsFromList: Value does not match typing"
+ go ConsUnknown{} _ = error "shsFromList: impossible case"
+ go _ _ = error $ "shsFromList: List too short (type says " ++ show (listhLength topsh) ++ ")"
+-- This is equivalent to but faster than @coerce shxToList@.
{-# INLINEABLE shsToList #-}
shsToList :: ShS sh -> [Int]
-shsToList topsh = build (\(cons :: Int -> is -> is) (nil :: is) ->
- let go :: ShS sh -> is
- go ZSS = nil
- go (sn :$$ sh) = fromSNat' sn `cons` go sh
- in go topsh)
+shsToList (ShS (ShX l)) = build (\(cons :: i -> is -> is) (nil :: is) ->
+ let go :: ListH sh Int -> is
+ go ZH = nil
+ go ConsUnknown{} = error "shsToList: impossible case"
+ go (ConsKnown snat rest) = fromSNat' snat `cons` go rest
+ in go l)
shsHead :: ShS (n : sh) -> SNat n
-shsHead (ShS list) = listsHead list
+shsHead (ShS shx) = case shxHead shx of
+ SKnown SNat -> SNat
-shsTail :: ShS (n : sh) -> ShS sh
-shsTail (ShS list) = ShS (listsTail list)
+shsTail :: forall n sh. ShS (n : sh) -> ShS sh
+shsTail = coerce (shxTail @_ @_ @Int)
-shsInit :: ShS (n : sh) -> ShS (Init (n : sh))
-shsInit (ShS list) = ShS (listsInit list)
+shsInit :: forall n sh. ShS (n : sh) -> ShS (Init (n : sh))
+shsInit =
+ gcastWith (unsafeCoerceRefl
+ :: Init (Just n : MapJust sh) :~: MapJust (Init (n : sh))) $
+ coerce (shxInit @_ @_ @Int)
-shsLast :: ShS (n : sh) -> SNat (Last (n : sh))
-shsLast (ShS list) = listsLast list
+shsLast :: forall n sh. ShS (n : sh) -> SNat (Last (n : sh))
+shsLast (ShS shx) =
+ gcastWith (unsafeCoerceRefl
+ :: Last (Just n : MapJust sh) :~: Just (Last (n : sh))) $
+ case shxLast shx of
+ SKnown SNat -> SNat
shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')
-shsAppend = coerce (listsAppend @_ @SNat)
+shsAppend =
+ gcastWith (unsafeCoerceRefl
+ :: MapJust sh ++ MapJust sh' :~: MapJust (sh ++ sh')) $
+ coerce (shxAppend @_ @_ @Int)
+
+shsTakeLen :: forall is sh. Perm is -> ShS sh -> ShS (TakeLen is sh)
+shsTakeLen =
+ gcastWith (unsafeCoerceRefl
+ :: TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)) $
+ coerce shxTakeLen
-shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
-shsTakeLen = coerce (listsTakeLenPerm @SNat)
+shsDropLen :: forall is sh. Perm is -> ShS sh -> ShS (DropLen is sh)
+shsDropLen =
+ gcastWith (unsafeCoerceRefl
+ :: DropLen is (MapJust sh) :~: MapJust (DropLen is sh)) $
+ coerce shxDropLen
-shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
-shsPermute = coerce (listsPermute @SNat)
+shsPermute :: forall is sh. Perm is -> ShS sh -> ShS (Permute is sh)
+shsPermute =
+ gcastWith (unsafeCoerceRefl
+ :: Permute is (MapJust sh) :~: MapJust (Permute is sh)) $
+ coerce shxPermute
-shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
-shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh)))
+shsIndex :: forall i sh. SNat i -> ShS sh -> SNat (Index i sh)
+shsIndex i (ShS sh) =
+ gcastWith (unsafeCoerceRefl
+ :: Index i (MapJust sh) :~: Just (Index i sh)) $
+ case shxIndex @_ @_ @Int i sh of
+ SKnown SNat -> SNat
shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
-shsPermutePrefix = coerce (listsPermutePrefix @SNat)
+shsPermutePrefix perm (ShS shx)
+ {- TODO: here and elsewhere, solve the module dependency cycle and add this:
+ | Refl <- lemTakeLenMapJust perm sh
+ , Refl <- lemDropLenMapJust perm sh
+ , Refl <- lemPermuteMapJust perm sh
+ , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm sh)) (shsDropLen perm sh) -}
+ = gcastWith (unsafeCoerceRefl
+ :: Permute is (TakeLen is (MapJust sh))
+ ++ DropLen is (MapJust sh)
+ :~: MapJust (Permute is (TakeLen is sh) ++ DropLen is sh)) $
+ ShS (shxPermutePrefix perm shx)
type family Product sh where
Product '[] = 1
@@ -435,37 +478,10 @@ shsOrthotopeShape :: ShS sh -> Dict O.Shape sh
shsOrthotopeShape ZSS = Dict
shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict
--- | This function is a hack made possible by the 'KnownNat' inside 'ListS'.
--- This function may be removed in a future release.
-shsFromListS :: ListS sh f -> ShS sh
-shsFromListS ZS = ZSS
-shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l
-
--- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This
--- function may be removed in a future release.
-shsFromIxS :: IxS sh i -> ShS sh
-shsFromIxS (IxS l) = shsFromListS l
-
-shsEnum :: ShS sh -> [IIxS sh]
-shsEnum = shsEnum'
-
-{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site
-shsEnum' :: Num i => ShS sh -> [IxS sh i]
-shsEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shsSize sh - 1]]
- where
- suffixes = drop 1 (scanr (*) 1 (shsToList sh))
-
- fromLin :: Num i => ShS sh -> [Int] -> Int# -> IxS sh i
- fromLin ZSS _ _ = ZIS
- fromLin (_ :$$ sh') (I# suff# : suffs) i# =
- let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shsSize sh'
- in fromIntegral (I# q#) :.$ fromLin sh' suffs r#
- fromLin _ _ _ = error "impossible"
-
-- | Untyped: length is checked at runtime.
-instance KnownShS sh => IsList (ListS sh (Const i)) where
- type Item (ListS sh (Const i)) = i
+instance KnownShS sh => IsList (ListS sh i) where
+ type Item (ListS sh i) = i
fromList = listsFromList (knownShS @sh)
toList = listsToList
@@ -480,6 +496,3 @@ instance KnownShS sh => IsList (ShS sh) where
type Item (ShS sh) = Int
fromList = shsFromList (knownShS @sh)
toList = shsToList
-
-$(ixFromLinearStub "ixsFromLinear" [t| ShS |] [t| IxS |] [p| ZSS |] (\a b -> [p| (fromSNat' -> $a) :$$ $b |]) [| ZIS |] [| (:.$) |] [| shsToList |])
-{-# INLINEABLE ixsFromLinear #-}
diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs
index a43ae0c..8bb5b85 100644
--- a/src/Data/Array/Nested/Types.hs
+++ b/src/Data/Array/Nested/Types.hs
@@ -46,7 +46,6 @@ import GHC.TypeLits
import GHC.TypeNats qualified as TN
import Unsafe.Coerce qualified
-
-- Reasoning helpers
subst1 :: forall f a b. a :~: b -> f a :~: f b
@@ -59,8 +58,9 @@ subst2 Refl = Refl
data Dict c a where
Dict :: c a => Dict c a
+{-# INLINE fromSNat' #-}
fromSNat' :: SNat n -> Int
-fromSNat' = fromIntegral . fromSNat
+fromSNat' = fromEnum . TN.fromSNat
sameNat' :: SNat n -> SNat m -> Maybe (n :~: m)
sameNat' n@SNat m@SNat = sameNat n m
@@ -110,7 +110,7 @@ type family Replicate n a where
Replicate n a = a : Replicate (n - 1) a
lemReplicateSucc :: forall a n proxy.
- proxy n -> (a : Replicate n a) :~: Replicate (n + 1) a
+ proxy n -> a : Replicate n a :~: Replicate (n + 1) a
lemReplicateSucc _ = unsafeCoerceRefl
type family MapJust l = r | r -> l where
diff --git a/src/Data/Array/Strided/Orthotope.hs b/src/Data/Array/Strided/Orthotope.hs
index 5c38d14..e2cd17c 100644
--- a/src/Data/Array/Strided/Orthotope.hs
+++ b/src/Data/Array/Strided/Orthotope.hs
@@ -24,14 +24,19 @@ fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset ve
toO :: AS.Array n a -> RS.Array n a
toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec))
+{-# INLINE liftO1 #-}
liftO1 :: (AS.Array n a -> AS.Array n' b)
-> RS.Array n a -> RS.Array n' b
liftO1 f = toO . f . fromO
+{-# INLINE liftO2 #-}
liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c)
-> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c
liftO2 f x y = toO (f (fromO x) (fromO y))
+-- We don't inline this lifting function, because its code is not just
+-- a wrapper, being relatively long and expensive.
+{-# INLINEABLE liftVEltwise1 #-}
liftVEltwise1 :: (Storable a, Storable b)
=> SNat n
-> (VS.Vector a -> VS.Vector b)
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index 1445ce6..4f5bb08 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -17,7 +17,7 @@
module Data.Array.XArray where
import Control.DeepSeq (NFData)
-import Control.Monad (foldM)
+import Control.Monad (foldM_, foldM)
import Control.Monad.ST
import Data.Array.Internal qualified as OI
import Data.Array.Internal.RankedG qualified as ORG
@@ -26,7 +26,7 @@ import Data.Array.RankedS qualified as S
import Data.Coerce
import Data.Foldable (toList)
import Data.Kind
-import Data.List.NonEmpty (NonEmpty)
+import Data.List.NonEmpty (NonEmpty(..))
import Data.Proxy
import Data.Type.Equality
import Data.Type.Ord
@@ -62,6 +62,7 @@ shape = \ssh (XArray arr) -> go ssh (S.shapeL arr)
go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l
go _ _ = error "Invalid shapeL"
+{-# INLINEABLE fromVector #-}
fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
fromVector sh v
| Dict <- lemKnownNatRank sh
@@ -87,7 +88,7 @@ cast ssh1 sh2 ssh' (XArray arr)
| Refl <- lemRankApp ssh1 ssh'
, Refl <- lemRankApp (ssxFromShX sh2) ssh'
= let arrsh :: IShX sh1
- (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr))
+ arrsh = shxTakeSSX (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr))
in if shxToList arrsh == shxToList sh2
then XArray arr
else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")"
@@ -184,7 +185,7 @@ rerank :: forall sh sh1 sh2 a b.
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
rerank ssh ssh1 ssh2 f xarr@(XArray arr)
| Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
- = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr)
+ = let sh = shxTakeSSX (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr)
in if 0 `elem` shxToList sh
then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) [])
else case () of
@@ -211,7 +212,7 @@ rerank2 :: forall sh sh1 sh2 a b c.
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2)
| Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
- = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1)
+ = let sh = shxTakeSSX (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1)
in if 0 `elem` shxToList sh
then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) [])
else case () of
@@ -274,14 +275,14 @@ sumInner :: forall sh sh' a. (Storable a, NumElt a)
=> StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a
sumInner ssh ssh' arr
| Refl <- lemAppNil @sh
- = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
+ = let sh' = shxDropSSX @sh @sh' ssh (shape (ssxAppend ssh ssh') arr)
sh'F = shxFlatten sh' :$% ZSX
ssh'F = ssxFromShX sh'F
go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a
go (XArray arr')
| Refl <- lemRankApp ssh ssh'F
- , let sn = listxRank (let StaticShX l = ssh in l)
+ , let sn = ssxRank ssh
= XArray (liftO1 (numEltSum1Inner sn) arr')
in go $
@@ -294,7 +295,7 @@ sumOuter :: forall sh sh' a. (Storable a, NumElt a)
=> StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a
sumOuter ssh ssh' arr
| Refl <- lemAppNil @sh
- = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
+ = let sh = shxTakeSSX (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
shF = shxFlatten sh :$% ZSX
in sumInner ssh' (ssxFromShX shF) $
transpose2 (ssxFromShX shF) ssh' $
@@ -305,50 +306,48 @@ sumOuter ssh ssh' arr
-- the list's spine must be fully materialised to compute its length before
-- constructing the array. The list can't be empty (not enough information
-- in the given shape to guess the shape of the empty array, in general).
-fromListOuter :: forall n sh a. Storable a
- => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
-fromListOuter ssh l
- | Dict <- lemKnownNatRankSSX (ssxTail ssh)
- , let l' = coerce @[XArray sh a] @[S.Array (Rank sh) a] l
- = case ssh of
- _ :!% ZKX ->
- fromList1 ssh (map S.unScalar l')
- SKnown m :!% _ ->
- let n = fromSNat' m
- in XArray (ravelOuterN n l')
- _ ->
- let n = length l
- in XArray (ravelOuterN n l')
+{-# INLINE fromListOuterSN #-}
+fromListOuterSN :: forall n sh a. Storable a
+ => SNat n -> IShX sh -> NonEmpty (XArray sh a) -> XArray (Just n : sh) a
+fromListOuterSN m sh l
+ | Dict <- lemKnownNatRank sh
+ , let l' = coerce @(NonEmpty (XArray sh a)) @(NonEmpty (S.Array (Rank sh) a)) l
+ = case sh of
+ ZSX -> fromList1SN m (map S.unScalar (toList l'))
+ _ -> XArray (ravelOuterN (fromSNat' m) l')
-- | This checks that the list has the given length and that all shapes in the
-- list are equal. The list must be non-empty, and is streamed.
+{-# INLINEABLE ravelOuterN #-}
ravelOuterN :: (KnownNat k, Storable a)
- => Int -> [S.Array k a] -> S.Array (1 + k) a
+ => Int -> NonEmpty (S.Array k a) -> S.Array (1 + k) a
ravelOuterN 0 _ = error "ravelOuterN: N == 0"
-ravelOuterN _ [] = error "ravelOuterN: empty list"
-ravelOuterN k as@(a0 : _) = runST $ do
+ravelOuterN k as@(a0 :| _) = runST $ do
let sh0 = S.shapeL a0
len = product sh0
vecSize = k * len
vec <- VSM.unsafeNew vecSize
- let f !n a =
+ let f !n (ORS.A (ORG.A sht t)) =
if | n >= k ->
error $ "ravelOuterN: list too long " ++ show (n, k)
-- if we do this check just once at the end, we may
-- crash instead of producing an accurate error message
- | S.shapeL a == sh0 -> do
- VS.unsafeCopy (VSM.slice (n * len) len vec) (S.toVector a)
- return $! n + 1
+ | sht == sh0 -> do
+ let g off el = do
+ VS.unsafeCopy (VSM.slice off (VS.length el) vec) el
+ return $! off + VS.length el
+ foldM_ g (n * len) (OI.toVectorListT sht t)
+ return $! n + 1
| otherwise ->
- error $ "ravelOuterN: unequal shapes " ++ show (S.shapeL a, sh0)
+ error $ "ravelOuterN: unequal shapes " ++ show (sht, sh0)
nFinal <- foldM f 0 as
if nFinal == k
then S.fromVector (k : sh0) <$> VS.unsafeFreeze vec
else error $ "ravelOuterN: list too short " ++ show (nFinal, k)
toListOuter :: forall a n sh. Storable a => XArray (n : sh) a -> [XArray sh a]
-toListOuter (XArray arr@(ORS.A (ORG.A _ t))) =
- case S.shapeL arr of
+toListOuter (XArray arr@(ORS.A (ORG.A shArr t))) =
+ case shArr of
[] -> error "impossible"
0 : _ -> []
-- using orthotope's functions here would entail using rerank, which is slow, so we don't
@@ -358,15 +357,20 @@ toListOuter (XArray arr@(ORS.A (ORG.A _ t))) =
-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown,
-- the list's spine must be fully materialised to compute its length before
-- constructing the array.
-fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a
-fromList1 ssh l =
- case ssh of
- SKnown m :!% _ ->
- let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed
- in XArray (S.fromVector [n] (VGC.fromListNChecked n l))
- _ ->
- let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself
- in XArray (S.fromVector [n] (VS.fromListN n l))
+{-# INLINE fromList1 #-}
+fromList1 :: Storable a => [a] -> XArray '[Nothing] a
+fromList1 l =
+ let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself
+ in XArray (S.fromVector [n] (VS.fromListN n l))
+
+-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown,
+-- the list's spine must be fully materialised to compute its length before
+-- constructing the array.
+{-# INLINE fromList1SN #-}
+fromList1SN :: Storable a => SNat n -> [a] -> XArray '[Just n] a
+fromList1SN m l =
+ let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed
+ in XArray (S.fromVector [n] (VGC.fromListNChecked n l))
toList1 :: Storable a => XArray '[n] a -> [a]
toList1 (XArray arr) = S.toList arr
diff --git a/test/Gen.hs b/test/Gen.hs
index 4f5fe96..789a59c 100644
--- a/test/Gen.hs
+++ b/test/Gen.hs
@@ -11,7 +11,6 @@
module Gen where
import Data.ByteString qualified as BS
-import Data.Foldable (toList)
import Data.Type.Equality
import Data.Type.Ord
import Data.Vector.Storable qualified as VS
@@ -46,7 +45,7 @@ genLowBiased (lo, hi) = do
return (lo + x * x * x * (hi - lo))
shuffleShR :: IShR n -> Gen (IShR n)
-shuffleShR = \sh -> go (length sh) (toList sh) sh
+shuffleShR = \sh -> go (shrLength sh) (shrToList sh) sh
where
go :: Int -> [Int] -> IShR n -> Gen (IShR n)
go _ _ ZSR = return ZSR
@@ -78,7 +77,7 @@ genShRwithTarget targetMax sn = do
dims <- genDims m (if dim == 0 then 0 else tgt `div` dim)
return (dim :$: dims)
dims <- genDims sn targetSize
- let maxdim = maximum dims
+ let maxdim = maximum $ shrToList dims
cap = binarySearch (`div` 2) 1 maxdim (\cap' -> shrSize (min cap' <$> dims) <= targetSize)
shuffleShR (min cap <$> dims)
@@ -139,7 +138,7 @@ genStaticShX = \n k -> case n of
genStaticShX n' $ \ssh ->
k (item :!% ssh)
where
- genItem :: Monad m => (forall n. SMayNat () SNat n -> PropertyT m ()) -> PropertyT m ()
+ genItem :: Monad m => (forall n. SMayNat () n -> PropertyT m ()) -> PropertyT m ()
genItem k = do
b <- forAll Gen.bool
if b
diff --git a/test/Tests/C.hs b/test/Tests/C.hs
index e26c3dd..2d35cd9 100644
--- a/test/Tests/C.hs
+++ b/test/Tests/C.hs
@@ -15,7 +15,6 @@ module Tests.C where
import Control.Monad
import Data.Array.RankedS qualified as OR
-import Data.Foldable (toList)
import Data.Functor.Const
import Data.Type.Equality
import Foreign
@@ -50,10 +49,10 @@ gen_red_nonempty f = property $ genRank $ \outrank@(SNat @n) -> do
-- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
let inrank = SNat @(n + 1)
sh <- forAll $ genShR inrank
- -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
- guard (all (> 0) (shrTail sh)) -- only constrain the tail
- arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
- genStorables (Range.singleton (product sh))
+ -- traceM ("sh: " ++ show sh ++ " -> " ++ show (shrSize sh))
+ guard (all (> 0) (shrToList $ shrTail sh)) -- only constrain the tail
+ arr <- forAllT $ OR.fromVector @Double @(n + 1) (shrToList sh) <$>
+ genStorables (Range.singleton (shrSize sh))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
f inrank outrank arr
@@ -67,19 +66,19 @@ gen_red_empty f = property $ genRank $ \outrankm1@(SNat @nm1) -> do
sht <- shuffleShR (0 :$: shtt) -- n
n <- Gen.int (Range.linear 0 20)
return (n :$: sht) -- n + 1
- guard (0 `elem` shrTail sh)
- -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
- let arr = OR.fromList @(n + 1) @Double (toList sh) []
+ guard (0 `elem` (shrToList $ shrTail sh))
+ -- traceM ("sh: " ++ show sh ++ " -> " ++ show (shrSize sh))
+ let arr = OR.fromList @(n + 1) @Double (shrToList sh) []
f inrank arr
gen_red_lasteq1 :: (forall n. SNat (n + 1) -> SNat n -> OR.Array (n + 1) Double -> PropertyT IO ()) -> Property
gen_red_lasteq1 f = property $ genRank $ \outrank@(SNat @n) -> do
let inrank = SNat @(n + 1)
outsh <- forAll $ genShR outrank
- guard (all (> 0) outsh)
+ guard (all (> 0) $ shrToList outsh)
let insh = shrAppend outsh (1 :$: ZSR)
- arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$>
- genStorables (Range.singleton (product insh))
+ arr <- forAllT $ OR.fromVector @Double @(n + 1) (shrToList insh) <$>
+ genStorables (Range.singleton (shrSize insh))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
f inrank outrank arr
@@ -96,12 +95,12 @@ gen_red_replicated doTranspose f = property $
label (LabelName ("rankdiff " ++ show (fromSNat' inrank2 - fromSNat' inrank1)))
label (LabelName ("size sh1 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh1) :: Double)) :: Int)))
label (LabelName ("size sh3 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh3) :: Double)) :: Int)))
- guard (all (> 0) sh3)
+ guard (all (> 0) $ shrToList sh3)
arr <- forAllT $
- OR.stretch (toList sh3)
- . OR.reshape (toList sh2)
- . OR.fromVector @Double @m (toList sh1) <$>
- genStorables (Range.singleton (product sh1))
+ OR.stretch (shrToList sh3)
+ . OR.reshape (shrToList sh2)
+ . OR.fromVector @Double @m (shrToList sh1) <$>
+ genStorables (Range.singleton (shrSize sh1))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
arrTrans <-
if doTranspose then do perm <- forAll $ genPermR (fromSNat' inrank2)
@@ -160,9 +159,9 @@ prop_negate_with :: forall f b. Show b
prop_negate_with genRank' genB preproc = property $
genRank' $ \extra rank@(SNat @n) -> do
sh <- forAll $ genShR rank
- guard (all (> 0) sh)
- arr <- forAllT $ OR.fromVector @Double @n (toList sh) <$>
- genStorables (Range.singleton (product sh))
+ guard (all (> 0) $ shrToList sh)
+ arr <- forAllT $ OR.fromVector @Double @n (shrToList sh) <$>
+ genStorables (Range.singleton (shrSize sh))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
bval <- forAll $ genB extra sh
let arr' = preproc extra bval arr
@@ -202,7 +201,7 @@ tests = testGroup "C"
(\_ sh -> do let genPair n = do lo <- Gen.integral (Range.constant 0 (n-1))
len <- Gen.integral (Range.constant 0 (n-lo-1))
return (lo, len)
- pairs <- mapM genPair (toList sh)
+ pairs <- mapM genPair (shrToList sh)
return pairs)
(\_ -> OR.slice)
]