Update GPUTreeshap (#6163)
* Reduce shap test duration * Test interoperability with shap package * Add feature interactions * Update GPUTreeShap
This commit is contained in:
@@ -176,12 +176,13 @@ TEST(GPUPredictor, ShapStump) {
|
||||
model.CommitModel(std::move(trees), 0);
|
||||
|
||||
auto gpu_lparam = CreateEmptyGenericParam(0);
|
||||
std::unique_ptr<Predictor> gpu_predictor =
|
||||
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &gpu_lparam));
|
||||
std::unique_ptr<Predictor> gpu_predictor = std::unique_ptr<Predictor>(
|
||||
Predictor::Create("gpu_predictor", &gpu_lparam));
|
||||
gpu_predictor->Configure({});
|
||||
std::vector<float > phis;
|
||||
auto dmat = RandomDataGenerator(3, 1, 0).GenerateDMatrix();
|
||||
gpu_predictor->PredictContribution(dmat.get(), &phis, model);
|
||||
HostDeviceVector<float> predictions;
|
||||
auto dmat = RandomDataGenerator(3, 1, 0).GenerateDMatrix();
|
||||
gpu_predictor->PredictContribution(dmat.get(), &predictions, model);
|
||||
auto& phis = predictions.HostVector();
|
||||
EXPECT_EQ(phis[0], 0.0);
|
||||
EXPECT_EQ(phis[1], param.base_score);
|
||||
EXPECT_EQ(phis[2], 0.0);
|
||||
@@ -202,19 +203,20 @@ TEST(GPUPredictor, Shap) {
|
||||
|
||||
auto gpu_lparam = CreateEmptyGenericParam(0);
|
||||
auto cpu_lparam = CreateEmptyGenericParam(-1);
|
||||
std::unique_ptr<Predictor> gpu_predictor =
|
||||
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &gpu_lparam));
|
||||
std::unique_ptr<Predictor> cpu_predictor =
|
||||
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &cpu_lparam));
|
||||
std::unique_ptr<Predictor> gpu_predictor = std::unique_ptr<Predictor>(
|
||||
Predictor::Create("gpu_predictor", &gpu_lparam));
|
||||
std::unique_ptr<Predictor> cpu_predictor = std::unique_ptr<Predictor>(
|
||||
Predictor::Create("cpu_predictor", &cpu_lparam));
|
||||
gpu_predictor->Configure({});
|
||||
cpu_predictor->Configure({});
|
||||
std::vector<float > phis;
|
||||
std::vector<float > cpu_phis;
|
||||
HostDeviceVector<float> predictions;
|
||||
HostDeviceVector<float> cpu_predictions;
|
||||
auto dmat = RandomDataGenerator(3, 1, 0).GenerateDMatrix();
|
||||
gpu_predictor->PredictContribution(dmat.get(), &phis, model);
|
||||
cpu_predictor->PredictContribution(dmat.get(), &cpu_phis, model);
|
||||
for(auto i = 0ull; i < phis.size(); i++)
|
||||
{
|
||||
gpu_predictor->PredictContribution(dmat.get(), &predictions, model);
|
||||
cpu_predictor->PredictContribution(dmat.get(), &cpu_predictions, model);
|
||||
auto& phis = predictions.HostVector();
|
||||
auto& cpu_phis = cpu_predictions.HostVector();
|
||||
for (auto i = 0ull; i < phis.size(); i++) {
|
||||
EXPECT_NEAR(cpu_phis[i], phis[i], 1e-3);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user