GPU performance logging/improvements (#3945)

- Improved GPU performance logging

- Only use one execute shards function

- Revert performance regression on multi-GPU

- Use threads to launch NCCL AllReduce (see the sketch following the commit metadata)
Rory Mitchell 2018-11-29 14:36:51 +13:00 committed by GitHub
parent c5f92df475
commit a9d684db18
8 changed files with 127 additions and 102 deletions
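
The last bullet above refers to how histograms are now reduced across GPUs: each device's NCCL call is issued from its own controlling CPU thread (via dh::ExecuteIndexShards in AllReduceHist below) rather than from a single thread inside a GroupStart()/GroupEnd() block. A minimal sketch of that launch pattern, not code from this commit (ThreadedAllReduce is a hypothetical name; error checking and buffer setup are elided):

#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

// One OpenMP thread per device issues its own ncclAllReduce. With one
// thread per communicator, NCCL matches the concurrent calls across the
// clique, so no explicit ncclGroupStart()/ncclGroupEnd() is required.
void ThreadedAllReduce(const std::vector<int>& devices,
                       const std::vector<ncclComm_t>& comms,
                       const std::vector<cudaStream_t>& streams,
                       const std::vector<double*>& bufs, int count) {
#pragma omp parallel for schedule(static, 1) if (devices.size() > 1)
  for (int i = 0; i < static_cast<int>(devices.size()); ++i) {
    cudaSetDevice(devices[i]);         // bind this thread to its GPU
    ncclAllReduce(bufs[i], bufs[i],    // in-place sum across devices
                  count, ncclDouble, ncclSum, comms[i], streams[i]);
  }
}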

src/common/device_helpers.cuh

@@ -19,6 +19,7 @@
 #include <sstream>
 #include <string>
 #include <vector>
+#include "timer.h"
 #ifdef XGBOOST_USE_NCCL
 #include "nccl.h"
@@ -840,14 +841,17 @@ void Gather(int device_idx, T *out, const T *in, const int *instId, int nVals) {
  */
 class AllReducer {
-  bool initialised;
+  bool initialised_;
+  bool debug_verbose_;
+  size_t allreduce_bytes_;  // Keep statistics of the number of bytes communicated
+  size_t allreduce_calls_;  // Keep statistics of the number of reduce calls
 #ifdef XGBOOST_USE_NCCL
   std::vector<ncclComm_t> comms;
   std::vector<cudaStream_t> streams;
   std::vector<int> device_ordinals;
 #endif
  public:
-  AllReducer() : initialised(false) {}
+  AllReducer() : initialised_(false), debug_verbose_(false) {}
 /**
  * \fn void Init(const std::vector<int> &device_ordinals)
@@ -858,8 +862,10 @@ class AllReducer {
  * \param device_ordinals The device ordinals.
  */
-  void Init(const std::vector<int> &device_ordinals) {
+  void Init(const std::vector<int> &device_ordinals, bool debug_verbose) {
 #ifdef XGBOOST_USE_NCCL
+    this->debug_verbose_ = debug_verbose;
     this->device_ordinals = device_ordinals;
     comms.resize(device_ordinals.size());
     dh::safe_nccl(ncclCommInitAll(comms.data(),
@@ -870,7 +876,7 @@ class AllReducer {
       safe_cuda(cudaSetDevice(device_ordinals[i]));
       safe_cuda(cudaStreamCreate(&streams[i]));
     }
-    initialised = true;
+    initialised_ = true;
 #else
     CHECK_EQ(device_ordinals.size(), 1)
         << "XGBoost must be compiled with NCCL to use more than one GPU.";
@@ -878,7 +884,7 @@ class AllReducer {
   }
   ~AllReducer() {
 #ifdef XGBOOST_USE_NCCL
-    if (initialised) {
+    if (initialised_) {
       for (auto &stream : streams) {
         dh::safe_cuda(cudaStreamDestroy(stream));
       }
@@ -886,6 +892,11 @@ class AllReducer {
         ncclCommDestroy(comm);
       }
     }
+    if (debug_verbose_) {
+      LOG(CONSOLE) << "======== NCCL Statistics========";
+      LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
+      LOG(CONSOLE) << "AllReduce total MB communicated: " << allreduce_bytes_/1000000;
+    }
 #endif
   }
@@ -920,11 +931,16 @@ class AllReducer {
   void AllReduceSum(int communication_group_idx, const double *sendbuff,
                     double *recvbuff, int count) {
 #ifdef XGBOOST_USE_NCCL
-    CHECK(initialised);
+    CHECK(initialised_);
     dh::safe_cuda(cudaSetDevice(device_ordinals.at(communication_group_idx)));
     dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum,
                                 comms.at(communication_group_idx),
                                 streams.at(communication_group_idx)));
+    if (communication_group_idx == 0) {
+      allreduce_bytes_ += count * sizeof(double);
+      allreduce_calls_ += 1;
+    }
 #endif
   }
@@ -942,7 +958,7 @@ class AllReducer {
   void AllReduceSum(int communication_group_idx, const int64_t *sendbuff,
                     int64_t *recvbuff, int count) {
 #ifdef XGBOOST_USE_NCCL
-    CHECK(initialised);
+    CHECK(initialised_);
     dh::safe_cuda(cudaSetDevice(device_ordinals[communication_group_idx]));
     dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclInt64, ncclSum,
@@ -989,27 +1005,6 @@ class SaveCudaContext {
   }
 };
-/**
- * \brief Executes some operation on each element of the input vector, using a
- * single controlling thread for each element.
- *
- * \tparam T Generic type parameter.
- * \tparam FunctionT Type of the function t.
- * \param shards The shards.
- * \param f The func_t to process.
- */
-template <typename T, typename FunctionT>
-void ExecuteShards(std::vector<T> *shards, FunctionT f) {
-  SaveCudaContext {
-    [&](){
-#pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
-      for (int shard = 0; shard < shards->size(); ++shard) {
-        f(shards->at(shard));
-      }
-    }};
-}
 /**
  * \brief Executes some operation on each element of the input vector, using a
  * single controlling thread for each element. In addition, passes the shard index
@@ -1023,13 +1018,12 @@ void ExecuteShards(std::vector<T> *shards, FunctionT f) {
 template <typename T, typename FunctionT>
 void ExecuteIndexShards(std::vector<T> *shards, FunctionT f) {
-  SaveCudaContext {
-    [&](){
+  SaveCudaContext{[&]() {
 #pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
     for (int shard = 0; shard < shards->size(); ++shard) {
       f(shard, shards->at(shard));
     }
   }};
 }
 /**
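
With the duplicate helper removed, ExecuteIndexShards is the only sharding primitive left in device_helpers.cuh. A standalone sketch of its dispatch pattern, simplified by dropping the SaveCudaContext wrapper (an assumption made for brevity; the real helper also restores the caller's CUDA device afterwards):

#include <cstdio>
#include <vector>

// Each shard gets its own controlling CPU thread; the shard index is
// passed through so the functor can address per-device state.
template <typename T, typename FunctionT>
void ExecuteIndexShards(std::vector<T>* shards, FunctionT f) {
#pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
  for (int shard = 0; shard < static_cast<int>(shards->size()); ++shard) {
    f(shard, shards->at(shard));
  }
}

int main() {
  std::vector<int> shards{10, 20, 30};
  ExecuteIndexShards(&shards, [](int idx, int& value) {
    std::printf("shard %d holds %d\n", idx, value);
  });
  return 0;
}

Callers that never needed the shard index, such as the HostDeviceVectorImpl lambdas below, simply ignore the first lambda parameter.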

src/common/hist_util.cu

@@ -358,7 +358,7 @@ struct GPUSketcher {
     });
     // compute sketches for each shard
-    dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
+    dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
       shard->Init(batch, info);
       shard->Sketch(batch, info);
     });

src/common/host_device_vector.cu

@@ -275,7 +275,7 @@ struct HostDeviceVectorImpl {
                              (end - begin) * sizeof(T),
                              cudaMemcpyDeviceToHost));
     } else {
-      dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
+      dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
         shard.ScatterFrom(begin.get());
       });
     }
@@ -288,7 +288,7 @@ struct HostDeviceVectorImpl {
                              data_h_.size() * sizeof(T),
                              cudaMemcpyHostToDevice));
     } else {
-      dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { shard.GatherTo(begin); });
+      dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) { shard.GatherTo(begin); });
     }
   }
@@ -296,7 +296,7 @@ struct HostDeviceVectorImpl {
     if (perm_h_.CanWrite()) {
       std::fill(data_h_.begin(), data_h_.end(), v);
     } else {
-      dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { shard.Fill(v); });
+      dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) { shard.Fill(v); });
     }
   }
@@ -323,7 +323,7 @@ struct HostDeviceVectorImpl {
     if (perm_h_.CanWrite()) {
       std::copy(other.begin(), other.end(), data_h_.begin());
     } else {
-      dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
+      dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
         shard.ScatterFrom(other.data());
       });
     }
@@ -334,7 +334,7 @@ struct HostDeviceVectorImpl {
     if (perm_h_.CanWrite()) {
       std::copy(other.begin(), other.end(), data_h_.begin());
     } else {
-      dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
+      dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
         shard.ScatterFrom(other.begin());
       });
     }
@@ -387,14 +387,14 @@ struct HostDeviceVectorImpl {
     if (perm_h_.CanAccess(access)) { return; }
     if (perm_h_.CanRead()) {
       // data is present, just need to deny access to the device
-      dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
+      dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
         shard.perm_d_.DenyComplementary(access);
       });
       perm_h_.Grant(access);
       return;
     }
     if (data_h_.size() != size_d_) { data_h_.resize(size_d_); }
-    dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
+    dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
       shard.LazySyncHost(access);
     });
     perm_h_.Grant(access);

src/common/timer.h

@@ -45,9 +45,13 @@ struct Timer {
  */
 struct Monitor {
+  struct Statistics {
+    Timer timer;
+    size_t count{0};
+  };
   bool debug_verbose = false;
   std::string label = "";
-  std::map<std::string, Timer> timer_map;
+  std::map<std::string, Statistics> statistics_map;
   Timer self_timer;
   Monitor() { self_timer.Start(); }
@@ -56,35 +60,46 @@ struct Monitor {
     if (!debug_verbose) return;
     LOG(CONSOLE) << "======== Monitor: " << label << " ========";
-    for (auto &kv : timer_map) {
-      kv.second.PrintElapsed(kv.first);
+    for (auto &kv : statistics_map) {
+      LOG(CONSOLE) << kv.first << ": " << kv.second.timer.ElapsedSeconds()
+                   << "s, " << kv.second.count << " calls @ "
+                   << std::chrono::duration_cast<std::chrono::microseconds>(
+                          kv.second.timer.elapsed / kv.second.count)
+                          .count()
+                   << "us";
     }
     self_timer.Stop();
-    self_timer.PrintElapsed(label + " Lifetime");
   }
   void Init(std::string label, bool debug_verbose) {
     this->debug_verbose = debug_verbose;
     this->label = label;
   }
-  void Start(const std::string &name) { timer_map[name].Start(); }
+  void Start(const std::string &name) { statistics_map[name].timer.Start(); }
   void Start(const std::string &name, GPUSet devices) {
     if (debug_verbose) {
 #ifdef __CUDACC__
-#include "device_helpers.cuh"
-      dh::SynchronizeNDevices(devices);
+      for (auto device : devices) {
+        cudaSetDevice(device);
+        cudaDeviceSynchronize();
+      }
 #endif
     }
-    timer_map[name].Start();
+    statistics_map[name].timer.Start();
+  }
+  void Stop(const std::string &name) {
+    statistics_map[name].timer.Stop();
+    statistics_map[name].count++;
   }
-  void Stop(const std::string &name) { timer_map[name].Stop(); }
   void Stop(const std::string &name, GPUSet devices) {
     if (debug_verbose) {
 #ifdef __CUDACC__
-#include "device_helpers.cuh"
-      dh::SynchronizeNDevices(devices);
+      for (auto device : devices) {
+        cudaSetDevice(device);
+        cudaDeviceSynchronize();
+      }
 #endif
     }
-    timer_map[name].Stop();
+    this->Stop(name);
   }
 };
 }  // namespace common
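
The per-label Statistics struct is what enables the new "N calls @ Xus" line in the monitor report: each Stop() both accumulates elapsed time and bumps a call counter, so the mean cost per call is total elapsed divided by count. A self-contained sketch of that bookkeeping, with a toy std::chrono Timer standing in for the one defined earlier in timer.h:

#include <chrono>
#include <cstdio>
#include <map>
#include <string>

// Toy stand-ins for the real Timer/Monitor in timer.h.
struct Timer {
  using Clock = std::chrono::high_resolution_clock;
  Clock::duration elapsed{0};
  Clock::time_point start;
  void Start() { start = Clock::now(); }
  void Stop() { elapsed += Clock::now() - start; }
  double ElapsedSeconds() const {
    return std::chrono::duration<double>(elapsed).count();
  }
};

struct Monitor {
  struct Statistics {
    Timer timer;
    size_t count{0};
  };
  std::map<std::string, Statistics> statistics_map;

  void Start(const std::string& name) { statistics_map[name].timer.Start(); }
  void Stop(const std::string& name) {
    statistics_map[name].timer.Stop();
    statistics_map[name].count++;  // one increment per timed region
  }
  // Prints total time, call count, and mean time per call for each label
  // (assumes every label has been stopped at least once).
  void Report() const {
    for (const auto& kv : statistics_map) {
      auto mean_us = std::chrono::duration_cast<std::chrono::microseconds>(
                         kv.second.timer.elapsed / kv.second.count).count();
      std::printf("%s: %.6fs, %zu calls @ %lldus\n", kv.first.c_str(),
                  kv.second.timer.ElapsedSeconds(), kv.second.count,
                  static_cast<long long>(mean_us));
    }
  }
};

int main() {
  Monitor monitor;
  for (int i = 0; i < 3; ++i) {
    monitor.Start("BuildHist");
    // ... timed work would run here ...
    monitor.Stop("BuildHist");
  }
  monitor.Report();
  return 0;
}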

src/linear/updater_gpu_coordinate.cu

@@ -258,7 +258,7 @@ class GPUCoordinateUpdater : public LinearUpdater {
     monitor.Start("UpdateGpair");
     // Update gpair
-    dh::ExecuteShards(&shards, [&](std::unique_ptr<DeviceShard> &shard) {
+    dh::ExecuteIndexShards(&shards, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
       shard->UpdateGpair(in_gpair->ConstHostVector(), model->param);
     });
     monitor.Stop("UpdateGpair");
@@ -300,7 +300,7 @@ class GPUCoordinateUpdater : public LinearUpdater {
     model->bias()[group_idx] += dbias;
     // Update residual
-    dh::ExecuteShards(&shards, [&](std::unique_ptr<DeviceShard> &shard) {
+    dh::ExecuteIndexShards(&shards, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
       shard->UpdateBiasResidual(dbias, group_idx,
                                 model->param.num_output_group);
     });
@@ -324,7 +324,7 @@ class GPUCoordinateUpdater : public LinearUpdater {
                                     param.reg_lambda_denorm));
     w += dw;
-    dh::ExecuteShards(&shards, [&](std::unique_ptr<DeviceShard> &shard) {
+    dh::ExecuteIndexShards(&shards, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
       shard->UpdateResidual(dw, group_idx, model->param.num_output_group, fidx);
     });
   }

src/predictor/gpu_predictor.cu

@@ -337,10 +337,10 @@ class GPUPredictor : public xgboost::Predictor {
       std::vector<size_t> device_offsets;
       DeviceOffsets(batch.offset, &device_offsets);
       batch.data.Reshard(GPUDistribution::Explicit(devices_, device_offsets));
-      dh::ExecuteShards(&shards, [&](DeviceShard& shard){
-        shard.PredictInternal(batch, dmat->Info(), out_preds, model, h_tree_segments,
-                              h_nodes, tree_begin, tree_end);
+      dh::ExecuteIndexShards(&shards, [&](int idx, DeviceShard& shard) {
+        shard.PredictInternal(batch, dmat->Info(), out_preds, model,
+                              h_tree_segments, h_nodes, tree_begin, tree_end);
       });
       i_batch++;
     }
   }

src/tree/updater_gpu_hist.cu

@@ -587,23 +587,6 @@ struct DeviceShard {
     return best_split;
   }
-  /** \brief Builds both left and right hist with subtraction trick if possible.
-   */
-  void BuildHistWithSubtractionTrick(int nidx_parent, int nidx_left,
-                                     int nidx_right) {
-    auto smallest_nidx =
-        ridx_segments[nidx_left].Size() < ridx_segments[nidx_right].Size()
-            ? nidx_left
-            : nidx_right;
-    auto largest_nidx = smallest_nidx == nidx_left ? nidx_right : nidx_left;
-    this->BuildHist(smallest_nidx);
-    if (this->CanDoSubtractionTrick(nidx_parent, smallest_nidx, largest_nidx)) {
-      this->SubtractionTrick(nidx_parent, smallest_nidx, largest_nidx);
-    } else {
-      this->BuildHist(largest_nidx);
-    }
-  }
   void BuildHist(int nidx) {
     hist.AllocateHistogram(nidx);
     hist_builder->Build(this, nidx);
@@ -954,7 +937,7 @@ class GPUHistMaker : public TreeUpdater {
       device_list_[index] = device_id;
     }
-    reducer_.Init(device_list_);
+    reducer_.Init(device_list_, param_.debug_verbose);
     auto batch_iter = dmat->GetRowBatches().begin();
     const SparsePage& batch = *batch_iter;
@@ -976,7 +959,7 @@ class GPUHistMaker : public TreeUpdater {
     monitor_.Stop("Quantiles", dist_.Devices());
     monitor_.Start("BinningCompression", dist_.Devices());
-    dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
+    dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
       shard->InitCompressedData(hmat_, batch);
     });
     monitor_.Stop("BinningCompression", dist_.Devices());
@@ -1000,7 +983,7 @@ class GPUHistMaker : public TreeUpdater {
     monitor_.Start("InitDataReset", dist_.Devices());
     gpair->Reshard(dist_);
-    dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
+    dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
       shard->Reset(gpair);
     });
     monitor_.Stop("InitDataReset", dist_.Devices());
@@ -1009,34 +992,66 @@ class GPUHistMaker : public TreeUpdater {
   void AllReduceHist(int nidx) {
     if (shards_.size() == 1) return;
-    reducer_.GroupStart();
-    for (auto& shard : shards_) {
+    monitor_.Start("AllReduce");
+    dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
       auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data();
       reducer_.AllReduceSum(
          dist_.Devices().Index(shard->device_id_),
          reinterpret_cast<GradientPairSumT::ValueT*>(d_node_hist),
          reinterpret_cast<GradientPairSumT::ValueT*>(d_node_hist),
          n_bins_ * (sizeof(GradientPairSumT) / sizeof(GradientPairSumT::ValueT)));
-    }
-    reducer_.GroupEnd();
-    reducer_.Synchronize();
+    });
+    monitor_.Stop("AllReduce");
   }
   /**
    * \brief Build GPU local histograms for the left and right child of some parent node
    */
   void BuildHistLeftRight(int nidx_parent, int nidx_left, int nidx_right) {
-    // If one GPU
-    if (shards_.size() == 1) {
-      shards_.back()->BuildHistWithSubtractionTrick(nidx_parent, nidx_left, nidx_right);
-    } else {
-      dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
-        shard->BuildHist(nidx_left);
-        shard->BuildHist(nidx_right);
-      });
-      this->AllReduceHist(nidx_left);
-      this->AllReduceHist(nidx_right);
-    }
+    size_t left_node_max_elements = 0;
+    size_t right_node_max_elements = 0;
+    for (auto& shard : shards_) {
+      left_node_max_elements = (std::max)(
+          left_node_max_elements, shard->ridx_segments[nidx_left].Size());
+      right_node_max_elements = (std::max)(
+          right_node_max_elements, shard->ridx_segments[nidx_right].Size());
+    }
+
+    auto build_hist_nidx = nidx_left;
+    auto subtraction_trick_nidx = nidx_right;
+    if (right_node_max_elements < left_node_max_elements) {
+      build_hist_nidx = nidx_right;
+      subtraction_trick_nidx = nidx_left;
+    }
+
+    // Build histogram for node with the smallest number of training examples
+    dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
+      shard->BuildHist(build_hist_nidx);
+    });
+    this->AllReduceHist(build_hist_nidx);
+
+    // Check whether we can use the subtraction trick to calculate the other
+    bool do_subtraction_trick = true;
+    for (auto& shard : shards_) {
+      do_subtraction_trick &= shard->CanDoSubtractionTrick(
+          nidx_parent, build_hist_nidx, subtraction_trick_nidx);
+    }
+
+    if (do_subtraction_trick) {
+      // Calculate other histogram using subtraction trick
+      dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
+        shard->SubtractionTrick(nidx_parent, build_hist_nidx,
+                                subtraction_trick_nidx);
+      });
+    } else {
+      // Calculate other histogram manually
+      dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
        shard->BuildHist(subtraction_trick_nidx);
+      });
+      this->AllReduceHist(subtraction_trick_nidx);
+    }
   }
@@ -1061,7 +1076,7 @@ class GPUHistMaker : public TreeUpdater {
         std::accumulate(tmp_sums.begin(), tmp_sums.end(), GradientPair());
     // Generate root histogram
-    dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
+    dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
       shard->BuildHist(root_nidx);
     });
@@ -1107,7 +1122,7 @@ class GPUHistMaker : public TreeUpdater {
     }
     auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_;
-    dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
+    dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
       shard->UpdatePosition(nidx, left_nidx, right_nidx, fidx,
                             split_gidx, default_dir_left,
                             is_dense, fidx_begin, fidx_end);
@@ -1153,7 +1168,6 @@ class GPUHistMaker : public TreeUpdater {
       shard->node_sum_gradients[parent.LeftChild()] = candidate.split.left_sum;
       shard->node_sum_gradients[parent.RightChild()] = candidate.split.right_sum;
     }
-    this->UpdatePosition(candidate, p_tree);
   }
   void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
@@ -1175,9 +1189,10 @@ class GPUHistMaker : public TreeUpdater {
       qexpand_->pop();
       if (!candidate.IsValid(param_, num_leaves)) continue;
-      monitor_.Start("ApplySplit", dist_.Devices());
       this->ApplySplit(candidate, p_tree);
-      monitor_.Stop("ApplySplit", dist_.Devices());
+      monitor_.Start("UpdatePosition", dist_.Devices());
+      this->UpdatePosition(candidate, p_tree);
+      monitor_.Stop("UpdatePosition", dist_.Devices());
       num_leaves++;
       int left_child_nidx = tree[candidate.nid].LeftChild();
@@ -1213,7 +1228,7 @@ class GPUHistMaker : public TreeUpdater {
     if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data)
       return false;
     p_out_preds->Reshard(dist_.Devices());
-    dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
+    dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
       shard->UpdatePredictionCache(
           p_out_preds->DevicePointer(shard->device_id_));
     });
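
For context on the new BuildHistLeftRight: gradient histograms are additive over rows, so hist(parent) = hist(left) + hist(right). The code above builds and all-reduces only the child with fewer rows (the cheaper build on every device), then derives the sibling by subtraction; because the parent and the built child are both already globally reduced, the derived histogram needs no second AllReduce. A toy single-bin illustration (GradientPairSum here is a stand-in for the real GradientPairSumT):

#include <cassert>

struct GradientPairSum {
  double grad{0}, hess{0};
};

// hist(sibling) = hist(parent) - hist(built child), applied per bin.
GradientPairSum Subtract(const GradientPairSum& parent,
                         const GradientPairSum& built_child) {
  return {parent.grad - built_child.grad, parent.hess - built_child.hess};
}

int main() {
  GradientPairSum parent{1.5, 8.0};  // already reduced over all rows/devices
  GradientPairSum left{0.5, 3.0};    // smaller child, built explicitly
  GradientPairSum right = Subtract(parent, left);
  assert(right.grad == 1.0 && right.hess == 5.0);
  return 0;
}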

tests/cpp/tree/test_gpu_hist.cu

@@ -375,6 +375,7 @@ TEST(GpuHist, ApplySplit) {
   hist_maker.info_ = &info;
   hist_maker.ApplySplit(candidate_entry, &tree);
+  hist_maker.UpdatePosition(candidate_entry, &tree);
   ASSERT_FALSE(tree[nid].IsLeaf());