Extract JSON type check. (#8677)
- Reuse it in `GetMissing`. - Add test.
This commit is contained in:
parent
9f598efc3e
commit
43152657d4
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright (c) by XGBoost Contributors 2019-2022
|
* Copyright by XGBoost Contributors 2019-2023
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_JSON_H_
|
#ifndef XGBOOST_JSON_H_
|
||||||
#define XGBOOST_JSON_H_
|
#define XGBOOST_JSON_H_
|
||||||
@ -13,6 +13,7 @@
|
|||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <type_traits> // std::enable_if,std::enable_if_t
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -595,6 +596,52 @@ using String = JsonString;
|
|||||||
using Null = JsonNull;
|
using Null = JsonNull;
|
||||||
|
|
||||||
// Utils tailored for XGBoost.
|
// Utils tailored for XGBoost.
|
||||||
|
namespace detail {
|
||||||
|
template <typename Head>
|
||||||
|
bool TypeCheckImpl(Json const& value) {
|
||||||
|
return IsA<Head>(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Head, typename... JT>
|
||||||
|
std::enable_if_t<sizeof...(JT) != 0, bool> TypeCheckImpl(Json const& value) {
|
||||||
|
return IsA<Head>(value) || TypeCheckImpl<JT...>(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Head>
|
||||||
|
std::string TypeCheckError() {
|
||||||
|
return "`" + Head{}.TypeStr() + "`";
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Head, typename... JT>
|
||||||
|
std::enable_if_t<sizeof...(JT) != 0, std::string> TypeCheckError() {
|
||||||
|
return "`" + Head{}.TypeStr() + "`, " + TypeCheckError<JT...>();
|
||||||
|
}
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Type check for JSON-based parameters
|
||||||
|
*
|
||||||
|
* \tparam JT Expected JSON types.
|
||||||
|
* \param value Value to be checked.
|
||||||
|
*/
|
||||||
|
template <typename... JT>
|
||||||
|
void TypeCheck(Json const& value, StringView name) {
|
||||||
|
if (!detail::TypeCheckImpl<JT...>(value)) {
|
||||||
|
LOG(FATAL) << "Invalid type for: `" << name << "`, expecting one of the: {`"
|
||||||
|
<< detail::TypeCheckError<JT...>() << "}, got: `" << value.GetValue().TypeStr()
|
||||||
|
<< "`";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Convert XGBoost parameter to JSON object.
|
||||||
|
*
|
||||||
|
* \tparam Parameter An instantiation of XGBoostParameter
|
||||||
|
*
|
||||||
|
* \param param Input parameter
|
||||||
|
*
|
||||||
|
* \return JSON object representing the input parameter
|
||||||
|
*/
|
||||||
template <typename Parameter>
|
template <typename Parameter>
|
||||||
Object ToJson(Parameter const& param) {
|
Object ToJson(Parameter const& param) {
|
||||||
Object obj;
|
Object obj;
|
||||||
@ -604,6 +651,16 @@ Object ToJson(Parameter const& param) {
|
|||||||
return obj;
|
return obj;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Load a XGBoost parameter from a JSON object.
|
||||||
|
*
|
||||||
|
* \tparam Parameter An instantiation of XGBoostParameter
|
||||||
|
*
|
||||||
|
* \param obj JSON object representing the parameter.
|
||||||
|
* \param param Output parameter.
|
||||||
|
*
|
||||||
|
* \return Unknown arguments in the JSON object.
|
||||||
|
*/
|
||||||
template <typename Parameter>
|
template <typename Parameter>
|
||||||
Args FromJson(Json const& obj, Parameter* param) {
|
Args FromJson(Json const& obj, Parameter* param) {
|
||||||
auto const& j_param = get<Object const>(obj);
|
auto const& j_param = get<Object const>(obj);
|
||||||
|
|||||||
@ -164,7 +164,7 @@ inline float GetMissing(Json const &config) {
|
|||||||
missing = get<Integer const>(j_missing);
|
missing = get<Integer const>(j_missing);
|
||||||
} else {
|
} else {
|
||||||
missing = nan("");
|
missing = nan("");
|
||||||
LOG(FATAL) << "Invalid missing value: " << j_missing;
|
TypeCheck<Number, Integer>(j_missing, "missing");
|
||||||
}
|
}
|
||||||
return missing;
|
return missing;
|
||||||
}
|
}
|
||||||
@ -248,15 +248,6 @@ inline void GenerateFeatureMap(Learner const *learner,
|
|||||||
|
|
||||||
void XGBBuildInfoDevice(Json* p_info);
|
void XGBBuildInfoDevice(Json* p_info);
|
||||||
|
|
||||||
template <typename JT>
|
|
||||||
void TypeCheck(Json const &value, StringView name) {
|
|
||||||
using T = std::remove_const_t<JT> const;
|
|
||||||
if (!IsA<T>(value)) {
|
|
||||||
LOG(FATAL) << "Incorrect type for: `" << name << "`, expecting: `" << T{}.TypeStr()
|
|
||||||
<< "`, got: `" << value.GetValue().TypeStr() << "`.";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename JT>
|
template <typename JT>
|
||||||
auto const &RequiredArg(Json const &in, StringView key, StringView func) {
|
auto const &RequiredArg(Json const &in, StringView key, StringView func) {
|
||||||
auto const &obj = get<Object const>(in);
|
auto const &obj = get<Object const>(in);
|
||||||
|
|||||||
@ -10,6 +10,7 @@
|
|||||||
#include "../../../src/common/io.h"
|
#include "../../../src/common/io.h"
|
||||||
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
#include "dmlc/logging.h"
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
#include "xgboost/json_io.h"
|
#include "xgboost/json_io.h"
|
||||||
#include "xgboost/logging.h"
|
#include "xgboost/logging.h"
|
||||||
@ -675,4 +676,19 @@ TEST(UBJson, Basic) {
|
|||||||
ASSERT_FLOAT_EQ(2.71, get<Number>(get<Array>(ret["test"])[0]));
|
ASSERT_FLOAT_EQ(2.71, get<Number>(get<Array>(ret["test"])[0]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Json, TypeCheck) {
|
||||||
|
Json config{Object{}};
|
||||||
|
config["foo"] = String{"bar"};
|
||||||
|
auto test = [&]() { TypeCheck<Number, Integer, Array, I32Array>(config["foo"], "foo"); };
|
||||||
|
ASSERT_THROW({ test(); }, dmlc::Error);
|
||||||
|
try {
|
||||||
|
test();
|
||||||
|
} catch (dmlc::Error const& e) {
|
||||||
|
auto err = std::string{e.what()};
|
||||||
|
ASSERT_NE(err.find("Number"), std::string::npos);
|
||||||
|
ASSERT_NE(err.find("I32Array"), std::string::npos);
|
||||||
|
ASSERT_NE(err.find("foo"), std::string::npos);
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user