From 142bdc73ece531d8681779763e473e8df27510a9 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 22 Aug 2024 05:25:10 +0800 Subject: [PATCH] [EM] Support SHAP contribution with QDM. (#10724) - Add GPU support. - Add external memory support. - Update the GPU tree shap. --- gputreeshap | 2 +- src/predictor/gpu_predictor.cu | 137 +++++++++++---------- tests/cpp/data/test_simple_dmatrix.cc | 11 +- tests/cpp/gbm/test_gbtree.cc | 10 +- tests/cpp/helpers.cc | 63 +++++----- tests/cpp/helpers.h | 8 +- tests/cpp/objective/test_lambdarank_obj.cc | 3 +- tests/cpp/predictor/test_cpu_predictor.cc | 10 +- tests/cpp/predictor/test_gpu_predictor.cu | 14 ++- tests/cpp/predictor/test_predictor.cc | 133 ++++++++++++++------ tests/cpp/predictor/test_predictor.h | 8 +- tests/cpp/test_learner.cc | 2 +- tests/python-gpu/test_gpu_prediction.py | 32 ++++- 13 files changed, 274 insertions(+), 159 deletions(-) diff --git a/gputreeshap b/gputreeshap index 787259b41..40eae8c4c 160000 --- a/gputreeshap +++ b/gputreeshap @@ -1 +1 @@ -Subproject commit 787259b412c18ab8d5f24bf2b8bd6a59ff8208f3 +Subproject commit 40eae8c4c45974705f8053e4d3d05b88e3cfaefd diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 570872aa5..38d6eca4d 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -143,10 +143,9 @@ struct SparsePageLoader { }; struct EllpackLoader { - EllpackDeviceAccessor const& matrix; - XGBOOST_DEVICE EllpackLoader(EllpackDeviceAccessor const& m, bool, bst_feature_t, bst_idx_t, - float) - : matrix{m} {} + EllpackDeviceAccessor matrix; + XGBOOST_DEVICE EllpackLoader(EllpackDeviceAccessor m, bool, bst_feature_t, bst_idx_t, float) + : matrix{std::move(m)} {} [[nodiscard]] XGBOOST_DEV_INLINE float GetElement(size_t ridx, size_t fidx) const { auto gidx = matrix.GetBinIndex(ridx, fidx); if (gidx == -1) { @@ -162,6 +161,8 @@ struct EllpackLoader { } return matrix.gidx_fvalue_map[gidx - 1]; } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return this->matrix.NumFeatures(); } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return this->matrix.n_rows; } }; template @@ -1031,9 +1032,6 @@ class GPUPredictor : public xgboost::Predictor { if (tree_weights != nullptr) { LOG(FATAL) << "Dart booster feature " << not_implemented; } - if (!p_fmat->PageExists()) { - LOG(FATAL) << "SHAP value for QuantileDMatrix is not yet implemented for GPU."; - } CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict contribution support for column-wise data split is not yet implemented."; dh::safe_cuda(cudaSetDevice(ctx_->Ordinal())); @@ -1047,8 +1045,8 @@ class GPUPredictor : public xgboost::Predictor { // allocate space for (number of features + bias) times the number of rows size_t contributions_columns = model.learner_model_param->num_feature + 1; // +1 for bias - out_contribs->Resize(p_fmat->Info().num_row_ * contributions_columns * - model.learner_model_param->num_output_group); + auto dim_size = contributions_columns * model.learner_model_param->num_output_group; + out_contribs->Resize(p_fmat->Info().num_row_ * dim_size); out_contribs->Fill(0.0f); auto phis = out_contribs->DeviceSpan(); @@ -1058,16 +1056,27 @@ class GPUPredictor : public xgboost::Predictor { d_model.Init(model, 0, tree_end, ctx_->Device()); dh::device_vector categories; ExtractPaths(&device_paths, &d_model, &categories, ctx_->Device()); - for (auto& batch : p_fmat->GetBatches()) { - batch.data.SetDevice(ctx_->Device()); - batch.offset.SetDevice(ctx_->Device()); - SparsePageView 
X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), - model.learner_model_param->num_feature); - auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns; - gpu_treeshap::GPUTreeShap>( - X, device_paths.begin(), device_paths.end(), ngroup, begin, - dh::tend(phis)); + if (p_fmat->PageExists()) { + for (auto& batch : p_fmat->GetBatches()) { + batch.data.SetDevice(ctx_->Device()); + batch.offset.SetDevice(ctx_->Device()); + SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), + model.learner_model_param->num_feature); + auto begin = dh::tbegin(phis) + batch.base_rowid * dim_size; + gpu_treeshap::GPUTreeShap>( + X, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis)); + } + } else { + for (auto& batch : p_fmat->GetBatches(ctx_, {})) { + EllpackDeviceAccessor acc{batch.Impl()->GetDeviceAccessor(ctx_->Device())}; + auto X = EllpackLoader{acc, true, model.learner_model_param->num_feature, batch.Size(), + std::numeric_limits::quiet_NaN()}; + auto begin = dh::tbegin(phis) + batch.BaseRowId() * dim_size; + gpu_treeshap::GPUTreeShap>( + X, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis)); + } } + // Add the base margin term to last column p_fmat->Info().base_margin_.SetDevice(ctx_->Device()); const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); @@ -1094,9 +1103,6 @@ class GPUPredictor : public xgboost::Predictor { if (tree_weights != nullptr) { LOG(FATAL) << "Dart booster feature " << not_implemented; } - if (!p_fmat->PageExists()) { - LOG(FATAL) << "SHAP value for QuantileDMatrix is not yet implemented for GPU."; - } dh::safe_cuda(cudaSetDevice(ctx_->Ordinal())); out_contribs->SetDevice(ctx_->Device()); if (tree_end == 0 || tree_end > model.trees.size()) { @@ -1108,9 +1114,9 @@ class GPUPredictor : public xgboost::Predictor { // allocate space for (number of features + bias) times the number of rows size_t contributions_columns = model.learner_model_param->num_feature + 1; // +1 for bias - out_contribs->Resize(p_fmat->Info().num_row_ * contributions_columns * - contributions_columns * - model.learner_model_param->num_output_group); + auto dim_size = + contributions_columns * contributions_columns * model.learner_model_param->num_output_group; + out_contribs->Resize(p_fmat->Info().num_row_ * dim_size); out_contribs->Fill(0.0f); auto phis = out_contribs->DeviceSpan(); @@ -1120,16 +1126,29 @@ class GPUPredictor : public xgboost::Predictor { d_model.Init(model, 0, tree_end, ctx_->Device()); dh::device_vector categories; ExtractPaths(&device_paths, &d_model, &categories, ctx_->Device()); - for (auto& batch : p_fmat->GetBatches()) { - batch.data.SetDevice(ctx_->Device()); - batch.offset.SetDevice(ctx_->Device()); - SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), - model.learner_model_param->num_feature); - auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns; - gpu_treeshap::GPUTreeShapInteractions>( - X, device_paths.begin(), device_paths.end(), ngroup, begin, - dh::tend(phis)); + if (p_fmat->PageExists()) { + for (auto const& batch : p_fmat->GetBatches()) { + batch.data.SetDevice(ctx_->Device()); + batch.offset.SetDevice(ctx_->Device()); + SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), + model.learner_model_param->num_feature); + auto begin = dh::tbegin(phis) + batch.base_rowid * dim_size; + gpu_treeshap::GPUTreeShapInteractions>( + X, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis)); + } + } else { + for (auto const& 
batch : p_fmat->GetBatches(ctx_, {})) { + auto impl = batch.Impl(); + auto acc = + impl->GetDeviceAccessor(ctx_->Device(), p_fmat->Info().feature_types.ConstDeviceSpan()); + auto begin = dh::tbegin(phis) + batch.BaseRowId() * dim_size; + auto X = EllpackLoader{acc, true, model.learner_model_param->num_feature, batch.Size(), + std::numeric_limits::quiet_NaN()}; + gpu_treeshap::GPUTreeShapInteractions>( + X, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis)); + } } + // Add the base margin term to last column p_fmat->Info().base_margin_.SetDevice(ctx_->Device()); const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); @@ -1180,51 +1199,35 @@ class GPUPredictor : public xgboost::Predictor { bool use_shared = shared_memory_bytes != 0; bst_feature_t num_features = info.num_col_; + auto launch = [&](auto fn, std::uint32_t grid, auto data, bst_idx_t batch_offset) { + dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes}( + fn, data, d_model.nodes.ConstDeviceSpan(), + predictions->DeviceSpan().subspan(batch_offset), d_model.tree_segments.ConstDeviceSpan(), + + d_model.split_types.ConstDeviceSpan(), d_model.categories_tree_segments.ConstDeviceSpan(), + d_model.categories_node_segments.ConstDeviceSpan(), d_model.categories.ConstDeviceSpan(), + + d_model.tree_beg_, d_model.tree_end_, num_features, num_rows, use_shared, + std::numeric_limits::quiet_NaN()); + }; + if (p_fmat->PageExists()) { + bst_idx_t batch_offset = 0; for (auto const& batch : p_fmat->GetBatches()) { batch.data.SetDevice(ctx_->Device()); batch.offset.SetDevice(ctx_->Device()); - bst_idx_t batch_offset = 0; SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan(), model.learner_model_param->num_feature}; - size_t num_rows = batch.Size(); - auto grid = - static_cast(common::DivRoundUp(num_rows, kBlockThreads)); - dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} ( - PredictLeafKernel, data, - d_model.nodes.ConstDeviceSpan(), - predictions->DeviceSpan().subspan(batch_offset), - d_model.tree_segments.ConstDeviceSpan(), - - d_model.split_types.ConstDeviceSpan(), - d_model.categories_tree_segments.ConstDeviceSpan(), - d_model.categories_node_segments.ConstDeviceSpan(), - d_model.categories.ConstDeviceSpan(), - - d_model.tree_beg_, d_model.tree_end_, num_features, num_rows, - use_shared, std::numeric_limits::quiet_NaN()); + auto grid = static_cast(common::DivRoundUp(batch.Size(), kBlockThreads)); + launch(PredictLeafKernel, grid, data, batch_offset); batch_offset += batch.Size(); } } else { + bst_idx_t batch_offset = 0; for (auto const& batch : p_fmat->GetBatches(ctx_, BatchParam{})) { - bst_idx_t batch_offset = 0; EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_->Device())}; - size_t num_rows = batch.Size(); - auto grid = - static_cast(common::DivRoundUp(num_rows, kBlockThreads)); - dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} ( - PredictLeafKernel, data, - d_model.nodes.ConstDeviceSpan(), - predictions->DeviceSpan().subspan(batch_offset), - d_model.tree_segments.ConstDeviceSpan(), - - d_model.split_types.ConstDeviceSpan(), - d_model.categories_tree_segments.ConstDeviceSpan(), - d_model.categories_node_segments.ConstDeviceSpan(), - d_model.categories.ConstDeviceSpan(), - - d_model.tree_beg_, d_model.tree_end_, num_features, num_rows, - use_shared, std::numeric_limits::quiet_NaN()); + auto grid = static_cast(common::DivRoundUp(batch.Size(), kBlockThreads)); + launch(PredictLeafKernel, grid, data, batch_offset); batch_offset += batch.Size(); 
} } diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index ea6eedbb2..16448c2e1 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -1,5 +1,5 @@ /** - * Copyright 2016-2023 by XGBoost Contributors + * Copyright 2016-2024, XGBoost Contributors */ #include @@ -434,12 +434,11 @@ namespace { void VerifyColumnSplit() { size_t constexpr kRows {16}; size_t constexpr kCols {8}; - auto dmat = - RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(false, false, 1, DataSplitMode::kCol); + auto p_fmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(false, DataSplitMode::kCol); - ASSERT_EQ(dmat->Info().num_col_, kCols * collective::GetWorldSize()); - ASSERT_EQ(dmat->Info().num_row_, kRows); - ASSERT_EQ(dmat->Info().data_split_mode, DataSplitMode::kCol); + ASSERT_EQ(p_fmat->Info().num_col_, kCols * collective::GetWorldSize()); + ASSERT_EQ(p_fmat->Info().num_row_, kRows); + ASSERT_EQ(p_fmat->Info().data_split_mode, DataSplitMode::kCol); } } // anonymous namespace diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index 8a5383ad4..79e236f11 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019-2023, XGBoost contributors + * Copyright 2019-2024, XGBoost contributors */ #include #include @@ -463,7 +463,7 @@ INSTANTIATE_TEST_SUITE_P(PredictorTypes, Dart, testing::Values("CPU")); std::pair TestModelSlice(std::string booster) { size_t constexpr kRows = 1000, kCols = 100, kForest = 2, kClasses = 3; - auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true, false, kClasses); + auto m = RandomDataGenerator{kRows, kCols, 0}.Classes(kClasses).GenerateDMatrix(true); int32_t kIters = 10; std::unique_ptr learner { @@ -592,7 +592,7 @@ TEST(Dart, Slice) { TEST(GBTree, FeatureScore) { size_t n_samples = 1000, n_features = 10, n_classes = 4; - auto m = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes); + auto m = RandomDataGenerator{n_samples, n_features, 0.5}.Classes(n_classes).GenerateDMatrix(true); std::unique_ptr learner{ Learner::Create({m}) }; learner->SetParam("num_class", std::to_string(n_classes)); @@ -629,7 +629,7 @@ TEST(GBTree, FeatureScore) { TEST(GBTree, PredictRange) { size_t n_samples = 1000, n_features = 10, n_classes = 4; - auto m = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes); + auto m = RandomDataGenerator{n_samples, n_features, 0.5}.Classes(n_classes).GenerateDMatrix(true); std::unique_ptr learner{Learner::Create({m})}; learner->SetParam("num_class", std::to_string(n_classes)); @@ -642,7 +642,7 @@ TEST(GBTree, PredictRange) { ASSERT_THROW(learner->Predict(m, false, &out_predt, 0, 3), dmlc::Error); auto m_1 = - RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes); + RandomDataGenerator{n_samples, n_features, 0.5}.Classes(n_classes).GenerateDMatrix(true); HostDeviceVector out_predt_full; learner->Predict(m_1, false, &out_predt_full, 0, 0); ASSERT_TRUE(std::equal(out_predt.HostVector().begin(), out_predt.HostVector().end(), diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index ae5698d2c..3dbf18970 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -376,8 +376,33 @@ void RandomDataGenerator::GenerateCSR( CHECK_EQ(columns->Size(), value->Size()); } +namespace { +void MakeLabels(DeviceOrd device, bst_idx_t n_samples, bst_target_t n_classes, + bst_target_t n_targets, std::shared_ptr 
out) { + RandomDataGenerator gen{n_samples, n_targets, 0.0f}; + if (n_classes != 0) { + gen.Lower(0).Upper(n_classes).GenerateDense(out->Info().labels.Data()); + out->Info().labels.Reshape(n_samples, n_targets); + auto& h_labels = out->Info().labels.Data()->HostVector(); + for (auto& v : h_labels) { + v = static_cast(static_cast(v)); + } + } else { + gen.GenerateDense(out->Info().labels.Data()); + CHECK_EQ(out->Info().labels.Size(), n_samples * n_targets); + out->Info().labels.Reshape(n_samples, n_targets); + } + if (device.IsCUDA()) { + out->Info().labels.Data()->SetDevice(device); + out->Info().labels.Data()->ConstDevicePointer(); + out->Info().feature_types.SetDevice(device); + out->Info().feature_types.ConstDevicePointer(); + } +} +} // namespace + [[nodiscard]] std::shared_ptr RandomDataGenerator::GenerateDMatrix( - bool with_label, bool float_label, size_t classes, DataSplitMode data_split_mode) const { + bool with_label, DataSplitMode data_split_mode) const { HostDeviceVector data; HostDeviceVector rptrs; HostDeviceVector columns; @@ -388,19 +413,7 @@ void RandomDataGenerator::GenerateCSR( DMatrix::Create(&adapter, std::numeric_limits::quiet_NaN(), 1, "", data_split_mode)}; if (with_label) { - RandomDataGenerator gen{rows_, n_targets_, 0.0f}; - if (!float_label) { - gen.Lower(0).Upper(classes).GenerateDense(out->Info().labels.Data()); - out->Info().labels.Reshape(this->rows_, this->n_targets_); - auto& h_labels = out->Info().labels.Data()->HostVector(); - for (auto& v : h_labels) { - v = static_cast(static_cast(v)); - } - } else { - gen.GenerateDense(out->Info().labels.Data()); - CHECK_EQ(out->Info().labels.Size(), this->rows_ * this->n_targets_); - out->Info().labels.Reshape(this->rows_, this->n_targets_); - } + MakeLabels(this->device_, this->rows_, this->n_classes_, this->n_targets_, out); } if (device_.IsCUDA()) { out->Info().labels.SetDevice(device_); @@ -435,34 +448,31 @@ void RandomDataGenerator::GenerateCSR( #endif // defined(XGBOOST_USE_CUDA) } - std::unique_ptr dmat{DMatrix::Create( + std::shared_ptr p_fmat{DMatrix::Create( static_cast(iter.get()), iter->Proxy(), Reset, Next, std::numeric_limits::quiet_NaN(), Context{}.Threads(), prefix, on_host_)}; auto row_page_path = - data::MakeId(prefix, dynamic_cast(dmat.get())) + ".row.page"; + data::MakeId(prefix, dynamic_cast(p_fmat.get())) + ".row.page"; EXPECT_TRUE(FileExists(row_page_path)) << row_page_path; // Loop over the batches and count the number of pages std::size_t batch_count = 0; bst_idx_t row_count = 0; - for (const auto& batch : dmat->GetBatches()) { + for (const auto& batch : p_fmat->GetBatches()) { batch_count++; row_count += batch.Size(); CHECK_NE(batch.data.Size(), 0); } EXPECT_EQ(batch_count, n_batches_); - EXPECT_EQ(dmat->NumBatches(), n_batches_); - EXPECT_EQ(row_count, dmat->Info().num_row_); + EXPECT_EQ(p_fmat->NumBatches(), n_batches_); + EXPECT_EQ(row_count, p_fmat->Info().num_row_); if (with_label) { - RandomDataGenerator{static_cast(dmat->Info().num_row_), this->n_targets_, 0.0f}.GenerateDense( - dmat->Info().labels.Data()); - CHECK_EQ(dmat->Info().labels.Size(), this->rows_ * this->n_targets_); - dmat->Info().labels.Reshape(this->rows_, this->n_targets_); + MakeLabels(this->device_, this->rows_, this->n_classes_, this->n_targets_, p_fmat); } - return dmat; + return p_fmat; } [[nodiscard]] std::shared_ptr RandomDataGenerator::GenerateExtMemQuantileDMatrix( @@ -492,10 +502,7 @@ void RandomDataGenerator::GenerateCSR( } if (with_label) { - RandomDataGenerator{static_cast(p_fmat->Info().num_row_), 
this->n_targets_, 0.0f} - .GenerateDense(p_fmat->Info().labels.Data()); - CHECK_EQ(p_fmat->Info().labels.Size(), this->rows_ * this->n_targets_); - p_fmat->Info().labels.Reshape(this->rows_, this->n_targets_); + MakeLabels(this->device_, this->rows_, this->n_classes_, this->n_targets_, p_fmat); } return p_fmat; } diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index a8d5f370f..8e4e82a91 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -229,6 +229,7 @@ class RandomDataGenerator { float upper_{1.0f}; bst_target_t n_targets_{1}; + bst_target_t n_classes_{0}; DeviceOrd device_{DeviceOrd::CPU()}; std::size_t n_batches_{0}; @@ -291,6 +292,10 @@ class RandomDataGenerator { n_targets_ = n_targets; return *this; } + RandomDataGenerator& Classes(bst_target_t n_classes) { + n_classes_ = n_classes; + return *this; + } void GenerateDense(HostDeviceVector* out) const; @@ -315,8 +320,7 @@ class RandomDataGenerator { HostDeviceVector* columns) const; [[nodiscard]] std::shared_ptr GenerateDMatrix( - bool with_label = false, bool float_label = true, size_t classes = 1, - DataSplitMode data_split_mode = DataSplitMode::kRow) const; + bool with_label = false, DataSplitMode data_split_mode = DataSplitMode::kRow) const; [[nodiscard]] std::shared_ptr GenerateSparsePageDMatrix(std::string prefix, bool with_label) const; diff --git a/tests/cpp/objective/test_lambdarank_obj.cc b/tests/cpp/objective/test_lambdarank_obj.cc index 2b34cfa38..a9249fc28 100644 --- a/tests/cpp/objective/test_lambdarank_obj.cc +++ b/tests/cpp/objective/test_lambdarank_obj.cc @@ -119,7 +119,8 @@ void TestUnbiasedNDCG(Context const* ctx) { obj->Configure(Args{{"lambdarank_pair_method", "topk"}, {"lambdarank_unbiased", "true"}, {"lambdarank_bias_norm", "0"}}); - std::shared_ptr p_fmat{RandomDataGenerator{10, 1, 0.0f}.GenerateDMatrix(true, false, 2)}; + std::shared_ptr p_fmat{ + RandomDataGenerator{10, 1, 0.0f}.Classes(2).GenerateDMatrix(true)}; auto h_label = p_fmat->Info().labels.HostView().Values(); // Move clicked samples to the beginning. 
std::sort(h_label.begin(), h_label.end(), std::greater<>{}); diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index ee28adb15..2a1b43bf7 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -61,6 +61,12 @@ TEST(CpuPredictor, ExternalMemory) { TestBasic(dmat.get(), &ctx); } +TEST_P(ShapExternalMemoryTest, CPUPredictor) { + Context ctx; + auto [is_qdm, is_interaction] = this->GetParam(); + this->Run(&ctx, is_qdm, is_interaction); +} + TEST(CpuPredictor, InplacePredict) { bst_idx_t constexpr kRows{128}; bst_feature_t constexpr kCols{64}; @@ -110,7 +116,7 @@ void TestUpdatePredictionCache(bool use_subsampling) { } gbm->Configure(args); - auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses); + auto dmat = RandomDataGenerator(kRows, kCols, 0).Classes(kClasses).GenerateDMatrix(true); linalg::Matrix gpair({kRows, kClasses}, ctx.Device()); auto h_gpair = gpair.HostView(); @@ -145,7 +151,7 @@ TEST(CPUPredictor, GHistIndexTraining) { auto adapter = data::ArrayAdapter(columnar.c_str()); std::shared_ptr p_full{ DMatrix::Create(&adapter, std::numeric_limits::quiet_NaN(), 1)}; - TestTrainingPrediction(&ctx, kRows, kBins, p_full, p_hist, true); + TestTrainingPrediction(&ctx, kRows, kBins, p_full, p_hist); } TEST(CPUPredictor, CategoricalPrediction) { diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 5e3021fd7..366d0ab6a 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -1,5 +1,5 @@ /** - * Copyright 2017-2023, XGBoost contributors + * Copyright 2017-2024, XGBoost contributors */ #include #include @@ -17,7 +17,6 @@ #include "test_predictor.h" namespace xgboost::predictor { - TEST(GPUPredictor, Basic) { auto cpu_lparam = MakeCUDACtx(-1); auto gpu_lparam = MakeCUDACtx(0); @@ -269,10 +268,9 @@ TEST(GPUPredictor, Shap) { trees[0]->ExpandNode(0, 0, 0.5, true, 1.0, -1.0, 1.0, 0.0, 5.0, 2.0, 3.0); model.CommitModelGroup(std::move(trees), 0); - auto gpu_lparam = MakeCUDACtx(0); auto cpu_lparam = MakeCUDACtx(-1); - std::unique_ptr gpu_predictor = std::unique_ptr( - Predictor::Create("gpu_predictor", &gpu_lparam)); + std::unique_ptr gpu_predictor = + std::unique_ptr(Predictor::Create("gpu_predictor", &ctx)); std::unique_ptr cpu_predictor = std::unique_ptr( Predictor::Create("cpu_predictor", &cpu_lparam)); gpu_predictor->Configure({}); @@ -289,6 +287,12 @@ TEST(GPUPredictor, Shap) { } } +TEST_P(ShapExternalMemoryTest, GPUPredictor) { + auto ctx = MakeCUDACtx(0); + auto [is_qdm, is_interaction] = this->GetParam(); + this->Run(&ctx, is_qdm, is_interaction); +} + TEST(GPUPredictor, IterationRange) { auto ctx = MakeCUDACtx(0); TestIterationRange(&ctx); diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index b79b75012..1af873f58 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -4,15 +4,16 @@ #include "test_predictor.h" #include -#include // for Context -#include // for DMatrix, BatchIterator, BatchSet, MetaInfo -#include // for HostDeviceVector -#include // for PredictionCacheEntry, Predictor, Predic... -#include // for StringView +#include // for Context +#include // for DMatrix, BatchIterator, BatchSet, MetaInfo +#include // for HostDeviceVector +#include // for Json +#include // for PredictionCacheEntry, Predictor, Predic... 
+#include // for StringView -#include // for numeric_limits -#include // for shared_ptr -#include // for unordered_map +#include // for numeric_limits +#include // for shared_ptr +#include // for unordered_map #include "../../../src/common/bitfield.h" // for LBitField32 #include "../../../src/data/iterative_dmatrix.h" // for IterativeDMatrix @@ -26,7 +27,6 @@ #include "xgboost/tree_model.h" // for RegTree namespace xgboost { - void TestBasic(DMatrix* dmat, Context const *ctx) { auto predictor = std::unique_ptr(CreatePredictorForTest(ctx)); @@ -118,8 +118,7 @@ TEST(Predictor, PredictionCache) { } void TestTrainingPrediction(Context const *ctx, size_t rows, size_t bins, - std::shared_ptr p_full, std::shared_ptr p_hist, - bool check_contribs) { + std::shared_ptr p_full, std::shared_ptr p_hist) { size_t constexpr kCols = 16; size_t constexpr kClasses = 3; size_t constexpr kIters = 3; @@ -163,34 +162,32 @@ void TestTrainingPrediction(Context const *ctx, size_t rows, size_t bins, EXPECT_NEAR(from_hist.ConstHostVector()[i], from_full.ConstHostVector()[i], kRtEps); } - if (check_contribs) { - // Contributions - HostDeviceVector from_full_contribs; - learner->Predict(p_full, false, &from_full_contribs, 0, 0, false, false, true); - HostDeviceVector from_hist_contribs; - learner->Predict(p_hist, false, &from_hist_contribs, 0, 0, false, false, true); - for (size_t i = 0; i < from_full_contribs.ConstHostVector().size(); ++i) { - EXPECT_NEAR(from_hist_contribs.ConstHostVector()[i], - from_full_contribs.ConstHostVector()[i], kRtEps); - } + // Contributions + HostDeviceVector from_full_contribs; + learner->Predict(p_full, false, &from_full_contribs, 0, 0, false, false, true); + HostDeviceVector from_hist_contribs; + learner->Predict(p_hist, false, &from_hist_contribs, 0, 0, false, false, true); + for (size_t i = 0; i < from_full_contribs.ConstHostVector().size(); ++i) { + EXPECT_NEAR(from_hist_contribs.ConstHostVector()[i], from_full_contribs.ConstHostVector()[i], + kRtEps); + } - // Contributions (approximate method) - HostDeviceVector from_full_approx_contribs; - learner->Predict(p_full, false, &from_full_approx_contribs, 0, 0, false, false, false, true); - HostDeviceVector from_hist_approx_contribs; - learner->Predict(p_hist, false, &from_hist_approx_contribs, 0, 0, false, false, false, true); - for (size_t i = 0; i < from_full_approx_contribs.ConstHostVector().size(); ++i) { - EXPECT_NEAR(from_hist_approx_contribs.ConstHostVector()[i], - from_full_approx_contribs.ConstHostVector()[i], kRtEps); - } + // Contributions (approximate method) + HostDeviceVector from_full_approx_contribs; + learner->Predict(p_full, false, &from_full_approx_contribs, 0, 0, false, false, false, true); + HostDeviceVector from_hist_approx_contribs; + learner->Predict(p_hist, false, &from_hist_approx_contribs, 0, 0, false, false, false, true); + for (size_t i = 0; i < from_full_approx_contribs.ConstHostVector().size(); ++i) { + EXPECT_NEAR(from_hist_approx_contribs.ConstHostVector()[i], + from_full_approx_contribs.ConstHostVector()[i], kRtEps); } } void TestInplacePrediction(Context const *ctx, std::shared_ptr x, bst_idx_t rows, bst_feature_t cols) { std::size_t constexpr kClasses { 4 }; - auto gen = RandomDataGenerator{rows, cols, 0.5}.Device(ctx->Device()); - std::shared_ptr m = gen.GenerateDMatrix(true, false, kClasses); + auto gen = RandomDataGenerator{rows, cols, 0.5}.Device(ctx->Device()).Classes(kClasses); + std::shared_ptr m = gen.GenerateDMatrix(true); std::unique_ptr learner { Learner::Create({m}) @@ -444,7 +441,8 @@ 
void TestIterationRange(Context const* ctx) { size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10; auto dmat = RandomDataGenerator(kRows, kCols, 0) .Device(ctx->Device()) - .GenerateDMatrix(true, true, kClasses); + .Classes(kClasses) + .GenerateDMatrix(true); auto learner = LearnerForTest(ctx, dmat, kIters, kForest); bool bound = false; @@ -515,7 +513,7 @@ void VerifyIterationRangeColumnSplit(bool use_gpu, Json const &ranged_model, ctx.UpdateAllowUnknown( Args{{"nthread", std::to_string(n_threads)}, {"device", ctx.DeviceName()}}); - auto dmat = RandomDataGenerator(rows, cols, 0).GenerateDMatrix(true, true, classes); + auto dmat = RandomDataGenerator(rows, cols, 0).Classes(classes).GenerateDMatrix(true); std::shared_ptr Xy{dmat->SliceCol(world_size, rank)}; std::unique_ptr learner{Learner::Create({Xy})}; @@ -566,7 +564,7 @@ void VerifyIterationRangeColumnSplit(bool use_gpu, Json const &ranged_model, void TestIterationRangeColumnSplit(int world_size, bool use_gpu) { std::size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10; - auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses); + auto dmat = RandomDataGenerator(kRows, kCols, 0).Classes(kClasses).GenerateDMatrix(true); Context ctx; if (use_gpu) { ctx = MakeCUDACtx(0); @@ -835,4 +833,69 @@ void TestVectorLeafPrediction(Context const *ctx) { data.HostVector().assign(data.Size(), model.trees.front()->SplitCond(RegTree::kRoot) - 1.0); run_test(1.5, &data); } + +void ShapExternalMemoryTest::Run(Context const *ctx, bool is_qdm, bool is_interaction) { + bst_idx_t n_samples{2048}; + bst_feature_t n_features{16}; + bst_target_t n_classes{3}; + bst_bin_t max_bin{64}; + auto create_pfmat = [&](RandomDataGenerator &rng) { + if (is_qdm) { + return rng.Bins(max_bin).GenerateExtMemQuantileDMatrix("temp", true); + } + return rng.GenerateSparsePageDMatrix("temp", true); + }; + auto p_fmat = create_pfmat(RandomDataGenerator(n_samples, n_features, 0) + .Batches(1) + .Device(ctx->Device()) + .Classes(n_classes)); + std::unique_ptr learner{Learner::Create({p_fmat})}; + learner->SetParam("device", ctx->DeviceName()); + learner->SetParam("base_score", "0.5"); + learner->SetParam("num_parallel_tree", "3"); + learner->SetParam("max_bin", std::to_string(max_bin)); + for (std::int32_t i = 0; i < 4; ++i) { + learner->UpdateOneIter(i, p_fmat); + } + Json model{Object{}}; + learner->SaveModel(&model); + auto j_booster = model["learner"]["gradient_booster"]["model"]; + auto model_param = MakeMP(n_features, 0.0, n_classes, ctx->Device()); + + gbm::GBTreeModel gbtree{&model_param, ctx}; + gbtree.LoadModel(j_booster); + + std::unique_ptr predictor{ + Predictor::Create(ctx->IsCPU() ? 
"cpu_predictor" : "gpu_predictor", ctx)}; + predictor->Configure({}); + HostDeviceVector contrib; + if (is_interaction) { + predictor->PredictInteractionContributions(p_fmat.get(), &contrib, gbtree); + } else { + predictor->PredictContribution(p_fmat.get(), &contrib, gbtree); + } + + auto p_fmat_ext = create_pfmat(RandomDataGenerator(n_samples, n_features, 0) + .Batches(4) + .Device(ctx->Device()) + .Classes(n_classes)); + + HostDeviceVector contrib_ext; + if (is_interaction) { + predictor->PredictInteractionContributions(p_fmat_ext.get(), &contrib_ext, gbtree); + } else { + predictor->PredictContribution(p_fmat_ext.get(), &contrib_ext, gbtree); + } + + ASSERT_EQ(contrib_ext.Size(), contrib.Size()); + + auto h_contrib = contrib.ConstHostSpan(); + auto h_contrib_ext = contrib_ext.ConstHostSpan(); + for (std::size_t i = 0; i < h_contrib.size(); ++i) { + ASSERT_EQ(h_contrib[i], h_contrib_ext[i]); + } +} + +INSTANTIATE_TEST_SUITE_P(Predictor, ShapExternalMemoryTest, + ::testing::Combine(::testing::Bool(), ::testing::Bool())); } // namespace xgboost diff --git a/tests/cpp/predictor/test_predictor.h b/tests/cpp/predictor/test_predictor.h index 1ccd35102..8f110efe0 100644 --- a/tests/cpp/predictor/test_predictor.h +++ b/tests/cpp/predictor/test_predictor.h @@ -89,8 +89,7 @@ void TestBasic(DMatrix* dmat, Context const * ctx); // p_full and p_hist should come from the same data set. void TestTrainingPrediction(Context const* ctx, size_t rows, size_t bins, - std::shared_ptr p_full, std::shared_ptr p_hist, - bool check_contribs = false); + std::shared_ptr p_full, std::shared_ptr p_hist); void TestInplacePrediction(Context const* ctx, std::shared_ptr x, bst_idx_t rows, bst_feature_t cols); @@ -114,6 +113,11 @@ void TestSparsePrediction(Context const* ctx, float sparsity); void TestSparsePredictionColumnSplit(int world_size, bool use_gpu, float sparsity); void TestVectorLeafPrediction(Context const* ctx); + +class ShapExternalMemoryTest : public ::testing::TestWithParam> { + public: + void Run(Context const* ctx, bool is_qdm, bool is_interaction); +}; } // namespace xgboost #endif // XGBOOST_TEST_PREDICTOR_H_ diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index a6f3eacec..d53a568d4 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -209,7 +209,7 @@ TEST(Learner, ConfigIO) { bst_idx_t n_samples = 128; bst_feature_t n_features = 12; std::shared_ptr p_fmat{ - RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true, false, 2)}; + RandomDataGenerator{n_samples, n_features, 0}.Classes(2).GenerateDMatrix(true)}; auto serialised_model_tmp = std::string{}; std::string eval_res_0; diff --git a/tests/python-gpu/test_gpu_prediction.py b/tests/python-gpu/test_gpu_prediction.py index 98b60aecf..b3ccf4ae5 100644 --- a/tests/python-gpu/test_gpu_prediction.py +++ b/tests/python-gpu/test_gpu_prediction.py @@ -343,32 +343,45 @@ class TestGPUPredict: strategies.integers(1, 10), tm.make_dataset_strategy(), shap_parameter_strategy ) @settings(deadline=None, max_examples=20, print_blob=True) - def test_shap(self, num_rounds, dataset, param): + def test_shap(self, num_rounds: int, dataset: tm.TestDataset, param: dict) -> None: if dataset.name.endswith("-l1"): # not supported by the exact tree method return param.update({"tree_method": "hist", "device": "gpu:0"}) param = dataset.set_params(param) dmat = dataset.get_dmat() bst = xgb.train(param, dmat, num_rounds) - test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin) + test_dmat = xgb.DMatrix( + 
dataset.X, dataset.y, weight=dataset.w, base_margin=dataset.margin + ) bst.set_param({"device": "gpu:0"}) shap = bst.predict(test_dmat, pred_contribs=True) margin = bst.predict(test_dmat, output_margin=True) assume(len(dataset.y) > 0) assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-3, 1e-3) + dmat = dataset.get_external_dmat() + shap = bst.predict(dmat, pred_contribs=True) + margin = bst.predict(dmat, output_margin=True) + assume(len(dataset.y) > 0) + assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-3, 1e-3) + @given( strategies.integers(1, 10), tm.make_dataset_strategy(), shap_parameter_strategy ) @settings(deadline=None, max_examples=10, print_blob=True) - def test_shap_interactions(self, num_rounds, dataset, param): + def test_shap_interactions( + self, num_rounds: int, dataset: tm.TestDataset, param: dict + ) -> None: if dataset.name.endswith("-l1"): # not supported by the exact tree method return param.update({"tree_method": "hist", "device": "cuda:0"}) param = dataset.set_params(param) dmat = dataset.get_dmat() bst = xgb.train(param, dmat, num_rounds) - test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin) + + test_dmat = xgb.DMatrix( + dataset.X, dataset.y, weight=dataset.w, base_margin=dataset.margin + ) bst.set_param({"device": "cuda:0"}) shap = bst.predict(test_dmat, pred_interactions=True) margin = bst.predict(test_dmat, output_margin=True) @@ -380,6 +393,17 @@ class TestGPUPredict: 1e-3, ) + test_dmat = dataset.get_external_dmat() + shap = bst.predict(test_dmat, pred_interactions=True) + margin = bst.predict(test_dmat, output_margin=True) + assume(len(dataset.y) > 0) + assert np.allclose( + np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)), + margin, + 1e-3, + 1e-3, + ) + def test_shap_categorical(self): X, y = tm.make_categorical(100, 20, 7, False) Xy = xgb.DMatrix(X, y, enable_categorical=True)
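
Not part of the patch: a minimal usage sketch of what this change enables from the Python side, namely SHAP contribution and interaction prediction on a `QuantileDMatrix` with `device="cuda"`, which previously failed with "SHAP value for QuantileDMatrix is not yet implemented for GPU." Only `xgb.QuantileDMatrix`, `pred_contribs`, `pred_interactions`, and `output_margin` are existing public API; the data and parameter choices below are illustrative.

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X = rng.normal(size=(2048, 16))
y = X @ rng.normal(size=16) + rng.normal(scale=0.1, size=2048)

# QuantileDMatrix keeps only the pre-binned (Ellpack) representation;
# no SparsePage is materialized.
Xy = xgb.QuantileDMatrix(X, y, max_bin=64)
booster = xgb.train({"tree_method": "hist", "device": "cuda"}, Xy, num_boost_round=8)

# With this patch, the GPU predictor accepts a QuantileDMatrix here.
shap = booster.predict(Xy, pred_contribs=True)       # (n_samples, n_features + 1)
inter = booster.predict(Xy, pred_interactions=True)  # (n_samples, n_features + 1, n_features + 1)

# Contributions (including the bias column) sum to the raw margin.
margin = booster.predict(Xy, output_margin=True)
assert np.allclose(shap.sum(axis=-1), margin, atol=1e-3)
assert np.allclose(inter.sum(axis=(-1, -2)), margin, atol=1e-3)
```

The sketch mirrors the C++ change above: when no `SparsePage` exists (`PageExists()` is false), the SHAP kernels now read features through the `EllpackDeviceAccessor` via `EllpackLoader` instead of requiring a CSR page, which is also what makes the external-memory (`ExtMemQuantileDMatrix`) test path work.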