Extract JSON utils. (#9645)
This commit is contained in:
parent
4e5a7729c3
commit
680d53db43
@ -608,44 +608,6 @@ using Boolean = JsonBoolean;
|
|||||||
using String = JsonString;
|
using String = JsonString;
|
||||||
using Null = JsonNull;
|
using Null = JsonNull;
|
||||||
|
|
||||||
// Utils tailored for XGBoost.
|
|
||||||
namespace detail {
|
|
||||||
template <typename Head>
|
|
||||||
bool TypeCheckImpl(Json const& value) {
|
|
||||||
return IsA<Head>(value);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Head, typename... JT>
|
|
||||||
std::enable_if_t<sizeof...(JT) != 0, bool> TypeCheckImpl(Json const& value) {
|
|
||||||
return IsA<Head>(value) || TypeCheckImpl<JT...>(value);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Head>
|
|
||||||
std::string TypeCheckError() {
|
|
||||||
return "`" + Head{}.TypeStr() + "`";
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Head, typename... JT>
|
|
||||||
std::enable_if_t<sizeof...(JT) != 0, std::string> TypeCheckError() {
|
|
||||||
return "`" + Head{}.TypeStr() + "`, " + TypeCheckError<JT...>();
|
|
||||||
}
|
|
||||||
} // namespace detail
|
|
||||||
|
|
||||||
/**
|
|
||||||
* \brief Type check for JSON-based parameters
|
|
||||||
*
|
|
||||||
* \tparam JT Expected JSON types.
|
|
||||||
* \param value Value to be checked.
|
|
||||||
*/
|
|
||||||
template <typename... JT>
|
|
||||||
void TypeCheck(Json const& value, StringView name) {
|
|
||||||
if (!detail::TypeCheckImpl<JT...>(value)) {
|
|
||||||
LOG(FATAL) << "Invalid type for: `" << name << "`, expecting one of the: {`"
|
|
||||||
<< detail::TypeCheckError<JT...>() << "}, got: `" << value.GetValue().TypeStr()
|
|
||||||
<< "`";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Convert XGBoost parameter to JSON object.
|
* \brief Convert XGBoost parameter to JSON object.
|
||||||
*
|
*
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2021-2023 by XGBoost Contributors
|
* Copyright 2021-2023, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_C_API_C_API_UTILS_H_
|
#ifndef XGBOOST_C_API_C_API_UTILS_H_
|
||||||
#define XGBOOST_C_API_C_API_UTILS_H_
|
#define XGBOOST_C_API_C_API_UTILS_H_
|
||||||
@ -13,6 +13,7 @@
|
|||||||
#include <utility> // for move
|
#include <utility> // for move
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "../common/json_utils.h" // for TypeCheck
|
||||||
#include "xgboost/c_api.h"
|
#include "xgboost/c_api.h"
|
||||||
#include "xgboost/data.h" // DMatrix
|
#include "xgboost/data.h" // DMatrix
|
||||||
#include "xgboost/feature_map.h" // for FeatureMap
|
#include "xgboost/feature_map.h" // for FeatureMap
|
||||||
@ -254,28 +255,6 @@ inline void GenerateFeatureMap(Learner const *learner,
|
|||||||
|
|
||||||
void XGBBuildInfoDevice(Json* p_info);
|
void XGBBuildInfoDevice(Json* p_info);
|
||||||
|
|
||||||
template <typename JT>
|
|
||||||
auto const &RequiredArg(Json const &in, StringView key, StringView func) {
|
|
||||||
auto const &obj = get<Object const>(in);
|
|
||||||
auto it = obj.find(key);
|
|
||||||
if (it == obj.cend() || IsA<Null>(it->second)) {
|
|
||||||
LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`.";
|
|
||||||
}
|
|
||||||
TypeCheck<JT>(it->second, StringView{key});
|
|
||||||
return get<std::remove_const_t<JT> const>(it->second);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename JT, typename T>
|
|
||||||
auto const &OptionalArg(Json const &in, StringView key, T const &dft) {
|
|
||||||
auto const &obj = get<Object const>(in);
|
|
||||||
auto it = obj.find(key);
|
|
||||||
if (it != obj.cend() && !IsA<Null>(it->second)) {
|
|
||||||
TypeCheck<JT>(it->second, key);
|
|
||||||
return get<std::remove_const_t<JT> const>(it->second);
|
|
||||||
}
|
|
||||||
return dft;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Get shared ptr from DMatrix C handle with additional checks.
|
* \brief Get shared ptr from DMatrix C handle with additional checks.
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -8,7 +8,7 @@
|
|||||||
#define XGBOOST_COMMON_IO_H_
|
#define XGBOOST_COMMON_IO_H_
|
||||||
|
|
||||||
#include <dmlc/io.h>
|
#include <dmlc/io.h>
|
||||||
#include <rabit/rabit.h>
|
#include <rabit/internal/io.h> // for MemoryFixSizeBuffer, MemoryBufferStream
|
||||||
|
|
||||||
#include <algorithm> // for min, fill_n, copy_n
|
#include <algorithm> // for min, fill_n, copy_n
|
||||||
#include <array> // for array
|
#include <array> // for array
|
||||||
|
|||||||
74
src/common/json_utils.h
Normal file
74
src/common/json_utils.h
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*
|
||||||
|
* @brief Utils tailored for XGBoost.
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string> // for string
|
||||||
|
#include <type_traits> // for enable_if_t, remove_const_t
|
||||||
|
|
||||||
|
#include "xgboost/json.h"
|
||||||
|
#include "xgboost/string_view.h" // for StringView
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace detail {
|
||||||
|
template <typename Head>
|
||||||
|
bool TypeCheckImpl(Json const &value) {
|
||||||
|
return IsA<Head>(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Head, typename... JT>
|
||||||
|
std::enable_if_t<sizeof...(JT) != 0, bool> TypeCheckImpl(Json const &value) {
|
||||||
|
return IsA<Head>(value) || TypeCheckImpl<JT...>(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Head>
|
||||||
|
std::string TypeCheckError() {
|
||||||
|
return "`" + Head{}.TypeStr() + "`";
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Head, typename... JT>
|
||||||
|
std::enable_if_t<sizeof...(JT) != 0, std::string> TypeCheckError() {
|
||||||
|
return "`" + Head{}.TypeStr() + "`, " + TypeCheckError<JT...>();
|
||||||
|
}
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Type check for JSON-based parameters
|
||||||
|
*
|
||||||
|
* @tparam JT Expected JSON types.
|
||||||
|
* @param value Value to be checked.
|
||||||
|
*/
|
||||||
|
template <typename... JT>
|
||||||
|
void TypeCheck(Json const &value, StringView name) {
|
||||||
|
if (!detail::TypeCheckImpl<JT...>(value)) {
|
||||||
|
LOG(FATAL) << "Invalid type for: `" << name << "`, expecting one of the: {`"
|
||||||
|
<< detail::TypeCheckError<JT...>() << "}, got: `" << value.GetValue().TypeStr()
|
||||||
|
<< "`";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename JT>
|
||||||
|
auto const &RequiredArg(Json const &in, StringView key, StringView func) {
|
||||||
|
auto const &obj = get<Object const>(in);
|
||||||
|
auto it = obj.find(key);
|
||||||
|
if (it == obj.cend() || IsA<Null>(it->second)) {
|
||||||
|
LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`.";
|
||||||
|
}
|
||||||
|
TypeCheck<JT>(it->second, StringView{key});
|
||||||
|
return get<std::remove_const_t<JT> const>(it->second);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename JT, typename T>
|
||||||
|
auto const &OptionalArg(Json const &in, StringView key, T const &dft) {
|
||||||
|
auto const &obj = get<Object const>(in);
|
||||||
|
auto it = obj.find(key);
|
||||||
|
if (it != obj.cend() && !IsA<Null>(it->second)) {
|
||||||
|
TypeCheck<JT>(it->second, key);
|
||||||
|
|
||||||
|
return get<std::remove_const_t<JT> const>(it->second);
|
||||||
|
}
|
||||||
|
return dft;
|
||||||
|
}
|
||||||
|
} // namespace xgboost
|
||||||
@ -1,19 +1,19 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2023 by XGBoost contributors
|
* Copyright 2023, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include "quantile_loss_utils.h"
|
#include "quantile_loss_utils.h"
|
||||||
|
|
||||||
#include <cctype> // std::isspace
|
#include <cctype> // for isspace
|
||||||
#include <istream> // std::istream
|
#include <istream> // for istream
|
||||||
#include <ostream> // std::ostream
|
#include <ostream> // for ostream
|
||||||
#include <string> // std::string
|
#include <string> // for string
|
||||||
#include <vector> // std::vector
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include "xgboost/json.h" // F32Array,TypeCheck,get,Number
|
#include "../common/json_utils.h" // for TypeCheck
|
||||||
#include "xgboost/json_io.h" // JsonWriter
|
#include "xgboost/json.h" // for F32Array, get, Number
|
||||||
|
#include "xgboost/json_io.h" // for JsonWriter
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::common {
|
||||||
namespace common {
|
|
||||||
std::ostream& operator<<(std::ostream& os, const ParamFloatArray& array) {
|
std::ostream& operator<<(std::ostream& os, const ParamFloatArray& array) {
|
||||||
auto const& t = array.Get();
|
auto const& t = array.Get();
|
||||||
xgboost::F32Array arr{t.size()};
|
xgboost::F32Array arr{t.size()};
|
||||||
@ -70,5 +70,4 @@ std::istream& operator>>(std::istream& is, ParamFloatArray& array) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
DMLC_REGISTER_PARAMETER(QuantileLossParam);
|
DMLC_REGISTER_PARAMETER(QuantileLossParam);
|
||||||
} // namespace common
|
} // namespace xgboost::common
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright (c) 2019-2023, XGBoost Contributors
|
* Copyright 2019-2023, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
@ -9,6 +9,7 @@
|
|||||||
|
|
||||||
#include "../../../src/common/charconv.h"
|
#include "../../../src/common/charconv.h"
|
||||||
#include "../../../src/common/io.h"
|
#include "../../../src/common/io.h"
|
||||||
|
#include "../../../src/common/json_utils.h"
|
||||||
#include "../../../src/common/threading_utils.h" // for ParallelFor
|
#include "../../../src/common/threading_utils.h" // for ParallelFor
|
||||||
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user