Initial support for column-split cpu predictor (#8676)
This commit is contained in:
parent
980233e648
commit
78396f8a6e
@ -207,6 +207,8 @@ class Communicator {
|
|||||||
result = CommunicatorType::kRabit;
|
result = CommunicatorType::kRabit;
|
||||||
} else if (!CompareStringsCaseInsensitive("federated", str)) {
|
} else if (!CompareStringsCaseInsensitive("federated", str)) {
|
||||||
result = CommunicatorType::kFederated;
|
result = CommunicatorType::kFederated;
|
||||||
|
} else if (!CompareStringsCaseInsensitive("in-memory", str)) {
|
||||||
|
result = CommunicatorType::kInMemory;
|
||||||
} else {
|
} else {
|
||||||
LOG(FATAL) << "Unknown communicator type " << str;
|
LOG(FATAL) << "Unknown communicator type " << str;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,12 +8,12 @@
|
|||||||
#include <limits>
|
#include <limits>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
|
|
||||||
|
#include "../collective/communicator-inl.h"
|
||||||
#include "../common/categorical.h"
|
#include "../common/categorical.h"
|
||||||
#include "../common/math.h"
|
#include "../common/math.h"
|
||||||
#include "../common/threading_utils.h"
|
#include "../common/threading_utils.h"
|
||||||
#include "../data/adapter.h"
|
#include "../data/adapter.h"
|
||||||
#include "../data/gradient_index.h"
|
#include "../data/gradient_index.h"
|
||||||
#include "../data/proxy_dmatrix.h"
|
|
||||||
#include "../gbm/gbtree_model.h"
|
#include "../gbm/gbtree_model.h"
|
||||||
#include "cpu_treeshap.h" // CalculateContributions
|
#include "cpu_treeshap.h" // CalculateContributions
|
||||||
#include "predict_fn.h"
|
#include "predict_fn.h"
|
||||||
@ -23,7 +23,6 @@
|
|||||||
#include "xgboost/logging.h"
|
#include "xgboost/logging.h"
|
||||||
#include "xgboost/predictor.h"
|
#include "xgboost/predictor.h"
|
||||||
#include "xgboost/tree_model.h"
|
#include "xgboost/tree_model.h"
|
||||||
#include "xgboost/tree_updater.h"
|
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace predictor {
|
namespace predictor {
|
||||||
@ -284,16 +283,277 @@ void FillNodeMeanValues(RegTree const* tree, std::vector<float>* mean_values) {
|
|||||||
FillNodeMeanValues(tree, 0, mean_values);
|
FillNodeMeanValues(tree, 0, mean_values);
|
||||||
}
|
}
|
||||||
|
|
||||||
class CPUPredictor : public Predictor {
|
namespace {
|
||||||
protected:
|
// init thread buffers
|
||||||
// init thread buffers
|
static void InitThreadTemp(int nthread, std::vector<RegTree::FVec> *out) {
|
||||||
static void InitThreadTemp(int nthread, std::vector<RegTree::FVec> *out) {
|
int prev_thread_temp_size = out->size();
|
||||||
int prev_thread_temp_size = out->size();
|
if (prev_thread_temp_size < nthread) {
|
||||||
if (prev_thread_temp_size < nthread) {
|
out->resize(nthread, RegTree::FVec());
|
||||||
out->resize(nthread, RegTree::FVec());
|
}
|
||||||
|
}
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief A helper class for prediction when the DMatrix is split by column.
|
||||||
|
*
|
||||||
|
* When data is split by column, a local DMatrix only contains a subset of features. All the workers
|
||||||
|
* in a distributed/federated environment need to cooperate to produce a prediction. This is done in
|
||||||
|
* two passes with the help of bit vectors.
|
||||||
|
*
|
||||||
|
* First pass:
|
||||||
|
* for each tree:
|
||||||
|
* for each row:
|
||||||
|
* for each node:
|
||||||
|
* if the feature is available and passes the filter, mark the corresponding decision bit
|
||||||
|
* if the feature is missing, mark the missing bit
|
||||||
|
*
|
||||||
|
* Once the two bit vectors are populated, run allreduce on both, using bitwise OR for the decision
|
||||||
|
* bits, and bitwise AND for the missing bits.
|
||||||
|
*
|
||||||
|
* Second pass:
|
||||||
|
* for each tree:
|
||||||
|
* for each row:
|
||||||
|
* find the leaf node using the decision and missing bits, return the leaf value
|
||||||
|
*
|
||||||
|
* The size of the decision/missing bit vector is:
|
||||||
|
* number of rows in a batch * sum(number of nodes in each tree)
|
||||||
|
*/
|
||||||
|
class ColumnSplitHelper {
|
||||||
|
public:
|
||||||
|
ColumnSplitHelper(std::int32_t n_threads, gbm::GBTreeModel const &model, uint32_t tree_begin,
|
||||||
|
uint32_t tree_end)
|
||||||
|
: n_threads_{n_threads}, model_{model}, tree_begin_{tree_begin}, tree_end_{tree_end} {
|
||||||
|
auto const n_trees = tree_end_ - tree_begin_;
|
||||||
|
tree_sizes_.resize(n_trees);
|
||||||
|
tree_offsets_.resize(n_trees);
|
||||||
|
for (auto i = 0; i < n_trees; i++) {
|
||||||
|
auto const &tree = *model_.trees[tree_begin_ + i];
|
||||||
|
tree_sizes_[i] = tree.GetNodes().size();
|
||||||
|
}
|
||||||
|
// std::exclusive_scan (only available in c++17) equivalent to get tree offsets.
|
||||||
|
tree_offsets_[0] = 0;
|
||||||
|
for (auto i = 1; i < n_trees; i++) {
|
||||||
|
tree_offsets_[i] = tree_offsets_[i - 1] + tree_sizes_[i - 1];
|
||||||
|
}
|
||||||
|
bits_per_row_ = tree_offsets_.back() + tree_sizes_.back();
|
||||||
|
|
||||||
|
InitThreadTemp(n_threads_ * kBlockOfRowsSize, &feat_vecs_);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable copy (and move) semantics.
|
||||||
|
ColumnSplitHelper(ColumnSplitHelper const &) = delete;
|
||||||
|
ColumnSplitHelper &operator=(ColumnSplitHelper const &) = delete;
|
||||||
|
ColumnSplitHelper(ColumnSplitHelper &&) noexcept = delete;
|
||||||
|
ColumnSplitHelper &operator=(ColumnSplitHelper &&) noexcept = delete;
|
||||||
|
|
||||||
|
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds) {
|
||||||
|
CHECK(xgboost::collective::IsDistributed())
|
||||||
|
<< "column-split prediction is only supported for distributed training";
|
||||||
|
|
||||||
|
for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||||
|
CHECK_EQ(out_preds->size(),
|
||||||
|
p_fmat->Info().num_row_ * model_.learner_model_param->num_output_group);
|
||||||
|
PredictBatchKernel<SparsePageView, kBlockOfRowsSize>(SparsePageView{&batch}, out_preds);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
using BitVector = RBitField8;
|
||||||
|
|
||||||
|
void InitBitVectors(std::size_t n_rows) {
|
||||||
|
n_rows_ = n_rows;
|
||||||
|
auto const size = BitVector::ComputeStorageSize(bits_per_row_ * n_rows_);
|
||||||
|
decision_storage_.resize(size);
|
||||||
|
decision_bits_ = BitVector(common::Span<BitVector::value_type>(decision_storage_));
|
||||||
|
missing_storage_.resize(size);
|
||||||
|
missing_bits_ = BitVector(common::Span<BitVector::value_type>(missing_storage_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void ClearBitVectors() {
|
||||||
|
std::fill(decision_storage_.begin(), decision_storage_.end(), 0);
|
||||||
|
std::fill(missing_storage_.begin(), missing_storage_.end(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::size_t BitIndex(std::size_t tree_id, std::size_t row_id, std::size_t node_id) const {
|
||||||
|
size_t tree_index = tree_id - tree_begin_;
|
||||||
|
return tree_offsets_[tree_index] * n_rows_ + row_id * tree_sizes_[tree_index] + node_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AllreduceBitVectors() {
|
||||||
|
collective::Allreduce<collective::Operation::kBitwiseOR>(decision_storage_.data(),
|
||||||
|
decision_storage_.size());
|
||||||
|
collective::Allreduce<collective::Operation::kBitwiseAND>(missing_storage_.data(),
|
||||||
|
missing_storage_.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
void MaskOneTree(RegTree::FVec const &feat, std::size_t tree_id, std::size_t row_id) {
|
||||||
|
auto const &tree = *model_.trees[tree_id];
|
||||||
|
auto const &cats = tree.GetCategoriesMatrix();
|
||||||
|
auto const has_categorical = tree.HasCategoricalSplit();
|
||||||
|
|
||||||
|
for (auto nid = 0; nid < tree.GetNodes().size(); nid++) {
|
||||||
|
auto const &node = tree[nid];
|
||||||
|
if (node.IsDeleted() || node.IsLeaf()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto const bit_index = BitIndex(tree_id, row_id, nid);
|
||||||
|
unsigned split_index = node.SplitIndex();
|
||||||
|
if (feat.IsMissing(split_index)) {
|
||||||
|
missing_bits_.Set(bit_index);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto const fvalue = feat.GetFvalue(split_index);
|
||||||
|
if (has_categorical && common::IsCat(cats.split_type, nid)) {
|
||||||
|
auto const node_categories =
|
||||||
|
cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
|
||||||
|
if (!common::Decision(node_categories, fvalue)) {
|
||||||
|
decision_bits_.Set(bit_index);
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (fvalue >= node.SplitCond()) {
|
||||||
|
decision_bits_.Set(bit_index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void MaskAllTrees(std::size_t batch_offset, std::size_t fvec_offset, std::size_t block_size) {
|
||||||
|
for (auto tree_id = tree_begin_; tree_id < tree_end_; ++tree_id) {
|
||||||
|
for (size_t i = 0; i < block_size; ++i) {
|
||||||
|
MaskOneTree(feat_vecs_[fvec_offset + i], tree_id, batch_offset + i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bst_node_t GetNextNode(RegTree::Node const &node, std::size_t bit_index) {
|
||||||
|
if (missing_bits_.Check(bit_index)) {
|
||||||
|
return node.DefaultChild();
|
||||||
|
} else {
|
||||||
|
return node.LeftChild() + decision_bits_.Check(bit_index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bst_node_t GetLeafIndex(RegTree const &tree, std::size_t tree_id, std::size_t row_id) {
|
||||||
|
bst_node_t nid = 0;
|
||||||
|
while (!tree[nid].IsLeaf()) {
|
||||||
|
auto const bit_index = BitIndex(tree_id, row_id, nid);
|
||||||
|
nid = GetNextNode(tree[nid], bit_index);
|
||||||
|
}
|
||||||
|
return nid;
|
||||||
|
}
|
||||||
|
|
||||||
|
bst_float PredictOneTree(std::size_t tree_id, std::size_t row_id) {
|
||||||
|
auto const &tree = *model_.trees[tree_id];
|
||||||
|
auto const leaf = GetLeafIndex(tree, tree_id, row_id);
|
||||||
|
return tree[leaf].LeafValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
void PredictAllTrees(std::vector<bst_float> *out_preds, std::size_t batch_offset,
|
||||||
|
std::size_t predict_offset, std::size_t num_group, std::size_t block_size) {
|
||||||
|
auto &preds = *out_preds;
|
||||||
|
for (size_t tree_id = tree_begin_; tree_id < tree_end_; ++tree_id) {
|
||||||
|
auto const gid = model_.tree_info[tree_id];
|
||||||
|
for (size_t i = 0; i < block_size; ++i) {
|
||||||
|
preds[(predict_offset + i) * num_group + gid] += PredictOneTree(tree_id, batch_offset + i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename DataView, size_t block_of_rows_size>
|
||||||
|
void PredictBatchKernel(DataView batch, std::vector<bst_float> *out_preds) {
|
||||||
|
auto const num_group = model_.learner_model_param->num_output_group;
|
||||||
|
|
||||||
|
CHECK_EQ(model_.param.size_leaf_vector, 0) << "size_leaf_vector is enforced to 0 so far";
|
||||||
|
// parallel over local batch
|
||||||
|
auto const nsize = batch.Size();
|
||||||
|
auto const num_feature = model_.learner_model_param->num_feature;
|
||||||
|
auto const n_blocks = common::DivRoundUp(nsize, block_of_rows_size);
|
||||||
|
InitBitVectors(nsize);
|
||||||
|
|
||||||
|
// auto block_id has the same type as `n_blocks`.
|
||||||
|
common::ParallelFor(n_blocks, n_threads_, [&](auto block_id) {
|
||||||
|
auto const batch_offset = block_id * block_of_rows_size;
|
||||||
|
auto const block_size = std::min(nsize - batch_offset, block_of_rows_size);
|
||||||
|
auto const fvec_offset = omp_get_thread_num() * block_of_rows_size;
|
||||||
|
|
||||||
|
FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, &feat_vecs_);
|
||||||
|
MaskAllTrees(batch_offset, fvec_offset, block_size);
|
||||||
|
FVecDrop(block_size, batch_offset, &batch, fvec_offset, &feat_vecs_);
|
||||||
|
});
|
||||||
|
|
||||||
|
AllreduceBitVectors();
|
||||||
|
|
||||||
|
// auto block_id has the same type as `n_blocks`.
|
||||||
|
common::ParallelFor(n_blocks, n_threads_, [&](auto block_id) {
|
||||||
|
auto const batch_offset = block_id * block_of_rows_size;
|
||||||
|
auto const block_size = std::min(nsize - batch_offset, block_of_rows_size);
|
||||||
|
PredictAllTrees(out_preds, batch_offset, batch_offset + batch.base_rowid, num_group,
|
||||||
|
block_size);
|
||||||
|
});
|
||||||
|
|
||||||
|
ClearBitVectors();
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::size_t constexpr kBlockOfRowsSize = 64;
|
||||||
|
|
||||||
|
std::int32_t const n_threads_;
|
||||||
|
gbm::GBTreeModel const &model_;
|
||||||
|
uint32_t const tree_begin_;
|
||||||
|
uint32_t const tree_end_;
|
||||||
|
|
||||||
|
std::vector<std::size_t> tree_sizes_{};
|
||||||
|
std::vector<std::size_t> tree_offsets_{};
|
||||||
|
std::size_t bits_per_row_{};
|
||||||
|
std::vector<RegTree::FVec> feat_vecs_{};
|
||||||
|
|
||||||
|
std::size_t n_rows_;
|
||||||
|
/**
|
||||||
|
* @brief Stores decision bit for each split node.
|
||||||
|
*
|
||||||
|
* Conceptually it's a 3-dimensional bit matrix:
|
||||||
|
* - 1st dimension is the tree index, from `tree_begin_` to `tree_end_`.
|
||||||
|
* - 2nd dimension is the row index, for each row in the batch.
|
||||||
|
* - 3rd dimension is the node id, for each node in the tree.
|
||||||
|
*
|
||||||
|
* Since we have to ship the whole thing over the wire to do an allreduce, the matrix is flattened
|
||||||
|
* into a 1-dimensional array.
|
||||||
|
*
|
||||||
|
* First, it's divided by the tree index:
|
||||||
|
*
|
||||||
|
* [ tree 0 ] [ tree 1 ] ...
|
||||||
|
*
|
||||||
|
* Then each tree is divided by row:
|
||||||
|
*
|
||||||
|
* [ tree 0 ] [ tree 1 ] ...
|
||||||
|
* [ row 0 ] [ row 1 ] ... [ row n-1 ] [ row 0 ] ...
|
||||||
|
*
|
||||||
|
* Finally, each row is divided by the node id:
|
||||||
|
*
|
||||||
|
* [ tree 0 ]
|
||||||
|
* [ row 0 ] [ row 1 ] ...
|
||||||
|
* [ node 0 ] [ node 1 ] ... [ node n-1 ] [ node 0 ] ...
|
||||||
|
*
|
||||||
|
* The first two dimensions are fixed length, while the last dimension is variable length since
|
||||||
|
* each tree may have a different number of nodes. We precompute the tree offsets, which are the
|
||||||
|
* cumulative sums of tree sizes. The index of tree t, row r, node n is:
|
||||||
|
* index(t, r, n) = tree_offsets[t] * n_rows + r * tree_sizes[t] + n
|
||||||
|
*/
|
||||||
|
std::vector<BitVector::value_type> decision_storage_{};
|
||||||
|
BitVector decision_bits_{};
|
||||||
|
/**
|
||||||
|
* @brief Stores whether the feature is missing for each split node.
|
||||||
|
*
|
||||||
|
* See above for the storage layout.
|
||||||
|
*/
|
||||||
|
std::vector<BitVector::value_type> missing_storage_{};
|
||||||
|
BitVector missing_bits_{};
|
||||||
|
};
|
||||||
|
|
||||||
|
class CPUPredictor : public Predictor {
|
||||||
|
protected:
|
||||||
void PredictGHistIndex(DMatrix *p_fmat, gbm::GBTreeModel const &model, int32_t tree_begin,
|
void PredictGHistIndex(DMatrix *p_fmat, gbm::GBTreeModel const &model, int32_t tree_begin,
|
||||||
int32_t tree_end, std::vector<bst_float> *out_preds) const {
|
int32_t tree_end, std::vector<bst_float> *out_preds) const {
|
||||||
auto const n_threads = this->ctx_->Threads();
|
auto const n_threads = this->ctx_->Threads();
|
||||||
@ -323,6 +583,12 @@ class CPUPredictor : public Predictor {
|
|||||||
|
|
||||||
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
|
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
|
||||||
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const {
|
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const {
|
||||||
|
if (p_fmat->Info().data_split_mode == DataSplitMode::kCol) {
|
||||||
|
ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
|
||||||
|
helper.PredictDMatrix(p_fmat, out_preds);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (!p_fmat->PageExists<SparsePage>()) {
|
if (!p_fmat->PageExists<SparsePage>()) {
|
||||||
this->PredictGHistIndex(p_fmat, model, tree_begin, tree_end, out_preds);
|
this->PredictGHistIndex(p_fmat, model, tree_begin, tree_end, out_preds);
|
||||||
return;
|
return;
|
||||||
|
|||||||
@ -12,14 +12,17 @@ namespace collective {
|
|||||||
TEST(CommunicatorFactory, TypeFromEnv) {
|
TEST(CommunicatorFactory, TypeFromEnv) {
|
||||||
EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromEnv());
|
EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromEnv());
|
||||||
|
|
||||||
|
dmlc::SetEnv<std::string>("XGBOOST_COMMUNICATOR", "foo");
|
||||||
|
EXPECT_THROW(Communicator::GetTypeFromEnv(), dmlc::Error);
|
||||||
|
|
||||||
dmlc::SetEnv<std::string>("XGBOOST_COMMUNICATOR", "rabit");
|
dmlc::SetEnv<std::string>("XGBOOST_COMMUNICATOR", "rabit");
|
||||||
EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromEnv());
|
EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromEnv());
|
||||||
|
|
||||||
dmlc::SetEnv<std::string>("XGBOOST_COMMUNICATOR", "Federated");
|
dmlc::SetEnv<std::string>("XGBOOST_COMMUNICATOR", "Federated");
|
||||||
EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromEnv());
|
EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromEnv());
|
||||||
|
|
||||||
dmlc::SetEnv<std::string>("XGBOOST_COMMUNICATOR", "foo");
|
dmlc::SetEnv<std::string>("XGBOOST_COMMUNICATOR", "In-Memory");
|
||||||
EXPECT_THROW(Communicator::GetTypeFromEnv(), dmlc::Error);
|
EXPECT_EQ(CommunicatorType::kInMemory, Communicator::GetTypeFromEnv());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CommunicatorFactory, TypeFromArgs) {
|
TEST(CommunicatorFactory, TypeFromArgs) {
|
||||||
@ -32,6 +35,9 @@ TEST(CommunicatorFactory, TypeFromArgs) {
|
|||||||
config["xgboost_communicator"] = String("federated");
|
config["xgboost_communicator"] = String("federated");
|
||||||
EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config));
|
EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config));
|
||||||
|
|
||||||
|
config["xgboost_communicator"] = String("in-memory");
|
||||||
|
EXPECT_EQ(CommunicatorType::kInMemory, Communicator::GetTypeFromConfig(config));
|
||||||
|
|
||||||
config["xgboost_communicator"] = String("foo");
|
config["xgboost_communicator"] = String("foo");
|
||||||
EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error);
|
EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error);
|
||||||
}
|
}
|
||||||
@ -46,6 +52,9 @@ TEST(CommunicatorFactory, TypeFromArgsUpperCase) {
|
|||||||
config["XGBOOST_COMMUNICATOR"] = String("federated");
|
config["XGBOOST_COMMUNICATOR"] = String("federated");
|
||||||
EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config));
|
EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config));
|
||||||
|
|
||||||
|
config["XGBOOST_COMMUNICATOR"] = String("in-memory");
|
||||||
|
EXPECT_EQ(CommunicatorType::kInMemory, Communicator::GetTypeFromConfig(config));
|
||||||
|
|
||||||
config["XGBOOST_COMMUNICATOR"] = String("foo");
|
config["XGBOOST_COMMUNICATOR"] = String("foo");
|
||||||
EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error);
|
EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,6 +4,9 @@
|
|||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/predictor.h>
|
#include <xgboost/predictor.h>
|
||||||
|
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
#include "../../../src/collective/communicator-inl.h"
|
||||||
#include "../../../src/data/adapter.h"
|
#include "../../../src/data/adapter.h"
|
||||||
#include "../../../src/data/proxy_dmatrix.h"
|
#include "../../../src/data/proxy_dmatrix.h"
|
||||||
#include "../../../src/gbm/gbtree.h"
|
#include "../../../src/gbm/gbtree.h"
|
||||||
@ -86,6 +89,49 @@ TEST(CpuPredictor, Basic) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Runs column-split prediction with two in-memory workers, one thread per rank. Each worker
// predicts on its own column slice of a shared DMatrix, and every prediction must be 1.5.
TEST(CpuPredictor, ColumnSplit) {
  size_t constexpr kRows = 5;
  size_t constexpr kCols = 5;
  auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();

  size_t constexpr kWorldSize = 2;
  // Rounded up so that the slices of all ranks together cover every column.
  // NOTE(review): with kCols = 5 the slice of rank 1 is [3, 6), one past the last column --
  // confirm SliceCol tolerates a range that extends beyond the available columns.
  size_t constexpr kSliceSize = (kCols + 1) / kWorldSize;

  std::vector<std::thread> threads;
  threads.reserve(kWorldSize);
  // `rank` stays a plain int so the value stored in the JSON config keeps its original type.
  for (auto rank = 0; rank < static_cast<int>(kWorldSize); rank++) {
    threads.emplace_back([=, &dmat]() {
      Json config{JsonObject()};
      config["xgboost_communicator"] = String("in-memory");
      config["in_memory_world_size"] = kWorldSize;
      config["in_memory_rank"] = rank;
      xgboost::collective::Init(config);

      auto lparam = CreateEmptyGenericParam(GPUIDX);
      std::unique_ptr<Predictor> cpu_predictor =
          std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam));

      LearnerModelParam mparam{MakeMP(kCols, .0, 1)};

      Context ctx;
      ctx.UpdateAllowUnknown(Args{});
      gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx);

      // Test predict batch
      PredictionCacheEntry out_predictions;
      cpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model);
      auto sliced = std::unique_ptr<DMatrix>{dmat->SliceCol(rank * kSliceSize, kSliceSize)};
      cpu_predictor->PredictBatch(sliced.get(), &out_predictions, model, 0);

      std::vector<float>& out_predictions_h = out_predictions.predictions.HostVector();
      for (size_t i = 0; i < out_predictions.predictions.Size(); i++) {
        // Non-fatal on purpose: a fatal assertion returns from the lambda immediately and would
        // skip the Finalize() call below.
        EXPECT_EQ(out_predictions_h[i], 1.5);
      }
      xgboost::collective::Finalize();
    });
  }
  for (auto& thread : threads) {
    thread.join();
  }
}
|
||||||
|
|
||||||
TEST(CpuPredictor, IterationRange) {
|
TEST(CpuPredictor, IterationRange) {
|
||||||
TestIterationRange("cpu_predictor");
|
TestIterationRange("cpu_predictor");
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user