Extract JSON utils. (#9645)
This commit is contained in:
parent
4e5a7729c3
commit
680d53db43
@ -608,44 +608,6 @@ using Boolean = JsonBoolean;
|
||||
using String = JsonString;
|
||||
using Null = JsonNull;
|
||||
|
||||
// Utils tailored for XGBoost.
|
||||
namespace detail {
|
||||
template <typename Head>
|
||||
bool TypeCheckImpl(Json const& value) {
|
||||
return IsA<Head>(value);
|
||||
}
|
||||
|
||||
template <typename Head, typename... JT>
|
||||
std::enable_if_t<sizeof...(JT) != 0, bool> TypeCheckImpl(Json const& value) {
|
||||
return IsA<Head>(value) || TypeCheckImpl<JT...>(value);
|
||||
}
|
||||
|
||||
template <typename Head>
|
||||
std::string TypeCheckError() {
|
||||
return "`" + Head{}.TypeStr() + "`";
|
||||
}
|
||||
|
||||
template <typename Head, typename... JT>
|
||||
std::enable_if_t<sizeof...(JT) != 0, std::string> TypeCheckError() {
|
||||
return "`" + Head{}.TypeStr() + "`, " + TypeCheckError<JT...>();
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
/**
|
||||
* \brief Type check for JSON-based parameters
|
||||
*
|
||||
* \tparam JT Expected JSON types.
|
||||
* \param value Value to be checked.
|
||||
*/
|
||||
template <typename... JT>
|
||||
void TypeCheck(Json const& value, StringView name) {
|
||||
if (!detail::TypeCheckImpl<JT...>(value)) {
|
||||
LOG(FATAL) << "Invalid type for: `" << name << "`, expecting one of the: {`"
|
||||
<< detail::TypeCheckError<JT...>() << "}, got: `" << value.GetValue().TypeStr()
|
||||
<< "`";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Convert XGBoost parameter to JSON object.
|
||||
*
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2021-2023 by XGBoost Contributors
|
||||
* Copyright 2021-2023, XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_C_API_C_API_UTILS_H_
|
||||
#define XGBOOST_C_API_C_API_UTILS_H_
|
||||
@ -13,6 +13,7 @@
|
||||
#include <utility> // for move
|
||||
#include <vector>
|
||||
|
||||
#include "../common/json_utils.h" // for TypeCheck
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/data.h" // DMatrix
|
||||
#include "xgboost/feature_map.h" // for FeatureMap
|
||||
@ -254,28 +255,6 @@ inline void GenerateFeatureMap(Learner const *learner,
|
||||
|
||||
void XGBBuildInfoDevice(Json* p_info);
|
||||
|
||||
template <typename JT>
|
||||
auto const &RequiredArg(Json const &in, StringView key, StringView func) {
|
||||
auto const &obj = get<Object const>(in);
|
||||
auto it = obj.find(key);
|
||||
if (it == obj.cend() || IsA<Null>(it->second)) {
|
||||
LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`.";
|
||||
}
|
||||
TypeCheck<JT>(it->second, StringView{key});
|
||||
return get<std::remove_const_t<JT> const>(it->second);
|
||||
}
|
||||
|
||||
template <typename JT, typename T>
|
||||
auto const &OptionalArg(Json const &in, StringView key, T const &dft) {
|
||||
auto const &obj = get<Object const>(in);
|
||||
auto it = obj.find(key);
|
||||
if (it != obj.cend() && !IsA<Null>(it->second)) {
|
||||
TypeCheck<JT>(it->second, key);
|
||||
return get<std::remove_const_t<JT> const>(it->second);
|
||||
}
|
||||
return dft;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get shared ptr from DMatrix C handle with additional checks.
|
||||
*/
|
||||
|
||||
@ -8,7 +8,7 @@
|
||||
#define XGBOOST_COMMON_IO_H_
|
||||
|
||||
#include <dmlc/io.h>
|
||||
#include <rabit/rabit.h>
|
||||
#include <rabit/internal/io.h> // for MemoryFixSizeBuffer, MemoryBufferStream
|
||||
|
||||
#include <algorithm> // for min, fill_n, copy_n
|
||||
#include <array> // for array
|
||||
|
||||
74
src/common/json_utils.h
Normal file
74
src/common/json_utils.h
Normal file
@ -0,0 +1,74 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*
|
||||
* @brief Utils tailored for XGBoost.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <string> // for string
|
||||
#include <type_traits> // for enable_if_t, remove_const_t
|
||||
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/string_view.h" // for StringView
|
||||
|
||||
namespace xgboost {
|
||||
namespace detail {
|
||||
template <typename Head>
|
||||
bool TypeCheckImpl(Json const &value) {
|
||||
return IsA<Head>(value);
|
||||
}
|
||||
|
||||
template <typename Head, typename... JT>
|
||||
std::enable_if_t<sizeof...(JT) != 0, bool> TypeCheckImpl(Json const &value) {
|
||||
return IsA<Head>(value) || TypeCheckImpl<JT...>(value);
|
||||
}
|
||||
|
||||
template <typename Head>
|
||||
std::string TypeCheckError() {
|
||||
return "`" + Head{}.TypeStr() + "`";
|
||||
}
|
||||
|
||||
template <typename Head, typename... JT>
|
||||
std::enable_if_t<sizeof...(JT) != 0, std::string> TypeCheckError() {
|
||||
return "`" + Head{}.TypeStr() + "`, " + TypeCheckError<JT...>();
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
/**
|
||||
* @brief Type check for JSON-based parameters
|
||||
*
|
||||
* @tparam JT Expected JSON types.
|
||||
* @param value Value to be checked.
|
||||
*/
|
||||
template <typename... JT>
|
||||
void TypeCheck(Json const &value, StringView name) {
|
||||
if (!detail::TypeCheckImpl<JT...>(value)) {
|
||||
LOG(FATAL) << "Invalid type for: `" << name << "`, expecting one of the: {`"
|
||||
<< detail::TypeCheckError<JT...>() << "}, got: `" << value.GetValue().TypeStr()
|
||||
<< "`";
|
||||
}
|
||||
}
|
||||
|
||||
template <typename JT>
|
||||
auto const &RequiredArg(Json const &in, StringView key, StringView func) {
|
||||
auto const &obj = get<Object const>(in);
|
||||
auto it = obj.find(key);
|
||||
if (it == obj.cend() || IsA<Null>(it->second)) {
|
||||
LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`.";
|
||||
}
|
||||
TypeCheck<JT>(it->second, StringView{key});
|
||||
return get<std::remove_const_t<JT> const>(it->second);
|
||||
}
|
||||
|
||||
template <typename JT, typename T>
|
||||
auto const &OptionalArg(Json const &in, StringView key, T const &dft) {
|
||||
auto const &obj = get<Object const>(in);
|
||||
auto it = obj.find(key);
|
||||
if (it != obj.cend() && !IsA<Null>(it->second)) {
|
||||
TypeCheck<JT>(it->second, key);
|
||||
|
||||
return get<std::remove_const_t<JT> const>(it->second);
|
||||
}
|
||||
return dft;
|
||||
}
|
||||
} // namespace xgboost
|
||||
@ -1,19 +1,19 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
* Copyright 2023, XGBoost contributors
|
||||
*/
|
||||
#include "quantile_loss_utils.h"
|
||||
|
||||
#include <cctype> // std::isspace
|
||||
#include <istream> // std::istream
|
||||
#include <ostream> // std::ostream
|
||||
#include <string> // std::string
|
||||
#include <vector> // std::vector
|
||||
#include <cctype> // for isspace
|
||||
#include <istream> // for istream
|
||||
#include <ostream> // for ostream
|
||||
#include <string> // for string
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "xgboost/json.h" // F32Array,TypeCheck,get,Number
|
||||
#include "xgboost/json_io.h" // JsonWriter
|
||||
#include "../common/json_utils.h" // for TypeCheck
|
||||
#include "xgboost/json.h" // for F32Array, get, Number
|
||||
#include "xgboost/json_io.h" // for JsonWriter
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
namespace xgboost::common {
|
||||
std::ostream& operator<<(std::ostream& os, const ParamFloatArray& array) {
|
||||
auto const& t = array.Get();
|
||||
xgboost::F32Array arr{t.size()};
|
||||
@ -70,5 +70,4 @@ std::istream& operator>>(std::istream& is, ParamFloatArray& array) {
|
||||
}
|
||||
|
||||
DMLC_REGISTER_PARAMETER(QuantileLossParam);
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::common
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright (c) 2019-2023, XGBoost Contributors
|
||||
* Copyright 2019-2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
|
||||
#include "../../../src/common/charconv.h"
|
||||
#include "../../../src/common/io.h"
|
||||
#include "../../../src/common/json_utils.h"
|
||||
#include "../../../src/common/threading_utils.h" // for ParallelFor
|
||||
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
||||
#include "../helpers.h"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user