diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 928 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 47 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Lists.hs | 95 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs | 83 | ||||
-rw-r--r-- | src/Data/Array/Mixed/XArray.hs | 5 | ||||
-rw-r--r-- | src/Data/Array/Nested.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 75 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 2 |
9 files changed, 55 insertions, 1184 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 27ebb64..f7a76bc 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -1,929 +1,23 @@ -{-# LANGUAGE DataKinds #-} {-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Data.Array.Mixed.Internal.Arith where -import Control.Monad (forM, guard) import Data.Array.Internal qualified as OI import Data.Array.Internal.RankedG qualified as RG import Data.Array.Internal.RankedS qualified as RS -import Data.Bifunctor (second) -import Data.Bits -import Data.Int -import Data.List (sort) -import Data.Vector.Storable qualified as VS -import Data.Vector.Storable.Mutable qualified as VSM -import Foreign.C.Types -import Foreign.Marshal.Alloc (alloca) -import Foreign.Ptr -import Foreign.Storable (Storable(sizeOf), peek, poke) -import GHC.TypeLits -import GHC.TypeNats qualified as TypeNats -import Language.Haskell.TH -import System.IO (hFlush, stdout) -import System.IO.Unsafe -import Data.Array.Mixed.Internal.Arith.Foreign -import Data.Array.Mixed.Internal.Arith.Lists -import Data.Array.Mixed.Types (fromSNat') +import Data.Array.Strided qualified as AS --- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition +fromO :: RS.Array n a -> AS.Array n a +fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset vec +toO :: AS.Array n a -> RS.Array n a +toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec)) --- TODO: test all the cases of this thing with various input strides -liftVEltwise1 :: (Storable a, Storable b) - => SNat n - -> (VS.Vector a -> VS.Vector b) - -> RS.Array n a -> RS.Array n b -liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) - | Just (blockOff, blockSz) <- stridesDense sh offset strides = - let vec' = f (VS.slice blockOff blockSz vec) - in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) - | otherwise = RS.fromVector sh (f (RS.toVector arr)) +liftO1 :: (AS.Array n a -> AS.Array n' b) + -> RS.Array n a -> RS.Array n' b +liftO1 f = toO . f . fromO --- TODO: test all the cases of this thing with various input strides -{-# NOINLINE liftOpEltwise1 #-} -liftOpEltwise1 :: (Storable a, Storable b) - => SNat n - -> (Ptr a -> Ptr a') - -> (Ptr b -> Ptr b') - -> (Int64 -> Ptr b' -> Ptr Int64 -> Ptr Int64 -> Ptr a' -> IO ()) - -> RS.Array n a -> RS.Array n b -liftOpEltwise1 sn@SNat ptrconv1 ptrconv2 cf_strided (RS.A (RG.A sh (OI.T strides offset vec))) - -- TODO: less code duplication between these two branches - | Just (blockOff, blockSz) <- stridesDense sh offset strides = - if blockSz == 0 - then RS.A (RG.A sh (OI.T (map (const 0) strides) 0 VS.empty)) - else unsafePerformIO $ do - outv <- VSM.unsafeNew blockSz - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith (VS.singleton (fromIntegral blockSz)) $ \psh -> - VS.unsafeWith (VS.singleton 1) $ \pstrides -> - VS.unsafeWith (VS.slice blockOff blockSz vec) $ \pv -> - cf_strided 1 (ptrconv2 poutv) psh pstrides (ptrconv1 pv) - RS.A . RG.A sh . OI.T strides (offset - blockOff) <$> VS.unsafeFreeze outv - | otherwise = unsafePerformIO $ do - outv <- VSM.unsafeNew (product sh) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh -> - VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides -> - VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pv -> - cf_strided (fromIntegral (fromSNat sn)) (ptrconv2 poutv) psh pstrides (ptrconv1 pv) - RS.fromVector sh <$> VS.unsafeFreeze outv - --- TODO: test all the cases of this thing with various input strides -liftVEltwise2 :: Storable a - => SNat n - -> (a -> b) - -> (Ptr a -> Ptr b) - -> (a -> a -> a) - -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ sv - -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ()) -- ^ vs - -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ vv - -> RS.Array n a -> RS.Array n a -> RS.Array n a -liftVEltwise2 sn@SNat valconv ptrconv f_ss f_sv f_vs f_vv - arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1))) - arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) - | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 - | product sh1 == 0 = RS.A (RG.A sh1 (OI.T (0 <$ strides1) 0 VS.empty)) - | otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of - (Just (_, 1), Just (_, 1)) -> -- both are a (potentially replicated) scalar; just apply f to the scalars - let vec' = VS.singleton (f_ss (vec1 VS.! offset1) (vec2 VS.! offset2)) - in RS.A (RG.A sh1 (OI.T strides1 0 vec')) - - (Just (_, 1), Just (blockOff, blockSz)) -> -- scalar * dense - let arr2' = RS.fromVector [blockSz] (VS.slice blockOff blockSz vec2) - RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinarySV (SNat @1) valconv ptrconv f_sv (vec1 VS.! offset1) arr2' - in RS.A (RG.A sh1 (OI.T strides2 (offset2 - blockOff) resvec)) - - (Just (_, 1), Nothing) -> -- scalar * array - wrapBinarySV sn valconv ptrconv f_sv (vec1 VS.! offset1) arr2 - - (Just (blockOff, blockSz), Just (_, 1)) -> -- dense * scalar - let arr1' = RS.fromVector [blockSz] (VS.slice blockOff blockSz vec1) - RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVS (SNat @1) valconv ptrconv f_vs arr1' (vec2 VS.! offset2) - in RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff) resvec)) - - (Nothing, Just (_, 1)) -> -- array * scalar - wrapBinaryVS sn valconv ptrconv f_vs arr1 (vec2 VS.! offset2) - - (Just (blockOff1, blockSz1), Just (blockOff2, blockSz2)) - | blockSz1 == blockSz2 -- not sure if this check is necessary, might be implied by the strides check - , strides1 == strides2 - -> -- dense * dense but the strides match - let arr1' = RS.fromVector [blockSz1] (VS.slice blockOff1 blockSz1 vec1) - arr2' = RS.fromVector [blockSz1] (VS.slice blockOff2 blockSz2 vec2) - RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVV (SNat @1) ptrconv f_vv arr1' arr2' - in RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff1) resvec)) - - (_, _) -> -- fallback case - wrapBinaryVV sn ptrconv f_vv arr1 arr2 - --- | Given shape vector, offset and stride vector, check whether this virtual --- vector uses a dense subarray of its backing array. If so, the first index --- and the number of elements in this subarray is returned. --- This excludes any offset. -stridesDense :: [Int] -> Int -> [Int] -> Maybe (Int, Int) -stridesDense sh offset _ | any (<= 0) sh = Just (offset, 0) -stridesDense sh offsetNeg stridesNeg = - -- First reverse all dimensions with negative stride, so that the first used - -- value is at 'offset' and the rest is >= offset. - let (offset, strides) = flipReverseds sh offsetNeg stridesNeg - in -- sort dimensions on their stride, ascending, dropping any zero strides - case filter ((/= 0) . fst) (sort (zip strides sh)) of - [] -> Just (offset, 1) - (1, n) : pairs -> (offset,) <$> checkCover n pairs - _ -> Nothing -- if the smallest stride is not 1, it will never be dense - where - -- Given size of currently densely covered region at beginning of the - -- array and the remaining (stride, size) pairs with all strides >=1, - -- return whether this all together covers a dense prefix of the array. If - -- it does, return the number of elements in this prefix. - checkCover :: Int -> [(Int, Int)] -> Maybe Int - checkCover block [] = Just block - checkCover block ((s, n) : pairs) = guard (s <= block) >> checkCover ((n-1) * s + block) pairs - - -- Given shape, offset and strides, returns new (offset, strides) such that all strides are >=0 - flipReverseds :: [Int] -> Int -> [Int] -> (Int, [Int]) - flipReverseds [] off [] = (off, []) - flipReverseds (n : sh') off (s : str') - | s >= 0 = second (s :) (flipReverseds sh' off str') - | otherwise = - let off' = off + (n - 1) * s - in second ((-s) :) (flipReverseds sh' off' str') - flipReverseds _ _ _ = error "flipReverseds: invalid arguments" - -{-# NOINLINE wrapBinarySV #-} -wrapBinarySV :: Storable a - => SNat n - -> (a -> b) - -> (Ptr a -> Ptr b) - -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) - -> a -> RS.Array n a - -> RS.Array n a -wrapBinarySV sn@SNat valconv ptrconv cf_strided x (RS.A (RG.A sh (OI.T strides offset vec))) = - unsafePerformIO $ do - outv <- VSM.unsafeNew (product sh) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh -> - VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides -> - VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pv -> - cf_strided (fromIntegral (fromSNat' sn)) psh (ptrconv poutv) (valconv x) pstrides (ptrconv pv) - RS.fromVector sh <$> VS.unsafeFreeze outv - -wrapBinaryVS :: Storable a - => SNat n - -> (a -> b) - -> (Ptr a -> Ptr b) - -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ()) - -> RS.Array n a -> a - -> RS.Array n a -wrapBinaryVS sn valconv ptrconv cf_strided arr y = - wrapBinarySV sn valconv ptrconv - (\rank psh poutv y' pstrides pv -> cf_strided rank psh poutv pstrides pv y') y arr - --- | This function assumes that the two shapes are equal. -{-# NOINLINE wrapBinaryVV #-} -wrapBinaryVV :: Storable a - => SNat n - -> (Ptr a -> Ptr b) - -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) - -> RS.Array n a -> RS.Array n a - -> RS.Array n a -wrapBinaryVV sn@SNat ptrconv cf_strided - (RS.A (RG.A sh (OI.T strides1 offset1 vec1))) - (RS.A (RG.A _ (OI.T strides2 offset2 vec2))) = - unsafePerformIO $ do - outv <- VSM.unsafeNew (product sh) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh -> - VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides1)) $ \pstrides1 -> - VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides2)) $ \pstrides2 -> - VS.unsafeWith (VS.slice offset1 (VS.length vec1 - offset1) vec1) $ \pv1 -> - VS.unsafeWith (VS.slice offset2 (VS.length vec2 - offset2) vec2) $ \pv2 -> - cf_strided (fromIntegral (fromSNat' sn)) psh (ptrconv poutv) pstrides1 (ptrconv pv1) pstrides2 (ptrconv pv2) - RS.fromVector sh <$> VS.unsafeFreeze outv - -{-# NOINLINE vectorOp1 #-} -vectorOp1 :: forall a b. Storable a - => (Ptr a -> Ptr b) - -> (Int64 -> Ptr b -> Ptr b -> IO ()) - -> VS.Vector a -> VS.Vector a -vectorOp1 ptrconv f v = unsafePerformIO $ do - outv <- VSM.unsafeNew (VS.length v) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith v $ \pv -> - f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv) - VS.unsafeFreeze outv - --- | If two vectors are given, assumes that they have the same length. -{-# NOINLINE vectorOp2 #-} -vectorOp2 :: forall a b. Storable a - => (a -> b) - -> (Ptr a -> Ptr b) - -> (a -> a -> a) - -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- sv - -> (Int64 -> Ptr b -> Ptr b -> b -> IO ()) -- vs - -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ()) -- vv - -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a -vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases - (Left x) (Left y) -> VS.singleton (fss x y) - - (Left x) (Right vy) -> - unsafePerformIO $ do - outv <- VSM.unsafeNew (VS.length vy) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith vy $ \pvy -> - fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy) - VS.unsafeFreeze outv - - (Right vx) (Left y) -> - unsafePerformIO $ do - outv <- VSM.unsafeNew (VS.length vx) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith vx $ \pvx -> - fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y) - VS.unsafeFreeze outv - - (Right vx) (Right vy) - | VS.length vx == VS.length vy -> - unsafePerformIO $ do - outv <- VSM.unsafeNew (VS.length vx) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith vx $ \pvx -> - VS.unsafeWith vy $ \pvy -> - fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy) - VS.unsafeFreeze outv - | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) - --- TODO: test handling of negative strides --- | Reduce along the inner dimension -{-# NOINLINE vectorRedInnerOp #-} -vectorRedInnerOp :: forall a b n. (Num a, Storable a) - => SNat n - -> (a -> b) - -> (Ptr a -> Ptr b) - -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant - -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel - -> RS.Array (n + 1) a -> RS.Array n a -vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec))) - | null sh = error "unreachable" - | last sh <= 0 = RS.stretch (init sh) (RS.fromList (1 <$ init sh) [0]) - | any (<= 0) (init sh) = RS.A (RG.A (init sh) (OI.T (0 <$ init strides) 0 VS.empty)) - -- now the input array is nonempty - | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec)) - | last strides == 0 = - liftVEltwise1 sn - (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px))) - (RS.A (RG.A (init sh) (OI.T (init strides) offset vec))) - -- now there is useful work along the inner dimension - | otherwise = - let -- replicated dimensions: dimensions with zero stride. The reduction - -- kernel need not concern itself with those (and in fact has a - -- precondition that there are no such dimensions in its input). - replDims = map (== 0) strides - -- filter out replicated dimensions - (shF, stridesF) = unzip [(n, s) | (n, s, False) <- zip3 sh strides replDims] - -- replace replicated dimensions with ones - shOnes = zipWith (\n repl -> if repl then 1 else n) sh replDims - ndimsF = length shF -- > 0, otherwise `last strides == 0` - - -- reversed dimensions: dimensions with negative stride. Reversal is - -- irrelevant for a reduction, and indeed the kernel has a - -- precondition that there are no such dimensions. - revDims = map (< 0) stridesF - stridesR = map abs stridesF - offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF) - -- The *R values give an array with strides all > 0, hence the - -- left-most element is at offsetR. - in unsafePerformIO $ do - outvR <- VSM.unsafeNew (product (init shF)) - VSM.unsafeWith outvR $ \poutvR -> - VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> - VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR -> - VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR -> - fred (fromIntegral ndimsF) (ptrconv poutvR) pshF pstridesR (ptrconv pvecR) - TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) -> - RS.stretch (init sh) -- replicate to original shape - . RS.reshape (init shOnes) -- add 1-sized dimensions where the original was replicated - . RS.rev (map fst (filter snd (zip [0..] revDims))) -- re-reverse the correct dimensions - . RS.fromVector @_ @lenFm1 (init shF) -- the partially-reversed result array - <$> VS.unsafeFreeze outvR - --- TODO: test handling of negative strides --- | Reduce full array -{-# NOINLINE vectorRedFullOp #-} -vectorRedFullOp :: forall a b n. (Num a, Storable a) - => SNat n - -> (a -> Int -> a) - -> (b -> a) - -> (Ptr a -> Ptr b) - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b) -- ^ reduction kernel - -> RS.Array n a -> a -vectorRedFullOp _ scaleval valbackconv ptrconv fred (RS.A (RG.A sh (OI.T strides offset vec))) - | null sh = vec VS.! offset -- 0D array has one element - | any (<= 0) sh = 0 - -- now the input array is nonempty - | all (== 0) strides = fromIntegral (product sh) * vec VS.! offset - -- now there is at least one non-replicated dimension - | otherwise = - let -- replicated dimensions: dimensions with zero stride. The reduction - -- kernel need not concern itself with those (and in fact has a - -- precondition that there are no such dimensions in its input). - replDims = map (== 0) strides - -- filter out replicated dimensions - (shF, stridesF) = unzip [(n, s) | (n, s, False) <- zip3 sh strides replDims] - ndimsF = length shF -- > 0, otherwise `all (== 0) strides` - -- we should scale up the output this many times to account for the replicated dimensions - multiplier = product [n | (n, True) <- zip sh replDims] - - -- reversed dimensions: dimensions with negative stride. Reversal is - -- irrelevant for a reduction, and indeed the kernel has a - -- precondition that there are no such dimensions. - revDims = map (< 0) stridesF - stridesR = map abs stridesF - offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF) - -- The *R values give an array with strides all > 0, hence the - -- left-most element is at offsetR. - in unsafePerformIO $ do - VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> - VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR -> - VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR -> - (`scaleval` multiplier) . valbackconv - <$> fred (fromIntegral ndimsF) pshF pstridesR (ptrconv pvecR) - --- TODO: test this function --- | Find extremum (minindex ("argmin") or maxindex) in full array -{-# NOINLINE vectorExtremumOp #-} -vectorExtremumOp :: forall a b n. Storable a - => (Ptr a -> Ptr b) - -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ extremum kernel - -> RS.Array n a -> [Int] -- result length: n -vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec))) - | null sh = [] - | any (<= 0) sh = error "Extremum (minindex/maxindex): empty array" - -- now the input array is nonempty - | all (== 0) strides = 0 <$ sh - -- now there is at least one non-replicated dimension - | otherwise = - let -- replicated dimensions: dimensions with zero stride. The extremum - -- kernel need not concern itself with those (and in fact has a - -- precondition that there are no such dimensions in its input). - replDims = map (== 0) strides - -- filter out replicated dimensions - (shF, stridesF) = unzip [(n, s) | (n, s, False) <- zip3 sh strides replDims] - ndimsF = length shF -- > 0, because not all strides were <=0 - - -- un-reverse reversed dimensions - revDims = map (< 0) stridesF - stridesR = map abs stridesF - offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF) - - -- function to insert zeros in replicated-out dimensions - insertZeros :: [Bool] -> [Int] -> [Int] - insertZeros [] idx = idx - insertZeros (True : repls) idx = 0 : insertZeros repls idx - insertZeros (False : repls) (i : idx) = i : insertZeros repls idx - insertZeros (_:_) [] = error "unreachable" - in unsafePerformIO $ do - outvR <- VSM.unsafeNew (length shF) - VSM.unsafeWith outvR $ \poutvR -> - VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> - VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR -> - VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR -> - fextrem poutvR (fromIntegral ndimsF) pshF pstridesR (ptrconv pvecR) - insertZeros replDims - . zipWith3 (\rev n i -> if rev then n - 1 - i else i) revDims shF -- re-reverse the reversed dimensions - . map (fromIntegral @Int64 @Int) - . VS.toList - <$> VS.unsafeFreeze outvR - -vectorDotprodInnerOp :: forall a b n. (Num a, Storable a) - => SNat n - -> (a -> b) - -> (Ptr a -> Ptr b) - -> (SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a) -- ^ elementwise multiplication - -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant - -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel - -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ dotprod kernel - -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a -vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner - arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1))) - arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) - | null sh1 || null sh2 = error "unreachable" - | sh1 /= sh2 = error $ "vectorDotprodInnerOp: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 - | last sh1 <= 0 = RS.stretch (init sh1) (RS.fromList (1 <$ init sh1) [0]) - | any (<= 0) (init sh1) = RS.A (RG.A (init sh1) (OI.T (0 <$ init strides1) 0 VS.empty)) - -- now the input arrays are nonempty - | last sh1 == 1 = fmul sn (RS.reshape (init sh1) arr1) (RS.reshape (init sh1) arr2) - | last strides1 == 0 = - fmul sn - (RS.A (RG.A (init sh1) (OI.T (init strides1) offset1 vec1))) - (vectorRedInnerOp sn valconv ptrconv fscale fred arr2) - | last strides2 == 0 = - fmul sn - (vectorRedInnerOp sn valconv ptrconv fscale fred arr1) - (RS.A (RG.A (init sh2) (OI.T (init strides2) offset2 vec2))) - -- now there is useful dotprod work along the inner dimension - | otherwise = unsafePerformIO $ do - let inrank = fromSNat' sn + 1 - outv <- VSM.unsafeNew (product (init sh1)) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh1)) $ \psh -> - VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1)) $ \pstrides1 -> - VS.unsafeWith vec1 $ \pvec1 -> - VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2)) $ \pstrides2 -> - VS.unsafeWith vec2 $ \pvec2 -> - fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv) - pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1)) - pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2)) - RS.fromVector @_ @n (init sh1) <$> VS.unsafeFreeze outv - -{-# NOINLINE dotScalarVector #-} -dotScalarVector :: forall a b. (Num a, Storable a) - => Int -> (Ptr a -> Ptr b) - -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel - -> a -> VS.Vector a -> a -dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do - alloca @a $ \pout -> do - alloca @Int64 $ \pshape -> do - poke pshape (fromIntegral @Int @Int64 len) - alloca @Int64 $ \pstride -> do - poke pstride 1 - VS.unsafeWith vec $ \pvec -> - fred 1 (ptrconv pout) pshape pstride (ptrconv pvec) - res <- peek pout - return (scalar * res) - -{-# NOINLINE dotVectorVector #-} -dotVectorVector :: Storable a => Int -> (b -> a) -> (Ptr a -> Ptr b) - -> (Int64 -> Ptr b -> Ptr b -> IO b) -- ^ dotprod kernel - -> VS.Vector a -> VS.Vector a -> a -dotVectorVector len valbackconv ptrconv fdot vec1 vec2 = unsafePerformIO $ do - VS.unsafeWith vec1 $ \pvec1 -> - VS.unsafeWith vec2 $ \pvec2 -> - valbackconv <$> fdot (fromIntegral @Int @Int64 len) (ptrconv pvec1) (ptrconv pvec2) - -{-# NOINLINE dotVectorVectorStrided #-} -dotVectorVectorStrided :: Storable a => Int -> (b -> a) -> (Ptr a -> Ptr b) - -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b) -- ^ dotprod kernel - -> Int -> Int -> VS.Vector a - -> Int -> Int -> VS.Vector a - -> a -dotVectorVectorStrided len valbackconv ptrconv fdot offset1 stride1 vec1 offset2 stride2 vec2 = unsafePerformIO $ do - VS.unsafeWith vec1 $ \pvec1 -> - VS.unsafeWith vec2 $ \pvec2 -> - valbackconv <$> fdot (fromIntegral @Int @Int64 len) - (fromIntegral offset1) (fromIntegral stride1) (ptrconv pvec1) - (fromIntegral offset2) (fromIntegral stride2) (ptrconv pvec2) - -flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) - -> Int64 -> Ptr a -> Ptr a -> a -> IO () -flipOp f n out v s = f n out s v - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype)) - cnamebase = "c_binary_" ++ atCName arithtype - c_ss_str = varE (aboNumOp arithop) - c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) - c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) - c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |] - return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM intTypesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (aiboName arithop ++ "Vector" ++ nameBase (atType arithtype)) - cnamebase = "c_ibinary_" ++ atCName arithtype - c_ss_str = varE (aiboNumOp arithop) - c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop))) - c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop))) - c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aiboEnum arithop))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |] - return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype)) - cnamebase = "c_fbinary_" ++ atCName arithtype - c_ss_str = varE (afboNumOp arithop) - c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) - c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) - c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> liftVEltwise2 sn id id $c_ss_str $c_sv_str $c_vs_str $c_vv_str |] - return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) - c_op_strided = varE (mkName ("c_unary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (auoEnum arithop))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> liftOpEltwise1 sn id id $c_op_strided |] - return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype)) - c_op_strided = varE (mkName ("c_funary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (afuoEnum arithop))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> liftOpEltwise1 sn id id $c_op_strided |] - return $ FunD name [Clause [] (NormalB body) []]]) - -mulWithInt :: Num a => a -> Int -> a -mulWithInt a i = a * fromIntegral i - -scaleFromSVStrided :: (Int64 -> Ptr Int64 -> Ptr a -> a -> Ptr Int64 -> Ptr a -> IO ()) - -> Int64 -> Ptr a -> a -> Ptr a -> IO () -scaleFromSVStrided fsv n out x ys = - VS.unsafeWith (VS.singleton n) $ \psh -> - VS.unsafeWith (VS.singleton 1) $ \pstrides -> - fsv 1 psh out x pstrides ys - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let scaleVar = case arithop of - RO_SUM -> varE 'mulWithInt - RO_PRODUCT -> varE '(^) - let name1 = mkName (aroName arithop ++ "1Vector" ++ nameBase (atType arithtype)) - namefull = mkName (aroName arithop ++ "FullVector" ++ nameBase (atType arithtype)) - c_op1 = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) - c_opfull = varE (mkName ("c_reducefull_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) - c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) - sequence [SigD name1 <$> - [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> vectorRedInnerOp sn id id (scaleFromSVStrided $c_scale_op) $c_op1 |] - return $ FunD name1 [Clause [] (NormalB body) []] - ,SigD namefull <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> $ttyp |] - ,do body <- [| \sn -> vectorRedFullOp sn $scaleVar id id $c_opfull |] - return $ FunD namefull [Clause [] (NormalB body) []] - ]) - -$(fmap concat . forM typesList $ \arithtype -> - fmap concat . forM ["min", "max"] $ \fname -> do - let ttyp = conT (atType arithtype) - name = mkName (fname ++ "indexVector" ++ nameBase (atType arithtype)) - c_op = varE (mkName ("c_extremum_" ++ fname ++ "_" ++ atCName arithtype)) - sequence [SigD name <$> - [t| forall n. RS.Array n $ttyp -> [Int] |] - ,do body <- [| vectorExtremumOp id $c_op |] - return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - name = mkName ("dotprodinnerVector" ++ nameBase (atType arithtype)) - c_op = varE (mkName ("c_dotprodinner_" ++ atCName arithtype)) - mul_op = varE (mkName ("mulVector" ++ nameBase (atType arithtype))) - c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) - c_red_op = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> vectorDotprodInnerOp sn id id $mul_op (scaleFromSVStrided $c_scale_op) $c_red_op $c_op |] - return $ FunD name [Clause [] (NormalB body) []]]) - -foreign import ccall unsafe "oxarrays_stats_enable" c_stats_enable :: Int32 -> IO () -foreign import ccall unsafe "oxarrays_stats_print_all" c_stats_print_all :: IO () - -statisticsEnable :: Bool -> IO () -statisticsEnable b = c_stats_enable (if b then 1 else 0) - --- | Consumes the log: one particular event will only ever be printed once, --- even if statisticsPrintAll is called multiple times. -statisticsPrintAll :: IO () -statisticsPrintAll = do - hFlush stdout -- lower the chance of overlapping output - c_stats_print_all - --- This branch is ostensibly a runtime branch, but will (hopefully) be --- constant-folded away by GHC. -intWidBranch1 :: forall i n. (FiniteBits i, Storable i) - => (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) - -> (SNat n -> RS.Array n i -> RS.Array n i) -intWidBranch1 f32 f64 sn - | finiteBitSize (undefined :: i) == 32 = liftOpEltwise1 sn castPtr castPtr f32 - | finiteBitSize (undefined :: i) == 64 = liftOpEltwise1 sn castPtr castPtr f64 - | otherwise = error "Unsupported Int width" - -intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i) - => (i -> i -> i) -- ss - -- int32 - -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Int32 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- sv - -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Int32 -> IO ()) -- vs - -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- vv - -- int64 - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- sv - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ()) -- vs - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv - -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i) -intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn - | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn fromIntegral castPtr ss sv32 vs32 vv32 - | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn fromIntegral castPtr ss sv64 vs64 vv64 - | otherwise = error "Unsupported Int width" - -intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i) - => -- int32 - (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant - -> (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ reduction kernel - -- int64 - -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel - -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i) -intWidBranchRed1 fsc32 fred32 fsc64 fred64 sn - | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32 - | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64 - | otherwise = error "Unsupported Int width" - -intWidBranchRedFull :: forall i n. (FiniteBits i, Storable i, Integral i) - => (i -> Int -> i) -- ^ scale op - -- int32 - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO Int32) -- ^ reduction kernel - -- int64 - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO Int64) -- ^ reduction kernel - -> (SNat n -> RS.Array n i -> i) -intWidBranchRedFull fsc fred32 fred64 sn - | finiteBitSize (undefined :: i) == 32 = vectorRedFullOp @i @Int32 sn fsc fromIntegral castPtr fred32 - | finiteBitSize (undefined :: i) == 64 = vectorRedFullOp @i @Int64 sn fsc fromIntegral castPtr fred64 - | otherwise = error "Unsupported Int width" - -intWidBranchExtr :: forall i n. (FiniteBits i, Storable i, Integral i) - => -- int32 - (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ extremum kernel - -- int64 - -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ extremum kernel - -> (RS.Array n i -> [Int]) -intWidBranchExtr fextr32 fextr64 - | finiteBitSize (undefined :: i) == 32 = vectorExtremumOp @i @Int32 castPtr fextr32 - | finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64 - | otherwise = error "Unsupported Int width" - -intWidBranchDotprod :: forall i n. (FiniteBits i, Storable i, Integral i, NumElt i) - => -- int32 - (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant - -> (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ reduction kernel - -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ dotprod kernel - -- int64 - -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ dotprod kernel - -> (SNat n -> RS.Array (n + 1) i -> RS.Array (n + 1) i -> RS.Array n i) -intWidBranchDotprod fsc32 fred32 fdot32 fsc64 fred64 fdot64 sn - | finiteBitSize (undefined :: i) == 32 = vectorDotprodInnerOp @i @Int32 sn fromIntegral castPtr numEltMul fsc32 fred32 fdot32 - | finiteBitSize (undefined :: i) == 64 = vectorDotprodInnerOp @i @Int64 sn fromIntegral castPtr numEltMul fsc64 fred64 fdot64 - | otherwise = error "Unsupported Int width" - -class NumElt a where - numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a - numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a - numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a - numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a - numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a - numEltSumFull :: SNat n -> RS.Array n a -> a - numEltProductFull :: SNat n -> RS.Array n a -> a - numEltMinIndex :: SNat n -> RS.Array n a -> [Int] - numEltMaxIndex :: SNat n -> RS.Array n a -> [Int] - numEltDotprodInner :: SNat n -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a - -instance NumElt Int32 where - numEltAdd = addVectorInt32 - numEltSub = subVectorInt32 - numEltMul = mulVectorInt32 - numEltNeg = negVectorInt32 - numEltAbs = absVectorInt32 - numEltSignum = signumVectorInt32 - numEltSum1Inner = sum1VectorInt32 - numEltProduct1Inner = product1VectorInt32 - numEltSumFull = sumFullVectorInt32 - numEltProductFull = productFullVectorInt32 - numEltMinIndex _ = minindexVectorInt32 - numEltMaxIndex _ = maxindexVectorInt32 - numEltDotprodInner = dotprodinnerVectorInt32 - -instance NumElt Int64 where - numEltAdd = addVectorInt64 - numEltSub = subVectorInt64 - numEltMul = mulVectorInt64 - numEltNeg = negVectorInt64 - numEltAbs = absVectorInt64 - numEltSignum = signumVectorInt64 - numEltSum1Inner = sum1VectorInt64 - numEltProduct1Inner = product1VectorInt64 - numEltSumFull = sumFullVectorInt64 - numEltProductFull = productFullVectorInt64 - numEltMinIndex _ = minindexVectorInt64 - numEltMaxIndex _ = maxindexVectorInt64 - numEltDotprodInner = dotprodinnerVectorInt64 - -instance NumElt Float where - numEltAdd = addVectorFloat - numEltSub = subVectorFloat - numEltMul = mulVectorFloat - numEltNeg = negVectorFloat - numEltAbs = absVectorFloat - numEltSignum = signumVectorFloat - numEltSum1Inner = sum1VectorFloat - numEltProduct1Inner = product1VectorFloat - numEltSumFull = sumFullVectorFloat - numEltProductFull = productFullVectorFloat - numEltMinIndex _ = minindexVectorFloat - numEltMaxIndex _ = maxindexVectorFloat - numEltDotprodInner = dotprodinnerVectorFloat - -instance NumElt Double where - numEltAdd = addVectorDouble - numEltSub = subVectorDouble - numEltMul = mulVectorDouble - numEltNeg = negVectorDouble - numEltAbs = absVectorDouble - numEltSignum = signumVectorDouble - numEltSum1Inner = sum1VectorDouble - numEltProduct1Inner = product1VectorDouble - numEltSumFull = sumFullVectorDouble - numEltProductFull = productFullVectorDouble - numEltMinIndex _ = minindexVectorDouble - numEltMaxIndex _ = maxindexVectorDouble - numEltDotprodInner = dotprodinnerVectorDouble - -instance NumElt Int where - numEltAdd = intWidBranch2 @Int (+) - (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD)) - (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD)) - numEltSub = intWidBranch2 @Int (-) - (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB)) - (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB)) - numEltMul = intWidBranch2 @Int (*) - (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL)) - (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL)) - numEltNeg = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG)) - numEltAbs = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS)) - numEltSignum = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM)) - numEltSum1Inner = intWidBranchRed1 @Int - (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) - (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) - numEltProduct1Inner = intWidBranchRed1 @Int - (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_PRODUCT)) - (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_PRODUCT)) - numEltSumFull = intWidBranchRedFull @Int (*) (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM)) - numEltProductFull = intWidBranchRedFull @Int (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT)) - numEltMinIndex _ = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64 - numEltMaxIndex _ = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64 - numEltDotprodInner = intWidBranchDotprod @Int (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32 - (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64 - -instance NumElt CInt where - numEltAdd = intWidBranch2 @CInt (+) - (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD)) - (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD)) - numEltSub = intWidBranch2 @CInt (-) - (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB)) - (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB)) - numEltMul = intWidBranch2 @CInt (*) - (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL)) - (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL)) - numEltNeg = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG)) - numEltAbs = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS)) - numEltSignum = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM)) - numEltSum1Inner = intWidBranchRed1 @CInt - (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) - (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) - numEltProduct1Inner = intWidBranchRed1 @CInt - (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_PRODUCT)) - (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_PRODUCT)) - numEltSumFull = intWidBranchRedFull @CInt mulWithInt (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM)) - numEltProductFull = intWidBranchRedFull @CInt (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT)) - numEltMinIndex _ = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64 - numEltMaxIndex _ = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64 - numEltDotprodInner = intWidBranchDotprod @CInt (scaleFromSVStrided (c_binary_i32_sv_strided (aboEnum BO_MUL))) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32 - (scaleFromSVStrided (c_binary_i64_sv_strided (aboEnum BO_MUL))) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64 - -class NumElt a => IntElt a where - intEltQuot :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - intEltRem :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - -instance IntElt Int32 where - intEltQuot = quotVectorInt32 - intEltRem = remVectorInt32 - -instance IntElt Int64 where - intEltQuot = quotVectorInt64 - intEltRem = remVectorInt64 - -instance IntElt Int where - intEltQuot = intWidBranch2 @Int quot - (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT)) - (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT)) - intEltRem = intWidBranch2 @Int rem - (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM)) - (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM)) - -instance IntElt CInt where - intEltQuot = intWidBranch2 @CInt quot - (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT)) - (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT)) - intEltRem = intWidBranch2 @CInt rem - (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM)) - (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM)) - -class NumElt a => FloatElt a where - floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - floatEltLogbase :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a - floatEltExp :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog :: SNat n -> RS.Array n a -> RS.Array n a - floatEltSqrt :: SNat n -> RS.Array n a -> RS.Array n a - floatEltSin :: SNat n -> RS.Array n a -> RS.Array n a - floatEltCos :: SNat n -> RS.Array n a -> RS.Array n a - floatEltTan :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAsin :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAcos :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAtan :: SNat n -> RS.Array n a -> RS.Array n a - floatEltSinh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltCosh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltTanh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAsinh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAcosh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAtanh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog1p :: SNat n -> RS.Array n a -> RS.Array n a - floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAtan2 :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - -instance FloatElt Float where - floatEltDiv = divVectorFloat - floatEltPow = powVectorFloat - floatEltLogbase = logbaseVectorFloat - floatEltRecip = recipVectorFloat - floatEltExp = expVectorFloat - floatEltLog = logVectorFloat - floatEltSqrt = sqrtVectorFloat - floatEltSin = sinVectorFloat - floatEltCos = cosVectorFloat - floatEltTan = tanVectorFloat - floatEltAsin = asinVectorFloat - floatEltAcos = acosVectorFloat - floatEltAtan = atanVectorFloat - floatEltSinh = sinhVectorFloat - floatEltCosh = coshVectorFloat - floatEltTanh = tanhVectorFloat - floatEltAsinh = asinhVectorFloat - floatEltAcosh = acoshVectorFloat - floatEltAtanh = atanhVectorFloat - floatEltLog1p = log1pVectorFloat - floatEltExpm1 = expm1VectorFloat - floatEltLog1pexp = log1pexpVectorFloat - floatEltLog1mexp = log1mexpVectorFloat - floatEltAtan2 = atan2VectorFloat - -instance FloatElt Double where - floatEltDiv = divVectorDouble - floatEltPow = powVectorDouble - floatEltLogbase = logbaseVectorDouble - floatEltRecip = recipVectorDouble - floatEltExp = expVectorDouble - floatEltLog = logVectorDouble - floatEltSqrt = sqrtVectorDouble - floatEltSin = sinVectorDouble - floatEltCos = cosVectorDouble - floatEltTan = tanVectorDouble - floatEltAsin = asinVectorDouble - floatEltAcos = acosVectorDouble - floatEltAtan = atanVectorDouble - floatEltSinh = sinhVectorDouble - floatEltCosh = coshVectorDouble - floatEltTanh = tanhVectorDouble - floatEltAsinh = asinhVectorDouble - floatEltAcosh = acoshVectorDouble - floatEltAtanh = atanhVectorDouble - floatEltLog1p = log1pVectorDouble - floatEltExpm1 = expm1VectorDouble - floatEltLog1pexp = log1pexpVectorDouble - floatEltLog1mexp = log1mexpVectorDouble - floatEltAtan2 = atan2VectorDouble +liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c) + -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c +liftO2 f x y = toO (f (fromO x) (fromO y)) diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs deleted file mode 100644 index 78d5365..0000000 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ /dev/null @@ -1,47 +0,0 @@ -{-# LANGUAGE ForeignFunctionInterface #-} -{-# LANGUAGE TemplateHaskell #-} -module Data.Array.Mixed.Internal.Arith.Foreign where - -import Data.Int -import Foreign.C.Types -import Foreign.Ptr -import Language.Haskell.TH - -import Data.Array.Mixed.Internal.Arith.Lists - - -$(do - let importsScal ttyp tyn = - [("binary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) - ,("binary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) - ,("binary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |]) - ,("unary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) - ,("reduce1_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) - ,("reducefull_" ++ tyn, [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]) - ,("extremum_min_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) - ,("extremum_max_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) - ,("dotprodinner_" ++ tyn, [t| Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) - ] - - let importsInt ttyp tyn = - [("ibinary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) - ,("ibinary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) - ,("ibinary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |]) - ] - - let importsFloat ttyp tyn = - [("fbinary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) - ,("fbinary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) - ,("fbinary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |]) - ,("funary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) - ] - - let generate types imports = - sequence - [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ name) (mkName ("c_" ++ name)) <$> typ - | arithtype <- types - , (name, typ) <- imports (conT (atType arithtype)) (atCName arithtype)] - decs1 <- generate typesList importsScal - decs2 <- generate intTypesList importsInt - decs3 <- generate floatTypesList importsFloat - return (decs1 ++ decs2 ++ decs3)) diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists.hs b/src/Data/Array/Mixed/Internal/Arith/Lists.hs deleted file mode 100644 index 370b708..0000000 --- a/src/Data/Array/Mixed/Internal/Arith/Lists.hs +++ /dev/null @@ -1,95 +0,0 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE TemplateHaskell #-} -module Data.Array.Mixed.Internal.Arith.Lists where - -import Data.Char -import Data.Int -import Language.Haskell.TH - -import Data.Array.Mixed.Internal.Arith.Lists.TH - - -data ArithType = ArithType - { atType :: Name -- ''Int32 - , atCName :: String -- "i32" - } - -intTypesList :: [ArithType] -intTypesList = - [ArithType ''Int32 "i32" - ,ArithType ''Int64 "i64" - ] - -floatTypesList :: [ArithType] -floatTypesList = - [ArithType ''Float "float" - ,ArithType ''Double "double" - ] - -typesList :: [ArithType] -typesList = intTypesList ++ floatTypesList - --- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded) -$(genArithDataType Binop "ArithBOp") - -$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3)) -$(genArithEnumFun Binop ''ArithBOp "aboEnum") - -$(do clauses <- readArithLists Binop - (\name _num hsop -> return (Clause [ConP (mkName name) [] []] - (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) - [])) - return - sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |] - ,return $ FunD (mkName "aboNumOp") clauses]) - - --- data ArithIBOp = IB_QUOT deriving (Show, Enum, Bounded) -$(genArithDataType IBinop "ArithIBOp") - -$(genArithNameFun IBinop ''ArithIBOp "aiboName" (map toLower . drop 3)) -$(genArithEnumFun IBinop ''ArithIBOp "aiboEnum") - -$(do clauses <- readArithLists IBinop - (\name _num hsop -> return (Clause [ConP (mkName name) [] []] - (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) - [])) - return - sequence [SigD (mkName "aiboNumOp") <$> [t| ArithIBOp -> Name |] - ,return $ FunD (mkName "aiboNumOp") clauses]) - - --- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded) -$(genArithDataType FBinop "ArithFBOp") - -$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . drop 3)) -$(genArithEnumFun FBinop ''ArithFBOp "afboEnum") - -$(do clauses <- readArithLists FBinop - (\name _num hsop -> return (Clause [ConP (mkName name) [] []] - (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) - [])) - return - sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |] - ,return $ FunD (mkName "afboNumOp") clauses]) - - --- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded) -$(genArithDataType Unop "ArithUOp") - -$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3)) -$(genArithEnumFun Unop ''ArithUOp "auoEnum") - - --- data ArithFUOp = FU_RECIP | ... deriving (Show, Enum, Bounded) -$(genArithDataType FUnop "ArithFUOp") - -$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3)) -$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum") - - --- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded) -$(genArithDataType Redop "ArithRedOp") - -$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . drop 3)) -$(genArithEnumFun Redop ''ArithRedOp "aroEnum") diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs deleted file mode 100644 index a156e29..0000000 --- a/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs +++ /dev/null @@ -1,83 +0,0 @@ -{-# LANGUAGE TemplateHaskellQuotes #-} -module Data.Array.Mixed.Internal.Arith.Lists.TH where - -import Control.Monad -import Control.Monad.IO.Class -import Data.Maybe -import Foreign.C.Types -import Language.Haskell.TH -import Language.Haskell.TH.Syntax -import Text.Read - - -data OpKind = Binop | IBinop | FBinop | Unop | FUnop | Redop - deriving (Show, Eq) - -readArithLists :: OpKind - -> (String -> Int -> String -> Q a) - -> ([a] -> Q r) - -> Q r -readArithLists targetkind fop fcombine = do - addDependentFile "cbits/arith_lists.h" - lns <- liftIO $ lines <$> readFile "cbits/arith_lists.h" - - mvals <- forM lns $ \line -> do - if null (dropWhile (== ' ') line) - then return Nothing - else do let (kind, name, num, aux) = parseLine line - if kind == targetkind - then Just <$> fop name num aux - else return Nothing - - fcombine (catMaybes mvals) - where - parseLine s0 - | ("LIST_", s1) <- splitAt 5 s0 - , (kindstr, '(' : s2) <- break (== '(') s1 - , (f1, ',' : s3) <- parseField s2 - , (f2, ',' : s4) <- parseField s3 - , (f3, ')' : _) <- parseField s4 - , Just kind <- parseKind kindstr - , let name = f1 - , Just num <- readMaybe f2 - , let aux = f3 - = (kind, name, num, aux) - | otherwise - = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0 - - parseField s = break (`elem` ",)") (dropWhile (== ' ') s) - - parseKind "BINOP" = Just Binop - parseKind "IBINOP" = Just IBinop - parseKind "FBINOP" = Just FBinop - parseKind "UNOP" = Just Unop - parseKind "FUNOP" = Just FUnop - parseKind "REDOP" = Just Redop - parseKind _ = Nothing - -genArithDataType :: OpKind -> String -> Q [Dec] -genArithDataType kind dtname = do - cons <- readArithLists kind - (\name _num _ -> return $ NormalC (mkName name) []) - return - return [DataD [] (mkName dtname) [] Nothing cons [DerivClause Nothing [ConT ''Show, ConT ''Enum, ConT ''Bounded]]] - -genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec] -genArithNameFun kind dtname funname nametrans = do - clauses <- readArithLists kind - (\name _num _ -> return (Clause [ConP (mkName name) [] []] - (NormalB (LitE (StringL (nametrans name)))) - [])) - return - return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''String) - ,FunD (mkName funname) clauses] - -genArithEnumFun :: OpKind -> Name -> String -> Q [Dec] -genArithEnumFun kind dtname funname = do - clauses <- readArithLists kind - (\name num _ -> return (Clause [ConP (mkName name) [] []] - (NormalB (LitE (IntegerL (fromIntegral num)))) - [])) - return - return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''CInt) - ,FunD (mkName funname) clauses] diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs index 71bdc1f..204c1d8 100644 --- a/src/Data/Array/Mixed/XArray.hs +++ b/src/Data/Array/Mixed/XArray.hs @@ -34,6 +34,7 @@ import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation import Data.Array.Mixed.Shape import Data.Array.Mixed.Types +import Data.Array.Strided.Arith type XArray :: [Maybe Nat] -> Type -> Type @@ -240,7 +241,7 @@ transpose2 ssh1 ssh2 (XArray arr) sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a sumFull _ (XArray arr) = S.unScalar $ - numEltSum1Inner (SNat @0) $ + liftO1 (numEltSum1Inner (SNat @0)) $ S.fromVector [product (S.shapeL arr)] $ S.toVector arr @@ -256,7 +257,7 @@ sumInner ssh ssh' arr go (XArray arr') | Refl <- lemRankApp ssh ssh'F , let sn = listxRank (let StaticShX l = ssh in l) - = XArray (numEltSum1Inner sn arr') + = XArray (liftO1 (numEltSum1Inner sn) arr') in go $ transpose2 ssh'F ssh $ diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index db13da4..9869cba 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -103,7 +103,6 @@ module Data.Array.Nested ( import Prelude hiding (mappend, mconcat) -import Data.Array.Mixed.Internal.Arith import Data.Array.Mixed.Permutation import Data.Array.Mixed.Shape import Data.Array.Mixed.Types @@ -112,6 +111,7 @@ import Data.Array.Nested.Internal.Mixed import Data.Array.Nested.Internal.Ranked import Data.Array.Nested.Internal.Shape import Data.Array.Nested.Internal.Shaped +import Data.Array.Strided.Arith import Foreign.Storable import GHC.TypeLits diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 80d581e..eb452dd 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -49,6 +49,7 @@ import Data.Array.Mixed.Shape import Data.Array.Mixed.Types import Data.Array.Mixed.Permutation import Data.Array.Mixed.Lemmas +import Data.Array.Strided.Arith -- TODO: -- sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a @@ -225,52 +226,52 @@ mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_ | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 instance (NumElt a, PrimElt a) => Num (Mixed sh a) where - (+) = mliftNumElt2 numEltAdd - (-) = mliftNumElt2 numEltSub - (*) = mliftNumElt2 numEltMul - negate = mliftNumElt1 numEltNeg - abs = mliftNumElt1 numEltAbs - signum = mliftNumElt1 numEltSignum + (+) = mliftNumElt2 (liftO2 . numEltAdd) + (-) = mliftNumElt2 (liftO2 . numEltSub) + (*) = mliftNumElt2 (liftO2 . numEltMul) + negate = mliftNumElt1 (liftO1 . numEltNeg) + abs = mliftNumElt1 (liftO1 . numEltAbs) + signum = mliftNumElt1 (liftO1 . numEltSignum) -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal" instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" - recip = mliftNumElt1 floatEltRecip - (/) = mliftNumElt2 floatEltDiv + recip = mliftNumElt1 (liftO1 . floatEltRecip) + (/) = mliftNumElt2 (liftO2 . floatEltDiv) instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" - exp = mliftNumElt1 floatEltExp - log = mliftNumElt1 floatEltLog - sqrt = mliftNumElt1 floatEltSqrt - - (**) = mliftNumElt2 floatEltPow - logBase = mliftNumElt2 floatEltLogbase - - sin = mliftNumElt1 floatEltSin - cos = mliftNumElt1 floatEltCos - tan = mliftNumElt1 floatEltTan - asin = mliftNumElt1 floatEltAsin - acos = mliftNumElt1 floatEltAcos - atan = mliftNumElt1 floatEltAtan - sinh = mliftNumElt1 floatEltSinh - cosh = mliftNumElt1 floatEltCosh - tanh = mliftNumElt1 floatEltTanh - asinh = mliftNumElt1 floatEltAsinh - acosh = mliftNumElt1 floatEltAcosh - atanh = mliftNumElt1 floatEltAtanh - log1p = mliftNumElt1 floatEltLog1p - expm1 = mliftNumElt1 floatEltExpm1 - log1pexp = mliftNumElt1 floatEltLog1pexp - log1mexp = mliftNumElt1 floatEltLog1mexp + exp = mliftNumElt1 (liftO1 . floatEltExp) + log = mliftNumElt1 (liftO1 . floatEltLog) + sqrt = mliftNumElt1 (liftO1 . floatEltSqrt) + + (**) = mliftNumElt2 (liftO2 . floatEltPow) + logBase = mliftNumElt2 (liftO2 . floatEltLogbase) + + sin = mliftNumElt1 (liftO1 . floatEltSin) + cos = mliftNumElt1 (liftO1 . floatEltCos) + tan = mliftNumElt1 (liftO1 . floatEltTan) + asin = mliftNumElt1 (liftO1 . floatEltAsin) + acos = mliftNumElt1 (liftO1 . floatEltAcos) + atan = mliftNumElt1 (liftO1 . floatEltAtan) + sinh = mliftNumElt1 (liftO1 . floatEltSinh) + cosh = mliftNumElt1 (liftO1 . floatEltCosh) + tanh = mliftNumElt1 (liftO1 . floatEltTanh) + asinh = mliftNumElt1 (liftO1 . floatEltAsinh) + acosh = mliftNumElt1 (liftO1 . floatEltAcosh) + atanh = mliftNumElt1 (liftO1 . floatEltAtanh) + log1p = mliftNumElt1 (liftO1 . floatEltLog1p) + expm1 = mliftNumElt1 (liftO1 . floatEltExpm1) + log1pexp = mliftNumElt1 (liftO1 . floatEltLog1pexp) + log1mexp = mliftNumElt1 (liftO1 . floatEltLog1mexp) mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a -mquotArray = mliftNumElt2 intEltQuot -mremArray = mliftNumElt2 intEltRem +mquotArray = mliftNumElt2 (liftO2 . intEltQuot) +mremArray = mliftNumElt2 (liftO2 . intEltRem) matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a -matan2Array = mliftNumElt2 floatEltAtan2 +matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2) -- | Allowable element types in a mixed array, and by extension in a 'Ranked' or @@ -867,12 +868,12 @@ miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) -- | Throws if the array is empty. mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = - ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) arr) + ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) (fromO arr)) -- | Throws if the array is empty. mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = - ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) arr) + ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) (fromO arr)) mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a @@ -883,7 +884,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi _ :$% _ | sh1 == sh2 , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) -> - fromPrimitive $ M_Primitive (shxInit sh1) (XArray (numEltDotprodInner (shxRank (shxInit sh1)) a b)) + fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b)) | otherwise -> error "mdot1Inner: Unequal shapes" ZSX -> error "unreachable" diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 1c6b789..0a165bc 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -41,13 +41,13 @@ import GHC.TypeNats qualified as TN import Data.Array.Mixed.XArray (XArray(..)) import Data.Array.Mixed.XArray qualified as X -import Data.Array.Mixed.Internal.Arith import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation import Data.Array.Mixed.Shape import Data.Array.Mixed.Types import Data.Array.Nested.Internal.Mixed import Data.Array.Nested.Internal.Shape +import Data.Array.Strided.Arith -- | A rank-typed array: the number of dimensions of the array (its /rank/) is diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index 35628db..d7a8ece 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -41,7 +41,6 @@ import GHC.TypeLits import Data.Array.Mixed.XArray (XArray) import Data.Array.Mixed.XArray qualified as X -import Data.Array.Mixed.Internal.Arith import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation import Data.Array.Mixed.Shape @@ -49,6 +48,7 @@ import Data.Array.Mixed.Types import Data.Array.Nested.Internal.Lemmas import Data.Array.Nested.Internal.Mixed import Data.Array.Nested.Internal.Shape +import Data.Array.Strided.Arith -- | A shape-typed array: the full shape of the array (the sizes of its |