diff options
Diffstat (limited to 'src/Data/Array/Nested/Internal/Arith.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal/Arith.hs | 240 |
1 files changed, 240 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs new file mode 100644 index 0000000..4312cd5 --- /dev/null +++ b/src/Data/Array/Nested/Internal/Arith.hs @@ -0,0 +1,240 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} +module Data.Array.Nested.Internal.Arith where + +import Control.Monad (forM, guard) +import qualified Data.Array.Internal as OI +import qualified Data.Array.Internal.RankedG as RG +import qualified Data.Array.Internal.RankedS as RS +import Data.Bits +import Data.Int +import Data.List (sort) +import qualified Data.Vector.Storable as VS +import qualified Data.Vector.Storable.Mutable as VSM +import Foreign.C.Types +import Foreign.Ptr +import Foreign.Storable (Storable) +import GHC.TypeLits +import Language.Haskell.TH +import System.IO.Unsafe + +import Data.Array.Nested.Internal.Arith.Foreign +import Data.Array.Nested.Internal.Arith.Lists + + +mliftVEltwise1 :: Storable a + => SNat n + -> (VS.Vector a -> VS.Vector a) + -> RS.Array n a -> RS.Array n a +mliftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) + | Just prefixSz <- stridesDense sh strides = + let vec' = f (VS.slice offset prefixSz vec) + in RS.A (RG.A sh (OI.T strides 0 vec')) + | otherwise = RS.fromVector sh (f (RS.toVector arr)) + +mliftVEltwise2 :: Storable a + => SNat n + -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a) + -> RS.Array n a -> RS.Array n a -> RS.Array n a +mliftVEltwise2 SNat f + arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1))) + arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) + | sh1 /= sh2 = error $ "mliftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 + | product sh1 == 0 = arr1 -- if the arrays are empty, just return one of the empty inputs + | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of + (Just 1, Just 1) -> -- both are a (potentially replicated) scalar; just apply f to the scalars + let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2)) + in RS.A (RG.A sh1 (OI.T strides1 0 vec')) + (Just 1, Just n) -> -- scalar * dense + RS.fromVector sh1 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2))) + (Just n, Just 1) -> -- dense * scalar + RS.fromVector sh1 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2))) + (_, _) -> -- fallback case + RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2))) + +-- | Given the shape vector and the stride vector, return whether this vector +-- of strides uses a dense prefix of its backing array. If so, the number of +-- elements in this prefix is returned. +-- This excludes any offset. +stridesDense :: [Int] -> [Int] -> Maybe Int +stridesDense sh _ | any (<= 0) sh = Just 0 +stridesDense sh str = + -- sort dimensions on their stride, ascending + case sort (zip str sh) of + [] -> Just 0 + (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str' + _ -> error "Orthotope array's shape vector and stride vector have different lengths" + where + -- Given size of currently densely covered region at beginning of the + -- array, the remaining shape vector and the corresponding remaining stride + -- vector, return whether this all together covers a dense prefix of the + -- array. If it does, return the number of elements in this prefix. + checkCover :: Int -> [Int] -> [Int] -> Maybe Int + checkCover block [] [] = Just block + checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str' + checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths" + +vectorOp1 :: forall a b. Storable a + => (Ptr a -> Ptr b) + -> (Int64 -> Ptr b -> Ptr b -> IO ()) + -> VS.Vector a -> VS.Vector a +vectorOp1 ptrconv f v = unsafePerformIO $ do + outv <- VSM.unsafeNew (VS.length v) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith v $ \pv -> + f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv) + VS.unsafeFreeze outv + +-- | If two vectors are given, assumes that they have the same length. +vectorOp2 :: forall a b. Storable a + => (a -> b) + -> (Ptr a -> Ptr b) + -> (a -> a -> a) + -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- sv + -> (Int64 -> Ptr b -> Ptr b -> b -> IO ()) -- vs + -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ()) -- vv + -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a +vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases + (Left x) (Left y) -> VS.singleton (fss x y) + + (Left x) (Right vy) -> + unsafePerformIO $ do + outv <- VSM.unsafeNew (VS.length vy) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith vy $ \pvy -> + fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy) + VS.unsafeFreeze outv + + (Right vx) (Left y) -> + unsafePerformIO $ do + outv <- VSM.unsafeNew (VS.length vx) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith vx $ \pvx -> + fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y) + VS.unsafeFreeze outv + + (Right vx) (Right vy) + | VS.length vx == VS.length vy -> + unsafePerformIO $ do + outv <- VSM.unsafeNew (VS.length vx) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith vx $ \pvx -> + VS.unsafeWith vy $ \pvy -> + fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy) + VS.unsafeFreeze outv + | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) + +flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) + -> Int64 -> Ptr a -> Ptr a -> a -> IO () +flipOp f n out v s = f n out s v + +class NumElt a where + numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a + numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a + numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM binopsList $ \arithop -> do + let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype)) + cnamebase = "c_" ++ aboName arithop ++ "_" ++ atCName arithtype + c_ss = varE (aboScalFun arithop arithtype) + c_sv = varE $ mkName (cnamebase ++ "_sv") + c_vs | aboComm arithop == NonComm = varE $ mkName (cnamebase ++ "_vs") + | otherwise = [| flipOp $c_sv |] + c_vv = varE $ mkName (cnamebase ++ "_vv") + sequence [SigD name <$> + [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] + ,do body <- [| \sn -> mliftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] + return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM unopsList $ \arithop -> do + let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) + c_op = varE $ mkName ("c_" ++ auoName arithop ++ "_" ++ atCName arithtype) + sequence [SigD name <$> + [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] + ,do body <- [| \sn -> mliftVEltwise1 sn (vectorOp1 id $c_op) |] + return $ FunD name [Clause [] (NormalB body) []]]) + +-- This branch is ostensibly a runtime branch, but will (hopefully) be +-- constant-folded away by GHC. +intWidBranch1 :: forall i n. (FiniteBits i, Storable i) + => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) + -> (SNat n -> RS.Array n i -> RS.Array n i) +intWidBranch1 f32 f64 sn + | finiteBitSize (undefined :: i) == 32 = mliftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32) + | finiteBitSize (undefined :: i) == 64 = mliftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64) + | otherwise = error "Unsupported Int width" + +intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i) + => (i -> i -> i) -- ss + -- int32 + -> (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- sv + -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Int32 -> IO ()) -- vs + -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- vv + -- int64 + -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- sv + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ()) -- vs + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv + -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i) +intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn + | finiteBitSize (undefined :: i) == 32 = mliftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32) + | finiteBitSize (undefined :: i) == 64 = mliftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64) + | otherwise = error "Unsupported Int width" + +instance NumElt Int32 where + numEltAdd = addVectorInt32 + numEltSub = subVectorInt32 + numEltMul = mulVectorInt32 + numEltNeg = negVectorInt32 + numEltAbs = absVectorInt32 + numEltSignum = signumVectorInt32 + +instance NumElt Int64 where + numEltAdd = addVectorInt64 + numEltSub = subVectorInt64 + numEltMul = mulVectorInt64 + numEltNeg = negVectorInt64 + numEltAbs = absVectorInt64 + numEltSignum = signumVectorInt64 + +instance NumElt Float where + numEltAdd = addVectorFloat + numEltSub = subVectorFloat + numEltMul = mulVectorFloat + numEltNeg = negVectorFloat + numEltAbs = absVectorFloat + numEltSignum = signumVectorFloat + +instance NumElt Double where + numEltAdd = addVectorDouble + numEltSub = subVectorDouble + numEltMul = mulVectorDouble + numEltNeg = negVectorDouble + numEltAbs = absVectorDouble + numEltSignum = signumVectorDouble + +instance NumElt Int where + numEltAdd = intWidBranch2 @Int (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv + numEltSub = intWidBranch2 @Int (-) c_sub_i32_sv (flipOp c_sub_i32_sv) c_sub_i32_vv c_sub_i64_sv (flipOp c_sub_i64_sv) c_sub_i64_vv + numEltMul = intWidBranch2 @Int (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv + numEltNeg = intWidBranch1 @Int c_neg_i32 c_neg_i64 + numEltAbs = intWidBranch1 @Int c_abs_i32 c_abs_i64 + numEltSignum = intWidBranch1 @Int c_signum_i32 c_signum_i64 + +instance NumElt CInt where + numEltAdd = intWidBranch2 @CInt (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv + numEltSub = intWidBranch2 @CInt (-) c_sub_i32_sv (flipOp c_sub_i32_sv) c_sub_i32_vv c_sub_i64_sv (flipOp c_sub_i64_sv) c_sub_i64_vv + numEltMul = intWidBranch2 @CInt (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv + numEltNeg = intWidBranch1 @CInt c_neg_i32 c_neg_i64 + numEltAbs = intWidBranch1 @CInt c_abs_i32 c_abs_i64 + numEltSignum = intWidBranch1 @CInt c_signum_i32 c_signum_i64 |