diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2026-04-09 22:58:39 +0200 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2026-04-09 22:58:39 +0200 |
| commit | 21e52b349aaf0978f0ce5925fef6e53e0c9436f9 (patch) | |
| tree | 6798918805e1142ffa889cba66e9dc6d1ffa9249 /src/Data | |
| parent | 7b0824ad591e9df501a57b8a2e4b5148d55f6dd0 (diff) | |
Get rid of most ListX operations
Diffstat (limited to 'src/Data')
| -rw-r--r-- | src/Data/Array/Nested.hs | 1 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 81 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Permutation.hs | 39 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 2 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 2 |
5 files changed, 44 insertions, 81 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 9922644..6d4ae78 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -58,7 +58,6 @@ module Data.Array.Nested ( -- * Mixed arrays Mixed, - ListX(ZX, (::%)), IxX(.., ZIX, (:.%)), IIxX, ShX(.., (:$%)), KnownShX(..), IShX, StaticShX(.., ZKX, (:!%)), diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 9869d03..a01e0f3 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -50,43 +50,6 @@ import Data.Array.Nested.Mixed.ListX import Data.Array.Nested.Types --- * Mixed lists - -{-# INLINE listxFromList #-} -listxFromList :: StaticShX sh -> [i] -> ListX sh i -listxFromList sh l = assert (ssxLength sh == length l) $ IsList.fromList l - -listxRank :: ListX sh i -> SNat (Rank sh) -listxRank ZX = SNat -listxRank (_ ::% l) | SNat <- listxRank l = SNat - -listxHead :: ListX (mn ': sh) i -> i -listxHead (i ::% _) = i - -listxTail :: ListX (n : sh) i -> ListX sh i -listxTail (_ ::% sh) = sh - -listxAppend :: forall sh sh' i. ListX sh i -> ListX sh' i -> ListX (sh ++ sh') i -listxAppend = lazilyConcat (++) - -listxDrop :: forall i j sh sh'. ListX sh j -> ListX (sh ++ sh') i -> ListX sh' i -listxDrop ZX long = long -listxDrop (_ ::% short) long = case long of _ ::% long' -> listxDrop short long' - -listxInit :: forall i n sh. ListX (n : sh) i -> ListX (Init (n : sh)) i -listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh -listxInit (_ ::% ZX) = ZX - -listxLast :: forall i n sh. ListX (n : sh) i -> i -listxLast (_ ::% sh@(_ ::% _)) = listxLast sh -listxLast (x ::% ZX) = x - -{-# INLINE listxZipWith #-} -listxZipWith :: (i -> j -> k) -> ListX sh i -> ListX sh j -> ListX sh k -listxZipWith _ ZX ZX = ZX -listxZipWith f (i ::% is) (j ::% js) = f i j ::% listxZipWith f is js - - -- * Mixed indices -- | An index into a mixed-typed array. @@ -119,8 +82,13 @@ instance Show i => Show (IxX sh i) where showsPrec _ (IxX l) = listxShow shows l #endif +{-# INLINE ixxFromList #-} +ixxFromList :: StaticShX sh -> [i] -> IxX sh i +ixxFromList sh l = assert (ssxLength sh == length l) $ IsList.fromList l + ixxRank :: IxX sh i -> SNat (Rank sh) -ixxRank (IxX l) = listxRank l +ixxRank ZIX = SNat +ixxRank (_ :.% l) | SNat <- ixxRank l = SNat ixxZero :: StaticShX sh -> IIxX sh ixxZero ZKX = ZIX @@ -130,32 +98,26 @@ ixxZero' :: IShX sh -> IIxX sh ixxZero' ZSX = ZIX ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh -{-# INLINEABLE ixxFromList #-} -ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i -ixxFromList = coerce (listxFromList @_ @i) - -ixxHead :: IxX (n : sh) i -> i -ixxHead (IxX list) = listxHead list +ixxHead :: IxX (mn ': sh) i -> i +ixxHead (i :.% _) = i ixxTail :: IxX (n : sh) i -> IxX sh i -ixxTail (IxX list) = IxX (listxTail list) +ixxTail (_ :.% sh) = sh ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i -ixxAppend = coerce (listxAppend @_ @_ @i) - -ixxDrop :: forall sh sh' i. IxX sh i -> IxX (sh ++ sh') i -> IxX sh' i -ixxDrop = coerce (listxDrop @i @i) +ixxAppend (IxX l1) (IxX l2) = IxX $ lazilyConcat (++) l1 l2 -ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i -ixxInit = coerce (listxInit @i) +ixxDrop :: forall i j sh sh'. IxX sh j -> IxX (sh ++ sh') i -> IxX sh' i +ixxDrop ZIX long = long +ixxDrop (_ :.% short) long = case long of _ :.% long' -> ixxDrop short long' -ixxLast :: forall n sh i. IxX (n : sh) i -> i -ixxLast = coerce (listxLast @i) +ixxInit :: forall i n sh. IxX (n : sh) i -> IxX (Init (n : sh)) i +ixxInit (i :.% sh@(_ :.% _)) = i :.% ixxInit sh +ixxInit (_ :.% ZIX) = ZIX -ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i -ixxCast ZKX ZIX = ZIX -ixxCast (_ :!% sh) (i :.% idx) = i :.% ixxCast sh idx -ixxCast _ _ = error "ixxCast: ranks don't match" +ixxLast :: forall i n sh. IxX (n : sh) i -> i +ixxLast (_ :.% sh@(_ :.% _)) = ixxLast sh +ixxLast (x :.% ZIX) = x ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j) ixxZip ZIX ZIX = ZIX @@ -166,6 +128,11 @@ ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k ixxZipWith _ ZIX ZIX = ZIX ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js +ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i +ixxCast ZKX ZIX = ZIX +ixxCast (_ :!% sh) (i :.% idx) = i :.% ixxCast sh idx +ixxCast _ _ = error "ixxCast: ranks don't match" + -- | Given a multidimensional index, get the corresponding linear -- index into the buffer. {-# INLINEABLE ixxToLinear #-} diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs index ee79ecf..19d81f0 100644 --- a/src/Data/Array/Nested/Permutation.hs +++ b/src/Data/Array/Nested/Permutation.hs @@ -31,7 +31,6 @@ import GHC.TypeNats qualified as TN import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Types -import Data.Array.Nested.Mixed.ListX -- * Permutations @@ -222,31 +221,29 @@ ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) ssxPermutePrefix = coerce (shxPermutePrefix @()) -listxTakeLenPerm :: forall i is sh. Perm is -> ListX sh i -> ListX (TakeLen is sh) i -listxTakeLenPerm PNil _ = ZX -listxTakeLenPerm (_ `PCons` is) (n ::% sh) = n ::% listxTakeLenPerm is sh -listxTakeLenPerm (_ `PCons` _) ZX = error "Permutation longer than shape" +ixxTakeLenPerm :: forall i is sh. Perm is -> IxX sh i -> IxX (TakeLen is sh) i +ixxTakeLenPerm PNil _ = ZIX +ixxTakeLenPerm (_ `PCons` is) (n :.% sh) = n :.% ixxTakeLenPerm is sh +ixxTakeLenPerm (_ `PCons` _) ZIX = error "Permutation longer than shape" -listxDropLenPerm :: forall i is sh. Perm is -> ListX sh i -> ListX (DropLen is sh) i -listxDropLenPerm PNil sh = sh -listxDropLenPerm (_ `PCons` is) (_ ::% sh) = listxDropLenPerm is sh -listxDropLenPerm (_ `PCons` _) ZX = error "Permutation longer than shape" +ixxDropLenPerm :: forall i is sh. Perm is -> IxX sh i -> IxX (DropLen is sh) i +ixxDropLenPerm PNil sh = sh +ixxDropLenPerm (_ `PCons` is) (_ :.% sh) = ixxDropLenPerm is sh +ixxDropLenPerm (_ `PCons` _) ZIX = error "Permutation longer than shape" -listxPermute :: forall i is sh. Perm is -> ListX sh i -> ListX (Permute is sh) i -listxPermute PNil _ = ZX -listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = - listxIndex i sh ::% listxPermute is sh +ixxPermute :: forall i is sh. Perm is -> IxX sh i -> IxX (Permute is sh) i +ixxPermute PNil _ = ZIX +ixxPermute (i `PCons` (is :: Perm is')) (sh :: IxX sh f) = + ixxIndex i sh :.% ixxPermute is sh -listxIndex :: forall j i sh. SNat i -> ListX sh j -> j -listxIndex SZ (n ::% _) = n -listxIndex (SS i) (_ ::% sh) = listxIndex i sh -listxIndex _ ZX = error "Index into empty shape" - -listxPermutePrefix :: forall i is sh. Perm is -> ListX sh i -> ListX (PermutePrefix is sh) i -listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLenPerm perm sh)) (listxDropLenPerm perm sh) +ixxIndex :: forall j i sh. SNat i -> IxX sh j -> j +ixxIndex SZ (n :.% _) = n +ixxIndex (SS i) (_ :.% sh) = ixxIndex i sh +ixxIndex _ ZIX = error "Index into empty shape" ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i -ixxPermutePrefix = coerce (listxPermutePrefix @i) +ixxPermutePrefix perm sh = ixxAppend (ixxPermute perm (ixxTakeLenPerm perm sh)) (ixxDropLenPerm perm sh) + -- * Operations on permutations diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 5e84a2d..f260d07 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -153,7 +153,7 @@ ixrCast _ _ = error "ixrCast: ranks don't match" -- lemReplicatePlusApp requires SNat that would cause overhead (not benchmarked) ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i ixrAppend = gcastWith (unsafeCoerceRefl :: Replicate (n + m) (Nothing @Nat) :~: Replicate n Nothing ++ Replicate m Nothing) $ - coerce (listxAppend @_ @_ @i) + coerce (ixxAppend @_ @_ @i) ixrIndex :: forall k n i. (k + 1 <= n) => SNat k -> IxR n i -> i ixrIndex SZ (x :.: _) = x diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 392ceac..60e0252 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -132,7 +132,7 @@ ixsCast (i :.$ idx) = i :.$ ixsCast idx ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i ixsAppend = gcastWith (unsafeCoerceRefl :: MapJust (sh ++ sh') :~: MapJust sh ++ MapJust sh') $ - coerce (listxAppend @_ @_ @i) + coerce (ixxAppend @_ @_ @i) ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j) ixsZip ZIS ZIS = ZIS |
