Implement categorical data support for SHAP. (#7053)

* Add CPU implementation.
* Update GPUTreeSHAP.
* Add GPU implementation by defining custom split condition.
This commit is contained in:
Jiaming Yuan
2021-06-25 19:02:46 +08:00
committed by GitHub
parent 663136aa08
commit 8fa32fdda2
12 changed files with 287 additions and 50 deletions

View File

@@ -86,6 +86,11 @@ TEST(CpuPredictor, Basic) {
}
}
// Runs the shared TestIterationRange check with "cpu_predictor" as the
// predictor name.
TEST(CpuPredictor, IterationRange) {
TestIterationRange("cpu_predictor");
}
TEST(CpuPredictor, ExternalMemory) {
dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm";

View File

@@ -224,6 +224,11 @@ TEST(GPUPredictor, Shap) {
}
}
// Runs the shared TestIterationRange check with "gpu_predictor" as the
// predictor name.
TEST(GPUPredictor, IterationRange) {
TestIterationRange("gpu_predictor");
}
// Runs the shared TestCategoricalPrediction check with "gpu_predictor" as the
// predictor name.
TEST(GPUPredictor, CategoricalPrediction) {
TestCategoricalPrediction("gpu_predictor");
}

View File

@@ -281,4 +281,78 @@ void TestCategoricalPredictLeaf(StringView name) {
predictor->PredictLeaf(m.get(), &out_predictions.predictions, model);
ASSERT_EQ(out_predictions.predictions.HostVector()[0], 1);
}
void TestIterationRange(std::string name) {
size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3;
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
std::unique_ptr<Learner> learner{Learner::Create({dmat})};
learner->SetParams(Args{{"num_parallel_tree", std::to_string(kForest)},
{"predictor", name}});
size_t kIters = 10;
for (size_t i = 0; i < kIters; ++i) {
learner->UpdateOneIter(i, dmat);
}
bool bound = false;
std::unique_ptr<Learner> sliced {learner->Slice(0, 3, 1, &bound)};
ASSERT_FALSE(bound);
HostDeviceVector<float> out_predt_sliced;
HostDeviceVector<float> out_predt_ranged;
// margin
{
sliced->Predict(dmat, true, &out_predt_sliced, 0, 0, false, false, false,
false, false);
learner->Predict(dmat, true, &out_predt_ranged, 0, 3, false, false, false,
false, false);
auto const &h_sliced = out_predt_sliced.HostVector();
auto const &h_range = out_predt_ranged.HostVector();
ASSERT_EQ(h_sliced.size(), h_range.size());
ASSERT_EQ(h_sliced, h_range);
}
// SHAP
{
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false,
true, false, false);
learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, true,
false, false);
auto const &h_sliced = out_predt_sliced.HostVector();
auto const &h_range = out_predt_ranged.HostVector();
ASSERT_EQ(h_sliced.size(), h_range.size());
ASSERT_EQ(h_sliced, h_range);
}
// SHAP interaction
{
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false,
false, false, true);
learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, false,
false, true);
auto const &h_sliced = out_predt_sliced.HostVector();
auto const &h_range = out_predt_ranged.HostVector();
ASSERT_EQ(h_sliced.size(), h_range.size());
ASSERT_EQ(h_sliced, h_range);
}
// Leaf
{
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, true,
false, false, false);
learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, true, false,
false, false);
auto const &h_sliced = out_predt_sliced.HostVector();
auto const &h_range = out_predt_ranged.HostVector();
ASSERT_EQ(h_sliced.size(), h_range.size());
ASSERT_EQ(h_sliced, h_range);
}
}
} // namespace xgboost

View File

@@ -68,6 +68,8 @@ void TestPredictionWithLesserFeatures(std::string preditor_name);
// Shared categorical prediction check; the tests call it with a predictor
// name such as "gpu_predictor".
void TestCategoricalPrediction(std::string name);
// Shared leaf-prediction check for categorical data.
// NOTE(review): `name` is presumably the predictor name -- confirm against
// the definition.
void TestCategoricalPredictLeaf(StringView name);
// Checks that predicting over a layer range on a full model matches
// predicting with a model sliced to that range (margin, SHAP, SHAP
// interaction and leaf outputs).  `name` is passed as the "predictor"
// parameter.
void TestIterationRange(std::string name);
} // namespace xgboost
#endif // XGBOOST_TEST_PREDICTOR_H_