From 680d53db43ffbd76023a0dec8f998d6a4c4e3c88 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 10 Oct 2023 07:15:14 +0800 Subject: [PATCH] Extract JSON utils. (#9645) --- include/xgboost/json.h | 38 ---------------- src/c_api/c_api_utils.h | 25 +---------- src/common/io.h | 2 +- src/common/json_utils.h | 74 +++++++++++++++++++++++++++++++ src/common/quantile_loss_utils.cc | 23 +++++----- tests/cpp/common/test_json.cc | 3 +- 6 files changed, 90 insertions(+), 75 deletions(-) create mode 100644 src/common/json_utils.h diff --git a/include/xgboost/json.h b/include/xgboost/json.h index b099d1c47..c2c16ef8f 100644 --- a/include/xgboost/json.h +++ b/include/xgboost/json.h @@ -608,44 +608,6 @@ using Boolean = JsonBoolean; using String = JsonString; using Null = JsonNull; -// Utils tailored for XGBoost. -namespace detail { -template -bool TypeCheckImpl(Json const& value) { - return IsA(value); -} - -template -std::enable_if_t TypeCheckImpl(Json const& value) { - return IsA(value) || TypeCheckImpl(value); -} - -template -std::string TypeCheckError() { - return "`" + Head{}.TypeStr() + "`"; -} - -template -std::enable_if_t TypeCheckError() { - return "`" + Head{}.TypeStr() + "`, " + TypeCheckError(); -} -} // namespace detail - -/** - * \brief Type check for JSON-based parameters - * - * \tparam JT Expected JSON types. - * \param value Value to be checked. - */ -template -void TypeCheck(Json const& value, StringView name) { - if (!detail::TypeCheckImpl(value)) { - LOG(FATAL) << "Invalid type for: `" << name << "`, expecting one of the: {`" - << detail::TypeCheckError() << "}, got: `" << value.GetValue().TypeStr() - << "`"; - } -} - /** * \brief Convert XGBoost parameter to JSON object. * diff --git a/src/c_api/c_api_utils.h b/src/c_api/c_api_utils.h index 19dd6d639..95efb5b9d 100644 --- a/src/c_api/c_api_utils.h +++ b/src/c_api/c_api_utils.h @@ -1,5 +1,5 @@ /** - * Copyright 2021-2023 by XGBoost Contributors + * Copyright 2021-2023, XGBoost Contributors */ #ifndef XGBOOST_C_API_C_API_UTILS_H_ #define XGBOOST_C_API_C_API_UTILS_H_ @@ -13,6 +13,7 @@ #include // for move #include +#include "../common/json_utils.h" // for TypeCheck #include "xgboost/c_api.h" #include "xgboost/data.h" // DMatrix #include "xgboost/feature_map.h" // for FeatureMap @@ -254,28 +255,6 @@ inline void GenerateFeatureMap(Learner const *learner, void XGBBuildInfoDevice(Json* p_info); -template -auto const &RequiredArg(Json const &in, StringView key, StringView func) { - auto const &obj = get(in); - auto it = obj.find(key); - if (it == obj.cend() || IsA(it->second)) { - LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`."; - } - TypeCheck(it->second, StringView{key}); - return get const>(it->second); -} - -template -auto const &OptionalArg(Json const &in, StringView key, T const &dft) { - auto const &obj = get(in); - auto it = obj.find(key); - if (it != obj.cend() && !IsA(it->second)) { - TypeCheck(it->second, key); - return get const>(it->second); - } - return dft; -} - /** * \brief Get shared ptr from DMatrix C handle with additional checks. */ diff --git a/src/common/io.h b/src/common/io.h index 2eb62b094..5e9d27582 100644 --- a/src/common/io.h +++ b/src/common/io.h @@ -8,7 +8,7 @@ #define XGBOOST_COMMON_IO_H_ #include -#include +#include // for MemoryFixSizeBuffer, MemoryBufferStream #include // for min, fill_n, copy_n #include // for array diff --git a/src/common/json_utils.h b/src/common/json_utils.h new file mode 100644 index 000000000..a2a8a3cae --- /dev/null +++ b/src/common/json_utils.h @@ -0,0 +1,74 @@ +/** + * Copyright 2023, XGBoost Contributors + * + * @brief Utils tailored for XGBoost. + */ +#pragma once + +#include // for string +#include // for enable_if_t, remove_const_t + +#include "xgboost/json.h" +#include "xgboost/string_view.h" // for StringView + +namespace xgboost { +namespace detail { +template +bool TypeCheckImpl(Json const &value) { + return IsA(value); +} + +template +std::enable_if_t TypeCheckImpl(Json const &value) { + return IsA(value) || TypeCheckImpl(value); +} + +template +std::string TypeCheckError() { + return "`" + Head{}.TypeStr() + "`"; +} + +template +std::enable_if_t TypeCheckError() { + return "`" + Head{}.TypeStr() + "`, " + TypeCheckError(); +} +} // namespace detail + +/** + * @brief Type check for JSON-based parameters + * + * @tparam JT Expected JSON types. + * @param value Value to be checked. + */ +template +void TypeCheck(Json const &value, StringView name) { + if (!detail::TypeCheckImpl(value)) { + LOG(FATAL) << "Invalid type for: `" << name << "`, expecting one of the: {`" + << detail::TypeCheckError() << "}, got: `" << value.GetValue().TypeStr() + << "`"; + } +} + +template +auto const &RequiredArg(Json const &in, StringView key, StringView func) { + auto const &obj = get(in); + auto it = obj.find(key); + if (it == obj.cend() || IsA(it->second)) { + LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`."; + } + TypeCheck(it->second, StringView{key}); + return get const>(it->second); +} + +template +auto const &OptionalArg(Json const &in, StringView key, T const &dft) { + auto const &obj = get(in); + auto it = obj.find(key); + if (it != obj.cend() && !IsA(it->second)) { + TypeCheck(it->second, key); + + return get const>(it->second); + } + return dft; +} +} // namespace xgboost diff --git a/src/common/quantile_loss_utils.cc b/src/common/quantile_loss_utils.cc index 59397b701..df2fa6edd 100644 --- a/src/common/quantile_loss_utils.cc +++ b/src/common/quantile_loss_utils.cc @@ -1,19 +1,19 @@ /** - * Copyright 2023 by XGBoost contributors + * Copyright 2023, XGBoost contributors */ #include "quantile_loss_utils.h" -#include // std::isspace -#include // std::istream -#include // std::ostream -#include // std::string -#include // std::vector +#include // for isspace +#include // for istream +#include // for ostream +#include // for string +#include // for vector -#include "xgboost/json.h" // F32Array,TypeCheck,get,Number -#include "xgboost/json_io.h" // JsonWriter +#include "../common/json_utils.h" // for TypeCheck +#include "xgboost/json.h" // for F32Array, get, Number +#include "xgboost/json_io.h" // for JsonWriter -namespace xgboost { -namespace common { +namespace xgboost::common { std::ostream& operator<<(std::ostream& os, const ParamFloatArray& array) { auto const& t = array.Get(); xgboost::F32Array arr{t.size()}; @@ -70,5 +70,4 @@ std::istream& operator>>(std::istream& is, ParamFloatArray& array) { } DMLC_REGISTER_PARAMETER(QuantileLossParam); -} // namespace common -} // namespace xgboost +} // namespace xgboost::common diff --git a/tests/cpp/common/test_json.cc b/tests/cpp/common/test_json.cc index 1d1319274..d361552ce 100644 --- a/tests/cpp/common/test_json.cc +++ b/tests/cpp/common/test_json.cc @@ -1,5 +1,5 @@ /** - * Copyright (c) 2019-2023, XGBoost Contributors + * Copyright 2019-2023, XGBoost Contributors */ #include @@ -9,6 +9,7 @@ #include "../../../src/common/charconv.h" #include "../../../src/common/io.h" +#include "../../../src/common/json_utils.h" #include "../../../src/common/threading_utils.h" // for ParallelFor #include "../filesystem.h" // dmlc::TemporaryDirectory #include "../helpers.h"