Support cpu quantile sketch with column-wise data split (#8742)

This commit is contained in:
Rong Ou 2023-02-04 22:26:24 -08:00 committed by GitHub
parent c1786849e3
commit 66191e9926
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 250 additions and 118 deletions

View File

@ -627,11 +627,11 @@ class DMatrix {
/**
* \brief Slice a DMatrix by columns.
*
* @param start The position of the first column
* @param size The number of columns in the slice
* @param num_slices Total number of slices
* @param slice_id Index of the current slice
* @return DMatrix containing the slice of columns
*/
virtual DMatrix *SliceCol(std::size_t start, std::size_t size) = 0;
virtual DMatrix *SliceCol(int num_slices, int slice_id) = 0;
protected:
virtual BatchSet<SparsePage> GetRowBatches() = 0;

View File

@ -45,14 +45,16 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
if (!use_sorted) {
HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info), n_threads);
HostSketchContainer::UseGroup(info),
m->Info().data_split_mode == DataSplitMode::kCol, n_threads);
for (auto const& page : m->GetBatches<SparsePage>()) {
container.PushRowPage(page, info, hessian);
}
container.MakeCuts(&out);
} else {
SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info), n_threads};
HostSketchContainer::UseGroup(info),
m->Info().data_split_mode == DataSplitMode::kCol, n_threads};
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
container.PushColPage(page, info, hessian);
}

View File

@ -18,11 +18,13 @@ template <typename WQSketch>
SketchContainerImpl<WQSketch>::SketchContainerImpl(std::vector<bst_row_t> columns_size,
int32_t max_bins,
Span<FeatureType const> feature_types,
bool use_group, int32_t n_threads)
bool use_group, bool col_split,
int32_t n_threads)
: feature_types_(feature_types.cbegin(), feature_types.cend()),
columns_size_{std::move(columns_size)},
max_bins_{max_bins},
use_group_ind_{use_group},
col_split_{col_split},
n_threads_{n_threads} {
monitor_.Init(__func__);
CHECK_NE(columns_size_.size(), 0);
@ -137,80 +139,6 @@ struct QuantileAllreduce {
return worker_values.subspan(feat_beg, feat_size);
}
};
/**
* \brief Merge all categories from other workers.
*/
void AllreduceCategories(Span<FeatureType const> feature_types, int32_t n_threads,
std::vector<std::set<float>> *p_categories) {
auto &categories = *p_categories;
auto world_size = collective::GetWorldSize();
auto rank = collective::GetRank();
if (world_size == 1) {
return;
}
// CSC indptr to each feature
std::vector<size_t> feature_ptr(categories.size() + 1, 0);
for (size_t i = 0; i < categories.size(); ++i) {
auto const &feat = categories[i];
feature_ptr[i + 1] = feat.size();
}
std::partial_sum(feature_ptr.begin(), feature_ptr.end(), feature_ptr.begin());
CHECK_EQ(feature_ptr.front(), 0);
// gather all feature ptrs from workers
std::vector<size_t> global_feat_ptrs(feature_ptr.size() * world_size, 0);
size_t feat_begin = rank * feature_ptr.size(); // pointer to current worker
std::copy(feature_ptr.begin(), feature_ptr.end(), global_feat_ptrs.begin() + feat_begin);
collective::Allreduce<collective::Operation::kSum>(global_feat_ptrs.data(),
global_feat_ptrs.size());
// move all categories into a flatten vector to prepare for allreduce
size_t total = feature_ptr.back();
std::vector<float> flatten(total, 0);
auto cursor{flatten.begin()};
for (auto const &feat : categories) {
cursor = std::copy(feat.cbegin(), feat.cend(), cursor);
}
// indptr for indexing workers
std::vector<size_t> global_worker_ptr(world_size + 1, 0);
global_worker_ptr[rank + 1] = total; // shift 1 to right for constructing the indptr
collective::Allreduce<collective::Operation::kSum>(global_worker_ptr.data(),
global_worker_ptr.size());
std::partial_sum(global_worker_ptr.cbegin(), global_worker_ptr.cend(), global_worker_ptr.begin());
// total number of categories in all workers with all features
auto gtotal = global_worker_ptr.back();
// categories in all workers with all features.
std::vector<float> global_categories(gtotal, 0);
auto rank_begin = global_worker_ptr[rank];
auto rank_size = global_worker_ptr[rank + 1] - rank_begin;
CHECK_EQ(rank_size, total);
std::copy(flatten.cbegin(), flatten.cend(), global_categories.begin() + rank_begin);
// gather values from all workers.
collective::Allreduce<collective::Operation::kSum>(global_categories.data(),
global_categories.size());
QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
categories.size()};
ParallelFor(categories.size(), n_threads, [&](auto fidx) {
if (!IsCat(feature_types, fidx)) {
return;
}
for (int32_t r = 0; r < world_size; ++r) {
if (r == rank) {
// continue if it's current worker.
continue;
}
// 1 feature of 1 worker
auto worker_feature = allreduce_result.Values(r, fidx);
for (auto c : worker_feature) {
categories[fidx].emplace(c);
}
}
});
}
} // anonymous namespace
template <typename WQSketch>
@ -273,6 +201,76 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float));
}
template <typename WQSketch>
void SketchContainerImpl<WQSketch>::AllreduceCategories() {
auto world_size = collective::GetWorldSize();
auto rank = collective::GetRank();
if (world_size == 1 || col_split_) {
return;
}
// CSC indptr to each feature
std::vector<size_t> feature_ptr(categories_.size() + 1, 0);
for (size_t i = 0; i < categories_.size(); ++i) {
auto const &feat = categories_[i];
feature_ptr[i + 1] = feat.size();
}
std::partial_sum(feature_ptr.begin(), feature_ptr.end(), feature_ptr.begin());
CHECK_EQ(feature_ptr.front(), 0);
// gather all feature ptrs from workers
std::vector<size_t> global_feat_ptrs(feature_ptr.size() * world_size, 0);
size_t feat_begin = rank * feature_ptr.size(); // pointer to current worker
std::copy(feature_ptr.begin(), feature_ptr.end(), global_feat_ptrs.begin() + feat_begin);
collective::Allreduce<collective::Operation::kSum>(global_feat_ptrs.data(),
global_feat_ptrs.size());
// move all categories into a flatten vector to prepare for allreduce
size_t total = feature_ptr.back();
std::vector<float> flatten(total, 0);
auto cursor{flatten.begin()};
for (auto const &feat : categories_) {
cursor = std::copy(feat.cbegin(), feat.cend(), cursor);
}
// indptr for indexing workers
std::vector<size_t> global_worker_ptr(world_size + 1, 0);
global_worker_ptr[rank + 1] = total; // shift 1 to right for constructing the indptr
collective::Allreduce<collective::Operation::kSum>(global_worker_ptr.data(),
global_worker_ptr.size());
std::partial_sum(global_worker_ptr.cbegin(), global_worker_ptr.cend(), global_worker_ptr.begin());
// total number of categories in all workers with all features
auto gtotal = global_worker_ptr.back();
// categories in all workers with all features.
std::vector<float> global_categories(gtotal, 0);
auto rank_begin = global_worker_ptr[rank];
auto rank_size = global_worker_ptr[rank + 1] - rank_begin;
CHECK_EQ(rank_size, total);
std::copy(flatten.cbegin(), flatten.cend(), global_categories.begin() + rank_begin);
// gather values from all workers.
collective::Allreduce<collective::Operation::kSum>(global_categories.data(),
global_categories.size());
QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
categories_.size()};
ParallelFor(categories_.size(), n_threads_, [&](auto fidx) {
if (!IsCat(feature_types_, fidx)) {
return;
}
for (int32_t r = 0; r < world_size; ++r) {
if (r == rank) {
// continue if it's current worker.
continue;
}
// 1 feature of 1 worker
auto worker_feature = allreduce_result.Values(r, fidx);
for (auto c : worker_feature) {
categories_[fidx].emplace(c);
}
}
});
}
template <typename WQSketch>
void SketchContainerImpl<WQSketch>::AllReduce(
std::vector<typename WQSketch::SummaryContainer> *p_reduced,
@ -283,7 +281,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
collective::Allreduce<collective::Operation::kMax>(&n_columns, 1);
CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers";
AllreduceCategories(feature_types_, n_threads_, &categories_);
AllreduceCategories();
auto& num_cuts = *p_num_cuts;
CHECK_EQ(num_cuts.size(), 0);
@ -294,8 +292,10 @@ void SketchContainerImpl<WQSketch>::AllReduce(
// Prune the intermediate num cuts for synchronization.
std::vector<bst_row_t> global_column_size(columns_size_);
collective::Allreduce<collective::Operation::kSum>(global_column_size.data(),
global_column_size.size());
if (!col_split_) {
collective::Allreduce<collective::Operation::kSum>(global_column_size.data(),
global_column_size.size());
}
ParallelFor(sketches_.size(), n_threads_, [&](size_t i) {
int32_t intermediate_num_cuts = static_cast<int32_t>(
@ -316,7 +316,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
});
auto world = collective::GetWorldSize();
if (world == 1) {
if (world == 1 || col_split_) {
monitor_.Stop(__func__);
return;
}
@ -442,8 +442,8 @@ template class SketchContainerImpl<WXQuantileSketch<float, float>>;
HostSketchContainer::HostSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
std::vector<size_t> columns_size, bool use_group,
int32_t n_threads)
: SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} {
bool col_split, int32_t n_threads)
: SketchContainerImpl{columns_size, max_bins, ft, use_group, col_split, n_threads} {
monitor_.Init(__func__);
ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) {
auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);

View File

@ -802,6 +802,7 @@ class SketchContainerImpl {
std::vector<bst_row_t> columns_size_;
int32_t max_bins_;
bool use_group_ind_{false};
bool col_split_;
int32_t n_threads_;
bool has_categorical_{false};
Monitor monitor_;
@ -814,7 +815,7 @@ class SketchContainerImpl {
* \param use_group whether is assigned to group to data instance.
*/
SketchContainerImpl(std::vector<bst_row_t> columns_size, int32_t max_bins,
common::Span<FeatureType const> feature_types, bool use_group,
common::Span<FeatureType const> feature_types, bool use_group, bool col_split,
int32_t n_threads);
static bool UseGroup(MetaInfo const &info) {
@ -896,6 +897,10 @@ class SketchContainerImpl {
void PushRowPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian = {});
void MakeCuts(HistogramCuts* cuts);
private:
// Merge all categories from other workers.
void AllreduceCategories();
};
class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, float>> {
@ -904,7 +909,8 @@ class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, fl
public:
HostSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
std::vector<size_t> columns_size, bool use_group, int32_t n_threads);
std::vector<size_t> columns_size, bool use_group, bool col_split,
int32_t n_threads);
template <typename Batch>
void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, float missing);
@ -1000,9 +1006,9 @@ class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float,
public:
explicit SortedSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
std::vector<size_t> columns_size, bool use_group,
std::vector<size_t> columns_size, bool use_group, bool col_split,
int32_t n_threads)
: SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} {
: SketchContainerImpl{columns_size, max_bins, ft, use_group, col_split, n_threads} {
monitor_.Init(__func__);
sketches_.resize(columns_size.size());
size_t i = 0;

View File

@ -897,10 +897,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
if (!cache_file.empty()) {
LOG(FATAL) << "Column-wise data split is not support for external memory.";
}
auto slice_cols = (dmat->Info().num_col_ + 1) / npart;
auto slice_start = slice_cols * partid;
auto size = std::min(slice_cols, dmat->Info().num_col_ - slice_start);
auto* sliced = dmat->SliceCol(slice_start, size);
auto* sliced = dmat->SliceCol(npart, partid);
delete dmat;
return sliced;
} else {

View File

@ -172,9 +172,9 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
size_t i = 0;
while (iter.Next()) {
if (!p_sketch) {
p_sketch.reset(new common::HostSketchContainer{batch_param_.max_bin,
proxy->Info().feature_types.ConstHostSpan(),
column_sizes, false, ctx_.Threads()});
p_sketch.reset(new common::HostSketchContainer{
batch_param_.max_bin, proxy->Info().feature_types.ConstHostSpan(), column_sizes, false,
proxy->Info().data_split_mode == DataSplitMode::kCol, ctx_.Threads()});
}
HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i];

View File

@ -86,7 +86,7 @@ class IterativeDMatrix : public DMatrix {
LOG(FATAL) << "Slicing DMatrix is not supported for Quantile DMatrix.";
return nullptr;
}
DMatrix *SliceCol(std::size_t, std::size_t) override {
DMatrix *SliceCol(int num_slices, int slice_id) override {
LOG(FATAL) << "Slicing DMatrix columns is not supported for Quantile DMatrix.";
return nullptr;
}

View File

@ -87,7 +87,7 @@ class DMatrixProxy : public DMatrix {
LOG(FATAL) << "Slicing DMatrix is not supported for Proxy DMatrix.";
return nullptr;
}
DMatrix* SliceCol(std::size_t, std::size_t) override {
DMatrix* SliceCol(int num_slices, int slice_id) override {
LOG(FATAL) << "Slicing DMatrix columns is not supported for Proxy DMatrix.";
return nullptr;
}

View File

@ -46,9 +46,12 @@ DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
return out;
}
DMatrix* SimpleDMatrix::SliceCol(std::size_t start, std::size_t size) {
DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
auto out = new SimpleDMatrix;
SparsePage& out_page = *out->sparse_page_;
auto const slice_size = info_.num_col_ / num_slices;
auto const slice_start = slice_size * slice_id;
auto const slice_end = (slice_id == num_slices - 1) ? info_.num_col_ : slice_start + slice_size;
for (auto const &page : this->GetBatches<SparsePage>()) {
auto batch = page.GetView();
auto& h_data = out_page.data.HostVector();
@ -58,7 +61,7 @@ DMatrix* SimpleDMatrix::SliceCol(std::size_t start, std::size_t size) {
auto inst = batch[i];
auto prev_size = h_data.size();
std::copy_if(inst.begin(), inst.end(), std::back_inserter(h_data), [&](Entry e) {
return e.index >= start && e.index < start + size;
return e.index >= slice_start && e.index < slice_end;
});
rptr += h_data.size() - prev_size;
h_offset.emplace_back(rptr);

View File

@ -35,7 +35,7 @@ class SimpleDMatrix : public DMatrix {
bool SingleColBlock() const override { return true; }
DMatrix* Slice(common::Span<int32_t const> ridxs) override;
DMatrix* SliceCol(std::size_t start, std::size_t size) override;
DMatrix* SliceCol(int num_slices, int slice_id) override;
/*! \brief magic number used to identify SimpleDMatrix binary files */
static const int kMagic = 0xffffab01;

View File

@ -107,7 +107,7 @@ class SparsePageDMatrix : public DMatrix {
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
return nullptr;
}
DMatrix *SliceCol(std::size_t, std::size_t) override {
DMatrix *SliceCol(int num_slices, int slice_id) override {
LOG(FATAL) << "Slicing DMatrix columns is not supported for external memory.";
return nullptr;
}

View File

@ -6,7 +6,6 @@
#include <gtest/gtest.h>
#include "../../../src/common/hist_util.h"
#include "../../../src/common/quantile.h"
#include "../../../src/data/adapter.h"
#include "xgboost/context.h"
@ -74,7 +73,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
auto hess = Span<float const>{hessian};
ContainerType<use_column> sketch_distributed(n_bins, m->Info().feature_types.ConstHostSpan(),
column_size, false, AllThreadsForTest());
column_size, false, false, AllThreadsForTest());
if (use_column) {
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
@ -95,7 +94,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
m->Info().num_row_ = world * rows;
ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
column_size, false, AllThreadsForTest());
column_size, false, false, AllThreadsForTest());
m->Info().num_row_ = rows;
for (auto rank = 0; rank < world; ++rank) {
@ -170,6 +169,132 @@ TEST(Quantile, SortedDistributed) {
TestDistributedQuantile<true>(kRows, kCols);
}
namespace {
template <bool use_column>
void DoTestColSplitQuantile(size_t rows, size_t cols) {
auto const world = collective::GetWorldSize();
auto const rank = collective::GetRank();
auto m = std::unique_ptr<DMatrix>{[=]() {
auto sparsity = 0.5f;
std::vector<FeatureType> ft(cols);
for (size_t i = 0; i < ft.size(); ++i) {
ft[i] = (i % 2 == 0) ? FeatureType::kNumerical : FeatureType::kCategorical;
}
auto dmat = RandomDataGenerator{rows, cols, sparsity}
.Seed(0)
.Lower(.0f)
.Upper(1.0f)
.Type(ft)
.MaxCategory(13)
.GenerateDMatrix();
return dmat->SliceCol(world, rank);
}()};
std::vector<bst_row_t> column_size(cols, 0);
auto const slice_size = cols / world;
auto const slice_start = slice_size * rank;
auto const slice_end = (rank == world - 1) ? cols : slice_start + slice_size;
for (auto i = slice_start; i < slice_end; i++) {
column_size[i] = rows;
}
auto const n_bins = 64;
// Generate cuts for distributed environment.
HistogramCuts distributed_cuts;
{
ContainerType<use_column> sketch_distributed(n_bins, m->Info().feature_types.ConstHostSpan(),
column_size, false, true, AllThreadsForTest());
std::vector<float> hessian(rows, 1.0);
auto hess = Span<float const>{hessian};
if (use_column) {
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
PushPage(&sketch_distributed, page, m->Info(), hess);
}
} else {
for (auto const& page : m->GetBatches<SparsePage>()) {
PushPage(&sketch_distributed, page, m->Info(), hess);
}
}
sketch_distributed.MakeCuts(&distributed_cuts);
}
// Generate cuts for single node environment
collective::Finalize();
CHECK_EQ(collective::GetWorldSize(), 1);
HistogramCuts single_node_cuts;
{
ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
column_size, false, false, AllThreadsForTest());
std::vector<float> hessian(rows, 1.0);
auto hess = Span<float const>{hessian};
if (use_column) {
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
PushPage(&sketch_on_single_node, page, m->Info(), hess);
}
} else {
for (auto const& page : m->GetBatches<SparsePage>()) {
PushPage(&sketch_on_single_node, page, m->Info(), hess);
}
}
sketch_on_single_node.MakeCuts(&single_node_cuts);
}
auto const& sptrs = single_node_cuts.Ptrs();
auto const& dptrs = distributed_cuts.Ptrs();
auto const& svals = single_node_cuts.Values();
auto const& dvals = distributed_cuts.Values();
auto const& smins = single_node_cuts.MinValues();
auto const& dmins = distributed_cuts.MinValues();
EXPECT_EQ(sptrs.size(), dptrs.size());
for (size_t i = 0; i < sptrs.size(); ++i) {
EXPECT_EQ(sptrs[i], dptrs[i]) << "rank: " << rank << ", i: " << i;
}
EXPECT_EQ(svals.size(), dvals.size());
for (size_t i = 0; i < svals.size(); ++i) {
EXPECT_NEAR(svals[i], dvals[i], 2e-2f) << "rank: " << rank << ", i: " << i;
}
EXPECT_EQ(smins.size(), dmins.size());
for (size_t i = 0; i < smins.size(); ++i) {
EXPECT_FLOAT_EQ(smins[i], dmins[i]) << "rank: " << rank << ", i: " << i;
}
}
template <bool use_column>
void TestColSplitQuantile(size_t rows, size_t cols) {
auto constexpr kWorkers = 4;
RunWithInMemoryCommunicator(kWorkers, DoTestColSplitQuantile<use_column>, rows, cols);
}
} // anonymous namespace
TEST(Quantile, ColSplitBasic) {
constexpr size_t kRows = 10, kCols = 10;
TestColSplitQuantile<false>(kRows, kCols);
}
TEST(Quantile, ColSplit) {
constexpr size_t kRows = 4000, kCols = 200;
TestColSplitQuantile<false>(kRows, kCols);
}
TEST(Quantile, ColSplitSortedBasic) {
constexpr size_t kRows = 10, kCols = 10;
TestColSplitQuantile<true>(kRows, kCols);
}
TEST(Quantile, ColSplitSorted) {
constexpr size_t kRows = 4000, kCols = 200;
TestColSplitQuantile<true>(kRows, kCols);
}
namespace {
void TestSameOnAllWorkers() {
auto const world = collective::GetWorldSize();
@ -222,17 +347,17 @@ void TestSameOnAllWorkers() {
for (int32_t i = 0; i < world; i++) {
for (size_t j = 0; j < value_size; ++j) {
size_t idx = i * value_size + j;
ASSERT_NEAR(cuts.Values().at(j), cut_values.at(idx), kRtEps);
EXPECT_NEAR(cuts.Values().at(j), cut_values.at(idx), kRtEps);
}
for (size_t j = 0; j < ptr_size; ++j) {
size_t idx = i * ptr_size + j;
ASSERT_EQ(cuts.Ptrs().at(j), cut_ptrs.at(idx));
EXPECT_EQ(cuts.Ptrs().at(j), cut_ptrs.at(idx));
}
for (size_t j = 0; j < min_value_size; ++j) {
size_t idx = i * min_value_size + j;
ASSERT_EQ(cuts.MinValues().at(j), cut_min_values.at(idx));
EXPECT_EQ(cuts.MinValues().at(j), cut_min_values.at(idx));
}
}
});

View File

@ -6,7 +6,6 @@
#include <vector>
#include "../helpers.h"
#include "../../src/collective/communicator-inl.h"
namespace xgboost {
namespace common {

View File

@ -338,10 +338,10 @@ TEST(SimpleDMatrix, SliceCol) {
auto& margin = p_m->Info().base_margin_;
margin = decltype(p_m->Info().base_margin_){{kRows, kClasses}, Context::kCpuId};
size_t constexpr kSlicCols {4};
for (auto slice = 0; slice < 2; slice++) {
auto const slice_start = slice * kSlicCols;
std::unique_ptr<DMatrix> out { p_m->SliceCol(slice_start, kSlicCols) };
auto constexpr kSlices {2};
auto constexpr kSliceSize {4};
for (auto slice = 0; slice < kSlices; slice++) {
std::unique_ptr<DMatrix> out { p_m->SliceCol(kSlices, slice) };
ASSERT_EQ(out->Info().labels.Size(), kRows);
ASSERT_EQ(out->Info().labels_lower_bound_.Size(), kRows);
ASSERT_EQ(out->Info().labels_upper_bound_.Size(), kRows);
@ -355,7 +355,8 @@ TEST(SimpleDMatrix, SliceCol) {
auto out_inst = out_page[i];
auto in_inst = in_page[i];
ASSERT_EQ(out_inst.size() * 2, in_inst.size()) << i;
for (size_t j = 0; j < kSlicCols; ++j) {
for (size_t j = 0; j < kSliceSize; ++j) {
auto const slice_start = kSliceSize * slice;
ASSERT_EQ(in_inst[slice_start + j].fvalue, out_inst[j].fvalue);
ASSERT_EQ(in_inst[slice_start + j].index, out_inst[j].index);
}
@ -377,7 +378,7 @@ TEST(SimpleDMatrix, SliceCol) {
ASSERT_EQ(out->Info().num_col_, out->Info().num_col_);
ASSERT_EQ(out->Info().num_row_, kRows);
ASSERT_EQ(out->Info().num_nonzero_, kRows * kSlicCols); // dense
ASSERT_EQ(out->Info().num_nonzero_, kRows * kSliceSize); // dense
ASSERT_EQ(out->Info().data_split_mode, DataSplitMode::kCol);
}
}

View File

@ -97,7 +97,6 @@ void TestColumnSplitPredictBatch() {
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
auto const kSliceSize = (kCols + 1) / world_size;
auto lparam = CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<Predictor> cpu_predictor =
@ -112,7 +111,7 @@ void TestColumnSplitPredictBatch() {
// Test predict batch
PredictionCacheEntry out_predictions;
cpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model);
auto sliced = std::unique_ptr<DMatrix>{dmat->SliceCol(rank * kSliceSize, kSliceSize)};
auto sliced = std::unique_ptr<DMatrix>{dmat->SliceCol(world_size, rank)};
cpu_predictor->PredictBatch(sliced.get(), &out_predictions, model, 0);
std::vector<float>& out_predictions_h = out_predictions.predictions.HostVector();