aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-03 22:49:26 +0200
committerTom Smeding <t.j.smeding@uu.nl>2024-06-10 16:33:30 +0200
commit890f4afd45ea416134ddfaf8a9115602316e17dc (patch)
tree0c21a1523d42802a82e902c017a0ed995823bd8c
parentcebce678a2b86b03796ef71ceec42664d180b107 (diff)
Very unsafe fromInteger that crashes everything if you do it wrong
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs10
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs8
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs8
3 files changed, 14 insertions, 12 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index c80325c..1120dbb 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -40,6 +40,7 @@ import Foreign.Storable (Storable)
import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp)
import GHC.Generics (Generic)
import GHC.TypeLits
+import Unsafe.Coerce (unsafeCoerce)
import Data.Array.Mixed.XArray (XArray(..))
import Data.Array.Mixed.XArray qualified as X
@@ -204,21 +205,22 @@ mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_
| sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2))
| otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2
-instance (NumElt a, PrimElt a) => Num (Mixed sh a) where
+instance (NumElt a, PrimElt a, Num a) => Num (Mixed sh a) where
(+) = mliftNumElt2 numEltAdd
(-) = mliftNumElt2 numEltSub
(*) = mliftNumElt2 numEltMul
negate = mliftNumElt1 numEltNeg
abs = mliftNumElt1 numEltAbs
signum = mliftNumElt1 numEltSignum
- fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate"
+ -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS
+ fromInteger n = unsafeCoerce @(Mixed '[] a) @(Mixed sh a) $ fromPrimitive $ M_Primitive ZSX (X.scalar (fromInteger n))
-instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Mixed sh a) where
+instance (FloatElt a, NumElt a, PrimElt a, Num a) => Fractional (Mixed sh a) where
fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"
recip = mliftNumElt1 floatEltRecip
(/) = mliftNumElt2 floatEltDiv
-instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where
+instance (FloatElt a, NumElt a, PrimElt a, Num a) => Floating (Mixed sh a) where
pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"
exp = mliftNumElt1 floatEltExp
log = mliftNumElt1 floatEltLog
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index c67e892..3e9f528 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -184,21 +184,21 @@ arithPromoteRanked2 :: forall n a. PrimElt a
-> Ranked n a -> Ranked n a -> Ranked n a
arithPromoteRanked2 = coerce
-instance (NumElt a, PrimElt a) => Num (Ranked n a) where
+instance (NumElt a, PrimElt a, Num a) => Num (Ranked n a) where
(+) = arithPromoteRanked2 (+)
(-) = arithPromoteRanked2 (-)
(*) = arithPromoteRanked2 (*)
negate = arithPromoteRanked negate
abs = arithPromoteRanked abs
signum = arithPromoteRanked signum
- fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal"
+ fromInteger = Ranked . fromInteger
-instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Ranked n a) where
+instance (FloatElt a, NumElt a, PrimElt a, Num a) => Fractional (Ranked n a) where
fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal"
recip = arithPromoteRanked recip
(/) = arithPromoteRanked2 (/)
-instance (FloatElt a, NumElt a, PrimElt a) => Floating (Ranked n a) where
+instance (FloatElt a, NumElt a, PrimElt a, Num a) => Floating (Ranked n a) where
pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal"
exp = arithPromoteRanked exp
log = arithPromoteRanked log
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index 9320495..863e604 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -179,21 +179,21 @@ arithPromoteShaped2 :: forall sh a. PrimElt a
-> Shaped sh a -> Shaped sh a -> Shaped sh a
arithPromoteShaped2 = coerce
-instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
+instance (NumElt a, PrimElt a, Num a) => Num (Shaped sh a) where
(+) = arithPromoteShaped2 (+)
(-) = arithPromoteShaped2 (-)
(*) = arithPromoteShaped2 (*)
negate = arithPromoteShaped negate
abs = arithPromoteShaped abs
signum = arithPromoteShaped signum
- fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal"
+ fromInteger = Shaped . fromInteger
-instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Shaped sh a) where
+instance (FloatElt a, NumElt a, PrimElt a, Num a) => Fractional (Shaped sh a) where
fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"
recip = arithPromoteShaped recip
(/) = arithPromoteShaped2 (/)
-instance (FloatElt a, NumElt a, PrimElt a) => Floating (Shaped sh a) where
+instance (FloatElt a, NumElt a, PrimElt a, Num a) => Floating (Shaped sh a) where
pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"
exp = arithPromoteShaped exp
log = arithPromoteShaped log