[EM] Support SHAP contribution with QDM. (#10724)

- Add GPU support.
- Add external memory support.
- Update the GPU tree shap.
This commit is contained in:
Jiaming Yuan 2024-08-22 05:25:10 +08:00 committed by GitHub
parent cb54374550
commit 142bdc73ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 274 additions and 159 deletions

@ -1 +1 @@
Subproject commit 787259b412c18ab8d5f24bf2b8bd6a59ff8208f3
Subproject commit 40eae8c4c45974705f8053e4d3d05b88e3cfaefd

View File

@ -143,10 +143,9 @@ struct SparsePageLoader {
};
struct EllpackLoader {
EllpackDeviceAccessor const& matrix;
XGBOOST_DEVICE EllpackLoader(EllpackDeviceAccessor const& m, bool, bst_feature_t, bst_idx_t,
float)
: matrix{m} {}
EllpackDeviceAccessor matrix;
XGBOOST_DEVICE EllpackLoader(EllpackDeviceAccessor m, bool, bst_feature_t, bst_idx_t, float)
: matrix{std::move(m)} {}
[[nodiscard]] XGBOOST_DEV_INLINE float GetElement(size_t ridx, size_t fidx) const {
auto gidx = matrix.GetBinIndex<false>(ridx, fidx);
if (gidx == -1) {
@ -162,6 +161,8 @@ struct EllpackLoader {
}
return matrix.gidx_fvalue_map[gidx - 1];
}
[[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return this->matrix.NumFeatures(); }
[[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return this->matrix.n_rows; }
};
template <typename Batch>
@ -1031,9 +1032,6 @@ class GPUPredictor : public xgboost::Predictor {
if (tree_weights != nullptr) {
LOG(FATAL) << "Dart booster feature " << not_implemented;
}
if (!p_fmat->PageExists<SparsePage>()) {
LOG(FATAL) << "SHAP value for QuantileDMatrix is not yet implemented for GPU.";
}
CHECK(!p_fmat->Info().IsColumnSplit())
<< "Predict contribution support for column-wise data split is not yet implemented.";
dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
@ -1047,8 +1045,8 @@ class GPUPredictor : public xgboost::Predictor {
// allocate space for (number of features + bias) times the number of rows
size_t contributions_columns =
model.learner_model_param->num_feature + 1; // +1 for bias
out_contribs->Resize(p_fmat->Info().num_row_ * contributions_columns *
model.learner_model_param->num_output_group);
auto dim_size = contributions_columns * model.learner_model_param->num_output_group;
out_contribs->Resize(p_fmat->Info().num_row_ * dim_size);
out_contribs->Fill(0.0f);
auto phis = out_contribs->DeviceSpan();
@ -1058,16 +1056,27 @@ class GPUPredictor : public xgboost::Predictor {
d_model.Init(model, 0, tree_end, ctx_->Device());
dh::device_vector<uint32_t> categories;
ExtractPaths(&device_paths, &d_model, &categories, ctx_->Device());
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(ctx_->Device());
batch.offset.SetDevice(ctx_->Device());
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature);
auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
gpu_treeshap::GPUTreeShap<dh::XGBDeviceAllocator<int>>(
X, device_paths.begin(), device_paths.end(), ngroup, begin,
dh::tend(phis));
if (p_fmat->PageExists<SparsePage>()) {
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(ctx_->Device());
batch.offset.SetDevice(ctx_->Device());
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature);
auto begin = dh::tbegin(phis) + batch.base_rowid * dim_size;
gpu_treeshap::GPUTreeShap<dh::XGBDeviceAllocator<int>>(
X, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis));
}
} else {
for (auto& batch : p_fmat->GetBatches<EllpackPage>(ctx_, {})) {
EllpackDeviceAccessor acc{batch.Impl()->GetDeviceAccessor(ctx_->Device())};
auto X = EllpackLoader{acc, true, model.learner_model_param->num_feature, batch.Size(),
std::numeric_limits<float>::quiet_NaN()};
auto begin = dh::tbegin(phis) + batch.BaseRowId() * dim_size;
gpu_treeshap::GPUTreeShap<dh::XGBDeviceAllocator<int>>(
X, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis));
}
}
// Add the base margin term to last column
p_fmat->Info().base_margin_.SetDevice(ctx_->Device());
const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan();
@ -1094,9 +1103,6 @@ class GPUPredictor : public xgboost::Predictor {
if (tree_weights != nullptr) {
LOG(FATAL) << "Dart booster feature " << not_implemented;
}
if (!p_fmat->PageExists<SparsePage>()) {
LOG(FATAL) << "SHAP value for QuantileDMatrix is not yet implemented for GPU.";
}
dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
out_contribs->SetDevice(ctx_->Device());
if (tree_end == 0 || tree_end > model.trees.size()) {
@ -1108,9 +1114,9 @@ class GPUPredictor : public xgboost::Predictor {
// allocate space for (number of features + bias) times the number of rows
size_t contributions_columns =
model.learner_model_param->num_feature + 1; // +1 for bias
out_contribs->Resize(p_fmat->Info().num_row_ * contributions_columns *
contributions_columns *
model.learner_model_param->num_output_group);
auto dim_size =
contributions_columns * contributions_columns * model.learner_model_param->num_output_group;
out_contribs->Resize(p_fmat->Info().num_row_ * dim_size);
out_contribs->Fill(0.0f);
auto phis = out_contribs->DeviceSpan();
@ -1120,16 +1126,29 @@ class GPUPredictor : public xgboost::Predictor {
d_model.Init(model, 0, tree_end, ctx_->Device());
dh::device_vector<uint32_t> categories;
ExtractPaths(&device_paths, &d_model, &categories, ctx_->Device());
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(ctx_->Device());
batch.offset.SetDevice(ctx_->Device());
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature);
auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
gpu_treeshap::GPUTreeShapInteractions<dh::XGBDeviceAllocator<int>>(
X, device_paths.begin(), device_paths.end(), ngroup, begin,
dh::tend(phis));
if (p_fmat->PageExists<SparsePage>()) {
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(ctx_->Device());
batch.offset.SetDevice(ctx_->Device());
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature);
auto begin = dh::tbegin(phis) + batch.base_rowid * dim_size;
gpu_treeshap::GPUTreeShapInteractions<dh::XGBDeviceAllocator<int>>(
X, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis));
}
} else {
for (auto const& batch : p_fmat->GetBatches<EllpackPage>(ctx_, {})) {
auto impl = batch.Impl();
auto acc =
impl->GetDeviceAccessor(ctx_->Device(), p_fmat->Info().feature_types.ConstDeviceSpan());
auto begin = dh::tbegin(phis) + batch.BaseRowId() * dim_size;
auto X = EllpackLoader{acc, true, model.learner_model_param->num_feature, batch.Size(),
std::numeric_limits<float>::quiet_NaN()};
gpu_treeshap::GPUTreeShapInteractions<dh::XGBDeviceAllocator<int>>(
X, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis));
}
}
// Add the base margin term to last column
p_fmat->Info().base_margin_.SetDevice(ctx_->Device());
const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan();
@ -1180,51 +1199,35 @@ class GPUPredictor : public xgboost::Predictor {
bool use_shared = shared_memory_bytes != 0;
bst_feature_t num_features = info.num_col_;
auto launch = [&](auto fn, std::uint32_t grid, auto data, bst_idx_t batch_offset) {
dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes}(
fn, data, d_model.nodes.ConstDeviceSpan(),
predictions->DeviceSpan().subspan(batch_offset), d_model.tree_segments.ConstDeviceSpan(),
d_model.split_types.ConstDeviceSpan(), d_model.categories_tree_segments.ConstDeviceSpan(),
d_model.categories_node_segments.ConstDeviceSpan(), d_model.categories.ConstDeviceSpan(),
d_model.tree_beg_, d_model.tree_end_, num_features, num_rows, use_shared,
std::numeric_limits<float>::quiet_NaN());
};
if (p_fmat->PageExists<SparsePage>()) {
bst_idx_t batch_offset = 0;
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(ctx_->Device());
batch.offset.SetDevice(ctx_->Device());
bst_idx_t batch_offset = 0;
SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature};
size_t num_rows = batch.Size();
auto grid =
static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} (
PredictLeafKernel<SparsePageLoader, SparsePageView>, data,
d_model.nodes.ConstDeviceSpan(),
predictions->DeviceSpan().subspan(batch_offset),
d_model.tree_segments.ConstDeviceSpan(),
d_model.split_types.ConstDeviceSpan(),
d_model.categories_tree_segments.ConstDeviceSpan(),
d_model.categories_node_segments.ConstDeviceSpan(),
d_model.categories.ConstDeviceSpan(),
d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
use_shared, std::numeric_limits<float>::quiet_NaN());
auto grid = static_cast<std::uint32_t>(common::DivRoundUp(batch.Size(), kBlockThreads));
launch(PredictLeafKernel<SparsePageLoader, SparsePageView>, grid, data, batch_offset);
batch_offset += batch.Size();
}
} else {
bst_idx_t batch_offset = 0;
for (auto const& batch : p_fmat->GetBatches<EllpackPage>(ctx_, BatchParam{})) {
bst_idx_t batch_offset = 0;
EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_->Device())};
size_t num_rows = batch.Size();
auto grid =
static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} (
PredictLeafKernel<EllpackLoader, EllpackDeviceAccessor>, data,
d_model.nodes.ConstDeviceSpan(),
predictions->DeviceSpan().subspan(batch_offset),
d_model.tree_segments.ConstDeviceSpan(),
d_model.split_types.ConstDeviceSpan(),
d_model.categories_tree_segments.ConstDeviceSpan(),
d_model.categories_node_segments.ConstDeviceSpan(),
d_model.categories.ConstDeviceSpan(),
d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
use_shared, std::numeric_limits<float>::quiet_NaN());
auto grid = static_cast<std::uint32_t>(common::DivRoundUp(batch.Size(), kBlockThreads));
launch(PredictLeafKernel<EllpackLoader, EllpackDeviceAccessor>, grid, data, batch_offset);
batch_offset += batch.Size();
}
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2016-2023 by XGBoost Contributors
* Copyright 2016-2024, XGBoost Contributors
*/
#include <xgboost/data.h>
@ -434,12 +434,11 @@ namespace {
void VerifyColumnSplit() {
size_t constexpr kRows {16};
size_t constexpr kCols {8};
auto dmat =
RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(false, false, 1, DataSplitMode::kCol);
auto p_fmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(false, DataSplitMode::kCol);
ASSERT_EQ(dmat->Info().num_col_, kCols * collective::GetWorldSize());
ASSERT_EQ(dmat->Info().num_row_, kRows);
ASSERT_EQ(dmat->Info().data_split_mode, DataSplitMode::kCol);
ASSERT_EQ(p_fmat->Info().num_col_, kCols * collective::GetWorldSize());
ASSERT_EQ(p_fmat->Info().num_row_, kRows);
ASSERT_EQ(p_fmat->Info().data_split_mode, DataSplitMode::kCol);
}
} // anonymous namespace

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2023, XGBoost contributors
* Copyright 2019-2024, XGBoost contributors
*/
#include <gtest/gtest.h>
#include <xgboost/context.h>
@ -463,7 +463,7 @@ INSTANTIATE_TEST_SUITE_P(PredictorTypes, Dart, testing::Values("CPU"));
std::pair<Json, Json> TestModelSlice(std::string booster) {
size_t constexpr kRows = 1000, kCols = 100, kForest = 2, kClasses = 3;
auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true, false, kClasses);
auto m = RandomDataGenerator{kRows, kCols, 0}.Classes(kClasses).GenerateDMatrix(true);
int32_t kIters = 10;
std::unique_ptr<Learner> learner {
@ -592,7 +592,7 @@ TEST(Dart, Slice) {
TEST(GBTree, FeatureScore) {
size_t n_samples = 1000, n_features = 10, n_classes = 4;
auto m = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes);
auto m = RandomDataGenerator{n_samples, n_features, 0.5}.Classes(n_classes).GenerateDMatrix(true);
std::unique_ptr<Learner> learner{ Learner::Create({m}) };
learner->SetParam("num_class", std::to_string(n_classes));
@ -629,7 +629,7 @@ TEST(GBTree, FeatureScore) {
TEST(GBTree, PredictRange) {
size_t n_samples = 1000, n_features = 10, n_classes = 4;
auto m = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes);
auto m = RandomDataGenerator{n_samples, n_features, 0.5}.Classes(n_classes).GenerateDMatrix(true);
std::unique_ptr<Learner> learner{Learner::Create({m})};
learner->SetParam("num_class", std::to_string(n_classes));
@ -642,7 +642,7 @@ TEST(GBTree, PredictRange) {
ASSERT_THROW(learner->Predict(m, false, &out_predt, 0, 3), dmlc::Error);
auto m_1 =
RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes);
RandomDataGenerator{n_samples, n_features, 0.5}.Classes(n_classes).GenerateDMatrix(true);
HostDeviceVector<float> out_predt_full;
learner->Predict(m_1, false, &out_predt_full, 0, 0);
ASSERT_TRUE(std::equal(out_predt.HostVector().begin(), out_predt.HostVector().end(),

View File

@ -376,8 +376,33 @@ void RandomDataGenerator::GenerateCSR(
CHECK_EQ(columns->Size(), value->Size());
}
namespace {
void MakeLabels(DeviceOrd device, bst_idx_t n_samples, bst_target_t n_classes,
bst_target_t n_targets, std::shared_ptr<DMatrix> out) {
RandomDataGenerator gen{n_samples, n_targets, 0.0f};
if (n_classes != 0) {
gen.Lower(0).Upper(n_classes).GenerateDense(out->Info().labels.Data());
out->Info().labels.Reshape(n_samples, n_targets);
auto& h_labels = out->Info().labels.Data()->HostVector();
for (auto& v : h_labels) {
v = static_cast<float>(static_cast<uint32_t>(v));
}
} else {
gen.GenerateDense(out->Info().labels.Data());
CHECK_EQ(out->Info().labels.Size(), n_samples * n_targets);
out->Info().labels.Reshape(n_samples, n_targets);
}
if (device.IsCUDA()) {
out->Info().labels.Data()->SetDevice(device);
out->Info().labels.Data()->ConstDevicePointer();
out->Info().feature_types.SetDevice(device);
out->Info().feature_types.ConstDevicePointer();
}
}
} // namespace
[[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(
bool with_label, bool float_label, size_t classes, DataSplitMode data_split_mode) const {
bool with_label, DataSplitMode data_split_mode) const {
HostDeviceVector<float> data;
HostDeviceVector<std::size_t> rptrs;
HostDeviceVector<bst_feature_t> columns;
@ -388,19 +413,7 @@ void RandomDataGenerator::GenerateCSR(
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1, "", data_split_mode)};
if (with_label) {
RandomDataGenerator gen{rows_, n_targets_, 0.0f};
if (!float_label) {
gen.Lower(0).Upper(classes).GenerateDense(out->Info().labels.Data());
out->Info().labels.Reshape(this->rows_, this->n_targets_);
auto& h_labels = out->Info().labels.Data()->HostVector();
for (auto& v : h_labels) {
v = static_cast<float>(static_cast<uint32_t>(v));
}
} else {
gen.GenerateDense(out->Info().labels.Data());
CHECK_EQ(out->Info().labels.Size(), this->rows_ * this->n_targets_);
out->Info().labels.Reshape(this->rows_, this->n_targets_);
}
MakeLabels(this->device_, this->rows_, this->n_classes_, this->n_targets_, out);
}
if (device_.IsCUDA()) {
out->Info().labels.SetDevice(device_);
@ -435,34 +448,31 @@ void RandomDataGenerator::GenerateCSR(
#endif // defined(XGBOOST_USE_CUDA)
}
std::unique_ptr<DMatrix> dmat{DMatrix::Create(
std::shared_ptr<DMatrix> p_fmat{DMatrix::Create(
static_cast<DataIterHandle>(iter.get()), iter->Proxy(), Reset, Next,
std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(), prefix, on_host_)};
auto row_page_path =
data::MakeId(prefix, dynamic_cast<data::SparsePageDMatrix*>(dmat.get())) + ".row.page";
data::MakeId(prefix, dynamic_cast<data::SparsePageDMatrix*>(p_fmat.get())) + ".row.page";
EXPECT_TRUE(FileExists(row_page_path)) << row_page_path;
// Loop over the batches and count the number of pages
std::size_t batch_count = 0;
bst_idx_t row_count = 0;
for (const auto& batch : dmat->GetBatches<xgboost::SparsePage>()) {
for (const auto& batch : p_fmat->GetBatches<xgboost::SparsePage>()) {
batch_count++;
row_count += batch.Size();
CHECK_NE(batch.data.Size(), 0);
}
EXPECT_EQ(batch_count, n_batches_);
EXPECT_EQ(dmat->NumBatches(), n_batches_);
EXPECT_EQ(row_count, dmat->Info().num_row_);
EXPECT_EQ(p_fmat->NumBatches(), n_batches_);
EXPECT_EQ(row_count, p_fmat->Info().num_row_);
if (with_label) {
RandomDataGenerator{static_cast<bst_idx_t>(dmat->Info().num_row_), this->n_targets_, 0.0f}.GenerateDense(
dmat->Info().labels.Data());
CHECK_EQ(dmat->Info().labels.Size(), this->rows_ * this->n_targets_);
dmat->Info().labels.Reshape(this->rows_, this->n_targets_);
MakeLabels(this->device_, this->rows_, this->n_classes_, this->n_targets_, p_fmat);
}
return dmat;
return p_fmat;
}
[[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateExtMemQuantileDMatrix(
@ -492,10 +502,7 @@ void RandomDataGenerator::GenerateCSR(
}
if (with_label) {
RandomDataGenerator{static_cast<bst_idx_t>(p_fmat->Info().num_row_), this->n_targets_, 0.0f}
.GenerateDense(p_fmat->Info().labels.Data());
CHECK_EQ(p_fmat->Info().labels.Size(), this->rows_ * this->n_targets_);
p_fmat->Info().labels.Reshape(this->rows_, this->n_targets_);
MakeLabels(this->device_, this->rows_, this->n_classes_, this->n_targets_, p_fmat);
}
return p_fmat;
}

View File

@ -229,6 +229,7 @@ class RandomDataGenerator {
float upper_{1.0f};
bst_target_t n_targets_{1};
bst_target_t n_classes_{0};
DeviceOrd device_{DeviceOrd::CPU()};
std::size_t n_batches_{0};
@ -291,6 +292,10 @@ class RandomDataGenerator {
n_targets_ = n_targets;
return *this;
}
RandomDataGenerator& Classes(bst_target_t n_classes) {
n_classes_ = n_classes;
return *this;
}
void GenerateDense(HostDeviceVector<float>* out) const;
@ -315,8 +320,7 @@ class RandomDataGenerator {
HostDeviceVector<bst_feature_t>* columns) const;
[[nodiscard]] std::shared_ptr<DMatrix> GenerateDMatrix(
bool with_label = false, bool float_label = true, size_t classes = 1,
DataSplitMode data_split_mode = DataSplitMode::kRow) const;
bool with_label = false, DataSplitMode data_split_mode = DataSplitMode::kRow) const;
[[nodiscard]] std::shared_ptr<DMatrix> GenerateSparsePageDMatrix(std::string prefix,
bool with_label) const;

View File

@ -119,7 +119,8 @@ void TestUnbiasedNDCG(Context const* ctx) {
obj->Configure(Args{{"lambdarank_pair_method", "topk"},
{"lambdarank_unbiased", "true"},
{"lambdarank_bias_norm", "0"}});
std::shared_ptr<DMatrix> p_fmat{RandomDataGenerator{10, 1, 0.0f}.GenerateDMatrix(true, false, 2)};
std::shared_ptr<DMatrix> p_fmat{
RandomDataGenerator{10, 1, 0.0f}.Classes(2).GenerateDMatrix(true)};
auto h_label = p_fmat->Info().labels.HostView().Values();
// Move clicked samples to the beginning.
std::sort(h_label.begin(), h_label.end(), std::greater<>{});

View File

@ -61,6 +61,12 @@ TEST(CpuPredictor, ExternalMemory) {
TestBasic(dmat.get(), &ctx);
}
TEST_P(ShapExternalMemoryTest, CPUPredictor) {
Context ctx;
auto [is_qdm, is_interaction] = this->GetParam();
this->Run(&ctx, is_qdm, is_interaction);
}
TEST(CpuPredictor, InplacePredict) {
bst_idx_t constexpr kRows{128};
bst_feature_t constexpr kCols{64};
@ -110,7 +116,7 @@ void TestUpdatePredictionCache(bool use_subsampling) {
}
gbm->Configure(args);
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
auto dmat = RandomDataGenerator(kRows, kCols, 0).Classes(kClasses).GenerateDMatrix(true);
linalg::Matrix<GradientPair> gpair({kRows, kClasses}, ctx.Device());
auto h_gpair = gpair.HostView();
@ -145,7 +151,7 @@ TEST(CPUPredictor, GHistIndexTraining) {
auto adapter = data::ArrayAdapter(columnar.c_str());
std::shared_ptr<DMatrix> p_full{
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1)};
TestTrainingPrediction(&ctx, kRows, kBins, p_full, p_hist, true);
TestTrainingPrediction(&ctx, kRows, kBins, p_full, p_hist);
}
TEST(CPUPredictor, CategoricalPrediction) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2017-2023, XGBoost contributors
* Copyright 2017-2024, XGBoost contributors
*/
#include <gtest/gtest.h>
#include <xgboost/c_api.h>
@ -17,7 +17,6 @@
#include "test_predictor.h"
namespace xgboost::predictor {
TEST(GPUPredictor, Basic) {
auto cpu_lparam = MakeCUDACtx(-1);
auto gpu_lparam = MakeCUDACtx(0);
@ -269,10 +268,9 @@ TEST(GPUPredictor, Shap) {
trees[0]->ExpandNode(0, 0, 0.5, true, 1.0, -1.0, 1.0, 0.0, 5.0, 2.0, 3.0);
model.CommitModelGroup(std::move(trees), 0);
auto gpu_lparam = MakeCUDACtx(0);
auto cpu_lparam = MakeCUDACtx(-1);
std::unique_ptr<Predictor> gpu_predictor = std::unique_ptr<Predictor>(
Predictor::Create("gpu_predictor", &gpu_lparam));
std::unique_ptr<Predictor> gpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &ctx));
std::unique_ptr<Predictor> cpu_predictor = std::unique_ptr<Predictor>(
Predictor::Create("cpu_predictor", &cpu_lparam));
gpu_predictor->Configure({});
@ -289,6 +287,12 @@ TEST(GPUPredictor, Shap) {
}
}
TEST_P(ShapExternalMemoryTest, GPUPredictor) {
auto ctx = MakeCUDACtx(0);
auto [is_qdm, is_interaction] = this->GetParam();
this->Run(&ctx, is_qdm, is_interaction);
}
TEST(GPUPredictor, IterationRange) {
auto ctx = MakeCUDACtx(0);
TestIterationRange(&ctx);

View File

@ -4,15 +4,16 @@
#include "test_predictor.h"
#include <gtest/gtest.h>
#include <xgboost/context.h> // for Context
#include <xgboost/data.h> // for DMatrix, BatchIterator, BatchSet, MetaInfo
#include <xgboost/host_device_vector.h> // for HostDeviceVector
#include <xgboost/predictor.h> // for PredictionCacheEntry, Predictor, Predic...
#include <xgboost/string_view.h> // for StringView
#include <xgboost/context.h> // for Context
#include <xgboost/data.h> // for DMatrix, BatchIterator, BatchSet, MetaInfo
#include <xgboost/host_device_vector.h> // for HostDeviceVector
#include <xgboost/json.h> // for Json
#include <xgboost/predictor.h> // for PredictionCacheEntry, Predictor, Predic...
#include <xgboost/string_view.h> // for StringView
#include <limits> // for numeric_limits
#include <memory> // for shared_ptr
#include <unordered_map> // for unordered_map
#include <limits> // for numeric_limits
#include <memory> // for shared_ptr
#include <unordered_map> // for unordered_map
#include "../../../src/common/bitfield.h" // for LBitField32
#include "../../../src/data/iterative_dmatrix.h" // for IterativeDMatrix
@ -26,7 +27,6 @@
#include "xgboost/tree_model.h" // for RegTree
namespace xgboost {
void TestBasic(DMatrix* dmat, Context const *ctx) {
auto predictor = std::unique_ptr<Predictor>(CreatePredictorForTest(ctx));
@ -118,8 +118,7 @@ TEST(Predictor, PredictionCache) {
}
void TestTrainingPrediction(Context const *ctx, size_t rows, size_t bins,
std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist,
bool check_contribs) {
std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist) {
size_t constexpr kCols = 16;
size_t constexpr kClasses = 3;
size_t constexpr kIters = 3;
@ -163,34 +162,32 @@ void TestTrainingPrediction(Context const *ctx, size_t rows, size_t bins,
EXPECT_NEAR(from_hist.ConstHostVector()[i], from_full.ConstHostVector()[i], kRtEps);
}
if (check_contribs) {
// Contributions
HostDeviceVector<float> from_full_contribs;
learner->Predict(p_full, false, &from_full_contribs, 0, 0, false, false, true);
HostDeviceVector<float> from_hist_contribs;
learner->Predict(p_hist, false, &from_hist_contribs, 0, 0, false, false, true);
for (size_t i = 0; i < from_full_contribs.ConstHostVector().size(); ++i) {
EXPECT_NEAR(from_hist_contribs.ConstHostVector()[i],
from_full_contribs.ConstHostVector()[i], kRtEps);
}
// Contributions
HostDeviceVector<float> from_full_contribs;
learner->Predict(p_full, false, &from_full_contribs, 0, 0, false, false, true);
HostDeviceVector<float> from_hist_contribs;
learner->Predict(p_hist, false, &from_hist_contribs, 0, 0, false, false, true);
for (size_t i = 0; i < from_full_contribs.ConstHostVector().size(); ++i) {
EXPECT_NEAR(from_hist_contribs.ConstHostVector()[i], from_full_contribs.ConstHostVector()[i],
kRtEps);
}
// Contributions (approximate method)
HostDeviceVector<float> from_full_approx_contribs;
learner->Predict(p_full, false, &from_full_approx_contribs, 0, 0, false, false, false, true);
HostDeviceVector<float> from_hist_approx_contribs;
learner->Predict(p_hist, false, &from_hist_approx_contribs, 0, 0, false, false, false, true);
for (size_t i = 0; i < from_full_approx_contribs.ConstHostVector().size(); ++i) {
EXPECT_NEAR(from_hist_approx_contribs.ConstHostVector()[i],
from_full_approx_contribs.ConstHostVector()[i], kRtEps);
}
// Contributions (approximate method)
HostDeviceVector<float> from_full_approx_contribs;
learner->Predict(p_full, false, &from_full_approx_contribs, 0, 0, false, false, false, true);
HostDeviceVector<float> from_hist_approx_contribs;
learner->Predict(p_hist, false, &from_hist_approx_contribs, 0, 0, false, false, false, true);
for (size_t i = 0; i < from_full_approx_contribs.ConstHostVector().size(); ++i) {
EXPECT_NEAR(from_hist_approx_contribs.ConstHostVector()[i],
from_full_approx_contribs.ConstHostVector()[i], kRtEps);
}
}
void TestInplacePrediction(Context const *ctx, std::shared_ptr<DMatrix> x, bst_idx_t rows,
bst_feature_t cols) {
std::size_t constexpr kClasses { 4 };
auto gen = RandomDataGenerator{rows, cols, 0.5}.Device(ctx->Device());
std::shared_ptr<DMatrix> m = gen.GenerateDMatrix(true, false, kClasses);
auto gen = RandomDataGenerator{rows, cols, 0.5}.Device(ctx->Device()).Classes(kClasses);
std::shared_ptr<DMatrix> m = gen.GenerateDMatrix(true);
std::unique_ptr<Learner> learner {
Learner::Create({m})
@ -444,7 +441,8 @@ void TestIterationRange(Context const* ctx) {
size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10;
auto dmat = RandomDataGenerator(kRows, kCols, 0)
.Device(ctx->Device())
.GenerateDMatrix(true, true, kClasses);
.Classes(kClasses)
.GenerateDMatrix(true);
auto learner = LearnerForTest(ctx, dmat, kIters, kForest);
bool bound = false;
@ -515,7 +513,7 @@ void VerifyIterationRangeColumnSplit(bool use_gpu, Json const &ranged_model,
ctx.UpdateAllowUnknown(
Args{{"nthread", std::to_string(n_threads)}, {"device", ctx.DeviceName()}});
auto dmat = RandomDataGenerator(rows, cols, 0).GenerateDMatrix(true, true, classes);
auto dmat = RandomDataGenerator(rows, cols, 0).Classes(classes).GenerateDMatrix(true);
std::shared_ptr<DMatrix> Xy{dmat->SliceCol(world_size, rank)};
std::unique_ptr<Learner> learner{Learner::Create({Xy})};
@ -566,7 +564,7 @@ void VerifyIterationRangeColumnSplit(bool use_gpu, Json const &ranged_model,
void TestIterationRangeColumnSplit(int world_size, bool use_gpu) {
std::size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10;
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
auto dmat = RandomDataGenerator(kRows, kCols, 0).Classes(kClasses).GenerateDMatrix(true);
Context ctx;
if (use_gpu) {
ctx = MakeCUDACtx(0);
@ -835,4 +833,69 @@ void TestVectorLeafPrediction(Context const *ctx) {
data.HostVector().assign(data.Size(), model.trees.front()->SplitCond(RegTree::kRoot) - 1.0);
run_test(1.5, &data);
}
void ShapExternalMemoryTest::Run(Context const *ctx, bool is_qdm, bool is_interaction) {
bst_idx_t n_samples{2048};
bst_feature_t n_features{16};
bst_target_t n_classes{3};
bst_bin_t max_bin{64};
auto create_pfmat = [&](RandomDataGenerator &rng) {
if (is_qdm) {
return rng.Bins(max_bin).GenerateExtMemQuantileDMatrix("temp", true);
}
return rng.GenerateSparsePageDMatrix("temp", true);
};
auto p_fmat = create_pfmat(RandomDataGenerator(n_samples, n_features, 0)
.Batches(1)
.Device(ctx->Device())
.Classes(n_classes));
std::unique_ptr<Learner> learner{Learner::Create({p_fmat})};
learner->SetParam("device", ctx->DeviceName());
learner->SetParam("base_score", "0.5");
learner->SetParam("num_parallel_tree", "3");
learner->SetParam("max_bin", std::to_string(max_bin));
for (std::int32_t i = 0; i < 4; ++i) {
learner->UpdateOneIter(i, p_fmat);
}
Json model{Object{}};
learner->SaveModel(&model);
auto j_booster = model["learner"]["gradient_booster"]["model"];
auto model_param = MakeMP(n_features, 0.0, n_classes, ctx->Device());
gbm::GBTreeModel gbtree{&model_param, ctx};
gbtree.LoadModel(j_booster);
std::unique_ptr<Predictor> predictor{
Predictor::Create(ctx->IsCPU() ? "cpu_predictor" : "gpu_predictor", ctx)};
predictor->Configure({});
HostDeviceVector<float> contrib;
if (is_interaction) {
predictor->PredictInteractionContributions(p_fmat.get(), &contrib, gbtree);
} else {
predictor->PredictContribution(p_fmat.get(), &contrib, gbtree);
}
auto p_fmat_ext = create_pfmat(RandomDataGenerator(n_samples, n_features, 0)
.Batches(4)
.Device(ctx->Device())
.Classes(n_classes));
HostDeviceVector<float> contrib_ext;
if (is_interaction) {
predictor->PredictInteractionContributions(p_fmat_ext.get(), &contrib_ext, gbtree);
} else {
predictor->PredictContribution(p_fmat_ext.get(), &contrib_ext, gbtree);
}
ASSERT_EQ(contrib_ext.Size(), contrib.Size());
auto h_contrib = contrib.ConstHostSpan();
auto h_contrib_ext = contrib_ext.ConstHostSpan();
for (std::size_t i = 0; i < h_contrib.size(); ++i) {
ASSERT_EQ(h_contrib[i], h_contrib_ext[i]);
}
}
INSTANTIATE_TEST_SUITE_P(Predictor, ShapExternalMemoryTest,
::testing::Combine(::testing::Bool(), ::testing::Bool()));
} // namespace xgboost

View File

@ -89,8 +89,7 @@ void TestBasic(DMatrix* dmat, Context const * ctx);
// p_full and p_hist should come from the same data set.
void TestTrainingPrediction(Context const* ctx, size_t rows, size_t bins,
std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist,
bool check_contribs = false);
std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist);
void TestInplacePrediction(Context const* ctx, std::shared_ptr<DMatrix> x, bst_idx_t rows,
bst_feature_t cols);
@ -114,6 +113,11 @@ void TestSparsePrediction(Context const* ctx, float sparsity);
void TestSparsePredictionColumnSplit(int world_size, bool use_gpu, float sparsity);
void TestVectorLeafPrediction(Context const* ctx);
class ShapExternalMemoryTest : public ::testing::TestWithParam<std::tuple<bool, bool>> {
public:
void Run(Context const* ctx, bool is_qdm, bool is_interaction);
};
} // namespace xgboost
#endif // XGBOOST_TEST_PREDICTOR_H_

View File

@ -209,7 +209,7 @@ TEST(Learner, ConfigIO) {
bst_idx_t n_samples = 128;
bst_feature_t n_features = 12;
std::shared_ptr<DMatrix> p_fmat{
RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true, false, 2)};
RandomDataGenerator{n_samples, n_features, 0}.Classes(2).GenerateDMatrix(true)};
auto serialised_model_tmp = std::string{};
std::string eval_res_0;

View File

@ -343,32 +343,45 @@ class TestGPUPredict:
strategies.integers(1, 10), tm.make_dataset_strategy(), shap_parameter_strategy
)
@settings(deadline=None, max_examples=20, print_blob=True)
def test_shap(self, num_rounds, dataset, param):
def test_shap(self, num_rounds: int, dataset: tm.TestDataset, param: dict) -> None:
if dataset.name.endswith("-l1"): # not supported by the exact tree method
return
param.update({"tree_method": "hist", "device": "gpu:0"})
param = dataset.set_params(param)
dmat = dataset.get_dmat()
bst = xgb.train(param, dmat, num_rounds)
test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin)
test_dmat = xgb.DMatrix(
dataset.X, dataset.y, weight=dataset.w, base_margin=dataset.margin
)
bst.set_param({"device": "gpu:0"})
shap = bst.predict(test_dmat, pred_contribs=True)
margin = bst.predict(test_dmat, output_margin=True)
assume(len(dataset.y) > 0)
assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-3, 1e-3)
dmat = dataset.get_external_dmat()
shap = bst.predict(dmat, pred_contribs=True)
margin = bst.predict(dmat, output_margin=True)
assume(len(dataset.y) > 0)
assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-3, 1e-3)
@given(
strategies.integers(1, 10), tm.make_dataset_strategy(), shap_parameter_strategy
)
@settings(deadline=None, max_examples=10, print_blob=True)
def test_shap_interactions(self, num_rounds, dataset, param):
def test_shap_interactions(
self, num_rounds: int, dataset: tm.TestDataset, param: dict
) -> None:
if dataset.name.endswith("-l1"): # not supported by the exact tree method
return
param.update({"tree_method": "hist", "device": "cuda:0"})
param = dataset.set_params(param)
dmat = dataset.get_dmat()
bst = xgb.train(param, dmat, num_rounds)
test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin)
test_dmat = xgb.DMatrix(
dataset.X, dataset.y, weight=dataset.w, base_margin=dataset.margin
)
bst.set_param({"device": "cuda:0"})
shap = bst.predict(test_dmat, pred_interactions=True)
margin = bst.predict(test_dmat, output_margin=True)
@ -380,6 +393,17 @@ class TestGPUPredict:
1e-3,
)
test_dmat = dataset.get_external_dmat()
shap = bst.predict(test_dmat, pred_interactions=True)
margin = bst.predict(test_dmat, output_margin=True)
assume(len(dataset.y) > 0)
assert np.allclose(
np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)),
margin,
1e-3,
1e-3,
)
def test_shap_categorical(self):
X, y = tm.make_categorical(100, 20, 7, False)
Xy = xgb.DMatrix(X, y, enable_categorical=True)