Better error message when world size and rank are set as strings (#8316)
Co-authored-by: jiamingy <jm.yuan@outlook.com>
This commit is contained in:
parent
210915c985
commit
39afdac3be
@ -7,7 +7,7 @@ find_package(Threads)
|
|||||||
add_library(federated_proto federated.proto)
|
add_library(federated_proto federated.proto)
|
||||||
target_link_libraries(federated_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
|
target_link_libraries(federated_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
|
||||||
target_include_directories(federated_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
|
target_include_directories(federated_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
|
||||||
set_property(TARGET federated_proto PROPERTY POSITION_INDEPENDENT_CODE ON)
|
xgboost_target_properties(federated_proto)
|
||||||
|
|
||||||
get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION)
|
get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION)
|
||||||
protobuf_generate(TARGET federated_proto LANGUAGE cpp)
|
protobuf_generate(TARGET federated_proto LANGUAGE cpp)
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
#include <xgboost/json.h>
|
#include <xgboost/json.h>
|
||||||
|
|
||||||
|
#include "../../src/c_api/c_api_utils.h"
|
||||||
#include "../../src/collective/communicator.h"
|
#include "../../src/collective/communicator.h"
|
||||||
#include "../../src/common/io.h"
|
#include "../../src/common/io.h"
|
||||||
#include "federated_client.h"
|
#include "federated_client.h"
|
||||||
@ -89,31 +90,14 @@ class FederatedCommunicator : public Communicator {
|
|||||||
client_cert = value;
|
client_cert = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Runtime configuration overrides.
|
// Runtime configuration overrides, optional as users can specify them as env vars.
|
||||||
auto const &j_server_address = config["federated_server_address"];
|
server_address = OptionalArg<String>(config, "federated_server_address", server_address);
|
||||||
if (IsA<String const>(j_server_address)) {
|
world_size =
|
||||||
server_address = get<String const>(j_server_address);
|
OptionalArg<Integer>(config, "federated_world_size", static_cast<Integer::Int>(world_size));
|
||||||
}
|
rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));
|
||||||
auto const &j_world_size = config["federated_world_size"];
|
server_cert = OptionalArg<String>(config, "federated_server_cert", server_cert);
|
||||||
if (IsA<Integer const>(j_world_size)) {
|
client_key = OptionalArg<String>(config, "federated_client_key", client_key);
|
||||||
world_size = static_cast<int>(get<Integer const>(j_world_size));
|
client_cert = OptionalArg<String>(config, "federated_client_cert", client_cert);
|
||||||
}
|
|
||||||
auto const &j_rank = config["federated_rank"];
|
|
||||||
if (IsA<Integer const>(j_rank)) {
|
|
||||||
rank = static_cast<int>(get<Integer const>(j_rank));
|
|
||||||
}
|
|
||||||
auto const &j_server_cert = config["federated_server_cert"];
|
|
||||||
if (IsA<String const>(j_server_cert)) {
|
|
||||||
server_cert = get<String const>(j_server_cert);
|
|
||||||
}
|
|
||||||
auto const &j_client_key = config["federated_client_key"];
|
|
||||||
if (IsA<String const>(j_client_key)) {
|
|
||||||
client_key = get<String const>(j_client_key);
|
|
||||||
}
|
|
||||||
auto const &j_client_cert = config["federated_client_cert"];
|
|
||||||
if (IsA<String const>(j_client_cert)) {
|
|
||||||
client_cert = get<String const>(j_client_cert);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (server_address.empty()) {
|
if (server_address.empty()) {
|
||||||
LOG(FATAL) << "Federated server address must be set.";
|
LOG(FATAL) << "Federated server address must be set.";
|
||||||
|
|||||||
@ -248,13 +248,23 @@ inline void GenerateFeatureMap(Learner const *learner,
|
|||||||
|
|
||||||
void XGBBuildInfoDevice(Json* p_info);
|
void XGBBuildInfoDevice(Json* p_info);
|
||||||
|
|
||||||
|
template <typename JT>
|
||||||
|
void TypeCheck(Json const &value, StringView name) {
|
||||||
|
using T = std::remove_const_t<JT> const;
|
||||||
|
if (!IsA<T>(value)) {
|
||||||
|
LOG(FATAL) << "Incorrect type for: `" << name << "`, expecting: `" << T{}.TypeStr()
|
||||||
|
<< "`, got: `" << value.GetValue().TypeStr() << "`.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename JT>
|
template <typename JT>
|
||||||
auto const &RequiredArg(Json const &in, std::string const &key, StringView func) {
|
auto const &RequiredArg(Json const &in, std::string const &key, StringView func) {
|
||||||
auto const &obj = get<Object const>(in);
|
auto const &obj = get<Object const>(in);
|
||||||
auto it = obj.find(key);
|
auto it = obj.find(key);
|
||||||
if (it == obj.cend() || IsA<Null>(it->second)) {
|
if (it == obj.cend() || IsA<Null>(it->second)) {
|
||||||
LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`";
|
LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`.";
|
||||||
}
|
}
|
||||||
|
TypeCheck<JT>(it->second, StringView{key});
|
||||||
return get<std::remove_const_t<JT> const>(it->second);
|
return get<std::remove_const_t<JT> const>(it->second);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -262,7 +272,8 @@ template <typename JT, typename T>
|
|||||||
auto const &OptionalArg(Json const &in, std::string const &key, T const &dft) {
|
auto const &OptionalArg(Json const &in, std::string const &key, T const &dft) {
|
||||||
auto const &obj = get<Object const>(in);
|
auto const &obj = get<Object const>(in);
|
||||||
auto it = obj.find(key);
|
auto it = obj.find(key);
|
||||||
if (it != obj.cend()) {
|
if (it != obj.cend() && !IsA<Null>(it->second)) {
|
||||||
|
TypeCheck<JT>(it->second, StringView{key});
|
||||||
return get<std::remove_const_t<JT> const>(it->second);
|
return get<std::remove_const_t<JT> const>(it->second);
|
||||||
}
|
}
|
||||||
return dft;
|
return dft;
|
||||||
|
|||||||
@ -17,9 +17,8 @@ class NoOpCommunicator : public Communicator {
|
|||||||
NoOpCommunicator() : Communicator(1, 0) {}
|
NoOpCommunicator() : Communicator(1, 0) {}
|
||||||
bool IsDistributed() const override { return false; }
|
bool IsDistributed() const override { return false; }
|
||||||
bool IsFederated() const override { return false; }
|
bool IsFederated() const override { return false; }
|
||||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
void AllReduce(void *, std::size_t, DataType, Operation) override {}
|
||||||
Operation op) override {}
|
void Broadcast(void *, std::size_t, int) override {}
|
||||||
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {}
|
|
||||||
std::string GetProcessorName() override { return ""; }
|
std::string GetProcessorName() override { return ""; }
|
||||||
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
|
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
|
||||||
|
|
||||||
|
|||||||
@ -324,4 +324,36 @@ TEST(CAPI, NullPtr) {
|
|||||||
ASSERT_NE(pos, std::string::npos);
|
ASSERT_NE(pos, std::string::npos);
|
||||||
XGBAPISetLastError("");
|
XGBAPISetLastError("");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, JArgs) {
|
||||||
|
{
|
||||||
|
Json args{Object{}};
|
||||||
|
args["key"] = String{"value"};
|
||||||
|
args["null"] = Null{};
|
||||||
|
auto value = OptionalArg<String>(args, "key", std::string{"foo"});
|
||||||
|
ASSERT_EQ(value, "value");
|
||||||
|
value = OptionalArg<String const>(args, "key", std::string{"foo"});
|
||||||
|
ASSERT_EQ(value, "value");
|
||||||
|
|
||||||
|
ASSERT_THROW({ OptionalArg<Number>(args, "key", 0.0f); }, dmlc::Error);
|
||||||
|
value = OptionalArg<String const>(args, "bar", std::string{"foo"});
|
||||||
|
ASSERT_EQ(value, "foo");
|
||||||
|
value = OptionalArg<String const>(args, "null", std::string{"foo"});
|
||||||
|
ASSERT_EQ(value, "foo");
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
Json args{Object{}};
|
||||||
|
args["key"] = String{"value"};
|
||||||
|
args["null"] = Null{};
|
||||||
|
auto value = RequiredArg<String>(args, "key", __func__);
|
||||||
|
ASSERT_EQ(value, "value");
|
||||||
|
value = RequiredArg<String const>(args, "key", __func__);
|
||||||
|
ASSERT_EQ(value, "value");
|
||||||
|
|
||||||
|
ASSERT_THROW({ RequiredArg<Integer>(args, "key", __func__); }, dmlc::Error);
|
||||||
|
ASSERT_THROW({ RequiredArg<String const>(args, "foo", __func__); }, dmlc::Error);
|
||||||
|
ASSERT_THROW({ RequiredArg<String>(args, "null", __func__); }, dmlc::Error);
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -85,6 +85,28 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooBig) {
|
|||||||
EXPECT_THROW(construct(), dmlc::Error);
|
EXPECT_THROW(construct(), dmlc::Error);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) {
|
||||||
|
auto construct = []() {
|
||||||
|
Json config{JsonObject()};
|
||||||
|
config["federated_server_address"] = kServerAddress;
|
||||||
|
config["federated_world_size"] = std::string("1");
|
||||||
|
config["federated_rank"] = Integer(0);
|
||||||
|
auto *comm = FederatedCommunicator::Create(config);
|
||||||
|
};
|
||||||
|
EXPECT_THROW(construct(), dmlc::Error);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) {
|
||||||
|
auto construct = []() {
|
||||||
|
Json config{JsonObject()};
|
||||||
|
config["federated_server_address"] = kServerAddress;
|
||||||
|
config["federated_world_size"] = 1;
|
||||||
|
config["federated_rank"] = std::string("0");
|
||||||
|
auto *comm = FederatedCommunicator::Create(config);
|
||||||
|
};
|
||||||
|
EXPECT_THROW(construct(), dmlc::Error);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(FederatedCommunicatorSimpleTest, GetWorldSizeAndRank) {
|
TEST(FederatedCommunicatorSimpleTest, GetWorldSizeAndRank) {
|
||||||
FederatedCommunicator comm{6, 3, kServerAddress};
|
FederatedCommunicator comm{6, 3, kServerAddress};
|
||||||
EXPECT_EQ(comm.GetWorldSize(), 6);
|
EXPECT_EQ(comm.GetWorldSize(), 6);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user