Add categorical data support to GPU predictor. (#6165)

This commit is contained in:
Jiaming Yuan
2020-09-29 11:25:34 +08:00
committed by GitHub
parent 7622b8cdb8
commit 798af22ff4
6 changed files with 198 additions and 44 deletions

View File

@@ -221,5 +221,8 @@ TEST(GPUPredictor, Shap) {
}
}
TEST(GPUPredictor, CategoricalPrediction) {
TestCategoricalPrediction("gpu_predictor");
}
} // namespace predictor
} // namespace xgboost

View File

@@ -12,6 +12,8 @@
#include "../helpers.h"
#include "../../../src/common/io.h"
#include "../../../src/common/categorical.h"
#include "../../../src/common/bitfield.h"
namespace xgboost {
TEST(Predictor, PredictionCache) {
@@ -27,7 +29,7 @@ TEST(Predictor, PredictionCache) {
};
add_cache();
ASSERT_EQ(container.Container().size(), 0);
ASSERT_EQ(container.Container().size(), 0ul);
add_cache();
EXPECT_ANY_THROW(container.Entry(m));
}
@@ -174,4 +176,55 @@ void TestPredictionWithLesserFeatures(std::string predictor_name) {
}
#endif // defined(XGBOOST_USE_CUDA)
}
void TestCategoricalPrediction(std::string name) {
size_t constexpr kCols = 10;
PredictionCacheEntry out_predictions;
LearnerModelParam param;
param.num_feature = kCols;
param.num_output_group = 1;
param.base_score = 0.5;
gbm::GBTreeModel model(&param);
std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
auto& p_tree = trees.front();
uint32_t split_ind = 3;
bst_cat_t split_cat = 4;
float left_weight = 1.3f;
float right_weight = 1.7f;
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(split_cat));
LBitField32 cats_bits(split_cats);
cats_bits.Set(split_cat);
p_tree->ExpandCategorical(0, split_ind, split_cats, true, 1.5f,
left_weight, right_weight,
3.0f, 2.2f, 7.0f, 9.0f);
model.CommitModel(std::move(trees), 0);
GenericParameter runtime;
runtime.gpu_id = 0;
std::unique_ptr<Predictor> predictor{
Predictor::Create(name.c_str(), &runtime)};
std::vector<float> row(kCols);
row[split_ind] = split_cat;
auto m = GetDMatrixFromData(row, 1, kCols);
predictor->PredictBatch(m.get(), &out_predictions, model, 0);
ASSERT_EQ(out_predictions.predictions.Size(), 1ul);
ASSERT_EQ(out_predictions.predictions.HostVector()[0],
right_weight + param.base_score); // go to right for matching cat
row[split_ind] = split_cat + 1;
m = GetDMatrixFromData(row, 1, kCols);
out_predictions.version = 0;
predictor->PredictBatch(m.get(), &out_predictions, model, 0);
ASSERT_EQ(out_predictions.predictions.HostVector()[0],
left_weight + param.base_score);
}
} // namespace xgboost

View File

@@ -61,6 +61,8 @@ void TestInplacePrediction(dmlc::any x, std::string predictor,
int32_t device = -1);
void TestPredictionWithLesserFeatures(std::string preditor_name);
void TestCategoricalPrediction(std::string name);
} // namespace xgboost
#endif // XGBOOST_TEST_PREDICTOR_H_