More tests for cpu predictor with column split (#9270)

This commit is contained in:
Rong Ou
2023-06-08 07:47:19 -07:00
committed by GitHub
parent 84d3fcb7ea
commit ff122d61ff
5 changed files with 243 additions and 41 deletions

View File

@@ -430,8 +430,7 @@ class ColumnSplitHelper {
<< "column-split prediction is only supported for distributed training";
for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
CHECK_EQ(out_preds->size(),
p_fmat->Info().num_row_ * model_.learner_model_param->num_output_group);
CHECK_EQ(out_preds->size(), p_fmat->Info().num_row_ * (tree_end_ - tree_begin_));
PredictBatchKernel<SparsePageView, kBlockOfRowsSize, true>(SparsePageView{&batch}, out_preds);
}
}
@@ -543,8 +542,12 @@ class ColumnSplitHelper {
for (size_t tree_id = tree_begin_; tree_id < tree_end_; ++tree_id) {
auto const gid = model_.tree_info[tree_id];
for (size_t i = 0; i < block_size; ++i) {
preds[(predict_offset + i) * num_group + gid] +=
PredictOneTree<predict_leaf>(tree_id, batch_offset + i);
auto const result = PredictOneTree<predict_leaf>(tree_id, batch_offset + i);
if constexpr (predict_leaf) {
preds[(predict_offset + i) * (tree_end_ - tree_begin_) + tree_id] = result;
} else {
preds[(predict_offset + i) * num_group + gid] += result;
}
}
}
}
@@ -645,6 +648,9 @@ class CPUPredictor : public Predictor {
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const {
if (p_fmat->Info().IsColumnSplit()) {
CHECK(!model.learner_model_param->IsVectorLeaf())
<< "Predict DMatrix with column split" << MTNotImplemented();
ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
helper.PredictDMatrix(p_fmat, out_preds);
return;
@@ -743,6 +749,8 @@ class CPUPredictor : public Predictor {
unsigned tree_end) const override {
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
CHECK(proxy)<< "Inplace predict accepts only DMatrixProxy as input.";
CHECK(!p_m->Info().IsColumnSplit())
<< "Inplace predict support for column-wise data split is not yet implemented.";
auto x = proxy->Adapter();
if (x.type() == typeid(std::shared_ptr<data::DenseAdapter>)) {
this->DispatchedInplacePredict<data::DenseAdapter, kBlockOfRowsSize>(
@@ -773,6 +781,9 @@ class CPUPredictor : public Predictor {
out_preds->resize(model.learner_model_param->num_output_group);
if (is_column_split) {
CHECK(!model.learner_model_param->IsVectorLeaf())
<< "Predict instance with column split" << MTNotImplemented();
ColumnSplitHelper helper(this->ctx_->Threads(), model, 0, ntree_limit);
helper.PredictInstance(inst, out_preds);
return;
@@ -802,6 +813,9 @@ class CPUPredictor : public Predictor {
preds.resize(info.num_row_ * ntree_limit);
if (p_fmat->Info().IsColumnSplit()) {
CHECK(!model.learner_model_param->IsVectorLeaf())
<< "Predict leaf with column split" << MTNotImplemented();
ColumnSplitHelper helper(n_threads, model, 0, ntree_limit);
helper.PredictLeaf(p_fmat, &preds);
return;