Fix dart inplace prediction with GPU input. (#6777)

* Fix dart inplace predict with data on GPU, which might trigger a fatal check
for device access right.
* Avoid copying data whenever possible.
This commit is contained in:
Jiaming Yuan 2021-03-25 12:00:32 +08:00 committed by GitHub
parent 1d90577800
commit a7083d3c13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 135 additions and 25 deletions

View File

@ -185,9 +185,15 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
size_t begin, size_t end, SketchContainer *sketch_container,
int num_cuts_per_feature, size_t num_columns) {
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
host_data.begin() + end);
dh::device_vector<Entry> sorted_entries;
if (page.data.DeviceCanRead()) {
const auto& device_data = page.data.ConstDevicePointer();
sorted_entries = dh::device_vector<Entry>(device_data + begin, device_data + end);
} else {
const auto& host_data = page.data.ConstHostVector();
sorted_entries = dh::device_vector<Entry>(host_data.begin() + begin,
host_data.begin() + end);
}
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), detail::EntryCompareOp());

View File

@ -92,7 +92,10 @@ class HostDeviceVectorImpl {
} else {
gpu_access_ = GPUAccess::kWrite;
SetDevice();
thrust::fill(data_d_->begin(), data_d_->end(), v);
auto s_data = dh::ToSpan(*data_d_);
dh::LaunchN(device_, data_d_->size(), [=]XGBOOST_DEVICE(size_t i) {
s_data[i] = v;
});
}
}

View File

@ -407,7 +407,6 @@ void EllpackPageImpl::CreateHistIndices(int device,
size_t gpu_batch_nrows =
std::min(dh::TotalMemory(device) / (16 * row_stride * sizeof(Entry)),
static_cast<size_t>(row_batch.Size()));
const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();
size_t gpu_nbatches = common::DivRoundUp(row_batch.Size(), gpu_batch_nrows);
@ -429,9 +428,18 @@ void EllpackPageImpl::CreateHistIndices(int device,
size_t n_entries = ent_cnt_end - ent_cnt_begin;
dh::device_vector<Entry> entries_d(n_entries);
// copy data entries to device.
dh::safe_cuda(cudaMemcpyAsync(entries_d.data().get(),
data_vec.data() + ent_cnt_begin,
n_entries * sizeof(Entry), cudaMemcpyDefault));
if (row_batch.data.DeviceCanRead()) {
auto const& d_data = row_batch.data.ConstDeviceSpan();
dh::safe_cuda(cudaMemcpyAsync(
entries_d.data().get(), d_data.data() + ent_cnt_begin,
n_entries * sizeof(Entry), cudaMemcpyDefault));
} else {
const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();
dh::safe_cuda(cudaMemcpyAsync(
entries_d.data().get(), data_vec.data() + ent_cnt_begin,
n_entries * sizeof(Entry), cudaMemcpyDefault));
}
const dim3 block3(32, 8, 1); // 256 threads
const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
common::DivRoundUp(row_stride, block3.y), 1);

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2014-2020 by Contributors
* Copyright 2014-2021 by Contributors
* \file gbtree.cc
* \brief gradient boosted tree implementation.
* \author Tianqi Chen
@ -558,6 +558,23 @@ GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
return cpu_predictor_;
}
/** Increment the prediction on GPU.
*
* \param out_predts Prediction for the whole model.
* \param predts Prediction for current tree.
* \param tree_w Tree weight.
*/
void GPUDartPredictInc(common::Span<float> out_predts,
common::Span<float> predts, float tree_w, size_t n_rows,
bst_group_t n_groups, bst_group_t group)
#if defined(XGBOOST_USE_CUDA)
; // NOLINT
#else
{
common::AssertGPUSupport();
}
#endif
class Dart : public GBTree {
public:
explicit Dart(LearnerModelParam const* booster_config) :
@ -647,31 +664,46 @@ class Dart : public GBTree {
model_);
p_out_preds->version = 0;
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
std::tie(tree_begin, tree_end) =
detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
auto n_groups = model_.learner_model_param->num_output_group;
PredictionCacheEntry predts; // temporary storage for prediction
if (generic_param_->gpu_id != GenericParameter::kCpuId) {
predts.predictions.SetDevice(generic_param_->gpu_id);
}
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
for (size_t i = tree_begin; i < tree_end; i += 1) {
if (training &&
std::binary_search(idx_drop_.cbegin(), idx_drop_.cend(), i)) {
if (training && std::binary_search(idx_drop_.cbegin(), idx_drop_.cend(), i)) {
continue;
}
CHECK_GE(i, p_out_preds->version);
auto version = i / this->LayerTrees();
p_out_preds->version = version;
auto n_groups = model_.learner_model_param->num_output_group;
PredictionCacheEntry predts;
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
predts.predictions.Fill(0);
predictor->PredictBatch(p_fmat, &predts, model_, i, i + 1);
// Multiple the weight to output prediction.
auto w = this->weight_drop_.at(i);
auto &h_predts = predts.predictions.HostVector();
auto group = model_.tree_info.at(i);
auto &h_out_predts = p_out_preds->predictions.HostVector();
CHECK_EQ(h_out_predts.size(), h_predts.size());
for (size_t ridx = 0; ridx < p_fmat->Info().num_row_; ++ridx) {
const size_t offset = ridx * n_groups + group;
h_out_predts[offset] += (h_predts[offset] * w);
CHECK_EQ(p_out_preds->predictions.Size(), predts.predictions.Size());
size_t n_rows = p_fmat->Info().num_row_;
if (predts.predictions.DeviceIdx() != GenericParameter::kCpuId) {
p_out_preds->predictions.SetDevice(predts.predictions.DeviceIdx());
GPUDartPredictInc(p_out_preds->predictions.DeviceSpan(),
predts.predictions.DeviceSpan(), w, n_rows, n_groups,
group);
} else {
auto &h_out_predts = p_out_preds->predictions.HostVector();
auto &h_predts = predts.predictions.HostVector();
#pragma omp parallel for
for (omp_ulong ridx = 0; ridx < p_fmat->Info().num_row_; ++ridx) {
const size_t offset = ridx * n_groups + group;
h_out_predts[offset] += (h_predts[offset] * w);
}
}
}
}
@ -699,6 +731,7 @@ class Dart : public GBTree {
MetaInfo info;
StringView msg{"Unsupported data type for inplace predict."};
int32_t device = GenericParameter::kCpuId;
// Inplace predict is not used for training, so no need to drop tree.
for (size_t i = tree_begin; i < tree_end; ++i) {
PredictionCacheEntry predts;
@ -709,6 +742,9 @@ class Dart : public GBTree {
if (p && p->InplacePredict(x, nullptr, model_, missing, &predts, i,
i + 1)) {
success = true;
#if defined(XGBOOST_USE_CUDA)
device = predts.predictions.DeviceIdx();
#endif // defined(XGBOOST_USE_CUDA)
break;
}
}
@ -716,14 +752,16 @@ class Dart : public GBTree {
} else {
// No base margin for each tree
bool success = this->GetPredictor()->InplacePredict(
x, nullptr, model_, missing, &predts, tree_begin, tree_end);
x, nullptr, model_, missing, &predts, i, i + 1);
device = predts.predictions.DeviceIdx();
CHECK(success) << msg;
}
auto w = this->weight_drop_.at(i);
auto &h_predts = predts.predictions.HostVector();
auto &h_out_predts = out_preds->predictions.HostVector();
if (h_out_predts.empty()) {
if (i == tree_begin) {
auto n_rows =
h_predts.size() / model_.learner_model_param->num_output_group;
if (p_m) {
@ -739,12 +777,19 @@ class Dart : public GBTree {
// Multiple the tree weight
CHECK_EQ(h_predts.size(), h_out_predts.size());
for (size_t i = 0; i < h_out_predts.size(); ++i) {
#pragma omp parallel for
for (omp_ulong i = 0; i < h_out_predts.size(); ++i) {
// Need to remove the base margin from indiviual tree.
h_out_predts[i] +=
(h_predts[i] - model_.learner_model_param->base_score) * w;
}
}
if (device != GenericParameter::kCpuId) {
out_preds->predictions.SetDevice(device);
out_preds->predictions.DeviceSpan();
}
}
void PredictInstance(const SparsePage::Inst &inst,

18
src/gbm/gbtree.cu Normal file
View File

@ -0,0 +1,18 @@
/*!
* Copyright 2021 by Contributors
*/
#include "xgboost/span.h"
#include "../common/device_helpers.cuh"
namespace xgboost {
namespace gbm {
void GPUDartPredictInc(common::Span<float> out_predts,
common::Span<float> predts, float tree_w, size_t n_rows,
bst_group_t n_groups, bst_group_t group) {
dh::LaunchN(dh::CurrentDevice(), n_rows, [=]XGBOOST_DEVICE(size_t ridx) {
const size_t offset = ridx * n_groups + group;
out_predts[offset] += (predts[offset] * tree_w);
});
}
} // namespace gbm
} // namespace xgboost

View File

@ -312,3 +312,33 @@ class TestGPUPredict:
pred = bst.predict(dtrain)
rmse = mean_squared_error(y_true=y, y_pred=pred, squared=False)
np.testing.assert_almost_equal(rmse, eval_history['train']['rmse'][-1], decimal=5)
def test_predict_dart(self):
import cupy as cp
rng = cp.random.RandomState(1994)
n_samples = 1000
X = rng.randn(n_samples, 10)
y = rng.randn(n_samples)
Xy = xgb.DMatrix(X, y)
booster = xgb.train(
{
"tree_method": "gpu_hist",
"booster": "dart",
"rate_drop": 0.5,
},
Xy,
num_boost_round=32
)
# predictor=auto
inplace = booster.inplace_predict(X)
copied = booster.predict(Xy)
copied = cp.array(copied)
cp.testing.assert_allclose(inplace, copied, atol=1e-6)
booster.set_param({"predictor": "gpu_predictor"})
inplace = booster.inplace_predict(X)
copied = booster.predict(Xy)
copied = cp.array(copied)
cp.testing.assert_allclose(inplace, copied, atol=1e-6)