aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@gmail.com>2024-04-23 16:53:20 +0200
committerMikolaj Konarski <mikolaj.konarski@gmail.com>2024-04-23 16:53:20 +0200
commitf8e131d7924c24e0ed015507e2299638b72b6a57 (patch)
treed2092c236168d3a6381bc4a49b7c41fe3d1690e0
parentda0f6fa4515dbb2c4b794e6418fd0633415af17d (diff)
Define and expose the recomputing of KnownINat for ranks
-rw-r--r--src/Data/Array/Nested.hs6
-rw-r--r--src/Data/Array/Nested/Internal.hs14
2 files changed, 13 insertions, 7 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 7f2c232..9cb3182 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -3,9 +3,9 @@
module Data.Array.Nested (
-- * Ranked arrays
Ranked,
- ListR, pattern (:::), pattern ZR,
- IxR(..), pattern (:.:), pattern ZIR, IIxR,
- StaticShapeR(..), pattern (:$:), pattern ZSR,
+ ListR, pattern (:::), pattern ZR, knownListR,
+ IxR(..), pattern (:.:), pattern ZIR, IIxR, knownIxR,
+ StaticShapeR(..), pattern (:$:), pattern ZSR, knownStaticShapeR,
rshape, rindex, rindexPartial, rgenerate, rsumOuter1,
rtranspose, rappend, rscalar, rfromVector, runScalar,
rconstant, rfromList, rfromList1, rtoList, rtoList1,
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index fb2ae48..594abd4 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -872,6 +872,10 @@ listRToList :: ListR n i -> [i]
listRToList ZR = []
listRToList (i ::: is) = i : listRToList is
+knownListR :: ListR n i -> Dict KnownINat n
+knownListR ZR = Dict
+knownListR (_ ::: l) | Dict <- knownListR l = Dict
+
-- | An index into a rank-typed array.
type role IxR nominal representational
type IxR :: INat -> Type -> Type
@@ -904,6 +908,9 @@ unconsIxR (IxR sh) = case sh of
type IIxR n = IxR n Int
+knownIxR :: IxR n i -> Dict KnownINat n
+knownIxR (IxR sh) = knownListR sh
+
type role StaticShapeR nominal representational
type StaticShapeR :: INat -> Type -> Type
newtype StaticShapeR n i = StaticShapeR (ListR n i)
@@ -933,6 +940,9 @@ unconsStaticShapeR (StaticShapeR sh) = case sh of
i ::: sh' -> Just (UnconsStaticShapeRRes (StaticShapeR sh') i)
ZR -> Nothing
+knownStaticShapeR :: StaticShapeR n i -> Dict KnownINat n
+knownStaticShapeR (StaticShapeR sh) = knownListR sh
+
zeroIxR :: SINat n -> IIxR n
zeroIxR SZ = ZIR
zeroIxR (SS n) = 0 :.: zeroIxR n
@@ -946,10 +956,6 @@ ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
ixCvtRX ZIR = ZIX
ixCvtRX (n :.: idx) = n :.? ixCvtRX idx
-knownIxR :: IIxR n -> Dict KnownINat n
-knownIxR ZIR = Dict
-knownIxR (_ :.: idx) | Dict <- knownIxR idx = Dict
-
shapeSizeR :: IIxR n -> Int
shapeSizeR ZIR = 1
shapeSizeR (n :.: sh) = n * shapeSizeR sh