Use ellpack for prediction only when sparsepage doesn't exist. (#5504)

This commit is contained in:
Jiaming Yuan
2020-04-10 12:15:46 +08:00
committed by GitHub
parent ad826e913f
commit 6671b42dd4
35 changed files with 166 additions and 116 deletions

View File

@@ -8,11 +8,12 @@
namespace xgboost {
template <typename Page>
void TestPredictionFromGradientIndex(std::string name, size_t rows, int32_t bins) {
constexpr size_t kCols { 8 }, kClasses { 3 };
void TestPredictionFromGradientIndex(std::string name, size_t rows, size_t cols,
std::shared_ptr<DMatrix> p_hist) {
constexpr size_t kClasses { 3 };
LearnerModelParam param;
param.num_feature = kCols;
param.num_feature = cols;
param.num_output_group = kClasses;
param.base_score = 0.5;
@@ -25,16 +26,10 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, int32_t bins
gbm::GBTreeModel model = CreateTestModel(&param, kClasses);
{
auto p_ellpack = RandomDataGenerator(rows, kCols, 0).GenerateDMatix();
// Use same number of bins as rows.
for (auto const &page DMLC_ATTRIBUTE_UNUSED :
p_ellpack->GetBatches<Page>({0, static_cast<int32_t>(bins), 0})) {
}
auto p_precise = RandomDataGenerator(rows, kCols, 0).GenerateDMatix();
auto p_precise = RandomDataGenerator(rows, cols, 0).GenerateDMatrix();
PredictionCacheEntry approx_out_predictions;
predictor->PredictBatch(p_ellpack.get(), &approx_out_predictions, model, 0);
predictor->PredictBatch(p_hist.get(), &approx_out_predictions, model, 0);
PredictionCacheEntry precise_out_predictions;
predictor->PredictBatch(p_precise.get(), &precise_out_predictions, model, 0);
@@ -49,14 +44,17 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, int32_t bins
// Predictor should never try to create the histogram index by itself. As only
// histogram index from training data is valid and predictor doesn't know which
// matrix is used for training.
auto p_dmat = RandomDataGenerator(rows, kCols, 0).GenerateDMatix();
auto p_dmat = RandomDataGenerator(rows, cols, 0).GenerateDMatrix();
PredictionCacheEntry precise_out_predictions;
predictor->PredictBatch(p_dmat.get(), &precise_out_predictions, model, 0);
ASSERT_FALSE(p_dmat->PageExists<Page>());
}
}
void TestTrainingPrediction(size_t rows, std::string tree_method);
// p_full and p_hist should come from the same data set.
void TestTrainingPrediction(size_t rows, std::string tree_method,
std::shared_ptr<DMatrix> p_full,
std::shared_ptr<DMatrix> p_hist);
void TestInplacePrediction(dmlc::any x, std::string predictor,
bst_row_t rows, bst_feature_t cols,