From 6967ef726723e906a2f31bf6d1cdcd5c57ec59f7 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Mon, 24 Jan 2022 04:35:49 +0800
Subject: [PATCH] Remove `omp_get_max_threads` in objective. (#7589)

---
 include/xgboost/objective.h              |  2 +-
 src/common/transform.h                   | 45 ++++++++++--------------
 src/objective/aft_obj.cu                 | 13 +++----
 src/objective/hinge.cu                   |  8 ++---
 src/objective/multiclass_obj.cu          | 12 +++---
 src/objective/objective.cc               |  2 +-
 src/objective/rank_obj.cu                |  4 +--
 src/objective/regression_obj.cu          | 31 ++++++++--------
 src/tree/split_evaluator.h               |  2 +-
 tests/cpp/common/test_transform_range.cc | 18 ++++++----
 tests/cpp/common/test_transform_range.cu | 13 ++++---
 11 files changed, 76 insertions(+), 74 deletions(-)

diff --git a/include/xgboost/objective.h b/include/xgboost/objective.h
index 40db951b4..44dc46ddc 100644
--- a/include/xgboost/objective.h
+++ b/include/xgboost/objective.h
@@ -25,7 +25,7 @@ namespace xgboost {
 /*! \brief interface of objective function */
 class ObjFunction : public Configurable {
  protected:
-  GenericParameter const* tparam_;
+  GenericParameter const* ctx_;
 
  public:
   /*! \brief virtual destructor */
diff --git a/src/common/transform.h b/src/common/transform.h
index 79c97391e..a7b96766c 100644
--- a/src/common/transform.h
+++ b/src/common/transform.h
@@ -1,22 +1,21 @@
 /*!
- * Copyright 2018 XGBoost contributors
+ * Copyright 2018-2022 XGBoost contributors
  */
 #ifndef XGBOOST_COMMON_TRANSFORM_H_
 #define XGBOOST_COMMON_TRANSFORM_H_
 
-#include <dmlc/omp.h>
 #include <dmlc/common.h>
-
 #include <xgboost/data.h>
+
+#include <type_traits>  // enable_if
 #include <utility>
 #include <vector>
-#include <type_traits>  // enable_if
-
-#include "xgboost/host_device_vector.h"
-#include "xgboost/span.h"
 #include "common.h"
 #include "threading_utils.h"
+#include "xgboost/host_device_vector.h"
+#include "xgboost/span.h"
 
 #if defined (__CUDACC__)
 #include "device_helpers.cuh"
@@ -61,10 +60,8 @@ class Transform {
   template <typename Functor>
   struct Evaluator {
    public:
-    Evaluator(Functor func, Range range, int device, bool shard) :
-        func_(func), range_{std::move(range)},
-        shard_{shard},
-        device_{device} {}
+    Evaluator(Functor func, Range range, int32_t n_threads, int32_t device_idx)
+        : func_(func), range_{std::move(range)}, n_threads_{n_threads}, device_{device_idx} {}
 
     /*!
      * \brief Evaluate the functor with input pointers to HostDeviceVector.
@@ -134,9 +131,7 @@
     template <typename std::enable_if<CompiledWithCuda>::type* = nullptr,
               typename... HDV>
     void LaunchCUDA(Functor _func, HDV*... _vectors) const {
-      if (shard_) {
-        UnpackShard(device_, _vectors...);
-      }
+      UnpackShard(device_, _vectors...);
 
       size_t range_size = *range_.end() - *range_.begin();
 
@@ -167,12 +162,10 @@
 #endif  // defined(__CUDACC__)
 
     template <typename... HDV>
-    void LaunchCPU(Functor func, HDV*... vectors) const {
+    void LaunchCPU(Functor func, HDV *...vectors) const {
      omp_ulong end = static_cast<omp_ulong>(*(range_.end()));
      SyncHost(vectors...);
-      ParallelFor(end, [&](omp_ulong idx) {
-        func(idx, UnpackHDV(vectors)...);
-      });
+      ParallelFor(end, n_threads_, [&](omp_ulong idx) { func(idx, UnpackHDV(vectors)...); });
     }
 
    private:
     /*! \brief Callable object. */
     Functor func_;
     /*! \brief Range object specifying parallel threads index range. */
     Range range_;
-    /*! \brief Whether sharding for vectors is required. */
-    bool shard_;
-    int device_;
+    int32_t n_threads_;
+    int32_t device_;
   };
 
  public:
@@ -195,14 +187,13 @@ class Transform {
    * \param func A callable object, accepting a size_t thread index,
    *             followed by a set of Span classes.
    * \param range Range object specifying parallel threads index range.
-   * \param device Specify GPU to use.
-   * \param shard Whether Shard for HostDeviceVector is needed.
+   * \param n_threads Number of CPU threads
+   * \param device_idx GPU device ordinal
    */
   template <typename Functor>
-  static Evaluator<Functor> Init(Functor func, Range const range,
-                                 int device,
-                                 bool const shard = true) {
-    return Evaluator<Functor> {func, std::move(range), device, shard};
+  static Evaluator<Functor> Init(Functor func, Range const range, int32_t n_threads,
+                                 int32_t device_idx) {
+    return Evaluator<Functor>{func, std::move(range), n_threads, device_idx};
   }
 };
 
diff --git a/src/objective/aft_obj.cu b/src/objective/aft_obj.cu
index 882402a0c..0e2d9290f 100644
--- a/src/objective/aft_obj.cu
+++ b/src/objective/aft_obj.cu
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2019-2020 by Contributors
+ * Copyright 2019-2022 by Contributors
  * \file aft_obj.cu
  * \brief Definition of AFT loss for survival analysis.
  * \author Avinash Barnwal, Hyunsu Cho and Toby Hocking
@@ -65,7 +65,7 @@ class AFTObj : public ObjFunction {
         const bst_float w = is_null_weight ? 1.0f : _weights[_idx];
         _out_gpair[_idx] = GradientPair(grad * w, hess * w);
       },
-      common::Range{0, static_cast<size_t>(ndata)}, device).Eval(
+      common::Range{0, static_cast<size_t>(ndata)}, this->ctx_->Threads(), device).Eval(
         out_gpair, &preds, &info.labels_lower_bound_, &info.labels_upper_bound_,
         &info.weights_);
   }
@@ -78,7 +78,7 @@ class AFTObj : public ObjFunction {
     CHECK_EQ(info.labels_lower_bound_.Size(), ndata);
     CHECK_EQ(info.labels_upper_bound_.Size(), ndata);
     out_gpair->Resize(ndata);
-    const int device = tparam_->gpu_id;
+    const int device = ctx_->gpu_id;
     const float aft_loss_distribution_scale = param_.aft_loss_distribution_scale;
     const bool is_null_weight = info.weights_.Size() == 0;
     if (!is_null_weight) {
@@ -108,10 +108,11 @@ class AFTObj : public ObjFunction {
     // Trees give us a prediction in log scale, so exponentiate
     common::Transform<>::Init(
         [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
-          _preds[_idx] = exp(_preds[_idx]);
-        }, common::Range{0, static_cast<size_t>(io_preds->Size())},
+          _preds[_idx] = exp(_preds[_idx]);
+        },
+        common::Range{0, static_cast<size_t>(io_preds->Size())}, this->ctx_->Threads(),
         io_preds->DeviceIdx())
-        .Eval(io_preds);
+      .Eval(io_preds);
   }
 
   void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
diff --git a/src/objective/hinge.cu b/src/objective/hinge.cu
index 09b379804..e1f0df74d 100644
--- a/src/objective/hinge.cu
+++ b/src/objective/hinge.cu
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2018-2019 by Contributors
+ * Copyright 2018-2022 by XGBoost Contributors
  * \file hinge.cc
  * \brief Provides an implementation of the hinge loss function
  * \author Henry Gouk
@@ -65,8 +65,8 @@ class HingeObj : public ObjFunction {
         }
         _out_gpair[_idx] = GradientPair(g, h);
       },
-      common::Range{0, static_cast<size_t>(ndata)},
-      tparam_->gpu_id).Eval(
+      common::Range{0, static_cast<size_t>(ndata)}, this->ctx_->Threads(),
+      ctx_->gpu_id).Eval(
         out_gpair, &preds, info.labels.Data(), &info.weights_);
   }
 
@@ -75,7 +75,7 @@ class HingeObj : public ObjFunction {
       [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
         _preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0;
       },
-      common::Range{0, static_cast<size_t>(io_preds->Size()), 1},
+      common::Range{0, static_cast<size_t>(io_preds->Size()), 1}, this->ctx_->Threads(),
       io_preds->DeviceIdx())
       .Eval(io_preds);
   }
diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu
index a3f01b419..4b912a817 100644
--- a/src/objective/multiclass_obj.cu
+++ b/src/objective/multiclass_obj.cu
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2015-2018 by Contributors
+ * Copyright 2015-2022 by XGBoost Contributors
  * \file multi_class.cc
  * \brief Definition of multi-class classification objectives.
  * \author Tianqi Chen
@@ -68,7 +68,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
 
     const int nclass = param_.num_class;
     const auto ndata = static_cast<int64_t>(preds.Size() / nclass);
 
-    auto device = tparam_->gpu_id;
+    auto device = ctx_->gpu_id;
     out_gpair->SetDevice(device);
     info.labels.SetDevice(device);
     info.weights_.SetDevice(device);
@@ -114,7 +114,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
             p = label == k ? p - 1.0f : p;
             gpair[idx * nclass + k] = GradientPair(p * wt, h);
           }
-        }, common::Range{0, ndata}, device, false)
+        }, common::Range{0, ndata}, ctx_->Threads(), device)
       .Eval(out_gpair, info.labels.Data(), &preds, &info.weights_, &label_correct_);
 
     std::vector<int>& label_correct_h = label_correct_.HostVector();
@@ -146,8 +146,8 @@ class SoftmaxMultiClassObj : public ObjFunction {
             _preds.subspan(_idx * nclass, nclass);
           common::Softmax(point.begin(), point.end());
         },
-        common::Range{0, ndata}, device)
-        .Eval(io_preds);
+        common::Range{0, ndata}, this->ctx_->Threads(), device)
+          .Eval(io_preds);
     } else {
       io_preds->SetDevice(device);
       HostDeviceVector<bst_float> max_preds;
@@ -162,7 +162,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
             common::FindMaxIndex(point.cbegin(), point.cend()) -
             point.cbegin();
         },
-        common::Range{0, ndata}, device, false)
+        common::Range{0, ndata}, this->ctx_->Threads(), device)
         .Eval(io_preds, &max_preds);
     io_preds->Resize(max_preds.Size());
     io_preds->Copy(max_preds);
diff --git a/src/objective/objective.cc b/src/objective/objective.cc
index 8f65cec0d..5991e918d 100644
--- a/src/objective/objective.cc
+++ b/src/objective/objective.cc
@@ -27,7 +27,7 @@ ObjFunction* ObjFunction::Create(const std::string& name, GenericParameter const
         << ss.str();
   }
   auto pobj = (e->body)();
-  pobj->tparam_ = tparam;
+  pobj->ctx_ = tparam;
   return pobj;
 }
 
diff --git a/src/objective/rank_obj.cu b/src/objective/rank_obj.cu
index 9f4d86aaf..75acde3ce 100644
--- a/src/objective/rank_obj.cu
+++ b/src/objective/rank_obj.cu
@@ -773,7 +773,7 @@ class LambdaRankObj : public ObjFunction {
 #if defined(__CUDACC__)
     // Check if we have a GPU assignment; else, revert back to CPU
-    auto device = tparam_->gpu_id;
+    auto device = ctx_->gpu_id;
     if (device >= 0) {
       ComputeGradientsOnGPU(preds, info, iter, out_gpair, gptr);
     } else {
@@ -909,7 +909,7 @@ class LambdaRankObj : public ObjFunction {
                              const std::vector<unsigned> &gptr) {
     LOG(DEBUG) << "Computing " << LambdaWeightComputerT::Name() << " gradients on GPU.";
 
-    auto device = tparam_->gpu_id;
+    auto device = ctx_->gpu_id;
     dh::safe_cuda(cudaSetDevice(device));
 
     bst_float weight_normalization_factor = ComputeWeightNormalizationFactor(info, gptr);
diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu
index 63a3f881e..a07de8e44 100644
--- a/src/objective/regression_obj.cu
+++ b/src/objective/regression_obj.cu
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2015-2019 by Contributors
+ * Copyright 2015-2022 by XGBoost Contributors
  * \file regression_obj.cu
  * \brief Definition of single-value regression and classification objectives.
 * \author Tianqi Chen, Kailong Chen
@@ -70,7 +70,7 @@ class RegLossObj : public ObjFunction {
         << "Loss: " << Loss::Name();
     size_t const ndata = preds.Size();
     out_gpair->Resize(ndata);
-    auto device = tparam_->gpu_id;
+    auto device = ctx_->gpu_id;
     additional_input_.HostVector().begin()[0] = 1;  // Fill the label_correct flag
 
     bool is_null_weight = info.weights_.Size() == 0;
@@ -82,7 +82,7 @@ class RegLossObj : public ObjFunction {
     additional_input_.HostVector().begin()[1] = scale_pos_weight;
     additional_input_.HostVector().begin()[2] = is_null_weight;
 
-    const size_t nthreads = tparam_->Threads();
+    const size_t nthreads = ctx_->Threads();
     bool on_device = device >= 0;
     // On CPU we run the transformation each thread processing a contigious block of data
     // for better performance.
@@ -121,7 +121,7 @@ class RegLossObj : public ObjFunction {
                 Loss::SecondOrderGradient(p, label) * w);
           }
         },
-        common::Range{0, static_cast<int64_t>(n_data_blocks)}, device)
+        common::Range{0, static_cast<int64_t>(n_data_blocks)}, nthreads, device)
        .Eval(&additional_input_, out_gpair, &preds, info.labels.Data(),
              &info.weights_);
 
@@ -140,7 +140,8 @@ class RegLossObj : public ObjFunction {
     common::Transform<>::Init(
         [] XGBOOST_DEVICE(size_t _idx, common::Span<float> _preds) {
           _preds[_idx] = Loss::PredTransform(_preds[_idx]);
-        }, common::Range{0, static_cast<size_t>(io_preds->Size())},
+        },
+        common::Range{0, static_cast<size_t>(io_preds->Size())}, this->ctx_->Threads(),
         io_preds->DeviceIdx())
         .Eval(io_preds);
   }
@@ -228,7 +229,7 @@ class PoissonRegression : public ObjFunction {
     CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided";
     size_t const ndata = preds.Size();
     out_gpair->Resize(ndata);
-    auto device = tparam_->gpu_id;
+    auto device = ctx_->gpu_id;
     label_correct_.Resize(1);
     label_correct_.Fill(1);
 
@@ -254,7 +255,7 @@ class PoissonRegression : public ObjFunction {
           _out_gpair[_idx] = GradientPair{(expf(p) - y) * w,
                                           expf(p + max_delta_step) * w};
         },
-        common::Range{0, static_cast<size_t>(ndata)}, device).Eval(
+        common::Range{0, static_cast<size_t>(ndata)}, this->ctx_->Threads(), device).Eval(
             &label_correct_, out_gpair, &preds, info.labels.Data(), &info.weights_);
     // copy "label correct" flags back to host
     std::vector<int>& label_correct_h = label_correct_.HostVector();
@@ -269,7 +270,7 @@ class PoissonRegression : public ObjFunction {
         [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
           _preds[_idx] = expf(_preds[_idx]);
         },
-        common::Range{0, static_cast<size_t>(io_preds->Size())},
+        common::Range{0, static_cast<size_t>(io_preds->Size())}, this->ctx_->Threads(),
         io_preds->DeviceIdx())
         .Eval(io_preds);
   }
@@ -381,7 +382,7 @@ class CoxRegression : public ObjFunction {
   void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     std::vector<bst_float> &preds = io_preds->HostVector();
     const long ndata = static_cast<long>(preds.size());  // NOLINT(*)
-    common::ParallelFor(ndata, [&](long j) {  // NOLINT(*)
+    common::ParallelFor(ndata, ctx_->Threads(), [&](long j) {  // NOLINT(*)
       preds[j] = std::exp(preds[j]);
     });
   }
@@ -423,7 +424,7 @@ class GammaRegression : public ObjFunction {
     CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty";
     CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided";
     const size_t ndata = preds.Size();
-    auto device = tparam_->gpu_id;
+    auto device = ctx_->gpu_id;
     out_gpair->Resize(ndata);
     label_correct_.Resize(1);
     label_correct_.Fill(1);
@@ -448,7 +449,7 @@ class GammaRegression : public ObjFunction {
           }
           _out_gpair[_idx] = GradientPair((1 - y / expf(p)) * w, y / expf(p) * w);
         },
-        common::Range{0, static_cast<size_t>(ndata)}, device).Eval(
+        common::Range{0, static_cast<size_t>(ndata)}, this->ctx_->Threads(), device).Eval(
             &label_correct_, out_gpair, &preds, info.labels.Data(), &info.weights_);
 
     // copy "label correct" flags back to host
@@ -464,7 +465,7 @@ class GammaRegression : public ObjFunction {
         [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
           _preds[_idx] = expf(_preds[_idx]);
         },
-        common::Range{0, static_cast<size_t>(io_preds->Size())},
+        common::Range{0, static_cast<size_t>(io_preds->Size())}, this->ctx_->Threads(),
         io_preds->DeviceIdx())
         .Eval(io_preds);
   }
@@ -525,7 +526,7 @@ class TweedieRegression : public ObjFunction {
     const size_t ndata = preds.Size();
     out_gpair->Resize(ndata);
 
-    auto device = tparam_->gpu_id;
+    auto device = ctx_->gpu_id;
     label_correct_.Resize(1);
     label_correct_.Fill(1);
 
@@ -555,7 +556,7 @@ class TweedieRegression : public ObjFunction {
               std::exp((1 - rho) * p) + (2 - rho) * expf((2 - rho) * p);
           _out_gpair[_idx] = GradientPair(grad * w, hess * w);
         },
-        common::Range{0, static_cast<size_t>(ndata), 1}, device)
+        common::Range{0, static_cast<size_t>(ndata), 1}, this->ctx_->Threads(), device)
         .Eval(&label_correct_, out_gpair, &preds, info.labels.Data(), &info.weights_);
 
     // copy "label correct" flags back to host
@@ -571,7 +572,7 @@ class TweedieRegression : public ObjFunction {
         [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
           _preds[_idx] = expf(_preds[_idx]);
         },
-        common::Range{0, static_cast<size_t>(io_preds->Size())},
+        common::Range{0, static_cast<size_t>(io_preds->Size())}, this->ctx_->Threads(),
         io_preds->DeviceIdx())
         .Eval(io_preds);
   }
diff --git a/src/tree/split_evaluator.h b/src/tree/split_evaluator.h
index 4fdf70145..5030fcb6d 100644
--- a/src/tree/split_evaluator.h
+++ b/src/tree/split_evaluator.h
@@ -176,7 +176,7 @@ class TreeEvaluator {
             lower[rightid] = mid;
           }
         },
-        common::Range(0, 1), device_, false)
+        common::Range(0, 1), 1, device_)
         .Eval(&lower_bounds_, &upper_bounds_, &monotone_);
   }
 };
diff --git a/tests/cpp/common/test_transform_range.cc b/tests/cpp/common/test_transform_range.cc
index a740c278c..97103d8f3 100644
--- a/tests/cpp/common/test_transform_range.cc
+++ b/tests/cpp/common/test_transform_range.cc
@@ -1,3 +1,6 @@
+/*!
+ * Copyright 2018-2022 by XGBoost Contributors
+ */
 #include <gtest/gtest.h>
 #include <xgboost/host_device_vector.h>
 #include <vector>
@@ -42,7 +45,7 @@ TEST(Transform, DeclareUnifiedTest(Basic)) {
   out_vec.Fill(0);
 
   Transform<>::Init(TestTransformRange<bst_float>{},
-                    Range{0, static_cast<Range::DifferenceType>(size)},
+                    Range{0, static_cast<Range::DifferenceType>(size)}, common::OmpGetNumThreads(0),
                     TRANSFORM_GPU)
       .Eval(&out_vec, &in_vec);
   std::vector<bst_float> res = out_vec.HostVector();
@@ -55,11 +58,14 @@ TEST(TransformDeathTest, Exception) {
   size_t const kSize {16};
   std::vector<bst_float> h_in(kSize);
   const HostDeviceVector<bst_float> in_vec{h_in, -1};
-  EXPECT_DEATH({
-    Transform<>::Init([](size_t idx, common::Span<bst_float const> _in) { _in[idx + 1]; },
-                      Range(0, static_cast<Range::DifferenceType>(kSize)), -1)
-        .Eval(&in_vec);
-  }, "");
+  EXPECT_DEATH(
+      {
+        Transform<>::Init([](size_t idx, common::Span<bst_float const> _in) { _in[idx + 1]; },
+                          Range(0, static_cast<Range::DifferenceType>(kSize)),
+                          common::OmpGetNumThreads(0), -1)
+            .Eval(&in_vec);
+      },
+      "");
 }
 #endif
 
diff --git a/tests/cpp/common/test_transform_range.cu b/tests/cpp/common/test_transform_range.cu
index 5e1b2b024..c16093127 100644
--- a/tests/cpp/common/test_transform_range.cu
+++ b/tests/cpp/common/test_transform_range.cu
@@ -1,4 +1,7 @@
-// This converts all tests from CPU to GPU.
+/*!
+ * Copyright 2018-2022 by XGBoost Contributors
+ * \brief This converts all tests from CPU to GPU.
+ */
 #include "test_transform_range.cc"
 
 #if defined(XGBOOST_USE_NCCL)
 namespace xgboost {
 namespace common {
@@ -22,13 +25,13 @@ TEST(Transform, MGPU_SpecifiedGpuId) {  // NOLINT
   const HostDeviceVector<bst_float> in_vec{h_in, device};
   HostDeviceVector<bst_float> out_vec{h_out, device};
 
-  ASSERT_NO_THROW(
-      Transform<>::Init(TestTransformRange<bst_float>{}, Range{0, size}, device)
-          .Eval(&out_vec, &in_vec));
+  ASSERT_NO_THROW(Transform<>::Init(TestTransformRange<bst_float>{}, Range{0, size},
+                                    common::OmpGetNumThreads(0), device)
+                      .Eval(&out_vec, &in_vec));
   std::vector<bst_float> res = out_vec.HostVector();
   ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin()));
 }
 
 }  // namespace common
 }  // namespace xgboost
-#endif
\ No newline at end of file
+#endif
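
A note on the pattern above: the heart of this patch is the new `Transform<>::Init(func, range, n_threads, device_idx)` signature. The CPU thread count becomes an explicit argument, handed down from the objective's context via `ctx_->Threads()`, and the old `shard` flag disappears; the CPU launch path then runs `ParallelFor(end, n_threads_, ...)` with exactly that count. The sketch below illustrates the same convention in isolation. It is a minimal, CPU-only stand-in written for this note — `Evaluator`, `Init`, and the block-partitioning scheme here are hypothetical, not the classes in `src/common/transform.h`:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <thread>
    #include <vector>

    // Hypothetical stand-in for xgboost::common::Transform<>::Evaluator: the
    // thread count is state supplied by the caller, never queried globally.
    template <typename Functor>
    struct Evaluator {
      Functor func;
      std::size_t n;
      std::int32_t n_threads;

      void Eval(std::vector<float>* data) const {
        // One contiguous block of indices per thread.
        std::size_t const chunk = (n + n_threads - 1) / n_threads;
        std::vector<std::thread> workers;
        for (std::int32_t t = 0; t < n_threads; ++t) {
          std::size_t const begin = t * chunk;
          std::size_t const end = std::min(n, begin + chunk);
          workers.emplace_back([this, begin, end, data] {
            for (std::size_t i = begin; i < end; ++i) {
              func(i, data->data());
            }
          });
        }
        for (auto& w : workers) {
          w.join();
        }
      }
    };

    // Mirrors the shape of the new Transform<>::Init(func, range, n_threads,
    // device_idx), with the device argument omitted; this sketch is CPU-only.
    template <typename Functor>
    Evaluator<Functor> Init(Functor func, std::size_t n, std::int32_t n_threads) {
      return Evaluator<Functor>{func, n, n_threads};
    }

    int main() {
      std::vector<float> preds(1024, 0.0f);
      // In the patch the count comes from the objective's context
      // (ctx_->Threads()), which honours the user's nthread setting;
      // hardware_concurrency() is only a placeholder to keep this runnable.
      auto n_threads = static_cast<std::int32_t>(std::thread::hardware_concurrency());
      if (n_threads <= 0) {
        n_threads = 1;
      }
      auto sigmoid = [](std::size_t i, float* p) { p[i] = 1.0f / (1.0f + std::exp(-p[i])); };
      Init(sigmoid, preds.size(), n_threads).Eval(&preds);
      std::printf("preds[0] = %f\n", static_cast<double>(preds[0]));  // prints 0.5
      return 0;
    }

The point of this shape for `Init` is that nothing below the caller ever asks the runtime how many threads exist: the count travels with the call, the same way the index range does.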
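
Why this matters (a reading of the change, not text from the patch): `omp_get_max_threads()` reflects global OpenMP runtime state, which need not agree with the `nthread` the user configured on the booster, for instance when XGBoost runs inside a host application with its own OpenMP settings. Routing the count through the renamed `ctx_` member (formerly `tparam_`) keeps thread configuration in one place; call sites that are inherently serial, such as the one-element `common::Range(0, 1)` in `src/tree/split_evaluator.h`, now simply pass `1`, and the tests make the choice explicit by passing `common::OmpGetNumThreads(0)` rather than relying on a default.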