More tests for cpu predictor with column split (#9270)
This commit is contained in:
parent
84d3fcb7ea
commit
ff122d61ff
@ -430,8 +430,7 @@ class ColumnSplitHelper {
|
|||||||
<< "column-split prediction is only supported for distributed training";
|
<< "column-split prediction is only supported for distributed training";
|
||||||
|
|
||||||
for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
|
for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||||
CHECK_EQ(out_preds->size(),
|
CHECK_EQ(out_preds->size(), p_fmat->Info().num_row_ * (tree_end_ - tree_begin_));
|
||||||
p_fmat->Info().num_row_ * model_.learner_model_param->num_output_group);
|
|
||||||
PredictBatchKernel<SparsePageView, kBlockOfRowsSize, true>(SparsePageView{&batch}, out_preds);
|
PredictBatchKernel<SparsePageView, kBlockOfRowsSize, true>(SparsePageView{&batch}, out_preds);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -543,8 +542,12 @@ class ColumnSplitHelper {
|
|||||||
for (size_t tree_id = tree_begin_; tree_id < tree_end_; ++tree_id) {
|
for (size_t tree_id = tree_begin_; tree_id < tree_end_; ++tree_id) {
|
||||||
auto const gid = model_.tree_info[tree_id];
|
auto const gid = model_.tree_info[tree_id];
|
||||||
for (size_t i = 0; i < block_size; ++i) {
|
for (size_t i = 0; i < block_size; ++i) {
|
||||||
preds[(predict_offset + i) * num_group + gid] +=
|
auto const result = PredictOneTree<predict_leaf>(tree_id, batch_offset + i);
|
||||||
PredictOneTree<predict_leaf>(tree_id, batch_offset + i);
|
if constexpr (predict_leaf) {
|
||||||
|
preds[(predict_offset + i) * (tree_end_ - tree_begin_) + tree_id] = result;
|
||||||
|
} else {
|
||||||
|
preds[(predict_offset + i) * num_group + gid] += result;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -645,6 +648,9 @@ class CPUPredictor : public Predictor {
|
|||||||
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
|
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
|
||||||
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const {
|
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const {
|
||||||
if (p_fmat->Info().IsColumnSplit()) {
|
if (p_fmat->Info().IsColumnSplit()) {
|
||||||
|
CHECK(!model.learner_model_param->IsVectorLeaf())
|
||||||
|
<< "Predict DMatrix with column split" << MTNotImplemented();
|
||||||
|
|
||||||
ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
|
ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
|
||||||
helper.PredictDMatrix(p_fmat, out_preds);
|
helper.PredictDMatrix(p_fmat, out_preds);
|
||||||
return;
|
return;
|
||||||
@ -743,6 +749,8 @@ class CPUPredictor : public Predictor {
|
|||||||
unsigned tree_end) const override {
|
unsigned tree_end) const override {
|
||||||
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
|
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
|
||||||
CHECK(proxy)<< "Inplace predict accepts only DMatrixProxy as input.";
|
CHECK(proxy)<< "Inplace predict accepts only DMatrixProxy as input.";
|
||||||
|
CHECK(!p_m->Info().IsColumnSplit())
|
||||||
|
<< "Inplace predict support for column-wise data split is not yet implemented.";
|
||||||
auto x = proxy->Adapter();
|
auto x = proxy->Adapter();
|
||||||
if (x.type() == typeid(std::shared_ptr<data::DenseAdapter>)) {
|
if (x.type() == typeid(std::shared_ptr<data::DenseAdapter>)) {
|
||||||
this->DispatchedInplacePredict<data::DenseAdapter, kBlockOfRowsSize>(
|
this->DispatchedInplacePredict<data::DenseAdapter, kBlockOfRowsSize>(
|
||||||
@ -773,6 +781,9 @@ class CPUPredictor : public Predictor {
|
|||||||
out_preds->resize(model.learner_model_param->num_output_group);
|
out_preds->resize(model.learner_model_param->num_output_group);
|
||||||
|
|
||||||
if (is_column_split) {
|
if (is_column_split) {
|
||||||
|
CHECK(!model.learner_model_param->IsVectorLeaf())
|
||||||
|
<< "Predict instance with column split" << MTNotImplemented();
|
||||||
|
|
||||||
ColumnSplitHelper helper(this->ctx_->Threads(), model, 0, ntree_limit);
|
ColumnSplitHelper helper(this->ctx_->Threads(), model, 0, ntree_limit);
|
||||||
helper.PredictInstance(inst, out_preds);
|
helper.PredictInstance(inst, out_preds);
|
||||||
return;
|
return;
|
||||||
@ -802,6 +813,9 @@ class CPUPredictor : public Predictor {
|
|||||||
preds.resize(info.num_row_ * ntree_limit);
|
preds.resize(info.num_row_ * ntree_limit);
|
||||||
|
|
||||||
if (p_fmat->Info().IsColumnSplit()) {
|
if (p_fmat->Info().IsColumnSplit()) {
|
||||||
|
CHECK(!model.learner_model_param->IsVectorLeaf())
|
||||||
|
<< "Predict leaf with column split" << MTNotImplemented();
|
||||||
|
|
||||||
ColumnSplitHelper helper(n_threads, model, 0, ntree_limit);
|
ColumnSplitHelper helper(n_threads, model, 0, ntree_limit);
|
||||||
helper.PredictLeaf(p_fmat, &preds);
|
helper.PredictLeaf(p_fmat, &preds);
|
||||||
return;
|
return;
|
||||||
|
|||||||
@ -117,7 +117,7 @@ void TestColumnSplit() {
|
|||||||
}
|
}
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
TEST(CpuPredictor, ColumnSplitBasic) {
|
TEST(CpuPredictor, BasicColumnSplit) {
|
||||||
auto constexpr kWorldSize = 2;
|
auto constexpr kWorldSize = 2;
|
||||||
RunWithInMemoryCommunicator(kWorldSize, TestColumnSplit);
|
RunWithInMemoryCommunicator(kWorldSize, TestColumnSplit);
|
||||||
}
|
}
|
||||||
@ -126,6 +126,10 @@ TEST(CpuPredictor, IterationRange) {
|
|||||||
TestIterationRange("cpu_predictor");
|
TestIterationRange("cpu_predictor");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CpuPredictor, IterationRangeColmnSplit) {
|
||||||
|
TestIterationRangeColumnSplit("cpu_predictor");
|
||||||
|
}
|
||||||
|
|
||||||
TEST(CpuPredictor, ExternalMemory) {
|
TEST(CpuPredictor, ExternalMemory) {
|
||||||
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
|
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
|
||||||
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
|
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
|
||||||
@ -223,10 +227,18 @@ TEST(CPUPredictor, CategoricalPrediction) {
|
|||||||
TestCategoricalPrediction("cpu_predictor");
|
TestCategoricalPrediction("cpu_predictor");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CPUPredictor, CategoricalPredictionColumnSplit) {
|
||||||
|
TestCategoricalPredictionColumnSplit("cpu_predictor");
|
||||||
|
}
|
||||||
|
|
||||||
TEST(CPUPredictor, CategoricalPredictLeaf) {
|
TEST(CPUPredictor, CategoricalPredictLeaf) {
|
||||||
TestCategoricalPredictLeaf(StringView{"cpu_predictor"});
|
TestCategoricalPredictLeaf(StringView{"cpu_predictor"});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CPUPredictor, CategoricalPredictLeafColumnSplit) {
|
||||||
|
TestCategoricalPredictLeafColumnSplit(StringView{"cpu_predictor"});
|
||||||
|
}
|
||||||
|
|
||||||
TEST(CpuPredictor, UpdatePredictionCache) {
|
TEST(CpuPredictor, UpdatePredictionCache) {
|
||||||
TestUpdatePredictionCache(false);
|
TestUpdatePredictionCache(false);
|
||||||
TestUpdatePredictionCache(true);
|
TestUpdatePredictionCache(true);
|
||||||
@ -236,11 +248,20 @@ TEST(CpuPredictor, LesserFeatures) {
|
|||||||
TestPredictionWithLesserFeatures("cpu_predictor");
|
TestPredictionWithLesserFeatures("cpu_predictor");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CpuPredictor, LesserFeaturesColumnSplit) {
|
||||||
|
TestPredictionWithLesserFeaturesColumnSplit("cpu_predictor");
|
||||||
|
}
|
||||||
|
|
||||||
TEST(CpuPredictor, Sparse) {
|
TEST(CpuPredictor, Sparse) {
|
||||||
TestSparsePrediction(0.2, "cpu_predictor");
|
TestSparsePrediction(0.2, "cpu_predictor");
|
||||||
TestSparsePrediction(0.8, "cpu_predictor");
|
TestSparsePrediction(0.8, "cpu_predictor");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CpuPredictor, SparseColumnSplit) {
|
||||||
|
TestSparsePredictionColumnSplit(0.2, "cpu_predictor");
|
||||||
|
TestSparsePredictionColumnSplit(0.8, "cpu_predictor");
|
||||||
|
}
|
||||||
|
|
||||||
TEST(CpuPredictor, Multi) {
|
TEST(CpuPredictor, Multi) {
|
||||||
Context ctx;
|
Context ctx;
|
||||||
ctx.nthread = 1;
|
ctx.nthread = 1;
|
||||||
|
|||||||
@ -209,7 +209,6 @@ TEST(GPUPredictor, IterationRange) {
|
|||||||
TestIterationRange("gpu_predictor");
|
TestIterationRange("gpu_predictor");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
TEST(GPUPredictor, CategoricalPrediction) {
|
TEST(GPUPredictor, CategoricalPrediction) {
|
||||||
TestCategoricalPrediction("gpu_predictor");
|
TestCategoricalPrediction("gpu_predictor");
|
||||||
}
|
}
|
||||||
|
|||||||
@ -153,27 +153,31 @@ void TestInplacePrediction(std::shared_ptr<DMatrix> x, std::string predictor, bs
|
|||||||
learner->Configure();
|
learner->Configure();
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestPredictionWithLesserFeatures(std::string predictor_name) {
|
namespace {
|
||||||
size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
|
std::unique_ptr<Learner> LearnerForTest(std::shared_ptr<DMatrix> dmat, size_t iters,
|
||||||
auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
|
size_t forest = 1) {
|
||||||
auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
|
std::unique_ptr<Learner> learner{Learner::Create({dmat})};
|
||||||
std::unique_ptr<Learner> learner{Learner::Create({m_train})};
|
learner->SetParams(Args{{"num_parallel_tree", std::to_string(forest)}});
|
||||||
|
for (size_t i = 0; i < iters; ++i) {
|
||||||
for (size_t i = 0; i < kIters; ++i) {
|
learner->UpdateOneIter(i, dmat);
|
||||||
learner->UpdateOneIter(i, m_train);
|
}
|
||||||
|
return learner;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void VerifyPredictionWithLesserFeatures(Learner *learner, std::string const &predictor_name,
|
||||||
|
size_t rows, std::shared_ptr<DMatrix> const &m_test,
|
||||||
|
std::shared_ptr<DMatrix> const &m_invalid) {
|
||||||
HostDeviceVector<float> prediction;
|
HostDeviceVector<float> prediction;
|
||||||
learner->SetParam("predictor", predictor_name);
|
learner->SetParam("predictor", predictor_name);
|
||||||
learner->Configure();
|
learner->Configure();
|
||||||
Json config{Object()};
|
Json config{Object()};
|
||||||
learner->SaveConfig(&config);
|
learner->SaveConfig(&config);
|
||||||
ASSERT_EQ(get<String>(config["learner"]["gradient_booster"]["gbtree_train_param"]["predictor"]), predictor_name);
|
ASSERT_EQ(get<String>(config["learner"]["gradient_booster"]["gbtree_train_param"]["predictor"]),
|
||||||
|
predictor_name);
|
||||||
|
|
||||||
learner->Predict(m_test, false, &prediction, 0, 0);
|
learner->Predict(m_test, false, &prediction, 0, 0);
|
||||||
ASSERT_EQ(prediction.Size(), kRows);
|
ASSERT_EQ(prediction.Size(), rows);
|
||||||
|
|
||||||
auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);
|
|
||||||
ASSERT_THROW({ learner->Predict(m_invalid, false, &prediction, 0, 0); }, dmlc::Error);
|
ASSERT_THROW({ learner->Predict(m_invalid, false, &prediction, 0, 0); }, dmlc::Error);
|
||||||
|
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
@ -192,6 +196,42 @@ void TestPredictionWithLesserFeatures(std::string predictor_name) {
|
|||||||
}
|
}
|
||||||
#endif // defined(XGBOOST_USE_CUDA)
|
#endif // defined(XGBOOST_USE_CUDA)
|
||||||
}
|
}
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
void TestPredictionWithLesserFeatures(std::string predictor_name) {
|
||||||
|
size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
|
||||||
|
auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
|
||||||
|
auto learner = LearnerForTest(m_train, kIters);
|
||||||
|
auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
|
||||||
|
auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);
|
||||||
|
VerifyPredictionWithLesserFeatures(learner.get(), predictor_name, kRows, m_test, m_invalid);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
void VerifyPredictionWithLesserFeaturesColumnSplit(Learner *learner,
|
||||||
|
std::string const &predictor_name, size_t rows,
|
||||||
|
std::shared_ptr<DMatrix> m_test,
|
||||||
|
std::shared_ptr<DMatrix> m_invalid) {
|
||||||
|
auto const world_size = collective::GetWorldSize();
|
||||||
|
auto const rank = collective::GetRank();
|
||||||
|
std::shared_ptr<DMatrix> sliced_test{m_test->SliceCol(world_size, rank)};
|
||||||
|
std::shared_ptr<DMatrix> sliced_invalid{m_invalid->SliceCol(world_size, rank)};
|
||||||
|
|
||||||
|
VerifyPredictionWithLesserFeatures(learner, predictor_name, rows, sliced_test, sliced_invalid);
|
||||||
|
}
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
void TestPredictionWithLesserFeaturesColumnSplit(std::string predictor_name) {
|
||||||
|
size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
|
||||||
|
auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
|
||||||
|
auto learner = LearnerForTest(m_train, kIters);
|
||||||
|
auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
|
||||||
|
auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);
|
||||||
|
|
||||||
|
auto constexpr kWorldSize = 2;
|
||||||
|
RunWithInMemoryCommunicator(kWorldSize, VerifyPredictionWithLesserFeaturesColumnSplit,
|
||||||
|
learner.get(), predictor_name, kRows, m_test, m_invalid);
|
||||||
|
}
|
||||||
|
|
||||||
void GBTreeModelForTest(gbm::GBTreeModel *model, uint32_t split_ind,
|
void GBTreeModelForTest(gbm::GBTreeModel *model, uint32_t split_ind,
|
||||||
bst_cat_t split_cat, float left_weight,
|
bst_cat_t split_cat, float left_weight,
|
||||||
@ -212,7 +252,7 @@ void GBTreeModelForTest(gbm::GBTreeModel *model, uint32_t split_ind,
|
|||||||
model->CommitModelGroup(std::move(trees), 0);
|
model->CommitModelGroup(std::move(trees), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestCategoricalPrediction(std::string name) {
|
void TestCategoricalPrediction(std::string name, bool is_column_split) {
|
||||||
size_t constexpr kCols = 10;
|
size_t constexpr kCols = 10;
|
||||||
PredictionCacheEntry out_predictions;
|
PredictionCacheEntry out_predictions;
|
||||||
|
|
||||||
@ -236,6 +276,9 @@ void TestCategoricalPrediction(std::string name) {
|
|||||||
|
|
||||||
std::vector<FeatureType> types(10, FeatureType::kCategorical);
|
std::vector<FeatureType> types(10, FeatureType::kCategorical);
|
||||||
m->Info().feature_types.HostVector() = types;
|
m->Info().feature_types.HostVector() = types;
|
||||||
|
if (is_column_split) {
|
||||||
|
m = std::shared_ptr<DMatrix>{m->SliceCol(collective::GetWorldSize(), collective::GetRank())};
|
||||||
|
}
|
||||||
|
|
||||||
predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
|
predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
|
||||||
predictor->PredictBatch(m.get(), &out_predictions, model, 0);
|
predictor->PredictBatch(m.get(), &out_predictions, model, 0);
|
||||||
@ -246,13 +289,21 @@ void TestCategoricalPrediction(std::string name) {
|
|||||||
|
|
||||||
row[split_ind] = split_cat + 1;
|
row[split_ind] = split_cat + 1;
|
||||||
m = GetDMatrixFromData(row, 1, kCols);
|
m = GetDMatrixFromData(row, 1, kCols);
|
||||||
|
if (is_column_split) {
|
||||||
|
m = std::shared_ptr<DMatrix>{m->SliceCol(collective::GetWorldSize(), collective::GetRank())};
|
||||||
|
}
|
||||||
out_predictions.version = 0;
|
out_predictions.version = 0;
|
||||||
predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
|
predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
|
||||||
predictor->PredictBatch(m.get(), &out_predictions, model, 0);
|
predictor->PredictBatch(m.get(), &out_predictions, model, 0);
|
||||||
ASSERT_EQ(out_predictions.predictions.HostVector()[0], left_weight + score);
|
ASSERT_EQ(out_predictions.predictions.HostVector()[0], left_weight + score);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestCategoricalPredictLeaf(StringView name) {
|
void TestCategoricalPredictionColumnSplit(std::string name) {
|
||||||
|
auto constexpr kWorldSize = 2;
|
||||||
|
RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPrediction, name, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestCategoricalPredictLeaf(StringView name, bool is_column_split) {
|
||||||
size_t constexpr kCols = 10;
|
size_t constexpr kCols = 10;
|
||||||
PredictionCacheEntry out_predictions;
|
PredictionCacheEntry out_predictions;
|
||||||
|
|
||||||
@ -275,6 +326,9 @@ void TestCategoricalPredictLeaf(StringView name) {
|
|||||||
std::vector<float> row(kCols);
|
std::vector<float> row(kCols);
|
||||||
row[split_ind] = split_cat;
|
row[split_ind] = split_cat;
|
||||||
auto m = GetDMatrixFromData(row, 1, kCols);
|
auto m = GetDMatrixFromData(row, 1, kCols);
|
||||||
|
if (is_column_split) {
|
||||||
|
m = std::shared_ptr<DMatrix>{m->SliceCol(collective::GetWorldSize(), collective::GetRank())};
|
||||||
|
}
|
||||||
|
|
||||||
predictor->PredictLeaf(m.get(), &out_predictions.predictions, model);
|
predictor->PredictLeaf(m.get(), &out_predictions.predictions, model);
|
||||||
CHECK_EQ(out_predictions.predictions.Size(), 1);
|
CHECK_EQ(out_predictions.predictions.Size(), 1);
|
||||||
@ -283,25 +337,25 @@ void TestCategoricalPredictLeaf(StringView name) {
|
|||||||
|
|
||||||
row[split_ind] = split_cat + 1;
|
row[split_ind] = split_cat + 1;
|
||||||
m = GetDMatrixFromData(row, 1, kCols);
|
m = GetDMatrixFromData(row, 1, kCols);
|
||||||
|
if (is_column_split) {
|
||||||
|
m = std::shared_ptr<DMatrix>{m->SliceCol(collective::GetWorldSize(), collective::GetRank())};
|
||||||
|
}
|
||||||
out_predictions.version = 0;
|
out_predictions.version = 0;
|
||||||
predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
|
predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
|
||||||
predictor->PredictLeaf(m.get(), &out_predictions.predictions, model);
|
predictor->PredictLeaf(m.get(), &out_predictions.predictions, model);
|
||||||
ASSERT_EQ(out_predictions.predictions.HostVector()[0], 1);
|
ASSERT_EQ(out_predictions.predictions.HostVector()[0], 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TestCategoricalPredictLeafColumnSplit(StringView name) {
|
||||||
|
auto constexpr kWorldSize = 2;
|
||||||
|
RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPredictLeaf, name, true);
|
||||||
|
}
|
||||||
|
|
||||||
void TestIterationRange(std::string name) {
|
void TestIterationRange(std::string name) {
|
||||||
size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3;
|
size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10;
|
||||||
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
|
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
|
||||||
std::unique_ptr<Learner> learner{Learner::Create({dmat})};
|
auto learner = LearnerForTest(dmat, kIters, kForest);
|
||||||
|
learner->SetParams(Args{{"predictor", name}});
|
||||||
learner->SetParams(Args{{"num_parallel_tree", std::to_string(kForest)},
|
|
||||||
{"predictor", name}});
|
|
||||||
|
|
||||||
size_t kIters = 10;
|
|
||||||
for (size_t i = 0; i < kIters; ++i) {
|
|
||||||
learner->UpdateOneIter(i, dmat);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool bound = false;
|
bool bound = false;
|
||||||
std::unique_ptr<Learner> sliced {learner->Slice(0, 3, 1, &bound)};
|
std::unique_ptr<Learner> sliced {learner->Slice(0, 3, 1, &bound)};
|
||||||
@ -363,15 +417,82 @@ void TestIterationRange(std::string name) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestSparsePrediction(float sparsity, std::string predictor) {
|
namespace {
|
||||||
size_t constexpr kRows = 512, kCols = 128;
|
void VerifyIterationRangeColumnSplit(DMatrix *dmat, Learner *learner, Learner *sliced,
|
||||||
auto Xy = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix(true);
|
std::vector<float> const &expected_margin_ranged,
|
||||||
std::unique_ptr<Learner> learner{Learner::Create({Xy})};
|
std::vector<float> const &expected_margin_sliced,
|
||||||
learner->Configure();
|
std::vector<float> const &expected_leaf_ranged,
|
||||||
for (size_t i = 0; i < 4; ++i) {
|
std::vector<float> const &expected_leaf_sliced) {
|
||||||
learner->UpdateOneIter(i, Xy);
|
auto const world_size = collective::GetWorldSize();
|
||||||
|
auto const rank = collective::GetRank();
|
||||||
|
std::shared_ptr<DMatrix> Xy{dmat->SliceCol(world_size, rank)};
|
||||||
|
|
||||||
|
HostDeviceVector<float> out_predt_sliced;
|
||||||
|
HostDeviceVector<float> out_predt_ranged;
|
||||||
|
|
||||||
|
// margin
|
||||||
|
{
|
||||||
|
sliced->Predict(Xy, true, &out_predt_sliced, 0, 0, false, false, false, false, false);
|
||||||
|
learner->Predict(Xy, true, &out_predt_ranged, 0, 3, false, false, false, false, false);
|
||||||
|
auto const &h_sliced = out_predt_sliced.HostVector();
|
||||||
|
auto const &h_range = out_predt_ranged.HostVector();
|
||||||
|
ASSERT_EQ(h_sliced.size(), expected_margin_sliced.size());
|
||||||
|
ASSERT_EQ(h_sliced, expected_margin_sliced);
|
||||||
|
ASSERT_EQ(h_range.size(), expected_margin_ranged.size());
|
||||||
|
ASSERT_EQ(h_range, expected_margin_ranged);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Leaf
|
||||||
|
{
|
||||||
|
sliced->Predict(Xy, false, &out_predt_sliced, 0, 0, false, true, false, false, false);
|
||||||
|
learner->Predict(Xy, false, &out_predt_ranged, 0, 3, false, true, false, false, false);
|
||||||
|
auto const &h_sliced = out_predt_sliced.HostVector();
|
||||||
|
auto const &h_range = out_predt_ranged.HostVector();
|
||||||
|
ASSERT_EQ(h_sliced.size(), expected_leaf_sliced.size());
|
||||||
|
ASSERT_EQ(h_sliced, expected_leaf_sliced);
|
||||||
|
ASSERT_EQ(h_range.size(), expected_leaf_ranged.size());
|
||||||
|
ASSERT_EQ(h_range, expected_leaf_ranged);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
void TestIterationRangeColumnSplit(std::string name) {
|
||||||
|
size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10;
|
||||||
|
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
|
||||||
|
auto learner = LearnerForTest(dmat, kIters, kForest);
|
||||||
|
learner->SetParams(Args{{"predictor", name}});
|
||||||
|
|
||||||
|
bool bound = false;
|
||||||
|
std::unique_ptr<Learner> sliced{learner->Slice(0, 3, 1, &bound)};
|
||||||
|
ASSERT_FALSE(bound);
|
||||||
|
|
||||||
|
// margin
|
||||||
|
HostDeviceVector<float> margin_predt_sliced;
|
||||||
|
HostDeviceVector<float> margin_predt_ranged;
|
||||||
|
sliced->Predict(dmat, true, &margin_predt_sliced, 0, 0, false, false, false, false, false);
|
||||||
|
learner->Predict(dmat, true, &margin_predt_ranged, 0, 3, false, false, false, false, false);
|
||||||
|
auto const &margin_sliced = margin_predt_sliced.HostVector();
|
||||||
|
auto const &margin_ranged = margin_predt_ranged.HostVector();
|
||||||
|
|
||||||
|
// Leaf
|
||||||
|
HostDeviceVector<float> leaf_predt_sliced;
|
||||||
|
HostDeviceVector<float> leaf_predt_ranged;
|
||||||
|
sliced->Predict(dmat, false, &leaf_predt_sliced, 0, 0, false, true, false, false, false);
|
||||||
|
learner->Predict(dmat, false, &leaf_predt_ranged, 0, 3, false, true, false, false, false);
|
||||||
|
auto const &leaf_sliced = leaf_predt_sliced.HostVector();
|
||||||
|
auto const &leaf_ranged = leaf_predt_ranged.HostVector();
|
||||||
|
|
||||||
|
auto constexpr kWorldSize = 2;
|
||||||
|
RunWithInMemoryCommunicator(kWorldSize, VerifyIterationRangeColumnSplit, dmat.get(),
|
||||||
|
learner.get(), sliced.get(), margin_ranged, margin_sliced,
|
||||||
|
leaf_ranged, leaf_sliced);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestSparsePrediction(float sparsity, std::string predictor) {
|
||||||
|
size_t constexpr kRows = 512, kCols = 128, kIters = 4;
|
||||||
|
auto Xy = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix(true);
|
||||||
|
auto learner = LearnerForTest(Xy, kIters);
|
||||||
|
|
||||||
HostDeviceVector<float> sparse_predt;
|
HostDeviceVector<float> sparse_predt;
|
||||||
|
|
||||||
Json model{Object{}};
|
Json model{Object{}};
|
||||||
@ -419,6 +540,43 @@ void TestSparsePrediction(float sparsity, std::string predictor) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
void VerifySparsePredictionColumnSplit(DMatrix *dmat, Learner *learner,
|
||||||
|
std::vector<float> const &expected_predt) {
|
||||||
|
std::shared_ptr<DMatrix> sliced{
|
||||||
|
dmat->SliceCol(collective::GetWorldSize(), collective::GetRank())};
|
||||||
|
HostDeviceVector<float> sparse_predt;
|
||||||
|
learner->Predict(sliced, false, &sparse_predt, 0, 0);
|
||||||
|
|
||||||
|
auto const &predt = sparse_predt.HostVector();
|
||||||
|
ASSERT_EQ(predt.size(), expected_predt.size());
|
||||||
|
for (size_t i = 0; i < predt.size(); ++i) {
|
||||||
|
ASSERT_FLOAT_EQ(predt[i], expected_predt[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
void TestSparsePredictionColumnSplit(float sparsity, std::string predictor) {
|
||||||
|
size_t constexpr kRows = 512, kCols = 128, kIters = 4;
|
||||||
|
auto Xy = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix(true);
|
||||||
|
auto learner = LearnerForTest(Xy, kIters);
|
||||||
|
|
||||||
|
HostDeviceVector<float> sparse_predt;
|
||||||
|
|
||||||
|
Json model{Object{}};
|
||||||
|
learner->SaveModel(&model);
|
||||||
|
|
||||||
|
learner.reset(Learner::Create({Xy}));
|
||||||
|
learner->LoadModel(model);
|
||||||
|
|
||||||
|
learner->SetParam("predictor", predictor);
|
||||||
|
learner->Predict(Xy, false, &sparse_predt, 0, 0);
|
||||||
|
|
||||||
|
auto constexpr kWorldSize = 2;
|
||||||
|
RunWithInMemoryCommunicator(kWorldSize, VerifySparsePredictionColumnSplit, Xy.get(),
|
||||||
|
learner.get(), sparse_predt.HostVector());
|
||||||
|
}
|
||||||
|
|
||||||
void TestVectorLeafPrediction(Context const *ctx) {
|
void TestVectorLeafPrediction(Context const *ctx) {
|
||||||
std::unique_ptr<Predictor> cpu_predictor =
|
std::unique_ptr<Predictor> cpu_predictor =
|
||||||
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", ctx));
|
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", ctx));
|
||||||
|
|||||||
@ -86,14 +86,24 @@ void TestInplacePrediction(std::shared_ptr<DMatrix> x, std::string predictor, bs
|
|||||||
|
|
||||||
void TestPredictionWithLesserFeatures(std::string preditor_name);
|
void TestPredictionWithLesserFeatures(std::string preditor_name);
|
||||||
|
|
||||||
void TestCategoricalPrediction(std::string name);
|
void TestPredictionWithLesserFeaturesColumnSplit(std::string preditor_name);
|
||||||
|
|
||||||
void TestCategoricalPredictLeaf(StringView name);
|
void TestCategoricalPrediction(std::string name, bool is_column_split = false);
|
||||||
|
|
||||||
|
void TestCategoricalPredictionColumnSplit(std::string name);
|
||||||
|
|
||||||
|
void TestCategoricalPredictLeaf(StringView name, bool is_column_split = false);
|
||||||
|
|
||||||
|
void TestCategoricalPredictLeafColumnSplit(StringView name);
|
||||||
|
|
||||||
void TestIterationRange(std::string name);
|
void TestIterationRange(std::string name);
|
||||||
|
|
||||||
|
void TestIterationRangeColumnSplit(std::string name);
|
||||||
|
|
||||||
void TestSparsePrediction(float sparsity, std::string predictor);
|
void TestSparsePrediction(float sparsity, std::string predictor);
|
||||||
|
|
||||||
|
void TestSparsePredictionColumnSplit(float sparsity, std::string predictor);
|
||||||
|
|
||||||
void TestVectorLeafPrediction(Context const* ctx);
|
void TestVectorLeafPrediction(Context const* ctx);
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user