Allow using string view to find JSON value. (#8332)
- Allow comparison between string and string view. - Fix compiler warnings.
This commit is contained in:
parent
29595102b9
commit
3ef1703553
@ -187,13 +187,17 @@ using I32Array = JsonTypedArray<int32_t, Value::ValueKind::kI32Array>;
|
|||||||
using I64Array = JsonTypedArray<int64_t, Value::ValueKind::kI64Array>;
|
using I64Array = JsonTypedArray<int64_t, Value::ValueKind::kI64Array>;
|
||||||
|
|
||||||
class JsonObject : public Value {
|
class JsonObject : public Value {
|
||||||
std::map<std::string, Json> object_;
|
public:
|
||||||
|
using Map = std::map<std::string, Json, std::less<>>;
|
||||||
|
|
||||||
|
private:
|
||||||
|
Map object_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
JsonObject() : Value(ValueKind::kObject) {}
|
JsonObject() : Value(ValueKind::kObject) {}
|
||||||
JsonObject(std::map<std::string, Json>&& object) noexcept; // NOLINT
|
JsonObject(Map&& object) noexcept; // NOLINT
|
||||||
JsonObject(JsonObject const& that) = delete;
|
JsonObject(JsonObject const& that) = delete;
|
||||||
JsonObject(JsonObject && that) noexcept;
|
JsonObject(JsonObject&& that) noexcept;
|
||||||
|
|
||||||
void Save(JsonWriter* writer) const override;
|
void Save(JsonWriter* writer) const override;
|
||||||
|
|
||||||
@ -201,15 +205,13 @@ class JsonObject : public Value {
|
|||||||
Json& operator[](int ind) override { return Value::operator[](ind); }
|
Json& operator[](int ind) override { return Value::operator[](ind); }
|
||||||
Json& operator[](std::string const& key) override { return object_[key]; }
|
Json& operator[](std::string const& key) override { return object_[key]; }
|
||||||
|
|
||||||
std::map<std::string, Json> const& GetObject() && { return object_; }
|
Map const& GetObject() && { return object_; }
|
||||||
std::map<std::string, Json> const& GetObject() const & { return object_; }
|
Map const& GetObject() const& { return object_; }
|
||||||
std::map<std::string, Json> & GetObject() & { return object_; }
|
Map& GetObject() & { return object_; }
|
||||||
|
|
||||||
bool operator==(Value const& rhs) const override;
|
bool operator==(Value const& rhs) const override;
|
||||||
|
|
||||||
static bool IsClassOf(Value const* value) {
|
static bool IsClassOf(Value const* value) { return value->Type() == ValueKind::kObject; }
|
||||||
return value->Type() == ValueKind::kObject;
|
|
||||||
}
|
|
||||||
~JsonObject() override = default;
|
~JsonObject() override = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -559,16 +561,13 @@ std::vector<T> const& GetImpl(JsonTypedArray<T, kind> const& val) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Object
|
// Object
|
||||||
template <typename T,
|
template <typename T, typename std::enable_if<std::is_same<T, JsonObject>::value>::type* = nullptr>
|
||||||
typename std::enable_if<
|
JsonObject::Map& GetImpl(T& val) { // NOLINT
|
||||||
std::is_same<T, JsonObject>::value>::type* = nullptr>
|
|
||||||
std::map<std::string, Json>& GetImpl(T& val) { // NOLINT
|
|
||||||
return val.GetObject();
|
return val.GetObject();
|
||||||
}
|
}
|
||||||
template <typename T,
|
template <typename T,
|
||||||
typename std::enable_if<
|
typename std::enable_if<std::is_same<T, JsonObject const>::value>::type* = nullptr>
|
||||||
std::is_same<T, JsonObject const>::value>::type* = nullptr>
|
JsonObject::Map const& GetImpl(T& val) { // NOLINT
|
||||||
std::map<std::string, Json> const& GetImpl(T& val) { // NOLINT
|
|
||||||
return val.GetObject();
|
return val.GetObject();
|
||||||
}
|
}
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
#ifndef XGBOOST_STRING_VIEW_H_
|
#ifndef XGBOOST_STRING_VIEW_H_
|
||||||
#define XGBOOST_STRING_VIEW_H_
|
#define XGBOOST_STRING_VIEW_H_
|
||||||
#include <xgboost/logging.h>
|
#include <xgboost/logging.h>
|
||||||
|
#include <xgboost/span.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
@ -19,6 +20,7 @@ struct StringView {
|
|||||||
size_t size_{0};
|
size_t size_{0};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
using value_type = CharT; // NOLINT
|
||||||
using iterator = const CharT*; // NOLINT
|
using iterator = const CharT*; // NOLINT
|
||||||
using const_iterator = iterator; // NOLINT
|
using const_iterator = iterator; // NOLINT
|
||||||
using reverse_iterator = std::reverse_iterator<const_iterator>; // NOLINT
|
using reverse_iterator = std::reverse_iterator<const_iterator>; // NOLINT
|
||||||
@ -77,5 +79,14 @@ inline bool operator==(StringView l, StringView r) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline bool operator!=(StringView l, StringView r) { return !(l == r); }
|
inline bool operator!=(StringView l, StringView r) { return !(l == r); }
|
||||||
|
|
||||||
|
inline bool operator<(StringView l, StringView r) {
|
||||||
|
return common::Span<StringView::value_type const>{l.c_str(), l.size()} <
|
||||||
|
common::Span<StringView::value_type const>{r.c_str(), r.size()};
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool operator<(std::string const& l, StringView r) { return StringView{l} < r; }
|
||||||
|
|
||||||
|
inline bool operator<(StringView l, std::string const& r) { return l < StringView{r}; }
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
#endif // XGBOOST_STRING_VIEW_H_
|
#endif // XGBOOST_STRING_VIEW_H_
|
||||||
|
|||||||
@ -258,7 +258,7 @@ void TypeCheck(Json const &value, StringView name) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename JT>
|
template <typename JT>
|
||||||
auto const &RequiredArg(Json const &in, std::string const &key, StringView func) {
|
auto const &RequiredArg(Json const &in, StringView key, StringView func) {
|
||||||
auto const &obj = get<Object const>(in);
|
auto const &obj = get<Object const>(in);
|
||||||
auto it = obj.find(key);
|
auto it = obj.find(key);
|
||||||
if (it == obj.cend() || IsA<Null>(it->second)) {
|
if (it == obj.cend() || IsA<Null>(it->second)) {
|
||||||
@ -269,11 +269,11 @@ auto const &RequiredArg(Json const &in, std::string const &key, StringView func)
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename JT, typename T>
|
template <typename JT, typename T>
|
||||||
auto const &OptionalArg(Json const &in, std::string const &key, T const &dft) {
|
auto const &OptionalArg(Json const &in, StringView key, T const &dft) {
|
||||||
auto const &obj = get<Object const>(in);
|
auto const &obj = get<Object const>(in);
|
||||||
auto it = obj.find(key);
|
auto it = obj.find(key);
|
||||||
if (it != obj.cend() && !IsA<Null>(it->second)) {
|
if (it != obj.cend() && !IsA<Null>(it->second)) {
|
||||||
TypeCheck<JT>(it->second, StringView{key});
|
TypeCheck<JT>(it->second, key);
|
||||||
return get<std::remove_const_t<JT> const>(it->second);
|
return get<std::remove_const_t<JT> const>(it->second);
|
||||||
}
|
}
|
||||||
return dft;
|
return dft;
|
||||||
|
|||||||
@ -199,8 +199,8 @@ JsonObject::JsonObject(JsonObject&& that) noexcept : Value(ValueKind::kObject) {
|
|||||||
std::swap(that.object_, this->object_);
|
std::swap(that.object_, this->object_);
|
||||||
}
|
}
|
||||||
|
|
||||||
JsonObject::JsonObject(std::map<std::string, Json>&& object) noexcept
|
JsonObject::JsonObject(Map&& object) noexcept
|
||||||
: Value(ValueKind::kObject), object_{std::forward<std::map<std::string, Json>>(object)} {}
|
: Value(ValueKind::kObject), object_{std::forward<Map>(object)} {}
|
||||||
|
|
||||||
bool JsonObject::operator==(Value const& rhs) const {
|
bool JsonObject::operator==(Value const& rhs) const {
|
||||||
if (!IsA<JsonObject>(&rhs)) {
|
if (!IsA<JsonObject>(&rhs)) {
|
||||||
@ -502,7 +502,7 @@ Json JsonReader::ParseArray() {
|
|||||||
Json JsonReader::ParseObject() {
|
Json JsonReader::ParseObject() {
|
||||||
GetConsecutiveChar('{');
|
GetConsecutiveChar('{');
|
||||||
|
|
||||||
std::map<std::string, Json> data;
|
Object::Map data;
|
||||||
SkipSpaces();
|
SkipSpaces();
|
||||||
char ch = PeekNextChar();
|
char ch = PeekNextChar();
|
||||||
|
|
||||||
@ -777,7 +777,7 @@ std::string UBJReader::DecodeStr() {
|
|||||||
|
|
||||||
Json UBJReader::ParseObject() {
|
Json UBJReader::ParseObject() {
|
||||||
auto marker = PeekNextChar();
|
auto marker = PeekNextChar();
|
||||||
std::map<std::string, Json> results;
|
Object::Map results;
|
||||||
|
|
||||||
while (marker != '}') {
|
while (marker != '}') {
|
||||||
auto str = this->DecodeStr();
|
auto str = this->DecodeStr();
|
||||||
|
|||||||
@ -99,7 +99,7 @@ class ArrayInterfaceHandler {
|
|||||||
enum Type : std::int8_t { kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
|
enum Type : std::int8_t { kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
|
||||||
|
|
||||||
template <typename PtrType>
|
template <typename PtrType>
|
||||||
static PtrType GetPtrFromArrayData(std::map<std::string, Json> const &obj) {
|
static PtrType GetPtrFromArrayData(Object::Map const &obj) {
|
||||||
auto data_it = obj.find("data");
|
auto data_it = obj.find("data");
|
||||||
if (data_it == obj.cend()) {
|
if (data_it == obj.cend()) {
|
||||||
LOG(FATAL) << "Empty data passed in.";
|
LOG(FATAL) << "Empty data passed in.";
|
||||||
@ -109,7 +109,7 @@ class ArrayInterfaceHandler {
|
|||||||
return p_data;
|
return p_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void Validate(std::map<std::string, Json> const &array) {
|
static void Validate(Object::Map const &array) {
|
||||||
auto version_it = array.find("version");
|
auto version_it = array.find("version");
|
||||||
if (version_it == array.cend()) {
|
if (version_it == array.cend()) {
|
||||||
LOG(FATAL) << "Missing `version' field for array interface";
|
LOG(FATAL) << "Missing `version' field for array interface";
|
||||||
@ -136,7 +136,7 @@ class ArrayInterfaceHandler {
|
|||||||
|
|
||||||
// Find null mask (validity mask) field
|
// Find null mask (validity mask) field
|
||||||
// Mask object is also an array interface, but with different requirements.
|
// Mask object is also an array interface, but with different requirements.
|
||||||
static size_t ExtractMask(std::map<std::string, Json> const &column,
|
static size_t ExtractMask(Object::Map const &column,
|
||||||
common::Span<RBitField8::value_type> *p_out) {
|
common::Span<RBitField8::value_type> *p_out) {
|
||||||
auto &s_mask = *p_out;
|
auto &s_mask = *p_out;
|
||||||
if (column.find("mask") != column.cend()) {
|
if (column.find("mask") != column.cend()) {
|
||||||
@ -208,7 +208,7 @@ class ArrayInterfaceHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <int32_t D>
|
template <int32_t D>
|
||||||
static void ExtractShape(std::map<std::string, Json> const &array, size_t (&out_shape)[D]) {
|
static void ExtractShape(Object::Map const &array, size_t (&out_shape)[D]) {
|
||||||
auto const &j_shape = get<Array const>(array.at("shape"));
|
auto const &j_shape = get<Array const>(array.at("shape"));
|
||||||
std::vector<size_t> shape_arr(j_shape.size(), 0);
|
std::vector<size_t> shape_arr(j_shape.size(), 0);
|
||||||
std::transform(j_shape.cbegin(), j_shape.cend(), shape_arr.begin(),
|
std::transform(j_shape.cbegin(), j_shape.cend(), shape_arr.begin(),
|
||||||
@ -229,7 +229,7 @@ class ArrayInterfaceHandler {
|
|||||||
* \brief Extracts the optiona `strides' field and returns whether the array is c-contiguous.
|
* \brief Extracts the optiona `strides' field and returns whether the array is c-contiguous.
|
||||||
*/
|
*/
|
||||||
template <int32_t D>
|
template <int32_t D>
|
||||||
static bool ExtractStride(std::map<std::string, Json> const &array, size_t itemsize,
|
static bool ExtractStride(Object::Map const &array, size_t itemsize,
|
||||||
size_t (&shape)[D], size_t (&stride)[D]) {
|
size_t (&shape)[D], size_t (&stride)[D]) {
|
||||||
auto strides_it = array.find("strides");
|
auto strides_it = array.find("strides");
|
||||||
// No stride is provided
|
// No stride is provided
|
||||||
@ -272,7 +272,7 @@ class ArrayInterfaceHandler {
|
|||||||
return std::equal(stride_tmp, stride_tmp + D, stride);
|
return std::equal(stride_tmp, stride_tmp + D, stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void *ExtractData(std::map<std::string, Json> const &array, size_t size) {
|
static void *ExtractData(Object::Map const &array, size_t size) {
|
||||||
Validate(array);
|
Validate(array);
|
||||||
void *p_data = ArrayInterfaceHandler::GetPtrFromArrayData<void *>(array);
|
void *p_data = ArrayInterfaceHandler::GetPtrFromArrayData<void *>(array);
|
||||||
if (!p_data) {
|
if (!p_data) {
|
||||||
@ -378,7 +378,7 @@ class ArrayInterface {
|
|||||||
* to a vector of size n_samples. For for inputs like weights, this should be a 1
|
* to a vector of size n_samples. For for inputs like weights, this should be a 1
|
||||||
* dimension column vector even though user might provide a matrix.
|
* dimension column vector even though user might provide a matrix.
|
||||||
*/
|
*/
|
||||||
void Initialize(std::map<std::string, Json> const &array) {
|
void Initialize(Object::Map const &array) {
|
||||||
ArrayInterfaceHandler::Validate(array);
|
ArrayInterfaceHandler::Validate(array);
|
||||||
|
|
||||||
auto typestr = get<String const>(array.at("typestr"));
|
auto typestr = get<String const>(array.at("typestr"));
|
||||||
@ -413,7 +413,7 @@ class ArrayInterface {
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
ArrayInterface() = default;
|
ArrayInterface() = default;
|
||||||
explicit ArrayInterface(std::map<std::string, Json> const &array) { this->Initialize(array); }
|
explicit ArrayInterface(Object::Map const &array) { this->Initialize(array); }
|
||||||
|
|
||||||
explicit ArrayInterface(Json const &array) {
|
explicit ArrayInterface(Json const &array) {
|
||||||
if (IsA<Object>(array)) {
|
if (IsA<Object>(array)) {
|
||||||
|
|||||||
@ -60,8 +60,7 @@ struct DeviceAUCCache {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <bool is_multi>
|
template <bool is_multi>
|
||||||
void InitCacheOnce(common::Span<float const> predts, int32_t device,
|
void InitCacheOnce(common::Span<float const> predts, std::shared_ptr<DeviceAUCCache> *p_cache) {
|
||||||
std::shared_ptr<DeviceAUCCache>* p_cache) {
|
|
||||||
auto& cache = *p_cache;
|
auto& cache = *p_cache;
|
||||||
if (!cache) {
|
if (!cache) {
|
||||||
cache.reset(new DeviceAUCCache);
|
cache.reset(new DeviceAUCCache);
|
||||||
@ -167,7 +166,7 @@ std::tuple<double, double, double>
|
|||||||
GPUBinaryROCAUC(common::Span<float const> predts, MetaInfo const &info,
|
GPUBinaryROCAUC(common::Span<float const> predts, MetaInfo const &info,
|
||||||
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
|
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
|
||||||
auto &cache = *p_cache;
|
auto &cache = *p_cache;
|
||||||
InitCacheOnce<false>(predts, device, p_cache);
|
InitCacheOnce<false>(predts, p_cache);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create sorted index for each class
|
* Create sorted index for each class
|
||||||
@ -196,8 +195,7 @@ void Transpose(common::Span<float const> in, common::Span<float> out, size_t m,
|
|||||||
}
|
}
|
||||||
|
|
||||||
double ScaleClasses(common::Span<double> results, common::Span<double> local_area,
|
double ScaleClasses(common::Span<double> results, common::Span<double> local_area,
|
||||||
common::Span<double> tp, common::Span<double> auc,
|
common::Span<double> tp, common::Span<double> auc, size_t n_classes) {
|
||||||
std::shared_ptr<DeviceAUCCache> cache, size_t n_classes) {
|
|
||||||
dh::XGBDeviceAllocator<char> alloc;
|
dh::XGBDeviceAllocator<char> alloc;
|
||||||
if (collective::IsDistributed()) {
|
if (collective::IsDistributed()) {
|
||||||
int32_t device = dh::CurrentDevice();
|
int32_t device = dh::CurrentDevice();
|
||||||
@ -330,7 +328,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, int32_t device, common::Span<ui
|
|||||||
auto local_area = d_results.subspan(0, n_classes);
|
auto local_area = d_results.subspan(0, n_classes);
|
||||||
auto tp = d_results.subspan(2 * n_classes, n_classes);
|
auto tp = d_results.subspan(2 * n_classes, n_classes);
|
||||||
auto auc = d_results.subspan(3 * n_classes, n_classes);
|
auto auc = d_results.subspan(3 * n_classes, n_classes);
|
||||||
return ScaleClasses(d_results, local_area, tp, auc, cache, n_classes);
|
return ScaleClasses(d_results, local_area, tp, auc, n_classes);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -434,7 +432,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, int32_t device, common::Span<ui
|
|||||||
tp[c] = 1.0f;
|
tp[c] = 1.0f;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
return ScaleClasses(d_results, local_area, tp, auc, cache, n_classes);
|
return ScaleClasses(d_results, local_area, tp, auc, n_classes);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MultiClassSortedIdx(common::Span<float const> predts,
|
void MultiClassSortedIdx(common::Span<float const> predts,
|
||||||
@ -458,7 +456,7 @@ double GPUMultiClassROCAUC(common::Span<float const> predts,
|
|||||||
std::shared_ptr<DeviceAUCCache> *p_cache,
|
std::shared_ptr<DeviceAUCCache> *p_cache,
|
||||||
size_t n_classes) {
|
size_t n_classes) {
|
||||||
auto& cache = *p_cache;
|
auto& cache = *p_cache;
|
||||||
InitCacheOnce<true>(predts, device, p_cache);
|
InitCacheOnce<true>(predts, p_cache);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create sorted index for each class
|
* Create sorted index for each class
|
||||||
@ -486,7 +484,7 @@ std::pair<double, uint32_t>
|
|||||||
GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
|
GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
|
||||||
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
|
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
|
||||||
auto& cache = *p_cache;
|
auto& cache = *p_cache;
|
||||||
InitCacheOnce<false>(predts, device, p_cache);
|
InitCacheOnce<false>(predts, p_cache);
|
||||||
|
|
||||||
dh::caching_device_vector<bst_group_t> group_ptr(info.group_ptr_);
|
dh::caching_device_vector<bst_group_t> group_ptr(info.group_ptr_);
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||||
@ -606,7 +604,7 @@ std::tuple<double, double, double>
|
|||||||
GPUBinaryPRAUC(common::Span<float const> predts, MetaInfo const &info,
|
GPUBinaryPRAUC(common::Span<float const> predts, MetaInfo const &info,
|
||||||
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
|
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
|
||||||
auto& cache = *p_cache;
|
auto& cache = *p_cache;
|
||||||
InitCacheOnce<false>(predts, device, p_cache);
|
InitCacheOnce<false>(predts, p_cache);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create sorted index for each class
|
* Create sorted index for each class
|
||||||
@ -647,7 +645,7 @@ double GPUMultiClassPRAUC(common::Span<float const> predts,
|
|||||||
std::shared_ptr<DeviceAUCCache> *p_cache,
|
std::shared_ptr<DeviceAUCCache> *p_cache,
|
||||||
size_t n_classes) {
|
size_t n_classes) {
|
||||||
auto& cache = *p_cache;
|
auto& cache = *p_cache;
|
||||||
InitCacheOnce<true>(predts, device, p_cache);
|
InitCacheOnce<true>(predts, p_cache);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create sorted index for each class
|
* Create sorted index for each class
|
||||||
@ -827,7 +825,7 @@ GPURankingPRAUC(common::Span<float const> predts, MetaInfo const &info,
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto &cache = *p_cache;
|
auto &cache = *p_cache;
|
||||||
InitCacheOnce<false>(predts, device, p_cache);
|
InitCacheOnce<false>(predts, p_cache);
|
||||||
|
|
||||||
dh::device_vector<bst_group_t> group_ptr(info.group_ptr_.size());
|
dh::device_vector<bst_group_t> group_ptr(info.group_ptr_.size());
|
||||||
thrust::copy(info.group_ptr_.begin(), info.group_ptr_.end(), group_ptr.begin());
|
thrust::copy(info.group_ptr_.begin(), info.group_ptr_.end(), group_ptr.begin());
|
||||||
|
|||||||
@ -499,8 +499,7 @@ TEST(Json, WrongCasts) {
|
|||||||
ASSERT_ANY_THROW(get<Number>(json));
|
ASSERT_ANY_THROW(get<Number>(json));
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
Json json = Json{ Object{std::map<std::string, Json>{
|
Json json = Json{Object{{{"key", Json{String{"value"}}}}}};
|
||||||
{"key", Json{String{"value"}}}} } };
|
|
||||||
ASSERT_ANY_THROW(get<Number>(json));
|
ASSERT_ANY_THROW(get<Number>(json));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -31,7 +31,7 @@ inline GradientQuantiser DummyRoundingFactor() {
|
|||||||
thrust::device_vector<GradientPairInt64> ConvertToInteger(std::vector<GradientPairPrecise> x) {
|
thrust::device_vector<GradientPairInt64> ConvertToInteger(std::vector<GradientPairPrecise> x) {
|
||||||
auto r = DummyRoundingFactor();
|
auto r = DummyRoundingFactor();
|
||||||
std::vector<GradientPairInt64> y(x.size());
|
std::vector<GradientPairInt64> y(x.size());
|
||||||
for (int i = 0; i < x.size(); i++) {
|
for (std::size_t i = 0; i < x.size(); i++) {
|
||||||
y[i] = r.ToFixedPoint(GradientPair(x[i]));
|
y[i] = r.ToFixedPoint(GradientPair(x[i]));
|
||||||
}
|
}
|
||||||
return y;
|
return y;
|
||||||
@ -51,14 +51,13 @@ TEST_F(TestCategoricalSplitWithMissing, GPUHistEvaluator) {
|
|||||||
auto quantiser = DummyRoundingFactor();
|
auto quantiser = DummyRoundingFactor();
|
||||||
EvaluateSplitInputs input{1, 0, quantiser.ToFixedPoint(parent_sum_), dh::ToSpan(feature_set),
|
EvaluateSplitInputs input{1, 0, quantiser.ToFixedPoint(parent_sum_), dh::ToSpan(feature_set),
|
||||||
dh::ToSpan(feature_histogram)};
|
dh::ToSpan(feature_histogram)};
|
||||||
EvaluateSplitSharedInputs shared_inputs{
|
EvaluateSplitSharedInputs shared_inputs{param,
|
||||||
param,
|
quantiser,
|
||||||
quantiser,
|
d_feature_types,
|
||||||
d_feature_types,
|
cuts_.cut_ptrs_.ConstDeviceSpan(),
|
||||||
cuts_.cut_ptrs_.ConstDeviceSpan(),
|
cuts_.cut_values_.ConstDeviceSpan(),
|
||||||
cuts_.cut_values_.ConstDeviceSpan(),
|
cuts_.min_vals_.ConstDeviceSpan(),
|
||||||
cuts_.min_vals_.ConstDeviceSpan(), false
|
false};
|
||||||
};
|
|
||||||
|
|
||||||
GPUHistEvaluator evaluator{param_, static_cast<bst_feature_t>(feature_set.size()), 0};
|
GPUHistEvaluator evaluator{param_, static_cast<bst_feature_t>(feature_set.size()), 0};
|
||||||
|
|
||||||
@ -99,6 +98,7 @@ TEST(GpuHist, PartitionBasic) {
|
|||||||
cuts.cut_ptrs_.ConstDeviceSpan(),
|
cuts.cut_ptrs_.ConstDeviceSpan(),
|
||||||
cuts.cut_values_.ConstDeviceSpan(),
|
cuts.cut_values_.ConstDeviceSpan(),
|
||||||
cuts.min_vals_.ConstDeviceSpan(),
|
cuts.min_vals_.ConstDeviceSpan(),
|
||||||
|
false,
|
||||||
};
|
};
|
||||||
|
|
||||||
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
|
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
|
||||||
@ -204,14 +204,13 @@ TEST(GpuHist, PartitionTwoFeatures) {
|
|||||||
cuts.SetCategorical(true, max_cat);
|
cuts.SetCategorical(true, max_cat);
|
||||||
|
|
||||||
auto quantiser = DummyRoundingFactor();
|
auto quantiser = DummyRoundingFactor();
|
||||||
EvaluateSplitSharedInputs shared_inputs{
|
EvaluateSplitSharedInputs shared_inputs{param,
|
||||||
param,
|
quantiser,
|
||||||
quantiser,
|
d_feature_types,
|
||||||
d_feature_types,
|
cuts.cut_ptrs_.ConstDeviceSpan(),
|
||||||
cuts.cut_ptrs_.ConstDeviceSpan(),
|
cuts.cut_values_.ConstDeviceSpan(),
|
||||||
cuts.cut_values_.ConstDeviceSpan(),
|
cuts.min_vals_.ConstDeviceSpan(),
|
||||||
cuts.min_vals_.ConstDeviceSpan(),
|
false};
|
||||||
};
|
|
||||||
|
|
||||||
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
|
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
|
||||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
|
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
|
||||||
@ -263,14 +262,13 @@ TEST(GpuHist, PartitionTwoNodes) {
|
|||||||
cuts.SetCategorical(true, max_cat);
|
cuts.SetCategorical(true, max_cat);
|
||||||
|
|
||||||
auto quantiser = DummyRoundingFactor();
|
auto quantiser = DummyRoundingFactor();
|
||||||
EvaluateSplitSharedInputs shared_inputs{
|
EvaluateSplitSharedInputs shared_inputs{param,
|
||||||
param,
|
quantiser,
|
||||||
quantiser,
|
d_feature_types,
|
||||||
d_feature_types,
|
cuts.cut_ptrs_.ConstDeviceSpan(),
|
||||||
cuts.cut_ptrs_.ConstDeviceSpan(),
|
cuts.cut_values_.ConstDeviceSpan(),
|
||||||
cuts.cut_values_.ConstDeviceSpan(),
|
cuts.min_vals_.ConstDeviceSpan(),
|
||||||
cuts.min_vals_.ConstDeviceSpan(),
|
false};
|
||||||
};
|
|
||||||
|
|
||||||
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
|
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
|
||||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
|
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
|
||||||
@ -320,14 +318,13 @@ void TestEvaluateSingleSplit(bool is_categorical) {
|
|||||||
|
|
||||||
EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
|
EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
|
||||||
dh::ToSpan(feature_histogram)};
|
dh::ToSpan(feature_histogram)};
|
||||||
EvaluateSplitSharedInputs shared_inputs{
|
EvaluateSplitSharedInputs shared_inputs{param,
|
||||||
param,
|
quantiser,
|
||||||
quantiser,
|
d_feature_types,
|
||||||
d_feature_types,
|
cuts.cut_ptrs_.ConstDeviceSpan(),
|
||||||
cuts.cut_ptrs_.ConstDeviceSpan(),
|
cuts.cut_values_.ConstDeviceSpan(),
|
||||||
cuts.cut_values_.ConstDeviceSpan(),
|
cuts.min_vals_.ConstDeviceSpan(),
|
||||||
cuts.min_vals_.ConstDeviceSpan(),
|
false};
|
||||||
};
|
|
||||||
|
|
||||||
GPUHistEvaluator evaluator{
|
GPUHistEvaluator evaluator{
|
||||||
tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
|
tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
|
||||||
@ -368,14 +365,13 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
|
|||||||
parent_sum,
|
parent_sum,
|
||||||
dh::ToSpan(feature_set),
|
dh::ToSpan(feature_set),
|
||||||
dh::ToSpan(feature_histogram)};
|
dh::ToSpan(feature_histogram)};
|
||||||
EvaluateSplitSharedInputs shared_inputs{
|
EvaluateSplitSharedInputs shared_inputs{param,
|
||||||
param,
|
quantiser,
|
||||||
quantiser,
|
{},
|
||||||
{},
|
dh::ToSpan(feature_segments),
|
||||||
dh::ToSpan(feature_segments),
|
dh::ToSpan(feature_values),
|
||||||
dh::ToSpan(feature_values),
|
dh::ToSpan(feature_min_values),
|
||||||
dh::ToSpan(feature_min_values),
|
false};
|
||||||
};
|
|
||||||
|
|
||||||
GPUHistEvaluator evaluator(tparam, feature_set.size(), 0);
|
GPUHistEvaluator evaluator(tparam, feature_set.size(), 0);
|
||||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||||
@ -394,7 +390,8 @@ TEST(GpuHist, EvaluateSingleSplitEmpty) {
|
|||||||
evaluator
|
evaluator
|
||||||
.EvaluateSingleSplit(
|
.EvaluateSingleSplit(
|
||||||
EvaluateSplitInputs{},
|
EvaluateSplitInputs{},
|
||||||
EvaluateSplitSharedInputs{GPUTrainingParam(tparam), DummyRoundingFactor()})
|
EvaluateSplitSharedInputs{
|
||||||
|
GPUTrainingParam(tparam), DummyRoundingFactor(), {}, {}, {}, {}, false})
|
||||||
.split;
|
.split;
|
||||||
EXPECT_EQ(result.findex, -1);
|
EXPECT_EQ(result.findex, -1);
|
||||||
EXPECT_LT(result.loss_chg, 0.0f);
|
EXPECT_LT(result.loss_chg, 0.0f);
|
||||||
@ -421,14 +418,13 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
|
|||||||
parent_sum,
|
parent_sum,
|
||||||
dh::ToSpan(feature_set),
|
dh::ToSpan(feature_set),
|
||||||
dh::ToSpan(feature_histogram)};
|
dh::ToSpan(feature_histogram)};
|
||||||
EvaluateSplitSharedInputs shared_inputs{
|
EvaluateSplitSharedInputs shared_inputs{param,
|
||||||
param,
|
quantiser,
|
||||||
quantiser,
|
|
||||||
{},
|
{},
|
||||||
dh::ToSpan(feature_segments),
|
dh::ToSpan(feature_segments),
|
||||||
dh::ToSpan(feature_values),
|
dh::ToSpan(feature_values),
|
||||||
dh::ToSpan(feature_min_values),
|
dh::ToSpan(feature_min_values),
|
||||||
};
|
false};
|
||||||
|
|
||||||
GPUHistEvaluator evaluator(tparam, feature_min_values.size(), 0);
|
GPUHistEvaluator evaluator(tparam, feature_min_values.size(), 0);
|
||||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||||
@ -460,14 +456,13 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) {
|
|||||||
parent_sum,
|
parent_sum,
|
||||||
dh::ToSpan(feature_set),
|
dh::ToSpan(feature_set),
|
||||||
dh::ToSpan(feature_histogram)};
|
dh::ToSpan(feature_histogram)};
|
||||||
EvaluateSplitSharedInputs shared_inputs{
|
EvaluateSplitSharedInputs shared_inputs{param,
|
||||||
param,
|
quantiser,
|
||||||
quantiser,
|
|
||||||
{},
|
{},
|
||||||
dh::ToSpan(feature_segments),
|
dh::ToSpan(feature_segments),
|
||||||
dh::ToSpan(feature_values),
|
dh::ToSpan(feature_values),
|
||||||
dh::ToSpan(feature_min_values),
|
dh::ToSpan(feature_min_values),
|
||||||
};
|
false};
|
||||||
|
|
||||||
GPUHistEvaluator evaluator(tparam, feature_min_values.size(), 0);
|
GPUHistEvaluator evaluator(tparam, feature_min_values.size(), 0);
|
||||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input,shared_inputs).split;
|
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input,shared_inputs).split;
|
||||||
@ -507,10 +502,11 @@ TEST(GpuHist, EvaluateSplits) {
|
|||||||
EvaluateSplitSharedInputs shared_inputs{
|
EvaluateSplitSharedInputs shared_inputs{
|
||||||
param,
|
param,
|
||||||
quantiser,
|
quantiser,
|
||||||
{},
|
{},
|
||||||
dh::ToSpan(feature_segments),
|
dh::ToSpan(feature_segments),
|
||||||
dh::ToSpan(feature_values),
|
dh::ToSpan(feature_values),
|
||||||
dh::ToSpan(feature_min_values),
|
dh::ToSpan(feature_min_values),
|
||||||
|
false
|
||||||
};
|
};
|
||||||
|
|
||||||
GPUHistEvaluator evaluator{
|
GPUHistEvaluator evaluator{
|
||||||
@ -548,14 +544,13 @@ TEST_F(TestPartitionBasedSplit, GpuHist) {
|
|||||||
dh::device_vector<bst_feature_t> feature_set{std::vector<bst_feature_t>{0}};
|
dh::device_vector<bst_feature_t> feature_set{std::vector<bst_feature_t>{0}};
|
||||||
|
|
||||||
EvaluateSplitInputs input{0, 0, quantiser.ToFixedPoint(total_gpair_), dh::ToSpan(feature_set), dh::ToSpan(d_hist)};
|
EvaluateSplitInputs input{0, 0, quantiser.ToFixedPoint(total_gpair_), dh::ToSpan(feature_set), dh::ToSpan(d_hist)};
|
||||||
EvaluateSplitSharedInputs shared_inputs{
|
EvaluateSplitSharedInputs shared_inputs{GPUTrainingParam{param_},
|
||||||
GPUTrainingParam{param_},
|
quantiser,
|
||||||
quantiser,
|
dh::ToSpan(ft),
|
||||||
dh::ToSpan(ft),
|
cuts_.cut_ptrs_.ConstDeviceSpan(),
|
||||||
cuts_.cut_ptrs_.ConstDeviceSpan(),
|
cuts_.cut_values_.ConstDeviceSpan(),
|
||||||
cuts_.cut_values_.ConstDeviceSpan(),
|
cuts_.min_vals_.ConstDeviceSpan(),
|
||||||
cuts_.min_vals_.ConstDeviceSpan(),
|
false};
|
||||||
};
|
|
||||||
auto split = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
auto split = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||||
ASSERT_NEAR(split.loss_chg, best_score_, 1e-2);
|
ASSERT_NEAR(split.loss_chg, best_score_, 1e-2);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user