From 3a5d069565cf4a19fbf94c7b548f072bbada265b Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Fri, 6 Jun 2025 12:15:15 +0200 Subject: Bidirectional inference for KnownElt --- src/Data/Array/Nested/Mixed.hs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'src/Data/Array/Nested/Mixed.hs') diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 0a2fc17..250c999 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -386,7 +386,9 @@ class EltC a => Elt a where -- This class is (currently) only required for 'mgenerate', -- 'Data.Array.Nested.Ranked.rgenerate' and -- 'Data.Array.Nested.Shaped.sgenerate'. -class Elt a => KnownElt a where +class (Elt a, KnownEltC a) => KnownElt a where + type KnownEltC a :: Constraint + -- | Create an empty array. The given shape must have size zero; this may or may not be checked. memptyArrayUnsafe :: IShX sh -> Mixed sh a @@ -489,6 +491,7 @@ deriving via Primitive Float instance Elt Float deriving via Primitive () instance Elt () instance Storable a => KnownElt (Primitive a) where + type KnownEltC (Primitive a) = () memptyArrayUnsafe sh = M_Primitive sh (X.empty sh) mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 @@ -550,6 +553,7 @@ instance (Elt a, Elt b) => Elt (a, b) where mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b instance (KnownElt a, KnownElt b) => KnownElt (a, b) where + type KnownEltC (a, b) = (KnownEltC a, KnownEltC b) memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh) mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) @@ -690,6 +694,7 @@ instance Elt a => Elt (Mixed sh' a) where mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where + type KnownEltC (Mixed sh' a) = (KnownShX sh', KnownElt a) memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh')))) mvecsUnsafeNew sh example -- cgit v1.2.3-70-g09d2