Define CUDA Context. (#8604)
We will transition to a non-default, non-blocking CUDA stream.
This commit is contained in:
@@ -267,12 +267,12 @@ __global__ void __launch_bounds__(kBlockThreads)
|
||||
}
|
||||
}
|
||||
|
||||
void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
|
||||
void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
|
||||
FeatureGroupsAccessor const& feature_groups,
|
||||
common::Span<GradientPair const> gpair,
|
||||
common::Span<const uint32_t> d_ridx,
|
||||
common::Span<GradientPairInt64> histogram,
|
||||
GradientQuantiser rounding, bool force_global_memory) {
|
||||
common::Span<GradientPairInt64> histogram, GradientQuantiser rounding,
|
||||
bool force_global_memory) {
|
||||
// decide whether to use shared memory
|
||||
int device = 0;
|
||||
dh::safe_cuda(cudaGetDevice(&device));
|
||||
@@ -318,9 +318,9 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
|
||||
min(grid_size,
|
||||
unsigned(common::DivRoundUp(items_per_group, kMinItemsPerBlock)));
|
||||
|
||||
dh::LaunchKernel {dim3(grid_size, num_groups),
|
||||
static_cast<uint32_t>(kBlockThreads), smem_size}(
|
||||
kernel, matrix, feature_groups, d_ridx, histogram.data(), gpair.data(), rounding);
|
||||
dh::LaunchKernel{dim3(grid_size, num_groups), static_cast<uint32_t>(kBlockThreads), smem_size,
|
||||
ctx->Stream()} (kernel, matrix, feature_groups, d_ridx, histogram.data(),
|
||||
gpair.data(), rounding);
|
||||
};
|
||||
|
||||
if (shared) {
|
||||
|
||||
@@ -5,9 +5,9 @@
|
||||
#define HISTOGRAM_CUH_
|
||||
#include <thrust/transform.h>
|
||||
|
||||
#include "feature_groups.cuh"
|
||||
|
||||
#include "../../common/cuda_context.cuh"
|
||||
#include "../../data/ellpack_page.cuh"
|
||||
#include "feature_groups.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@@ -56,12 +56,11 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
|
||||
void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
|
||||
FeatureGroupsAccessor const& feature_groups,
|
||||
common::Span<GradientPair const> gpair,
|
||||
common::Span<const uint32_t> ridx,
|
||||
common::Span<GradientPairInt64> histogram,
|
||||
GradientQuantiser rounding,
|
||||
common::Span<GradientPairInt64> histogram, GradientQuantiser rounding,
|
||||
bool force_global_memory = false);
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user