From 890f4afd45ea416134ddfaf8a9115602316e17dc Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 3 Jun 2024 22:49:26 +0200 Subject: Very unsafe fromInteger that crashes everything if you do it wrong --- src/Data/Array/Nested/Internal/Mixed.hs | 10 ++++++---- src/Data/Array/Nested/Internal/Ranked.hs | 8 ++++---- src/Data/Array/Nested/Internal/Shaped.hs | 8 ++++---- 3 files changed, 14 insertions(+), 12 deletions(-) (limited to 'src/Data/Array') diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index c80325c..1120dbb 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -40,6 +40,7 @@ import Foreign.Storable (Storable) import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp) import GHC.Generics (Generic) import GHC.TypeLits +import Unsafe.Coerce (unsafeCoerce) import Data.Array.Mixed.XArray (XArray(..)) import Data.Array.Mixed.XArray qualified as X @@ -204,21 +205,22 @@ mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_ | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2)) | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 -instance (NumElt a, PrimElt a) => Num (Mixed sh a) where +instance (NumElt a, PrimElt a, Num a) => Num (Mixed sh a) where (+) = mliftNumElt2 numEltAdd (-) = mliftNumElt2 numEltSub (*) = mliftNumElt2 numEltMul negate = mliftNumElt1 numEltNeg abs = mliftNumElt1 numEltAbs signum = mliftNumElt1 numEltSignum - fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" + -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS + fromInteger n = unsafeCoerce @(Mixed '[] a) @(Mixed sh a) $ fromPrimitive $ M_Primitive ZSX (X.scalar (fromInteger n)) -instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Mixed sh a) where +instance (FloatElt a, NumElt a, PrimElt a, Num a) => Fractional (Mixed sh a) where fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" recip = mliftNumElt1 floatEltRecip (/) = mliftNumElt2 floatEltDiv -instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where +instance (FloatElt a, NumElt a, PrimElt a, Num a) => Floating (Mixed sh a) where pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" exp = mliftNumElt1 floatEltExp log = mliftNumElt1 floatEltLog diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index c67e892..3e9f528 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -184,21 +184,21 @@ arithPromoteRanked2 :: forall n a. PrimElt a -> Ranked n a -> Ranked n a -> Ranked n a arithPromoteRanked2 = coerce -instance (NumElt a, PrimElt a) => Num (Ranked n a) where +instance (NumElt a, PrimElt a, Num a) => Num (Ranked n a) where (+) = arithPromoteRanked2 (+) (-) = arithPromoteRanked2 (-) (*) = arithPromoteRanked2 (*) negate = arithPromoteRanked negate abs = arithPromoteRanked abs signum = arithPromoteRanked signum - fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal" + fromInteger = Ranked . fromInteger -instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Ranked n a) where +instance (FloatElt a, NumElt a, PrimElt a, Num a) => Fractional (Ranked n a) where fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal" recip = arithPromoteRanked recip (/) = arithPromoteRanked2 (/) -instance (FloatElt a, NumElt a, PrimElt a) => Floating (Ranked n a) where +instance (FloatElt a, NumElt a, PrimElt a, Num a) => Floating (Ranked n a) where pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal" exp = arithPromoteRanked exp log = arithPromoteRanked log diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index 9320495..863e604 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -179,21 +179,21 @@ arithPromoteShaped2 :: forall sh a. PrimElt a -> Shaped sh a -> Shaped sh a -> Shaped sh a arithPromoteShaped2 = coerce -instance (NumElt a, PrimElt a) => Num (Shaped sh a) where +instance (NumElt a, PrimElt a, Num a) => Num (Shaped sh a) where (+) = arithPromoteShaped2 (+) (-) = arithPromoteShaped2 (-) (*) = arithPromoteShaped2 (*) negate = arithPromoteShaped negate abs = arithPromoteShaped abs signum = arithPromoteShaped signum - fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal" + fromInteger = Shaped . fromInteger -instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Shaped sh a) where +instance (FloatElt a, NumElt a, PrimElt a, Num a) => Fractional (Shaped sh a) where fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" recip = arithPromoteShaped recip (/) = arithPromoteShaped2 (/) -instance (FloatElt a, NumElt a, PrimElt a) => Floating (Shaped sh a) where +instance (FloatElt a, NumElt a, PrimElt a, Num a) => Floating (Shaped sh a) where pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" exp = arithPromoteShaped exp log = arithPromoteShaped log -- cgit v1.2.3-70-g09d2