aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Mixed.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-02 23:22:23 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-02 23:22:39 +0200
commit8344417a8ed9ad59337ce3b880e2fadf89bec964 (patch)
treef0e4fdc9d7a272f4f76dded3aae9871d4c3b7c66 /src/Data/Array/Nested/Mixed.hs
parent75ee1572b75b45dcdc50e3af82ed50259ca77df0 (diff)
WIP simplify Castable
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
-rw-r--r--src/Data/Array/Nested/Mixed.hs13
1 files changed, 11 insertions, 2 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 221393f..8d25bc3 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -16,6 +16,7 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Nested.Mixed where
@@ -29,7 +30,7 @@ import Data.Bifunctor (bimap)
import Data.Coerce
import Data.Foldable (toList)
import Data.Int
-import Data.Kind (Type)
+import Data.Kind (Type, Constraint)
import Data.List.NonEmpty (NonEmpty(..))
import Data.List.NonEmpty qualified as NE
import Data.Proxy
@@ -290,7 +291,9 @@ matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2)
-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
-- a@; see the documentation for 'Primitive' for more details.
-class Elt a where
+class EltC a => Elt a where
+ type EltC a :: Constraint
+
-- ====== PUBLIC METHODS ====== --
mshape :: Mixed sh a -> IShX sh
@@ -396,6 +399,8 @@ class Elt a => KnownElt a where
-- Arrays of scalars are basically just arrays of scalars.
instance Storable a => Elt (Primitive a) where
+ type EltC (Primitive a) = Storable a
+
mshape (M_Primitive sh _) = sh
mindex (M_Primitive _ a) i = Primitive (X.index a i)
mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i)
@@ -500,6 +505,8 @@ deriving via Primitive () instance KnownElt ()
-- Arrays of pairs are pairs of arrays.
instance (Elt a, Elt b) => Elt (a, b) where
+ type EltC (a, b) = (Elt a, Elt b)
+
mshape (M_Tup2 a _) = mshape a
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
@@ -549,6 +556,8 @@ instance (KnownElt a, KnownElt b) => KnownElt (a, b) where
-- Arrays of arrays are just arrays, but with more dimensions.
instance Elt a => Elt (Mixed sh' a) where
+ type EltC (Mixed sh' a) = Elt a
+
-- TODO: this is quadratic in the nesting depth because it repeatedly
-- truncates the shape vector to one a little shorter. Fix with a
-- moverlongShape method, a prefix of which is mshape.