Extract JSON utils. (#9645)

This commit is contained in:
Jiaming Yuan 2023-10-10 07:15:14 +08:00 committed by GitHub
parent 4e5a7729c3
commit 680d53db43
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 90 additions and 75 deletions

View File

@ -608,44 +608,6 @@ using Boolean = JsonBoolean;
using String = JsonString; using String = JsonString;
using Null = JsonNull; using Null = JsonNull;
// Utils tailored for XGBoost.
namespace detail {
template <typename Head>
bool TypeCheckImpl(Json const& value) {
return IsA<Head>(value);
}
template <typename Head, typename... JT>
std::enable_if_t<sizeof...(JT) != 0, bool> TypeCheckImpl(Json const& value) {
return IsA<Head>(value) || TypeCheckImpl<JT...>(value);
}
template <typename Head>
std::string TypeCheckError() {
return "`" + Head{}.TypeStr() + "`";
}
template <typename Head, typename... JT>
std::enable_if_t<sizeof...(JT) != 0, std::string> TypeCheckError() {
return "`" + Head{}.TypeStr() + "`, " + TypeCheckError<JT...>();
}
} // namespace detail
/**
* \brief Type check for JSON-based parameters
*
* \tparam JT Expected JSON types.
* \param value Value to be checked.
*/
template <typename... JT>
void TypeCheck(Json const& value, StringView name) {
if (!detail::TypeCheckImpl<JT...>(value)) {
LOG(FATAL) << "Invalid type for: `" << name << "`, expecting one of the: {`"
<< detail::TypeCheckError<JT...>() << "}, got: `" << value.GetValue().TypeStr()
<< "`";
}
}
/** /**
* \brief Convert XGBoost parameter to JSON object. * \brief Convert XGBoost parameter to JSON object.
* *

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021-2023 by XGBoost Contributors * Copyright 2021-2023, XGBoost Contributors
*/ */
#ifndef XGBOOST_C_API_C_API_UTILS_H_ #ifndef XGBOOST_C_API_C_API_UTILS_H_
#define XGBOOST_C_API_C_API_UTILS_H_ #define XGBOOST_C_API_C_API_UTILS_H_
@ -13,6 +13,7 @@
#include <utility> // for move #include <utility> // for move
#include <vector> #include <vector>
#include "../common/json_utils.h" // for TypeCheck
#include "xgboost/c_api.h" #include "xgboost/c_api.h"
#include "xgboost/data.h" // DMatrix #include "xgboost/data.h" // DMatrix
#include "xgboost/feature_map.h" // for FeatureMap #include "xgboost/feature_map.h" // for FeatureMap
@ -254,28 +255,6 @@ inline void GenerateFeatureMap(Learner const *learner,
void XGBBuildInfoDevice(Json* p_info); void XGBBuildInfoDevice(Json* p_info);
template <typename JT>
auto const &RequiredArg(Json const &in, StringView key, StringView func) {
auto const &obj = get<Object const>(in);
auto it = obj.find(key);
if (it == obj.cend() || IsA<Null>(it->second)) {
LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`.";
}
TypeCheck<JT>(it->second, StringView{key});
return get<std::remove_const_t<JT> const>(it->second);
}
template <typename JT, typename T>
auto const &OptionalArg(Json const &in, StringView key, T const &dft) {
auto const &obj = get<Object const>(in);
auto it = obj.find(key);
if (it != obj.cend() && !IsA<Null>(it->second)) {
TypeCheck<JT>(it->second, key);
return get<std::remove_const_t<JT> const>(it->second);
}
return dft;
}
/** /**
* \brief Get shared ptr from DMatrix C handle with additional checks. * \brief Get shared ptr from DMatrix C handle with additional checks.
*/ */

View File

@ -8,7 +8,7 @@
#define XGBOOST_COMMON_IO_H_ #define XGBOOST_COMMON_IO_H_
#include <dmlc/io.h> #include <dmlc/io.h>
#include <rabit/rabit.h> #include <rabit/internal/io.h> // for MemoryFixSizeBuffer, MemoryBufferStream
#include <algorithm> // for min, fill_n, copy_n #include <algorithm> // for min, fill_n, copy_n
#include <array> // for array #include <array> // for array

74
src/common/json_utils.h Normal file
View File

@ -0,0 +1,74 @@
/**
* Copyright 2023, XGBoost Contributors
*
* @brief Utils tailored for XGBoost.
*/
#pragma once
#include <string> // for string
#include <type_traits> // for enable_if_t, remove_const_t
#include "xgboost/json.h"
#include "xgboost/string_view.h" // for StringView
namespace xgboost {
namespace detail {
template <typename Head>
bool TypeCheckImpl(Json const &value) {
return IsA<Head>(value);
}
template <typename Head, typename... JT>
std::enable_if_t<sizeof...(JT) != 0, bool> TypeCheckImpl(Json const &value) {
return IsA<Head>(value) || TypeCheckImpl<JT...>(value);
}
template <typename Head>
std::string TypeCheckError() {
return "`" + Head{}.TypeStr() + "`";
}
template <typename Head, typename... JT>
std::enable_if_t<sizeof...(JT) != 0, std::string> TypeCheckError() {
return "`" + Head{}.TypeStr() + "`, " + TypeCheckError<JT...>();
}
} // namespace detail
/**
* @brief Type check for JSON-based parameters
*
* @tparam JT Expected JSON types.
* @param value Value to be checked.
*/
template <typename... JT>
void TypeCheck(Json const &value, StringView name) {
if (!detail::TypeCheckImpl<JT...>(value)) {
LOG(FATAL) << "Invalid type for: `" << name << "`, expecting one of the: {`"
<< detail::TypeCheckError<JT...>() << "}, got: `" << value.GetValue().TypeStr()
<< "`";
}
}
template <typename JT>
auto const &RequiredArg(Json const &in, StringView key, StringView func) {
auto const &obj = get<Object const>(in);
auto it = obj.find(key);
if (it == obj.cend() || IsA<Null>(it->second)) {
LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`.";
}
TypeCheck<JT>(it->second, StringView{key});
return get<std::remove_const_t<JT> const>(it->second);
}
template <typename JT, typename T>
auto const &OptionalArg(Json const &in, StringView key, T const &dft) {
auto const &obj = get<Object const>(in);
auto it = obj.find(key);
if (it != obj.cend() && !IsA<Null>(it->second)) {
TypeCheck<JT>(it->second, key);
return get<std::remove_const_t<JT> const>(it->second);
}
return dft;
}
} // namespace xgboost

View File

@ -1,19 +1,19 @@
/** /**
* Copyright 2023 by XGBoost contributors * Copyright 2023, XGBoost contributors
*/ */
#include "quantile_loss_utils.h" #include "quantile_loss_utils.h"
#include <cctype> // std::isspace #include <cctype> // for isspace
#include <istream> // std::istream #include <istream> // for istream
#include <ostream> // std::ostream #include <ostream> // for ostream
#include <string> // std::string #include <string> // for string
#include <vector> // std::vector #include <vector> // for vector
#include "xgboost/json.h" // F32Array,TypeCheck,get,Number #include "../common/json_utils.h" // for TypeCheck
#include "xgboost/json_io.h" // JsonWriter #include "xgboost/json.h" // for F32Array, get, Number
#include "xgboost/json_io.h" // for JsonWriter
namespace xgboost { namespace xgboost::common {
namespace common {
std::ostream& operator<<(std::ostream& os, const ParamFloatArray& array) { std::ostream& operator<<(std::ostream& os, const ParamFloatArray& array) {
auto const& t = array.Get(); auto const& t = array.Get();
xgboost::F32Array arr{t.size()}; xgboost::F32Array arr{t.size()};
@ -70,5 +70,4 @@ std::istream& operator>>(std::istream& is, ParamFloatArray& array) {
} }
DMLC_REGISTER_PARAMETER(QuantileLossParam); DMLC_REGISTER_PARAMETER(QuantileLossParam);
} // namespace common } // namespace xgboost::common
} // namespace xgboost

View File

@ -1,5 +1,5 @@
/** /**
* Copyright (c) 2019-2023, XGBoost Contributors * Copyright 2019-2023, XGBoost Contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
@ -9,6 +9,7 @@
#include "../../../src/common/charconv.h" #include "../../../src/common/charconv.h"
#include "../../../src/common/io.h" #include "../../../src/common/io.h"
#include "../../../src/common/json_utils.h"
#include "../../../src/common/threading_utils.h" // for ParallelFor #include "../../../src/common/threading_utils.h" // for ParallelFor
#include "../filesystem.h" // dmlc::TemporaryDirectory #include "../filesystem.h" // dmlc::TemporaryDirectory
#include "../helpers.h" #include "../helpers.h"