Extract make metric name from ranking metric. (#8768)
- Extract the metric parsing routine from ranking. - Add a test. - Accept null for string view.
This commit is contained in:
parent
4ead65a28c
commit
5f76edd296
@ -90,6 +90,7 @@ OBJECTS= \
|
||||
$(PKGROOT)/src/common/stats.o \
|
||||
$(PKGROOT)/src/common/survival_util.o \
|
||||
$(PKGROOT)/src/common/threading_utils.o \
|
||||
$(PKGROOT)/src/common/ranking_utils.o \
|
||||
$(PKGROOT)/src/common/timer.o \
|
||||
$(PKGROOT)/src/common/version.o \
|
||||
$(PKGROOT)/src/c_api/c_api.o \
|
||||
|
||||
@ -90,6 +90,7 @@ OBJECTS= \
|
||||
$(PKGROOT)/src/common/stats.o \
|
||||
$(PKGROOT)/src/common/survival_util.o \
|
||||
$(PKGROOT)/src/common/threading_utils.o \
|
||||
$(PKGROOT)/src/common/ranking_utils.o \
|
||||
$(PKGROOT)/src/common/timer.o \
|
||||
$(PKGROOT)/src/common/version.o \
|
||||
$(PKGROOT)/src/c_api/c_api.o \
|
||||
|
||||
@ -1,15 +1,15 @@
|
||||
/*!
|
||||
* Copyright 2021 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2021-2023 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_STRING_VIEW_H_
|
||||
#define XGBOOST_STRING_VIEW_H_
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/span.h>
|
||||
#include <xgboost/logging.h> // CHECK_LT
|
||||
#include <xgboost/span.h> // Span
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <algorithm> // std::equal,std::min
|
||||
#include <iterator> // std::reverse_iterator
|
||||
#include <ostream> // std::ostream
|
||||
#include <string> // std::char_traits,std::string
|
||||
|
||||
namespace xgboost {
|
||||
struct StringView {
|
||||
@ -28,29 +28,31 @@ struct StringView {
|
||||
|
||||
public:
|
||||
constexpr StringView() = default;
|
||||
constexpr StringView(CharT const* str, size_t size) : str_{str}, size_{size} {}
|
||||
constexpr StringView(CharT const* str, std::size_t size) : str_{str}, size_{size} {}
|
||||
explicit StringView(std::string const& str) : str_{str.c_str()}, size_{str.size()} {}
|
||||
StringView(CharT const* str) : str_{str}, size_{Traits::length(str)} {} // NOLINT
|
||||
constexpr StringView(CharT const* str) // NOLINT
|
||||
: str_{str}, size_{str == nullptr ? 0ul : Traits::length(str)} {}
|
||||
|
||||
CharT const& operator[](size_t p) const { return str_[p]; }
|
||||
CharT const& at(size_t p) const { // NOLINT
|
||||
CHECK_LT(p, size_);
|
||||
return str_[p];
|
||||
}
|
||||
constexpr size_t size() const { return size_; } // NOLINT
|
||||
StringView substr(size_t beg, size_t n) const { // NOLINT
|
||||
constexpr std::size_t size() const { return size_; } // NOLINT
|
||||
constexpr bool empty() const { return size() == 0; } // NOLINT
|
||||
StringView substr(size_t beg, size_t n) const { // NOLINT
|
||||
CHECK_LE(beg, size_);
|
||||
size_t len = std::min(n, size_ - beg);
|
||||
return {str_ + beg, len};
|
||||
}
|
||||
CharT const* c_str() const { return str_; } // NOLINT
|
||||
CharT const* c_str() const { return str_; } // NOLINT
|
||||
|
||||
constexpr CharT const* cbegin() const { return str_; } // NOLINT
|
||||
constexpr CharT const* cend() const { return str_ + size(); } // NOLINT
|
||||
constexpr CharT const* begin() const { return str_; } // NOLINT
|
||||
constexpr CharT const* end() const { return str_ + size(); } // NOLINT
|
||||
|
||||
const_reverse_iterator rbegin() const noexcept { // NOLINT
|
||||
const_reverse_iterator rbegin() const noexcept { // NOLINT
|
||||
return const_reverse_iterator(this->end());
|
||||
}
|
||||
const_reverse_iterator crbegin() const noexcept { // NOLINT
|
||||
|
||||
34
src/common/ranking_utils.cc
Normal file
34
src/common/ranking_utils.cc
Normal file
@ -0,0 +1,34 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
*/
|
||||
#include "ranking_utils.h"
|
||||
|
||||
#include <cstdint>  // std::uint32_t
#include <cstdio>   // std::sscanf
#include <sstream>  // std::ostringstream
#include <string>   // std::string
|
||||
|
||||
#include "xgboost/string_view.h" // StringView
|
||||
|
||||
namespace xgboost {
|
||||
namespace ltr {
|
||||
/**
 * \brief Construct name for ranking metric given parameters.
 *
 * The result is `name@param` when `param` begins with an unsigned integer (e.g. `ndcg@3-`),
 * `name` immediately followed by `param` when it does not (e.g. `ndcg-`), and `name` alone
 * when `param` is empty.
 *
 * \param [in] name   Metric name. Copied using its size, need not be null terminated.
 * \param [in] param  Parameter like the `3-` in `ndcg@3-`. May be empty, need not be null
 *                    terminated.
 * \param [out] topn  Receives the leading number of `param`. Unchanged when `param` has no
 *                    leading number. Ignored when null.
 * \param [out] minus Set to true when `param` ends with `-`, unchanged otherwise. Ignored
 *                    when null.
 *
 * \return The name of the metric.
 */
std::string MakeMetricName(StringView name, StringView param, std::uint32_t* topn, bool* minus) {
  // Copy through the [cbegin, cend) range instead of c_str(): a view carries its own size and
  // is not guaranteed to be null terminated (or non-null).
  std::string out_name(name.cbegin(), name.cend());
  if (param.empty()) {
    return out_name;
  }

  // sscanf needs a null terminated buffer, own a terminated copy of the parameter.
  std::string const param_str(param.cbegin(), param.cend());

  // Scan into an `unsigned`, which is the type `%u` is specified for, and only touch the
  // output on success so that it stays unchanged otherwise. The former `[-]?` suffix of the
  // format was matched as literal text after the conversion and never affected the result.
  unsigned parsed_topn{0};
  if (std::sscanf(param_str.c_str(), "%u", &parsed_topn) == 1) {
    if (topn != nullptr) {
      *topn = static_cast<std::uint32_t>(parsed_topn);
    }
    out_name += '@';
  }
  out_name += param_str;

  // A trailing `-` asks for the score to be reported as a loss.
  if (minus != nullptr && param_str.back() == '-') {
    *minus = true;
  }
  return out_name;
}
|
||||
} // namespace ltr
|
||||
} // namespace xgboost
|
||||
29
src/common/ranking_utils.h
Normal file
29
src/common/ranking_utils.h
Normal file
@ -0,0 +1,29 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_RANKING_UTILS_H_
|
||||
#define XGBOOST_COMMON_RANKING_UTILS_H_
|
||||
|
||||
#include <cstddef> // std::size_t
|
||||
#include <cstdint> // std::uint32_t
|
||||
#include <string> // std::string
|
||||
|
||||
#include "xgboost/string_view.h" // StringView
|
||||
|
||||
namespace xgboost {
|
||||
namespace ltr {
|
||||
/**
 * \brief Construct name for ranking metric given parameters.
 *
 * The result is `name@param` when `param` begins with an unsigned integer (e.g. `ndcg@3-`),
 * `name` immediately followed by `param` when it does not (e.g. `ndcg-`), and `name` alone
 * when `param` is empty.
 *
 * \param [in] name   Null terminated string for metric name
 * \param [in] param  Null terminated string for parameter like the `3-` in `ndcg@3-`. May be
 *                    empty.
 * \param [out] topn  Top n documents parsed from param. Unchanged if it's not specified.
 * \param [out] minus Whether we should turn the score into loss. Set to true when param ends
 *                    with `-`; unchanged (never reset to false) if it's not specified.
 *
 * \return The name of the metric.
 */
|
||||
std::string MakeMetricName(StringView name, StringView param, std::uint32_t* topn, bool* minus);
|
||||
} // namespace ltr
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_RANKING_UTILS_H_
|
||||
@ -28,6 +28,7 @@
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/ranking_utils.h" // MakeMetricName
|
||||
#include "../common/threading_utils.h"
|
||||
#include "metric_common.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
@ -232,23 +233,7 @@ struct EvalRank : public Metric, public EvalRankConfig {
|
||||
|
||||
protected:
|
||||
explicit EvalRank(const char* name, const char* param) {
|
||||
using namespace std; // NOLINT(*)
|
||||
|
||||
if (param != nullptr) {
|
||||
std::ostringstream os;
|
||||
if (sscanf(param, "%u[-]?", &topn) == 1) {
|
||||
os << name << '@' << param;
|
||||
this->name = os.str();
|
||||
} else {
|
||||
os << name << param;
|
||||
this->name = os.str();
|
||||
}
|
||||
if (param[strlen(param) - 1] == '-') {
|
||||
minus = true;
|
||||
}
|
||||
} else {
|
||||
this->name = name;
|
||||
}
|
||||
this->name = ltr::MakeMetricName(name, param, &topn, &minus);
|
||||
}
|
||||
|
||||
virtual double EvalGroup(PredIndPairContainer *recptr) const = 0;
|
||||
|
||||
38
tests/cpp/common/test_ranking_utils.cc
Normal file
38
tests/cpp/common/test_ranking_utils.cc
Normal file
@ -0,0 +1,38 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cstdint> // std::uint32_t
|
||||
|
||||
#include "../../../src/common/ranking_utils.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace ltr {
|
||||
/**
 * Checks the name returned by MakeMetricName and which of the `topn` / `minus` outputs are
 * written for each shape of the parameter string: number with trailing `-`, number only,
 * `-` only, null and default-constructed (empty) views.
 */
TEST(RankingUtils, MakeMetricName) {
  std::uint32_t topn{32};
  bool minus{false};

  // Leading number and trailing `-`: both outputs are written.
  auto name = MakeMetricName("ndcg", "3-", &topn, &minus);
  ASSERT_EQ(name, "ndcg@3-");
  ASSERT_EQ(topn, 3u);
  ASSERT_TRUE(minus);

  // Number only: topn is updated, minus keeps the value from the previous call.
  name = MakeMetricName("ndcg", "6", &topn, &minus);
  ASSERT_EQ(name, "ndcg@6");
  ASSERT_EQ(topn, 6u);
  ASSERT_TRUE(minus);  // unchanged

  // `-` only: no number to parse, so no `@` in the name and topn is left alone.
  minus = false;
  name = MakeMetricName("ndcg", "-", &topn, &minus);
  ASSERT_EQ(name, "ndcg-");
  ASSERT_EQ(topn, 6u);  // unchanged
  ASSERT_TRUE(minus);

  // Null parameter is an empty view: the name is returned as is.
  name = MakeMetricName("ndcg", nullptr, &topn, &minus);
  ASSERT_EQ(name, "ndcg");
  ASSERT_EQ(topn, 6u);  // unchanged
  ASSERT_TRUE(minus);   // unchanged

  // Default constructed view behaves the same as the null parameter.
  name = MakeMetricName("ndcg", StringView{}, &topn, &minus);
  ASSERT_EQ(name, "ndcg");
  ASSERT_EQ(topn, 6u);  // unchanged
  ASSERT_TRUE(minus);   // unchanged
}
|
||||
} // namespace ltr
|
||||
} // namespace xgboost
|
||||
@ -1,3 +1,6 @@
|
||||
/**
|
||||
* Copyright 2021 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include "../../../src/common/ranking_utils.cuh"
|
||||
#include "../../../src/common/device_helpers.cuh"
|
||||
|
||||
@ -1,9 +1,13 @@
|
||||
/*!
|
||||
* Copyright (c) by XGBoost Contributors 2021
|
||||
/**
|
||||
* Copyright 2021-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/string_view.h>
|
||||
#include <string_view>
|
||||
|
||||
#include <algorithm> // std::equal
|
||||
#include <sstream> // std::stringstream
|
||||
#include <string> // std::string
|
||||
|
||||
namespace xgboost {
|
||||
TEST(StringView, Basic) {
|
||||
StringView str{"This is a string."};
|
||||
@ -24,5 +28,16 @@ TEST(StringView, Basic) {
|
||||
ASSERT_FALSE(substr == "i");
|
||||
|
||||
ASSERT_TRUE(std::equal(substr.crbegin(), substr.crend(), StringView{"si"}.cbegin()));
|
||||
|
||||
{
|
||||
StringView empty{nullptr};
|
||||
ASSERT_TRUE(empty.empty());
|
||||
}
|
||||
{
|
||||
StringView empty{""};
|
||||
ASSERT_TRUE(empty.empty());
|
||||
StringView empty2{nullptr};
|
||||
ASSERT_EQ(empty, empty2);
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user