From 66191e9926b68de2989c292c587a04383e084aae Mon Sep 17 00:00:00 2001
From: Rong Ou
Date: Sat, 4 Feb 2023 22:26:24 -0800
Subject: [PATCH] Support cpu quantile sketch with column-wise data split
 (#8742)

---
 include/xgboost/data.h                    |   6 +-
 src/common/hist_util.cc                   |   6 +-
 src/common/quantile.cc                    | 162 +++++++++++-----------
 src/common/quantile.h                     |  14 +-
 src/data/data.cc                          |   5 +-
 src/data/iterative_dmatrix.cc             |   6 +-
 src/data/iterative_dmatrix.h              |   2 +-
 src/data/proxy_dmatrix.h                  |   2 +-
 src/data/simple_dmatrix.cc                |   7 +-
 src/data/simple_dmatrix.h                 |   2 +-
 src/data/sparse_page_dmatrix.h            |   2 +-
 tests/cpp/common/test_quantile.cc         | 137 +++++++++++++++++-
 tests/cpp/common/test_quantile.h          |   1 -
 tests/cpp/data/test_simple_dmatrix.cc     |  13 +-
 tests/cpp/predictor/test_cpu_predictor.cc |   3 +-
 15 files changed, 250 insertions(+), 118 deletions(-)

diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index e243e4219..9411fcfab 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -627,11 +627,11 @@ class DMatrix {
   /**
    * \brief Slice a DMatrix by columns.
    *
-   * @param start The position of the first column
-   * @param size The number of columns in the slice
+   * @param num_slices Total number of slices
+   * @param slice_id Index of the current slice
    * @return DMatrix containing the slice of columns
    */
-  virtual DMatrix *SliceCol(std::size_t start, std::size_t size) = 0;
+  virtual DMatrix *SliceCol(int num_slices, int slice_id) = 0;
 
  protected:
   virtual BatchSet<SparsePage> GetRowBatches() = 0;
diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc
index f34b6f7a2..3b4d42a8d 100644
--- a/src/common/hist_util.cc
+++ b/src/common/hist_util.cc
@@ -45,14 +45,16 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
   if (!use_sorted) {
     HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
-                                  HostSketchContainer::UseGroup(info), n_threads);
+                                  HostSketchContainer::UseGroup(info),
+                                  m->Info().data_split_mode == DataSplitMode::kCol, n_threads);
     for (auto const& page : m->GetBatches<SparsePage>()) {
       container.PushRowPage(page, info, hessian);
     }
     container.MakeCuts(&out);
   } else {
     SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
-                                    HostSketchContainer::UseGroup(info), n_threads};
+                                    HostSketchContainer::UseGroup(info),
+                                    m->Info().data_split_mode == DataSplitMode::kCol, n_threads};
     for (auto const& page : m->GetBatches<SortedCSCPage>()) {
       container.PushColPage(page, info, hessian);
     }
diff --git a/src/common/quantile.cc b/src/common/quantile.cc
index 3f3bb5326..87eb0ec20 100644
--- a/src/common/quantile.cc
+++ b/src/common/quantile.cc
@@ -18,11 +18,13 @@ template <typename WQSketch>
 SketchContainerImpl<WQSketch>::SketchContainerImpl(std::vector<bst_row_t> columns_size,
                                                    int32_t max_bins,
                                                    Span<FeatureType const> feature_types,
-                                                   bool use_group, int32_t n_threads)
+                                                   bool use_group, bool col_split,
+                                                   int32_t n_threads)
     : feature_types_(feature_types.cbegin(), feature_types.cend()),
       columns_size_{std::move(columns_size)},
       max_bins_{max_bins},
       use_group_ind_{use_group},
+      col_split_{col_split},
       n_threads_{n_threads} {
   monitor_.Init(__func__);
   CHECK_NE(columns_size_.size(), 0);
@@ -137,80 +139,6 @@ struct QuantileAllreduce {
     return worker_values.subspan(feat_beg, feat_size);
   }
 };
-
-/**
- * \brief Merge all categories from other workers.
- */
-void AllreduceCategories(Span<FeatureType const> feature_types, int32_t n_threads,
-                         std::vector<std::set<float>> *p_categories) {
-  auto &categories = *p_categories;
-  auto world_size = collective::GetWorldSize();
-  auto rank = collective::GetRank();
-  if (world_size == 1) {
-    return;
-  }
-
-  // CSC indptr to each feature
-  std::vector<size_t> feature_ptr(categories.size() + 1, 0);
-  for (size_t i = 0; i < categories.size(); ++i) {
-    auto const &feat = categories[i];
-    feature_ptr[i + 1] = feat.size();
-  }
-  std::partial_sum(feature_ptr.begin(), feature_ptr.end(), feature_ptr.begin());
-  CHECK_EQ(feature_ptr.front(), 0);
-
-  // gather all feature ptrs from workers
-  std::vector<size_t> global_feat_ptrs(feature_ptr.size() * world_size, 0);
-  size_t feat_begin = rank * feature_ptr.size();  // pointer to current worker
-  std::copy(feature_ptr.begin(), feature_ptr.end(), global_feat_ptrs.begin() + feat_begin);
-  collective::Allreduce<collective::Operation::kSum>(global_feat_ptrs.data(),
-                                                     global_feat_ptrs.size());
-
-  // move all categories into a flatten vector to prepare for allreduce
-  size_t total = feature_ptr.back();
-  std::vector<float> flatten(total, 0);
-  auto cursor{flatten.begin()};
-  for (auto const &feat : categories) {
-    cursor = std::copy(feat.cbegin(), feat.cend(), cursor);
-  }
-
-  // indptr for indexing workers
-  std::vector<size_t> global_worker_ptr(world_size + 1, 0);
-  global_worker_ptr[rank + 1] = total;  // shift 1 to right for constructing the indptr
-  collective::Allreduce<collective::Operation::kSum>(global_worker_ptr.data(),
-                                                     global_worker_ptr.size());
-  std::partial_sum(global_worker_ptr.cbegin(), global_worker_ptr.cend(), global_worker_ptr.begin());
-  // total number of categories in all workers with all features
-  auto gtotal = global_worker_ptr.back();
-
-  // categories in all workers with all features.
-  std::vector<float> global_categories(gtotal, 0);
-  auto rank_begin = global_worker_ptr[rank];
-  auto rank_size = global_worker_ptr[rank + 1] - rank_begin;
-  CHECK_EQ(rank_size, total);
-  std::copy(flatten.cbegin(), flatten.cend(), global_categories.begin() + rank_begin);
-  // gather values from all workers.
-  collective::Allreduce<collective::Operation::kSum>(global_categories.data(),
-                                                     global_categories.size());
-  QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
-                                            categories.size()};
-  ParallelFor(categories.size(), n_threads, [&](auto fidx) {
-    if (!IsCat(feature_types, fidx)) {
-      return;
-    }
-    for (int32_t r = 0; r < world_size; ++r) {
-      if (r == rank) {
-        // continue if it's current worker.
-        continue;
-      }
-      // 1 feature of 1 worker
-      auto worker_feature = allreduce_result.Values(r, fidx);
-      for (auto c : worker_feature) {
-        categories[fidx].emplace(c);
-      }
-    }
-  });
-}
 }  // anonymous namespace
 
 template <typename WQSketch>
@@ -273,6 +201,76 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
                        global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float));
 }
 
+template <typename WQSketch>
+void SketchContainerImpl<WQSketch>::AllreduceCategories() {
+  auto world_size = collective::GetWorldSize();
+  auto rank = collective::GetRank();
+  if (world_size == 1 || col_split_) {
+    return;
+  }
+
+  // CSC indptr to each feature
+  std::vector<size_t> feature_ptr(categories_.size() + 1, 0);
+  for (size_t i = 0; i < categories_.size(); ++i) {
+    auto const &feat = categories_[i];
+    feature_ptr[i + 1] = feat.size();
+  }
+  std::partial_sum(feature_ptr.begin(), feature_ptr.end(), feature_ptr.begin());
+  CHECK_EQ(feature_ptr.front(), 0);
+
+  // gather all feature ptrs from workers
+  std::vector<size_t> global_feat_ptrs(feature_ptr.size() * world_size, 0);
+  size_t feat_begin = rank * feature_ptr.size();  // pointer to current worker
+  std::copy(feature_ptr.begin(), feature_ptr.end(), global_feat_ptrs.begin() + feat_begin);
+  collective::Allreduce<collective::Operation::kSum>(global_feat_ptrs.data(),
+                                                     global_feat_ptrs.size());
+
+  // move all categories into a flattened vector to prepare for allreduce
+  size_t total = feature_ptr.back();
+  std::vector<float> flatten(total, 0);
+  auto cursor{flatten.begin()};
+  for (auto const &feat : categories_) {
+    cursor = std::copy(feat.cbegin(), feat.cend(), cursor);
+  }
+
+  // indptr for indexing workers
+  std::vector<size_t> global_worker_ptr(world_size + 1, 0);
+  global_worker_ptr[rank + 1] = total;  // shift 1 to right for constructing the indptr
+  collective::Allreduce<collective::Operation::kSum>(global_worker_ptr.data(),
+                                                     global_worker_ptr.size());
+  std::partial_sum(global_worker_ptr.cbegin(), global_worker_ptr.cend(), global_worker_ptr.begin());
+  // total number of categories in all workers with all features
+  auto gtotal = global_worker_ptr.back();
+
+  // categories in all workers with all features.
+  std::vector<float> global_categories(gtotal, 0);
+  auto rank_begin = global_worker_ptr[rank];
+  auto rank_size = global_worker_ptr[rank + 1] - rank_begin;
+  CHECK_EQ(rank_size, total);
+  std::copy(flatten.cbegin(), flatten.cend(), global_categories.begin() + rank_begin);
+  // gather values from all workers.
+  collective::Allreduce<collective::Operation::kSum>(global_categories.data(),
+                                                     global_categories.size());
+  QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
+                                            categories_.size()};
+  ParallelFor(categories_.size(), n_threads_, [&](auto fidx) {
+    if (!IsCat(feature_types_, fidx)) {
+      return;
+    }
+    for (int32_t r = 0; r < world_size; ++r) {
+      if (r == rank) {
+        // continue if it's current worker.
+        continue;
+      }
+      // 1 feature of 1 worker
+      auto worker_feature = allreduce_result.Values(r, fidx);
+      for (auto c : worker_feature) {
+        categories_[fidx].emplace(c);
+      }
+    }
+  });
+}
+
 template <typename WQSketch>
 void SketchContainerImpl<WQSketch>::AllReduce(
     std::vector<typename WQSketch::SummaryContainer> *p_reduced,
@@ -283,7 +281,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
   collective::Allreduce<collective::Operation::kMax>(&n_columns, 1);
   CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers";
 
-  AllreduceCategories(feature_types_, n_threads_, &categories_);
+  AllreduceCategories();
 
   auto& num_cuts = *p_num_cuts;
   CHECK_EQ(num_cuts.size(), 0);
@@ -294,8 +292,10 @@ void SketchContainerImpl<WQSketch>::AllReduce(
   // Prune the intermediate num cuts for synchronization.
   std::vector<bst_row_t> global_column_size(columns_size_);
-  collective::Allreduce<collective::Operation::kSum>(global_column_size.data(),
-                                                     global_column_size.size());
+  if (!col_split_) {
+    collective::Allreduce<collective::Operation::kSum>(global_column_size.data(),
+                                                       global_column_size.size());
+  }
 
   ParallelFor(sketches_.size(), n_threads_, [&](size_t i) {
     int32_t intermediate_num_cuts = static_cast<int32_t>(
@@ -316,7 +316,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
   });
 
   auto world = collective::GetWorldSize();
-  if (world == 1) {
+  if (world == 1 || col_split_) {
     monitor_.Stop(__func__);
     return;
   }
@@ -442,8 +442,8 @@ template class SketchContainerImpl<WXQuantileSketch<float, float>>;
 
 HostSketchContainer::HostSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
                                          std::vector<bst_row_t> columns_size, bool use_group,
-                                         int32_t n_threads)
-    : SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} {
+                                         bool col_split, int32_t n_threads)
+    : SketchContainerImpl{columns_size, max_bins, ft, use_group, col_split, n_threads} {
   monitor_.Init(__func__);
   ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) {
     auto n_bins = std::min(static_cast<bst_row_t>(max_bins_), columns_size_[i]);
diff --git a/src/common/quantile.h b/src/common/quantile.h
index 27c528e8e..751fb773f 100644
--- a/src/common/quantile.h
+++ b/src/common/quantile.h
@@ -802,6 +802,7 @@ class SketchContainerImpl {
   std::vector<bst_row_t> columns_size_;
   int32_t max_bins_;
   bool use_group_ind_{false};
+  bool col_split_;
   int32_t n_threads_;
   bool has_categorical_{false};
   Monitor monitor_;
@@ -814,7 +815,7 @@ class SketchContainerImpl {
    * \param use_group whether query groups are assigned to the data instances.
    */
   SketchContainerImpl(std::vector<bst_row_t> columns_size, int32_t max_bins,
-                      common::Span<FeatureType const> feature_types, bool use_group,
+                      common::Span<FeatureType const> feature_types, bool use_group, bool col_split,
                       int32_t n_threads);
 
   static bool UseGroup(MetaInfo const &info) {
@@ -896,6 +897,10 @@ class SketchContainerImpl {
   void PushRowPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian = {});
 
   void MakeCuts(HistogramCuts* cuts);
+
+ private:
+  // Merge all categories from other workers.
+  void AllreduceCategories();
 };
 
 class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, float>> {
@@ -904,7 +909,8 @@ class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, float>> {
  public:
   explicit HostSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
-                               std::vector<bst_row_t> columns_size, bool use_group, int32_t n_threads);
+                               std::vector<bst_row_t> columns_size, bool use_group, bool col_split,
+                               int32_t n_threads);
 
   template <typename Batch>
   void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, float missing);
@@ -1000,9 +1006,9 @@ class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float, float>> {
 
  public:
   explicit SortedSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
-                                 std::vector<bst_row_t> columns_size, bool use_group,
+                                 std::vector<bst_row_t> columns_size, bool use_group, bool col_split,
                                  int32_t n_threads)
-      : SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} {
+      : SketchContainerImpl{columns_size, max_bins, ft, use_group, col_split, n_threads} {
     monitor_.Init(__func__);
     sketches_.resize(columns_size.size());
     size_t i = 0;
diff --git a/src/data/data.cc b/src/data/data.cc
index 9aa0271c2..570226212 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -897,10 +897,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
     if (!cache_file.empty()) {
       LOG(FATAL) << "Column-wise data split is not supported for external memory.";
     }
-    auto slice_cols = (dmat->Info().num_col_ + 1) / npart;
-    auto slice_start = slice_cols * partid;
-    auto size = std::min(slice_cols, dmat->Info().num_col_ - slice_start);
-    auto* sliced = dmat->SliceCol(slice_start, size);
+    auto* sliced = dmat->SliceCol(npart, partid);
     delete dmat;
     return sliced;
   } else {
diff --git a/src/data/iterative_dmatrix.cc b/src/data/iterative_dmatrix.cc
index 19dd3490d..8aacca48e 100644
--- a/src/data/iterative_dmatrix.cc
+++ b/src/data/iterative_dmatrix.cc
@@ -172,9 +172,9 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
   size_t i = 0;
   while (iter.Next()) {
     if (!p_sketch) {
-      p_sketch.reset(new common::HostSketchContainer{batch_param_.max_bin,
-                                                     proxy->Info().feature_types.ConstHostSpan(),
-                                                     column_sizes, false, ctx_.Threads()});
+      p_sketch.reset(new common::HostSketchContainer{
+          batch_param_.max_bin, proxy->Info().feature_types.ConstHostSpan(), column_sizes, false,
+          proxy->Info().data_split_mode == DataSplitMode::kCol, ctx_.Threads()});
     }
     HostAdapterDispatch(proxy, [&](auto const& batch) {
       proxy->Info().num_nonzero_ = batch_nnz[i];
diff --git a/src/data/iterative_dmatrix.h b/src/data/iterative_dmatrix.h
index c79e84370..4df2c9753 100644
--- a/src/data/iterative_dmatrix.h
+++ b/src/data/iterative_dmatrix.h
@@ -86,7 +86,7 @@ class IterativeDMatrix : public DMatrix {
     LOG(FATAL) << "Slicing DMatrix is not supported for Quantile DMatrix.";
     return nullptr;
   }
-  DMatrix *SliceCol(std::size_t, std::size_t) override {
+  DMatrix *SliceCol(int num_slices, int slice_id) override {
     LOG(FATAL) << "Slicing DMatrix columns is not supported for Quantile DMatrix.";
     return nullptr;
   }
diff --git a/src/data/proxy_dmatrix.h b/src/data/proxy_dmatrix.h
index af579ea72..6c8a04077 100644
--- a/src/data/proxy_dmatrix.h
+++ b/src/data/proxy_dmatrix.h
@@ -87,7 +87,7 @@ class DMatrixProxy : public DMatrix {
     LOG(FATAL) << "Slicing DMatrix is not supported for Proxy DMatrix.";
     return nullptr;
   }
-  DMatrix* SliceCol(std::size_t, std::size_t) override {
+  DMatrix* SliceCol(int num_slices, int slice_id) override {
     LOG(FATAL) << "Slicing DMatrix columns is not supported for Proxy DMatrix.";
     return nullptr;
   }
diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc
index 33992f5f7..014b57282 100644
--- a/src/data/simple_dmatrix.cc
+++ b/src/data/simple_dmatrix.cc
@@ -46,9 +46,12 @@ DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
   return out;
 }
 
-DMatrix* SimpleDMatrix::SliceCol(std::size_t start, std::size_t size) {
+DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
   auto out = new SimpleDMatrix;
   SparsePage& out_page = *out->sparse_page_;
+  auto const slice_size = info_.num_col_ / num_slices;
+  auto const slice_start = slice_size * slice_id;
+  auto const slice_end = (slice_id == num_slices - 1) ? info_.num_col_ : slice_start + slice_size;
   for (auto const &page : this->GetBatches<SparsePage>()) {
     auto batch = page.GetView();
     auto& h_data = out_page.data.HostVector();
@@ -58,7 +61,7 @@ DMatrix* SimpleDMatrix::SliceCol(std::size_t start, std::size_t size) {
       auto inst = batch[i];
       auto prev_size = h_data.size();
       std::copy_if(inst.begin(), inst.end(), std::back_inserter(h_data), [&](Entry e) {
-        return e.index >= start && e.index < start + size;
+        return e.index >= slice_start && e.index < slice_end;
       });
       rptr += h_data.size() - prev_size;
       h_offset.emplace_back(rptr);
diff --git a/src/data/simple_dmatrix.h b/src/data/simple_dmatrix.h
index 9b9b5accf..897abfcf0 100644
--- a/src/data/simple_dmatrix.h
+++ b/src/data/simple_dmatrix.h
@@ -35,7 +35,7 @@ class SimpleDMatrix : public DMatrix {
   bool SingleColBlock() const override { return true; }
 
   DMatrix* Slice(common::Span<int32_t const> ridxs) override;
-  DMatrix* SliceCol(std::size_t start, std::size_t size) override;
+  DMatrix* SliceCol(int num_slices, int slice_id) override;
 
   /*! \brief magic number used to identify SimpleDMatrix binary files */
   static const int kMagic = 0xffffab01;
diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h
index 4f09684a7..5157116bf 100644
--- a/src/data/sparse_page_dmatrix.h
+++ b/src/data/sparse_page_dmatrix.h
@@ -107,7 +107,7 @@ class SparsePageDMatrix : public DMatrix {
     LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
     return nullptr;
   }
-  DMatrix *SliceCol(std::size_t, std::size_t) override {
+  DMatrix *SliceCol(int num_slices, int slice_id) override {
     LOG(FATAL) << "Slicing DMatrix columns is not supported for external memory.";
     return nullptr;
   }
diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc
index 7b609f476..3cd32ea0c 100644
--- a/tests/cpp/common/test_quantile.cc
+++ b/tests/cpp/common/test_quantile.cc
@@ -6,7 +6,6 @@
 #include
 
 #include "../../../src/common/hist_util.h"
-#include "../../../src/common/quantile.h"
 #include "../../../src/data/adapter.h"
 #include "xgboost/context.h"
 
@@ -74,7 +73,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
   auto hess = Span<float const>{hessian};
 
   ContainerType<use_column> sketch_distributed(n_bins, m->Info().feature_types.ConstHostSpan(),
-                                               column_size, false, AllThreadsForTest());
+                                               column_size, false, false, AllThreadsForTest());
 
   if (use_column) {
     for (auto const& page : m->GetBatches<SortedCSCPage>()) {
@@ -95,7 +94,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
   std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
   m->Info().num_row_ = world * rows;
   ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
-                                                  column_size, false, AllThreadsForTest());
+                                                  column_size, false, false, AllThreadsForTest());
   m->Info().num_row_ = rows;
 
   for (auto rank = 0; rank < world; ++rank) {
@@ -170,6 +169,132 @@ TEST(Quantile, SortedDistributed) {
   TestDistributedQuantile<true>(kRows, kCols);
 }
 
+namespace {
+template <bool use_column>
+void DoTestColSplitQuantile(size_t rows, size_t cols) {
+  auto const world = collective::GetWorldSize();
+  auto const rank = collective::GetRank();
+
+  auto m = std::unique_ptr<DMatrix>{[=]() {
+    auto sparsity = 0.5f;
+    std::vector<FeatureType> ft(cols);
+    for (size_t i = 0; i < ft.size(); ++i) {
+      ft[i] = (i % 2 == 0) ? FeatureType::kNumerical : FeatureType::kCategorical;
+    }
+    auto dmat = RandomDataGenerator{rows, cols, sparsity}
+                    .Seed(0)
+                    .Lower(.0f)
+                    .Upper(1.0f)
+                    .Type(ft)
+                    .MaxCategory(13)
+                    .GenerateDMatrix();
+    return dmat->SliceCol(world, rank);
+  }()};
+
+  std::vector<bst_row_t> column_size(cols, 0);
+  auto const slice_size = cols / world;
+  auto const slice_start = slice_size * rank;
+  auto const slice_end = (rank == world - 1) ? cols : slice_start + slice_size;
+  for (auto i = slice_start; i < slice_end; i++) {
+    column_size[i] = rows;
+  }
+
+  auto const n_bins = 64;
+
+  // Generate cuts for distributed environment.
+  HistogramCuts distributed_cuts;
+  {
+    ContainerType<use_column> sketch_distributed(n_bins, m->Info().feature_types.ConstHostSpan(),
+                                                 column_size, false, true, AllThreadsForTest());
+
+    std::vector<float> hessian(rows, 1.0);
+    auto hess = Span<float const>{hessian};
+    if (use_column) {
+      for (auto const& page : m->GetBatches<SortedCSCPage>()) {
+        PushPage(&sketch_distributed, page, m->Info(), hess);
+      }
+    } else {
+      for (auto const& page : m->GetBatches<SparsePage>()) {
+        PushPage(&sketch_distributed, page, m->Info(), hess);
+      }
+    }
+
+    sketch_distributed.MakeCuts(&distributed_cuts);
+  }
+
+  // Generate cuts for single node environment
+  collective::Finalize();
+  CHECK_EQ(collective::GetWorldSize(), 1);
+  HistogramCuts single_node_cuts;
+  {
+    ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
+                                                    column_size, false, false, AllThreadsForTest());
+
+    std::vector<float> hessian(rows, 1.0);
+    auto hess = Span<float const>{hessian};
+    if (use_column) {
+      for (auto const& page : m->GetBatches<SortedCSCPage>()) {
+        PushPage(&sketch_on_single_node, page, m->Info(), hess);
+      }
+    } else {
+      for (auto const& page : m->GetBatches<SparsePage>()) {
+        PushPage(&sketch_on_single_node, page, m->Info(), hess);
+      }
+    }
+
+    sketch_on_single_node.MakeCuts(&single_node_cuts);
+  }
+
+  auto const& sptrs = single_node_cuts.Ptrs();
+  auto const& dptrs = distributed_cuts.Ptrs();
+  auto const& svals = single_node_cuts.Values();
+  auto const& dvals = distributed_cuts.Values();
+  auto const& smins = single_node_cuts.MinValues();
+  auto const& dmins = distributed_cuts.MinValues();
+
+  EXPECT_EQ(sptrs.size(), dptrs.size());
+  for (size_t i = 0; i < sptrs.size(); ++i) {
+    EXPECT_EQ(sptrs[i], dptrs[i]) << "rank: " << rank << ", i: " << i;
+  }
+
+  EXPECT_EQ(svals.size(), dvals.size());
+  for (size_t i = 0; i < svals.size(); ++i) {
+    EXPECT_NEAR(svals[i], dvals[i], 2e-2f) << "rank: " << rank << ", i: " << i;
+  }
+
+  EXPECT_EQ(smins.size(), dmins.size());
+  for (size_t i = 0; i < smins.size(); ++i) {
+    EXPECT_FLOAT_EQ(smins[i], dmins[i]) << "rank: " << rank << ", i: " << i;
+  }
+}
+
+template <bool use_column>
+void TestColSplitQuantile(size_t rows, size_t cols) {
+  auto constexpr kWorkers = 4;
+  RunWithInMemoryCommunicator(kWorkers, DoTestColSplitQuantile<use_column>, rows, cols);
+}
+}  // anonymous namespace
+
+TEST(Quantile, ColSplitBasic) {
+  constexpr size_t kRows = 10, kCols = 10;
+  TestColSplitQuantile<false>(kRows, kCols);
+}
+
+TEST(Quantile, ColSplit) {
+  constexpr size_t kRows = 4000, kCols = 200;
+  TestColSplitQuantile<false>(kRows, kCols);
+}
+
+TEST(Quantile, ColSplitSortedBasic) {
+  constexpr size_t kRows = 10, kCols = 10;
+  TestColSplitQuantile<true>(kRows, kCols);
+}
+
+TEST(Quantile, ColSplitSorted) {
+  constexpr size_t kRows = 4000, kCols = 200;
+  TestColSplitQuantile<true>(kRows, kCols);
+}
+
 namespace {
 void TestSameOnAllWorkers() {
   auto const world = collective::GetWorldSize();
@@ -222,17 +347,17 @@ void TestSameOnAllWorkers() {
     for (int32_t i = 0; i < world; i++) {
       for (size_t j = 0; j < value_size; ++j) {
         size_t idx = i * value_size + j;
-        ASSERT_NEAR(cuts.Values().at(j), cut_values.at(idx), kRtEps);
+        EXPECT_NEAR(cuts.Values().at(j), cut_values.at(idx), kRtEps);
       }
 
       for (size_t j = 0; j < ptr_size; ++j) {
         size_t idx = i * ptr_size + j;
-        ASSERT_EQ(cuts.Ptrs().at(j), cut_ptrs.at(idx));
+        EXPECT_EQ(cuts.Ptrs().at(j), cut_ptrs.at(idx));
       }
 
       for (size_t j = 0; j < min_value_size; ++j) {
         size_t idx = i * min_value_size + j;
-        ASSERT_EQ(cuts.MinValues().at(j), cut_min_values.at(idx));
+        EXPECT_EQ(cuts.MinValues().at(j), cut_min_values.at(idx));
       }
     }
   });
diff --git a/tests/cpp/common/test_quantile.h b/tests/cpp/common/test_quantile.h
index 957e5c987..d34c5e0e4 100644
--- a/tests/cpp/common/test_quantile.h
+++ b/tests/cpp/common/test_quantile.h
@@ -6,7 +6,6 @@
 #include
 
 #include "../helpers.h"
-#include "../../src/collective/communicator-inl.h"
 
 namespace xgboost {
 namespace common {
diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc
index 3dbe0a51a..a37352626 100644
--- a/tests/cpp/data/test_simple_dmatrix.cc
+++ b/tests/cpp/data/test_simple_dmatrix.cc
@@ -338,10 +338,10 @@ TEST(SimpleDMatrix, SliceCol) {
   auto& margin = p_m->Info().base_margin_;
   margin = decltype(p_m->Info().base_margin_){{kRows, kClasses}, Context::kCpuId};
 
-  size_t constexpr kSlicCols {4};
-  for (auto slice = 0; slice < 2; slice++) {
-    auto const slice_start = slice * kSlicCols;
-    std::unique_ptr<DMatrix> out { p_m->SliceCol(slice_start, kSlicCols) };
+  auto constexpr kSlices {2};
+  auto constexpr kSliceSize {4};
+  for (auto slice = 0; slice < kSlices; slice++) {
+    std::unique_ptr<DMatrix> out { p_m->SliceCol(kSlices, slice) };
     ASSERT_EQ(out->Info().labels.Size(), kRows);
     ASSERT_EQ(out->Info().labels_lower_bound_.Size(), kRows);
     ASSERT_EQ(out->Info().labels_upper_bound_.Size(), kRows);
@@ -355,7 +355,8 @@ TEST(SimpleDMatrix, SliceCol) {
       auto out_inst = out_page[i];
       auto in_inst = in_page[i];
       ASSERT_EQ(out_inst.size() * 2, in_inst.size()) << i;
-      for (size_t j = 0; j < kSlicCols; ++j) {
+      for (size_t j = 0; j < kSliceSize; ++j) {
+        auto const slice_start = kSliceSize * slice;
         ASSERT_EQ(in_inst[slice_start + j].fvalue, out_inst[j].fvalue);
         ASSERT_EQ(in_inst[slice_start + j].index, out_inst[j].index);
       }
@@ -377,7 +378,7 @@ TEST(SimpleDMatrix, SliceCol) {
 
     ASSERT_EQ(out->Info().num_col_, p_m->Info().num_col_);
     ASSERT_EQ(out->Info().num_row_, kRows);
-    ASSERT_EQ(out->Info().num_nonzero_, kRows * kSlicCols);  // dense
+    ASSERT_EQ(out->Info().num_nonzero_, kRows * kSliceSize);  // dense
     ASSERT_EQ(out->Info().data_split_mode, DataSplitMode::kCol);
   }
 }
diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc
index af666432a..9a0ebee18 100644
--- a/tests/cpp/predictor/test_cpu_predictor.cc
+++ b/tests/cpp/predictor/test_cpu_predictor.cc
@@ -97,7 +97,6 @@ void TestColumnSplitPredictBatch() {
   auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
   auto const world_size = collective::GetWorldSize();
   auto const rank = collective::GetRank();
-  auto const kSliceSize = (kCols + 1) / world_size;
 
   auto lparam = CreateEmptyGenericParam(GPUIDX);
   std::unique_ptr<Predictor> cpu_predictor =
@@ -112,7 +111,7 @@ void TestColumnSplitPredictBatch() {
   // Test predict batch
   PredictionCacheEntry out_predictions;
   cpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model);
-  auto sliced = std::unique_ptr<DMatrix>{dmat->SliceCol(rank * kSliceSize, kSliceSize)};
+  auto sliced = std::unique_ptr<DMatrix>{dmat->SliceCol(world_size, rank)};
   cpu_predictor->PredictBatch(sliced.get(), &out_predictions, model, 0);
 
   std::vector<float>& out_predictions_h = out_predictions.predictions.HostVector();
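
Reviewer note (appendix, not part of the diff): the new SliceCol(num_slices, slice_id) contract divides the columns evenly and lets the last slice absorb the remainder, as in SimpleDMatrix::SliceCol and the column_size setup in DoTestColSplitQuantile. A standalone sketch of that arithmetic, with illustrative names:

#include <cstddef>
#include <iostream>
#include <utility>

// Half-open column range [first, last) owned by `slice_id`, mirroring the
// arithmetic in SimpleDMatrix::SliceCol: an even split, with the last slice
// absorbing the remainder.
std::pair<std::size_t, std::size_t> ColumnRange(std::size_t num_col, int num_slices,
                                                int slice_id) {
  std::size_t const slice_size = num_col / num_slices;
  std::size_t const slice_start = slice_size * slice_id;
  std::size_t const slice_end =
      (slice_id == num_slices - 1) ? num_col : slice_start + slice_size;
  return {slice_start, slice_end};
}

int main() {
  // 10 columns over 4 slices -> [0,2) [2,4) [4,6) [6,10): every slice gets
  // floor(10/4) = 2 columns except the last, which also takes the remainder.
  for (int slice_id = 0; slice_id < 4; ++slice_id) {
    auto range = ColumnRange(10, 4, slice_id);
    std::cout << "slice " << slice_id << ": [" << range.first << ", " << range.second << ")\n";
  }
  return 0;
}

This is also why the tests mark exactly `rows` entries in column_size for each rank's own column range and leave the rest at zero.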
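The category merge in AllreduceCategories relies on a sum-allreduce acting as an allgather: each worker writes its values into its own region of a zero-initialized global buffer, with offsets coming from a partial sum over per-worker counts, so the element-wise sum concatenates the inputs rather than mixing them. A minimal single-process simulation of that pattern (the accumulation loop stands in for the collective allreduce; names are illustrative):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

int main() {
  // Per-worker category values; worker r contributes worker_values[r].
  std::vector<std::vector<float>> worker_values{{1.f, 2.f}, {3.f}, {4.f, 5.f, 6.f}};
  int const world_size = static_cast<int>(worker_values.size());

  // indptr over workers: slot r + 1 holds worker r's count, and a partial sum
  // turns counts into offsets, as the patch builds global_worker_ptr.
  std::vector<std::size_t> global_worker_ptr(world_size + 1, 0);
  for (int r = 0; r < world_size; ++r) {
    global_worker_ptr[r + 1] = worker_values[r].size();
  }
  std::partial_sum(global_worker_ptr.cbegin(), global_worker_ptr.cend(),
                   global_worker_ptr.begin());

  // Each worker copies its values into its own region of a zero-filled buffer;
  // the element-wise sum below stands in for the sum-allreduce.
  std::vector<float> global_categories(global_worker_ptr.back(), 0.f);
  for (int r = 0; r < world_size; ++r) {
    std::vector<float> local(global_categories.size(), 0.f);
    std::copy(worker_values[r].cbegin(), worker_values[r].cend(),
              local.begin() + global_worker_ptr[r]);
    for (std::size_t i = 0; i < global_categories.size(); ++i) {
      global_categories[i] += local[i];
    }
  }

  // Summing zero-initialized, non-overlapping regions concatenates the inputs.
  assert((global_categories == std::vector<float>{1.f, 2.f, 3.f, 4.f, 5.f, 6.f}));
  return 0;
}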