From f8e131d7924c24e0ed015507e2299638b72b6a57 Mon Sep 17 00:00:00 2001
From: Mikolaj Konarski <mikolaj.konarski@gmail.com>
Date: Tue, 23 Apr 2024 16:53:20 +0200
Subject: Define and expose the recomputing of KnownINat for ranks

---
 src/Data/Array/Nested.hs          |  6 +++---
 src/Data/Array/Nested/Internal.hs | 14 ++++++++++----
 2 files changed, 13 insertions(+), 7 deletions(-)

(limited to 'src/Data')

diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 7f2c232..9cb3182 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -3,9 +3,9 @@
 module Data.Array.Nested (
   -- * Ranked arrays
   Ranked,
-  ListR, pattern (:::), pattern ZR,
-  IxR(..), pattern (:.:), pattern ZIR, IIxR,
-  StaticShapeR(..), pattern (:$:), pattern ZSR,
+  ListR, pattern (:::), pattern ZR, knownListR,
+  IxR(..), pattern (:.:), pattern ZIR, IIxR, knownIxR,
+  StaticShapeR(..), pattern (:$:), pattern ZSR, knownStaticShapeR,
   rshape, rindex, rindexPartial, rgenerate, rsumOuter1,
   rtranspose, rappend, rscalar, rfromVector, runScalar,
   rconstant, rfromList, rfromList1, rtoList, rtoList1,
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index fb2ae48..594abd4 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -872,6 +872,10 @@ listRToList :: ListR n i -> [i]
 listRToList ZR = []
 listRToList (i ::: is) = i : listRToList is
 
+knownListR :: ListR n i -> Dict KnownINat n
+knownListR ZR = Dict
+knownListR (_ ::: l) | Dict <- knownListR l = Dict
+
 -- | An index into a rank-typed array.
 type role IxR nominal representational
 type IxR :: INat -> Type -> Type
@@ -904,6 +908,9 @@ unconsIxR (IxR sh) = case sh of
 
 type IIxR n = IxR n Int
 
+knownIxR :: IxR n i -> Dict KnownINat n
+knownIxR (IxR sh) = knownListR sh
+
 type role StaticShapeR nominal representational
 type StaticShapeR :: INat -> Type -> Type
 newtype StaticShapeR n i = StaticShapeR (ListR n i)
@@ -933,6 +940,9 @@ unconsStaticShapeR (StaticShapeR sh) = case sh of
   i ::: sh' -> Just (UnconsStaticShapeRRes (StaticShapeR sh') i)
   ZR -> Nothing
 
+knownStaticShapeR :: StaticShapeR n i -> Dict KnownINat n
+knownStaticShapeR (StaticShapeR sh) = knownListR sh
+
 zeroIxR :: SINat n -> IIxR n
 zeroIxR SZ = ZIR
 zeroIxR (SS n) = 0 :.: zeroIxR n
@@ -946,10 +956,6 @@ ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
 ixCvtRX ZIR = ZIX
 ixCvtRX (n :.: idx) = n :.? ixCvtRX idx
 
-knownIxR :: IIxR n -> Dict KnownINat n
-knownIxR ZIR = Dict
-knownIxR (_ :.: idx) | Dict <- knownIxR idx = Dict
-
 shapeSizeR :: IIxR n -> Int
 shapeSizeR ZIR = 1
 shapeSizeR (n :.: sh) = n * shapeSizeR sh
-- 
cgit v1.2.3-70-g09d2