GPU performance logging/improvements (#3945)
- Improved GPU performance logging - Only use one execute shards function - Revert performance regression on multi-GPU - Use threads to launch NCCL AllReduce
This commit is contained in:
parent
c5f92df475
commit
a9d684db18
@ -19,6 +19,7 @@
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "timer.h"
|
||||
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
#include "nccl.h"
|
||||
@ -840,14 +841,17 @@ void Gather(int device_idx, T *out, const T *in, const int *instId, int nVals) {
|
||||
*/
|
||||
|
||||
class AllReducer {
|
||||
bool initialised;
|
||||
bool initialised_;
|
||||
bool debug_verbose_;
|
||||
size_t allreduce_bytes_; // Keep statistics of the number of bytes communicated
|
||||
size_t allreduce_calls_; // Keep statistics of the number of reduce calls
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
std::vector<ncclComm_t> comms;
|
||||
std::vector<cudaStream_t> streams;
|
||||
std::vector<int> device_ordinals;
|
||||
#endif
|
||||
public:
|
||||
AllReducer() : initialised(false) {}
|
||||
AllReducer() : initialised_(false),debug_verbose_(false) {}
|
||||
|
||||
/**
|
||||
* \fn void Init(const std::vector<int> &device_ordinals)
|
||||
@ -858,8 +862,10 @@ class AllReducer {
|
||||
* \param device_ordinals The device ordinals.
|
||||
*/
|
||||
|
||||
void Init(const std::vector<int> &device_ordinals) {
|
||||
void Init(const std::vector<int> &device_ordinals, bool debug_verbose) {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
/** \brief this >monitor . init. */
|
||||
this->debug_verbose_ = debug_verbose;
|
||||
this->device_ordinals = device_ordinals;
|
||||
comms.resize(device_ordinals.size());
|
||||
dh::safe_nccl(ncclCommInitAll(comms.data(),
|
||||
@ -870,7 +876,7 @@ class AllReducer {
|
||||
safe_cuda(cudaSetDevice(device_ordinals[i]));
|
||||
safe_cuda(cudaStreamCreate(&streams[i]));
|
||||
}
|
||||
initialised = true;
|
||||
initialised_ = true;
|
||||
#else
|
||||
CHECK_EQ(device_ordinals.size(), 1)
|
||||
<< "XGBoost must be compiled with NCCL to use more than one GPU.";
|
||||
@ -878,7 +884,7 @@ class AllReducer {
|
||||
}
|
||||
~AllReducer() {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
if (initialised) {
|
||||
if (initialised_) {
|
||||
for (auto &stream : streams) {
|
||||
dh::safe_cuda(cudaStreamDestroy(stream));
|
||||
}
|
||||
@ -886,6 +892,11 @@ class AllReducer {
|
||||
ncclCommDestroy(comm);
|
||||
}
|
||||
}
|
||||
if (debug_verbose_) {
|
||||
LOG(CONSOLE) << "======== NCCL Statistics========";
|
||||
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
|
||||
LOG(CONSOLE) << "AllReduce total MB communicated: " << allreduce_bytes_/1000000;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -920,11 +931,16 @@ class AllReducer {
|
||||
void AllReduceSum(int communication_group_idx, const double *sendbuff,
|
||||
double *recvbuff, int count) {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
CHECK(initialised);
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinals.at(communication_group_idx)));
|
||||
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum,
|
||||
comms.at(communication_group_idx),
|
||||
streams.at(communication_group_idx)));
|
||||
if(communication_group_idx == 0)
|
||||
{
|
||||
allreduce_bytes_ += count * sizeof(double);
|
||||
allreduce_calls_ += 1;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -942,7 +958,7 @@ class AllReducer {
|
||||
void AllReduceSum(int communication_group_idx, const int64_t *sendbuff,
|
||||
int64_t *recvbuff, int count) {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
CHECK(initialised);
|
||||
CHECK(initialised_);
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinals[communication_group_idx]));
|
||||
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclInt64, ncclSum,
|
||||
@ -989,27 +1005,6 @@ class SaveCudaContext {
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Executes some operation on each element of the input vector, using a
|
||||
* single controlling thread for each element.
|
||||
*
|
||||
* \tparam T Generic type parameter.
|
||||
* \tparam FunctionT Type of the function t.
|
||||
* \param shards The shards.
|
||||
* \param f The func_t to process.
|
||||
*/
|
||||
|
||||
template <typename T, typename FunctionT>
|
||||
void ExecuteShards(std::vector<T> *shards, FunctionT f) {
|
||||
SaveCudaContext {
|
||||
[&](){
|
||||
#pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
|
||||
for (int shard = 0; shard < shards->size(); ++shard) {
|
||||
f(shards->at(shard));
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Executes some operation on each element of the input vector, using a
|
||||
* single controlling thread for each element. In addition, passes the shard index
|
||||
@ -1023,13 +1018,12 @@ void ExecuteShards(std::vector<T> *shards, FunctionT f) {
|
||||
|
||||
template <typename T, typename FunctionT>
|
||||
void ExecuteIndexShards(std::vector<T> *shards, FunctionT f) {
|
||||
SaveCudaContext {
|
||||
[&](){
|
||||
SaveCudaContext{[&]() {
|
||||
#pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
|
||||
for (int shard = 0; shard < shards->size(); ++shard) {
|
||||
f(shard, shards->at(shard));
|
||||
}
|
||||
}};
|
||||
for (int shard = 0; shard < shards->size(); ++shard) {
|
||||
f(shard, shards->at(shard));
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -358,7 +358,7 @@ struct GPUSketcher {
|
||||
});
|
||||
|
||||
// compute sketches for each shard
|
||||
dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
|
||||
shard->Init(batch, info);
|
||||
shard->Sketch(batch, info);
|
||||
});
|
||||
|
||||
@ -275,7 +275,7 @@ struct HostDeviceVectorImpl {
|
||||
(end - begin) * sizeof(T),
|
||||
cudaMemcpyDeviceToHost));
|
||||
} else {
|
||||
dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
|
||||
shard.ScatterFrom(begin.get());
|
||||
});
|
||||
}
|
||||
@ -288,7 +288,7 @@ struct HostDeviceVectorImpl {
|
||||
data_h_.size() * sizeof(T),
|
||||
cudaMemcpyHostToDevice));
|
||||
} else {
|
||||
dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { shard.GatherTo(begin); });
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) { shard.GatherTo(begin); });
|
||||
}
|
||||
}
|
||||
|
||||
@ -296,7 +296,7 @@ struct HostDeviceVectorImpl {
|
||||
if (perm_h_.CanWrite()) {
|
||||
std::fill(data_h_.begin(), data_h_.end(), v);
|
||||
} else {
|
||||
dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { shard.Fill(v); });
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) { shard.Fill(v); });
|
||||
}
|
||||
}
|
||||
|
||||
@ -323,7 +323,7 @@ struct HostDeviceVectorImpl {
|
||||
if (perm_h_.CanWrite()) {
|
||||
std::copy(other.begin(), other.end(), data_h_.begin());
|
||||
} else {
|
||||
dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
|
||||
shard.ScatterFrom(other.data());
|
||||
});
|
||||
}
|
||||
@ -334,7 +334,7 @@ struct HostDeviceVectorImpl {
|
||||
if (perm_h_.CanWrite()) {
|
||||
std::copy(other.begin(), other.end(), data_h_.begin());
|
||||
} else {
|
||||
dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
|
||||
shard.ScatterFrom(other.begin());
|
||||
});
|
||||
}
|
||||
@ -387,14 +387,14 @@ struct HostDeviceVectorImpl {
|
||||
if (perm_h_.CanAccess(access)) { return; }
|
||||
if (perm_h_.CanRead()) {
|
||||
// data is present, just need to deny access to the device
|
||||
dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
|
||||
shard.perm_d_.DenyComplementary(access);
|
||||
});
|
||||
perm_h_.Grant(access);
|
||||
return;
|
||||
}
|
||||
if (data_h_.size() != size_d_) { data_h_.resize(size_d_); }
|
||||
dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
|
||||
shard.LazySyncHost(access);
|
||||
});
|
||||
perm_h_.Grant(access);
|
||||
|
||||
@ -45,9 +45,13 @@ struct Timer {
|
||||
*/
|
||||
|
||||
struct Monitor {
|
||||
struct Statistics {
|
||||
Timer timer;
|
||||
size_t count{0};
|
||||
};
|
||||
bool debug_verbose = false;
|
||||
std::string label = "";
|
||||
std::map<std::string, Timer> timer_map;
|
||||
std::map<std::string, Statistics> statistics_map;
|
||||
Timer self_timer;
|
||||
|
||||
Monitor() { self_timer.Start(); }
|
||||
@ -56,35 +60,46 @@ struct Monitor {
|
||||
if (!debug_verbose) return;
|
||||
|
||||
LOG(CONSOLE) << "======== Monitor: " << label << " ========";
|
||||
for (auto &kv : timer_map) {
|
||||
kv.second.PrintElapsed(kv.first);
|
||||
for (auto &kv : statistics_map) {
|
||||
LOG(CONSOLE) << kv.first << ": " << kv.second.timer.ElapsedSeconds()
|
||||
<< "s, " << kv.second.count << " calls @ "
|
||||
<< std::chrono::duration_cast<std::chrono::microseconds>(
|
||||
kv.second.timer.elapsed / kv.second.count)
|
||||
.count()
|
||||
<< "us";
|
||||
}
|
||||
self_timer.Stop();
|
||||
self_timer.PrintElapsed(label + " Lifetime");
|
||||
}
|
||||
void Init(std::string label, bool debug_verbose) {
|
||||
this->debug_verbose = debug_verbose;
|
||||
this->label = label;
|
||||
}
|
||||
void Start(const std::string &name) { timer_map[name].Start(); }
|
||||
void Start(const std::string &name) { statistics_map[name].timer.Start(); }
|
||||
void Start(const std::string &name, GPUSet devices) {
|
||||
if (debug_verbose) {
|
||||
#ifdef __CUDACC__
|
||||
#include "device_helpers.cuh"
|
||||
dh::SynchronizeNDevices(devices);
|
||||
for (auto device : devices) {
|
||||
cudaSetDevice(device);
|
||||
cudaDeviceSynchronize();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
timer_map[name].Start();
|
||||
statistics_map[name].timer.Start();
|
||||
}
|
||||
void Stop(const std::string &name) {
|
||||
statistics_map[name].timer.Stop();
|
||||
statistics_map[name].count++;
|
||||
}
|
||||
void Stop(const std::string &name) { timer_map[name].Stop(); }
|
||||
void Stop(const std::string &name, GPUSet devices) {
|
||||
if (debug_verbose) {
|
||||
#ifdef __CUDACC__
|
||||
#include "device_helpers.cuh"
|
||||
dh::SynchronizeNDevices(devices);
|
||||
for (auto device : devices) {
|
||||
cudaSetDevice(device);
|
||||
cudaDeviceSynchronize();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
timer_map[name].Stop();
|
||||
this->Stop(name);
|
||||
}
|
||||
};
|
||||
} // namespace common
|
||||
|
||||
@ -258,7 +258,7 @@ class GPUCoordinateUpdater : public LinearUpdater {
|
||||
|
||||
monitor.Start("UpdateGpair");
|
||||
// Update gpair
|
||||
dh::ExecuteShards(&shards, [&](std::unique_ptr<DeviceShard> &shard) {
|
||||
dh::ExecuteIndexShards(&shards, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
|
||||
shard->UpdateGpair(in_gpair->ConstHostVector(), model->param);
|
||||
});
|
||||
monitor.Stop("UpdateGpair");
|
||||
@ -300,7 +300,7 @@ class GPUCoordinateUpdater : public LinearUpdater {
|
||||
model->bias()[group_idx] += dbias;
|
||||
|
||||
// Update residual
|
||||
dh::ExecuteShards(&shards, [&](std::unique_ptr<DeviceShard> &shard) {
|
||||
dh::ExecuteIndexShards(&shards, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
|
||||
shard->UpdateBiasResidual(dbias, group_idx,
|
||||
model->param.num_output_group);
|
||||
});
|
||||
@ -324,7 +324,7 @@ class GPUCoordinateUpdater : public LinearUpdater {
|
||||
param.reg_lambda_denorm));
|
||||
w += dw;
|
||||
|
||||
dh::ExecuteShards(&shards, [&](std::unique_ptr<DeviceShard> &shard) {
|
||||
dh::ExecuteIndexShards(&shards, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
|
||||
shard->UpdateResidual(dw, group_idx, model->param.num_output_group, fidx);
|
||||
});
|
||||
}
|
||||
|
||||
@ -337,10 +337,10 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
std::vector<size_t> device_offsets;
|
||||
DeviceOffsets(batch.offset, &device_offsets);
|
||||
batch.data.Reshard(GPUDistribution::Explicit(devices_, device_offsets));
|
||||
dh::ExecuteShards(&shards, [&](DeviceShard& shard){
|
||||
shard.PredictInternal(batch, dmat->Info(), out_preds, model, h_tree_segments,
|
||||
h_nodes, tree_begin, tree_end);
|
||||
});
|
||||
dh::ExecuteIndexShards(&shards, [&](int idx, DeviceShard& shard) {
|
||||
shard.PredictInternal(batch, dmat->Info(), out_preds, model,
|
||||
h_tree_segments, h_nodes, tree_begin, tree_end);
|
||||
});
|
||||
i_batch++;
|
||||
}
|
||||
}
|
||||
|
||||
@ -587,23 +587,6 @@ struct DeviceShard {
|
||||
return best_split;
|
||||
}
|
||||
|
||||
/** \brief Builds both left and right hist with subtraction trick if possible.
|
||||
*/
|
||||
void BuildHistWithSubtractionTrick(int nidx_parent, int nidx_left,
|
||||
int nidx_right) {
|
||||
auto smallest_nidx =
|
||||
ridx_segments[nidx_left].Size() < ridx_segments[nidx_right].Size()
|
||||
? nidx_left
|
||||
: nidx_right;
|
||||
auto largest_nidx = smallest_nidx == nidx_left ? nidx_right : nidx_left;
|
||||
this->BuildHist(smallest_nidx);
|
||||
if (this->CanDoSubtractionTrick(nidx_parent, smallest_nidx, largest_nidx)) {
|
||||
this->SubtractionTrick(nidx_parent, smallest_nidx, largest_nidx);
|
||||
} else {
|
||||
this->BuildHist(largest_nidx);
|
||||
}
|
||||
}
|
||||
|
||||
void BuildHist(int nidx) {
|
||||
hist.AllocateHistogram(nidx);
|
||||
hist_builder->Build(this, nidx);
|
||||
@ -954,7 +937,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
device_list_[index] = device_id;
|
||||
}
|
||||
|
||||
reducer_.Init(device_list_);
|
||||
reducer_.Init(device_list_, param_.debug_verbose);
|
||||
|
||||
auto batch_iter = dmat->GetRowBatches().begin();
|
||||
const SparsePage& batch = *batch_iter;
|
||||
@ -976,7 +959,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
monitor_.Stop("Quantiles", dist_.Devices());
|
||||
|
||||
monitor_.Start("BinningCompression", dist_.Devices());
|
||||
dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
|
||||
shard->InitCompressedData(hmat_, batch);
|
||||
});
|
||||
monitor_.Stop("BinningCompression", dist_.Devices());
|
||||
@ -1000,7 +983,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
monitor_.Start("InitDataReset", dist_.Devices());
|
||||
|
||||
gpair->Reshard(dist_);
|
||||
dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
|
||||
shard->Reset(gpair);
|
||||
});
|
||||
monitor_.Stop("InitDataReset", dist_.Devices());
|
||||
@ -1009,34 +992,66 @@ class GPUHistMaker : public TreeUpdater {
|
||||
void AllReduceHist(int nidx) {
|
||||
if (shards_.size() == 1) return;
|
||||
|
||||
reducer_.GroupStart();
|
||||
for (auto& shard : shards_) {
|
||||
monitor_.Start("AllReduce");
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
|
||||
auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data();
|
||||
reducer_.AllReduceSum(
|
||||
dist_.Devices().Index(shard->device_id_),
|
||||
reinterpret_cast<GradientPairSumT::ValueT*>(d_node_hist),
|
||||
reinterpret_cast<GradientPairSumT::ValueT*>(d_node_hist),
|
||||
n_bins_ * (sizeof(GradientPairSumT) / sizeof(GradientPairSumT::ValueT)));
|
||||
}
|
||||
reducer_.GroupEnd();
|
||||
|
||||
reducer_.Synchronize();
|
||||
});
|
||||
monitor_.Stop("AllReduce");
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Build GPU local histograms for the left and right child of some parent node
|
||||
*/
|
||||
void BuildHistLeftRight(int nidx_parent, int nidx_left, int nidx_right) {
|
||||
// If one GPU
|
||||
if (shards_.size() == 1) {
|
||||
shards_.back()->BuildHistWithSubtractionTrick(nidx_parent, nidx_left, nidx_right);
|
||||
} else {
|
||||
dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
|
||||
shard->BuildHist(nidx_left);
|
||||
shard->BuildHist(nidx_right);
|
||||
size_t left_node_max_elements = 0;
|
||||
size_t right_node_max_elements = 0;
|
||||
for (auto& shard : shards_) {
|
||||
left_node_max_elements = (std::max)(
|
||||
left_node_max_elements, shard->ridx_segments[nidx_left].Size());
|
||||
right_node_max_elements = (std::max)(
|
||||
right_node_max_elements, shard->ridx_segments[nidx_right].Size());
|
||||
}
|
||||
|
||||
auto build_hist_nidx = nidx_left;
|
||||
auto subtraction_trick_nidx = nidx_right;
|
||||
|
||||
if (right_node_max_elements < left_node_max_elements) {
|
||||
build_hist_nidx = nidx_right;
|
||||
subtraction_trick_nidx = nidx_left;
|
||||
}
|
||||
|
||||
// Build histogram for node with the smallest number of training examples
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
|
||||
shard->BuildHist(build_hist_nidx);
|
||||
});
|
||||
|
||||
this->AllReduceHist(build_hist_nidx);
|
||||
|
||||
// Check whether we can use the subtraction trick to calculate the other
|
||||
bool do_subtraction_trick = true;
|
||||
for (auto& shard : shards_) {
|
||||
do_subtraction_trick &= shard->CanDoSubtractionTrick(
|
||||
nidx_parent, build_hist_nidx, subtraction_trick_nidx);
|
||||
}
|
||||
|
||||
if (do_subtraction_trick) {
|
||||
// Calculate other histogram using subtraction trick
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
|
||||
shard->SubtractionTrick(nidx_parent, build_hist_nidx,
|
||||
subtraction_trick_nidx);
|
||||
});
|
||||
this->AllReduceHist(nidx_left);
|
||||
this->AllReduceHist(nidx_right);
|
||||
} else {
|
||||
// Calculate other histogram manually
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
|
||||
shard->BuildHist(subtraction_trick_nidx);
|
||||
});
|
||||
|
||||
this->AllReduceHist(subtraction_trick_nidx);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1061,7 +1076,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
std::accumulate(tmp_sums.begin(), tmp_sums.end(), GradientPair());
|
||||
|
||||
// Generate root histogram
|
||||
dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
|
||||
shard->BuildHist(root_nidx);
|
||||
});
|
||||
|
||||
@ -1107,7 +1122,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
}
|
||||
auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_;
|
||||
|
||||
dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
|
||||
shard->UpdatePosition(nidx, left_nidx, right_nidx, fidx,
|
||||
split_gidx, default_dir_left,
|
||||
is_dense, fidx_begin, fidx_end);
|
||||
@ -1153,7 +1168,6 @@ class GPUHistMaker : public TreeUpdater {
|
||||
shard->node_sum_gradients[parent.LeftChild()] = candidate.split.left_sum;
|
||||
shard->node_sum_gradients[parent.RightChild()] = candidate.split.right_sum;
|
||||
}
|
||||
this->UpdatePosition(candidate, p_tree);
|
||||
}
|
||||
|
||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
|
||||
@ -1175,9 +1189,10 @@ class GPUHistMaker : public TreeUpdater {
|
||||
qexpand_->pop();
|
||||
if (!candidate.IsValid(param_, num_leaves)) continue;
|
||||
|
||||
monitor_.Start("ApplySplit", dist_.Devices());
|
||||
this->ApplySplit(candidate, p_tree);
|
||||
monitor_.Stop("ApplySplit", dist_.Devices());
|
||||
monitor_.Start("UpdatePosition", dist_.Devices());
|
||||
this->UpdatePosition(candidate, p_tree);
|
||||
monitor_.Stop("UpdatePosition", dist_.Devices());
|
||||
num_leaves++;
|
||||
|
||||
int left_child_nidx = tree[candidate.nid].LeftChild();
|
||||
@ -1213,7 +1228,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data)
|
||||
return false;
|
||||
p_out_preds->Reshard(dist_.Devices());
|
||||
dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
|
||||
shard->UpdatePredictionCache(
|
||||
p_out_preds->DevicePointer(shard->device_id_));
|
||||
});
|
||||
|
||||
@ -375,6 +375,7 @@ TEST(GpuHist, ApplySplit) {
|
||||
|
||||
hist_maker.info_ = &info;
|
||||
hist_maker.ApplySplit(candidate_entry, &tree);
|
||||
hist_maker.UpdatePosition(candidate_entry, &tree);
|
||||
|
||||
ASSERT_FALSE(tree[nid].IsLeaf());
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user