From 6a7c6a8ae6ff7e35ce6fdedae3331dd2ef324485 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Sat, 23 Mar 2024 05:55:25 +0100 Subject: [PATCH 01/26] add sycl reaslisation of ghist builder (#10138) Co-authored-by: Dmitry Razdoburdin <> --- plugin/sycl/common/hist_util.cc | 334 ++++++++++++++++++++ plugin/sycl/common/hist_util.h | 89 ++++++ tests/cpp/plugin/sycl_helpers.h | 9 +- tests/cpp/plugin/test_sycl_ghist_builder.cc | 157 +++++++++ 4 files changed, 585 insertions(+), 4 deletions(-) create mode 100644 plugin/sycl/common/hist_util.cc create mode 100644 plugin/sycl/common/hist_util.h create mode 100644 tests/cpp/plugin/test_sycl_ghist_builder.cc diff --git a/plugin/sycl/common/hist_util.cc b/plugin/sycl/common/hist_util.cc new file mode 100644 index 000000000..fd813a92c --- /dev/null +++ b/plugin/sycl/common/hist_util.cc @@ -0,0 +1,334 @@ +/*! + * Copyright 2017-2023 by Contributors + * \file hist_util.cc + */ +#include +#include +#include + +#include "../data/gradient_index.h" +#include "hist_util.h" + +#include + +namespace xgboost { +namespace sycl { +namespace common { + +/*! + * \brief Fill histogram with zeroes + */ +template +void InitHist(::sycl::queue qu, GHistRow* hist, + size_t size, ::sycl::event* event) { + *event = qu.fill(hist->Begin(), + xgboost::detail::GradientPairInternal(), size, *event); +} +template void InitHist(::sycl::queue qu, + GHistRow* hist, + size_t size, ::sycl::event* event); +template void InitHist(::sycl::queue qu, + GHistRow* hist, + size_t size, ::sycl::event* event); + +/*! 
+ * \brief Compute Subtraction: dst = src1 - src2 + */ +template +::sycl::event SubtractionHist(::sycl::queue qu, + GHistRow* dst, + const GHistRow& src1, + const GHistRow& src2, + size_t size, ::sycl::event event_priv) { + GradientSumT* pdst = reinterpret_cast(dst->Data()); + const GradientSumT* psrc1 = reinterpret_cast(src1.DataConst()); + const GradientSumT* psrc2 = reinterpret_cast(src2.DataConst()); + + auto event_final = qu.submit([&](::sycl::handler& cgh) { + cgh.depends_on(event_priv); + cgh.parallel_for<>(::sycl::range<1>(2 * size), [pdst, psrc1, psrc2](::sycl::item<1> pid) { + const size_t i = pid.get_id(0); + pdst[i] = psrc1[i] - psrc2[i]; + }); + }); + return event_final; +} +template ::sycl::event SubtractionHist(::sycl::queue qu, + GHistRow* dst, + const GHistRow& src1, + const GHistRow& src2, + size_t size, ::sycl::event event_priv); +template ::sycl::event SubtractionHist(::sycl::queue qu, + GHistRow* dst, + const GHistRow& src1, + const GHistRow& src2, + size_t size, ::sycl::event event_priv); + +// Kernel with buffer using +template +::sycl::event BuildHistKernel(::sycl::queue qu, + const USMVector& gpair_device, + const RowSetCollection::Elem& row_indices, + const GHistIndexMatrix& gmat, + GHistRow* hist, + GHistRow* hist_buffer, + ::sycl::event event_priv) { + const size_t size = row_indices.Size(); + const size_t* rid = row_indices.begin; + const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride; + const GradientPair::ValueT* pgh = + reinterpret_cast(gpair_device.DataConst()); + const BinIdxType* gradient_index = gmat.index.data(); + const uint32_t* offsets = gmat.index.Offset(); + FPType* hist_data = reinterpret_cast(hist->Data()); + const size_t nbins = gmat.nbins; + + const size_t max_work_group_size = + qu.get_device().get_info<::sycl::info::device::max_work_group_size>(); + const size_t work_group_size = n_columns < max_work_group_size ? 
n_columns : max_work_group_size; + + const size_t max_nblocks = hist_buffer->Size() / (nbins * 2); + const size_t min_block_size = 128; + size_t nblocks = std::min(max_nblocks, size / min_block_size + !!(size % min_block_size)); + const size_t block_size = size / nblocks + !!(size % nblocks); + FPType* hist_buffer_data = reinterpret_cast(hist_buffer->Data()); + + auto event_fill = qu.fill(hist_buffer_data, FPType(0), nblocks * nbins * 2, event_priv); + auto event_main = qu.submit([&](::sycl::handler& cgh) { + cgh.depends_on(event_fill); + cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(nblocks, work_group_size), + ::sycl::range<2>(1, work_group_size)), + [=](::sycl::nd_item<2> pid) { + size_t block = pid.get_global_id(0); + size_t feat = pid.get_global_id(1); + + FPType* hist_local = hist_buffer_data + block * nbins * 2; + for (size_t idx = 0; idx < block_size; ++idx) { + size_t i = block * block_size + idx; + if (i < size) { + const size_t icol_start = n_columns * rid[i]; + const size_t idx_gh = rid[i]; + + pid.barrier(::sycl::access::fence_space::local_space); + const BinIdxType* gr_index_local = gradient_index + icol_start; + + for (size_t j = feat; j < n_columns; j += work_group_size) { + uint32_t idx_bin = static_cast(gr_index_local[j]); + if constexpr (isDense) { + idx_bin += offsets[j]; + } + if (idx_bin < nbins) { + hist_local[2 * idx_bin] += pgh[2 * idx_gh]; + hist_local[2 * idx_bin+1] += pgh[2 * idx_gh+1]; + } + } + } + } + }); + }); + + auto event_save = qu.submit([&](::sycl::handler& cgh) { + cgh.depends_on(event_main); + cgh.parallel_for<>(::sycl::range<1>(nbins), [=](::sycl::item<1> pid) { + size_t idx_bin = pid.get_id(0); + + FPType gsum = 0.0f; + FPType hsum = 0.0f; + + for (size_t j = 0; j < nblocks; ++j) { + gsum += hist_buffer_data[j * nbins * 2 + 2 * idx_bin]; + hsum += hist_buffer_data[j * nbins * 2 + 2 * idx_bin + 1]; + } + + hist_data[2 * idx_bin] = gsum; + hist_data[2 * idx_bin + 1] = hsum; + }); + }); + return event_save; +} + +// 
Kernel with atomic using +template +::sycl::event BuildHistKernel(::sycl::queue qu, + const USMVector& gpair_device, + const RowSetCollection::Elem& row_indices, + const GHistIndexMatrix& gmat, + GHistRow* hist, + ::sycl::event event_priv) { + const size_t size = row_indices.Size(); + const size_t* rid = row_indices.begin; + const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride; + const GradientPair::ValueT* pgh = + reinterpret_cast(gpair_device.DataConst()); + const BinIdxType* gradient_index = gmat.index.data(); + const uint32_t* offsets = gmat.index.Offset(); + FPType* hist_data = reinterpret_cast(hist->Data()); + const size_t nbins = gmat.nbins; + + const size_t max_work_group_size = + qu.get_device().get_info<::sycl::info::device::max_work_group_size>(); + const size_t feat_local = n_columns < max_work_group_size ? n_columns : max_work_group_size; + + auto event_fill = qu.fill(hist_data, FPType(0), nbins * 2, event_priv); + auto event_main = qu.submit([&](::sycl::handler& cgh) { + cgh.depends_on(event_fill); + cgh.parallel_for<>(::sycl::range<2>(size, feat_local), + [=](::sycl::item<2> pid) { + size_t i = pid.get_id(0); + size_t feat = pid.get_id(1); + + const size_t icol_start = n_columns * rid[i]; + const size_t idx_gh = rid[i]; + + const BinIdxType* gr_index_local = gradient_index + icol_start; + + for (size_t j = feat; j < n_columns; j += feat_local) { + uint32_t idx_bin = static_cast(gr_index_local[j]); + if constexpr (isDense) { + idx_bin += offsets[j]; + } + if (idx_bin < nbins) { + AtomicRef gsum(hist_data[2 * idx_bin]); + AtomicRef hsum(hist_data[2 * idx_bin + 1]); + gsum.fetch_add(pgh[2 * idx_gh]); + hsum.fetch_add(pgh[2 * idx_gh + 1]); + } + } + }); + }); + return event_main; +} + +template +::sycl::event BuildHistDispatchKernel( + ::sycl::queue qu, + const USMVector& gpair_device, + const RowSetCollection::Elem& row_indices, + const GHistIndexMatrix& gmat, + GHistRow* hist, + bool isDense, + GHistRow* hist_buffer, + ::sycl::event 
events_priv, + bool force_atomic_use) { + const size_t size = row_indices.Size(); + const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride; + const size_t nbins = gmat.nbins; + + // max cycle size, while atomics are still effective + const size_t max_cycle_size_atomics = nbins; + const size_t cycle_size = size; + + // TODO(razdoburdin): replace the ad hoc dispatching criteria with more suitable ones + bool use_atomic = (size < nbins) || (gmat.max_num_bins == gmat.nbins / n_columns); + + // force_atomic_use flag is used only for testing + use_atomic = use_atomic || force_atomic_use; + if (!use_atomic) { + if (isDense) { + return BuildHistKernel(qu, gpair_device, row_indices, + gmat, hist, hist_buffer, + events_priv); + } else { + return BuildHistKernel(qu, gpair_device, row_indices, + gmat, hist, hist_buffer, + events_priv); + } + } else { + if (isDense) { + return BuildHistKernel(qu, gpair_device, row_indices, + gmat, hist, events_priv); + } else { + return BuildHistKernel(qu, gpair_device, row_indices, + gmat, hist, events_priv); + } + } +} + +template +::sycl::event BuildHistKernel(::sycl::queue qu, + const USMVector& gpair_device, + const RowSetCollection::Elem& row_indices, + const GHistIndexMatrix& gmat, const bool isDense, + GHistRow* hist, + GHistRow* hist_buffer, + ::sycl::event event_priv, + bool force_atomic_use) { + const bool is_dense = isDense; + switch (gmat.index.GetBinTypeSize()) { + case BinTypeSize::kUint8BinsTypeSize: + return BuildHistDispatchKernel(qu, gpair_device, row_indices, + gmat, hist, is_dense, hist_buffer, + event_priv, force_atomic_use); + break; + case BinTypeSize::kUint16BinsTypeSize: + return BuildHistDispatchKernel(qu, gpair_device, row_indices, + gmat, hist, is_dense, hist_buffer, + event_priv, force_atomic_use); + break; + case BinTypeSize::kUint32BinsTypeSize: + return BuildHistDispatchKernel(qu, gpair_device, row_indices, + gmat, hist, is_dense, hist_buffer, + event_priv, force_atomic_use); + break; + default: + 
CHECK(false); // no default behavior + } +} + +template +::sycl::event GHistBuilder::BuildHist( + const USMVector& gpair_device, + const RowSetCollection::Elem& row_indices, + const GHistIndexMatrix &gmat, + GHistRowT* hist, + bool isDense, + GHistRowT* hist_buffer, + ::sycl::event event_priv, + bool force_atomic_use) { + return BuildHistKernel(qu_, gpair_device, row_indices, gmat, + isDense, hist, hist_buffer, event_priv, + force_atomic_use); +} + +template +::sycl::event GHistBuilder::BuildHist( + const USMVector& gpair_device, + const RowSetCollection::Elem& row_indices, + const GHistIndexMatrix& gmat, + GHistRow* hist, + bool isDense, + GHistRow* hist_buffer, + ::sycl::event event_priv, + bool force_atomic_use); +template +::sycl::event GHistBuilder::BuildHist( + const USMVector& gpair_device, + const RowSetCollection::Elem& row_indices, + const GHistIndexMatrix& gmat, + GHistRow* hist, + bool isDense, + GHistRow* hist_buffer, + ::sycl::event event_priv, + bool force_atomic_use); + +template +void GHistBuilder::SubtractionTrick(GHistRowT* self, + const GHistRowT& sibling, + const GHistRowT& parent) { + const size_t size = self->Size(); + CHECK_EQ(sibling.Size(), size); + CHECK_EQ(parent.Size(), size); + + SubtractionHist(qu_, self, parent, sibling, size, ::sycl::event()); +} +template +void GHistBuilder::SubtractionTrick(GHistRow* self, + const GHistRow& sibling, + const GHistRow& parent); +template +void GHistBuilder::SubtractionTrick(GHistRow* self, + const GHistRow& sibling, + const GHistRow& parent); +} // namespace common +} // namespace sycl +} // namespace xgboost diff --git a/plugin/sycl/common/hist_util.h b/plugin/sycl/common/hist_util.h new file mode 100644 index 000000000..7c7af71ae --- /dev/null +++ b/plugin/sycl/common/hist_util.h @@ -0,0 +1,89 @@ +/*! 
+ * Copyright 2017-2023 by Contributors + * \file hist_util.h + */ +#ifndef PLUGIN_SYCL_COMMON_HIST_UTIL_H_ +#define PLUGIN_SYCL_COMMON_HIST_UTIL_H_ + +#include +#include +#include + +#include "../data.h" +#include "row_set.h" + +#include "../../src/common/hist_util.h" +#include "../data/gradient_index.h" + +#include + +namespace xgboost { +namespace sycl { +namespace common { + +template +using GHistRow = USMVector, memory_type>; + +using BinTypeSize = ::xgboost::common::BinTypeSize; + +class ColumnMatrix; + +/*! + * \brief Fill histogram with zeroes + */ +template +void InitHist(::sycl::queue qu, + GHistRow* hist, + size_t size, ::sycl::event* event); + +/*! + * \brief Compute subtraction: dst = src1 - src2 + */ +template +::sycl::event SubtractionHist(::sycl::queue qu, + GHistRow* dst, + const GHistRow& src1, + const GHistRow& src2, + size_t size, ::sycl::event event_priv); + +/*! + * \brief Builder for histograms of gradient statistics + */ +template +class GHistBuilder { + public: + template + using GHistRowT = GHistRow; + + GHistBuilder() = default; + GHistBuilder(::sycl::queue qu, uint32_t nbins) : qu_{qu}, nbins_{nbins} {} + + // Construct a histogram via histogram aggregation + ::sycl::event BuildHist(const USMVector& gpair_device, + const RowSetCollection::Elem& row_indices, + const GHistIndexMatrix& gmat, + GHistRowT* HistCollection, + bool isDense, + GHistRowT* hist_buffer, + ::sycl::event event, + bool force_atomic_use = false); + + // Construct a histogram via subtraction trick + void SubtractionTrick(GHistRowT* self, + const GHistRowT& sibling, + const GHistRowT& parent); + + uint32_t GetNumBins() const { + return nbins_; + } + + private: + /*! 
\brief Number of all bins over all features */ + uint32_t nbins_ { 0 }; + + ::sycl::queue qu_; +}; +} // namespace common +} // namespace sycl +} // namespace xgboost +#endif // PLUGIN_SYCL_COMMON_HIST_UTIL_H_ diff --git a/tests/cpp/plugin/sycl_helpers.h b/tests/cpp/plugin/sycl_helpers.h index c5cdd3ea5..afc403d86 100644 --- a/tests/cpp/plugin/sycl_helpers.h +++ b/tests/cpp/plugin/sycl_helpers.h @@ -8,22 +8,23 @@ namespace xgboost::sycl { template void VerifySyclVector(const USMVector& sycl_vector, - const Container& host_vector) { + const Container& host_vector, T eps = T()) { ASSERT_EQ(sycl_vector.Size(), host_vector.size()); size_t size = sycl_vector.Size(); for (size_t i = 0; i < size; ++i) { - ASSERT_EQ(sycl_vector[i], host_vector[i]); + EXPECT_NEAR(sycl_vector[i], host_vector[i], eps); } } template -void VerifySyclVector(const std::vector& sycl_vector, const Container& host_vector) { +void VerifySyclVector(const std::vector& sycl_vector, + const Container& host_vector, T eps = T()) { ASSERT_EQ(sycl_vector.size(), host_vector.size()); size_t size = sycl_vector.size(); for (size_t i = 0; i < size; ++i) { - ASSERT_EQ(sycl_vector[i], host_vector[i]); + EXPECT_NEAR(sycl_vector[i], host_vector[i], eps); } } diff --git a/tests/cpp/plugin/test_sycl_ghist_builder.cc b/tests/cpp/plugin/test_sycl_ghist_builder.cc new file mode 100644 index 000000000..dacbc75fc --- /dev/null +++ b/tests/cpp/plugin/test_sycl_ghist_builder.cc @@ -0,0 +1,157 @@ +/** + * Copyright 2020-2024 by XGBoost contributors + */ +#include + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#pragma GCC diagnostic ignored "-W#pragma-messages" +#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix +#pragma GCC diagnostic pop + +#include "../../../plugin/sycl/common/hist_util.h" +#include "../../../plugin/sycl/device_manager.h" +#include "sycl_helpers.h" +#include "../helpers.h" + +namespace xgboost::sycl::common { + +template +void 
GHistBuilderTest(float sparsity, bool force_atomic_use) { + const size_t num_rows = 8; + const size_t num_columns = 1; + const int n_bins = 2; + const GradientSumT eps = 1e-6; + + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + DeviceManager device_manager; + auto qu = device_manager.GetQueue(ctx.Device()); + + auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix(); + sycl::DeviceMatrix dmat; + dmat.Init(qu, p_fmat.get()); + + GHistIndexMatrix gmat_sycl; + gmat_sycl.Init(qu, &ctx, dmat, n_bins); + + xgboost::GHistIndexMatrix gmat{&ctx, p_fmat.get(), n_bins, 0.3, false}; + + RowSetCollection row_set_collection; + auto& row_indices = row_set_collection.Data(); + row_indices.Resize(&qu, num_rows); + size_t* p_row_indices = row_indices.Data(); + + qu.submit([&](::sycl::handler& cgh) { + cgh.parallel_for<>(::sycl::range<1>(num_rows), + [p_row_indices](::sycl::item<1> pid) { + const size_t idx = pid.get_id(0); + p_row_indices[idx] = idx; + }); + }).wait_and_throw(); + row_set_collection.Init(); + + auto builder = GHistBuilder(qu, n_bins); + + std::vector gpair = { + {0.1f, 0.2f}, {0.3f, 0.4f}, {0.5f, 0.6f}, {0.7f, 0.8f}, + {0.9f, 0.1f}, {0.2f, 0.3f}, {0.4f, 0.5f}, {0.6f, 0.7f}}; + CHECK_EQ(gpair.size(), num_rows); + USMVector gpair_device(&qu, gpair); + + std::vector hist_host(2*n_bins); + GHistRow hist(&qu, 2 * n_bins); + ::sycl::event event; + + const size_t nblocks = 2; + GHistRow hist_buffer(&qu, 2 * nblocks * n_bins); + + InitHist(qu, &hist, hist.Size(), &event); + InitHist(qu, &hist_buffer, hist_buffer.Size(), &event); + + event = builder.BuildHist(gpair_device, row_set_collection[0], gmat_sycl, &hist, + sparsity < eps , &hist_buffer, event, force_atomic_use); + qu.memcpy(hist_host.data(), hist.Data(), + 2 * n_bins * sizeof(GradientSumT), event); + qu.wait_and_throw(); + + // Build hist on host to compare + std::vector hist_desired(2*n_bins); + for (size_t rid = 0; rid < num_rows; ++rid) { + const size_t ibegin = 
gmat.row_ptr[rid]; + const size_t iend = gmat.row_ptr[rid + 1]; + for (size_t i = ibegin; i < iend; ++i) { + const size_t bin_idx = gmat.index[i]; + hist_desired[2*bin_idx] += gpair[rid].GetGrad(); + hist_desired[2*bin_idx+1] += gpair[rid].GetHess(); + } + } + + VerifySyclVector(hist_host, hist_desired, eps); +} + +template +void GHistSubtractionTest() { + const size_t n_bins = 4; + using GHistType = GHistRow; + + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + DeviceManager device_manager; + auto qu = device_manager.GetQueue(ctx.Device()); + + ::sycl::event event; + std::vector hist1_host = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}; + GHistType hist1(&qu, 2 * n_bins); + event = qu.memcpy(hist1.Data(), hist1_host.data(), + 2 * n_bins * sizeof(GradientSumT), event); + + std::vector hist2_host = {0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1}; + GHistType hist2(&qu, 2 * n_bins); + event = qu.memcpy(hist2.Data(), hist2_host.data(), + 2 * n_bins * sizeof(GradientSumT), event); + + std::vector hist3_host(2 * n_bins); + GHistType hist3(&qu, 2 * n_bins); + event = SubtractionHist(qu, &hist3, hist1, hist2, n_bins, event); + qu.memcpy(hist3_host.data(), hist3.Data(), + 2 * n_bins * sizeof(GradientSumT), event); + qu.wait_and_throw(); + + std::vector hist3_desired(2 * n_bins); + for (size_t idx = 0; idx < 2 * n_bins; ++idx) { + hist3_desired[idx] = hist1_host[idx] - hist2_host[idx]; + } + + const GradientSumT eps = 1e-6; + VerifySyclVector(hist3_host, hist3_desired, eps); +} + +TEST(SyclGHistBuilder, ByBlockDenseCase) { + GHistBuilderTest(0.0, false); + GHistBuilderTest(0.0, false); +} + +TEST(SyclGHistBuilder, ByBlockSparseCase) { + GHistBuilderTest(0.3, false); + GHistBuilderTest(0.3, false); +} + +TEST(SyclGHistBuilder, ByAtomicDenseCase) { + GHistBuilderTest(0.0, true); + GHistBuilderTest(0.0, true); +} + +TEST(SyclGHistBuilder, ByAtomicSparseCase) { + GHistBuilderTest(0.3, true); + GHistBuilderTest(0.3, true); +} + +TEST(SyclGHistBuilder, Subtraction) { + 
GHistSubtractionTest(); + GHistSubtractionTest(); +} + +} // namespace xgboost::sycl::common From 230010d9a0ddd0334b34f30dfed183f51653baca Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 26 Mar 2024 23:26:24 +0800 Subject: [PATCH 02/26] Cleanup set info. (#10139) - Use the array interface internally. - Deprecate `XGDMatrixSetDenseInfo`. - Deprecate `XGDMatrixSetUIntInfo`. - Move the handling of `DataType` into the deprecated C function. --------- Co-authored-by: Philip Hyunsu Cho --- .github/workflows/r_tests.yml | 8 +-- include/xgboost/c_api.h | 60 ++++-------------- include/xgboost/data.h | 13 ---- include/xgboost/linalg.h | 15 ++++- include/xgboost/span.h | 10 ++- .../xgboost4j/src/native/xgboost4j.cpp | 10 +-- src/c_api/c_api.cc | 62 ++++++++++++++----- src/c_api/c_api_utils.h | 19 +++--- src/collective/nccl_device_communicator.cu | 2 + src/common/error_msg.cc | 2 +- src/common/error_msg.h | 2 +- src/common/host_device_vector.cu | 1 - src/common/quantile.cc | 1 + src/common/quantile.cu | 6 +- src/common/quantile.cuh | 5 +- src/data/data.cc | 46 +------------- src/data/file_iterator.cc | 52 ++++++++++++++-- src/data/file_iterator.h | 44 +------------ src/gbm/gbtree.h | 6 +- src/metric/elementwise_metric.cu | 6 +- src/metric/metric_common.h | 2 - src/metric/multiclass_metric.cu | 2 +- src/metric/survival_metric.cu | 3 +- tests/cpp/collective/test_allreduce.cc | 2 + tests/cpp/collective/test_worker.h | 3 +- tests/cpp/common/test_hist_util.cc | 12 ++-- tests/cpp/common/test_hist_util.cu | 10 +-- tests/cpp/common/test_transform_range.cc | 5 +- tests/cpp/data/test_array_interface.cu | 5 +- tests/cpp/data/test_metainfo.cc | 34 +++++----- tests/cpp/gbm/test_gbtree.cc | 2 +- tests/cpp/helpers.cu | 5 +- tests/cpp/helpers.h | 19 ++++-- tests/cpp/metric/test_elementwise_metric.h | 3 +- .../cpp/objective/test_regression_obj_cpu.cc | 11 ++-- tests/cpp/test_learner.cc | 12 ++-- tests/cpp/tree/gpu_hist/test_histogram.cu | 14 +++++ 37 files changed, 246 insertions(+), 
268 deletions(-) diff --git a/.github/workflows/r_tests.yml b/.github/workflows/r_tests.yml index 045dac575..7dbdf3a84 100644 --- a/.github/workflows/r_tests.yml +++ b/.github/workflows/r_tests.yml @@ -110,7 +110,7 @@ jobs: name: Test R package on Debian runs-on: ubuntu-latest container: - image: rhub/debian-gcc-devel + image: rhub/debian-gcc-release steps: - name: Install system dependencies @@ -130,12 +130,12 @@ jobs: - name: Install dependencies shell: bash -l {0} run: | - /tmp/R-devel/bin/Rscript -e "source('./R-package/tests/helper_scripts/install_deps.R')" + Rscript -e "source('./R-package/tests/helper_scripts/install_deps.R')" - name: Test R shell: bash -l {0} run: | - python3 tests/ci_build/test_r_package.py --r=/tmp/R-devel/bin/R --build-tool=autotools --task=check + python3 tests/ci_build/test_r_package.py --r=/usr/bin/R --build-tool=autotools --task=check - uses: dorny/paths-filter@v2 id: changes @@ -147,4 +147,4 @@ jobs: - name: Run document check if: steps.changes.outputs.r_package == 'true' run: | - python3 tests/ci_build/test_r_package.py --r=/tmp/R-devel/bin/R --task=doc + python3 tests/ci_build/test_r_package.py --r=/usr/bin/R --task=doc diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 795c78946..e065d8ba1 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -1,5 +1,5 @@ /** - * Copyright 2015~2023 by XGBoost Contributors + * Copyright 2015-2024, XGBoost Contributors * \file c_api.h * \author Tianqi Chen * \brief C API of XGBoost, used for interfacing to other languages. @@ -639,21 +639,14 @@ XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle, * \param len length of array * \return 0 when success, -1 when failure happens */ -XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, - const char *field, - const float *array, +XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const float *array, bst_ulong len); -/*! 
- * \brief set uint32 vector to a content in info - * \param handle a instance of data matrix - * \param field field name - * \param array pointer to unsigned int vector - * \param len length of array - * \return 0 when success, -1 when failure happens +/** + * @deprecated since 2.1.0 + * + * Use @ref XGDMatrixSetInfoFromInterface instead. */ -XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, - const char *field, - const unsigned *array, +XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const unsigned *array, bst_ulong len); /*! @@ -725,42 +718,13 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field, bst_ulong *size, const char ***out_features); -/*! - * \brief Set meta info from dense matrix. Valid field names are: +/** + * @deprecated since 2.1.0 * - * - label - * - weight - * - base_margin - * - group - * - label_lower_bound - * - label_upper_bound - * - feature_weights - * - * \param handle An instance of data matrix - * \param field Field name - * \param data Pointer to consecutive memory storing data. - * \param size Size of the data, this is relative to size of type. (Meaning NOT number - * of bytes.) - * \param type Indicator of data type. This is defined in xgboost::DataType enum class. - * - float = 1 - * - double = 2 - * - uint32_t = 3 - * - uint64_t = 4 - * \return 0 when success, -1 when failure happens + * Use @ref XGDMatrixSetInfoFromInterface instead. */ -XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, - void const *data, bst_ulong size, int type); - -/*! - * \brief (deprecated) Use XGDMatrixSetUIntInfo instead. 
Set group of the training matrix - * \param handle a instance of data matrix - * \param group pointer to group size - * \param len length of array - * \return 0 when success, -1 when failure happens - */ -XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, - const unsigned *group, - bst_ulong len); +XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void const *data, + bst_ulong size, int type); /*! * \brief get float info vector from matrix. diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 2bdf3713d..ec06a9c86 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -137,14 +136,6 @@ class MetaInfo { * \param fo The output stream. */ void SaveBinary(dmlc::Stream* fo) const; - /*! - * \brief Set information in the meta info. - * \param key The key of the information. - * \param dptr The data pointer of the source array. - * \param dtype The type of the source data. - * \param num Number of elements in the source array. - */ - void SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype, size_t num); /*! * \brief Set information in the meta info with array interface. * \param key The key of the information. @@ -517,10 +508,6 @@ class DMatrix { DMatrix() = default; /*! 
\brief meta information of the dataset */ virtual MetaInfo& Info() = 0; - virtual void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) { - auto const& ctx = *this->Ctx(); - this->Info().SetInfo(ctx, key, dptr, dtype, num); - } virtual void SetInfo(const char* key, std::string const& interface_str) { auto const& ctx = *this->Ctx(); this->Info().SetInfo(ctx, key, StringView{interface_str}); diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h index f538adbcd..cb7668f4c 100644 --- a/include/xgboost/linalg.h +++ b/include/xgboost/linalg.h @@ -190,13 +190,14 @@ constexpr auto ArrToTuple(T (&arr)[N]) { // uint division optimization inspired by the CIndexer in cupy. Division operation is // slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64 // bit when the index is smaller, then try to avoid division when it's exp of 2. -template +template LINALG_HD auto UnravelImpl(I idx, common::Span shape) { - size_t index[D]{0}; + std::size_t index[D]{0}; static_assert(std::is_signed::value, "Don't change the type without changing the for loop."); + auto const sptr = shape.data(); for (int32_t dim = D; --dim > 0;) { - auto s = static_cast>>(shape[dim]); + auto s = static_cast>>(sptr[dim]); if (s & (s - 1)) { auto t = idx / s; index[dim] = idx - t * s; @@ -745,6 +746,14 @@ auto ArrayInterfaceStr(TensorView const &t) { return str; } +template +auto Make1dInterface(T const *vec, std::size_t len) { + Context ctx; + auto t = linalg::MakeTensorView(&ctx, common::Span{vec, len}, len); + auto str = linalg::ArrayInterfaceStr(t); + return str; +} + /** * \brief A tensor storage. To use it for other functionality like slicing one needs to * obtain a view first. This way we can use it on both host and device. 
diff --git a/include/xgboost/span.h b/include/xgboost/span.h index 29ca76d3c..7471c2e44 100644 --- a/include/xgboost/span.h +++ b/include/xgboost/span.h @@ -30,9 +30,8 @@ #define XGBOOST_SPAN_H_ #include -#include -#include // size_t +#include // size_t #include #include #include // numeric_limits @@ -73,8 +72,7 @@ #endif // defined(_MSC_VER) && _MSC_VER < 1910 -namespace xgboost { -namespace common { +namespace xgboost::common { #if defined(__CUDA_ARCH__) // Usual logging facility is not available inside device code. @@ -707,8 +705,8 @@ class IterSpan { return it_ + size(); } }; -} // namespace common -} // namespace xgboost +} // namespace xgboost::common + #if defined(_MSC_VER) &&_MSC_VER < 1910 #undef constexpr diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index 332b1a127..9ba944d5a 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -408,7 +408,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatI jfloat* array = jenv->GetFloatArrayElements(jarray, NULL); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); - int ret = XGDMatrixSetFloatInfo(handle, field, (float const *)array, len); + auto str = xgboost::linalg::Make1dInterface(array, len); + int ret = XGDMatrixSetInfoFromInterface(handle, field, str.c_str()); JVM_CHECK_CALL(ret); //release if (field) jenv->ReleaseStringUTFChars(jfield, field); @@ -427,7 +428,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntIn const char* field = jenv->GetStringUTFChars(jfield, 0); jint* array = jenv->GetIntArrayElements(jarray, NULL); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); - int ret = XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len); + auto str = xgboost::linalg::Make1dInterface(array, len); + int ret = XGDMatrixSetInfoFromInterface(handle, field, str.c_str()); 
JVM_CHECK_CALL(ret); //release if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); @@ -730,8 +732,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFr if (jmargin) { margin = jenv->GetFloatArrayElements(jmargin, nullptr); JVM_CHECK_CALL(XGProxyDMatrixCreate(&proxy)); - JVM_CHECK_CALL( - XGDMatrixSetFloatInfo(proxy, "base_margin", margin, jenv->GetArrayLength(jmargin))); + auto str = xgboost::linalg::Make1dInterface(margin, jenv->GetArrayLength(jmargin)); + JVM_CHECK_CALL(XGDMatrixSetInfoFromInterface(proxy, "base_margin", str.c_str())); } bst_ulong const *out_shape; diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 0f4748bfe..598b7f2f5 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1,5 +1,5 @@ /** - * Copyright 2014-2024 by XGBoost Contributors + * Copyright 2014-2024, XGBoost Contributors */ #include "xgboost/c_api.h" @@ -614,8 +614,8 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const API_BEGIN(); CHECK_HANDLE(); xgboost_CHECK_C_ARG_PTR(field); - auto const& p_fmat = *static_cast *>(handle); - p_fmat->SetInfo(field, info, xgboost::DataType::kFloat32, len); + auto const &p_fmat = *static_cast *>(handle); + p_fmat->SetInfo(field, linalg::Make1dInterface(info, len)); API_END(); } @@ -634,8 +634,9 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const API_BEGIN(); CHECK_HANDLE(); xgboost_CHECK_C_ARG_PTR(field); + LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGDMatrixSetInfoFromInterface"); auto const &p_fmat = *static_cast *>(handle); - p_fmat->SetInfo(field, info, xgboost::DataType::kUInt32, len); + p_fmat->SetInfo(field, linalg::Make1dInterface(info, len)); API_END(); } @@ -679,19 +680,52 @@ XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void xgboost::bst_ulong size, int type) { API_BEGIN(); CHECK_HANDLE(); + LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", 
"XGDMatrixSetInfoFromInterface"); auto const &p_fmat = *static_cast *>(handle); CHECK(type >= 1 && type <= 4); xgboost_CHECK_C_ARG_PTR(field); - p_fmat->SetInfo(field, data, static_cast(type), size); - API_END(); -} -XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, const unsigned *group, xgboost::bst_ulong len) { - API_BEGIN(); - CHECK_HANDLE(); - LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead."; - auto const &p_fmat = *static_cast *>(handle); - p_fmat->SetInfo("group", group, xgboost::DataType::kUInt32, len); + Context ctx; + auto dtype = static_cast(type); + std::string str; + auto proc = [&](auto cast_d_ptr) { + using T = std::remove_pointer_t; + auto t = linalg::TensorView( + common::Span{cast_d_ptr, static_cast::index_type>(size)}, + {size}, DeviceOrd::CPU()); + CHECK(t.CContiguous()); + Json interface{linalg::ArrayInterface(t)}; + CHECK(ArrayInterface<1>{interface}.is_contiguous); + str = Json::Dump(interface); + return str; + }; + + // Legacy code using XGBoost dtype, which is a small subset of array interface types. 
+ switch (dtype) { + case xgboost::DataType::kFloat32: { + auto cast_ptr = reinterpret_cast(data); + p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr)); + break; + } + case xgboost::DataType::kDouble: { + auto cast_ptr = reinterpret_cast(data); + p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr)); + break; + } + case xgboost::DataType::kUInt32: { + auto cast_ptr = reinterpret_cast(data); + p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr)); + break; + } + case xgboost::DataType::kUInt64: { + auto cast_ptr = reinterpret_cast(data); + p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr)); + break; + } + default: + LOG(FATAL) << "Unknown data type" << static_cast(dtype); + } + API_END(); } @@ -987,7 +1021,7 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle, DMatrixHandle dtrain, bs bst_float *hess, xgboost::bst_ulong len) { API_BEGIN(); CHECK_HANDLE(); - error::DeprecatedFunc(__func__, "2.1.0", "XGBoosterTrainOneIter"); + LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGBoosterTrainOneIter"); auto *learner = static_cast(handle); auto ctx = learner->Ctx()->MakeCPU(); diff --git a/src/c_api/c_api_utils.h b/src/c_api/c_api_utils.h index 95efb5b9d..04b0fc007 100644 --- a/src/c_api/c_api_utils.h +++ b/src/c_api/c_api_utils.h @@ -1,17 +1,18 @@ /** - * Copyright 2021-2023, XGBoost Contributors + * Copyright 2021-2024, XGBoost Contributors */ #ifndef XGBOOST_C_API_C_API_UTILS_H_ #define XGBOOST_C_API_C_API_UTILS_H_ -#include -#include -#include -#include // for shared_ptr -#include // for string -#include // for make_tuple -#include // for move -#include +#include // for min +#include // for size_t +#include // for multiplies +#include // for shared_ptr +#include // for accumulate +#include // for string +#include // for make_tuple +#include // for move +#include // for vector #include "../common/json_utils.h" // for TypeCheck #include "xgboost/c_api.h" diff --git a/src/collective/nccl_device_communicator.cu b/src/collective/nccl_device_communicator.cu index 
31c2d394d..b896e7d06 100644 --- a/src/collective/nccl_device_communicator.cu +++ b/src/collective/nccl_device_communicator.cu @@ -2,6 +2,8 @@ * Copyright 2023 XGBoost contributors */ #if defined(XGBOOST_USE_NCCL) +#include // for accumulate + #include "comm.cuh" #include "nccl_device_communicator.cuh" diff --git a/src/common/error_msg.cc b/src/common/error_msg.cc index 8871c1a1d..cdbe5ebf6 100644 --- a/src/common/error_msg.cc +++ b/src/common/error_msg.cc @@ -11,7 +11,7 @@ #include "xgboost/logging.h" namespace xgboost::error { -std::string DeprecatedFunc(StringView old, StringView since, StringView replacement) { +[[nodiscard]] std::string DeprecatedFunc(StringView old, StringView since, StringView replacement) { std::stringstream ss; ss << "`" << old << "` is deprecated since" << since << ", use `" << replacement << "` instead."; return ss.str(); diff --git a/src/common/error_msg.h b/src/common/error_msg.h index 7264c3532..67114320b 100644 --- a/src/common/error_msg.h +++ b/src/common/error_msg.h @@ -89,7 +89,7 @@ void WarnDeprecatedGPUId(); void WarnEmptyDataset(); -std::string DeprecatedFunc(StringView old, StringView since, StringView replacement); +[[nodiscard]] std::string DeprecatedFunc(StringView old, StringView since, StringView replacement); constexpr StringView InvalidCUDAOrdinal() { return "Invalid device. 
`device` is required to be CUDA and there must be at least one GPU " diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu index 267309288..99448df21 100644 --- a/src/common/host_device_vector.cu +++ b/src/common/host_device_vector.cu @@ -6,7 +6,6 @@ #include #include -#include #include "xgboost/data.h" #include "xgboost/host_device_vector.h" diff --git a/src/common/quantile.cc b/src/common/quantile.cc index 4ae6ecd36..8c743d940 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -4,6 +4,7 @@ #include "quantile.h" #include +#include // for partial_sum #include #include "../collective/aggregator.h" diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 529ee30df..e7f09fc4d 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -1,5 +1,5 @@ /** - * Copyright 2020-2023 by XGBoost Contributors + * Copyright 2020-2024, XGBoost Contributors */ #include #include @@ -8,8 +8,8 @@ #include #include -#include // std::numeric_limits -#include +#include // std::numeric_limits +#include // for partial_sum #include #include "../collective/communicator-inl.cuh" diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh index 6a5a38613..898da03a0 100644 --- a/src/common/quantile.cuh +++ b/src/common/quantile.cuh @@ -1,8 +1,9 @@ +/** + * Copyright 2020-2024, XGBoost Contributors + */ #ifndef XGBOOST_COMMON_QUANTILE_CUH_ #define XGBOOST_COMMON_QUANTILE_CUH_ -#include - #include "xgboost/span.h" #include "xgboost/data.h" #include "device_helpers.cuh" diff --git a/src/data/data.cc b/src/data/data.cc index 8cdcde201..22854def8 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -11,7 +11,6 @@ #include // for abs #include // for uint64_t, int32_t, uint8_t, uint32_t #include // for size_t, strcmp, memcpy -#include // for exception #include // for operator<<, basic_ostream, basic_ostream::op... 
#include // for map, operator!= #include // for accumulate, partial_sum @@ -22,7 +21,6 @@ #include "../collective/communicator.h" // for Operation #include "../common/algorithm.h" // for StableSort #include "../common/api_entry.h" // for XGBAPIThreadLocalEntry -#include "../common/common.h" // for Split #include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData #include "../common/group_data.h" // for ParallelGroupBuilder #include "../common/io.h" // for PeekableInStream @@ -473,11 +471,11 @@ void MetaInfo::SetInfo(Context const& ctx, StringView key, StringView interface_ << ", must have at least 1 column even if it's empty."; auto const& first = get(array.front()); auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData(first); - is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr); + is_cuda = first.find("stream") != first.cend() || ArrayInterfaceHandler::IsCudaPtr(ptr); } else { auto const& first = get(j_interface); auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData(first); - is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr); + is_cuda = first.find("stream") != first.cend() || ArrayInterfaceHandler::IsCudaPtr(ptr); } if (is_cuda) { @@ -567,46 +565,6 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) { } } -void MetaInfo::SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype, - size_t num) { - CHECK(key); - auto proc = [&](auto cast_d_ptr) { - using T = std::remove_pointer_t; - auto t = linalg::TensorView(common::Span{cast_d_ptr, num}, {num}, DeviceOrd::CPU()); - CHECK(t.CContiguous()); - Json interface { - linalg::ArrayInterface(t) - }; - assert(ArrayInterface<1>{interface}.is_contiguous); - return interface; - }; - // Legacy code using XGBoost dtype, which is a small subset of array interface types. 
- switch (dtype) { - case xgboost::DataType::kFloat32: { - auto cast_ptr = reinterpret_cast(dptr); - this->SetInfoFromHost(ctx, key, proc(cast_ptr)); - break; - } - case xgboost::DataType::kDouble: { - auto cast_ptr = reinterpret_cast(dptr); - this->SetInfoFromHost(ctx, key, proc(cast_ptr)); - break; - } - case xgboost::DataType::kUInt32: { - auto cast_ptr = reinterpret_cast(dptr); - this->SetInfoFromHost(ctx, key, proc(cast_ptr)); - break; - } - case xgboost::DataType::kUInt64: { - auto cast_ptr = reinterpret_cast(dptr); - this->SetInfoFromHost(ctx, key, proc(cast_ptr)); - break; - } - default: - LOG(FATAL) << "Unknown data type" << static_cast(dtype); - } -} - void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype, const void** out_dptr) const { if (dtype == DataType::kFloat32) { diff --git a/src/data/file_iterator.cc b/src/data/file_iterator.cc index cebfbdc19..1e341447c 100644 --- a/src/data/file_iterator.cc +++ b/src/data/file_iterator.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021-2023, XGBoost contributors + * Copyright 2021-2024, XGBoost contributors */ #include "file_iterator.h" @@ -10,7 +10,10 @@ #include // for operator<<, basic_ostream, istringstream #include // for vector -#include "../common/common.h" // for Split +#include "../common/common.h" // for Split +#include "xgboost/linalg.h" // for ArrayInterfaceStr, MakeVec +#include "xgboost/linalg.h" +#include "xgboost/logging.h" // for CHECK #include "xgboost/string_view.h" // for operator<<, StringView namespace xgboost::data { @@ -28,10 +31,10 @@ std::string ValidateFileFormat(std::string const& uri) { for (size_t i = 0; i < arg_list.size(); ++i) { std::istringstream is(arg_list[i]); std::pair kv; - CHECK(std::getline(is, kv.first, '=')) << "Invalid uri argument format" - << " for key in arg " << i + 1; - CHECK(std::getline(is, kv.second)) << "Invalid uri argument format" - << " for value in arg " << i + 1; + CHECK(std::getline(is, kv.first, '=')) + << "Invalid uri argument format" << " 
for key in arg " << i + 1; + CHECK(std::getline(is, kv.second)) + << "Invalid uri argument format" << " for value in arg " << i + 1; args.insert(kv); } if (args.find("format") == args.cend()) { @@ -48,4 +51,41 @@ std::string ValidateFileFormat(std::string const& uri) { return name_args[0] + "?" + name_args[1] + '#' + name_args_cache[1]; } } + +int FileIterator::Next() { + CHECK(parser_); + if (parser_->Next()) { + row_block_ = parser_->Value(); + + indptr_ = linalg::Make1dInterface(row_block_.offset, row_block_.size + 1); + values_ = linalg::Make1dInterface(row_block_.value, row_block_.offset[row_block_.size]); + indices_ = linalg::Make1dInterface(row_block_.index, row_block_.offset[row_block_.size]); + + size_t n_columns = + *std::max_element(row_block_.index, row_block_.index + row_block_.offset[row_block_.size]); + // dmlc parser converts 1-based indexing back to 0-based indexing so we can ignore + // this condition and just add 1 to n_columns + n_columns += 1; + + XGProxyDMatrixSetDataCSR(proxy_, indptr_.c_str(), indices_.c_str(), values_.c_str(), n_columns); + + if (row_block_.label) { + auto str = linalg::Make1dInterface(row_block_.label, row_block_.size); + XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str()); + } + if (row_block_.qid) { + auto str = linalg::Make1dInterface(row_block_.qid, row_block_.size); + XGDMatrixSetInfoFromInterface(proxy_, "qid", str.c_str()); + } + if (row_block_.weight) { + auto str = linalg::Make1dInterface(row_block_.weight, row_block_.size); + XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str()); + } + // Continue iteration + return true; + } else { + // Stop iteration + return false; + } +} } // namespace xgboost::data diff --git a/src/data/file_iterator.h b/src/data/file_iterator.h index c7f23b478..a4afbabe4 100644 --- a/src/data/file_iterator.h +++ b/src/data/file_iterator.h @@ -1,20 +1,16 @@ /** - * Copyright 2021-2023, XGBoost contributors + * Copyright 2021-2024, XGBoost contributors */ #ifndef 
XGBOOST_DATA_FILE_ITERATOR_H_ #define XGBOOST_DATA_FILE_ITERATOR_H_ -#include // for max_element -#include // for size_t #include // for uint32_t #include // for unique_ptr #include // for string #include // for move #include "dmlc/data.h" // for RowBlock, Parser -#include "xgboost/c_api.h" // for XGDMatrixSetDenseInfo, XGDMatrixFree, XGProxyDMatrixCreate -#include "xgboost/linalg.h" // for ArrayInterfaceStr, MakeVec -#include "xgboost/logging.h" // for CHECK +#include "xgboost/c_api.h" // for XGDMatrixFree, XGProxyDMatrixCreate namespace xgboost::data { [[nodiscard]] std::string ValidateFileFormat(std::string const& uri); @@ -53,41 +49,7 @@ class FileIterator { XGDMatrixFree(proxy_); } - int Next() { - CHECK(parser_); - if (parser_->Next()) { - row_block_ = parser_->Value(); - using linalg::MakeVec; - - indptr_ = ArrayInterfaceStr(MakeVec(row_block_.offset, row_block_.size + 1)); - values_ = ArrayInterfaceStr(MakeVec(row_block_.value, row_block_.offset[row_block_.size])); - indices_ = ArrayInterfaceStr(MakeVec(row_block_.index, row_block_.offset[row_block_.size])); - - size_t n_columns = *std::max_element(row_block_.index, - row_block_.index + row_block_.offset[row_block_.size]); - // dmlc parser converts 1-based indexing back to 0-based indexing so we can ignore - // this condition and just add 1 to n_columns - n_columns += 1; - - XGProxyDMatrixSetDataCSR(proxy_, indptr_.c_str(), indices_.c_str(), - values_.c_str(), n_columns); - - if (row_block_.label) { - XGDMatrixSetDenseInfo(proxy_, "label", row_block_.label, row_block_.size, 1); - } - if (row_block_.qid) { - XGDMatrixSetDenseInfo(proxy_, "qid", row_block_.qid, row_block_.size, 1); - } - if (row_block_.weight) { - XGDMatrixSetDenseInfo(proxy_, "weight", row_block_.weight, row_block_.size, 1); - } - // Continue iteration - return true; - } else { - // Stop iteration - return false; - } - } + int Next(); auto Proxy() -> decltype(proxy_) { return proxy_; } diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 
a2d84d848..d6ed851c8 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -1,5 +1,5 @@ /** - * Copyright 2014-2023 by Contributors + * Copyright 2014-2024, XGBoost Contributors * \file gbtree.cc * \brief gradient boosted tree implementation. * \author Tianqi Chen @@ -11,14 +11,12 @@ #include #include // std::int32_t -#include #include +#include // for iota #include -#include #include #include -#include "../common/common.h" #include "../common/timer.h" #include "../tree/param.h" // TrainParam #include "gbtree_model.h" diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index 9c26011aa..ec5b9079d 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -10,15 +10,15 @@ #include #include +#include // for accumulate -#include "../collective/communicator-inl.h" -#include "../common/common.h" // MetricNoCache +#include "../common/common.h" // for AssertGPUSupport #include "../common/math.h" #include "../common/optional_weight.h" // OptionalWeights #include "../common/pseudo_huber.h" #include "../common/quantile_loss_utils.h" // QuantileLossParam #include "../common/threading_utils.h" -#include "metric_common.h" +#include "metric_common.h" // MetricNoCache #include "xgboost/collective/result.h" // for SafeColl #include "xgboost/metric.h" diff --git a/src/metric/metric_common.h b/src/metric/metric_common.h index 53c38ff2a..2b9239990 100644 --- a/src/metric/metric_common.h +++ b/src/metric/metric_common.h @@ -9,8 +9,6 @@ #include #include "../collective/aggregator.h" -#include "../collective/communicator-inl.h" -#include "../common/common.h" #include "xgboost/metric.h" namespace xgboost { diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu index acaef7cf7..e51509fc7 100644 --- a/src/metric/multiclass_metric.cu +++ b/src/metric/multiclass_metric.cu @@ -9,8 +9,8 @@ #include #include #include +#include // for accumulate -#include "../collective/communicator-inl.h" #include 
"../common/math.h" #include "../common/threading_utils.h" #include "metric_common.h" // MetricNoCache diff --git a/src/metric/survival_metric.cu b/src/metric/survival_metric.cu index c64fece6c..9c57be3ab 100644 --- a/src/metric/survival_metric.cu +++ b/src/metric/survival_metric.cu @@ -9,10 +9,9 @@ #include #include +#include // for accumulate #include -#include "../collective/communicator-inl.h" -#include "../common/math.h" #include "../common/survival_util.h" #include "../common/threading_utils.h" #include "metric_common.h" // MetricNoCache diff --git a/tests/cpp/collective/test_allreduce.cc b/tests/cpp/collective/test_allreduce.cc index 21b4d9fd0..8359d17a6 100644 --- a/tests/cpp/collective/test_allreduce.cc +++ b/tests/cpp/collective/test_allreduce.cc @@ -3,6 +3,8 @@ */ #include +#include // for iota + #include "../../../src/collective/allreduce.h" #include "../../../src/collective/coll.h" // for Coll #include "../../../src/collective/tracker.h" diff --git a/tests/cpp/collective/test_worker.h b/tests/cpp/collective/test_worker.h index acee0f297..7b76052c8 100644 --- a/tests/cpp/collective/test_worker.h +++ b/tests/cpp/collective/test_worker.h @@ -1,11 +1,12 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include #include // for seconds #include // for int32_t +#include // for ifstream #include // for string #include // for thread #include // for move diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index 5391bc2cf..24e67c9aa 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -1,10 +1,9 @@ /** - * Copyright 2019-2023 by XGBoost Contributors + * Copyright 2019-2024, XGBoost Contributors */ #include #include #include -#include #include "../../../src/common/hist_util.h" #include "../../../src/data/gradient_index.h" @@ -135,7 +134,7 @@ TEST(CutsBuilder, SearchGroupInd) { group[2] = 7; group[3] = 5; - p_mat->SetInfo("group", 
group.data(), DataType::kUInt32, kNumGroups); + p_mat->SetInfo("group", Make1dInterfaceTest(group.data(), group.size())); HistogramCuts hmat; @@ -348,7 +347,8 @@ void TestSketchFromWeights(bool with_group) { for (size_t i = 0; i < kGroups; ++i) { groups[i] = kRows / kGroups; } - info.SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups); + auto sg = linalg::Make1dInterface(groups.data(), kGroups); + info.SetInfo(ctx, "group", sg.c_str()); } info.num_row_ = kRows; @@ -356,10 +356,10 @@ void TestSketchFromWeights(bool with_group) { // Assign weights. if (with_group) { - m->SetInfo("group", groups.data(), DataType::kUInt32, kGroups); + m->SetInfo("group", Make1dInterfaceTest(groups.data(), kGroups)); } - m->SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size()); + m->SetInfo("weight", Make1dInterfaceTest(h_weights.data(), h_weights.size())); m->Info().num_col_ = kCols; m->Info().num_row_ = kRows; ASSERT_EQ(cuts.Ptrs().size(), kCols + 1); diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 73af7115c..e37f02ddb 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -1,5 +1,5 @@ /** - * Copyright 2019-2023 by XGBoost Contributors + * Copyright 2019-2024, XGBoost Contributors */ #include #include @@ -682,7 +682,7 @@ TEST(HistUtil, DeviceSketchFromGroupWeights) { for (size_t i = 0; i < kGroups; ++i) { groups[i] = kRows / kGroups; } - m->SetInfo("group", groups.data(), DataType::kUInt32, kGroups); + m->SetInfo("group", Make1dInterfaceTest(groups.data(), kGroups)); HistogramCuts weighted_cuts = DeviceSketch(&ctx, m.get(), kBins, 0); // sketch with no weight @@ -727,7 +727,7 @@ void TestAdapterSketchFromWeights(bool with_group) { for (size_t i = 0; i < kGroups; ++i) { groups[i] = kRows / kGroups; } - info.SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups); + info.SetInfo(ctx, "group", Make1dInterfaceTest(groups.data(), kGroups)); } 
info.weights_.SetDevice(DeviceOrd::CUDA(0)); @@ -746,10 +746,10 @@ void TestAdapterSketchFromWeights(bool with_group) { auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols); if (with_group) { - dmat->Info().SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups); + dmat->Info().SetInfo(ctx, "group", Make1dInterfaceTest(groups.data(), kGroups)); } - dmat->Info().SetInfo(ctx, "weight", h_weights.data(), DataType::kFloat32, h_weights.size()); + dmat->Info().SetInfo(ctx, "weight", Make1dInterfaceTest(h_weights.data(), h_weights.size())); dmat->Info().num_col_ = kCols; dmat->Info().num_row_ = kRows; ASSERT_EQ(cuts.Ptrs().size(), kCols + 1); diff --git a/tests/cpp/common/test_transform_range.cc b/tests/cpp/common/test_transform_range.cc index 24d0267b6..4fc06f639 100644 --- a/tests/cpp/common/test_transform_range.cc +++ b/tests/cpp/common/test_transform_range.cc @@ -1,11 +1,12 @@ /** - * Copyright 2018-2023 by XGBoost Contributors + * Copyright 2018-2024, XGBoost Contributors */ #include #include -#include #include +#include +#include // for iota #include #include "../../../src/common/transform.h" diff --git a/tests/cpp/data/test_array_interface.cu b/tests/cpp/data/test_array_interface.cu index 00b996fb9..be8160c8a 100644 --- a/tests/cpp/data/test_array_interface.cu +++ b/tests/cpp/data/test_array_interface.cu @@ -1,10 +1,11 @@ /** - * Copyright 2021-2023, XGBoost Contributors + * Copyright 2021-2024, XGBoost Contributors */ #include #include -#include "../helpers.h" + #include "../../../src/data/array_interface.h" +#include "../helpers.h" namespace xgboost { diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index 0e63ab8f8..a7d9a0c76 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -10,7 +10,6 @@ #include #include -#include "../../../src/common/version.h" #include "../filesystem.h" // dmlc::TemporaryDirectory #include "../helpers.h" // for GMockTHrow #include "xgboost/base.h" @@ 
-23,23 +22,22 @@ TEST(MetaInfo, GetSet) { double double2[2] = {1.0, 2.0}; EXPECT_EQ(info.labels.Size(), 0); - info.SetInfo(ctx, "label", double2, xgboost::DataType::kFloat32, 2); + info.SetInfo(ctx, "label", Make1dInterfaceTest(double2, 2)); EXPECT_EQ(info.labels.Size(), 2); float float2[2] = {1.0f, 2.0f}; - EXPECT_EQ(info.GetWeight(1), 1.0f) - << "When no weights are given, was expecting default value 1"; - info.SetInfo(ctx, "weight", float2, xgboost::DataType::kFloat32, 2); + EXPECT_EQ(info.GetWeight(1), 1.0f) << "When no weights are given, was expecting default value 1"; + info.SetInfo(ctx, "weight", Make1dInterfaceTest(float2, 2)); EXPECT_EQ(info.GetWeight(1), 2.0f); uint32_t uint32_t2[2] = {1U, 2U}; EXPECT_EQ(info.base_margin_.Size(), 0); - info.SetInfo(ctx, "base_margin", uint32_t2, xgboost::DataType::kUInt32, 2); + info.SetInfo(ctx, "base_margin", Make1dInterfaceTest(uint32_t2, 2)); EXPECT_EQ(info.base_margin_.Size(), 2); uint64_t uint64_t2[2] = {1U, 2U}; EXPECT_EQ(info.group_ptr_.size(), 0); - info.SetInfo(ctx, "group", uint64_t2, xgboost::DataType::kUInt64, 2); + info.SetInfo(ctx, "group", Make1dInterfaceTest(uint64_t2, 2)); ASSERT_EQ(info.group_ptr_.size(), 3); EXPECT_EQ(info.group_ptr_[2], 3); @@ -135,9 +133,9 @@ TEST(MetaInfo, SaveLoadBinary) { }; std::vector values (kRows); std::generate(values.begin(), values.end(), generator); - info.SetInfo(ctx, "label", values.data(), xgboost::DataType::kFloat32, kRows); - info.SetInfo(ctx, "weight", values.data(), xgboost::DataType::kFloat32, kRows); - info.SetInfo(ctx, "base_margin", values.data(), xgboost::DataType::kFloat32, kRows); + info.SetInfo(ctx, "label", Make1dInterfaceTest(values.data(), kRows)); + info.SetInfo(ctx, "weight", Make1dInterfaceTest(values.data(), kRows)); + info.SetInfo(ctx, "base_margin", Make1dInterfaceTest(values.data(), kRows)); info.num_row_ = kRows; info.num_col_ = kCols; @@ -271,7 +269,7 @@ TEST(MetaInfo, CPUQid) { qid[i] = i; } - info.SetInfo(ctx, "qid", qid.data(), 
xgboost::DataType::kUInt32, info.num_row_); + info.SetInfo(ctx, "qid", Make1dInterfaceTest(qid.data(), info.num_row_)); ASSERT_EQ(info.group_ptr_.size(), info.num_row_ + 1); ASSERT_EQ(info.group_ptr_.front(), 0); ASSERT_EQ(info.group_ptr_.back(), info.num_row_); @@ -288,14 +286,12 @@ TEST(MetaInfo, Validate) { info.num_col_ = 3; std::vector groups (11); Context ctx; - info.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, 11); + info.SetInfo(ctx, "group", Make1dInterfaceTest(groups.data(), groups.size())); EXPECT_THROW(info.Validate(FstCU()), dmlc::Error); std::vector labels(info.num_row_ + 1); EXPECT_THROW( - { - info.SetInfo(ctx, "label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1); - }, + { info.SetInfo(ctx, "label", Make1dInterfaceTest(labels.data(), info.num_row_ + 1)); }, dmlc::Error); // Make overflow data, which can happen when users pass group structure as int @@ -305,13 +301,13 @@ TEST(MetaInfo, Validate) { groups.push_back(1562500); } groups.push_back(static_cast(-1)); - EXPECT_THROW(info.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, groups.size()), + EXPECT_THROW(info.SetInfo(ctx, "group", Make1dInterfaceTest(groups.data(), groups.size())), dmlc::Error); #if defined(XGBOOST_USE_CUDA) info.group_ptr_.clear(); labels.resize(info.num_row_); - info.SetInfo(ctx, "label", labels.data(), xgboost::DataType::kFloat32, info.num_row_); + info.SetInfo(ctx, "label", Make1dInterfaceTest(labels.data(), info.num_row_)); info.labels.SetDevice(FstCU()); EXPECT_THROW(info.Validate(DeviceOrd::CUDA(1)), dmlc::Error); @@ -340,8 +336,8 @@ TEST(MetaInfo, HostExtend) { for (size_t g = 0; g < kRows / per_group; ++g) { groups.emplace_back(per_group); } - lhs.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, groups.size()); - rhs.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, groups.size()); + lhs.SetInfo(ctx, "group", Make1dInterfaceTest(groups.data(), groups.size())); + rhs.SetInfo(ctx, 
"group", Make1dInterfaceTest(groups.data(), groups.size())); lhs.Extend(rhs, true, true); ASSERT_EQ(lhs.num_row_, kRows * 2); diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index 8f1588077..dcb89b971 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -408,7 +408,7 @@ class Dart : public testing::TestWithParam { for (size_t i = 0; i < kRows; ++i) { labels[i] = i % 2; } - p_mat->SetInfo("label", labels.data(), DataType::kFloat32, kRows); + p_mat->SetInfo("label", Make1dInterfaceTest(labels.data(), kRows)); auto learner = std::unique_ptr(Learner::Create({p_mat})); learner->SetParam("booster", "dart"); diff --git a/tests/cpp/helpers.cu b/tests/cpp/helpers.cu index db94da27a..f75628953 100644 --- a/tests/cpp/helpers.cu +++ b/tests/cpp/helpers.cu @@ -1,8 +1,11 @@ +/** + * Copyright 2020-2024, XGBoost contributors + */ #include -#include "helpers.h" #include "../../src/data/device_adapter.cuh" #include "../../src/data/iterative_dmatrix.h" +#include "helpers.h" namespace xgboost { diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index c161856bb..273cc0f00 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -15,19 +15,18 @@ #include // std::int32_t #include -#include -#include #include #include -#include #include #include "../../src/collective/communicator-inl.h" #include "../../src/common/common.h" #include "../../src/common/threading_utils.h" -#include "../../src/data/array_interface.h" #include "filesystem.h" // dmlc::TemporaryDirectory #include "xgboost/linalg.h" +#if !defined(_OPENMP) +#include +#endif #if defined(__CUDACC__) #define DeclareUnifiedTest(name) GPU ## name @@ -333,7 +332,7 @@ inline std::vector GenerateRandomCategoricalSingleColumn(int n, size_t nu std::vector x(n); std::mt19937 rng(0); std::uniform_int_distribution dist(0, num_categories - 1); - std::generate(x.begin(), x.end(), [&]() { return dist(rng); }); + std::generate(x.begin(), x.end(), [&]() { return static_cast(dist(rng)); 
}); // Make sure each category is present for (size_t i = 0; i < num_categories; i++) { x[i] = static_cast(i); @@ -494,6 +493,16 @@ inline int Next(DataIterHandle self) { return static_cast(self)->Next(); } +/** + * @brief Create an array interface for host vector. + */ +template +char const* Make1dInterfaceTest(T const* vec, std::size_t len) { + static thread_local std::string str; + str = linalg::Make1dInterface(vec, len); + return str.c_str(); +} + class RMMAllocator; using RMMAllocatorPtr = std::unique_ptr; RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv); diff --git a/tests/cpp/metric/test_elementwise_metric.h b/tests/cpp/metric/test_elementwise_metric.h index ef34d7651..4435c0807 100644 --- a/tests/cpp/metric/test_elementwise_metric.h +++ b/tests/cpp/metric/test_elementwise_metric.h @@ -5,10 +5,9 @@ #include #include -#include #include +#include // for iota -#include "../../../src/common/linalg_op.h" #include "../helpers.h" namespace xgboost::metric { diff --git a/tests/cpp/objective/test_regression_obj_cpu.cc b/tests/cpp/objective/test_regression_obj_cpu.cc index 3613d0d90..18ee4db7e 100644 --- a/tests/cpp/objective/test_regression_obj_cpu.cc +++ b/tests/cpp/objective/test_regression_obj_cpu.cc @@ -1,14 +1,15 @@ -/*! 
- * Copyright 2018-2023 XGBoost contributors +/** + * Copyright 2018-2024, XGBoost contributors */ #include #include #include -#include "../../../src/objective/adaptive.h" -#include "../../../src/tree/param.h" // for TrainParam -#include "../helpers.h" +#include // for iota +#include "../../../src/objective/adaptive.h" +#include "../../../src/tree/param.h" // for TrainParam +#include "../helpers.h" #include "test_regression_obj.h" namespace xgboost { diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index 6fe65b97e..541f53008 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -12,7 +12,6 @@ #include // for int32_t, int64_t, uint32_t #include // for size_t #include // for ofstream -#include // for back_insert_iterator, back_inserter #include // for numeric_limits #include // for map #include // for unique_ptr, shared_ptr, __shared_ptr_... @@ -30,7 +29,6 @@ #include "../../src/common/random.h" // for GlobalRandom #include "dmlc/io.h" // for Stream #include "dmlc/omp.h" // for omp_get_max_threads -#include "dmlc/registry.h" // for Registry #include "filesystem.h" // for TemporaryDirectory #include "helpers.h" // for GetBaseScore, RandomDataGenerator #include "objective_helpers.h" // for MakeObjNamesForTest, ObjTestNameGenerator @@ -103,9 +101,9 @@ TEST(Learner, CheckGroup) { labels[i] = i % 2; } - p_mat->SetInfo("weight", static_cast(weight.data()), DataType::kFloat32, kNumGroups); - p_mat->SetInfo("group", group.data(), DataType::kUInt32, kNumGroups); - p_mat->SetInfo("label", labels.data(), DataType::kFloat32, kNumRows); + p_mat->SetInfo("weight", Make1dInterfaceTest(weight.data(), kNumGroups)); + p_mat->SetInfo("group", Make1dInterfaceTest(group.data(), kNumGroups)); + p_mat->SetInfo("label", Make1dInterfaceTest(labels.data(), kNumRows)); std::vector> mat = {p_mat}; auto learner = std::unique_ptr(Learner::Create(mat)); @@ -115,7 +113,7 @@ TEST(Learner, CheckGroup) { group.resize(kNumGroups+1); group[3] = 4; group[4] = 1; - 
p_mat->SetInfo("group", group.data(), DataType::kUInt32, kNumGroups+1); + p_mat->SetInfo("group", Make1dInterfaceTest(group.data(), kNumGroups+1)); EXPECT_ANY_THROW(learner->UpdateOneIter(0, p_mat)); } @@ -132,7 +130,7 @@ TEST(Learner, SLOW_CheckMultiBatch) { // NOLINT for (size_t i = 0; i < num_row; ++i) { labels[i] = i % 2; } - dmat->SetInfo("label", labels.data(), DataType::kFloat32, num_row); + dmat->SetInfo("label", Make1dInterfaceTest(labels.data(), num_row)); std::vector> mat{dmat}; auto learner = std::unique_ptr(Learner::Create(mat)); learner->SetParams(Args{{"objective", "binary:logistic"}}); diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu index f7f2e27ea..84cd956db 100644 --- a/tests/cpp/tree/gpu_hist/test_histogram.cu +++ b/tests/cpp/tree/gpu_hist/test_histogram.cu @@ -239,4 +239,18 @@ void TestAtomicAdd() { TEST(Histogram, AtomicAddInt64) { TestAtomicAdd(); } + +TEST(Histogram, Quantiser) { + auto ctx = MakeCUDACtx(0); + std::size_t n_samples{16}; + HostDeviceVector gpair(n_samples, GradientPair{1.0, 1.0}); + gpair.SetDevice(ctx.Device()); + + auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo()); + for (auto v : gpair.ConstHostVector()) { + auto gh = quantiser.ToFloatingPoint(quantiser.ToFixedPoint(v)); + ASSERT_EQ(gh.GetGrad(), 1.0); + ASSERT_EQ(gh.GetHess(), 1.0); + } +} } // namespace xgboost::tree From 8bad677c2f5a178bdecc3c0fdf4d2720a90f8be9 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 30 Mar 2024 18:57:31 +0800 Subject: [PATCH 03/26] Update collective implementation. (#10152) * Update collective implementation. - Cleanup resource during `Finalize` to avoid handling threads in destructor. - Calculate the size for allgather automatically. - Use simple allgather for small (smaller than the number of worker) allreduce. 
--- plugin/federated/federated_coll.cc | 10 +-- plugin/federated/federated_coll.cu | 5 +- plugin/federated/federated_coll.cuh | 5 +- plugin/federated/federated_coll.h | 8 +- plugin/federated/federated_comm.cuh | 3 +- plugin/federated/federated_comm.h | 15 +++- plugin/federated/federated_server.h | 8 +- plugin/federated/federated_tracker.h | 3 +- src/c_api/c_api.cc | 6 +- src/c_api/coll_c_api.cc | 27 +++++-- src/collective/allgather.cc | 21 +++-- src/collective/allgather.h | 22 +++--- src/collective/allreduce.cc | 78 ++++++++++++++++--- src/collective/coll.cc | 5 +- src/collective/coll.cu | 7 +- src/collective/coll.cuh | 8 +- src/collective/coll.h | 4 +- src/collective/comm.cc | 26 ++++--- src/collective/comm.cu | 2 +- src/collective/comm.cuh | 4 + src/collective/comm.h | 25 ++++-- src/collective/comm_group.cc | 19 ++--- src/collective/comm_group.h | 21 ++++- src/collective/in_memory_handler.h | 3 +- tests/cpp/collective/test_allgather.cc | 6 +- tests/cpp/collective/test_allgather.cu | 6 +- tests/cpp/collective/test_allreduce.cc | 3 +- tests/cpp/collective/test_allreduce.cu | 2 +- .../plugin/federated/test_federated_coll.cc | 3 +- .../plugin/federated/test_federated_coll.cu | 4 +- tests/cpp/plugin/test_federated_adapter.cu | 1 - 31 files changed, 233 insertions(+), 127 deletions(-) diff --git a/plugin/federated/federated_coll.cc b/plugin/federated/federated_coll.cc index 980992d61..b62abdada 100644 --- a/plugin/federated/federated_coll.cc +++ b/plugin/federated/federated_coll.cc @@ -89,19 +89,15 @@ Coll *FederatedColl::MakeCUDAVar() { [[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span data, std::int32_t root) { - if (comm.Rank() == root) { - return BroadcastImpl(comm, &this->sequence_number_, data, root); - } else { - return BroadcastImpl(comm, &this->sequence_number_, data, root); - } + return BroadcastImpl(comm, &this->sequence_number_, data, root); } -[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span data, - 
std::int64_t size) { +[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span data) { using namespace federated; // NOLINT auto fed = dynamic_cast(&comm); CHECK(fed); auto stub = fed->Handle(); + auto size = data.size_bytes() / comm.World(); auto offset = comm.Rank() * size; auto segment = data.subspan(offset, size); diff --git a/plugin/federated/federated_coll.cu b/plugin/federated/federated_coll.cu index a922e1c11..3f604c50d 100644 --- a/plugin/federated/federated_coll.cu +++ b/plugin/federated/federated_coll.cu @@ -53,8 +53,7 @@ Coll *FederatedColl::MakeCUDAVar() { }; } -[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span data, - std::int64_t size) { +[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span data) { auto cufed = dynamic_cast(&comm); CHECK(cufed); std::vector h_data(data.size()); @@ -63,7 +62,7 @@ Coll *FederatedColl::MakeCUDAVar() { return GetCUDAResult( cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost)); } << [&] { - return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()}, size); + return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()}); } << [&] { return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(), cudaMemcpyHostToDevice, cufed->Stream())); diff --git a/plugin/federated/federated_coll.cuh b/plugin/federated/federated_coll.cuh index a1121d88f..6a690a33d 100644 --- a/plugin/federated/federated_coll.cuh +++ b/plugin/federated/federated_coll.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost contributors + * Copyright 2023-2024, XGBoost contributors */ #include "../../src/collective/comm.h" // for Comm, Coll #include "federated_coll.h" // for FederatedColl @@ -16,8 +16,7 @@ class CUDAFederatedColl : public Coll { ArrayInterfaceHandler::Type type, Op op) override; [[nodiscard]] Result Broadcast(Comm const &comm, common::Span data, std::int32_t root) override; - [[nodiscard]] Result 
Allgather(Comm const &, common::Span data, - std::int64_t size) override; + [[nodiscard]] Result Allgather(Comm const &, common::Span data) override; [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span data, common::Span sizes, common::Span recv_segments, diff --git a/plugin/federated/federated_coll.h b/plugin/federated/federated_coll.h index c261b01e1..12443a3e1 100644 --- a/plugin/federated/federated_coll.h +++ b/plugin/federated/federated_coll.h @@ -1,12 +1,9 @@ /** - * Copyright 2023, XGBoost contributors + * Copyright 2023-2024, XGBoost contributors */ #pragma once #include "../../src/collective/coll.h" // for Coll #include "../../src/collective/comm.h" // for Comm -#include "../../src/common/io.h" // for ReadAll -#include "../../src/common/json_utils.h" // for OptionalArg -#include "xgboost/json.h" // for Json namespace xgboost::collective { class FederatedColl : public Coll { @@ -20,8 +17,7 @@ class FederatedColl : public Coll { ArrayInterfaceHandler::Type type, Op op) override; [[nodiscard]] Result Broadcast(Comm const &comm, common::Span data, std::int32_t root) override; - [[nodiscard]] Result Allgather(Comm const &, common::Span data, - std::int64_t) override; + [[nodiscard]] Result Allgather(Comm const &, common::Span data) override; [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span data, common::Span sizes, common::Span recv_segments, diff --git a/plugin/federated/federated_comm.cuh b/plugin/federated/federated_comm.cuh index 58c52f67e..85cecb3eb 100644 --- a/plugin/federated/federated_comm.cuh +++ b/plugin/federated/federated_comm.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once @@ -9,7 +9,6 @@ #include "../../src/common/device_helpers.cuh" // for CUDAStreamView #include "federated_comm.h" // for FederatedComm #include "xgboost/context.h" // for Context -#include "xgboost/logging.h" namespace xgboost::collective { class CUDAFederatedComm : public 
FederatedComm { diff --git a/plugin/federated/federated_comm.h b/plugin/federated/federated_comm.h index 750d94abd..b39e1878a 100644 --- a/plugin/federated/federated_comm.h +++ b/plugin/federated/federated_comm.h @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost contributors + * Copyright 2023-2024, XGBoost contributors */ #pragma once @@ -11,7 +11,6 @@ #include // for string #include "../../src/collective/comm.h" // for HostComm -#include "../../src/common/json_utils.h" // for OptionalArg #include "xgboost/json.h" namespace xgboost::collective { @@ -51,6 +50,10 @@ class FederatedComm : public HostComm { std::int32_t rank) { this->Init(host, port, world, rank, {}, {}, {}); } + [[nodiscard]] Result Shutdown() final { + this->ResetState(); + return Success(); + } ~FederatedComm() override { stub_.reset(); } [[nodiscard]] std::shared_ptr Chan(std::int32_t) const override { @@ -65,5 +68,13 @@ class FederatedComm : public HostComm { [[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); } [[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr pimpl) const override; + /** + * @brief Get a string ID for the current process. 
+ */ + [[nodiscard]] Result ProcessorName(std::string* out) const final { + auto rank = this->Rank(); + *out = "rank:" + std::to_string(rank); + return Success(); + }; }; } // namespace xgboost::collective diff --git a/plugin/federated/federated_server.h b/plugin/federated/federated_server.h index de760d9d8..4692ad6c2 100644 --- a/plugin/federated/federated_server.h +++ b/plugin/federated/federated_server.h @@ -1,22 +1,18 @@ /** - * Copyright 2022-2023, XGBoost contributors + * Copyright 2022-2024, XGBoost contributors */ #pragma once #include #include // for int32_t -#include // for future #include "../../src/collective/in_memory_handler.h" -#include "../../src/collective/tracker.h" // for Tracker -#include "xgboost/collective/result.h" // for Result namespace xgboost::federated { class FederatedService final : public Federated::Service { public: - explicit FederatedService(std::int32_t world_size) - : handler_{static_cast(world_size)} {} + explicit FederatedService(std::int32_t world_size) : handler_{world_size} {} grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, AllgatherReply* reply) override; diff --git a/plugin/federated/federated_tracker.h b/plugin/federated/federated_tracker.h index 33592fefe..ac46b6eaa 100644 --- a/plugin/federated/federated_tracker.h +++ b/plugin/federated/federated_tracker.h @@ -17,8 +17,7 @@ namespace xgboost::collective { namespace federated { class FederatedService final : public Federated::Service { public: - explicit FederatedService(std::int32_t world_size) - : handler_{static_cast(world_size)} {} + explicit FederatedService(std::int32_t world_size) : handler_{world_size} {} grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, AllgatherReply* reply) override; diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 598b7f2f5..79d9793e6 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -694,9 +694,9 @@ XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle 
handle, const char *field, void common::Span{cast_d_ptr, static_cast::index_type>(size)}, {size}, DeviceOrd::CPU()); CHECK(t.CContiguous()); - Json interface{linalg::ArrayInterface(t)}; - CHECK(ArrayInterface<1>{interface}.is_contiguous); - str = Json::Dump(interface); + Json iface{linalg::ArrayInterface(t)}; + CHECK(ArrayInterface<1>{iface}.is_contiguous); + str = Json::Dump(iface); return str; }; diff --git a/src/c_api/coll_c_api.cc b/src/c_api/coll_c_api.cc index 01713dbad..24e94f3de 100644 --- a/src/c_api/coll_c_api.cc +++ b/src/c_api/coll_c_api.cc @@ -1,8 +1,7 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include // for seconds -#include // for size_t #include // for future #include // for unique_ptr #include // for string @@ -10,6 +9,7 @@ #include // for pair #include "../collective/tracker.h" // for RabitTracker +#include "../common/timer.h" // for Timer #include "c_api_error.h" // for API_BEGIN #include "xgboost/c_api.h" #include "xgboost/collective/result.h" // for Result @@ -40,17 +40,27 @@ struct CollAPIEntry { }; using CollAPIThreadLocalStore = dmlc::ThreadLocalStore; -void WaitImpl(TrackerHandleT *ptr) { - std::chrono::seconds wait_for{100}; +void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) { + constexpr std::int64_t kDft{60}; + std::chrono::seconds wait_for{timeout.count() != 0 ? 
std::min(kDft, timeout.count()) : kDft}; + + common::Timer timer; + timer.Start(); + auto fut = ptr->second; while (fut.valid()) { auto res = fut.wait_for(wait_for); CHECK(res != std::future_status::deferred); + if (res == std::future_status::ready) { auto const &rc = ptr->second.get(); - CHECK(rc.OK()) << rc.Report(); + collective::SafeColl(rc); break; } + + if (timer.Duration() > timeout && timeout.count() != 0) { + collective::SafeColl(collective::Fail("Timeout waiting for the tracker.")); + } } } } // namespace @@ -106,14 +116,17 @@ XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) { auto *ptr = GetTrackerHandle(handle); xgboost_CHECK_C_ARG_PTR(config); auto jconfig = Json::Load(StringView{config}); - WaitImpl(ptr); + // Internally, 0 indicates no timeout, which is the default since we don't want to + // interrupt the model training. + auto timeout = OptionalArg(jconfig, "timeout", std::int64_t{0}); + WaitImpl(ptr, std::chrono::seconds{timeout}); API_END(); } XGB_DLL int XGTrackerFree(TrackerHandle handle) { API_BEGIN(); auto *ptr = GetTrackerHandle(handle); - WaitImpl(ptr); + WaitImpl(ptr, ptr->first->Timeout()); delete ptr; API_END(); } diff --git a/src/collective/allgather.cc b/src/collective/allgather.cc index 148cb6cd2..446db73b5 100644 --- a/src/collective/allgather.cc +++ b/src/collective/allgather.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include "allgather.h" @@ -7,6 +7,7 @@ #include // for size_t #include // for int8_t, int32_t, int64_t #include // for shared_ptr +#include // for move #include "broadcast.h" #include "comm.h" // for Comm, Channel @@ -29,16 +30,20 @@ Result RingAllgather(Comm const& comm, common::Span data, std::size auto rc = Success() << [&] { auto send_rank = (rank + world - r + worker_off) % world; auto send_off = send_rank * segment_size; - send_off = std::min(send_off, data.size_bytes()); - auto send_seg = data.subspan(send_off, 
std::min(segment_size, data.size_bytes() - send_off)); + bool is_last_segment = send_rank == (world - 1); + auto send_nbytes = is_last_segment ? (data.size_bytes() - send_off) : segment_size; + auto send_seg = data.subspan(send_off, send_nbytes); return next_ch->SendAll(send_seg.data(), send_seg.size_bytes()); } << [&] { auto recv_rank = (rank + world - r - 1 + worker_off) % world; auto recv_off = recv_rank * segment_size; - recv_off = std::min(recv_off, data.size_bytes()); - auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off)); + bool is_last_segment = recv_rank == (world - 1); + auto recv_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : segment_size; + auto recv_seg = data.subspan(recv_off, recv_nbytes); return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes()); - } << [&] { return prev_ch->Block(); }; + } << [&] { + return prev_ch->Block(); + }; if (!rc.OK()) { return rc; } @@ -91,7 +96,9 @@ namespace detail { auto recv_size = sizes[recv_rank]; auto recv_seg = erased_result.subspan(recv_off, recv_size); return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes()); - } << [&] { return prev_ch->Block(); }; + } << [&] { + return prev_ch->Block(); + }; if (!rc.OK()) { return rc; } diff --git a/src/collective/allgather.h b/src/collective/allgather.h index 4f13014be..8de9f1984 100644 --- a/src/collective/allgather.h +++ b/src/collective/allgather.h @@ -1,25 +1,27 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include // for size_t #include // for int32_t #include // for shared_ptr #include // for accumulate +#include // for string #include // for remove_cv_t #include // for vector -#include "../common/type.h" // for EraseType +#include "../common/type.h" // for EraseType #include "comm.h" // for Comm, Channel +#include "comm_group.h" // for CommGroup #include "xgboost/collective/result.h" // for Result -#include "xgboost/linalg.h" -#include 
"xgboost/span.h" // for Span +#include "xgboost/linalg.h" // for MakeVec +#include "xgboost/span.h" // for Span namespace xgboost::collective { namespace cpu_impl { /** * @param worker_off Segment offset. For example, if the rank 2 worker specifies - * worker_off = 1, then it owns the third segment. + * worker_off = 1, then it owns the third segment (2 + 1). */ [[nodiscard]] Result RingAllgather(Comm const& comm, common::Span data, std::size_t segment_size, std::int32_t worker_off, @@ -51,8 +53,10 @@ inline void AllgatherVOffset(common::Span sizes, } // namespace detail template -[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span data, std::size_t size) { - auto n_bytes = sizeof(T) * size; +[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span data) { + // This function is also used for ring allreduce, hence we allow the last segment to be + // larger due to round-down. + auto n_bytes_per_segment = data.size_bytes() / comm.World(); auto erased = common::EraseType(data); auto rank = comm.Rank(); @@ -61,7 +65,7 @@ template auto prev_ch = comm.Chan(prev); auto next_ch = comm.Chan(next); - auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes, 0, prev_ch, next_ch); + auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes_per_segment, 0, prev_ch, next_ch); if (!rc.OK()) { return rc; } @@ -76,7 +80,7 @@ template std::vector sizes(world, 0); sizes[rank] = data.size_bytes(); - auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()}, 1); + auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()}); if (!rc.OK()) { return rc; } diff --git a/src/collective/allreduce.cc b/src/collective/allreduce.cc index 93b76355f..d9cf8b828 100644 --- a/src/collective/allreduce.cc +++ b/src/collective/allreduce.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include "allreduce.h" @@ -16,7 +16,44 @@ #include "xgboost/span.h" // for Span namespace 
xgboost::collective::cpu_impl { +namespace { template +Result RingAllreduceSmall(Comm const& comm, common::Span data, Func const& op) { + auto rank = comm.Rank(); + auto world = comm.World(); + + auto next_ch = comm.Chan(BootstrapNext(rank, world)); + auto prev_ch = comm.Chan(BootstrapPrev(rank, world)); + + std::vector buffer(data.size_bytes() * world, 0); + auto s_buffer = common::Span{buffer.data(), buffer.size()}; + + auto offset = data.size_bytes() * rank; + auto self = s_buffer.subspan(offset, data.size_bytes()); + std::copy_n(data.data(), data.size_bytes(), self.data()); + + auto typed = common::RestoreType(s_buffer); + auto rc = RingAllgather(comm, typed); + + if (!rc.OK()) { + return rc; + } + auto first = s_buffer.subspan(0, data.size_bytes()); + CHECK_EQ(first.size(), data.size()); + + for (std::int32_t r = 1; r < world; ++r) { + auto offset = data.size_bytes() * r; + auto buf = s_buffer.subspan(offset, data.size_bytes()); + op(buf, first); + } + std::copy_n(first.data(), first.size(), data.data()); + + return Success(); +} +} // namespace + +template +// note that n_bytes_in_seg is calculated with round-down. 
Result RingScatterReduceTyped(Comm const& comm, common::Span data, std::size_t n_bytes_in_seg, Func const& op) { auto rank = comm.Rank(); @@ -27,14 +64,17 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span data, auto next_ch = comm.Chan(dst_rank); auto prev_ch = comm.Chan(src_rank); - std::vector buffer(n_bytes_in_seg, 0); + std::vector buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, 0); auto s_buf = common::Span{buffer.data(), buffer.size()}; for (std::int32_t r = 0; r < world - 1; ++r) { // send to ring next - auto send_off = ((rank + world - r) % world) * n_bytes_in_seg; - send_off = std::min(send_off, data.size_bytes()); - auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg); + auto send_rank = (rank + world - r) % world; + auto send_off = send_rank * n_bytes_in_seg; + + bool is_last_segment = send_rank == (world - 1); + + auto seg_nbytes = is_last_segment ? data.size_bytes() - send_off : n_bytes_in_seg; auto send_seg = data.subspan(send_off, seg_nbytes); auto rc = next_ch->SendAll(send_seg); @@ -43,14 +83,21 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span data, } // receive from ring prev - auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg; - recv_off = std::min(recv_off, data.size_bytes()); - seg_nbytes = std::min(data.size_bytes() - recv_off, n_bytes_in_seg); + auto recv_rank = (rank + world - r - 1) % world; + auto recv_off = recv_rank * n_bytes_in_seg; + + is_last_segment = recv_rank == (world - 1); + + seg_nbytes = is_last_segment ? 
data.size_bytes() - recv_off : n_bytes_in_seg; CHECK_EQ(seg_nbytes % sizeof(T), 0); auto recv_seg = data.subspan(recv_off, seg_nbytes); auto seg = s_buf.subspan(0, recv_seg.size()); - rc = std::move(rc) << [&] { return prev_ch->RecvAll(seg); } << [&] { return comm.Block(); }; + rc = std::move(rc) << [&] { + return prev_ch->RecvAll(seg); + } << [&] { + return comm.Block(); + }; if (!rc.OK()) { return rc; } @@ -68,6 +115,9 @@ Result RingAllreduce(Comm const& comm, common::Span data, Func cons if (comm.World() == 1) { return Success(); } + if (data.size_bytes() == 0) { + return Success(); + } return DispatchDType(type, [&](auto t) { using T = decltype(t); // Divide the data into segments according to the number of workers. @@ -75,7 +125,11 @@ Result RingAllreduce(Comm const& comm, common::Span data, Func cons CHECK_EQ(data.size_bytes() % n_bytes_elem, 0); auto n = data.size_bytes() / n_bytes_elem; auto world = comm.World(); - auto n_bytes_in_seg = common::DivRoundUp(n, world) * sizeof(T); + if (n < static_cast(world)) { + return RingAllreduceSmall(comm, data, op); + } + + auto n_bytes_in_seg = (n / world) * sizeof(T); auto rc = RingScatterReduceTyped(comm, data, n_bytes_in_seg, op); if (!rc.OK()) { return rc; @@ -88,7 +142,9 @@ Result RingAllreduce(Comm const& comm, common::Span data, Func cons return std::move(rc) << [&] { return RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch); - } << [&] { return comm.Block(); }; + } << [&] { + return comm.Block(); + }; }); } } // namespace xgboost::collective::cpu_impl diff --git a/src/collective/coll.cc b/src/collective/coll.cc index 1f47d0c55..c6d03c6df 100644 --- a/src/collective/coll.cc +++ b/src/collective/coll.cc @@ -104,9 +104,8 @@ bool constexpr IsFloatingPointV() { return cpu_impl::Broadcast(comm, data, root); } -[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span data, - std::int64_t size) { - return RingAllgather(comm, data, size); +[[nodiscard]] Result Coll::Allgather(Comm const& comm, 
common::Span data) { + return RingAllgather(comm, data); } [[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span data, diff --git a/src/collective/coll.cu b/src/collective/coll.cu index d1b66a8ce..b06435bfe 100644 --- a/src/collective/coll.cu +++ b/src/collective/coll.cu @@ -1,10 +1,9 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #if defined(XGBOOST_USE_NCCL) #include // for int8_t, int64_t -#include "../common/cuda_context.cuh" #include "../common/device_helpers.cuh" #include "../data/array_interface.h" #include "allgather.h" // for AllgatherVOffset @@ -162,14 +161,14 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) { } << [&] { return nccl->Block(); }; } -[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span data, - std::int64_t size) { +[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span data) { if (!comm.IsDistributed()) { return Success(); } auto nccl = dynamic_cast(&comm); CHECK(nccl); auto stub = nccl->Stub(); + auto size = data.size_bytes() / comm.World(); auto send = data.subspan(comm.Rank() * size, size); return Success() << [&] { diff --git a/src/collective/coll.cuh b/src/collective/coll.cuh index 6ededd101..4d45295d7 100644 --- a/src/collective/coll.cuh +++ b/src/collective/coll.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once @@ -8,8 +8,7 @@ #include "../data/array_interface.h" // for ArrayInterfaceHandler #include "coll.h" // for Coll #include "comm.h" // for Comm -#include "nccl_stub.h" -#include "xgboost/span.h" // for Span +#include "xgboost/span.h" // for Span namespace xgboost::collective { class NCCLColl : public Coll { @@ -20,8 +19,7 @@ class NCCLColl : public Coll { ArrayInterfaceHandler::Type type, Op op) override; [[nodiscard]] Result Broadcast(Comm const& comm, common::Span data, std::int32_t root) override; - [[nodiscard]] Result Allgather(Comm const& comm, 
common::Span data, - std::int64_t size) override; + [[nodiscard]] Result Allgather(Comm const& comm, common::Span data) override; [[nodiscard]] Result AllgatherV(Comm const& comm, common::Span data, common::Span sizes, common::Span recv_segments, diff --git a/src/collective/coll.h b/src/collective/coll.h index 1afc8ed59..96fe35229 100644 --- a/src/collective/coll.h +++ b/src/collective/coll.h @@ -48,10 +48,8 @@ class Coll : public std::enable_shared_from_this { * @brief Allgather * * @param [in,out] data Data buffer for input and output. - * @param [in] size Size of data for each worker. */ - [[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span data, - std::int64_t size); + [[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span data); /** * @brief Allgather with variable length. * diff --git a/src/collective/comm.cc b/src/collective/comm.cc index 8260b28f6..23a8e89ed 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include "comm.h" @@ -9,8 +9,9 @@ #include // for shared_ptr #include // for string #include // for move, forward - -#include "../common/common.h" // for AssertGPUSupport +#if !defined(XGBOOST_USE_NCCL) +#include "../common/common.h" // for AssertNCCLSupport +#endif // !defined(XGBOOST_USE_NCCL) #include "allgather.h" // for RingAllgather #include "protocol.h" // for kMagic #include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE @@ -21,11 +22,7 @@ namespace xgboost::collective { Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, std::int32_t retry, std::string task_id) - : timeout_{timeout}, - retry_{retry}, - tracker_{host, port, -1}, - task_id_{std::move(task_id)}, - loop_{std::shared_ptr{new Loop{timeout}}} {} + : timeout_{timeout}, retry_{retry}, tracker_{host, port, -1}, task_id_{std::move(task_id)} {} Result ConnectTrackerImpl(proto::PeerInfo info, 
std::chrono::seconds timeout, std::int32_t retry, std::string const& task_id, TCPSocket* out, std::int32_t rank, @@ -191,6 +188,7 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se std::int32_t retry, std::string task_id, StringView nccl_path) : HostComm{std::move(host), port, timeout, retry, std::move(task_id)}, nccl_path_{std::move(nccl_path)} { + loop_.reset(new Loop{std::chrono::seconds{timeout_}}); // NOLINT auto rc = this->Bootstrap(timeout_, retry_, task_id_); if (!rc.OK()) { SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc))); @@ -254,9 +252,6 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr) const { // get ring neighbors std::string snext; tracker.Recv(&snext); - if (!rc.OK()) { - return Fail("Failed to receive the rank for the next worker.", std::move(rc)); - } auto jnext = Json::Load(StringView{snext}); proto::PeerInfo ninfo{jnext}; @@ -295,6 +290,10 @@ RabitComm::~RabitComm() noexcept(false) { } [[nodiscard]] Result RabitComm::Shutdown() { + if (!this->IsDistributed()) { + return Success(); + } + TCPSocket tracker; return Success() << [&] { return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World()); @@ -308,6 +307,11 @@ RabitComm::~RabitComm() noexcept(false) { if (n_bytes != scmd.size()) { return Fail("Faled to send cmd."); } + + this->ResetState(); + return Success(); + } << [&] { + this->channels_.clear(); return Success(); }; } diff --git a/src/collective/comm.cu b/src/collective/comm.cu index 56681253c..8788a2436 100644 --- a/src/collective/comm.cu +++ b/src/collective/comm.cu @@ -80,7 +80,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p auto s_this_uuid = s_uuid.subspan(root.Rank() * kUuidLength, kUuidLength); GetCudaUUID(s_this_uuid, ctx->Device()); - auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes()); + auto rc = pimpl->Allgather(root, common::EraseType(s_uuid)); 
CHECK(rc.OK()) << rc.Report(); diff --git a/src/collective/comm.cuh b/src/collective/comm.cuh index a818d95f8..4add9ca61 100644 --- a/src/collective/comm.cuh +++ b/src/collective/comm.cuh @@ -50,6 +50,10 @@ class NCCLComm : public Comm { auto rc = this->Stream().Sync(false); return GetCUDAResult(rc); } + [[nodiscard]] Result Shutdown() final { + this->ResetState(); + return Success(); + } }; class NCCLChannel : public Channel { diff --git a/src/collective/comm.h b/src/collective/comm.h index 82aa2c45e..6ad5bc5c1 100644 --- a/src/collective/comm.h +++ b/src/collective/comm.h @@ -14,7 +14,7 @@ #include "loop.h" // for Loop #include "protocol.h" // for PeerInfo #include "xgboost/collective/result.h" // for Result -#include "xgboost/collective/socket.h" // for TCPSocket +#include "xgboost/collective/socket.h" // for TCPSocket, GetHostName #include "xgboost/context.h" // for Context #include "xgboost/span.h" // for Span @@ -54,8 +54,12 @@ class Comm : public std::enable_shared_from_this { std::thread error_worker_; std::string task_id_; std::vector> channels_; - std::shared_ptr loop_{new Loop{std::chrono::seconds{ - DefaultTimeoutSec()}}}; // fixme: require federated comm to have a timeout + std::shared_ptr loop_{nullptr}; // fixme: require federated comm to have a timeout + + void ResetState() { + this->world_ = -1; + this->rank_ = 0; + } public: Comm() = default; @@ -78,7 +82,10 @@ class Comm : public std::enable_shared_from_this { [[nodiscard]] auto Rank() const { return rank_; } [[nodiscard]] auto World() const { return IsDistributed() ? 
world_ : 1; } [[nodiscard]] bool IsDistributed() const { return world_ != -1; } - void Submit(Loop::Op op) const { loop_->Submit(op); } + void Submit(Loop::Op op) const { + CHECK(loop_); + loop_->Submit(op); + } [[nodiscard]] virtual Result Block() const { return loop_->Block(); } [[nodiscard]] virtual std::shared_ptr Chan(std::int32_t rank) const { @@ -88,6 +95,14 @@ class Comm : public std::enable_shared_from_this { [[nodiscard]] virtual Result LogTracker(std::string msg) const = 0; [[nodiscard]] virtual Result SignalError(Result const&) { return Success(); } + /** + * @brief Get a string ID for the current process. + */ + [[nodiscard]] virtual Result ProcessorName(std::string* out) const { + auto rc = GetHostName(out); + return rc; + } + [[nodiscard]] virtual Result Shutdown() = 0; }; /** @@ -105,7 +120,7 @@ class RabitComm : public HostComm { [[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry, std::string task_id); - [[nodiscard]] Result Shutdown(); + [[nodiscard]] Result Shutdown() final; public: // bootstrapping construction. 
diff --git a/src/collective/comm_group.cc b/src/collective/comm_group.cc index f7bbba754..7408882f6 100644 --- a/src/collective/comm_group.cc +++ b/src/collective/comm_group.cc @@ -1,22 +1,21 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include "comm_group.h" #include // for transform +#include // for tolower #include // for seconds #include // for int32_t +#include // for back_inserter #include // for shared_ptr, unique_ptr #include // for string -#include // for vector -#include "../common/json_utils.h" // for OptionalArg -#include "coll.h" // for Coll -#include "comm.h" // for Comm -#include "tracker.h" // for GetHostAddress -#include "xgboost/collective/result.h" // for Result -#include "xgboost/context.h" // for DeviceOrd -#include "xgboost/json.h" // for Json +#include "../common/json_utils.h" // for OptionalArg +#include "coll.h" // for Coll +#include "comm.h" // for Comm +#include "xgboost/context.h" // for DeviceOrd +#include "xgboost/json.h" // for Json #if defined(XGBOOST_USE_FEDERATED) #include "../../plugin/federated/federated_coll.h" @@ -117,6 +116,8 @@ void GlobalCommGroupInit(Json config) { void GlobalCommGroupFinalize() { auto& sptr = GlobalCommGroup(); + auto rc = sptr->Finalize(); sptr.reset(); + SafeColl(rc); } } // namespace xgboost::collective diff --git a/src/collective/comm_group.h b/src/collective/comm_group.h index 2f6f91d73..61a58ba56 100644 --- a/src/collective/comm_group.h +++ b/src/collective/comm_group.h @@ -9,7 +9,6 @@ #include "coll.h" // for Comm #include "comm.h" // for Coll #include "xgboost/collective/result.h" // for Result -#include "xgboost/collective/socket.h" // for GetHostName namespace xgboost::collective { /** @@ -35,15 +34,31 @@ class CommGroup { [[nodiscard]] auto Rank() const { return comm_->Rank(); } [[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); } + [[nodiscard]] Result Finalize() const { + return Success() << [this] { + if (gpu_comm_) { 
+ return gpu_comm_->Shutdown(); + } + return Success(); + } << [&] { + return comm_->Shutdown(); + }; + } + [[nodiscard]] static CommGroup* Create(Json config); [[nodiscard]] std::shared_ptr Backend(DeviceOrd device) const; + /** + * @brief Decide the context to use for communication. + * + * @param ctx Global context, provides the CUDA stream and ordinal. + * @param device The device used by the data to be communicated. + */ [[nodiscard]] Comm const& Ctx(Context const* ctx, DeviceOrd device) const; [[nodiscard]] Result SignalError(Result const& res) { return comm_->SignalError(res); } [[nodiscard]] Result ProcessorName(std::string* out) const { - auto rc = GetHostName(out); - return rc; + return this->comm_->ProcessorName(out); } }; diff --git a/src/collective/in_memory_handler.h b/src/collective/in_memory_handler.h index f9ac52007..e9c69f537 100644 --- a/src/collective/in_memory_handler.h +++ b/src/collective/in_memory_handler.h @@ -32,7 +32,8 @@ class InMemoryHandler { * * This is used when the handler only needs to be initialized once with a known world size. */ - explicit InMemoryHandler(std::size_t worldSize) : world_size_{worldSize} {} + explicit InMemoryHandler(std::int32_t worldSize) + : world_size_{static_cast(worldSize)} {} /** * @brief Initialize the handler with the world size and rank. 
diff --git a/tests/cpp/collective/test_allgather.cc b/tests/cpp/collective/test_allgather.cc index decad8786..b6158693b 100644 --- a/tests/cpp/collective/test_allgather.cc +++ b/tests/cpp/collective/test_allgather.cc @@ -34,7 +34,7 @@ class Worker : public WorkerForTest { std::vector data(comm_.World(), 0); data[comm_.Rank()] = comm_.Rank(); - auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()}, 1); + auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()}); ASSERT_TRUE(rc.OK()) << rc.Report(); for (std::int32_t r = 0; r < comm_.World(); ++r) { @@ -51,7 +51,7 @@ class Worker : public WorkerForTest { auto seg = s_data.subspan(comm_.Rank() * n, n); std::iota(seg.begin(), seg.end(), comm_.Rank()); - auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()}, n); + auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()}); ASSERT_TRUE(rc.OK()) << rc.Report(); for (std::int32_t r = 0; r < comm_.World(); ++r) { @@ -104,7 +104,7 @@ class Worker : public WorkerForTest { std::vector sizes(comm_.World(), 0); sizes[comm_.Rank()] = s_data.size_bytes(); - auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1); + auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}); ASSERT_TRUE(rc.OK()) << rc.Report(); std::shared_ptr pcoll{new Coll{}}; diff --git a/tests/cpp/collective/test_allgather.cu b/tests/cpp/collective/test_allgather.cu index 236108198..98ece7d17 100644 --- a/tests/cpp/collective/test_allgather.cu +++ b/tests/cpp/collective/test_allgather.cu @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #if defined(XGBOOST_USE_NCCL) #include @@ -33,7 +33,7 @@ class Worker : public NCCLWorkerForTest { // get size std::vector sizes(comm_.World(), -1); sizes[comm_.Rank()] = s_data.size_bytes(); - auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1); + auto rc = RingAllgather(comm_, 
common::Span{sizes.data(), sizes.size()}); ASSERT_TRUE(rc.OK()) << rc.Report(); // create result dh::device_vector result(comm_.World(), -1); @@ -57,7 +57,7 @@ class Worker : public NCCLWorkerForTest { // get size std::vector sizes(nccl_comm_->World(), 0); sizes[comm_.Rank()] = dh::ToSpan(data).size_bytes(); - auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1); + auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}); ASSERT_TRUE(rc.OK()) << rc.Report(); auto n_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0); // create result diff --git a/tests/cpp/collective/test_allreduce.cc b/tests/cpp/collective/test_allreduce.cc index 8359d17a6..457594cd9 100644 --- a/tests/cpp/collective/test_allreduce.cc +++ b/tests/cpp/collective/test_allreduce.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include @@ -7,7 +7,6 @@ #include "../../../src/collective/allreduce.h" #include "../../../src/collective/coll.h" // for Coll -#include "../../../src/collective/tracker.h" #include "../../../src/common/type.h" // for EraseType #include "test_worker.h" // for WorkerForTest, TestDistributed diff --git a/tests/cpp/collective/test_allreduce.cu b/tests/cpp/collective/test_allreduce.cu index 04ec9f773..f7e11dec2 100644 --- a/tests/cpp/collective/test_allreduce.cu +++ b/tests/cpp/collective/test_allreduce.cu @@ -5,7 +5,7 @@ #include #include // for host_vector -#include "../../../src/common/common.h" +#include "../../../src/common/common.h" // for AllVisibleGPUs #include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector #include "../../../src/common/type.h" // for EraseType #include "test_worker.cuh" // for NCCLWorkerForTest diff --git a/tests/cpp/plugin/federated/test_federated_coll.cc b/tests/cpp/plugin/federated/test_federated_coll.cc index ad053f286..6b7000ef9 100644 --- a/tests/cpp/plugin/federated/test_federated_coll.cc +++ 
b/tests/cpp/plugin/federated/test_federated_coll.cc @@ -60,8 +60,7 @@ TEST_F(FederatedCollTest, Allgather) { std::vector buffer(n_workers, 0); buffer[comm->Rank()] = comm->Rank(); - auto rc = coll.Allgather(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), - sizeof(int)); + auto rc = coll.Allgather(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()})); ASSERT_TRUE(rc.OK()); for (auto i = 0; i < n_workers; i++) { ASSERT_EQ(buffer[i], i); diff --git a/tests/cpp/plugin/federated/test_federated_coll.cu b/tests/cpp/plugin/federated/test_federated_coll.cu index a6ec7e352..237bdeb9d 100644 --- a/tests/cpp/plugin/federated/test_federated_coll.cu +++ b/tests/cpp/plugin/federated/test_federated_coll.cu @@ -5,13 +5,13 @@ #include #include // for Result +#include "../../../../src/collective/allreduce.h" #include "../../../../src/common/common.h" // for AllVisibleGPUs #include "../../../../src/common/device_helpers.cuh" // for device_vector #include "../../../../src/common/type.h" // for EraseType #include "../../collective/test_worker.h" // for SocketTest #include "../../helpers.h" // for MakeCUDACtx #include "federated_coll.cuh" -#include "federated_comm.cuh" #include "test_worker.h" // for TestFederated namespace xgboost::collective { @@ -71,7 +71,7 @@ void TestAllgather(std::shared_ptr comm, std::int32_t rank, std:: dh::device_vector buffer(n_workers, 0); buffer[comm->Rank()] = comm->Rank(); - auto rc = w.coll->Allgather(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), sizeof(int)); + auto rc = w.coll->Allgather(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer))); ASSERT_TRUE(rc.OK()); for (auto i = 0; i < n_workers; i++) { ASSERT_EQ(buffer[i], i); diff --git a/tests/cpp/plugin/test_federated_adapter.cu b/tests/cpp/plugin/test_federated_adapter.cu index cec180e70..b96524878 100644 --- a/tests/cpp/plugin/test_federated_adapter.cu +++ b/tests/cpp/plugin/test_federated_adapter.cu @@ -26,7 +26,6 @@ TEST(FederatedAdapterSimpleTest, 
ThrowOnInvalidDeviceOrdinal) { namespace { void VerifyAllReduceSum() { auto const world_size = collective::GetWorldSize(); - auto const rank = collective::GetRank(); auto const device = GPUIDX; int count = 3; common::SetDevice(device); From bc9ea62ec0da401bd49e31a5e5799e2ce5c78853 Mon Sep 17 00:00:00 2001 From: david-cortes Date: Sun, 31 Mar 2024 15:53:00 +0200 Subject: [PATCH 04/26] [R] Make `xgb.cv` work with `xgb.DMatrix` only, adding support for survival and ranking fields (#10031) --------- Co-authored-by: Philip Hyunsu Cho --- R-package/R/utils.R | 68 +++++++++++++---- R-package/R/xgb.DMatrix.R | 31 +++++--- R-package/R/xgb.cv.R | 92 +++++++++++++++-------- R-package/man/print.xgb.cv.Rd | 4 +- R-package/man/xgb.cv.Rd | 53 +++++++++---- R-package/man/xgb.slice.DMatrix.Rd | 10 ++- R-package/src/init.c | 4 +- R-package/src/xgboost_R.cc | 4 +- R-package/src/xgboost_R.h | 3 +- R-package/tests/testthat/test_basic.R | 60 ++++++++++++++- R-package/tests/testthat/test_callbacks.R | 4 +- R-package/tests/testthat/test_dmatrix.R | 40 +++++++++- 12 files changed, 283 insertions(+), 90 deletions(-) diff --git a/R-package/R/utils.R b/R-package/R/utils.R index 01c282a96..7b6a20f70 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -26,6 +26,11 @@ NVL <- function(x, val) { 'multi:softprob', 'rank:pairwise', 'rank:ndcg', 'rank:map')) } +.RANKING_OBJECTIVES <- function() { + return(c('binary:logistic', 'binary:logitraw', 'binary:hinge', 'multi:softmax', + 'multi:softprob')) +} + # # Low-level functions for boosting -------------------------------------------- @@ -235,33 +240,43 @@ convert.labels <- function(labels, objective_name) { } # Generates random (stratified if needed) CV folds -generate.cv.folds <- function(nfold, nrows, stratified, label, params) { +generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) { + if (NROW(group)) { + if (stratified) { + warning( + paste0( + "Stratified splitting is not supported when using 'group' 
attribute.", + " Will use unstratified splitting." + ) + ) + } + return(generate.group.folds(nfold, group)) + } + objective <- params$objective + if (!is.character(objective)) { + warning("Will use unstratified splitting (custom objective used)") + stratified <- FALSE + } + # cannot stratify if label is NULL + if (stratified && is.null(label)) { + warning("Will use unstratified splitting (no 'labels' available)") + stratified <- FALSE + } # cannot do it for rank - objective <- params$objective if (is.character(objective) && strtrim(objective, 5) == 'rank:') { - stop("\n\tAutomatic generation of CV-folds is not implemented for ranking!\n", + stop("\n\tAutomatic generation of CV-folds is not implemented for ranking without 'group' field!\n", "\tConsider providing pre-computed CV-folds through the 'folds=' parameter.\n") } # shuffle rnd_idx <- sample.int(nrows) - if (stratified && - length(label) == length(rnd_idx)) { + if (stratified && length(label) == length(rnd_idx)) { y <- label[rnd_idx] - # WARNING: some heuristic logic is employed to identify classification setting! # - For classification, need to convert y labels to factor before making the folds, # and then do stratification by factor levels. # - For regression, leave y numeric and do stratification by quantiles. if (is.character(objective)) { - y <- convert.labels(y, params$objective) - } else { - # If no 'objective' given in params, it means that user either wants to - # use the default 'reg:squarederror' objective or has provided a custom - # obj function. 
Here, assume classification setting when y has 5 or less - # unique values: - if (length(unique(y)) <= 5) { - y <- factor(y) - } + y <- convert.labels(y, objective) } folds <- xgb.createFolds(y = y, k = nfold) } else { @@ -277,6 +292,29 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) { return(folds) } +generate.group.folds <- function(nfold, group) { + ngroups <- length(group) - 1 + if (ngroups < nfold) { + stop("DMatrix has fewer groups than folds.") + } + seq_groups <- seq_len(ngroups) + indices <- lapply(seq_groups, function(gr) seq(group[gr] + 1, group[gr + 1])) + assignments <- base::split(seq_groups, as.integer(seq_groups %% nfold)) + assignments <- unname(assignments) + + out <- vector("list", nfold) + randomized_groups <- sample(ngroups) + for (idx in seq_len(nfold)) { + groups_idx_test <- randomized_groups[assignments[[idx]]] + groups_test <- indices[groups_idx_test] + idx_test <- unlist(groups_test) + attributes(idx_test)$group_test <- lengths(groups_test) + attributes(idx_test)$group_train <- lengths(indices[-groups_idx_test]) + out[[idx]] <- idx_test + } + return(out) +} + # Creates CV folds stratified by the values of y. # It was borrowed from caret::createFolds and simplified # by always returning an unnamed list of fold indices. diff --git a/R-package/R/xgb.DMatrix.R b/R-package/R/xgb.DMatrix.R index edbc267c1..15f6faed0 100644 --- a/R-package/R/xgb.DMatrix.R +++ b/R-package/R/xgb.DMatrix.R @@ -1259,8 +1259,11 @@ xgb.get.DMatrix.data <- function(dmat) { #' Get a new DMatrix containing the specified rows of #' original xgb.DMatrix object #' -#' @param object Object of class "xgb.DMatrix" -#' @param idxset a integer vector of indices of rows needed +#' @param object Object of class "xgb.DMatrix". +#' @param idxset An integer vector of indices of rows needed (base-1 indexing). +#' @param allow_groups Whether to allow slicing an `xgb.DMatrix` with `group` (or +#' equivalently `qid`) field. 
Note that in such case, the result will not have +#' the groups anymore - they need to be set manually through `setinfo`. #' @param colset currently not used (columns subsetting is not available) #' #' @examples @@ -1275,11 +1278,11 @@ xgb.get.DMatrix.data <- function(dmat) { #' #' @rdname xgb.slice.DMatrix #' @export -xgb.slice.DMatrix <- function(object, idxset) { +xgb.slice.DMatrix <- function(object, idxset, allow_groups = FALSE) { if (!inherits(object, "xgb.DMatrix")) { stop("object must be xgb.DMatrix") } - ret <- .Call(XGDMatrixSliceDMatrix_R, object, idxset) + ret <- .Call(XGDMatrixSliceDMatrix_R, object, idxset, allow_groups) attr_list <- attributes(object) nr <- nrow(object) @@ -1296,7 +1299,15 @@ xgb.slice.DMatrix <- function(object, idxset) { } } } - return(structure(ret, class = "xgb.DMatrix")) + + out <- structure(ret, class = "xgb.DMatrix") + parent_fields <- as.list(attributes(object)$fields) + if (NROW(parent_fields)) { + child_fields <- parent_fields[!(names(parent_fields) %in% c("group", "qid"))] + child_fields <- as.environment(child_fields) + attributes(out)$fields <- child_fields + } + return(out) } #' @rdname xgb.slice.DMatrix @@ -1340,11 +1351,11 @@ print.xgb.DMatrix <- function(x, verbose = FALSE, ...) 
{ } cat(class_print, ' dim:', nrow(x), 'x', ncol(x), ' info: ') - infos <- character(0) - if (xgb.DMatrix.hasinfo(x, 'label')) infos <- 'label' - if (xgb.DMatrix.hasinfo(x, 'weight')) infos <- c(infos, 'weight') - if (xgb.DMatrix.hasinfo(x, 'base_margin')) infos <- c(infos, 'base_margin') - if (length(infos) == 0) infos <- 'NA' + infos <- names(attributes(x)$fields) + infos <- infos[infos != "feature_name"] + if (!NROW(infos)) infos <- "NA" + infos <- infos[order(infos)] + infos <- paste(infos, collapse = ", ") cat(infos) cnames <- colnames(x) cat(' colnames:') diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R index 1cafd7be7..880fd5697 100644 --- a/R-package/R/xgb.cv.R +++ b/R-package/R/xgb.cv.R @@ -1,6 +1,6 @@ #' Cross Validation #' -#' The cross validation function of xgboost +#' The cross validation function of xgboost. #' #' @param params the list of parameters. The complete list of parameters is #' available in the \href{http://xgboost.readthedocs.io/en/latest/parameter.html}{online documentation}. Below @@ -19,13 +19,17 @@ #' #' See \code{\link{xgb.train}} for further details. #' See also demo/ for walkthrough example in R. -#' @param data takes an \code{xgb.DMatrix}, \code{matrix}, or \code{dgCMatrix} as the input. +#' +#' Note that, while `params` accepts a `seed` entry and will use such parameter for model training if +#' supplied, this seed is not used for creation of train-test splits, which instead rely on R's own RNG +#' system - thus, for reproducible results, one needs to call the `set.seed` function beforehand. +#' @param data An `xgb.DMatrix` object, with corresponding fields like `label` or bounds as required +#' for model training by the objective. +#' +#' Note that only the basic `xgb.DMatrix` class is supported - variants such as `xgb.QuantileDMatrix` +#' or `xgb.ExternalDMatrix` are not supported here. 
#' @param nrounds the max number of iterations #' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples. -#' @param label vector of response values. Should be provided only when data is an R-matrix. -#' @param missing is only used when input is a dense matrix. By default is set to NA, which means -#' that NA values should be considered as 'missing' by the algorithm. -#' Sometimes, 0 or other extreme value might be used to represent missing values. #' @param prediction A logical value indicating whether to return the test fold predictions #' from each CV model. This parameter engages the \code{\link{xgb.cb.cv.predict}} callback. #' @param showsd \code{boolean}, whether to show standard deviation of cross validation @@ -47,13 +51,30 @@ #' @param feval customized evaluation function. Returns #' \code{list(metric='metric-name', value='metric-value')} with given #' prediction and dtrain. -#' @param stratified a \code{boolean} indicating whether sampling of folds should be stratified -#' by the values of outcome labels. +#' @param stratified A \code{boolean} indicating whether sampling of folds should be stratified +#' by the values of outcome labels. For real-valued labels in regression objectives, +#' stratification will be done by discretizing the labels into up to 5 buckets beforehand. +#' +#' If passing "auto", will be set to `TRUE` if the objective in `params` is a classification +#' objective (from XGBoost's built-in objectives, doesn't apply to custom ones), and to +#' `FALSE` otherwise. +#' +#' This parameter is ignored when `data` has a `group` field - in such case, the splitting +#' will be based on whole groups (note that this might make the folds have different sizes). +#' +#' Value `TRUE` here is \bold{not} supported for custom objectives. #' @param folds \code{list} provides a possibility to use a list of pre-defined CV folds #' (each element must be a vector of test fold's indices). 
When folds are supplied, #' the \code{nfold} and \code{stratified} parameters are ignored. +#' +#' If `data` has a `group` field and the objective requires this field, each fold (list element) +#' must additionally have two attributes (retrievable through \link{attributes}) named `group_test` +#' and `group_train`, which should hold the `group` to assign through \link{setinfo.xgb.DMatrix} to +#' the resulting DMatrices. #' @param train_folds \code{list} list specifying which indicies to use for training. If \code{NULL} #' (the default) all indices not specified in \code{folds} will be used for training. +#' +#' This is not supported when `data` has `group` field. #' @param verbose \code{boolean}, print the statistics during the process #' @param print_every_n Print each n-th iteration evaluation messages when \code{verbose>0}. #' Default is 1 which means all messages are printed. This parameter is passed to the @@ -118,13 +139,14 @@ #' print(cv, verbose=TRUE) #' #' @export -xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing = NA, +xgb.cv <- function(params = list(), data, nrounds, nfold, prediction = FALSE, showsd = TRUE, metrics = list(), - obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, train_folds = NULL, + obj = NULL, feval = NULL, stratified = "auto", folds = NULL, train_folds = NULL, verbose = TRUE, print_every_n = 1L, early_stopping_rounds = NULL, maximize = NULL, callbacks = list(), ...) { check.deprecation(...) + stopifnot(inherits(data, "xgb.DMatrix")) if (inherits(data, "xgb.DMatrix") && .Call(XGCheckNullPtr_R, data)) { stop("'data' is an invalid 'xgb.DMatrix' object. 
Must be constructed again.") } @@ -137,16 +159,22 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing check.custom.obj() check.custom.eval() - # Check the labels - if ((inherits(data, 'xgb.DMatrix') && !xgb.DMatrix.hasinfo(data, 'label')) || - (!inherits(data, 'xgb.DMatrix') && is.null(label))) { - stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix") - } else if (inherits(data, 'xgb.DMatrix')) { - if (!is.null(label)) - warning("xgb.cv: label will be ignored, since data is of type xgb.DMatrix") - cv_label <- getinfo(data, 'label') - } else { - cv_label <- label + if (stratified == "auto") { + if (is.character(params$objective)) { + stratified <- ( + (params$objective %in% .CLASSIFICATION_OBJECTIVES()) + && !(params$objective %in% .RANKING_OBJECTIVES()) + ) + } else { + stratified <- FALSE + } + } + + # Check the labels and groups + cv_label <- getinfo(data, "label") + cv_group <- getinfo(data, "group") + if (!is.null(train_folds) && NROW(cv_group)) { + stop("'train_folds' is not supported for DMatrix object with 'group' field.") } # CV folds @@ -157,7 +185,7 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing } else { if (nfold <= 1) stop("'nfold' must be > 1") - folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, params) + folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, cv_group, params) } # Callbacks @@ -195,20 +223,18 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing # create the booster-folds # train_folds - dall <- xgb.get.DMatrix( - data = data, - label = label, - missing = missing, - weight = NULL, - nthread = params$nthread - ) + dall <- data bst_folds <- lapply(seq_along(folds), function(k) { - dtest <- xgb.slice.DMatrix(dall, folds[[k]]) + dtest <- xgb.slice.DMatrix(dall, folds[[k]], allow_groups = TRUE) # code originally contributed by @RolandASc on stackoverflow if 
(is.null(train_folds)) - dtrain <- xgb.slice.DMatrix(dall, unlist(folds[-k])) + dtrain <- xgb.slice.DMatrix(dall, unlist(folds[-k]), allow_groups = TRUE) else - dtrain <- xgb.slice.DMatrix(dall, train_folds[[k]]) + dtrain <- xgb.slice.DMatrix(dall, train_folds[[k]], allow_groups = TRUE) + if (!is.null(attributes(folds[[k]])$group_test)) { + setinfo(dtest, "group", attributes(folds[[k]])$group_test) + setinfo(dtrain, "group", attributes(folds[[k]])$group_train) + } bst <- xgb.Booster( params = params, cachelist = list(dtrain, dtest), @@ -312,8 +338,8 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing #' @examples #' data(agaricus.train, package='xgboost') #' train <- agaricus.train -#' cv <- xgb.cv(data = train$data, label = train$label, nfold = 5, max_depth = 2, -#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic") +#' cv <- xgb.cv(data = xgb.DMatrix(train$data, label = train$label), nfold = 5, max_depth = 2, +#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic") #' print(cv) #' print(cv, verbose=TRUE) #' diff --git a/R-package/man/print.xgb.cv.Rd b/R-package/man/print.xgb.cv.Rd index 05ad61eed..74fc15d01 100644 --- a/R-package/man/print.xgb.cv.Rd +++ b/R-package/man/print.xgb.cv.Rd @@ -23,8 +23,8 @@ including the best iteration (when available). 
\examples{ data(agaricus.train, package='xgboost') train <- agaricus.train -cv <- xgb.cv(data = train$data, label = train$label, nfold = 5, max_depth = 2, - eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic") +cv <- xgb.cv(data = xgb.DMatrix(train$data, label = train$label), nfold = 5, max_depth = 2, + eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic") print(cv) print(cv, verbose=TRUE) diff --git a/R-package/man/xgb.cv.Rd b/R-package/man/xgb.cv.Rd index 778b4540a..cede67570 100644 --- a/R-package/man/xgb.cv.Rd +++ b/R-package/man/xgb.cv.Rd @@ -9,14 +9,12 @@ xgb.cv( data, nrounds, nfold, - label = NULL, - missing = NA, prediction = FALSE, showsd = TRUE, metrics = list(), obj = NULL, feval = NULL, - stratified = TRUE, + stratified = "auto", folds = NULL, train_folds = NULL, verbose = TRUE, @@ -44,20 +42,23 @@ is a shorter summary: } See \code{\link{xgb.train}} for further details. -See also demo/ for walkthrough example in R.} +See also demo/ for walkthrough example in R. -\item{data}{takes an \code{xgb.DMatrix}, \code{matrix}, or \code{dgCMatrix} as the input.} +Note that, while \code{params} accepts a \code{seed} entry and will use such parameter for model training if +supplied, this seed is not used for creation of train-test splits, which instead rely on R's own RNG +system - thus, for reproducible results, one needs to call the \code{set.seed} function beforehand.} + +\item{data}{An \code{xgb.DMatrix} object, with corresponding fields like \code{label} or bounds as required +for model training by the objective. + +\if{html}{\out{
<div class="sourceCode">}}\preformatted{ Note that only the basic `xgb.DMatrix` class is supported - variants such as `xgb.QuantileDMatrix` + or `xgb.ExternalDMatrix` are not supported here. +}\if{html}{\out{</div>
}}} \item{nrounds}{the max number of iterations} \item{nfold}{the original dataset is randomly partitioned into \code{nfold} equal size subsamples.} -\item{label}{vector of response values. Should be provided only when data is an R-matrix.} - -\item{missing}{is only used when input is a dense matrix. By default is set to NA, which means -that NA values should be considered as 'missing' by the algorithm. -Sometimes, 0 or other extreme value might be used to represent missing values.} - \item{prediction}{A logical value indicating whether to return the test fold predictions from each CV model. This parameter engages the \code{\link{xgb.cb.cv.predict}} callback.} @@ -84,15 +85,35 @@ gradient with given prediction and dtrain.} \code{list(metric='metric-name', value='metric-value')} with given prediction and dtrain.} -\item{stratified}{a \code{boolean} indicating whether sampling of folds should be stratified -by the values of outcome labels.} +\item{stratified}{A \code{boolean} indicating whether sampling of folds should be stratified +by the values of outcome labels. For real-valued labels in regression objectives, +stratification will be done by discretizing the labels into up to 5 buckets beforehand. + +\if{html}{\out{
<div class="sourceCode">}}\preformatted{ If passing "auto", will be set to `TRUE` if the objective in `params` is a classification + objective (from XGBoost's built-in objectives, doesn't apply to custom ones), and to + `FALSE` otherwise. + + This parameter is ignored when `data` has a `group` field - in such case, the splitting + will be based on whole groups (note that this might make the folds have different sizes). + + Value `TRUE` here is \\bold\{not\} supported for custom objectives. +}\if{html}{\out{</div>
}}} \item{folds}{\code{list} provides a possibility to use a list of pre-defined CV folds (each element must be a vector of test fold's indices). When folds are supplied, -the \code{nfold} and \code{stratified} parameters are ignored.} +the \code{nfold} and \code{stratified} parameters are ignored. + +\if{html}{\out{
<div class="sourceCode">}}\preformatted{ If `data` has a `group` field and the objective requires this field, each fold (list element) + must additionally have two attributes (retrievable through \link{attributes}) named `group_test` + and `group_train`, which should hold the `group` to assign through \link{setinfo.xgb.DMatrix} to + the resulting DMatrices. +}\if{html}{\out{</div>
}}} \item{train_folds}{\code{list} list specifying which indicies to use for training. If \code{NULL} -(the default) all indices not specified in \code{folds} will be used for training.} +(the default) all indices not specified in \code{folds} will be used for training. + +\if{html}{\out{
<div class="sourceCode">}}\preformatted{ This is not supported when `data` has `group` field. +}\if{html}{\out{</div>
}}} \item{verbose}{\code{boolean}, print the statistics during the process} @@ -142,7 +163,7 @@ such as saving also the models created during cross validation); or a list \code will contain elements such as \code{best_iteration} when using the early stopping callback (\link{xgb.cb.early.stop}). } \description{ -The cross validation function of xgboost +The cross validation function of xgboost. } \details{ The original sample is randomly partitioned into \code{nfold} equal size subsamples. diff --git a/R-package/man/xgb.slice.DMatrix.Rd b/R-package/man/xgb.slice.DMatrix.Rd index c9695996b..c4f776594 100644 --- a/R-package/man/xgb.slice.DMatrix.Rd +++ b/R-package/man/xgb.slice.DMatrix.Rd @@ -6,14 +6,18 @@ \title{Get a new DMatrix containing the specified rows of original xgb.DMatrix object} \usage{ -xgb.slice.DMatrix(object, idxset) +xgb.slice.DMatrix(object, idxset, allow_groups = FALSE) \method{[}{xgb.DMatrix}(object, idxset, colset = NULL) } \arguments{ -\item{object}{Object of class "xgb.DMatrix"} +\item{object}{Object of class "xgb.DMatrix".} -\item{idxset}{a integer vector of indices of rows needed} +\item{idxset}{An integer vector of indices of rows needed (base-1 indexing).} + +\item{allow_groups}{Whether to allow slicing an \code{xgb.DMatrix} with \code{group} (or +equivalently \code{qid}) field. 
Note that in such case, the result will not have +the groups anymore - they need to be set manually through \code{setinfo}.} \item{colset}{currently not used (columns subsetting is not available)} } diff --git a/R-package/src/init.c b/R-package/src/init.c index c869871c6..5db3218b4 100644 --- a/R-package/src/init.c +++ b/R-package/src/init.c @@ -71,7 +71,7 @@ extern SEXP XGDMatrixGetDataAsCSR_R(SEXP); extern SEXP XGDMatrixSaveBinary_R(SEXP, SEXP, SEXP); extern SEXP XGDMatrixSetInfo_R(SEXP, SEXP, SEXP); extern SEXP XGDMatrixSetStrFeatureInfo_R(SEXP, SEXP, SEXP); -extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP); +extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP, SEXP); extern SEXP XGBSetGlobalConfig_R(SEXP); extern SEXP XGBGetGlobalConfig_R(void); extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP); @@ -134,7 +134,7 @@ static const R_CallMethodDef CallEntries[] = { {"XGDMatrixSaveBinary_R", (DL_FUNC) &XGDMatrixSaveBinary_R, 3}, {"XGDMatrixSetInfo_R", (DL_FUNC) &XGDMatrixSetInfo_R, 3}, {"XGDMatrixSetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixSetStrFeatureInfo_R, 3}, - {"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 2}, + {"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 3}, {"XGBSetGlobalConfig_R", (DL_FUNC) &XGBSetGlobalConfig_R, 1}, {"XGBGetGlobalConfig_R", (DL_FUNC) &XGBGetGlobalConfig_R, 0}, {"XGBoosterFeatureScore_R", (DL_FUNC) &XGBoosterFeatureScore_R, 2}, diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index 2228932bd..cdb9ba65c 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -512,7 +512,7 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP return ret; } -XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) { +XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset, SEXP allow_groups) { SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue)); R_API_BEGIN(); R_xlen_t len = Rf_xlength(idxset); @@ -531,7 +531,7 @@ XGB_DLL SEXP 
XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) { res_code = XGDMatrixSliceDMatrixEx(R_ExternalPtrAddr(handle), BeginPtr(idxvec), len, &res, - 0); + Rf_asLogical(allow_groups)); } CHECK_CALL(res_code); R_SetExternalPtrAddr(ret, res); diff --git a/R-package/src/xgboost_R.h b/R-package/src/xgboost_R.h index cea50c146..62be5022a 100644 --- a/R-package/src/xgboost_R.h +++ b/R-package/src/xgboost_R.h @@ -112,9 +112,10 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP * \brief create a new dmatrix from sliced content of existing matrix * \param handle instance of data matrix to be sliced * \param idxset index set + * \param allow_groups Whether to allow slicing the DMatrix if it has a 'group' field * \return a sliced new matrix */ -XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset); +XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset, SEXP allow_groups); /*! * \brief load a data matrix into binary file diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index 18a3b99e6..bbb8fb323 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -334,7 +334,7 @@ test_that("xgb.cv works", { set.seed(11) expect_output( cv <- xgb.cv( - data = train$data, label = train$label, max_depth = 2, nfold = 5, + data = xgb.DMatrix(train$data, label = train$label), max_depth = 2, nfold = 5, eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic", eval_metric = "error", verbose = TRUE ), @@ -357,13 +357,13 @@ test_that("xgb.cv works with stratified folds", { cv <- xgb.cv( data = dtrain, max_depth = 2, nfold = 5, eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic", - verbose = TRUE, stratified = FALSE + verbose = FALSE, stratified = FALSE ) set.seed(314159) cv2 <- xgb.cv( data = dtrain, max_depth = 2, nfold = 5, eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic", - verbose = TRUE, stratified = TRUE 
+ verbose = FALSE, stratified = TRUE ) # Stratified folds should result in a different evaluation logs expect_true(all(cv$evaluation_log[, test_logloss_mean] != cv2$evaluation_log[, test_logloss_mean])) @@ -885,3 +885,57 @@ test_that("Seed in params override PRNG from R", { ) ) }) + +test_that("xgb.cv works for AFT", { + X <- matrix(c(1, -1, -1, 1, 0, 1, 1, 0), nrow = 4, byrow = TRUE) # 4x2 matrix + dtrain <- xgb.DMatrix(X, nthread = n_threads) + + params <- list(objective = 'survival:aft', learning_rate = 0.2, max_depth = 2L) + + # data must have bounds + expect_error( + xgb.cv( + params = params, + data = dtrain, + nround = 5L, + nfold = 4L, + nthread = n_threads + ) + ) + + setinfo(dtrain, 'label_lower_bound', c(2, 3, 0, 4)) + setinfo(dtrain, 'label_upper_bound', c(2, Inf, 4, 5)) + + # automatic stratified splitting is turned off + expect_warning( + xgb.cv( + params = params, data = dtrain, nround = 5L, nfold = 4L, + nthread = n_threads, stratified = TRUE, verbose = FALSE + ) + ) + + # this works without any issue + expect_no_warning( + xgb.cv(params = params, data = dtrain, nround = 5L, nfold = 4L, verbose = FALSE) + ) +}) + +test_that("xgb.cv works for ranking", { + data(iris) + x <- iris[, -(4:5)] + y <- as.integer(iris$Petal.Width) + group <- rep(50, 3) + dm <- xgb.DMatrix(x, label = y, group = group) + res <- xgb.cv( + data = dm, + params = list( + objective = "rank:pairwise", + max_depth = 3 + ), + nrounds = 3, + nfold = 2, + verbose = FALSE, + stratified = FALSE + ) + expect_equal(length(res$folds), 2L) +}) diff --git a/R-package/tests/testthat/test_callbacks.R b/R-package/tests/testthat/test_callbacks.R index 913791de4..bf95a170d 100644 --- a/R-package/tests/testthat/test_callbacks.R +++ b/R-package/tests/testthat/test_callbacks.R @@ -367,7 +367,7 @@ test_that("prediction in early-stopping xgb.cv works", { expect_output( cv <- xgb.cv(param, dtrain, nfold = 5, eta = 0.1, nrounds = 20, early_stopping_rounds = 5, maximize = FALSE, stratified = FALSE, - 
prediction = TRUE, base_score = 0.5) + prediction = TRUE, base_score = 0.5, verbose = TRUE) , "Stopping. Best iteration") expect_false(is.null(cv$early_stop$best_iteration)) @@ -387,7 +387,7 @@ test_that("prediction in xgb.cv for softprob works", { lb <- as.numeric(iris$Species) - 1 set.seed(11) expect_warning( - cv <- xgb.cv(data = as.matrix(iris[, -5]), label = lb, nfold = 4, + cv <- xgb.cv(data = xgb.DMatrix(as.matrix(iris[, -5]), label = lb), nfold = 4, eta = 0.5, nrounds = 5, max_depth = 3, nthread = n_threads, subsample = 0.8, gamma = 2, verbose = 0, prediction = TRUE, objective = "multi:softprob", num_class = 3) diff --git a/R-package/tests/testthat/test_dmatrix.R b/R-package/tests/testthat/test_dmatrix.R index 44d1566c6..548afece3 100644 --- a/R-package/tests/testthat/test_dmatrix.R +++ b/R-package/tests/testthat/test_dmatrix.R @@ -243,7 +243,7 @@ test_that("xgb.DMatrix: print", { txt <- capture.output({ print(dtrain) }) - expect_equal(txt, "xgb.DMatrix dim: 6513 x 126 info: label weight base_margin colnames: yes") + expect_equal(txt, "xgb.DMatrix dim: 6513 x 126 info: base_margin, label, weight colnames: yes") # DMatrix with just features dtrain <- xgb.DMatrix( @@ -724,6 +724,44 @@ test_that("xgb.DMatrix: quantile cuts look correct", { ) }) +test_that("xgb.DMatrix: slicing keeps field indicators", { + data(mtcars) + x <- as.matrix(mtcars[, -1]) + y <- mtcars[, 1] + dm <- xgb.DMatrix( + data = x, + label_lower_bound = -y, + label_upper_bound = y, + nthread = 1 + ) + idx_take <- seq(1, 5) + dm_slice <- xgb.slice.DMatrix(dm, idx_take) + + expect_true(xgb.DMatrix.hasinfo(dm_slice, "label_lower_bound")) + expect_true(xgb.DMatrix.hasinfo(dm_slice, "label_upper_bound")) + expect_false(xgb.DMatrix.hasinfo(dm_slice, "label")) + + expect_equal(getinfo(dm_slice, "label_lower_bound"), -y[idx_take], tolerance = 1e-6) + expect_equal(getinfo(dm_slice, "label_upper_bound"), y[idx_take], tolerance = 1e-6) +}) + +test_that("xgb.DMatrix: can slice with groups", { + 
data(iris) + x <- as.matrix(iris[, -5]) + set.seed(123) + y <- sample(3, size = nrow(x), replace = TRUE) + group <- c(50, 50, 50) + dm <- xgb.DMatrix(x, label = y, group = group, nthread = 1) + idx_take <- seq(1, 50) + dm_slice <- xgb.slice.DMatrix(dm, idx_take, allow_groups = TRUE) + + expect_true(xgb.DMatrix.hasinfo(dm_slice, "label")) + expect_false(xgb.DMatrix.hasinfo(dm_slice, "group")) + expect_false(xgb.DMatrix.hasinfo(dm_slice, "qid")) + expect_null(getinfo(dm_slice, "group")) + expect_equal(getinfo(dm_slice, "label"), y[idx_take], tolerance = 1e-6) +}) + test_that("xgb.DMatrix: can read CSV", { txt <- paste( "1,2,3", From e15d61b916cdb29815bc53497fa4949a7e988b56 Mon Sep 17 00:00:00 2001 From: Fabi <117525608+fabfabi@users.noreply.github.com> Date: Mon, 1 Apr 2024 04:14:40 +0200 Subject: [PATCH 05/26] docs: fix bug in tutorial (#10143) --- doc/tutorials/learning_to_rank.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/tutorials/learning_to_rank.rst b/doc/tutorials/learning_to_rank.rst index bfc727ed7..15a611bd0 100644 --- a/doc/tutorials/learning_to_rank.rst +++ b/doc/tutorials/learning_to_rank.rst @@ -52,7 +52,7 @@ Notice that the samples are sorted based on their query index in a non-decreasin X, y = make_classification(random_state=seed) rng = np.random.default_rng(seed) n_query_groups = 3 - qid = rng.integers(0, 3, size=X.shape[0]) + qid = rng.integers(0, n_query_groups, size=X.shape[0]) # Sort the inputs based on query index sorted_idx = np.argsort(qid) @@ -65,14 +65,14 @@ The simplest way to train a ranking model is by using the scikit-learn estimator .. code-block:: python ranker = xgb.XGBRanker(tree_method="hist", lambdarank_num_pair_per_sample=8, objective="rank:ndcg", lambdarank_pair_method="topk") - ranker.fit(X, y, qid=qid) + ranker.fit(X, y, qid=qid[sorted_idx]) Please note that, as of writing, there's no learning-to-rank interface in scikit-learn. 
As a result, the :py:class:`xgboost.XGBRanker` class does not fully conform the scikit-learn estimator guideline and can not be directly used with some of its utility functions. For instances, the ``auc_score`` and ``ndcg_score`` in scikit-learn don't consider query group information nor the pairwise loss. Most of the metrics are implemented as part of XGBoost, but to use scikit-learn utilities like :py:func:`sklearn.model_selection.cross_validation`, we need to make some adjustments in order to pass the ``qid`` as an additional parameter for :py:meth:`xgboost.XGBRanker.score`. Given a data frame ``X`` (either pandas or cuDF), add the column ``qid`` as follows: .. code-block:: python df = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])]) - df["qid"] = qid + df["qid"] = qid[sorted_idx] ranker.fit(df, y) # No need to pass qid as a separate argument from sklearn.model_selection import StratifiedGroupKFold, cross_val_score From a99bb38bd2762e35e6a1673a0c11e09eddd8e723 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 3 Apr 2024 16:45:54 -0700 Subject: [PATCH 06/26] Bump org.apache.maven.plugins:maven-gpg-plugin from 3.1.0 to 3.2.2 in /jvm-packages/xgboost4j-spark (#10151) --- jvm-packages/pom.xml | 2 +- tests/buildkite/conftest.sh | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index 23ab70734..f9266b854 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -152,7 +152,7 @@ org.apache.maven.plugins maven-gpg-plugin - 3.1.0 + 3.2.2 sign-artifacts diff --git a/tests/buildkite/conftest.sh b/tests/buildkite/conftest.sh index c6e8ef65a..0d051001f 100755 --- a/tests/buildkite/conftest.sh +++ b/tests/buildkite/conftest.sh @@ -39,13 +39,14 @@ fi if [[ -n $BUILDKITE_PULL_REQUEST && $BUILDKITE_PULL_REQUEST != "false" ]] then is_pull_request=1 - export BRANCH_NAME=PR-$BUILDKITE_PULL_REQUEST + 
BRANCH_NAME=PR-$BUILDKITE_PULL_REQUEST else is_pull_request=0 - export BRANCH_NAME=$BUILDKITE_BRANCH + BRANCH_NAME=$BUILDKITE_BRANCH fi +export BRANCH_NAME=${BRANCH_NAME//\//-} -if [[ $BUILDKITE_BRANCH == "master" || $BUILDKITE_BRANCH == "release_"* ]] +if [[ $BRANCH_NAME == "master" || $BRANCH_NAME == "release_"* ]] then is_release_branch=1 enforce_daily_budget=0 From f0a138f33a6ee578c122554419fdc8593d98c487 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 9 Apr 2024 23:18:56 +0800 Subject: [PATCH 07/26] Fix pyspark with verbosity=3. (#10172) --- python-package/xgboost/spark/core.py | 20 +++++++++++--------- src/common/device_helpers.cuh | 14 ++++++++------ 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index e44182cb3..741adcb03 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -1052,6 +1052,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable): dev_ordinal = None use_qdm = _can_use_qdm(booster_params.get("tree_method", None)) + verbosity = booster_params.get("verbosity", 1) msg = "Training on CPUs" if run_on_gpu: dev_ordinal = ( @@ -1089,15 +1090,16 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable): evals_result: Dict[str, Any] = {} with CommunicatorContext(context, **_rabit_args): - dtrain, dvalid = create_dmatrix_from_partitions( - pandas_df_iter, - feature_prop.features_cols_names, - dev_ordinal, - use_qdm, - dmatrix_kwargs, - enable_sparse_data_optim=feature_prop.enable_sparse_data_optim, - has_validation_col=feature_prop.has_validation_col, - ) + with xgboost.config_context(verbosity=verbosity): + dtrain, dvalid = create_dmatrix_from_partitions( + pandas_df_iter, + feature_prop.features_cols_names, + dev_ordinal, + use_qdm, + dmatrix_kwargs, + enable_sparse_data_optim=feature_prop.enable_sparse_data_optim, + 
has_validation_col=feature_prop.has_validation_col, + ) if dvalid is not None: dval = [(dtrain, "training"), (dvalid, "validation")] else: diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 026fbacf2..98b83bae0 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -307,13 +307,15 @@ class MemoryLogger { void RegisterDeallocation(void *ptr, size_t n, int current_device) { auto itr = device_allocations.find(ptr); if (itr == device_allocations.end()) { - LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device " - << current_device << " that was never allocated "; + LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device " << current_device + << " that was never allocated\n" + << dmlc::StackTrace(); + } else { + num_deallocations++; + CHECK_LE(num_deallocations, num_allocations); + currently_allocated_bytes -= itr->second; + device_allocations.erase(itr); } - num_deallocations++; - CHECK_LE(num_deallocations, num_allocations); - currently_allocated_bytes -= itr->second; - device_allocations.erase(itr); } }; DeviceStats stats_; From 1022909bbe4546777613ea04d110479196936225 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 11 Apr 2024 01:29:28 +0800 Subject: [PATCH 08/26] Fix global config for external memory. (#10173) Pass the thread-local configuration between threads. 
--- src/common/device_helpers.cuh | 3 +-- src/common/timer.cc | 8 ++++--- src/data/sparse_page_source.h | 39 ++++++++++++++++++++--------------- 3 files changed, 28 insertions(+), 22 deletions(-) diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 98b83bae0..9223302aa 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -299,8 +299,7 @@ class MemoryLogger { void RegisterAllocation(void *ptr, size_t n) { device_allocations[ptr] = n; currently_allocated_bytes += n; - peak_allocated_bytes = - std::max(peak_allocated_bytes, currently_allocated_bytes); + peak_allocated_bytes = std::max(peak_allocated_bytes, currently_allocated_bytes); num_allocations++; CHECK_GT(num_allocations, num_deallocations); } diff --git a/src/common/timer.cc b/src/common/timer.cc index 99150aa26..2eccc67cd 100644 --- a/src/common/timer.cc +++ b/src/common/timer.cc @@ -1,9 +1,8 @@ -/*! - * Copyright by Contributors 2019 +/** + * Copyright 2019-2024, XGBoost Contributors */ #include "timer.h" -#include #include #include "../collective/communicator-inl.h" @@ -61,6 +60,9 @@ void Monitor::Print() const { kv.second.timer.elapsed) .count()); } + if (stat_map.empty()) { + return; + } LOG(CONSOLE) << "======== Monitor (" << rank << "): " << label_ << " ========"; this->PrintStatistics(stat_map); } diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 9cb0e364f..60129741b 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -1,5 +1,5 @@ /** - * Copyright 2014-2023, XGBoost Contributors + * Copyright 2014-2024, XGBoost Contributors * \file sparse_page_source.h */ #ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_ @@ -7,23 +7,26 @@ #include // for min #include // for atomic +#include // for remove #include // for async -#include -#include -#include // for mutex -#include -#include -#include // for pair, move -#include +#include // for unique_ptr +#include // for mutex +#include // for string +#include 
// for pair, move +#include // for vector -#include "../common/common.h" -#include "../common/io.h" // for PrivateMmapConstStream -#include "../common/timer.h" // for Monitor, Timer -#include "adapter.h" -#include "proxy_dmatrix.h" // for DMatrixProxy -#include "sparse_page_writer.h" // for SparsePageFormat -#include "xgboost/base.h" -#include "xgboost/data.h" +#if !defined(XGBOOST_USE_CUDA) +#include "../common/common.h" // for AssertGPUSupport +#endif // !defined(XGBOOST_USE_CUDA) + +#include "../common/io.h" // for PrivateMmapConstStream +#include "../common/timer.h" // for Monitor, Timer +#include "proxy_dmatrix.h" // for DMatrixProxy +#include "sparse_page_writer.h" // for SparsePageFormat +#include "xgboost/base.h" // for bst_feature_t +#include "xgboost/data.h" // for SparsePage, CSCPage +#include "xgboost/global_config.h" // for GlobalConfigThreadLocalStore +#include "xgboost/logging.h" // for CHECK_EQ namespace xgboost::data { inline void TryDeleteCacheFile(const std::string& file) { @@ -185,6 +188,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl { exce_.Rethrow(); + auto const config = *GlobalConfigThreadLocalStore::Get(); for (std::int32_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) { fetch_it %= n_batches_; // ring if (ring_->at(fetch_it).valid()) { @@ -192,7 +196,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl { } auto const* self = this; // make sure it's const CHECK_LT(fetch_it, cache_info_->offset.size()); - ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, this]() { + ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, config, this]() { + *GlobalConfigThreadLocalStore::Get() = config; auto page = std::make_shared(); this->exce_.Run([&] { std::unique_ptr> fmt{CreatePageFormat("raw")}; From 732e27cebc0c4085d0d9cc57f278c531cf68a8b0 Mon Sep 17 00:00:00 2001 From: Trinh Quoc Anh Date: Fri, 12 Apr 2024 19:10:50 +0200 Subject: [PATCH 09/26] [doc] Update python3statement URL (#10179) --- 
NEWS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 43019d877..b067c8e3c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2101,7 +2101,7 @@ This release marks a major milestone for the XGBoost project. ## v0.90 (2019.05.18) ### XGBoost Python package drops Python 2.x (#4379, #4381) -Python 2.x is reaching its end-of-life at the end of this year. [Many scientific Python packages are now moving to drop Python 2.x](https://python3statement.org/). +Python 2.x is reaching its end-of-life at the end of this year. [Many scientific Python packages are now moving to drop Python 2.x](https://python3statement.github.io/). ### XGBoost4J-Spark now requires Spark 2.4.x (#4377) * Spark 2.3 is reaching its end-of-life soon. See discussion at #4389. From 882f4136e09b539edf8a7cb0a7e39ac57c0389af Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Sat, 13 Apr 2024 13:52:12 -0700 Subject: [PATCH 10/26] [CI] Update create-pull-request action --- .github/workflows/update_rapids.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/update_rapids.yml b/.github/workflows/update_rapids.yml index 395a42148..22a395799 100644 --- a/.github/workflows/update_rapids.yml +++ b/.github/workflows/update_rapids.yml @@ -3,7 +3,7 @@ name: update-rapids on: workflow_dispatch: schedule: - - cron: "0 20 * * *" # Run once daily + - cron: "0 20 * * 1" # Run once weekly permissions: pull-requests: write @@ -32,7 +32,7 @@ jobs: run: | bash tests/buildkite/update-rapids.sh - name: Create Pull Request - uses: peter-evans/create-pull-request@v5 + uses: peter-evans/create-pull-request@v6 if: github.ref == 'refs/heads/master' with: add-paths: | From 6e5c335ceaadf2f4c5b61b115de64ea93dbbe6e6 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Mon, 15 Apr 2024 15:24:46 +0200 Subject: [PATCH 11/26] [SYCL] Add basic features for QuantileHistMaker (#10174) --------- Co-authored-by: Dmitry Razdoburdin <> --- 
plugin/sycl/tree/updater_quantile_hist.cc | 55 +++++++++++ plugin/sycl/tree/updater_quantile_hist.h | 91 +++++++++++++++++++ .../plugin/test_sycl_quantile_hist_builder.cc | 55 +++++++++++ 3 files changed, 201 insertions(+) create mode 100644 plugin/sycl/tree/updater_quantile_hist.cc create mode 100644 plugin/sycl/tree/updater_quantile_hist.h create mode 100644 tests/cpp/plugin/test_sycl_quantile_hist_builder.cc diff --git a/plugin/sycl/tree/updater_quantile_hist.cc b/plugin/sycl/tree/updater_quantile_hist.cc new file mode 100644 index 000000000..98a42c3c8 --- /dev/null +++ b/plugin/sycl/tree/updater_quantile_hist.cc @@ -0,0 +1,55 @@ +/*! + * Copyright 2017-2024 by Contributors + * \file updater_quantile_hist.cc + */ +#include + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#pragma GCC diagnostic ignored "-W#pragma-messages" +#include "xgboost/tree_updater.h" +#pragma GCC diagnostic pop + +#include "xgboost/logging.h" + +#include "updater_quantile_hist.h" +#include "../data.h" + +namespace xgboost { +namespace sycl { +namespace tree { + +DMLC_REGISTRY_FILE_TAG(updater_quantile_hist_sycl); + +DMLC_REGISTER_PARAMETER(HistMakerTrainParam); + +void QuantileHistMaker::Configure(const Args& args) { + const DeviceOrd device_spec = ctx_->Device(); + qu_ = device_manager.GetQueue(device_spec); + + param_.UpdateAllowUnknown(args); + hist_maker_param_.UpdateAllowUnknown(args); +} + +void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param, + linalg::Matrix* gpair, + DMatrix *dmat, + xgboost::common::Span> out_position, + const std::vector &trees) { + LOG(FATAL) << "Not Implemented yet"; +} + +bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data, + linalg::MatrixView out_preds) { + LOG(FATAL) << "Not Implemented yet"; +} + +XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker_sycl") +.describe("Grow tree using quantized histogram with SYCL.") +.set_body( + [](Context const* ctx, 
ObjInfo const * task) { + return new QuantileHistMaker(ctx, task); + }); +} // namespace tree +} // namespace sycl +} // namespace xgboost diff --git a/plugin/sycl/tree/updater_quantile_hist.h b/plugin/sycl/tree/updater_quantile_hist.h new file mode 100644 index 000000000..93a50de3e --- /dev/null +++ b/plugin/sycl/tree/updater_quantile_hist.h @@ -0,0 +1,91 @@ +/*! + * Copyright 2017-2024 by Contributors + * \file updater_quantile_hist.h + */ +#ifndef PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_ +#define PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_ + +#include +#include + +#include + +#include "../data/gradient_index.h" +#include "../common/hist_util.h" +#include "../common/row_set.h" +#include "../common/partition_builder.h" +#include "split_evaluator.h" +#include "../device_manager.h" + +#include "xgboost/data.h" +#include "xgboost/json.h" +#include "../../src/tree/constraints.h" +#include "../../src/common/random.h" + +namespace xgboost { +namespace sycl { +namespace tree { + +// training parameters specific to this algorithm +struct HistMakerTrainParam + : public XGBoostParameter { + bool single_precision_histogram = false; + // declare parameters + DMLC_DECLARE_PARAMETER(HistMakerTrainParam) { + DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe( + "Use single precision to build histograms."); + } +}; + +/*! 
\brief construct a tree using quantized feature values with SYCL backend*/ +class QuantileHistMaker: public TreeUpdater { + public: + QuantileHistMaker(Context const* ctx, ObjInfo const * task) : + TreeUpdater(ctx), task_{task} { + updater_monitor_.Init("SYCLQuantileHistMaker"); + } + void Configure(const Args& args) override; + + void Update(xgboost::tree::TrainParam const *param, + linalg::Matrix* gpair, + DMatrix* dmat, + xgboost::common::Span> out_position, + const std::vector& trees) override; + + bool UpdatePredictionCache(const DMatrix* data, + linalg::MatrixView out_preds) override; + + void LoadConfig(Json const& in) override { + auto const& config = get(in); + FromJson(config.at("train_param"), &this->param_); + FromJson(config.at("sycl_hist_train_param"), &this->hist_maker_param_); + } + + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["train_param"] = ToJson(param_); + out["sycl_hist_train_param"] = ToJson(hist_maker_param_); + } + + char const* Name() const override { + return "grow_quantile_histmaker_sycl"; + } + + protected: + HistMakerTrainParam hist_maker_param_; + // training parameter + xgboost::tree::TrainParam param_; + + xgboost::common::Monitor updater_monitor_; + + ::sycl::queue qu_; + DeviceManager device_manager; + ObjInfo const *task_{nullptr}; +}; + + +} // namespace tree +} // namespace sycl +} // namespace xgboost + +#endif // PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_ diff --git a/tests/cpp/plugin/test_sycl_quantile_hist_builder.cc b/tests/cpp/plugin/test_sycl_quantile_hist_builder.cc new file mode 100644 index 000000000..4bf7bd962 --- /dev/null +++ b/tests/cpp/plugin/test_sycl_quantile_hist_builder.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020-2024 by XGBoost contributors + */ +#include + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#pragma GCC diagnostic ignored "-W#pragma-messages" +#include +#include +#include 
"../../../plugin/sycl/tree/updater_quantile_hist.h" // for QuantileHistMaker +#pragma GCC diagnostic pop + +namespace xgboost::sycl::tree { +TEST(SyclQuantileHistMaker, Basic) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + ObjInfo task{ObjInfo::kRegression}; + std::unique_ptr updater{TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)}; + + ASSERT_EQ(updater->Name(), "grow_quantile_histmaker_sycl"); +} + +TEST(SyclQuantileHistMaker, JsonIO) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + ObjInfo task{ObjInfo::kRegression}; + Json config {Object()}; + { + std::unique_ptr updater{TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)}; + updater->Configure({{"max_depth", std::to_string(42)}}); + updater->Configure({{"single_precision_histogram", std::to_string(true)}}); + updater->SaveConfig(&config); + } + + { + std::unique_ptr updater{TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)}; + updater->LoadConfig(config); + + Json new_config {Object()}; + updater->SaveConfig(&new_config); + + ASSERT_EQ(config, new_config); + + auto max_depth = atoi(get(new_config["train_param"]["max_depth"]).c_str()); + ASSERT_EQ(max_depth, 42); + + auto single_precision_histogram = atoi(get(new_config["sycl_hist_train_param"]["single_precision_histogram"]).c_str()); + ASSERT_EQ(single_precision_histogram, 1); + } + +} +} // namespace xgboost::sycl::tree From 2925cebdca37e43a3f18422c78e05fa88ede7244 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 15 Apr 2024 13:38:53 -0700 Subject: [PATCH 12/26] [CI] Use latest RAPIDS; Pandas 2.0 compatibility fix (#10175) * [CI] Update RAPIDS to latest stable * [CI] Use rapidsai stable channel; fix syntax errors in Dockerfile.gpu * Don't combine astype() with loc() * Work around https://github.com/dmlc/xgboost/issues/10181 * Fix formatting * Fix test --------- Co-authored-by: hcho3 
Co-authored-by: Hyunsu Cho --- python-package/xgboost/data.py | 12 +++++++++++- python-package/xgboost/testing/__init__.py | 4 ++-- tests/buildkite/conftest.sh | 2 +- tests/ci_build/Dockerfile.gpu | 8 ++++---- tests/python-gpu/test_from_cudf.py | 9 --------- 5 files changed, 18 insertions(+), 17 deletions(-) diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 07a08dc5f..12b576566 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -909,9 +909,19 @@ def _transform_cudf_df( enable_categorical: bool, ) -> Tuple[ctypes.c_void_p, list, Optional[FeatureNames], Optional[FeatureTypes]]: try: - from cudf.api.types import is_categorical_dtype + from cudf.api.types import is_bool_dtype, is_categorical_dtype except ImportError: from cudf.utils.dtypes import is_categorical_dtype + from pandas.api.types import is_bool_dtype + + # Work around https://github.com/dmlc/xgboost/issues/10181 + if _is_cudf_ser(data): + if is_bool_dtype(data.dtype): + data = data.astype(np.uint8) + else: + data = data.astype( + {col: np.uint8 for col in data.select_dtypes(include="bool")} + ) if _is_cudf_ser(data): dtypes = [data.dtype] diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index 409fd0274..f7d9510fa 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -429,8 +429,8 @@ def make_categorical( categories = np.arange(0, n_categories) for col in df.columns: if rng.binomial(1, cat_ratio, size=1)[0] == 1: - df.loc[:, col] = df[col].astype("category") - df.loc[:, col] = df[col].cat.set_categories(categories) + df[col] = df[col].astype("category") + df[col] = df[col].cat.set_categories(categories) if sparsity > 0.0: for i in range(n_features): diff --git a/tests/buildkite/conftest.sh b/tests/buildkite/conftest.sh index 0d051001f..44043910b 100755 --- a/tests/buildkite/conftest.sh +++ b/tests/buildkite/conftest.sh @@ -24,7 +24,7 @@ set 
-x CUDA_VERSION=11.8.0 NCCL_VERSION=2.16.5-1 -RAPIDS_VERSION=24.02 +RAPIDS_VERSION=24.04 SPARK_VERSION=3.4.0 JDK_VERSION=8 R_VERSION=4.3.2 diff --git a/tests/ci_build/Dockerfile.gpu b/tests/ci_build/Dockerfile.gpu index 698a61e93..255dd9d71 100644 --- a/tests/ci_build/Dockerfile.gpu +++ b/tests/ci_build/Dockerfile.gpu @@ -21,14 +21,14 @@ ENV PATH=/opt/mambaforge/bin:$PATH # Create new Conda environment with cuDF, Dask, and cuPy RUN \ - conda install -c conda-forge mamba && \ - mamba create -n gpu_test -c rapidsai-nightly -c rapidsai -c nvidia -c conda-forge -c defaults \ + export NCCL_SHORT_VER=$(echo "$NCCL_VERSION_ARG" | cut -d "-" -f 1) && \ + mamba create -y -n gpu_test -c rapidsai -c nvidia -c conda-forge \ python=3.10 cudf=$RAPIDS_VERSION_ARG* rmm=$RAPIDS_VERSION_ARG* cudatoolkit=$CUDA_VERSION_ARG \ - nccl>=$(cut -d "-" -f 1 << $NCCL_VERSION_ARG) \ + "nccl>=${NCCL_SHORT_VER}" \ dask=2024.1.1 \ dask-cuda=$RAPIDS_VERSION_ARG* dask-cudf=$RAPIDS_VERSION_ARG* cupy \ numpy pytest pytest-timeout scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis \ - pyspark>=3.4.0 cloudpickle cuda-python && \ + "pyspark>=3.4.0" cloudpickle cuda-python && \ mamba clean --all && \ conda run --no-capture-output -n gpu_test pip install buildkite-test-collector diff --git a/tests/python-gpu/test_from_cudf.py b/tests/python-gpu/test_from_cudf.py index 8707af0c8..c3a0b7d5f 100644 --- a/tests/python-gpu/test_from_cudf.py +++ b/tests/python-gpu/test_from_cudf.py @@ -71,15 +71,6 @@ def _test_from_cudf(DMatrixT): assert dtrain.num_col() == 1 assert dtrain.num_row() == 5 - # Boolean is not supported. 
- X_boolean = cudf.DataFrame({"x": cudf.Series([True, False])}) - with pytest.raises(Exception): - dtrain = DMatrixT(X_boolean) - - y_boolean = cudf.DataFrame({"x": cudf.Series([True, False, True, True, True])}) - with pytest.raises(Exception): - dtrain = DMatrixT(X_boolean, label=y_boolean) - def _test_cudf_training(DMatrixT): import pandas as pd From 9e354fb120863a558ca4cd374e6e82efde777932 Mon Sep 17 00:00:00 2001 From: Eric Leung <2754821+erictleung@users.noreply.github.com> Date: Tue, 16 Apr 2024 18:09:59 -0400 Subject: [PATCH 13/26] docs: update Ruby package link (#10182) --- doc/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/index.rst b/doc/index.rst index a2ae9bbd3..7b241c0a1 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -28,7 +28,7 @@ Contents Python Package R Package JVM Package - Ruby Package + Ruby Package Swift Package Julia Package C Package From 3d1d97c8ccdfa9ca65577379000f18ab289ae4a4 Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Tue, 16 Apr 2024 19:42:08 -0700 Subject: [PATCH 14/26] [CI] Reduce clutter from dependabot (#10187) --- .github/dependabot.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index c03a52c60..0cc0c16fd 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -8,7 +8,7 @@ updates: - package-ecosystem: "maven" directory: "/jvm-packages" schedule: - interval: "daily" + interval: "monthly" - package-ecosystem: "maven" directory: "/jvm-packages/xgboost4j" schedule: @@ -16,11 +16,11 @@ updates: - package-ecosystem: "maven" directory: "/jvm-packages/xgboost4j-gpu" schedule: - interval: "daily" + interval: "monthly" - package-ecosystem: "maven" directory: "/jvm-packages/xgboost4j-example" schedule: - interval: "daily" + interval: "monthly" - package-ecosystem: "maven" directory: "/jvm-packages/xgboost4j-spark" schedule: @@ -28,4 +28,4 @@ updates: - package-ecosystem: "maven" directory: 
"/jvm-packages/xgboost4j-spark-gpu" schedule: - interval: "daily" + interval: "monthly" From 32be4669fbaa886e081cf9feb546e7d38f966baf Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Tue, 16 Apr 2024 19:43:04 -0700 Subject: [PATCH 15/26] [jvm-packages] Omnibus patch to update all minor dependencies (#10188) * Fold in #10184 * Fold in #10176 * Fold in #10168 * Fold in #10165 * Fold in #10164 * Fold in #10155 * Fold in #10062 * Fold in #9984 * Fold in #9843 * Upgrade to Maven 3.6.3 --- jvm-packages/pom.xml | 14 +++++++------- jvm-packages/xgboost4j-gpu/pom.xml | 4 ++-- jvm-packages/xgboost4j-tester/generate_pom.py | 2 +- jvm-packages/xgboost4j/pom.xml | 4 ++-- tests/ci_build/Dockerfile.jvm | 6 +++--- tests/ci_build/Dockerfile.jvm_cross | 6 +++--- tests/ci_build/Dockerfile.jvm_gpu_build | 6 +++--- 7 files changed, 21 insertions(+), 21 deletions(-) diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index f9266b854..f34deefd2 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -123,7 +123,7 @@ org.apache.maven.plugins maven-jar-plugin - 3.3.0 + 3.4.0 empty-javadoc-jar @@ -152,7 +152,7 @@ org.apache.maven.plugins maven-gpg-plugin - 3.2.2 + 3.2.3 sign-artifacts @@ -166,7 +166,7 @@ org.apache.maven.plugins maven-source-plugin - 3.3.0 + 3.3.1 attach-sources @@ -204,7 +204,7 @@ org.apache.maven.plugins maven-assembly-plugin - 3.6.0 + 3.7.1 jar-with-dependencies @@ -445,7 +445,7 @@ org.apache.maven.plugins maven-surefire-plugin - 3.2.2 + 3.2.5 false false @@ -487,12 +487,12 @@ com.esotericsoftware kryo - 5.5.0 + 5.6.0 commons-logging commons-logging - 1.3.0 + 1.3.1 org.scalatest diff --git a/jvm-packages/xgboost4j-gpu/pom.xml b/jvm-packages/xgboost4j-gpu/pom.xml index fc55dd156..bd0e01213 100644 --- a/jvm-packages/xgboost4j-gpu/pom.xml +++ b/jvm-packages/xgboost4j-gpu/pom.xml @@ -72,7 +72,7 @@ org.apache.maven.plugins maven-javadoc-plugin - 3.6.2 + 3.6.3 protected true @@ -88,7 +88,7 @@ exec-maven-plugin org.codehaus.mojo - 3.1.0 + 3.2.0 native diff
--git a/jvm-packages/xgboost4j-tester/generate_pom.py b/jvm-packages/xgboost4j-tester/generate_pom.py index b9c274c28..eb7cf94b3 100644 --- a/jvm-packages/xgboost4j-tester/generate_pom.py +++ b/jvm-packages/xgboost4j-tester/generate_pom.py @@ -22,7 +22,7 @@ pom_template = """ {scala_version} 3.2.15 {scala_binary_version} - 5.5.0 + 5.6.0 diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml index 7eb186919..4ef55ae2c 100644 --- a/jvm-packages/xgboost4j/pom.xml +++ b/jvm-packages/xgboost4j/pom.xml @@ -60,7 +60,7 @@ org.apache.maven.plugins maven-javadoc-plugin - 3.6.2 + 3.6.3 protected true @@ -76,7 +76,7 @@ exec-maven-plugin org.codehaus.mojo - 3.1.0 + 3.2.0 native diff --git a/tests/ci_build/Dockerfile.jvm b/tests/ci_build/Dockerfile.jvm index 43fbd8ff5..a115fd52c 100644 --- a/tests/ci_build/Dockerfile.jvm +++ b/tests/ci_build/Dockerfile.jvm @@ -15,9 +15,9 @@ RUN \ wget -nv -nc https://cmake.org/files/v3.18/cmake-3.18.0-Linux-x86_64.sh --no-check-certificate && \ bash cmake-3.18.0-Linux-x86_64.sh --skip-license --prefix=/usr && \ # Maven - wget -nv -nc https://archive.apache.org/dist/maven/maven-3/3.6.1/binaries/apache-maven-3.6.1-bin.tar.gz && \ - tar xvf apache-maven-3.6.1-bin.tar.gz -C /opt && \ - ln -s /opt/apache-maven-3.6.1/ /opt/maven + wget -nv -nc https://archive.apache.org/dist/maven/maven-3/3.6.3/binaries/apache-maven-3.6.3-bin.tar.gz && \ + tar xvf apache-maven-3.6.3-bin.tar.gz -C /opt && \ + ln -s /opt/apache-maven-3.6.3/ /opt/maven ENV PATH=/opt/mambaforge/bin:/opt/maven/bin:$PATH ENV CC=/opt/rh/devtoolset-9/root/usr/bin/gcc diff --git a/tests/ci_build/Dockerfile.jvm_cross b/tests/ci_build/Dockerfile.jvm_cross index fdfae310a..5c4bb569b 100644 --- a/tests/ci_build/Dockerfile.jvm_cross +++ b/tests/ci_build/Dockerfile.jvm_cross @@ -17,9 +17,9 @@ RUN \ bash conda.sh -b -p /opt/mambaforge && \ /opt/mambaforge/bin/pip install awscli && \ # Maven - wget -nv 
https://archive.apache.org/dist/maven/maven-3/3.6.1/binaries/apache-maven-3.6.1-bin.tar.gz && \ - tar xvf apache-maven-3.6.1-bin.tar.gz -C /opt && \ - ln -s /opt/apache-maven-3.6.1/ /opt/maven && \ + wget -nv https://archive.apache.org/dist/maven/maven-3/3.6.3/binaries/apache-maven-3.6.3-bin.tar.gz && \ + tar xvf apache-maven-3.6.3-bin.tar.gz -C /opt && \ + ln -s /opt/apache-maven-3.6.3/ /opt/maven && \ # Spark with scala 2.12 mkdir -p /opt/spark-scala-2.12 && \ wget -nv https://archive.apache.org/dist/spark/spark-$SPARK_VERSION/spark-$SPARK_VERSION-bin-hadoop3.tgz && \ diff --git a/tests/ci_build/Dockerfile.jvm_gpu_build b/tests/ci_build/Dockerfile.jvm_gpu_build index 86ce7e72a..cee418942 100644 --- a/tests/ci_build/Dockerfile.jvm_gpu_build +++ b/tests/ci_build/Dockerfile.jvm_gpu_build @@ -18,9 +18,9 @@ RUN \ wget -nv -nc https://cmake.org/files/v3.18/cmake-3.18.0-Linux-x86_64.sh --no-check-certificate && \ bash cmake-3.18.0-Linux-x86_64.sh --skip-license --prefix=/usr && \ # Maven - wget -nv -nc https://archive.apache.org/dist/maven/maven-3/3.6.1/binaries/apache-maven-3.6.1-bin.tar.gz && \ - tar xvf apache-maven-3.6.1-bin.tar.gz -C /opt && \ - ln -s /opt/apache-maven-3.6.1/ /opt/maven + wget -nv -nc https://archive.apache.org/dist/maven/maven-3/3.6.3/binaries/apache-maven-3.6.3-bin.tar.gz && \ + tar xvf apache-maven-3.6.3-bin.tar.gz -C /opt && \ + ln -s /opt/apache-maven-3.6.3/ /opt/maven # NCCL2 (License: https://docs.nvidia.com/deeplearning/sdk/nccl-sla/index.html) RUN \ From 7c0c9677a98e021087a19d2978377215266b4cb0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 16 Apr 2024 19:52:33 -0700 Subject: [PATCH 16/26] Bump org.apache.maven.plugins:maven-jar-plugin (#10191) Bumps [org.apache.maven.plugins:maven-jar-plugin](https://github.com/apache/maven-jar-plugin) from 3.3.0 to 3.4.0. 
- [Release notes](https://github.com/apache/maven-jar-plugin/releases) - [Commits](https://github.com/apache/maven-jar-plugin/compare/maven-jar-plugin-3.3.0...maven-jar-plugin-3.4.0) --- updated-dependencies: - dependency-name: org.apache.maven.plugins:maven-jar-plugin dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- jvm-packages/xgboost4j-gpu/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jvm-packages/xgboost4j-gpu/pom.xml b/jvm-packages/xgboost4j-gpu/pom.xml index bd0e01213..25b44d6b2 100644 --- a/jvm-packages/xgboost4j-gpu/pom.xml +++ b/jvm-packages/xgboost4j-gpu/pom.xml @@ -113,7 +113,7 @@ org.apache.maven.plugins maven-jar-plugin - 3.3.0 + 3.4.0 From 4b102004565b16c4f08f532f71291c1dd19bda09 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 18 Apr 2024 03:29:52 +0800 Subject: [PATCH 17/26] [coll] Improve event loop. (#10199) - Add a test for blocking calls. - Do not require the queue to be empty after waking up; this frees up the thread to answer blocking calls. - Handle EOF in read. - Improve the error message in the result. Allow concatenation of multiple results. 
--- R-package/src/Makevars.in | 1 + R-package/src/Makevars.win | 1 + demo/dask/cpu_training.py | 2 +- doc/contrib/unit_tests.rst | 8 +++ include/xgboost/collective/result.h | 107 ++++++++++++---------------- include/xgboost/collective/socket.h | 32 ++++++--- src/collective/loop.cc | 94 +++++++++++++++++------- src/collective/loop.h | 20 +++--- src/collective/result.cc | 86 ++++++++++++++++++++++ tests/cpp/collective/test_loop.cc | 41 ++++++++--- tests/cpp/collective/test_result.cc | 31 ++++++++ 11 files changed, 312 insertions(+), 111 deletions(-) create mode 100644 src/collective/result.cc create mode 100644 tests/cpp/collective/test_result.cc diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index 0f4b3ac6f..99241249f 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -99,6 +99,7 @@ OBJECTS= \ $(PKGROOT)/src/context.o \ $(PKGROOT)/src/logging.o \ $(PKGROOT)/src/global_config.o \ + $(PKGROOT)/src/collective/result.o \ $(PKGROOT)/src/collective/allgather.o \ $(PKGROOT)/src/collective/allreduce.o \ $(PKGROOT)/src/collective/broadcast.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index 0c2084de9..fc2cd3b9f 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -99,6 +99,7 @@ OBJECTS= \ $(PKGROOT)/src/context.o \ $(PKGROOT)/src/logging.o \ $(PKGROOT)/src/global_config.o \ + $(PKGROOT)/src/collective/result.o \ $(PKGROOT)/src/collective/allgather.o \ $(PKGROOT)/src/collective/allreduce.o \ $(PKGROOT)/src/collective/broadcast.o \ diff --git a/demo/dask/cpu_training.py b/demo/dask/cpu_training.py index 2bee444f7..7117eddd9 100644 --- a/demo/dask/cpu_training.py +++ b/demo/dask/cpu_training.py @@ -40,7 +40,7 @@ def main(client): # you can pass output directly into `predict` too. 
prediction = dxgb.predict(client, bst, dtrain) print("Evaluation history:", history) - return prediction + print("Error:", da.sqrt((prediction - y) ** 2).mean().compute()) if __name__ == "__main__": diff --git a/doc/contrib/unit_tests.rst b/doc/contrib/unit_tests.rst index 662a632e2..908e5ed99 100644 --- a/doc/contrib/unit_tests.rst +++ b/doc/contrib/unit_tests.rst @@ -144,6 +144,14 @@ which provides higher flexibility. For example: ctest --verbose +If you need to debug errors on Windows using the debugger from VS, you can append the gtest flags in `test_main.cc`: + +.. code-block:: + + ::testing::GTEST_FLAG(filter) = "Suite.Test"; + ::testing::GTEST_FLAG(repeat) = 10; + + *********************************************** Sanitizers: Detect memory errors and data races *********************************************** diff --git a/include/xgboost/collective/result.h b/include/xgboost/collective/result.h index 919d3a902..23e70a8e6 100644 --- a/include/xgboost/collective/result.h +++ b/include/xgboost/collective/result.h @@ -3,13 +3,11 @@ */ #pragma once -#include - -#include // for unique_ptr -#include // for stringstream -#include // for stack -#include // for string -#include // for move +#include // for int32_t +#include // for unique_ptr +#include // for string +#include // for error_code +#include // for move namespace xgboost::collective { namespace detail { @@ -48,48 +46,19 @@ struct ResultImpl { return cur_eq; } - [[nodiscard]] std::string Report() { - std::stringstream ss; - ss << "\n- " << this->message; - if (this->errc != std::error_code{}) { - ss << " system error:" << this->errc.message(); - } + [[nodiscard]] std::string Report() const; + [[nodiscard]] std::error_code Code() const; - auto ptr = prev.get(); - while (ptr) { - ss << "\n- "; - ss << ptr->message; - - if (ptr->errc != std::error_code{}) { - ss << " " << ptr->errc.message(); - } - ptr = ptr->prev.get(); - } - - return ss.str(); - } - [[nodiscard]] auto Code() const { - // Find the root error. 
- std::stack stack; - auto ptr = this; - while (ptr) { - stack.push(ptr); - if (ptr->prev) { - ptr = ptr->prev.get(); - } else { - break; - } - } - while (!stack.empty()) { - auto frame = stack.top(); - stack.pop(); - if (frame->errc != std::error_code{}) { - return frame->errc; - } - } - return std::error_code{}; - } + void Concat(std::unique_ptr rhs); }; + +#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__) +#define __builtin_FILE() nullptr +#define __builtin_LINE() (-1) +std::string MakeMsg(std::string&& msg, char const*, std::int32_t); +#else +std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line); +#endif } // namespace detail /** @@ -131,8 +100,21 @@ struct Result { } return *impl_ == *that.impl_; } + + friend Result operator+(Result&& lhs, Result&& rhs); }; +[[nodiscard]] inline Result operator+(Result&& lhs, Result&& rhs) { + if (lhs.OK()) { + return std::forward(rhs); + } + if (rhs.OK()) { + return std::forward(lhs); + } + lhs.impl_->Concat(std::move(rhs.impl_)); + return std::forward(lhs); +} + /** * @brief Return success. */ @@ -140,38 +122,43 @@ struct Result { /** * @brief Return failure. */ -[[nodiscard]] inline auto Fail(std::string msg) { return Result{std::move(msg)}; } +[[nodiscard]] inline auto Fail(std::string msg, char const* file = __builtin_FILE(), + std::int32_t line = __builtin_LINE()) { + return Result{detail::MakeMsg(std::move(msg), file, line)}; +} /** * @brief Return failure with `errno`. */ -[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc) { - return Result{std::move(msg), std::move(errc)}; +[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, + char const* file = __builtin_FILE(), + std::int32_t line = __builtin_LINE()) { + return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc)}; } /** * @brief Return failure with a previous error. 
*/ -[[nodiscard]] inline auto Fail(std::string msg, Result&& prev) { - return Result{std::move(msg), std::forward(prev)}; +[[nodiscard]] inline auto Fail(std::string msg, Result&& prev, char const* file = __builtin_FILE(), + std::int32_t line = __builtin_LINE()) { + return Result{detail::MakeMsg(std::move(msg), file, line), std::forward(prev)}; } /** * @brief Return failure with a previous error and a new `errno`. */ -[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) { - return Result{std::move(msg), std::move(errc), std::forward(prev)}; +[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev, + char const* file = __builtin_FILE(), + std::int32_t line = __builtin_LINE()) { + return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc), + std::forward(prev)}; } // We don't have monad, a simple helper would do. template -[[nodiscard]] Result operator<<(Result&& r, Fn&& fn) { +[[nodiscard]] std::enable_if_t, Result> operator<<(Result&& r, Fn&& fn) { if (!r.OK()) { return std::forward(r); } return fn(); } -inline void SafeColl(Result const& rc) { - if (!rc.OK()) { - LOG(FATAL) << rc.Report(); - } -} +void SafeColl(Result const& rc); } // namespace xgboost::collective diff --git a/include/xgboost/collective/socket.h b/include/xgboost/collective/socket.h index 3bc3b389c..11520eede 100644 --- a/include/xgboost/collective/socket.h +++ b/include/xgboost/collective/socket.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, XGBoost Contributors + * Copyright (c) 2022-2024, XGBoost Contributors */ #pragma once @@ -12,7 +12,6 @@ #include // std::size_t #include // std::int32_t, std::uint16_t #include // memset -#include // std::numeric_limits #include // std::string #include // std::error_code, std::system_category #include // std::swap @@ -468,19 +467,30 @@ class TCPSocket { *addr = SockAddress{SockAddrV6{caddr}}; *out = TCPSocket{newfd}; } + // On MacOS, this is automatically set to async socket if 
the parent socket is async + // We make sure all socket are blocking by default. + // + // On Windows, a closed socket is returned during shutdown. We guard against it when + // setting non-blocking. + if (!out->IsClosed()) { + return out->NonBlocking(false); + } return Success(); } ~TCPSocket() { if (!IsClosed()) { - Close(); + auto rc = this->Close(); + if (!rc.OK()) { + LOG(WARNING) << rc.Report(); + } } } TCPSocket(TCPSocket const &that) = delete; TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); } TCPSocket &operator=(TCPSocket const &that) = delete; - TCPSocket &operator=(TCPSocket &&that) { + TCPSocket &operator=(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); return *this; } @@ -635,22 +645,26 @@ class TCPSocket { */ std::size_t Recv(std::string *p_str); /** - * \brief Close the socket, called automatically in destructor if the socket is not closed. + * @brief Close the socket, called automatically in destructor if the socket is not closed. */ - void Close() { + Result Close() { if (InvalidSocket() != handle_) { -#if defined(_WIN32) auto rc = system::CloseSocket(handle_); +#if defined(_WIN32) // it's possible that we close TCP sockets after finalizing WSA due to detached thread. if (rc != 0 && system::LastError() != WSANOTINITIALISED) { - system::ThrowAtError("close", rc); + return system::FailWithCode("Failed to close the socket."); } #else - xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0); + if (rc != 0) { + return system::FailWithCode("Failed to close the socket."); + } #endif handle_ = InvalidSocket(); } + return Success(); } + /** * \brief Create a TCP socket on specified domain. 
*/ diff --git a/src/collective/loop.cc b/src/collective/loop.cc index b51749fcd..0cd41426d 100644 --- a/src/collective/loop.cc +++ b/src/collective/loop.cc @@ -18,9 +18,11 @@ #include "xgboost/logging.h" // for CHECK namespace xgboost::collective { -Result Loop::EmptyQueue(std::queue* p_queue) const { +Result Loop::ProcessQueue(std::queue* p_queue, bool blocking) const { timer_.Start(__func__); - auto error = [this] { timer_.Stop(__func__); }; + auto error = [this] { + timer_.Stop(__func__); + }; if (stop_) { timer_.Stop(__func__); @@ -48,6 +50,9 @@ Result Loop::EmptyQueue(std::queue* p_queue) const { poll.WatchWrite(*op.sock); break; } + case Op::kSleep: { + break; + } default: { error(); return Fail("Invalid socket operation."); @@ -59,12 +64,14 @@ Result Loop::EmptyQueue(std::queue* p_queue) const { // poll, work on fds that are ready. timer_.Start("poll"); - auto rc = poll.Poll(timeout_); - timer_.Stop("poll"); - if (!rc.OK()) { - error(); - return rc; + if (!poll.fds.empty()) { + auto rc = poll.Poll(timeout_); + if (!rc.OK()) { + error(); + return rc; + } } + timer_.Stop("poll"); // we wonldn't be here if the queue is empty. CHECK(!qcopy.empty()); @@ -75,12 +82,20 @@ Result Loop::EmptyQueue(std::queue* p_queue) const { qcopy.pop(); std::int32_t n_bytes_done{0}; - CHECK(op.sock->NonBlocking()); + if (!op.sock) { + CHECK(op.code == Op::kSleep); + } else { + CHECK(op.sock->NonBlocking()); + } switch (op.code) { case Op::kRead: { if (poll.CheckRead(*op.sock)) { n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off); + if (n_bytes_done == 0) { + error(); + return Fail("Encountered EOF. The other end is likely closed."); + } } break; } @@ -90,6 +105,12 @@ Result Loop::EmptyQueue(std::queue* p_queue) const { } break; } + case Op::kSleep: { + // For testing only. 
+ std::this_thread::sleep_for(std::chrono::seconds{op.n}); + n_bytes_done = op.n; + break; + } default: { error(); return Fail("Invalid socket operation."); @@ -110,6 +131,10 @@ Result Loop::EmptyQueue(std::queue* p_queue) const { qcopy.push(op); } } + + if (!blocking) { + break; + } } timer_.Stop(__func__); @@ -128,6 +153,15 @@ void Loop::Process() { while (true) { try { std::unique_lock lock{mu_}; + // This can handle missed notification: wait(lock, predicate) is equivalent to: + // + // while (!predicate()) { + // cv.wait(lock); + // } + // + // As a result, if there's a missed notification, the queue wouldn't be empty, hence + // the predicate would be false and the actual wait wouldn't be invoked. Therefore, + // the blocking call can never go unanswered. cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; }); if (stop_) { break; // only point where this loop can exit. @@ -142,26 +176,27 @@ void Loop::Process() { queue_.pop(); if (op.code == Op::kBlock) { is_blocking = true; - // Block must be the last op in the current batch since no further submit can be - // issued until the blocking call is finished. - CHECK(queue_.empty()); } else { qcopy.push(op); } } - if (!is_blocking) { - // Unblock, we can write to the global queue again. - lock.unlock(); + lock.unlock(); + // Clear the local queue, if `is_blocking` is true, this is blocking the current + // worker thread (but not the client thread), wait until all operations are + // finished. + auto rc = this->ProcessQueue(&qcopy, is_blocking); + + if (is_blocking && rc.OK()) { + CHECK(qcopy.empty()); } - - // Clear the local queue, this is blocking the current worker thread (but not the - // client thread), wait until all operations are finished. - auto rc = this->EmptyQueue(&qcopy); - - if (is_blocking) { - // The unlock is delayed if this is a blocking call - lock.unlock(); + // Push back the remaining operations. 
+ if (rc.OK()) { + std::unique_lock lock{mu_}; + while (!qcopy.empty()) { + queue_.push(qcopy.front()); + qcopy.pop(); + } } // Notify the client thread who called block after all error conditions are set. @@ -228,7 +263,6 @@ Result Loop::Stop() { } this->Submit(Op{Op::kBlock}); - { // Wait for the block call to finish. std::unique_lock lock{mu_}; @@ -243,8 +277,20 @@ Result Loop::Stop() { } } +void Loop::Submit(Op op) { + std::unique_lock lock{mu_}; + if (op.code != Op::kBlock) { + CHECK_NE(op.n, 0); + } + queue_.push(op); + lock.unlock(); + cv_.notify_one(); +} + Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} { timer_.Init(__func__); - worker_ = std::thread{[this] { this->Process(); }}; + worker_ = std::thread{[this] { + this->Process(); + }}; } } // namespace xgboost::collective diff --git a/src/collective/loop.h b/src/collective/loop.h index 4839abfd3..a4de2a81b 100644 --- a/src/collective/loop.h +++ b/src/collective/loop.h @@ -19,20 +19,27 @@ namespace xgboost::collective { class Loop { public: struct Op { - enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2 } code; + // kSleep is only for testing + enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2, kSleep = 4 } code; std::int32_t rank{-1}; std::int8_t* ptr{nullptr}; std::size_t n{0}; TCPSocket* sock{nullptr}; std::size_t off{0}; - explicit Op(Code c) : code{c} { CHECK(c == kBlock); } + explicit Op(Code c) : code{c} { CHECK(c == kBlock || c == kSleep); } Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off) : code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {} Op(Op const&) = default; Op& operator=(Op const&) = default; Op(Op&&) = default; Op& operator=(Op&&) = default; + // For testing purpose only + [[nodiscard]] static Op Sleep(std::size_t seconds) { + Op op{kSleep}; + op.n = seconds; + return op; + } }; private: @@ -54,7 +61,7 @@ class Loop { std::exception_ptr curr_exce_{nullptr}; common::Monitor mutable timer_; - Result 
EmptyQueue(std::queue* p_queue) const; + Result ProcessQueue(std::queue* p_queue, bool blocking) const; // The cunsumer function that runs inside a worker thread. void Process(); @@ -64,12 +71,7 @@ class Loop { */ Result Stop(); - void Submit(Op op) { - std::unique_lock lock{mu_}; - queue_.push(op); - lock.unlock(); - cv_.notify_one(); - } + void Submit(Op op); /** * @brief Block the event loop until all ops are finished. In the case of failure, this diff --git a/src/collective/result.cc b/src/collective/result.cc new file mode 100644 index 000000000..b11710572 --- /dev/null +++ b/src/collective/result.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2024, XGBoost Contributors + */ +#include "xgboost/collective/result.h" + +#include // for path +#include // for stringstream +#include // for stack + +#include "xgboost/logging.h" + +namespace xgboost::collective { +namespace detail { +[[nodiscard]] std::string ResultImpl::Report() const { + std::stringstream ss; + ss << "\n- " << this->message; + if (this->errc != std::error_code{}) { + ss << " system error:" << this->errc.message(); + } + + auto ptr = prev.get(); + while (ptr) { + ss << "\n- "; + ss << ptr->message; + + if (ptr->errc != std::error_code{}) { + ss << " " << ptr->errc.message(); + } + ptr = ptr->prev.get(); + } + + return ss.str(); +} + +[[nodiscard]] std::error_code ResultImpl::Code() const { + // Find the root error. 
+ std::stack stack; + auto ptr = this; + while (ptr) { + stack.push(ptr); + if (ptr->prev) { + ptr = ptr->prev.get(); + } else { + break; + } + } + while (!stack.empty()) { + auto frame = stack.top(); + stack.pop(); + if (frame->errc != std::error_code{}) { + return frame->errc; + } + } + return std::error_code{}; +} + +void ResultImpl::Concat(std::unique_ptr rhs) { + auto ptr = this; + while (ptr->prev) { + ptr = ptr->prev.get(); + } + ptr->prev = std::move(rhs); +} + +#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__) +std::string MakeMsg(std::string&& msg, char const*, std::int32_t) { + return std::forward(msg); +} +#else +std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line) { + auto name = std::filesystem::path{file}.filename(); + if (file && line != -1) { + return "[" + name.string() + ":" + std::to_string(line) + // NOLINT + "]: " + std::forward(msg); + } + return std::forward(msg); +} +#endif +} // namespace detail + +void SafeColl(Result const& rc) { + if (!rc.OK()) { + LOG(FATAL) << rc.Report(); + } +} +} // namespace xgboost::collective diff --git a/tests/cpp/collective/test_loop.cc b/tests/cpp/collective/test_loop.cc index e5ef987f3..0908d9623 100644 --- a/tests/cpp/collective/test_loop.cc +++ b/tests/cpp/collective/test_loop.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include // for ASSERT_TRUE, ASSERT_EQ #include // for TCPSocket, Connect, SocketFinalize, SocketStartup @@ -28,18 +28,25 @@ class LoopTest : public ::testing::Test { auto domain = SockDomain::kV4; pair_.first = TCPSocket::Create(domain); - auto port = pair_.first.BindHost(); - pair_.first.Listen(); + in_port_t port{0}; + auto rc = Success() << [&] { + port = pair_.first.BindHost(); + return Success(); + } << [&] { + pair_.first.Listen(); + return Success(); + }; + SafeColl(rc); auto const& addr = SockAddrV4::Loopback().Addr(); - auto rc = Connect(StringView{addr}, port, 1, 
timeout, &pair_.second); - ASSERT_TRUE(rc.OK()); + rc = Connect(StringView{addr}, port, 1, timeout, &pair_.second); + SafeColl(rc); rc = pair_.second.NonBlocking(true); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); pair_.first = pair_.first.Accept(); rc = pair_.first.NonBlocking(true); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); loop_ = std::shared_ptr{new Loop{timeout}}; } @@ -74,8 +81,26 @@ TEST_F(LoopTest, Op) { loop_->Submit(rop); auto rc = loop_->Block(); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); ASSERT_EQ(rbuf[0], wbuf[0]); } + +TEST_F(LoopTest, Block) { + // We need to ensure that a blocking call doesn't go unanswered. + auto op = Loop::Op::Sleep(2); + + common::Timer t; + t.Start(); + loop_->Submit(op); + t.Stop(); + // submit is non-blocking + ASSERT_LT(t.ElapsedSeconds(), 1); + + t.Start(); + auto rc = loop_->Block(); + t.Stop(); + SafeColl(rc); + ASSERT_GE(t.ElapsedSeconds(), 1); +} } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_result.cc b/tests/cpp/collective/test_result.cc new file mode 100644 index 000000000..1c7194f92 --- /dev/null +++ b/tests/cpp/collective/test_result.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2024, XGBoost Contributors + */ +#include +#include + +namespace xgboost::collective { +TEST(Result, Concat) { + auto rc0 = Fail("foo"); + auto rc1 = Fail("bar"); + auto rc = std::move(rc0) + std::move(rc1); + ASSERT_NE(rc.Report().find("foo"), std::string::npos); + ASSERT_NE(rc.Report().find("bar"), std::string::npos); + + auto rc2 = Fail("Another", std::move(rc)); + auto assert_that = [](Result const& rc) { + ASSERT_NE(rc.Report().find("Another"), std::string::npos); + ASSERT_NE(rc.Report().find("foo"), std::string::npos); + ASSERT_NE(rc.Report().find("bar"), std::string::npos); + }; + assert_that(rc2); + + auto empty = Success(); + auto rc3 = std::move(empty) + std::move(rc2); + assert_that(rc3); + + empty = Success(); + auto rc4 = std::move(rc3) + std::move(empty); + assert_that(rc4); +} +} // namespace 
xgboost::collective From f53f5ca35971a162ee370afb516d0a3bf1e60f89 Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Wed, 17 Apr 2024 19:15:06 -0700 Subject: [PATCH 18/26] [CI] Update machine images (#10201) --- tests/buildkite/infrastructure/README.md | 106 ++++++++++++++++++ .../aws-stack-creator/metadata.py | 14 +-- .../windows-gpu-bootstrap.yml | 12 +- 3 files changed, 119 insertions(+), 13 deletions(-) create mode 100644 tests/buildkite/infrastructure/README.md diff --git a/tests/buildkite/infrastructure/README.md b/tests/buildkite/infrastructure/README.md new file mode 100644 index 000000000..cc3e552e7 --- /dev/null +++ b/tests/buildkite/infrastructure/README.md @@ -0,0 +1,106 @@ +BuildKite CI Infrastructure +=========================== + +# Worker image builder (`worker-image-pipeline/`) + +Use EC2 Image Builder to build machine images in a deterministic fashion. +The machine images are used to initialize workers in the CI/CD pipelines. + +## Editing bootstrap scripts + +Currently, we create two pipelines for machine images: one for Linux workers and another +for Windows workers. +You can edit the bootstrap scripts to change how the worker machines are initialized. + +* `linux-amd64-gpu-bootstrap.yml`: Bootstrap script for Linux worker machines +* `windows-gpu-bootstrap.yml`: Bootstrap script for Windows worker machines + +## Creating and running Image Builder pipelines + +Run the following commands to create and run pipelines in EC2 Image Builder service: +```bash +python worker-image-pipeline/create_worker_image_pipelines.py --aws-region us-west-2 +python worker-image-pipeline/run_pipelines.py --aws-region us-west-2 +``` +Go to the AWS CloudFormation console and verify the existence of two CloudFormation stacks: +* `buildkite-windows-gpu-worker` +* `buildkite-linux-amd64-gpu-worker` + +Then go to the EC2 Image Builder console to check the status of the image builds. You may +want to inspect the log output should a build fails. 
+Once the new machine images are done building, see the next section to deploy the new +images to the worker machines. + +# Elastic CI Stack for AWS (`aws-stack-creator/`) + +Use EC2 Autoscaling groups to launch worker machines in EC2. BuildKite periodically sends +messages to the Autoscaling groups to increase or decrease the number of workers according +to the number of outstanding testing jobs. + +## Deploy an updated CI stack with new machine images + +First, edit `aws-stack-creator/metadata.py` to update the `AMI_ID` fields: +```python +AMI_ID = { + # Managed by XGBoost team + "linux-amd64-gpu": { + "us-west-2": "...", + }, + "linux-amd64-mgpu": { + "us-west-2": "...", + }, + "windows-gpu": { + "us-west-2": "...", + }, + "windows-cpu": { + "us-west-2": "...", + }, + # Managed by BuildKite + # from https://s3.amazonaws.com/buildkite-aws-stack/latest/aws-stack.yml + "linux-amd64-cpu": { + "us-west-2": "...", + }, + "pipeline-loader": { + "us-west-2": "...", + }, + "linux-arm64-cpu": { + "us-west-2": "...", + }, +} +``` +AMI IDs uniquely identify the machine images in the EC2 service. 
+Go to the EC2 Image Builder console to find the AMI IDs for the new machine images +(see the previous section), and update the following fields: + +* `AMI_ID["linux-amd64-gpu"]["us-west-2"]`: + Use the latest output from the `buildkite-linux-amd64-gpu-worker` pipeline +* `AMI_ID["linux-amd64-mgpu"]["us-west-2"]`: + Should be identical to `AMI_ID["linux-amd64-gpu"]["us-west-2"]` +* `AMI_ID["windows-gpu"]["us-west-2"]`: + Use the latest output from the `buildkite-windows-gpu-worker` pipeline +* `AMI_ID["windows-cpu"]["us-west-2"]`: + Should be identical to `AMI_ID["windows-gpu"]["us-west-2"]` + +Next, visit https://s3.amazonaws.com/buildkite-aws-stack/latest/aws-stack.yml +to look up the AMI IDs for the following fields: + +* `AMI_ID["linux-amd64-cpu"]["us-west-2"]`: Copy and paste the AMI ID from the field + `Mappings/AWSRegion2AMI/us-west-2/linuxamd64` +* `AMI_ID["pipeline-loader"]["us-west-2"]`: + Should be identical to `AMI_ID["linux-amd64-cpu"]["us-west-2"]` +* `AMI_ID["linux-arm64-cpu"]["us-west-2"]`: Copy and paste the AMI ID from the field + `Mappings/AWSRegion2AMI/us-west-2/linuxarm64` + +Finally, run the following commands to deploy the new machine images: +``` +python aws-stack-creator/create_stack.py --aws-region us-west-2 --agent-token AGENT_TOKEN +``` +Go to the AWS CloudFormation console and verify the existence of the following +CloudFormation stacks: +* `buildkite-pipeline-loader-autoscaling-group` +* `buildkite-linux-amd64-cpu-autoscaling-group` +* `buildkite-linux-amd64-gpu-autoscaling-group` +* `buildkite-linux-amd64-mgpu-autoscaling-group` +* `buildkite-linux-arm64-cpu-autoscaling-group` +* `buildkite-windows-cpu-autoscaling-group` +* `buildkite-windows-gpu-autoscaling-group` diff --git a/tests/buildkite/infrastructure/aws-stack-creator/metadata.py b/tests/buildkite/infrastructure/aws-stack-creator/metadata.py index 3b56a2d8c..e086021da 100644 --- a/tests/buildkite/infrastructure/aws-stack-creator/metadata.py +++ 
b/tests/buildkite/infrastructure/aws-stack-creator/metadata.py @@ -1,27 +1,27 @@ AMI_ID = { # Managed by XGBoost team "linux-amd64-gpu": { - "us-west-2": "ami-08c3bc1dd5ec8bc5c", + "us-west-2": "ami-070080d04e81c5e39", }, "linux-amd64-mgpu": { - "us-west-2": "ami-08c3bc1dd5ec8bc5c", + "us-west-2": "ami-070080d04e81c5e39", }, "windows-gpu": { - "us-west-2": "ami-03c7f2156f93b22a7", + "us-west-2": "ami-07c14abcf529d816a", }, "windows-cpu": { - "us-west-2": "ami-03c7f2156f93b22a7", + "us-west-2": "ami-07c14abcf529d816a", }, # Managed by BuildKite # from https://s3.amazonaws.com/buildkite-aws-stack/latest/aws-stack.yml "linux-amd64-cpu": { - "us-west-2": "ami-015e64acb52b3e595", + "us-west-2": "ami-0180f7fb0f07eb0bc", }, "pipeline-loader": { - "us-west-2": "ami-015e64acb52b3e595", + "us-west-2": "ami-0180f7fb0f07eb0bc", }, "linux-arm64-cpu": { - "us-west-2": "ami-0884e9c23a2fa98d0", + "us-west-2": "ami-00686bdc2043a5505", }, } diff --git a/tests/buildkite/infrastructure/worker-image-pipeline/windows-gpu-bootstrap.yml b/tests/buildkite/infrastructure/worker-image-pipeline/windows-gpu-bootstrap.yml index e4d212fda..128351e0d 100644 --- a/tests/buildkite/infrastructure/worker-image-pipeline/windows-gpu-bootstrap.yml +++ b/tests/buildkite/infrastructure/worker-image-pipeline/windows-gpu-bootstrap.yml @@ -15,9 +15,9 @@ phases: choco --version choco feature enable -n=allowGlobalConfirmation - # CMake 3.27 - Write-Host '>>> Installing CMake 3.27...' - choco install cmake --version 3.27.9 --installargs "ADD_CMAKE_TO_PATH=System" + # CMake 3.29.2 + Write-Host '>>> Installing CMake 3.29.2...' + choco install cmake --version 3.29.2 --installargs "ADD_CMAKE_TO_PATH=System" if ($LASTEXITCODE -ne 0) { throw "Last command failed" } # Notepad++ @@ -53,9 +53,9 @@ phases: "--wait --passive --norestart --includeOptional" if ($LASTEXITCODE -ne 0) { throw "Last command failed" } - # Install CUDA 11.8 - Write-Host '>>> Installing CUDA 11.8...' 
- choco install cuda --version=11.8.0.52206 + # Install CUDA 12.4 + Write-Host '>>> Installing CUDA 12.4...' + choco install cuda --version=12.4.1.551 if ($LASTEXITCODE -ne 0) { throw "Last command failed" } # Install R From 0aa260039901c961361509f0136b9cc324eddb0a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 17 Apr 2024 23:04:41 -0700 Subject: [PATCH 19/26] Bump org.apache.maven.plugins:maven-jar-plugin (#10202) Bumps [org.apache.maven.plugins:maven-jar-plugin](https://github.com/apache/maven-jar-plugin) from 3.3.0 to 3.4.0. - [Release notes](https://github.com/apache/maven-jar-plugin/releases) - [Commits](https://github.com/apache/maven-jar-plugin/compare/maven-jar-plugin-3.3.0...maven-jar-plugin-3.4.0) --- updated-dependencies: - dependency-name: org.apache.maven.plugins:maven-jar-plugin dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- jvm-packages/xgboost4j/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml index 4ef55ae2c..5012eaf14 100644 --- a/jvm-packages/xgboost4j/pom.xml +++ b/jvm-packages/xgboost4j/pom.xml @@ -99,7 +99,7 @@ org.apache.maven.plugins maven-jar-plugin - 3.3.0 + 3.4.0 From 303c603c7d7b77006303e71be5d5ef56c53f7152 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 18 Apr 2024 19:09:30 +0800 Subject: [PATCH 20/26] [pyspark] Reuse the collective communicator. 
(#10198) --- python-package/xgboost/spark/utils.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py index c0a876419..7dbe290ae 100644 --- a/python-package/xgboost/spark/utils.py +++ b/python-package/xgboost/spark/utils.py @@ -14,7 +14,8 @@ import pyspark from pyspark import BarrierTaskContext, SparkConf, SparkContext, SparkFiles, TaskContext from pyspark.sql.session import SparkSession -from xgboost import Booster, XGBModel, collective +from xgboost import Booster, XGBModel +from xgboost.collective import CommunicatorContext as CCtx from xgboost.tracker import RabitTracker @@ -42,22 +43,12 @@ def _get_default_params_from_func( return filtered_params_dict -class CommunicatorContext: - """A context controlling collective communicator initialization and finalization. - This isn't specificially necessary (note Part 3), but it is more understandable - coding-wise. - - """ +class CommunicatorContext(CCtx): # pylint: disable=too-few-public-methods + """Context with PySpark specific task ID.""" def __init__(self, context: BarrierTaskContext, **args: Any) -> None: - self.args = args - self.args["DMLC_TASK_ID"] = str(context.partitionId()) - - def __enter__(self) -> None: - collective.init(**self.args) - - def __exit__(self, *args: Any) -> None: - collective.finalize() + args["DMLC_TASK_ID"] = str(context.partitionId()) + super().__init__(**args) def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]: From 531ff21b201a0dbf631fee2626cda49371f410fd Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 18 Apr 2024 07:18:18 -0700 Subject: [PATCH 21/26] Bump org.scala-lang.modules:scala-collection-compat_2.12 (#10193) Bumps [org.scala-lang.modules:scala-collection-compat_2.12](https://github.com/scala/scala-collection-compat) from 2.11.0 to 2.12.0. 
- [Release notes](https://github.com/scala/scala-collection-compat/releases) - [Commits](https://github.com/scala/scala-collection-compat/compare/v2.11.0...v2.12.0) --- updated-dependencies: - dependency-name: org.scala-lang.modules:scala-collection-compat_2.12 dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- jvm-packages/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index f34deefd2..ced9f7b3f 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -47,7 +47,7 @@ 23.12.1 cuda12 3.2.17 - 2.11.0 + 2.12.0 From 551fa6e25e36efd057b7af934a243690fcafb058 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 18 Apr 2024 11:46:28 -0700 Subject: [PATCH 22/26] Bump scalatest.version from 3.2.17 to 3.2.18 in /jvm-packages/xgboost4j (#10196) Bumps `scalatest.version` from 3.2.17 to 3.2.18. Updates `org.scalatest:scalatest_2.12` from 3.2.17 to 3.2.18 - [Release notes](https://github.com/scalatest/scalatest/releases) - [Commits](https://github.com/scalatest/scalatest/compare/release-3.2.17...release-3.2.18) Updates `org.scalactic:scalactic_2.12` from 3.2.17 to 3.2.18 - [Release notes](https://github.com/scalatest/scalatest/releases) - [Commits](https://github.com/scalatest/scalatest/compare/release-3.2.17...release-3.2.18) --- updated-dependencies: - dependency-name: org.scalatest:scalatest_2.12 dependency-type: direct:development update-type: version-update:semver-patch - dependency-name: org.scalactic:scalactic_2.12 dependency-type: direct:development update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- jvm-packages/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index ced9f7b3f..71c2b5fa1 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -46,7 +46,7 @@ 23.12.1 23.12.1 cuda12 - 3.2.17 + 3.2.18 2.12.0 From 3f64b4fde3b62458eb32f56d36837298d25ccb26 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 19 Apr 2024 03:17:23 +0800 Subject: [PATCH 23/26] [coll] Add global functions. (#10203) --- R-package/src/Makevars.in | 1 + R-package/src/Makevars.win | 1 + src/collective/aggregator.h | 2 +- src/collective/allgather.cc | 47 ++++++++- src/collective/allgather.h | 111 ++++++++++++++++++++ src/collective/allreduce.cc | 42 ++++---- src/collective/allreduce.h | 46 +++++++- src/collective/broadcast.h | 27 ++++- src/collective/comm.cu | 9 +- src/common/device_helpers.cuh | 1 - src/common/hist_util.cuh | 1 + src/tree/gpu_hist/gradient_based_sampler.cu | 4 +- tests/cpp/CMakeLists.txt | 2 +- tests/cpp/collective/test_allgather.cc | 12 +-- tests/cpp/collective/test_allgather.cu | 8 +- tests/cpp/collective/test_allreduce.cc | 2 +- tests/cpp/collective/test_allreduce.cu | 6 +- tests/cpp/collective/test_broadcast.cc | 11 +- tests/cpp/common/test_device_helpers.cu | 12 ++- tests/cpp/common/test_io.cc | 3 +- tests/cpp/common/test_json.cc | 4 +- 21 files changed, 283 insertions(+), 69 deletions(-) diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index 99241249f..69cdd09a3 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -104,6 +104,7 @@ OBJECTS= \ $(PKGROOT)/src/collective/allreduce.o \ $(PKGROOT)/src/collective/broadcast.o \ $(PKGROOT)/src/collective/comm.o \ + $(PKGROOT)/src/collective/comm_group.o \ $(PKGROOT)/src/collective/coll.o \ $(PKGROOT)/src/collective/communicator-inl.o \ $(PKGROOT)/src/collective/tracker.o \ diff --git 
a/R-package/src/Makevars.win b/R-package/src/Makevars.win index fc2cd3b9f..b34d8c649 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -104,6 +104,7 @@ OBJECTS= \ $(PKGROOT)/src/collective/allreduce.o \ $(PKGROOT)/src/collective/broadcast.o \ $(PKGROOT)/src/collective/comm.o \ + $(PKGROOT)/src/collective/comm_group.o \ $(PKGROOT)/src/collective/coll.o \ $(PKGROOT)/src/collective/communicator-inl.o \ $(PKGROOT)/src/collective/tracker.o \ diff --git a/src/collective/aggregator.h b/src/collective/aggregator.h index 8a5b31c36..bc652f2e8 100644 --- a/src/collective/aggregator.h +++ b/src/collective/aggregator.h @@ -165,7 +165,7 @@ template T GlobalRatio(Context const* ctx, MetaInfo const& info, T dividend, T divisor) { std::array results{dividend, divisor}; auto rc = GlobalSum(ctx, info, linalg::MakeVec(results.data(), results.size())); - collective::SafeColl(rc); + SafeColl(rc); std::tie(dividend, divisor) = std::tuple_cat(results); if (divisor <= 0) { return std::numeric_limits::quiet_NaN(); diff --git a/src/collective/allgather.cc b/src/collective/allgather.cc index 446db73b5..5d1ec664e 100644 --- a/src/collective/allgather.cc +++ b/src/collective/allgather.cc @@ -33,6 +33,7 @@ Result RingAllgather(Comm const& comm, common::Span data, std::size bool is_last_segment = send_rank == (world - 1); auto send_nbytes = is_last_segment ? (data.size_bytes() - send_off) : segment_size; auto send_seg = data.subspan(send_off, send_nbytes); + CHECK_NE(send_seg.size(), 0); return next_ch->SendAll(send_seg.data(), send_seg.size_bytes()); } << [&] { auto recv_rank = (rank + world - r - 1 + worker_off) % world; @@ -40,9 +41,10 @@ Result RingAllgather(Comm const& comm, common::Span data, std::size bool is_last_segment = recv_rank == (world - 1); auto recv_nbytes = is_last_segment ? 
(data.size_bytes() - recv_off) : segment_size; auto recv_seg = data.subspan(recv_off, recv_nbytes); + CHECK_NE(recv_seg.size(), 0); return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes()); } << [&] { - return prev_ch->Block(); + return comm.Block(); }; if (!rc.OK()) { return rc; @@ -106,4 +108,47 @@ namespace detail { return comm.Block(); } } // namespace detail + +[[nodiscard]] std::vector> VectorAllgatherV( + Context const* ctx, CommGroup const& comm, std::vector> const& input) { + auto n_inputs = input.size(); + std::vector sizes(n_inputs); + std::transform(input.cbegin(), input.cend(), sizes.begin(), + [](auto const& vec) { return vec.size(); }); + + std::vector recv_segments(comm.World() + 1, 0); + + HostDeviceVector recv; + auto rc = + AllgatherV(ctx, comm, linalg::MakeVec(sizes.data(), sizes.size()), &recv_segments, &recv); + SafeColl(rc); + + auto global_sizes = common::RestoreType(recv.ConstHostSpan()); + std::vector offset(global_sizes.size() + 1); + offset[0] = 0; + for (std::size_t i = 1; i < offset.size(); i++) { + offset[i] = offset[i - 1] + global_sizes[i - 1]; + } + + std::vector collected; + for (auto const& vec : input) { + collected.insert(collected.end(), vec.cbegin(), vec.cend()); + } + rc = AllgatherV(ctx, comm, linalg::MakeVec(collected.data(), collected.size()), &recv_segments, + &recv); + SafeColl(rc); + auto out = common::RestoreType(recv.ConstHostSpan()); + + std::vector> result; + for (std::size_t i = 1; i < offset.size(); ++i) { + std::vector local(out.cbegin() + offset[i - 1], out.cbegin() + offset[i]); + result.emplace_back(std::move(local)); + } + return result; +} + +[[nodiscard]] std::vector> VectorAllgatherV( + Context const* ctx, std::vector> const& input) { + return VectorAllgatherV(ctx, *GlobalCommGroup(), input); +} } // namespace xgboost::collective diff --git a/src/collective/allgather.h b/src/collective/allgather.h index 8de9f1984..ca44c3916 100644 --- a/src/collective/allgather.h +++ b/src/collective/allgather.h @@ 
-102,4 +102,115 @@ template return detail::RingAllgatherV(comm, sizes, s_segments, erased_result); } + +template +[[nodiscard]] Result Allgather(Context const* ctx, CommGroup const& comm, + linalg::VectorView data) { + if (!comm.IsDistributed()) { + return Success(); + } + CHECK(data.Contiguous()); + auto erased = common::EraseType(data.Values()); + + auto const& cctx = comm.Ctx(ctx, data.Device()); + auto backend = comm.Backend(data.Device()); + return backend->Allgather(cctx, erased); +} + +/** + * @brief Gather all data from all workers. + * + * @param data The input and output buffer, needs to be pre-allocated by the caller. + */ +template +[[nodiscard]] Result Allgather(Context const* ctx, linalg::VectorView data) { + auto const& cg = *GlobalCommGroup(); + if (data.Size() % cg.World() != 0) { + return Fail("The total number of elements should be multiple of the number of workers."); + } + return Allgather(ctx, cg, data); +} + +template +[[nodiscard]] Result AllgatherV(Context const* ctx, CommGroup const& comm, + linalg::VectorView data, + std::vector* recv_segments, + HostDeviceVector* recv) { + if (!comm.IsDistributed()) { + return Success(); + } + std::vector sizes(comm.World(), 0); + sizes[comm.Rank()] = data.Values().size_bytes(); + auto erased_sizes = common::EraseType(common::Span{sizes.data(), sizes.size()}); + auto rc = comm.Backend(DeviceOrd::CPU()) + ->Allgather(comm.Ctx(ctx, DeviceOrd::CPU()), erased_sizes); + if (!rc.OK()) { + return rc; + } + + recv_segments->resize(sizes.size() + 1); + detail::AllgatherVOffset(sizes, common::Span{recv_segments->data(), recv_segments->size()}); + auto total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0LL); + recv->SetDevice(data.Device()); + recv->Resize(total_bytes); + + auto s_segments = common::Span{recv_segments->data(), recv_segments->size()}; + + auto backend = comm.Backend(data.Device()); + auto erased = common::EraseType(data.Values()); + + return backend->AllgatherV( + comm.Ctx(ctx, 
data.Device()), erased, common::Span{sizes.data(), sizes.size()}, s_segments, + data.Device().IsCUDA() ? recv->DeviceSpan() : recv->HostSpan(), AllgatherVAlgo::kBcast); +} + +/** + * @brief Allgather with variable length data. + * + * @param data The input data. + * @param recv_segments segment size for each worker. [0, 2, 5] means [0, 2) elements are + * from the first worker, [2, 5) elements are from the second one. + * @param recv The buffer storing the result. + */ +template +[[nodiscard]] Result AllgatherV(Context const* ctx, linalg::VectorView data, + std::vector* recv_segments, + HostDeviceVector* recv) { + return AllgatherV(ctx, *GlobalCommGroup(), data, recv_segments, recv); +} + +[[nodiscard]] std::vector> VectorAllgatherV( + Context const* ctx, CommGroup const& comm, std::vector> const& input); + +/** + * @brief Gathers variable-length data from all processes and distributes it to all processes. + * + * @param inputs All the inputs from the local worker. The number of inputs can vary + * across different workers. Along with which, the size of each vector in + * the input can also vary. + * + * @return The AllgatherV result, containing vectors from all workers. + */ +[[nodiscard]] std::vector> VectorAllgatherV( + Context const* ctx, std::vector> const& input); + +/** + * @brief Gathers variable-length strings from all processes and distributes them to all processes. + * @param input Variable-length list of variable-length strings. 
+ */ +[[nodiscard]] inline Result AllgatherStrings(std::vector const& input, + std::vector* p_result) { + std::vector> inputs(input.size()); + for (std::size_t i = 0; i < input.size(); ++i) { + inputs[i] = {input[i].cbegin(), input[i].cend()}; + } + Context ctx; + auto out = VectorAllgatherV(&ctx, *GlobalCommGroup(), inputs); + auto& result = *p_result; + result.resize(out.size()); + for (std::size_t i = 0; i < out.size(); ++i) { + result[i] = {out[i].cbegin(), out[i].cend()}; + } + return Success(); +} } // namespace xgboost::collective diff --git a/src/collective/allreduce.cc b/src/collective/allreduce.cc index d9cf8b828..55c5c8854 100644 --- a/src/collective/allreduce.cc +++ b/src/collective/allreduce.cc @@ -68,39 +68,35 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span data, auto s_buf = common::Span{buffer.data(), buffer.size()}; for (std::int32_t r = 0; r < world - 1; ++r) { - // send to ring next - auto send_rank = (rank + world - r) % world; - auto send_off = send_rank * n_bytes_in_seg; + common::Span seg, recv_seg; + auto rc = Success() << [&] { + // send to ring next + auto send_rank = (rank + world - r) % world; + auto send_off = send_rank * n_bytes_in_seg; - bool is_last_segment = send_rank == (world - 1); + bool is_last_segment = send_rank == (world - 1); - auto seg_nbytes = is_last_segment ? data.size_bytes() - send_off : n_bytes_in_seg; - auto send_seg = data.subspan(send_off, seg_nbytes); + auto seg_nbytes = is_last_segment ? 
data.size_bytes() - send_off : n_bytes_in_seg; + CHECK_EQ(seg_nbytes % sizeof(T), 0); - auto rc = next_ch->SendAll(send_seg); - if (!rc.OK()) { - return rc; - } + auto send_seg = data.subspan(send_off, seg_nbytes); + return next_ch->SendAll(send_seg); + } << [&] { + // receive from ring prev + auto recv_rank = (rank + world - r - 1) % world; + auto recv_off = recv_rank * n_bytes_in_seg; - // receive from ring prev - auto recv_rank = (rank + world - r - 1) % world; - auto recv_off = recv_rank * n_bytes_in_seg; + bool is_last_segment = recv_rank == (world - 1); - is_last_segment = recv_rank == (world - 1); + auto seg_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : n_bytes_in_seg; + CHECK_EQ(seg_nbytes % sizeof(T), 0); - seg_nbytes = is_last_segment ? data.size_bytes() - recv_off : n_bytes_in_seg; - CHECK_EQ(seg_nbytes % sizeof(T), 0); - auto recv_seg = data.subspan(recv_off, seg_nbytes); - auto seg = s_buf.subspan(0, recv_seg.size()); - - rc = std::move(rc) << [&] { + recv_seg = data.subspan(recv_off, seg_nbytes); + seg = s_buf.subspan(0, recv_seg.size()); return prev_ch->RecvAll(seg); } << [&] { return comm.Block(); }; - if (!rc.OK()) { - return rc; - } // accumulate to recv_seg CHECK_EQ(seg.size(), recv_seg.size()); diff --git a/src/collective/allreduce.h b/src/collective/allreduce.h index 0c94d11cc..3e88cca11 100644 --- a/src/collective/allreduce.h +++ b/src/collective/allreduce.h @@ -1,15 +1,18 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include // for int8_t #include // for function #include // for is_invocable_v, enable_if_t +#include // for vector #include "../common/type.h" // for EraseType, RestoreType -#include "../data/array_interface.h" // for ArrayInterfaceHandler +#include "../data/array_interface.h" // for ToDType, ArrayInterfaceHandler #include "comm.h" // for Comm, RestoreType +#include "comm_group.h" // for GlobalCommGroup #include "xgboost/collective/result.h" // for 
Result +#include "xgboost/context.h" // for Context #include "xgboost/span.h" // for Span namespace xgboost::collective { @@ -27,8 +30,7 @@ std::enable_if_t, common::Span> auto erased = common::EraseType(data); auto type = ToDType::kType; - auto erased_fn = [type, redop](common::Span lhs, - common::Span out) { + auto erased_fn = [redop](common::Span lhs, common::Span out) { CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction."; auto lhs_t = common::RestoreType(lhs); auto rhs_t = common::RestoreType(out); @@ -37,4 +39,40 @@ std::enable_if_t, common::Span> return cpu_impl::RingAllreduce(comm, erased, erased_fn, type); } + +template +[[nodiscard]] Result Allreduce(Context const* ctx, CommGroup const& comm, + linalg::TensorView data, Op op) { + if (!comm.IsDistributed()) { + return Success(); + } + CHECK(data.Contiguous()); + auto erased = common::EraseType(data.Values()); + auto type = ToDType::kType; + + auto backend = comm.Backend(data.Device()); + return backend->Allreduce(comm.Ctx(ctx, data.Device()), erased, type, op); +} + +template +[[nodiscard]] Result Allreduce(Context const* ctx, linalg::TensorView data, Op op) { + return Allreduce(ctx, *GlobalCommGroup(), data, op); +} + +/** + * @brief Specialization for std::vector. + */ +template +[[nodiscard]] Result Allreduce(Context const* ctx, std::vector* data, Op op) { + return Allreduce(ctx, linalg::MakeVec(data->data(), data->size()), op); +} + +/** + * @brief Specialization for scalar value. 
+ */ +template +[[nodiscard]] std::enable_if_t && std::is_trivial_v, Result> +Allreduce(Context const* ctx, T* data, Op op) { + return Allreduce(ctx, linalg::MakeVec(data, 1), op); +} } // namespace xgboost::collective diff --git a/src/collective/broadcast.h b/src/collective/broadcast.h index 28db83815..61cab8cdd 100644 --- a/src/collective/broadcast.h +++ b/src/collective/broadcast.h @@ -1,11 +1,15 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include // for int32_t, int8_t -#include "comm.h" // for Comm -#include "xgboost/collective/result.h" // for +#include "../common/type.h" +#include "comm.h" // for Comm, EraseType +#include "comm_group.h" // for CommGroup +#include "xgboost/collective/result.h" // for Result +#include "xgboost/context.h" // for Context +#include "xgboost/linalg.h" // for VectorView #include "xgboost/span.h" // for Span namespace xgboost::collective { @@ -23,4 +27,21 @@ template common::Span{reinterpret_cast(data.data()), n_total_bytes}; return cpu_impl::Broadcast(comm, erased, root); } + +template +[[nodiscard]] Result Broadcast(Context const* ctx, CommGroup const& comm, + linalg::VectorView data, std::int32_t root) { + if (!comm.IsDistributed()) { + return Success(); + } + CHECK(data.Contiguous()); + auto erased = common::EraseType(data.Values()); + auto backend = comm.Backend(data.Device()); + return backend->Broadcast(comm.Ctx(ctx, data.Device()), erased, root); +} + +template +[[nodiscard]] Result Broadcast(Context const* ctx, linalg::VectorView data, std::int32_t root) { + return Broadcast(ctx, *GlobalCommGroup(), data, root); +} } // namespace xgboost::collective diff --git a/src/collective/comm.cu b/src/collective/comm.cu index 8788a2436..6566f28fa 100644 --- a/src/collective/comm.cu +++ b/src/collective/comm.cu @@ -27,7 +27,7 @@ Result GetUniqueId(Comm const& comm, std::shared_ptr stub, std::shared ncclUniqueId id; if (comm.Rank() == kRootRank) { auto rc = 
stub->GetUniqueId(&id); - CHECK(rc.OK()) << rc.Report(); + SafeColl(rc); } auto rc = coll->Broadcast( comm, common::Span{reinterpret_cast(&id), sizeof(ncclUniqueId)}, kRootRank); @@ -81,8 +81,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p GetCudaUUID(s_this_uuid, ctx->Device()); auto rc = pimpl->Allgather(root, common::EraseType(s_uuid)); - - CHECK(rc.OK()) << rc.Report(); + SafeColl(rc); std::vector> converted(root.World()); std::size_t j = 0; @@ -103,7 +102,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p [&] { return this->stub_->CommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank()); }; - CHECK(rc.OK()) << rc.Report(); + SafeColl(rc); for (std::int32_t r = 0; r < root.World(); ++r) { this->channels_.emplace_back( @@ -114,7 +113,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p NCCLComm::~NCCLComm() { if (nccl_comm_) { auto rc = stub_->CommDestroy(nccl_comm_); - CHECK(rc.OK()) << rc.Report(); + SafeColl(rc); } } } // namespace xgboost::collective diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 9223302aa..f4fce42f8 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -12,7 +12,6 @@ #include // make_transform_output_iterator #include #include -#include #include #include #include diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index fe3771924..cf1043ddb 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -8,6 +8,7 @@ #define COMMON_HIST_UTIL_CUH_ #include +#include // for sort #include // for size_t diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 718474a3e..f9a3819ad 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -1,13 +1,13 @@ /** - * Copyright 2019-2023 by XGBoost Contributors + * Copyright 2019-2024, XGBoost Contributors */ #include 
#include +#include // for sort #include #include #include -#include #include // for size_t #include #include diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index b1a2e0ded..0bbe5e223 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -71,7 +71,7 @@ target_include_directories(testxgboost ${xgboost_SOURCE_DIR}/rabit/include) target_link_libraries(testxgboost PRIVATE - ${GTEST_LIBRARIES}) + GTest::gtest GTest::gmock) set_output_directory(testxgboost ${xgboost_BINARY_DIR}) diff --git a/tests/cpp/collective/test_allgather.cc b/tests/cpp/collective/test_allgather.cc index b6158693b..b25db54cb 100644 --- a/tests/cpp/collective/test_allgather.cc +++ b/tests/cpp/collective/test_allgather.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include // for ASSERT_EQ #include // for Span, oper... @@ -35,7 +35,7 @@ class Worker : public WorkerForTest { data[comm_.Rank()] = comm_.Rank(); auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()}); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); for (std::int32_t r = 0; r < comm_.World(); ++r) { ASSERT_EQ(data[r], r); @@ -52,7 +52,7 @@ class Worker : public WorkerForTest { std::iota(seg.begin(), seg.end(), comm_.Rank()); auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()}); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); for (std::int32_t r = 0; r < comm_.World(); ++r) { auto seg = s_data.subspan(r * n, n); @@ -81,7 +81,7 @@ class Worker : public WorkerForTest { std::vector data(comm_.Rank() + 1, comm_.Rank()); std::vector result; auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2); CheckV(result); } @@ -91,7 +91,7 @@ class Worker : public WorkerForTest { std::int32_t n{comm_.Rank()}; std::vector result; auto rc = RingAllgatherV(comm_, 
common::Span{&n, 1}, &result); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); for (std::int32_t i = 0; i < comm_.World(); ++i) { ASSERT_EQ(result[i], i); } @@ -105,7 +105,7 @@ class Worker : public WorkerForTest { std::vector sizes(comm_.World(), 0); sizes[comm_.Rank()] = s_data.size_bytes(); auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); std::shared_ptr pcoll{new Coll{}}; std::vector recv_segments(comm_.World() + 1, 0); diff --git a/tests/cpp/collective/test_allgather.cu b/tests/cpp/collective/test_allgather.cu index 98ece7d17..f145681da 100644 --- a/tests/cpp/collective/test_allgather.cu +++ b/tests/cpp/collective/test_allgather.cu @@ -34,7 +34,7 @@ class Worker : public NCCLWorkerForTest { std::vector sizes(comm_.World(), -1); sizes[comm_.Rank()] = s_data.size_bytes(); auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); // create result dh::device_vector result(comm_.World(), -1); auto s_result = common::EraseType(dh::ToSpan(result)); @@ -42,7 +42,7 @@ class Worker : public NCCLWorkerForTest { std::vector recv_seg(nccl_comm_->World() + 1, 0); rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()}, common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); for (std::int32_t i = 0; i < comm_.World(); ++i) { ASSERT_EQ(result[i], i); @@ -58,7 +58,7 @@ class Worker : public NCCLWorkerForTest { std::vector sizes(nccl_comm_->World(), 0); sizes[comm_.Rank()] = dh::ToSpan(data).size_bytes(); auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); auto n_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0); // create result dh::device_vector result(n_bytes / sizeof(std::int32_t), -1); @@ -67,7 +67,7 @@ class Worker : public 
NCCLWorkerForTest { std::vector recv_seg(nccl_comm_->World() + 1, 0); rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()}, common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); // check segment size if (algo != AllgatherVAlgo::kBcast) { auto size = recv_seg[nccl_comm_->Rank() + 1] - recv_seg[nccl_comm_->Rank()]; diff --git a/tests/cpp/collective/test_allreduce.cc b/tests/cpp/collective/test_allreduce.cc index 457594cd9..13a6ca656 100644 --- a/tests/cpp/collective/test_allreduce.cc +++ b/tests/cpp/collective/test_allreduce.cc @@ -59,7 +59,7 @@ class AllreduceWorker : public WorkerForTest { auto pcoll = std::shared_ptr{new Coll{}}; auto rc = pcoll->Allreduce(comm_, common::EraseType(common::Span{data.data(), data.size()}), ArrayInterfaceHandler::kU4, Op::kBitwiseOR); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); for (auto v : data) { ASSERT_EQ(v, ~std::uint32_t{0}); } diff --git a/tests/cpp/collective/test_allreduce.cu b/tests/cpp/collective/test_allreduce.cu index f7e11dec2..8bda1e0de 100644 --- a/tests/cpp/collective/test_allreduce.cu +++ b/tests/cpp/collective/test_allreduce.cu @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #if defined(XGBOOST_USE_NCCL) #include @@ -24,7 +24,7 @@ class Worker : public NCCLWorkerForTest { data[comm_.Rank()] = ~std::uint32_t{0}; auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)), ArrayInterfaceHandler::kU4, Op::kBitwiseOR); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); thrust::host_vector h_data(data.size()); thrust::copy(data.cbegin(), data.cend(), h_data.begin()); for (auto v : h_data) { @@ -36,7 +36,7 @@ class Worker : public NCCLWorkerForTest { dh::device_vector data(314, 1.5); auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)), ArrayInterfaceHandler::kF8, Op::kSum); - 
ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); for (std::size_t i = 0; i < data.size(); ++i) { auto v = data[i]; ASSERT_EQ(v, 1.5 * static_cast(comm_.World())) << i; diff --git a/tests/cpp/collective/test_broadcast.cc b/tests/cpp/collective/test_broadcast.cc index 4d0d87e93..1b1d73428 100644 --- a/tests/cpp/collective/test_broadcast.cc +++ b/tests/cpp/collective/test_broadcast.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include #include @@ -10,7 +10,6 @@ #include // for vector #include "../../../src/collective/broadcast.h" // for Broadcast -#include "../../../src/collective/tracker.h" // for GetHostAddress #include "test_worker.h" // for WorkerForTest, TestDistributed namespace xgboost::collective { @@ -24,14 +23,14 @@ class Worker : public WorkerForTest { // basic test std::vector data(1, comm_.Rank()); auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); ASSERT_EQ(data[0], r); } for (std::int32_t r = 0; r < comm_.World(); ++r) { std::vector data(1 << 16, comm_.Rank()); auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); ASSERT_EQ(data[0], r); } } @@ -41,11 +40,11 @@ class BroadcastTest : public SocketTest {}; } // namespace TEST_F(BroadcastTest, Basic) { - std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency()); + std::int32_t n_workers = std::min(2u, std::thread::hardware_concurrency()); TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) { Worker worker{host, port, timeout, n_workers, r}; worker.Run(); }); -} // namespace +} } // namespace xgboost::collective diff --git a/tests/cpp/common/test_device_helpers.cu b/tests/cpp/common/test_device_helpers.cu index 1d10a48ad..4178e55d8 100644 --- a/tests/cpp/common/test_device_helpers.cu +++ 
b/tests/cpp/common/test_device_helpers.cu @@ -1,14 +1,16 @@ -/*! - * Copyright 2017-2021 XGBoost contributors +/** + * Copyright 2017-2024, XGBoost contributors */ +#include +#include // for is_sorted +#include + #include #include -#include #include -#include + #include "../../../src/common/device_helpers.cuh" #include "../../../src/common/quantile.h" -#include "../helpers.h" #include "gtest/gtest.h" TEST(SumReduce, Test) { diff --git a/tests/cpp/common/test_io.cc b/tests/cpp/common/test_io.cc index 4c4d4efe0..face21851 100644 --- a/tests/cpp/common/test_io.cc +++ b/tests/cpp/common/test_io.cc @@ -1,10 +1,11 @@ /** - * Copyright 2019-2023, XGBoost Contributors + * Copyright 2019-2024, XGBoost Contributors */ #include #include // for size_t #include // for ofstream +#include // for iota #include "../../../src/common/io.h" #include "../filesystem.h" // dmlc::TemporaryDirectory diff --git a/tests/cpp/common/test_json.cc b/tests/cpp/common/test_json.cc index 3ee041a33..e144bdc45 100644 --- a/tests/cpp/common/test_json.cc +++ b/tests/cpp/common/test_json.cc @@ -4,10 +4,10 @@ #include #include -#include // for back_inserter +#include // for numeric_limits #include +#include // for iota -#include "../../../src/common/charconv.h" #include "../../../src/common/io.h" #include "../../../src/common/json_utils.h" #include "../../../src/common/threading_utils.h" // for ParallelFor From bb212bf33c1ead1d2ccebc06ba1b8c0fff1b5d0f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 18 Apr 2024 15:20:31 -0700 Subject: [PATCH 24/26] Bump org.apache.flink:flink-clients in /jvm-packages (#10197) Bumps [org.apache.flink:flink-clients](https://github.com/apache/flink) from 1.18.0 to 1.19.0. - [Commits](https://github.com/apache/flink/compare/release-1.18.0...release-1.19.0) --- updated-dependencies: - dependency-name: org.apache.flink:flink-clients dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- jvm-packages/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index 71c2b5fa1..dc47b74ae 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -33,7 +33,7 @@ UTF-8 1.8 1.8 - 1.18.0 + 1.19.0 4.13.2 3.4.1 3.4.1 From 8fb05c8c957817e730d311bc5f2a01dd3c5ea47c Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Sat, 20 Apr 2024 00:24:40 +0800 Subject: [PATCH 25/26] [pyspark] support stage-level for yarn/k8s (#10209) --- python-package/xgboost/spark/core.py | 73 ++++--- .../test_with_spark/test_spark_local.py | 206 ++++++++++++++---- 2 files changed, 213 insertions(+), 66 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 741adcb03..2f24effe5 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -347,15 +347,14 @@ class _SparkXGBParams( predict_params[param.name] = self.getOrDefault(param) return predict_params - def _validate_gpu_params(self) -> None: + def _validate_gpu_params( + self, spark_version: str, conf: SparkConf, is_local: bool = False + ) -> None: """Validate the gpu parameters and gpu configurations""" if self._run_on_gpu(): - ss = _get_spark_session() - sc = ss.sparkContext - - if _is_local(sc): - # Support GPU training in Spark local mode is just for debugging + if is_local: + # Supporting GPU training in Spark local mode is just for debugging # purposes, so it's okay for printing the below warning instead of # checking the real gpu numbers and raising the exception. 
get_logger(self.__class__.__name__).warning( @@ -364,33 +363,41 @@ class _SparkXGBParams( self.getOrDefault(self.num_workers), ) else: - executor_gpus = sc.getConf().get("spark.executor.resource.gpu.amount") + executor_gpus = conf.get("spark.executor.resource.gpu.amount") if executor_gpus is None: raise ValueError( "The `spark.executor.resource.gpu.amount` is required for training" " on GPU." ) - - if not ( - ss.version >= "3.4.0" - and _is_standalone_or_localcluster(sc.getConf()) + gpu_per_task = conf.get("spark.task.resource.gpu.amount") + if gpu_per_task is not None and float(gpu_per_task) > 1.0: + get_logger(self.__class__.__name__).warning( + "The configuration assigns %s GPUs to each Spark task, but each " + "XGBoost training task only utilizes 1 GPU, which will lead to " + "unnecessary GPU waste", + gpu_per_task, + ) + # For 3.5.1+, Spark supports task stage-level scheduling for + # Yarn/K8s/Standalone/Local cluster + # From 3.4.0 ~ 3.5.0, Spark only supports task stage-level scheduling for + # Standalone/Local cluster + # For spark below 3.4.0, Task stage-level scheduling is not supported. + # + # With stage-level scheduling, spark.task.resource.gpu.amount is not required + # to be set explicitly. Or else, spark.task.resource.gpu.amount is a must-have and + # must be set to 1.0 + if spark_version < "3.4.0" or ( + "3.4.0" <= spark_version < "3.5.1" + and not _is_standalone_or_localcluster(conf) ): - # We will enable stage-level scheduling in spark 3.4.0+ which doesn't - # require spark.task.resource.gpu.amount to be set explicitly - gpu_per_task = sc.getConf().get("spark.task.resource.gpu.amount") if gpu_per_task is not None: if float(gpu_per_task) < 1.0: raise ValueError( - "XGBoost doesn't support GPU fractional configurations. 
" - "Please set `spark.task.resource.gpu.amount=spark.executor" - ".resource.gpu.amount`" - ) - - if float(gpu_per_task) > 1.0: - get_logger(self.__class__.__name__).warning( - "%s GPUs for each Spark task is configured, but each " - "XGBoost training task uses only 1 GPU.", - gpu_per_task, + "XGBoost doesn't support GPU fractional configurations. Please set " + "`spark.task.resource.gpu.amount=spark.executor.resource.gpu." + "amount`. To enable GPU fractional configurations, you can try " + "standalone/localcluster with spark 3.4.0+ and" + "YARN/K8S with spark 3.5.1+" ) else: raise ValueError( @@ -475,7 +482,9 @@ class _SparkXGBParams( "`pyspark.ml.linalg.Vector` type." ) - self._validate_gpu_params() + ss = _get_spark_session() + sc = ss.sparkContext + self._validate_gpu_params(ss.version, sc.getConf(), _is_local(sc)) def _run_on_gpu(self) -> bool: """If train or transform on the gpu according to the parameters""" @@ -925,10 +934,14 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable): ) return True - if not _is_standalone_or_localcluster(conf): + if ( + "3.4.0" <= spark_version < "3.5.1" + and not _is_standalone_or_localcluster(conf) + ): self.logger.info( - "Stage-level scheduling in xgboost requires spark standalone or " - "local-cluster mode" + "For %s, Stage-level scheduling in xgboost requires spark standalone " + "or local-cluster mode", + spark_version, ) return True @@ -980,7 +993,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable): """Try to enable stage-level scheduling""" ss = _get_spark_session() conf = ss.sparkContext.getConf() - if self._skip_stage_level_scheduling(ss.version, conf): + if _is_local(ss.sparkContext) or self._skip_stage_level_scheduling( + ss.version, conf + ): return rdd # executor_cores will not be None diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py index b8c16ef1c..ab983c920 100644 --- 
a/tests/test_distributed/test_with_spark/test_spark_local.py +++ b/tests/test_distributed/test_with_spark/test_spark_local.py @@ -929,8 +929,127 @@ class TestPySparkLocal: model_loaded.set_device("cuda") assert model_loaded._run_on_gpu() + def test_validate_gpu_params(self) -> None: + # Standalone + standalone_conf = ( + SparkConf() + .setMaster("spark://foo") + .set("spark.executor.cores", "12") + .set("spark.task.cpus", "1") + .set("spark.executor.resource.gpu.amount", "1") + .set("spark.task.resource.gpu.amount", "0.08") + ) + classifer_on_cpu = SparkXGBClassifier(use_gpu=False) + classifer_on_gpu = SparkXGBClassifier(use_gpu=True) + + # No exception for classifier on CPU + classifer_on_cpu._validate_gpu_params("3.4.0", standalone_conf) + + with pytest.raises( + ValueError, match="XGBoost doesn't support GPU fractional configurations" + ): + classifer_on_gpu._validate_gpu_params("3.3.0", standalone_conf) + + # No issues + classifer_on_gpu._validate_gpu_params("3.4.0", standalone_conf) + classifer_on_gpu._validate_gpu_params("3.4.1", standalone_conf) + classifer_on_gpu._validate_gpu_params("3.5.0", standalone_conf) + classifer_on_gpu._validate_gpu_params("3.5.1", standalone_conf) + + # no spark.executor.resource.gpu.amount + standalone_bad_conf = ( + SparkConf() + .setMaster("spark://foo") + .set("spark.executor.cores", "12") + .set("spark.task.cpus", "1") + .set("spark.task.resource.gpu.amount", "0.08") + ) + msg_match = ( + "The `spark.executor.resource.gpu.amount` is required for training on GPU" + ) + with pytest.raises(ValueError, match=msg_match): + classifer_on_gpu._validate_gpu_params("3.3.0", standalone_bad_conf) + with pytest.raises(ValueError, match=msg_match): + classifer_on_gpu._validate_gpu_params("3.4.0", standalone_bad_conf) + with pytest.raises(ValueError, match=msg_match): + classifer_on_gpu._validate_gpu_params("3.4.1", standalone_bad_conf) + with pytest.raises(ValueError, match=msg_match): + classifer_on_gpu._validate_gpu_params("3.5.0", 
standalone_bad_conf) + with pytest.raises(ValueError, match=msg_match): + classifer_on_gpu._validate_gpu_params("3.5.1", standalone_bad_conf) + + standalone_bad_conf = ( + SparkConf() + .setMaster("spark://foo") + .set("spark.executor.cores", "12") + .set("spark.task.cpus", "1") + .set("spark.executor.resource.gpu.amount", "1") + ) + msg_match = ( + "The `spark.task.resource.gpu.amount` is required for training on GPU" + ) + with pytest.raises(ValueError, match=msg_match): + classifer_on_gpu._validate_gpu_params("3.3.0", standalone_bad_conf) + + classifer_on_gpu._validate_gpu_params("3.4.0", standalone_bad_conf) + classifer_on_gpu._validate_gpu_params("3.5.0", standalone_bad_conf) + classifer_on_gpu._validate_gpu_params("3.5.1", standalone_bad_conf) + + # Yarn and K8s mode + for mode in ["yarn", "k8s://"]: + conf = ( + SparkConf() + .setMaster(mode) + .set("spark.executor.cores", "12") + .set("spark.task.cpus", "1") + .set("spark.executor.resource.gpu.amount", "1") + .set("spark.task.resource.gpu.amount", "0.08") + ) + with pytest.raises( + ValueError, + match="XGBoost doesn't support GPU fractional configurations", + ): + classifer_on_gpu._validate_gpu_params("3.3.0", conf) + with pytest.raises( + ValueError, + match="XGBoost doesn't support GPU fractional configurations", + ): + classifer_on_gpu._validate_gpu_params("3.4.0", conf) + with pytest.raises( + ValueError, + match="XGBoost doesn't support GPU fractional configurations", + ): + classifer_on_gpu._validate_gpu_params("3.4.1", conf) + with pytest.raises( + ValueError, + match="XGBoost doesn't support GPU fractional configurations", + ): + classifer_on_gpu._validate_gpu_params("3.5.0", conf) + + classifer_on_gpu._validate_gpu_params("3.5.1", conf) + + for mode in ["yarn", "k8s://"]: + bad_conf = ( + SparkConf() + .setMaster(mode) + .set("spark.executor.cores", "12") + .set("spark.task.cpus", "1") + .set("spark.executor.resource.gpu.amount", "1") + ) + msg_match = ( + "The `spark.task.resource.gpu.amount` is 
required for training on GPU" + ) + with pytest.raises(ValueError, match=msg_match): + classifer_on_gpu._validate_gpu_params("3.3.0", bad_conf) + with pytest.raises(ValueError, match=msg_match): + classifer_on_gpu._validate_gpu_params("3.4.0", bad_conf) + with pytest.raises(ValueError, match=msg_match): + classifer_on_gpu._validate_gpu_params("3.5.0", bad_conf) + + classifer_on_gpu._validate_gpu_params("3.5.1", bad_conf) + def test_skip_stage_level_scheduling(self) -> None: - conf = ( + standalone_conf = ( SparkConf() .setMaster("spark://foo") .set("spark.executor.cores", "12") @@ -943,26 +1062,36 @@ class TestPySparkLocal: classifer_on_gpu = SparkXGBClassifier(use_gpu=True) # the correct configurations should not skip stage-level scheduling - assert not classifer_on_gpu._skip_stage_level_scheduling("3.4.0", conf) + assert not classifer_on_gpu._skip_stage_level_scheduling( + "3.4.0", standalone_conf + ) + assert not classifer_on_gpu._skip_stage_level_scheduling( + "3.4.1", standalone_conf + ) + assert not classifer_on_gpu._skip_stage_level_scheduling( + "3.5.0", standalone_conf + ) + assert not classifer_on_gpu._skip_stage_level_scheduling( + "3.5.1", standalone_conf + ) # spark version < 3.4.0 - assert classifer_on_gpu._skip_stage_level_scheduling("3.3.0", conf) - + assert classifer_on_gpu._skip_stage_level_scheduling("3.3.0", standalone_conf) # not run on GPU - assert classifer_on_cpu._skip_stage_level_scheduling("3.4.0", conf) + assert classifer_on_cpu._skip_stage_level_scheduling("3.4.0", standalone_conf) # spark.executor.cores is not set - badConf = ( + bad_conf = ( SparkConf() .setMaster("spark://foo") .set("spark.task.cpus", "1") .set("spark.executor.resource.gpu.amount", "1") .set("spark.task.resource.gpu.amount", "0.08") ) - assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf) + assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", bad_conf) # spark.executor.cores=1 - badConf = ( + bad_conf = ( SparkConf() 
.setMaster("spark://foo") .set("spark.executor.cores", "1") @@ -970,20 +1099,20 @@ class TestPySparkLocal: .set("spark.executor.resource.gpu.amount", "1") .set("spark.task.resource.gpu.amount", "0.08") ) - assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf) + assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", bad_conf) # spark.executor.resource.gpu.amount is not set - badConf = ( + bad_conf = ( SparkConf() .setMaster("spark://foo") .set("spark.executor.cores", "12") .set("spark.task.cpus", "1") .set("spark.task.resource.gpu.amount", "0.08") ) - assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf) + assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", bad_conf) # spark.executor.resource.gpu.amount>1 - badConf = ( + bad_conf = ( SparkConf() .setMaster("spark://foo") .set("spark.executor.cores", "12") @@ -991,20 +1120,20 @@ class TestPySparkLocal: .set("spark.executor.resource.gpu.amount", "2") .set("spark.task.resource.gpu.amount", "0.08") ) - assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf) + assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", bad_conf) # spark.task.resource.gpu.amount is not set - badConf = ( + bad_conf = ( SparkConf() .setMaster("spark://foo") .set("spark.executor.cores", "12") .set("spark.task.cpus", "1") .set("spark.executor.resource.gpu.amount", "1") ) - assert not classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf) + assert not classifer_on_gpu._skip_stage_level_scheduling("3.4.0", bad_conf) # spark.task.resource.gpu.amount=1 - badConf = ( + bad_conf = ( SparkConf() .setMaster("spark://foo") .set("spark.executor.cores", "12") @@ -1012,29 +1141,32 @@ class TestPySparkLocal: .set("spark.executor.resource.gpu.amount", "1") .set("spark.task.resource.gpu.amount", "1") ) - assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf) + assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", bad_conf) - # yarn - badConf = ( - 
SparkConf() - .setMaster("yarn") - .set("spark.executor.cores", "12") - .set("spark.task.cpus", "1") - .set("spark.executor.resource.gpu.amount", "1") - .set("spark.task.resource.gpu.amount", "1") - ) - assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf) + # For Yarn and K8S + for mode in ["yarn", "k8s://"]: + for gpu_amount in ["0.08", "0.2", "1.0"]: + conf = ( + SparkConf() + .setMaster(mode) + .set("spark.executor.cores", "12") + .set("spark.task.cpus", "1") + .set("spark.executor.resource.gpu.amount", "1") + .set("spark.task.resource.gpu.amount", gpu_amount) + ) + assert classifer_on_gpu._skip_stage_level_scheduling("3.3.0", conf) + assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", conf) + assert classifer_on_gpu._skip_stage_level_scheduling("3.4.1", conf) + assert classifer_on_gpu._skip_stage_level_scheduling("3.5.0", conf) - # k8s - badConf = ( - SparkConf() - .setMaster("k8s://") - .set("spark.executor.cores", "12") - .set("spark.task.cpus", "1") - .set("spark.executor.resource.gpu.amount", "1") - .set("spark.task.resource.gpu.amount", "1") - ) - assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf) + # This will be fixed when spark 4.0.0 is released. + if gpu_amount == "1.0": + assert classifer_on_gpu._skip_stage_level_scheduling("3.5.1", conf) + else: + # Starting from 3.5.1+, stage-level scheduling is working for Yarn and K8s + assert not classifer_on_gpu._skip_stage_level_scheduling( + "3.5.1", conf + ) class XgboostLocalTest(SparkTestCase): From 3fbb221fecf0d7bc98a56f79d668497624ab8f62 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 20 Apr 2024 04:08:17 +0800 Subject: [PATCH 26/26] [coll] Implement shutdown for tracker and comm. (#10208) - Force shutdown the tracker. - Implement shutdown notice for error handling thread in comm. 
--- include/xgboost/c_api.h | 18 ++- include/xgboost/collective/socket.h | 81 +++++++--- plugin/federated/federated_tracker.cc | 6 +- rabit/include/rabit/internal/socket.h | 24 ++- rabit/src/allreduce_base.cc | 24 +-- rabit/src/allreduce_base.h | 4 +- src/c_api/coll_c_api.cc | 35 ++++- src/collective/coll.cc | 4 + src/collective/comm.cc | 140 ++++++++++++++---- src/collective/comm.h | 30 ++-- src/collective/comm_group.cc | 17 ++- src/collective/comm_group.h | 6 +- src/collective/protocol.h | 55 ++++++- src/collective/socket.cc | 27 ++-- src/collective/tracker.cc | 92 +++++++++--- src/collective/tracker.h | 41 +++-- tests/cpp/collective/test_coll_c_api.cc | 8 +- tests/cpp/collective/test_comm.cc | 8 +- tests/cpp/collective/test_comm_group.cc | 15 +- tests/cpp/collective/test_loop.cc | 8 +- tests/cpp/collective/test_socket.cc | 16 +- tests/cpp/collective/test_tracker.cc | 58 +++++++- tests/cpp/collective/test_worker.h | 30 +++- .../federated/test_federated_tracker.cc | 5 +- 24 files changed, 553 insertions(+), 199 deletions(-) diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index e065d8ba1..19b93c644 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -1555,7 +1555,7 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle); /** * @brief Get the arguments needed for running workers. This should be called after - * XGTrackerRun() and XGTrackerWait() + * XGTrackerRun(). * * @param handle The handle to the tracker. * @param args The arguments returned as a JSON document. @@ -1565,16 +1565,19 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle); XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args); /** - * @brief Run the tracker. + * @brief Start the tracker. The tracker runs in the background and this function returns + * once the tracker is started. * * @param handle The handle to the tracker. + * @param config Unused at the moment, preserved for the future. 
* * @return 0 for success, -1 for failure. */ -XGB_DLL int XGTrackerRun(TrackerHandle handle); +XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *config); /** - * @brief Wait for the tracker to finish, should be called after XGTrackerRun(). + * @brief Wait for the tracker to finish, should be called after XGTrackerRun(). This + * function will block until the tracker task is finished or timeout is reached. * * @param handle The handle to the tracker. * @param config JSON encoded configuration. No argument is required yet, preserved for @@ -1582,11 +1585,12 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle); * * @return 0 for success, -1 for failure. */ -XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config); +XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config); /** - * @brief Free a tracker instance. XGTrackerWait() is called internally. If the tracker - * cannot close properly, manual interruption is required. + * @brief Free a tracker instance. This should be called after XGTrackerWaitFor(). If the + * tracker is not properly waited, this function will shutdown all connections with + * the tracker, potentially leading to undefined behavior. * * @param handle The handle to the tracker. 
* diff --git a/include/xgboost/collective/socket.h b/include/xgboost/collective/socket.h index 11520eede..0e098052c 100644 --- a/include/xgboost/collective/socket.h +++ b/include/xgboost/collective/socket.h @@ -124,6 +124,21 @@ inline std::int32_t CloseSocket(SocketT fd) { #endif } +inline std::int32_t ShutdownSocket(SocketT fd) { +#if defined(_WIN32) + auto rc = shutdown(fd, SD_BOTH); + if (rc != 0 && LastError() == WSANOTINITIALISED) { + return 0; + } +#else + auto rc = shutdown(fd, SHUT_RDWR); + if (rc != 0 && LastError() == ENOTCONN) { + return 0; + } +#endif + return rc; +} + inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) { #ifdef _WIN32 return errsv == WSAEWOULDBLOCK; @@ -499,36 +514,49 @@ class TCPSocket { */ [[nodiscard]] HandleT const &Handle() const { return handle_; } /** - * \brief Listen to incoming requests. Should be called after bind. + * @brief Listen to incoming requests. Should be called after bind. */ - void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); } + [[nodiscard]] Result Listen(std::int32_t backlog = 16) { + if (listen(handle_, backlog) != 0) { + return system::FailWithCode("Failed to listen."); + } + return Success(); + } /** - * \brief Bind socket to INADDR_ANY, return the port selected by the OS. + * @brief Bind socket to INADDR_ANY, return the port selected by the OS. */ - [[nodiscard]] in_port_t BindHost() { + [[nodiscard]] Result BindHost(std::int32_t* p_out) { + // Use int32 instead of in_port_t for consistency. We take port as parameter from + // users using other languages, the port is usually stored and passed around as int. 
if (Domain() == SockDomain::kV6) { auto addr = SockAddrV6::InaddrAny(); auto handle = reinterpret_cast(&addr.Handle()); - xgboost_CHECK_SYS_CALL( - bind(handle_, handle, sizeof(std::remove_reference_t)), 0); + if (bind(handle_, handle, sizeof(std::remove_reference_t)) != 0) { + return system::FailWithCode("bind failed."); + } sockaddr_in6 res_addr; socklen_t addrlen = sizeof(res_addr); - xgboost_CHECK_SYS_CALL( - getsockname(handle_, reinterpret_cast(&res_addr), &addrlen), 0); - return ntohs(res_addr.sin6_port); + if (getsockname(handle_, reinterpret_cast(&res_addr), &addrlen) != 0) { + return system::FailWithCode("getsockname failed."); + } + *p_out = ntohs(res_addr.sin6_port); } else { auto addr = SockAddrV4::InaddrAny(); auto handle = reinterpret_cast(&addr.Handle()); - xgboost_CHECK_SYS_CALL( - bind(handle_, handle, sizeof(std::remove_reference_t)), 0); + if (bind(handle_, handle, sizeof(std::remove_reference_t)) != 0) { + return system::FailWithCode("bind failed."); + } sockaddr_in res_addr; socklen_t addrlen = sizeof(res_addr); - xgboost_CHECK_SYS_CALL( - getsockname(handle_, reinterpret_cast(&res_addr), &addrlen), 0); - return ntohs(res_addr.sin_port); + if (getsockname(handle_, reinterpret_cast(&res_addr), &addrlen) != 0) { + return system::FailWithCode("getsockname failed."); + } + *p_out = ntohs(res_addr.sin_port); } + + return Success(); } [[nodiscard]] auto Port() const { @@ -641,13 +669,13 @@ class TCPSocket { */ std::size_t Send(StringView str); /** - * \brief Receive string, format is matched with the Python socket wrapper in RABIT. + * @brief Receive string, format is matched with the Python socket wrapper in RABIT. */ - std::size_t Recv(std::string *p_str); + [[nodiscard]] Result Recv(std::string *p_str); /** * @brief Close the socket, called automatically in destructor if the socket is not closed. 
*/ - Result Close() { + [[nodiscard]] Result Close() { if (InvalidSocket() != handle_) { auto rc = system::CloseSocket(handle_); #if defined(_WIN32) @@ -664,6 +692,25 @@ class TCPSocket { } return Success(); } + /** + * @brief Call shutdown on the socket. + */ + [[nodiscard]] Result Shutdown() { + if (this->IsClosed()) { + return Success(); + } + auto rc = system::ShutdownSocket(this->Handle()); +#if defined(_WIN32) + // Windows cannot shutdown a socket if it's not connected. + if (rc == -1 && system::LastError() == WSAENOTCONN) { + return Success(); + } +#endif + if (rc != 0) { + return system::FailWithCode("Failed to shutdown socket."); + } + return Success(); + } /** * \brief Create a TCP socket on specified domain. diff --git a/plugin/federated/federated_tracker.cc b/plugin/federated/federated_tracker.cc index 37b6c3639..5051d43cb 100644 --- a/plugin/federated/federated_tracker.cc +++ b/plugin/federated/federated_tracker.cc @@ -125,14 +125,14 @@ Result FederatedTracker::Shutdown() { [[nodiscard]] Json FederatedTracker::WorkerArgs() const { auto rc = this->WaitUntilReady(); - CHECK(rc.OK()) << rc.Report(); + SafeColl(rc); std::string host; rc = GetHostAddress(&host); CHECK(rc.OK()); Json args{Object{}}; - args["DMLC_TRACKER_URI"] = String{host}; - args["DMLC_TRACKER_PORT"] = this->Port(); + args["dmlc_tracker_uri"] = String{host}; + args["dmlc_tracker_port"] = this->Port(); return args; } } // namespace xgboost::collective diff --git a/rabit/include/rabit/internal/socket.h b/rabit/include/rabit/internal/socket.h index 89e324482..cec246efd 100644 --- a/rabit/include/rabit/internal/socket.h +++ b/rabit/include/rabit/internal/socket.h @@ -100,6 +100,24 @@ std::enable_if_t, xgboost::collective::Result> PollError(E if ((revents & POLLNVAL) != 0) { return xgboost::system::FailWithCode("Invalid polling request."); } + if ((revents & POLLHUP) != 0) { + // Excerpt from the Linux manual: + // + // Note that when reading from a channel such as a pipe or a stream socket, 
this event + // merely indicates that the peer closed its end of the channel. Subsequent reads from + // the channel will return 0 (end of file) only after all outstanding data in the + // channel has been consumed. + // + // We don't usually have a barrier for exiting workers, it's normal to have one end + // exit while the other is still reading data. + return xgboost::collective::Success(); + } +#if defined(POLLRDHUP) + // Linux only flag + if ((revents & POLLRDHUP) != 0) { + return xgboost::system::FailWithCode("Poll hung up on the other end."); + } +#endif // defined(POLLRDHUP) return xgboost::collective::Success(); } @@ -179,9 +197,11 @@ struct PollHelper { } std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout); if (ret == 0) { - return xgboost::collective::Fail("Poll timeout.", std::make_error_code(std::errc::timed_out)); + return xgboost::collective::Fail( + "Poll timeout:" + std::to_string(timeout.count()) + " seconds.", + std::make_error_code(std::errc::timed_out)); } else if (ret < 0) { - return xgboost::system::FailWithCode("Poll failed."); + return xgboost::system::FailWithCode("Poll failed, nfds:" + std::to_string(fdset.size())); } for (auto& pfd : fdset) { diff --git a/rabit/src/allreduce_base.cc b/rabit/src/allreduce_base.cc index b99eb3763..fcf80b414 100644 --- a/rabit/src/allreduce_base.cc +++ b/rabit/src/allreduce_base.cc @@ -132,7 +132,7 @@ bool AllreduceBase::Shutdown() { try { for (auto &all_link : all_links) { if (!all_link.sock.IsClosed()) { - all_link.sock.Close(); + SafeColl(all_link.sock.Close()); } } all_links.clear(); @@ -146,7 +146,7 @@ bool AllreduceBase::Shutdown() { LOG(FATAL) << rc.Report(); } tracker.Send(xgboost::StringView{"shutdown"}); - tracker.Close(); + SafeColl(tracker.Close()); xgboost::system::SocketFinalize(); return true; } catch (std::exception const &e) { @@ -167,7 +167,7 @@ void AllreduceBase::TrackerPrint(const std::string &msg) { tracker.Send(xgboost::StringView{"print"}); 
tracker.Send(xgboost::StringView{msg}); - tracker.Close(); + SafeColl(tracker.Close()); } // util to parse data with unit suffix @@ -332,15 +332,15 @@ void AllreduceBase::SetParam(const char *name, const char *val) { auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())}; // create listening socket - int port = sock_listen.BindHost(); - utils::Check(port != -1, "ReConnectLink fail to bind the ports specified"); - sock_listen.Listen(); + std::int32_t port{0}; + SafeColl(sock_listen.BindHost(&port)); + SafeColl(sock_listen.Listen()); // get number of to connect and number of to accept nodes from tracker int num_conn, num_accept, num_error = 1; do { for (auto & all_link : all_links) { - all_link.sock.Close(); + SafeColl(all_link.sock.Close()); } // tracker construct goodset Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn), @@ -352,7 +352,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) { LinkRecord r; int hport, hrank; std::string hname; - tracker.Recv(&hname); + SafeColl(tracker.Recv(&hname)); Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9"); Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10"); // connect to peer @@ -360,7 +360,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) { timeout_sec, &r.sock) .OK()) { num_error += 1; - r.sock.Close(); + SafeColl(r.sock.Close()); continue; } Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), @@ -386,7 +386,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) { // send back socket listening port to tracker Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14"); // close connection to tracker - tracker.Close(); + SafeColl(tracker.Close()); // listen to incoming links for (int i = 0; i < num_accept; ++i) { @@ -408,7 +408,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) { } if (!match) 
all_links.emplace_back(std::move(r)); } - sock_listen.Close(); + SafeColl(sock_listen.Close()); this->parent_index = -1; // setup tree links and ring structure @@ -635,7 +635,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_, Recv(sendrecvbuf + size_down_in, total_size - size_down_in); if (len == 0) { - links[parent_index].sock.Close(); + SafeColl(links[parent_index].sock.Close()); return ReportError(&links[parent_index], kRecvZeroLen); } if (len != -1) { diff --git a/rabit/src/allreduce_base.h b/rabit/src/allreduce_base.h index 7724bf3d5..9991c2138 100644 --- a/rabit/src/allreduce_base.h +++ b/rabit/src/allreduce_base.h @@ -270,7 +270,7 @@ class AllreduceBase : public IEngine { ssize_t len = sock.Recv(buffer_head + offset, nmax); // length equals 0, remote disconnected if (len == 0) { - sock.Close(); return kRecvZeroLen; + SafeColl(sock.Close()); return kRecvZeroLen; } if (len == -1) return Errno2Return(); size_read += static_cast(len); @@ -289,7 +289,7 @@ class AllreduceBase : public IEngine { ssize_t len = sock.Recv(p + size_read, max_size - size_read); // length equals 0, remote disconnected if (len == 0) { - sock.Close(); return kRecvZeroLen; + SafeColl(sock.Close()); return kRecvZeroLen; } if (len == -1) return Errno2Return(); size_read += static_cast(len); diff --git a/src/c_api/coll_c_api.cc b/src/c_api/coll_c_api.cc index 24e94f3de..fba2647cc 100644 --- a/src/c_api/coll_c_api.cc +++ b/src/c_api/coll_c_api.cc @@ -5,9 +5,11 @@ #include // for future #include // for unique_ptr #include // for string +#include // for sleep_for #include // for is_same_v, remove_pointer_t #include // for pair +#include "../collective/comm.h" // for DefaultTimeoutSec #include "../collective/tracker.h" // for RabitTracker #include "../common/timer.h" // for Timer #include "c_api_error.h" // for API_BEGIN @@ -26,7 +28,7 @@ using namespace xgboost; // NOLINT namespace { using TrackerHandleT = - std::pair, std::shared_future>; + std::pair, std::shared_future>; TrackerHandleT 
*GetTrackerHandle(TrackerHandle handle) { xgboost_CHECK_C_ARG_PTR(handle); @@ -41,12 +43,14 @@ struct CollAPIEntry { using CollAPIThreadLocalStore = dmlc::ThreadLocalStore; void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) { - constexpr std::int64_t kDft{60}; + constexpr std::int64_t kDft{collective::DefaultTimeoutSec()}; std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft}; common::Timer timer; timer.Start(); + auto ref = ptr->first; // hold a reference so that free doesn't delete it while waiting. + auto fut = ptr->second; while (fut.valid()) { auto res = fut.wait_for(wait_for); @@ -72,15 +76,15 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle) { Json jconfig = Json::Load(config); auto type = RequiredArg(jconfig, "dmlc_communicator", __func__); - std::unique_ptr tptr; + std::shared_ptr tptr; if (type == "federated") { #if defined(XGBOOST_USE_FEDERATED) - tptr = std::make_unique(jconfig); + tptr = std::make_shared(jconfig); #else LOG(FATAL) << error::NoFederated(); #endif // defined(XGBOOST_USE_FEDERATED) } else if (type == "rabit") { - tptr = std::make_unique(jconfig); + tptr = std::make_shared(jconfig); } else { LOG(FATAL) << "Unknown communicator:" << type; } @@ -103,7 +107,7 @@ XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args) { API_END(); } -XGB_DLL int XGTrackerRun(TrackerHandle handle) { +XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *) { API_BEGIN(); auto *ptr = GetTrackerHandle(handle); CHECK(!ptr->second.valid()) << "Tracker is already running."; @@ -111,13 +115,14 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle) { API_END(); } -XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) { +XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config) { API_BEGIN(); auto *ptr = GetTrackerHandle(handle); xgboost_CHECK_C_ARG_PTR(config); auto jconfig = Json::Load(StringView{config}); // Internally, 0 indicates no timeout, which 
is the default since we don't want to // interrupt the model training. + xgboost_CHECK_C_ARG_PTR(config); auto timeout = OptionalArg(jconfig, "timeout", std::int64_t{0}); WaitImpl(ptr, std::chrono::seconds{timeout}); API_END(); @@ -125,8 +130,24 @@ XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) { XGB_DLL int XGTrackerFree(TrackerHandle handle) { API_BEGIN(); + using namespace std::chrono_literals; // NOLINT auto *ptr = GetTrackerHandle(handle); + ptr->first->Stop(); + // The wait is not necessary since we just called stop, just reusing the function to do + // any potential cleanups. WaitImpl(ptr, ptr->first->Timeout()); + common::Timer timer; + timer.Start(); + // Make sure no one else is waiting on the tracker. + while (!ptr->first.unique()) { + auto ela = timer.Duration().count(); + if (ela > ptr->first->Timeout().count()) { + LOG(WARNING) << "Time out " << ptr->first->Timeout().count() + << " seconds reached for TrackerFree, killing the tracker."; + break; + } + std::this_thread::sleep_for(64ms); + } delete ptr; API_END(); } diff --git a/src/collective/coll.cc b/src/collective/coll.cc index c6d03c6df..b720d09b7 100644 --- a/src/collective/coll.cc +++ b/src/collective/coll.cc @@ -38,6 +38,10 @@ bool constexpr IsFloatingPointV() { auto redop_fn = [](auto lhs, auto out, auto elem_op) { auto p_lhs = lhs.data(); auto p_out = out.data(); +#if defined(__GNUC__) || defined(__clang__) + // For the sum op, one can verify the simd by: addps %xmm15, %xmm14 +#pragma omp simd +#endif for (std::size_t i = 0; i < lhs.size(); ++i) { p_out[i] = elem_op(p_lhs[i], p_out[i]); } diff --git a/src/collective/comm.cc b/src/collective/comm.cc index 23a8e89ed..50a14aaaf 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -5,9 +5,11 @@ #include // for copy #include // for seconds +#include // for int32_t #include // for exit #include // for shared_ptr #include // for string +#include // for thread #include // for move, forward #if 
!defined(XGBOOST_USE_NCCL) #include "../common/common.h" // for AssertNCCLSupport @@ -184,13 +186,30 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st return Success(); } -RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, - std::int32_t retry, std::string task_id, StringView nccl_path) - : HostComm{std::move(host), port, timeout, retry, std::move(task_id)}, +namespace { +std::string InitLog(std::string task_id, std::int32_t rank) { + if (task_id.empty()) { + return "Rank " + std::to_string(rank); + } + return "Task " + task_id + " got rank " + std::to_string(rank); +} +} // namespace + +RabitComm::RabitComm(std::string const& tracker_host, std::int32_t tracker_port, + std::chrono::seconds timeout, std::int32_t retry, std::string task_id, + StringView nccl_path) + : HostComm{tracker_host, tracker_port, timeout, retry, std::move(task_id)}, nccl_path_{std::move(nccl_path)} { + if (this->TrackerInfo().host.empty()) { + // Not in a distributed environment. + LOG(CONSOLE) << InitLog(task_id_, rank_); + return; + } + loop_.reset(new Loop{std::chrono::seconds{timeout_}}); // NOLINT auto rc = this->Bootstrap(timeout_, retry_, task_id_); if (!rc.OK()) { + this->ResetState(); SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc))); } } @@ -217,20 +236,54 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr) const { // Start command TCPSocket listener = TCPSocket::Create(tracker.Domain()); - std::int32_t lport = listener.BindHost(); - listener.Listen(); + std::int32_t lport{0}; + rc = std::move(rc) << [&] { + return listener.BindHost(&lport); + } << [&] { + return listener.Listen(); + }; + if (!rc.OK()) { + return rc; + } // create worker for listening to error notice. 
auto domain = tracker.Domain(); std::shared_ptr error_sock{TCPSocket::CreatePtr(domain)}; - auto eport = error_sock->BindHost(); - error_sock->Listen(); + std::int32_t eport{0}; + rc = std::move(rc) << [&] { + return error_sock->BindHost(&eport); + } << [&] { + return error_sock->Listen(); + }; + if (!rc.OK()) { + return rc; + } + error_port_ = eport; + error_worker_ = std::thread{[error_sock = std::move(error_sock)] { - auto conn = error_sock->Accept(); + TCPSocket conn; + SockAddress addr; + auto rc = error_sock->Accept(&conn, &addr); + // On Linux, a shutdown causes an invalid argument error; + if (rc.Code() == std::errc::invalid_argument) { + return; + } // On Windows, accept returns a closed socket after finalize. if (conn.IsClosed()) { return; } + // The error signal is from the tracker, while shutdown signal is from the shutdown method + // of the RabitComm class (this). + bool is_error{false}; + rc = proto::Error{}.RecvSignal(&conn, &is_error); + if (!rc.OK()) { + LOG(WARNING) << rc.Report(); + return; + } + if (!is_error) { + return; // shutdown + } + LOG(WARNING) << "Another worker is running into error."; #if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0 // exit is nicer than abort as the former performs cleanups. @@ -239,6 +292,9 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr) const { LOG(FATAL) << "abort"; #endif }}; + // The worker thread is detached here to avoid the need to handle it later during + // destruction. For C++, if a thread is not joined or detached, it will segfault during + // destruction. 
error_worker_.detach(); proto::Start start; @@ -251,7 +307,10 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr) const { // get ring neighbors std::string snext; - tracker.Recv(&snext); + rc = tracker.Recv(&snext); + if (!rc.OK()) { + return Fail("Failed to receive the rank for the next worker.", std::move(rc)); + } auto jnext = Json::Load(StringView{snext}); proto::PeerInfo ninfo{jnext}; @@ -268,14 +327,21 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr) const { CHECK(this->channels_.empty()); for (auto& w : workers) { if (w) { - rc = std::move(rc) << [&] { return w->SetNoDelay(); } << [&] { return w->NonBlocking(true); } - << [&] { return w->SetKeepAlive(); }; + rc = std::move(rc) << [&] { + return w->SetNoDelay(); + } << [&] { + return w->NonBlocking(true); + } << [&] { + return w->SetKeepAlive(); + }; } if (!rc.OK()) { return rc; } this->channels_.emplace_back(std::make_shared(*this, w)); } + + LOG(CONSOLE) << InitLog(task_id_, rank_); return rc; } @@ -283,6 +349,8 @@ RabitComm::~RabitComm() noexcept(false) { if (!this->IsDistributed()) { return; } + LOG(WARNING) << "The communicator is being destroyed without a call to shutdown first. This can " + "lead to undefined behaviour."; auto rc = this->Shutdown(); if (!rc.OK()) { LOG(WARNING) << rc.Report(); @@ -293,30 +361,49 @@ RabitComm::~RabitComm() noexcept(false) { if (!this->IsDistributed()) { return Success(); } - + // Tell the tracker that this worker is shutting down. TCPSocket tracker; + // Tell the error handling thread that we are shutting down.
+ TCPSocket err_client; + return Success() << [&] { return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World()); } << [&] { return this->Block(); } << [&] { - Json jcmd{Object{}}; - jcmd["cmd"] = Integer{static_cast(proto::CMD::kShutdown)}; - auto scmd = Json::Dump(jcmd); - auto n_bytes = tracker.Send(scmd); - if (n_bytes != scmd.size()) { - return Fail("Faled to send cmd."); - } - - this->ResetState(); - return Success(); + return proto::ShutdownCMD{}.Send(&tracker); } << [&] { this->channels_.clear(); return Success(); + } << [&] { + // Use tracker address to determine whether we want to use IPv6. + auto taddr = MakeSockAddress(xgboost::StringView{this->tracker_.host}, this->tracker_.port); + // Shutdown the error handling thread. We signal the thread through socket, + // alternatively, we can get the native handle and use pthread_cancel. But using a + // socket seems to be clearer as we know what's happening. + auto const& addr = taddr.IsV4() ? SockAddrV4::Loopback().Addr() : SockAddrV6::Loopback().Addr(); + // We use hardcoded 10 seconds and 1 retry here since we are just connecting to a + // local socket. For a normal OS, this should be enough time to schedule the + // connection. + auto rc = Connect(StringView{addr}, this->error_port_, 1, + std::min(std::chrono::seconds{10}, timeout_), &err_client); + this->ResetState(); + if (!rc.OK()) { + return Fail("Failed to connect to the error socket.", std::move(rc)); + } + return rc; + } << [&] { + // We put error thread shutdown at the end so that we have a better chance to finish + // the previous more important steps. 
+ return proto::Error{}.SignalShutdown(&err_client); }; } [[nodiscard]] Result RabitComm::LogTracker(std::string msg) const { + if (!this->IsDistributed()) { + LOG(CONSOLE) << msg; + return Success(); + } TCPSocket out; proto::Print print; return Success() << [&] { return this->ConnectTracker(&out); } @@ -324,8 +411,11 @@ RabitComm::~RabitComm() noexcept(false) { } [[nodiscard]] Result RabitComm::SignalError(Result const& res) { - TCPSocket out; - return Success() << [&] { return this->ConnectTracker(&out); } - << [&] { return proto::ErrorCMD{}.WorkerSend(&out, res); }; + TCPSocket tracker; + return Success() << [&] { + return this->ConnectTracker(&tracker); + } << [&] { + return proto::ErrorCMD{}.WorkerSend(&tracker, res); + }; } } // namespace xgboost::collective diff --git a/src/collective/comm.h b/src/collective/comm.h index 6ad5bc5c1..a41f47be9 100644 --- a/src/collective/comm.h +++ b/src/collective/comm.h @@ -1,10 +1,10 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include // for seconds #include // for size_t -#include // for int32_t +#include // for int32_t, int64_t #include // for shared_ptr #include // for string #include // for thread @@ -20,7 +20,7 @@ namespace xgboost::collective { -inline constexpr std::int32_t DefaultTimeoutSec() { return 300; } // 5min +inline constexpr std::int64_t DefaultTimeoutSec() { return 300; } // 5min inline constexpr std::int32_t DefaultRetry() { return 3; } // indexing into the ring @@ -51,7 +51,10 @@ class Comm : public std::enable_shared_from_this { proto::PeerInfo tracker_; SockDomain domain_{SockDomain::kV4}; + std::thread error_worker_; + std::int32_t error_port_; + std::string task_id_; std::vector> channels_; std::shared_ptr loop_{nullptr}; // fixme: require federated comm to have a timeout @@ -59,6 +62,13 @@ class Comm : public std::enable_shared_from_this { void ResetState() { this->world_ = -1; this->rank_ = 0; + this->timeout_ = 
std::chrono::seconds{DefaultTimeoutSec()}; + + tracker_ = proto::PeerInfo{}; + this->task_id_.clear(); + channels_.clear(); + + loop_.reset(); } public: @@ -79,9 +89,9 @@ class Comm : public std::enable_shared_from_this { [[nodiscard]] auto Retry() const { return retry_; } [[nodiscard]] auto TaskID() const { return task_id_; } - [[nodiscard]] auto Rank() const { return rank_; } - [[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; } - [[nodiscard]] bool IsDistributed() const { return world_ != -1; } + [[nodiscard]] auto Rank() const noexcept { return rank_; } + [[nodiscard]] auto World() const noexcept { return IsDistributed() ? world_ : 1; } + [[nodiscard]] bool IsDistributed() const noexcept { return world_ != -1; } void Submit(Loop::Op op) const { CHECK(loop_); loop_->Submit(op); @@ -120,20 +130,20 @@ class RabitComm : public HostComm { [[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry, std::string task_id); - [[nodiscard]] Result Shutdown() final; public: // bootstrapping construction. RabitComm() = default; - // ctor for testing where environment is known. 
- RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, - std::int32_t retry, std::string task_id, StringView nccl_path); + RabitComm(std::string const& tracker_host, std::int32_t tracker_port, + std::chrono::seconds timeout, std::int32_t retry, std::string task_id, + StringView nccl_path); ~RabitComm() noexcept(false) override; [[nodiscard]] bool IsFederated() const override { return false; } [[nodiscard]] Result LogTracker(std::string msg) const override; [[nodiscard]] Result SignalError(Result const&) override; + [[nodiscard]] Result Shutdown() final; [[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr pimpl) const override; }; diff --git a/src/collective/comm_group.cc b/src/collective/comm_group.cc index 7408882f6..18a5ba8a7 100644 --- a/src/collective/comm_group.cc +++ b/src/collective/comm_group.cc @@ -64,6 +64,9 @@ CommGroup::CommGroup() auto const& obj = get(config); auto it = obj.find(upper); + if (it != obj.cend() && obj.find(name) != obj.cend()) { + LOG(FATAL) << "Duplicated parameter:" << name; + } if (it != obj.cend()) { return OptionalArg(config, upper, dft); } else { @@ -77,14 +80,14 @@ CommGroup::CommGroup() auto task_id = get_param("dmlc_task_id", std::string{}, String{}); if (type == "rabit") { - auto host = get_param("dmlc_tracker_uri", std::string{}, String{}); - auto port = get_param("dmlc_tracker_port", static_cast(0), Integer{}); + auto tracker_host = get_param("dmlc_tracker_uri", std::string{}, String{}); + auto tracker_port = get_param("dmlc_tracker_port", static_cast(0), Integer{}); auto nccl = get_param("dmlc_nccl_path", std::string{DefaultNcclName()}, String{}); - auto ptr = - new CommGroup{std::shared_ptr{new RabitComm{ // NOLINT - host, static_cast(port), std::chrono::seconds{timeout}, - static_cast(retry), task_id, nccl}}, - std::shared_ptr(new Coll{})}; // NOLINT + auto ptr = new CommGroup{ + std::shared_ptr{new RabitComm{ // NOLINT + tracker_host, static_cast(tracker_port), 
std::chrono::seconds{timeout}, + static_cast(retry), task_id, nccl}}, + std::shared_ptr(new Coll{})}; // NOLINT return ptr; } else if (type == "federated") { #if defined(XGBOOST_USE_FEDERATED) diff --git a/src/collective/comm_group.h b/src/collective/comm_group.h index 61a58ba56..a98de0c16 100644 --- a/src/collective/comm_group.h +++ b/src/collective/comm_group.h @@ -30,9 +30,9 @@ class CommGroup { public: CommGroup(); - [[nodiscard]] auto World() const { return comm_->World(); } - [[nodiscard]] auto Rank() const { return comm_->Rank(); } - [[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); } + [[nodiscard]] auto World() const noexcept { return comm_->World(); } + [[nodiscard]] auto Rank() const noexcept { return comm_->Rank(); } + [[nodiscard]] bool IsDistributed() const noexcept { return comm_->IsDistributed(); } [[nodiscard]] Result Finalize() const { return Success() << [this] { diff --git a/src/collective/protocol.h b/src/collective/protocol.h index 96edf4e29..29e6c9619 100644 --- a/src/collective/protocol.h +++ b/src/collective/protocol.h @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include // for int32_t @@ -58,6 +58,7 @@ struct Magic { } }; +// Basic commands for communication between workers and the tracker. 
enum class CMD : std::int32_t { kInvalid = 0, kStart = 1, @@ -84,7 +85,10 @@ struct Connect { [[nodiscard]] Result TrackerRecv(TCPSocket* sock, std::int32_t* world, std::int32_t* rank, std::string* task_id) const { std::string init; - sock->Recv(&init); + auto rc = sock->Recv(&init); + if (!rc.OK()) { + return Fail("Connect protocol failed.", std::move(rc)); + } auto jinit = Json::Load(StringView{init}); *world = get(jinit["world_size"]); *rank = get(jinit["rank"]); @@ -122,9 +126,9 @@ class Start { } [[nodiscard]] Result WorkerRecv(TCPSocket* tracker, std::int32_t* p_world) const { std::string scmd; - auto n_bytes = tracker->Recv(&scmd); - if (n_bytes <= 0) { - return Fail("Failed to recv init command from tracker."); + auto rc = tracker->Recv(&scmd); + if (!rc.OK()) { + return Fail("Failed to recv init command from tracker.", std::move(rc)); } auto jcmd = Json::Load(scmd); auto world = get(jcmd["world_size"]); @@ -132,7 +136,7 @@ class Start { return Fail("Invalid world size."); } *p_world = world; - return Success(); + return rc; } [[nodiscard]] Result TrackerHandle(Json jcmd, std::int32_t* recv_world, std::int32_t world, std::int32_t* p_port, TCPSocket* p_sock, @@ -150,6 +154,7 @@ class Start { } }; +// Protocol for communicating with the tracker for printing message. struct Print { [[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::string msg) const { Json jcmd{Object{}}; @@ -172,6 +177,7 @@ struct Print { } }; +// Protocol for communicating with the tracker during error. struct ErrorCMD { [[nodiscard]] Result WorkerSend(TCPSocket* tracker, Result const& res) const { auto msg = res.Report(); @@ -199,6 +205,7 @@ struct ErrorCMD { } }; +// Protocol for communicating with the tracker during shutdown. struct ShutdownCMD { [[nodiscard]] Result Send(TCPSocket* peer) const { Json jcmd{Object{}}; @@ -211,4 +218,40 @@ struct ShutdownCMD { return Success(); } }; + +// Protocol for communicating with the local error handler during error or shutdown. 
Only +// one protocol that doesn't have the tracker involved. +struct Error { + constexpr static std::int32_t ShutdownSignal() { return 0; } + constexpr static std::int32_t ErrorSignal() { return -1; } + + [[nodiscard]] Result SignalError(TCPSocket* worker) const { + std::int32_t err{ErrorSignal()}; + auto n_sent = worker->SendAll(&err, sizeof(err)); + if (n_sent == sizeof(err)) { + return Success(); + } + return Fail("Failed to send error signal"); + } + // self is localhost, we are sending the signal to the error handling thread for it to + // close. + [[nodiscard]] Result SignalShutdown(TCPSocket* self) const { + std::int32_t err{ShutdownSignal()}; + auto n_sent = self->SendAll(&err, sizeof(err)); + if (n_sent == sizeof(err)) { + return Success(); + } + return Fail("Failed to send shutdown signal"); + } + // get signal, either for error or for shutdown; sets *p_is_error iff it equals ErrorSignal(). + [[nodiscard]] Result RecvSignal(TCPSocket* peer, bool* p_is_error) const { + std::int32_t err{ShutdownSignal()}; + auto n_recv = peer->RecvAll(&err, sizeof(err)); + if (n_recv == sizeof(err)) { + *p_is_error = err == ErrorSignal(); + return Success(); + } + return Fail("Failed to receive error signal."); + } +}; } // namespace xgboost::collective::proto diff --git a/src/collective/socket.cc b/src/collective/socket.cc index 43da366bd..737ce584e 100644 --- a/src/collective/socket.cc +++ b/src/collective/socket.cc @@ -1,5 +1,5 @@ /** - * Copyright 2022-2023 by XGBoost Contributors + * Copyright 2022-2024, XGBoost Contributors */ #include "xgboost/collective/socket.h" @@ -8,7 +8,8 @@ #include // std::int32_t #include // std::memcpy, std::memset #include // for path -#include // std::error_code, std::system_category +#include // for error_code, system_category +#include // for sleep_for #include "rabit/internal/socket.h" // for PollHelper #include "xgboost/collective/result.h" // for Result @@ -65,14 +66,18 @@ std::size_t TCPSocket::Send(StringView str) { return bytes; } -std::size_t TCPSocket::Recv(std::string *p_str) {
+[[nodiscard]] Result TCPSocket::Recv(std::string *p_str) { CHECK(!this->IsClosed()); std::int32_t len; - CHECK_EQ(this->RecvAll(&len, sizeof(len)), sizeof(len)) << "Failed to recv string length."; + if (this->RecvAll(&len, sizeof(len)) != sizeof(len) || len < 0) { // reject negative length + return Fail("Failed to recv string length."); + } p_str->resize(len); auto bytes = this->RecvAll(&(*p_str)[0], len); - CHECK_EQ(bytes, len) << "Failed to recv string."; - return bytes; + if (static_cast(bytes) != len) { + return Fail("Failed to recv string."); + } + return Success(); } [[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry, @@ -110,11 +115,7 @@ std::size_t TCPSocket::Recv(std::string *p_str) { for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) { if (attempt > 0) { LOG(WARNING) << "Retrying connection to " << host << " for the " << attempt << " time."; -#if defined(_MSC_VER) || defined(__MINGW32__) - Sleep(attempt << 1); -#else - sleep(attempt << 1); -#endif + std::this_thread::sleep_for(std::chrono::seconds{attempt << 1}); } auto rc = connect(conn.Handle(), addr_handle, addr_len); @@ -158,8 +159,8 @@ std::size_t TCPSocket::Recv(std::string *p_str) { std::stringstream ss; ss << "Failed to connect to " << host << ":" << port; - conn.Close(); - return Fail(ss.str(), std::move(last_error)); + auto close_rc = conn.Close(); + return Fail(ss.str(), std::move(close_rc) + std::move(last_error)); } [[nodiscard]] Result GetHostName(std::string *p_out) { diff --git a/src/collective/tracker.cc b/src/collective/tracker.cc index 3fdf75ead..142483ccf 100644 --- a/src/collective/tracker.cc +++ b/src/collective/tracker.cc @@ -1,6 +1,7 @@ /** * Copyright 2023-2024, XGBoost Contributors */ +#include "rabit/internal/socket.h" #if defined(__unix__) || defined(__APPLE__) #include // gethostbyname #include // socket, AF_INET6, AF_INET, connect, getsockname @@ -70,10 +71,13 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA return
proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_); } << [&] { std::string cmd; - sock_.Recv(&cmd); + auto rc = sock_.Recv(&cmd); + if (!rc.OK()) { + return rc; + } jcmd = Json::Load(StringView{cmd}); cmd_ = static_cast(get(jcmd["cmd"])); - return Success(); + return rc; } << [&] { if (cmd_ == proto::CMD::kStart) { proto::Start start; @@ -100,14 +104,18 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA RabitTracker::RabitTracker(Json const& config) : Tracker{config} { std::string self; - auto rc = collective::GetHostAddress(&self); - host_ = OptionalArg(config, "host", self); + auto rc = Success() << [&] { + return collective::GetHostAddress(&self); + } << [&] { + host_ = OptionalArg(config, "host", self); - auto addr = MakeSockAddress(xgboost::StringView{host_}, 0); - listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6); - rc = listener_.Bind(host_, &this->port_); + auto addr = MakeSockAddress(xgboost::StringView{host_}, 0); + listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6); + return listener_.Bind(host_, &this->port_); + } << [&] { + return listener_.Listen(); + }; SafeColl(rc); - listener_.Listen(); } Result RabitTracker::Bootstrap(std::vector* p_workers) { @@ -220,9 +228,13 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { // // retry is set to 1, just let the worker timeout or error. Otherwise the // tracker and the worker might be waiting for each other. 
- auto rc = Connect(w.first, w.second, 1, timeout_, &out); + auto rc = Success() << [&] { + return Connect(w.first, w.second, 1, timeout_, &out); + } << [&] { + return proto::Error{}.SignalError(&out); + }; if (!rc.OK()) { - return Fail("Failed to inform workers to stop."); + return Fail("Failed to inform worker:" + w.first + " for error.", std::move(rc)); } } return Success(); @@ -231,13 +243,37 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { return std::async(std::launch::async, [this, handle_error] { State state{this->n_workers_}; + auto select_accept = [&](TCPSocket* sock, auto* addr) { + // accept with poll so that we can enable timeout and interruption. + rabit::utils::PollHelper poll; + auto rc = Success() << [&] { + std::lock_guard lock{listener_mu_}; + return listener_.NonBlocking(true); + } << [&] { + std::lock_guard lock{listener_mu_}; + poll.WatchRead(listener_); + if (state.running) { + // Don't timeout if the communicator group is up and running. + return poll.Poll(std::chrono::seconds{-1}); + } else { + // Have timeout for workers to bootstrap. + return poll.Poll(timeout_); + } + } << [&] { + // this->Stop() closes the socket with a lock. Therefore, when the accept returns + // due to shutdown, the state is still valid (closed). 
+ return listener_.Accept(sock, addr); + }; + return rc; + }; + while (state.ShouldContinue()) { TCPSocket sock; SockAddress addr; this->ready_ = true; - auto rc = listener_.Accept(&sock, &addr); + auto rc = select_accept(&sock, &addr); if (!rc.OK()) { - return Fail("Failed to accept connection.", std::move(rc)); + return Fail("Failed to accept connection.", this->Stop() + std::move(rc)); } auto worker = WorkerProxy{n_workers_, std::move(sock), std::move(addr)}; @@ -252,7 +288,7 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { state.Error(); rc = handle_error(worker); if (!rc.OK()) { - return Fail("Failed to handle abort.", std::move(rc)); + return Fail("Failed to handle abort.", this->Stop() + std::move(rc)); } } @@ -262,7 +298,7 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { state.Bootstrap(); } if (!rc.OK()) { - return rc; + return this->Stop() + std::move(rc); } continue; } @@ -289,12 +325,11 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { } case proto::CMD::kInvalid: default: { - return Fail("Invalid command received."); + return Fail("Invalid command received.", this->Stop()); } } } - ready_ = false; - return Success(); + return this->Stop(); }); } @@ -303,11 +338,30 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { SafeColl(rc); Json args{Object{}}; - args["DMLC_TRACKER_URI"] = String{host_}; - args["DMLC_TRACKER_PORT"] = this->Port(); + args["dmlc_tracker_uri"] = String{host_}; + args["dmlc_tracker_port"] = this->Port(); return args; } +[[nodiscard]] Result RabitTracker::Stop() { + if (!this->Ready()) { + return Success(); + } + + ready_ = false; + std::lock_guard lock{listener_mu_}; + if (this->listener_.IsClosed()) { + return Success(); + } + + return Success() << [&] { + // This should have the effect of stopping the `accept` call. 
+ return this->listener_.Shutdown(); + } << [&] { + return listener_.Close(); + }; +} + [[nodiscard]] Result GetHostAddress(std::string* out) { auto rc = GetHostName(out); if (!rc.OK()) { diff --git a/src/collective/tracker.h b/src/collective/tracker.h index e15aaee59..af30e0be7 100644 --- a/src/collective/tracker.h +++ b/src/collective/tracker.h @@ -36,15 +36,18 @@ namespace xgboost::collective { * signal an error to the tracker and the tracker will notify other workers. */ class Tracker { + public: + enum class SortBy : std::int8_t { + kHost = 0, + kTask = 1, + }; + protected: // How to sort the workers, either by host name or by task ID. When using a multi-GPU // setting, multiple workers can occupy the same host, in which case one should sort // workers by task. Due to compatibility reason, the task ID is not always available, so // we use host as the default. - enum class SortBy : std::int8_t { - kHost = 0, - kTask = 1, - } sortby_; + SortBy sortby_; protected: std::int32_t n_workers_{0}; @@ -54,10 +57,7 @@ class Tracker { public: explicit Tracker(Json const& config); - Tracker(std::int32_t n_worders, std::int32_t port, std::chrono::seconds timeout) - : n_workers_{n_worders}, port_{port}, timeout_{timeout} {} - - virtual ~Tracker() noexcept(false){}; // NOLINT + virtual ~Tracker() = default; [[nodiscard]] Result WaitUntilReady() const; @@ -69,6 +69,11 @@ class Tracker { * @brief Flag to indicate whether the server is running. */ [[nodiscard]] bool Ready() const { return ready_; } + /** + * @brief Shutdown the tracker, cannot be restarted again. Useful when the tracker hangs while + * calling accept. + */ + virtual Result Stop() { return Success(); } }; class RabitTracker : public Tracker { @@ -127,28 +132,22 @@ class RabitTracker : public Tracker { // record for how to reach out to workers if error happens. std::vector> worker_error_handles_; // listening socket for incoming workers. - // - // At the moment, the listener calls accept without first polling. 
We can add an - // additional unix domain socket to allow cancelling the accept. TCPSocket listener_; + // mutex for protecting the listener, used to prevent race when it's listening while + // another thread tries to shut it down. + std::mutex listener_mu_; Result Bootstrap(std::vector* p_workers); public: - explicit RabitTracker(StringView host, std::int32_t n_worders, std::int32_t port, - std::chrono::seconds timeout) - : Tracker{n_worders, port, timeout}, host_{host.c_str(), host.size()} { - listener_ = TCPSocket::Create(SockDomain::kV4); - auto rc = listener_.Bind(host, &this->port_); - CHECK(rc.OK()) << rc.Report(); - listener_.Listen(); - } - explicit RabitTracker(Json const& config); - ~RabitTracker() noexcept(false) override = default; + ~RabitTracker() override = default; std::future Run() override; [[nodiscard]] Json WorkerArgs() const override; + // Stop the tracker without waiting. This is to prevent the tracker from hanging when + // one of the workers fails to start. + [[nodiscard]] Result Stop() override; }; // Prob the public IP address of the host, need a better method.
diff --git a/tests/cpp/collective/test_coll_c_api.cc b/tests/cpp/collective/test_coll_c_api.cc index d80fbc140..c7229ff77 100644 --- a/tests/cpp/collective/test_coll_c_api.cc +++ b/tests/cpp/collective/test_coll_c_api.cc @@ -25,13 +25,13 @@ TEST_F(TrackerAPITest, CAPI) { auto config_str = Json::Dump(config); auto rc = XGTrackerCreate(config_str.c_str(), &handle); ASSERT_EQ(rc, 0); - rc = XGTrackerRun(handle); + rc = XGTrackerRun(handle, nullptr); ASSERT_EQ(rc, 0); std::thread bg_wait{[&] { Json config{Object{}}; auto config_str = Json::Dump(config); - auto rc = XGTrackerWait(handle, config_str.c_str()); + auto rc = XGTrackerWaitFor(handle, config_str.c_str()); ASSERT_EQ(rc, 0); }}; @@ -42,8 +42,8 @@ TEST_F(TrackerAPITest, CAPI) { std::string host; ASSERT_TRUE(GetHostAddress(&host).OK()); - ASSERT_EQ(host, get(args["DMLC_TRACKER_URI"])); - auto port = get(args["DMLC_TRACKER_PORT"]); + ASSERT_EQ(host, get(args["dmlc_tracker_uri"])); + auto port = get(args["dmlc_tracker_port"]); ASSERT_NE(port, 0); std::vector workers; diff --git a/tests/cpp/collective/test_comm.cc b/tests/cpp/collective/test_comm.cc index 8e69b2f8e..c1eb06465 100644 --- a/tests/cpp/collective/test_comm.cc +++ b/tests/cpp/collective/test_comm.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include @@ -14,7 +14,7 @@ class CommTest : public TrackerTest {}; TEST_F(CommTest, Channel) { auto n_workers = 4; - RabitTracker tracker{host, n_workers, 0, timeout}; + RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)}; auto fut = tracker.Run(); std::vector workers; @@ -29,7 +29,7 @@ TEST_F(CommTest, Channel) { return p_chan->SendAll( EraseType(common::Span{&i, static_cast(1)})); } << [&] { return p_chan->Block(); }; - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); } else { auto p_chan = worker.Comm().Chan(i - 1); std::int32_t r{-1}; @@ -37,7 +37,7 @@ TEST_F(CommTest, Channel) { return p_chan->RecvAll( EraseType(common::Span{&r, 
static_cast(1)})); } << [&] { return p_chan->Block(); }; - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); ASSERT_EQ(r, i - 1); } }); diff --git a/tests/cpp/collective/test_comm_group.cc b/tests/cpp/collective/test_comm_group.cc index 0f6bc23a2..3b1b5c5df 100644 --- a/tests/cpp/collective/test_comm_group.cc +++ b/tests/cpp/collective/test_comm_group.cc @@ -17,17 +17,6 @@ namespace xgboost::collective { namespace { -auto MakeConfig(std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) { - Json config{Object{}}; - config["dmlc_communicator"] = std::string{"rabit"}; - config["DMLC_TRACKER_URI"] = host; - config["DMLC_TRACKER_PORT"] = port; - config["dmlc_timeout_sec"] = static_cast(timeout.count()); - config["DMLC_TASK_ID"] = std::to_string(r); - config["dmlc_retry"] = 2; - return config; -} - class CommGroupTest : public SocketTest {}; } // namespace @@ -36,7 +25,7 @@ TEST_F(CommGroupTest, Basic) { TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) { Context ctx; - auto config = MakeConfig(host, port, timeout, r); + auto config = MakeDistributedTestConfig(host, port, timeout, r); std::unique_ptr ptr{CommGroup::Create(config)}; ASSERT_TRUE(ptr->IsDistributed()); ASSERT_EQ(ptr->World(), n_workers); @@ -52,7 +41,7 @@ TEST_F(CommGroupTest, BasicGPU) { TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) { auto ctx = MakeCUDACtx(r); - auto config = MakeConfig(host, port, timeout, r); + auto config = MakeDistributedTestConfig(host, port, timeout, r); std::unique_ptr ptr{CommGroup::Create(config)}; auto const& comm = ptr->Ctx(&ctx, DeviceOrd::CUDA(0)); ASSERT_EQ(comm.TaskID(), std::to_string(r)); diff --git a/tests/cpp/collective/test_loop.cc b/tests/cpp/collective/test_loop.cc index 0908d9623..34e0c1de8 100644 --- a/tests/cpp/collective/test_loop.cc +++ b/tests/cpp/collective/test_loop.cc @@ -28,13 +28,11 @@ class 
LoopTest : public ::testing::Test { auto domain = SockDomain::kV4; pair_.first = TCPSocket::Create(domain); - in_port_t port{0}; + std::int32_t port{0}; auto rc = Success() << [&] { - port = pair_.first.BindHost(); - return Success(); + return pair_.first.BindHost(&port); } << [&] { - pair_.first.Listen(); - return Success(); + return pair_.first.Listen(); }; SafeColl(rc); diff --git a/tests/cpp/collective/test_socket.cc b/tests/cpp/collective/test_socket.cc index ced795fef..ea57da9b4 100644 --- a/tests/cpp/collective/test_socket.cc +++ b/tests/cpp/collective/test_socket.cc @@ -1,5 +1,5 @@ /** - * Copyright 2022-2023, XGBoost Contributors + * Copyright 2022-2024, XGBoost Contributors */ #include #include @@ -21,14 +21,19 @@ TEST_F(SocketTest, Basic) { auto run_test = [msg](SockDomain domain) { auto server = TCPSocket::Create(domain); ASSERT_EQ(server.Domain(), domain); - auto port = server.BindHost(); - server.Listen(); + std::int32_t port{0}; + auto rc = Success() << [&] { + return server.BindHost(&port); + } << [&] { + return server.Listen(); + }; + SafeColl(rc); TCPSocket client; if (domain == SockDomain::kV4) { auto const& addr = SockAddrV4::Loopback().Addr(); auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); } else { auto const& addr = SockAddrV6::Loopback().Addr(); auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client); @@ -45,7 +50,8 @@ TEST_F(SocketTest, Basic) { accepted.Send(msg); std::string str; - client.Recv(&str); + rc = client.Recv(&str); + SafeColl(rc); ASSERT_EQ(StringView{str}, msg); }; diff --git a/tests/cpp/collective/test_tracker.cc b/tests/cpp/collective/test_tracker.cc index 0dce33c0c..8d6cbeff2 100644 --- a/tests/cpp/collective/test_tracker.cc +++ b/tests/cpp/collective/test_tracker.cc @@ -1,6 +1,7 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ +#include #include #include // for 
seconds @@ -10,6 +11,7 @@ #include // for vector #include "../../../src/collective/comm.h" +#include "../helpers.h" // for GMockThrow #include "test_worker.h" namespace xgboost::collective { @@ -20,13 +22,13 @@ class PrintWorker : public WorkerForTest { void Print() { auto rc = comm_.LogTracker("ack:" + std::to_string(this->comm_.Rank())); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); } }; } // namespace TEST_F(TrackerTest, Bootstrap) { - RabitTracker tracker{host, n_workers, 0, timeout}; + RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)}; ASSERT_FALSE(tracker.Ready()); auto fut = tracker.Run(); @@ -34,7 +36,7 @@ TEST_F(TrackerTest, Bootstrap) { auto args = tracker.WorkerArgs(); ASSERT_TRUE(tracker.Ready()); - ASSERT_EQ(get(args["DMLC_TRACKER_URI"]), host); + ASSERT_EQ(get(args["dmlc_tracker_uri"]), host); std::int32_t port = tracker.Port(); @@ -44,12 +46,11 @@ TEST_F(TrackerTest, Bootstrap) { for (auto &w : workers) { w.join(); } - - ASSERT_TRUE(fut.get().OK()); + SafeColl(fut.get()); } TEST_F(TrackerTest, Print) { - RabitTracker tracker{host, n_workers, 0, timeout}; + RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)}; auto fut = tracker.Run(); std::vector workers; @@ -73,4 +74,47 @@ TEST_F(TrackerTest, Print) { } TEST_F(TrackerTest, GetHostAddress) { ASSERT_TRUE(host.find("127.") == std::string::npos); } + +/** + * Test connecting the tracker after it has finished. This should not hang the workers. + */ +TEST_F(TrackerTest, AfterShutdown) { + RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)}; + auto fut = tracker.Run(); + + std::vector workers; + auto rc = tracker.WaitUntilReady(); + ASSERT_TRUE(rc.OK()); + + std::int32_t port = tracker.Port(); + + // Launch no-op workers to cause the tracker to shutdown. 
+ for (std::int32_t i = 0; i < n_workers; ++i) { + workers.emplace_back([=] { WorkerForTest worker{host, port, timeout, n_workers, i}; }); + } + + for (auto &w : workers) { + w.join(); + } + + ASSERT_TRUE(fut.get().OK()); + + // Launch workers again, they should fail. + workers.clear(); + for (std::int32_t i = 0; i < n_workers; ++i) { + auto assert_that = [=] { + WorkerForTest worker{host, port, timeout, n_workers, i}; + }; + // On a Linux platform, the connection will be refused, on Apple platform, this gets + // an operation now in progress poll failure, on Windows, it's a timeout error. +#if defined(__linux__) + workers.emplace_back([=] { ASSERT_THAT(assert_that, GMockThrow("Connection refused")); }); +#else + workers.emplace_back([=] { ASSERT_THAT(assert_that, GMockThrow("Failed to connect to")); }); +#endif + } + for (auto &w : workers) { + w.join(); + } +} } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_worker.h b/tests/cpp/collective/test_worker.h index 7b76052c8..c84df528f 100644 --- a/tests/cpp/collective/test_worker.h +++ b/tests/cpp/collective/test_worker.h @@ -37,7 +37,7 @@ class WorkerForTest { comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_, DefaultNcclName()} { CHECK_EQ(world_size_, comm_.World()); } - virtual ~WorkerForTest() = default; + virtual ~WorkerForTest() noexcept(false) { SafeColl(comm_.Shutdown()); } auto& Comm() { return comm_; } void LimitSockBuf(std::int32_t n_bytes) { @@ -87,19 +87,30 @@ class TrackerTest : public SocketTest { void SetUp() override { SocketTest::SetUp(); auto rc = GetHostAddress(&host); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); } }; +inline Json MakeTrackerConfig(std::string host, std::int32_t n_workers, + std::chrono::seconds timeout) { + Json config{Object{}}; + config["host"] = host; + config["port"] = Integer{0}; + config["n_workers"] = Integer{n_workers}; + config["sortby"] = Integer{static_cast(Tracker::SortBy::kHost)}; + config["timeout"] = timeout.count(); + 
return config; +} + template void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) { std::chrono::seconds timeout{2}; std::string host; auto rc = GetHostAddress(&host); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); LOG(INFO) << "Using " << n_workers << " workers for test."; - RabitTracker tracker{StringView{host}, n_workers, 0, timeout}; + RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)}; auto fut = tracker.Run(); std::vector workers; @@ -115,4 +126,15 @@ void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) { ASSERT_TRUE(fut.get().OK()); } +inline auto MakeDistributedTestConfig(std::string host, std::int32_t port, + std::chrono::seconds timeout, std::int32_t r) { + Json config{Object{}}; + config["dmlc_communicator"] = std::string{"rabit"}; + config["dmlc_tracker_uri"] = host; + config["dmlc_tracker_port"] = port; + config["dmlc_timeout_sec"] = static_cast(timeout.count()); + config["dmlc_task_id"] = std::to_string(r); + config["dmlc_retry"] = 2; + return config; +} } // namespace xgboost::collective diff --git a/tests/cpp/plugin/federated/test_federated_tracker.cc b/tests/cpp/plugin/federated/test_federated_tracker.cc index 81ff95540..aa979ff15 100644 --- a/tests/cpp/plugin/federated/test_federated_tracker.cc +++ b/tests/cpp/plugin/federated/test_federated_tracker.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include @@ -8,7 +8,6 @@ #include "../../../../src/collective/tracker.h" // for GetHostAddress #include "federated_tracker.h" -#include "test_worker.h" #include "xgboost/json.h" // for Json namespace xgboost::collective { @@ -26,7 +25,7 @@ TEST(FederatedTrackerTest, Basic) { ASSERT_GE(tracker->Port(), 1); std::string host; auto rc = GetHostAddress(&host); - ASSERT_EQ(get(args["DMLC_TRACKER_URI"]), host); + ASSERT_EQ(get(args["dmlc_tracker_uri"]), host); rc = tracker->Shutdown(); ASSERT_TRUE(rc.OK());