Allow using string view to find JSON value. (#8332)
- Allow comparison between string and string view. - Fix compiler warnings.
This commit is contained in:
@@ -258,7 +258,7 @@ void TypeCheck(Json const &value, StringView name) {
|
||||
}
|
||||
|
||||
template <typename JT>
|
||||
auto const &RequiredArg(Json const &in, std::string const &key, StringView func) {
|
||||
auto const &RequiredArg(Json const &in, StringView key, StringView func) {
|
||||
auto const &obj = get<Object const>(in);
|
||||
auto it = obj.find(key);
|
||||
if (it == obj.cend() || IsA<Null>(it->second)) {
|
||||
@@ -269,11 +269,11 @@ auto const &RequiredArg(Json const &in, std::string const &key, StringView func)
|
||||
}
|
||||
|
||||
template <typename JT, typename T>
|
||||
auto const &OptionalArg(Json const &in, std::string const &key, T const &dft) {
|
||||
auto const &OptionalArg(Json const &in, StringView key, T const &dft) {
|
||||
auto const &obj = get<Object const>(in);
|
||||
auto it = obj.find(key);
|
||||
if (it != obj.cend() && !IsA<Null>(it->second)) {
|
||||
TypeCheck<JT>(it->second, StringView{key});
|
||||
TypeCheck<JT>(it->second, key);
|
||||
return get<std::remove_const_t<JT> const>(it->second);
|
||||
}
|
||||
return dft;
|
||||
|
||||
@@ -199,8 +199,8 @@ JsonObject::JsonObject(JsonObject&& that) noexcept : Value(ValueKind::kObject) {
|
||||
std::swap(that.object_, this->object_);
|
||||
}
|
||||
|
||||
JsonObject::JsonObject(std::map<std::string, Json>&& object) noexcept
|
||||
: Value(ValueKind::kObject), object_{std::forward<std::map<std::string, Json>>(object)} {}
|
||||
JsonObject::JsonObject(Map&& object) noexcept
|
||||
: Value(ValueKind::kObject), object_{std::forward<Map>(object)} {}
|
||||
|
||||
bool JsonObject::operator==(Value const& rhs) const {
|
||||
if (!IsA<JsonObject>(&rhs)) {
|
||||
@@ -502,7 +502,7 @@ Json JsonReader::ParseArray() {
|
||||
Json JsonReader::ParseObject() {
|
||||
GetConsecutiveChar('{');
|
||||
|
||||
std::map<std::string, Json> data;
|
||||
Object::Map data;
|
||||
SkipSpaces();
|
||||
char ch = PeekNextChar();
|
||||
|
||||
@@ -777,7 +777,7 @@ std::string UBJReader::DecodeStr() {
|
||||
|
||||
Json UBJReader::ParseObject() {
|
||||
auto marker = PeekNextChar();
|
||||
std::map<std::string, Json> results;
|
||||
Object::Map results;
|
||||
|
||||
while (marker != '}') {
|
||||
auto str = this->DecodeStr();
|
||||
|
||||
@@ -99,7 +99,7 @@ class ArrayInterfaceHandler {
|
||||
enum Type : std::int8_t { kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
|
||||
|
||||
template <typename PtrType>
|
||||
static PtrType GetPtrFromArrayData(std::map<std::string, Json> const &obj) {
|
||||
static PtrType GetPtrFromArrayData(Object::Map const &obj) {
|
||||
auto data_it = obj.find("data");
|
||||
if (data_it == obj.cend()) {
|
||||
LOG(FATAL) << "Empty data passed in.";
|
||||
@@ -109,7 +109,7 @@ class ArrayInterfaceHandler {
|
||||
return p_data;
|
||||
}
|
||||
|
||||
static void Validate(std::map<std::string, Json> const &array) {
|
||||
static void Validate(Object::Map const &array) {
|
||||
auto version_it = array.find("version");
|
||||
if (version_it == array.cend()) {
|
||||
LOG(FATAL) << "Missing `version' field for array interface";
|
||||
@@ -136,7 +136,7 @@ class ArrayInterfaceHandler {
|
||||
|
||||
// Find null mask (validity mask) field
|
||||
// Mask object is also an array interface, but with different requirements.
|
||||
static size_t ExtractMask(std::map<std::string, Json> const &column,
|
||||
static size_t ExtractMask(Object::Map const &column,
|
||||
common::Span<RBitField8::value_type> *p_out) {
|
||||
auto &s_mask = *p_out;
|
||||
if (column.find("mask") != column.cend()) {
|
||||
@@ -208,7 +208,7 @@ class ArrayInterfaceHandler {
|
||||
}
|
||||
|
||||
template <int32_t D>
|
||||
static void ExtractShape(std::map<std::string, Json> const &array, size_t (&out_shape)[D]) {
|
||||
static void ExtractShape(Object::Map const &array, size_t (&out_shape)[D]) {
|
||||
auto const &j_shape = get<Array const>(array.at("shape"));
|
||||
std::vector<size_t> shape_arr(j_shape.size(), 0);
|
||||
std::transform(j_shape.cbegin(), j_shape.cend(), shape_arr.begin(),
|
||||
@@ -229,7 +229,7 @@ class ArrayInterfaceHandler {
|
||||
* \brief Extracts the optiona `strides' field and returns whether the array is c-contiguous.
|
||||
*/
|
||||
template <int32_t D>
|
||||
static bool ExtractStride(std::map<std::string, Json> const &array, size_t itemsize,
|
||||
static bool ExtractStride(Object::Map const &array, size_t itemsize,
|
||||
size_t (&shape)[D], size_t (&stride)[D]) {
|
||||
auto strides_it = array.find("strides");
|
||||
// No stride is provided
|
||||
@@ -272,7 +272,7 @@ class ArrayInterfaceHandler {
|
||||
return std::equal(stride_tmp, stride_tmp + D, stride);
|
||||
}
|
||||
|
||||
static void *ExtractData(std::map<std::string, Json> const &array, size_t size) {
|
||||
static void *ExtractData(Object::Map const &array, size_t size) {
|
||||
Validate(array);
|
||||
void *p_data = ArrayInterfaceHandler::GetPtrFromArrayData<void *>(array);
|
||||
if (!p_data) {
|
||||
@@ -378,7 +378,7 @@ class ArrayInterface {
|
||||
* to a vector of size n_samples. For for inputs like weights, this should be a 1
|
||||
* dimension column vector even though user might provide a matrix.
|
||||
*/
|
||||
void Initialize(std::map<std::string, Json> const &array) {
|
||||
void Initialize(Object::Map const &array) {
|
||||
ArrayInterfaceHandler::Validate(array);
|
||||
|
||||
auto typestr = get<String const>(array.at("typestr"));
|
||||
@@ -413,7 +413,7 @@ class ArrayInterface {
|
||||
|
||||
public:
|
||||
ArrayInterface() = default;
|
||||
explicit ArrayInterface(std::map<std::string, Json> const &array) { this->Initialize(array); }
|
||||
explicit ArrayInterface(Object::Map const &array) { this->Initialize(array); }
|
||||
|
||||
explicit ArrayInterface(Json const &array) {
|
||||
if (IsA<Object>(array)) {
|
||||
|
||||
@@ -60,8 +60,7 @@ struct DeviceAUCCache {
|
||||
};
|
||||
|
||||
template <bool is_multi>
|
||||
void InitCacheOnce(common::Span<float const> predts, int32_t device,
|
||||
std::shared_ptr<DeviceAUCCache>* p_cache) {
|
||||
void InitCacheOnce(common::Span<float const> predts, std::shared_ptr<DeviceAUCCache> *p_cache) {
|
||||
auto& cache = *p_cache;
|
||||
if (!cache) {
|
||||
cache.reset(new DeviceAUCCache);
|
||||
@@ -167,7 +166,7 @@ std::tuple<double, double, double>
|
||||
GPUBinaryROCAUC(common::Span<float const> predts, MetaInfo const &info,
|
||||
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
|
||||
auto &cache = *p_cache;
|
||||
InitCacheOnce<false>(predts, device, p_cache);
|
||||
InitCacheOnce<false>(predts, p_cache);
|
||||
|
||||
/**
|
||||
* Create sorted index for each class
|
||||
@@ -196,8 +195,7 @@ void Transpose(common::Span<float const> in, common::Span<float> out, size_t m,
|
||||
}
|
||||
|
||||
double ScaleClasses(common::Span<double> results, common::Span<double> local_area,
|
||||
common::Span<double> tp, common::Span<double> auc,
|
||||
std::shared_ptr<DeviceAUCCache> cache, size_t n_classes) {
|
||||
common::Span<double> tp, common::Span<double> auc, size_t n_classes) {
|
||||
dh::XGBDeviceAllocator<char> alloc;
|
||||
if (collective::IsDistributed()) {
|
||||
int32_t device = dh::CurrentDevice();
|
||||
@@ -330,7 +328,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, int32_t device, common::Span<ui
|
||||
auto local_area = d_results.subspan(0, n_classes);
|
||||
auto tp = d_results.subspan(2 * n_classes, n_classes);
|
||||
auto auc = d_results.subspan(3 * n_classes, n_classes);
|
||||
return ScaleClasses(d_results, local_area, tp, auc, cache, n_classes);
|
||||
return ScaleClasses(d_results, local_area, tp, auc, n_classes);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -434,7 +432,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, int32_t device, common::Span<ui
|
||||
tp[c] = 1.0f;
|
||||
}
|
||||
});
|
||||
return ScaleClasses(d_results, local_area, tp, auc, cache, n_classes);
|
||||
return ScaleClasses(d_results, local_area, tp, auc, n_classes);
|
||||
}
|
||||
|
||||
void MultiClassSortedIdx(common::Span<float const> predts,
|
||||
@@ -458,7 +456,7 @@ double GPUMultiClassROCAUC(common::Span<float const> predts,
|
||||
std::shared_ptr<DeviceAUCCache> *p_cache,
|
||||
size_t n_classes) {
|
||||
auto& cache = *p_cache;
|
||||
InitCacheOnce<true>(predts, device, p_cache);
|
||||
InitCacheOnce<true>(predts, p_cache);
|
||||
|
||||
/**
|
||||
* Create sorted index for each class
|
||||
@@ -486,7 +484,7 @@ std::pair<double, uint32_t>
|
||||
GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
|
||||
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
|
||||
auto& cache = *p_cache;
|
||||
InitCacheOnce<false>(predts, device, p_cache);
|
||||
InitCacheOnce<false>(predts, p_cache);
|
||||
|
||||
dh::caching_device_vector<bst_group_t> group_ptr(info.group_ptr_);
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
@@ -606,7 +604,7 @@ std::tuple<double, double, double>
|
||||
GPUBinaryPRAUC(common::Span<float const> predts, MetaInfo const &info,
|
||||
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
|
||||
auto& cache = *p_cache;
|
||||
InitCacheOnce<false>(predts, device, p_cache);
|
||||
InitCacheOnce<false>(predts, p_cache);
|
||||
|
||||
/**
|
||||
* Create sorted index for each class
|
||||
@@ -647,7 +645,7 @@ double GPUMultiClassPRAUC(common::Span<float const> predts,
|
||||
std::shared_ptr<DeviceAUCCache> *p_cache,
|
||||
size_t n_classes) {
|
||||
auto& cache = *p_cache;
|
||||
InitCacheOnce<true>(predts, device, p_cache);
|
||||
InitCacheOnce<true>(predts, p_cache);
|
||||
|
||||
/**
|
||||
* Create sorted index for each class
|
||||
@@ -827,7 +825,7 @@ GPURankingPRAUC(common::Span<float const> predts, MetaInfo const &info,
|
||||
}
|
||||
|
||||
auto &cache = *p_cache;
|
||||
InitCacheOnce<false>(predts, device, p_cache);
|
||||
InitCacheOnce<false>(predts, p_cache);
|
||||
|
||||
dh::device_vector<bst_group_t> group_ptr(info.group_ptr_.size());
|
||||
thrust::copy(info.group_ptr_.begin(), info.group_ptr_.end(), group_ptr.begin());
|
||||
|
||||
Reference in New Issue
Block a user