[back port] Optimize dart inplace predict perf. (#6804) (#6829)

This commit is contained in:
Jiaming Yuan 2021-04-07 00:21:12 +08:00 committed by GitHub
parent d231e7c35f
commit 357a78b3de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 82 additions and 32 deletions

View File

@ -575,6 +575,20 @@ void GPUDartPredictInc(common::Span<float> out_predts,
}
#endif
void GPUDartInplacePredictInc(common::Span<float> out_predts,
common::Span<float> predts, float tree_w,
size_t n_rows, float base_score,
bst_group_t n_groups,
bst_group_t group)
#if defined(XGBOOST_USE_CUDA)
; // NOLINT
#else
{
common::AssertGPUSupport();
}
#endif
class Dart : public GBTree {
public:
explicit Dart(LearnerModelParam const* booster_config) :
@ -728,13 +742,14 @@ class Dart : public GBTree {
gpu_predictor_.get()
#endif // defined(XGBOOST_USE_CUDA)
};
Predictor const * predictor {nullptr};
MetaInfo info;
StringView msg{"Unsupported data type for inplace predict."};
int32_t device = GenericParameter::kCpuId;
PredictionCacheEntry predts;
// Inplace predict is not used for training, so no need to drop tree.
for (size_t i = tree_begin; i < tree_end; ++i) {
PredictionCacheEntry predts;
if (tparam_.predictor == PredictorType::kAuto) {
// Try both predictor implementations
bool success = false;
@ -742,6 +757,7 @@ class Dart : public GBTree {
if (p && p->InplacePredict(x, nullptr, model_, missing, &predts, i,
i + 1)) {
success = true;
predictor = p;
#if defined(XGBOOST_USE_CUDA)
device = predts.predictions.DeviceIdx();
#endif // defined(XGBOOST_USE_CUDA)
@ -750,45 +766,52 @@ class Dart : public GBTree {
}
CHECK(success) << msg;
} else {
// No base margin for each tree
bool success = this->GetPredictor()->InplacePredict(
x, nullptr, model_, missing, &predts, i, i + 1);
// No base margin from meta info for each tree
predictor = this->GetPredictor().get();
bool success = predictor->InplacePredict(x, nullptr, model_, missing,
&predts, i, i + 1);
device = predts.predictions.DeviceIdx();
CHECK(success) << msg;
}
auto w = this->weight_drop_.at(i);
auto &h_predts = predts.predictions.HostVector();
auto &h_out_predts = out_preds->predictions.HostVector();
size_t n_groups = model_.learner_model_param->num_output_group;
auto n_rows = predts.predictions.Size() / n_groups;
if (i == tree_begin) {
auto n_rows =
h_predts.size() / model_.learner_model_param->num_output_group;
// base margin is added here.
if (p_m) {
p_m->Info().num_row_ = n_rows;
cpu_predictor_->InitOutPredictions(p_m->Info(),
&out_preds->predictions, model_);
predictor->InitOutPredictions(p_m->Info(), &out_preds->predictions,
model_);
} else {
info.num_row_ = n_rows;
cpu_predictor_->InitOutPredictions(info, &out_preds->predictions,
model_);
predictor->InitOutPredictions(info, &out_preds->predictions, model_);
}
}
// Multiple the tree weight
CHECK_EQ(h_predts.size(), h_out_predts.size());
CHECK_EQ(predts.predictions.Size(), out_preds->predictions.Size());
auto group = model_.tree_info.at(i);
if (device == GenericParameter::kCpuId) {
auto &h_predts = predts.predictions.HostVector();
auto &h_out_predts = out_preds->predictions.HostVector();
#pragma omp parallel for
for (omp_ulong i = 0; i < h_out_predts.size(); ++i) {
for (omp_ulong ridx = 0; ridx < n_rows; ++ridx) {
const size_t offset = ridx * n_groups + group;
// Need to remove the base margin from indiviual tree.
h_out_predts[i] +=
(h_predts[i] - model_.learner_model_param->base_score) * w;
h_out_predts[offset] +=
(h_predts[offset] - model_.learner_model_param->base_score) * w;
}
}
if (device != GenericParameter::kCpuId) {
} else {
out_preds->predictions.SetDevice(device);
out_preds->predictions.DeviceSpan();
predts.predictions.SetDevice(device);
GPUDartInplacePredictInc(out_preds->predictions.DeviceSpan(),
predts.predictions.DeviceSpan(), w, n_rows,
model_.learner_model_param->base_score,
n_groups, group);
}
}
}

View File

@ -14,5 +14,15 @@ void GPUDartPredictInc(common::Span<float> out_predts,
out_predts[offset] += (predts[offset] * tree_w);
});
}
void GPUDartInplacePredictInc(common::Span<float> out_predts,
common::Span<float> predts, float tree_w,
size_t n_rows, float base_score,
bst_group_t n_groups, bst_group_t group) {
dh::LaunchN(dh::CurrentDevice(), n_rows, [=] XGBOOST_DEVICE(size_t ridx) {
const size_t offset = ridx * n_groups + group;
out_predts[offset] += (predts[offset] - base_score) * tree_w;
});
}
} // namespace gbm
} // namespace xgboost

View File

@ -332,27 +332,44 @@ class TestGPUPredict:
rmse = mean_squared_error(y_true=y, y_pred=pred, squared=False)
np.testing.assert_almost_equal(rmse, eval_history['train']['rmse'][-1], decimal=5)
def test_predict_dart(self):
@pytest.mark.parametrize("n_classes", [2, 3])
def test_predict_dart(self, n_classes):
from sklearn.datasets import make_classification
import cupy as cp
rng = cp.random.RandomState(1994)
n_samples = 1000
X = rng.randn(n_samples, 10)
y = rng.randn(n_samples)
X_, y_ = make_classification(
n_samples=n_samples, n_informative=5, n_classes=n_classes
)
X, y = cp.array(X_), cp.array(y_)
Xy = xgb.DMatrix(X, y)
booster = xgb.train(
{
if n_classes == 2:
params = {
"tree_method": "gpu_hist",
"booster": "dart",
"rate_drop": 0.5,
},
Xy,
num_boost_round=32
)
"objective": "binary:logistic"
}
else:
params = {
"tree_method": "gpu_hist",
"booster": "dart",
"rate_drop": 0.5,
"objective": "multi:softprob",
"num_class": n_classes
}
booster = xgb.train(params, Xy, num_boost_round=32)
# predictor=auto
inplace = booster.inplace_predict(X)
copied = booster.predict(Xy)
cpu_inplace = booster.inplace_predict(X_)
booster.set_param({"predictor": "cpu_predictor"})
cpu_copied = booster.predict(Xy)
copied = cp.array(copied)
cp.testing.assert_allclose(cpu_inplace, copied, atol=1e-6)
cp.testing.assert_allclose(cpu_copied, copied, atol=1e-6)
cp.testing.assert_allclose(inplace, copied, atol=1e-6)
booster.set_param({"predictor": "gpu_predictor"})