diff options
Diffstat (limited to 'src/Data/Array/Nested/Internal')
-rw-r--r-- | src/Data/Array/Nested/Internal/Arith.hs | 435 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Foreign.hs | 55 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Lists.hs | 78 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Lists/TH.hs | 82 |
4 files changed, 0 insertions, 650 deletions
diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs deleted file mode 100644 index 95fcfcf..0000000 --- a/src/Data/Array/Nested/Internal/Arith.hs +++ /dev/null @@ -1,435 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Internal.Arith where - -import Control.Monad (forM, guard) -import qualified Data.Array.Internal as OI -import qualified Data.Array.Internal.RankedG as RG -import qualified Data.Array.Internal.RankedS as RS -import Data.Bits -import Data.Int -import Data.List (sort) -import qualified Data.Vector.Storable as VS -import qualified Data.Vector.Storable.Mutable as VSM -import Foreign.C.Types -import Foreign.Ptr -import Foreign.Storable (Storable) -import GHC.TypeLits -import Language.Haskell.TH -import System.IO.Unsafe - -import Data.Array.Nested.Internal.Arith.Foreign -import Data.Array.Nested.Internal.Arith.Lists - - -liftVEltwise1 :: Storable a - => SNat n - -> (VS.Vector a -> VS.Vector a) - -> RS.Array n a -> RS.Array n a -liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) - | Just prefixSz <- stridesDense sh strides = - let vec' = f (VS.slice offset prefixSz vec) - in RS.A (RG.A sh (OI.T strides 0 vec')) - | otherwise = RS.fromVector sh (f (RS.toVector arr)) - -liftVEltwise2 :: Storable a - => SNat n - -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a) - -> RS.Array n a -> RS.Array n a -> RS.Array n a -liftVEltwise2 SNat f - arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1))) - arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) - | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 - | product sh1 == 0 = arr1 -- if the arrays are empty, just return one of the empty inputs - | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of - (Just 1, Just 1) -> -- both are a (potentially replicated) scalar; just apply f to the scalars - let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2)) - in RS.A (RG.A sh1 (OI.T strides1 0 vec')) - (Just 1, Just n) -> -- scalar * dense - RS.fromVector sh1 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2))) - (Just n, Just 1) -> -- dense * scalar - RS.fromVector sh1 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2))) - (_, _) -> -- fallback case - RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2))) - --- | Given the shape vector and the stride vector, return whether this vector --- of strides uses a dense prefix of its backing array. If so, the number of --- elements in this prefix is returned. --- This excludes any offset. -stridesDense :: [Int] -> [Int] -> Maybe Int -stridesDense sh _ | any (<= 0) sh = Just 0 -stridesDense sh str = - -- sort dimensions on their stride, ascending, dropping any zero strides - case dropWhile ((== 0) . fst) (sort (zip str sh)) of - [] -> Just 1 - (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str' - _ -> Nothing -- if the smallest stride is not 1, it will never be dense - where - -- Given size of currently densely covered region at beginning of the - -- array, the remaining shape vector and the corresponding remaining stride - -- vector, return whether this all together covers a dense prefix of the - -- array. If it does, return the number of elements in this prefix. - checkCover :: Int -> [Int] -> [Int] -> Maybe Int - checkCover block [] [] = Just block - checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str' - checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths" - -{-# NOINLINE vectorOp1 #-} -vectorOp1 :: forall a b. Storable a - => (Ptr a -> Ptr b) - -> (Int64 -> Ptr b -> Ptr b -> IO ()) - -> VS.Vector a -> VS.Vector a -vectorOp1 ptrconv f v = unsafePerformIO $ do - outv <- VSM.unsafeNew (VS.length v) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith v $ \pv -> - f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv) - VS.unsafeFreeze outv - --- | If two vectors are given, assumes that they have the same length. -{-# NOINLINE vectorOp2 #-} -vectorOp2 :: forall a b. Storable a - => (a -> b) - -> (Ptr a -> Ptr b) - -> (a -> a -> a) - -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- sv - -> (Int64 -> Ptr b -> Ptr b -> b -> IO ()) -- vs - -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ()) -- vv - -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a -vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases - (Left x) (Left y) -> VS.singleton (fss x y) - - (Left x) (Right vy) -> - unsafePerformIO $ do - outv <- VSM.unsafeNew (VS.length vy) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith vy $ \pvy -> - fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy) - VS.unsafeFreeze outv - - (Right vx) (Left y) -> - unsafePerformIO $ do - outv <- VSM.unsafeNew (VS.length vx) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith vx $ \pvx -> - fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y) - VS.unsafeFreeze outv - - (Right vx) (Right vy) - | VS.length vx == VS.length vy -> - unsafePerformIO $ do - outv <- VSM.unsafeNew (VS.length vx) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith vx $ \pvx -> - VS.unsafeWith vy $ \pvy -> - fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy) - VS.unsafeFreeze outv - | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) - --- TODO: test all the weird cases of this function --- | Reduce along the inner dimension -{-# NOINLINE vectorRedInnerOp #-} -vectorRedInnerOp :: forall a b n. (Num a, Storable a) - => SNat n - -> (a -> b) - -> (Ptr a -> Ptr b) - -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel - -> RS.Array (n + 1) a -> RS.Array n a -vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec))) - | null sh = error "unreachable" - | last sh <= 0 = RS.stretch (init sh) (RS.fromList (map (const 1) (init sh)) [0]) - | any (<= 0) (init sh) = RS.A (RG.A (init sh) (OI.T (map (const 0) (init strides)) 0 VS.empty)) - -- now the input array is nonempty - | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec)) - | last strides == 0 = - liftVEltwise1 sn - (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px))) - (RS.A (RG.A (init sh) (OI.T (init strides) offset vec))) - -- now there is useful work along the inner dimension - | otherwise = - let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those - (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides) - ndimsF = length shF - in unsafePerformIO $ do - outv <- VSM.unsafeNew (product (init shF)) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> - VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> - VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> - fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec) - RS.fromVector (init sh) <$> VS.unsafeFreeze outv - -flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) - -> Int64 -> Ptr a -> Ptr a -> a -> IO () -flipOp f n out v s = f n out s v - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype)) - cnamebase = "c_binary_" ++ atCName arithtype - c_ss = varE (aboNumOp arithop) - c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) - c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) - c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] - return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype)) - cnamebase = "c_fbinary_" ++ atCName arithtype - c_ss = varE (afboNumOp arithop) - c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) - c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) - c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] - return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) - c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] - return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype)) - c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] - return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype)) - c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) - c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |] - return $ FunD name [Clause [] (NormalB body) []]]) - --- This branch is ostensibly a runtime branch, but will (hopefully) be --- constant-folded away by GHC. -intWidBranch1 :: forall i n. (FiniteBits i, Storable i) - => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) - -> (SNat n -> RS.Array n i -> RS.Array n i) -intWidBranch1 f32 f64 sn - | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32) - | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64) - | otherwise = error "Unsupported Int width" - -intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i) - => (i -> i -> i) -- ss - -- int32 - -> (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- sv - -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Int32 -> IO ()) -- vs - -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- vv - -- int64 - -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- sv - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ()) -- vs - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv - -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i) -intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn - | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32) - | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64) - | otherwise = error "Unsupported Int width" - -intWidBranchRed :: forall i n. (FiniteBits i, Storable i, Integral i) - => -- int32 - (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel - -- int64 - -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel - -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i) -intWidBranchRed fsc32 fred32 fsc64 fred64 sn - | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32 - | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64 - | otherwise = error "Unsupported Int width" - -class NumElt a where - numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a - numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a - numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a - numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a - numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a - -instance NumElt Int32 where - numEltAdd = addVectorInt32 - numEltSub = subVectorInt32 - numEltMul = mulVectorInt32 - numEltNeg = negVectorInt32 - numEltAbs = absVectorInt32 - numEltSignum = signumVectorInt32 - numEltSum1Inner = sum1VectorInt32 - numEltProduct1Inner = product1VectorInt32 - -instance NumElt Int64 where - numEltAdd = addVectorInt64 - numEltSub = subVectorInt64 - numEltMul = mulVectorInt64 - numEltNeg = negVectorInt64 - numEltAbs = absVectorInt64 - numEltSignum = signumVectorInt64 - numEltSum1Inner = sum1VectorInt64 - numEltProduct1Inner = product1VectorInt64 - -instance NumElt Float where - numEltAdd = addVectorFloat - numEltSub = subVectorFloat - numEltMul = mulVectorFloat - numEltNeg = negVectorFloat - numEltAbs = absVectorFloat - numEltSignum = signumVectorFloat - numEltSum1Inner = sum1VectorFloat - numEltProduct1Inner = product1VectorFloat - -instance NumElt Double where - numEltAdd = addVectorDouble - numEltSub = subVectorDouble - numEltMul = mulVectorDouble - numEltNeg = negVectorDouble - numEltAbs = absVectorDouble - numEltSignum = signumVectorDouble - numEltSum1Inner = sum1VectorDouble - numEltProduct1Inner = product1VectorDouble - -instance NumElt Int where - numEltAdd = intWidBranch2 @Int (+) - (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) - (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) - numEltSub = intWidBranch2 @Int (-) - (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) - (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) - numEltMul = intWidBranch2 @Int (*) - (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) - (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) - numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) - numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) - numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) - numEltSum1Inner = intWidBranchRed @Int - (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) - (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) - numEltProduct1Inner = intWidBranchRed @Int - (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) - (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) - -instance NumElt CInt where - numEltAdd = intWidBranch2 @CInt (+) - (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) - (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) - numEltSub = intWidBranch2 @CInt (-) - (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) - (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) - numEltMul = intWidBranch2 @CInt (*) - (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) - (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) - numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) - numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) - numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) - numEltSum1Inner = intWidBranchRed @CInt - (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) - (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) - numEltProduct1Inner = intWidBranchRed @CInt - (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) - (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) - -class FloatElt a where - floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - floatEltLogbase :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a - floatEltExp :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog :: SNat n -> RS.Array n a -> RS.Array n a - floatEltSqrt :: SNat n -> RS.Array n a -> RS.Array n a - floatEltSin :: SNat n -> RS.Array n a -> RS.Array n a - floatEltCos :: SNat n -> RS.Array n a -> RS.Array n a - floatEltTan :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAsin :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAcos :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAtan :: SNat n -> RS.Array n a -> RS.Array n a - floatEltSinh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltCosh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltTanh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAsinh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAcosh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAtanh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog1p :: SNat n -> RS.Array n a -> RS.Array n a - floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a - -instance FloatElt Float where - floatEltDiv = divVectorFloat - floatEltPow = powVectorFloat - floatEltLogbase = logbaseVectorFloat - floatEltRecip = recipVectorFloat - floatEltExp = expVectorFloat - floatEltLog = logVectorFloat - floatEltSqrt = sqrtVectorFloat - floatEltSin = sinVectorFloat - floatEltCos = cosVectorFloat - floatEltTan = tanVectorFloat - floatEltAsin = asinVectorFloat - floatEltAcos = acosVectorFloat - floatEltAtan = atanVectorFloat - floatEltSinh = sinhVectorFloat - floatEltCosh = coshVectorFloat - floatEltTanh = tanhVectorFloat - floatEltAsinh = asinhVectorFloat - floatEltAcosh = acoshVectorFloat - floatEltAtanh = atanhVectorFloat - floatEltLog1p = log1pVectorFloat - floatEltExpm1 = expm1VectorFloat - floatEltLog1pexp = log1pexpVectorFloat - floatEltLog1mexp = log1mexpVectorFloat - -instance FloatElt Double where - floatEltDiv = divVectorDouble - floatEltPow = powVectorDouble - floatEltLogbase = logbaseVectorDouble - floatEltRecip = recipVectorDouble - floatEltExp = expVectorDouble - floatEltLog = logVectorDouble - floatEltSqrt = sqrtVectorDouble - floatEltSin = sinVectorDouble - floatEltCos = cosVectorDouble - floatEltTan = tanVectorDouble - floatEltAsin = asinVectorDouble - floatEltAcos = acosVectorDouble - floatEltAtan = atanVectorDouble - floatEltSinh = sinhVectorDouble - floatEltCosh = coshVectorDouble - floatEltTanh = tanhVectorDouble - floatEltAsinh = asinhVectorDouble - floatEltAcosh = acoshVectorDouble - floatEltAtanh = atanhVectorDouble - floatEltLog1p = log1pVectorDouble - floatEltExpm1 = expm1VectorDouble - floatEltLog1pexp = log1pexpVectorDouble - floatEltLog1mexp = log1mexpVectorDouble diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs deleted file mode 100644 index ac83188..0000000 --- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs +++ /dev/null @@ -1,55 +0,0 @@ -{-# LANGUAGE ForeignFunctionInterface #-} -{-# LANGUAGE TemplateHaskell #-} -module Data.Array.Nested.Internal.Arith.Foreign where - -import Control.Monad -import Data.Int -import Data.Maybe -import Foreign.C.Types -import Foreign.Ptr -import Language.Haskell.TH - -import Data.Array.Nested.Internal.Arith.Lists - - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - let base = "binary_" ++ atCName arithtype - sequence $ catMaybes - [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) - ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) - ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) - ]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - let base = "fbinary_" ++ atCName arithtype - sequence $ catMaybes - [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) - ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) - ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) - ]) - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - let base = "unary_" ++ atCName arithtype - pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - let base = "funary_" ++ atCName arithtype - pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - let base = "reduce_" ++ atCName arithtype - pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> - [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs deleted file mode 100644 index ce2836d..0000000 --- a/src/Data/Array/Nested/Internal/Arith/Lists.hs +++ /dev/null @@ -1,78 +0,0 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE TemplateHaskell #-} -module Data.Array.Nested.Internal.Arith.Lists where - -import Data.Char -import Data.Int -import Language.Haskell.TH - -import Data.Array.Nested.Internal.Arith.Lists.TH - - -data ArithType = ArithType - { atType :: Name -- ''Int32 - , atCName :: String -- "i32" - } - -floatTypesList :: [ArithType] -floatTypesList = - [ArithType ''Float "float" - ,ArithType ''Double "double" - ] - -typesList :: [ArithType] -typesList = - [ArithType ''Int32 "i32" - ,ArithType ''Int64 "i64" - ] - ++ floatTypesList - --- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded) -$(genArithDataType Binop "ArithBOp") - -$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3)) -$(genArithEnumFun Binop ''ArithBOp "aboEnum") - -$(do clauses <- readArithLists Binop - (\name _num hsop -> return (Clause [ConP (mkName name) [] []] - (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) - [])) - return - sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |] - ,return $ FunD (mkName "aboNumOp") clauses]) - - --- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded) -$(genArithDataType FBinop "ArithFBOp") - -$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . drop 3)) -$(genArithEnumFun FBinop ''ArithFBOp "afboEnum") - -$(do clauses <- readArithLists FBinop - (\name _num hsop -> return (Clause [ConP (mkName name) [] []] - (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) - [])) - return - sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |] - ,return $ FunD (mkName "afboNumOp") clauses]) - - --- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded) -$(genArithDataType Unop "ArithUOp") - -$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3)) -$(genArithEnumFun Unop ''ArithUOp "auoEnum") - - --- data ArithFUOp = FU_RECIP | ... deriving (Show, Enum, Bounded) -$(genArithDataType FUnop "ArithFUOp") - -$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3)) -$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum") - - --- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded) -$(genArithDataType Redop "ArithRedOp") - -$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . drop 3)) -$(genArithEnumFun Redop ''ArithRedOp "aroEnum") diff --git a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs deleted file mode 100644 index 7142dfa..0000000 --- a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs +++ /dev/null @@ -1,82 +0,0 @@ -{-# LANGUAGE TemplateHaskellQuotes #-} -module Data.Array.Nested.Internal.Arith.Lists.TH where - -import Control.Monad -import Control.Monad.IO.Class -import Data.Maybe -import Foreign.C.Types -import Language.Haskell.TH -import Language.Haskell.TH.Syntax -import Text.Read - - -data OpKind = Binop | FBinop | Unop | FUnop | Redop - deriving (Show, Eq) - -readArithLists :: OpKind - -> (String -> Int -> String -> Q a) - -> ([a] -> Q r) - -> Q r -readArithLists targetkind fop fcombine = do - addDependentFile "cbits/arith_lists.h" - lns <- liftIO $ lines <$> readFile "cbits/arith_lists.h" - - mvals <- forM lns $ \line -> do - if null (dropWhile (== ' ') line) - then return Nothing - else do let (kind, name, num, aux) = parseLine line - if kind == targetkind - then Just <$> fop name num aux - else return Nothing - - fcombine (catMaybes mvals) - where - parseLine s0 - | ("LIST_", s1) <- splitAt 5 s0 - , (kindstr, '(' : s2) <- break (== '(') s1 - , (f1, ',' : s3) <- parseField s2 - , (f2, ',' : s4) <- parseField s3 - , (f3, ')' : _) <- parseField s4 - , Just kind <- parseKind kindstr - , let name = f1 - , Just num <- readMaybe f2 - , let aux = f3 - = (kind, name, num, aux) - | otherwise - = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0 - - parseField s = break (`elem` ",)") (dropWhile (== ' ') s) - - parseKind "BINOP" = Just Binop - parseKind "FBINOP" = Just FBinop - parseKind "UNOP" = Just Unop - parseKind "FUNOP" = Just FUnop - parseKind "REDOP" = Just Redop - parseKind _ = Nothing - -genArithDataType :: OpKind -> String -> Q [Dec] -genArithDataType kind dtname = do - cons <- readArithLists kind - (\name _num _ -> return $ NormalC (mkName name) []) - return - return [DataD [] (mkName dtname) [] Nothing cons [DerivClause Nothing [ConT ''Show, ConT ''Enum, ConT ''Bounded]]] - -genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec] -genArithNameFun kind dtname funname nametrans = do - clauses <- readArithLists kind - (\name _num _ -> return (Clause [ConP (mkName name) [] []] - (NormalB (LitE (StringL (nametrans name)))) - [])) - return - return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''String) - ,FunD (mkName funname) clauses] - -genArithEnumFun :: OpKind -> Name -> String -> Q [Dec] -genArithEnumFun kind dtname funname = do - clauses <- readArithLists kind - (\name num _ -> return (Clause [ConP (mkName name) [] []] - (NormalB (LitE (IntegerL (fromIntegral num)))) - [])) - return - return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''CInt) - ,FunD (mkName funname) clauses] |