diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 11:58:40 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 11:58:40 +0200 |
commit | a65306ba5d80891b20ac86fa3a3242f9497751e6 (patch) | |
tree | 834af370556a46bbeca807a92c31bef098b47a89 /src/Data/Array/Nested/Internal/Arith.hs | |
parent | d8e2fcf4ea979fe272db48fc2889f4c2636c50d7 (diff) |
Refactor Mixed (modules, regular function names)
Diffstat (limited to 'src/Data/Array/Nested/Internal/Arith.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal/Arith.hs | 435 |
1 files changed, 0 insertions, 435 deletions
diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs deleted file mode 100644 index 95fcfcf..0000000 --- a/src/Data/Array/Nested/Internal/Arith.hs +++ /dev/null @@ -1,435 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Internal.Arith where - -import Control.Monad (forM, guard) -import qualified Data.Array.Internal as OI -import qualified Data.Array.Internal.RankedG as RG -import qualified Data.Array.Internal.RankedS as RS -import Data.Bits -import Data.Int -import Data.List (sort) -import qualified Data.Vector.Storable as VS -import qualified Data.Vector.Storable.Mutable as VSM -import Foreign.C.Types -import Foreign.Ptr -import Foreign.Storable (Storable) -import GHC.TypeLits -import Language.Haskell.TH -import System.IO.Unsafe - -import Data.Array.Nested.Internal.Arith.Foreign -import Data.Array.Nested.Internal.Arith.Lists - - -liftVEltwise1 :: Storable a - => SNat n - -> (VS.Vector a -> VS.Vector a) - -> RS.Array n a -> RS.Array n a -liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) - | Just prefixSz <- stridesDense sh strides = - let vec' = f (VS.slice offset prefixSz vec) - in RS.A (RG.A sh (OI.T strides 0 vec')) - | otherwise = RS.fromVector sh (f (RS.toVector arr)) - -liftVEltwise2 :: Storable a - => SNat n - -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a) - -> RS.Array n a -> RS.Array n a -> RS.Array n a -liftVEltwise2 SNat f - arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1))) - arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) - | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 - | product sh1 == 0 = arr1 -- if the arrays are empty, just return one of the empty inputs - | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of - (Just 1, Just 1) -> -- both are a (potentially replicated) scalar; just apply f to the scalars - let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2)) - in RS.A (RG.A sh1 (OI.T strides1 0 vec')) - (Just 1, Just n) -> -- scalar * dense - RS.fromVector sh1 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2))) - (Just n, Just 1) -> -- dense * scalar - RS.fromVector sh1 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2))) - (_, _) -> -- fallback case - RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2))) - --- | Given the shape vector and the stride vector, return whether this vector --- of strides uses a dense prefix of its backing array. If so, the number of --- elements in this prefix is returned. --- This excludes any offset. -stridesDense :: [Int] -> [Int] -> Maybe Int -stridesDense sh _ | any (<= 0) sh = Just 0 -stridesDense sh str = - -- sort dimensions on their stride, ascending, dropping any zero strides - case dropWhile ((== 0) . fst) (sort (zip str sh)) of - [] -> Just 1 - (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str' - _ -> Nothing -- if the smallest stride is not 1, it will never be dense - where - -- Given size of currently densely covered region at beginning of the - -- array, the remaining shape vector and the corresponding remaining stride - -- vector, return whether this all together covers a dense prefix of the - -- array. If it does, return the number of elements in this prefix. - checkCover :: Int -> [Int] -> [Int] -> Maybe Int - checkCover block [] [] = Just block - checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str' - checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths" - -{-# NOINLINE vectorOp1 #-} -vectorOp1 :: forall a b. Storable a - => (Ptr a -> Ptr b) - -> (Int64 -> Ptr b -> Ptr b -> IO ()) - -> VS.Vector a -> VS.Vector a -vectorOp1 ptrconv f v = unsafePerformIO $ do - outv <- VSM.unsafeNew (VS.length v) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith v $ \pv -> - f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv) - VS.unsafeFreeze outv - --- | If two vectors are given, assumes that they have the same length. -{-# NOINLINE vectorOp2 #-} -vectorOp2 :: forall a b. Storable a - => (a -> b) - -> (Ptr a -> Ptr b) - -> (a -> a -> a) - -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- sv - -> (Int64 -> Ptr b -> Ptr b -> b -> IO ()) -- vs - -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ()) -- vv - -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a -vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases - (Left x) (Left y) -> VS.singleton (fss x y) - - (Left x) (Right vy) -> - unsafePerformIO $ do - outv <- VSM.unsafeNew (VS.length vy) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith vy $ \pvy -> - fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy) - VS.unsafeFreeze outv - - (Right vx) (Left y) -> - unsafePerformIO $ do - outv <- VSM.unsafeNew (VS.length vx) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith vx $ \pvx -> - fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y) - VS.unsafeFreeze outv - - (Right vx) (Right vy) - | VS.length vx == VS.length vy -> - unsafePerformIO $ do - outv <- VSM.unsafeNew (VS.length vx) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith vx $ \pvx -> - VS.unsafeWith vy $ \pvy -> - fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy) - VS.unsafeFreeze outv - | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) - --- TODO: test all the weird cases of this function --- | Reduce along the inner dimension -{-# NOINLINE vectorRedInnerOp #-} -vectorRedInnerOp :: forall a b n. (Num a, Storable a) - => SNat n - -> (a -> b) - -> (Ptr a -> Ptr b) - -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel - -> RS.Array (n + 1) a -> RS.Array n a -vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec))) - | null sh = error "unreachable" - | last sh <= 0 = RS.stretch (init sh) (RS.fromList (map (const 1) (init sh)) [0]) - | any (<= 0) (init sh) = RS.A (RG.A (init sh) (OI.T (map (const 0) (init strides)) 0 VS.empty)) - -- now the input array is nonempty - | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec)) - | last strides == 0 = - liftVEltwise1 sn - (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px))) - (RS.A (RG.A (init sh) (OI.T (init strides) offset vec))) - -- now there is useful work along the inner dimension - | otherwise = - let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those - (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides) - ndimsF = length shF - in unsafePerformIO $ do - outv <- VSM.unsafeNew (product (init shF)) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> - VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> - VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> - fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec) - RS.fromVector (init sh) <$> VS.unsafeFreeze outv - -flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) - -> Int64 -> Ptr a -> Ptr a -> a -> IO () -flipOp f n out v s = f n out s v - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype)) - cnamebase = "c_binary_" ++ atCName arithtype - c_ss = varE (aboNumOp arithop) - c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) - c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) - c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] - return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype)) - cnamebase = "c_fbinary_" ++ atCName arithtype - c_ss = varE (afboNumOp arithop) - c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) - c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) - c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] - return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) - c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] - return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype)) - c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] - return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype)) - c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) - c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |] - return $ FunD name [Clause [] (NormalB body) []]]) - --- This branch is ostensibly a runtime branch, but will (hopefully) be --- constant-folded away by GHC. -intWidBranch1 :: forall i n. (FiniteBits i, Storable i) - => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) - -> (SNat n -> RS.Array n i -> RS.Array n i) -intWidBranch1 f32 f64 sn - | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32) - | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64) - | otherwise = error "Unsupported Int width" - -intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i) - => (i -> i -> i) -- ss - -- int32 - -> (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- sv - -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Int32 -> IO ()) -- vs - -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- vv - -- int64 - -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- sv - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ()) -- vs - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv - -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i) -intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn - | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32) - | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64) - | otherwise = error "Unsupported Int width" - -intWidBranchRed :: forall i n. (FiniteBits i, Storable i, Integral i) - => -- int32 - (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel - -- int64 - -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel - -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i) -intWidBranchRed fsc32 fred32 fsc64 fred64 sn - | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32 - | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64 - | otherwise = error "Unsupported Int width" - -class NumElt a where - numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a - numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a - numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a - numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a - numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a - -instance NumElt Int32 where - numEltAdd = addVectorInt32 - numEltSub = subVectorInt32 - numEltMul = mulVectorInt32 - numEltNeg = negVectorInt32 - numEltAbs = absVectorInt32 - numEltSignum = signumVectorInt32 - numEltSum1Inner = sum1VectorInt32 - numEltProduct1Inner = product1VectorInt32 - -instance NumElt Int64 where - numEltAdd = addVectorInt64 - numEltSub = subVectorInt64 - numEltMul = mulVectorInt64 - numEltNeg = negVectorInt64 - numEltAbs = absVectorInt64 - numEltSignum = signumVectorInt64 - numEltSum1Inner = sum1VectorInt64 - numEltProduct1Inner = product1VectorInt64 - -instance NumElt Float where - numEltAdd = addVectorFloat - numEltSub = subVectorFloat - numEltMul = mulVectorFloat - numEltNeg = negVectorFloat - numEltAbs = absVectorFloat - numEltSignum = signumVectorFloat - numEltSum1Inner = sum1VectorFloat - numEltProduct1Inner = product1VectorFloat - -instance NumElt Double where - numEltAdd = addVectorDouble - numEltSub = subVectorDouble - numEltMul = mulVectorDouble - numEltNeg = negVectorDouble - numEltAbs = absVectorDouble - numEltSignum = signumVectorDouble - numEltSum1Inner = sum1VectorDouble - numEltProduct1Inner = product1VectorDouble - -instance NumElt Int where - numEltAdd = intWidBranch2 @Int (+) - (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) - (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) - numEltSub = intWidBranch2 @Int (-) - (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) - (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) - numEltMul = intWidBranch2 @Int (*) - (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) - (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) - numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) - numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) - numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) - numEltSum1Inner = intWidBranchRed @Int - (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) - (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) - numEltProduct1Inner = intWidBranchRed @Int - (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) - (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) - -instance NumElt CInt where - numEltAdd = intWidBranch2 @CInt (+) - (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) - (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) - numEltSub = intWidBranch2 @CInt (-) - (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) - (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) - numEltMul = intWidBranch2 @CInt (*) - (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) - (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) - numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) - numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) - numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) - numEltSum1Inner = intWidBranchRed @CInt - (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) - (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) - numEltProduct1Inner = intWidBranchRed @CInt - (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) - (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) - -class FloatElt a where - floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - floatEltLogbase :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a - floatEltExp :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog :: SNat n -> RS.Array n a -> RS.Array n a - floatEltSqrt :: SNat n -> RS.Array n a -> RS.Array n a - floatEltSin :: SNat n -> RS.Array n a -> RS.Array n a - floatEltCos :: SNat n -> RS.Array n a -> RS.Array n a - floatEltTan :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAsin :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAcos :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAtan :: SNat n -> RS.Array n a -> RS.Array n a - floatEltSinh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltCosh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltTanh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAsinh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAcosh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAtanh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog1p :: SNat n -> RS.Array n a -> RS.Array n a - floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a - -instance FloatElt Float where - floatEltDiv = divVectorFloat - floatEltPow = powVectorFloat - floatEltLogbase = logbaseVectorFloat - floatEltRecip = recipVectorFloat - floatEltExp = expVectorFloat - floatEltLog = logVectorFloat - floatEltSqrt = sqrtVectorFloat - floatEltSin = sinVectorFloat - floatEltCos = cosVectorFloat - floatEltTan = tanVectorFloat - floatEltAsin = asinVectorFloat - floatEltAcos = acosVectorFloat - floatEltAtan = atanVectorFloat - floatEltSinh = sinhVectorFloat - floatEltCosh = coshVectorFloat - floatEltTanh = tanhVectorFloat - floatEltAsinh = asinhVectorFloat - floatEltAcosh = acoshVectorFloat - floatEltAtanh = atanhVectorFloat - floatEltLog1p = log1pVectorFloat - floatEltExpm1 = expm1VectorFloat - floatEltLog1pexp = log1pexpVectorFloat - floatEltLog1mexp = log1mexpVectorFloat - -instance FloatElt Double where - floatEltDiv = divVectorDouble - floatEltPow = powVectorDouble - floatEltLogbase = logbaseVectorDouble - floatEltRecip = recipVectorDouble - floatEltExp = expVectorDouble - floatEltLog = logVectorDouble - floatEltSqrt = sqrtVectorDouble - floatEltSin = sinVectorDouble - floatEltCos = cosVectorDouble - floatEltTan = tanVectorDouble - floatEltAsin = asinVectorDouble - floatEltAcos = acosVectorDouble - floatEltAtan = atanVectorDouble - floatEltSinh = sinhVectorDouble - floatEltCosh = coshVectorDouble - floatEltTanh = tanhVectorDouble - floatEltAsinh = asinhVectorDouble - floatEltAcosh = acoshVectorDouble - floatEltAtanh = atanhVectorDouble - floatEltLog1p = log1pVectorDouble - floatEltExpm1 = expm1VectorDouble - floatEltLog1pexp = log1pexpVectorDouble - floatEltLog1mexp = log1mexpVectorDouble |