Extract make metric name from ranking metric. (#8768)
- Extract the metric parsing routine from ranking. - Add a test. - Accept null for string view.
This commit is contained in:
parent
4ead65a28c
commit
5f76edd296
@ -90,6 +90,7 @@ OBJECTS= \
|
|||||||
$(PKGROOT)/src/common/stats.o \
|
$(PKGROOT)/src/common/stats.o \
|
||||||
$(PKGROOT)/src/common/survival_util.o \
|
$(PKGROOT)/src/common/survival_util.o \
|
||||||
$(PKGROOT)/src/common/threading_utils.o \
|
$(PKGROOT)/src/common/threading_utils.o \
|
||||||
|
$(PKGROOT)/src/common/ranking_utils.o \
|
||||||
$(PKGROOT)/src/common/timer.o \
|
$(PKGROOT)/src/common/timer.o \
|
||||||
$(PKGROOT)/src/common/version.o \
|
$(PKGROOT)/src/common/version.o \
|
||||||
$(PKGROOT)/src/c_api/c_api.o \
|
$(PKGROOT)/src/c_api/c_api.o \
|
||||||
|
|||||||
@ -90,6 +90,7 @@ OBJECTS= \
|
|||||||
$(PKGROOT)/src/common/stats.o \
|
$(PKGROOT)/src/common/stats.o \
|
||||||
$(PKGROOT)/src/common/survival_util.o \
|
$(PKGROOT)/src/common/survival_util.o \
|
||||||
$(PKGROOT)/src/common/threading_utils.o \
|
$(PKGROOT)/src/common/threading_utils.o \
|
||||||
|
$(PKGROOT)/src/common/ranking_utils.o \
|
||||||
$(PKGROOT)/src/common/timer.o \
|
$(PKGROOT)/src/common/timer.o \
|
||||||
$(PKGROOT)/src/common/version.o \
|
$(PKGROOT)/src/common/version.o \
|
||||||
$(PKGROOT)/src/c_api/c_api.o \
|
$(PKGROOT)/src/c_api/c_api.o \
|
||||||
|
|||||||
@ -1,15 +1,15 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2021 by XGBoost Contributors
|
* Copyright 2021-2023 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_STRING_VIEW_H_
|
#ifndef XGBOOST_STRING_VIEW_H_
|
||||||
#define XGBOOST_STRING_VIEW_H_
|
#define XGBOOST_STRING_VIEW_H_
|
||||||
#include <xgboost/logging.h>
|
#include <xgboost/logging.h> // CHECK_LT
|
||||||
#include <xgboost/span.h>
|
#include <xgboost/span.h> // Span
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm> // std::equal,std::min
|
||||||
#include <iterator>
|
#include <iterator> // std::reverse_iterator
|
||||||
#include <ostream>
|
#include <ostream> // std::ostream
|
||||||
#include <string>
|
#include <string> // std::char_traits,std::string
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
struct StringView {
|
struct StringView {
|
||||||
@ -28,29 +28,31 @@ struct StringView {
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
constexpr StringView() = default;
|
constexpr StringView() = default;
|
||||||
constexpr StringView(CharT const* str, size_t size) : str_{str}, size_{size} {}
|
constexpr StringView(CharT const* str, std::size_t size) : str_{str}, size_{size} {}
|
||||||
explicit StringView(std::string const& str) : str_{str.c_str()}, size_{str.size()} {}
|
explicit StringView(std::string const& str) : str_{str.c_str()}, size_{str.size()} {}
|
||||||
StringView(CharT const* str) : str_{str}, size_{Traits::length(str)} {} // NOLINT
|
constexpr StringView(CharT const* str) // NOLINT
|
||||||
|
: str_{str}, size_{str == nullptr ? 0ul : Traits::length(str)} {}
|
||||||
|
|
||||||
CharT const& operator[](size_t p) const { return str_[p]; }
|
CharT const& operator[](size_t p) const { return str_[p]; }
|
||||||
CharT const& at(size_t p) const { // NOLINT
|
CharT const& at(size_t p) const { // NOLINT
|
||||||
CHECK_LT(p, size_);
|
CHECK_LT(p, size_);
|
||||||
return str_[p];
|
return str_[p];
|
||||||
}
|
}
|
||||||
constexpr size_t size() const { return size_; } // NOLINT
|
constexpr std::size_t size() const { return size_; } // NOLINT
|
||||||
StringView substr(size_t beg, size_t n) const { // NOLINT
|
constexpr bool empty() const { return size() == 0; } // NOLINT
|
||||||
|
StringView substr(size_t beg, size_t n) const { // NOLINT
|
||||||
CHECK_LE(beg, size_);
|
CHECK_LE(beg, size_);
|
||||||
size_t len = std::min(n, size_ - beg);
|
size_t len = std::min(n, size_ - beg);
|
||||||
return {str_ + beg, len};
|
return {str_ + beg, len};
|
||||||
}
|
}
|
||||||
CharT const* c_str() const { return str_; } // NOLINT
|
CharT const* c_str() const { return str_; } // NOLINT
|
||||||
|
|
||||||
constexpr CharT const* cbegin() const { return str_; } // NOLINT
|
constexpr CharT const* cbegin() const { return str_; } // NOLINT
|
||||||
constexpr CharT const* cend() const { return str_ + size(); } // NOLINT
|
constexpr CharT const* cend() const { return str_ + size(); } // NOLINT
|
||||||
constexpr CharT const* begin() const { return str_; } // NOLINT
|
constexpr CharT const* begin() const { return str_; } // NOLINT
|
||||||
constexpr CharT const* end() const { return str_ + size(); } // NOLINT
|
constexpr CharT const* end() const { return str_ + size(); } // NOLINT
|
||||||
|
|
||||||
const_reverse_iterator rbegin() const noexcept { // NOLINT
|
const_reverse_iterator rbegin() const noexcept { // NOLINT
|
||||||
return const_reverse_iterator(this->end());
|
return const_reverse_iterator(this->end());
|
||||||
}
|
}
|
||||||
const_reverse_iterator crbegin() const noexcept { // NOLINT
|
const_reverse_iterator crbegin() const noexcept { // NOLINT
|
||||||
|
|||||||
34
src/common/ranking_utils.cc
Normal file
34
src/common/ranking_utils.cc
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023 by XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include "ranking_utils.h"
|
||||||
|
|
||||||
|
#include <cstdint> // std::uint32_t
|
||||||
|
#include <sstream> // std::ostringstream
|
||||||
|
#include <string> // std::string,std::sscanf
|
||||||
|
|
||||||
|
#include "xgboost/string_view.h" // StringView
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace ltr {
|
||||||
|
std::string MakeMetricName(StringView name, StringView param, std::uint32_t* topn, bool* minus) {
|
||||||
|
std::string out_name;
|
||||||
|
if (!param.empty()) {
|
||||||
|
std::ostringstream os;
|
||||||
|
if (std::sscanf(param.c_str(), "%u[-]?", topn) == 1) {
|
||||||
|
os << name << '@' << param;
|
||||||
|
out_name = os.str();
|
||||||
|
} else {
|
||||||
|
os << name << param;
|
||||||
|
out_name = os.str();
|
||||||
|
}
|
||||||
|
if (*param.crbegin() == '-') {
|
||||||
|
*minus = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
out_name = name.c_str();
|
||||||
|
}
|
||||||
|
return out_name;
|
||||||
|
}
|
||||||
|
} // namespace ltr
|
||||||
|
} // namespace xgboost
|
||||||
29
src/common/ranking_utils.h
Normal file
29
src/common/ranking_utils.h
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023 by XGBoost contributors
|
||||||
|
*/
|
||||||
|
#ifndef XGBOOST_COMMON_RANKING_UTILS_H_
|
||||||
|
#define XGBOOST_COMMON_RANKING_UTILS_H_
|
||||||
|
|
||||||
|
#include <cstddef> // std::size_t
|
||||||
|
#include <cstdint> // std::uint32_t
|
||||||
|
#include <string> // std::string
|
||||||
|
|
||||||
|
#include "xgboost/string_view.h" // StringView
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace ltr {
|
||||||
|
/**
|
||||||
|
* \brief Construct name for ranking metric given parameters.
|
||||||
|
*
|
||||||
|
* \param [in] name Null terminated string for metric name
|
||||||
|
* \param [in] param Null terminated string for parameter like the `3-` in `ndcg@3-`.
|
||||||
|
* \param [out] topn Top n documents parsed from param. Unchanged if it's not specified.
|
||||||
|
* \param [out] minus Whether we should turn the score into loss. Unchanged if it's not
|
||||||
|
* specified.
|
||||||
|
*
|
||||||
|
* \return The name of the metric.
|
||||||
|
*/
|
||||||
|
std::string MakeMetricName(StringView name, StringView param, std::uint32_t* topn, bool* minus);
|
||||||
|
} // namespace ltr
|
||||||
|
} // namespace xgboost
|
||||||
|
#endif // XGBOOST_COMMON_RANKING_UTILS_H_
|
||||||
@ -28,6 +28,7 @@
|
|||||||
|
|
||||||
#include "../collective/communicator-inl.h"
|
#include "../collective/communicator-inl.h"
|
||||||
#include "../common/math.h"
|
#include "../common/math.h"
|
||||||
|
#include "../common/ranking_utils.h" // MakeMetricName
|
||||||
#include "../common/threading_utils.h"
|
#include "../common/threading_utils.h"
|
||||||
#include "metric_common.h"
|
#include "metric_common.h"
|
||||||
#include "xgboost/host_device_vector.h"
|
#include "xgboost/host_device_vector.h"
|
||||||
@ -232,23 +233,7 @@ struct EvalRank : public Metric, public EvalRankConfig {
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
explicit EvalRank(const char* name, const char* param) {
|
explicit EvalRank(const char* name, const char* param) {
|
||||||
using namespace std; // NOLINT(*)
|
this->name = ltr::MakeMetricName(name, param, &topn, &minus);
|
||||||
|
|
||||||
if (param != nullptr) {
|
|
||||||
std::ostringstream os;
|
|
||||||
if (sscanf(param, "%u[-]?", &topn) == 1) {
|
|
||||||
os << name << '@' << param;
|
|
||||||
this->name = os.str();
|
|
||||||
} else {
|
|
||||||
os << name << param;
|
|
||||||
this->name = os.str();
|
|
||||||
}
|
|
||||||
if (param[strlen(param) - 1] == '-') {
|
|
||||||
minus = true;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
this->name = name;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual double EvalGroup(PredIndPairContainer *recptr) const = 0;
|
virtual double EvalGroup(PredIndPairContainer *recptr) const = 0;
|
||||||
|
|||||||
38
tests/cpp/common/test_ranking_utils.cc
Normal file
38
tests/cpp/common/test_ranking_utils.cc
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023 by XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <cstdint> // std::uint32_t
|
||||||
|
|
||||||
|
#include "../../../src/common/ranking_utils.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace ltr {
|
||||||
|
TEST(RankingUtils, MakeMetricName) {
|
||||||
|
std::uint32_t topn{32};
|
||||||
|
bool minus{false};
|
||||||
|
auto name = MakeMetricName("ndcg", "3-", &topn, &minus);
|
||||||
|
ASSERT_EQ(name, "ndcg@3-");
|
||||||
|
ASSERT_EQ(topn, 3);
|
||||||
|
ASSERT_TRUE(minus);
|
||||||
|
|
||||||
|
name = MakeMetricName("ndcg", "6", &topn, &minus);
|
||||||
|
ASSERT_EQ(topn, 6);
|
||||||
|
ASSERT_TRUE(minus); // unchanged
|
||||||
|
|
||||||
|
minus = false;
|
||||||
|
name = MakeMetricName("ndcg", "-", &topn, &minus);
|
||||||
|
ASSERT_EQ(topn, 6); // unchanged
|
||||||
|
ASSERT_TRUE(minus);
|
||||||
|
|
||||||
|
name = MakeMetricName("ndcg", nullptr, &topn, &minus);
|
||||||
|
ASSERT_EQ(topn, 6); // unchanged
|
||||||
|
ASSERT_TRUE(minus); // unchanged
|
||||||
|
|
||||||
|
name = MakeMetricName("ndcg", StringView{}, &topn, &minus);
|
||||||
|
ASSERT_EQ(topn, 6); // unchanged
|
||||||
|
ASSERT_TRUE(minus); // unchanged
|
||||||
|
}
|
||||||
|
} // namespace ltr
|
||||||
|
} // namespace xgboost
|
||||||
@ -1,3 +1,6 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2021 by XGBoost Contributors
|
||||||
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include "../../../src/common/ranking_utils.cuh"
|
#include "../../../src/common/ranking_utils.cuh"
|
||||||
#include "../../../src/common/device_helpers.cuh"
|
#include "../../../src/common/device_helpers.cuh"
|
||||||
|
|||||||
@ -1,9 +1,13 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright (c) by XGBoost Contributors 2021
|
* Copyright 2021-2023 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/string_view.h>
|
#include <xgboost/string_view.h>
|
||||||
#include <string_view>
|
|
||||||
|
#include <algorithm> // std::equal
|
||||||
|
#include <sstream> // std::stringstream
|
||||||
|
#include <string> // std::string
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
TEST(StringView, Basic) {
|
TEST(StringView, Basic) {
|
||||||
StringView str{"This is a string."};
|
StringView str{"This is a string."};
|
||||||
@ -24,5 +28,16 @@ TEST(StringView, Basic) {
|
|||||||
ASSERT_FALSE(substr == "i");
|
ASSERT_FALSE(substr == "i");
|
||||||
|
|
||||||
ASSERT_TRUE(std::equal(substr.crbegin(), substr.crend(), StringView{"si"}.cbegin()));
|
ASSERT_TRUE(std::equal(substr.crbegin(), substr.crend(), StringView{"si"}.cbegin()));
|
||||||
|
|
||||||
|
{
|
||||||
|
StringView empty{nullptr};
|
||||||
|
ASSERT_TRUE(empty.empty());
|
||||||
|
}
|
||||||
|
{
|
||||||
|
StringView empty{""};
|
||||||
|
ASSERT_TRUE(empty.empty());
|
||||||
|
StringView empty2{nullptr};
|
||||||
|
ASSERT_EQ(empty, empty2);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user