aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2025-05-14 11:38:22 +0200
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2025-05-14 11:38:22 +0200
commitffa8dacb1d7ea53438f784bf5f8b425b8cd48f46 (patch)
tree154026b503f334bf14851a3826d1562c5cfb9d08
parent978919b5be0fe74f5da1e071e97557d2e30f0ad2 (diff)
Split and uniformly rename Shape modules
-rw-r--r--ox-arrays.cabal5
-rw-r--r--src/Data/Array/Mixed/Lemmas.hs2
-rw-r--r--src/Data/Array/Mixed/Permutation.hs2
-rw-r--r--src/Data/Array/Mixed/XArray.hs2
-rw-r--r--src/Data/Array/Nested.hs5
-rw-r--r--src/Data/Array/Nested/Internal/Convert.hs4
-rw-r--r--src/Data/Array/Nested/Internal/Lemmas.hs4
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs2
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs4
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs4
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs (renamed from src/Data/Array/Mixed/Shape.hs)2
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs363
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs (renamed from src/Data/Array/Nested/Internal/Shape.hs)325
-rw-r--r--test/Gen.hs3
-rw-r--r--test/Tests/C.hs2
15 files changed, 386 insertions, 343 deletions
diff --git a/ox-arrays.cabal b/ox-arrays.cabal
index d47cdfa..2a15cbb 100644
--- a/ox-arrays.cabal
+++ b/ox-arrays.cabal
@@ -50,15 +50,16 @@ library
Data.Array.Mixed.Internal.Arith
Data.Array.Mixed.Lemmas
Data.Array.Mixed.Permutation
- Data.Array.Mixed.Shape
Data.Array.Mixed.Types
Data.Array.Mixed.XArray
Data.Array.Nested.Internal.Convert
Data.Array.Nested.Internal.Mixed
Data.Array.Nested.Internal.Lemmas
Data.Array.Nested.Internal.Ranked
- Data.Array.Nested.Internal.Shape
Data.Array.Nested.Internal.Shaped
+ Data.Array.Nested.Mixed.Shape
+ Data.Array.Nested.Ranked.Shape
+ Data.Array.Nested.Shaped.Shape
Data.Bag
if flag(trace-wrappers)
diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs
index 560f762..ca82573 100644
--- a/src/Data/Array/Mixed/Lemmas.hs
+++ b/src/Data/Array/Mixed/Lemmas.hs
@@ -13,7 +13,7 @@ import Data.Type.Equality
import GHC.TypeLits
import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
+import Data.Array.Nested.Mixed.Shape
import Data.Array.Mixed.Types
diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs
index cedfa22..22672cb 100644
--- a/src/Data/Array/Mixed/Permutation.hs
+++ b/src/Data/Array/Mixed/Permutation.hs
@@ -29,7 +29,7 @@ import GHC.TypeError
import GHC.TypeLits
import GHC.TypeNats qualified as TN
-import Data.Array.Mixed.Shape
+import Data.Array.Nested.Mixed.Shape
import Data.Array.Mixed.Types
diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs
index cb790e1..3e7a498 100644
--- a/src/Data/Array/Mixed/XArray.hs
+++ b/src/Data/Array/Mixed/XArray.hs
@@ -34,7 +34,7 @@ import GHC.TypeLits
import Data.Array.Mixed.Internal.Arith
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
+import Data.Array.Nested.Mixed.Shape
import Data.Array.Mixed.Types
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 9869cba..8198a54 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -104,13 +104,14 @@ module Data.Array.Nested (
import Prelude hiding (mappend, mconcat)
import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Nested.Internal.Convert
import Data.Array.Nested.Internal.Mixed
import Data.Array.Nested.Internal.Ranked
-import Data.Array.Nested.Internal.Shape
import Data.Array.Nested.Internal.Shaped
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Shaped.Shape
import Data.Array.Strided.Arith
import Foreign.Storable
import GHC.TypeLits
diff --git a/src/Data/Array/Nested/Internal/Convert.hs b/src/Data/Array/Nested/Internal/Convert.hs
index c316161..5d6cee4 100644
--- a/src/Data/Array/Nested/Internal/Convert.hs
+++ b/src/Data/Array/Nested/Internal/Convert.hs
@@ -12,12 +12,12 @@ import Data.Proxy
import Data.Type.Equality
import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Shape
+import Data.Array.Nested.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Internal.Mixed
import Data.Array.Nested.Internal.Ranked
-import Data.Array.Nested.Internal.Shape
+import Data.Array.Nested.Shaped.Shape
import Data.Array.Nested.Internal.Shaped
diff --git a/src/Data/Array/Nested/Internal/Lemmas.hs b/src/Data/Array/Nested/Internal/Lemmas.hs
index f894f78..b8baf96 100644
--- a/src/Data/Array/Nested/Internal/Lemmas.hs
+++ b/src/Data/Array/Nested/Internal/Lemmas.hs
@@ -11,9 +11,9 @@ import GHC.TypeLits
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
+import Data.Array.Nested.Mixed.Shape
import Data.Array.Mixed.Types
-import Data.Array.Nested.Internal.Shape
+import Data.Array.Nested.Shaped.Shape
lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index a2f9737..b76aa50 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -45,7 +45,7 @@ import Unsafe.Coerce (unsafeCoerce)
import Data.Array.Mixed.Internal.Arith
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
+import Data.Array.Nested.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Mixed.XArray (XArray(..))
import Data.Array.Mixed.XArray qualified as X
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index daf0374..368e337 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -40,12 +40,12 @@ import GHC.TypeNats qualified as TN
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Mixed.XArray (XArray(..))
import Data.Array.Mixed.XArray qualified as X
import Data.Array.Nested.Internal.Mixed
-import Data.Array.Nested.Internal.Shape
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Ranked.Shape
import Data.Array.Strided.Arith
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index 372439f..86dcee2 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -40,13 +40,13 @@ import GHC.TypeLits
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
+import Data.Array.Nested.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Mixed.XArray (XArray)
import Data.Array.Mixed.XArray qualified as X
import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Internal.Mixed
-import Data.Array.Nested.Internal.Shape
+import Data.Array.Nested.Shaped.Shape
import Data.Array.Strided.Arith
diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index eb8434f..5f4775c 100644
--- a/src/Data/Array/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -20,7 +20,7 @@
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed.Shape where
+module Data.Array.Nested.Mixed.Shape where
import Control.DeepSeq (NFData(..))
import Data.Bifunctor (first)
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
new file mode 100644
index 0000000..1c0b9eb
--- /dev/null
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -0,0 +1,363 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveFoldable #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Ranked.Shape where
+
+import Control.DeepSeq (NFData(..))
+import Data.Array.Mixed.Types
+import Data.Coerce (coerce)
+import Data.Foldable qualified as Foldable
+import Data.Kind (Type)
+import Data.Proxy
+import Data.Type.Equality
+import GHC.Generics (Generic)
+import GHC.IsList (IsList)
+import GHC.IsList qualified as IsList
+import GHC.TypeLits
+import GHC.TypeNats qualified as TN
+
+import Data.Array.Mixed.Lemmas
+import Data.Array.Nested.Mixed.Shape
+
+
+type role ListR nominal representational
+type ListR :: Nat -> Type -> Type
+data ListR n i where
+ ZR :: ListR 0 i
+ (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
+deriving instance Eq i => Eq (ListR n i)
+deriving instance Ord i => Ord (ListR n i)
+deriving instance Functor (ListR n)
+deriving instance Foldable (ListR n)
+infixr 3 :::
+
+instance Show i => Show (ListR n i) where
+ showsPrec _ = listrShow shows
+
+instance NFData i => NFData (ListR n i) where
+ rnf ZR = ()
+ rnf (x ::: l) = rnf x `seq` rnf l
+
+data UnconsListRRes i n1 =
+ forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
+listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
+listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)
+listrUncons ZR = Nothing
+
+-- | This checks only whether the ranks are equal, not whether the actual
+-- values are.
+listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n')
+listrEqRank ZR ZR = Just Refl
+listrEqRank (_ ::: sh) (_ ::: sh')
+ | Just Refl <- listrEqRank sh sh'
+ = Just Refl
+listrEqRank _ _ = Nothing
+
+-- | This compares the lists for value equality.
+listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n')
+listrEqual ZR ZR = Just Refl
+listrEqual (i ::: sh) (j ::: sh')
+ | Just Refl <- listrEqual sh sh'
+ , i == j
+ = Just Refl
+listrEqual _ _ = Nothing
+
+listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS
+listrShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> ListR n' i -> ShowS
+ go _ ZR = id
+ go prefix (x ::: xs) = showString prefix . f x . go "," xs
+
+listrLength :: ListR n i -> Int
+listrLength = length
+
+listrRank :: ListR n i -> SNat n
+listrRank ZR = SNat
+listrRank (_ ::: sh) = snatSucc (listrRank sh)
+
+listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
+listrAppend ZR sh = sh
+listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
+
+listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
+listrFromList [] k = k ZR
+listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
+
+listrHead :: ListR (n + 1) i -> i
+listrHead (i ::: _) = i
+listrHead ZR = error "unreachable"
+
+listrTail :: ListR (n + 1) i -> ListR n i
+listrTail (_ ::: sh) = sh
+listrTail ZR = error "unreachable"
+
+listrInit :: ListR (n + 1) i -> ListR n i
+listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh
+listrInit (_ ::: ZR) = ZR
+listrInit ZR = error "unreachable"
+
+listrLast :: ListR (n + 1) i -> i
+listrLast (_ ::: sh@(_ ::: _)) = listrLast sh
+listrLast (n ::: ZR) = n
+listrLast ZR = error "unreachable"
+
+listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
+listrIndex SZ (x ::: _) = x
+listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
+listrIndex _ ZR = error "k + 1 <= 0"
+
+listrZip :: ListR n i -> ListR n j -> ListR n (i, j)
+listrZip ZR ZR = ZR
+listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest
+listrZip _ _ = error "listrZip: impossible pattern needlessly required"
+
+listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k
+listrZipWith _ ZR ZR = ZR
+listrZipWith f (i ::: irest) (j ::: jrest) =
+ f i j ::: listrZipWith f irest jrest
+listrZipWith _ _ _ =
+ error "listrZipWith: impossible pattern needlessly required"
+
+listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
+listrPermutePrefix = \perm sh ->
+ listrFromList perm $ \sperm ->
+ case (listrRank sperm, listrRank sh) of
+ (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
+ LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
+ ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ where
+ listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
+ listrSplitAt SZ sh = (ZR, sh)
+ listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
+ listrSplitAt SS{} ZR = error "m' + 1 <= 0"
+
+ applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
+ applyPermRFull _ ZR _ = ZR
+ applyPermRFull sm@SNat (i ::: perm) l =
+ TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
+ case cmpNat (SNat @(idx + 1)) sm of
+ LTI -> listrIndex si l ::: applyPermRFull sm perm l
+ EQI -> listrIndex si l ::: applyPermRFull sm perm l
+ GTI -> error "listrPermutePrefix: Index in permutation out of range"
+
+
+-- | An index into a rank-typed array.
+type role IxR nominal representational
+type IxR :: Nat -> Type -> Type
+newtype IxR n i = IxR (ListR n i)
+ deriving (Eq, Ord, Generic)
+ deriving newtype (Functor, Foldable)
+
+pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
+pattern ZIR = IxR ZR
+
+pattern (:.:)
+ :: forall {n1} {i}.
+ forall n. (n + 1 ~ n1)
+ => i -> IxR n i -> IxR n1 i
+pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i))
+ where i :.: IxR sh = IxR (i ::: sh)
+infixr 3 :.:
+
+{-# COMPLETE ZIR, (:.:) #-}
+
+type IIxR n = IxR n Int
+
+instance Show i => Show (IxR n i) where
+ showsPrec _ (IxR l) = listrShow shows l
+
+instance NFData i => NFData (IxR sh i)
+
+ixrLength :: IxR sh i -> Int
+ixrLength (IxR l) = listrLength l
+
+ixrRank :: IxR n i -> SNat n
+ixrRank (IxR sh) = listrRank sh
+
+ixrZero :: SNat n -> IIxR n
+ixrZero SZ = ZIR
+ixrZero (SS n) = 0 :.: ixrZero n
+
+ixCvtXR :: IIxX sh -> IIxR (Rank sh)
+ixCvtXR ZIX = ZIR
+ixCvtXR (n :.% idx) = n :.: ixCvtXR idx
+
+ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
+ixCvtRX ZIR = ZIX
+ixCvtRX (n :.: (idx :: IxR m Int)) =
+ castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
+ (n :.% ixCvtRX idx)
+
+ixrHead :: IxR (n + 1) i -> i
+ixrHead (IxR list) = listrHead list
+
+ixrTail :: IxR (n + 1) i -> IxR n i
+ixrTail (IxR list) = IxR (listrTail list)
+
+ixrInit :: IxR (n + 1) i -> IxR n i
+ixrInit (IxR list) = IxR (listrInit list)
+
+ixrLast :: IxR (n + 1) i -> i
+ixrLast (IxR list) = listrLast list
+
+ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
+ixrAppend = coerce (listrAppend @_ @i)
+
+ixrZip :: IxR n i -> IxR n j -> IxR n (i, j)
+ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2
+
+ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
+ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
+
+ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
+ixrPermutePrefix = coerce (listrPermutePrefix @i)
+
+
+type role ShR nominal representational
+type ShR :: Nat -> Type -> Type
+newtype ShR n i = ShR (ListR n i)
+ deriving (Eq, Ord, Generic)
+ deriving newtype (Functor, Foldable)
+
+pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
+pattern ZSR = ShR ZR
+
+pattern (:$:)
+ :: forall {n1} {i}.
+ forall n. (n + 1 ~ n1)
+ => i -> ShR n i -> ShR n1 i
+pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
+ where i :$: ShR sh = ShR (i ::: sh)
+infixr 3 :$:
+
+{-# COMPLETE ZSR, (:$:) #-}
+
+type IShR n = ShR n Int
+
+instance Show i => Show (ShR n i) where
+ showsPrec _ (ShR l) = listrShow shows l
+
+instance NFData i => NFData (ShR sh i)
+
+shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n
+shCvtXR' ZSX =
+ castWith (subst2 (unsafeCoerceRefl :: 0 :~: n))
+ ZSR
+shCvtXR' (n :$% (idx :: IShX sh))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =
+ castWith (subst2 (lem1 @sh Refl))
+ (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
+ where
+ lem1 :: forall sh' n' k.
+ k : sh' :~: Replicate n' Nothing
+ -> Rank sh' + 1 :~: n'
+ lem1 Refl = unsafeCoerceRefl
+
+ lem2 :: k : sh :~: Replicate n Nothing
+ -> sh :~: Replicate (Rank sh) Nothing
+ lem2 Refl = unsafeCoerceRefl
+
+shCvtRX :: IShR n -> IShX (Replicate n Nothing)
+shCvtRX ZSR = ZSX
+shCvtRX (n :$: (idx :: ShR m Int)) =
+ castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
+ (SUnknown n :$% shCvtRX idx)
+
+-- | This checks only whether the ranks are equal, not whether the actual
+-- values are.
+shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
+shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh'
+
+-- | This compares the shapes for value equality.
+shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n')
+shrEqual (ShR sh) (ShR sh') = listrEqual sh sh'
+
+shrLength :: ShR sh i -> Int
+shrLength (ShR l) = listrLength l
+
+-- | This function can also be used to conjure up a 'KnownNat' dictionary;
+-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern
+-- synonym yields 'KnownNat' evidence.
+shrRank :: ShR n i -> SNat n
+shrRank (ShR sh) = listrRank sh
+
+-- | The number of elements in an array described by this shape.
+shrSize :: IShR n -> Int
+shrSize ZSR = 1
+shrSize (n :$: sh) = n * shrSize sh
+
+shrHead :: ShR (n + 1) i -> i
+shrHead (ShR list) = listrHead list
+
+shrTail :: ShR (n + 1) i -> ShR n i
+shrTail (ShR list) = ShR (listrTail list)
+
+shrInit :: ShR (n + 1) i -> ShR n i
+shrInit (ShR list) = ShR (listrInit list)
+
+shrLast :: ShR (n + 1) i -> i
+shrLast (ShR list) = listrLast list
+
+shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
+shrAppend = coerce (listrAppend @_ @i)
+
+shrZip :: ShR n i -> ShR n j -> ShR n (i, j)
+shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2
+
+shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k
+shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2
+
+shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
+shrPermutePrefix = coerce (listrPermutePrefix @i)
+
+
+-- | Untyped: length is checked at runtime.
+instance KnownNat n => IsList (ListR n i) where
+ type Item (ListR n i) = i
+ fromList topl = go (SNat @n) topl
+ where
+ go :: SNat n' -> [i] -> ListR n' i
+ go SZ [] = ZR
+ go (SS n) (i : is) = i ::: go n is
+ go _ _ = error $ "IsList(ListR): Mismatched list length (type says "
+ ++ show (fromSNat (SNat @n)) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+ toList = Foldable.toList
+
+-- | Untyped: length is checked at runtime.
+instance KnownNat n => IsList (IxR n i) where
+ type Item (IxR n i) = i
+ fromList = IxR . IsList.fromList
+ toList = Foldable.toList
+
+-- | Untyped: length is checked at runtime.
+instance KnownNat n => IsList (ShR n i) where
+ type Item (ShR n i) = i
+ fromList = ShR . IsList.fromList
+ toList = Foldable.toList
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index 97b9456..6c43fa7 100644
--- a/src/Data/Array/Nested/Internal/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -24,7 +24,7 @@
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Nested.Internal.Shape where
+module Data.Array.Nested.Shaped.Shape where
import Control.DeepSeq (NFData(..))
import Data.Array.Mixed.Types
@@ -42,331 +42,10 @@ import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
-import GHC.TypeNats qualified as TN
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
-
-
-type role ListR nominal representational
-type ListR :: Nat -> Type -> Type
-data ListR n i where
- ZR :: ListR 0 i
- (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
-deriving instance Eq i => Eq (ListR n i)
-deriving instance Ord i => Ord (ListR n i)
-deriving instance Functor (ListR n)
-deriving instance Foldable (ListR n)
-infixr 3 :::
-
-instance Show i => Show (ListR n i) where
- showsPrec _ = listrShow shows
-
-instance NFData i => NFData (ListR n i) where
- rnf ZR = ()
- rnf (x ::: l) = rnf x `seq` rnf l
-
-data UnconsListRRes i n1 =
- forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
-listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
-listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)
-listrUncons ZR = Nothing
-
--- | This checks only whether the ranks are equal, not whether the actual
--- values are.
-listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n')
-listrEqRank ZR ZR = Just Refl
-listrEqRank (_ ::: sh) (_ ::: sh')
- | Just Refl <- listrEqRank sh sh'
- = Just Refl
-listrEqRank _ _ = Nothing
-
--- | This compares the lists for value equality.
-listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n')
-listrEqual ZR ZR = Just Refl
-listrEqual (i ::: sh) (j ::: sh')
- | Just Refl <- listrEqual sh sh'
- , i == j
- = Just Refl
-listrEqual _ _ = Nothing
-
-listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS
-listrShow f l = showString "[" . go "" l . showString "]"
- where
- go :: String -> ListR n' i -> ShowS
- go _ ZR = id
- go prefix (x ::: xs) = showString prefix . f x . go "," xs
-
-listrLength :: ListR n i -> Int
-listrLength = length
-
-listrRank :: ListR n i -> SNat n
-listrRank ZR = SNat
-listrRank (_ ::: sh) = snatSucc (listrRank sh)
-
-listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
-listrAppend ZR sh = sh
-listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
-
-listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
-listrFromList [] k = k ZR
-listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
-
-listrHead :: ListR (n + 1) i -> i
-listrHead (i ::: _) = i
-listrHead ZR = error "unreachable"
-
-listrTail :: ListR (n + 1) i -> ListR n i
-listrTail (_ ::: sh) = sh
-listrTail ZR = error "unreachable"
-
-listrInit :: ListR (n + 1) i -> ListR n i
-listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh
-listrInit (_ ::: ZR) = ZR
-listrInit ZR = error "unreachable"
-
-listrLast :: ListR (n + 1) i -> i
-listrLast (_ ::: sh@(_ ::: _)) = listrLast sh
-listrLast (n ::: ZR) = n
-listrLast ZR = error "unreachable"
-
-listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
-listrIndex SZ (x ::: _) = x
-listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
-listrIndex _ ZR = error "k + 1 <= 0"
-
-listrZip :: ListR n i -> ListR n j -> ListR n (i, j)
-listrZip ZR ZR = ZR
-listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest
-listrZip _ _ = error "listrZip: impossible pattern needlessly required"
-
-listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k
-listrZipWith _ ZR ZR = ZR
-listrZipWith f (i ::: irest) (j ::: jrest) =
- f i j ::: listrZipWith f irest jrest
-listrZipWith _ _ _ =
- error "listrZipWith: impossible pattern needlessly required"
-
-listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
-listrPermutePrefix = \perm sh ->
- listrFromList perm $ \sperm ->
- case (listrRank sperm, listrRank sh) of
- (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
- LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
- ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
- where
- listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
- listrSplitAt SZ sh = (ZR, sh)
- listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
- listrSplitAt SS{} ZR = error "m' + 1 <= 0"
-
- applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
- applyPermRFull _ ZR _ = ZR
- applyPermRFull sm@SNat (i ::: perm) l =
- TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
- case cmpNat (SNat @(idx + 1)) sm of
- LTI -> listrIndex si l ::: applyPermRFull sm perm l
- EQI -> listrIndex si l ::: applyPermRFull sm perm l
- GTI -> error "listrPermutePrefix: Index in permutation out of range"
-
-
--- | An index into a rank-typed array.
-type role IxR nominal representational
-type IxR :: Nat -> Type -> Type
-newtype IxR n i = IxR (ListR n i)
- deriving (Eq, Ord, Generic)
- deriving newtype (Functor, Foldable)
-
-pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
-pattern ZIR = IxR ZR
-
-pattern (:.:)
- :: forall {n1} {i}.
- forall n. (n + 1 ~ n1)
- => i -> IxR n i -> IxR n1 i
-pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i))
- where i :.: IxR sh = IxR (i ::: sh)
-infixr 3 :.:
-
-{-# COMPLETE ZIR, (:.:) #-}
-
-type IIxR n = IxR n Int
-
-instance Show i => Show (IxR n i) where
- showsPrec _ (IxR l) = listrShow shows l
-
-instance NFData i => NFData (IxR sh i)
-
-ixrLength :: IxR sh i -> Int
-ixrLength (IxR l) = listrLength l
-
-ixrRank :: IxR n i -> SNat n
-ixrRank (IxR sh) = listrRank sh
-
-ixrZero :: SNat n -> IIxR n
-ixrZero SZ = ZIR
-ixrZero (SS n) = 0 :.: ixrZero n
-
-ixCvtXR :: IIxX sh -> IIxR (Rank sh)
-ixCvtXR ZIX = ZIR
-ixCvtXR (n :.% idx) = n :.: ixCvtXR idx
-
-ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
-ixCvtRX ZIR = ZIX
-ixCvtRX (n :.: (idx :: IxR m Int)) =
- castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
- (n :.% ixCvtRX idx)
-
-ixrHead :: IxR (n + 1) i -> i
-ixrHead (IxR list) = listrHead list
-
-ixrTail :: IxR (n + 1) i -> IxR n i
-ixrTail (IxR list) = IxR (listrTail list)
-
-ixrInit :: IxR (n + 1) i -> IxR n i
-ixrInit (IxR list) = IxR (listrInit list)
-
-ixrLast :: IxR (n + 1) i -> i
-ixrLast (IxR list) = listrLast list
-
-ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
-ixrAppend = coerce (listrAppend @_ @i)
-
-ixrZip :: IxR n i -> IxR n j -> IxR n (i, j)
-ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2
-
-ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
-ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
-
-ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
-ixrPermutePrefix = coerce (listrPermutePrefix @i)
-
-
-type role ShR nominal representational
-type ShR :: Nat -> Type -> Type
-newtype ShR n i = ShR (ListR n i)
- deriving (Eq, Ord, Generic)
- deriving newtype (Functor, Foldable)
-
-pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
-pattern ZSR = ShR ZR
-
-pattern (:$:)
- :: forall {n1} {i}.
- forall n. (n + 1 ~ n1)
- => i -> ShR n i -> ShR n1 i
-pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
- where i :$: ShR sh = ShR (i ::: sh)
-infixr 3 :$:
-
-{-# COMPLETE ZSR, (:$:) #-}
-
-type IShR n = ShR n Int
-
-instance Show i => Show (ShR n i) where
- showsPrec _ (ShR l) = listrShow shows l
-
-instance NFData i => NFData (ShR sh i)
-
-shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n
-shCvtXR' ZSX =
- castWith (subst2 (unsafeCoerceRefl :: 0 :~: n))
- ZSR
-shCvtXR' (n :$% (idx :: IShX sh))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =
- castWith (subst2 (lem1 @sh Refl))
- (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
- where
- lem1 :: forall sh' n' k.
- k : sh' :~: Replicate n' Nothing
- -> Rank sh' + 1 :~: n'
- lem1 Refl = unsafeCoerceRefl
-
- lem2 :: k : sh :~: Replicate n Nothing
- -> sh :~: Replicate (Rank sh) Nothing
- lem2 Refl = unsafeCoerceRefl
-
-shCvtRX :: IShR n -> IShX (Replicate n Nothing)
-shCvtRX ZSR = ZSX
-shCvtRX (n :$: (idx :: ShR m Int)) =
- castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
- (SUnknown n :$% shCvtRX idx)
-
--- | This checks only whether the ranks are equal, not whether the actual
--- values are.
-shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
-shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh'
-
--- | This compares the shapes for value equality.
-shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n')
-shrEqual (ShR sh) (ShR sh') = listrEqual sh sh'
-
-shrLength :: ShR sh i -> Int
-shrLength (ShR l) = listrLength l
-
--- | This function can also be used to conjure up a 'KnownNat' dictionary;
--- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern
--- synonym yields 'KnownNat' evidence.
-shrRank :: ShR n i -> SNat n
-shrRank (ShR sh) = listrRank sh
-
--- | The number of elements in an array described by this shape.
-shrSize :: IShR n -> Int
-shrSize ZSR = 1
-shrSize (n :$: sh) = n * shrSize sh
-
-shrHead :: ShR (n + 1) i -> i
-shrHead (ShR list) = listrHead list
-
-shrTail :: ShR (n + 1) i -> ShR n i
-shrTail (ShR list) = ShR (listrTail list)
-
-shrInit :: ShR (n + 1) i -> ShR n i
-shrInit (ShR list) = ShR (listrInit list)
-
-shrLast :: ShR (n + 1) i -> i
-shrLast (ShR list) = listrLast list
-
-shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
-shrAppend = coerce (listrAppend @_ @i)
-
-shrZip :: ShR n i -> ShR n j -> ShR n (i, j)
-shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2
-
-shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k
-shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2
-
-shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
-shrPermutePrefix = coerce (listrPermutePrefix @i)
-
-
--- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (ListR n i) where
- type Item (ListR n i) = i
- fromList topl = go (SNat @n) topl
- where
- go :: SNat n' -> [i] -> ListR n' i
- go SZ [] = ZR
- go (SS n) (i : is) = i ::: go n is
- go _ _ = error $ "IsList(ListR): Mismatched list length (type says "
- ++ show (fromSNat (SNat @n)) ++ ", list has length "
- ++ show (length topl) ++ ")"
- toList = Foldable.toList
-
--- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (IxR n i) where
- type Item (IxR n i) = i
- fromList = IxR . IsList.fromList
- toList = Foldable.toList
-
--- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (ShR n i) where
- type Item (ShR n i) = i
- fromList = ShR . IsList.fromList
- toList = Foldable.toList
+import Data.Array.Nested.Mixed.Shape
type role ListS nominal representational
diff --git a/test/Gen.hs b/test/Gen.hs
index bf002ca..50c671f 100644
--- a/test/Gen.hs
+++ b/test/Gen.hs
@@ -21,10 +21,9 @@ import GHC.TypeLits
import GHC.TypeNats qualified as TN
import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Nested
-import Data.Array.Nested.Internal.Shape
+import Data.Array.Nested.Ranked.Shape
import Hedgehog
import Hedgehog.Gen qualified as Gen
diff --git a/test/Tests/C.hs b/test/Tests/C.hs
index 4861eb1..72bf8f6 100644
--- a/test/Tests/C.hs
+++ b/test/Tests/C.hs
@@ -20,7 +20,7 @@ import GHC.TypeLits
import Data.Array.Mixed.Types (fromSNat')
import Data.Array.Nested
-import Data.Array.Nested.Internal.Shape
+import Data.Array.Nested.Ranked.Shape
import Hedgehog
import Hedgehog.Gen qualified as Gen