Fix dart inplace prediction with GPU input. (#6777)
* Fix dart inplace predict with data on GPU, which might trigger a fatal check for device access right. * Avoid copying data whenever possible.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014-2020 by Contributors
|
||||
* Copyright 2014-2021 by Contributors
|
||||
* \file gbtree.cc
|
||||
* \brief gradient boosted tree implementation.
|
||||
* \author Tianqi Chen
|
||||
@@ -558,6 +558,23 @@ GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
|
||||
return cpu_predictor_;
|
||||
}
|
||||
|
||||
/** Increment the prediction on GPU.
|
||||
*
|
||||
* \param out_predts Prediction for the whole model.
|
||||
* \param predts Prediction for current tree.
|
||||
* \param tree_w Tree weight.
|
||||
*/
|
||||
void GPUDartPredictInc(common::Span<float> out_predts,
|
||||
common::Span<float> predts, float tree_w, size_t n_rows,
|
||||
bst_group_t n_groups, bst_group_t group)
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
; // NOLINT
|
||||
#else
|
||||
{
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
#endif
|
||||
|
||||
class Dart : public GBTree {
|
||||
public:
|
||||
explicit Dart(LearnerModelParam const* booster_config) :
|
||||
@@ -647,31 +664,46 @@ class Dart : public GBTree {
|
||||
model_);
|
||||
p_out_preds->version = 0;
|
||||
uint32_t tree_begin, tree_end;
|
||||
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
|
||||
std::tie(tree_begin, tree_end) =
|
||||
detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
|
||||
auto n_groups = model_.learner_model_param->num_output_group;
|
||||
|
||||
PredictionCacheEntry predts; // temporary storage for prediction
|
||||
if (generic_param_->gpu_id != GenericParameter::kCpuId) {
|
||||
predts.predictions.SetDevice(generic_param_->gpu_id);
|
||||
}
|
||||
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
|
||||
|
||||
for (size_t i = tree_begin; i < tree_end; i += 1) {
|
||||
if (training &&
|
||||
std::binary_search(idx_drop_.cbegin(), idx_drop_.cend(), i)) {
|
||||
if (training && std::binary_search(idx_drop_.cbegin(), idx_drop_.cend(), i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
CHECK_GE(i, p_out_preds->version);
|
||||
auto version = i / this->LayerTrees();
|
||||
p_out_preds->version = version;
|
||||
|
||||
auto n_groups = model_.learner_model_param->num_output_group;
|
||||
PredictionCacheEntry predts;
|
||||
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
|
||||
predts.predictions.Fill(0);
|
||||
predictor->PredictBatch(p_fmat, &predts, model_, i, i + 1);
|
||||
|
||||
// Multiple the weight to output prediction.
|
||||
auto w = this->weight_drop_.at(i);
|
||||
auto &h_predts = predts.predictions.HostVector();
|
||||
auto group = model_.tree_info.at(i);
|
||||
auto &h_out_predts = p_out_preds->predictions.HostVector();
|
||||
CHECK_EQ(h_out_predts.size(), h_predts.size());
|
||||
for (size_t ridx = 0; ridx < p_fmat->Info().num_row_; ++ridx) {
|
||||
const size_t offset = ridx * n_groups + group;
|
||||
h_out_predts[offset] += (h_predts[offset] * w);
|
||||
CHECK_EQ(p_out_preds->predictions.Size(), predts.predictions.Size());
|
||||
|
||||
size_t n_rows = p_fmat->Info().num_row_;
|
||||
if (predts.predictions.DeviceIdx() != GenericParameter::kCpuId) {
|
||||
p_out_preds->predictions.SetDevice(predts.predictions.DeviceIdx());
|
||||
GPUDartPredictInc(p_out_preds->predictions.DeviceSpan(),
|
||||
predts.predictions.DeviceSpan(), w, n_rows, n_groups,
|
||||
group);
|
||||
} else {
|
||||
auto &h_out_predts = p_out_preds->predictions.HostVector();
|
||||
auto &h_predts = predts.predictions.HostVector();
|
||||
#pragma omp parallel for
|
||||
for (omp_ulong ridx = 0; ridx < p_fmat->Info().num_row_; ++ridx) {
|
||||
const size_t offset = ridx * n_groups + group;
|
||||
h_out_predts[offset] += (h_predts[offset] * w);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -699,6 +731,7 @@ class Dart : public GBTree {
|
||||
|
||||
MetaInfo info;
|
||||
StringView msg{"Unsupported data type for inplace predict."};
|
||||
int32_t device = GenericParameter::kCpuId;
|
||||
// Inplace predict is not used for training, so no need to drop tree.
|
||||
for (size_t i = tree_begin; i < tree_end; ++i) {
|
||||
PredictionCacheEntry predts;
|
||||
@@ -709,6 +742,9 @@ class Dart : public GBTree {
|
||||
if (p && p->InplacePredict(x, nullptr, model_, missing, &predts, i,
|
||||
i + 1)) {
|
||||
success = true;
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
device = predts.predictions.DeviceIdx();
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -716,14 +752,16 @@ class Dart : public GBTree {
|
||||
} else {
|
||||
// No base margin for each tree
|
||||
bool success = this->GetPredictor()->InplacePredict(
|
||||
x, nullptr, model_, missing, &predts, tree_begin, tree_end);
|
||||
x, nullptr, model_, missing, &predts, i, i + 1);
|
||||
device = predts.predictions.DeviceIdx();
|
||||
CHECK(success) << msg;
|
||||
}
|
||||
|
||||
auto w = this->weight_drop_.at(i);
|
||||
auto &h_predts = predts.predictions.HostVector();
|
||||
auto &h_out_predts = out_preds->predictions.HostVector();
|
||||
if (h_out_predts.empty()) {
|
||||
|
||||
if (i == tree_begin) {
|
||||
auto n_rows =
|
||||
h_predts.size() / model_.learner_model_param->num_output_group;
|
||||
if (p_m) {
|
||||
@@ -739,12 +777,19 @@ class Dart : public GBTree {
|
||||
|
||||
// Multiple the tree weight
|
||||
CHECK_EQ(h_predts.size(), h_out_predts.size());
|
||||
for (size_t i = 0; i < h_out_predts.size(); ++i) {
|
||||
|
||||
#pragma omp parallel for
|
||||
for (omp_ulong i = 0; i < h_out_predts.size(); ++i) {
|
||||
// Need to remove the base margin from indiviual tree.
|
||||
h_out_predts[i] +=
|
||||
(h_predts[i] - model_.learner_model_param->base_score) * w;
|
||||
}
|
||||
}
|
||||
|
||||
if (device != GenericParameter::kCpuId) {
|
||||
out_preds->predictions.SetDevice(device);
|
||||
out_preds->predictions.DeviceSpan();
|
||||
}
|
||||
}
|
||||
|
||||
void PredictInstance(const SparsePage::Inst &inst,
|
||||
|
||||
18
src/gbm/gbtree.cu
Normal file
18
src/gbm/gbtree.cu
Normal file
@@ -0,0 +1,18 @@
|
||||
/*!
|
||||
* Copyright 2021 by Contributors
|
||||
*/
|
||||
#include "xgboost/span.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace gbm {
|
||||
void GPUDartPredictInc(common::Span<float> out_predts,
|
||||
common::Span<float> predts, float tree_w, size_t n_rows,
|
||||
bst_group_t n_groups, bst_group_t group) {
|
||||
dh::LaunchN(dh::CurrentDevice(), n_rows, [=]XGBOOST_DEVICE(size_t ridx) {
|
||||
const size_t offset = ridx * n_groups + group;
|
||||
out_predts[offset] += (predts[offset] * tree_w);
|
||||
});
|
||||
}
|
||||
} // namespace gbm
|
||||
} // namespace xgboost
|
||||
Reference in New Issue
Block a user