From 28bb01aa227193727f8ef41f7560287468a488aa Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 7 Feb 2023 03:11:53 +0800 Subject: [PATCH] Extract optional weight. (#8747) - Extract optional weight from coommon.h to reduce dependency on this header. - Add test. --- src/common/common.h | 15 ++--------- src/common/optional_weight.h | 33 ++++++++++++++++++++++++ src/common/quantile.h | 1 + src/common/stats.cu | 2 +- src/common/stats.h | 1 + src/metric/auc.cc | 1 + src/metric/auc.cu | 15 ++++++----- src/metric/elementwise_metric.cu | 1 + src/objective/regression_obj.cu | 3 ++- tests/cpp/common/test_optional_weight.cc | 24 +++++++++++++++++ 10 files changed, 74 insertions(+), 22 deletions(-) create mode 100644 src/common/optional_weight.h create mode 100644 tests/cpp/common/test_optional_weight.cc diff --git a/src/common/common.h b/src/common/common.h index 438669e5f..5ac764817 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -1,5 +1,5 @@ -/*! - * Copyright 2015-2022 by XGBoost Contributors +/** + * Copyright 2015-2023 by XGBoost Contributors * \file common.h * \brief Common utilities */ @@ -199,17 +199,6 @@ std::vector ArgSort(Container const &array, Comp comp = std::less{}) { return result; } -struct OptionalWeights { - Span weights; - float dft{1.0f}; // fixme: make this compile time constant - - explicit OptionalWeights(Span w) : weights{w} {} - explicit OptionalWeights(float w) : dft{w} {} - - XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; } - auto Empty() const { return weights.empty(); } -}; - /** * Last index of a group in a CSR style of index pointer. */ diff --git a/src/common/optional_weight.h b/src/common/optional_weight.h new file mode 100644 index 000000000..e929aecb5 --- /dev/null +++ b/src/common/optional_weight.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022-2023 by XGBoost Contributors + */ +#ifndef XGBOOST_COMMON_OPTIONAL_WEIGHT_H_ +#define XGBOOST_COMMON_OPTIONAL_WEIGHT_H_ +#include "xgboost/base.h" // XGBOOST_DEVICE +#include "xgboost/context.h" // Context +#include "xgboost/host_device_vector.h" // HostDeviceVector +#include "xgboost/span.h" // Span + +namespace xgboost { +namespace common { +struct OptionalWeights { + Span weights; + float dft{1.0f}; // fixme: make this compile time constant + + explicit OptionalWeights(Span w) : weights{w} {} + explicit OptionalWeights(float w) : dft{w} {} + + XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; } + auto Empty() const { return weights.empty(); } +}; + +inline OptionalWeights MakeOptionalWeights(Context const* ctx, + HostDeviceVector const& weights) { + if (ctx->IsCUDA()) { + weights.SetDevice(ctx->gpu_id); + } + return OptionalWeights{ctx->IsCPU() ? weights.ConstHostSpan() : weights.ConstDeviceSpan()}; +} +} // namespace common +} // namespace xgboost +#endif // XGBOOST_COMMON_OPTIONAL_WEIGHT_H_ diff --git a/src/common/quantile.h b/src/common/quantile.h index a9955d2a0..c8dcf6ada 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -19,6 +19,7 @@ #include "categorical.h" #include "common.h" +#include "optional_weight.h" // OptionalWeights #include "threading_utils.h" #include "timer.h" diff --git a/src/common/stats.cu b/src/common/stats.cu index b06b20268..ab4871776 100644 --- a/src/common/stats.cu +++ b/src/common/stats.cu @@ -6,9 +6,9 @@ #include // size_t -#include "common.h" // common::OptionalWeights #include "cuda_context.cuh" // CUDAContext #include "device_helpers.cuh" // dh::MakeTransformIterator, tcbegin, tcend +#include "optional_weight.h" // common::OptionalWeights #include "stats.cuh" // common::SegmentedQuantile, common::SegmentedWeightedQuantile #include "xgboost/base.h" // XGBOOST_DEVICE #include "xgboost/context.h" // Context diff --git a/src/common/stats.h b/src/common/stats.h index 1fe344e03..5f7892cb5 100644 --- a/src/common/stats.h +++ b/src/common/stats.h @@ -9,6 +9,7 @@ #include #include "common.h" // AssertGPUSupport, OptionalWeights +#include "optional_weight.h" // OptionalWeights #include "transform_iterator.h" // MakeIndexTransformIter #include "xgboost/context.h" // Context #include "xgboost/linalg.h" diff --git a/src/metric/auc.cc b/src/metric/auc.cc index 89414cb69..8a2e2199e 100644 --- a/src/metric/auc.cc +++ b/src/metric/auc.cc @@ -15,6 +15,7 @@ #include #include "../common/math.h" +#include "../common/optional_weight.h" // OptionalWeights #include "xgboost/host_device_vector.h" #include "xgboost/linalg.h" #include "xgboost/metric.h" diff --git a/src/metric/auc.cu b/src/metric/auc.cu index 646e670b6..7673e7b29 100644 --- a/src/metric/auc.cu +++ b/src/metric/auc.cu @@ -1,21 +1,22 @@ -/*! - * Copyright 2021-2022 by XGBoost Contributors +/** + * Copyright 2021-2023 by XGBoost Contributors */ #include -#include #include #include +#include #include #include -#include #include +#include -#include "xgboost/span.h" -#include "xgboost/data.h" -#include "auc.h" #include "../collective/device_communicator.cuh" +#include "../common/optional_weight.h" // OptionalWeights #include "../common/ranking_utils.cuh" +#include "auc.h" +#include "xgboost/data.h" +#include "xgboost/span.h" namespace xgboost { namespace metric { diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index 5c6e00d68..5f55d85e7 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -14,6 +14,7 @@ #include "../collective/communicator-inl.h" #include "../common/common.h" #include "../common/math.h" +#include "../common/optional_weight.h" // OptionalWeights #include "../common/pseudo_huber.h" #include "../common/threading_utils.h" #include "metric_common.h" diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index 165f435d3..7a0df336a 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -14,7 +14,8 @@ #include "../common/common.h" #include "../common/linalg_op.h" -#include "../common/numeric.h" // Reduce +#include "../common/numeric.h" // Reduce +#include "../common/optional_weight.h" // OptionalWeights #include "../common/pseudo_huber.h" #include "../common/stats.h" #include "../common/threading_utils.h" diff --git a/tests/cpp/common/test_optional_weight.cc b/tests/cpp/common/test_optional_weight.cc new file mode 100644 index 000000000..e2c59e608 --- /dev/null +++ b/tests/cpp/common/test_optional_weight.cc @@ -0,0 +1,24 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#include +#include // Context +#include // HostDeviceVector + +#include "../../../src/common/optional_weight.h" +namespace xgboost { +namespace common { +TEST(OptionalWeight, Basic) { + HostDeviceVector weight{{2.0f, 3.0f, 4.0f}}; + Context ctx; + auto opt_w = MakeOptionalWeights(&ctx, weight); + ASSERT_EQ(opt_w[0], 2.0f); + ASSERT_FALSE(opt_w.Empty()); + + weight.HostVector().clear(); + opt_w = MakeOptionalWeights(&ctx, weight); + ASSERT_EQ(opt_w[0], 1.0f); + ASSERT_TRUE(opt_w.Empty()); +} +} // namespace common +} // namespace xgboost