From 9fbde21e9d7f421286623e0c5f21cf9dff0366a9 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 2 Jun 2023 20:49:43 +0800 Subject: [PATCH 1/9] Rework the precision metric. (#9222) - Rework the precision metric for both CPU and GPU. - Mention it in the document. - Cleanup old support code for GPU ranking metric. - Deterministic GPU implementation. * Drop support for classification. * type. * use batch shape. * lint. * cpu build. * cpu build. * lint. * Tests. * Fix. * Cleanup error message. --- R-package/R/callbacks.R | 2 +- doc/parameter.rst | 3 +- python-package/xgboost/callback.py | 2 + python-package/xgboost/testing/metrics.py | 54 ++++++- src/common/device_helpers.cuh | 170 ---------------------- src/common/optional_weight.h | 9 +- src/common/quantile.cc | 3 + src/common/quantile.h | 20 ++- src/common/ranking_utils.cc | 13 +- src/common/ranking_utils.cu | 7 +- src/common/ranking_utils.h | 31 +++- src/data/iterative_dmatrix.cc | 4 +- src/metric/metric.cc | 24 +-- src/metric/metric_common.h | 56 +------ src/metric/rank_metric.cc | 161 +++++++------------- src/metric/rank_metric.cu | 151 +++++++------------ src/metric/rank_metric.h | 21 ++- tests/ci_build/lint_python.py | 2 +- tests/cpp/metric/test_rank_metric.h | 39 +++-- tests/python-gpu/test_gpu_eval_metrics.py | 5 +- tests/python/test_eval_metrics.py | 5 +- tests/python/test_quantile_dmatrix.py | 32 ++++ 22 files changed, 312 insertions(+), 502 deletions(-) diff --git a/R-package/R/callbacks.R b/R-package/R/callbacks.R index 20fbd0617..d2ee59476 100644 --- a/R-package/R/callbacks.R +++ b/R-package/R/callbacks.R @@ -319,7 +319,7 @@ cb.early.stop <- function(stopping_rounds, maximize = FALSE, # maximize is usually NULL when not set in xgb.train and built-in metrics if (is.null(maximize)) - maximize <<- grepl('(_auc|_map|_ndcg)', metric_name) + maximize <<- grepl('(_auc|_map|_ndcg|_pre)', metric_name) if (verbose && NVL(env$rank, 0) == 0) cat("Will train until ", metric_name, " hasn't improved in ", diff --git a/doc/parameter.rst b/doc/parameter.rst index 8c7cadcdc..f6d3a06b6 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -424,6 +424,7 @@ Specify the learning task and the corresponding learning objective. The objectiv After XGBoost 1.6, both of the requirements and restrictions for using ``aucpr`` in classification problem are similar to ``auc``. For ranking task, only binary relevance label :math:`y \in [0, 1]` is supported. Different from ``map (mean average precision)``, ``aucpr`` calculates the *interpolated* area under precision recall curve using continuous interpolation. + - ``pre``: Precision at :math:`k`. Supports only learning to rank task. - ``ndcg``: `Normalized Discounted Cumulative Gain `_ - ``map``: `Mean Average Precision `_ @@ -435,7 +436,7 @@ Specify the learning task and the corresponding learning objective. The objectiv where :math:`I_{(k)}` is an indicator function that equals to :math:`1` when the document at :math:`k` is relevant and :math:`0` otherwise. The :math:`P@k` is the precision at :math:`k`, and :math:`N` is the total number of relevant documents. Lastly, the `mean average precision` is defined as the weighted average across all queries. - - ``ndcg@n``, ``map@n``: :math:`n` can be assigned as an integer to cut off the top positions in the lists for evaluation. + - ``ndcg@n``, ``map@n``, ``pre@n``: :math:`n` can be assigned as an integer to cut off the top positions in the lists for evaluation. - ``ndcg-``, ``map-``, ``ndcg@n-``, ``map@n-``: In XGBoost, the NDCG and MAP evaluate the score of a list without any positive samples as :math:`1`. By appending "-" to the evaluation metric name, we can ask XGBoost to evaluate these scores as :math:`0` to be consistent under some conditions. - ``poisson-nloglik``: negative log-likelihood for Poisson regression - ``gamma-nloglik``: negative log-likelihood for gamma regression diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py index cc62b354d..88e340737 100644 --- a/python-package/xgboost/callback.py +++ b/python-package/xgboost/callback.py @@ -372,6 +372,8 @@ class EarlyStopping(TrainingCallback): maximize_metrics = ( "auc", "aucpr", + "pre", + "pre@", "map", "ndcg", "auc@", diff --git a/python-package/xgboost/testing/metrics.py b/python-package/xgboost/testing/metrics.py index 6edbe0e3d..c9f449f22 100644 --- a/python-package/xgboost/testing/metrics.py +++ b/python-package/xgboost/testing/metrics.py @@ -1,9 +1,61 @@ """Tests for evaluation metrics.""" -from typing import Dict +from typing import Dict, List import numpy as np +import pytest import xgboost as xgb +from xgboost.compat import concat +from xgboost.core import _parse_eval_str + + +def check_precision_score(tree_method: str) -> None: + """Test for precision with ranking and classification.""" + datasets = pytest.importorskip("sklearn.datasets") + + X, y = datasets.make_classification( + n_samples=1024, n_features=4, n_classes=2, random_state=2023 + ) + qid = np.zeros(shape=y.shape) # same group + + ltr = xgb.XGBRanker(n_estimators=2, tree_method=tree_method) + ltr.fit(X, y, qid=qid) + + # re-generate so that XGBoost doesn't evaluate the result to 1.0 + X, y = datasets.make_classification( + n_samples=512, n_features=4, n_classes=2, random_state=1994 + ) + + ltr.set_params(eval_metric="pre@32") + result = _parse_eval_str( + ltr.get_booster().eval_set(evals=[(xgb.DMatrix(X, y), "Xy")]) + ) + score_0 = result[1][1] + + X_list = [] + y_list = [] + n_query_groups = 3 + q_list: List[np.ndarray] = [] + for i in range(n_query_groups): + # same for all groups + X, y = datasets.make_classification( + n_samples=512, n_features=4, n_classes=2, random_state=1994 + ) + X_list.append(X) + y_list.append(y) + q = np.full(shape=y.shape, fill_value=i, dtype=np.uint64) + q_list.append(q) + + qid = concat(q_list) + X = concat(X_list) + y = concat(y_list) + + result = _parse_eval_str( + ltr.get_booster().eval_set(evals=[(xgb.DMatrix(X, y, qid=qid), "Xy")]) + ) + assert result[1][0].endswith("pre@32") + score_1 = result[1][1] + assert score_1 == score_0 def check_quantile_error(tree_method: str) -> None: diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 4aadfb0c0..db38b2222 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -825,176 +825,6 @@ XGBOOST_DEVICE auto tcrend(xgboost::common::Span const &span) { // NOLINT return tcrbegin(span) + span.size(); } -// This type sorts an array which is divided into multiple groups. The sorting is influenced -// by the function object 'Comparator' -template -class SegmentSorter { - private: - // Items sorted within the group - caching_device_vector ditems_; - - // Original position of the items before they are sorted descending within their groups - caching_device_vector doriginal_pos_; - - // Segments within the original list that delineates the different groups - caching_device_vector group_segments_; - - // Need this on the device as it is used in the kernels - caching_device_vector dgroups_; // Group information on device - - // Where did the item that was originally present at position 'x' move to after they are sorted - caching_device_vector dindexable_sorted_pos_; - - // Initialize everything but the segments - void Init(uint32_t num_elems) { - ditems_.resize(num_elems); - - doriginal_pos_.resize(num_elems); - thrust::sequence(doriginal_pos_.begin(), doriginal_pos_.end()); - } - - // Initialize all with group info - void Init(const std::vector &groups) { - uint32_t num_elems = groups.back(); - this->Init(num_elems); - this->CreateGroupSegments(groups); - } - - public: - // This needs to be public due to device lambda - void CreateGroupSegments(const std::vector &groups) { - uint32_t num_elems = groups.back(); - group_segments_.resize(num_elems, 0); - - dgroups_ = groups; - - if (GetNumGroups() == 1) return; // There are no segments; hence, no need to compute them - - // Define the segments by assigning a group ID to each element - const uint32_t *dgroups = dgroups_.data().get(); - uint32_t ngroups = dgroups_.size(); - auto ComputeGroupIDLambda = [=] __device__(uint32_t idx) { - return thrust::upper_bound(thrust::seq, dgroups, dgroups + ngroups, idx) - - dgroups - 1; - }; // NOLINT - - thrust::transform(thrust::make_counting_iterator(static_cast(0)), - thrust::make_counting_iterator(num_elems), - group_segments_.begin(), - ComputeGroupIDLambda); - } - - // Accessors that returns device pointer - inline uint32_t GetNumItems() const { return ditems_.size(); } - inline const xgboost::common::Span GetItemsSpan() const { - return { ditems_.data().get(), ditems_.size() }; - } - - inline const xgboost::common::Span GetOriginalPositionsSpan() const { - return { doriginal_pos_.data().get(), doriginal_pos_.size() }; - } - - inline const xgboost::common::Span GetGroupSegmentsSpan() const { - return { group_segments_.data().get(), group_segments_.size() }; - } - - inline uint32_t GetNumGroups() const { return dgroups_.size() - 1; } - inline const xgboost::common::Span GetGroupsSpan() const { - return { dgroups_.data().get(), dgroups_.size() }; - } - - inline const xgboost::common::Span GetIndexableSortedPositionsSpan() const { - return { dindexable_sorted_pos_.data().get(), dindexable_sorted_pos_.size() }; - } - - // Sort an array that is divided into multiple groups. The array is sorted within each group. - // This version provides the group information that is on the host. - // The array is sorted based on an adaptable binary predicate. By default a stateless predicate - // is used. - template > - void SortItems(const T *ditems, uint32_t item_size, const std::vector &groups, - const Comparator &comp = Comparator()) { - this->Init(groups); - this->SortItems(ditems, item_size, this->GetGroupSegmentsSpan(), comp); - } - - // Sort an array that is divided into multiple groups. The array is sorted within each group. - // This version provides the group information that is on the device. - // The array is sorted based on an adaptable binary predicate. By default a stateless predicate - // is used. - template > - void SortItems(const T *ditems, uint32_t item_size, - const xgboost::common::Span &group_segments, - const Comparator &comp = Comparator()) { - this->Init(item_size); - - // Sort the items that are grouped. We would like to avoid using predicates to perform the sort, - // as thrust resorts to using a merge sort as opposed to a much much faster radix sort - // when comparators are used. Hence, the following algorithm is used. This is done so that - // we can grab the appropriate related values from the original list later, after the - // items are sorted. - // - // Here is the internal representation: - // dgroups_: [ 0, 3, 5, 8, 10 ] - // group_segments_: 0 0 0 | 1 1 | 2 2 2 | 3 3 - // doriginal_pos_: 0 1 2 | 3 4 | 5 6 7 | 8 9 - // ditems_: 1 0 1 | 2 1 | 1 3 3 | 4 4 (from original items) - // - // Sort the items first and make a note of the original positions in doriginal_pos_ - // based on the sort - // ditems_: 4 4 3 3 2 1 1 1 1 0 - // doriginal_pos_: 8 9 6 7 3 0 2 4 5 1 - // NOTE: This consumes space, but is much faster than some of the other approaches - sorting - // in kernel, sorting using predicates etc. - - ditems_.assign(thrust::device_ptr(ditems), - thrust::device_ptr(ditems) + item_size); - - // Allocator to be used by sort for managing space overhead while sorting - dh::XGBCachingDeviceAllocator alloc; - - thrust::stable_sort_by_key(thrust::cuda::par(alloc), - ditems_.begin(), ditems_.end(), - doriginal_pos_.begin(), comp); - - if (GetNumGroups() == 1) return; // The entire array is sorted, as it isn't segmented - - // Next, gather the segments based on the doriginal_pos_. This is to reflect the - // holisitic item sort order on the segments - // group_segments_c_: 3 3 2 2 1 0 0 1 2 0 - // doriginal_pos_: 8 9 6 7 3 0 2 4 5 1 (stays the same) - caching_device_vector group_segments_c(item_size); - thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(), - dh::tcbegin(group_segments), group_segments_c.begin()); - - // Now, sort the group segments so that you may bring the items within the group together, - // in the process also noting the relative changes to the doriginal_pos_ while that happens - // group_segments_c_: 0 0 0 1 1 2 2 2 3 3 - // doriginal_pos_: 0 2 1 3 4 6 7 5 8 9 - thrust::stable_sort_by_key(thrust::cuda::par(alloc), - group_segments_c.begin(), group_segments_c.end(), - doriginal_pos_.begin(), thrust::less()); - - // Finally, gather the original items based on doriginal_pos_ to sort the input and - // to store them in ditems_ - // doriginal_pos_: 0 2 1 3 4 6 7 5 8 9 (stays the same) - // ditems_: 1 1 0 2 1 3 3 1 4 4 (from unsorted items - ditems) - thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(), - thrust::device_ptr(ditems), ditems_.begin()); - } - - // Determine where an item that was originally present at position 'x' has been relocated to - // after a sort. Creation of such an index has to be explicitly requested after a sort - void CreateIndexableSortedPositions() { - dindexable_sorted_pos_.resize(GetNumItems()); - thrust::scatter(thrust::make_counting_iterator(static_cast(0)), - thrust::make_counting_iterator(GetNumItems()), // Rearrange indices... - // ...based on this map - dh::tcbegin(GetOriginalPositionsSpan()), - dindexable_sorted_pos_.begin()); // Write results into this - } -}; - // Atomic add function for gradients template XGBOOST_DEV_INLINE void AtomicAddGpair(OutputGradientT* dest, diff --git a/src/common/optional_weight.h b/src/common/optional_weight.h index e929aecb5..c2844d73f 100644 --- a/src/common/optional_weight.h +++ b/src/common/optional_weight.h @@ -8,8 +8,7 @@ #include "xgboost/host_device_vector.h" // HostDeviceVector #include "xgboost/span.h" // Span -namespace xgboost { -namespace common { +namespace xgboost::common { struct OptionalWeights { Span weights; float dft{1.0f}; // fixme: make this compile time constant @@ -18,7 +17,8 @@ struct OptionalWeights { explicit OptionalWeights(float w) : dft{w} {} XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; } - auto Empty() const { return weights.empty(); } + [[nodiscard]] auto Empty() const { return weights.empty(); } + [[nodiscard]] auto Size() const { return weights.size(); } }; inline OptionalWeights MakeOptionalWeights(Context const* ctx, @@ -28,6 +28,5 @@ inline OptionalWeights MakeOptionalWeights(Context const* ctx, } return OptionalWeights{ctx->IsCPU() ? weights.ConstHostSpan() : weights.ConstDeviceSpan()}; } -} // namespace common -} // namespace xgboost +} // namespace xgboost::common #endif // XGBOOST_COMMON_OPTIONAL_WEIGHT_H_ diff --git a/src/common/quantile.cc b/src/common/quantile.cc index 390ce34d2..5250abd0f 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -90,6 +90,9 @@ void HostSketchContainer::PushAdapterBatch(Batch const &batch, size_t base_rowid MetaInfo const &info, float missing) { auto const &h_weights = (use_group_ind_ ? detail::UnrollGroupWeights(info) : info.weights_.HostVector()); + if (!use_group_ind_ && !h_weights.empty()) { + CHECK_EQ(h_weights.size(), batch.Size()) << "Invalid size of sample weight."; + } auto is_valid = data::IsValidFunctor{missing}; auto weights = OptionalWeights{Span{h_weights}}; diff --git a/src/common/quantile.h b/src/common/quantile.h index dc0a4872a..48758b8dc 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -19,12 +19,12 @@ #include "categorical.h" #include "common.h" +#include "error_msg.h" // GroupWeight #include "optional_weight.h" // OptionalWeights #include "threading_utils.h" #include "timer.h" -namespace xgboost { -namespace common { +namespace xgboost::common { /*! * \brief experimental wsummary * \tparam DType type of data content @@ -695,13 +695,18 @@ inline std::vector UnrollGroupWeights(MetaInfo const &info) { return group_weights; } - size_t n_samples = info.num_row_; auto const &group_ptr = info.group_ptr_; - std::vector results(n_samples); CHECK_GE(group_ptr.size(), 2); - CHECK_EQ(group_ptr.back(), n_samples); + + auto n_groups = group_ptr.size() - 1; + CHECK_EQ(info.weights_.Size(), n_groups) << error::GroupWeight(); + + bst_row_t n_samples = info.num_row_; + std::vector results(n_samples); + CHECK_EQ(group_ptr.back(), n_samples) + << error::GroupSize() << " the number of rows from the data."; size_t cur_group = 0; - for (size_t i = 0; i < n_samples; ++i) { + for (bst_row_t i = 0; i < n_samples; ++i) { results[i] = group_weights[cur_group]; if (i == group_ptr[cur_group + 1]) { cur_group++; @@ -1010,6 +1015,5 @@ class SortedSketchContainer : public SketchContainerImpl hessian); }; -} // namespace common -} // namespace xgboost +} // namespace xgboost::common #endif // XGBOOST_COMMON_QUANTILE_H_ diff --git a/src/common/ranking_utils.cc b/src/common/ranking_utils.cc index d831b551c..65793a13a 100644 --- a/src/common/ranking_utils.cc +++ b/src/common/ranking_utils.cc @@ -114,9 +114,20 @@ void NDCGCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUS DMLC_REGISTER_PARAMETER(LambdaRankParam); +void PreCache::InitOnCPU(Context const*, MetaInfo const& info) { + auto const& h_label = info.labels.HostView().Slice(linalg::All(), 0); + CheckPreLabels("pre", h_label, + [](auto beg, auto end, auto op) { return std::all_of(beg, end, op); }); +} + +#if !defined(XGBOOST_USE_CUDA) +void PreCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUSupport(); } +#endif // !defined(XGBOOST_USE_CUDA) + void MAPCache::InitOnCPU(Context const*, MetaInfo const& info) { auto const& h_label = info.labels.HostView().Slice(linalg::All(), 0); - CheckMapLabels(h_label, [](auto beg, auto end, auto op) { return std::all_of(beg, end, op); }); + CheckPreLabels("map", h_label, + [](auto beg, auto end, auto op) { return std::all_of(beg, end, op); }); } #if !defined(XGBOOST_USE_CUDA) diff --git a/src/common/ranking_utils.cu b/src/common/ranking_utils.cu index 8fbf89818..283ccc21d 100644 --- a/src/common/ranking_utils.cu +++ b/src/common/ranking_utils.cu @@ -205,8 +205,13 @@ void NDCGCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) { [=] XGBOOST_DEVICE(std::size_t i) { d_discount[i] = CalcDCGDiscount(i); }); } +void PreCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) { + auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + CheckPreLabels("pre", d_label, CheckMAPOp{ctx->CUDACtx()}); +} + void MAPCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) { auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); - CheckMapLabels(d_label, CheckMAPOp{ctx->CUDACtx()}); + CheckPreLabels("map", d_label, CheckMAPOp{ctx->CUDACtx()}); } } // namespace xgboost::ltr diff --git a/src/common/ranking_utils.h b/src/common/ranking_utils.h index dd823a0d6..7d11de048 100644 --- a/src/common/ranking_utils.h +++ b/src/common/ranking_utils.h @@ -366,18 +366,43 @@ bool IsBinaryRel(linalg::VectorView label, AllOf all_of) { }); } /** - * \brief Validate label for MAP + * \brief Validate label for precision-based metric. * * \tparam Implementation of std::all_of. Specified as a parameter to reuse the check for * both CPU and GPU. */ template -void CheckMapLabels(linalg::VectorView label, AllOf all_of) { +void CheckPreLabels(StringView name, linalg::VectorView label, AllOf all_of) { auto s_label = label.Values(); auto is_binary = IsBinaryRel(label, all_of); - CHECK(is_binary) << "MAP can only be used with binary labels."; + CHECK(is_binary) << name << " can only be used with binary labels."; } +class PreCache : public RankingCache { + HostDeviceVector pre_; + + void InitOnCPU(Context const* ctx, MetaInfo const& info); + void InitOnCUDA(Context const* ctx, MetaInfo const& info); + + public: + PreCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p) + : RankingCache{ctx, info, p} { + if (ctx->IsCPU()) { + this->InitOnCPU(ctx, info); + } else { + this->InitOnCUDA(ctx, info); + } + } + + common::Span Pre(Context const* ctx) { + if (pre_.Empty()) { + pre_.SetDevice(ctx->gpu_id); + pre_.Resize(this->Groups()); + } + return ctx->IsCPU() ? pre_.HostSpan() : pre_.DeviceSpan(); + } +}; + class MAPCache : public RankingCache { // Total number of relevant documents for each group HostDeviceVector n_rel_; diff --git a/src/data/iterative_dmatrix.cc b/src/data/iterative_dmatrix.cc index 8eb1c2034..627606aa3 100644 --- a/src/data/iterative_dmatrix.cc +++ b/src/data/iterative_dmatrix.cc @@ -366,8 +366,8 @@ inline void IterativeDMatrix::InitFromCUDA(Context const*, BatchParam const&, Da common::AssertGPUSupport(); } -inline BatchSet IterativeDMatrix::GetEllpackBatches(Context const* ctx, - BatchParam const& param) { +inline BatchSet IterativeDMatrix::GetEllpackBatches(Context const*, + BatchParam const&) { common::AssertGPUSupport(); auto begin_iter = BatchIterator(new SimpleBatchIteratorImpl(ellpack_)); return BatchSet(BatchIterator(begin_iter)); diff --git a/src/metric/metric.cc b/src/metric/metric.cc index ebb579827..d7e2683ec 100644 --- a/src/metric/metric.cc +++ b/src/metric/metric.cc @@ -52,32 +52,13 @@ Metric::Create(const std::string& name, Context const* ctx) { metric->ctx_ = ctx; return metric; } - -GPUMetric* GPUMetric::CreateGPUMetric(const std::string& name, Context const* ctx) { - auto metric = CreateMetricImpl(name); - if (metric == nullptr) { - LOG(WARNING) << "Cannot find a GPU metric builder for metric " << name - << ". Resorting to the CPU builder"; - return nullptr; - } - - // Narrowing reference only for the compiler to allow assignment to a base class member. - // As such, using this narrowed reference to refer to derived members will be an illegal op. - // This is moot, as this type is stateless. - auto casted = static_cast(metric); - CHECK(casted); - casted->ctx_ = ctx; - return casted; -} } // namespace xgboost namespace dmlc { DMLC_REGISTRY_ENABLE(::xgboost::MetricReg); -DMLC_REGISTRY_ENABLE(::xgboost::MetricGPUReg); } -namespace xgboost { -namespace metric { +namespace xgboost::metric { // List of files that will be force linked in static links. DMLC_REGISTRY_LINK_TAG(auc); DMLC_REGISTRY_LINK_TAG(elementwise_metric); @@ -88,5 +69,4 @@ DMLC_REGISTRY_LINK_TAG(rank_metric); DMLC_REGISTRY_LINK_TAG(auc_gpu); DMLC_REGISTRY_LINK_TAG(rank_metric_gpu); #endif -} // namespace metric -} // namespace xgboost +} // namespace xgboost::metric diff --git a/src/metric/metric_common.h b/src/metric/metric_common.h index a6fad7158..1b148ab0f 100644 --- a/src/metric/metric_common.h +++ b/src/metric/metric_common.h @@ -23,53 +23,14 @@ class MetricNoCache : public Metric { double Evaluate(HostDeviceVector const &predts, std::shared_ptr p_fmat) final { double result{0.0}; - auto const& info = p_fmat->Info(); - collective::ApplyWithLabels(info, &result, sizeof(double), [&] { - result = this->Eval(predts, info); - }); + auto const &info = p_fmat->Info(); + collective::ApplyWithLabels(info, &result, sizeof(double), + [&] { result = this->Eval(predts, info); }); return result; } }; -// This creates a GPU metric instance dynamically and adds it to the GPU metric registry, if not -// present already. This is created when there is a device ordinal present and if xgboost -// is compiled with CUDA support -struct GPUMetric : public MetricNoCache { - static GPUMetric *CreateGPUMetric(const std::string &name, Context const *tparam); -}; - -/*! - * \brief Internal registry entries for GPU Metric factory functions. - * The additional parameter const char* param gives the value after @, can be null. - * For example, metric map@3, then: param == "3". - */ -struct MetricGPUReg - : public dmlc::FunctionRegEntryBase > { -}; - -/*! - * \brief Macro to register metric computed on GPU. - * - * \code - * // example of registering a objective ndcg@k - * XGBOOST_REGISTER_GPU_METRIC(NDCG_GPU, "ndcg") - * .describe("NDCG metric computer on GPU.") - * .set_body([](const char* param) { - * int at_k = atoi(param); - * return new NDCG(at_k); - * }); - * \endcode - */ - -// Note: Metric names registered in the GPU registry should follow this convention: -// - GPU metric types should be registered with the same name as the non GPU metric types -#define XGBOOST_REGISTER_GPU_METRIC(UniqueId, Name) \ - ::xgboost::MetricGPUReg& __make_ ## MetricGPUReg ## _ ## UniqueId ## __ = \ - ::dmlc::Registry< ::xgboost::MetricGPUReg>::Get()->__REGISTER__(Name) - namespace metric { - // Ranking config to be used on device and host struct EvalRankConfig { public: @@ -81,8 +42,8 @@ struct EvalRankConfig { }; class PackedReduceResult { - double residue_sum_ { 0 }; - double weights_sum_ { 0 }; + double residue_sum_{0}; + double weights_sum_{0}; public: XGBOOST_DEVICE PackedReduceResult() {} // NOLINT @@ -91,16 +52,15 @@ class PackedReduceResult { XGBOOST_DEVICE PackedReduceResult operator+(PackedReduceResult const &other) const { - return PackedReduceResult{residue_sum_ + other.residue_sum_, - weights_sum_ + other.weights_sum_}; + return PackedReduceResult{residue_sum_ + other.residue_sum_, weights_sum_ + other.weights_sum_}; } PackedReduceResult &operator+=(PackedReduceResult const &other) { this->residue_sum_ += other.residue_sum_; this->weights_sum_ += other.weights_sum_; return *this; } - double Residue() const { return residue_sum_; } - double Weights() const { return weights_sum_; } + [[nodiscard]] double Residue() const { return residue_sum_; } + [[nodiscard]] double Weights() const { return weights_sum_; } }; } // namespace metric diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index c4549458d..dd9adc017 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -1,25 +1,6 @@ /** * Copyright 2020-2023 by XGBoost contributors */ -// When device ordinal is present, we would want to build the metrics on the GPU. It is *not* -// possible for a valid device ordinal to be present for non GPU builds. However, it is possible -// for an invalid device ordinal to be specified in GPU builds - to train/predict and/or compute -// the metrics on CPU. To accommodate these scenarios, the following is done for the metrics -// accelerated on the GPU. -// - An internal GPU registry holds all the GPU metric types (defined in the .cu file) -// - An instance of the appropriate GPU metric type is created when a device ordinal is present -// - If the creation is successful, the metric computation is done on the device -// - else, it falls back on the CPU -// - The GPU metric types are *only* registered when xgboost is built for GPUs -// -// This is done for 2 reasons: -// - Clear separation of CPU and GPU logic -// - Sorting datasets containing large number of rows is (much) faster when parallel sort -// semantics is used on the CPU. The __gnu_parallel/concurrency primitives needed to perform -// this cannot be used when the translation unit is compiled using the 'nvcc' compiler (as the -// corresponding headers that brings in those function declaration can't be included with CUDA). -// This precludes the CPU and GPU logic to coexist inside a .cu file - #include "rank_metric.h" #include @@ -57,55 +38,8 @@ #include "xgboost/string_view.h" // for StringView namespace { - using PredIndPair = std::pair; using PredIndPairContainer = std::vector; - -/* - * Adapter to access instance weights. - * - * - For ranking task, weights are per-group - * - For binary classification task, weights are per-instance - * - * WeightPolicy::GetWeightOfInstance() : - * get weight associated with an individual instance, using index into - * `info.weights` - * WeightPolicy::GetWeightOfSortedRecord() : - * get weight associated with an individual instance, using index into - * sorted records `rec` (in ascending order of predicted labels). `rec` is - * of type PredIndPairContainer - */ - -class PerInstanceWeightPolicy { - public: - inline static xgboost::bst_float - GetWeightOfInstance(const xgboost::MetaInfo& info, - unsigned instance_id, unsigned) { - return info.GetWeight(instance_id); - } - inline static xgboost::bst_float - GetWeightOfSortedRecord(const xgboost::MetaInfo& info, - const PredIndPairContainer& rec, - unsigned record_id, unsigned) { - return info.GetWeight(rec[record_id].second); - } -}; - -class PerGroupWeightPolicy { - public: - inline static xgboost::bst_float - GetWeightOfInstance(const xgboost::MetaInfo& info, - unsigned, unsigned group_id) { - return info.GetWeight(group_id); - } - - inline static xgboost::bst_float - GetWeightOfSortedRecord(const xgboost::MetaInfo& info, - const PredIndPairContainer&, - unsigned, unsigned group_id) { - return info.GetWeight(group_id); - } -}; } // anonymous namespace namespace xgboost::metric { @@ -177,10 +111,6 @@ struct EvalAMS : public MetricNoCache { /*! \brief Evaluate rank list */ struct EvalRank : public MetricNoCache, public EvalRankConfig { - private: - // This is used to compute the ranking metrics on the GPU - for training jobs that run on the GPU. - std::unique_ptr rank_gpu_; - public: double Eval(const HostDeviceVector& preds, const MetaInfo& info) override { CHECK_EQ(preds.Size(), info.labels.Size()) @@ -199,20 +129,10 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig { // sum statistics double sum_metric = 0.0f; - // Check and see if we have the GPU metric registered in the internal registry - if (ctx_->gpu_id >= 0) { - if (!rank_gpu_) { - rank_gpu_.reset(GPUMetric::CreateGPUMetric(this->Name(), ctx_)); - } - if (rank_gpu_) { - sum_metric = rank_gpu_->Eval(preds, info); - } - } - CHECK(ctx_); std::vector sum_tloc(ctx_->Threads(), 0.0); - if (!rank_gpu_ || ctx_->gpu_id < 0) { + { const auto& labels = info.labels.View(Context::kCpuId); const auto &h_preds = preds.ConstHostVector(); @@ -253,23 +173,6 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig { virtual double EvalGroup(PredIndPairContainer *recptr) const = 0; }; -/*! \brief Precision at N, for both classification and rank */ -struct EvalPrecision : public EvalRank { - public: - explicit EvalPrecision(const char* name, const char* param) : EvalRank(name, param) {} - - double EvalGroup(PredIndPairContainer *recptr) const override { - PredIndPairContainer &rec(*recptr); - // calculate Precision - std::stable_sort(rec.begin(), rec.end(), common::CmpFirst); - unsigned nhit = 0; - for (size_t j = 0; j < rec.size() && j < this->topn; ++j) { - nhit += (rec[j].second != 0); - } - return static_cast(nhit) / this->topn; - } -}; - /*! \brief Cox: Partial likelihood of the Cox proportional hazards model */ struct EvalCox : public MetricNoCache { public: @@ -312,7 +215,7 @@ struct EvalCox : public MetricNoCache { return out/num_events; // normalize by the number of events } - const char* Name() const override { + [[nodiscard]] const char* Name() const override { return "cox-nloglik"; } }; @@ -321,10 +224,6 @@ XGBOOST_REGISTER_METRIC(AMS, "ams") .describe("AMS metric for higgs.") .set_body([](const char* param) { return new EvalAMS(param); }); -XGBOOST_REGISTER_METRIC(Precision, "pre") -.describe("precision@k for rank.") -.set_body([](const char* param) { return new EvalPrecision("pre", param); }); - XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik") .describe("Negative log partial likelihood of Cox proportional hazards model.") .set_body([](const char*) { return new EvalCox(); }); @@ -387,6 +286,8 @@ class EvalRankWithCache : public Metric { return result; } + [[nodiscard]] const char* Name() const override { return name_.c_str(); } + virtual double Eval(HostDeviceVector const& preds, MetaInfo const& info, std::shared_ptr p_cache) = 0; }; @@ -408,6 +309,52 @@ double Finalize(MetaInfo const& info, double score, double sw) { } } // namespace +class EvalPrecision : public EvalRankWithCache { + public: + using EvalRankWithCache::EvalRankWithCache; + + double Eval(HostDeviceVector const& predt, MetaInfo const& info, + std::shared_ptr p_cache) final { + auto n_groups = p_cache->Groups(); + if (!info.weights_.Empty()) { + CHECK_EQ(info.weights_.Size(), n_groups) << error::GroupWeight(); + } + + if (ctx_->IsCUDA()) { + auto pre = cuda_impl::PreScore(ctx_, info, predt, p_cache); + return Finalize(info, pre.Residue(), pre.Weights()); + } + + auto gptr = p_cache->DataGroupPtr(ctx_); + auto h_label = info.labels.HostView().Slice(linalg::All(), 0); + auto h_predt = linalg::MakeTensorView(ctx_, &predt, predt.Size()); + auto rank_idx = p_cache->SortedIdx(ctx_, predt.ConstHostSpan()); + + auto weight = common::MakeOptionalWeights(ctx_, info.weights_); + auto pre = p_cache->Pre(ctx_); + + common::ParallelFor(p_cache->Groups(), ctx_->Threads(), [&](auto g) { + auto g_label = h_label.Slice(linalg::Range(gptr[g], gptr[g + 1])); + auto g_rank = rank_idx.subspan(gptr[g], gptr[g + 1] - gptr[g]); + + auto n = std::min(static_cast(param_.TopK()), g_label.Size()); + double n_hits{0.0}; + for (std::size_t i = 0; i < n; ++i) { + n_hits += g_label(g_rank[i]) * weight[g]; + } + pre[g] = n_hits / static_cast(n); + }); + + auto sw = 0.0; + for (std::size_t i = 0; i < pre.size(); ++i) { + sw += weight[i]; + } + + auto sum = std::accumulate(pre.cbegin(), pre.cend(), 0.0); + return Finalize(info, sum, sw); + } +}; + /** * \brief Implement the NDCG score function for learning to rank. * @@ -416,7 +363,6 @@ double Finalize(MetaInfo const& info, double score, double sw) { class EvalNDCG : public EvalRankWithCache { public: using EvalRankWithCache::EvalRankWithCache; - const char* Name() const override { return name_.c_str(); } double Eval(HostDeviceVector const& preds, MetaInfo const& info, std::shared_ptr p_cache) override { @@ -475,7 +421,6 @@ class EvalNDCG : public EvalRankWithCache { class EvalMAPScore : public EvalRankWithCache { public: using EvalRankWithCache::EvalRankWithCache; - const char* Name() const override { return name_.c_str(); } double Eval(HostDeviceVector const& predt, MetaInfo const& info, std::shared_ptr p_cache) override { @@ -494,7 +439,7 @@ class EvalMAPScore : public EvalRankWithCache { common::ParallelFor(p_cache->Groups(), ctx_->Threads(), [&](auto g) { auto g_label = h_label.Slice(linalg::Range(gptr[g], gptr[g + 1])); - auto g_rank = rank_idx.subspan(gptr[g]); + auto g_rank = rank_idx.subspan(gptr[g], gptr[g + 1] - gptr[g]); auto n = std::min(static_cast(param_.TopK()), g_label.Size()); double n_hits{0.0}; @@ -527,6 +472,10 @@ class EvalMAPScore : public EvalRankWithCache { } }; +XGBOOST_REGISTER_METRIC(Precision, "pre") + .describe("precision@k for rank.") + .set_body([](const char* param) { return new EvalPrecision("pre", param); }); + XGBOOST_REGISTER_METRIC(EvalMAP, "map") .describe("map@k for ranking.") .set_body([](char const* param) { diff --git a/src/metric/rank_metric.cu b/src/metric/rank_metric.cu index 386f0d53d..9ba1baf8f 100644 --- a/src/metric/rank_metric.cu +++ b/src/metric/rank_metric.cu @@ -28,108 +28,57 @@ namespace xgboost::metric { // tag the this file, used by force static link later. DMLC_REGISTRY_FILE_TAG(rank_metric_gpu); -/*! \brief Evaluate rank list on GPU */ -template -struct EvalRankGpu : public GPUMetric, public EvalRankConfig { - public: - double Eval(const HostDeviceVector &preds, const MetaInfo &info) override { - // Sanity check is done by the caller - std::vector tgptr(2, 0); - tgptr[1] = static_cast(preds.Size()); - const std::vector &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_; - - const auto ngroups = static_cast(gptr.size() - 1); - - auto device = ctx_->gpu_id; - dh::safe_cuda(cudaSetDevice(device)); - - info.labels.SetDevice(device); - preds.SetDevice(device); - - auto dpreds = preds.ConstDevicePointer(); - auto dlabels = info.labels.View(device); - - // Sort all the predictions - dh::SegmentSorter segment_pred_sorter; - segment_pred_sorter.SortItems(dpreds, preds.Size(), gptr); - - // Compute individual group metric and sum them up - return EvalMetricT::EvalMetric(segment_pred_sorter, dlabels.Values().data(), *this); - } - - const char* Name() const override { - return name.c_str(); - } - - explicit EvalRankGpu(const char* name, const char* param) { - using namespace std; // NOLINT(*) - if (param != nullptr) { - std::ostringstream os; - if (sscanf(param, "%u[-]?", &this->topn) == 1) { - os << name << '@' << param; - this->name = os.str(); - } else { - os << name << param; - this->name = os.str(); - } - if (param[strlen(param) - 1] == '-') { - this->minus = true; - } - } else { - this->name = name; - } - } -}; - -/*! \brief Precision at N, for both classification and rank */ -struct EvalPrecisionGpu { - public: - static double EvalMetric(const dh::SegmentSorter &pred_sorter, - const float *dlabels, - const EvalRankConfig &ecfg) { - // Group info on device - const auto &dgroups = pred_sorter.GetGroupsSpan(); - const auto ngroups = pred_sorter.GetNumGroups(); - const auto &dgroup_idx = pred_sorter.GetGroupSegmentsSpan(); - - // Original positions of the predictions after they have been sorted - const auto &dpreds_orig_pos = pred_sorter.GetOriginalPositionsSpan(); - - // First, determine non zero labels in the dataset individually - auto DetermineNonTrivialLabelLambda = [=] __device__(uint32_t idx) { - return (static_cast(dlabels[dpreds_orig_pos[idx]]) != 0) ? 1 : 0; - }; // NOLINT - - // Find each group's metric sum - dh::caching_device_vector hits(ngroups, 0); - const auto nitems = pred_sorter.GetNumItems(); - auto *dhits = hits.data().get(); - - int device_id = -1; - dh::safe_cuda(cudaGetDevice(&device_id)); - // For each group item compute the aggregated precision - dh::LaunchN(nitems, nullptr, [=] __device__(uint32_t idx) { - const auto group_idx = dgroup_idx[idx]; - const auto group_begin = dgroups[group_idx]; - const auto ridx = idx - group_begin; - if (ridx < ecfg.topn && DetermineNonTrivialLabelLambda(idx)) { - atomicAdd(&dhits[group_idx], 1); - } - }); - - // Allocator to be used for managing space overhead while performing reductions - dh::XGBCachingDeviceAllocator alloc; - return static_cast(thrust::reduce(thrust::cuda::par(alloc), - hits.begin(), hits.end())) / ecfg.topn; - } -}; - - -XGBOOST_REGISTER_GPU_METRIC(PrecisionGpu, "pre") -.describe("precision@k for rank computed on GPU.") -.set_body([](const char* param) { return new EvalRankGpu("pre", param); }); - namespace cuda_impl { +PackedReduceResult PreScore(Context const *ctx, MetaInfo const &info, + HostDeviceVector const &predt, + std::shared_ptr p_cache) { + auto d_gptr = p_cache->DataGroupPtr(ctx); + auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + + predt.SetDevice(ctx->gpu_id); + auto d_rank_idx = p_cache->SortedIdx(ctx, predt.ConstDeviceSpan()); + auto topk = p_cache->Param().TopK(); + auto d_weight = common::MakeOptionalWeights(ctx, info.weights_); + + auto it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) { + auto g = dh::SegmentId(d_gptr, i); + auto g_begin = d_gptr[g]; + auto g_end = d_gptr[g + 1]; + i -= g_begin; + auto g_label = d_label.Slice(linalg::Range(g_begin, g_end)); + auto g_rank = d_rank_idx.subspan(g_begin, g_end - g_begin); + double y = g_label(g_rank[i]); + auto n = std::min(static_cast(topk), g_label.Size()); + double w{d_weight[g]}; + if (i >= n) { + return 0.0; + } + return y / static_cast(n) * w; + }); + + auto cuctx = ctx->CUDACtx(); + auto pre = p_cache->Pre(ctx); + thrust::fill_n(cuctx->CTP(), pre.data(), pre.size(), 0.0); + + std::size_t bytes; + cub::DeviceSegmentedReduce::Sum(nullptr, bytes, it, pre.data(), p_cache->Groups(), d_gptr.data(), + d_gptr.data() + 1, cuctx->Stream()); + dh::TemporaryArray temp(bytes); + cub::DeviceSegmentedReduce::Sum(temp.data().get(), bytes, it, pre.data(), p_cache->Groups(), + d_gptr.data(), d_gptr.data() + 1, cuctx->Stream()); + + auto w_it = + dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), + [=] XGBOOST_DEVICE(std::size_t g) { return d_weight[g]; }); + auto n_weights = p_cache->Groups(); + auto sw = dh::Reduce(cuctx->CTP(), w_it, w_it + n_weights, 0.0, thrust::plus{}); + auto sum = + dh::Reduce(cuctx->CTP(), dh::tcbegin(pre), dh::tcend(pre), 0.0, thrust::plus{}); + auto result = PackedReduceResult{sum, sw}; + return result; +} + PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info, HostDeviceVector const &predt, bool minus, std::shared_ptr p_cache) { diff --git a/src/metric/rank_metric.h b/src/metric/rank_metric.h index b3b121973..40954ffcb 100644 --- a/src/metric/rank_metric.h +++ b/src/metric/rank_metric.h @@ -3,7 +3,7 @@ /** * Copyright 2023 by XGBoost Contributors */ -#include // for shared_ptr +#include // for shared_ptr #include "../common/common.h" // for AssertGPUSupport #include "../common/ranking_utils.h" // for NDCGCache, MAPCache @@ -12,9 +12,7 @@ #include "xgboost/data.h" // for MetaInfo #include "xgboost/host_device_vector.h" // for HostDeviceVector -namespace xgboost { -namespace metric { -namespace cuda_impl { +namespace xgboost::metric::cuda_impl { PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info, HostDeviceVector const &predt, bool minus, std::shared_ptr p_cache); @@ -23,6 +21,10 @@ PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info, HostDeviceVector const &predt, bool minus, std::shared_ptr p_cache); +PackedReduceResult PreScore(Context const *ctx, MetaInfo const &info, + HostDeviceVector const &predt, + std::shared_ptr p_cache); + #if !defined(XGBOOST_USE_CUDA) inline PackedReduceResult NDCGScore(Context const *, MetaInfo const &, HostDeviceVector const &, bool, @@ -37,8 +39,13 @@ inline PackedReduceResult MAPScore(Context const *, MetaInfo const &, common::AssertGPUSupport(); return {}; } + +inline PackedReduceResult PreScore(Context const *, MetaInfo const &, + HostDeviceVector const &, + std::shared_ptr) { + common::AssertGPUSupport(); + return {}; +} #endif -} // namespace cuda_impl -} // namespace metric -} // namespace xgboost +} // namespace xgboost::metric::cuda_impl #endif // XGBOOST_METRIC_RANK_METRIC_H_ diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py index 4601c4378..a6ef0b804 100644 --- a/tests/ci_build/lint_python.py +++ b/tests/ci_build/lint_python.py @@ -90,7 +90,7 @@ def check_cmd_print_failure_assistance(cmd: List[str]) -> bool: subprocess.run([cmd[0], "--version"]) msg = """ -Please run the following command on your machine to address the formatting error: +Please run the following command on your machine to address the error: """ msg += " ".join(cmd) diff --git a/tests/cpp/metric/test_rank_metric.h b/tests/cpp/metric/test_rank_metric.h index 318de961b..4b959d857 100644 --- a/tests/cpp/metric/test_rank_metric.h +++ b/tests/cpp/metric/test_rank_metric.h @@ -17,34 +17,30 @@ #include "xgboost/host_device_vector.h" // for HostDeviceVector #include "xgboost/json.h" // for Json, String, Object -namespace xgboost { -namespace metric { +namespace xgboost::metric { inline void VerifyPrecision(DataSplitMode data_split_mode = DataSplitMode::kRow) { - // When the limit for precision is not given, it takes the limit at - // std::numeric_limits::max(); hence all values are very small - // NOTE(AbdealiJK): Maybe this should be fixed to be num_row by default. auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX); - xgboost::Metric * metric = xgboost::Metric::Create("pre", &ctx); + std::unique_ptr metric{Metric::Create("pre", &ctx)}; ASSERT_STREQ(metric->Name(), "pre"); - EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}, {}, {}, data_split_mode), 0, 1e-7); - EXPECT_NEAR(GetMetricEval(metric, - {0.1f, 0.9f, 0.1f, 0.9f}, - { 0, 0, 1, 1}, {}, {}, data_split_mode), - 0, 1e-7); + EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1}, {0, 1}, {}, {}, data_split_mode), 0.5, 1e-7); + EXPECT_NEAR( + GetMetricEval(metric.get(), {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}, {}, {}, data_split_mode), + 0.5, 1e-7); - delete metric; - metric = xgboost::Metric::Create("pre@2", &ctx); + metric.reset(xgboost::Metric::Create("pre@2", &ctx)); ASSERT_STREQ(metric->Name(), "pre@2"); - EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}, {}, {}, data_split_mode), 0.5f, 1e-7); - EXPECT_NEAR(GetMetricEval(metric, - {0.1f, 0.9f, 0.1f, 0.9f}, - { 0, 0, 1, 1}, {}, {}, data_split_mode), - 0.5f, 0.001f); + EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1}, {0, 1}, {}, {}, data_split_mode), 0.5f, 1e-7); + EXPECT_NEAR( + GetMetricEval(metric.get(), {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}, {}, {}, data_split_mode), + 0.5f, 0.001f); - EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {}, {}, {}, data_split_mode)); + EXPECT_ANY_THROW(GetMetricEval(metric.get(), {0, 1}, {}, {}, {}, data_split_mode)); - delete metric; + metric.reset(xgboost::Metric::Create("pre@4", &ctx)); + EXPECT_NEAR(GetMetricEval(metric.get(), {0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f}, + {0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f}, {}, {}, data_split_mode), + 0.5f, 1e-7); } inline void VerifyNDCG(DataSplitMode data_split_mode = DataSplitMode::kRow) { @@ -187,5 +183,4 @@ inline void VerifyNDCGExpGain(DataSplitMode data_split_mode = DataSplitMode::kRo ndcg = metric->Evaluate(predt, p_fmat); ASSERT_NEAR(ndcg, 1.0, kRtEps); } -} // namespace metric -} // namespace xgboost +} // namespace xgboost::metric diff --git a/tests/python-gpu/test_gpu_eval_metrics.py b/tests/python-gpu/test_gpu_eval_metrics.py index 1e9d1a282..f5f770d2f 100644 --- a/tests/python-gpu/test_gpu_eval_metrics.py +++ b/tests/python-gpu/test_gpu_eval_metrics.py @@ -5,7 +5,7 @@ import pytest import xgboost from xgboost import testing as tm -from xgboost.testing.metrics import check_quantile_error +from xgboost.testing.metrics import check_precision_score, check_quantile_error sys.path.append("tests/python") import test_eval_metrics as test_em # noqa @@ -59,6 +59,9 @@ class TestGPUEvalMetrics: def test_pr_auc_ltr(self): self.cpu_test.run_pr_auc_ltr("gpu_hist") + def test_precision_score(self): + check_precision_score("gpu_hist") + @pytest.mark.skipif(**tm.no_sklearn()) def test_quantile_error(self) -> None: check_quantile_error("gpu_hist") diff --git a/tests/python/test_eval_metrics.py b/tests/python/test_eval_metrics.py index 3b7dc5b8e..0328765f5 100644 --- a/tests/python/test_eval_metrics.py +++ b/tests/python/test_eval_metrics.py @@ -3,7 +3,7 @@ import pytest import xgboost as xgb from xgboost import testing as tm -from xgboost.testing.metrics import check_quantile_error +from xgboost.testing.metrics import check_precision_score, check_quantile_error rng = np.random.RandomState(1337) @@ -315,6 +315,9 @@ class TestEvalMetrics: def test_pr_auc_ltr(self): self.run_pr_auc_ltr("hist") + def test_precision_score(self): + check_precision_score("hist") + @pytest.mark.skipif(**tm.no_sklearn()) def test_quantile_error(self) -> None: check_quantile_error("hist") diff --git a/tests/python/test_quantile_dmatrix.py b/tests/python/test_quantile_dmatrix.py index 537910725..0e0aaed08 100644 --- a/tests/python/test_quantile_dmatrix.py +++ b/tests/python/test_quantile_dmatrix.py @@ -55,6 +55,38 @@ class TestQuantileDMatrix: r = np.arange(1.0, n_samples) np.testing.assert_allclose(Xy.get_data().toarray()[1:, 0], r) + def test_error(self): + from sklearn.model_selection import train_test_split + + rng = np.random.default_rng(1994) + X, y = make_categorical( + n_samples=128, n_features=2, n_categories=3, onehot=False + ) + reg = xgb.XGBRegressor(tree_method="hist", enable_categorical=True) + w = rng.uniform(0, 1, size=y.shape[0]) + + X_train, X_test, y_train, y_test, w_train, w_test = train_test_split( + X, y, w, random_state=1994 + ) + + with pytest.raises(ValueError, match="sample weight"): + reg.fit( + X, + y, + sample_weight=w_train, + eval_set=[(X_test, y_test)], + sample_weight_eval_set=[w_test], + ) + + with pytest.raises(ValueError, match="sample weight"): + reg.fit( + X_train, + y_train, + sample_weight=w, + eval_set=[(X_test, y_test)], + sample_weight_eval_set=[w_test], + ) + @pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.8, 0.9]) def test_with_iterator(self, sparsity: float) -> None: n_samples_per_batch = 317 From 288539ac781117a73e156a9e4ee56cbaf6c30f89 Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Fri, 2 Jun 2023 08:17:41 -0700 Subject: [PATCH 2/9] [CI] Automatically bump Rapids version in containers (#9234) * [CI] Use RAPIDS 23.04 * [CI] Remove outdated filters in dependabot * [CI] Automatically bump Rapids version in containers * Automate pull request --- .github/dependabot.yml | 62 ----------------------------- .github/workflows/update_rapids.yml | 37 +++++++++++++++++ tests/buildkite/conftest.sh | 2 +- tests/buildkite/update-rapids.sh | 10 +++++ 4 files changed, 48 insertions(+), 63 deletions(-) create mode 100644 .github/workflows/update_rapids.yml create mode 100755 tests/buildkite/update-rapids.sh diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 0b593216c..c03a52c60 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -9,85 +9,23 @@ updates: directory: "/jvm-packages" schedule: interval: "daily" - ignore: - # Pin Scala version to 2.12.x - - dependency-name: "org.scala-lang:scala-compiler" - versions: [">= 2.13.0"] - - dependency-name: "org.scala-lang:scala-reflect" - versions: [">= 2.13.0"] - - dependency-name: "org.scala-lang:scala-library" - versions: [">= 2.13.0"] - package-ecosystem: "maven" directory: "/jvm-packages/xgboost4j" schedule: interval: "daily" - ignore: - # Pin Scala version to 2.12.x - - dependency-name: "org.scala-lang:scala-compiler" - versions: [">= 2.13.0"] - - dependency-name: "org.scala-lang:scala-reflect" - versions: [">= 2.13.0"] - - dependency-name: "org.scala-lang:scala-library" - versions: [">= 2.13.0"] - package-ecosystem: "maven" directory: "/jvm-packages/xgboost4j-gpu" schedule: interval: "daily" - ignore: - # Pin Scala version to 2.12.x - - dependency-name: "org.scala-lang:scala-compiler" - versions: [">= 2.13.0"] - - dependency-name: "org.scala-lang:scala-reflect" - versions: [">= 2.13.0"] - - dependency-name: "org.scala-lang:scala-library" - versions: [">= 2.13.0"] - package-ecosystem: "maven" directory: "/jvm-packages/xgboost4j-example" schedule: interval: "daily" - ignore: - # Pin Scala version to 2.12.x - - dependency-name: "org.scala-lang:scala-compiler" - versions: [">= 2.13.0"] - - dependency-name: "org.scala-lang:scala-reflect" - versions: [">= 2.13.0"] - - dependency-name: "org.scala-lang:scala-library" - versions: [">= 2.13.0"] - package-ecosystem: "maven" directory: "/jvm-packages/xgboost4j-spark" schedule: interval: "daily" - ignore: - # Pin Scala version to 2.12.x - - dependency-name: "org.scala-lang:scala-compiler" - versions: [">= 2.13.0"] - - dependency-name: "org.scala-lang:scala-reflect" - versions: [">= 2.13.0"] - - dependency-name: "org.scala-lang:scala-library" - versions: [">= 2.13.0"] - # Pin Spark version to 3.0.x - - dependency-name: "org.apache.spark:spark-core_2.12" - versions: [">= 3.1.0"] - - dependency-name: "org.apache.spark:spark-sql_2.12" - versions: [">= 3.1.0"] - - dependency-name: "org.apache.spark:spark-mllib_2.12" - versions: [">= 3.1.0"] - package-ecosystem: "maven" directory: "/jvm-packages/xgboost4j-spark-gpu" schedule: interval: "daily" - ignore: - # Pin Scala version to 2.12.x - - dependency-name: "org.scala-lang:scala-compiler" - versions: [">= 2.13.0"] - - dependency-name: "org.scala-lang:scala-reflect" - versions: [">= 2.13.0"] - - dependency-name: "org.scala-lang:scala-library" - versions: [">= 2.13.0"] - # Pin Spark version to 3.0.x - - dependency-name: "org.apache.spark:spark-core_2.12" - versions: [">= 3.1.0"] - - dependency-name: "org.apache.spark:spark-sql_2.12" - versions: [">= 3.1.0"] - - dependency-name: "org.apache.spark:spark-mllib_2.12" - versions: [">= 3.1.0"] diff --git a/.github/workflows/update_rapids.yml b/.github/workflows/update_rapids.yml new file mode 100644 index 000000000..83dd19b35 --- /dev/null +++ b/.github/workflows/update_rapids.yml @@ -0,0 +1,37 @@ +name: update-rapids + +on: + schedule: + - cron: "0 7 * * *" # Run once daily + +permissions: + contents: read # to fetch code (actions/checkout) + +defaults: + run: + shell: bash -l {0} + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # To use GitHub CLI + +jobs: + update-rapids: + name: Check latest RAPIDS + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + with: + submodules: 'true' + - name: Check latest RAPIDS and update conftest.sh + run: | + bash tests/buildkite/update-rapids.sh + - name: Create Pull Request + uses: peter-evans/create-pull-request@v5 + add-paths: tests/buildkite + branch: create-pull-request/update-rapids + base: master + if: github.ref == 'refs/heads/master' diff --git a/tests/buildkite/conftest.sh b/tests/buildkite/conftest.sh index 957dd443c..108b3d402 100755 --- a/tests/buildkite/conftest.sh +++ b/tests/buildkite/conftest.sh @@ -24,7 +24,7 @@ set -x CUDA_VERSION=11.8.0 NCCL_VERSION=2.16.5-1 -RAPIDS_VERSION=23.02 +RAPIDS_VERSION=23.04 SPARK_VERSION=3.4.0 JDK_VERSION=8 diff --git a/tests/buildkite/update-rapids.sh b/tests/buildkite/update-rapids.sh new file mode 100755 index 000000000..f617ccd11 --- /dev/null +++ b/tests/buildkite/update-rapids.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +set -euo pipefail + +LATEST_RAPIDS_VERSION=$(gh api repos/rapidsai/cuml/releases/latest --jq '.name' | sed -e 's/^v\([[:digit:]]\+\.[[:digit:]]\+\).*/\1/') +echo "LATEST_RAPIDS_VERSION = $LATEST_RAPIDS_VERSION" + +PARENT_PATH=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P ) + +sed -i "s/^RAPIDS_VERSION=[[:digit:]]\+\.[[:digit:]]\+/RAPIDS_VERSION=${LATEST_RAPIDS_VERSION}/" $PARENT_PATH/conftest.sh From a1fad72ab382fead287b4de13e79dc4ae4466893 Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Fri, 2 Jun 2023 08:22:25 -0700 Subject: [PATCH 3/9] Update outdated build badges (#9232) --- README.md | 2 +- jvm-packages/README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 2fae68ac5..92c246dfd 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ eXtreme Gradient Boosting =========== -[![Build Status](https://xgboost-ci.net/job/xgboost/job/master/badge/icon)](https://xgboost-ci.net/blue/organizations/jenkins/xgboost/activity) +[![Build Status](https://badge.buildkite.com/aca47f40a32735c00a8550540c5eeff6a4c1d246a580cae9b0.svg?branch=master)](https://buildkite.com/xgboost/xgboost-ci) [![XGBoost-CI](https://github.com/dmlc/xgboost/workflows/XGBoost-CI/badge.svg?branch=master)](https://github.com/dmlc/xgboost/actions) [![Documentation Status](https://readthedocs.org/projects/xgboost/badge/?version=latest)](https://xgboost.readthedocs.org) [![GitHub license](http://dmlc.github.io/img/apache2.svg)](./LICENSE) diff --git a/jvm-packages/README.md b/jvm-packages/README.md index 239464342..451a0d981 100644 --- a/jvm-packages/README.md +++ b/jvm-packages/README.md @@ -1,5 +1,5 @@ # XGBoost4J: Distributed XGBoost for Scala/Java -[![Build Status](https://travis-ci.org/dmlc/xgboost.svg?branch=master)](https://travis-ci.org/dmlc/xgboost) +[![Build Status](https://badge.buildkite.com/aca47f40a32735c00a8550540c5eeff6a4c1d246a580cae9b0.svg?branch=master)](https://buildkite.com/xgboost/xgboost-ci) [![Documentation Status](https://readthedocs.org/projects/xgboost/badge/?version=latest)](https://xgboost.readthedocs.org/en/latest/jvm/index.html) [![GitHub license](http://dmlc.github.io/img/apache2.svg)](../LICENSE) From 3bf0f145bb129f56ff34a0f849945b545a7d4922 Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Sat, 3 Jun 2023 13:12:12 -0700 Subject: [PATCH 4/9] Update update_rapids.yml --- .github/workflows/update_rapids.yml | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/update_rapids.yml b/.github/workflows/update_rapids.yml index 83dd19b35..022639a1f 100644 --- a/.github/workflows/update_rapids.yml +++ b/.github/workflows/update_rapids.yml @@ -2,7 +2,7 @@ name: update-rapids on: schedule: - - cron: "0 7 * * *" # Run once daily + - cron: "20 20 * * *" # Run once daily permissions: contents: read # to fetch code (actions/checkout) @@ -31,7 +31,10 @@ jobs: bash tests/buildkite/update-rapids.sh - name: Create Pull Request uses: peter-evans/create-pull-request@v5 - add-paths: tests/buildkite - branch: create-pull-request/update-rapids - base: master if: github.ref == 'refs/heads/master' + with: + add-paths: | + tests/buildkite + branch: create-pull-request/update-rapids + base: master + From 962a20693fd5e6266618edeef6c43b35e42fa675 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Sun, 4 Jun 2023 17:05:38 -0700 Subject: [PATCH 5/9] More support for column split in cpu predictor (#9244) - Added column split support to `PredictInstance` and `PredictLeaf`. - Refactoring of tests. --- include/xgboost/predictor.h | 12 +- src/predictor/cpu_predictor.cc | 82 +++++++++++--- src/predictor/gpu_predictor.cu | 2 +- tests/cpp/predictor/test_cpu_predictor.cc | 131 ++++++---------------- 4 files changed, 108 insertions(+), 119 deletions(-) diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index 50665341a..615bc0f39 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -134,16 +134,18 @@ class Predictor { * usually more efficient than online prediction This function is NOT * threadsafe, make sure you only call from one thread. * - * \param inst The instance to predict. - * \param [in,out] out_preds The output preds. - * \param model The model to predict from - * \param tree_end (Optional) The tree end index. + * \param inst The instance to predict. + * \param [in,out] out_preds The output preds. + * \param model The model to predict from + * \param tree_end (Optional) The tree end index. + * \param is_column_split (Optional) If the data is split column-wise. */ virtual void PredictInstance(const SparsePage::Inst& inst, std::vector* out_preds, const gbm::GBTreeModel& model, - unsigned tree_end = 0) const = 0; + unsigned tree_end = 0, + bool is_column_split = false) const = 0; /** * \brief predict the leaf index of each tree, the output will be nsample * diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index aa8972989..96c1fbe18 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -191,6 +191,15 @@ struct SparsePageView { size_t Size() const { return view.Size(); } }; +struct SingleInstanceView { + bst_row_t base_rowid{}; + SparsePage::Inst const &inst; + + explicit SingleInstanceView(SparsePage::Inst const &instance) : inst{instance} {} + SparsePage::Inst operator[](size_t) { return inst; } + static size_t Size() { return 1; } +}; + struct GHistIndexMatrixView { private: GHistIndexMatrix const &page_; @@ -409,6 +418,24 @@ class ColumnSplitHelper { } } + void PredictInstance(SparsePage::Inst const &inst, std::vector *out_preds) { + CHECK(xgboost::collective::IsDistributed()) + << "column-split prediction is only supported for distributed training"; + + PredictBatchKernel(SingleInstanceView{inst}, out_preds); + } + + void PredictLeaf(DMatrix *p_fmat, std::vector *out_preds) { + CHECK(xgboost::collective::IsDistributed()) + << "column-split prediction is only supported for distributed training"; + + for (auto const &batch : p_fmat->GetBatches()) { + CHECK_EQ(out_preds->size(), + p_fmat->Info().num_row_ * model_.learner_model_param->num_output_group); + PredictBatchKernel(SparsePageView{&batch}, out_preds); + } + } + private: using BitVector = RBitField8; @@ -498,24 +525,31 @@ class ColumnSplitHelper { return nid; } + template bst_float PredictOneTree(std::size_t tree_id, std::size_t row_id) { auto const &tree = *model_.trees[tree_id]; auto const leaf = GetLeafIndex(tree, tree_id, row_id); - return tree[leaf].LeafValue(); + if constexpr (predict_leaf) { + return static_cast(leaf); + } else { + return tree[leaf].LeafValue(); + } } + template void PredictAllTrees(std::vector *out_preds, std::size_t batch_offset, std::size_t predict_offset, std::size_t num_group, std::size_t block_size) { auto &preds = *out_preds; for (size_t tree_id = tree_begin_; tree_id < tree_end_; ++tree_id) { auto const gid = model_.tree_info[tree_id]; for (size_t i = 0; i < block_size; ++i) { - preds[(predict_offset + i) * num_group + gid] += PredictOneTree(tree_id, batch_offset + i); + preds[(predict_offset + i) * num_group + gid] += + PredictOneTree(tree_id, batch_offset + i); } } } - template + template void PredictBatchKernel(DataView batch, std::vector *out_preds) { auto const num_group = model_.learner_model_param->num_output_group; @@ -544,8 +578,8 @@ class ColumnSplitHelper { auto const batch_offset = block_id * block_of_rows_size; auto const block_size = std::min(static_cast(nsize - batch_offset), static_cast(block_of_rows_size)); - PredictAllTrees(out_preds, batch_offset, batch_offset + batch.base_rowid, num_group, - block_size); + PredictAllTrees(out_preds, batch_offset, batch_offset + batch.base_rowid, + num_group, block_size); }); ClearBitVectors(); @@ -728,18 +762,25 @@ class CPUPredictor : public Predictor { return true; } - void PredictInstance(const SparsePage::Inst& inst, - std::vector* out_preds, - const gbm::GBTreeModel& model, unsigned ntree_limit) const override { + void PredictInstance(const SparsePage::Inst &inst, std::vector *out_preds, + const gbm::GBTreeModel &model, unsigned ntree_limit, + bool is_column_split) const override { CHECK(!model.learner_model_param->IsVectorLeaf()) << "predict instance" << MTNotImplemented(); - std::vector feat_vecs; - feat_vecs.resize(1, RegTree::FVec()); - feat_vecs[0].Init(model.learner_model_param->num_feature); ntree_limit *= model.learner_model_param->num_output_group; if (ntree_limit == 0 || ntree_limit > model.trees.size()) { ntree_limit = static_cast(model.trees.size()); } out_preds->resize(model.learner_model_param->num_output_group); + + if (is_column_split) { + ColumnSplitHelper helper(this->ctx_->Threads(), model, 0, ntree_limit); + helper.PredictInstance(inst, out_preds); + return; + } + + std::vector feat_vecs; + feat_vecs.resize(1, RegTree::FVec()); + feat_vecs[0].Init(model.learner_model_param->num_feature); auto base_score = model.learner_model_param->BaseScore(ctx_)(0); // loop over output groups for (uint32_t gid = 0; gid < model.learner_model_param->num_output_group; ++gid) { @@ -752,16 +793,23 @@ class CPUPredictor : public Predictor { void PredictLeaf(DMatrix *p_fmat, HostDeviceVector *out_preds, const gbm::GBTreeModel &model, unsigned ntree_limit) const override { auto const n_threads = this->ctx_->Threads(); - std::vector feat_vecs; - const int num_feature = model.learner_model_param->num_feature; - InitThreadTemp(n_threads, &feat_vecs); - const MetaInfo &info = p_fmat->Info(); // number of valid trees if (ntree_limit == 0 || ntree_limit > model.trees.size()) { ntree_limit = static_cast(model.trees.size()); } + const MetaInfo &info = p_fmat->Info(); std::vector &preds = out_preds->HostVector(); preds.resize(info.num_row_ * ntree_limit); + + if (p_fmat->Info().IsColumnSplit()) { + ColumnSplitHelper helper(n_threads, model, 0, ntree_limit); + helper.PredictLeaf(p_fmat, &preds); + return; + } + + std::vector feat_vecs; + const int num_feature = model.learner_model_param->num_feature; + InitThreadTemp(n_threads, &feat_vecs); // start collecting the prediction for (const auto &batch : p_fmat->GetBatches()) { // parallel over local batch @@ -796,6 +844,8 @@ class CPUPredictor : public Predictor { int condition, unsigned condition_feature) const override { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict contribution" << MTNotImplemented(); + CHECK(!p_fmat->Info().IsColumnSplit()) + << "Predict contribution support for column-wise data split is not yet implemented."; auto const n_threads = this->ctx_->Threads(); const int num_feature = model.learner_model_param->num_feature; std::vector feat_vecs; @@ -877,6 +927,8 @@ class CPUPredictor : public Predictor { bool approximate) const override { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict interaction contribution" << MTNotImplemented(); + CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict interaction contribution support for " + "column-wise data split is not yet implemented."; const MetaInfo& info = p_fmat->Info(); const int ngroup = model.learner_model_param->num_output_group; size_t const ncolumns = model.learner_model_param->num_feature; diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 11662f9b8..4b834e78f 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -929,7 +929,7 @@ class GPUPredictor : public xgboost::Predictor { void PredictInstance(const SparsePage::Inst&, std::vector*, - const gbm::GBTreeModel&, unsigned) const override { + const gbm::GBTreeModel&, unsigned, bool) const override { LOG(FATAL) << "[Internal error]: " << __func__ << " is not implemented in GPU Predictor."; } diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index 401d33c4d..279ba6118 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -17,13 +17,15 @@ #include "test_predictor.h" namespace xgboost { -TEST(CpuPredictor, Basic) { + +namespace { +void TestBasic(DMatrix* dmat) { auto lparam = CreateEmptyGenericParam(GPUIDX); std::unique_ptr cpu_predictor = std::unique_ptr(Predictor::Create("cpu_predictor", &lparam)); - size_t constexpr kRows = 5; - size_t constexpr kCols = 5; + size_t const kRows = dmat->Info().num_row_; + size_t const kCols = dmat->Info().num_col_; LearnerModelParam mparam{MakeMP(kCols, .0, 1)}; @@ -31,12 +33,10 @@ TEST(CpuPredictor, Basic) { ctx.UpdateAllowUnknown(Args{}); gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx); - auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); - // Test predict batch PredictionCacheEntry out_predictions; cpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model); - cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); + cpu_predictor->PredictBatch(dmat, &out_predictions, model, 0); std::vector& out_predictions_h = out_predictions.predictions.HostVector(); for (size_t i = 0; i < out_predictions.predictions.Size(); i++) { @@ -44,26 +44,32 @@ TEST(CpuPredictor, Basic) { } // Test predict instance - auto const &batch = *dmat->GetBatches().begin(); + auto const& batch = *dmat->GetBatches().begin(); auto page = batch.GetView(); for (size_t i = 0; i < batch.Size(); i++) { std::vector instance_out_predictions; - cpu_predictor->PredictInstance(page[i], &instance_out_predictions, model); + cpu_predictor->PredictInstance(page[i], &instance_out_predictions, model, 0, + dmat->Info().IsColumnSplit()); ASSERT_EQ(instance_out_predictions[0], 1.5); } // Test predict leaf HostDeviceVector leaf_out_predictions; - cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model); + cpu_predictor->PredictLeaf(dmat, &leaf_out_predictions, model); auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector(); for (auto v : h_leaf_out_predictions) { ASSERT_EQ(v, 0); } + if (dmat->Info().IsColumnSplit()) { + // Predict contribution is not supported for column split. + return; + } + // Test predict contribution HostDeviceVector out_contribution_hdv; auto& out_contribution = out_contribution_hdv.HostVector(); - cpu_predictor->PredictContribution(dmat.get(), &out_contribution_hdv, model); + cpu_predictor->PredictContribution(dmat, &out_contribution_hdv, model); ASSERT_EQ(out_contribution.size(), kRows * (kCols + 1)); for (size_t i = 0; i < out_contribution.size(); ++i) { auto const& contri = out_contribution[i]; @@ -76,8 +82,7 @@ TEST(CpuPredictor, Basic) { } } // Test predict contribution (approximate method) - cpu_predictor->PredictContribution(dmat.get(), &out_contribution_hdv, model, - 0, nullptr, true); + cpu_predictor->PredictContribution(dmat, &out_contribution_hdv, model, 0, nullptr, true); for (size_t i = 0; i < out_contribution.size(); ++i) { auto const& contri = out_contribution[i]; // shift 1 for bias, as test tree is a decision dump, only global bias is @@ -89,41 +94,32 @@ TEST(CpuPredictor, Basic) { } } } +} // anonymous namespace -namespace { -void TestColumnSplitPredictBatch() { +TEST(CpuPredictor, Basic) { size_t constexpr kRows = 5; size_t constexpr kCols = 5; auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); + TestBasic(dmat.get()); +} + +namespace { +void TestColumnSplit() { + size_t constexpr kRows = 5; + size_t constexpr kCols = 5; + auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); + auto const world_size = collective::GetWorldSize(); auto const rank = collective::GetRank(); + dmat = std::unique_ptr{dmat->SliceCol(world_size, rank)}; - auto lparam = CreateEmptyGenericParam(GPUIDX); - std::unique_ptr cpu_predictor = - std::unique_ptr(Predictor::Create("cpu_predictor", &lparam)); - - LearnerModelParam mparam{MakeMP(kCols, .0, 1)}; - - Context ctx; - ctx.UpdateAllowUnknown(Args{}); - gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx); - - // Test predict batch - PredictionCacheEntry out_predictions; - cpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model); - auto sliced = std::unique_ptr{dmat->SliceCol(world_size, rank)}; - cpu_predictor->PredictBatch(sliced.get(), &out_predictions, model, 0); - - std::vector& out_predictions_h = out_predictions.predictions.HostVector(); - for (size_t i = 0; i < out_predictions.predictions.Size(); i++) { - ASSERT_EQ(out_predictions_h[i], 1.5); - } + TestBasic(dmat.get()); } } // anonymous namespace -TEST(CpuPredictor, ColumnSplit) { +TEST(CpuPredictor, ColumnSplitBasic) { auto constexpr kWorldSize = 2; - RunWithInMemoryCommunicator(kWorldSize, TestColumnSplitPredictBatch); + RunWithInMemoryCommunicator(kWorldSize, TestColumnSplit); } TEST(CpuPredictor, IterationRange) { @@ -133,69 +129,8 @@ TEST(CpuPredictor, IterationRange) { TEST(CpuPredictor, ExternalMemory) { size_t constexpr kPageSize = 64, kEntriesPerCol = 3; size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2; - std::unique_ptr dmat = CreateSparsePageDMatrix(kEntries); - auto lparam = CreateEmptyGenericParam(GPUIDX); - - std::unique_ptr cpu_predictor = - std::unique_ptr(Predictor::Create("cpu_predictor", &lparam)); - - LearnerModelParam mparam{MakeMP(dmat->Info().num_col_, .0, 1)}; - - Context ctx; - ctx.UpdateAllowUnknown(Args{}); - gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx); - - // Test predict batch - PredictionCacheEntry out_predictions; - cpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model); - cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); - std::vector &out_predictions_h = out_predictions.predictions.HostVector(); - ASSERT_EQ(out_predictions.predictions.Size(), dmat->Info().num_row_); - for (const auto& v : out_predictions_h) { - ASSERT_EQ(v, 1.5); - } - - // Test predict leaf - HostDeviceVector leaf_out_predictions; - cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model); - auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector(); - ASSERT_EQ(h_leaf_out_predictions.size(), dmat->Info().num_row_); - for (const auto& v : h_leaf_out_predictions) { - ASSERT_EQ(v, 0); - } - - // Test predict contribution - HostDeviceVector out_contribution_hdv; - auto& out_contribution = out_contribution_hdv.HostVector(); - cpu_predictor->PredictContribution(dmat.get(), &out_contribution_hdv, model); - ASSERT_EQ(out_contribution.size(), dmat->Info().num_row_ * (dmat->Info().num_col_ + 1)); - for (size_t i = 0; i < out_contribution.size(); ++i) { - auto const& contri = out_contribution[i]; - // shift 1 for bias, as test tree is a decision dump, only global bias is filled with LeafValue(). - if ((i + 1) % (dmat->Info().num_col_ + 1) == 0) { - ASSERT_EQ(out_contribution.back(), 1.5f); - } else { - ASSERT_EQ(contri, 0); - } - } - - // Test predict contribution (approximate method) - HostDeviceVector out_contribution_approximate_hdv; - auto& out_contribution_approximate = out_contribution_approximate_hdv.HostVector(); - cpu_predictor->PredictContribution( - dmat.get(), &out_contribution_approximate_hdv, model, 0, nullptr, true); - ASSERT_EQ(out_contribution_approximate.size(), - dmat->Info().num_row_ * (dmat->Info().num_col_ + 1)); - for (size_t i = 0; i < out_contribution.size(); ++i) { - auto const& contri = out_contribution[i]; - // shift 1 for bias, as test tree is a decision dump, only global bias is filled with LeafValue(). - if ((i + 1) % (dmat->Info().num_col_ + 1) == 0) { - ASSERT_EQ(out_contribution.back(), 1.5f); - } else { - ASSERT_EQ(contri, 0); - } - } + TestBasic(dmat.get()); } TEST(CpuPredictor, InplacePredict) { From a474a6657322ae0ecbfbd667cd5e67109774a667 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Jun 2023 21:29:59 +0800 Subject: [PATCH 6/9] Bump maven-release-plugin from 3.0.0 to 3.0.1 in /jvm-packages (#9252) Bumps [maven-release-plugin](https://github.com/apache/maven-release) from 3.0.0 to 3.0.1. - [Release notes](https://github.com/apache/maven-release/releases) - [Commits](https://github.com/apache/maven-release/compare/maven-release-3.0.0...maven-release-3.0.1) --- updated-dependencies: - dependency-name: org.apache.maven.plugins:maven-release-plugin dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- jvm-packages/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index f7e90162d..1e72f7b57 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -130,7 +130,7 @@ org.apache.maven.plugins maven-release-plugin - 3.0.0 + 3.0.1 true false From 7f9cb921f4d1cece63f80876f5aece9c46681db6 Mon Sep 17 00:00:00 2001 From: Boris Date: Mon, 5 Jun 2023 22:52:10 +0200 Subject: [PATCH 7/9] Rearranged maven profiles so that scala-2.13 artifacts are published without gpu-related libraries (#9253) --- jvm-packages/pom.xml | 8 -------- tests/ci_build/deploy_jvm_packages.sh | 11 ++++++++++- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index 1e72f7b57..586661025 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -301,14 +301,6 @@ https://s3.amazonaws.com/xgboost-maven-repo/release - - xgboost4j - xgboost4j-example - xgboost4j-spark - xgboost4j-flink - xgboost4j-gpu - xgboost4j-spark-gpu - diff --git a/tests/ci_build/deploy_jvm_packages.sh b/tests/ci_build/deploy_jvm_packages.sh index de875b14e..5f448ee2a 100755 --- a/tests/ci_build/deploy_jvm_packages.sh +++ b/tests/ci_build/deploy_jvm_packages.sh @@ -18,8 +18,17 @@ rm -rf $(find . -name target) rm -rf ../build/ # Re-build package without Mock Rabit +# Maven profiles: +# `default` includes modules: xgboost4j, xgboost4j-spark, xgboost4j-flink, xgboost4j-example +# `gpu` includes modules: xgboost4j-gpu, xgboost4j-spark-gpu, sets `use.cuda = ON` +# `scala-2.13` sets the scala binary version to the 2.13 +# `release-to-s3` sets maven deployment targets + # Deploy to S3 bucket xgboost-maven-repo -mvn --no-transfer-progress package deploy -Duse.cuda=ON -P release-to-s3 -Dspark.version=${spark_version} -DskipTests +mvn --no-transfer-progress package deploy -P default,gpu,release-to-s3 -Dspark.version=${spark_version} -DskipTests +# Deploy scala 2.13 to S3 bucket xgboost-maven-repo +mvn --no-transfer-progress package deploy -P release-to-s3,default,scala-2.13 -Dspark.version=${spark_version} -DskipTests + set +x set +e From fc8110ef79e58bc8648c1934cdf303ce1bcc009a Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 6 Jun 2023 08:20:10 +0800 Subject: [PATCH 8/9] Remove document and demo in RABIT. (#9246) --- rabit/.gitignore | 52 ---- rabit/doc/.gitignore | 5 - rabit/doc/Doxyfile | 281 ---------------------- rabit/doc/Makefile | 192 --------------- rabit/doc/conf.py | 184 -------------- rabit/doc/cpp_api.md | 9 - rabit/doc/guide.md | 383 ------------------------------ rabit/doc/index.md | 24 -- rabit/doc/parameters.md | 21 -- rabit/doc/python-requirements.txt | 3 - rabit/doc/python_api.md | 11 - rabit/doc/sphinx_util.py | 16 -- rabit/guide/Makefile | 26 -- rabit/guide/README | 1 - rabit/guide/basic.cc | 35 --- rabit/guide/basic.py | 27 --- rabit/guide/broadcast.cc | 16 -- rabit/guide/broadcast.py | 23 -- rabit/guide/lazy_allreduce.cc | 34 --- rabit/guide/lazy_allreduce.py | 31 --- 20 files changed, 1374 deletions(-) delete mode 100644 rabit/.gitignore delete mode 100644 rabit/doc/.gitignore delete mode 100644 rabit/doc/Doxyfile delete mode 100644 rabit/doc/Makefile delete mode 100644 rabit/doc/conf.py delete mode 100644 rabit/doc/cpp_api.md delete mode 100644 rabit/doc/guide.md delete mode 100644 rabit/doc/index.md delete mode 100644 rabit/doc/parameters.md delete mode 100644 rabit/doc/python-requirements.txt delete mode 100644 rabit/doc/python_api.md delete mode 100644 rabit/doc/sphinx_util.py delete mode 100644 rabit/guide/Makefile delete mode 100644 rabit/guide/README delete mode 100644 rabit/guide/basic.cc delete mode 100755 rabit/guide/basic.py delete mode 100644 rabit/guide/broadcast.cc delete mode 100755 rabit/guide/broadcast.py delete mode 100644 rabit/guide/lazy_allreduce.cc delete mode 100755 rabit/guide/lazy_allreduce.py diff --git a/rabit/.gitignore b/rabit/.gitignore deleted file mode 100644 index ad9fedf10..000000000 --- a/rabit/.gitignore +++ /dev/null @@ -1,52 +0,0 @@ -# Compiled Object files -*.slo -*.lo -*.o -*.obj - -# Precompiled Headers -*.gch -*.pch -*.lnk -# Compiled Dynamic libraries -*.so -*.dylib -*.dll - -# Fortran module files -*.mod - -# Compiled Static libraries -*.lai -*.la -*.a -*.lib - -# Executables -*.miss -*.exe -*.out -*.app -*~ -*.pyc -*.mpi -*.exe -*tmp* -*.rabit -*.mock -recommonmark -recom -_* - -#mpi lib -mpich/ -mpich-3.2/ - -# Jetbrain -.idea -cmake-build-debug/ -.vscode/ - -# cmake -build/ -compile_commands.json \ No newline at end of file diff --git a/rabit/doc/.gitignore b/rabit/doc/.gitignore deleted file mode 100644 index 95f88be43..000000000 --- a/rabit/doc/.gitignore +++ /dev/null @@ -1,5 +0,0 @@ -html -latex -*.sh -_* -doxygen diff --git a/rabit/doc/Doxyfile b/rabit/doc/Doxyfile deleted file mode 100644 index 3e64641f3..000000000 --- a/rabit/doc/Doxyfile +++ /dev/null @@ -1,281 +0,0 @@ -# Doxyfile 1.7.6.1 - -#--------------------------------------------------------------------------- -# Project related configuration options -#--------------------------------------------------------------------------- -DOXYFILE_ENCODING = UTF-8 -PROJECT_NAME = "rabit" -PROJECT_NUMBER = -PROJECT_BRIEF = -PROJECT_LOGO = -OUTPUT_DIRECTORY = ../doc/doxygen -CREATE_SUBDIRS = NO -OUTPUT_LANGUAGE = English -BRIEF_MEMBER_DESC = YES -REPEAT_BRIEF = YES -ABBREVIATE_BRIEF = -ALWAYS_DETAILED_SEC = NO -INLINE_INHERITED_MEMB = NO -FULL_PATH_NAMES = YES -STRIP_FROM_PATH = -STRIP_FROM_INC_PATH = -SHORT_NAMES = NO -JAVADOC_AUTOBRIEF = NO -QT_AUTOBRIEF = NO -MULTILINE_CPP_IS_BRIEF = NO -INHERIT_DOCS = YES -SEPARATE_MEMBER_PAGES = NO -TAB_SIZE = 8 -ALIASES = -TCL_SUBST = -OPTIMIZE_OUTPUT_FOR_C = YES -OPTIMIZE_OUTPUT_JAVA = NO -OPTIMIZE_FOR_FORTRAN = NO -OPTIMIZE_OUTPUT_VHDL = NO -EXTENSION_MAPPING = -BUILTIN_STL_SUPPORT = NO -CPP_CLI_SUPPORT = NO -SIP_SUPPORT = NO -IDL_PROPERTY_SUPPORT = YES -DISTRIBUTE_GROUP_DOC = NO -SUBGROUPING = YES -INLINE_GROUPED_CLASSES = NO -INLINE_SIMPLE_STRUCTS = NO -TYPEDEF_HIDES_STRUCT = NO -LOOKUP_CACHE_SIZE = 0 -#--------------------------------------------------------------------------- -# Build related configuration options -#--------------------------------------------------------------------------- -EXTRACT_ALL = NO -EXTRACT_PRIVATE = NO -EXTRACT_STATIC = NO -EXTRACT_LOCAL_CLASSES = YES -EXTRACT_LOCAL_METHODS = NO -EXTRACT_ANON_NSPACES = NO -HIDE_UNDOC_MEMBERS = NO -HIDE_UNDOC_CLASSES = YES -HIDE_FRIEND_COMPOUNDS = NO -HIDE_IN_BODY_DOCS = NO -INTERNAL_DOCS = NO -CASE_SENSE_NAMES = YES -HIDE_SCOPE_NAMES = NO -SHOW_INCLUDE_FILES = YES -FORCE_LOCAL_INCLUDES = NO -INLINE_INFO = YES -SORT_MEMBER_DOCS = YES -SORT_BRIEF_DOCS = NO -SORT_MEMBERS_CTORS_1ST = NO -SORT_GROUP_NAMES = NO -SORT_BY_SCOPE_NAME = NO -STRICT_PROTO_MATCHING = NO -GENERATE_TODOLIST = YES -GENERATE_TESTLIST = YES -GENERATE_BUGLIST = YES -GENERATE_DEPRECATEDLIST= YES -ENABLED_SECTIONS = -MAX_INITIALIZER_LINES = 30 -SHOW_USED_FILES = YES -SHOW_FILES = YES -SHOW_NAMESPACES = YES -FILE_VERSION_FILTER = -LAYOUT_FILE = -CITE_BIB_FILES = -#--------------------------------------------------------------------------- -# configuration options related to warning and progress messages -#--------------------------------------------------------------------------- -QUIET = NO -WARNINGS = YES -WARN_IF_UNDOCUMENTED = YES -WARN_IF_DOC_ERROR = YES -WARN_NO_PARAMDOC = YES -WARN_FORMAT = "$file:$line: $text" -WARN_LOGFILE = -#--------------------------------------------------------------------------- -# configuration options related to the input files -#--------------------------------------------------------------------------- -INPUT = rabit -INPUT_ENCODING = UTF-8 -FILE_PATTERNS = -RECURSIVE = NO -EXCLUDE = -EXCLUDE_SYMLINKS = NO -EXCLUDE_PATTERNS = *-inl.hpp -EXCLUDE_SYMBOLS = -EXAMPLE_PATH = -EXAMPLE_PATTERNS = -EXAMPLE_RECURSIVE = NO -IMAGE_PATH = -INPUT_FILTER = -FILTER_PATTERNS = -FILTER_SOURCE_FILES = NO -FILTER_SOURCE_PATTERNS = -#--------------------------------------------------------------------------- -# configuration options related to source browsing -#--------------------------------------------------------------------------- -SOURCE_BROWSER = NO -INLINE_SOURCES = NO -STRIP_CODE_COMMENTS = YES -REFERENCED_BY_RELATION = NO -REFERENCES_RELATION = NO -REFERENCES_LINK_SOURCE = YES -USE_HTAGS = NO -VERBATIM_HEADERS = YES -#--------------------------------------------------------------------------- -# configuration options related to the alphabetical class index -#--------------------------------------------------------------------------- -ALPHABETICAL_INDEX = YES -COLS_IN_ALPHA_INDEX = 5 -IGNORE_PREFIX = -#--------------------------------------------------------------------------- -# configuration options related to the HTML output -#--------------------------------------------------------------------------- -GENERATE_HTML = YES -HTML_OUTPUT = html -HTML_FILE_EXTENSION = .html -HTML_HEADER = -HTML_FOOTER = -HTML_STYLESHEET = -HTML_EXTRA_FILES = -HTML_COLORSTYLE_HUE = 220 -HTML_COLORSTYLE_SAT = 100 -HTML_COLORSTYLE_GAMMA = 80 -HTML_TIMESTAMP = YES -HTML_DYNAMIC_SECTIONS = NO -GENERATE_DOCSET = NO -DOCSET_FEEDNAME = "Doxygen generated docs" -DOCSET_BUNDLE_ID = org.doxygen.Project -DOCSET_PUBLISHER_ID = org.doxygen.Publisher -DOCSET_PUBLISHER_NAME = Publisher -GENERATE_HTMLHELP = NO -CHM_FILE = -HHC_LOCATION = -GENERATE_CHI = NO -CHM_INDEX_ENCODING = -BINARY_TOC = NO -TOC_EXPAND = NO -GENERATE_QHP = NO -QCH_FILE = -QHP_NAMESPACE = org.doxygen.Project -QHP_VIRTUAL_FOLDER = doc -QHP_CUST_FILTER_NAME = -QHP_CUST_FILTER_ATTRS = -QHP_SECT_FILTER_ATTRS = -QHG_LOCATION = -GENERATE_ECLIPSEHELP = NO -ECLIPSE_DOC_ID = org.doxygen.Project -DISABLE_INDEX = NO -GENERATE_TREEVIEW = NO -ENUM_VALUES_PER_LINE = 4 -TREEVIEW_WIDTH = 250 -EXT_LINKS_IN_WINDOW = NO -FORMULA_FONTSIZE = 10 -FORMULA_TRANSPARENT = YES -USE_MATHJAX = NO -MATHJAX_RELPATH = http://www.mathjax.org/mathjax -MATHJAX_EXTENSIONS = -SEARCHENGINE = YES -SERVER_BASED_SEARCH = NO -#--------------------------------------------------------------------------- -# configuration options related to the LaTeX output -#--------------------------------------------------------------------------- -GENERATE_LATEX = YES -LATEX_OUTPUT = latex -LATEX_CMD_NAME = latex -MAKEINDEX_CMD_NAME = makeindex -COMPACT_LATEX = NO -PAPER_TYPE = a4 -EXTRA_PACKAGES = -LATEX_HEADER = -LATEX_FOOTER = -PDF_HYPERLINKS = YES -USE_PDFLATEX = YES -LATEX_BATCHMODE = NO -LATEX_HIDE_INDICES = NO -LATEX_SOURCE_CODE = NO -LATEX_BIB_STYLE = plain -#--------------------------------------------------------------------------- -# configuration options related to the RTF output -#--------------------------------------------------------------------------- -GENERATE_RTF = NO -RTF_OUTPUT = rtf -COMPACT_RTF = NO -RTF_HYPERLINKS = NO -RTF_STYLESHEET_FILE = -RTF_EXTENSIONS_FILE = -#--------------------------------------------------------------------------- -# configuration options related to the man page output -#--------------------------------------------------------------------------- -GENERATE_MAN = NO -MAN_OUTPUT = man -MAN_EXTENSION = .3 -MAN_LINKS = NO -#--------------------------------------------------------------------------- -# configuration options related to the XML output -#--------------------------------------------------------------------------- -GENERATE_XML = YES -XML_OUTPUT = xml -XML_PROGRAMLISTING = YES -#--------------------------------------------------------------------------- -# configuration options for the AutoGen Definitions output -#--------------------------------------------------------------------------- -GENERATE_AUTOGEN_DEF = NO -#--------------------------------------------------------------------------- -# configuration options related to the Perl module output -#--------------------------------------------------------------------------- -GENERATE_PERLMOD = NO -PERLMOD_LATEX = NO -PERLMOD_PRETTY = YES -PERLMOD_MAKEVAR_PREFIX = -#--------------------------------------------------------------------------- -# Configuration options related to the preprocessor -#--------------------------------------------------------------------------- -ENABLE_PREPROCESSING = NO -MACRO_EXPANSION = NO -EXPAND_ONLY_PREDEF = NO -SEARCH_INCLUDES = YES -INCLUDE_PATH = -INCLUDE_FILE_PATTERNS = -PREDEFINED = -EXPAND_AS_DEFINED = -SKIP_FUNCTION_MACROS = YES -#--------------------------------------------------------------------------- -# Configuration::additions related to external references -#--------------------------------------------------------------------------- -TAGFILES = -GENERATE_TAGFILE = -ALLEXTERNALS = NO -EXTERNAL_GROUPS = YES -PERL_PATH = /usr/bin/perl -#--------------------------------------------------------------------------- -# Configuration options related to the dot tool -#--------------------------------------------------------------------------- -CLASS_DIAGRAMS = YES -MSCGEN_PATH = -HIDE_UNDOC_RELATIONS = YES -HAVE_DOT = NO -DOT_NUM_THREADS = 0 -DOT_FONTNAME = Helvetica -DOT_FONTSIZE = 10 -DOT_FONTPATH = -CLASS_GRAPH = YES -COLLABORATION_GRAPH = YES -GROUP_GRAPHS = YES -UML_LOOK = NO -TEMPLATE_RELATIONS = NO -INCLUDE_GRAPH = YES -INCLUDED_BY_GRAPH = YES -CALL_GRAPH = NO -CALLER_GRAPH = NO -GRAPHICAL_HIERARCHY = YES -DIRECTORY_GRAPH = YES -DOT_IMAGE_FORMAT = png -INTERACTIVE_SVG = NO -DOT_PATH = -DOTFILE_DIRS = -MSCFILE_DIRS = -DOT_GRAPH_MAX_NODES = 50 -MAX_DOT_GRAPH_DEPTH = 0 -DOT_TRANSPARENT = NO -DOT_MULTI_TARGETS = YES -GENERATE_LEGEND = YES -DOT_CLEANUP = YES diff --git a/rabit/doc/Makefile b/rabit/doc/Makefile deleted file mode 100644 index 40bba2a28..000000000 --- a/rabit/doc/Makefile +++ /dev/null @@ -1,192 +0,0 @@ -# Makefile for Sphinx documentation -# - -# You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build -PAPER = -BUILDDIR = _build - -# User-friendly check for sphinx-build -ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) -$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) -endif - -# Internal variables. -PAPEROPT_a4 = -D latex_paper_size=a4 -PAPEROPT_letter = -D latex_paper_size=letter -ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . -# the i18n builder cannot share the environment and doctrees with the others -I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . - -.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest coverage gettext - -help: - @echo "Please use \`make ' where is one of" - @echo " html to make standalone HTML files" - @echo " dirhtml to make HTML files named index.html in directories" - @echo " singlehtml to make a single large HTML file" - @echo " pickle to make pickle files" - @echo " json to make JSON files" - @echo " htmlhelp to make HTML files and a HTML help project" - @echo " qthelp to make HTML files and a qthelp project" - @echo " applehelp to make an Apple Help Book" - @echo " devhelp to make HTML files and a Devhelp project" - @echo " epub to make an epub" - @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" - @echo " latexpdf to make LaTeX files and run them through pdflatex" - @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" - @echo " text to make text files" - @echo " man to make manual pages" - @echo " texinfo to make Texinfo files" - @echo " info to make Texinfo files and run them through makeinfo" - @echo " gettext to make PO message catalogs" - @echo " changes to make an overview of all changed/added/deprecated items" - @echo " xml to make Docutils-native XML files" - @echo " pseudoxml to make pseudoxml-XML files for display purposes" - @echo " linkcheck to check all external links for integrity" - @echo " doctest to run all doctests embedded in the documentation (if enabled)" - @echo " coverage to run coverage check of the documentation (if enabled)" - -clean: - rm -rf $(BUILDDIR)/* - -html: - $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html - @echo - @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." - -dirhtml: - $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml - @echo - @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." - -singlehtml: - $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml - @echo - @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." - -pickle: - $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle - @echo - @echo "Build finished; now you can process the pickle files." - -json: - $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json - @echo - @echo "Build finished; now you can process the JSON files." - -htmlhelp: - $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp - @echo - @echo "Build finished; now you can run HTML Help Workshop with the" \ - ".hhp project file in $(BUILDDIR)/htmlhelp." - -qthelp: - $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp - @echo - @echo "Build finished; now you can run "qcollectiongenerator" with the" \ - ".qhcp project file in $(BUILDDIR)/qthelp, like this:" - @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/rabit.qhcp" - @echo "To view the help file:" - @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/rabit.qhc" - -applehelp: - $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp - @echo - @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." - @echo "N.B. You won't be able to view it unless you put it in" \ - "~/Library/Documentation/Help or install it in your application" \ - "bundle." - -devhelp: - $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp - @echo - @echo "Build finished." - @echo "To view the help file:" - @echo "# mkdir -p $$HOME/.local/share/devhelp/rabit" - @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/rabit" - @echo "# devhelp" - -epub: - $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub - @echo - @echo "Build finished. The epub file is in $(BUILDDIR)/epub." - -latex: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo - @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." - @echo "Run \`make' in that directory to run these through (pdf)latex" \ - "(use \`make latexpdf' here to do that automatically)." - -latexpdf: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo "Running LaTeX files through pdflatex..." - $(MAKE) -C $(BUILDDIR)/latex all-pdf - @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." - -latexpdfja: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo "Running LaTeX files through platex and dvipdfmx..." - $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja - @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." - -text: - $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text - @echo - @echo "Build finished. The text files are in $(BUILDDIR)/text." - -man: - $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man - @echo - @echo "Build finished. The manual pages are in $(BUILDDIR)/man." - -texinfo: - $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo - @echo - @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." - @echo "Run \`make' in that directory to run these through makeinfo" \ - "(use \`make info' here to do that automatically)." - -info: - $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo - @echo "Running Texinfo files through makeinfo..." - make -C $(BUILDDIR)/texinfo info - @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." - -gettext: - $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale - @echo - @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." - -changes: - $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes - @echo - @echo "The overview file is in $(BUILDDIR)/changes." - -linkcheck: - $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck - @echo - @echo "Link check complete; look for any errors in the above output " \ - "or in $(BUILDDIR)/linkcheck/output.txt." - -doctest: - $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest - @echo "Testing of doctests in the sources finished, look at the " \ - "results in $(BUILDDIR)/doctest/output.txt." - -coverage: - $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage - @echo "Testing of coverage in the sources finished, look at the " \ - "results in $(BUILDDIR)/coverage/python.txt." - -xml: - $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml - @echo - @echo "Build finished. The XML files are in $(BUILDDIR)/xml." - -pseudoxml: - $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml - @echo - @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." diff --git a/rabit/doc/conf.py b/rabit/doc/conf.py deleted file mode 100644 index ef89de489..000000000 --- a/rabit/doc/conf.py +++ /dev/null @@ -1,184 +0,0 @@ -# -*- coding: utf-8 -*- -# -# documentation build configuration file, created by -# sphinx-quickstart on Thu Jul 23 19:40:08 2015. -# -# This file is execfile()d with the current directory set to its -# containing dir. -# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. -import sys -import os, subprocess -import shlex -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) -libpath = os.path.join(curr_path, '../wrapper/') -sys.path.insert(0, os.path.join(curr_path, '../wrapper/')) -sys.path.insert(0, curr_path) -from sphinx_util import MarkdownParser, AutoStructify - -# -- General configuration ------------------------------------------------ - -# General information about the project. -project = u'rabit' -copyright = u'2015, rabit developers' -author = u'rabit developers' -github_doc_root = 'https://github.com/dmlc/rabit/tree/master/doc/' - -# add markdown parser -MarkdownParser.github_doc_root = github_doc_root -source_parsers = { - '.md': MarkdownParser, -} -# Version information. -import rabit - -version = rabit.__version__ -release = rabit.__version__ - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones -extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.napoleon', - 'sphinx.ext.mathjax', - 'breathe', -] - -# Use breathe to include doxygen documents -breathe_projects = {'rabit' : 'doxygen/xml/'} -breathe_default_project = 'rabit' - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] - -# The suffix(es) of source filenames. -# You can specify multiple suffix as a list of string: -# source_suffix = ['.rst', '.md'] -source_suffix = ['.rst', '.md'] - -# The encoding of source files. -#source_encoding = 'utf-8-sig' - -# The master toctree document. -master_doc = 'index' - -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -# -# This is also used if you do content translation via gettext catalogs. -# Usually you set "language" from the command line for these cases. -language = None - -# There are two options for replacing |today|: either, you set today to some -# non-false value, then it is used: -#today = '' -# Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -exclude_patterns = ['_build'] - -# The reST default role (used for this markup: `text`) to use for all -# documents. -#default_role = None - -# If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True - -# If true, the current module name will be prepended to all description -# unit titles (such as .. function::). -#add_module_names = True - -# If true, sectionauthor and moduleauthor directives will be shown in the -# output. They are ignored by default. -#show_authors = False - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' - -# A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] - -# If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False - -# If true, `todo` and `todoList` produce output, else they produce nothing. -todo_include_todos = False - -# -- Options for HTML output ---------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# html_theme = 'alabaster' - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] - -# Output file base name for HTML help builder. -htmlhelp_basename = project + 'doc' - -# -- Options for LaTeX output --------------------------------------------- -latex_elements = { -} - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, -# author, documentclass [howto, manual, or own class]). -latex_documents = [ - (master_doc, 'rabit.tex', project, - author, 'manual'), -] - -# hook for doxygen -def run_doxygen(folder): - """Run the doxygen make command in the designated folder.""" - try: - retcode = subprocess.call("cd %s; make doxygen" % folder, shell=True) - if retcode < 0: - sys.stderr.write("doxygen terminated by signal %s" % (-retcode)) - except OSError as e: - sys.stderr.write("doxygen execution failed: %s" % e) - - -def run_build_lib(folder): - """Run the doxygen make command in the designated folder.""" - try: - retcode = subprocess.call("cd %s; make" % folder, shell=True) - retcode = subprocess.call("rm -rf _build/html/doxygen", shell=True) - retcode = subprocess.call("mkdir _build", shell=True) - retcode = subprocess.call("mkdir _build/html", shell=True) - retcode = subprocess.call("cp -rf doxygen/html _build/html/doxygen", shell=True) - if retcode < 0: - sys.stderr.write("build terminated by signal %s" % (-retcode)) - except OSError as e: - sys.stderr.write("build execution failed: %s" % e) - - -def generate_doxygen_xml(app): - """Run the doxygen make commands if we're on the ReadTheDocs server""" - read_the_docs_build = os.environ.get('READTHEDOCS', None) == 'True' - if read_the_docs_build: - run_doxygen('..') - sys.stderr.write('Check if shared lib exists\n') - run_build_lib('..') - sys.stderr.write('The wrapper path: %s\n' % str(os.listdir('../wrapper'))) - rabit._loadlib() - - -def setup(app): - # Add hook for building doxygen xml when needed - app.connect("builder-inited", generate_doxygen_xml) - app.add_config_value('recommonmark_config', { - 'url_resolver': lambda url: github_doc_root + url, - }, True) - app.add_transform(AutoStructify) diff --git a/rabit/doc/cpp_api.md b/rabit/doc/cpp_api.md deleted file mode 100644 index c6184aa08..000000000 --- a/rabit/doc/cpp_api.md +++ /dev/null @@ -1,9 +0,0 @@ -C++ Library API of Rabit -======================== -This page contains document of Library API of rabit. - -```eval_rst -.. toctree:: - -.. doxygennamespace:: rabit -``` diff --git a/rabit/doc/guide.md b/rabit/doc/guide.md deleted file mode 100644 index 7bf50b09d..000000000 --- a/rabit/doc/guide.md +++ /dev/null @@ -1,383 +0,0 @@ -Tutorial -======== -This is rabit's tutorial, a ***Reliable Allreduce and Broadcast Interface***. -All the example codes are in the [guide](https://github.com/dmlc/rabit/blob/master/guide/) folder of the project. -To run the examples locally, you will need to build them with ```make```. - -**List of Topics** -* [What is Allreduce](#what-is-allreduce) -* [Common Use Case](#common-use-case) -* [Use Rabit API](#use-rabit-api) - - [Structure of a Rabit Program](#structure-of-a-rabit-program) - - [Allreduce and Lazy Preparation](#allreduce-and-lazy-preparation) - - [Checkpoint and LazyCheckpoint](#checkpoint-and-lazycheckpoint) -* [Compile Programs with Rabit](#compile-programs-with-rabit) -* [Running Rabit Jobs](#running-rabit-jobs) -* [Fault Tolerance](#fault-tolerance) - -What is Allreduce ------------------ -The main methods provided by rabit are Allreduce and Broadcast. Allreduce performs reduction across different computation nodes, -and returns the result to every node. To understand the behavior of the function, consider the following example in [basic.cc](../guide/basic.cc) (there is a python example right after this if you are more familiar with python). -```c++ -#include -using namespace rabit; -const int N = 3; -int main(int argc, char *argv[]) { - int a[N]; - rabit::Init(argc, argv); - for (int i = 0; i < N; ++i) { - a[i] = rabit::GetRank() + i; - } - printf("@node[%d] before-allreduce: a={%d, %d, %d}\n", - rabit::GetRank(), a[0], a[1], a[2]); - // allreduce take max of each elements in all processes - Allreduce(&a[0], N); - printf("@node[%d] after-allreduce-max: a={%d, %d, %d}\n", - rabit::GetRank(), a[0], a[1], a[2]); - // second allreduce that sums everything up - Allreduce(&a[0], N); - printf("@node[%d] after-allreduce-sum: a={%d, %d, %d}\n", - rabit::GetRank(), a[0], a[1], a[2]); - rabit::Finalize(); - return 0; -} -``` -You can run the example using the rabit_demo.py script. The following command -starts the rabit program with two worker processes. -```bash -../tracker/rabit_demo.py -n 2 basic.rabit -``` -This will start two processes, one process with rank 0 and the other with rank 1, both processes run the same code. -The ```rabit::GetRank()``` function returns the rank of current process. - -Before the call to Allreduce, process 0 contains the array ```a = {0, 1, 2}```, while process 1 has the array -```a = {1, 2, 3}```. After the call to Allreduce, the array contents in all processes are replaced by the -reduction result (in this case, the maximum value in each position across all the processes). So, after the -Allreduce call, the result will become ```a = {1, 2, 3}```. -Rabit provides different reduction operators, for example, if you change ```op::Max``` to ```op::Sum```, -the reduction operation will be a summation, and the result will become ```a = {1, 3, 5}```. -You can also run the example with different processes by setting -n to different values. - -If you are more familiar with python, you can also use rabit in python. The same example as before can be found in [basic.py](../guide/basic.py): - -```python -import numpy as np -import rabit - -rabit.init() -n = 3 -rank = rabit.get_rank() -a = np.zeros(n) -for i in xrange(n): - a[i] = rank + i - -print '@node[%d] before-allreduce: a=%s' % (rank, str(a)) -a = rabit.allreduce(a, rabit.MAX) -print '@node[%d] after-allreduce-max: a=%s' % (rank, str(a)) -a = rabit.allreduce(a, rabit.SUM) -print '@node[%d] after-allreduce-sum: a=%s' % (rank, str(a)) -rabit.finalize() -``` -You can run the program using the following command -```bash -../tracker/rabit_demo.py -n 2 basic.py -``` - -Broadcast is another method provided by rabit besides Allreduce. This function allows one node to broadcast its -local data to all other nodes. The following code in [broadcast.cc](../guide/broadcast.cc) broadcasts a string from -node 0 to all other nodes. -```c++ -#include -using namespace rabit; -const int N = 3; -int main(int argc, char *argv[]) { - rabit::Init(argc, argv); - std::string s; - if (rabit::GetRank() == 0) s = "hello world"; - printf("@node[%d] before-broadcast: s=\"%s\"\n", - rabit::GetRank(), s.c_str()); - // broadcast s from node 0 to all other nodes - rabit::Broadcast(&s, 0); - printf("@node[%d] after-broadcast: s=\"%s\"\n", - rabit::GetRank(), s.c_str()); - rabit::Finalize(); - return 0; -} -``` -The following command starts the program with three worker processes. -```bash -../tracker/rabit_demo.py -n 3 broadcast.rabit -``` -Besides strings, rabit also allows to broadcast constant size array and vectors. - -The counterpart in python can be found in [broadcast.py](../guide/broadcast.py). Here is a snippet so that you can get a better sense of how simple is to use the python library: - -```python -import rabit -rabit.init() -n = 3 -rank = rabit.get_rank() -s = None -if rank == 0: - s = {'hello world':100, 2:3} -print '@node[%d] before-broadcast: s=\"%s\"' % (rank, str(s)) -s = rabit.broadcast(s, 0) -print '@node[%d] after-broadcast: s=\"%s\"' % (rank, str(s)) -rabit.finalize() -``` - -Common Use Case ---------------- -Many distributed machine learning algorithms involve splitting the data into different nodes, -computing statistics locally, and finally aggregating them. Such workflow is usually done repetitively through many iterations before the algorithm converges. Allreduce naturally meets the structure of such programs, -common use cases include: - -* Aggregation of gradient values, which can be used in optimization methods such as L-BFGS. -* Aggregation of other statistics, which can be used in KMeans and Gaussian Mixture Models. -* Find the best split candidate and aggregation of split statistics, used for tree based models. - -Rabit is a reliable and portable library for distributed machine learning programs, that allow programs to run reliably on different platforms. - -Use Rabit API -------------- -This section introduces topics about how to use rabit API. -You can always refer to [API Documentation](http://homes.cs.washington.edu/~tqchen/rabit/doc) for definition of each functions. -This section trys to gives examples of different aspectes of rabit API. - -#### Structure of a Rabit Program -The following code illustrates the common structure of a rabit program. This is an abstract example, -you can also refer to [wormhole](https://github.com/dmlc/wormhole/blob/master/learn/kmeans/kmeans.cc) for an example implementation of kmeans algorithm. - -```c++ -#include -int main(int argc, char *argv[]) { - ... - rabit::Init(argc, argv); - // sync on expected model size before load checkpoint, if we pass rabit_bootstrap_cache=true - rabit::Allreduce(&model.size(), 1); - // load the latest checked model - int version = rabit::LoadCheckPoint(&model); - // initialize the model if it is the first version - if (version == 0) model.InitModel(); - // the version number marks the iteration to resume - for (int iter = version; iter < max_iter; ++iter) { - // at this point, the model object should allow us to recover the program state - ... - // each iteration can contain multiple calls of allreduce/broadcast - rabit::Allreduce(&data[0], n); - ... - // checkpoint model after one iteration finishes - rabit::CheckPoint(&model); - } - rabit::Finalize(); - return 0; -} -``` - -Besides the common Allreduce and Broadcast functions, there are two additional functions: ```LoadCheckPoint``` -and ```CheckPoint```. These two functions are used for fault-tolerance purposes. -As mentioned before, traditional machine learning programs involve several iterations. In each iteration, we start with a model, make some calls -to Allreduce or Broadcast and update the model. The calling sequence in each iteration does not need to be the same. - -* When the nodes start from the beginning (i.e. iteration 0), ```LoadCheckPoint``` returns 0, so we can initialize the model. -* ```CheckPoint``` saves the model after each iteration. - - Efficiency Note: the model is only kept in local memory and no save to disk is performed when calling Checkpoint -* When a node goes down and restarts, ```LoadCheckPoint``` will recover the latest saved model, and -* When a node goes down, the rest of the nodes will block in the call of Allreduce/Broadcast and wait for - the recovery of the failed node until it catches up. - -Please see the [Fault Tolerance](#fault-tolerance) section to understand the recovery procedure executed by rabit. - -#### Allreduce and Lazy Preparation -Allreduce is one of the most important function provided by rabit. You can call allreduce by specifying the -reduction operator, pointer to the data and size of the buffer, as follows -```c++ -Allreduce(pointer_of_data, size_of_data); -``` -This is the basic use case of Allreduce function. It is common that user writes the code to prepare the data needed -into the data buffer, pass the data to Allreduce function, and get the reduced result. However, when a node restarts -from failure, we can directly recover the result from other nodes(see also [Fault Tolerance](#fault-tolerance)) and -the data preparation procedure no longer necessary. Rabit Allreduce add an optional parameter preparation function -to support such scenario. User can pass in a function that corresponds to the data preparation procedure to Allreduce -calls, and the data preparation function will only be called when necessary. We use [lazy_allreduce.cc](../guide/lazy_allreduce.cc) -as an example to demonstrate this feature. It is modified from [basic.cc](../guide/basic.cc), and you can compare the two codes. -```c++ -#include -using namespace rabit; -const int N = 3; -int main(int argc, char *argv[]) { - int a[N] = {0}; - rabit::Init(argc, argv); - // lazy preparation function - auto prepare = [&]() { - printf("@node[%d] run prepare function\n", rabit::GetRank()); - for (int i = 0; i < N; ++i) { - a[i] = rabit::GetRank() + i; - } - }; - printf("@node[%d] before-allreduce: a={%d, %d, %d}\n", - rabit::GetRank(), a[0], a[1], a[2]); - // allreduce take max of each elements in all processes - Allreduce(&a[0], N, prepare); - printf("@node[%d] after-allreduce-sum: a={%d, %d, %d}\n", - rabit::GetRank(), a[0], a[1], a[2]); - // rum second allreduce - Allreduce(&a[0], N); - printf("@node[%d] after-allreduce-max: a={%d, %d, %d}\n", - rabit::GetRank(), a[0], a[1], a[2]); - rabit::Finalize(); - return 0; -} -``` -Here we use features of C++11 because the lambda function makes things much shorter. -There is also C++ compatible callback interface provided in the [API](http://homes.cs.washington.edu/~tqchen/rabit/doc). -You can compile the program by typing ```make lazy_allreduce.mock```. We link against the mock library so that we can see -the effect when a process goes down. You can run the program using the following command -```bash -../tracker/rabit_demo.py -n 2 lazy_allreduce.mock mock=0,0,1,0 -``` -The additional arguments ```mock=0,0,1,0``` will cause node 0 to kill itself before second call of Allreduce (see also [mock test](#link-against-mock-test-rabit-library)). -You will find that the prepare function's print is only executed once and node 0 will no longer execute the preparation function when it restarts from failure. - -You can also find python version of the example in [lazy_allreduce.py](../guide/lazy_allreduce.py), and run it using the followin command -```bash -../tracker/rabit_demo.py -n 2 lazy_allreduce.py mock=0,0,1,0 - -``` - -Since lazy preparation function may not be called during execution. User should be careful when using this feature. For example, a possible mistake -could be putting some memory allocation code in the lazy preparation function, and the computing memory was not allocated when lazy preparation function is not called. -The example in [lazy_allreduce.cc](../guide/lazy_allreduce.cc) provides a simple way to migrate normal prepration code([basic.cc](../guide/basic.cc)) to lazy version: wrap the preparation -code with a lambda function, and pass it to allreduce. - -#### Checkpoint and LazyCheckpoint -Common machine learning algorithms usually involves iterative computation. As mentioned in the section ([Structure of a Rabit Program](#structure-of-a-rabit-program)), -user can and should use Checkpoint to ```save``` the progress so far, so that when a node fails, the latest checkpointed model can be loaded. - -There are two model arguments you can pass to Checkpoint and LoadCheckpoint: ```global_model``` and ```local_model```: -* ```global_model``` refers to the model that is commonly shared across all the nodes - - For example, the centriods of clusters in kmeans is shared across all nodes -* ```local_model``` refers to the model that is specifically tied to the current node - - For example, in topic modeling, the topic assignments of subset of documents in current node is local model - -Because the different nature of the two types of models, different strategy will be used for them. -```global_model``` is simply saved in local memory of each node, while ```local_model``` will replicated to some other -nodes (selected using a ring replication strategy). The checkpoint is only saved in the memory without touching the disk which makes rabit programs more efficient. -User is encouraged to use ```global_model``` only when is sufficient for better efficiency. - -To enable a model class to be checked pointed, user can implement a [serialization interface](../include/rabit_serialization.h). The serialization interface already -provide serialization functions of STL vector and string. For python API, user can checkpoint any python object that can be pickled. - -There is a special Checkpoint function called [LazyCheckpoint](http://homes.cs.washington.edu/~tqchen/rabit/doc/namespacerabit.html#a99f74c357afa5fba2c80cc0363e4e459), -which can be used for ```global_model``` only cases under certain condition. -When LazyCheckpoint is called, no action is taken and the rabit engine only remembers the pointer to the model. -The serialization will only happen when another node fails and the recovery starts. So user basically pays no extra cost calling LazyCheckpoint. -To use this function, the user need to ensure the model remain unchanged until the last call of Allreduce/Broadcast in the current version finishes. -So that when recovery procedure happens in these function calls, the serialized model will be the same. - -For example, consider the following calling sequence -``` -LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint -``` -The user must only change the model in code3. Such condition can usually be satiesfied in many scenarios, and user can use LazyCheckpoint to further -improve the efficiency of the program. - - -Compile Programs with Rabit ---------------------------- -Rabit is a portable library, to use it, you only need to include the rabit header file. -* You will need to add the path to [../include](../include) to the header search path of the compiler - - Solution 1: add ```-I/path/to/rabit/include``` to the compiler flag in gcc or clang - - Solution 2: add the path to the environment variable CPLUS_INCLUDE_PATH -* You will need to add the path to [../lib](../lib) to the library search path of the compiler - - Solution 1: add ```-L/path/to/rabit/lib``` to the linker flag - - Solution 2: add the path to environment variable LIBRARY_PATH AND LD_LIBRARY_PATH -* Link against lib/rabit.a - - Add ```-lrabit``` to the linker flag - -The procedure above allows you to compile a program with rabit. The following two sections contain additional -options you can use to link against different backends other than the normal one. - -#### Link against MPI Allreduce -You can link against ```rabit_mpi.a``` instead of using MPI Allreduce, however, the resulting program is backed by MPI and -is not fault tolerant anymore. -* Simply change the linker flag from ```-lrabit``` to ```-lrabit_mpi``` -* The final linking needs to be done by mpi wrapper compiler ```mpicxx``` - -#### Link against Mock Test Rabit Library -If you want to use a mock to test the program in order to see the behavior of the code when some nodes go down, you can link against ```rabit_mock.a``` . -* Simply change the linker flag from ```-lrabit``` to ```-lrabit_mock``` - -The resulting rabit mock program can take in additional arguments in the following format -``` -mock=rank,version,seq,ndeath -``` - -The four integers specify an event that will cause the program to ```commit suicide```(exit with -2) -* rank specifies the rank of the node to kill -* version specifies the version (iteration) of the model where you want the process to die -* seq specifies the sequence number of the Allreduce/Broadcast call since last checkpoint, where the process will be killed -* ndeath specifies how many times this node died already - -For example, consider the following script in the test case -```bash -../tracker/rabit_demo.py -n 10 test_model_recover 10000\ - mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 -``` -* The first mock will cause node 0 to exit when calling the second Allreduce/Broadcast (seq = 1) in iteration 0 -* The second mock will cause node 1 to exit when calling the second Allreduce/Broadcast (seq = 1) in iteration 1 -* The third mock will cause node 1 to exit again when calling second Allreduce/Broadcast (seq = 1) in iteration 1 - - Note that ndeath = 1 means this will happen only if node 1 died once, which is our case - -Running Rabit Jobs ------------------- -Rabit is a portable library that can run on multiple platforms. -All the rabit jobs can be submitted using [dmlc-tracker](https://github.com/dmlc/dmlc-core/tree/master/tracker) - -Fault Tolerance ---------------- -This section introduces how fault tolerance works in rabit. -The following figure shows how rabit deals with failures. - -![](http://homes.cs.washington.edu/~tqchen/rabit/fig/fault-tol.png) - -The scenario is as follows: -* Node 1 fails between the first and second call of Allreduce after the second checkpoint -* The other nodes wait in the call of the second Allreduce in order to help node 1 to recover. -* When node 1 restarts, it will call ```LoadCheckPoint```, and get the latest checkpoint from one of the existing nodes. -* Then node 1 can start from the latest checkpoint and continue running. -* When node 1 calls the first Allreduce again, as the other nodes already know the result, node 1 can get it from one of them. -* When node 1 reaches the second Allreduce, the other nodes find out that node 1 has catched up and they can continue the program normally. - -This fault tolerance model is based on a key property of Allreduce and -Broadcast: All the nodes get the same result after calling Allreduce/Broadcast. -Because of this property, any node can record the results of history -Allreduce/Broadcast calls. When a node is recovered, it can fetch the lost -results from some alive nodes and rebuild its model. - -The checkpoint is introduced so that we can discard the history results of -Allreduce/Broadcast calls before the latest checkpoint. This saves memory -consumption used for backup. The checkpoint of each node is a model defined by -users and can be split into 2 parts: a global model and a local model. The -global model is shared by all nodes and can be backed up by any nodes. The -local model of a node is replicated to some other nodes (selected using a ring -replication strategy). The checkpoint is only saved in the memory without -touching the disk which makes rabit programs more efficient. The strategy of -rabit is different from the fail-restart strategy where all the nodes restart -from the same checkpoint when any of them fail. In rabit, all the alive nodes -will block in the Allreduce call and help the recovery. To catch up, the -recovered node fetches its latest checkpoint and the results of -Allreduce/Broadcast calls after the checkpoint from some alive nodes. - -This is just a conceptual introduction to rabit's fault tolerance model. The actual implementation is more sophisticated, -and can deal with more complicated cases such as multiple nodes failure and node failure during recovery phase. - -Rabit Timeout ---------------- - -In certain cases, rabit cluster may suffer lack of resources to retry failed workers. -Thanks to fault tolerant assumption with infinite retry, it might cause entire cluster hang infinitely. -We introduce sidecar thread which runs when rabit fault tolerant runtime observed allreduce/broadcast errors. -By default, it will wait for 30 mins before all workers program exit. -User can opt-in this feature and change treshold by passing rabit_timeout=true and rabit_timeout_sec=x (in seconds). diff --git a/rabit/doc/index.md b/rabit/doc/index.md deleted file mode 100644 index d209d95ba..000000000 --- a/rabit/doc/index.md +++ /dev/null @@ -1,24 +0,0 @@ -Rabit Documentation -===================== -rabit is a light weight library that provides a fault tolerant interface of Allreduce and Broadcast. It is designed to support easy implementations of distributed machine learning programs, many of which fall naturally under the Allreduce abstraction. The goal of rabit is to support **portable** , **scalable** and **reliable** distributed machine learning programs. - -API Documents -------------- -```eval_rst - -.. toctree:: - :maxdepth: 2 - - python_api.md - cpp_api.md - parameters.md - guide.md -``` -Indices and tables ------------------- - -```eval_rst -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` -``` \ No newline at end of file diff --git a/rabit/doc/parameters.md b/rabit/doc/parameters.md deleted file mode 100644 index eca8d0f5d..000000000 --- a/rabit/doc/parameters.md +++ /dev/null @@ -1,21 +0,0 @@ -Parameters -========== -This section list all the parameters that can be passed to rabit::Init function as argv. -All the parameters are passed in as string in format of ``parameter-name=parameter-value``. -In most setting these parameters have default value or will be automatically detected, -and do not need to be manually configured. - -* rabit_tracker_uri [passed in automatically by tracker] - - The uri/ip of rabit tracker -* rabit_tracker_port [passed in automatically by tracker] - - The port of rabit tracker -* rabit_task_id [automatically detected] - - The unique identifier of computing process - - When running on Hadoop, this is automatically extracted from environment variable -* rabit_reduce_buffer [default = 256MB] - - The memory buffer used to store intermediate result of reduction - - Format "digits + unit", can be 128M, 1G -* rabit_global_replica [default = 5] - - Number of replication copies of result kept for each Allreduce/Broadcast call -* rabit_local_replica [default = 2] - - Number of replication of local model in check point diff --git a/rabit/doc/python-requirements.txt b/rabit/doc/python-requirements.txt deleted file mode 100644 index 244b8378f..000000000 --- a/rabit/doc/python-requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -numpy -breathe -commonmark diff --git a/rabit/doc/python_api.md b/rabit/doc/python_api.md deleted file mode 100644 index 8a0eda921..000000000 --- a/rabit/doc/python_api.md +++ /dev/null @@ -1,11 +0,0 @@ -Python API of Rabit -=================== -This page contains document of python API of rabit. - -```eval_rst -.. toctree:: - -.. automodule:: rabit - :members: - :show-inheritance: -``` diff --git a/rabit/doc/sphinx_util.py b/rabit/doc/sphinx_util.py deleted file mode 100644 index f6a33ffa3..000000000 --- a/rabit/doc/sphinx_util.py +++ /dev/null @@ -1,16 +0,0 @@ -# -*- coding: utf-8 -*- -"""Helper utilty function for customization.""" -import sys -import os -import docutils -import subprocess - -if os.environ.get('READTHEDOCS', None) == 'True': - subprocess.call('cd ..; rm -rf recommonmark;' + - 'git clone https://github.com/tqchen/recommonmark', shell=True) - -sys.path.insert(0, os.path.abspath('../recommonmark/')) -from recommonmark import parser, transform - -MarkdownParser = parser.CommonMarkParser -AutoStructify = transform.AutoStructify diff --git a/rabit/guide/Makefile b/rabit/guide/Makefile deleted file mode 100644 index 802889095..000000000 --- a/rabit/guide/Makefile +++ /dev/null @@ -1,26 +0,0 @@ -export CC = gcc -export CXX = g++ -export MPICXX = mpicxx -export LDFLAGS= -pthread -lm -L../lib -export CFLAGS = -Wall -O3 -msse2 -std=c++11 -Wno-unknown-pragmas -fPIC -fopenmp -I../include - -.PHONY: clean all lib libmpi -BIN = basic.rabit broadcast.rabit -MOCKBIN= lazy_allreduce.mock - -all: $(BIN) -basic.rabit: basic.cc lib ../lib/librabit.a -broadcast.rabit: broadcast.cc lib ../lib/librabit.a -lazy_allreduce.mock: lazy_allreduce.cc lib ../lib/librabit.a - -$(BIN) : - $(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS) - -$(MOCKBIN) : - $(CXX) $(CFLAGS) -std=c++11 -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit_mock - -$(OBJ) : - $(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) ) - -clean: - $(RM) $(OBJ) $(BIN) $(MOCKBIN) *~ ../src/*~ diff --git a/rabit/guide/README b/rabit/guide/README deleted file mode 100644 index 2483d683f..000000000 --- a/rabit/guide/README +++ /dev/null @@ -1 +0,0 @@ -See tutorial at ../doc/guide.md \ No newline at end of file diff --git a/rabit/guide/basic.cc b/rabit/guide/basic.cc deleted file mode 100644 index d08397b54..000000000 --- a/rabit/guide/basic.cc +++ /dev/null @@ -1,35 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file basic.cc - * \brief This is an example demonstrating what is Allreduce - * - * \author Tianqi Chen - */ -#define _CRT_SECURE_NO_WARNINGS -#define _CRT_SECURE_NO_DEPRECATE -#include -#include -using namespace rabit; -int main(int argc, char *argv[]) { - int N = 3; - if (argc > 1) { - N = atoi(argv[1]); - } - std::vector a(N); - rabit::Init(argc, argv); - for (int i = 0; i < N; ++i) { - a[i] = rabit::GetRank() + i; - } - printf("@node[%d] before-allreduce: a={%d, %d, %d}\n", - rabit::GetRank(), a[0], a[1], a[2]); - // allreduce take max of each elements in all processes - Allreduce(&a[0], N); - printf("@node[%d] after-allreduce-max: a={%d, %d, %d}\n", - rabit::GetRank(), a[0], a[1], a[2]); - // second allreduce that sums everything up - Allreduce(&a[0], N); - printf("@node[%d] after-allreduce-sum: a={%d, %d, %d}\n", - rabit::GetRank(), a[0], a[1], a[2]); - rabit::Finalize(); - return 0; -} diff --git a/rabit/guide/basic.py b/rabit/guide/basic.py deleted file mode 100755 index 363150b5d..000000000 --- a/rabit/guide/basic.py +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/python -""" -demo python script of rabit -""" -from __future__ import print_function -from builtins import range -import os -import sys -import numpy as np -# import rabit, the tracker script will setup the lib path correctly -# for normal run without tracker script, add following line -# sys.path.append(os.path.dirname(__file__) + '/../python') -import rabit - -rabit.init() -n = 3 -rank = rabit.get_rank() -a = np.zeros(n) -for i in range(n): - a[i] = rank + i - -print('@node[%d] before-allreduce: a=%s' % (rank, str(a))) -a = rabit.allreduce(a, rabit.MAX) -print('@node[%d] after-allreduce-max: a=%s' % (rank, str(a))) -a = rabit.allreduce(a, rabit.SUM) -print('@node[%d] after-allreduce-sum: a=%s' % (rank, str(a))) -rabit.finalize() diff --git a/rabit/guide/broadcast.cc b/rabit/guide/broadcast.cc deleted file mode 100644 index 9e360d8de..000000000 --- a/rabit/guide/broadcast.cc +++ /dev/null @@ -1,16 +0,0 @@ -#include -using namespace rabit; -const int N = 3; -int main(int argc, char *argv[]) { - rabit::Init(argc, argv); - std::string s; - if (rabit::GetRank() == 0) s = "hello world"; - printf("@node[%d] before-broadcast: s=\"%s\"\n", - rabit::GetRank(), s.c_str()); - // broadcast s from node 0 to all other nodes - rabit::Broadcast(&s, 0); - printf("@node[%d] after-broadcast: s=\"%s\"\n", - rabit::GetRank(), s.c_str()); - rabit::Finalize(); - return 0; -} diff --git a/rabit/guide/broadcast.py b/rabit/guide/broadcast.py deleted file mode 100755 index 8b8169223..000000000 --- a/rabit/guide/broadcast.py +++ /dev/null @@ -1,23 +0,0 @@ -#!/usr/bin/python -""" -demo python script of rabit -""" -from __future__ import print_function -import os -import sys -# add path to wrapper -# for normal run without tracker script, add following line -# sys.path.append(os.path.dirname(__file__) + '/../wrapper') -import rabit - -rabit.init() -n = 3 -rank = rabit.get_rank() -s = None -if rank == 0: - s = {'hello world':100, 2:3} -print('@node[%d] before-broadcast: s=\"%s\"' % (rank, str(s))) -s = rabit.broadcast(s, 0) - -print('@node[%d] after-broadcast: s=\"%s\"' % (rank, str(s))) -rabit.finalize() diff --git a/rabit/guide/lazy_allreduce.cc b/rabit/guide/lazy_allreduce.cc deleted file mode 100644 index b4b816fa0..000000000 --- a/rabit/guide/lazy_allreduce.cc +++ /dev/null @@ -1,34 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file basic.cc - * \brief This is an example demonstrating what is Allreduce - * - * \author Tianqi Chen - */ -#include - -using namespace rabit; -const int N = 3; -int main(int argc, char *argv[]) { - int a[N] = {0}; - rabit::Init(argc, argv); - // lazy preparation function - auto prepare = [&]() { - printf("@node[%d] run prepare function\n", rabit::GetRank()); - for (int i = 0; i < N; ++i) { - a[i] = rabit::GetRank() + i; - } - }; - printf("@node[%d] before-allreduce: a={%d, %d, %d}\n", - rabit::GetRank(), a[0], a[1], a[2]); - // allreduce take max of each elements in all processes - Allreduce(&a[0], N, prepare); - printf("@node[%d] after-allreduce-sum: a={%d, %d, %d}\n", - rabit::GetRank(), a[0], a[1], a[2]); - // rum second allreduce - Allreduce(&a[0], N); - printf("@node[%d] after-allreduce-max: a={%d, %d, %d}\n", - rabit::GetRank(), a[0], a[1], a[2]); - rabit::Finalize(); - return 0; -} diff --git a/rabit/guide/lazy_allreduce.py b/rabit/guide/lazy_allreduce.py deleted file mode 100755 index 1c4b6b1e1..000000000 --- a/rabit/guide/lazy_allreduce.py +++ /dev/null @@ -1,31 +0,0 @@ -#!/usr/bin/python -""" -demo python script of rabit: Lazy preparation function -""" -import os -import sys -import numpy as np -# import rabit, the tracker script will setup the lib path correctly -# for normal run without tracker script, add following line -# sys.path.append(os.path.dirname(__file__) + '/../wrapper') -import rabit - - -# use mock library so that we can run failure test -rabit.init(lib = 'mock') -n = 3 -rank = rabit.get_rank() -a = np.zeros(n) - -def prepare(a): - print('@node[%d] run prepare function' % rank) - # must take in reference and modify the reference - for i in range(n): - a[i] = rank + i - -print('@node[%d] before-allreduce: a=%s' % (rank, str(a))) -a = rabit.allreduce(a, rabit.MAX, prepare_fun = prepare) -print('@node[%d] after-allreduce-max: a=%s' % (rank, str(a))) -a = rabit.allreduce(a, rabit.SUM) -print('@node[%d] after-allreduce-sum: a=%s' % (rank, str(a))) -rabit.finalize() From 0cba2cdbb02ad16d26b4f3f248bdb692329d44c3 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 6 Jun 2023 09:47:24 +0800 Subject: [PATCH 9/9] Support linalg data structures in check device. (#9243) --- src/data/data.cc | 50 +++++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/src/data/data.cc b/src/data/data.cc index f9886b2f0..00cff8ab0 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -7,14 +7,15 @@ #include #include +#include #include #include "../collective/communicator-inl.h" #include "../collective/communicator.h" -#include "../common/common.h" #include "../common/algorithm.h" // for StableSort #include "../common/api_entry.h" // for XGBAPIThreadLocalEntry -#include "../common/error_msg.h" // for InfInData +#include "../common/common.h" +#include "../common/error_msg.h" // for InfInData, GroupWeight, GroupSize #include "../common/group_data.h" #include "../common/io.h" #include "../common/linalg_op.h" @@ -35,6 +36,7 @@ #include "xgboost/context.h" #include "xgboost/host_device_vector.h" #include "xgboost/learner.h" +#include "xgboost/linalg.h" // Vector #include "xgboost/logging.h" #include "xgboost/string_view.h" #include "xgboost/version_config.h" @@ -491,7 +493,7 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) { } // uint info if (key == "group") { - linalg::Tensor t; + linalg::Vector t; CopyTensorInfoImpl(ctx, arr, &t); auto const& h_groups = t.Data()->HostVector(); group_ptr_.clear(); @@ -516,6 +518,7 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) { data::ValidateQueryGroup(group_ptr_); return; } + // float info linalg::Tensor t; CopyTensorInfoImpl<1>(ctx, arr, &t); @@ -717,58 +720,63 @@ void MetaInfo::SynchronizeNumberOfColumns() { } } +namespace { +template +void CheckDevice(std::int32_t device, HostDeviceVector const& v) { + CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device) + << "Data is resided on a different device than `gpu_id`. " + << "Device that data is on: " << v.DeviceIdx() << ", " + << "`gpu_id` for XGBoost: " << device; +} +template +void CheckDevice(std::int32_t device, linalg::Tensor const& v) { + CheckDevice(device, *v.Data()); +} +} // anonymous namespace + void MetaInfo::Validate(std::int32_t device) const { if (group_ptr_.size() != 0 && weights_.Size() != 0) { - CHECK_EQ(group_ptr_.size(), weights_.Size() + 1) - << "Size of weights must equal to number of groups when ranking " - "group is used."; + CHECK_EQ(group_ptr_.size(), weights_.Size() + 1) << error::GroupWeight(); return; } if (group_ptr_.size() != 0) { CHECK_EQ(group_ptr_.back(), num_row_) - << "Invalid group structure. Number of rows obtained from groups " - "doesn't equal to actual number of rows given by data."; + << error::GroupSize() << "the actual number of rows given by data."; } - auto check_device = [device](HostDeviceVector const& v) { - CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device) - << "Data is resided on a different device than `gpu_id`. " - << "Device that data is on: " << v.DeviceIdx() << ", " - << "`gpu_id` for XGBoost: " << device; - }; if (weights_.Size() != 0) { CHECK_EQ(weights_.Size(), num_row_) << "Size of weights must equal to number of rows."; - check_device(weights_); + CheckDevice(device, weights_); return; } if (labels.Size() != 0) { CHECK_EQ(labels.Shape(0), num_row_) << "Size of labels must equal to number of rows."; - check_device(*labels.Data()); + CheckDevice(device, labels); return; } if (labels_lower_bound_.Size() != 0) { CHECK_EQ(labels_lower_bound_.Size(), num_row_) << "Size of label_lower_bound must equal to number of rows."; - check_device(labels_lower_bound_); + CheckDevice(device, labels_lower_bound_); return; } if (feature_weights.Size() != 0) { CHECK_EQ(feature_weights.Size(), num_col_) << "Size of feature_weights must equal to number of columns."; - check_device(feature_weights); + CheckDevice(device, feature_weights); } if (labels_upper_bound_.Size() != 0) { CHECK_EQ(labels_upper_bound_.Size(), num_row_) << "Size of label_upper_bound must equal to number of rows."; - check_device(labels_upper_bound_); + CheckDevice(device, labels_upper_bound_); return; } CHECK_LE(num_nonzero_, num_col_ * num_row_); if (base_margin_.Size() != 0) { CHECK_EQ(base_margin_.Size() % num_row_, 0) << "Size of base margin must be a multiple of number of rows."; - check_device(*base_margin_.Data()); + CheckDevice(device, base_margin_); } } @@ -1028,6 +1036,8 @@ SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const { bool SparsePage::IsIndicesSorted(int32_t n_threads) const { auto& h_offset = this->offset.HostVector(); auto& h_data = this->data.HostVector(); + n_threads = std::max(std::min(static_cast(n_threads), this->Size()), + static_cast(1)); std::vector is_sorted_tloc(n_threads, 0); common::ParallelFor(this->Size(), n_threads, [&](auto i) { auto beg = h_offset[i];