[coll] Add C API for the tracker. (#9773)

This commit is contained in:
Jiaming Yuan 2023-11-08 18:17:14 +08:00 committed by GitHub
parent 06bdc15e9b
commit 44099f585d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 264 additions and 0 deletions

View File

@ -1508,6 +1508,83 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config,
* @{
*/
/**
* @brief Opaque handle to a tracker.
*
* There are currently two types of tracker in XGBoost: the first one is `rabit` and the
* other one is `federated`.
*
* A handle is obtained from XGTrackerCreate() and released with XGTrackerFree().
*
* This is still under development.
*/
typedef void *TrackerHandle; /* NOLINT */
/**
* @brief Create a new tracker.
*
* @param config JSON encoded parameters. Must not be NULL.
*
* - dmlc_communicator: String, the type of tracker to create. Available options are `rabit`
* and `federated`. Any other value is an error, and `federated` is only available
* when XGBoost is compiled with federated learning support.
* - n_workers: Integer, the number of workers.
* - port: (Optional) Integer, the port this tracker should listen to.
* - timeout: (Optional) Integer, timeout in seconds for various networking operations.
*
* Some configurations are `rabit` specific:
* - host: (Optional) String, used by the `rabit` tracker to specify the address of the host.
*
* Some `federated` specific configurations:
* - federated_secure: Boolean, whether this is a secure server.
* - server_key_path: Path to the server key. Used only if this is a secure server.
* - server_cert_path: Path to the server certificate. Used only if this is a secure server.
* - client_cert_path: Path to the client certificate. Used only if this is a secure server.
*
* @param handle Output, receives the handle to the created tracker. Must not be NULL.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
/**
* @brief Get the arguments needed for running workers. This should be called after
* XGTrackerRun(). Note that XGTrackerWait() blocks until the tracker finishes, so
* the arguments should be obtained before waiting (or while another thread waits).
*
* @param handle The handle to the tracker.
* @param args Output, the arguments returned as a JSON document. Must not be NULL. The
* string is owned by XGBoost (it lives in a thread-local buffer) and is
* overwritten by the next call to this function from the same thread; the
* caller must not free it.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args);
/**
* @brief Run the tracker.
*
* The result of the run is collected by XGTrackerWait() or XGTrackerFree(). Calling this
* function on a tracker that has already been started is an error.
*
* @param handle The handle to the tracker.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGTrackerRun(TrackerHandle handle);
/**
* @brief Wait for the tracker to finish, should be called after XGTrackerRun().
*
* @param handle The handle to the tracker.
* @param config JSON encoded configuration. No argument is required yet, it is reserved
* for future use. It must not be NULL and is still parsed as JSON, so pass
* an empty object (`{}`) when there is nothing to configure.
*
* @return 0 for success, -1 for failure (including when the tracker finished with an
* error).
*/
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config);
/**
* @brief Free a tracker instance. The function first waits for the tracker to finish in
* the same way as XGTrackerWait(). If the tracker cannot close properly, manual
* interruption is required.
*
* @param handle The handle to the tracker, as returned by XGTrackerCreate().
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGTrackerFree(TrackerHandle handle);
/*!
* \brief Initialize the collective communicator.
*

119
src/c_api/coll_c_api.cc Normal file
View File

@ -0,0 +1,119 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include <chrono> // for seconds
#include <cstddef> // for size_t
#include <future> // for future
#include <memory> // for unique_ptr
#include <string> // for string
#include <type_traits> // for is_same_v, remove_pointer_t
#include <utility> // for pair
#include "../collective/tracker.h" // for RabitTracker
#include "c_api_error.h" // for API_BEGIN
#include "xgboost/c_api.h"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/json.h" // for Json
#include "xgboost/string_view.h" // for StringView
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_tracker.h" // for FederatedTracker
#else
#include "../common/error_msg.h" // for NoFederated
#endif
using namespace xgboost; // NOLINT
namespace {
// The object behind the opaque `TrackerHandle`: the tracker itself paired with the future
// that holds the result of its run. The future is default constructed (invalid) by
// XGTrackerCreate() and assigned by XGTrackerRun().
using TrackerHandleT =
std::pair<std::unique_ptr<collective::Tracker>, std::shared_future<collective::Result>>;
// Convert the opaque C handle back into the internal handle type.
//
// A null handle is reported through the argument-check macro; callers invoke this
// helper inside the API_BEGIN()/API_END() section of the C API functions.
TrackerHandleT *GetTrackerHandle(TrackerHandle handle) {
xgboost_CHECK_C_ARG_PTR(handle);
auto *ptr = static_cast<TrackerHandleT *>(handle);
// Defensive re-check of the converted pointer.
CHECK(ptr);
return ptr;
}
// Storage for values returned through the C API. Strings handed back to the caller as
// `char const*` are kept here so that the pointer stays valid after the API call returns.
struct CollAPIEntry {
// Buffer backing the string returned by XGTrackerWorkerArgs().
std::string ret_str;
};
// Accessor for the CollAPIEntry instance, provided by dmlc's thread local store.
using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
/**
 * @brief Block until the tracker behind `ptr` has finished and check its result.
 *
 * Returns immediately if the tracker has never been started (the future is invalid).
 * Fails (through CHECK) if the tracker finished with an error.
 *
 * @param ptr Non-null internal tracker handle.
 */
void WaitImpl(TrackerHandleT *ptr) {
  std::chrono::seconds wait_for{100};
  // Take a private copy of the shared future and only use that copy from here on. The
  // wait below can last arbitrarily long, and another thread may release the handle in
  // the meantime (e.g. XGTrackerFree() running concurrently with XGTrackerWait()), so
  // `ptr` must not be dereferenced again once we start waiting.
  auto fut = ptr->second;
  while (fut.valid()) {
    auto res = fut.wait_for(wait_for);
    CHECK(res != std::future_status::deferred);
    if (res == std::future_status::ready) {
      // `fut` keeps the shared state alive, so the reference remains valid.
      auto const &rc = fut.get();
      CHECK(rc.OK()) << rc.Report();
      break;
    }
    // Timed out: keep polling until the tracker finishes.
  }
}
} // namespace
/**
 * Create a tracker from a JSON configuration and hand it to the caller through `handle`.
 *
 * The `dmlc_communicator` entry selects the tracker type (`rabit` or `federated`); the
 * whole configuration is forwarded to the tracker's constructor.
 */
XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle) {
  API_BEGIN();
  // Validate both pointers before acquiring any resource. Checking `handle` only after
  // the allocation would leak the freshly created tracker when the caller passes NULL.
  xgboost_CHECK_C_ARG_PTR(config);
  xgboost_CHECK_C_ARG_PTR(handle);

  Json jconfig = Json::Load(config);
  auto type = RequiredArg<String>(jconfig, "dmlc_communicator", __func__);
  std::unique_ptr<collective::Tracker> tptr;
  if (type == "federated") {
#if defined(XGBOOST_USE_FEDERATED)
    tptr = std::make_unique<collective::FederatedTracker>(jconfig);
#else
    LOG(FATAL) << error::NoFederated();
#endif  // defined(XGBOOST_USE_FEDERATED)
  } else if (type == "rabit") {
    tptr = std::make_unique<collective::RabitTracker>(jconfig);
  } else {
    LOG(FATAL) << "Unknown communicator:" << type;
  }

  // The future stays invalid until XGTrackerRun() assigns the result of Tracker::Run().
  auto ptr = new TrackerHandleT{std::move(tptr), std::future<collective::Result>{}};
  static_assert(std::is_same_v<std::remove_pointer_t<decltype(ptr)>, TrackerHandleT>);
  // Ownership is transferred to the caller, who releases it with XGTrackerFree().
  *handle = ptr;
  API_END();
}
/**
 * Serialize the worker arguments of the tracker to JSON and return them through `args`.
 *
 * The string is stored in the thread local C API entry, hence the returned pointer is
 * only valid until that buffer is written again.
 */
XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args) {
  API_BEGIN();
  TrackerHandleT *tracker = GetTrackerHandle(handle);
  std::string &buffer = CollAPIThreadLocalStore::Get()->ret_str;
  buffer = Json::Dump(tracker->first->WorkerArgs());

  xgboost_CHECK_C_ARG_PTR(args);
  *args = buffer.c_str();
  API_END();
}
/**
 * Start the tracker and remember the future holding its result, so that
 * XGTrackerWait()/XGTrackerFree() can collect it later.
 */
XGB_DLL int XGTrackerRun(TrackerHandle handle) {
  API_BEGIN();
  auto *ptr = GetTrackerHandle(handle);
  // A valid future means Run() has been called on this handle before.
  CHECK(!ptr->second.valid()) << "Tracker is already running.";

  auto &[tracker, result] = *ptr;
  result = tracker->Run();
  API_END();
}
/**
 * Wait for the tracker to finish.
 *
 * No configuration entry is consumed at the moment, but the document is still parsed
 * so that a malformed `config` is reported to the caller.
 */
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
  API_BEGIN();
  TrackerHandleT *tracker = GetTrackerHandle(handle);

  xgboost_CHECK_C_ARG_PTR(config);
  [[maybe_unused]] auto const parsed = Json::Load(StringView{config});

  WaitImpl(tracker);
  API_END();
}
/**
 * Wait for the tracker to finish, then release the handle.
 *
 * The handle is released even when waiting reports an error. The result of a finished
 * tracker never changes, so retrying could not succeed and the handle would otherwise be
 * leaked for good. The caller must not use `handle` after this call, whatever it returns
 * (a null handle is rejected before anything is released).
 */
XGB_DLL int XGTrackerFree(TrackerHandle handle) {
  API_BEGIN();
  auto *ptr = GetTrackerHandle(handle);
  // Take ownership before waiting: WaitImpl() throws if the tracker finished with an
  // error, and a plain `delete` after it would then be skipped.
  std::unique_ptr<TrackerHandleT> guard{ptr};
  WaitImpl(ptr);
  API_END();
}

View File

@ -114,6 +114,9 @@ class RabitTracker : public Tracker {
// record for how to reach out to workers if error happens.
std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_;
// listening socket for incoming workers.
//
// At the moment, the listener calls accept without first polling. We can add an
// additional unix domain socket to allow cancelling the accept.
TCPSocket listener_;
Result Bootstrap(std::vector<WorkerProxy>* p_workers);

View File

@ -97,5 +97,7 @@ constexpr StringView InvalidCUDAOrdinal() {
}
void MismatchedDevices(Context const* booster, Context const* data);
/**
 * @brief Error message used when a federated feature is requested from a build without
 *        federated learning support.
 */
inline char const* NoFederated() {
  return "XGBoost is not compiled with federated learning support.";
}
} // namespace xgboost::error
#endif // XGBOOST_COMMON_ERROR_MSG_H_

View File

@ -0,0 +1,63 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/c_api.h>
#include <chrono> // for ""s
#include <thread> // for thread
#include "../../../src/collective/tracker.h"
#include "test_worker.h" // for SocketTest
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
namespace {
// Fixture for the tracker C API tests, reuses SocketTest from test_worker.h.
class TrackerAPITest : public SocketTest {};
} // namespace
// Exercise the full life cycle of a rabit tracker through the C API: create, run, query
// the worker arguments, connect two workers, wait and free.
TEST_F(TrackerAPITest, CAPI) {
  TrackerHandle handle{nullptr};
  Json config{Object{}};
  config["dmlc_communicator"] = String{"rabit"};
  config["n_workers"] = 2;
  config["timeout"] = 1;
  auto config_str = Json::Dump(config);
  auto rc = XGTrackerCreate(config_str.c_str(), &handle);
  ASSERT_EQ(rc, 0);
  rc = XGTrackerRun(handle);
  ASSERT_EQ(rc, 0);

  // Wait in the background; the call blocks until the tracker has finished.
  std::thread bg_wait{[&] {
    Json wait_config{Object{}};
    auto wait_config_str = Json::Dump(wait_config);
    auto wait_rc = XGTrackerWait(handle, wait_config_str.c_str());
    ASSERT_EQ(wait_rc, 0);
  }};

  char const* cargs{nullptr};
  rc = XGTrackerWorkerArgs(handle, &cargs);
  ASSERT_EQ(rc, 0);
  auto args = Json::Load(StringView{cargs});

  std::string host;
  ASSERT_TRUE(GetHostAddress(&host).OK());
  ASSERT_EQ(host, get<String const>(args["DMLC_TRACKER_URI"]));
  auto port = get<Integer const>(args["DMLC_TRACKER_PORT"]);
  ASSERT_NE(port, 0);

  std::vector<std::thread> workers;
  using namespace std::chrono_literals;  // NOLINT
  for (std::int32_t r = 0; r < 2; ++r) {
    workers.emplace_back([=] { WorkerForTest w{host, static_cast<std::int32_t>(port), 1s, 2, r}; });
  }
  for (auto& w : workers) {
    w.join();
  }

  // Join the waiting thread before releasing the handle, otherwise the background
  // thread might still be using the handle while it is being deleted.
  bg_wait.join();

  rc = XGTrackerFree(handle);
  ASSERT_EQ(rc, 0);
}
} // namespace xgboost::collective