Implement unified update prediction cache for (gpu_)hist. (#6860)

* Implement utilities for linalg.
* Unify the update prediction cache functions.
* Implement update prediction cache for multi-class gpu hist.
Author:    Jiaming Yuan
Date:      2021-04-17 00:29:34 +08:00 (committed via GitHub)
Parent:    1b26a2a561
Commit:    556a83022d

10 changed files with 246 additions and 68 deletions
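The linalg utility at the center of this change is a non-owning vector view that can be passed across the updater interface without copying data or tying the callee to HostDeviceVector. A rough sketch of the shape such a type takes is below; the member names, constructor, and the SPAN_DEVICE macro are illustrative stand-ins, not the actual xgboost::VectorView definition.

    #include <cstddef>

    #if defined(__CUDACC__)
    #define SPAN_DEVICE __host__ __device__  // stand-in for a host/device macro
    #else
    #define SPAN_DEVICE
    #endif

    template <typename T>
    class VectorView {
      T* ptr_{nullptr};
      std::size_t size_{0};
      int device_{-1};  // CUDA ordinal of the memory; -1 means host

     public:
      SPAN_DEVICE VectorView(T* ptr, std::size_t size, int device)
          : ptr_{ptr}, size_{size}, device_{device} {}

      SPAN_DEVICE T& operator[](std::size_t i) { return ptr_[i]; }
      SPAN_DEVICE std::size_t Size() const { return size_; }
      int DeviceIdx() const { return device_; }  // checked by CHECK_EQ below
    };

Because the view records which device its memory lives on, callees can assert placement instead of silently dereferencing the wrong address space, which is what the CHECK_EQ added in the second hunk does.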


@@ -273,9 +273,9 @@ struct GPUHistMakerDevice {
     if (d_gpair.size() != dh_gpair->Size()) {
       d_gpair.resize(dh_gpair->Size());
     }
-    thrust::copy(thrust::device, dh_gpair->ConstDevicePointer(),
-                 dh_gpair->ConstDevicePointer() + dh_gpair->Size(),
-                 d_gpair.begin());
+    dh::safe_cuda(cudaMemcpyAsync(
+        d_gpair.data().get(), dh_gpair->ConstDevicePointer(),
+        dh_gpair->Size() * sizeof(GradientPair), cudaMemcpyDeviceToDevice));
     auto sample = sampler->Sample(dh::ToSpan(d_gpair), dmat);
     page = sample.page;
     gpair = sample.gpair;
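Swapping thrust::copy for cudaMemcpyAsync keeps the copy device-to-device but returns control to the host immediately, instead of blocking the way the synchronous Thrust algorithm does. A standalone sketch of the pattern; SAFE_CUDA is a stand-in for xgboost's dh::safe_cuda wrapper, not its real definition.

    #include <cstdio>
    #include <cstdlib>
    #include <cuda_runtime.h>
    #include <thrust/device_vector.h>

    // Stand-in for dh::safe_cuda: abort on any CUDA error.
    #define SAFE_CUDA(call)                                            \
      do {                                                             \
        cudaError_t err = (call);                                      \
        if (err != cudaSuccess) {                                      \
          std::fprintf(stderr, "CUDA: %s\n", cudaGetErrorString(err)); \
          std::abort();                                                \
        }                                                              \
      } while (0)

    int main() {
      thrust::device_vector<float> src(1024, 1.0f);
      thrust::device_vector<float> dst(src.size());
      // Device-to-device copy on the default stream: unlike the
      // synchronous thrust::copy it replaces, this call does not
      // block the host.
      SAFE_CUDA(cudaMemcpyAsync(dst.data().get(), src.data().get(),
                                src.size() * sizeof(float),
                                cudaMemcpyDeviceToDevice));
      SAFE_CUDA(cudaDeviceSynchronize());  // wait before reading dst
      return 0;
    }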
@@ -528,8 +528,9 @@ struct GPUHistMakerDevice {
     });
   }
-  void UpdatePredictionCache(common::Span<bst_float> out_preds_d) {
+  void UpdatePredictionCache(VectorView<float> out_preds_d) {
     dh::safe_cuda(cudaSetDevice(device_id));
+    CHECK_EQ(out_preds_d.DeviceIdx(), device_id);
     auto d_ridx = row_partitioner->GetRows();
     GPUTrainingParam param_d(param);
@@ -543,14 +544,14 @@ struct GPUHistMakerDevice {
     auto d_node_sum_gradients = device_node_sum_gradients.data().get();
     auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
-    dh::LaunchN(
-        device_id, out_preds_d.size(), [=] __device__(int local_idx) {
-          int pos = d_position[local_idx];
-          bst_float weight = evaluator.CalcWeight(pos, param_d,
-                                                  GradStats{d_node_sum_gradients[pos]});
-          out_preds_d[d_ridx[local_idx]] +=
-              weight * param_d.learning_rate;
-        });
+    dh::LaunchN(device_id, d_ridx.size(), [=] __device__(int local_idx) {
+      int pos = d_position[local_idx];
+      bst_float weight = evaluator.CalcWeight(
+          pos, param_d, GradStats{d_node_sum_gradients[pos]});
+      static_assert(!std::is_const<decltype(out_preds_d)>::value, "");
+      auto v_predt = out_preds_d;  // for some reason out_preds_d is const by both nvcc and clang.
+      v_predt[d_ridx[local_idx]] += weight * param_d.learning_rate;
+    });
     row_partitioner.reset();
   }
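The rewritten launch iterates over the sampled rows (d_ridx.size()) rather than the whole prediction vector, which matters once row sampling makes the two sizes differ. The local v_predt copy works around by-value captures being const inside a non-mutable device lambda: the captured view cannot be written through, but a local copy of it can. A simplified stand-in for the dh::LaunchN lambda follows; buffer names and the toy main are illustrative, and the leaf weights are precomputed here instead of coming from evaluator.CalcWeight.

    #include <cstddef>
    #include <cuda_runtime.h>

    // One thread per sampled row: look up the leaf the row landed in and
    // fold that leaf's weight, scaled by the learning rate, into the
    // cached prediction for the row.
    __global__ void UpdatePredictionCacheKernel(
        float* out_preds, const unsigned* d_ridx, const int* d_position,
        const float* leaf_weight, float learning_rate, std::size_t n_rows) {
      std::size_t i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i >= n_rows) return;
      int leaf = d_position[i];  // node index this row ended up in
      out_preds[d_ridx[i]] += leaf_weight[leaf] * learning_rate;
    }

    int main() {
      std::size_t n = 4;
      float* preds;  unsigned* ridx;  int* pos;  float* weight;
      cudaMallocManaged(&preds, n * sizeof(float));
      cudaMallocManaged(&ridx, n * sizeof(unsigned));
      cudaMallocManaged(&pos, n * sizeof(int));
      cudaMallocManaged(&weight, 2 * sizeof(float));
      for (std::size_t i = 0; i < n; ++i) {
        preds[i] = 0.5f;  ridx[i] = i;  pos[i] = i % 2;
      }
      weight[0] = 1.0f;  weight[1] = -1.0f;
      UpdatePredictionCacheKernel<<<1, 32>>>(preds, ridx, pos, weight, 0.3f, n);
      cudaDeviceSynchronize();
      return 0;
    }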
@@ -862,13 +863,12 @@ class GPUHistMakerSpecialised {
     maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
   }
-  bool UpdatePredictionCache(const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
+  bool UpdatePredictionCache(const DMatrix* data, VectorView<bst_float> p_out_preds) {
     if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
       return false;
     }
     monitor_.Start("UpdatePredictionCache");
-    p_out_preds->SetDevice(device_);
-    maker->UpdatePredictionCache(p_out_preds->DeviceSpan());
+    maker->UpdatePredictionCache(p_out_preds);
     monitor_.Stop("UpdatePredictionCache");
     return true;
   }
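Taking a VectorView by value instead of a HostDeviceVector<bst_float>* shifts responsibility for device placement to the caller, which is why the SetDevice call is dropped here, and makes the parameter a cheap pointer-plus-length copy. A minimal host-side illustration of that calling convention; the View and Update names are invented for the example.

    #include <cstddef>
    #include <vector>

    // Host-side toy version of the view: just a pointer and a length.
    template <typename T>
    struct View {
      T* ptr;
      std::size_t size;
      T& operator[](std::size_t i) { return ptr[i]; }
    };

    // The callee mutates the caller's buffer through the view; copying
    // the view itself costs nothing, and ownership never changes hands.
    void Update(View<float> preds, float eta) {
      for (std::size_t i = 0; i < preds.size; ++i) preds[i] += eta;
    }

    int main() {
      std::vector<float> cache(8, 0.5f);
      Update(View<float>{cache.data(), cache.size()}, 0.3f);
      return 0;
    }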
@@ -947,8 +947,8 @@ class GPUHistMaker : public TreeUpdater {
     }
   }
-  bool UpdatePredictionCache(
-      const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) override {
+  bool UpdatePredictionCache(const DMatrix *data,
+                             VectorView<bst_float> p_out_preds) override {
     if (hist_maker_param_.single_precision_histogram) {
       return float_maker_->UpdatePredictionCache(data, p_out_preds);
     } else {