Implement categorical data support for SHAP. (#7053)
* Add CPU implementation. * Update GPUTreeSHAP. * Add GPU implementation by defining custom split condition.
This commit is contained in:
@@ -86,6 +86,11 @@ TEST(CpuPredictor, Basic) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST(CpuPredictor, IterationRange) {
|
||||
TestIterationRange("cpu_predictor");
|
||||
}
|
||||
|
||||
TEST(CpuPredictor, ExternalMemory) {
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::string filename = tmpdir.path + "/big.libsvm";
|
||||
|
||||
@@ -224,6 +224,11 @@ TEST(GPUPredictor, Shap) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GPUPredictor, IterationRange) {
|
||||
TestIterationRange("gpu_predictor");
|
||||
}
|
||||
|
||||
|
||||
TEST(GPUPredictor, CategoricalPrediction) {
|
||||
TestCategoricalPrediction("gpu_predictor");
|
||||
}
|
||||
|
||||
@@ -281,4 +281,78 @@ void TestCategoricalPredictLeaf(StringView name) {
|
||||
predictor->PredictLeaf(m.get(), &out_predictions.predictions, model);
|
||||
ASSERT_EQ(out_predictions.predictions.HostVector()[0], 1);
|
||||
}
|
||||
|
||||
|
||||
void TestIterationRange(std::string name) {
|
||||
size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3;
|
||||
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
|
||||
std::unique_ptr<Learner> learner{Learner::Create({dmat})};
|
||||
|
||||
learner->SetParams(Args{{"num_parallel_tree", std::to_string(kForest)},
|
||||
{"predictor", name}});
|
||||
|
||||
size_t kIters = 10;
|
||||
for (size_t i = 0; i < kIters; ++i) {
|
||||
learner->UpdateOneIter(i, dmat);
|
||||
}
|
||||
|
||||
bool bound = false;
|
||||
std::unique_ptr<Learner> sliced {learner->Slice(0, 3, 1, &bound)};
|
||||
ASSERT_FALSE(bound);
|
||||
|
||||
HostDeviceVector<float> out_predt_sliced;
|
||||
HostDeviceVector<float> out_predt_ranged;
|
||||
|
||||
// margin
|
||||
{
|
||||
sliced->Predict(dmat, true, &out_predt_sliced, 0, 0, false, false, false,
|
||||
false, false);
|
||||
|
||||
learner->Predict(dmat, true, &out_predt_ranged, 0, 3, false, false, false,
|
||||
false, false);
|
||||
|
||||
auto const &h_sliced = out_predt_sliced.HostVector();
|
||||
auto const &h_range = out_predt_ranged.HostVector();
|
||||
ASSERT_EQ(h_sliced.size(), h_range.size());
|
||||
ASSERT_EQ(h_sliced, h_range);
|
||||
}
|
||||
|
||||
// SHAP
|
||||
{
|
||||
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false,
|
||||
true, false, false);
|
||||
|
||||
learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, true,
|
||||
false, false);
|
||||
|
||||
auto const &h_sliced = out_predt_sliced.HostVector();
|
||||
auto const &h_range = out_predt_ranged.HostVector();
|
||||
ASSERT_EQ(h_sliced.size(), h_range.size());
|
||||
ASSERT_EQ(h_sliced, h_range);
|
||||
}
|
||||
|
||||
// SHAP interaction
|
||||
{
|
||||
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false,
|
||||
false, false, true);
|
||||
learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, false,
|
||||
false, true);
|
||||
auto const &h_sliced = out_predt_sliced.HostVector();
|
||||
auto const &h_range = out_predt_ranged.HostVector();
|
||||
ASSERT_EQ(h_sliced.size(), h_range.size());
|
||||
ASSERT_EQ(h_sliced, h_range);
|
||||
}
|
||||
|
||||
// Leaf
|
||||
{
|
||||
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, true,
|
||||
false, false, false);
|
||||
learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, true, false,
|
||||
false, false);
|
||||
auto const &h_sliced = out_predt_sliced.HostVector();
|
||||
auto const &h_range = out_predt_ranged.HostVector();
|
||||
ASSERT_EQ(h_sliced.size(), h_range.size());
|
||||
ASSERT_EQ(h_sliced, h_range);
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -68,6 +68,8 @@ void TestPredictionWithLesserFeatures(std::string preditor_name);
|
||||
void TestCategoricalPrediction(std::string name);
|
||||
|
||||
void TestCategoricalPredictLeaf(StringView name);
|
||||
|
||||
void TestIterationRange(std::string name);
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_TEST_PREDICTOR_H_
|
||||
|
||||
Reference in New Issue
Block a user