aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Ranked/Shape.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Ranked/Shape.hs')
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs177
1 files changed, 122 insertions, 55 deletions
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index b6bee2e..6ce0f4f 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -1,8 +1,6 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
@@ -37,14 +35,15 @@ import Data.Kind (Type)
import Data.Proxy
import Data.Type.Equality
import GHC.Exts (Int(..), Int#, build, quotRemInt#)
-import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
import GHC.TypeNats qualified as TN
import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Mixed.Shape.Internal
+import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
@@ -183,7 +182,12 @@ listrZipWith f (i ::: irest) (j ::: jrest) =
listrZipWith _ _ _ =
error "listrZipWith: impossible pattern needlessly required"
-listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
+listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
+listrSplitAt SZ sh = (ZR, sh)
+listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
+listrSplitAt SS{} ZR = error "m' + 1 <= 0"
+
+listrPermutePrefix :: forall i n. PermR -> ListR n i -> ListR n i
listrPermutePrefix = \perm sh ->
TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
case listrRank sh of { shlen@SNat ->
@@ -195,11 +199,6 @@ listrPermutePrefix = \perm sh ->
++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
}
where
- listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
- listrSplitAt SZ sh = (ZR, sh)
- listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
- listrSplitAt SS{} ZR = error "m' + 1 <= 0"
-
applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
applyPermRFull _ ZR _ = ZR
applyPermRFull sm@SNat (i ::: perm) l =
@@ -216,8 +215,7 @@ listrPermutePrefix = \perm sh ->
type role IxR nominal representational
type IxR :: Nat -> Type -> Type
newtype IxR n i = IxR (ListR n i)
- deriving (Eq, Ord, Generic)
- deriving newtype (Functor, Foldable)
+ deriving (Eq, Ord, NFData, Functor, Foldable)
pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
pattern ZIR = IxR ZR
@@ -243,8 +241,6 @@ instance Show i => Show (IxR n i) where
showsPrec _ (IxR l) = listrShow shows l
#endif
-instance NFData i => NFData (IxR sh i)
-
ixrLength :: IxR sh i -> Int
ixrLength (IxR l) = listrLength l
@@ -288,7 +284,7 @@ ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2
ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
-ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
+ixrPermutePrefix :: forall n i. PermR -> IxR n i -> IxR n i
ixrPermutePrefix = coerce (listrPermutePrefix @i)
-- | Given a multidimensional index, get the corresponding linear
@@ -309,19 +305,34 @@ ixrToLinear = \sh i -> go sh i 0
type role ShR nominal representational
type ShR :: Nat -> Type -> Type
-newtype ShR n i = ShR (ListR n i)
- deriving (Eq, Ord, Generic)
- deriving newtype (Functor, Foldable)
+newtype ShR n i = ShR (ShX (Replicate n Nothing) i)
+ deriving (Eq, Ord, NFData, Functor)
pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
-pattern ZSR = ShR ZR
+pattern ZSR <- ShR (matchZSR @n -> Just Refl)
+ where ZSR = ShR ZSX
+
+matchZSR :: forall n i. ShX (Replicate n Nothing) i -> Maybe (n :~: 0)
+matchZSR ZSX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl
+matchZSR _ = Nothing
pattern (:$:)
:: forall {n1} {i}.
forall n. (n + 1 ~ n1)
=> i -> ShR n i -> ShR n1 i
-pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
- where i :$: ShR sh = ShR (i ::: sh)
+pattern i :$: shl <- (shrUncons -> Just (UnconsShRRes shl i))
+ where i :$: ShR shl | Refl <- lemReplicateSucc2 (Proxy @n1) Refl
+ = ShR (SUnknown i :$% shl)
+
+data UnconsShRRes i n1 =
+ forall n. (n + 1 ~ n1) => UnconsShRRes (ShR n i) i
+shrUncons :: forall n1 i. ShR n1 i -> Maybe (UnconsShRRes i n1)
+shrUncons (ShR (SUnknown x :$% (sh' :: ShX sh' i)))
+ | Refl <- lemReplicateCons (Proxy @sh') (Proxy @n1) Refl
+ , Refl <- lemReplicateCons2 (Proxy @sh') (Proxy @n1) Refl
+ = Just (UnconsShRRes (ShR sh') x)
+shrUncons (ShR _) = Nothing
+
infixr 3 :$:
{-# COMPLETE ZSR, (:$:) #-}
@@ -332,69 +343,125 @@ type IShR n = ShR n Int
deriving instance Show i => Show (ShR n i)
#else
instance Show i => Show (ShR n i) where
- showsPrec _ (ShR l) = listrShow shows l
+ showsPrec d (ShR l) = showsPrec d l
#endif
-instance NFData i => NFData (ShR sh i)
-
-- | This checks only whether the ranks are equal, not whether the actual
-- values are.
shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
-shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh'
+shrEqRank ZSR ZSR = Just Refl
+shrEqRank (_ :$: sh) (_ :$: sh')
+ | Just Refl <- shrEqRank sh sh'
+ = Just Refl
+shrEqRank _ _ = Nothing
-- | This compares the shapes for value equality.
shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n')
-shrEqual (ShR sh) (ShR sh') = listrEqual sh sh'
+shrEqual ZSR ZSR = Just Refl
+shrEqual (i :$: sh) (i' :$: sh')
+ | Just Refl <- shrEqual sh sh'
+ , i == i'
+ = Just Refl
+shrEqual _ _ = Nothing
shrLength :: ShR sh i -> Int
-shrLength (ShR l) = listrLength l
+shrLength (ShR l) = shxLength l
-- | This function can also be used to conjure up a 'KnownNat' dictionary;
-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern
-- synonym yields 'KnownNat' evidence.
-shrRank :: ShR n i -> SNat n
-shrRank (ShR sh) = listrRank sh
+shrRank :: forall n i. ShR n i -> SNat n
+shrRank (ShR sh) | Refl <- lemRankReplicate (Proxy @n) = shxRank sh
-- | The number of elements in an array described by this shape.
shrSize :: IShR n -> Int
-shrSize ZSR = 1
-shrSize (n :$: sh) = n * shrSize sh
+shrSize (ShR sh) = shxSize sh
-shrFromList :: forall n i. SNat n -> [i] -> ShR n i
-shrFromList = coerce (listrFromList @_ @i)
+shrFromList :: SNat n -> [Int] -> IShR n
+shrFromList snat = coerce (shxFromList (ssxReplicate snat))
{-# INLINEABLE shrToList #-}
-shrToList :: forall n i. ShR n i -> [i]
-shrToList = coerce (listrToList @_ @i)
+shrToList :: IShR n -> [Int]
+shrToList = coerce shxToList
-shrHead :: ShR (n + 1) i -> i
-shrHead (ShR list) = listrHead list
+shrHead :: forall n i. ShR (n + 1) i -> i
+shrHead (ShR sh)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = case shxHead @Nothing @(Replicate n Nothing) sh of
+ SUnknown i -> i
-shrTail :: ShR (n + 1) i -> ShR n i
-shrTail (ShR list) = ShR (listrTail list)
+shrTail :: forall n i. ShR (n + 1) i -> ShR n i
+shrTail
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = coerce (shxTail @_ @_ @i)
-shrInit :: ShR (n + 1) i -> ShR n i
-shrInit (ShR list) = ShR (listrInit list)
+shrInit :: forall n i. ShR (n + 1) i -> ShR n i
+shrInit
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = -- TODO: change this and all other unsafeCoerceRefl to lemmas:
+ gcastWith (unsafeCoerceRefl
+ :: Init (Replicate (n + 1) (Nothing @Nat)) :~: Replicate n Nothing) $
+ coerce (shxInit @_ @_ @i)
-shrLast :: ShR (n + 1) i -> i
-shrLast (ShR list) = listrLast list
+shrLast :: forall n i. ShR (n + 1) i -> i
+shrLast (ShR sh)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = case shxLast sh of
+ SUnknown i -> i
+ SKnown{} -> error "shrLast: impossible SKnown"
-- | Performs a runtime check that the lengths are identical.
shrCast :: SNat n' -> ShR n i -> ShR n' i
-shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh)
+shrCast SZ ZSR = ZSR
+shrCast (SS n) (i :$: sh) = i :$: shrCast n sh
+shrCast _ _ = error "shrCast: ranks don't match"
shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
-shrAppend = coerce (listrAppend @_ @i)
-
-shrZip :: ShR n i -> ShR n j -> ShR n (i, j)
-shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2
+shrAppend =
+ -- lemReplicatePlusApp requires an SNat
+ gcastWith (unsafeCoerceRefl
+ :: Replicate n (Nothing @Nat) ++ Replicate m Nothing :~: Replicate (n + m) Nothing) $
+ coerce (shxAppend @_ @_ @i)
{-# INLINE shrZipWith #-}
shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k
-shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2
+shrZipWith _ ZSR ZSR = ZSR
+shrZipWith f (i :$: irest) (j :$: jrest) =
+ f i j :$: shrZipWith f irest jrest
+shrZipWith _ _ _ =
+ error "shrZipWith: impossible pattern needlessly required"
+
+shrSplitAt :: m <= n' => SNat m -> ShR n' i -> (ShR m i, ShR (n' - m) i)
+shrSplitAt SZ sh = (ZSR, sh)
+shrSplitAt (SS m) (n :$: sh) = (\(pre, post) -> (n :$: pre, post)) (shrSplitAt m sh)
+shrSplitAt SS{} ZSR = error "m' + 1 <= 0"
-shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
-shrPermutePrefix = coerce (listrPermutePrefix @i)
+shrIndex :: forall k sh i. SNat k -> ShR sh i -> i
+shrIndex k (ShR sh) = case shxIndex @_ @_ @i k sh of
+ SUnknown i -> i
+ SKnown{} -> error "shrIndex: impossible SKnown"
+
+-- Copy-pasted from listrPermutePrefix, probably unavoidably.
+shrPermutePrefix :: forall i n. PermR -> ShR n i -> ShR n i
+shrPermutePrefix = \perm sh ->
+ TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
+ case shrRank sh of { shlen@SNat ->
+ let sperm = shrFromList permlen perm in
+ case cmpNat permlen shlen of
+ LTI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post
+ EQI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post
+ GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
+ ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ }
+ where
+ applyPermRFull :: SNat m -> ShR k Int -> ShR m i -> ShR k i
+ applyPermRFull _ ZSR _ = ZSR
+ applyPermRFull sm@SNat (i :$: perm) l =
+ TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
+ case cmpNat (SNat @(idx + 1)) sm of
+ LTI -> shrIndex si l :$: applyPermRFull sm perm l
+ EQI -> shrIndex si l :$: applyPermRFull sm perm l
+ GTI -> error "shrPermutePrefix: Index in permutation out of range"
shrEnum :: IShR sh -> [IIxR sh]
shrEnum = shrEnum'
@@ -426,17 +493,17 @@ instance KnownNat n => IsList (IxR n i) where
toList = Foldable.toList
-- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (ShR n i) where
- type Item (ShR n i) = i
- fromList = ShR . IsList.fromList
- toList = Foldable.toList
+instance KnownNat n => IsList (IShR n) where
+ type Item (IShR n) = Int
+ fromList = shrFromList (SNat @n)
+ toList = shrToList
-- * Internal helper functions
listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i
listrCastWithName _ SZ ZR = ZR
-listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx
+listrCastWithName name (SS n) (i ::: l) = i ::: listrCastWithName name n l
listrCastWithName name _ _ = error $ name ++ ": ranks don't match"
$(ixFromLinearStub "ixrFromLinear" [t| IShR |] [t| IxR |] [p| ZSR |] (\a b -> [p| $a :$: $b |]) [| ZIR |] [| (:.:) |] [| shrToList |])