From d062a9e0095149a06611e09f3952e040698291ae Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 30 Mar 2023 12:00:35 +0800 Subject: [PATCH 01/12] Define pair generation strategies for LTR. (#8984) --- src/common/ranking_utils.h | 2 +- src/metric/rank_metric.cc | 1 - src/objective/init_estimation.cc | 8 +- src/objective/init_estimation.h | 6 +- src/objective/lambdarank_obj.cu | 62 +++++ src/objective/lambdarank_obj.cuh | 172 +++++++++++++ src/objective/lambdarank_obj.h | 260 ++++++++++++++++++++ tests/cpp/objective/test_lambdarank_obj.cc | 106 ++++++++ tests/cpp/objective/test_lambdarank_obj.cu | 138 +++++++++++ tests/cpp/objective/test_lambdarank_obj.h | 26 ++ tests/cpp/objective/test_ranking_obj_gpu.cu | 37 --- 11 files changed, 770 insertions(+), 48 deletions(-) create mode 100644 src/objective/lambdarank_obj.cu create mode 100644 src/objective/lambdarank_obj.cuh create mode 100644 src/objective/lambdarank_obj.h create mode 100644 tests/cpp/objective/test_lambdarank_obj.cc create mode 100644 tests/cpp/objective/test_lambdarank_obj.cu create mode 100644 tests/cpp/objective/test_lambdarank_obj.h diff --git a/src/common/ranking_utils.h b/src/common/ranking_utils.h index 727f918f2..bc071c2d6 100644 --- a/src/common/ranking_utils.h +++ b/src/common/ranking_utils.h @@ -123,7 +123,7 @@ struct LambdaRankParam : public XGBoostParameter { DMLC_DECLARE_PARAMETER(LambdaRankParam) { DMLC_DECLARE_FIELD(lambdarank_pair_method) - .set_default(PairMethod::kMean) + .set_default(PairMethod::kTopK) .add_enum("mean", PairMethod::kMean) .add_enum("topk", PairMethod::kTopK) .describe("Method for constructing pairs."); diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index 3a1416b0f..a84d0edb1 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -112,7 +112,6 @@ class PerGroupWeightPolicy { return info.GetWeight(group_id); } }; - } // anonymous namespace namespace xgboost::metric { diff --git a/src/objective/init_estimation.cc b/src/objective/init_estimation.cc index 938ceb59d..834c052f5 100644 --- a/src/objective/init_estimation.cc +++ b/src/objective/init_estimation.cc @@ -14,8 +14,7 @@ #include "xgboost/linalg.h" // Tensor,Vector #include "xgboost/task.h" // ObjInfo -namespace xgboost { -namespace obj { +namespace xgboost::obj { void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector* base_score) const { if (this->Task().task == ObjInfo::kRegression) { CheckInitInputs(info); @@ -31,14 +30,13 @@ void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector* b ObjFunction::Create(get(config["name"]), this->ctx_)}; new_obj->LoadConfig(config); new_obj->GetGradient(dummy_predt, info, 0, &gpair); + bst_target_t n_targets = this->Targets(info); linalg::Vector leaf_weight; tree::FitStump(this->ctx_, info, gpair, n_targets, &leaf_weight); - // workaround, we don't support multi-target due to binary model serialization for // base margin. common::Mean(this->ctx_, leaf_weight, base_score); this->PredTransform(base_score->Data()); } -} // namespace obj -} // namespace xgboost +} // namespace xgboost::obj diff --git a/src/objective/init_estimation.h b/src/objective/init_estimation.h index b0a91d8c3..0ac5c5206 100644 --- a/src/objective/init_estimation.h +++ b/src/objective/init_estimation.h @@ -7,8 +7,7 @@ #include "xgboost/linalg.h" // Tensor #include "xgboost/objective.h" // ObjFunction -namespace xgboost { -namespace obj { +namespace xgboost::obj { class FitIntercept : public ObjFunction { void InitEstimation(MetaInfo const& info, linalg::Vector* base_score) const override; }; @@ -20,6 +19,5 @@ inline void CheckInitInputs(MetaInfo const& info) { << "Number of weights should be equal to number of data points."; } } -} // namespace obj -} // namespace xgboost +} // namespace xgboost::obj #endif // XGBOOST_OBJECTIVE_INIT_ESTIMATION_H_ diff --git a/src/objective/lambdarank_obj.cu b/src/objective/lambdarank_obj.cu new file mode 100644 index 000000000..eb82b17b4 --- /dev/null +++ b/src/objective/lambdarank_obj.cu @@ -0,0 +1,62 @@ +/** + * Copyright 2015-2023 by XGBoost contributors + * + * \brief CUDA implementation of lambdarank. + */ +#include // for fill_n +#include // for for_each_n +#include // for make_counting_iterator +#include // for make_zip_iterator +#include // for make_tuple, tuple, tie, get + +#include // for min +#include // for assert +#include // for abs, log2, isinf +#include // for size_t +#include // for int32_t +#include // for shared_ptr +#include + +#include "../common/algorithm.cuh" // for SegmentedArgSort +#include "../common/cuda_context.cuh" // for CUDAContext +#include "../common/deterministic.cuh" // for CreateRoundingFactor, TruncateWithRounding +#include "../common/device_helpers.cuh" // for SegmentId, TemporaryArray, AtomicAddGpair +#include "../common/optional_weight.h" // for MakeOptionalWeights +#include "../common/ranking_utils.h" // for NDCGCache, LambdaRankParam, rel_degree_t +#include "lambdarank_obj.cuh" +#include "lambdarank_obj.h" +#include "xgboost/base.h" // for bst_group_t, XGBOOST_DEVICE, GradientPair +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/linalg.h" // for VectorView, Range, Vector +#include "xgboost/logging.h" +#include "xgboost/span.h" // for Span + +namespace xgboost::obj { +DMLC_REGISTRY_FILE_TAG(lambdarank_obj_cu); + +namespace cuda_impl { +common::Span SortY(Context const* ctx, MetaInfo const& info, + common::Span d_rank, + std::shared_ptr p_cache) { + auto const d_group_ptr = p_cache->DataGroupPtr(ctx); + auto label = info.labels.View(ctx->gpu_id); + // The buffer for ranked y is necessary as cub segmented sort accepts only pointer. + auto d_y_ranked = p_cache->RankedY(ctx, info.num_row_); + thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_y_ranked.size(), + [=] XGBOOST_DEVICE(std::size_t i) { + auto g = dh::SegmentId(d_group_ptr, i); + auto g_label = + label.Slice(linalg::Range(d_group_ptr[g], d_group_ptr[g + 1]), 0); + auto g_rank_idx = d_rank.subspan(d_group_ptr[g], g_label.Size()); + i -= d_group_ptr[g]; + auto g_y_ranked = d_y_ranked.subspan(d_group_ptr[g], g_label.Size()); + g_y_ranked[i] = g_label(g_rank_idx[i]); + }); + auto d_y_sorted_idx = p_cache->SortedIdxY(ctx, info.num_row_); + common::SegmentedArgSort(ctx, d_y_ranked, d_group_ptr, d_y_sorted_idx); + return d_y_sorted_idx; +} +} // namespace cuda_impl +} // namespace xgboost::obj diff --git a/src/objective/lambdarank_obj.cuh b/src/objective/lambdarank_obj.cuh new file mode 100644 index 000000000..be9f479ce --- /dev/null +++ b/src/objective/lambdarank_obj.cuh @@ -0,0 +1,172 @@ +/** + * Copyright 2023 XGBoost contributors + */ +#ifndef XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_CUH_ +#define XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_CUH_ + +#include // for lower_bound, upper_bound +#include // for greater +#include // for make_counting_iterator +#include // for minstd_rand +#include // for uniform_int_distribution + +#include // for cassert +#include // for size_t +#include // for int32_t +#include // for make_tuple, tuple + +#include "../common/device_helpers.cuh" // for MakeTransformIterator +#include "../common/ranking_utils.cuh" // for PairsForGroup +#include "../common/ranking_utils.h" // for RankingCache +#include "../common/threading_utils.cuh" // for UnravelTrapeziodIdx +#include "xgboost/base.h" // for bst_group_t, GradientPair, XGBOOST_DEVICE +#include "xgboost/data.h" // for MetaInfo +#include "xgboost/linalg.h" // for VectorView, Range, UnravelIndex +#include "xgboost/span.h" // for Span + +namespace xgboost::obj::cuda_impl { +/** + * \brief Find number of elements left to the label bucket + */ +template ::value_type> +XGBOOST_DEVICE __forceinline__ std::size_t CountNumItemsToTheLeftOf(It items, std::size_t n, T v) { + return thrust::lower_bound(thrust::seq, items, items + n, v, thrust::greater{}) - items; +} +/** + * \brief Find number of elements right to the label bucket + */ +template ::value_type> +XGBOOST_DEVICE __forceinline__ std::size_t CountNumItemsToTheRightOf(It items, std::size_t n, T v) { + return n - (thrust::upper_bound(thrust::seq, items, items + n, v, thrust::greater{}) - items); +} +/** + * \brief Sort labels according to rank list for making pairs. + */ +common::Span SortY(Context const *ctx, MetaInfo const &info, + common::Span d_rank, + std::shared_ptr p_cache); + +/** + * \brief Parameters needed for calculating gradient + */ +struct KernelInputs { + linalg::VectorView ti_plus; // input bias ratio + linalg::VectorView tj_minus; // input bias ratio + linalg::VectorView li; + linalg::VectorView lj; + + common::Span d_group_ptr; + common::Span d_threads_group_ptr; + common::Span d_sorted_idx; + + linalg::MatrixView labels; + common::Span predts; + common::Span gpairs; + + linalg::VectorView d_roundings; + double const *d_cost_rounding; + + common::Span d_y_sorted_idx; + + std::int32_t iter; +}; +/** + * \brief Functor for generating pairs + */ +template +struct MakePairsOp { + KernelInputs args; + /** + * \brief Make pair for the topk pair method. + */ + XGBOOST_DEVICE std::tuple WithTruncation(std::size_t idx, + bst_group_t g) const { + auto thread_group_begin = args.d_threads_group_ptr[g]; + auto idx_in_thread_group = idx - thread_group_begin; + + auto data_group_begin = static_cast(args.d_group_ptr[g]); + std::size_t n_data = args.d_group_ptr[g + 1] - data_group_begin; + // obtain group segment data. + auto g_label = args.labels.Slice(linalg::Range(data_group_begin, data_group_begin + n_data), 0); + auto g_sorted_idx = args.d_sorted_idx.subspan(data_group_begin, n_data); + + std::size_t i = 0, j = 0; + common::UnravelTrapeziodIdx(idx_in_thread_group, n_data, &i, &j); + + std::size_t rank_high = i, rank_low = j; + return std::make_tuple(rank_high, rank_low); + } + /** + * \brief Make pair for the mean pair method + */ + XGBOOST_DEVICE std::tuple WithSampling(std::size_t idx, + bst_group_t g) const { + std::size_t n_samples = args.labels.Size(); + assert(n_samples == args.predts.size()); + // Constructed from ranking cache. + std::size_t n_pairs = + ltr::cuda_impl::PairsForGroup(args.d_threads_group_ptr[g + 1] - args.d_threads_group_ptr[g], + args.d_group_ptr[g + 1] - args.d_group_ptr[g]); + + assert(n_pairs > 0); + auto [sample_idx, sample_pair_idx] = linalg::UnravelIndex(idx, {n_samples, n_pairs}); + + auto g_begin = static_cast(args.d_group_ptr[g]); + std::size_t n_data = args.d_group_ptr[g + 1] - g_begin; + + auto g_label = args.labels.Slice(linalg::Range(g_begin, g_begin + n_data)); + auto g_rank_idx = args.d_sorted_idx.subspan(args.d_group_ptr[g], n_data); + auto g_y_sorted_idx = args.d_y_sorted_idx.subspan(g_begin, n_data); + + std::size_t const i = sample_idx - g_begin; + assert(sample_pair_idx < n_samples); + assert(i <= sample_idx); + + auto g_sorted_label = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), + [&](std::size_t i) { return g_label(g_rank_idx[g_y_sorted_idx[i]]); }); + + // Are the labels diverse enough? If they are all the same, then there is nothing to pick + // from another group - bail sooner + if (g_label.Size() == 0 || g_sorted_label[0] == g_sorted_label[n_data - 1]) { + auto z = static_cast(0ul); + return std::make_tuple(z, z); + } + + std::size_t n_lefts = CountNumItemsToTheLeftOf(g_sorted_label, i + 1, g_sorted_label[i]); + std::size_t n_rights = + CountNumItemsToTheRightOf(g_sorted_label + i, n_data - i, g_sorted_label[i]); + // The index pointing to the first element of the next bucket + std::size_t right_bound = n_data - n_rights; + + thrust::minstd_rand rng(args.iter); + auto pair_idx = i; + rng.discard(sample_pair_idx * n_data + g + pair_idx); // fixme + thrust::uniform_int_distribution dist(0, n_lefts + n_rights - 1); + auto ridx = dist(rng); + SPAN_CHECK(ridx < n_lefts + n_rights); + if (ridx >= n_lefts) { + ridx = ridx - n_lefts + right_bound; // fixme + } + + auto idx0 = g_y_sorted_idx[pair_idx]; + auto idx1 = g_y_sorted_idx[ridx]; + + return std::make_tuple(idx0, idx1); + } + /** + * \brief Generate a single pair. + * + * \param idx Pair index (CUDA thread index). + * \param g Query group index. + */ + XGBOOST_DEVICE auto operator()(std::size_t idx, bst_group_t g) const { + if (has_truncation) { + return this->WithTruncation(idx, g); + } else { + return this->WithSampling(idx, g); + } + } +}; +} // namespace xgboost::obj::cuda_impl +#endif // XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_CUH_ diff --git a/src/objective/lambdarank_obj.h b/src/objective/lambdarank_obj.h new file mode 100644 index 000000000..3adb27a2e --- /dev/null +++ b/src/objective/lambdarank_obj.h @@ -0,0 +1,260 @@ +/** + * Copyright 2023 XGBoost contributors + */ +#ifndef XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_ +#define XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_ +#include // for min, max +#include // for assert +#include // for log, abs +#include // for size_t +#include // for greater +#include // for shared_ptr +#include // for minstd_rand, uniform_int_distribution +#include // for vector + +#include "../common/algorithm.h" // for ArgSort +#include "../common/math.h" // for Sigmoid +#include "../common/ranking_utils.h" // for CalcDCGGain +#include "../common/transform_iterator.h" // for MakeIndexTransformIter +#include "xgboost/base.h" // for GradientPair, XGBOOST_DEVICE, kRtEps +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/linalg.h" // for VectorView, Vector +#include "xgboost/logging.h" // for CHECK_EQ +#include "xgboost/span.h" // for Span + +namespace xgboost::obj { +template +XGBOOST_DEVICE double DeltaNDCG(float y_high, float y_low, std::size_t r_high, std::size_t r_low, + double inv_IDCG, common::Span discount) { + double gain_high = exp ? ltr::CalcDCGGain(y_high) : y_high; + double discount_high = discount[r_high]; + + double gain_low = exp ? ltr::CalcDCGGain(y_low) : y_low; + double discount_low = discount[r_low]; + + double original = gain_high * discount_high + gain_low * discount_low; + double changed = gain_low * discount_high + gain_high * discount_low; + + double delta_NDCG = (original - changed) * inv_IDCG; + assert(delta_NDCG >= -1.0); + assert(delta_NDCG <= 1.0); + return delta_NDCG; +} + +XGBOOST_DEVICE inline double DeltaMAP(float y_high, float y_low, std::size_t rank_high, + std::size_t rank_low, common::Span n_rel, + common::Span acc) { + double r_h = static_cast(rank_high) + 1.0; + double r_l = static_cast(rank_low) + 1.0; + double delta{0.0}; + double n_total_relevances = n_rel.back(); + assert(n_total_relevances > 0.0); + auto m = n_rel[rank_low]; + double n = n_rel[rank_high]; + + if (y_high < y_low) { + auto a = m / r_l - (n + 1.0) / r_h; + auto b = acc[rank_low - 1] - acc[rank_high]; + delta = (a - b) / n_total_relevances; + } else { + auto a = n / r_h - m / r_l; + auto b = acc[rank_low - 1] - acc[rank_high]; + delta = (a + b) / n_total_relevances; + } + return delta; +} + +template +XGBOOST_DEVICE GradientPair +LambdaGrad(linalg::VectorView labels, common::Span predts, + common::Span sorted_idx, + std::size_t rank_high, // cordiniate + std::size_t rank_low, // cordiniate + Delta delta, // delta score + linalg::VectorView t_plus, // input bias ratio + linalg::VectorView t_minus, // input bias ratio + double* p_cost) { + assert(sorted_idx.size() > 0 && "Empty sorted idx for a group."); + std::size_t idx_high = sorted_idx[rank_high]; + std::size_t idx_low = sorted_idx[rank_low]; + + if (labels(idx_high) == labels(idx_low)) { + *p_cost = 0; + return {0.0f, 0.0f}; + } + + auto best_score = predts[sorted_idx.front()]; + auto worst_score = predts[sorted_idx.back()]; + + auto y_high = labels(idx_high); + float s_high = predts[idx_high]; + auto y_low = labels(idx_low); + float s_low = predts[idx_low]; + + // Use double whenever possible as we are working on the exp space. + double delta_score = std::abs(s_high - s_low); + double sigmoid = common::Sigmoid(s_high - s_low); + // Change in metric score like \delta NDCG or \delta MAP + double delta_metric = std::abs(delta(y_high, y_low, rank_high, rank_low)); + + if (best_score != worst_score) { + delta_metric /= (delta_score + kRtEps); + } + + if (unbiased) { + *p_cost = std::log(1.0 / (1.0 - sigmoid)) * delta_metric; + } + + constexpr double kEps = 1e-16; + auto lambda_ij = (sigmoid - 1.0) * delta_metric; + auto hessian_ij = std::max(sigmoid * (1.0 - sigmoid), kEps) * delta_metric * 2.0; + + auto k = t_plus.Size(); + assert(t_minus.Size() == k && "Invalid size of position bias"); + + if (unbiased && idx_high < k && idx_low < k) { + lambda_ij /= (t_minus(idx_low) * t_plus(idx_high) + kRtEps); + hessian_ij /= (t_minus(idx_low) * t_plus(idx_high) + kRtEps); + } + + auto pg = GradientPair{static_cast(lambda_ij), static_cast(hessian_ij)}; + return pg; +} + +XGBOOST_DEVICE inline GradientPair Repulse(GradientPair pg) { + auto ng = GradientPair{-pg.GetGrad(), pg.GetHess()}; + return ng; +} + +namespace cuda_impl { +void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter, + HostDeviceVector const& preds, MetaInfo const& info, + std::shared_ptr p_cache, + linalg::VectorView t_plus, // input bias ratio + linalg::VectorView t_minus, // input bias ratio + linalg::VectorView li, linalg::VectorView lj, + HostDeviceVector* out_gpair); + +/** + * \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart. + */ +void MAPStat(Context const* ctx, MetaInfo const& info, common::Span d_rank_idx, + std::shared_ptr p_cache); + +void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter, + HostDeviceVector const& predt, MetaInfo const& info, + std::shared_ptr p_cache, + linalg::VectorView t_plus, // input bias ratio + linalg::VectorView t_minus, // input bias ratio + linalg::VectorView li, linalg::VectorView lj, + HostDeviceVector* out_gpair); + +void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter, + HostDeviceVector const& predt, const MetaInfo& info, + std::shared_ptr p_cache, + linalg::VectorView ti_plus, // input bias ratio + linalg::VectorView tj_minus, // input bias ratio + linalg::VectorView li, linalg::VectorView lj, + HostDeviceVector* out_gpair); + +void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView li_full, + linalg::VectorView lj_full, + linalg::Vector* p_ti_plus, + linalg::Vector* p_tj_minus, linalg::Vector* p_li, + linalg::Vector* p_lj, + std::shared_ptr p_cache); +} // namespace cuda_impl + +namespace cpu_impl { +/** + * \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart. + * + * \param label Ground truth relevance label. + * \param rank_idx Sorted index of prediction. + * \param p_cache An initialized MAPCache. + */ +void MAPStat(Context const* ctx, linalg::VectorView label, + common::Span rank_idx, std::shared_ptr p_cache); +} // namespace cpu_impl + +/** + * \param Construct pairs on CPU + * + * \tparam Op Functor for upgrading a pair of gradients. + * + * \param ctx The global context. + * \param iter The boosting iteration. + * \param cache ltr cache. + * \param g The current query group + * \param g_label label The labels for the current query group + * \param g_rank Sorted index of model scores for the current query group. + * \param op A callable that accepts two index for a pair of documents. The index is for + * the ranked list (labels sorted according to model scores). + */ +template +void MakePairs(Context const* ctx, std::int32_t iter, + std::shared_ptr const cache, bst_group_t g, + linalg::VectorView g_label, common::Span g_rank, + Op op) { + auto group_ptr = cache->DataGroupPtr(ctx); + ltr::position_t cnt = group_ptr[g + 1] - group_ptr[g]; + + if (cache->Param().HasTruncation()) { + for (std::size_t i = 0; i < std::min(cnt, cache->Param().NumPair()); ++i) { + for (std::size_t j = i + 1; j < cnt; ++j) { + op(i, j); + } + } + } else { + CHECK_EQ(g_rank.size(), g_label.Size()); + std::minstd_rand rnd(iter); + rnd.discard(g); // fixme(jiamingy): honor the global seed + // sort label according to the rank list + auto it = common::MakeIndexTransformIter( + [&g_rank, &g_label](std::size_t idx) { return g_label(g_rank[idx]); }); + std::vector y_sorted_idx = + common::ArgSort(ctx, it, it + cnt, std::greater<>{}); + // permutation iterator to get the original label + auto rev_it = common::MakeIndexTransformIter( + [&](std::size_t idx) { return g_label(g_rank[y_sorted_idx[idx]]); }); + + for (std::size_t i = 0; i < cnt;) { + std::size_t j = i + 1; + // find the bucket boundary + while (j < cnt && rev_it[i] == rev_it[j]) { + ++j; + } + // Bucket [i,j), construct n_samples pairs for each sample inside the bucket with + // another sample outside the bucket. + // + // n elements left to the bucket, and n elements right to the bucket + std::size_t n_lefts = i, n_rights = static_cast(cnt - j); + if (n_lefts + n_rights == 0) { + i = j; + continue; + } + + auto n_samples = cache->Param().NumPair(); + // for each pair specifed by the user + while (n_samples--) { + // for each sample in the bucket + for (std::size_t pair_idx = i; pair_idx < j; ++pair_idx) { + std::size_t ridx = std::uniform_int_distribution( + static_cast(0), n_lefts + n_rights - 1)(rnd); + if (ridx >= n_lefts) { + ridx = ridx - i + j; // shift to the right of the bucket + } + // index that points to the rank list. + auto idx0 = y_sorted_idx[pair_idx]; + auto idx1 = y_sorted_idx[ridx]; + op(idx0, idx1); + } + } + i = j; + } + } +} +} // namespace xgboost::obj +#endif // XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_ diff --git a/tests/cpp/objective/test_lambdarank_obj.cc b/tests/cpp/objective/test_lambdarank_obj.cc new file mode 100644 index 000000000..11cbf6bec --- /dev/null +++ b/tests/cpp/objective/test_lambdarank_obj.cc @@ -0,0 +1,106 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#include "test_lambdarank_obj.h" + +#include // for Test, Message, TestPartResult, CmpHel... + +#include // for size_t +#include // for initializer_list +#include // for map +#include // for unique_ptr, shared_ptr, make_shared +#include // for iota +#include // for char_traits, basic_string, string +#include // for vector + +#include "../../../src/common/ranking_utils.h" // for LambdaRankParam +#include "../../../src/common/ranking_utils.h" // for NDCGCache, LambdaRankParam +#include "../helpers.h" // for CheckRankingObjFunction, CheckConfigReload +#include "xgboost/base.h" // for GradientPair, bst_group_t, Args +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo, DMatrix +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/linalg.h" // for Tensor, All, TensorView +#include "xgboost/objective.h" // for ObjFunction +#include "xgboost/span.h" // for Span + +namespace xgboost::obj { +void InitMakePairTest(Context const* ctx, MetaInfo* out_info, HostDeviceVector* out_predt) { + out_predt->SetDevice(ctx->gpu_id); + MetaInfo& info = *out_info; + info.num_row_ = 128; + info.labels.ModifyInplace([&](HostDeviceVector* data, common::Span shape) { + shape[0] = info.num_row_; + shape[1] = 1; + auto& h_data = data->HostVector(); + h_data.resize(shape[0]); + for (std::size_t i = 0; i < h_data.size(); ++i) { + h_data[i] = i % 2; + } + }); + std::vector predt(info.num_row_); + std::iota(predt.rbegin(), predt.rend(), 0.0f); + out_predt->HostVector() = predt; +} + +TEST(LambdaRank, MakePair) { + Context ctx; + MetaInfo info; + HostDeviceVector predt; + + InitMakePairTest(&ctx, &info, &predt); + + ltr::LambdaRankParam param; + param.UpdateAllowUnknown(Args{{"lambdarank_pair_method", "topk"}}); + ASSERT_TRUE(param.HasTruncation()); + + std::shared_ptr p_cache = std::make_shared(&ctx, info, param); + auto const& h_predt = predt.ConstHostVector(); + { + auto rank_idx = p_cache->SortedIdx(&ctx, h_predt); + for (std::size_t i = 0; i < h_predt.size(); ++i) { + ASSERT_EQ(rank_idx[i], static_cast(*(h_predt.crbegin() + i))); + } + std::int32_t n_pairs{0}; + MakePairs(&ctx, 0, p_cache, 0, info.labels.HostView().Slice(linalg::All(), 0), rank_idx, + [&](auto i, auto j) { + ASSERT_GT(j, i); + ASSERT_LT(i, p_cache->Param().NumPair()); + ++n_pairs; + }); + ASSERT_EQ(n_pairs, 3568); + } + + auto const h_label = info.labels.HostView(); + + { + param.UpdateAllowUnknown(Args{{"lambdarank_pair_method", "mean"}}); + auto p_cache = std::make_shared(&ctx, info, param); + ASSERT_FALSE(param.HasTruncation()); + std::int32_t n_pairs = 0; + auto rank_idx = p_cache->SortedIdx(&ctx, h_predt); + MakePairs(&ctx, 0, p_cache, 0, info.labels.HostView().Slice(linalg::All(), 0), rank_idx, + [&](auto i, auto j) { + ++n_pairs; + // Not in the same bucket + ASSERT_NE(h_label(rank_idx[i]), h_label(rank_idx[j])); + }); + ASSERT_EQ(n_pairs, info.num_row_ * param.NumPair()); + } + + { + param.UpdateAllowUnknown(Args{{"lambdarank_num_pair_per_sample", "2"}}); + auto p_cache = std::make_shared(&ctx, info, param); + auto rank_idx = p_cache->SortedIdx(&ctx, h_predt); + std::int32_t n_pairs = 0; + MakePairs(&ctx, 0, p_cache, 0, info.labels.HostView().Slice(linalg::All(), 0), rank_idx, + [&](auto i, auto j) { + ++n_pairs; + // Not in the same bucket + ASSERT_NE(h_label(rank_idx[i]), h_label(rank_idx[j])); + }); + ASSERT_EQ(param.NumPair(), 2); + ASSERT_EQ(n_pairs, info.num_row_ * param.NumPair()); + } +} +} // namespace xgboost::obj diff --git a/tests/cpp/objective/test_lambdarank_obj.cu b/tests/cpp/objective/test_lambdarank_obj.cu new file mode 100644 index 000000000..03ccdef8b --- /dev/null +++ b/tests/cpp/objective/test_lambdarank_obj.cu @@ -0,0 +1,138 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#include +#include // for Context + +#include // for uint32_t +#include // for vector + +#include "../../../src/common/cuda_context.cuh" // for CUDAContext +#include "../../../src/objective/lambdarank_obj.cuh" +#include "test_lambdarank_obj.h" + +namespace xgboost::obj { +void TestGPUMakePair() { + Context ctx; + ctx.gpu_id = 0; + + MetaInfo info; + HostDeviceVector predt; + InitMakePairTest(&ctx, &info, &predt); + + ltr::LambdaRankParam param; + + auto make_args = [&](std::shared_ptr p_cache, auto rank_idx, + common::Span y_sorted_idx) { + linalg::Vector dummy; + auto d = dummy.View(ctx.gpu_id); + linalg::Vector dgpair; + auto dg = dgpair.View(ctx.gpu_id); + cuda_impl::KernelInputs args{d, + d, + d, + d, + p_cache->DataGroupPtr(&ctx), + p_cache->CUDAThreadsGroupPtr(), + rank_idx, + info.labels.View(ctx.gpu_id), + predt.ConstDeviceSpan(), + {}, + dg, + nullptr, + y_sorted_idx, + 0}; + return args; + }; + + { + param.UpdateAllowUnknown(Args{{"lambdarank_pair_method", "topk"}}); + auto p_cache = std::make_shared(&ctx, info, param); + auto rank_idx = p_cache->SortedIdx(&ctx, predt.ConstDeviceSpan()); + + ASSERT_EQ(p_cache->CUDAThreads(), 3568); + + auto args = make_args(p_cache, rank_idx, {}); + auto n_pairs = p_cache->Param().NumPair(); + auto make_pair = cuda_impl::MakePairsOp{args}; + + dh::LaunchN(p_cache->CUDAThreads(), ctx.CUDACtx()->Stream(), + [=] XGBOOST_DEVICE(std::size_t idx) { + auto [i, j] = make_pair(idx, 0); + SPAN_CHECK(j > i); + SPAN_CHECK(i < n_pairs); + }); + } + { + param.UpdateAllowUnknown(Args{{"lambdarank_pair_method", "mean"}}); + auto p_cache = std::make_shared(&ctx, info, param); + auto rank_idx = p_cache->SortedIdx(&ctx, predt.ConstDeviceSpan()); + auto y_sorted_idx = cuda_impl::SortY(&ctx, info, rank_idx, p_cache); + + ASSERT_FALSE(param.HasTruncation()); + ASSERT_EQ(p_cache->CUDAThreads(), info.num_row_ * param.NumPair()); + + auto args = make_args(p_cache, rank_idx, y_sorted_idx); + auto make_pair = cuda_impl::MakePairsOp{args}; + auto n_pairs = p_cache->Param().NumPair(); + ASSERT_EQ(n_pairs, 1); + + dh::LaunchN( + p_cache->CUDAThreads(), ctx.CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t idx) { + idx = 97; + auto [i, j] = make_pair(idx, 0); + // Not in the same bucket + SPAN_CHECK(make_pair.args.labels(rank_idx[i]) != make_pair.args.labels(rank_idx[j])); + }); + } + { + param.UpdateAllowUnknown(Args{{"lambdarank_num_pair_per_sample", "2"}}); + auto p_cache = std::make_shared(&ctx, info, param); + auto rank_idx = p_cache->SortedIdx(&ctx, predt.ConstDeviceSpan()); + auto y_sorted_idx = cuda_impl::SortY(&ctx, info, rank_idx, p_cache); + + auto args = make_args(p_cache, rank_idx, y_sorted_idx); + auto make_pair = cuda_impl::MakePairsOp{args}; + + dh::LaunchN( + p_cache->CUDAThreads(), ctx.CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t idx) { + auto [i, j] = make_pair(idx, 0); + // Not in the same bucket + SPAN_CHECK(make_pair.args.labels(rank_idx[i]) != make_pair.args.labels(rank_idx[j])); + }); + ASSERT_EQ(param.NumPair(), 2); + ASSERT_EQ(p_cache->CUDAThreads(), info.num_row_ * param.NumPair()); + } +} + +TEST(LambdaRank, GPUMakePair) { TestGPUMakePair(); } + +template +void RankItemCountImpl(std::vector const &sorted_items, CountFunctor f, + std::uint32_t find_val, std::uint32_t exp_val) { + EXPECT_NE(std::find(sorted_items.begin(), sorted_items.end(), find_val), sorted_items.end()); + EXPECT_EQ(f(&sorted_items[0], sorted_items.size(), find_val), exp_val); +} + +TEST(LambdaRank, RankItemCountOnLeft) { + // Items sorted descendingly + std::vector sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0}; + auto wrapper = [](auto const &...args) { return cuda_impl::CountNumItemsToTheLeftOf(args...); }; + RankItemCountImpl(sorted_items, wrapper, 10, static_cast(0)); + RankItemCountImpl(sorted_items, wrapper, 6, static_cast(2)); + RankItemCountImpl(sorted_items, wrapper, 4, static_cast(3)); + RankItemCountImpl(sorted_items, wrapper, 1, static_cast(7)); + RankItemCountImpl(sorted_items, wrapper, 0, static_cast(12)); +} + +TEST(LambdaRank, RankItemCountOnRight) { + // Items sorted descendingly + std::vector sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0}; + auto wrapper = [](auto const &...args) { return cuda_impl::CountNumItemsToTheRightOf(args...); }; + RankItemCountImpl(sorted_items, wrapper, 10, static_cast(11)); + RankItemCountImpl(sorted_items, wrapper, 6, static_cast(10)); + RankItemCountImpl(sorted_items, wrapper, 4, static_cast(6)); + RankItemCountImpl(sorted_items, wrapper, 1, static_cast(1)); + RankItemCountImpl(sorted_items, wrapper, 0, static_cast(0)); +} +} // namespace xgboost::obj diff --git a/tests/cpp/objective/test_lambdarank_obj.h b/tests/cpp/objective/test_lambdarank_obj.h new file mode 100644 index 000000000..8dd238d2b --- /dev/null +++ b/tests/cpp/objective/test_lambdarank_obj.h @@ -0,0 +1,26 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#ifndef XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_ +#define XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_ +#include +#include // for MetaInfo +#include // for HostDeviceVector +#include // for All +#include // for ObjFunction + +#include // for shared_ptr, make_shared +#include // for iota +#include // for vector + +#include "../../../src/common/ranking_utils.h" // for LambdaRankParam, MAPCache +#include "../../../src/objective/lambdarank_obj.h" // for MAPStat +#include "../helpers.h" // for EmptyDMatrix + +namespace xgboost::obj { +/** + * \brief Initialize test data for make pair tests. + */ +void InitMakePairTest(Context const* ctx, MetaInfo* out_info, HostDeviceVector* out_predt); +} // namespace xgboost::obj +#endif // XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_ diff --git a/tests/cpp/objective/test_ranking_obj_gpu.cu b/tests/cpp/objective/test_ranking_obj_gpu.cu index 02286ab46..540560c1f 100644 --- a/tests/cpp/objective/test_ranking_obj_gpu.cu +++ b/tests/cpp/objective/test_ranking_obj_gpu.cu @@ -89,43 +89,6 @@ TEST(Objective, RankSegmentSorterAscendingTest) { 5, 4, 6}); } -using CountFunctor = uint32_t (*)(const int *, uint32_t, int); -void RankItemCountImpl(const std::vector &sorted_items, CountFunctor f, - int find_val, uint32_t exp_val) { - EXPECT_NE(std::find(sorted_items.begin(), sorted_items.end(), find_val), sorted_items.end()); - EXPECT_EQ(f(&sorted_items[0], sorted_items.size(), find_val), exp_val); -} - -TEST(Objective, RankItemCountOnLeft) { - // Items sorted descendingly - std::vector sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0}; - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf, - 10, static_cast(0)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf, - 6, static_cast(2)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf, - 4, static_cast(3)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf, - 1, static_cast(7)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf, - 0, static_cast(12)); -} - -TEST(Objective, RankItemCountOnRight) { - // Items sorted descendingly - std::vector sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0}; - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf, - 10, static_cast(11)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf, - 6, static_cast(10)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf, - 4, static_cast(6)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf, - 1, static_cast(1)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf, - 0, static_cast(0)); -} - TEST(Objective, NDCGLambdaWeightComputerTest) { std::vector hlabels = {3.1f, 1.2f, 2.3f, 4.4f, // Labels 7.8f, 5.01f, 6.96f, From cd05e38533f3db52b18f227e787da4015412cd2c Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 30 Mar 2023 19:09:07 +0800 Subject: [PATCH 02/12] [doc][R] Update link. (#8998) --- R-package/LICENSE | 4 ++-- R-package/R/xgb.plot.tree.R | 2 +- R-package/man/xgb.plot.tree.Rd | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/R-package/LICENSE b/R-package/LICENSE index b9f38c38a..bc1c21d59 100644 --- a/R-package/LICENSE +++ b/R-package/LICENSE @@ -1,9 +1,9 @@ -Copyright (c) 2014 by Tianqi Chen and Contributors +Copyright (c) 2014-2023, Tianqi Chen and XBGoost Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software diff --git a/R-package/R/xgb.plot.tree.R b/R-package/R/xgb.plot.tree.R index cb6cb25ad..956c13cf7 100644 --- a/R-package/R/xgb.plot.tree.R +++ b/R-package/R/xgb.plot.tree.R @@ -34,7 +34,7 @@ #' The branches that also used for missing values are marked as bold #' (as in "carrying extra capacity"). #' -#' This function uses \href{http://www.graphviz.org/}{GraphViz} as a backend of DiagrammeR. +#' This function uses \href{https://www.graphviz.org/}{GraphViz} as a backend of DiagrammeR. #' #' @return #' diff --git a/R-package/man/xgb.plot.tree.Rd b/R-package/man/xgb.plot.tree.Rd index 8fd7196af..d419eb76a 100644 --- a/R-package/man/xgb.plot.tree.Rd +++ b/R-package/man/xgb.plot.tree.Rd @@ -67,7 +67,7 @@ The "Yes" branches are marked by the "< split_value" label. The branches that also used for missing values are marked as bold (as in "carrying extra capacity"). -This function uses \href{http://www.graphviz.org/}{GraphViz} as a backend of DiagrammeR. +This function uses \href{https://www.graphviz.org/}{GraphViz} as a backend of DiagrammeR. } \examples{ data(agaricus.train, package='xgboost') From b647403baaa5086aca732e45f090b6a876a58355 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 31 Mar 2023 03:52:09 +0800 Subject: [PATCH 03/12] Update release news. [skip ci] (#9000) --- NEWS.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/NEWS.md b/NEWS.md index 03ed1d7e9..963dd3337 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,6 +3,26 @@ XGBoost Change Log This file records the changes in xgboost library in reverse chronological order. +## 1.7.5 (2023 Mar 30) +This is a patch release for bug fixes. + +* C++ requirement is updated to C++-17, along with which, CUDA 11.8 is used as the default CTK. (#8860, #8855, #8853) +* Fix import for pyspark ranker. (#8692) +* Fix Windows binary wheel to be compatible with Poetry (#8991) +* Fix GPU hist with column sampling. (#8850) +* Make sure iterative DMatrix is properly initialized. (#8997) +* [R] Update link in document. (#8998) + +## 1.7.4 (2023 Feb 16) +This is a patch release for bug fixes. + +* [R] Fix OpenMP detection on macOS. (#8684) +* [Python] Make sure input numpy array is aligned. (#8690) +* Fix feature interaction with column sampling in gpu_hist evaluator. (#8754) +* Fix GPU L1 error. (#8749) +* [PySpark] Fix feature types param (#8772) +* Fix ranking with quantile dmatrix and group weight. (#8762) + ## 1.7.3 (2023 Jan 6) This is a patch release for bug fixes. From bac22734fb6637089d0bf3e3889bb9cdf261eaa7 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 31 Mar 2023 19:01:55 +0800 Subject: [PATCH 04/12] Remove ntree limit in python package. (#8345) - Remove `ntree_limit`. The parameter has been deprecated since 1.4.0. - The SHAP package compatibility is broken. --- python-package/xgboost/callback.py | 16 +- python-package/xgboost/core.py | 67 +----- python-package/xgboost/dask.py | 9 - python-package/xgboost/sklearn.py | 35 +--- python-package/xgboost/spark/data.py | 4 +- tests/ci_build/conda_env/aarch64_test.yml | 1 - tests/ci_build/conda_env/linux_cpu_test.yml | 1 - tests/ci_build/lint_python.py | 1 + tests/python/test_basic_models.py | 16 +- tests/python/test_cli.py | 1 - tests/python/test_predict.py | 38 ++-- tests/python/test_ranking.py | 2 +- tests/python/test_training_continuation.py | 43 ++-- tests/python/test_with_shap.py | 6 +- tests/python/test_with_sklearn.py | 39 ++-- .../test_with_spark/test_spark_local.py | 169 +++++++-------- .../test_spark_local_cluster.py | 193 +++++++++++------- 17 files changed, 284 insertions(+), 357 deletions(-) diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py index 6569f7e3d..cc62b354d 100644 --- a/python-package/xgboost/callback.py +++ b/python-package/xgboost/callback.py @@ -23,13 +23,7 @@ from typing import ( import numpy from . import collective -from .core import ( - Booster, - DMatrix, - XGBoostError, - _get_booster_layer_trees, - _parse_eval_str, -) +from .core import Booster, DMatrix, XGBoostError, _parse_eval_str __all__ = [ "TrainingCallback", @@ -177,22 +171,14 @@ class CallbackContainer: assert isinstance(model, Booster), msg if not self.is_cv: - num_parallel_tree, _ = _get_booster_layer_trees(model) if model.attr("best_score") is not None: model.best_score = float(cast(str, model.attr("best_score"))) model.best_iteration = int(cast(str, model.attr("best_iteration"))) - # num_class is handled internally - model.set_attr( - best_ntree_limit=str((model.best_iteration + 1) * num_parallel_tree) - ) - model.best_ntree_limit = int(cast(str, model.attr("best_ntree_limit"))) else: # Due to compatibility with version older than 1.4, these attributes are # added to Python object even if early stopping is not used. model.best_iteration = model.num_boosted_rounds() - 1 model.set_attr(best_iteration=str(model.best_iteration)) - model.best_ntree_limit = (model.best_iteration + 1) * num_parallel_tree - model.set_attr(best_ntree_limit=str(model.best_ntree_limit)) return model diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 30aa771e3..68346d900 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -126,25 +126,6 @@ def _parse_eval_str(result: str) -> List[Tuple[str, float]]: IterRange = TypeVar("IterRange", Optional[Tuple[int, int]], Tuple[int, int]) -def _convert_ntree_limit( - booster: "Booster", ntree_limit: Optional[int], iteration_range: IterRange -) -> IterRange: - if ntree_limit is not None and ntree_limit != 0: - warnings.warn( - "ntree_limit is deprecated, use `iteration_range` or model " - "slicing instead.", - UserWarning, - ) - if iteration_range is not None and iteration_range[1] != 0: - raise ValueError( - "Only one of `iteration_range` and `ntree_limit` can be non zero." - ) - num_parallel_tree, _ = _get_booster_layer_trees(booster) - num_parallel_tree = max([num_parallel_tree, 1]) - iteration_range = (0, ntree_limit // num_parallel_tree) - return iteration_range - - def _expect(expectations: Sequence[Type], got: Type) -> str: """Translate input error into string. @@ -1508,41 +1489,6 @@ Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]] Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]] -def _get_booster_layer_trees(model: "Booster") -> Tuple[int, int]: - """Get number of trees added to booster per-iteration. This function will be removed - once `best_ntree_limit` is dropped in favor of `best_iteration`. Returns - `num_parallel_tree` and `num_groups`. - - """ - config = json.loads(model.save_config()) - booster = config["learner"]["gradient_booster"]["name"] - if booster == "gblinear": - num_parallel_tree = 0 - elif booster == "dart": - num_parallel_tree = int( - config["learner"]["gradient_booster"]["gbtree"]["gbtree_model_param"][ - "num_parallel_tree" - ] - ) - elif booster == "gbtree": - try: - num_parallel_tree = int( - config["learner"]["gradient_booster"]["gbtree_model_param"][ - "num_parallel_tree" - ] - ) - except KeyError: - num_parallel_tree = int( - config["learner"]["gradient_booster"]["gbtree_train_param"][ - "num_parallel_tree" - ] - ) - else: - raise ValueError(f"Unknown booster: {booster}") - num_groups = int(config["learner"]["learner_model_param"]["num_class"]) - return num_parallel_tree, num_groups - - def _configure_metrics(params: BoosterParam) -> BoosterParam: if ( isinstance(params, dict) @@ -1576,11 +1522,11 @@ class Booster: """ Parameters ---------- - params : dict + params : Parameters for boosters. - cache : list + cache : List of cache items. - model_file : string/os.PathLike/Booster/bytearray + model_file : Path to the model file if it's string or PathLike. """ cache = cache if cache is not None else [] @@ -2100,7 +2046,6 @@ class Booster: self, data: DMatrix, output_margin: bool = False, - ntree_limit: int = 0, pred_leaf: bool = False, pred_contribs: bool = False, approx_contribs: bool = False, @@ -2127,9 +2072,6 @@ class Booster: output_margin : Whether to output the raw untransformed margin value. - ntree_limit : - Deprecated, use `iteration_range` instead. - pred_leaf : When this option is on, the output will be a matrix of (nsample, ntrees) with each record indicating the predicted leaf index of @@ -2196,7 +2138,6 @@ class Booster: raise TypeError("Expecting data to be a DMatrix object, got: ", type(data)) if validate_features: self._validate_dmatrix_features(data) - iteration_range = _convert_ntree_limit(self, ntree_limit, iteration_range) args = { "type": 0, "training": training, @@ -2522,8 +2463,6 @@ class Booster: self.best_iteration = int(self.attr("best_iteration")) # type: ignore if self.attr("best_score") is not None: self.best_score = float(self.attr("best_score")) # type: ignore - if self.attr("best_ntree_limit") is not None: - self.best_ntree_limit = int(self.attr("best_ntree_limit")) # type: ignore def num_boosted_rounds(self) -> int: """Get number of boosted rounds. For gblinear this is reset to 0 after diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 8c679b75b..88bd1c819 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -1653,14 +1653,11 @@ class DaskScikitLearnBase(XGBModel): self, X: _DataT, output_margin: bool = False, - ntree_limit: Optional[int] = None, validate_features: bool = True, base_margin: Optional[_DaskCollection] = None, iteration_range: Optional[Tuple[int, int]] = None, ) -> Any: _assert_dask_support() - msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead." - assert ntree_limit is None, msg return self.client.sync( self._predict_async, X, @@ -1694,12 +1691,9 @@ class DaskScikitLearnBase(XGBModel): def apply( self, X: _DataT, - ntree_limit: Optional[int] = None, iteration_range: Optional[Tuple[int, int]] = None, ) -> Any: _assert_dask_support() - msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead." - assert ntree_limit is None, msg return self.client.sync(self._apply_async, X, iteration_range=iteration_range) def __await__(self) -> Awaitable[Any]: @@ -1993,14 +1987,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBa def predict_proba( self, X: _DaskCollection, - ntree_limit: Optional[int] = None, validate_features: bool = True, base_margin: Optional[_DaskCollection] = None, iteration_range: Optional[Tuple[int, int]] = None, ) -> Any: _assert_dask_support() - msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead." - assert ntree_limit is None, msg return self._client_sync( self._predict_proba_async, X=X, diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 563ff8659..fffc0eb9b 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -36,7 +36,6 @@ from .core import ( Objective, QuantileDMatrix, XGBoostError, - _convert_ntree_limit, _deprecate_positional_args, _parse_eval_str, ) @@ -391,8 +390,7 @@ __model_doc = f""" metric will be used for early stopping. - If early stopping occurs, the model will have three additional fields: - :py:attr:`best_score`, :py:attr:`best_iteration` and - :py:attr:`best_ntree_limit`. + :py:attr:`best_score`, :py:attr:`best_iteration`. .. note:: @@ -1117,7 +1115,6 @@ class XGBModel(XGBModelBase): self, X: ArrayLike, output_margin: bool = False, - ntree_limit: Optional[int] = None, validate_features: bool = True, base_margin: Optional[ArrayLike] = None, iteration_range: Optional[Tuple[int, int]] = None, @@ -1135,8 +1132,6 @@ class XGBModel(XGBModelBase): Data to predict with. output_margin : Whether to output the raw untransformed margin value. - ntree_limit : - Deprecated, use `iteration_range` instead. validate_features : When this is True, validate that the Booster's and data's feature_names are identical. Otherwise, it is assumed that the feature_names are the same. @@ -1156,9 +1151,6 @@ class XGBModel(XGBModelBase): """ with config_context(verbosity=self.verbosity): - iteration_range = _convert_ntree_limit( - self.get_booster(), ntree_limit, iteration_range - ) iteration_range = self._get_iteration_range(iteration_range) if self._can_use_inplace_predict(): try: @@ -1197,7 +1189,6 @@ class XGBModel(XGBModelBase): def apply( self, X: ArrayLike, - ntree_limit: int = 0, iteration_range: Optional[Tuple[int, int]] = None, ) -> np.ndarray: """Return the predicted leaf every tree for each sample. If the model is trained @@ -1211,9 +1202,6 @@ class XGBModel(XGBModelBase): iteration_range : See :py:meth:`predict`. - ntree_limit : - Deprecated, use ``iteration_range`` instead. - Returns ------- X_leaves : array_like, shape=[n_samples, n_trees] @@ -1223,9 +1211,6 @@ class XGBModel(XGBModelBase): """ with config_context(verbosity=self.verbosity): - iteration_range = _convert_ntree_limit( - self.get_booster(), ntree_limit, iteration_range - ) iteration_range = self._get_iteration_range(iteration_range) test_dmatrix = DMatrix( X, @@ -1309,10 +1294,6 @@ class XGBModel(XGBModelBase): """ return int(self._early_stopping_attr("best_iteration")) - @property - def best_ntree_limit(self) -> int: - return int(self._early_stopping_attr("best_ntree_limit")) - @property def feature_importances_(self) -> np.ndarray: """Feature importances property, return depends on `importance_type` @@ -1562,7 +1543,6 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase): self, X: ArrayLike, output_margin: bool = False, - ntree_limit: Optional[int] = None, validate_features: bool = True, base_margin: Optional[ArrayLike] = None, iteration_range: Optional[Tuple[int, int]] = None, @@ -1571,7 +1551,6 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase): class_probs = super().predict( X=X, output_margin=output_margin, - ntree_limit=ntree_limit, validate_features=validate_features, base_margin=base_margin, iteration_range=iteration_range, @@ -1599,7 +1578,6 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase): def predict_proba( self, X: ArrayLike, - ntree_limit: Optional[int] = None, validate_features: bool = True, base_margin: Optional[ArrayLike] = None, iteration_range: Optional[Tuple[int, int]] = None, @@ -1614,8 +1592,6 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase): ---------- X : array_like Feature matrix. See :ref:`py-data` for a list of supported types. - ntree_limit : int - Deprecated, use `iteration_range` instead. validate_features : bool When this is True, validate that the Booster's and data's feature_names are identical. Otherwise, it is assumed that the feature_names are the same. @@ -1642,7 +1618,6 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase): if self.objective == "multi:softmax": raw_predt = super().predict( X=X, - ntree_limit=ntree_limit, validate_features=validate_features, base_margin=base_margin, iteration_range=iteration_range, @@ -1652,7 +1627,6 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase): return class_prob class_probs = super().predict( X=X, - ntree_limit=ntree_limit, validate_features=validate_features, base_margin=base_margin, iteration_range=iteration_range, @@ -2074,7 +2048,6 @@ class XGBRanker(XGBModel, XGBRankerMixIn): self, X: ArrayLike, output_margin: bool = False, - ntree_limit: Optional[int] = None, validate_features: bool = True, base_margin: Optional[ArrayLike] = None, iteration_range: Optional[Tuple[int, int]] = None, @@ -2083,20 +2056,18 @@ class XGBRanker(XGBModel, XGBRankerMixIn): return super().predict( X, output_margin, - ntree_limit, validate_features, base_margin, - iteration_range, + iteration_range=iteration_range, ) def apply( self, X: ArrayLike, - ntree_limit: int = 0, iteration_range: Optional[Tuple[int, int]] = None, ) -> ArrayLike: X, _ = _get_qid(X, None) - return super().apply(X, ntree_limit, iteration_range) + return super().apply(X, iteration_range) def score(self, X: ArrayLike, y: ArrayLike) -> float: """Evaluate score for data using the last evaluation metric. If the model is diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py index 6e2d4c6db..f2c5e1197 100644 --- a/python-package/xgboost/spark/data.py +++ b/python-package/xgboost/spark/data.py @@ -11,7 +11,6 @@ from xgboost import DataIter, DMatrix, QuantileDMatrix, XGBModel from xgboost.compat import concat from .._typing import ArrayLike -from ..core import _convert_ntree_limit from .utils import get_logger # type: ignore @@ -343,8 +342,7 @@ def pred_contribs( strict_shape: bool = False, ) -> np.ndarray: """Predict contributions with data with the full model.""" - iteration_range = _convert_ntree_limit(model.get_booster(), None, None) - iteration_range = model._get_iteration_range(iteration_range) + iteration_range = model._get_iteration_range(None) data_dmatrix = DMatrix( data, base_margin=base_margin, diff --git a/tests/ci_build/conda_env/aarch64_test.yml b/tests/ci_build/conda_env/aarch64_test.yml index fe30eced1..42a2fe1e4 100644 --- a/tests/ci_build/conda_env/aarch64_test.yml +++ b/tests/ci_build/conda_env/aarch64_test.yml @@ -31,6 +31,5 @@ dependencies: - pyspark - cloudpickle - pip: - - shap - awscli - auditwheel diff --git a/tests/ci_build/conda_env/linux_cpu_test.yml b/tests/ci_build/conda_env/linux_cpu_test.yml index 7977abcd4..bf657708d 100644 --- a/tests/ci_build/conda_env/linux_cpu_test.yml +++ b/tests/ci_build/conda_env/linux_cpu_test.yml @@ -37,7 +37,6 @@ dependencies: - pyarrow - protobuf - cloudpickle -- shap>=0.41 - modin # TODO: Replace it with pyspark>=3.4 once 3.4 released. # - https://ml-team-public-read.s3.us-west-2.amazonaws.com/pyspark-3.4.0.dev0.tar.gz diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py index d248e14df..00791e19d 100644 --- a/tests/ci_build/lint_python.py +++ b/tests/ci_build/lint_python.py @@ -146,6 +146,7 @@ def main(args: argparse.Namespace) -> None: "tests/python/test_config.py", "tests/python/test_data_iterator.py", "tests/python/test_dt.py", + "tests/python/test_predict.py", "tests/python/test_quantile_dmatrix.py", "tests/python/test_tree_regularization.py", "tests/python-gpu/test_gpu_data_iterator.py", diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index 516cbd6cf..f9d6f37e1 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -64,7 +64,7 @@ class TestModels: num_round = 2 bst = xgb.train(param, dtrain, num_round, watchlist) # this is prediction - preds = bst.predict(dtest, ntree_limit=num_round) + preds = bst.predict(dtest, iteration_range=(0, num_round)) labels = dtest.get_label() err = sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds)) @@ -83,7 +83,7 @@ class TestModels: bst2 = xgb.Booster(params=param, model_file=model_path) dtest2 = xgb.DMatrix(dtest_path) - preds2 = bst2.predict(dtest2, ntree_limit=num_round) + preds2 = bst2.predict(dtest2, iteration_range=(0, num_round)) # assert they are the same assert np.sum(np.abs(preds2 - preds)) == 0 @@ -96,7 +96,7 @@ class TestModels: # check whether custom evaluation metrics work bst = xgb.train(param, dtrain, num_round, watchlist, feval=my_logloss) - preds3 = bst.predict(dtest, ntree_limit=num_round) + preds3 = bst.predict(dtest, iteration_range=(0, num_round)) assert all(preds3 == preds) # check whether sample_type and normalize_type work @@ -110,7 +110,7 @@ class TestModels: param['sample_type'] = p[0] param['normalize_type'] = p[1] bst = xgb.train(param, dtrain, num_round, watchlist) - preds = bst.predict(dtest, ntree_limit=num_round) + preds = bst.predict(dtest, iteration_range=(0, num_round)) err = sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds)) assert err < 0.1 @@ -472,8 +472,8 @@ class TestModels: X, y = load_iris(return_X_y=True) cls = xgb.XGBClassifier(n_estimators=2) cls.fit(X, y, early_stopping_rounds=1, eval_set=[(X, y)]) - assert cls.get_booster().best_ntree_limit == 2 - assert cls.best_ntree_limit == cls.get_booster().best_ntree_limit + assert cls.get_booster().best_iteration == cls.n_estimators - 1 + assert cls.best_iteration == cls.get_booster().best_iteration with tempfile.TemporaryDirectory() as tmpdir: path = os.path.join(tmpdir, "cls.json") @@ -481,8 +481,8 @@ class TestModels: cls = xgb.XGBClassifier(n_estimators=2) cls.load_model(path) - assert cls.get_booster().best_ntree_limit == 2 - assert cls.best_ntree_limit == cls.get_booster().best_ntree_limit + assert cls.get_booster().best_iteration == cls.n_estimators - 1 + assert cls.best_iteration == cls.get_booster().best_iteration def run_slice( self, diff --git a/tests/python/test_cli.py b/tests/python/test_cli.py index 69e8df83d..3d7415232 100644 --- a/tests/python/test_cli.py +++ b/tests/python/test_cli.py @@ -102,7 +102,6 @@ eval[test] = {data_path} booster.feature_names = None booster.feature_types = None booster.set_attr(best_iteration=None) - booster.set_attr(best_ntree_limit=None) booster.save_model(model_out_py) py_predt = booster.predict(data) diff --git a/tests/python/test_predict.py b/tests/python/test_predict.py index cb400df87..6f89edd16 100644 --- a/tests/python/test_predict.py +++ b/tests/python/test_predict.py @@ -1,4 +1,4 @@ -'''Tests for running inplace prediction.''' +"""Tests for running inplace prediction.""" from concurrent.futures import ThreadPoolExecutor import numpy as np @@ -17,10 +17,10 @@ def run_threaded_predict(X, rows, predict_func): per_thread = 20 with ThreadPoolExecutor(max_workers=10) as e: for i in range(0, rows, int(rows / per_thread)): - if hasattr(X, 'iloc'): - predictor = X.iloc[i:i+per_thread, :] + if hasattr(X, "iloc"): + predictor = X.iloc[i : i + per_thread, :] else: - predictor = X[i:i+per_thread, ...] + predictor = X[i : i + per_thread, ...] f = e.submit(predict_func, predictor) results.append(f) @@ -61,27 +61,31 @@ def run_predict_leaf(predictor): validate_leaf_output(leaf, num_parallel_tree) - ntree_limit = 2 + n_iters = 2 sliced = booster.predict( - m, pred_leaf=True, ntree_limit=num_parallel_tree * ntree_limit, strict_shape=True + m, + pred_leaf=True, + iteration_range=(0, n_iters), + strict_shape=True, ) first = sliced[0, ...] - assert np.prod(first.shape) == classes * num_parallel_tree * ntree_limit + assert np.prod(first.shape) == classes * num_parallel_tree * n_iters # When there's only 1 tree, the output is a 1 dim vector booster = xgb.train({"tree_method": "hist"}, num_boost_round=1, dtrain=m) - assert booster.predict(m, pred_leaf=True).shape == (rows, ) + assert booster.predict(m, pred_leaf=True).shape == (rows,) return leaf def test_predict_leaf(): - run_predict_leaf('cpu_predictor') + run_predict_leaf("cpu_predictor") def test_predict_shape(): from sklearn.datasets import fetch_california_housing + X, y = fetch_california_housing(return_X_y=True) reg = xgb.XGBRegressor(n_estimators=1) reg.fit(X, y) @@ -119,13 +123,14 @@ def test_predict_shape(): class TestInplacePredict: - '''Tests for running inplace prediction''' + """Tests for running inplace prediction""" + @classmethod def setup_class(cls): cls.rows = 1000 cls.cols = 10 - cls.missing = 11 # set to integer for testing + cls.missing = 11 # set to integer for testing cls.rng = np.random.RandomState(1994) @@ -139,7 +144,7 @@ class TestInplacePredict: cls.test = xgb.DMatrix(cls.X[:10, ...], missing=cls.missing) cls.num_boost_round = 10 - cls.booster = xgb.train({'tree_method': 'hist'}, dtrain, num_boost_round=10) + cls.booster = xgb.train({"tree_method": "hist"}, dtrain, num_boost_round=10) def test_predict(self): booster = self.booster @@ -162,28 +167,22 @@ class TestInplacePredict: predt_from_array = booster.inplace_predict( X[:10, ...], iteration_range=(0, 4), missing=self.missing ) - predt_from_dmatrix = booster.predict(test, ntree_limit=4) + predt_from_dmatrix = booster.predict(test, iteration_range=(0, 4)) np.testing.assert_allclose(predt_from_dmatrix, predt_from_array) - with pytest.raises(ValueError): - booster.predict(test, ntree_limit=booster.best_ntree_limit + 1) with pytest.raises(ValueError): booster.predict(test, iteration_range=(0, booster.best_iteration + 2)) default = booster.predict(test) range_full = booster.predict(test, iteration_range=(0, self.num_boost_round)) - ntree_full = booster.predict(test, ntree_limit=self.num_boost_round) np.testing.assert_allclose(range_full, default) - np.testing.assert_allclose(ntree_full, default) range_full = booster.predict( test, iteration_range=(0, booster.best_iteration + 1) ) - ntree_full = booster.predict(test, ntree_limit=booster.best_ntree_limit) np.testing.assert_allclose(range_full, default) - np.testing.assert_allclose(ntree_full, default) def predict_dense(x): inplace_predt = booster.inplace_predict(x) @@ -251,6 +250,7 @@ class TestInplacePredict: @pytest.mark.skipif(**tm.no_pandas()) def test_pd_dtypes(self) -> None: from pandas.api.types import is_bool_dtype + for orig, x in pd_dtypes(): dtypes = orig.dtypes if isinstance(orig, pd.DataFrame) else [orig.dtypes] if isinstance(orig, pd.DataFrame) and is_bool_dtype(dtypes[0]): diff --git a/tests/python/test_ranking.py b/tests/python/test_ranking.py index 30de920f7..088b681ff 100644 --- a/tests/python/test_ranking.py +++ b/tests/python/test_ranking.py @@ -60,7 +60,7 @@ def test_ranking_with_weighted_data(): assert all(p <= q for p, q in zip(auc_rec, auc_rec[1:])) for i in range(1, 11): - pred = bst.predict(dtrain, ntree_limit=i) + pred = bst.predict(dtrain, iteration_range=(0, i)) # is_sorted[i]: is i-th group correctly sorted by the ranking predictor? is_sorted = [] for k in range(0, 20, 5): diff --git a/tests/python/test_training_continuation.py b/tests/python/test_training_continuation.py index 258af760c..3ec1f1ffb 100644 --- a/tests/python/test_training_continuation.py +++ b/tests/python/test_training_continuation.py @@ -95,44 +95,39 @@ class TestTrainingContinuation: res2 = mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class)) assert res1 == res2 - gbdt_04 = xgb.train(xgb_params_02, dtrain_2class, - num_boost_round=3) - assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration + - 1) * self.num_parallel_tree - + gbdt_04 = xgb.train(xgb_params_02, dtrain_2class, num_boost_round=3) res1 = mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class)) - res2 = mean_squared_error(y_2class, - gbdt_04.predict( - dtrain_2class, - ntree_limit=gbdt_04.best_ntree_limit)) + res2 = mean_squared_error( + y_2class, + gbdt_04.predict( + dtrain_2class, iteration_range=(0, gbdt_04.best_iteration + 1) + ) + ) assert res1 == res2 - gbdt_04 = xgb.train(xgb_params_02, dtrain_2class, - num_boost_round=7, xgb_model=gbdt_04) - assert gbdt_04.best_ntree_limit == ( - gbdt_04.best_iteration + 1) * self.num_parallel_tree - + gbdt_04 = xgb.train( + xgb_params_02, dtrain_2class, num_boost_round=7, xgb_model=gbdt_04 + ) res1 = mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class)) - res2 = mean_squared_error(y_2class, - gbdt_04.predict( - dtrain_2class, - ntree_limit=gbdt_04.best_ntree_limit)) + res2 = mean_squared_error( + y_2class, + gbdt_04.predict( + dtrain_2class, iteration_range=(0, gbdt_04.best_iteration + 1) + ) + ) assert res1 == res2 gbdt_05 = xgb.train(xgb_params_03, dtrain_5class, num_boost_round=7) - assert gbdt_05.best_ntree_limit == ( - gbdt_05.best_iteration + 1) * self.num_parallel_tree gbdt_05 = xgb.train(xgb_params_03, dtrain_5class, num_boost_round=3, xgb_model=gbdt_05) - assert gbdt_05.best_ntree_limit == ( - gbdt_05.best_iteration + 1) * self.num_parallel_tree res1 = gbdt_05.predict(dtrain_5class) - res2 = gbdt_05.predict(dtrain_5class, - ntree_limit=gbdt_05.best_ntree_limit) + res2 = gbdt_05.predict( + dtrain_5class, iteration_range=(0, gbdt_05.best_iteration + 1) + ) np.testing.assert_almost_equal(res1, res2) @pytest.mark.skipif(**tm.no_sklearn()) diff --git a/tests/python/test_with_shap.py b/tests/python/test_with_shap.py index eab98f487..63d0fd11b 100644 --- a/tests/python/test_with_shap.py +++ b/tests/python/test_with_shap.py @@ -13,9 +13,9 @@ except Exception: pytestmark = pytest.mark.skipif(shap is None, reason="Requires shap package") -# Check integration is not broken from xgboost side -# Changes in binary format may cause problems -def test_with_shap(): +# xgboost removed ntree_limit in 2.0, which breaks the SHAP package. +@pytest.mark.xfail +def test_with_shap() -> None: from sklearn.datasets import fetch_california_housing X, y = fetch_california_housing(return_X_y=True) diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 90d4dff18..67620e6dd 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -63,9 +63,15 @@ def test_multiclass_classification(objective): assert xgb_model.get_booster().num_boosted_rounds() == 100 preds = xgb_model.predict(X[test_index]) # test other params in XGBClassifier().fit - preds2 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=3) - preds3 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=0) - preds4 = xgb_model.predict(X[test_index], output_margin=False, ntree_limit=3) + preds2 = xgb_model.predict( + X[test_index], output_margin=True, iteration_range=(0, 1) + ) + preds3 = xgb_model.predict( + X[test_index], output_margin=True, iteration_range=None + ) + preds4 = xgb_model.predict( + X[test_index], output_margin=False, iteration_range=(0, 1) + ) labels = y[test_index] check_pred(preds, labels, output_margin=False) @@ -86,25 +92,21 @@ def test_multiclass_classification(objective): assert proba.shape[1] == cls.n_classes_ -def test_best_ntree_limit(): +def test_best_iteration(): from sklearn.datasets import load_iris X, y = load_iris(return_X_y=True) - def train(booster, forest): + def train(booster: str, forest: Optional[int]) -> None: rounds = 4 cls = xgb.XGBClassifier( n_estimators=rounds, num_parallel_tree=forest, booster=booster ).fit( X, y, eval_set=[(X, y)], early_stopping_rounds=3 ) + assert cls.best_iteration == rounds - 1 - if forest: - assert cls.best_ntree_limit == rounds * forest - else: - assert cls.best_ntree_limit == 0 - - # best_ntree_limit is used by default, assert that under gblinear it's + # best_iteration is used by default, assert that under gblinear it's # automatically ignored due to being 0. cls.predict(X) @@ -430,12 +432,15 @@ def test_regression(): preds = xgb_model.predict(X[test_index]) # test other params in XGBRegressor().fit - preds2 = xgb_model.predict(X[test_index], output_margin=True, - ntree_limit=3) - preds3 = xgb_model.predict(X[test_index], output_margin=True, - ntree_limit=0) - preds4 = xgb_model.predict(X[test_index], output_margin=False, - ntree_limit=3) + preds2 = xgb_model.predict( + X[test_index], output_margin=True, iteration_range=(0, 3) + ) + preds3 = xgb_model.predict( + X[test_index], output_margin=True, iteration_range=None + ) + preds4 = xgb_model.predict( + X[test_index], output_margin=False, iteration_range=(0, 3) + ) labels = y[test_index] assert mean_squared_error(preds, labels) < 25 diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py index a8c64713f..0ffdb2a2b 100644 --- a/tests/test_distributed/test_with_spark/test_spark_local.py +++ b/tests/test_distributed/test_with_spark/test_spark_local.py @@ -169,7 +169,7 @@ def reg_with_weight( ) -RegData = namedtuple("RegData", ("reg_df_train", "reg_df_test")) +RegData = namedtuple("RegData", ("reg_df_train", "reg_df_test", "reg_params")) @pytest.fixture @@ -181,6 +181,13 @@ def reg_data(spark: SparkSession) -> Generator[RegData, None, None]: predt0 = reg1.predict(X) pred_contrib0: np.ndarray = pred_contribs(reg1, X, None, False) + reg_params = { + "max_depth": 5, + "n_estimators": 10, + "iteration_range": [0, 5], + "max_bin": 9, + } + # convert np array to pyspark dataframe reg_df_train_data = [ (Vectors.dense(X[0, :]), int(y[0])), @@ -188,26 +195,34 @@ def reg_data(spark: SparkSession) -> Generator[RegData, None, None]: ] reg_df_train = spark.createDataFrame(reg_df_train_data, ["features", "label"]) + reg2 = xgb.XGBRegressor(max_depth=5, n_estimators=10) + reg2.fit(X, y) + predt2 = reg2.predict(X, iteration_range=[0, 5]) + # array([0.22185266, 0.77814734], dtype=float32) + reg_df_test = spark.createDataFrame( [ ( Vectors.dense(X[0, :]), float(predt0[0]), pred_contrib0[0, :].tolist(), + float(predt2[0]), ), ( Vectors.sparse(3, {1: 1.0, 2: 5.5}), float(predt0[1]), pred_contrib0[1, :].tolist(), + float(predt2[1]), ), ], [ "features", "expected_prediction", "expected_pred_contribs", + "expected_prediction_with_params", ], ) - yield RegData(reg_df_train, reg_df_test) + yield RegData(reg_df_train, reg_df_test, reg_params) MultiClfData = namedtuple("MultiClfData", ("multi_clf_df_train", "multi_clf_df_test")) @@ -740,6 +755,76 @@ class TestPySparkLocal: model = classifier.fit(clf_data.cls_df_train) model.transform(clf_data.cls_df_test).collect() + def test_regressor_model_save_load(self, reg_data: RegData) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + path = "file:" + tmpdir + regressor = SparkXGBRegressor(**reg_data.reg_params) + model = regressor.fit(reg_data.reg_df_train) + model.save(path) + loaded_model = SparkXGBRegressorModel.load(path) + assert model.uid == loaded_model.uid + for k, v in reg_data.reg_params.items(): + assert loaded_model.getOrDefault(k) == v + + pred_result = loaded_model.transform(reg_data.reg_df_test).collect() + for row in pred_result: + assert np.isclose( + row.prediction, row.expected_prediction_with_params, atol=1e-3 + ) + + with pytest.raises(AssertionError, match="Expected class name"): + SparkXGBClassifierModel.load(path) + + assert_model_compatible(model, tmpdir) + + def test_regressor_with_params(self, reg_data: RegData) -> None: + regressor = SparkXGBRegressor(**reg_data.reg_params) + all_params = dict( + **(regressor._gen_xgb_params_dict()), + **(regressor._gen_fit_params_dict()), + **(regressor._gen_predict_params_dict()), + ) + check_sub_dict_match( + reg_data.reg_params, all_params, excluding_keys=_non_booster_params + ) + + model = regressor.fit(reg_data.reg_df_train) + all_params = dict( + **(model._gen_xgb_params_dict()), + **(model._gen_fit_params_dict()), + **(model._gen_predict_params_dict()), + ) + check_sub_dict_match( + reg_data.reg_params, all_params, excluding_keys=_non_booster_params + ) + pred_result = model.transform(reg_data.reg_df_test).collect() + for row in pred_result: + assert np.isclose( + row.prediction, row.expected_prediction_with_params, atol=1e-3 + ) + + def test_regressor_model_pipeline_save_load(self, reg_data: RegData) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + path = "file:" + tmpdir + regressor = SparkXGBRegressor() + pipeline = Pipeline(stages=[regressor]) + pipeline = pipeline.copy( + extra=get_params_map(reg_data.reg_params, regressor) + ) + model = pipeline.fit(reg_data.reg_df_train) + model.save(path) + + loaded_model = PipelineModel.load(path) + for k, v in reg_data.reg_params.items(): + assert loaded_model.stages[0].getOrDefault(k) == v + + pred_result = loaded_model.transform(reg_data.reg_df_test).collect() + for row in pred_result: + assert np.isclose( + row.prediction, row.expected_prediction_with_params, atol=1e-3 + ) + assert_model_compatible(model.stages[0], tmpdir) + class XgboostLocalTest(SparkTestCase): def setUp(self): @@ -918,12 +1003,6 @@ class XgboostLocalTest(SparkTestCase): def get_local_tmp_dir(self): return self.tempdir + str(uuid.uuid4()) - def assert_model_compatible(self, model: XGBModel, model_path: str): - bst = xgb.Booster() - path = glob.glob(f"{model_path}/**/model/part-00000", recursive=True)[0] - bst.load_model(path) - self.assertEqual(model.get_booster().save_raw("json"), bst.save_raw("json")) - def test_convert_to_sklearn_model_reg(self) -> None: regressor = SparkXGBRegressor( n_estimators=200, missing=2.0, max_depth=3, sketch_eps=0.5 @@ -1007,80 +1086,6 @@ class XgboostLocalTest(SparkTestCase): == "float64" ) - def test_regressor_with_params(self): - regressor = SparkXGBRegressor(**self.reg_params) - all_params = dict( - **(regressor._gen_xgb_params_dict()), - **(regressor._gen_fit_params_dict()), - **(regressor._gen_predict_params_dict()), - ) - check_sub_dict_match( - self.reg_params, all_params, excluding_keys=_non_booster_params - ) - - model = regressor.fit(self.reg_df_train) - all_params = dict( - **(model._gen_xgb_params_dict()), - **(model._gen_fit_params_dict()), - **(model._gen_predict_params_dict()), - ) - check_sub_dict_match( - self.reg_params, all_params, excluding_keys=_non_booster_params - ) - pred_result = model.transform(self.reg_df_test).collect() - for row in pred_result: - self.assertTrue( - np.isclose( - row.prediction, row.expected_prediction_with_params, atol=1e-3 - ) - ) - - def test_regressor_model_save_load(self): - tmp_dir = self.get_local_tmp_dir() - path = "file:" + tmp_dir - regressor = SparkXGBRegressor(**self.reg_params) - model = regressor.fit(self.reg_df_train) - model.save(path) - loaded_model = SparkXGBRegressorModel.load(path) - self.assertEqual(model.uid, loaded_model.uid) - for k, v in self.reg_params.items(): - self.assertEqual(loaded_model.getOrDefault(k), v) - - pred_result = loaded_model.transform(self.reg_df_test).collect() - for row in pred_result: - self.assertTrue( - np.isclose( - row.prediction, row.expected_prediction_with_params, atol=1e-3 - ) - ) - - with self.assertRaisesRegex(AssertionError, "Expected class name"): - SparkXGBClassifierModel.load(path) - - self.assert_model_compatible(model, tmp_dir) - - def test_regressor_model_pipeline_save_load(self): - tmp_dir = self.get_local_tmp_dir() - path = "file:" + tmp_dir - regressor = SparkXGBRegressor() - pipeline = Pipeline(stages=[regressor]) - pipeline = pipeline.copy(extra=get_params_map(self.reg_params, regressor)) - model = pipeline.fit(self.reg_df_train) - model.save(path) - - loaded_model = PipelineModel.load(path) - for k, v in self.reg_params.items(): - self.assertEqual(loaded_model.stages[0].getOrDefault(k), v) - - pred_result = loaded_model.transform(self.reg_df_test).collect() - for row in pred_result: - self.assertTrue( - np.isclose( - row.prediction, row.expected_prediction_with_params, atol=1e-3 - ) - ) - self.assert_model_compatible(model.stages[0], tmp_dir) - def test_callbacks(self): from xgboost.callback import LearningRateScheduler diff --git a/tests/test_distributed/test_with_spark/test_spark_local_cluster.py b/tests/test_distributed/test_with_spark/test_spark_local_cluster.py index 528b770ff..199a8087d 100644 --- a/tests/test_distributed/test_with_spark/test_spark_local_cluster.py +++ b/tests/test_distributed/test_with_spark/test_spark_local_cluster.py @@ -1,16 +1,24 @@ import json +import logging import os import random +import tempfile import uuid +from collections import namedtuple import numpy as np import pytest +import xgboost as xgb from xgboost import testing as tm +from xgboost.callback import LearningRateScheduler pytestmark = pytest.mark.skipif(**tm.no_spark()) +from typing import Generator + from pyspark.ml.linalg import Vectors +from pyspark.sql import SparkSession from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor from xgboost.spark.utils import _get_max_num_concurrent_tasks @@ -18,51 +26,119 @@ from xgboost.spark.utils import _get_max_num_concurrent_tasks from .utils import SparkLocalClusterTestCase +@pytest.fixture +def spark() -> Generator[SparkSession, None, None]: + config = { + "spark.master": "local-cluster[2, 2, 1024]", + "spark.python.worker.reuse": "false", + "spark.driver.host": "127.0.0.1", + "spark.task.maxFailures": "1", + "spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false", + "spark.sql.pyspark.jvmStacktrace.enabled": "true", + "spark.cores.max": "4", + "spark.task.cpus": "1", + "spark.executor.cores": "2", + } + + builder = SparkSession.builder.appName("XGBoost PySpark Python API Tests") + for k, v in config.items(): + builder.config(k, v) + logging.getLogger("pyspark").setLevel(logging.INFO) + sess = builder.getOrCreate() + yield sess + + sess.stop() + sess.sparkContext.stop() + + +RegData = namedtuple("RegData", ("reg_df_train", "reg_df_test", "reg_params")) + + +@pytest.fixture +def reg_data(spark: SparkSession) -> Generator[RegData, None, None]: + reg_params = {"max_depth": 5, "n_estimators": 10, "iteration_range": (0, 5)} + + X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]) + y = np.array([0, 1]) + + def custom_lr(boosting_round): + return 1.0 / (boosting_round + 1) + + reg1 = xgb.XGBRegressor(callbacks=[LearningRateScheduler(custom_lr)]) + reg1.fit(X, y) + predt1 = reg1.predict(X) + # array([0.02406833, 0.97593164], dtype=float32) + + reg2 = xgb.XGBRegressor(max_depth=5, n_estimators=10) + reg2.fit(X, y) + predt2 = reg2.predict(X, iteration_range=(0, 5)) + # array([0.22185263, 0.77814734], dtype=float32) + + reg_df_train = spark.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), 0), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1), + ], + ["features", "label"], + ) + reg_df_test = spark.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), 0.0, float(predt2[0]), float(predt1[0])), + ( + Vectors.sparse(3, {1: 1.0, 2: 5.5}), + 1.0, + float(predt2[1]), + float(predt1[1]), + ), + ], + [ + "features", + "expected_prediction", + "expected_prediction_with_params", + "expected_prediction_with_callbacks", + ], + ) + yield RegData(reg_df_train, reg_df_test, reg_params) + + +class TestPySparkLocalCluster: + def test_regressor_basic_with_params(self, reg_data: RegData) -> None: + regressor = SparkXGBRegressor(**reg_data.reg_params) + model = regressor.fit(reg_data.reg_df_train) + pred_result = model.transform(reg_data.reg_df_test).collect() + for row in pred_result: + assert np.isclose( + row.prediction, row.expected_prediction_with_params, atol=1e-3 + ) + + def test_callbacks(self, reg_data: RegData) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, str(uuid.uuid4())) + + def custom_lr(boosting_round): + return 1.0 / (boosting_round + 1) + + cb = [LearningRateScheduler(custom_lr)] + regressor = SparkXGBRegressor(callbacks=cb) + + # Test the save/load of the estimator instead of the model, since + # the callbacks param only exists in the estimator but not in the model + regressor.save(path) + regressor = SparkXGBRegressor.load(path) + + model = regressor.fit(reg_data.reg_df_train) + pred_result = model.transform(reg_data.reg_df_test).collect() + for row in pred_result: + assert np.isclose( + row.prediction, row.expected_prediction_with_callbacks, atol=1e-3 + ) + + class XgboostLocalClusterTestCase(SparkLocalClusterTestCase): def setUp(self): random.seed(2020) self.n_workers = _get_max_num_concurrent_tasks(self.session.sparkContext) - # The following code use xgboost python library to train xgb model and predict. - # - # >>> import numpy as np - # >>> import xgboost - # >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]) - # >>> y = np.array([0, 1]) - # >>> reg1 = xgboost.XGBRegressor() - # >>> reg1.fit(X, y) - # >>> reg1.predict(X) - # array([8.8363886e-04, 9.9911636e-01], dtype=float32) - # >>> def custom_lr(boosting_round, num_boost_round): - # ... return 1.0 / (boosting_round + 1) - # ... - # >>> reg1.fit(X, y, callbacks=[xgboost.callback.reset_learning_rate(custom_lr)]) - # >>> reg1.predict(X) - # array([0.02406833, 0.97593164], dtype=float32) - # >>> reg2 = xgboost.XGBRegressor(max_depth=5, n_estimators=10) - # >>> reg2.fit(X, y) - # >>> reg2.predict(X, ntree_limit=5) - # array([0.22185263, 0.77814734], dtype=float32) - self.reg_params = {"max_depth": 5, "n_estimators": 10, "ntree_limit": 5} - self.reg_df_train = self.session.createDataFrame( - [ - (Vectors.dense(1.0, 2.0, 3.0), 0), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1), - ], - ["features", "label"], - ) - self.reg_df_test = self.session.createDataFrame( - [ - (Vectors.dense(1.0, 2.0, 3.0), 0.0, 0.2219, 0.02406), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.0, 0.7781, 0.9759), - ], - [ - "features", - "expected_prediction", - "expected_prediction_with_params", - "expected_prediction_with_callbacks", - ], - ) # Distributed section # Binary classification @@ -218,42 +294,6 @@ class XgboostLocalClusterTestCase(SparkLocalClusterTestCase): self.reg_best_score_eval = 5.239e-05 self.reg_best_score_weight_and_eval = 4.850e-05 - def test_regressor_basic_with_params(self): - regressor = SparkXGBRegressor(**self.reg_params) - model = regressor.fit(self.reg_df_train) - pred_result = model.transform(self.reg_df_test).collect() - for row in pred_result: - self.assertTrue( - np.isclose( - row.prediction, row.expected_prediction_with_params, atol=1e-3 - ) - ) - - def test_callbacks(self): - from xgboost.callback import LearningRateScheduler - - path = os.path.join(self.tempdir, str(uuid.uuid4())) - - def custom_learning_rate(boosting_round): - return 1.0 / (boosting_round + 1) - - cb = [LearningRateScheduler(custom_learning_rate)] - regressor = SparkXGBRegressor(callbacks=cb) - - # Test the save/load of the estimator instead of the model, since - # the callbacks param only exists in the estimator but not in the model - regressor.save(path) - regressor = SparkXGBRegressor.load(path) - - model = regressor.fit(self.reg_df_train) - pred_result = model.transform(self.reg_df_test).collect() - for row in pred_result: - self.assertTrue( - np.isclose( - row.prediction, row.expected_prediction_with_callbacks, atol=1e-3 - ) - ) - def test_classifier_distributed_basic(self): classifier = SparkXGBClassifier(num_workers=self.n_workers, n_estimators=100) model = classifier.fit(self.cls_df_train_distributed) @@ -409,7 +449,6 @@ class XgboostLocalClusterTestCase(SparkLocalClusterTestCase): pred_result = model.transform( self.cls_df_test_distributed_lower_estimators ).collect() - print(pred_result) for row in pred_result: self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3)) self.assertTrue( From bcb55d3b6a5f821712713788432afbd4ad5cbbb5 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 31 Mar 2023 20:48:59 +0800 Subject: [PATCH 05/12] Portable macro definition. (#8999) --- include/xgboost/collective/socket.h | 44 ++++++++++++++++------------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/include/xgboost/collective/socket.h b/include/xgboost/collective/socket.h index 33d14fe8c..b5fa7cd70 100644 --- a/include/xgboost/collective/socket.h +++ b/include/xgboost/collective/socket.h @@ -1,11 +1,11 @@ -/*! - * Copyright (c) 2022 by XGBoost Contributors +/** + * Copyright (c) 2022-2023, XGBoost Contributors */ #pragma once #if !defined(NOMINMAX) && defined(_WIN32) #define NOMINMAX -#endif // !defined(NOMINMAX) +#endif // !defined(NOMINMAX) #include // errno, EINTR, EBADF #include // HOST_NAME_MAX @@ -18,7 +18,11 @@ #include // std::swap #if !defined(xgboost_IS_MINGW) -#define xgboost_IS_MINGW() defined(__MINGW32__) + +#if defined(__MINGW32__) +#define xgboost_IS_MINGW 1 +#endif // defined(__MINGW32__) + #endif // xgboost_IS_MINGW #if defined(_WIN32) @@ -32,11 +36,11 @@ using in_port_t = std::uint16_t; #pragma comment(lib, "Ws2_32.lib") #endif // _MSC_VER -#if !xgboost_IS_MINGW() +#if !defined(xgboost_IS_MINGW) using ssize_t = int; -#endif // !xgboost_IS_MINGW() +#endif // !xgboost_IS_MINGW() -#else // UNIX +#else // UNIX #include // inet_ntop #include // fcntl, F_GETFL, O_NONBLOCK @@ -48,9 +52,9 @@ using ssize_t = int; #if defined(__sun) || defined(sun) #include -#endif // defined(__sun) || defined(sun) +#endif // defined(__sun) || defined(sun) -#endif // defined(_WIN32) +#endif // defined(_WIN32) #include "xgboost/base.h" // XGBOOST_EXPECT #include "xgboost/logging.h" // LOG @@ -62,10 +66,10 @@ using ssize_t = int; namespace xgboost { -#if xgboost_IS_MINGW() +#if defined(xgboost_IS_MINGW) // see the dummy implementation of `poll` in rabit for more info. inline void MingWError() { LOG(FATAL) << "Distributed training on mingw is not supported."; } -#endif // xgboost_IS_MINGW() +#endif // defined(xgboost_IS_MINGW) namespace system { inline std::int32_t LastError() { @@ -144,7 +148,7 @@ inline void SocketFinalize() { #endif // defined(_WIN32) } -#if defined(_WIN32) && xgboost_IS_MINGW() +#if defined(_WIN32) && defined(xgboost_IS_MINGW) // dummy definition for old mysys32. inline const char *inet_ntop(int, const void *, char *, socklen_t) { // NOLINT MingWError(); @@ -152,7 +156,7 @@ inline const char *inet_ntop(int, const void *, char *, socklen_t) { // NOLINT } #else using ::inet_ntop; -#endif +#endif // defined(_WIN32) && defined(xgboost_IS_MINGW) } // namespace system @@ -296,13 +300,12 @@ class TCPSocket { #else struct sockaddr sa; socklen_t sizeofsa = sizeof(sa); - xgboost_CHECK_SYS_CALL( - getsockname(handle_, &sa, &sizeofsa), 0); - if (sizeofsa < sizeof(uchar_t)*2) { + xgboost_CHECK_SYS_CALL(getsockname(handle_, &sa, &sizeofsa), 0); + if (sizeofsa < sizeof(uchar_t) * 2) { return ret_iafamily(AF_INET); } return ret_iafamily(sa.sa_family); -#endif // __PASE__ +#endif // __PASE__ #else LOG(FATAL) << "Unknown platform."; return ret_iafamily(AF_INET); @@ -508,7 +511,7 @@ class TCPSocket { * \brief Create a TCP socket on specified domain. */ static TCPSocket Create(SockDomain domain) { -#if xgboost_IS_MINGW() +#if defined(xgboost_IS_MINGW) MingWError(); return {}; #else @@ -522,7 +525,7 @@ class TCPSocket { socket.domain_ = domain; #endif // defined(__APPLE__) return socket; -#endif // xgboost_IS_MINGW() +#endif // defined(xgboost_IS_MINGW) } }; @@ -544,4 +547,7 @@ inline std::string GetHostName() { } // namespace xgboost #undef xgboost_CHECK_SYS_CALL + +#if defined(xgboost_IS_MINGW) #undef xgboost_IS_MINGW +#endif From 4caca2947df22fa76af27bbaad379aafe9f6c51d Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 31 Mar 2023 23:14:58 +0800 Subject: [PATCH 06/12] Improve helper script for making release. [skip ci] (#9004) * Merge source tarball generation script. * Generate Python source wheel. * Generate hashes and release note. --- dev/release-artifacts.py | 343 +++++++++++++++++++++++++++++++++++++++ dev/release-py-r.py | 200 ----------------------- dev/release-tarball.sh | 91 ----------- doc/contrib/release.rst | 4 +- 4 files changed, 346 insertions(+), 292 deletions(-) create mode 100644 dev/release-artifacts.py delete mode 100644 dev/release-py-r.py delete mode 100755 dev/release-tarball.sh diff --git a/dev/release-artifacts.py b/dev/release-artifacts.py new file mode 100644 index 000000000..18c317a91 --- /dev/null +++ b/dev/release-artifacts.py @@ -0,0 +1,343 @@ +"""Simple script for managing Python, R, and source release packages. + +tqdm, sh are required to run this script. +""" +import argparse +import os +import shutil +import subprocess +import tarfile +import tempfile +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union +from urllib.request import urlretrieve + +import tqdm +from packaging import version +from sh.contrib import git + +# The package building is managed by Jenkins CI. +PREFIX = "https://s3-us-west-2.amazonaws.com/xgboost-nightly-builds/release_" +ROOT = Path(__file__).absolute().parent.parent +DIST = ROOT / "python-package" / "dist" + +pbar = None + + +class DirectoryExcursion: + def __init__(self, path: Union[os.PathLike, str]) -> None: + self.path = path + self.curdir = os.path.normpath(os.path.abspath(os.path.curdir)) + + def __enter__(self) -> None: + os.chdir(self.path) + + def __exit__(self, *args: Any) -> None: + os.chdir(self.curdir) + + +def show_progress(block_num, block_size, total_size): + "Show file download progress." + global pbar + if pbar is None: + pbar = tqdm.tqdm(total=total_size / 1024, unit="kB") + + downloaded = block_num * block_size + if downloaded < total_size: + upper = (total_size - downloaded) / 1024 + pbar.update(min(block_size / 1024, upper)) + else: + pbar.close() + pbar = None + + +def retrieve(url, filename=None): + print(f"{url} -> {filename}") + return urlretrieve(url, filename, reporthook=show_progress) + + +def latest_hash() -> str: + "Get latest commit hash." + ret = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True) + assert ret.returncode == 0, "Failed to get latest commit hash." + commit_hash = ret.stdout.decode("utf-8").strip() + return commit_hash + + +def download_wheels( + platforms: List[str], + dir_URL: str, + src_filename_prefix: str, + target_filename_prefix: str, + outdir: str, +) -> List[str]: + """Download all binary wheels. dir_URL is the URL for remote directory storing the + release wheels. + + """ + + filenames = [] + outdir = os.path.join(outdir, "dist") + if not os.path.exists(outdir): + os.mkdir(outdir) + + for platform in platforms: + src_wheel = src_filename_prefix + platform + ".whl" + url = dir_URL + src_wheel + + target_wheel = target_filename_prefix + platform + ".whl" + filename = os.path.join(outdir, target_wheel) + filenames.append(filename) + retrieve(url=url, filename=filename) + ret = subprocess.run(["twine", "check", filename], capture_output=True) + assert ret.returncode == 0, "Failed twine check" + stderr = ret.stderr.decode("utf-8") + stdout = ret.stdout.decode("utf-8") + assert stderr.find("warning") == -1, "Unresolved warnings:\n" + stderr + assert stdout.find("warning") == -1, "Unresolved warnings:\n" + stdout + return filenames + + +def make_pysrc_wheel(release: str, outdir: str) -> None: + """Make Python source distribution.""" + dist = os.path.join(outdir, "dist") + if not os.path.exists(dist): + os.mkdir(dist) + + with DirectoryExcursion(os.path.join(ROOT, "python-package")): + subprocess.check_call(["python", "setup.py", "sdist"]) + src = os.path.join(DIST, f"xgboost-{release}.tar.gz") + subprocess.check_call(["twine", "check", src]) + shutil.move(src, os.path.join(dist, f"xgboost-{release}.tar.gz")) + + +def download_py_packages( + branch: str, major: int, minor: int, commit_hash: str, outdir: str +) -> None: + platforms = [ + "win_amd64", + "manylinux2014_x86_64", + "manylinux2014_aarch64", + "macosx_10_15_x86_64.macosx_11_0_x86_64.macosx_12_0_x86_64", + "macosx_12_0_arm64", + ] + + branch = branch.split("_")[1] # release_x.y.z + dir_URL = PREFIX + branch + "/" + src_filename_prefix = "xgboost-" + args.release + "%2B" + commit_hash + "-py3-none-" + target_filename_prefix = "xgboost-" + args.release + "-py3-none-" + + if not os.path.exists(DIST): + os.mkdir(DIST) + + filenames = download_wheels( + platforms, dir_URL, src_filename_prefix, target_filename_prefix, outdir + ) + print("List of downloaded wheels:", filenames) + print( + """ +Following steps should be done manually: +- Upload pypi package by `python3 -m twine upload dist/` for all wheels. +- Check the uploaded files on `https://pypi.org/project/xgboost//#files` and + `pip install xgboost==` """ + ) + + +def download_r_packages( + release: str, branch: str, rc: str, commit: str, outdir: str +) -> Tuple[Dict[str, str], List[str]]: + platforms = ["win64", "linux"] + dirname = os.path.join(outdir, "r-packages") + if not os.path.exists(dirname): + os.mkdir(dirname) + + filenames = [] + branch = branch.split("_")[1] # release_x.y.z + urls = {} + + for plat in platforms: + url = f"{PREFIX}{branch}/xgboost_r_gpu_{plat}_{commit}.tar.gz" + + if not rc: + filename = f"xgboost_r_gpu_{plat}_{release}.tar.gz" + else: + filename = f"xgboost_r_gpu_{plat}_{release}-{rc}.tar.gz" + + target = os.path.join(dirname, filename) + retrieve(url=url, filename=target) + filenames.append(target) + urls[plat] = url + + print("Finished downloading R packages:", filenames) + hashes = [] + with DirectoryExcursion(os.path.join(outdir, "r-packages")): + for f in filenames: + ret = subprocess.run(["sha256sum", os.path.basename(f)], capture_output=True) + h = ret.stdout.decode().strip() + hashes.append(h) + return urls, hashes + + +def check_path(): + root = os.path.abspath(os.path.curdir) + assert os.path.basename(root) == "xgboost", "Must be run on project root." + + +def make_src_package(release: str, outdir: str) -> Tuple[str, str]: + tarname = f"xgboost-{release}.tar.gz" + tarpath = os.path.join(outdir, tarname) + if os.path.exists(tarpath): + os.remove(tarpath) + + with tempfile.TemporaryDirectory() as tmpdir_str: + tmpdir = Path(tmpdir_str) + shutil.copytree(os.path.curdir, tmpdir / "xgboost") + with DirectoryExcursion(tmpdir / "xgboost"): + ret = subprocess.run( + ["git", "submodule", "foreach", "--quiet", "echo $sm_path"], + capture_output=True, + ) + submodules = ret.stdout.decode().strip().split() + for mod in submodules: + mod_path = os.path.join(os.path.abspath(os.path.curdir), mod, ".git") + os.remove(mod_path) + shutil.rmtree(".git") + with tarfile.open(tarpath, "x:gz") as tar: + src = tmpdir / "xgboost" + tar.add(src, arcname="xgboost") + + with DirectoryExcursion(os.path.dirname(tarpath)): + ret = subprocess.run(["sha256sum", tarname], capture_output=True) + h = ret.stdout.decode().strip() + return tarname, h + + +def release_note( + release: str, + artifact_hashes: List[str], + r_urls: Dict[str, str], + tarname: str, + outdir: str, +) -> None: + """Generate a note for GitHub release description.""" + r_gpu_linux_url = r_urls["linux"] + r_gpu_win64_url = r_urls["win64"] + src_tarball = ( + f"https://github.com/dmlc/xgboost/releases/download/v{release}/{tarname}" + ) + hash_note = "\n".join(artifact_hashes) + + end_note = f""" +### Additional artifacts: + +You can verify the downloaded packages by running the following command on your Unix shell: + +``` sh +echo " " | shasum -a 256 --check +``` + +``` +{hash_note} +``` + +**Experimental binary packages for R with CUDA enabled** +* xgboost_r_gpu_linux_1.7.5.tar.gz: [Download]({r_gpu_linux_url}) +* xgboost_r_gpu_win64_1.7.5.tar.gz: [Download]({r_gpu_win64_url}) + +**Source tarball** +* xgboost.tar.gz: [Download]({src_tarball})""" + print(end_note) + with open(os.path.join(outdir, "end_note.md"), "w") as fd: + fd.write(end_note) + + +def main(args: argparse.Namespace) -> None: + check_path() + + rel = version.parse(args.release) + assert isinstance(rel, version.Version) + + major = rel.major + minor = rel.minor + patch = rel.micro + + print("Release:", rel) + if not rel.is_prerelease: + # Major release + rc: Optional[str] = None + rc_ver: Optional[int] = None + else: + # RC release + major = rel.major + minor = rel.minor + patch = rel.micro + assert rel.pre is not None + rc, rc_ver = rel.pre + assert rc == "rc" + + release = str(major) + "." + str(minor) + "." + str(patch) + if args.branch is not None: + branch = args.branch + else: + branch = "release_" + str(major) + "." + str(minor) + ".0" + + git.clean("-xdf") + git.checkout(branch) + git.pull("origin", branch) + git.submodule("update") + commit_hash = latest_hash() + + if not os.path.exists(args.outdir): + os.mkdir(args.outdir) + + # source tarball + hashes: List[str] = [] + tarname, h = make_src_package(release, args.outdir) + hashes.append(h) + + # CUDA R packages + urls, hr = download_r_packages( + release, + branch, + "" if rc is None else rc + str(rc_ver), + commit_hash, + args.outdir, + ) + hashes.extend(hr) + + # Python source wheel + make_pysrc_wheel(release, args.outdir) + + # Python binary wheels + download_py_packages(branch, major, minor, commit_hash, args.outdir) + + # Write end note + release_note(release, hashes, urls, tarname, args.outdir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--release", + type=str, + required=True, + help="Version tag, e.g. '1.3.2', or '1.5.0rc1'", + ) + parser.add_argument( + "--branch", + type=str, + default=None, + help=( + "Optional branch. Usually patch releases reuse the same branch of the" + " major release, but there can be exception." + ), + ) + parser.add_argument( + "--outdir", + type=str, + default=None, + required=True, + help="Directory to store the generated packages.", + ) + args = parser.parse_args() + main(args) diff --git a/dev/release-py-r.py b/dev/release-py-r.py deleted file mode 100644 index 11524927d..000000000 --- a/dev/release-py-r.py +++ /dev/null @@ -1,200 +0,0 @@ -"""Simple script for downloading and checking pypi release wheels. - -tqdm, sh are required to run this script. -""" -import argparse -import os -import subprocess -from typing import List, Optional -from urllib.request import urlretrieve - -import tqdm -from packaging import version -from sh.contrib import git - -# The package building is managed by Jenkins CI. -PREFIX = "https://s3-us-west-2.amazonaws.com/xgboost-nightly-builds/release_" -DIST = os.path.join(os.path.curdir, "python-package", "dist") - -pbar = None - - -def show_progress(block_num, block_size, total_size): - "Show file download progress." - global pbar - if pbar is None: - pbar = tqdm.tqdm(total=total_size / 1024, unit="kB") - - downloaded = block_num * block_size - if downloaded < total_size: - upper = (total_size - downloaded) / 1024 - pbar.update(min(block_size / 1024, upper)) - else: - pbar.close() - pbar = None - - -def retrieve(url, filename=None): - print(f"{url} -> {filename}") - return urlretrieve(url, filename, reporthook=show_progress) - - -def latest_hash() -> str: - "Get latest commit hash." - ret = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True) - assert ret.returncode == 0, "Failed to get latest commit hash." - commit_hash = ret.stdout.decode("utf-8").strip() - return commit_hash - - -def download_wheels( - platforms: List[str], - dir_URL: str, - src_filename_prefix: str, - target_filename_prefix: str, -) -> List[str]: - """Download all binary wheels. dir_URL is the URL for remote directory storing the release - wheels - - """ - - filenames = [] - for platform in platforms: - src_wheel = src_filename_prefix + platform + ".whl" - url = dir_URL + src_wheel - - target_wheel = target_filename_prefix + platform + ".whl" - filename = os.path.join(DIST, target_wheel) - filenames.append(filename) - retrieve(url=url, filename=filename) - ret = subprocess.run(["twine", "check", filename], capture_output=True) - assert ret.returncode == 0, "Failed twine check" - stderr = ret.stderr.decode("utf-8") - stdout = ret.stdout.decode("utf-8") - assert stderr.find("warning") == -1, "Unresolved warnings:\n" + stderr - assert stdout.find("warning") == -1, "Unresolved warnings:\n" + stdout - return filenames - - -def download_py_packages(branch: str, major: int, minor: int, commit_hash: str) -> None: - platforms = [ - "win_amd64", - "manylinux2014_x86_64", - "manylinux2014_aarch64", - "macosx_10_15_x86_64.macosx_11_0_x86_64.macosx_12_0_x86_64", - "macosx_12_0_arm64" - ] - - branch = branch.split("_")[1] # release_x.y.z - dir_URL = PREFIX + branch + "/" - src_filename_prefix = "xgboost-" + args.release + "%2B" + commit_hash + "-py3-none-" - target_filename_prefix = "xgboost-" + args.release + "-py3-none-" - - if not os.path.exists(DIST): - os.mkdir(DIST) - - filenames = download_wheels( - platforms, dir_URL, src_filename_prefix, target_filename_prefix - ) - print("List of downloaded wheels:", filenames) - print( - """ -Following steps should be done manually: -- Generate source package by running `python setup.py sdist`. -- Upload pypi package by `python3 -m twine upload dist/` for all wheels. -- Check the uploaded files on `https://pypi.org/project/xgboost//#files` and `pip - install xgboost==` """ - ) - - -def download_r_packages(release: str, branch: str, rc: str, commit: str) -> None: - platforms = ["win64", "linux"] - dirname = "./r-packages" - if not os.path.exists(dirname): - os.mkdir(dirname) - - filenames = [] - branch = branch.split("_")[1] # release_x.y.z - - for plat in platforms: - url = f"{PREFIX}{branch}/xgboost_r_gpu_{plat}_{commit}.tar.gz" - - if not rc: - filename = f"xgboost_r_gpu_{plat}_{release}.tar.gz" - else: - filename = f"xgboost_r_gpu_{plat}_{release}-{rc}.tar.gz" - - target = os.path.join(dirname, filename) - retrieve(url=url, filename=target) - filenames.append(target) - - print("Finished downloading R packages:", filenames) - - -def check_path(): - root = os.path.abspath(os.path.curdir) - assert os.path.basename(root) == "xgboost", "Must be run on project root." - - -def main(args: argparse.Namespace) -> None: - check_path() - - rel = version.parse(args.release) - assert isinstance(rel, version.Version) - - major = rel.major - minor = rel.minor - patch = rel.micro - - print("Release:", rel) - if not rel.is_prerelease: - # Major release - rc: Optional[str] = None - rc_ver: Optional[int] = None - else: - # RC release - major = rel.major - minor = rel.minor - patch = rel.micro - assert rel.pre is not None - rc, rc_ver = rel.pre - assert rc == "rc" - - release = str(major) + "." + str(minor) + "." + str(patch) - if args.branch is not None: - branch = args.branch - else: - branch = "release_" + str(major) + "." + str(minor) + ".0" - - git.clean("-xdf") - git.checkout(branch) - git.pull("origin", branch) - git.submodule("update") - commit_hash = latest_hash() - - download_r_packages( - release, branch, "" if rc is None else rc + str(rc_ver), commit_hash - ) - - download_py_packages(branch, major, minor, commit_hash) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--release", - type=str, - required=True, - help="Version tag, e.g. '1.3.2', or '1.5.0rc1'" - ) - parser.add_argument( - "--branch", - type=str, - default=None, - help=( - "Optional branch. Usually patch releases reuse the same branch of the" - " major release, but there can be exception." - ) - ) - args = parser.parse_args() - main(args) diff --git a/dev/release-tarball.sh b/dev/release-tarball.sh deleted file mode 100755 index c2c24f9a9..000000000 --- a/dev/release-tarball.sh +++ /dev/null @@ -1,91 +0,0 @@ -#!/usr/bin/env bash - -# Helper script for creating release tarball. - -print_usage() { - printf "Script for making release source tarball.\n" - printf "Usage:\n\trelease-tarball.sh \n\n" -} - -print_error() { - local msg=$1 - printf "\u001b[31mError\u001b[0m: $msg\n\n" - print_usage -} - -check_input() { - local TAG=$1 - if [ -z $TAG ]; then - print_error "Empty tag argument" - exit -1 - fi -} - -check_curdir() { - local CUR_ABS=$1 - printf "Current directory: ${CUR_ABS}\n" - local CUR=$(basename $CUR_ABS) - - if [ $CUR == "dev" ]; then - cd .. - CUR=$(basename $(pwd)) - fi - - if [ $CUR != "xgboost" ]; then - print_error "Must be in project root or xgboost/dev. Current directory: ${CUR}" - exit -1; - fi -} - -# Remove all submodules. -cleanup_git() { - local TAG=$1 - check_input $TAG - - git checkout $TAG || exit -1 - - local SUBMODULES=$(grep "path = " ./.gitmodules | cut -f 3 --delimiter=' ' -) - - for module in $SUBMODULES; do - rm -rf ${module}/.git - done - - rm -rf .git -} - -make_tarball() { - local SRCDIR=$1 - local CUR_ABS=$2 - tar -czf xgboost.tar.gz xgboost - - printf "Copying ${SRCDIR}/xgboost.tar.gz back to ${CUR_ABS}/xgboost.tar.gz .\n" - cp xgboost.tar.gz ${CUR_ABS}/xgboost.tar.gz - printf "Writing hash to ${CUR_ABS}/hash .\n" - sha256sum -z ${CUR_ABS}/xgboost.tar.gz | cut -f 1 --delimiter=' ' > ${CUR_ABS}/hash -} - -main() { - local TAG=$1 - check_input $TAG - - local CUR_ABS=$(pwd) - check_curdir $CUR_ABS - - local TMPDIR=$(mktemp -d) - printf "tmpdir: ${TMPDIR}\n" - - git clean -xdf || exit -1 - cp -R . $TMPDIR/xgboost - pushd . - - cd $TMPDIR/xgboost - cleanup_git $TAG - - cd .. - make_tarball $TMPDIR $CUR_ABS - - popd - rm -rf $TMPDIR -} - -main $1 diff --git a/doc/contrib/release.rst b/doc/contrib/release.rst index 2799d8fe2..c0370b14e 100644 --- a/doc/contrib/release.rst +++ b/doc/contrib/release.rst @@ -23,7 +23,9 @@ Making a Release 5. Make a release on GitHub tag page, which might be done with previous step if the tag is created on GitHub. 6. Submit pip, CRAN, and Maven packages. - + The pip package is maintained by `Hyunsu Cho `__ and `Jiaming Yuan `__. There's a helper script for downloading pre-built wheels and R packages ``xgboost/dev/release-pypi-r.py`` along with simple instructions for using ``twine``. + There are helper scripts for automating the process in ``xgboost/dev/``. + + + The pip package is maintained by `Hyunsu Cho `__ and `Jiaming Yuan `__. + The CRAN package is maintained by `Tong He `_ and `Jiaming Yuan `__. From 720a8c3273a844679e1edf11bd87297331f19d58 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 1 Apr 2023 04:04:30 +0800 Subject: [PATCH 07/12] [doc] Remove parameter type in Python doc strings. (#9005) --- demo/guide-python/quantile_regression.py | 2 + doc/parameter.rst | 8 ++ python-package/xgboost/core.py | 106 +++++++++++------------ python-package/xgboost/plotting.py | 63 +++++++------- python-package/xgboost/sklearn.py | 18 ++-- python-package/xgboost/training.py | 2 +- 6 files changed, 105 insertions(+), 94 deletions(-) diff --git a/demo/guide-python/quantile_regression.py b/demo/guide-python/quantile_regression.py index d92115bf0..6d3e08df5 100644 --- a/demo/guide-python/quantile_regression.py +++ b/demo/guide-python/quantile_regression.py @@ -2,6 +2,8 @@ Quantile Regression =================== + .. versionadded:: 2.0.0 + The script is inspired by this awesome example in sklearn: https://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_quantile.html diff --git a/doc/parameter.rst b/doc/parameter.rst index e26ec83b2..c070e7018 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -360,7 +360,13 @@ Specify the learning task and the corresponding learning objective. The objectiv - ``reg:logistic``: logistic regression. - ``reg:pseudohubererror``: regression with Pseudo Huber loss, a twice differentiable alternative to absolute loss. - ``reg:absoluteerror``: Regression with L1 error. When tree model is used, leaf value is refreshed after tree construction. If used in distributed training, the leaf value is calculated as the mean value from all workers, which is not guaranteed to be optimal. + + .. versionadded:: 1.7.0 + - ``reg:quantileerror``: Quantile loss, also known as ``pinball loss``. See later sections for its parameter and :ref:`sphx_glr_python_examples_quantile_regression.py` for a worked example. + + .. versionadded:: 2.0.0 + - ``binary:logistic``: logistic regression for binary classification, output probability - ``binary:logitraw``: logistic regression for binary classification, output score before logistic transformation - ``binary:hinge``: hinge loss for binary classification. This makes predictions of 0 or 1, rather than producing probabilities. @@ -467,6 +473,8 @@ Parameter for using Quantile Loss (``reg:quantileerror``) * ``quantile_alpha``: A scala or a list of targeted quantiles. + .. versionadded:: 2.0.0 + Parameter for using AFT Survival Loss (``survival:aft``) and Negative Log Likelihood of AFT metric (``aft-nloglik``) ==================================================================================================================== diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 68346d900..3a27f5e18 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -94,9 +94,9 @@ def from_cstr_to_pystr(data: CStrPptr, length: c_bst_ulong) -> List[str]: Parameters ---------- - data : ctypes pointer + data : pointer to data - length : ctypes pointer + length : pointer to length of data """ res = [] @@ -131,9 +131,9 @@ def _expect(expectations: Sequence[Type], got: Type) -> str: Parameters ---------- - expectations: sequence + expectations : a list of expected value. - got: + got : actual input Returns @@ -263,7 +263,7 @@ def _check_call(ret: int) -> None: Parameters ---------- - ret : int + ret : return value from API calls """ if ret != 0: @@ -271,10 +271,10 @@ def _check_call(ret: int) -> None: def build_info() -> dict: - """Build information of XGBoost. The returned value format is not stable. Also, please - note that build time dependency is not the same as runtime dependency. For instance, - it's possible to build XGBoost with older CUDA version but run it with the lastest - one. + """Build information of XGBoost. The returned value format is not stable. Also, + please note that build time dependency is not the same as runtime dependency. For + instance, it's possible to build XGBoost with older CUDA version but run it with the + lastest one. .. versionadded:: 1.6.0 @@ -658,28 +658,28 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m data : Data source of DMatrix. See :ref:`py-data` for a list of supported input types. - label : array_like + label : Label of the training data. - weight : array_like + weight : Weight for each instance. - .. note:: For ranking task, weights are per-group. + .. note:: - In ranking task, one weight is assigned to each group (not each - data point). This is because we only care about the relative - ordering of data points within each group, so it doesn't make - sense to assign weights to individual data points. + For ranking task, weights are per-group. In ranking task, one weight + is assigned to each group (not each data point). This is because we + only care about the relative ordering of data points within each group, + so it doesn't make sense to assign weights to individual data points. - base_margin: array_like + base_margin : Base margin used for boosting from existing model. - missing : float, optional - Value in the input data which needs to be present as a missing - value. If None, defaults to np.nan. - silent : boolean, optional + missing : + Value in the input data which needs to be present as a missing value. If + None, defaults to np.nan. + silent : Whether print messages during construction - feature_names : list, optional + feature_names : Set names for features. - feature_types : FeatureTypes + feature_types : Set types for features. When `enable_categorical` is set to `True`, string "c" represents categorical data type while "q" represents numerical feature @@ -689,20 +689,20 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m `.cat.codes` method. This is useful when users want to specify categorical features without having to construct a dataframe as input. - nthread : integer, optional + nthread : Number of threads to use for loading data when parallelization is applicable. If -1, uses maximum threads available on the system. - group : array_like + group : Group size for all ranking group. - qid : array_like + qid : Query ID for data samples, used for ranking. - label_lower_bound : array_like + label_lower_bound : Lower bound for survival training. - label_upper_bound : array_like + label_upper_bound : Upper bound for survival training. - feature_weights : array_like, optional + feature_weights : Set feature weights for column sampling. - enable_categorical: boolean, optional + enable_categorical : .. versionadded:: 1.3.0 @@ -1712,6 +1712,7 @@ class Booster: string. .. versionadded:: 1.0.0 + """ json_string = ctypes.c_char_p() length = c_bst_ulong() @@ -1744,8 +1745,8 @@ class Booster: Returns ------- - booster: `Booster` - a copied booster model + booster : + A copied booster model """ return copy.copy(self) @@ -1754,12 +1755,12 @@ class Booster: Parameters ---------- - key : str + key : The key to get attribute from. Returns ------- - value : str + value : The attribute value of the key, returns None if attribute do not exist. """ ret = ctypes.c_char_p() @@ -1878,9 +1879,9 @@ class Booster: Parameters ---------- - params: dict/list/str + params : list of key,value pairs, dict of key to value or simply str key - value: optional + value : value of the specified parameter, when params is str key """ if isinstance(params, Mapping): @@ -1903,11 +1904,11 @@ class Booster: Parameters ---------- - dtrain : DMatrix + dtrain : Training data. - iteration : int + iteration : Current iteration number. - fobj : function + fobj : Customized objective function. """ @@ -2205,8 +2206,7 @@ class Booster: Parameters ---------- - data : numpy.ndarray/scipy.sparse.csr_matrix/cupy.ndarray/ - cudf.DataFrame/pd.DataFrame + data : The input data, must not be a view for numpy array. Set ``predictor`` to ``gpu_predictor`` for running prediction on CuPy array or CuDF DataFrame. @@ -2390,7 +2390,7 @@ class Booster: Parameters ---------- - fname : string or os.PathLike + fname : Output file name """ @@ -2494,13 +2494,13 @@ class Booster: Parameters ---------- - fout : string or os.PathLike + fout : Output file name. - fmap : string or os.PathLike, optional + fmap : Name of the file containing feature map names. - with_stats : bool, optional + with_stats : Controls whether the split statistics are output. - dump_format : string, optional + dump_format : Format of model dump file. Can be 'text' or 'json'. """ if isinstance(fout, (str, os.PathLike)): @@ -2604,9 +2604,9 @@ class Booster: Parameters ---------- - fmap: + fmap : The name of feature map file. - importance_type: + importance_type : One of the importance types defined above. Returns @@ -2655,7 +2655,7 @@ class Booster: Parameters ---------- - fmap: str or os.PathLike (optional) + fmap : The name of feature map file. """ # pylint: disable=too-many-locals @@ -2821,15 +2821,15 @@ class Booster: Parameters ---------- - feature: str + feature : The name of the feature. - fmap: str or os.PathLike (optional) + fmap: The name of feature map file. - bin: int, default None + bin : The maximum number of bins. Number of bins equals number of unique split values n_unique, if bins == None or bins > n_unique. - as_pandas: bool, default True + as_pandas : Return pd.DataFrame when pandas is installed. If False or pandas is not installed, return numpy ndarray. diff --git a/python-package/xgboost/plotting.py b/python-package/xgboost/plotting.py index a364e1eb6..71058e8c9 100644 --- a/python-package/xgboost/plotting.py +++ b/python-package/xgboost/plotting.py @@ -1,10 +1,9 @@ # pylint: disable=too-many-locals, too-many-arguments, invalid-name, # pylint: disable=too-many-branches -# coding: utf-8 """Plotting Library.""" import json from io import BytesIO -from typing import Any, Optional +from typing import Any, Optional, Union import numpy as np @@ -17,7 +16,7 @@ GraphvizSource = Any # real type is graphviz.Source def plot_importance( - booster: Booster, + booster: Union[XGBModel, Booster, dict], ax: Optional[Axes] = None, height: float = 0.2, xlim: Optional[tuple] = None, @@ -37,40 +36,42 @@ def plot_importance( Parameters ---------- - booster : Booster, XGBModel or dict + booster : Booster or XGBModel instance, or dict taken by Booster.get_fscore() - ax : matplotlib Axes, default None + ax : matplotlib Axes Target axes instance. If None, new figure and axes will be created. - grid : bool, Turn the axes grids on or off. Default is True (On). - importance_type : str, default "weight" + grid : + Turn the axes grids on or off. Default is True (On). + importance_type : How the importance is calculated: either "weight", "gain", or "cover" * "weight" is the number of times a feature appears in a tree * "gain" is the average gain of splits which use the feature * "cover" is the average coverage of splits which use the feature where coverage is defined as the number of samples affected by the split - max_num_features : int, default None - Maximum number of top features displayed on plot. If None, all features will be displayed. - height : float, default 0.2 + max_num_features : + Maximum number of top features displayed on plot. If None, all features will be + displayed. + height : Bar height, passed to ax.barh() - xlim : tuple, default None + xlim : Tuple passed to axes.xlim() - ylim : tuple, default None + ylim : Tuple passed to axes.ylim() - title : str, default "Feature importance" + title : Axes title. To disable, pass None. - xlabel : str, default "F score" + xlabel : X axis title label. To disable, pass None. - ylabel : str, default "Features" + ylabel : Y axis title label. To disable, pass None. - fmap: str or os.PathLike (optional) + fmap : The name of feature map file. - show_values : bool, default True + show_values : Show values on plot. To disable, pass False. - values_format : str, default "{v}" - Format string for values. "v" will be replaced by the value of the feature importance. - e.g. Pass "{v:.2f}" in order to limit the number of digits after the decimal point - to two, for each value printed on the graph. + values_format : + Format string for values. "v" will be replaced by the value of the feature + importance. e.g. Pass "{v:.2f}" in order to limit the number of digits after + the decimal point to two, for each value printed on the graph. kwargs : Other keywords passed to ax.barh() @@ -146,7 +147,7 @@ def plot_importance( def to_graphviz( - booster: Booster, + booster: Union[Booster, XGBModel], fmap: PathLike = "", num_trees: int = 0, rankdir: Optional[str] = None, @@ -162,19 +163,19 @@ def to_graphviz( Parameters ---------- - booster : Booster, XGBModel + booster : Booster or XGBModel instance - fmap: str (optional) + fmap : The name of feature map file - num_trees : int, default 0 + num_trees : Specify the ordinal number of target tree - rankdir : str, default "UT" + rankdir : Passed to graphviz via graph_attr - yes_color : str, default '#0000FF' + yes_color : Edge color when meets the node condition. - no_color : str, default '#FF0000' + no_color : Edge color when doesn't meet the node condition. - condition_node_params : dict, optional + condition_node_params : Condition node configuration for for graphviz. Example: .. code-block:: python @@ -183,7 +184,7 @@ def to_graphviz( 'style': 'filled,rounded', 'fillcolor': '#78bceb'} - leaf_node_params : dict, optional + leaf_node_params : Leaf node configuration for graphviz. Example: .. code-block:: python @@ -192,7 +193,7 @@ def to_graphviz( 'style': 'filled', 'fillcolor': '#e48038'} - \\*\\*kwargs: dict, optional + kwargs : Other keywords passed to graphviz graph_attr, e.g. ``graph [ {key} = {value} ]`` Returns diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index fffc0eb9b..9b5949cdb 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -1012,9 +1012,9 @@ class XGBModel(XGBModelBase): verbose : If `verbose` is True and an evaluation set is used, the evaluation metric measured on the validation set is printed to stdout at each boosting stage. - If `verbose` is an integer, the evaluation metric is printed at each `verbose` - boosting stage. The last boosting stage / the boosting stage found by using - `early_stopping_rounds` is also printed. + If `verbose` is an integer, the evaluation metric is printed at each + `verbose` boosting stage. The last boosting stage / the boosting stage found + by using `early_stopping_rounds` is also printed. xgb_model : file name of stored XGBoost model or 'Booster' instance XGBoost model to be loaded before training (allows training continuation). @@ -1590,12 +1590,12 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase): Parameters ---------- - X : array_like + X : Feature matrix. See :ref:`py-data` for a list of supported types. - validate_features : bool + validate_features : When this is True, validate that the Booster's and data's feature_names are identical. Otherwise, it is assumed that the feature_names are the same. - base_margin : array_like + base_margin : Margin added to prediction. iteration_range : Specifies which layer of trees are used in prediction. For example, if a @@ -1964,9 +1964,9 @@ class XGBRanker(XGBModel, XGBRankerMixIn): verbose : If `verbose` is True and an evaluation set is used, the evaluation metric measured on the validation set is printed to stdout at each boosting stage. - If `verbose` is an integer, the evaluation metric is printed at each `verbose` - boosting stage. The last boosting stage / the boosting stage found by using - `early_stopping_rounds` is also printed. + If `verbose` is an integer, the evaluation metric is printed at each + `verbose` boosting stage. The last boosting stage / the boosting stage found + by using `early_stopping_rounds` is also printed. xgb_model : file name of stored XGBoost model or 'Booster' instance XGBoost model to be loaded before training (allows training continuation). diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py index 5ef6eeaa2..a238e73c8 100644 --- a/python-package/xgboost/training.py +++ b/python-package/xgboost/training.py @@ -95,7 +95,7 @@ def train( feval : .. deprecated:: 1.6.0 Use `custom_metric` instead. - maximize : bool + maximize : Whether to maximize feval. early_stopping_rounds : Activates early stopping. Validation metric needs to improve at least once in From 15e073ca9d633d622eae94d33b846b9fd0ecfa3e Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 3 Apr 2023 02:07:42 -0700 Subject: [PATCH 08/12] Make objectives work with vertical distributed and federated learning (#9002) --- src/objective/adaptive.cc | 65 ++++++----- src/objective/adaptive.cu | 4 +- src/objective/adaptive.h | 14 ++- src/objective/quantile_obj.cu | 11 +- src/objective/regression_obj.cu | 6 +- tests/cpp/plugin/test_federated_learner.cc | 127 ++++++++++++--------- tests/cpp/test_learner.cc | 83 ++++++++++---- 7 files changed, 199 insertions(+), 111 deletions(-) diff --git a/src/objective/adaptive.cc b/src/objective/adaptive.cc index 4a67e848b..32fda9ef1 100644 --- a/src/objective/adaptive.cc +++ b/src/objective/adaptive.cc @@ -85,7 +85,7 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector const& posit size_t n_leaf = nidx.size(); if (nptr.empty()) { std::vector quantiles; - UpdateLeafValues(&quantiles, nidx, learning_rate, p_tree); + UpdateLeafValues(&quantiles, nidx, info, learning_rate, p_tree); return; } @@ -99,39 +99,46 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector const& posit auto h_predt = linalg::MakeTensorView(ctx, predt.ConstHostSpan(), info.num_row_, predt.Size() / info.num_row_); - // loop over each leaf - common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) { - auto nidx = h_node_idx[k]; - CHECK(tree[nidx].IsLeaf()); - CHECK_LT(k + 1, h_node_ptr.size()); - size_t n = h_node_ptr[k + 1] - h_node_ptr[k]; - auto h_row_set = common::Span{ridx}.subspan(h_node_ptr[k], n); + if (!info.IsVerticalFederated() || collective::GetRank() == 0) { + // loop over each leaf + common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) { + auto nidx = h_node_idx[k]; + CHECK(tree[nidx].IsLeaf()); + CHECK_LT(k + 1, h_node_ptr.size()); + size_t n = h_node_ptr[k + 1] - h_node_ptr[k]; + auto h_row_set = common::Span{ridx}.subspan(h_node_ptr[k], n); - auto h_labels = info.labels.HostView().Slice(linalg::All(), IdxY(info, group_idx)); - auto h_weights = linalg::MakeVec(&info.weights_); + auto h_labels = info.labels.HostView().Slice(linalg::All(), IdxY(info, group_idx)); + auto h_weights = linalg::MakeVec(&info.weights_); - auto iter = common::MakeIndexTransformIter([&](size_t i) -> float { - auto row_idx = h_row_set[i]; - return h_labels(row_idx) - h_predt(row_idx, group_idx); - }); - auto w_it = common::MakeIndexTransformIter([&](size_t i) -> float { - auto row_idx = h_row_set[i]; - return h_weights(row_idx); + auto iter = common::MakeIndexTransformIter([&](size_t i) -> float { + auto row_idx = h_row_set[i]; + return h_labels(row_idx) - h_predt(row_idx, group_idx); + }); + auto w_it = common::MakeIndexTransformIter([&](size_t i) -> float { + auto row_idx = h_row_set[i]; + return h_weights(row_idx); + }); + + float q{0}; + if (info.weights_.Empty()) { + q = common::Quantile(ctx, alpha, iter, iter + h_row_set.size()); + } else { + q = common::WeightedQuantile(ctx, alpha, iter, iter + h_row_set.size(), w_it); + } + if (std::isnan(q)) { + CHECK(h_row_set.empty()); + } + quantiles.at(k) = q; }); + } - float q{0}; - if (info.weights_.Empty()) { - q = common::Quantile(ctx, alpha, iter, iter + h_row_set.size()); - } else { - q = common::WeightedQuantile(ctx, alpha, iter, iter + h_row_set.size(), w_it); - } - if (std::isnan(q)) { - CHECK(h_row_set.empty()); - } - quantiles.at(k) = q; - }); + if (info.IsVerticalFederated()) { + collective::Broadcast(static_cast(quantiles.data()), quantiles.size() * sizeof(float), + 0); + } - UpdateLeafValues(&quantiles, nidx, learning_rate, p_tree); + UpdateLeafValues(&quantiles, nidx, info, learning_rate, p_tree); } #if !defined(XGBOOST_USE_CUDA) diff --git a/src/objective/adaptive.cu b/src/objective/adaptive.cu index 662b0330b..bba8b85ad 100644 --- a/src/objective/adaptive.cu +++ b/src/objective/adaptive.cu @@ -151,7 +151,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span pos if (nptr.Empty()) { std::vector quantiles; - UpdateLeafValues(&quantiles, nidx.ConstHostVector(), learning_rate, p_tree); + UpdateLeafValues(&quantiles, nidx.ConstHostVector(), info, learning_rate, p_tree); } HostDeviceVector quantiles; @@ -186,7 +186,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span pos w_it + d_weights.size(), &quantiles); } - UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), learning_rate, p_tree); + UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), info, learning_rate, p_tree); } } // namespace detail } // namespace obj diff --git a/src/objective/adaptive.h b/src/objective/adaptive.h index fef920ec9..7494bceb1 100644 --- a/src/objective/adaptive.h +++ b/src/objective/adaptive.h @@ -36,13 +36,15 @@ inline void FillMissingLeaf(std::vector const& maybe_missing, } inline void UpdateLeafValues(std::vector* p_quantiles, std::vector const& nidx, - float learning_rate, RegTree* p_tree) { + MetaInfo const& info, float learning_rate, RegTree* p_tree) { auto& tree = *p_tree; auto& quantiles = *p_quantiles; auto const& h_node_idx = nidx; size_t n_leaf{h_node_idx.size()}; - collective::Allreduce(&n_leaf, 1); + if (info.IsRowSplit()) { + collective::Allreduce(&n_leaf, 1); + } CHECK(quantiles.empty() || quantiles.size() == n_leaf); if (quantiles.empty()) { quantiles.resize(n_leaf, std::numeric_limits::quiet_NaN()); @@ -52,12 +54,16 @@ inline void UpdateLeafValues(std::vector* p_quantiles, std::vector n_valids(quantiles.size()); std::transform(quantiles.cbegin(), quantiles.cend(), n_valids.begin(), [](float q) { return static_cast(!std::isnan(q)); }); - collective::Allreduce(n_valids.data(), n_valids.size()); + if (info.IsRowSplit()) { + collective::Allreduce(n_valids.data(), n_valids.size()); + } // convert to 0 for all reduce std::replace_if( quantiles.begin(), quantiles.end(), [](float q) { return std::isnan(q); }, 0.f); // use the mean value - collective::Allreduce(quantiles.data(), quantiles.size()); + if (info.IsRowSplit()) { + collective::Allreduce(quantiles.data(), quantiles.size()); + } for (size_t i = 0; i < n_leaf; ++i) { if (n_valids[i] > 0) { quantiles[i] /= static_cast(n_valids[i]); diff --git a/src/objective/quantile_obj.cu b/src/objective/quantile_obj.cu index 0a40758bc..b6e540b24 100644 --- a/src/objective/quantile_obj.cu +++ b/src/objective/quantile_obj.cu @@ -35,7 +35,10 @@ class QuantileRegression : public ObjFunction { bst_target_t Targets(MetaInfo const& info) const override { auto const& alpha = param_.quantile_alpha.Get(); CHECK_EQ(alpha.size(), alpha_.Size()) << "The objective is not yet configured."; - CHECK_EQ(info.labels.Shape(1), 1) << "Multi-target is not yet supported by the quantile loss."; + if (!info.IsVerticalFederated() || collective::GetRank() == 0) { + CHECK_EQ(info.labels.Shape(1), 1) + << "Multi-target is not yet supported by the quantile loss."; + } CHECK(!alpha.empty()); // We have some placeholders for multi-target in the quantile loss. But it's not // supported as the gbtree doesn't know how to slice the gradient and there's no 3-dim @@ -167,8 +170,10 @@ class QuantileRegression : public ObjFunction { common::Mean(ctx_, *base_score, &temp); double meanq = temp(0) * sw; - collective::Allreduce(&meanq, 1); - collective::Allreduce(&sw, 1); + if (info.IsRowSplit()) { + collective::Allreduce(&meanq, 1); + collective::Allreduce(&sw, 1); + } meanq /= (sw + kRtEps); base_score->Reshape(1); base_score->Data()->Fill(meanq); diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index d7999f8c1..e0dbb2edc 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -728,8 +728,10 @@ class MeanAbsoluteError : public ObjFunction { std::transform(linalg::cbegin(out), linalg::cend(out), linalg::begin(out), [w](float v) { return v * w; }); - collective::Allreduce(out.Values().data(), out.Values().size()); - collective::Allreduce(&w, 1); + if (info.IsRowSplit()) { + collective::Allreduce(out.Values().data(), out.Values().size()); + collective::Allreduce(&w, 1); + } if (common::CloseTo(w, 0.0)) { // Mostly for handling empty dataset test. diff --git a/tests/cpp/plugin/test_federated_learner.cc b/tests/cpp/plugin/test_federated_learner.cc index 67e322323..fe7fe6854 100644 --- a/tests/cpp/plugin/test_federated_learner.cc +++ b/tests/cpp/plugin/test_federated_learner.cc @@ -13,66 +13,91 @@ namespace xgboost { +void VerifyObjectives(size_t rows, size_t cols, std::vector const &expected_base_scores, + std::vector const &expected_models) { + auto const world_size = collective::GetWorldSize(); + auto const rank = collective::GetRank(); + std::shared_ptr dmat{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)}; + + if (rank == 0) { + auto &h_upper = dmat->Info().labels_upper_bound_.HostVector(); + auto &h_lower = dmat->Info().labels_lower_bound_.HostVector(); + h_lower.resize(rows); + h_upper.resize(rows); + for (size_t i = 0; i < rows; ++i) { + h_lower[i] = 1; + h_upper[i] = 10; + } + } + std::shared_ptr sliced{dmat->SliceCol(world_size, rank)}; + + auto i = 0; + for (auto const *entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) { + std::unique_ptr learner{Learner::Create({sliced})}; + learner->SetParam("tree_method", "approx"); + learner->SetParam("objective", entry->name); + if (entry->name.find("quantile") != std::string::npos) { + learner->SetParam("quantile_alpha", "0.5"); + } + if (entry->name.find("multi") != std::string::npos) { + learner->SetParam("num_class", "3"); + } + learner->UpdateOneIter(0, sliced); + + Json config{Object{}}; + learner->SaveConfig(&config); + auto base_score = GetBaseScore(config); + ASSERT_EQ(base_score, expected_base_scores[i]); + + Json model{Object{}}; + learner->SaveModel(&model); + ASSERT_EQ(model, expected_models[i]); + + i++; + } +} + class FederatedLearnerTest : public BaseFederatedTest { protected: static auto constexpr kRows{16}; static auto constexpr kCols{16}; }; -void VerifyBaseScore(size_t rows, size_t cols, float expected_base_score) { - auto const world_size = collective::GetWorldSize(); - auto const rank = collective::GetRank(); - std::shared_ptr Xy_{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)}; - std::shared_ptr sliced{Xy_->SliceCol(world_size, rank)}; - std::unique_ptr learner{Learner::Create({sliced})}; - learner->SetParam("tree_method", "approx"); - learner->SetParam("objective", "binary:logistic"); - learner->UpdateOneIter(0, sliced); - Json config{Object{}}; - learner->SaveConfig(&config); - auto base_score = GetBaseScore(config); - ASSERT_EQ(base_score, expected_base_score); -} +TEST_F(FederatedLearnerTest, Objectives) { + std::shared_ptr dmat{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)}; -void VerifyModel(size_t rows, size_t cols, Json const& expected_model) { - auto const world_size = collective::GetWorldSize(); - auto const rank = collective::GetRank(); - std::shared_ptr Xy_{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)}; - std::shared_ptr sliced{Xy_->SliceCol(world_size, rank)}; - std::unique_ptr learner{Learner::Create({sliced})}; - learner->SetParam("tree_method", "approx"); - learner->SetParam("objective", "binary:logistic"); - learner->UpdateOneIter(0, sliced); - Json model{Object{}}; - learner->SaveModel(&model); - ASSERT_EQ(model, expected_model); -} + auto &h_upper = dmat->Info().labels_upper_bound_.HostVector(); + auto &h_lower = dmat->Info().labels_lower_bound_.HostVector(); + h_lower.resize(kRows); + h_upper.resize(kRows); + for (size_t i = 0; i < kRows; ++i) { + h_lower[i] = 1; + h_upper[i] = 10; + } -TEST_F(FederatedLearnerTest, BaseScore) { - std::shared_ptr Xy_{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)}; - std::unique_ptr learner{Learner::Create({Xy_})}; - learner->SetParam("tree_method", "approx"); - learner->SetParam("objective", "binary:logistic"); - learner->UpdateOneIter(0, Xy_); - Json config{Object{}}; - learner->SaveConfig(&config); - auto base_score = GetBaseScore(config); - ASSERT_NE(base_score, ObjFunction::DefaultBaseScore()); + std::vector base_scores; + std::vector models; + for (auto const *entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) { + std::unique_ptr learner{Learner::Create({dmat})}; + learner->SetParam("tree_method", "approx"); + learner->SetParam("objective", entry->name); + if (entry->name.find("quantile") != std::string::npos) { + learner->SetParam("quantile_alpha", "0.5"); + } + if (entry->name.find("multi") != std::string::npos) { + learner->SetParam("num_class", "3"); + } + learner->UpdateOneIter(0, dmat); + Json config{Object{}}; + learner->SaveConfig(&config); + base_scores.emplace_back(GetBaseScore(config)); - RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyBaseScore, kRows, kCols, - base_score); -} + Json model{Object{}}; + learner->SaveModel(&model); + models.emplace_back(model); + } -TEST_F(FederatedLearnerTest, Model) { - std::shared_ptr Xy_{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)}; - std::unique_ptr learner{Learner::Create({Xy_})}; - learner->SetParam("tree_method", "approx"); - learner->SetParam("objective", "binary:logistic"); - learner->UpdateOneIter(0, Xy_); - Json model{Object{}}; - learner->SaveModel(&model); - - RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyModel, kRows, kCols, - std::cref(model)); + RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyObjectives, kRows, kCols, + base_scores, models); } } // namespace xgboost diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index e4313125d..537820e40 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -608,31 +608,74 @@ TEST_F(InitBaseScore, InitWithPredict) { this->TestInitWithPredt(); } TEST_F(InitBaseScore, UpdateProcess) { this->TestUpdateProcess(); } -void TestColumnSplitBaseScore(std::shared_ptr Xy_, float expected_base_score) { +void TestColumnSplit(std::shared_ptr dmat, std::vector const& expected_base_scores, + std::vector const& expected_models) { auto const world_size = collective::GetWorldSize(); auto const rank = collective::GetRank(); - std::shared_ptr sliced{Xy_->SliceCol(world_size, rank)}; - std::unique_ptr learner{Learner::Create({sliced})}; - learner->SetParam("tree_method", "approx"); - learner->SetParam("objective", "binary:logistic"); - learner->UpdateOneIter(0, sliced); - Json config{Object{}}; - learner->SaveConfig(&config); - auto base_score = GetBaseScore(config); - ASSERT_EQ(base_score, expected_base_score); + std::shared_ptr sliced{dmat->SliceCol(world_size, rank)}; + + auto i = 0; + for (auto const* entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) { + std::unique_ptr learner{Learner::Create({sliced})}; + learner->SetParam("tree_method", "approx"); + learner->SetParam("objective", entry->name); + if (entry->name.find("quantile") != std::string::npos) { + learner->SetParam("quantile_alpha", "0.5"); + } + if (entry->name.find("multi") != std::string::npos) { + learner->SetParam("num_class", "3"); + } + learner->UpdateOneIter(0, sliced); + Json config{Object{}}; + learner->SaveConfig(&config); + auto base_score = GetBaseScore(config); + ASSERT_EQ(base_score, expected_base_scores[i]); + + Json model{Object{}}; + learner->SaveModel(&model); + ASSERT_EQ(model, expected_models[i]); + + i++; + } } -TEST_F(InitBaseScore, ColumnSplit) { - std::unique_ptr learner{Learner::Create({Xy_})}; - learner->SetParam("tree_method", "approx"); - learner->SetParam("objective", "binary:logistic"); - learner->UpdateOneIter(0, Xy_); - Json config{Object{}}; - learner->SaveConfig(&config); - auto base_score = GetBaseScore(config); - ASSERT_NE(base_score, ObjFunction::DefaultBaseScore()); +TEST(ColumnSplit, Objectives) { + auto constexpr kRows = 10, kCols = 10; + std::shared_ptr dmat{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)}; + + auto& h_upper = dmat->Info().labels_upper_bound_.HostVector(); + auto& h_lower = dmat->Info().labels_lower_bound_.HostVector(); + h_lower.resize(kRows); + h_upper.resize(kRows); + for (size_t i = 0; i < kRows; ++i) { + h_lower[i] = 1; + h_upper[i] = 10; + } + + std::vector base_scores; + std::vector models; + for (auto const* entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) { + std::unique_ptr learner{Learner::Create({dmat})}; + learner->SetParam("tree_method", "approx"); + learner->SetParam("objective", entry->name); + if (entry->name.find("quantile") != std::string::npos) { + learner->SetParam("quantile_alpha", "0.5"); + } + if (entry->name.find("multi") != std::string::npos) { + learner->SetParam("num_class", "3"); + } + learner->UpdateOneIter(0, dmat); + + Json config{Object{}}; + learner->SaveConfig(&config); + base_scores.emplace_back(GetBaseScore(config)); + + Json model{Object{}}; + learner->SaveModel(&model); + models.emplace_back(model); + } auto constexpr kWorldSize{3}; - RunWithInMemoryCommunicator(kWorldSize, &TestColumnSplitBaseScore, Xy_, base_score); + RunWithInMemoryCommunicator(kWorldSize, &TestColumnSplit, dmat, base_scores, models); } } // namespace xgboost From 1cf4d93246821124cea9720197099857855ba14c Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 4 Apr 2023 01:29:47 +0800 Subject: [PATCH 09/12] Convert federated tests into test suite. (#9006) - Add specialization for learning to rank. --- tests/cpp/test_learner.cc | 167 +++++++++++++++++++++++++------------- 1 file changed, 111 insertions(+), 56 deletions(-) diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index 537820e40..b43a0ecc1 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -1,22 +1,48 @@ -/*! - * Copyright 2017-2023 by XGBoost contributors +/** + * Copyright (c) 2017-2023, XGBoost contributors */ #include -#include -#include // ObjFunction -#include +#include // for Learner +#include // for LogCheck_NE, CHECK_NE, LogCheck_EQ +#include // for ObjFunction +#include // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR -#include // std::stof, std::string -#include -#include +#include // for equal, transform +#include // for int32_t, int64_t, uint32_t +#include // for size_t +#include // for ofstream +#include // for back_insert_iterator, back_inserter +#include // for numeric_limits +#include // for map +#include // for unique_ptr, shared_ptr, __shared_ptr_... +#include // for uniform_real_distribution +#include // for allocator, basic_string, string, oper... +#include // for thread +#include // for is_integral +#include // for pair +#include // for vector -#include "../../src/common/api_entry.h" // XGBAPIThreadLocalEntry -#include "../../src/common/io.h" -#include "../../src/common/linalg_op.h" -#include "../../src/common/random.h" -#include "filesystem.h" // dmlc::TemporaryDirectory -#include "helpers.h" -#include "xgboost/json.h" +#include "../../src/collective/communicator-inl.h" // for GetRank, GetWorldSize +#include "../../src/common/api_entry.h" // for XGBAPIThreadLocalEntry +#include "../../src/common/io.h" // for LoadSequentialFile +#include "../../src/common/linalg_op.h" // for ElementWiseTransformHost, begin, end +#include "../../src/common/random.h" // for GlobalRandom +#include "../../src/common/transform_iterator.h" // for IndexTransformIter +#include "dmlc/io.h" // for Stream +#include "dmlc/omp.h" // for omp_get_max_threads +#include "dmlc/registry.h" // for Registry +#include "filesystem.h" // for TemporaryDirectory +#include "helpers.h" // for GetBaseScore, RandomDataGenerator +#include "xgboost/base.h" // for bst_float, Args, bst_feature_t, bst_int +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/json.h" // for Json, Object, get, String, IsA, opera... +#include "xgboost/linalg.h" // for Tensor, TensorView +#include "xgboost/logging.h" // for ConsoleLogger +#include "xgboost/predictor.h" // for PredictionCacheEntry +#include "xgboost/span.h" // for Span, operator!=, SpanIterator +#include "xgboost/string_view.h" // for StringView namespace xgboost { TEST(Learner, Basic) { @@ -608,74 +634,103 @@ TEST_F(InitBaseScore, InitWithPredict) { this->TestInitWithPredt(); } TEST_F(InitBaseScore, UpdateProcess) { this->TestUpdateProcess(); } -void TestColumnSplit(std::shared_ptr dmat, std::vector const& expected_base_scores, - std::vector const& expected_models) { - auto const world_size = collective::GetWorldSize(); - auto const rank = collective::GetRank(); - std::shared_ptr sliced{dmat->SliceCol(world_size, rank)}; +class TestColumnSplit : public ::testing::TestWithParam { + static auto MakeFmat(std::string const& obj) { + auto constexpr kRows = 10, kCols = 10; + auto p_fmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true); + auto& h_upper = p_fmat->Info().labels_upper_bound_.HostVector(); + auto& h_lower = p_fmat->Info().labels_lower_bound_.HostVector(); + h_lower.resize(kRows); + h_upper.resize(kRows); + for (size_t i = 0; i < kRows; ++i) { + h_lower[i] = 1; + h_upper[i] = 10; + } + if (obj.find("rank:") != std::string::npos) { + auto h_label = p_fmat->Info().labels.HostView(); + std::size_t k = 0; + for (auto& v : h_label) { + v = k % 2 == 0; + ++k; + } + } + return p_fmat; + }; - auto i = 0; - for (auto const* entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) { + void TestBaseScore(std::string objective, float expected_base_score, Json expected_model) { + auto const world_size = collective::GetWorldSize(); + auto const rank = collective::GetRank(); + + auto p_fmat = MakeFmat(objective); + std::shared_ptr sliced{p_fmat->SliceCol(world_size, rank)}; std::unique_ptr learner{Learner::Create({sliced})}; learner->SetParam("tree_method", "approx"); - learner->SetParam("objective", entry->name); - if (entry->name.find("quantile") != std::string::npos) { + learner->SetParam("objective", objective); + if (objective.find("quantile") != std::string::npos) { learner->SetParam("quantile_alpha", "0.5"); } - if (entry->name.find("multi") != std::string::npos) { + if (objective.find("multi") != std::string::npos) { learner->SetParam("num_class", "3"); } learner->UpdateOneIter(0, sliced); Json config{Object{}}; learner->SaveConfig(&config); auto base_score = GetBaseScore(config); - ASSERT_EQ(base_score, expected_base_scores[i]); + ASSERT_EQ(base_score, expected_base_score); Json model{Object{}}; learner->SaveModel(&model); - ASSERT_EQ(model, expected_models[i]); - - i++; - } -} - -TEST(ColumnSplit, Objectives) { - auto constexpr kRows = 10, kCols = 10; - std::shared_ptr dmat{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)}; - - auto& h_upper = dmat->Info().labels_upper_bound_.HostVector(); - auto& h_lower = dmat->Info().labels_lower_bound_.HostVector(); - h_lower.resize(kRows); - h_upper.resize(kRows); - for (size_t i = 0; i < kRows; ++i) { - h_lower[i] = 1; - h_upper[i] = 10; + ASSERT_EQ(model, expected_model); } - std::vector base_scores; - std::vector models; - for (auto const* entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) { - std::unique_ptr learner{Learner::Create({dmat})}; + public: + void Run(std::string objective) { + auto p_fmat = MakeFmat(objective); + std::unique_ptr learner{Learner::Create({p_fmat})}; learner->SetParam("tree_method", "approx"); - learner->SetParam("objective", entry->name); - if (entry->name.find("quantile") != std::string::npos) { + learner->SetParam("objective", objective); + if (objective.find("quantile") != std::string::npos) { learner->SetParam("quantile_alpha", "0.5"); } - if (entry->name.find("multi") != std::string::npos) { + if (objective.find("multi") != std::string::npos) { learner->SetParam("num_class", "3"); } - learner->UpdateOneIter(0, dmat); + learner->UpdateOneIter(0, p_fmat); Json config{Object{}}; learner->SaveConfig(&config); - base_scores.emplace_back(GetBaseScore(config)); Json model{Object{}}; learner->SaveModel(&model); - models.emplace_back(model); - } - auto constexpr kWorldSize{3}; - RunWithInMemoryCommunicator(kWorldSize, &TestColumnSplit, dmat, base_scores, models); + auto constexpr kWorldSize{3}; + auto call = [this, &objective](auto&... args) { TestBaseScore(objective, args...); }; + auto score = GetBaseScore(config); + RunWithInMemoryCommunicator(kWorldSize, call, score, model); + } +}; + +TEST_P(TestColumnSplit, Objective) { + std::string objective = GetParam(); + this->Run(objective); } + +auto MakeValues() { + auto list = ::dmlc::Registry<::xgboost::ObjFunctionReg>::List(); + std::vector names; + std::transform(list.cbegin(), list.cend(), std::back_inserter(names), + [](auto const* entry) { return entry->name; }); + return names; +} + +INSTANTIATE_TEST_SUITE_P(ColumnSplitObjective, TestColumnSplit, ::testing::ValuesIn(MakeValues()), + [](const ::testing::TestParamInfo& info) { + auto name = std::string{info.param}; + // Name must be a valid c++ symbol + auto it = std::find(name.cbegin(), name.cend(), ':'); + if (it != name.cend()) { + name[std::distance(name.cbegin(), it)] = '_'; + } + return name; + }); } // namespace xgboost From ebd64f6e22825ed098e0e31fbd49facd3780630f Mon Sep 17 00:00:00 2001 From: Sarah Charlotte Johnson Date: Thu, 6 Apr 2023 10:09:15 -0700 Subject: [PATCH 10/12] [doc] Update Dask deployment options (#9008) --- doc/tutorials/dask.rst | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/doc/tutorials/dask.rst b/doc/tutorials/dask.rst index c66c6131f..c33a90c81 100644 --- a/doc/tutorials/dask.rst +++ b/doc/tutorials/dask.rst @@ -66,7 +66,7 @@ on a dask cluster: Here we first create a cluster in single-node mode with :py:class:`distributed.LocalCluster`, then connect a :py:class:`distributed.Client` to this cluster, setting up an environment for later computation. Notice that the cluster -construction is guared by ``__name__ == "__main__"``, which is necessary otherwise there +construction is guarded by ``__name__ == "__main__"``, which is necessary otherwise there might be obscure errors. We then create a :py:class:`xgboost.dask.DaskDMatrix` object and pass it to @@ -226,13 +226,9 @@ collection. Working with other clusters *************************** -``LocalCluster`` is mostly used for testing. In real world applications some other -clusters might be preferred. Examples are like ``LocalCUDACluster`` for single node -multi-GPU instance, manually launched cluster by using command line utilities like -``dask-worker`` from ``distributed`` for not yet automated environments. Some special -clusters like ``KubeCluster`` from ``dask-kubernetes`` package are also possible. The -dask API in xgboost is orthogonal to the cluster type and can be used with any of them. A -typical testing workflow with ``KubeCluster`` looks like this: +Using Dask's ``LocalCluster`` is convenient for getting started quickly on a single-machine. Once you're ready to scale your work, though, there are a number of ways to deploy Dask on a distributed cluster. You can use `Dask-CUDA `_, for example, for GPUs and you can use Dask Cloud Provider to `deploy Dask clusters in the cloud `_. See the `Dask documentation for a more comprehensive list `_. + +In the example below, a ``KubeCluster`` is used for `deploying Dask on Kubernetes `_: .. code-block:: python @@ -272,8 +268,7 @@ typical testing workflow with ``KubeCluster`` looks like this: # main function will connect to that cluster and start training xgboost model. main() - -However, these clusters might have their subtle differences like network configuration, or +Different cluster classes might have subtle differences like network configuration, or specific cluster implementation might contains bugs that we are not aware of. Open an issue if such case is found and there's no documentation on how to resolve it in that cluster implementation. From 2c8d735cb31cdbc237128fee322b6699171268ec Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 11 Apr 2023 00:17:34 +0800 Subject: [PATCH 11/12] Fix tests with pandas 2.0. (#9014) * Fix tests with pandas 2.0. - `is_categorical` is replaced by `is_categorical_dtype`. - one hot encoding returns boolean type instead of integer type. --- tests/python/test_basic_models.py | 2 +- tests/python/test_with_pandas.py | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index f9d6f37e1..d76205593 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -664,7 +664,7 @@ class TestModels: y = rng.randn(rows) feature_names = ["test_feature_" + str(i) for i in range(cols)] X_pd = pd.DataFrame(X, columns=feature_names) - X_pd.iloc[:, 3] = X_pd.iloc[:, 3].astype(np.int32) + X_pd[f"test_feature_{3}"] = X_pd.iloc[:, 3].astype(np.int32) Xy = xgb.DMatrix(X_pd, y) assert Xy.feature_types[3] == "int" diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py index e5783b24d..07295eb6c 100644 --- a/tests/python/test_with_pandas.py +++ b/tests/python/test_with_pandas.py @@ -77,7 +77,10 @@ class TestPandas: np.testing.assert_array_equal(result, exp) dm = xgb.DMatrix(dummies) assert dm.feature_names == ['B', 'A_X', 'A_Y', 'A_Z'] - assert dm.feature_types == ['int', 'int', 'int', 'int'] + if int(pd.__version__[0]) >= 2: + assert dm.feature_types == ['int', 'i', 'i', 'i'] + else: + assert dm.feature_types == ['int', 'int', 'int', 'int'] assert dm.num_row() == 3 assert dm.num_col() == 4 @@ -298,14 +301,14 @@ class TestPandas: @pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix]) def test_nullable_type(self, DMatrixT) -> None: - from pandas.api.types import is_categorical + from pandas.api.types import is_categorical_dtype for orig, df in pd_dtypes(): if hasattr(df.dtypes, "__iter__"): - enable_categorical = any(is_categorical for dtype in df.dtypes) + enable_categorical = any(is_categorical_dtype for dtype in df.dtypes) else: # series - enable_categorical = is_categorical(df.dtype) + enable_categorical = is_categorical_dtype(df.dtype) f0_orig = orig[orig.columns[0]] if isinstance(orig, pd.DataFrame) else orig f0 = df[df.columns[0]] if isinstance(df, pd.DataFrame) else df From fe9dff339c5f5840bdc461e096f4e9733dafb5be Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 11 Apr 2023 09:52:55 +0800 Subject: [PATCH 12/12] Convert federated learner test into test suite. (#9018) * Convert federated learner test into test suite. - Add specialization to learning to rank. --- tests/cpp/objective_helpers.h | 32 ++++ tests/cpp/plugin/helpers.h | 29 ++-- tests/cpp/plugin/test_federated_adapter.cu | 4 +- .../cpp/plugin/test_federated_communicator.cc | 10 +- tests/cpp/plugin/test_federated_data.cc | 6 +- tests/cpp/plugin/test_federated_learner.cc | 147 ++++++++++-------- tests/cpp/plugin/test_federated_server.cc | 8 +- tests/cpp/test_learner.cc | 20 +-- 8 files changed, 152 insertions(+), 104 deletions(-) create mode 100644 tests/cpp/objective_helpers.h diff --git a/tests/cpp/objective_helpers.h b/tests/cpp/objective_helpers.h new file mode 100644 index 000000000..b26470746 --- /dev/null +++ b/tests/cpp/objective_helpers.h @@ -0,0 +1,32 @@ +/** + * Copyright (c) 2023, XGBoost contributors + */ +#include // for Registry +#include +#include // for ObjFunctionReg + +#include // for transform +#include // for back_insert_iterator, back_inserter +#include // for string +#include // for vector + +namespace xgboost { +inline auto MakeObjNamesForTest() { + auto list = ::dmlc::Registry<::xgboost::ObjFunctionReg>::List(); + std::vector names; + std::transform(list.cbegin(), list.cend(), std::back_inserter(names), + [](auto const* entry) { return entry->name; }); + return names; +} + +template +inline std::string ObjTestNameGenerator(const ::testing::TestParamInfo& info) { + auto name = std::string{info.param}; + // Name must be a valid c++ symbol + auto it = std::find(name.cbegin(), name.cend(), ':'); + if (it != name.cend()) { + name[std::distance(name.cbegin(), it)] = '_'; + } + return name; +}; +} // namespace xgboost diff --git a/tests/cpp/plugin/helpers.h b/tests/cpp/plugin/helpers.h index 7edfc5efc..10ba68b49 100644 --- a/tests/cpp/plugin/helpers.h +++ b/tests/cpp/plugin/helpers.h @@ -8,6 +8,7 @@ #include #include +#include // for thread, sleep_for #include "../../../plugin/federated/federated_server.h" #include "../../../src/collective/communicator-inl.h" @@ -33,13 +34,17 @@ inline std::string GetServerAddress() { namespace xgboost { -class BaseFederatedTest : public ::testing::Test { - protected: - void SetUp() override { +class ServerForTest { + std::string server_address_; + std::unique_ptr server_thread_; + std::unique_ptr server_; + + public: + explicit ServerForTest(std::int32_t world_size) { server_address_ = GetServerAddress(); - server_thread_.reset(new std::thread([this] { + server_thread_.reset(new std::thread([this, world_size] { grpc::ServerBuilder builder; - xgboost::federated::FederatedService service{kWorldSize}; + xgboost::federated::FederatedService service{world_size}; builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials()); builder.RegisterService(&service); server_ = builder.BuildAndStart(); @@ -47,15 +52,21 @@ class BaseFederatedTest : public ::testing::Test { })); } - void TearDown() override { + ~ServerForTest() { server_->Shutdown(); server_thread_->join(); } + auto Address() const { return server_address_; } +}; + +class BaseFederatedTest : public ::testing::Test { + protected: + void SetUp() override { server_ = std::make_unique(kWorldSize); } + + void TearDown() override { server_.reset(nullptr); } static int const kWorldSize{3}; - std::string server_address_; - std::unique_ptr server_thread_; - std::unique_ptr server_; + std::unique_ptr server_; }; template diff --git a/tests/cpp/plugin/test_federated_adapter.cu b/tests/cpp/plugin/test_federated_adapter.cu index c4816ff18..a5e901f26 100644 --- a/tests/cpp/plugin/test_federated_adapter.cu +++ b/tests/cpp/plugin/test_federated_adapter.cu @@ -29,7 +29,7 @@ TEST(FederatedAdapterSimpleTest, ThrowOnInvalidCommunicator) { TEST_F(FederatedAdapterTest, DeviceAllReduceSum) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back([rank, server_address = server_address_] { + threads.emplace_back([rank, server_address = server_->Address()] { FederatedCommunicator comm{kWorldSize, rank, server_address}; // Assign device 0 to all workers, since we run gtest in a single-GPU machine DeviceCommunicatorAdapter adapter{0, &comm}; @@ -52,7 +52,7 @@ TEST_F(FederatedAdapterTest, DeviceAllReduceSum) { TEST_F(FederatedAdapterTest, DeviceAllGatherV) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back([rank, server_address = server_address_] { + threads.emplace_back([rank, server_address = server_->Address()] { FederatedCommunicator comm{kWorldSize, rank, server_address}; // Assign device 0 to all workers, since we run gtest in a single-GPU machine DeviceCommunicatorAdapter adapter{0, &comm}; diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc index 5177187c5..340849606 100644 --- a/tests/cpp/plugin/test_federated_communicator.cc +++ b/tests/cpp/plugin/test_federated_communicator.cc @@ -92,7 +92,7 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) { config["federated_server_address"] = server_address; config["federated_world_size"] = std::string("1"); config["federated_rank"] = Integer(0); - auto *comm = FederatedCommunicator::Create(config); + FederatedCommunicator::Create(config); }; EXPECT_THROW(construct(), dmlc::Error); } @@ -104,7 +104,7 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) { config["federated_server_address"] = server_address; config["federated_world_size"] = 1; config["federated_rank"] = std::string("0"); - auto *comm = FederatedCommunicator::Create(config); + FederatedCommunicator::Create(config); }; EXPECT_THROW(construct(), dmlc::Error); } @@ -125,7 +125,7 @@ TEST(FederatedCommunicatorSimpleTest, IsDistributed) { TEST_F(FederatedCommunicatorTest, Allgather) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgather, rank, server_address_); + threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgather, rank, server_->Address()); } for (auto &thread : threads) { thread.join(); @@ -135,7 +135,7 @@ TEST_F(FederatedCommunicatorTest, Allgather) { TEST_F(FederatedCommunicatorTest, Allreduce) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_address_); + threads.emplace_back(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_->Address()); } for (auto &thread : threads) { thread.join(); @@ -145,7 +145,7 @@ TEST_F(FederatedCommunicatorTest, Allreduce) { TEST_F(FederatedCommunicatorTest, Broadcast) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_address_); + threads.emplace_back(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_->Address()); } for (auto &thread : threads) { thread.join(); diff --git a/tests/cpp/plugin/test_federated_data.cc b/tests/cpp/plugin/test_federated_data.cc index ed877131e..c6efb84d5 100644 --- a/tests/cpp/plugin/test_federated_data.cc +++ b/tests/cpp/plugin/test_federated_data.cc @@ -38,8 +38,8 @@ void VerifyLoadUri() { auto index = 0; int offsets[] = {0, 8, 17}; int offset = offsets[rank]; - for (auto row = 0; row < kRows; row++) { - for (auto col = 0; col < kCols; col++) { + for (std::size_t row = 0; row < kRows; row++) { + for (std::size_t col = 0; col < kCols; col++) { EXPECT_EQ(entries[index].index, col + offset); index++; } @@ -48,6 +48,6 @@ void VerifyLoadUri() { } TEST_F(FederatedDataTest, LoadUri) { - RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyLoadUri); + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLoadUri); } } // namespace xgboost diff --git a/tests/cpp/plugin/test_federated_learner.cc b/tests/cpp/plugin/test_federated_learner.cc index fe7fe6854..85d0a2b7d 100644 --- a/tests/cpp/plugin/test_federated_learner.cc +++ b/tests/cpp/plugin/test_federated_learner.cc @@ -8,13 +8,34 @@ #include "../../../plugin/federated/federated_server.h" #include "../../../src/collective/communicator-inl.h" +#include "../../../src/common/linalg_op.h" #include "../helpers.h" +#include "../objective_helpers.h" // for MakeObjNamesForTest, ObjTestNameGenerator #include "helpers.h" namespace xgboost { +namespace { +auto MakeModel(std::string objective, std::shared_ptr dmat) { + std::unique_ptr learner{Learner::Create({dmat})}; + learner->SetParam("tree_method", "approx"); + learner->SetParam("objective", objective); + if (objective.find("quantile") != std::string::npos) { + learner->SetParam("quantile_alpha", "0.5"); + } + if (objective.find("multi") != std::string::npos) { + learner->SetParam("num_class", "3"); + } + learner->UpdateOneIter(0, dmat); + Json config{Object{}}; + learner->SaveConfig(&config); -void VerifyObjectives(size_t rows, size_t cols, std::vector const &expected_base_scores, - std::vector const &expected_models) { + Json model{Object{}}; + learner->SaveModel(&model); + return model; +} + +void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json expected_model, + std::string objective) { auto const world_size = collective::GetWorldSize(); auto const rank = collective::GetRank(); std::shared_ptr dmat{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)}; @@ -28,76 +49,72 @@ void VerifyObjectives(size_t rows, size_t cols, std::vector const &expect h_lower[i] = 1; h_upper[i] = 10; } + + if (objective.find("rank:") != std::string::npos) { + auto h_label = dmat->Info().labels.HostView(); + std::size_t k = 0; + for (auto &v : h_label) { + v = k % 2 == 0; + ++k; + } + } } std::shared_ptr sliced{dmat->SliceCol(world_size, rank)}; - auto i = 0; - for (auto const *entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) { - std::unique_ptr learner{Learner::Create({sliced})}; - learner->SetParam("tree_method", "approx"); - learner->SetParam("objective", entry->name); - if (entry->name.find("quantile") != std::string::npos) { - learner->SetParam("quantile_alpha", "0.5"); - } - if (entry->name.find("multi") != std::string::npos) { - learner->SetParam("num_class", "3"); - } - learner->UpdateOneIter(0, sliced); - - Json config{Object{}}; - learner->SaveConfig(&config); - auto base_score = GetBaseScore(config); - ASSERT_EQ(base_score, expected_base_scores[i]); - - Json model{Object{}}; - learner->SaveModel(&model); - ASSERT_EQ(model, expected_models[i]); - - i++; - } + auto model = MakeModel(objective, sliced); + auto base_score = GetBaseScore(model); + ASSERT_EQ(base_score, expected_base_score); + ASSERT_EQ(model, expected_model); } +} // namespace + +class FederatedLearnerTest : public ::testing::TestWithParam { + std::unique_ptr server_; + static int const kWorldSize{3}; -class FederatedLearnerTest : public BaseFederatedTest { protected: - static auto constexpr kRows{16}; - static auto constexpr kCols{16}; + void SetUp() override { server_ = std::make_unique(kWorldSize); } + void TearDown() override { server_.reset(nullptr); } + + void Run(std::string objective) { + static auto constexpr kRows{16}; + static auto constexpr kCols{16}; + + std::shared_ptr dmat{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)}; + + auto &h_upper = dmat->Info().labels_upper_bound_.HostVector(); + auto &h_lower = dmat->Info().labels_lower_bound_.HostVector(); + h_lower.resize(kRows); + h_upper.resize(kRows); + for (size_t i = 0; i < kRows; ++i) { + h_lower[i] = 1; + h_upper[i] = 10; + } + if (objective.find("rank:") != std::string::npos) { + auto h_label = dmat->Info().labels.HostView(); + std::size_t k = 0; + for (auto &v : h_label) { + v = k % 2 == 0; + ++k; + } + } + + auto model = MakeModel(objective, dmat); + auto score = GetBaseScore(model); + + RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyObjective, kRows, kCols, + score, model, objective); + } }; -TEST_F(FederatedLearnerTest, Objectives) { - std::shared_ptr dmat{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)}; - - auto &h_upper = dmat->Info().labels_upper_bound_.HostVector(); - auto &h_lower = dmat->Info().labels_lower_bound_.HostVector(); - h_lower.resize(kRows); - h_upper.resize(kRows); - for (size_t i = 0; i < kRows; ++i) { - h_lower[i] = 1; - h_upper[i] = 10; - } - - std::vector base_scores; - std::vector models; - for (auto const *entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) { - std::unique_ptr learner{Learner::Create({dmat})}; - learner->SetParam("tree_method", "approx"); - learner->SetParam("objective", entry->name); - if (entry->name.find("quantile") != std::string::npos) { - learner->SetParam("quantile_alpha", "0.5"); - } - if (entry->name.find("multi") != std::string::npos) { - learner->SetParam("num_class", "3"); - } - learner->UpdateOneIter(0, dmat); - Json config{Object{}}; - learner->SaveConfig(&config); - base_scores.emplace_back(GetBaseScore(config)); - - Json model{Object{}}; - learner->SaveModel(&model); - models.emplace_back(model); - } - - RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyObjectives, kRows, kCols, - base_scores, models); +TEST_P(FederatedLearnerTest, Objective) { + std::string objective = GetParam(); + this->Run(objective); } + +INSTANTIATE_TEST_SUITE_P(FederatedLearnerObjective, FederatedLearnerTest, + ::testing::ValuesIn(MakeObjNamesForTest()), + [](const ::testing::TestParamInfo &info) { + return ObjTestNameGenerator(info); + }); } // namespace xgboost diff --git a/tests/cpp/plugin/test_federated_server.cc b/tests/cpp/plugin/test_federated_server.cc index 79e06bf5f..4dd2f3c40 100644 --- a/tests/cpp/plugin/test_federated_server.cc +++ b/tests/cpp/plugin/test_federated_server.cc @@ -73,7 +73,7 @@ class FederatedServerTest : public BaseFederatedTest { TEST_F(FederatedServerTest, Allgather) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedServerTest::VerifyAllgather, rank, server_address_); + threads.emplace_back(&FederatedServerTest::VerifyAllgather, rank, server_->Address()); } for (auto& thread : threads) { thread.join(); @@ -83,7 +83,7 @@ TEST_F(FederatedServerTest, Allgather) { TEST_F(FederatedServerTest, Allreduce) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedServerTest::VerifyAllreduce, rank, server_address_); + threads.emplace_back(&FederatedServerTest::VerifyAllreduce, rank, server_->Address()); } for (auto& thread : threads) { thread.join(); @@ -93,7 +93,7 @@ TEST_F(FederatedServerTest, Allreduce) { TEST_F(FederatedServerTest, Broadcast) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedServerTest::VerifyBroadcast, rank, server_address_); + threads.emplace_back(&FederatedServerTest::VerifyBroadcast, rank, server_->Address()); } for (auto& thread : threads) { thread.join(); @@ -103,7 +103,7 @@ TEST_F(FederatedServerTest, Broadcast) { TEST_F(FederatedServerTest, Mixture) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedServerTest::VerifyMixture, rank, server_address_); + threads.emplace_back(&FederatedServerTest::VerifyMixture, rank, server_->Address()); } for (auto& thread : threads) { thread.join(); diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index b43a0ecc1..91e8070c2 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -33,6 +33,7 @@ #include "dmlc/registry.h" // for Registry #include "filesystem.h" // for TemporaryDirectory #include "helpers.h" // for GetBaseScore, RandomDataGenerator +#include "objective_helpers.h" // for MakeObjNamesForTest, ObjTestNameGenerator #include "xgboost/base.h" // for bst_float, Args, bst_feature_t, bst_int #include "xgboost/context.h" // for Context #include "xgboost/data.h" // for DMatrix, MetaInfo, DataType @@ -715,22 +716,9 @@ TEST_P(TestColumnSplit, Objective) { this->Run(objective); } -auto MakeValues() { - auto list = ::dmlc::Registry<::xgboost::ObjFunctionReg>::List(); - std::vector names; - std::transform(list.cbegin(), list.cend(), std::back_inserter(names), - [](auto const* entry) { return entry->name; }); - return names; -} - -INSTANTIATE_TEST_SUITE_P(ColumnSplitObjective, TestColumnSplit, ::testing::ValuesIn(MakeValues()), +INSTANTIATE_TEST_SUITE_P(ColumnSplitObjective, TestColumnSplit, + ::testing::ValuesIn(MakeObjNamesForTest()), [](const ::testing::TestParamInfo& info) { - auto name = std::string{info.param}; - // Name must be a valid c++ symbol - auto it = std::find(name.cbegin(), name.cend(), ':'); - if (it != name.cend()) { - name[std::distance(name.cbegin(), it)] = '_'; - } - return name; + return ObjTestNameGenerator(info); }); } // namespace xgboost