More support for column split in cpu predictor (#9244)
- Added column split support to `PredictInstance` and `PredictLeaf`. - Refactoring of tests.
This commit is contained in:
parent
3bf0f145bb
commit
962a20693f
@ -138,12 +138,14 @@ class Predictor {
|
||||
* \param [in,out] out_preds The output preds.
|
||||
* \param model The model to predict from
|
||||
* \param tree_end (Optional) The tree end index.
|
||||
* \param is_column_split (Optional) If the data is split column-wise.
|
||||
*/
|
||||
|
||||
virtual void PredictInstance(const SparsePage::Inst& inst,
|
||||
std::vector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model,
|
||||
unsigned tree_end = 0) const = 0;
|
||||
unsigned tree_end = 0,
|
||||
bool is_column_split = false) const = 0;
|
||||
|
||||
/**
|
||||
* \brief predict the leaf index of each tree, the output will be nsample *
|
||||
|
||||
@ -191,6 +191,15 @@ struct SparsePageView {
|
||||
size_t Size() const { return view.Size(); }
|
||||
};
|
||||
|
||||
struct SingleInstanceView {
|
||||
bst_row_t base_rowid{};
|
||||
SparsePage::Inst const &inst;
|
||||
|
||||
explicit SingleInstanceView(SparsePage::Inst const &instance) : inst{instance} {}
|
||||
SparsePage::Inst operator[](size_t) { return inst; }
|
||||
static size_t Size() { return 1; }
|
||||
};
|
||||
|
||||
struct GHistIndexMatrixView {
|
||||
private:
|
||||
GHistIndexMatrix const &page_;
|
||||
@ -409,6 +418,24 @@ class ColumnSplitHelper {
|
||||
}
|
||||
}
|
||||
|
||||
void PredictInstance(SparsePage::Inst const &inst, std::vector<bst_float> *out_preds) {
|
||||
CHECK(xgboost::collective::IsDistributed())
|
||||
<< "column-split prediction is only supported for distributed training";
|
||||
|
||||
PredictBatchKernel<SingleInstanceView, 1>(SingleInstanceView{inst}, out_preds);
|
||||
}
|
||||
|
||||
void PredictLeaf(DMatrix *p_fmat, std::vector<bst_float> *out_preds) {
|
||||
CHECK(xgboost::collective::IsDistributed())
|
||||
<< "column-split prediction is only supported for distributed training";
|
||||
|
||||
for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
CHECK_EQ(out_preds->size(),
|
||||
p_fmat->Info().num_row_ * model_.learner_model_param->num_output_group);
|
||||
PredictBatchKernel<SparsePageView, kBlockOfRowsSize, true>(SparsePageView{&batch}, out_preds);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
using BitVector = RBitField8;
|
||||
|
||||
@ -498,24 +525,31 @@ class ColumnSplitHelper {
|
||||
return nid;
|
||||
}
|
||||
|
||||
template <bool predict_leaf = false>
|
||||
bst_float PredictOneTree(std::size_t tree_id, std::size_t row_id) {
|
||||
auto const &tree = *model_.trees[tree_id];
|
||||
auto const leaf = GetLeafIndex(tree, tree_id, row_id);
|
||||
if constexpr (predict_leaf) {
|
||||
return static_cast<bst_float>(leaf);
|
||||
} else {
|
||||
return tree[leaf].LeafValue();
|
||||
}
|
||||
}
|
||||
|
||||
template <bool predict_leaf = false>
|
||||
void PredictAllTrees(std::vector<bst_float> *out_preds, std::size_t batch_offset,
|
||||
std::size_t predict_offset, std::size_t num_group, std::size_t block_size) {
|
||||
auto &preds = *out_preds;
|
||||
for (size_t tree_id = tree_begin_; tree_id < tree_end_; ++tree_id) {
|
||||
auto const gid = model_.tree_info[tree_id];
|
||||
for (size_t i = 0; i < block_size; ++i) {
|
||||
preds[(predict_offset + i) * num_group + gid] += PredictOneTree(tree_id, batch_offset + i);
|
||||
preds[(predict_offset + i) * num_group + gid] +=
|
||||
PredictOneTree<predict_leaf>(tree_id, batch_offset + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataView, size_t block_of_rows_size>
|
||||
template <typename DataView, size_t block_of_rows_size, bool predict_leaf = false>
|
||||
void PredictBatchKernel(DataView batch, std::vector<bst_float> *out_preds) {
|
||||
auto const num_group = model_.learner_model_param->num_output_group;
|
||||
|
||||
@ -544,8 +578,8 @@ class ColumnSplitHelper {
|
||||
auto const batch_offset = block_id * block_of_rows_size;
|
||||
auto const block_size = std::min(static_cast<std::size_t>(nsize - batch_offset),
|
||||
static_cast<std::size_t>(block_of_rows_size));
|
||||
PredictAllTrees(out_preds, batch_offset, batch_offset + batch.base_rowid, num_group,
|
||||
block_size);
|
||||
PredictAllTrees<predict_leaf>(out_preds, batch_offset, batch_offset + batch.base_rowid,
|
||||
num_group, block_size);
|
||||
});
|
||||
|
||||
ClearBitVectors();
|
||||
@ -728,18 +762,25 @@ class CPUPredictor : public Predictor {
|
||||
return true;
|
||||
}
|
||||
|
||||
void PredictInstance(const SparsePage::Inst& inst,
|
||||
std::vector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, unsigned ntree_limit) const override {
|
||||
void PredictInstance(const SparsePage::Inst &inst, std::vector<bst_float> *out_preds,
|
||||
const gbm::GBTreeModel &model, unsigned ntree_limit,
|
||||
bool is_column_split) const override {
|
||||
CHECK(!model.learner_model_param->IsVectorLeaf()) << "predict instance" << MTNotImplemented();
|
||||
std::vector<RegTree::FVec> feat_vecs;
|
||||
feat_vecs.resize(1, RegTree::FVec());
|
||||
feat_vecs[0].Init(model.learner_model_param->num_feature);
|
||||
ntree_limit *= model.learner_model_param->num_output_group;
|
||||
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
|
||||
ntree_limit = static_cast<unsigned>(model.trees.size());
|
||||
}
|
||||
out_preds->resize(model.learner_model_param->num_output_group);
|
||||
|
||||
if (is_column_split) {
|
||||
ColumnSplitHelper helper(this->ctx_->Threads(), model, 0, ntree_limit);
|
||||
helper.PredictInstance(inst, out_preds);
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<RegTree::FVec> feat_vecs;
|
||||
feat_vecs.resize(1, RegTree::FVec());
|
||||
feat_vecs[0].Init(model.learner_model_param->num_feature);
|
||||
auto base_score = model.learner_model_param->BaseScore(ctx_)(0);
|
||||
// loop over output groups
|
||||
for (uint32_t gid = 0; gid < model.learner_model_param->num_output_group; ++gid) {
|
||||
@ -752,16 +793,23 @@ class CPUPredictor : public Predictor {
|
||||
void PredictLeaf(DMatrix *p_fmat, HostDeviceVector<bst_float> *out_preds,
|
||||
const gbm::GBTreeModel &model, unsigned ntree_limit) const override {
|
||||
auto const n_threads = this->ctx_->Threads();
|
||||
std::vector<RegTree::FVec> feat_vecs;
|
||||
const int num_feature = model.learner_model_param->num_feature;
|
||||
InitThreadTemp(n_threads, &feat_vecs);
|
||||
const MetaInfo &info = p_fmat->Info();
|
||||
// number of valid trees
|
||||
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
|
||||
ntree_limit = static_cast<unsigned>(model.trees.size());
|
||||
}
|
||||
const MetaInfo &info = p_fmat->Info();
|
||||
std::vector<bst_float> &preds = out_preds->HostVector();
|
||||
preds.resize(info.num_row_ * ntree_limit);
|
||||
|
||||
if (p_fmat->Info().IsColumnSplit()) {
|
||||
ColumnSplitHelper helper(n_threads, model, 0, ntree_limit);
|
||||
helper.PredictLeaf(p_fmat, &preds);
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<RegTree::FVec> feat_vecs;
|
||||
const int num_feature = model.learner_model_param->num_feature;
|
||||
InitThreadTemp(n_threads, &feat_vecs);
|
||||
// start collecting the prediction
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
// parallel over local batch
|
||||
@ -796,6 +844,8 @@ class CPUPredictor : public Predictor {
|
||||
int condition, unsigned condition_feature) const override {
|
||||
CHECK(!model.learner_model_param->IsVectorLeaf())
|
||||
<< "Predict contribution" << MTNotImplemented();
|
||||
CHECK(!p_fmat->Info().IsColumnSplit())
|
||||
<< "Predict contribution support for column-wise data split is not yet implemented.";
|
||||
auto const n_threads = this->ctx_->Threads();
|
||||
const int num_feature = model.learner_model_param->num_feature;
|
||||
std::vector<RegTree::FVec> feat_vecs;
|
||||
@ -877,6 +927,8 @@ class CPUPredictor : public Predictor {
|
||||
bool approximate) const override {
|
||||
CHECK(!model.learner_model_param->IsVectorLeaf())
|
||||
<< "Predict interaction contribution" << MTNotImplemented();
|
||||
CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict interaction contribution support for "
|
||||
"column-wise data split is not yet implemented.";
|
||||
const MetaInfo& info = p_fmat->Info();
|
||||
const int ngroup = model.learner_model_param->num_output_group;
|
||||
size_t const ncolumns = model.learner_model_param->num_feature;
|
||||
|
||||
@ -929,7 +929,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
|
||||
void PredictInstance(const SparsePage::Inst&,
|
||||
std::vector<bst_float>*,
|
||||
const gbm::GBTreeModel&, unsigned) const override {
|
||||
const gbm::GBTreeModel&, unsigned, bool) const override {
|
||||
LOG(FATAL) << "[Internal error]: " << __func__
|
||||
<< " is not implemented in GPU Predictor.";
|
||||
}
|
||||
|
||||
@ -17,13 +17,15 @@
|
||||
#include "test_predictor.h"
|
||||
|
||||
namespace xgboost {
|
||||
TEST(CpuPredictor, Basic) {
|
||||
|
||||
namespace {
|
||||
void TestBasic(DMatrix* dmat) {
|
||||
auto lparam = CreateEmptyGenericParam(GPUIDX);
|
||||
std::unique_ptr<Predictor> cpu_predictor =
|
||||
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam));
|
||||
|
||||
size_t constexpr kRows = 5;
|
||||
size_t constexpr kCols = 5;
|
||||
size_t const kRows = dmat->Info().num_row_;
|
||||
size_t const kCols = dmat->Info().num_col_;
|
||||
|
||||
LearnerModelParam mparam{MakeMP(kCols, .0, 1)};
|
||||
|
||||
@ -31,12 +33,10 @@ TEST(CpuPredictor, Basic) {
|
||||
ctx.UpdateAllowUnknown(Args{});
|
||||
gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx);
|
||||
|
||||
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
|
||||
// Test predict batch
|
||||
PredictionCacheEntry out_predictions;
|
||||
cpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model);
|
||||
cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
|
||||
cpu_predictor->PredictBatch(dmat, &out_predictions, model, 0);
|
||||
|
||||
std::vector<float>& out_predictions_h = out_predictions.predictions.HostVector();
|
||||
for (size_t i = 0; i < out_predictions.predictions.Size(); i++) {
|
||||
@ -48,22 +48,28 @@ TEST(CpuPredictor, Basic) {
|
||||
auto page = batch.GetView();
|
||||
for (size_t i = 0; i < batch.Size(); i++) {
|
||||
std::vector<float> instance_out_predictions;
|
||||
cpu_predictor->PredictInstance(page[i], &instance_out_predictions, model);
|
||||
cpu_predictor->PredictInstance(page[i], &instance_out_predictions, model, 0,
|
||||
dmat->Info().IsColumnSplit());
|
||||
ASSERT_EQ(instance_out_predictions[0], 1.5);
|
||||
}
|
||||
|
||||
// Test predict leaf
|
||||
HostDeviceVector<float> leaf_out_predictions;
|
||||
cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model);
|
||||
cpu_predictor->PredictLeaf(dmat, &leaf_out_predictions, model);
|
||||
auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector();
|
||||
for (auto v : h_leaf_out_predictions) {
|
||||
ASSERT_EQ(v, 0);
|
||||
}
|
||||
|
||||
if (dmat->Info().IsColumnSplit()) {
|
||||
// Predict contribution is not supported for column split.
|
||||
return;
|
||||
}
|
||||
|
||||
// Test predict contribution
|
||||
HostDeviceVector<float> out_contribution_hdv;
|
||||
auto& out_contribution = out_contribution_hdv.HostVector();
|
||||
cpu_predictor->PredictContribution(dmat.get(), &out_contribution_hdv, model);
|
||||
cpu_predictor->PredictContribution(dmat, &out_contribution_hdv, model);
|
||||
ASSERT_EQ(out_contribution.size(), kRows * (kCols + 1));
|
||||
for (size_t i = 0; i < out_contribution.size(); ++i) {
|
||||
auto const& contri = out_contribution[i];
|
||||
@ -76,8 +82,7 @@ TEST(CpuPredictor, Basic) {
|
||||
}
|
||||
}
|
||||
// Test predict contribution (approximate method)
|
||||
cpu_predictor->PredictContribution(dmat.get(), &out_contribution_hdv, model,
|
||||
0, nullptr, true);
|
||||
cpu_predictor->PredictContribution(dmat, &out_contribution_hdv, model, 0, nullptr, true);
|
||||
for (size_t i = 0; i < out_contribution.size(); ++i) {
|
||||
auto const& contri = out_contribution[i];
|
||||
// shift 1 for bias, as test tree is a decision dump, only global bias is
|
||||
@ -89,41 +94,32 @@ TEST(CpuPredictor, Basic) {
|
||||
}
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
namespace {
|
||||
void TestColumnSplitPredictBatch() {
|
||||
TEST(CpuPredictor, Basic) {
|
||||
size_t constexpr kRows = 5;
|
||||
size_t constexpr kCols = 5;
|
||||
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
TestBasic(dmat.get());
|
||||
}
|
||||
|
||||
namespace {
|
||||
void TestColumnSplit() {
|
||||
size_t constexpr kRows = 5;
|
||||
size_t constexpr kCols = 5;
|
||||
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
dmat = std::unique_ptr<DMatrix>{dmat->SliceCol(world_size, rank)};
|
||||
|
||||
auto lparam = CreateEmptyGenericParam(GPUIDX);
|
||||
std::unique_ptr<Predictor> cpu_predictor =
|
||||
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam));
|
||||
|
||||
LearnerModelParam mparam{MakeMP(kCols, .0, 1)};
|
||||
|
||||
Context ctx;
|
||||
ctx.UpdateAllowUnknown(Args{});
|
||||
gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx);
|
||||
|
||||
// Test predict batch
|
||||
PredictionCacheEntry out_predictions;
|
||||
cpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model);
|
||||
auto sliced = std::unique_ptr<DMatrix>{dmat->SliceCol(world_size, rank)};
|
||||
cpu_predictor->PredictBatch(sliced.get(), &out_predictions, model, 0);
|
||||
|
||||
std::vector<float>& out_predictions_h = out_predictions.predictions.HostVector();
|
||||
for (size_t i = 0; i < out_predictions.predictions.Size(); i++) {
|
||||
ASSERT_EQ(out_predictions_h[i], 1.5);
|
||||
}
|
||||
TestBasic(dmat.get());
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(CpuPredictor, ColumnSplit) {
|
||||
TEST(CpuPredictor, ColumnSplitBasic) {
|
||||
auto constexpr kWorldSize = 2;
|
||||
RunWithInMemoryCommunicator(kWorldSize, TestColumnSplitPredictBatch);
|
||||
RunWithInMemoryCommunicator(kWorldSize, TestColumnSplit);
|
||||
}
|
||||
|
||||
TEST(CpuPredictor, IterationRange) {
|
||||
@ -133,69 +129,8 @@ TEST(CpuPredictor, IterationRange) {
|
||||
TEST(CpuPredictor, ExternalMemory) {
|
||||
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
|
||||
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
|
||||
|
||||
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries);
|
||||
auto lparam = CreateEmptyGenericParam(GPUIDX);
|
||||
|
||||
std::unique_ptr<Predictor> cpu_predictor =
|
||||
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam));
|
||||
|
||||
LearnerModelParam mparam{MakeMP(dmat->Info().num_col_, .0, 1)};
|
||||
|
||||
Context ctx;
|
||||
ctx.UpdateAllowUnknown(Args{});
|
||||
gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx);
|
||||
|
||||
// Test predict batch
|
||||
PredictionCacheEntry out_predictions;
|
||||
cpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model);
|
||||
cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
|
||||
std::vector<float> &out_predictions_h = out_predictions.predictions.HostVector();
|
||||
ASSERT_EQ(out_predictions.predictions.Size(), dmat->Info().num_row_);
|
||||
for (const auto& v : out_predictions_h) {
|
||||
ASSERT_EQ(v, 1.5);
|
||||
}
|
||||
|
||||
// Test predict leaf
|
||||
HostDeviceVector<float> leaf_out_predictions;
|
||||
cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model);
|
||||
auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector();
|
||||
ASSERT_EQ(h_leaf_out_predictions.size(), dmat->Info().num_row_);
|
||||
for (const auto& v : h_leaf_out_predictions) {
|
||||
ASSERT_EQ(v, 0);
|
||||
}
|
||||
|
||||
// Test predict contribution
|
||||
HostDeviceVector<float> out_contribution_hdv;
|
||||
auto& out_contribution = out_contribution_hdv.HostVector();
|
||||
cpu_predictor->PredictContribution(dmat.get(), &out_contribution_hdv, model);
|
||||
ASSERT_EQ(out_contribution.size(), dmat->Info().num_row_ * (dmat->Info().num_col_ + 1));
|
||||
for (size_t i = 0; i < out_contribution.size(); ++i) {
|
||||
auto const& contri = out_contribution[i];
|
||||
// shift 1 for bias, as test tree is a decision dump, only global bias is filled with LeafValue().
|
||||
if ((i + 1) % (dmat->Info().num_col_ + 1) == 0) {
|
||||
ASSERT_EQ(out_contribution.back(), 1.5f);
|
||||
} else {
|
||||
ASSERT_EQ(contri, 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Test predict contribution (approximate method)
|
||||
HostDeviceVector<float> out_contribution_approximate_hdv;
|
||||
auto& out_contribution_approximate = out_contribution_approximate_hdv.HostVector();
|
||||
cpu_predictor->PredictContribution(
|
||||
dmat.get(), &out_contribution_approximate_hdv, model, 0, nullptr, true);
|
||||
ASSERT_EQ(out_contribution_approximate.size(),
|
||||
dmat->Info().num_row_ * (dmat->Info().num_col_ + 1));
|
||||
for (size_t i = 0; i < out_contribution.size(); ++i) {
|
||||
auto const& contri = out_contribution[i];
|
||||
// shift 1 for bias, as test tree is a decision dump, only global bias is filled with LeafValue().
|
||||
if ((i + 1) % (dmat->Info().num_col_ + 1) == 0) {
|
||||
ASSERT_EQ(out_contribution.back(), 1.5f);
|
||||
} else {
|
||||
ASSERT_EQ(contri, 0);
|
||||
}
|
||||
}
|
||||
TestBasic(dmat.get());
|
||||
}
|
||||
|
||||
TEST(CpuPredictor, InplacePredict) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user