From a0010622885dcb55a916bf3514c0e9040f6871e9 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 23 May 2024 00:18:17 +0200 Subject: Fast numeric operations for Num --- src/Data/Array/Nested/Internal.hs | 46 ++++++++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 8 deletions(-) (limited to 'src/Data/Array/Nested/Internal.hs') diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 588237d..831a9b5 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE DeriveFoldable #-} @@ -62,6 +63,7 @@ import Unsafe.Coerce import Data.Array.Mixed import qualified Data.Array.Mixed as X +import Data.Array.Nested.Internal.Arith -- Invariant in the API @@ -999,6 +1001,7 @@ mliftPrim2 :: PrimElt a mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) +{-} instance (Num a, PrimElt a) => Num (Mixed sh a) where (+) = mliftPrim2 (+) (-) = mliftPrim2 (-) @@ -1008,12 +1011,39 @@ instance (Num a, PrimElt a) => Num (Mixed sh a) where signum = mliftPrim signum fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" -instance (Fractional a, PrimElt a) => Fractional (Mixed sh a) where +type NumConstr a = Num a +--} + +{--} +mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a +mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (srankSh sh) arr)) + +mliftNumElt2 :: PrimElt a + => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a) + -> Mixed sh a -> Mixed sh a -> Mixed sh a +mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2)) + | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (srankSh sh1) arr1 arr2)) + | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 + +-- TODO: Clean up this mess and remove NumConstr +type NumConstr a = NumElt a + +instance (NumElt a, PrimElt a) => Num (Mixed sh a) where + (+) = mliftNumElt2 numEltAdd + (-) = mliftNumElt2 numEltSub + (*) = mliftNumElt2 numEltMul + negate = mliftNumElt1 numEltNeg + abs = mliftNumElt1 numEltAbs + signum = mliftNumElt1 numEltSignum + fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" +--} + +instance (Fractional a, PrimElt a, NumConstr a) => Fractional (Mixed sh a) where fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" recip = mliftPrim recip (/) = mliftPrim2 (/) -instance (Floating a, PrimElt a) => Floating (Mixed sh a) where +instance (Floating a, PrimElt a, NumConstr a) => Floating (Mixed sh a) where pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" exp = mliftPrim exp log = mliftPrim log @@ -1316,7 +1346,7 @@ arithPromoteRanked2 :: forall n a. PrimElt a -> Ranked n a -> Ranked n a -> Ranked n a arithPromoteRanked2 = coerce -instance (Num a, PrimElt a) => Num (Ranked n a) where +instance (NumConstr a, PrimElt a) => Num (Ranked n a) where (+) = arithPromoteRanked2 (+) (-) = arithPromoteRanked2 (-) (*) = arithPromoteRanked2 (*) @@ -1325,12 +1355,12 @@ instance (Num a, PrimElt a) => Num (Ranked n a) where signum = arithPromoteRanked signum fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicate" -instance (Fractional a, PrimElt a) => Fractional (Ranked n a) where +instance (Fractional a, PrimElt a, NumConstr a) => Fractional (Ranked n a) where fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicate" recip = arithPromoteRanked recip (/) = arithPromoteRanked2 (/) -instance (Floating a, PrimElt a) => Floating (Ranked n a) where +instance (Floating a, PrimElt a, NumConstr a) => Floating (Ranked n a) where pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicate" exp = arithPromoteRanked exp log = arithPromoteRanked log @@ -1616,7 +1646,7 @@ arithPromoteShaped2 :: forall sh a. PrimElt a -> Shaped sh a -> Shaped sh a -> Shaped sh a arithPromoteShaped2 = coerce -instance (Num a, PrimElt a) => Num (Shaped sh a) where +instance (NumConstr a, PrimElt a) => Num (Shaped sh a) where (+) = arithPromoteShaped2 (+) (-) = arithPromoteShaped2 (-) (*) = arithPromoteShaped2 (*) @@ -1625,12 +1655,12 @@ instance (Num a, PrimElt a) => Num (Shaped sh a) where signum = arithPromoteShaped signum fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicate" -instance (Fractional a, PrimElt a) => Fractional (Shaped sh a) where +instance (Fractional a, PrimElt a, NumConstr a) => Fractional (Shaped sh a) where fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicate" recip = arithPromoteShaped recip (/) = arithPromoteShaped2 (/) -instance (Floating a, PrimElt a) => Floating (Shaped sh a) where +instance (Floating a, PrimElt a, NumConstr a) => Floating (Shaped sh a) where pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicate" exp = arithPromoteShaped exp log = arithPromoteShaped log -- cgit v1.2.3-70-g09d2