From a65306ba5d80891b20ac86fa3a3242f9497751e6 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 30 May 2024 11:58:40 +0200 Subject: Refactor Mixed (modules, regular function names) --- src/Data/Array/Mixed/Internal/Arith.hs | 435 +++++++++++++++++++++++++++++++++ 1 file changed, 435 insertions(+) create mode 100644 src/Data/Array/Mixed/Internal/Arith.hs (limited to 'src/Data/Array/Mixed/Internal/Arith.hs') diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs new file mode 100644 index 0000000..cf6820b --- /dev/null +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -0,0 +1,435 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Mixed.Internal.Arith where + +import Control.Monad (forM, guard) +import qualified Data.Array.Internal as OI +import qualified Data.Array.Internal.RankedG as RG +import qualified Data.Array.Internal.RankedS as RS +import Data.Bits +import Data.Int +import Data.List (sort) +import qualified Data.Vector.Storable as VS +import qualified Data.Vector.Storable.Mutable as VSM +import Foreign.C.Types +import Foreign.Ptr +import Foreign.Storable (Storable) +import GHC.TypeLits +import Language.Haskell.TH +import System.IO.Unsafe + +import Data.Array.Mixed.Internal.Arith.Foreign +import Data.Array.Mixed.Internal.Arith.Lists + + +liftVEltwise1 :: Storable a + => SNat n + -> (VS.Vector a -> VS.Vector a) + -> RS.Array n a -> RS.Array n a +liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) + | Just prefixSz <- stridesDense sh strides = + let vec' = f (VS.slice offset prefixSz vec) + in RS.A (RG.A sh (OI.T strides 0 vec')) + | otherwise = RS.fromVector sh (f (RS.toVector arr)) + +liftVEltwise2 :: Storable a + => SNat n + -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a) + -> RS.Array n a -> RS.Array n a -> RS.Array n a +liftVEltwise2 SNat f + arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1))) + arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) + | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 + | product sh1 == 0 = arr1 -- if the arrays are empty, just return one of the empty inputs + | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of + (Just 1, Just 1) -> -- both are a (potentially replicated) scalar; just apply f to the scalars + let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2)) + in RS.A (RG.A sh1 (OI.T strides1 0 vec')) + (Just 1, Just n) -> -- scalar * dense + RS.fromVector sh1 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2))) + (Just n, Just 1) -> -- dense * scalar + RS.fromVector sh1 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2))) + (_, _) -> -- fallback case + RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2))) + +-- | Given the shape vector and the stride vector, return whether this vector +-- of strides uses a dense prefix of its backing array. If so, the number of +-- elements in this prefix is returned. +-- This excludes any offset. +stridesDense :: [Int] -> [Int] -> Maybe Int +stridesDense sh _ | any (<= 0) sh = Just 0 +stridesDense sh str = + -- sort dimensions on their stride, ascending, dropping any zero strides + case dropWhile ((== 0) . fst) (sort (zip str sh)) of + [] -> Just 1 + (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str' + _ -> Nothing -- if the smallest stride is not 1, it will never be dense + where + -- Given size of currently densely covered region at beginning of the + -- array, the remaining shape vector and the corresponding remaining stride + -- vector, return whether this all together covers a dense prefix of the + -- array. If it does, return the number of elements in this prefix. + checkCover :: Int -> [Int] -> [Int] -> Maybe Int + checkCover block [] [] = Just block + checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str' + checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths" + +{-# NOINLINE vectorOp1 #-} +vectorOp1 :: forall a b. Storable a + => (Ptr a -> Ptr b) + -> (Int64 -> Ptr b -> Ptr b -> IO ()) + -> VS.Vector a -> VS.Vector a +vectorOp1 ptrconv f v = unsafePerformIO $ do + outv <- VSM.unsafeNew (VS.length v) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith v $ \pv -> + f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv) + VS.unsafeFreeze outv + +-- | If two vectors are given, assumes that they have the same length. +{-# NOINLINE vectorOp2 #-} +vectorOp2 :: forall a b. Storable a + => (a -> b) + -> (Ptr a -> Ptr b) + -> (a -> a -> a) + -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- sv + -> (Int64 -> Ptr b -> Ptr b -> b -> IO ()) -- vs + -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ()) -- vv + -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a +vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases + (Left x) (Left y) -> VS.singleton (fss x y) + + (Left x) (Right vy) -> + unsafePerformIO $ do + outv <- VSM.unsafeNew (VS.length vy) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith vy $ \pvy -> + fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy) + VS.unsafeFreeze outv + + (Right vx) (Left y) -> + unsafePerformIO $ do + outv <- VSM.unsafeNew (VS.length vx) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith vx $ \pvx -> + fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y) + VS.unsafeFreeze outv + + (Right vx) (Right vy) + | VS.length vx == VS.length vy -> + unsafePerformIO $ do + outv <- VSM.unsafeNew (VS.length vx) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith vx $ \pvx -> + VS.unsafeWith vy $ \pvy -> + fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy) + VS.unsafeFreeze outv + | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) + +-- TODO: test all the weird cases of this function +-- | Reduce along the inner dimension +{-# NOINLINE vectorRedInnerOp #-} +vectorRedInnerOp :: forall a b n. (Num a, Storable a) + => SNat n + -> (a -> b) + -> (Ptr a -> Ptr b) + -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel + -> RS.Array (n + 1) a -> RS.Array n a +vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec))) + | null sh = error "unreachable" + | last sh <= 0 = RS.stretch (init sh) (RS.fromList (map (const 1) (init sh)) [0]) + | any (<= 0) (init sh) = RS.A (RG.A (init sh) (OI.T (map (const 0) (init strides)) 0 VS.empty)) + -- now the input array is nonempty + | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec)) + | last strides == 0 = + liftVEltwise1 sn + (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px))) + (RS.A (RG.A (init sh) (OI.T (init strides) offset vec))) + -- now there is useful work along the inner dimension + | otherwise = + let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those + (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides) + ndimsF = length shF + in unsafePerformIO $ do + outv <- VSM.unsafeNew (product (init shF)) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> + VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> + VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> + fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec) + RS.fromVector (init sh) <$> VS.unsafeFreeze outv + +flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) + -> Int64 -> Ptr a -> Ptr a -> a -> IO () +flipOp f n out v s = f n out s v + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype)) + cnamebase = "c_binary_" ++ atCName arithtype + c_ss = varE (aboNumOp arithop) + c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) + c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) + c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) + sequence [SigD name <$> + [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] + ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] + return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM floatTypesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype)) + cnamebase = "c_fbinary_" ++ atCName arithtype + c_ss = varE (afboNumOp arithop) + c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) + c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) + c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) + sequence [SigD name <$> + [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] + ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] + return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) + c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop))) + sequence [SigD name <$> + [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] + ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] + return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM floatTypesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype)) + c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop))) + sequence [SigD name <$> + [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] + ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] + return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype)) + c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) + c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) + sequence [SigD name <$> + [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] + ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |] + return $ FunD name [Clause [] (NormalB body) []]]) + +-- This branch is ostensibly a runtime branch, but will (hopefully) be +-- constant-folded away by GHC. +intWidBranch1 :: forall i n. (FiniteBits i, Storable i) + => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) + -> (SNat n -> RS.Array n i -> RS.Array n i) +intWidBranch1 f32 f64 sn + | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32) + | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64) + | otherwise = error "Unsupported Int width" + +intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i) + => (i -> i -> i) -- ss + -- int32 + -> (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- sv + -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Int32 -> IO ()) -- vs + -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- vv + -- int64 + -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- sv + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ()) -- vs + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv + -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i) +intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn + | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32) + | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64) + | otherwise = error "Unsupported Int width" + +intWidBranchRed :: forall i n. (FiniteBits i, Storable i, Integral i) + => -- int32 + (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel + -- int64 + -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel + -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i) +intWidBranchRed fsc32 fred32 fsc64 fred64 sn + | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32 + | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64 + | otherwise = error "Unsupported Int width" + +class NumElt a where + numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a + numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a + numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a + numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a + numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a + +instance NumElt Int32 where + numEltAdd = addVectorInt32 + numEltSub = subVectorInt32 + numEltMul = mulVectorInt32 + numEltNeg = negVectorInt32 + numEltAbs = absVectorInt32 + numEltSignum = signumVectorInt32 + numEltSum1Inner = sum1VectorInt32 + numEltProduct1Inner = product1VectorInt32 + +instance NumElt Int64 where + numEltAdd = addVectorInt64 + numEltSub = subVectorInt64 + numEltMul = mulVectorInt64 + numEltNeg = negVectorInt64 + numEltAbs = absVectorInt64 + numEltSignum = signumVectorInt64 + numEltSum1Inner = sum1VectorInt64 + numEltProduct1Inner = product1VectorInt64 + +instance NumElt Float where + numEltAdd = addVectorFloat + numEltSub = subVectorFloat + numEltMul = mulVectorFloat + numEltNeg = negVectorFloat + numEltAbs = absVectorFloat + numEltSignum = signumVectorFloat + numEltSum1Inner = sum1VectorFloat + numEltProduct1Inner = product1VectorFloat + +instance NumElt Double where + numEltAdd = addVectorDouble + numEltSub = subVectorDouble + numEltMul = mulVectorDouble + numEltNeg = negVectorDouble + numEltAbs = absVectorDouble + numEltSignum = signumVectorDouble + numEltSum1Inner = sum1VectorDouble + numEltProduct1Inner = product1VectorDouble + +instance NumElt Int where + numEltAdd = intWidBranch2 @Int (+) + (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) + (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) + numEltSub = intWidBranch2 @Int (-) + (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) + (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) + numEltMul = intWidBranch2 @Int (*) + (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) + (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) + numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) + numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) + numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) + numEltSum1Inner = intWidBranchRed @Int + (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) + (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) + numEltProduct1Inner = intWidBranchRed @Int + (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) + (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) + +instance NumElt CInt where + numEltAdd = intWidBranch2 @CInt (+) + (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) + (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) + numEltSub = intWidBranch2 @CInt (-) + (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) + (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) + numEltMul = intWidBranch2 @CInt (*) + (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) + (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) + numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) + numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) + numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) + numEltSum1Inner = intWidBranchRed @CInt + (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) + (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) + numEltProduct1Inner = intWidBranchRed @CInt + (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) + (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) + +class FloatElt a where + floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + floatEltLogbase :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a + floatEltExp :: SNat n -> RS.Array n a -> RS.Array n a + floatEltLog :: SNat n -> RS.Array n a -> RS.Array n a + floatEltSqrt :: SNat n -> RS.Array n a -> RS.Array n a + floatEltSin :: SNat n -> RS.Array n a -> RS.Array n a + floatEltCos :: SNat n -> RS.Array n a -> RS.Array n a + floatEltTan :: SNat n -> RS.Array n a -> RS.Array n a + floatEltAsin :: SNat n -> RS.Array n a -> RS.Array n a + floatEltAcos :: SNat n -> RS.Array n a -> RS.Array n a + floatEltAtan :: SNat n -> RS.Array n a -> RS.Array n a + floatEltSinh :: SNat n -> RS.Array n a -> RS.Array n a + floatEltCosh :: SNat n -> RS.Array n a -> RS.Array n a + floatEltTanh :: SNat n -> RS.Array n a -> RS.Array n a + floatEltAsinh :: SNat n -> RS.Array n a -> RS.Array n a + floatEltAcosh :: SNat n -> RS.Array n a -> RS.Array n a + floatEltAtanh :: SNat n -> RS.Array n a -> RS.Array n a + floatEltLog1p :: SNat n -> RS.Array n a -> RS.Array n a + floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a + floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a + floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a + +instance FloatElt Float where + floatEltDiv = divVectorFloat + floatEltPow = powVectorFloat + floatEltLogbase = logbaseVectorFloat + floatEltRecip = recipVectorFloat + floatEltExp = expVectorFloat + floatEltLog = logVectorFloat + floatEltSqrt = sqrtVectorFloat + floatEltSin = sinVectorFloat + floatEltCos = cosVectorFloat + floatEltTan = tanVectorFloat + floatEltAsin = asinVectorFloat + floatEltAcos = acosVectorFloat + floatEltAtan = atanVectorFloat + floatEltSinh = sinhVectorFloat + floatEltCosh = coshVectorFloat + floatEltTanh = tanhVectorFloat + floatEltAsinh = asinhVectorFloat + floatEltAcosh = acoshVectorFloat + floatEltAtanh = atanhVectorFloat + floatEltLog1p = log1pVectorFloat + floatEltExpm1 = expm1VectorFloat + floatEltLog1pexp = log1pexpVectorFloat + floatEltLog1mexp = log1mexpVectorFloat + +instance FloatElt Double where + floatEltDiv = divVectorDouble + floatEltPow = powVectorDouble + floatEltLogbase = logbaseVectorDouble + floatEltRecip = recipVectorDouble + floatEltExp = expVectorDouble + floatEltLog = logVectorDouble + floatEltSqrt = sqrtVectorDouble + floatEltSin = sinVectorDouble + floatEltCos = cosVectorDouble + floatEltTan = tanVectorDouble + floatEltAsin = asinVectorDouble + floatEltAcos = acosVectorDouble + floatEltAtan = atanVectorDouble + floatEltSinh = sinhVectorDouble + floatEltCosh = coshVectorDouble + floatEltTanh = tanhVectorDouble + floatEltAsinh = asinhVectorDouble + floatEltAcosh = acoshVectorDouble + floatEltAtanh = atanhVectorDouble + floatEltLog1p = log1pVectorDouble + floatEltExpm1 = expm1VectorDouble + floatEltLog1pexp = log1pexpVectorDouble + floatEltLog1mexp = log1mexpVectorDouble -- cgit v1.2.3-70-g09d2