Add quantile metric. (#8761)
This commit is contained in:
74
src/common/quantile_loss_utils.cc
Normal file
74
src/common/quantile_loss_utils.cc
Normal file
@@ -0,0 +1,74 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
*/
|
||||
#include "quantile_loss_utils.h"
|
||||
|
||||
#include <cctype> // std::isspace
|
||||
#include <istream> // std::istream
|
||||
#include <ostream> // std::ostream
|
||||
#include <string> // std::string
|
||||
#include <vector> // std::vector
|
||||
|
||||
#include "xgboost/json.h" // F32Array,TypeCheck,get,Number
|
||||
#include "xgboost/json_io.h" // JsonWriter
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
std::ostream& operator<<(std::ostream& os, const ParamFloatArray& array) {
|
||||
auto const& t = array.Get();
|
||||
xgboost::F32Array arr{t.size()};
|
||||
for (std::size_t i = 0; i < t.size(); ++i) {
|
||||
arr.Set(i, t[i]);
|
||||
}
|
||||
std::vector<char> stream;
|
||||
xgboost::JsonWriter writer{&stream};
|
||||
arr.Save(&writer);
|
||||
for (auto c : stream) {
|
||||
os << c;
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
std::istream& operator>>(std::istream& is, ParamFloatArray& array) {
|
||||
auto& t = array.Get();
|
||||
t.clear();
|
||||
std::string str;
|
||||
while (!is.eof()) {
|
||||
std::string tmp;
|
||||
is >> tmp;
|
||||
str += tmp;
|
||||
}
|
||||
std::size_t head{0};
|
||||
// unify notation for parsing.
|
||||
while (std::isspace(str[head])) {
|
||||
++head;
|
||||
}
|
||||
if (str[head] == '(') {
|
||||
str[head] = '[';
|
||||
}
|
||||
auto tail = str.size() - 1;
|
||||
while (std::isspace(str[tail])) {
|
||||
--tail;
|
||||
}
|
||||
if (str[tail] == ')') {
|
||||
str[tail] = ']';
|
||||
}
|
||||
|
||||
auto jarr = xgboost::Json::Load(xgboost::StringView{str});
|
||||
// return if there's only one element
|
||||
if (xgboost::IsA<xgboost::Number>(jarr)) {
|
||||
t.emplace_back(xgboost::get<xgboost::Number const>(jarr));
|
||||
return is;
|
||||
}
|
||||
|
||||
auto jvec = xgboost::get<xgboost::Array const>(jarr);
|
||||
for (auto v : jvec) {
|
||||
xgboost::TypeCheck<xgboost::Number>(v, "alpha");
|
||||
t.emplace_back(get<xgboost::Number const>(v));
|
||||
}
|
||||
return is;
|
||||
}
|
||||
|
||||
DMLC_REGISTER_PARAMETER(QuantileLossParam);
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
51
src/common/quantile_loss_utils.h
Normal file
51
src/common/quantile_loss_utils.h
Normal file
@@ -0,0 +1,51 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_QUANTILE_LOSS_UTILS_H_
|
||||
#define XGBOOST_COMMON_QUANTILE_LOSS_UTILS_H_
|
||||
|
||||
#include <algorithm> // std::all_of
|
||||
#include <istream> // std::istream
|
||||
#include <ostream> // std::ostream
|
||||
#include <vector> // std::vector
|
||||
|
||||
#include "xgboost/logging.h" // CHECK
|
||||
#include "xgboost/parameter.h" // XGBoostParameter
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
// A shim to enable ADL for parameter parsing. Alternatively, we can put the stream
|
||||
// operators in std namespace, which seems to be less ideal.
|
||||
class ParamFloatArray {
|
||||
std::vector<float> values_;
|
||||
|
||||
public:
|
||||
std::vector<float>& Get() { return values_; }
|
||||
std::vector<float> const& Get() const { return values_; }
|
||||
decltype(values_)::const_reference operator[](decltype(values_)::size_type i) const {
|
||||
return values_[i];
|
||||
}
|
||||
};
|
||||
|
||||
// For parsing quantile parameters. Input can be a string to a single float or a list of
|
||||
// floats.
|
||||
std::ostream& operator<<(std::ostream& os, const ParamFloatArray& t);
|
||||
std::istream& operator>>(std::istream& is, ParamFloatArray& t);
|
||||
|
||||
struct QuantileLossParam : public XGBoostParameter<QuantileLossParam> {
|
||||
ParamFloatArray quantile_alpha;
|
||||
DMLC_DECLARE_PARAMETER(QuantileLossParam) {
|
||||
DMLC_DECLARE_FIELD(quantile_alpha).describe("List of quantiles for quantile loss.");
|
||||
}
|
||||
void Validate() const {
|
||||
CHECK(GetInitialised());
|
||||
CHECK(!quantile_alpha.Get().empty());
|
||||
auto const& array = quantile_alpha.Get();
|
||||
auto valid =
|
||||
std::all_of(array.cbegin(), array.cend(), [](auto q) { return q >= 0.0 && q <= 1.0; });
|
||||
CHECK(valid) << "quantile alpha must be in the range [0.0, 1.0].";
|
||||
}
|
||||
};
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_QUANTILE_LOSS_UTILS_H_
|
||||
Reference in New Issue
Block a user