From 43152657d443fb1857ef9ed5c45a9ccd19b60df5 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 17 Jan 2023 03:11:07 +0800 Subject: [PATCH] Extract JSON type check. (#8677) - Reuse it in `GetMissing`. - Add test. --- include/xgboost/json.h | 61 +++++++++++++++++++++++++++++++++-- src/c_api/c_api_utils.h | 11 +------ tests/cpp/common/test_json.cc | 16 +++++++++ 3 files changed, 76 insertions(+), 12 deletions(-) diff --git a/include/xgboost/json.h b/include/xgboost/json.h index d35df78f6..3546e58d1 100644 --- a/include/xgboost/json.h +++ b/include/xgboost/json.h @@ -1,5 +1,5 @@ -/*! - * Copyright (c) by XGBoost Contributors 2019-2022 +/** + * Copyright by XGBoost Contributors 2019-2023 */ #ifndef XGBOOST_JSON_H_ #define XGBOOST_JSON_H_ @@ -13,6 +13,7 @@ #include #include #include +#include // std::enable_if,std::enable_if_t #include #include @@ -595,6 +596,52 @@ using String = JsonString; using Null = JsonNull; // Utils tailored for XGBoost. +namespace detail { +template +bool TypeCheckImpl(Json const& value) { + return IsA(value); +} + +template +std::enable_if_t TypeCheckImpl(Json const& value) { + return IsA(value) || TypeCheckImpl(value); +} + +template +std::string TypeCheckError() { + return "`" + Head{}.TypeStr() + "`"; +} + +template +std::enable_if_t TypeCheckError() { + return "`" + Head{}.TypeStr() + "`, " + TypeCheckError(); +} +} // namespace detail + +/** + * \brief Type check for JSON-based parameters + * + * \tparam JT Expected JSON types. + * \param value Value to be checked. + */ +template +void TypeCheck(Json const& value, StringView name) { + if (!detail::TypeCheckImpl(value)) { + LOG(FATAL) << "Invalid type for: `" << name << "`, expecting one of the: {`" + << detail::TypeCheckError() << "}, got: `" << value.GetValue().TypeStr() + << "`"; + } +} + +/** + * \brief Convert XGBoost parameter to JSON object. + * + * \tparam Parameter An instantiation of XGBoostParameter + * + * \param param Input parameter + * + * \return JSON object representing the input parameter + */ template Object ToJson(Parameter const& param) { Object obj; @@ -604,6 +651,16 @@ Object ToJson(Parameter const& param) { return obj; } +/** + * \brief Load a XGBoost parameter from a JSON object. + * + * \tparam Parameter An instantiation of XGBoostParameter + * + * \param obj JSON object representing the parameter. + * \param param Output parameter. + * + * \return Unknown arguments in the JSON object. + */ template Args FromJson(Json const& obj, Parameter* param) { auto const& j_param = get(obj); diff --git a/src/c_api/c_api_utils.h b/src/c_api/c_api_utils.h index 17c88e8c5..2ccf628bf 100644 --- a/src/c_api/c_api_utils.h +++ b/src/c_api/c_api_utils.h @@ -164,7 +164,7 @@ inline float GetMissing(Json const &config) { missing = get(j_missing); } else { missing = nan(""); - LOG(FATAL) << "Invalid missing value: " << j_missing; + TypeCheck(j_missing, "missing"); } return missing; } @@ -248,15 +248,6 @@ inline void GenerateFeatureMap(Learner const *learner, void XGBBuildInfoDevice(Json* p_info); -template -void TypeCheck(Json const &value, StringView name) { - using T = std::remove_const_t const; - if (!IsA(value)) { - LOG(FATAL) << "Incorrect type for: `" << name << "`, expecting: `" << T{}.TypeStr() - << "`, got: `" << value.GetValue().TypeStr() << "`."; - } -} - template auto const &RequiredArg(Json const &in, StringView key, StringView func) { auto const &obj = get(in); diff --git a/tests/cpp/common/test_json.cc b/tests/cpp/common/test_json.cc index 71afe4db3..cf8bcd81d 100644 --- a/tests/cpp/common/test_json.cc +++ b/tests/cpp/common/test_json.cc @@ -10,6 +10,7 @@ #include "../../../src/common/io.h" #include "../filesystem.h" // dmlc::TemporaryDirectory #include "../helpers.h" +#include "dmlc/logging.h" #include "xgboost/json.h" #include "xgboost/json_io.h" #include "xgboost/logging.h" @@ -675,4 +676,19 @@ TEST(UBJson, Basic) { ASSERT_FLOAT_EQ(2.71, get(get(ret["test"])[0])); } } + +TEST(Json, TypeCheck) { + Json config{Object{}}; + config["foo"] = String{"bar"}; + auto test = [&]() { TypeCheck(config["foo"], "foo"); }; + ASSERT_THROW({ test(); }, dmlc::Error); + try { + test(); + } catch (dmlc::Error const& e) { + auto err = std::string{e.what()}; + ASSERT_NE(err.find("Number"), std::string::npos); + ASSERT_NE(err.find("I32Array"), std::string::npos); + ASSERT_NE(err.find("foo"), std::string::npos); + } +} } // namespace xgboost