diff options
Diffstat (limited to 'src/Data/Array/Nested/Internal')
| -rw-r--r-- | src/Data/Array/Nested/Internal/Convert.hs | 4 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Lemmas.hs | 4 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 2 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 4 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shape.hs | 737 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 4 | 
6 files changed, 9 insertions, 746 deletions
| diff --git a/src/Data/Array/Nested/Internal/Convert.hs b/src/Data/Array/Nested/Internal/Convert.hs index c316161..5d6cee4 100644 --- a/src/Data/Array/Nested/Internal/Convert.hs +++ b/src/Data/Array/Nested/Internal/Convert.hs @@ -12,12 +12,12 @@ import Data.Proxy  import Data.Type.Equality  import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Shape +import Data.Array.Nested.Mixed.Shape  import Data.Array.Mixed.Types  import Data.Array.Nested.Internal.Lemmas  import Data.Array.Nested.Internal.Mixed  import Data.Array.Nested.Internal.Ranked -import Data.Array.Nested.Internal.Shape +import Data.Array.Nested.Shaped.Shape  import Data.Array.Nested.Internal.Shaped diff --git a/src/Data/Array/Nested/Internal/Lemmas.hs b/src/Data/Array/Nested/Internal/Lemmas.hs index f894f78..b8baf96 100644 --- a/src/Data/Array/Nested/Internal/Lemmas.hs +++ b/src/Data/Array/Nested/Internal/Lemmas.hs @@ -11,9 +11,9 @@ import GHC.TypeLits  import Data.Array.Mixed.Lemmas  import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape +import Data.Array.Nested.Mixed.Shape  import Data.Array.Mixed.Types -import Data.Array.Nested.Internal.Shape +import Data.Array.Nested.Shaped.Shape  lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index a2f9737..b76aa50 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -45,7 +45,7 @@ import Unsafe.Coerce (unsafeCoerce)  import Data.Array.Mixed.Internal.Arith  import Data.Array.Mixed.Lemmas  import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape +import Data.Array.Nested.Mixed.Shape  import Data.Array.Mixed.Types  import Data.Array.Mixed.XArray (XArray(..))  import Data.Array.Mixed.XArray qualified as X diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index daf0374..368e337 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -40,12 +40,12 @@ import GHC.TypeNats qualified as TN  import Data.Array.Mixed.Lemmas  import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape  import Data.Array.Mixed.Types  import Data.Array.Mixed.XArray (XArray(..))  import Data.Array.Mixed.XArray qualified as X  import Data.Array.Nested.Internal.Mixed -import Data.Array.Nested.Internal.Shape +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked.Shape  import Data.Array.Strided.Arith diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs deleted file mode 100644 index 97b9456..0000000 --- a/src/Data/Array/Nested/Internal/Shape.hs +++ /dev/null @@ -1,737 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE NoStarIsType #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE RoleAnnotations #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE StrictData #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Internal.Shape where - -import Control.DeepSeq (NFData(..)) -import Data.Array.Mixed.Types -import Data.Array.Shape qualified as O -import Data.Coerce (coerce) -import Data.Foldable qualified as Foldable -import Data.Functor.Const -import Data.Functor.Product qualified as Fun -import Data.Kind (Constraint, Type) -import Data.Monoid (Sum(..)) -import Data.Proxy -import Data.Type.Equality -import GHC.Exts (withDict) -import GHC.Generics (Generic) -import GHC.IsList (IsList) -import GHC.IsList qualified as IsList -import GHC.TypeLits -import GHC.TypeNats qualified as TN - -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape - - -type role ListR nominal representational -type ListR :: Nat -> Type -> Type -data ListR n i where -  ZR :: ListR 0 i -  (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i -deriving instance Eq i => Eq (ListR n i) -deriving instance Ord i => Ord (ListR n i) -deriving instance Functor (ListR n) -deriving instance Foldable (ListR n) -infixr 3 ::: - -instance Show i => Show (ListR n i) where -  showsPrec _ = listrShow shows - -instance NFData i => NFData (ListR n i) where -  rnf ZR = () -  rnf (x ::: l) = rnf x `seq` rnf l - -data UnconsListRRes i n1 = -  forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i -listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) -listrUncons (i ::: sh') = Just (UnconsListRRes sh' i) -listrUncons ZR = Nothing - --- | This checks only whether the ranks are equal, not whether the actual --- values are. -listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n') -listrEqRank ZR ZR = Just Refl -listrEqRank (_ ::: sh) (_ ::: sh') -  | Just Refl <- listrEqRank sh sh' -  = Just Refl -listrEqRank _ _ = Nothing - --- | This compares the lists for value equality. -listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n') -listrEqual ZR ZR = Just Refl -listrEqual (i ::: sh) (j ::: sh') -  | Just Refl <- listrEqual sh sh' -  , i == j -  = Just Refl -listrEqual _ _ = Nothing - -listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS -listrShow f l = showString "[" . go "" l . showString "]" -  where -    go :: String -> ListR n' i -> ShowS -    go _ ZR = id -    go prefix (x ::: xs) = showString prefix . f x . go "," xs - -listrLength :: ListR n i -> Int -listrLength = length - -listrRank :: ListR n i -> SNat n -listrRank ZR = SNat -listrRank (_ ::: sh) = snatSucc (listrRank sh) - -listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i -listrAppend ZR sh = sh -listrAppend (x ::: xs) sh = x ::: listrAppend xs sh - -listrFromList :: [i] -> (forall n. ListR n i -> r) -> r -listrFromList [] k = k ZR -listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) - -listrHead :: ListR (n + 1) i -> i -listrHead (i ::: _) = i -listrHead ZR = error "unreachable" - -listrTail :: ListR (n + 1) i -> ListR n i -listrTail (_ ::: sh) = sh -listrTail ZR = error "unreachable" - -listrInit :: ListR (n + 1) i -> ListR n i -listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh -listrInit (_ ::: ZR) = ZR -listrInit ZR = error "unreachable" - -listrLast :: ListR (n + 1) i -> i -listrLast (_ ::: sh@(_ ::: _)) = listrLast sh -listrLast (n ::: ZR) = n -listrLast ZR = error "unreachable" - -listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i -listrIndex SZ (x ::: _) = x -listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs -listrIndex _ ZR = error "k + 1 <= 0" - -listrZip :: ListR n i -> ListR n j -> ListR n (i, j) -listrZip ZR ZR = ZR -listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest -listrZip _ _ = error "listrZip: impossible pattern needlessly required" - -listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k -listrZipWith _ ZR ZR = ZR -listrZipWith f (i ::: irest) (j ::: jrest) = -  f i j ::: listrZipWith f irest jrest -listrZipWith _ _ _ = -  error "listrZipWith: impossible pattern needlessly required" - -listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i -listrPermutePrefix = \perm sh -> -  listrFromList perm $ \sperm -> -    case (listrRank sperm, listrRank sh) of -      (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of -        LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post -        EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post -        GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" -                       ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" -  where -    listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) -    listrSplitAt SZ sh = (ZR, sh) -    listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) -    listrSplitAt SS{} ZR = error "m' + 1 <= 0" - -    applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i -    applyPermRFull _ ZR _ = ZR -    applyPermRFull sm@SNat (i ::: perm) l = -      TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> -        case cmpNat (SNat @(idx + 1)) sm of -          LTI -> listrIndex si l ::: applyPermRFull sm perm l -          EQI -> listrIndex si l ::: applyPermRFull sm perm l -          GTI -> error "listrPermutePrefix: Index in permutation out of range" - - --- | An index into a rank-typed array. -type role IxR nominal representational -type IxR :: Nat -> Type -> Type -newtype IxR n i = IxR (ListR n i) -  deriving (Eq, Ord, Generic) -  deriving newtype (Functor, Foldable) - -pattern ZIR :: forall n i. () => n ~ 0 => IxR n i -pattern ZIR = IxR ZR - -pattern (:.:) -  :: forall {n1} {i}. -     forall n. (n + 1 ~ n1) -  => i -> IxR n i -> IxR n1 i -pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i)) -  where i :.: IxR sh = IxR (i ::: sh) -infixr 3 :.: - -{-# COMPLETE ZIR, (:.:) #-} - -type IIxR n = IxR n Int - -instance Show i => Show (IxR n i) where -  showsPrec _ (IxR l) = listrShow shows l - -instance NFData i => NFData (IxR sh i) - -ixrLength :: IxR sh i -> Int -ixrLength (IxR l) = listrLength l - -ixrRank :: IxR n i -> SNat n -ixrRank (IxR sh) = listrRank sh - -ixrZero :: SNat n -> IIxR n -ixrZero SZ = ZIR -ixrZero (SS n) = 0 :.: ixrZero n - -ixCvtXR :: IIxX sh -> IIxR (Rank sh) -ixCvtXR ZIX = ZIR -ixCvtXR (n :.% idx) = n :.: ixCvtXR idx - -ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) -ixCvtRX ZIR = ZIX -ixCvtRX (n :.: (idx :: IxR m Int)) = -  castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) -    (n :.% ixCvtRX idx) - -ixrHead :: IxR (n + 1) i -> i -ixrHead (IxR list) = listrHead list - -ixrTail :: IxR (n + 1) i -> IxR n i -ixrTail (IxR list) = IxR (listrTail list) - -ixrInit :: IxR (n + 1) i -> IxR n i -ixrInit (IxR list) = IxR (listrInit list) - -ixrLast :: IxR (n + 1) i -> i -ixrLast (IxR list) = listrLast list - -ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i -ixrAppend = coerce (listrAppend @_ @i) - -ixrZip :: IxR n i -> IxR n j -> IxR n (i, j) -ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 - -ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k -ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 - -ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i -ixrPermutePrefix = coerce (listrPermutePrefix @i) - - -type role ShR nominal representational -type ShR :: Nat -> Type -> Type -newtype ShR n i = ShR (ListR n i) -  deriving (Eq, Ord, Generic) -  deriving newtype (Functor, Foldable) - -pattern ZSR :: forall n i. () => n ~ 0 => ShR n i -pattern ZSR = ShR ZR - -pattern (:$:) -  :: forall {n1} {i}. -     forall n. (n + 1 ~ n1) -  => i -> ShR n i -> ShR n1 i -pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) -  where i :$: ShR sh = ShR (i ::: sh) -infixr 3 :$: - -{-# COMPLETE ZSR, (:$:) #-} - -type IShR n = ShR n Int - -instance Show i => Show (ShR n i) where -  showsPrec _ (ShR l) = listrShow shows l - -instance NFData i => NFData (ShR sh i) - -shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n -shCvtXR' ZSX = -  castWith (subst2 (unsafeCoerceRefl :: 0 :~: n)) -    ZSR -shCvtXR' (n :$% (idx :: IShX sh)) -  | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) = -  castWith (subst2 (lem1 @sh Refl)) -    (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) -  where -    lem1 :: forall sh' n' k. -            k : sh' :~: Replicate n' Nothing -         -> Rank sh' + 1 :~: n' -    lem1 Refl = unsafeCoerceRefl - -    lem2 :: k : sh :~: Replicate n Nothing -         -> sh :~: Replicate (Rank sh) Nothing -    lem2 Refl = unsafeCoerceRefl - -shCvtRX :: IShR n -> IShX (Replicate n Nothing) -shCvtRX ZSR = ZSX -shCvtRX (n :$: (idx :: ShR m Int)) = -  castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) -    (SUnknown n :$% shCvtRX idx) - --- | This checks only whether the ranks are equal, not whether the actual --- values are. -shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') -shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh' - --- | This compares the shapes for value equality. -shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n') -shrEqual (ShR sh) (ShR sh') = listrEqual sh sh' - -shrLength :: ShR sh i -> Int -shrLength (ShR l) = listrLength l - --- | This function can also be used to conjure up a 'KnownNat' dictionary; --- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern --- synonym yields 'KnownNat' evidence. -shrRank :: ShR n i -> SNat n -shrRank (ShR sh) = listrRank sh - --- | The number of elements in an array described by this shape. -shrSize :: IShR n -> Int -shrSize ZSR = 1 -shrSize (n :$: sh) = n * shrSize sh - -shrHead :: ShR (n + 1) i -> i -shrHead (ShR list) = listrHead list - -shrTail :: ShR (n + 1) i -> ShR n i -shrTail (ShR list) = ShR (listrTail list) - -shrInit :: ShR (n + 1) i -> ShR n i -shrInit (ShR list) = ShR (listrInit list) - -shrLast :: ShR (n + 1) i -> i -shrLast (ShR list) = listrLast list - -shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i -shrAppend = coerce (listrAppend @_ @i) - -shrZip :: ShR n i -> ShR n j -> ShR n (i, j) -shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2 - -shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k -shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2 - -shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i -shrPermutePrefix = coerce (listrPermutePrefix @i) - - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ListR n i) where -  type Item (ListR n i) = i -  fromList topl = go (SNat @n) topl -    where -      go :: SNat n' -> [i] -> ListR n' i -      go SZ [] = ZR -      go (SS n) (i : is) = i ::: go n is -      go _ _ = error $ "IsList(ListR): Mismatched list length (type says " -                         ++ show (fromSNat (SNat @n)) ++ ", list has length " -                         ++ show (length topl) ++ ")" -  toList = Foldable.toList - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (IxR n i) where -  type Item (IxR n i) = i -  fromList = IxR . IsList.fromList -  toList = Foldable.toList - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ShR n i) where -  type Item (ShR n i) = i -  fromList = ShR . IsList.fromList -  toList = Foldable.toList - - -type role ListS nominal representational -type ListS :: [Nat] -> (Nat -> Type) -> Type -data ListS sh f where -  ZS :: ListS '[] f -  -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity -  (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) -infixr 3 ::$ - -instance (forall n. Show (f n)) => Show (ListS sh f) where -  showsPrec _ = listsShow shows - -instance (forall m. NFData (f m)) => NFData (ListS n f) where -  rnf ZS = () -  rnf (x ::$ l) = rnf x `seq` rnf l - -data UnconsListSRes f sh1 = -  forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) -listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) -listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x) -listsUncons ZS = Nothing - --- | This checks only whether the types are equal; if the elements of the list --- are not singletons, their values may still differ. This corresponds to --- 'testEquality', except on the penultimate type parameter. -listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') -listsEqType ZS ZS = Just Refl -listsEqType (n ::$ sh) (m ::$ sh') -  | Just Refl <- testEquality n m -  , Just Refl <- listsEqType sh sh' -  = Just Refl -listsEqType _ _ = Nothing - --- | This checks whether the two lists actually contain equal values. This is --- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ --- in the @some@ package (except on the penultimate type parameter). -listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') -listsEqual ZS ZS = Just Refl -listsEqual (n ::$ sh) (m ::$ sh') -  | Just Refl <- testEquality n m -  , n == m -  , Just Refl <- listsEqual sh sh' -  = Just Refl -listsEqual _ _ = Nothing - -listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g -listsFmap _ ZS = ZS -listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs - -listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m -listsFold _ ZS = mempty -listsFold f (x ::$ xs) = f x <> listsFold f xs - -listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS -listsShow f l = showString "[" . go "" l . showString "]" -  where -    go :: String -> ListS sh' f -> ShowS -    go _ ZS = id -    go prefix (x ::$ xs) = showString prefix . f x . go "," xs - -listsLength :: ListS sh f -> Int -listsLength = getSum . listsFold (\_ -> Sum 1) - -listsRank :: ListS sh f -> SNat (Rank sh) -listsRank ZS = SNat -listsRank (_ ::$ sh) = snatSucc (listsRank sh) - -listsToList :: ListS sh (Const i) -> [i] -listsToList ZS = [] -listsToList (Const i ::$ is) = i : listsToList is - -listsHead :: ListS (n : sh) f -> f n -listsHead (i ::$ _) = i - -listsTail :: ListS (n : sh) f -> ListS sh f -listsTail (_ ::$ sh) = sh - -listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f -listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh -listsInit (_ ::$ ZS) = ZS - -listsLast :: ListS (n : sh) f -> f (Last (n : sh)) -listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh -listsLast (n ::$ ZS) = n - -listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f -listsAppend ZS idx' = idx' -listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' - -listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g) -listsZip ZS ZS = ZS -listsZip (i ::$ is) (j ::$ js) = -  Fun.Pair i j ::$ listsZip is js - -listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g -             -> ListS sh h -listsZipWith _ ZS ZS = ZS -listsZipWith f (i ::$ is) (j ::$ js) = -  f i j ::$ listsZipWith f is js - -listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f -listsTakeLenPerm PNil _ = ZS -listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh -listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" - -listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f -listsDropLenPerm PNil sh = sh -listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh -listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" - -listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f -listsPermute PNil _ = ZS -listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = -  case listsIndex (Proxy @is') (Proxy @sh) i sh of -    (item, SNat) -> item ::$ listsPermute is sh - --- TODO: remove this SNat when the KnownNat constaint in ListS is removed -listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh)) -listsIndex _ _ SZ (n ::$ _) = (n, SNat) -listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) -  | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') -  = listsIndex p pT i sh -listsIndex _ _ _ ZS = error "Index into empty shape" - -listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f -listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) - - --- | An index into a shape-typed array. --- --- For convenience, this contains regular 'Int's instead of bounded integers --- (traditionally called \"@Fin@\"). -type role IxS nominal representational -type IxS :: [Nat] -> Type -> Type -newtype IxS sh i = IxS (ListS sh (Const i)) -  deriving (Eq, Ord, Generic) - -pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i -pattern ZIS = IxS ZS - -pattern (:.$) -  :: forall {sh1} {i}. -     forall n sh. (KnownNat n, n : sh ~ sh1) -  => i -> IxS sh i -> IxS sh1 i -pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i))) -  where i :.$ IxS shl = IxS (Const i ::$ shl) -infixr 3 :.$ - -{-# COMPLETE ZIS, (:.$) #-} - -type IIxS sh = IxS sh Int - -instance Show i => Show (IxS sh i) where -  showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l - -instance Functor (IxS sh) where -  fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l) - -instance Foldable (IxS sh) where -  foldMap f (IxS l) = listsFold (f . getConst) l - -instance NFData i => NFData (IxS sh i) - -ixsLength :: IxS sh i -> Int -ixsLength (IxS l) = listsLength l - -ixsRank :: IxS sh i -> SNat (Rank sh) -ixsRank (IxS l) = listsRank l - -ixsZero :: ShS sh -> IIxS sh -ixsZero ZSS = ZIS -ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh - -ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh -ixCvtXS ZSS ZIX = ZIS -ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx - -ixCvtSX :: IIxS sh -> IIxX (MapJust sh) -ixCvtSX ZIS = ZIX -ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh - -ixsHead :: IxS (n : sh) i -> i -ixsHead (IxS list) = getConst (listsHead list) - -ixsTail :: IxS (n : sh) i -> IxS sh i -ixsTail (IxS list) = IxS (listsTail list) - -ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i -ixsInit (IxS list) = IxS (listsInit list) - -ixsLast :: IxS (n : sh) i -> i -ixsLast (IxS list) = getConst (listsLast list) - -ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i -ixsAppend = coerce (listsAppend @_ @(Const i)) - -ixsZip :: IxS n i -> IxS n j -> IxS n (i, j) -ixsZip ZIS ZIS = ZIS -ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js - -ixsZipWith :: (i -> j -> k) -> IxS n i -> IxS n j -> IxS n k -ixsZipWith _ ZIS ZIS = ZIS -ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js - -ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i -ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) - - --- | The shape of a shape-typed array given as a list of 'SNat' values. --- --- Note that because the shape of a shape-typed array is known statically, you --- can also retrieve the array shape from a 'KnownShS' dictionary. -type role ShS nominal -type ShS :: [Nat] -> Type -newtype ShS sh = ShS (ListS sh SNat) -  deriving (Eq, Ord, Generic) - -pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh -pattern ZSS = ShS ZS - -pattern (:$$) -  :: forall {sh1}. -     forall n sh. (KnownNat n, n : sh ~ sh1) -  => SNat n -> ShS sh -> ShS sh1 -pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i)) -  where i :$$ ShS shl = ShS (i ::$ shl) - -infixr 3 :$$ - -{-# COMPLETE ZSS, (:$$) #-} - -instance Show (ShS sh) where -  showsPrec _ (ShS l) = listsShow (shows . fromSNat) l - -instance NFData (ShS sh) where -  rnf (ShS ZS) = () -  rnf (ShS (SNat ::$ l)) = rnf (ShS l) - -instance TestEquality ShS where -  testEquality (ShS l1) (ShS l2) = listsEqType l1 l2 - --- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are --- equal if and only if values are equal.) -shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh') -shsEqual = testEquality - -shsLength :: ShS sh -> Int -shsLength (ShS l) = listsLength l - -shsRank :: ShS sh -> SNat (Rank sh) -shsRank (ShS l) = listsRank l - -shsSize :: ShS sh -> Int -shsSize ZSS = 1 -shsSize (n :$$ sh) = fromSNat' n * shsSize sh - -shsToList :: ShS sh -> [Int] -shsToList ZSS = [] -shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh - -shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh -shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS -shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) = -  castWith (subst1 (lem Refl)) $ -    n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) -                                 idx) -  where -    lem :: forall sh1 sh' n. -           Just n : sh1 :~: MapJust sh' -        -> n : Tail sh' :~: sh' -    lem Refl = unsafeCoerceRefl -shCvtXS' (SUnknown _ :$% _) = error "impossible" - -shCvtSX :: ShS sh -> IShX (MapJust sh) -shCvtSX ZSS = ZSX -shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh - -shsHead :: ShS (n : sh) -> SNat n -shsHead (ShS list) = listsHead list - -shsTail :: ShS (n : sh) -> ShS sh -shsTail (ShS list) = ShS (listsTail list) - -shsInit :: ShS (n : sh) -> ShS (Init (n : sh)) -shsInit (ShS list) = ShS (listsInit list) - -shsLast :: ShS (n : sh) -> SNat (Last (n : sh)) -shsLast (ShS list) = listsLast list - -shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh') -shsAppend = coerce (listsAppend @_ @SNat) - -shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh) -shsTakeLen = coerce (listsTakeLenPerm @SNat) - -shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) -shsPermute = coerce (listsPermute @SNat) - -shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh) -shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh))) - -shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) -shsPermutePrefix = coerce (listsPermutePrefix @SNat) - -type family Product sh where -  Product '[] = 1 -  Product (n : ns) = n * Product ns - -shsProduct :: ShS sh -> SNat (Product sh) -shsProduct ZSS = SNat -shsProduct (n :$$ sh) = n `snatMul` shsProduct sh - --- | Evidence for the static part of a shape. This pops up only when you are --- polymorphic in the element type of an array. -type KnownShS :: [Nat] -> Constraint -class KnownShS sh where knownShS :: ShS sh -instance KnownShS '[] where knownShS = ZSS -instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS - -withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r -withKnownShS k = withDict @(KnownShS sh) k - -shsKnownShS :: ShS sh -> Dict KnownShS sh -shsKnownShS ZSS = Dict -shsKnownShS (SNat :$$ sh) | Dict <- shsKnownShS sh = Dict - -shsOrthotopeShape :: ShS sh -> Dict O.Shape sh -shsOrthotopeShape ZSS = Dict -shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict - - --- | Untyped: length is checked at runtime. -instance KnownShS sh => IsList (ListS sh (Const i)) where -  type Item (ListS sh (Const i)) = i -  fromList topl = go (knownShS @sh) topl -    where -      go :: ShS sh' -> [i] -> ListS sh' (Const i) -      go ZSS [] = ZS -      go (_ :$$ sh) (i : is) = Const i ::$ go sh is -      go _ _ = error $ "IsList(ListS): Mismatched list length (type says " -                         ++ show (shsLength (knownShS @sh)) ++ ", list has length " -                         ++ show (length topl) ++ ")" -  toList = listsToList - --- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. -instance KnownShS sh => IsList (IxS sh i) where -  type Item (IxS sh i) = i -  fromList = IxS . IsList.fromList -  toList = Foldable.toList - --- | Untyped: length and values are checked at runtime. -instance KnownShS sh => IsList (ShS sh) where -  type Item (ShS sh) = Int -  fromList topl = ShS (go (knownShS @sh) topl) -    where -      go :: ShS sh' -> [Int] -> ListS sh' SNat -      go ZSS [] = ZS -      go (sn :$$ sh) (i : is) -        | i == fromSNat' sn = sn ::$ go sh is -        | otherwise = error $ "IsList(ShS): Value does not match typing (type says " -                                ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" -      go _ _ = error $ "IsList(ShS): Mismatched list length (type says " -                         ++ show (shsLength (knownShS @sh)) ++ ", list has length " -                         ++ show (length topl) ++ ")" -  toList = shsToList diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index 372439f..86dcee2 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -40,13 +40,13 @@ import GHC.TypeLits  import Data.Array.Mixed.Lemmas  import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape +import Data.Array.Nested.Mixed.Shape  import Data.Array.Mixed.Types  import Data.Array.Mixed.XArray (XArray)  import Data.Array.Mixed.XArray qualified as X  import Data.Array.Nested.Internal.Lemmas  import Data.Array.Nested.Internal.Mixed -import Data.Array.Nested.Internal.Shape +import Data.Array.Nested.Shaped.Shape  import Data.Array.Strided.Arith | 
