diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-06-03 22:49:26 +0200 | 
|---|---|---|
| committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-06-10 16:33:30 +0200 | 
| commit | 890f4afd45ea416134ddfaf8a9115602316e17dc (patch) | |
| tree | 0c21a1523d42802a82e902c017a0ed995823bd8c | |
| parent | cebce678a2b86b03796ef71ceec42664d180b107 (diff) | |
Very unsafe fromInteger that crashes everything if you do it wrong
| -rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 10 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 8 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 8 | 
3 files changed, 14 insertions, 12 deletions
| diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index c80325c..1120dbb 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -40,6 +40,7 @@ import Foreign.Storable (Storable)  import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp)  import GHC.Generics (Generic)  import GHC.TypeLits +import Unsafe.Coerce (unsafeCoerce)  import Data.Array.Mixed.XArray (XArray(..))  import Data.Array.Mixed.XArray qualified as X @@ -204,21 +205,22 @@ mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_    | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2))    | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 -instance (NumElt a, PrimElt a) => Num (Mixed sh a) where +instance (NumElt a, PrimElt a, Num a) => Num (Mixed sh a) where    (+) = mliftNumElt2 numEltAdd    (-) = mliftNumElt2 numEltSub    (*) = mliftNumElt2 numEltMul    negate = mliftNumElt1 numEltNeg    abs = mliftNumElt1 numEltAbs    signum = mliftNumElt1 numEltSignum -  fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" +  -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS +  fromInteger n = unsafeCoerce @(Mixed '[] a) @(Mixed sh a) $ fromPrimitive $ M_Primitive ZSX (X.scalar (fromInteger n)) -instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Mixed sh a) where +instance (FloatElt a, NumElt a, PrimElt a, Num a) => Fractional (Mixed sh a) where    fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"    recip = mliftNumElt1 floatEltRecip    (/) = mliftNumElt2 floatEltDiv -instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where +instance (FloatElt a, NumElt a, PrimElt a, Num a) => Floating (Mixed sh a) where    pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"    exp = mliftNumElt1 floatEltExp    log = mliftNumElt1 floatEltLog diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index c67e892..3e9f528 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -184,21 +184,21 @@ arithPromoteRanked2 :: forall n a. PrimElt a                      -> Ranked n a -> Ranked n a -> Ranked n a  arithPromoteRanked2 = coerce -instance (NumElt a, PrimElt a) => Num (Ranked n a) where +instance (NumElt a, PrimElt a, Num a) => Num (Ranked n a) where    (+) = arithPromoteRanked2 (+)    (-) = arithPromoteRanked2 (-)    (*) = arithPromoteRanked2 (*)    negate = arithPromoteRanked negate    abs = arithPromoteRanked abs    signum = arithPromoteRanked signum -  fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal" +  fromInteger = Ranked . fromInteger -instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Ranked n a) where +instance (FloatElt a, NumElt a, PrimElt a, Num a) => Fractional (Ranked n a) where    fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal"    recip = arithPromoteRanked recip    (/) = arithPromoteRanked2 (/) -instance (FloatElt a, NumElt a, PrimElt a) => Floating (Ranked n a) where +instance (FloatElt a, NumElt a, PrimElt a, Num a) => Floating (Ranked n a) where    pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal"    exp = arithPromoteRanked exp    log = arithPromoteRanked log diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index 9320495..863e604 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -179,21 +179,21 @@ arithPromoteShaped2 :: forall sh a. PrimElt a                      -> Shaped sh a -> Shaped sh a -> Shaped sh a  arithPromoteShaped2 = coerce -instance (NumElt a, PrimElt a) => Num (Shaped sh a) where +instance (NumElt a, PrimElt a, Num a) => Num (Shaped sh a) where    (+) = arithPromoteShaped2 (+)    (-) = arithPromoteShaped2 (-)    (*) = arithPromoteShaped2 (*)    negate = arithPromoteShaped negate    abs = arithPromoteShaped abs    signum = arithPromoteShaped signum -  fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal" +  fromInteger = Shaped . fromInteger -instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Shaped sh a) where +instance (FloatElt a, NumElt a, PrimElt a, Num a) => Fractional (Shaped sh a) where    fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"    recip = arithPromoteShaped recip    (/) = arithPromoteShaped2 (/) -instance (FloatElt a, NumElt a, PrimElt a) => Floating (Shaped sh a) where +instance (FloatElt a, NumElt a, PrimElt a, Num a) => Floating (Shaped sh a) where    pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"    exp = arithPromoteShaped exp    log = arithPromoteShaped log | 
