Implement contribution prediction with QuantileDMatrix (#10043)

---------

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
Louis Desreumaux 2024-02-19 14:03:29 +01:00 committed by GitHub
parent 057f03cacc
commit edf501d227
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 137 additions and 55 deletions

View File

@ -698,6 +698,67 @@ class CPUPredictor : public Predictor {
} }
} }
template <typename DataView>
void PredictContributionKernel(DataView batch, const MetaInfo& info,
const gbm::GBTreeModel& model,
const std::vector<bst_float>* tree_weights,
std::vector<std::vector<float>>* mean_values,
std::vector<RegTree::FVec>* feat_vecs,
std::vector<bst_float>* contribs, uint32_t ntree_limit,
bool approximate, int condition,
unsigned condition_feature) const {
const int num_feature = model.learner_model_param->num_feature;
const int ngroup = model.learner_model_param->num_output_group;
CHECK_NE(ngroup, 0);
size_t const ncolumns = num_feature + 1;
CHECK_NE(ncolumns, 0);
auto base_margin = info.base_margin_.View(ctx_->Device());
auto base_score = model.learner_model_param->BaseScore(ctx_->Device())(0);
// parallel over local batch
common::ParallelFor(batch.Size(), this->ctx_->Threads(), [&](auto i) {
auto row_idx = batch.base_rowid + i;
RegTree::FVec &feats = (*feat_vecs)[omp_get_thread_num()];
if (feats.Size() == 0) {
feats.Init(num_feature);
}
std::vector<bst_float> this_tree_contribs(ncolumns);
// loop over all classes
for (int gid = 0; gid < ngroup; ++gid) {
bst_float* p_contribs = &(*contribs)[(row_idx * ngroup + gid) * ncolumns];
feats.Fill(batch[i]);
// calculate contributions
for (unsigned j = 0; j < ntree_limit; ++j) {
auto *tree_mean_values = &mean_values->at(j);
std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0);
if (model.tree_info[j] != gid) {
continue;
}
if (!approximate) {
CalculateContributions(*model.trees[j], feats, tree_mean_values,
&this_tree_contribs[0], condition, condition_feature);
} else {
model.trees[j]->CalculateContributionsApprox(
feats, tree_mean_values, &this_tree_contribs[0]);
}
for (size_t ci = 0; ci < ncolumns; ++ci) {
p_contribs[ci] +=
this_tree_contribs[ci] *
(tree_weights == nullptr ? 1 : (*tree_weights)[j]);
}
}
feats.Drop();
// add base margin to BIAS
if (base_margin.Size() != 0) {
CHECK_EQ(base_margin.Shape(1), ngroup);
p_contribs[ncolumns - 1] += base_margin(row_idx, gid);
} else {
p_contribs[ncolumns - 1] += base_score;
}
}
});
}
public: public:
explicit CPUPredictor(Context const *ctx) : Predictor::Predictor{ctx} {} explicit CPUPredictor(Context const *ctx) : Predictor::Predictor{ctx} {}
@ -861,7 +922,6 @@ class CPUPredictor : public Predictor {
CHECK(!p_fmat->Info().IsColumnSplit()) CHECK(!p_fmat->Info().IsColumnSplit())
<< "Predict contribution support for column-wise data split is not yet implemented."; << "Predict contribution support for column-wise data split is not yet implemented.";
auto const n_threads = this->ctx_->Threads(); auto const n_threads = this->ctx_->Threads();
const int num_feature = model.learner_model_param->num_feature;
std::vector<RegTree::FVec> feat_vecs; std::vector<RegTree::FVec> feat_vecs;
InitThreadTemp(n_threads, &feat_vecs); InitThreadTemp(n_threads, &feat_vecs);
const MetaInfo& info = p_fmat->Info(); const MetaInfo& info = p_fmat->Info();
@ -869,10 +929,7 @@ class CPUPredictor : public Predictor {
if (ntree_limit == 0 || ntree_limit > model.trees.size()) { if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
ntree_limit = static_cast<unsigned>(model.trees.size()); ntree_limit = static_cast<unsigned>(model.trees.size());
} }
const int ngroup = model.learner_model_param->num_output_group; size_t const ncolumns = model.learner_model_param->num_feature + 1;
CHECK_NE(ngroup, 0);
size_t const ncolumns = num_feature + 1;
CHECK_NE(ncolumns, 0);
// allocate space for (number of features + bias) times the number of rows // allocate space for (number of features + bias) times the number of rows
std::vector<bst_float>& contribs = out_contribs->HostVector(); std::vector<bst_float>& contribs = out_contribs->HostVector();
contribs.resize(info.num_row_ * ncolumns * model.learner_model_param->num_output_group); contribs.resize(info.num_row_ * ncolumns * model.learner_model_param->num_output_group);
@ -884,53 +941,22 @@ class CPUPredictor : public Predictor {
common::ParallelFor(ntree_limit, n_threads, [&](bst_omp_uint i) { common::ParallelFor(ntree_limit, n_threads, [&](bst_omp_uint i) {
FillNodeMeanValues(model.trees[i].get(), &(mean_values[i])); FillNodeMeanValues(model.trees[i].get(), &(mean_values[i]));
}); });
auto base_margin = info.base_margin_.View(ctx_->Device());
auto base_score = model.learner_model_param->BaseScore(ctx_->Device())(0);
// start collecting the contributions // start collecting the contributions
if (!p_fmat->PageExists<SparsePage>()) {
std::vector<Entry> workspace(info.num_col_ * kUnroll * n_threads);
auto ft = p_fmat->Info().feature_types.ConstHostVector();
for (const auto &batch : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, {})) {
PredictContributionKernel(
GHistIndexMatrixView{batch, info.num_col_, ft, workspace, n_threads},
info, model, tree_weights, &mean_values, &feat_vecs, &contribs, ntree_limit,
approximate, condition, condition_feature);
}
} else {
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) { for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
auto page = batch.GetView(); PredictContributionKernel(
// parallel over local batch SparsePageView{&batch}, info, model, tree_weights, &mean_values, &feat_vecs,
common::ParallelFor(batch.Size(), n_threads, [&](auto i) { &contribs, ntree_limit, approximate, condition, condition_feature);
auto row_idx = batch.base_rowid + i;
RegTree::FVec &feats = feat_vecs[omp_get_thread_num()];
if (feats.Size() == 0) {
feats.Init(num_feature);
} }
std::vector<bst_float> this_tree_contribs(ncolumns);
// loop over all classes
for (int gid = 0; gid < ngroup; ++gid) {
bst_float* p_contribs = &contribs[(row_idx * ngroup + gid) * ncolumns];
feats.Fill(page[i]);
// calculate contributions
for (unsigned j = 0; j < ntree_limit; ++j) {
auto *tree_mean_values = &mean_values.at(j);
std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0);
if (model.tree_info[j] != gid) {
continue;
}
if (!approximate) {
CalculateContributions(*model.trees[j], feats, tree_mean_values,
&this_tree_contribs[0], condition, condition_feature);
} else {
model.trees[j]->CalculateContributionsApprox(
feats, tree_mean_values, &this_tree_contribs[0]);
}
for (size_t ci = 0; ci < ncolumns; ++ci) {
p_contribs[ci] +=
this_tree_contribs[ci] *
(tree_weights == nullptr ? 1 : (*tree_weights)[j]);
}
}
feats.Drop();
// add base margin to BIAS
if (base_margin.Size() != 0) {
CHECK_EQ(base_margin.Shape(1), ngroup);
p_contribs[ncolumns - 1] += base_margin(row_idx, gid);
} else {
p_contribs[ncolumns - 1] += base_score;
}
}
});
} }
} }

View File

@ -1042,6 +1042,9 @@ class GPUPredictor : public xgboost::Predictor {
if (tree_weights != nullptr) { if (tree_weights != nullptr) {
LOG(FATAL) << "Dart booster feature " << not_implemented; LOG(FATAL) << "Dart booster feature " << not_implemented;
} }
if (!p_fmat->PageExists<SparsePage>()) {
LOG(FATAL) << "SHAP value for QuantileDMatrix is not yet implemented for GPU.";
}
CHECK(!p_fmat->Info().IsColumnSplit()) CHECK(!p_fmat->Info().IsColumnSplit())
<< "Predict contribution support for column-wise data split is not yet implemented."; << "Predict contribution support for column-wise data split is not yet implemented.";
dh::safe_cuda(cudaSetDevice(ctx_->Ordinal())); dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
@ -1102,6 +1105,9 @@ class GPUPredictor : public xgboost::Predictor {
if (tree_weights != nullptr) { if (tree_weights != nullptr) {
LOG(FATAL) << "Dart booster feature " << not_implemented; LOG(FATAL) << "Dart booster feature " << not_implemented;
} }
if (!p_fmat->PageExists<SparsePage>()) {
LOG(FATAL) << "SHAP value for QuantileDMatrix is not yet implemented for GPU.";
}
dh::safe_cuda(cudaSetDevice(ctx_->Ordinal())); dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
out_contribs->SetDevice(ctx_->Device()); out_contribs->SetDevice(ctx_->Device());
if (tree_end == 0 || tree_end > model.trees.size()) { if (tree_end == 0 || tree_end > model.trees.size()) {

View File

@ -148,7 +148,7 @@ TEST(CPUPredictor, GHistIndexTraining) {
auto adapter = data::ArrayAdapter(columnar.c_str()); auto adapter = data::ArrayAdapter(columnar.c_str());
std::shared_ptr<DMatrix> p_full{ std::shared_ptr<DMatrix> p_full{
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1)}; DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1)};
TestTrainingPrediction(&ctx, kRows, kBins, p_full, p_hist); TestTrainingPrediction(&ctx, kRows, kBins, p_full, p_hist, true);
} }
TEST(CPUPredictor, CategoricalPrediction) { TEST(CPUPredictor, CategoricalPrediction) {

View File

@ -118,7 +118,8 @@ TEST(Predictor, PredictionCache) {
} }
void TestTrainingPrediction(Context const *ctx, size_t rows, size_t bins, void TestTrainingPrediction(Context const *ctx, size_t rows, size_t bins,
std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist) { std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist,
bool check_contribs) {
size_t constexpr kCols = 16; size_t constexpr kCols = 16;
size_t constexpr kClasses = 3; size_t constexpr kClasses = 3;
size_t constexpr kIters = 3; size_t constexpr kIters = 3;
@ -161,6 +162,28 @@ void TestTrainingPrediction(Context const *ctx, size_t rows, size_t bins,
for (size_t i = 0; i < rows; ++i) { for (size_t i = 0; i < rows; ++i) {
EXPECT_NEAR(from_hist.ConstHostVector()[i], from_full.ConstHostVector()[i], kRtEps); EXPECT_NEAR(from_hist.ConstHostVector()[i], from_full.ConstHostVector()[i], kRtEps);
} }
if (check_contribs) {
// Contributions
HostDeviceVector<float> from_full_contribs;
learner->Predict(p_full, false, &from_full_contribs, 0, 0, false, false, true);
HostDeviceVector<float> from_hist_contribs;
learner->Predict(p_hist, false, &from_hist_contribs, 0, 0, false, false, true);
for (size_t i = 0; i < from_full_contribs.ConstHostVector().size(); ++i) {
EXPECT_NEAR(from_hist_contribs.ConstHostVector()[i],
from_full_contribs.ConstHostVector()[i], kRtEps);
}
// Contributions (approximate method)
HostDeviceVector<float> from_full_approx_contribs;
learner->Predict(p_full, false, &from_full_approx_contribs, 0, 0, false, false, false, true);
HostDeviceVector<float> from_hist_approx_contribs;
learner->Predict(p_hist, false, &from_hist_approx_contribs, 0, 0, false, false, false, true);
for (size_t i = 0; i < from_full_approx_contribs.ConstHostVector().size(); ++i) {
EXPECT_NEAR(from_hist_approx_contribs.ConstHostVector()[i],
from_full_approx_contribs.ConstHostVector()[i], kRtEps);
}
}
} }
void TestInplacePrediction(Context const *ctx, std::shared_ptr<DMatrix> x, bst_row_t rows, void TestInplacePrediction(Context const *ctx, std::shared_ptr<DMatrix> x, bst_row_t rows,

View File

@ -89,7 +89,8 @@ void TestBasic(DMatrix* dmat, Context const * ctx);
// p_full and p_hist should come from the same data set. // p_full and p_hist should come from the same data set.
void TestTrainingPrediction(Context const* ctx, size_t rows, size_t bins, void TestTrainingPrediction(Context const* ctx, size_t rows, size_t bins,
std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist); std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist,
bool check_contribs = false);
void TestInplacePrediction(Context const* ctx, std::shared_ptr<DMatrix> x, bst_row_t rows, void TestInplacePrediction(Context const* ctx, std::shared_ptr<DMatrix> x, bst_row_t rows,
bst_feature_t cols); bst_feature_t cols);

View File

@ -2,7 +2,6 @@ import itertools
import re import re
import numpy as np import numpy as np
import scipy
import scipy.special import scipy.special
import xgboost as xgb import xgboost as xgb
@ -256,3 +255,30 @@ class TestSHAP:
brute_force[-1, -1] += base_score brute_force[-1, -1] += base_score
fast_method = bst.predict(xgb.DMatrix(X[0:1, :]), pred_interactions=True) fast_method = bst.predict(xgb.DMatrix(X[0:1, :]), pred_interactions=True)
assert np.linalg.norm(brute_force - fast_method[0, :, :]) < 1e-4 assert np.linalg.norm(brute_force - fast_method[0, :, :]) < 1e-4
def test_shap_values(self) -> None:
from sklearn.datasets import make_classification, make_regression
def assert_same(X: np.ndarray, y: np.ndarray) -> None:
Xy = xgb.DMatrix(X, y)
booster = xgb.train({}, Xy, num_boost_round=4)
shap_dm = booster.predict(Xy, pred_contribs=True)
Xy = xgb.QuantileDMatrix(X, y)
shap_qdm = booster.predict(Xy, pred_contribs=True)
np.testing.assert_allclose(shap_dm, shap_qdm)
margin = booster.predict(Xy, output_margin=True)
np.testing.assert_allclose(
np.sum(shap_qdm, axis=len(shap_qdm.shape) - 1), margin, 1e-3, 1e-3
)
shap_dm = booster.predict(Xy, pred_interactions=True)
Xy = xgb.QuantileDMatrix(X, y)
shap_qdm = booster.predict(Xy, pred_interactions=True)
np.testing.assert_allclose(shap_dm, shap_qdm)
X, y = make_regression()
assert_same(X, y)
X, y = make_classification()
assert_same(X, y)