From 890f4afd45ea416134ddfaf8a9115602316e17dc Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 3 Jun 2024 22:49:26 +0200 Subject: Very unsafe fromInteger that crashes everything if you do it wrong --- src/Data/Array/Nested/Internal/Mixed.hs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'src/Data/Array/Nested/Internal/Mixed.hs') diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index c80325c..1120dbb 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -40,6 +40,7 @@ import Foreign.Storable (Storable) import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp) import GHC.Generics (Generic) import GHC.TypeLits +import Unsafe.Coerce (unsafeCoerce) import Data.Array.Mixed.XArray (XArray(..)) import Data.Array.Mixed.XArray qualified as X @@ -204,21 +205,22 @@ mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_ | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2)) | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 -instance (NumElt a, PrimElt a) => Num (Mixed sh a) where +instance (NumElt a, PrimElt a, Num a) => Num (Mixed sh a) where (+) = mliftNumElt2 numEltAdd (-) = mliftNumElt2 numEltSub (*) = mliftNumElt2 numEltMul negate = mliftNumElt1 numEltNeg abs = mliftNumElt1 numEltAbs signum = mliftNumElt1 numEltSignum - fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" + -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS + fromInteger n = unsafeCoerce @(Mixed '[] a) @(Mixed sh a) $ fromPrimitive $ M_Primitive ZSX (X.scalar (fromInteger n)) -instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Mixed sh a) where +instance (FloatElt a, NumElt a, PrimElt a, Num a) => Fractional (Mixed sh a) where fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" recip = mliftNumElt1 floatEltRecip (/) = mliftNumElt2 floatEltDiv -instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where +instance (FloatElt a, NumElt a, PrimElt a, Num a) => Floating (Mixed sh a) where pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" exp = mliftNumElt1 floatEltExp log = mliftNumElt1 floatEltLog -- cgit v1.2.3-70-g09d2