aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-13 13:09:17 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-13 13:09:17 +0200
commitc6b912051ddac25c9d7efe2f8162eac9068a335c (patch)
tree1d19b621fa2b536ee4c409f24e34c769c82702bc /src/Data/Array/Nested/Internal
parent20173c939486ed6e27b8170e94f666d8ae3df152 (diff)
Add KnownShape generators from ShS
Diffstat (limited to 'src/Data/Array/Nested/Internal')
-rw-r--r--src/Data/Array/Nested/Internal/Shape.hs12
1 files changed, 12 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs
index 4fa4284..f449536 100644
--- a/src/Data/Array/Nested/Internal/Shape.hs
+++ b/src/Data/Array/Nested/Internal/Shape.hs
@@ -25,6 +25,7 @@
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Internal.Shape where
+import Data.Array.Shape qualified as O
import Data.Array.Mixed.Types
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
@@ -237,6 +238,9 @@ shrTail (ShR list) = ShR (listrTail list)
shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
shrAppend = coerce (listrAppend @_ @i)
+-- | This function can also be used to conjure up a 'KnownNat' dictionary;
+-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern
+-- synonym yields 'KnownNat' evidence.
shrRank :: ShR n i -> SNat n
shrRank (ShR sh) = listrRank sh
@@ -492,6 +496,14 @@ class KnownShS sh where knownShS :: ShS sh
instance KnownShS '[] where knownShS = ZSS
instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS
+shsKnownShS :: ShS sh -> Dict KnownShS sh
+shsKnownShS ZSS = Dict
+shsKnownShS (SNat :$$ sh) | Dict <- shsKnownShS sh = Dict
+
+shsOrthotopeShape :: ShS sh -> Dict O.Shape sh
+shsOrthotopeShape ZSS = Dict
+shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict
+
-- | Untyped: length is checked at runtime.
instance KnownShS sh => IsList (ListS sh (Const i)) where