[back port] Optimize dart inplace predict perf. (#6804) (#6829)

This commit is contained in:
Jiaming Yuan 2021-04-07 00:21:12 +08:00 committed by GitHub
parent d231e7c35f
commit 357a78b3de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 82 additions and 32 deletions

View File

@ -575,6 +575,20 @@ void GPUDartPredictInc(common::Span<float> out_predts,
} }
#endif #endif
/*!
 * \brief Accumulate one tree's weighted in-place prediction into the output
 *        buffer for a single output group.
 *
 * With XGBOOST_USE_CUDA defined this is only a declaration; the definition
 * must then be supplied by another translation unit.  Without it, the
 * function has a body that just calls common::AssertGPUSupport()
 * (presumably reporting that the build lacks GPU support -- confirm against
 * that helper).
 *
 * \param out_predts Accumulated predictions that are updated.
 * \param predts     Prediction of the single tree being added.
 * \param tree_w     Weight applied to the tree's contribution.
 * \param n_rows     Number of rows to process.
 * \param base_score Value subtracted from each tree prediction before weighting.
 * \param n_groups   Number of output groups (row stride in both buffers).
 * \param group      Output group that the tree belongs to.
 */
void GPUDartInplacePredictInc(common::Span<float> out_predts,
                              common::Span<float> predts,
                              float tree_w,
                              size_t n_rows,
                              float base_score,
                              bst_group_t n_groups,
                              bst_group_t group)
#if !defined(XGBOOST_USE_CUDA)
{
  // CPU-only build: no device implementation is available.
  common::AssertGPUSupport();
}
#else
;  // NOLINT
#endif  // !defined(XGBOOST_USE_CUDA)
class Dart : public GBTree { class Dart : public GBTree {
public: public:
explicit Dart(LearnerModelParam const* booster_config) : explicit Dart(LearnerModelParam const* booster_config) :
@ -728,13 +742,14 @@ class Dart : public GBTree {
gpu_predictor_.get() gpu_predictor_.get()
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
}; };
Predictor const * predictor {nullptr};
MetaInfo info; MetaInfo info;
StringView msg{"Unsupported data type for inplace predict."}; StringView msg{"Unsupported data type for inplace predict."};
int32_t device = GenericParameter::kCpuId; int32_t device = GenericParameter::kCpuId;
PredictionCacheEntry predts;
// Inplace predict is not used for training, so no need to drop tree. // Inplace predict is not used for training, so no need to drop tree.
for (size_t i = tree_begin; i < tree_end; ++i) { for (size_t i = tree_begin; i < tree_end; ++i) {
PredictionCacheEntry predts;
if (tparam_.predictor == PredictorType::kAuto) { if (tparam_.predictor == PredictorType::kAuto) {
// Try both predictor implementations // Try both predictor implementations
bool success = false; bool success = false;
@ -742,6 +757,7 @@ class Dart : public GBTree {
if (p && p->InplacePredict(x, nullptr, model_, missing, &predts, i, if (p && p->InplacePredict(x, nullptr, model_, missing, &predts, i,
i + 1)) { i + 1)) {
success = true; success = true;
predictor = p;
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
device = predts.predictions.DeviceIdx(); device = predts.predictions.DeviceIdx();
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
@ -750,46 +766,53 @@ class Dart : public GBTree {
} }
CHECK(success) << msg; CHECK(success) << msg;
} else { } else {
// No base margin for each tree // No base margin from meta info for each tree
bool success = this->GetPredictor()->InplacePredict( predictor = this->GetPredictor().get();
x, nullptr, model_, missing, &predts, i, i + 1); bool success = predictor->InplacePredict(x, nullptr, model_, missing,
&predts, i, i + 1);
device = predts.predictions.DeviceIdx(); device = predts.predictions.DeviceIdx();
CHECK(success) << msg; CHECK(success) << msg;
} }
auto w = this->weight_drop_.at(i); auto w = this->weight_drop_.at(i);
auto &h_predts = predts.predictions.HostVector(); size_t n_groups = model_.learner_model_param->num_output_group;
auto &h_out_predts = out_preds->predictions.HostVector(); auto n_rows = predts.predictions.Size() / n_groups;
if (i == tree_begin) { if (i == tree_begin) {
auto n_rows = // base margin is added here.
h_predts.size() / model_.learner_model_param->num_output_group;
if (p_m) { if (p_m) {
p_m->Info().num_row_ = n_rows; p_m->Info().num_row_ = n_rows;
cpu_predictor_->InitOutPredictions(p_m->Info(), predictor->InitOutPredictions(p_m->Info(), &out_preds->predictions,
&out_preds->predictions, model_); model_);
} else { } else {
info.num_row_ = n_rows; info.num_row_ = n_rows;
cpu_predictor_->InitOutPredictions(info, &out_preds->predictions, predictor->InitOutPredictions(info, &out_preds->predictions, model_);
model_);
} }
} }
// Multiply by the tree weight // Multiply by the tree weight
CHECK_EQ(h_predts.size(), h_out_predts.size()); CHECK_EQ(predts.predictions.Size(), out_preds->predictions.Size());
auto group = model_.tree_info.at(i);
if (device == GenericParameter::kCpuId) {
auto &h_predts = predts.predictions.HostVector();
auto &h_out_predts = out_preds->predictions.HostVector();
#pragma omp parallel for #pragma omp parallel for
for (omp_ulong i = 0; i < h_out_predts.size(); ++i) { for (omp_ulong ridx = 0; ridx < n_rows; ++ridx) {
// Need to remove the base margin from indiviual tree. const size_t offset = ridx * n_groups + group;
h_out_predts[i] += // Need to remove the base margin from indiviual tree.
(h_predts[i] - model_.learner_model_param->base_score) * w; h_out_predts[offset] +=
(h_predts[offset] - model_.learner_model_param->base_score) * w;
}
} else {
out_preds->predictions.SetDevice(device);
predts.predictions.SetDevice(device);
GPUDartInplacePredictInc(out_preds->predictions.DeviceSpan(),
predts.predictions.DeviceSpan(), w, n_rows,
model_.learner_model_param->base_score,
n_groups, group);
} }
} }
if (device != GenericParameter::kCpuId) {
out_preds->predictions.SetDevice(device);
out_preds->predictions.DeviceSpan();
}
} }
void PredictInstance(const SparsePage::Inst &inst, void PredictInstance(const SparsePage::Inst &inst,

View File

@ -14,5 +14,15 @@ void GPUDartPredictInc(common::Span<float> out_predts,
out_predts[offset] += (predts[offset] * tree_w); out_predts[offset] += (predts[offset] * tree_w);
}); });
} }
/*!
 * \brief Device-side accumulation of one tree's in-place prediction.
 *
 * For every row handed to the launched lambda, the element at
 * `row * n_groups + group` of `out_predts` is incremented by
 * `(predts[same index] - base_score) * tree_w`.  Elements belonging to other
 * groups are not touched.  dh::LaunchN is presumably invoked once per index
 * in [0, n_rows) on the current device -- confirm against that helper.
 *
 * \param out_predts Accumulated predictions that are updated.
 * \param predts     Prediction of the single tree being added.
 * \param tree_w     Weight applied to the tree's contribution.
 * \param n_rows     Number of rows passed to the launch.
 * \param base_score Value subtracted from each tree prediction before weighting.
 * \param n_groups   Number of output groups (row stride in both buffers).
 * \param group      Output group that the tree belongs to.
 */
void GPUDartInplacePredictInc(common::Span<float> out_predts,
                              common::Span<float> predts, float tree_w,
                              size_t n_rows, float base_score,
                              bst_group_t n_groups, bst_group_t group) {
  auto const device = dh::CurrentDevice();
  dh::LaunchN(device, n_rows, [=] XGBOOST_DEVICE(size_t row) {
    // Row-major layout: each row holds n_groups consecutive values.
    size_t const idx = row * n_groups + group;
    // Strip the base score from the single-tree prediction before scaling.
    float const tree_margin = predts[idx] - base_score;
    out_predts[idx] += tree_margin * tree_w;
  });
}
} // namespace gbm } // namespace gbm
} // namespace xgboost } // namespace xgboost

View File

@ -332,27 +332,44 @@ class TestGPUPredict:
rmse = mean_squared_error(y_true=y, y_pred=pred, squared=False) rmse = mean_squared_error(y_true=y, y_pred=pred, squared=False)
np.testing.assert_almost_equal(rmse, eval_history['train']['rmse'][-1], decimal=5) np.testing.assert_almost_equal(rmse, eval_history['train']['rmse'][-1], decimal=5)
def test_predict_dart(self): @pytest.mark.parametrize("n_classes", [2, 3])
def test_predict_dart(self, n_classes):
from sklearn.datasets import make_classification
import cupy as cp import cupy as cp
rng = cp.random.RandomState(1994)
n_samples = 1000 n_samples = 1000
X = rng.randn(n_samples, 10) X_, y_ = make_classification(
y = rng.randn(n_samples) n_samples=n_samples, n_informative=5, n_classes=n_classes
)
X, y = cp.array(X_), cp.array(y_)
Xy = xgb.DMatrix(X, y) Xy = xgb.DMatrix(X, y)
booster = xgb.train( if n_classes == 2:
{ params = {
"tree_method": "gpu_hist", "tree_method": "gpu_hist",
"booster": "dart", "booster": "dart",
"rate_drop": 0.5, "rate_drop": 0.5,
}, "objective": "binary:logistic"
Xy, }
num_boost_round=32 else:
) params = {
"tree_method": "gpu_hist",
"booster": "dart",
"rate_drop": 0.5,
"objective": "multi:softprob",
"num_class": n_classes
}
booster = xgb.train(params, Xy, num_boost_round=32)
# predictor=auto # predictor=auto
inplace = booster.inplace_predict(X) inplace = booster.inplace_predict(X)
copied = booster.predict(Xy) copied = booster.predict(Xy)
cpu_inplace = booster.inplace_predict(X_)
booster.set_param({"predictor": "cpu_predictor"})
cpu_copied = booster.predict(Xy)
copied = cp.array(copied) copied = cp.array(copied)
cp.testing.assert_allclose(cpu_inplace, copied, atol=1e-6)
cp.testing.assert_allclose(cpu_copied, copied, atol=1e-6)
cp.testing.assert_allclose(inplace, copied, atol=1e-6) cp.testing.assert_allclose(inplace, copied, atol=1e-6)
booster.set_param({"predictor": "gpu_predictor"}) booster.set_param({"predictor": "gpu_predictor"})