aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2026-04-04 16:59:37 +0200
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2026-04-04 23:51:39 +0200
commitdec7d6c47fe9b783e1a98008a4efffb77df6f393 (patch)
treeefad22c6f6a4c489d4ad8e7397acf934b6a2ce73 /src/Data
parentee319119b1f24db2b2e981e303db9935a1dca425 (diff)
Implement ListX as [] with strict pattern synonyms
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Nested.hs6
-rw-r--r--src/Data/Array/Nested/Mixed/ListX.hs125
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs63
-rw-r--r--src/Data/Array/Nested/Permutation.hs1
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs13
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs25
6 files changed, 148 insertions, 85 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index c898a75..14de7f9 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -113,8 +113,12 @@ module Data.Array.Nested (
import Prelude hiding (mappend, mconcat)
+import Foreign.Storable
+import GHC.TypeLits
+
import Data.Array.Nested.Convert
import Data.Array.Nested.Mixed
+import Data.Array.Nested.Mixed.ListX
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation
import Data.Array.Nested.Ranked
@@ -123,8 +127,6 @@ import Data.Array.Nested.Shaped
import Data.Array.Nested.Shaped.Shape
import Data.Array.Nested.Types
import Data.Array.Strided.Arith
-import Foreign.Storable
-import GHC.TypeLits
-- $integralRealFloat
--
diff --git a/src/Data/Array/Nested/Mixed/ListX.hs b/src/Data/Array/Nested/Mixed/ListX.hs
new file mode 100644
index 0000000..e89d1c8
--- /dev/null
+++ b/src/Data/Array/Nested/Mixed/ListX.hs
@@ -0,0 +1,125 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UnboxedTuples #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Mixed.ListX (ListX, pattern ZX, pattern (::%), listxShow, lazily, lazilyConcat, lazilyForce) where
+
+import Control.DeepSeq (NFData(..))
+import Data.Foldable qualified as Foldable
+import Data.Kind (Type)
+import Data.Type.Equality
+import GHC.IsList (IsList)
+import GHC.IsList qualified as IsList
+import GHC.TypeLits
+#if !MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
+import GHC.TypeLits.Orphans ()
+#endif
+
+import Data.Array.Nested.Types
+
+-- * Mixed lists implementation
+
+type role ListX nominal representational
+type ListX :: [Maybe Nat] -> Type -> Type
+newtype ListX sh i = ListX [i]
+ -- data invariant: each element is in WHNF; the spine maybe be not forced
+ deriving (Eq, Ord, NFData, Foldable)
+
+pattern ZX :: forall sh i. () => sh ~ '[] => ListX sh i
+pattern ZX <- (listxNull -> Just Refl)
+ where ZX = ListX []
+
+{-# INLINE listxNull #-}
+listxNull :: ListX sh i -> Maybe (sh :~: '[])
+listxNull (ListX []) = Just unsafeCoerceRefl
+listxNull (ListX (_ : _)) = Nothing
+
+{-# INLINE (::%) #-}
+pattern (::%)
+ :: forall {sh1} {i}.
+ forall n sh. (n : sh ~ sh1)
+ => i -> ListX sh i -> ListX sh1 i
+pattern i ::% sh <- (listxUncons -> Just (UnconsListXRes sh i))
+ where !i ::% ListX !sh = ListX (i : sh)
+infixr 3 ::%
+
+data UnconsListXRes i sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh i) i
+{-# INLINE listxUncons #-}
+listxUncons :: forall sh1 i. ListX sh1 i -> Maybe (UnconsListXRes i sh1)
+listxUncons (ListX (i : sh')) = gcastWith (unsafeCoerceRefl :: Head sh1 ': Tail sh1 :~: sh1) $
+ Just (UnconsListXRes (ListX @(Tail sh1) sh') i)
+listxUncons (ListX []) = Nothing
+
+{-# COMPLETE ZX, (::%) #-}
+
+{-# INLINE lazily #-}
+lazily :: ([a] -> [b]) -> ListX sh a -> ListX sh b
+lazily f (ListX l) = ListX $ f l
+
+{-# INLINE lazilyConcat #-}
+lazilyConcat :: ([a] -> [b] -> [c]) -> ListX sh a -> ListX sh' b -> ListX (sh ++ sh') c
+lazilyConcat f (ListX l) (ListX k) = ListX $ f l k
+
+{-# INLINE lazilyForce #-}
+lazilyForce :: ([a] -> [b]) -> ListX sh a -> ListX sh b
+lazilyForce f (ListX l) = let res = f l
+ in foldr seq () res `seq` ListX res
+
+#ifdef OXAR_DEFAULT_SHOW_INSTANCES
+deriving instance Show i => Show (ListX sh i)
+#else
+instance Show i => Show (ListX sh i) where
+ showsPrec _ = listxShow shows
+#endif
+
+{-# INLINE listxShow #-}
+listxShow :: forall sh i. (i -> ShowS) -> ListX sh i -> ShowS
+listxShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> ListX sh' i -> ShowS
+ go _ ZX = id
+ go prefix (x ::% xs) = showString prefix . f x . go "," xs
+
+-- This can't be derived, becauses the list needs to be fully evaluated,
+-- per data invariant. This version is faster than versions defined using
+-- (::%) or lazilyForce.
+instance Functor (ListX l) where
+ {-# INLINE fmap #-}
+ fmap f (ListX l) =
+ let fmap' [] = []
+ fmap' (x : xs) = let y = f x
+ rest = fmap' xs
+ in y `seq` rest `seq` y : rest
+ in ListX $ fmap' l
+
+-- | Very untyped: not even length is checked (at runtime).
+instance IsList (ListX sh i) where
+ type Item (ListX sh i) = i
+ {-# INLINE fromList #-}
+ fromList l = foldr seq () l `seq` ListX l
+ {-# INLINE toList #-}
+ toList = Foldable.toList
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index 2dfcc8c..611ec19 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -28,6 +28,7 @@
module Data.Array.Nested.Mixed.Shape where
import Control.DeepSeq (NFData(..))
+import Control.Exception (assert)
import Data.Bifunctor (first)
import Data.Coerce
import Data.Foldable qualified as Foldable
@@ -43,6 +44,7 @@ import GHC.TypeLits
import GHC.TypeLits.Orphans ()
#endif
+import Data.Array.Nested.Mixed.ListX
import Data.Array.Nested.Types
@@ -55,63 +57,14 @@ type family Rank sh where
-- * Mixed lists
-type role ListX nominal representational
-type ListX :: [Maybe Nat] -> Type -> Type
-data ListX sh i where
- ZX :: ListX '[] i
- (::%) :: forall n sh {i}. i -> ListX sh i -> ListX (n : sh) i
-deriving instance Eq i => Eq (ListX sh i)
-deriving instance Ord i => Ord (ListX sh i)
-infixr 3 ::%
-
-#ifdef OXAR_DEFAULT_SHOW_INSTANCES
-deriving instance Show i => Show (ListX sh i)
-#else
-instance Show i => Show (ListX sh i) where
- showsPrec _ = listxShow shows
-#endif
-
-instance NFData i => NFData (ListX sh i) where
- rnf ZX = ()
- rnf (x ::% l) = rnf x `seq` rnf l
-
-instance Functor (ListX l) where
- {-# INLINE fmap #-}
- fmap _ ZX = ZX
- fmap f (x ::% xs) = f x ::% fmap f xs
-
-instance Foldable (ListX l) where
- {-# INLINE foldMap #-}
- foldMap _ ZX = mempty
- foldMap f (x ::% xs) = f x <> foldMap f xs
- {-# INLINE foldr #-}
- foldr _ z ZX = z
- foldr f z (x ::% xs) = f x (foldr f z xs)
- null ZX = False
- null _ = True
+{-# INLINE listxFromList #-}
+listxFromList :: StaticShX sh -> [i] -> ListX sh i
+listxFromList sh l = assert (ssxLength sh == length l) $ IsList.fromList l
listxRank :: ListX sh i -> SNat (Rank sh)
listxRank ZX = SNat
listxRank (_ ::% l) | SNat <- listxRank l = SNat
-{-# INLINE listxShow #-}
-listxShow :: forall sh i. (i -> ShowS) -> ListX sh i -> ShowS
-listxShow f l = showString "[" . go "" l . showString "]"
- where
- go :: String -> ListX sh' i -> ShowS
- go _ ZX = id
- go prefix (x ::% xs) = showString prefix . f x . go "," xs
-
-listxFromList :: StaticShX sh -> [i] -> ListX sh i
-listxFromList topssh topl = go topssh topl
- where
- go :: StaticShX sh' -> [i] -> ListX sh' i
- go ZKX [] = ZX
- go (_ :!% sh) (i : is) = i ::% go sh is
- go _ _ = error $ "listxFromList: Mismatched list length (type says "
- ++ show (ssxLength topssh) ++ ", list has length "
- ++ show (length topl) ++ ")"
-
listxHead :: ListX (mn ': sh) i -> i
listxHead (i ::% _) = i
@@ -772,12 +725,6 @@ shxFlatten = go (SNat @1)
goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh
--- | Very untyped: only length is checked (at runtime).
-instance KnownShX sh => IsList (ListX sh i) where
- type Item (ListX sh i) = i
- fromList = listxFromList (knownShX @sh)
- toList = Foldable.toList
-
-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
instance KnownShX sh => IsList (IxX sh i) where
type Item (IxX sh i) = i
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs
index 9254728..b6e5f47 100644
--- a/src/Data/Array/Nested/Permutation.hs
+++ b/src/Data/Array/Nested/Permutation.hs
@@ -31,6 +31,7 @@ import GHC.TypeNats qualified as TN
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Types
+import Data.Array.Nested.Mixed.ListX
-- * Permutations
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index 2166123..4eef090 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -28,6 +28,7 @@
module Data.Array.Nested.Ranked.Shape where
import Control.DeepSeq (NFData(..))
+import Control.Exception (assert)
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
import Data.Kind (Type)
@@ -40,6 +41,7 @@ import GHC.TypeLits
import GHC.TypeNats qualified as TN
import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed.ListX
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
@@ -121,15 +123,10 @@ listrAppend :: forall n m i. ListR n i -> ListR m i -> ListR (n + m) i
listrAppend ZR sh = sh
listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
+{-# INLINE listrFromList #-}
listrFromList :: SNat n -> [i] -> ListR n i
-listrFromList topsn topl = go topsn topl
- where
- go :: SNat n' -> [i] -> ListR n' i
- go SZ [] = ZR
- go (SS n) (i : is) = i ::: go n is
- go _ _ = error $ "listrFromList: Mismatched list length (type says "
- ++ show (fromSNat topsn) ++ ", list has length "
- ++ show (length topl) ++ ")"
+listrFromList topsn topl = assert (fromSNat' topsn == length topl)
+ $ ListR $ IsList.fromList topl
listrHead :: ListR (n + 1) i -> i
listrHead (i ::: _) = i
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index 274f954..13596a7 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -26,6 +26,7 @@
module Data.Array.Nested.Shaped.Shape where
import Control.DeepSeq (NFData(..))
+import Control.Exception (assert)
import Data.Array.Shape qualified as O
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
@@ -37,6 +38,7 @@ import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
+import Data.Array.Nested.Mixed.ListX
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
@@ -93,26 +95,15 @@ listsRank :: ListS sh i -> SNat (Rank sh)
listsRank ZS = SNat
listsRank (_ ::$ sh) = snatSucc (listsRank sh)
+{-# INLINE listsFromList #-}
listsFromList :: ShS sh -> [i] -> ListS sh i
-listsFromList topsh topl = go topsh topl
- where
- go :: ShS sh' -> [i] -> ListS sh' i
- go ZSS [] = ZS
- go (_ :$$ sh) (i : is) = i ::$ go sh is
- go _ _ = error $ "listsFromList: Mismatched list length (type says "
- ++ show (shsLength topsh) ++ ", list has length "
- ++ show (length topl) ++ ")"
+listsFromList sh l = assert (shsLength sh == length l)
+ $ ListS $ IsList.fromList l
-{-# INLINEABLE listsFromListS #-}
+{-# INLINE listsFromListS #-}
listsFromListS :: ListS sh i0 -> [i] -> ListS sh i
-listsFromListS topl0 topl = go topl0 topl
- where
- go :: ListS sh i0 -> [i] -> ListS sh i
- go ZS [] = ZS
- go (_ ::$ l0) (i : is) = i ::$ go l0 is
- go _ _ = error $ "listsFromListS: Mismatched list length (the model says "
- ++ show (length topl0) ++ ", list has length "
- ++ show (length topl) ++ ")"
+listsFromListS sh l = assert (length sh == length l)
+ $ ListS $ IsList.fromList l
listsHead :: ListS (n : sh) i -> i
listsHead (i ::$ _) = i