diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-20 17:25:28 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-20 17:25:28 +0200 |
commit | 697d4360aefee9e5142091e880e7384112a3419d (patch) | |
tree | c1f401a0b4f58634587f40b21a68c1379069021b /src/Data/Array/Nested/Internal.hs | |
parent | 52c0237fbdbc3c99ee6565ba18250360a330fb8b (diff) |
Make Storable a superclass of PrimElt
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 40 |
1 files changed, 20 insertions, 20 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index badb910..a05ff84 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -421,7 +421,7 @@ newtype Primitive a = Primitive a -- | Element types that are primitive; arrays of these types are just a newtype -- wrapper over an array. -class PrimElt a where +class Storable a => PrimElt a where fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) @@ -881,13 +881,13 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) -mfromVector :: forall sh a. (Storable a, PrimElt a) => IShX sh -> VS.Vector a -> Mixed sh a +mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a mfromVector sh v = fromPrimitive (mfromVectorP sh v) mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a mtoVectorP (M_Primitive _ v) = X.toVector v -mtoVector :: (Storable a, PrimElt a) => Mixed sh a -> VS.Vector a +mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a mtoVector arr = mtoVectorP (coerce toPrimitive arr) mfromList :: Elt a => NonEmpty a -> Mixed '[Nothing] a @@ -910,7 +910,7 @@ mrerankP ssh sh2 f (M_Primitive sh arr) = (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) arr) -mrerank :: forall sh1 sh2 sh a. (Storable a, PrimElt a) +mrerank :: forall sh1 sh2 sh a. PrimElt a => StaticShX sh -> IShX sh2 -> (Mixed sh1 a -> Mixed sh2 a) -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) a @@ -920,7 +920,7 @@ mrerank ssh sh2 f (toPrimitive -> arr) = mreplicateP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) mreplicateP sh x = M_Primitive sh (X.replicate sh x) -mreplicate :: forall sh a. (Storable a, PrimElt a) +mreplicate :: forall sh a. PrimElt a => IShX sh -> a -> Mixed sh a mreplicate sh x = fromPrimitive (mreplicateP sh x) @@ -953,18 +953,18 @@ mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP -mliftPrim :: (Storable a, PrimElt a) +mliftPrim :: PrimElt a => (a -> a) -> Mixed sh a -> Mixed sh a mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) -mliftPrim2 :: (Storable a, PrimElt a) +mliftPrim2 :: PrimElt a => (a -> a -> a) -> Mixed sh a -> Mixed sh a -> Mixed sh a mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) -instance (Storable a, Num a, PrimElt a) => Num (Mixed sh a) where +instance (Num a, PrimElt a) => Num (Mixed sh a) where (+) = mliftPrim2 (+) (-) = mliftPrim2 (-) (*) = mliftPrim2 (*) @@ -1248,7 +1248,7 @@ arithPromoteRanked2 :: forall n a. PrimElt a -> Ranked n a -> Ranked n a -> Ranked n a arithPromoteRanked2 = coerce -instance (Storable a, Num a, PrimElt a) => Num (Ranked n a) where +instance (Num a, PrimElt a) => Num (Ranked n a) where (+) = arithPromoteRanked2 (+) (-) = arithPromoteRanked2 (-) (*) = arithPromoteRanked2 (*) @@ -1346,7 +1346,7 @@ rsumOuter1P (Ranked (M_Primitive sh arr)) , _ :$% shT <- sh = Ranked (M_Primitive shT (X.sumOuter (SUnknown () :!% ZKX) (X.staticShapeFrom shT) arr)) -rsumOuter1 :: forall n a. (Storable a, Num a, PrimElt a) +rsumOuter1 :: forall n a. (Num a, PrimElt a) => Ranked (n + 1) a -> Ranked n a rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive @@ -1378,13 +1378,13 @@ rfromVectorP sh v | Dict <- lemKnownReplicate (snatFromShR sh) = Ranked (mfromVectorP (shCvtRX sh) v) -rfromVector :: forall n a. (Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a +rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a rfromVector sh v = coerce fromPrimitive (rfromVectorP sh v) rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a rtoVectorP = coerce mtoVectorP -rtoVector :: (Storable a, PrimElt a) => Ranked n a -> VS.Vector a +rtoVector :: PrimElt a => Ranked n a -> VS.Vector a rtoVector = coerce mtoVector rfromList1 :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a @@ -1417,7 +1417,7 @@ rrerankP sn sh2 f (Ranked arr) (\a -> let Ranked r = f (Ranked a) in r) arr) -rrerank :: forall n1 n2 n a. (Storable a, PrimElt a) +rrerank :: forall n1 n2 n a. PrimElt a => SNat n -> IShR n2 -> (Ranked n1 a -> Ranked n2 a) -> Ranked (n + n1) a -> Ranked (n + n2) a @@ -1429,7 +1429,7 @@ rreplicateP sh x | Dict <- lemKnownReplicate (snatFromShR sh) = Ranked (mreplicateP (shCvtRX sh) x) -rreplicate :: forall n a. (Storable a, PrimElt a) +rreplicate :: forall n a. PrimElt a => IShR n -> a -> Ranked n a rreplicate sh x = coerce fromPrimitive (rreplicateP sh x) @@ -1492,7 +1492,7 @@ arithPromoteShaped2 :: forall sh a. PrimElt a -> Shaped sh a -> Shaped sh a -> Shaped sh a arithPromoteShaped2 = coerce -instance (Storable a, Num a, PrimElt a) => Num (Shaped sh a) where +instance (Num a, PrimElt a) => Num (Shaped sh a) where (+) = arithPromoteShaped2 (+) (-) = arithPromoteShaped2 (-) (*) = arithPromoteShaped2 (*) @@ -1571,7 +1571,7 @@ ssumOuter1P :: forall sh n a. (Storable a, Num a) ssumOuter1P (Shaped (M_Primitive (SKnown sn :$% sh) arr)) = Shaped (M_Primitive sh (X.sumOuter (SKnown sn :!% ZKX) (X.staticShapeFrom sh) arr)) -ssumOuter1 :: forall sh n a. (Storable a, Num a, PrimElt a) +ssumOuter1 :: forall sh n a. (Num a, PrimElt a) => Shaped (n : sh) a -> Shaped sh a ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive @@ -1636,13 +1636,13 @@ sscalar x = Shaped (mscalar x) sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v) -sfromVector :: (Storable a, PrimElt a) => ShS sh -> VS.Vector a -> Shaped sh a +sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a sfromVector sh v = coerce fromPrimitive (sfromVectorP sh v) stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a stoVectorP = coerce mtoVectorP -stoVector :: (Storable a, PrimElt a) => Shaped sh a -> VS.Vector a +stoVector :: PrimElt a => Shaped sh a -> VS.Vector a stoVector = coerce mtoVector sfromList1 :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a @@ -1672,7 +1672,7 @@ srerankP sh sh2 f sarr@(Shaped arr) (\a -> let Shaped r = f (Shaped a) in r) arr) -srerank :: forall sh1 sh2 sh a. (Storable a, PrimElt a) +srerank :: forall sh1 sh2 sh a. PrimElt a => StaticShX sh -> IShX sh2 -> (Mixed sh1 a -> Mixed sh2 a) -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) a @@ -1682,7 +1682,7 @@ srerank ssh sh2 f (toPrimitive -> arr) = sreplicateP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) sreplicateP sh x = Shaped (mreplicateP (shCvtSX sh) x) -sreplicate :: (Storable a, PrimElt a) => ShS sh -> a -> Shaped sh a +sreplicate :: PrimElt a => ShS sh -> a -> Shaped sh a sreplicate sh x = coerce fromPrimitive (sreplicateP sh x) sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a |