Implement contribution prediction with QuantileDMatrix (#10043)

---------

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
Louis Desreumaux
2024-02-19 14:03:29 +01:00
committed by GitHub
parent 057f03cacc
commit edf501d227
6 changed files with 137 additions and 55 deletions

View File

@@ -148,7 +148,7 @@ TEST(CPUPredictor, GHistIndexTraining) {
auto adapter = data::ArrayAdapter(columnar.c_str());
std::shared_ptr<DMatrix> p_full{
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1)};
TestTrainingPrediction(&ctx, kRows, kBins, p_full, p_hist);
TestTrainingPrediction(&ctx, kRows, kBins, p_full, p_hist, true);
}
TEST(CPUPredictor, CategoricalPrediction) {

View File

@@ -118,7 +118,8 @@ TEST(Predictor, PredictionCache) {
}
void TestTrainingPrediction(Context const *ctx, size_t rows, size_t bins,
std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist) {
std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist,
bool check_contribs) {
size_t constexpr kCols = 16;
size_t constexpr kClasses = 3;
size_t constexpr kIters = 3;
@@ -161,6 +162,28 @@ void TestTrainingPrediction(Context const *ctx, size_t rows, size_t bins,
for (size_t i = 0; i < rows; ++i) {
EXPECT_NEAR(from_hist.ConstHostVector()[i], from_full.ConstHostVector()[i], kRtEps);
}
if (check_contribs) {
  // Compare contribution predictions made from the full matrix against those
  // made from the histogram-index matrix; both are expected to agree within
  // kRtEps, element by element.
  //
  // The trailing `true` (8th positional argument) is presumably the
  // "predict contributions" switch of Learner::Predict -- TODO confirm
  // against the declaration.
  HostDeviceVector<float> from_full_contribs;
  learner->Predict(p_full, false, &from_full_contribs, 0, 0, false, false, true);
  HostDeviceVector<float> from_hist_contribs;
  learner->Predict(p_hist, false, &from_hist_contribs, 0, 0, false, false, true);
  auto const& h_full_contribs = from_full_contribs.ConstHostVector();
  auto const& h_hist_contribs = from_hist_contribs.ConstHostVector();
  // The loop bound comes from the full-matrix result, so a shorter
  // hist-matrix result would be indexed out of bounds. Fail fast instead.
  ASSERT_EQ(h_hist_contribs.size(), h_full_contribs.size());
  for (size_t i = 0; i < h_full_contribs.size(); ++i) {
    EXPECT_NEAR(h_hist_contribs[i], h_full_contribs[i], kRtEps);
  }

  // Contributions (approximate method).
  //
  // NOTE(review): these calls pass `false` in the 8th position (the one set
  // to `true` for the exact case above) and `true` only in the 9th. If the
  // 9th flag merely selects the approximation *within* contribution
  // prediction, this section compares ordinary predictions rather than
  // approximate contributions -- confirm against Learner::Predict.
  HostDeviceVector<float> from_full_approx_contribs;
  learner->Predict(p_full, false, &from_full_approx_contribs, 0, 0, false, false, false, true);
  HostDeviceVector<float> from_hist_approx_contribs;
  learner->Predict(p_hist, false, &from_hist_approx_contribs, 0, 0, false, false, false, true);
  auto const& h_full_approx = from_full_approx_contribs.ConstHostVector();
  auto const& h_hist_approx = from_hist_approx_contribs.ConstHostVector();
  // Same guard as above: never index past the end of the hist-matrix result.
  ASSERT_EQ(h_hist_approx.size(), h_full_approx.size());
  for (size_t i = 0; i < h_full_approx.size(); ++i) {
    EXPECT_NEAR(h_hist_approx[i], h_full_approx[i], kRtEps);
  }
}
}
void TestInplacePrediction(Context const *ctx, std::shared_ptr<DMatrix> x, bst_row_t rows,

View File

@@ -89,7 +89,8 @@ void TestBasic(DMatrix* dmat, Context const * ctx);
// p_full and p_hist should come from the same data set.
void TestTrainingPrediction(Context const* ctx, size_t rows, size_t bins,
std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist);
std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist,
bool check_contribs = false);
void TestInplacePrediction(Context const* ctx, std::shared_ptr<DMatrix> x, bst_row_t rows,
bst_feature_t cols);