diff options
Diffstat (limited to 'src/Data/Array/Nested/Ranked/Shape.hs')
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 177 |
1 files changed, 122 insertions, 55 deletions
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index b6bee2e..6ce0f4f 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -1,8 +1,6 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} @@ -37,14 +35,15 @@ import Data.Kind (Type) import Data.Proxy import Data.Type.Equality import GHC.Exts (Int(..), Int#, build, quotRemInt#) -import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits import GHC.TypeNats qualified as TN import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Mixed.Shape.Internal +import Data.Array.Nested.Permutation import Data.Array.Nested.Types @@ -183,7 +182,12 @@ listrZipWith f (i ::: irest) (j ::: jrest) = listrZipWith _ _ _ = error "listrZipWith: impossible pattern needlessly required" -listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i +listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) +listrSplitAt SZ sh = (ZR, sh) +listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) +listrSplitAt SS{} ZR = error "m' + 1 <= 0" + +listrPermutePrefix :: forall i n. PermR -> ListR n i -> ListR n i listrPermutePrefix = \perm sh -> TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat -> case listrRank sh of { shlen@SNat -> @@ -195,11 +199,6 @@ listrPermutePrefix = \perm sh -> ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" } where - listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) - listrSplitAt SZ sh = (ZR, sh) - listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) - listrSplitAt SS{} ZR = error "m' + 1 <= 0" - applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i applyPermRFull _ ZR _ = ZR applyPermRFull sm@SNat (i ::: perm) l = @@ -216,8 +215,7 @@ listrPermutePrefix = \perm sh -> type role IxR nominal representational type IxR :: Nat -> Type -> Type newtype IxR n i = IxR (ListR n i) - deriving (Eq, Ord, Generic) - deriving newtype (Functor, Foldable) + deriving (Eq, Ord, NFData, Functor, Foldable) pattern ZIR :: forall n i. () => n ~ 0 => IxR n i pattern ZIR = IxR ZR @@ -243,8 +241,6 @@ instance Show i => Show (IxR n i) where showsPrec _ (IxR l) = listrShow shows l #endif -instance NFData i => NFData (IxR sh i) - ixrLength :: IxR sh i -> Int ixrLength (IxR l) = listrLength l @@ -288,7 +284,7 @@ ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 -ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i +ixrPermutePrefix :: forall n i. PermR -> IxR n i -> IxR n i ixrPermutePrefix = coerce (listrPermutePrefix @i) -- | Given a multidimensional index, get the corresponding linear @@ -309,19 +305,34 @@ ixrToLinear = \sh i -> go sh i 0 type role ShR nominal representational type ShR :: Nat -> Type -> Type -newtype ShR n i = ShR (ListR n i) - deriving (Eq, Ord, Generic) - deriving newtype (Functor, Foldable) +newtype ShR n i = ShR (ShX (Replicate n Nothing) i) + deriving (Eq, Ord, NFData, Functor) pattern ZSR :: forall n i. () => n ~ 0 => ShR n i -pattern ZSR = ShR ZR +pattern ZSR <- ShR (matchZSR @n -> Just Refl) + where ZSR = ShR ZSX + +matchZSR :: forall n i. ShX (Replicate n Nothing) i -> Maybe (n :~: 0) +matchZSR ZSX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl +matchZSR _ = Nothing pattern (:$:) :: forall {n1} {i}. forall n. (n + 1 ~ n1) => i -> ShR n i -> ShR n1 i -pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) - where i :$: ShR sh = ShR (i ::: sh) +pattern i :$: shl <- (shrUncons -> Just (UnconsShRRes shl i)) + where i :$: ShR shl | Refl <- lemReplicateSucc2 (Proxy @n1) Refl + = ShR (SUnknown i :$% shl) + +data UnconsShRRes i n1 = + forall n. (n + 1 ~ n1) => UnconsShRRes (ShR n i) i +shrUncons :: forall n1 i. ShR n1 i -> Maybe (UnconsShRRes i n1) +shrUncons (ShR (SUnknown x :$% (sh' :: ShX sh' i))) + | Refl <- lemReplicateCons (Proxy @sh') (Proxy @n1) Refl + , Refl <- lemReplicateCons2 (Proxy @sh') (Proxy @n1) Refl + = Just (UnconsShRRes (ShR sh') x) +shrUncons (ShR _) = Nothing + infixr 3 :$: {-# COMPLETE ZSR, (:$:) #-} @@ -332,69 +343,125 @@ type IShR n = ShR n Int deriving instance Show i => Show (ShR n i) #else instance Show i => Show (ShR n i) where - showsPrec _ (ShR l) = listrShow shows l + showsPrec d (ShR l) = showsPrec d l #endif -instance NFData i => NFData (ShR sh i) - -- | This checks only whether the ranks are equal, not whether the actual -- values are. shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') -shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh' +shrEqRank ZSR ZSR = Just Refl +shrEqRank (_ :$: sh) (_ :$: sh') + | Just Refl <- shrEqRank sh sh' + = Just Refl +shrEqRank _ _ = Nothing -- | This compares the shapes for value equality. shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n') -shrEqual (ShR sh) (ShR sh') = listrEqual sh sh' +shrEqual ZSR ZSR = Just Refl +shrEqual (i :$: sh) (i' :$: sh') + | Just Refl <- shrEqual sh sh' + , i == i' + = Just Refl +shrEqual _ _ = Nothing shrLength :: ShR sh i -> Int -shrLength (ShR l) = listrLength l +shrLength (ShR l) = shxLength l -- | This function can also be used to conjure up a 'KnownNat' dictionary; -- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern -- synonym yields 'KnownNat' evidence. -shrRank :: ShR n i -> SNat n -shrRank (ShR sh) = listrRank sh +shrRank :: forall n i. ShR n i -> SNat n +shrRank (ShR sh) | Refl <- lemRankReplicate (Proxy @n) = shxRank sh -- | The number of elements in an array described by this shape. shrSize :: IShR n -> Int -shrSize ZSR = 1 -shrSize (n :$: sh) = n * shrSize sh +shrSize (ShR sh) = shxSize sh -shrFromList :: forall n i. SNat n -> [i] -> ShR n i -shrFromList = coerce (listrFromList @_ @i) +shrFromList :: SNat n -> [Int] -> IShR n +shrFromList snat = coerce (shxFromList (ssxReplicate snat)) {-# INLINEABLE shrToList #-} -shrToList :: forall n i. ShR n i -> [i] -shrToList = coerce (listrToList @_ @i) +shrToList :: IShR n -> [Int] +shrToList = coerce shxToList -shrHead :: ShR (n + 1) i -> i -shrHead (ShR list) = listrHead list +shrHead :: forall n i. ShR (n + 1) i -> i +shrHead (ShR sh) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = case shxHead @Nothing @(Replicate n Nothing) sh of + SUnknown i -> i -shrTail :: ShR (n + 1) i -> ShR n i -shrTail (ShR list) = ShR (listrTail list) +shrTail :: forall n i. ShR (n + 1) i -> ShR n i +shrTail + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = coerce (shxTail @_ @_ @i) -shrInit :: ShR (n + 1) i -> ShR n i -shrInit (ShR list) = ShR (listrInit list) +shrInit :: forall n i. ShR (n + 1) i -> ShR n i +shrInit + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = -- TODO: change this and all other unsafeCoerceRefl to lemmas: + gcastWith (unsafeCoerceRefl + :: Init (Replicate (n + 1) (Nothing @Nat)) :~: Replicate n Nothing) $ + coerce (shxInit @_ @_ @i) -shrLast :: ShR (n + 1) i -> i -shrLast (ShR list) = listrLast list +shrLast :: forall n i. ShR (n + 1) i -> i +shrLast (ShR sh) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = case shxLast sh of + SUnknown i -> i + SKnown{} -> error "shrLast: impossible SKnown" -- | Performs a runtime check that the lengths are identical. shrCast :: SNat n' -> ShR n i -> ShR n' i -shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh) +shrCast SZ ZSR = ZSR +shrCast (SS n) (i :$: sh) = i :$: shrCast n sh +shrCast _ _ = error "shrCast: ranks don't match" shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i -shrAppend = coerce (listrAppend @_ @i) - -shrZip :: ShR n i -> ShR n j -> ShR n (i, j) -shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2 +shrAppend = + -- lemReplicatePlusApp requires an SNat + gcastWith (unsafeCoerceRefl + :: Replicate n (Nothing @Nat) ++ Replicate m Nothing :~: Replicate (n + m) Nothing) $ + coerce (shxAppend @_ @_ @i) {-# INLINE shrZipWith #-} shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k -shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2 +shrZipWith _ ZSR ZSR = ZSR +shrZipWith f (i :$: irest) (j :$: jrest) = + f i j :$: shrZipWith f irest jrest +shrZipWith _ _ _ = + error "shrZipWith: impossible pattern needlessly required" + +shrSplitAt :: m <= n' => SNat m -> ShR n' i -> (ShR m i, ShR (n' - m) i) +shrSplitAt SZ sh = (ZSR, sh) +shrSplitAt (SS m) (n :$: sh) = (\(pre, post) -> (n :$: pre, post)) (shrSplitAt m sh) +shrSplitAt SS{} ZSR = error "m' + 1 <= 0" -shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i -shrPermutePrefix = coerce (listrPermutePrefix @i) +shrIndex :: forall k sh i. SNat k -> ShR sh i -> i +shrIndex k (ShR sh) = case shxIndex @_ @_ @i k sh of + SUnknown i -> i + SKnown{} -> error "shrIndex: impossible SKnown" + +-- Copy-pasted from listrPermutePrefix, probably unavoidably. +shrPermutePrefix :: forall i n. PermR -> ShR n i -> ShR n i +shrPermutePrefix = \perm sh -> + TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat -> + case shrRank sh of { shlen@SNat -> + let sperm = shrFromList permlen perm in + case cmpNat permlen shlen of + LTI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post + EQI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post + GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" + ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" + } + where + applyPermRFull :: SNat m -> ShR k Int -> ShR m i -> ShR k i + applyPermRFull _ ZSR _ = ZSR + applyPermRFull sm@SNat (i :$: perm) l = + TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> + case cmpNat (SNat @(idx + 1)) sm of + LTI -> shrIndex si l :$: applyPermRFull sm perm l + EQI -> shrIndex si l :$: applyPermRFull sm perm l + GTI -> error "shrPermutePrefix: Index in permutation out of range" shrEnum :: IShR sh -> [IIxR sh] shrEnum = shrEnum' @@ -426,17 +493,17 @@ instance KnownNat n => IsList (IxR n i) where toList = Foldable.toList -- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ShR n i) where - type Item (ShR n i) = i - fromList = ShR . IsList.fromList - toList = Foldable.toList +instance KnownNat n => IsList (IShR n) where + type Item (IShR n) = Int + fromList = shrFromList (SNat @n) + toList = shrToList -- * Internal helper functions listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i listrCastWithName _ SZ ZR = ZR -listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx +listrCastWithName name (SS n) (i ::: l) = i ::: listrCastWithName name n l listrCastWithName name _ _ = error $ name ++ ": ranks don't match" $(ixFromLinearStub "ixrFromLinear" [t| IShR |] [t| IxR |] [p| ZSR |] (\a b -> [p| $a :$: $b |]) [| ZIR |] [| (:.:) |] [| shrToList |]) |
