Remove various synchronisations from cuda API calls, instrument monitor (#4205)
* Remove various synchronisations from cuda API calls, instrument monitor with nvtx profiler ranges.
This commit is contained in:
parent: f83e62dca5
commit: 4eeeded7d1
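The core change, applied throughout the files below, is replacing blocking cudaMemcpy calls (and the explicit cudaDeviceSynchronize() calls that followed kernel launches) with cudaMemcpyAsync on the default stream: later kernels on that stream are still ordered after the copy, so correctness is preserved while the host no longer stalls on every API call. A minimal stand-alone sketch of that pattern, with a hypothetical kernel and buffer names that are not taken from the patch:

```cpp
#include <cuda_runtime.h>

// Hypothetical kernel, only here to show ordering on the default stream.
__global__ void ScaleKernel(float* data, size_t n) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= 2.0f;
}

void CopyThenLaunch(float* d_dst, const float* d_src, size_t n) {
  // Before: a blocking copy, sometimes followed by a device-wide sync:
  //   cudaMemcpy(d_dst, d_src, n * sizeof(float), cudaMemcpyDeviceToDevice);
  //   cudaDeviceSynchronize();
  // After: the copy is enqueued on the default stream; the kernel below is
  // still ordered after it, so no explicit synchronisation is needed here.
  cudaMemcpyAsync(d_dst, d_src, n * sizeof(float), cudaMemcpyDeviceToDevice);
  ScaleKernel<<<(n + 255) / 256, 256>>>(d_dst, n);
  // Synchronise only at the point where the host actually reads the results.
}
```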
@@ -10,9 +10,12 @@ msvc_use_static_runtime()
 # Options
 ## GPUs
 option(USE_CUDA "Build with GPU acceleration" OFF)
+option(USE_NVTX "Build with cuda profiling annotations. Developers only." OFF)
 option(USE_NCCL "Build with multiple GPUs support" OFF)
 set(GPU_COMPUTE_VER "" CACHE STRING
   "Space separated list of compute versions to be built against, e.g. '35 61'")
+set(NVTX_HEADER_DIR "" CACHE PATH
+  "Path to the stand-alone nvtx header")
 
 ## Bindings
 option(JVM_BINDINGS "Build JVM bindings" OFF)
@@ -175,6 +178,11 @@ if(USE_CUDA AND (NOT GENERATE_COMPILATION_DATABASE))
   add_definitions(-DXGBOOST_USE_NCCL)
 endif()
 
+if(USE_NVTX)
+  cuda_include_directories("${NVTX_HEADER_DIR}")
+  add_definitions(-DXGBOOST_USE_NVTX)
+endif()
+
 set(GENCODE_FLAGS "")
 format_gencode_flags("${GPU_COMPUTE_VER}" GENCODE_FLAGS)
 message("cuda architecture flags: ${GENCODE_FLAGS}")
@@ -190,6 +198,7 @@ if(USE_CUDA AND (NOT GENERATE_COMPILATION_DATABASE))
   link_directories(${NCCL_LIBRARY})
   target_link_libraries(gpuxgboost ${NCCL_LIB_NAME})
 endif()
 
 list(APPEND LINK_LIBRARIES gpuxgboost)
 
 elseif (USE_CUDA AND GENERATE_COMPILATION_DATABASE)
@@ -195,6 +195,10 @@ Training time time on 1,000,000 rows x 50 columns with 500 boosting iterations a
 
 See `GPU Accelerated XGBoost <https://xgboost.ai/2016/12/14/GPU-accelerated-xgboost.html>`_ and `Updates to the XGBoost GPU algorithms <https://xgboost.ai/2018/07/04/gpu-xgboost-update.html>`_ for additional performance benchmarks of the ``gpu_exact`` and ``gpu_hist`` tree methods.
 
+Developer notes
+===============
+The application may be profiled with annotations by specifying USE_NVTX to cmake and providing the path to the stand-alone nvtx header via NVTX_HEADER_DIR. Regions covered by the 'Monitor' class in cuda code will automatically appear in the nsight profiler.
+
 **********
 References
 **********
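The 'Monitor' annotations described in the developer notes above reduce to NVTX ranges opened and closed around each timed region. A minimal stand-alone sketch (an assumed example, not part of the patch) of how such a range is emitted when the nvtx header is available:

```cpp
#include <nvToolsExt.h>  // the stand-alone NVTX header pointed to by NVTX_HEADER_DIR

void TimedRegion() {
  // Open a named range; the profiler shows it as a bar labelled "BuildHist"
  // covering everything submitted between start and end.
  nvtxRangeId_t id = nvtxRangeStartA("BuildHist");
  // ... launch the kernels and copies belonging to this region ...
  nvtxRangeEnd(id);
}
```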
@@ -308,7 +308,7 @@ class DVec {
     }
     safe_cuda(cudaSetDevice(this->DeviceIdx()));
     if (other.DeviceIdx() == this->DeviceIdx()) {
-      dh::safe_cuda(cudaMemcpy(this->Data(), other.Data(),
+      dh::safe_cuda(cudaMemcpyAsync(this->Data(), other.Data(),
                                other.Size() * sizeof(T),
                                cudaMemcpyDeviceToDevice));
     } else {
@@ -338,7 +338,7 @@ class DVec {
       throw std::runtime_error(
           "Cannot copy assign vector to dvec, sizes are different");
     }
-    safe_cuda(cudaMemcpy(this->Data(), begin.get(), Size() * sizeof(T),
+    safe_cuda(cudaMemcpyAsync(this->Data(), begin.get(), Size() * sizeof(T),
                          cudaMemcpyDefault));
   }
 };
@@ -290,14 +290,14 @@ struct GPUSketcher {
         offset_vec[row_begin_ + batch_row_begin];
     // copy the batch to the GPU
     dh::safe_cuda
-        (cudaMemcpy(entries_.data().get(),
+        (cudaMemcpyAsync(entries_.data().get(),
                     data_vec.data() + offset_vec[row_begin_ + batch_row_begin],
                     n_entries * sizeof(Entry), cudaMemcpyDefault));
     // copy the weights if necessary
     if (has_weights_) {
       const auto& weights_vec = info.weights_.HostVector();
       dh::safe_cuda
-          (cudaMemcpy(weights_.data().get(),
+          (cudaMemcpyAsync(weights_.data().get(),
                       weights_vec.data() + row_begin_ + batch_row_begin,
                       batch_nrows * sizeof(bst_float), cudaMemcpyDefault));
     }
@@ -315,15 +315,11 @@ struct GPUSketcher {
         has_weights_ ? weights_.data().get() : nullptr, entries_.data().get(),
         gpu_batch_nrows_, num_cols_,
         offset_vec[row_begin_ + batch_row_begin], batch_nrows);
-    dh::safe_cuda(cudaGetLastError());  // NOLINT
-    dh::safe_cuda(cudaDeviceSynchronize());  // NOLINT
 
     for (int icol = 0; icol < num_cols_; ++icol) {
       FindColumnCuts(batch_nrows, icol);
     }
 
-    dh::safe_cuda(cudaDeviceSynchronize());  // NOLINT
-
     // add cuts into sketches
     thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin());
     for (int icol = 0; icol < num_cols_; ++icol) {
@@ -74,14 +74,14 @@ struct HostDeviceVectorImpl {
       // TODO(canonizer): avoid full copy of host data
       LazySyncDevice(GPUAccess::kWrite);
       SetDevice();
-      dh::safe_cuda(cudaMemcpy(data_.data().get(), begin + start_,
+      dh::safe_cuda(cudaMemcpyAsync(data_.data().get(), begin + start_,
                                data_.size() * sizeof(T), cudaMemcpyDefault));
     }
 
     void GatherTo(thrust::device_ptr<T> begin) {
       LazySyncDevice(GPUAccess::kRead);
       SetDevice();
-      dh::safe_cuda(cudaMemcpy(begin.get() + start_, data_.data().get(),
+      dh::safe_cuda(cudaMemcpyAsync(begin.get() + start_, data_.data().get(),
                                proper_size_ * sizeof(T), cudaMemcpyDefault));
     }
 
@@ -97,7 +97,7 @@ struct HostDeviceVectorImpl {
       LazySyncDevice(GPUAccess::kWrite);
       other->LazySyncDevice(GPUAccess::kRead);
       SetDevice();
-      dh::safe_cuda(cudaMemcpy(data_.data().get(), other->data_.data().get(),
+      dh::safe_cuda(cudaMemcpyAsync(data_.data().get(), other->data_.data().get(),
                                data_.size() * sizeof(T), cudaMemcpyDefault));
     }
 
@@ -8,7 +8,9 @@
 #include <map>
 #include <string>
 
-#include "common.h"
+#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
+#include <nvToolsExt.h>
+#endif
 
 namespace xgboost {
 namespace common {
@@ -45,9 +47,11 @@ struct Timer {
  */
 
 struct Monitor {
+ private:
   struct Statistics {
     Timer timer;
     size_t count{0};
+    uint64_t nvtx_id;
   };
   std::string label = "";
   std::map<std::string, Statistics> statistics_map;
@@ -75,35 +79,37 @@ struct Monitor {
     }
     self_timer.Stop();
   }
-  void Init(std::string label) {
-    this->label = label;
-  }
-  void Start(const std::string &name) { statistics_map[name].timer.Start(); }
-  void Start(const std::string &name, GPUSet devices) {
-    if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
-#ifdef __CUDACC__
-      for (auto device : devices) {
-        cudaSetDevice(device);
-        cudaDeviceSynchronize();
-      }
-#endif  // __CUDACC__
-    }
-    statistics_map[name].timer.Start();
-  }
-  void Stop(const std::string &name) {
-    statistics_map[name].timer.Stop();
-    statistics_map[name].count++;
-  }
-  void Stop(const std::string &name, GPUSet devices) {
-    if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
-#ifdef __CUDACC__
-      for (auto device : devices) {
-        cudaSetDevice(device);
-        cudaDeviceSynchronize();
-      }
-#endif  // __CUDACC__
-    }
-    this->Stop(name);
-  }
+  void Init(std::string label) { this->label = label; }
+  void Start(const std::string &name) {
+    if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
+      statistics_map[name].timer.Start();
+    }
+  }
+  void Stop(const std::string &name) {
+    if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
+      auto &stats = statistics_map[name];
+      stats.timer.Stop();
+      stats.count++;
+    }
+  }
+  void StartCuda(const std::string &name) {
+    if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
+      auto &stats = statistics_map[name];
+      stats.timer.Start();
+#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
+      stats.nvtx_id = nvtxRangeStartA(name.c_str());
+#endif
+    }
+  }
+  void StopCuda(const std::string &name) {
+    if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
+      auto &stats = statistics_map[name];
+      stats.timer.Stop();
+      stats.count++;
+#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
+      nvtxRangeEnd(stats.nvtx_id);
+#endif
+    }
+  }
 };
 }  // namespace common
@@ -145,8 +145,6 @@ class Transform {
           static_cast<int>(dh::DivRoundUp(*(range_.end()), kBlockThreads));
       detail::LaunchCUDAKernel<<<GRID_SIZE, kBlockThreads>>>(
           _func, shard_range, UnpackHDV(_vectors, device)...);
-      dh::safe_cuda(cudaGetLastError());
-      dh::safe_cuda(cudaDeviceSynchronize());
     }
   }
 #else
@@ -252,17 +252,17 @@ class GPUPredictor : public xgboost::Predictor {
                  size_t tree_begin, size_t tree_end) {
     dh::safe_cuda(cudaSetDevice(device_));
     nodes.resize(h_nodes.size());
-    dh::safe_cuda(cudaMemcpy(dh::Raw(nodes), h_nodes.data(),
+    dh::safe_cuda(cudaMemcpyAsync(dh::Raw(nodes), h_nodes.data(),
                              sizeof(DevicePredictionNode) * h_nodes.size(),
                              cudaMemcpyHostToDevice));
     tree_segments.resize(h_tree_segments.size());
 
-    dh::safe_cuda(cudaMemcpy(dh::Raw(tree_segments), h_tree_segments.data(),
+    dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_segments), h_tree_segments.data(),
                              sizeof(size_t) * h_tree_segments.size(),
                              cudaMemcpyHostToDevice));
     tree_group.resize(model.tree_info.size());
 
-    dh::safe_cuda(cudaMemcpy(dh::Raw(tree_group), model.tree_info.data(),
+    dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_group), model.tree_info.data(),
                              sizeof(int) * model.tree_info.size(),
                              cudaMemcpyHostToDevice));
 
@@ -288,9 +288,6 @@ class GPUPredictor : public xgboost::Predictor {
         dh::ToSpan(tree_group), batch.offset.DeviceSpan(device_),
         batch.data.DeviceSpan(device_), tree_begin, tree_end, info.num_col_,
         num_rows, entry_start, use_shared, model.param.num_output_group);
-
-    dh::safe_cuda(cudaGetLastError());
-    dh::safe_cuda(cudaDeviceSynchronize());
   }
 
   int device_;
@@ -289,7 +289,9 @@ struct DeviceHistogram {
 
   void Reset() {
     dh::safe_cuda(cudaSetDevice(device_id_));
-    data.resize(0);
+    dh::safe_cuda(cudaMemsetAsync(
+        data.data().get(), 0,
+        data.size() * sizeof(typename decltype(data)::value_type)));
     nidx_map.clear();
   }
 
@@ -299,20 +301,25 @@ struct DeviceHistogram {
 
   void AllocateHistogram(int nidx) {
     if (HistogramExists(nidx)) return;
+    size_t current_size =
+        nidx_map.size() * n_bins * 2;  // Number of items currently used in data
     dh::safe_cuda(cudaSetDevice(device_id_));
-    if (data.size() > kStopGrowingSize) {
+    if (data.size() >= kStopGrowingSize) {
       // Recycle histogram memory
       std::pair<int, size_t> old_entry = *nidx_map.begin();
       nidx_map.erase(old_entry.first);
-      dh::safe_cuda(cudaMemset(data.data().get() + old_entry.second, 0,
+      dh::safe_cuda(cudaMemsetAsync(data.data().get() + old_entry.second, 0,
                                n_bins * sizeof(GradientSumT)));
       nidx_map[nidx] = old_entry.second;
     } else {
       // Append new node histogram
-      nidx_map[nidx] = data.size();
-      // x 2: Hess and Grad.
-      data.resize(data.size() + (n_bins * 2));
+      nidx_map[nidx] = current_size;
+      if (data.size() < current_size + n_bins * 2) {
+        size_t new_size = current_size * 2;  // Double in size
+        new_size = std::max(static_cast<size_t>(n_bins * 2),
+                            new_size);  // Have at least one histogram
+        data.resize(new_size);
+      }
     }
   }
 
@@ -610,20 +617,20 @@ struct DeviceShard {
     feature_set_d.resize(feature_set.size());
     auto d_features = common::Span<int>(feature_set_d.data().get(),
                                         feature_set_d.size());
-    dh::safe_cuda(cudaMemcpy(d_features.data(), feature_set.data(),
+    dh::safe_cuda(cudaMemcpyAsync(d_features.data(), feature_set.data(),
                              d_features.size_bytes(), cudaMemcpyDefault));
     DeviceNodeStats node(node_sum_gradients[nidx], nidx, param);
 
     // One block for each feature
     int constexpr BLOCK_THREADS = 256;
     EvaluateSplitKernel<BLOCK_THREADS, GradientSumT>
-        <<<uint32_t(feature_set.size()), BLOCK_THREADS, 0>>>
-        (hist.GetNodeHistogram(nidx), d_features, node,
+        <<<uint32_t(feature_set.size()), BLOCK_THREADS, 0>>>(
+            hist.GetNodeHistogram(nidx), d_features, node,
             cut_.feature_segments.GetSpan(), cut_.min_fvalue.GetSpan(),
             cut_.gidx_fvalue_map.GetSpan(), GPUTrainingParam(param),
-            d_split_candidates, value_constraint, monotone_constraints.GetSpan());
+            d_split_candidates, value_constraint,
+            monotone_constraints.GetSpan());
 
-    dh::safe_cuda(cudaDeviceSynchronize());
     std::vector<DeviceSplitCandidate> split_candidates(feature_set.size());
     dh::safe_cuda(cudaMemcpy(split_candidates.data(), d_split_candidates.data(),
                              split_candidates.size() * sizeof(DeviceSplitCandidate),
@@ -725,20 +732,21 @@ struct DeviceShard {
         common::Span<bst_uint>(ridx.Current() + segment.begin, segment.Size()),
         common::Span<bst_uint>(ridx.other() + segment.begin, segment.Size()),
         left_nidx, right_nidx, left_count);
-    // Copy back key
-    dh::safe_cuda(cudaMemcpy(
-        position.Current() + segment.begin, position.other() + segment.begin,
-        segment.Size() * sizeof(int), cudaMemcpyDeviceToDevice));
-    // Copy back value
-    dh::safe_cuda(cudaMemcpy(
-        ridx.Current() + segment.begin, ridx.other() + segment.begin,
-        segment.Size() * sizeof(bst_uint), cudaMemcpyDeviceToDevice));
+    // Copy back key/value
+    const auto d_position_current = position.Current() + segment.begin;
+    const auto d_position_other = position.other() + segment.begin;
+    const auto d_ridx_current = ridx.Current() + segment.begin;
+    const auto d_ridx_other = ridx.other() + segment.begin;
+    dh::LaunchN(device_id_, segment.Size(), [=] __device__(size_t idx) {
+      d_position_current[idx] = d_position_other[idx];
+      d_ridx_current[idx] = d_ridx_other[idx];
+    });
   }
 
   void UpdatePredictionCache(bst_float* out_preds_d) {
     dh::safe_cuda(cudaSetDevice(device_id_));
     if (!prediction_cache_initialised) {
-      dh::safe_cuda(cudaMemcpy(
+      dh::safe_cuda(cudaMemcpyAsync(
           prediction_cache.Data(), out_preds_d,
           prediction_cache.Size() * sizeof(bst_float), cudaMemcpyDefault));
     }
@@ -746,7 +754,7 @@ struct DeviceShard {
 
     CalcWeightTrainParam param_d(param);
 
-    dh::safe_cuda(cudaMemcpy(node_sum_gradients_d.Data(),
+    dh::safe_cuda(cudaMemcpyAsync(node_sum_gradients_d.Data(),
                              node_sum_gradients.data(),
                              sizeof(GradientPair) * node_sum_gradients.size(),
                              cudaMemcpyHostToDevice));
@@ -925,9 +933,6 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(const SparsePage& row_b
         batch_row_begin, batch_nrows,
         row_ptrs[batch_row_begin],
         row_stride, null_gidx_value);
-
-    dh::safe_cuda(cudaGetLastError());
-    dh::safe_cuda(cudaDeviceSynchronize());
   }
 
   // free the memory that is no longer needed
@@ -965,7 +970,7 @@ class GPUHistMakerSpecialised{
 
   void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
               const std::vector<RegTree*>& trees) {
-    monitor_.Start("Update", dist_.Devices());
+    monitor_.StartCuda("Update");
     // rescale learning rate according to size of trees
     float lr = param_.learning_rate;
     param_.learning_rate = lr / trees.size();
@@ -980,7 +985,7 @@ class GPUHistMakerSpecialised{
       LOG(FATAL) << "Exception in gpu_hist: " << e.what() << std::endl;
     }
     param_.learning_rate = lr;
-    monitor_.Stop("Update", dist_.Devices());
+    monitor_.StopCuda("Update");
   }
 
   void InitDataOnce(DMatrix* dmat) {
@@ -1010,17 +1015,17 @@ class GPUHistMakerSpecialised{
     });
 
     // Find the cuts.
-    monitor_.Start("Quantiles", dist_.Devices());
+    monitor_.StartCuda("Quantiles");
     common::DeviceSketch(batch, *info_, param_, &hmat_, hist_maker_param_.gpu_batch_nrows);
     n_bins_ = hmat_.row_ptr.back();
-    monitor_.Stop("Quantiles", dist_.Devices());
+    monitor_.StopCuda("Quantiles");
 
-    monitor_.Start("BinningCompression", dist_.Devices());
+    monitor_.StartCuda("BinningCompression");
     dh::ExecuteIndexShards(&shards_, [&](int idx,
                            std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
       shard->InitCompressedData(hmat_, batch);
     });
-    monitor_.Stop("BinningCompression", dist_.Devices());
+    monitor_.StopCuda("BinningCompression");
     ++batch_iter;
     CHECK(batch_iter.AtEnd()) << "External memory not supported";
 
@@ -1030,16 +1035,16 @@ class GPUHistMakerSpecialised{
 
   void InitData(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat) {
     if (!initialised_) {
-      monitor_.Start("InitDataOnce", dist_.Devices());
+      monitor_.StartCuda("InitDataOnce");
       this->InitDataOnce(dmat);
-      monitor_.Stop("InitDataOnce", dist_.Devices());
+      monitor_.StopCuda("InitDataOnce");
     }
 
     column_sampler_.Init(info_->num_col_, param_.colsample_bynode,
                          param_.colsample_bylevel, param_.colsample_bytree);
 
     // Copy gpair & reset memory
-    monitor_.Start("InitDataReset", dist_.Devices());
+    monitor_.StartCuda("InitDataReset");
 
     gpair->Reshard(dist_);
     dh::ExecuteIndexShards(
@@ -1047,13 +1052,12 @@ class GPUHistMakerSpecialised{
         [&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
           shard->Reset(gpair);
         });
-    monitor_.Stop("InitDataReset", dist_.Devices());
+    monitor_.StopCuda("InitDataReset");
   }
 
   void AllReduceHist(int nidx) {
-    if (shards_.size() == 1 && !rabit::IsDistributed())
-      return;
-    monitor_.Start("AllReduce");
+    if (shards_.size() == 1 && !rabit::IsDistributed()) return;
+    monitor_.StartCuda("AllReduce");
 
     reducer_.GroupStart();
     for (auto& shard : shards_) {
@@ -1067,7 +1071,7 @@ class GPUHistMakerSpecialised{
     reducer_.GroupEnd();
     reducer_.Synchronize();
 
-    monitor_.Stop("AllReduce");
+    monitor_.StopCuda("AllReduce");
   }
 
   /**
@@ -1250,12 +1254,12 @@ class GPUHistMakerSpecialised{
                   RegTree* p_tree) {
     auto& tree = *p_tree;
 
-    monitor_.Start("InitData", dist_.Devices());
+    monitor_.StartCuda("InitData");
     this->InitData(gpair, p_fmat);
-    monitor_.Stop("InitData", dist_.Devices());
-    monitor_.Start("InitRoot", dist_.Devices());
+    monitor_.StopCuda("InitData");
+    monitor_.StartCuda("InitRoot");
     this->InitRoot(p_tree);
-    monitor_.Stop("InitRoot", dist_.Devices());
+    monitor_.StopCuda("InitRoot");
 
     auto timestamp = qexpand_->size();
     auto num_leaves = 1;
@@ -1266,9 +1270,9 @@ class GPUHistMakerSpecialised{
       if (!candidate.IsValid(param_, num_leaves)) continue;
 
       this->ApplySplit(candidate, p_tree);
-      monitor_.Start("UpdatePosition", dist_.Devices());
+      monitor_.StartCuda("UpdatePosition");
       this->UpdatePosition(candidate, p_tree);
-      monitor_.Stop("UpdatePosition", dist_.Devices());
+      monitor_.StopCuda("UpdatePosition");
       num_leaves++;
 
       int left_child_nidx = tree[candidate.nid].LeftChild();
@@ -1277,32 +1281,30 @@ class GPUHistMakerSpecialised{
       // Only create child entries if needed
       if (ExpandEntry::ChildIsValid(param_, tree.GetDepth(left_child_nidx),
                                     num_leaves)) {
-        monitor_.Start("BuildHist", dist_.Devices());
+        monitor_.StartCuda("BuildHist");
         this->BuildHistLeftRight(candidate.nid, left_child_nidx,
                                  right_child_nidx);
-        monitor_.Stop("BuildHist", dist_.Devices());
+        monitor_.StopCuda("BuildHist");
 
-        monitor_.Start("EvaluateSplits", dist_.Devices());
-        auto left_child_split =
-            this->EvaluateSplit(left_child_nidx, p_tree);
-        auto right_child_split =
-            this->EvaluateSplit(right_child_nidx, p_tree);
+        monitor_.StartCuda("EvaluateSplits");
+        auto left_child_split = this->EvaluateSplit(left_child_nidx, p_tree);
+        auto right_child_split = this->EvaluateSplit(right_child_nidx, p_tree);
         qexpand_->push(ExpandEntry(left_child_nidx,
-                                   tree.GetDepth(left_child_nidx), left_child_split,
-                                   timestamp++));
+                                   tree.GetDepth(left_child_nidx),
+                                   left_child_split, timestamp++));
         qexpand_->push(ExpandEntry(right_child_nidx,
-                                   tree.GetDepth(right_child_nidx), right_child_split,
-                                   timestamp++));
-        monitor_.Stop("EvaluateSplits", dist_.Devices());
+                                   tree.GetDepth(right_child_nidx),
+                                   right_child_split, timestamp++));
+        monitor_.StopCuda("EvaluateSplits");
       }
     }
   }
 
   bool UpdatePredictionCache(
       const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
-    monitor_.Start("UpdatePredictionCache", dist_.Devices());
     if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data)
       return false;
+    monitor_.StartCuda("UpdatePredictionCache");
     p_out_preds->Reshard(dist_.Devices());
     dh::ExecuteIndexShards(
         &shards_,
@@ -1310,7 +1312,7 @@ class GPUHistMakerSpecialised{
           shard->UpdatePredictionCache(
               p_out_preds->DevicePointer(shard->device_id_));
         });
-    monitor_.Stop("UpdatePredictionCache", dist_.Devices());
+    monitor_.StopCuda("UpdatePredictionCache");
     return true;
   }
 