Better error message when world size and rank are set as strings (#8316)

Co-authored-by: jiamingy <jm.yuan@outlook.com>
This commit is contained in:
Rong Ou
2022-10-12 00:53:25 -07:00
committed by GitHub
parent 210915c985
commit 39afdac3be
6 changed files with 79 additions and 31 deletions

View File

@@ -7,7 +7,7 @@ find_package(Threads)
add_library(federated_proto federated.proto)
target_link_libraries(federated_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
target_include_directories(federated_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
set_property(TARGET federated_proto PROPERTY POSITION_INDEPENDENT_CODE ON)
xgboost_target_properties(federated_proto)
get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION)
protobuf_generate(TARGET federated_proto LANGUAGE cpp)

View File

@@ -4,6 +4,7 @@
#pragma once
#include <xgboost/json.h>
#include "../../src/c_api/c_api_utils.h"
#include "../../src/collective/communicator.h"
#include "../../src/common/io.h"
#include "federated_client.h"
@@ -89,31 +90,14 @@ class FederatedCommunicator : public Communicator {
client_cert = value;
}
// Runtime configuration overrides.
auto const &j_server_address = config["federated_server_address"];
if (IsA<String const>(j_server_address)) {
server_address = get<String const>(j_server_address);
}
auto const &j_world_size = config["federated_world_size"];
if (IsA<Integer const>(j_world_size)) {
world_size = static_cast<int>(get<Integer const>(j_world_size));
}
auto const &j_rank = config["federated_rank"];
if (IsA<Integer const>(j_rank)) {
rank = static_cast<int>(get<Integer const>(j_rank));
}
auto const &j_server_cert = config["federated_server_cert"];
if (IsA<String const>(j_server_cert)) {
server_cert = get<String const>(j_server_cert);
}
auto const &j_client_key = config["federated_client_key"];
if (IsA<String const>(j_client_key)) {
client_key = get<String const>(j_client_key);
}
auto const &j_client_cert = config["federated_client_cert"];
if (IsA<String const>(j_client_cert)) {
client_cert = get<String const>(j_client_cert);
}
// Runtime configuration overrides, optional as users can specify them as env vars.
server_address = OptionalArg<String>(config, "federated_server_address", server_address);
world_size =
OptionalArg<Integer>(config, "federated_world_size", static_cast<Integer::Int>(world_size));
rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));
server_cert = OptionalArg<String>(config, "federated_server_cert", server_cert);
client_key = OptionalArg<String>(config, "federated_client_key", client_key);
client_cert = OptionalArg<String>(config, "federated_client_cert", client_cert);
if (server_address.empty()) {
LOG(FATAL) << "Federated server address must be set.";