diff options
| -rw-r--r-- | src/Data/Array/Nested.hs | 1 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Convert.hs | 3 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 282 |
3 files changed, 104 insertions, 182 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index ec81843..9922644 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -3,7 +3,6 @@ module Data.Array.Nested ( -- * Ranked arrays Ranked(Ranked), - ListR(ZR, (:::)), IxR(.., ZIR, (:.:)), IIxR, ShR(.., ZSR, (:$:)), IShR, rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rgeneratePrim, rsumOuter1Prim, rsumAllPrim, diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 2595c64..c6f23ae 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -16,7 +16,7 @@ module Data.Array.Nested.Convert ( -- * Shape\/index\/list casting functions -- ** To ranked ixrFromIxS, ixrFromIxS', ixrFromIxX, shrFromShS, shrFromShXAnyShape, shrFromShX, - listrCast, ixrCast, shrCast, + ixrCast, shrCast, -- ** To shaped ixsFromIxR, ixsFromIxR', ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX, ixsCast, @@ -86,7 +86,6 @@ shrFromShXAnyShape (n :$% idx) = fromSMayNat' n :$: shrFromShXAnyShape idx shrFromShX :: IShX (Replicate n Nothing) -> IShR n shrFromShX = coerce --- listrCast re-exported -- ixrCast re-exported -- shrCast re-exported diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index a352eb3..5e84a2d 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -44,170 +44,40 @@ import Data.Array.Nested.Permutation import Data.Array.Nested.Types --- * Ranked lists - -type role ListR nominal representational -type ListR :: Nat -> Type -> Type -newtype ListR n i = ListR (ListX (Replicate n Nothing) i) - deriving (Eq, Ord, NFData, Functor, Foldable) - -pattern ZR :: forall n i. () => n ~ 0 => ListR n i -pattern ZR <- ListR (matchZX @n -> Just Refl) - where ZR = ListR ZX - -matchZX :: forall n i. ListX (Replicate n Nothing) i -> Maybe (n :~: 0) -matchZX ZX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl -matchZX _ = Nothing - -pattern (:::) - :: forall {n1} {i}. - forall n. (n + 1 ~ n1) - => i -> ListR n i -> ListR n1 i -pattern i ::: l <- (listrUncons -> Just (UnconsListRRes i l)) - where i ::: ListR l | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ListR (i ::% l) -infixr 3 ::: - -data UnconsListRRes i n1 = - forall n. (n + 1 ~ n1) => UnconsListRRes i (ListR n i) -listrUncons :: forall n1 i. ListR n1 i -> Maybe (UnconsListRRes i n1) -listrUncons (ListR ((::%) @n @sh i l)) - | Refl <- lemReplicateHead (Proxy @n) (Proxy @sh) (Proxy @Nothing) (Proxy @n1) Refl - , Refl <- lemReplicateCons (Proxy @sh) (Proxy @n1) Refl - , Refl <- lemReplicateCons2 (Proxy @sh) (Proxy @n1) Refl = - Just (UnconsListRRes i (ListR @(Rank sh) l)) -listrUncons (ListR _) = Nothing - -{-# COMPLETE ZR, (:::) #-} - -#ifdef OXAR_DEFAULT_SHOW_INSTANCES -deriving instance Show i => Show (ListR n i) -#else -instance Show i => Show (ListR n i) where - showsPrec _ = listrShow shows -#endif - --- | This checks only whether the ranks are equal, not whether the actual --- values are. -listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n') -listrEqRank ZR ZR = Just Refl -listrEqRank (_ ::: sh) (_ ::: sh') - | Just Refl <- listrEqRank sh sh' - = Just Refl -listrEqRank _ _ = Nothing - --- | This compares the lists for value equality. -listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n') -listrEqual ZR ZR = Just Refl -listrEqual (i ::: sh) (j ::: sh') - | Just Refl <- listrEqual sh sh' - , i == j - = Just Refl -listrEqual _ _ = Nothing - -{-# INLINE listrShow #-} -listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS -listrShow f l = showString "[" . go "" l . showString "]" - where - go :: String -> ListR n' i -> ShowS - go _ ZR = id - go prefix (x ::: xs) = showString prefix . f x . go "," xs - -listrRank :: ListR n i -> SNat n -listrRank ZR = SNat -listrRank (_ ::: sh) = snatSucc (listrRank sh) - --- lemReplicatePlusApp requires SNat that would cause overhead (not benchmarked) -listrAppend :: forall n m i. ListR n i -> ListR m i -> ListR (n + m) i -listrAppend = gcastWith (unsafeCoerceRefl :: Replicate (n + m) (Nothing @Nat) :~: Replicate n Nothing ++ Replicate m Nothing) $ - coerce (listxAppend @_ @_ @i) - -{-# INLINE listrFromList #-} -listrFromList :: SNat n -> [i] -> ListR n i -listrFromList topsn topl = assert (fromSNat' topsn == length topl) - $ ListR $ IsList.fromList topl - -listrHead :: ListR (n + 1) i -> i -listrHead (i ::: _) = i - -listrTail :: ListR (n + 1) i -> ListR n i -listrTail (_ ::: sh) = sh - -listrInit :: ListR (n + 1) i -> ListR n i -listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh -listrInit (_ ::: ZR) = ZR - -listrLast :: ListR (n + 1) i -> i -listrLast (_ ::: sh@(_ ::: _)) = listrLast sh -listrLast (n ::: ZR) = n - --- | Performs a runtime check that the lengths are identical. -listrCast :: SNat n' -> ListR n i -> ListR n' i -listrCast = listrCastWithName "listrCast" - -listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i -listrIndex SZ (x ::: _) = x -listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs -listrIndex _ ZR = error "k + 1 <= 0" - -listrZip :: ListR n i -> ListR n j -> ListR n (i, j) -listrZip ZR ZR = ZR -listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest -listrZip _ _ = error "listrZip: impossible pattern needlessly required" - -{-# INLINE listrZipWith #-} -listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k -listrZipWith _ ZR ZR = ZR -listrZipWith f (i ::: irest) (j ::: jrest) = - f i j ::: listrZipWith f irest jrest -listrZipWith _ _ _ = - error "listrZipWith: impossible pattern needlessly required" - -listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) -listrSplitAt SZ sh = (ZR, sh) -listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) -listrSplitAt SS{} ZR = error "m' + 1 <= 0" - -listrPermutePrefix :: forall n i. PermR -> ListR n i -> ListR n i -listrPermutePrefix = \perm sh -> - TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat -> - case listrRank sh of { shlen@SNat -> - let sperm = listrFromList permlen perm in - case cmpNat permlen shlen of - LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" - ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" - } - where - applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i - applyPermRFull _ ZR _ = ZR - applyPermRFull sm@SNat (i ::: perm) l = - TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> - case cmpNat (SNat @(idx + 1)) sm of - LTI -> listrIndex si l ::: applyPermRFull sm perm l - EQI -> listrIndex si l ::: applyPermRFull sm perm l - GTI -> error "listrPermutePrefix: Index in permutation out of range" - - -- * Ranked indices -- | An index into a rank-typed array. type role IxR nominal representational type IxR :: Nat -> Type -> Type -newtype IxR n i = IxR (ListR n i) +newtype IxR n i = IxR (IxX (Replicate n Nothing) i) deriving (Eq, Ord, NFData, Functor, Foldable) pattern ZIR :: forall n i. () => n ~ 0 => IxR n i -pattern ZIR = IxR ZR +pattern ZIR <- IxR (matchZIX @n -> Just Refl) + where ZIR = IxR ZIX + +matchZIX :: forall n i. IxX (Replicate n Nothing) i -> Maybe (n :~: 0) +matchZIX ZIX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl +matchZIX _ = Nothing pattern (:.:) :: forall {n1} {i}. forall n. (n + 1 ~ n1) => i -> IxR n i -> IxR n1 i -pattern i :.: l <- IxR (i ::: (IxR -> l)) - where i :.: IxR l = IxR (i ::: l) +pattern i :.: l <- (ixrUncons -> Just (UnconsIxRRes i l)) + where i :.: IxR l | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = IxR (i :.% l) infixr 3 :.: +data UnconsIxRRes i n1 = + forall n. (n + 1 ~ n1) => UnconsIxRRes i (IxR n i) +ixrUncons :: forall n1 i. IxR n1 i -> Maybe (UnconsIxRRes i n1) +ixrUncons (IxR ((:.%) @n @sh i l)) + | Refl <- lemReplicateHead (Proxy @n) (Proxy @sh) (Proxy @Nothing) (Proxy @n1) Refl + , Refl <- lemReplicateCons (Proxy @sh) (Proxy @n1) Refl + , Refl <- lemReplicateCons2 (Proxy @sh) (Proxy @n1) Refl = + Just (UnconsIxRRes i (IxR @(Rank sh) l)) +ixrUncons (IxR _) = Nothing + {-# COMPLETE ZIR, (:.:) #-} -- For convenience, this contains regular 'Int's instead of bounded integers @@ -218,48 +88,116 @@ type IIxR n = IxR n Int deriving instance Show i => Show (IxR n i) #else instance Show i => Show (IxR n i) where - showsPrec _ (IxR l) = listrShow shows l + showsPrec _ = ixrShow shows #endif +-- | This checks only whether the ranks are equal, not whether the actual +-- values are. +ixrEqRank :: IxR n i -> IxR n' i -> Maybe (n :~: n') +ixrEqRank ZIR ZIR = Just Refl +ixrEqRank (_ :.: sh) (_ :.: sh') + | Just Refl <- ixrEqRank sh sh' + = Just Refl +ixrEqRank _ _ = Nothing + +-- | This compares the lists for value equality. +ixrEqual :: Eq i => IxR n i -> IxR n' i -> Maybe (n :~: n') +ixrEqual ZIR ZIR = Just Refl +ixrEqual (i :.: sh) (j :.: sh') + | Just Refl <- ixrEqual sh sh' + , i == j + = Just Refl +ixrEqual _ _ = Nothing + +{-# INLINE ixrShow #-} +ixrShow :: forall n i. (i -> ShowS) -> IxR n i -> ShowS +ixrShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> IxR n' i -> ShowS + go _ ZIR = id + go prefix (x :.: xs) = showString prefix . f x . go "," xs + ixrRank :: IxR n i -> SNat n -ixrRank (IxR sh) = listrRank sh +ixrRank ZIR = SNat +ixrRank (_ :.: sh) = snatSucc (ixrRank sh) ixrZero :: SNat n -> IIxR n ixrZero SZ = ZIR ixrZero (SS n) = 0 :.: ixrZero n -{-# INLINEABLE ixrFromList #-} -ixrFromList :: forall n i. SNat n -> [i] -> IxR n i -ixrFromList = coerce (listrFromList @_ @i) +{-# INLINE ixrFromList #-} +ixrFromList :: SNat n -> [i] -> IxR n i +ixrFromList topsn topl = assert (fromSNat' topsn == length topl) + $ IxR $ IsList.fromList topl ixrHead :: IxR (n + 1) i -> i -ixrHead (IxR list) = listrHead list +ixrHead (i :.: _) = i ixrTail :: IxR (n + 1) i -> IxR n i -ixrTail (IxR list) = IxR (listrTail list) +ixrTail (_ :.: sh) = sh ixrInit :: IxR (n + 1) i -> IxR n i -ixrInit (IxR list) = IxR (listrInit list) +ixrInit (n :.: sh@(_ :.: _)) = n :.: ixrInit sh +ixrInit (_ :.: ZIR) = ZIR ixrLast :: IxR (n + 1) i -> i -ixrLast (IxR list) = listrLast list +ixrLast (_ :.: sh@(_ :.: _)) = ixrLast sh +ixrLast (n :.: ZIR) = n -- | Performs a runtime check that the lengths are identical. ixrCast :: SNat n' -> IxR n i -> IxR n' i -ixrCast n (IxR idx) = IxR (listrCastWithName "ixrCast" n idx) +ixrCast SZ ZIR = ZIR +ixrCast (SS n) (i :.: l) = i :.: ixrCast n l +ixrCast _ _ = error "ixrCast: ranks don't match" +-- lemReplicatePlusApp requires SNat that would cause overhead (not benchmarked) ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i -ixrAppend = coerce (listrAppend @n @m @i) +ixrAppend = gcastWith (unsafeCoerceRefl :: Replicate (n + m) (Nothing @Nat) :~: Replicate n Nothing ++ Replicate m Nothing) $ + coerce (listxAppend @_ @_ @i) + +ixrIndex :: forall k n i. (k + 1 <= n) => SNat k -> IxR n i -> i +ixrIndex SZ (x :.: _) = x +ixrIndex (SS i) (_ :.: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = ixrIndex i xs +ixrIndex _ ZIR = error "k + 1 <= 0" ixrZip :: IxR n i -> IxR n j -> IxR n (i, j) -ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 +ixrZip ZIR ZIR = ZIR +ixrZip (i :.: irest) (j :.: jrest) = (i, j) :.: ixrZip irest jrest +ixrZip _ _ = error "ixrZip: impossible pattern needlessly required" {-# INLINE ixrZipWith #-} ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k -ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 +ixrZipWith _ ZIR ZIR = ZIR +ixrZipWith f (i :.: irest) (j :.: jrest) = + f i j :.: ixrZipWith f irest jrest +ixrZipWith _ _ _ = + error "ixrZipWith: impossible pattern needlessly required" + +ixrSplitAt :: m <= n' => SNat m -> IxR n' i -> (IxR m i, IxR (n' - m) i) +ixrSplitAt SZ sh = (ZIR, sh) +ixrSplitAt (SS m) (n :.: sh) = (\(pre, post) -> (n :.: pre, post)) (ixrSplitAt m sh) +ixrSplitAt SS{} ZIR = error "m' + 1 <= 0" ixrPermutePrefix :: forall n i. PermR -> IxR n i -> IxR n i -ixrPermutePrefix = coerce (listrPermutePrefix @n @i) +ixrPermutePrefix = \perm sh -> + TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat -> + case ixrRank sh of { shlen@SNat -> + let sperm = ixrFromList permlen perm in + case cmpNat permlen shlen of + LTI -> let (pre, post) = ixrSplitAt permlen sh in ixrAppend (applyPermRFull permlen sperm pre) post + EQI -> let (pre, post) = ixrSplitAt permlen sh in ixrAppend (applyPermRFull permlen sperm pre) post + GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" + ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" + } + where + applyPermRFull :: SNat m -> IxR k Int -> IxR m i -> IxR k i + applyPermRFull _ ZIR _ = ZIR + applyPermRFull sm@SNat (i :.: perm) l = + TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> + case cmpNat (SNat @(idx + 1)) sm of + LTI -> ixrIndex si l :.: applyPermRFull sm perm l + EQI -> ixrIndex si l :.: applyPermRFull sm perm l + GTI -> error "ixrPermutePrefix: Index in permutation out of range" -- | Given a multidimensional index, get the corresponding linear -- index into the buffer. @@ -451,7 +389,7 @@ shrIndex k (ShR sh) = case shxIndex @i k sh of SUnknown i -> i SKnown{} -> error "shrIndex: impossible SKnown" --- Copy-pasted from listrPermutePrefix, probably unavoidably. +-- Copy-pasted from ixrPermutePrefix, probably unavoidably. shrPermutePrefix :: forall i n. PermR -> ShR n i -> ShR n i shrPermutePrefix = \perm sh -> TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat -> @@ -475,15 +413,9 @@ shrPermutePrefix = \perm sh -> -- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ListR n i) where - type Item (ListR n i) = i - fromList = listrFromList (SNat @n) - toList = Foldable.toList - --- | Untyped: length is checked at runtime. instance KnownNat n => IsList (IxR n i) where type Item (IxR n i) = i - fromList = IxR . IsList.fromList + fromList = ixrFromList (SNat @n) toList = Foldable.toList -- | Untyped: length is checked at runtime. @@ -491,11 +423,3 @@ instance KnownNat n => IsList (IShR n) where type Item (IShR n) = Int fromList = shrFromList (SNat @n) toList = shrToList - - --- * Internal helper functions - -listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i -listrCastWithName _ SZ ZR = ZR -listrCastWithName name (SS n) (i ::: l) = i ::: listrCastWithName name n l -listrCastWithName name _ _ = error $ name ++ ": ranks don't match" |
