More tests for cpu predictor with column split (#9270)

Author: Rong Ou
Date:   2023-06-08 07:47:19 -07:00 (committed by GitHub)
Parent: 84d3fcb7ea
Commit: ff122d61ff
5 changed files with 243 additions and 41 deletions

--- a/src/predictor/cpu_predictor.cc
+++ b/src/predictor/cpu_predictor.cc

@@ -430,8 +430,7 @@ class ColumnSplitHelper {
         << "column-split prediction is only supported for distributed training";
     for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
-      CHECK_EQ(out_preds->size(),
-               p_fmat->Info().num_row_ * model_.learner_model_param->num_output_group);
+      CHECK_EQ(out_preds->size(), p_fmat->Info().num_row_ * (tree_end_ - tree_begin_));
       PredictBatchKernel<SparsePageView, kBlockOfRowsSize, true>(SparsePageView{&batch}, out_preds);
     }
   }
@@ -543,8 +542,12 @@ class ColumnSplitHelper {
     for (size_t tree_id = tree_begin_; tree_id < tree_end_; ++tree_id) {
       auto const gid = model_.tree_info[tree_id];
       for (size_t i = 0; i < block_size; ++i) {
-        preds[(predict_offset + i) * num_group + gid] +=
-            PredictOneTree<predict_leaf>(tree_id, batch_offset + i);
+        auto const result = PredictOneTree<predict_leaf>(tree_id, batch_offset + i);
+        if constexpr (predict_leaf) {
+          preds[(predict_offset + i) * (tree_end_ - tree_begin_) + tree_id] = result;
+        } else {
+          preds[(predict_offset + i) * num_group + gid] += result;
+        }
       }
     }
   }
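
Aside: a minimal standalone sketch of the output layout encoded by the branch above. It is not code from this commit, and every name in it is invented for illustration. For leaf prediction the kernel now writes one slot per (row, tree) holding the leaf index reached in that tree, while regular prediction keeps accumulating one slot per (row, output group).

    #include <cstddef>
    #include <vector>

    // Sketch only: mirrors the indexing used by the change above, outside of XGBoost.
    template <bool predict_leaf>
    void StoreTreeResult(std::vector<float> *preds, std::size_t row, std::size_t tree_id,
                         std::size_t tree_begin, std::size_t tree_end,
                         std::size_t num_group, std::size_t gid, float result) {
      if constexpr (predict_leaf) {
        // One slot per (row, tree): the leaf index reached in this tree.
        (*preds)[row * (tree_end - tree_begin) + tree_id] = result;
      } else {
        // One slot per (row, output group): margins accumulate across trees.
        (*preds)[row * num_group + gid] += result;
      }
    }

    int main() {
      std::size_t constexpr kRows = 2, kTrees = 3, kGroups = 1;
      std::vector<float> leaf(kRows * kTrees, 0.0f);
      std::vector<float> margin(kRows * kGroups, 0.0f);
      for (std::size_t row = 0; row < kRows; ++row) {
        for (std::size_t tree = 0; tree < kTrees; ++tree) {
          StoreTreeResult<true>(&leaf, row, tree, 0, kTrees, kGroups, 0, static_cast<float>(tree));
          StoreTreeResult<false>(&margin, row, tree, 0, kTrees, kGroups, 0, 0.1f);
        }
      }
      return 0;
    }
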
@@ -645,6 +648,9 @@ class CPUPredictor : public Predictor {
   void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
                       gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const {
     if (p_fmat->Info().IsColumnSplit()) {
+      CHECK(!model.learner_model_param->IsVectorLeaf())
+          << "Predict DMatrix with column split" << MTNotImplemented();
       ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
       helper.PredictDMatrix(p_fmat, out_preds);
       return;
@ -743,6 +749,8 @@ class CPUPredictor : public Predictor {
unsigned tree_end) const override { unsigned tree_end) const override {
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get()); auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
CHECK(proxy)<< "Inplace predict accepts only DMatrixProxy as input."; CHECK(proxy)<< "Inplace predict accepts only DMatrixProxy as input.";
CHECK(!p_m->Info().IsColumnSplit())
<< "Inplace predict support for column-wise data split is not yet implemented.";
auto x = proxy->Adapter(); auto x = proxy->Adapter();
if (x.type() == typeid(std::shared_ptr<data::DenseAdapter>)) { if (x.type() == typeid(std::shared_ptr<data::DenseAdapter>)) {
this->DispatchedInplacePredict<data::DenseAdapter, kBlockOfRowsSize>( this->DispatchedInplacePredict<data::DenseAdapter, kBlockOfRowsSize>(
@@ -773,6 +781,9 @@ class CPUPredictor : public Predictor {
     out_preds->resize(model.learner_model_param->num_output_group);
     if (is_column_split) {
+      CHECK(!model.learner_model_param->IsVectorLeaf())
+          << "Predict instance with column split" << MTNotImplemented();
       ColumnSplitHelper helper(this->ctx_->Threads(), model, 0, ntree_limit);
       helper.PredictInstance(inst, out_preds);
       return;
@@ -802,6 +813,9 @@ class CPUPredictor : public Predictor {
     preds.resize(info.num_row_ * ntree_limit);
     if (p_fmat->Info().IsColumnSplit()) {
+      CHECK(!model.learner_model_param->IsVectorLeaf())
+          << "Predict leaf with column split" << MTNotImplemented();
       ColumnSplitHelper helper(n_threads, model, 0, ntree_limit);
       helper.PredictLeaf(p_fmat, &preds);
       return;

--- a/tests/cpp/predictor/test_cpu_predictor.cc
+++ b/tests/cpp/predictor/test_cpu_predictor.cc

@@ -117,7 +117,7 @@ void TestColumnSplit() {
 }
 } // anonymous namespace

-TEST(CpuPredictor, ColumnSplitBasic) {
+TEST(CpuPredictor, BasicColumnSplit) {
   auto constexpr kWorldSize = 2;
   RunWithInMemoryCommunicator(kWorldSize, TestColumnSplit);
 }
@@ -126,6 +126,10 @@ TEST(CpuPredictor, IterationRange) {
   TestIterationRange("cpu_predictor");
 }

+TEST(CpuPredictor, IterationRangeColmnSplit) {
+  TestIterationRangeColumnSplit("cpu_predictor");
+}
+
 TEST(CpuPredictor, ExternalMemory) {
   size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
   size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
@@ -223,10 +227,18 @@ TEST(CPUPredictor, CategoricalPrediction) {
   TestCategoricalPrediction("cpu_predictor");
 }

+TEST(CPUPredictor, CategoricalPredictionColumnSplit) {
+  TestCategoricalPredictionColumnSplit("cpu_predictor");
+}
+
 TEST(CPUPredictor, CategoricalPredictLeaf) {
   TestCategoricalPredictLeaf(StringView{"cpu_predictor"});
 }

+TEST(CPUPredictor, CategoricalPredictLeafColumnSplit) {
+  TestCategoricalPredictLeafColumnSplit(StringView{"cpu_predictor"});
+}
+
 TEST(CpuPredictor, UpdatePredictionCache) {
   TestUpdatePredictionCache(false);
   TestUpdatePredictionCache(true);
@@ -236,11 +248,20 @@ TEST(CpuPredictor, LesserFeatures) {
   TestPredictionWithLesserFeatures("cpu_predictor");
 }

+TEST(CpuPredictor, LesserFeaturesColumnSplit) {
+  TestPredictionWithLesserFeaturesColumnSplit("cpu_predictor");
+}
+
 TEST(CpuPredictor, Sparse) {
   TestSparsePrediction(0.2, "cpu_predictor");
   TestSparsePrediction(0.8, "cpu_predictor");
 }

+TEST(CpuPredictor, SparseColumnSplit) {
+  TestSparsePredictionColumnSplit(0.2, "cpu_predictor");
+  TestSparsePredictionColumnSplit(0.8, "cpu_predictor");
+}
+
 TEST(CpuPredictor, Multi) {
   Context ctx;
   ctx.nthread = 1;

--- a/tests/cpp/predictor/test_gpu_predictor.cu
+++ b/tests/cpp/predictor/test_gpu_predictor.cu

@@ -209,7 +209,6 @@ TEST(GPUPredictor, IterationRange) {
   TestIterationRange("gpu_predictor");
 }

 TEST(GPUPredictor, CategoricalPrediction) {
   TestCategoricalPrediction("gpu_predictor");
 }

--- a/tests/cpp/predictor/test_predictor.cc
+++ b/tests/cpp/predictor/test_predictor.cc

@@ -153,28 +153,32 @@ void TestInplacePrediction(std::shared_ptr<DMatrix> x, std::string predictor, bs
   learner->Configure();
 }

-void TestPredictionWithLesserFeatures(std::string predictor_name) {
-  size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
-  auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
-  auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
-  std::unique_ptr<Learner> learner{Learner::Create({m_train})};
-  for (size_t i = 0; i < kIters; ++i) {
-    learner->UpdateOneIter(i, m_train);
+namespace {
+std::unique_ptr<Learner> LearnerForTest(std::shared_ptr<DMatrix> dmat, size_t iters,
+                                        size_t forest = 1) {
+  std::unique_ptr<Learner> learner{Learner::Create({dmat})};
+  learner->SetParams(Args{{"num_parallel_tree", std::to_string(forest)}});
+  for (size_t i = 0; i < iters; ++i) {
+    learner->UpdateOneIter(i, dmat);
   }
+  return learner;
+}
+
+void VerifyPredictionWithLesserFeatures(Learner *learner, std::string const &predictor_name,
+                                        size_t rows, std::shared_ptr<DMatrix> const &m_test,
+                                        std::shared_ptr<DMatrix> const &m_invalid) {
   HostDeviceVector<float> prediction;
   learner->SetParam("predictor", predictor_name);
   learner->Configure();

   Json config{Object()};
   learner->SaveConfig(&config);
-  ASSERT_EQ(get<String>(config["learner"]["gradient_booster"]["gbtree_train_param"]["predictor"]), predictor_name);
+  ASSERT_EQ(get<String>(config["learner"]["gradient_booster"]["gbtree_train_param"]["predictor"]),
+            predictor_name);

   learner->Predict(m_test, false, &prediction, 0, 0);
-  ASSERT_EQ(prediction.Size(), kRows);
+  ASSERT_EQ(prediction.Size(), rows);

-  auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);
-  ASSERT_THROW({learner->Predict(m_invalid, false, &prediction, 0, 0);}, dmlc::Error);
+  ASSERT_THROW({ learner->Predict(m_invalid, false, &prediction, 0, 0); }, dmlc::Error);

 #if defined(XGBOOST_USE_CUDA)
   HostDeviceVector<float> from_cpu;
@@ -185,13 +189,49 @@ void TestPredictionWithLesserFeatures(std::string predictor_name) {
   learner->SetParam("predictor", "gpu_predictor");
   learner->Predict(m_test, false, &from_cuda, 0, 0);

-  auto const& h_cpu = from_cpu.ConstHostVector();
-  auto const& h_gpu = from_cuda.ConstHostVector();
+  auto const &h_cpu = from_cpu.ConstHostVector();
+  auto const &h_gpu = from_cuda.ConstHostVector();
   for (size_t i = 0; i < h_cpu.size(); ++i) {
     ASSERT_NEAR(h_cpu[i], h_gpu[i], kRtEps);
   }
 #endif // defined(XGBOOST_USE_CUDA)
 }
+} // anonymous namespace
+
+void TestPredictionWithLesserFeatures(std::string predictor_name) {
+  size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
+  auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
+  auto learner = LearnerForTest(m_train, kIters);
+  auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
+  auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);
+  VerifyPredictionWithLesserFeatures(learner.get(), predictor_name, kRows, m_test, m_invalid);
+}
+
+namespace {
+void VerifyPredictionWithLesserFeaturesColumnSplit(Learner *learner,
+                                                   std::string const &predictor_name, size_t rows,
+                                                   std::shared_ptr<DMatrix> m_test,
+                                                   std::shared_ptr<DMatrix> m_invalid) {
+  auto const world_size = collective::GetWorldSize();
+  auto const rank = collective::GetRank();
+  std::shared_ptr<DMatrix> sliced_test{m_test->SliceCol(world_size, rank)};
+  std::shared_ptr<DMatrix> sliced_invalid{m_invalid->SliceCol(world_size, rank)};
+  VerifyPredictionWithLesserFeatures(learner, predictor_name, rows, sliced_test, sliced_invalid);
+}
+} // anonymous namespace
+
+void TestPredictionWithLesserFeaturesColumnSplit(std::string predictor_name) {
+  size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
+  auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
+  auto learner = LearnerForTest(m_train, kIters);
+  auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
+  auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);
+  auto constexpr kWorldSize = 2;
+  RunWithInMemoryCommunicator(kWorldSize, VerifyPredictionWithLesserFeaturesColumnSplit,
+                              learner.get(), predictor_name, kRows, m_test, m_invalid);
+}

 void GBTreeModelForTest(gbm::GBTreeModel *model, uint32_t split_ind,
                         bst_cat_t split_cat, float left_weight,
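
Aside: the column-split tests added below all reuse the shape introduced here. Train the learner and build the reference matrices on the full feature set, then run the verification under an in-memory communicator in which each simulated worker predicts on its own column slice. A condensed, hypothetical sketch of that flow follows; RunWithInMemoryCommunicator, SliceCol, and the collective getters are the test utilities used in this diff, while the verifier names and kWorldSize here are illustrative.

    // Condensed sketch of the recurring column-split test pattern (illustrative only).
    void VerifyOnColumnSlice(DMatrix *full, Learner *learner, std::vector<float> const &expected) {
      std::shared_ptr<DMatrix> part{
          full->SliceCol(collective::GetWorldSize(), collective::GetRank())};
      HostDeviceVector<float> predt;
      learner->Predict(part, false, &predt, 0, 0);
      ASSERT_EQ(predt.HostVector(), expected);  // each worker must reproduce the single-node result
    }

    void RunColumnSplitCheck(std::shared_ptr<DMatrix> full, Learner *learner,
                             std::vector<float> const &expected) {
      auto constexpr kWorldSize = 2;  // two simulated workers, matching the tests in this file
      RunWithInMemoryCommunicator(kWorldSize, VerifyOnColumnSlice, full.get(), learner, expected);
    }
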
@@ -212,7 +252,7 @@ void GBTreeModelForTest(gbm::GBTreeModel *model, uint32_t split_ind,
   model->CommitModelGroup(std::move(trees), 0);
 }

-void TestCategoricalPrediction(std::string name) {
+void TestCategoricalPrediction(std::string name, bool is_column_split) {
   size_t constexpr kCols = 10;
   PredictionCacheEntry out_predictions;
@@ -236,6 +276,9 @@ void TestCategoricalPrediction(std::string name) {
   std::vector<FeatureType> types(10, FeatureType::kCategorical);
   m->Info().feature_types.HostVector() = types;
+  if (is_column_split) {
+    m = std::shared_ptr<DMatrix>{m->SliceCol(collective::GetWorldSize(), collective::GetRank())};
+  }
   predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
   predictor->PredictBatch(m.get(), &out_predictions, model, 0);
@@ -246,13 +289,21 @@ void TestCategoricalPrediction(std::string name) {
   row[split_ind] = split_cat + 1;
   m = GetDMatrixFromData(row, 1, kCols);
+  if (is_column_split) {
+    m = std::shared_ptr<DMatrix>{m->SliceCol(collective::GetWorldSize(), collective::GetRank())};
+  }
   out_predictions.version = 0;
   predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
   predictor->PredictBatch(m.get(), &out_predictions, model, 0);
   ASSERT_EQ(out_predictions.predictions.HostVector()[0], left_weight + score);
 }

-void TestCategoricalPredictLeaf(StringView name) {
+void TestCategoricalPredictionColumnSplit(std::string name) {
+  auto constexpr kWorldSize = 2;
+  RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPrediction, name, true);
+}
+
+void TestCategoricalPredictLeaf(StringView name, bool is_column_split) {
   size_t constexpr kCols = 10;
   PredictionCacheEntry out_predictions;
@@ -275,6 +326,9 @@ void TestCategoricalPredictLeaf(StringView name) {
   std::vector<float> row(kCols);
   row[split_ind] = split_cat;
   auto m = GetDMatrixFromData(row, 1, kCols);
+  if (is_column_split) {
+    m = std::shared_ptr<DMatrix>{m->SliceCol(collective::GetWorldSize(), collective::GetRank())};
+  }
   predictor->PredictLeaf(m.get(), &out_predictions.predictions, model);
   CHECK_EQ(out_predictions.predictions.Size(), 1);
@@ -283,25 +337,25 @@ void TestCategoricalPredictLeaf(StringView name) {
   row[split_ind] = split_cat + 1;
   m = GetDMatrixFromData(row, 1, kCols);
+  if (is_column_split) {
+    m = std::shared_ptr<DMatrix>{m->SliceCol(collective::GetWorldSize(), collective::GetRank())};
+  }
   out_predictions.version = 0;
   predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
   predictor->PredictLeaf(m.get(), &out_predictions.predictions, model);
   ASSERT_EQ(out_predictions.predictions.HostVector()[0], 1);
 }

+void TestCategoricalPredictLeafColumnSplit(StringView name) {
+  auto constexpr kWorldSize = 2;
+  RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPredictLeaf, name, true);
+}
+
 void TestIterationRange(std::string name) {
-  size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3;
+  size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10;
   auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
-  std::unique_ptr<Learner> learner{Learner::Create({dmat})};
-  learner->SetParams(Args{{"num_parallel_tree", std::to_string(kForest)},
-                          {"predictor", name}});
-  size_t kIters = 10;
-  for (size_t i = 0; i < kIters; ++i) {
-    learner->UpdateOneIter(i, dmat);
-  }
+  auto learner = LearnerForTest(dmat, kIters, kForest);
+  learner->SetParams(Args{{"predictor", name}});

   bool bound = false;
   std::unique_ptr<Learner> sliced {learner->Slice(0, 3, 1, &bound)};
@@ -363,15 +417,82 @@ void TestIterationRange(std::string name) {
   }
 }

-void TestSparsePrediction(float sparsity, std::string predictor) {
-  size_t constexpr kRows = 512, kCols = 128;
-  auto Xy = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix(true);
-  std::unique_ptr<Learner> learner{Learner::Create({Xy})};
-  learner->Configure();
-  for (size_t i = 0; i < 4; ++i) {
-    learner->UpdateOneIter(i, Xy);
+namespace {
+void VerifyIterationRangeColumnSplit(DMatrix *dmat, Learner *learner, Learner *sliced,
+                                     std::vector<float> const &expected_margin_ranged,
+                                     std::vector<float> const &expected_margin_sliced,
+                                     std::vector<float> const &expected_leaf_ranged,
+                                     std::vector<float> const &expected_leaf_sliced) {
+  auto const world_size = collective::GetWorldSize();
+  auto const rank = collective::GetRank();
+  std::shared_ptr<DMatrix> Xy{dmat->SliceCol(world_size, rank)};
+
+  HostDeviceVector<float> out_predt_sliced;
+  HostDeviceVector<float> out_predt_ranged;
+
+  // margin
+  {
+    sliced->Predict(Xy, true, &out_predt_sliced, 0, 0, false, false, false, false, false);
+    learner->Predict(Xy, true, &out_predt_ranged, 0, 3, false, false, false, false, false);
+    auto const &h_sliced = out_predt_sliced.HostVector();
+    auto const &h_range = out_predt_ranged.HostVector();
+    ASSERT_EQ(h_sliced.size(), expected_margin_sliced.size());
+    ASSERT_EQ(h_sliced, expected_margin_sliced);
+    ASSERT_EQ(h_range.size(), expected_margin_ranged.size());
+    ASSERT_EQ(h_range, expected_margin_ranged);
   }
+
+  // Leaf
+  {
+    sliced->Predict(Xy, false, &out_predt_sliced, 0, 0, false, true, false, false, false);
+    learner->Predict(Xy, false, &out_predt_ranged, 0, 3, false, true, false, false, false);
+    auto const &h_sliced = out_predt_sliced.HostVector();
+    auto const &h_range = out_predt_ranged.HostVector();
+    ASSERT_EQ(h_sliced.size(), expected_leaf_sliced.size());
+    ASSERT_EQ(h_sliced, expected_leaf_sliced);
+    ASSERT_EQ(h_range.size(), expected_leaf_ranged.size());
+    ASSERT_EQ(h_range, expected_leaf_ranged);
+  }
+}
+} // anonymous namespace
+
+void TestIterationRangeColumnSplit(std::string name) {
+  size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10;
+  auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
+  auto learner = LearnerForTest(dmat, kIters, kForest);
+  learner->SetParams(Args{{"predictor", name}});
+
+  bool bound = false;
+  std::unique_ptr<Learner> sliced{learner->Slice(0, 3, 1, &bound)};
+  ASSERT_FALSE(bound);
+
+  // margin
+  HostDeviceVector<float> margin_predt_sliced;
+  HostDeviceVector<float> margin_predt_ranged;
+  sliced->Predict(dmat, true, &margin_predt_sliced, 0, 0, false, false, false, false, false);
+  learner->Predict(dmat, true, &margin_predt_ranged, 0, 3, false, false, false, false, false);
+  auto const &margin_sliced = margin_predt_sliced.HostVector();
+  auto const &margin_ranged = margin_predt_ranged.HostVector();
+
+  // Leaf
+  HostDeviceVector<float> leaf_predt_sliced;
+  HostDeviceVector<float> leaf_predt_ranged;
+  sliced->Predict(dmat, false, &leaf_predt_sliced, 0, 0, false, true, false, false, false);
+  learner->Predict(dmat, false, &leaf_predt_ranged, 0, 3, false, true, false, false, false);
+  auto const &leaf_sliced = leaf_predt_sliced.HostVector();
+  auto const &leaf_ranged = leaf_predt_ranged.HostVector();
+
+  auto constexpr kWorldSize = 2;
+  RunWithInMemoryCommunicator(kWorldSize, VerifyIterationRangeColumnSplit, dmat.get(),
+                              learner.get(), sliced.get(), margin_ranged, margin_sliced,
+                              leaf_ranged, leaf_sliced);
+}
+
+void TestSparsePrediction(float sparsity, std::string predictor) {
+  size_t constexpr kRows = 512, kCols = 128, kIters = 4;
+  auto Xy = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix(true);
+  auto learner = LearnerForTest(Xy, kIters);

   HostDeviceVector<float> sparse_predt;

   Json model{Object{}};
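
Aside: a rough size check for the vectors compared in the iteration-range test above, under the assumption that leaf prediction yields one value per (row, tree) and margin prediction one value per (row, class), with kForest parallel trees per class per boosting iteration. This arithmetic is an editorial illustration, not part of the commit.

    #include <cstddef>

    // kLayers = 3 because both the sliced learner and the ranged prediction cover iterations [0, 3).
    constexpr std::size_t kRows = 1000, kClasses = 4, kForest = 3, kLayers = 3;
    constexpr std::size_t kTrees = kLayers * kClasses * kForest;  // 36 trees contribute leaf values
    static_assert(kRows * kTrees == 36000, "expected leaf prediction size");
    static_assert(kRows * kClasses == 4000, "expected margin prediction size");

    int main() { return 0; }
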
@@ -419,6 +540,43 @@ void TestSparsePrediction(float sparsity, std::string predictor) {
   }
 }

+namespace {
+void VerifySparsePredictionColumnSplit(DMatrix *dmat, Learner *learner,
+                                       std::vector<float> const &expected_predt) {
+  std::shared_ptr<DMatrix> sliced{
+      dmat->SliceCol(collective::GetWorldSize(), collective::GetRank())};
+  HostDeviceVector<float> sparse_predt;
+  learner->Predict(sliced, false, &sparse_predt, 0, 0);
+  auto const &predt = sparse_predt.HostVector();
+  ASSERT_EQ(predt.size(), expected_predt.size());
+  for (size_t i = 0; i < predt.size(); ++i) {
+    ASSERT_FLOAT_EQ(predt[i], expected_predt[i]);
+  }
+}
+} // anonymous namespace
+
+void TestSparsePredictionColumnSplit(float sparsity, std::string predictor) {
+  size_t constexpr kRows = 512, kCols = 128, kIters = 4;
+  auto Xy = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix(true);
+  auto learner = LearnerForTest(Xy, kIters);
+
+  HostDeviceVector<float> sparse_predt;
+
+  Json model{Object{}};
+  learner->SaveModel(&model);
+  learner.reset(Learner::Create({Xy}));
+  learner->LoadModel(model);
+
+  learner->SetParam("predictor", predictor);
+  learner->Predict(Xy, false, &sparse_predt, 0, 0);
+
+  auto constexpr kWorldSize = 2;
+  RunWithInMemoryCommunicator(kWorldSize, VerifySparsePredictionColumnSplit, Xy.get(),
+                              learner.get(), sparse_predt.HostVector());
+}
+
 void TestVectorLeafPrediction(Context const *ctx) {
   std::unique_ptr<Predictor> cpu_predictor =
       std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", ctx));

--- a/tests/cpp/predictor/test_predictor.h
+++ b/tests/cpp/predictor/test_predictor.h

@@ -86,14 +86,24 @@ void TestInplacePrediction(std::shared_ptr<DMatrix> x, std::string predictor, bs
 void TestPredictionWithLesserFeatures(std::string preditor_name);
+void TestPredictionWithLesserFeaturesColumnSplit(std::string preditor_name);

-void TestCategoricalPrediction(std::string name);
+void TestCategoricalPrediction(std::string name, bool is_column_split = false);
+void TestCategoricalPredictionColumnSplit(std::string name);

-void TestCategoricalPredictLeaf(StringView name);
+void TestCategoricalPredictLeaf(StringView name, bool is_column_split = false);
+void TestCategoricalPredictLeafColumnSplit(StringView name);

 void TestIterationRange(std::string name);
+void TestIterationRangeColumnSplit(std::string name);

 void TestSparsePrediction(float sparsity, std::string predictor);
+void TestSparsePredictionColumnSplit(float sparsity, std::string predictor);

 void TestVectorLeafPrediction(Context const* ctx);

 } // namespace xgboost