[coll] Add C API for the tracker. (#9773)
This commit is contained in:
parent
06bdc15e9b
commit
44099f585d
@ -1508,6 +1508,83 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config,
|
||||
* @{
|
||||
*/
|
||||
|
||||
/**
|
||||
* @brief Handle to tracker.
|
||||
*
|
||||
* There are currently two types of tracker in XGBoost, first one is `rabit`, while the
|
||||
* other one is `federated`.
|
||||
*
|
||||
* This is still under development.
|
||||
*/
|
||||
typedef void *TrackerHandle; /* NOLINT */
|
||||
|
||||
/**
|
||||
* @brief Create a new tracker.
|
||||
*
|
||||
* @param config JSON encoded parameters.
|
||||
*
|
||||
* - dmlc_communicator: String, the type of tracker to create. Available options are `rabit`
|
||||
* and `federated`.
|
||||
* - n_workers: Integer, the number of workers.
|
||||
* - port: (Optional) Integer, the port this tracker should listen to.
|
||||
* - timeout: (Optional) Integer, timeout in seconds for various networking operations.
|
||||
*
|
||||
* Some configurations are `rabit` specific:
|
||||
* - host: (Optional) String, Used by the the `rabit` tracker to specify the address of the host.
|
||||
*
|
||||
* Some `federated` specific configurations:
|
||||
* - federated_secure: Boolean, whether this is a secure server.
|
||||
* - server_key_path: Path to the server key. Used only if this is a secure server.
|
||||
* - server_cert_path: Path to the server certificate. Used only if this is a secure server.
|
||||
* - client_cert_path: Path to the client certificate. Used only if this is a secure server.
|
||||
*
|
||||
* @param handle The handle to the created tracker.
|
||||
*
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
|
||||
|
||||
/**
|
||||
* @brief Get the arguments needed for running workers. This should be called after
|
||||
* XGTrackerRun() and XGTrackerWait()
|
||||
*
|
||||
* @param handle The handle to the tracker.
|
||||
* @param args The arguments returned as a JSON document.
|
||||
*
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args);
|
||||
|
||||
/**
|
||||
* @brief Run the tracker.
|
||||
*
|
||||
* @param handle The handle to the tracker.
|
||||
*
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGTrackerRun(TrackerHandle handle);
|
||||
|
||||
/**
|
||||
* @brief Wait for the tracker to finish, should be called after XGTrackerRun().
|
||||
*
|
||||
* @param handle The handle to the tracker.
|
||||
* @param config JSON encoded configuration. No argument is required yet, preserved for
|
||||
* the future.
|
||||
*
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config);
|
||||
|
||||
/**
|
||||
* @brief Free a tracker instance. XGTrackerWait() is called internally. If the tracker
|
||||
* cannot close properly, manual interruption is required.
|
||||
*
|
||||
* @param handle The handle to the tracker.
|
||||
*
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGTrackerFree(TrackerHandle handle);
|
||||
|
||||
/*!
|
||||
* \brief Initialize the collective communicator.
|
||||
*
|
||||
|
||||
119
src/c_api/coll_c_api.cc
Normal file
119
src/c_api/coll_c_api.cc
Normal file
@ -0,0 +1,119 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <chrono> // for seconds
|
||||
#include <cstddef> // for size_t
|
||||
#include <future> // for future
|
||||
#include <memory> // for unique_ptr
|
||||
#include <string> // for string
|
||||
#include <type_traits> // for is_same_v, remove_pointer_t
|
||||
#include <utility> // for pair
|
||||
|
||||
#include "../collective/tracker.h" // for RabitTracker
|
||||
#include "c_api_error.h" // for API_BEGIN
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/json.h" // for Json
|
||||
#include "xgboost/string_view.h" // for StringView
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
#include "../../plugin/federated/federated_tracker.h" // for FederatedTracker
|
||||
#else
|
||||
#include "../common/error_msg.h" // for NoFederated
|
||||
#endif
|
||||
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
namespace {
|
||||
using TrackerHandleT =
|
||||
std::pair<std::unique_ptr<collective::Tracker>, std::shared_future<collective::Result>>;
|
||||
|
||||
TrackerHandleT *GetTrackerHandle(TrackerHandle handle) {
|
||||
xgboost_CHECK_C_ARG_PTR(handle);
|
||||
auto *ptr = static_cast<TrackerHandleT *>(handle);
|
||||
CHECK(ptr);
|
||||
return ptr;
|
||||
}
|
||||
|
||||
struct CollAPIEntry {
|
||||
std::string ret_str;
|
||||
};
|
||||
using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
|
||||
|
||||
void WaitImpl(TrackerHandleT *ptr) {
|
||||
std::chrono::seconds wait_for{100};
|
||||
auto fut = ptr->second;
|
||||
while (fut.valid()) {
|
||||
auto res = fut.wait_for(wait_for);
|
||||
CHECK(res != std::future_status::deferred);
|
||||
if (res == std::future_status::ready) {
|
||||
auto const &rc = ptr->second.get();
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle) {
|
||||
API_BEGIN();
|
||||
xgboost_CHECK_C_ARG_PTR(config);
|
||||
|
||||
Json jconfig = Json::Load(config);
|
||||
|
||||
auto type = RequiredArg<String>(jconfig, "dmlc_communicator", __func__);
|
||||
std::unique_ptr<collective::Tracker> tptr;
|
||||
if (type == "federated") {
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
tptr = std::make_unique<collective::FederatedTracker>(jconfig);
|
||||
#else
|
||||
LOG(FATAL) << error::NoFederated();
|
||||
#endif // defined(XGBOOST_USE_FEDERATED)
|
||||
} else if (type == "rabit") {
|
||||
tptr = std::make_unique<collective::RabitTracker>(jconfig);
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown communicator:" << type;
|
||||
}
|
||||
|
||||
auto ptr = new TrackerHandleT{std::move(tptr), std::future<collective::Result>{}};
|
||||
static_assert(std::is_same_v<std::remove_pointer_t<decltype(ptr)>, TrackerHandleT>);
|
||||
|
||||
xgboost_CHECK_C_ARG_PTR(handle);
|
||||
*handle = ptr;
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args) {
|
||||
API_BEGIN();
|
||||
auto *ptr = GetTrackerHandle(handle);
|
||||
auto &local = *CollAPIThreadLocalStore::Get();
|
||||
local.ret_str = Json::Dump(ptr->first->WorkerArgs());
|
||||
xgboost_CHECK_C_ARG_PTR(args);
|
||||
*args = local.ret_str.c_str();
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGTrackerRun(TrackerHandle handle) {
|
||||
API_BEGIN();
|
||||
auto *ptr = GetTrackerHandle(handle);
|
||||
CHECK(!ptr->second.valid()) << "Tracker is already running.";
|
||||
ptr->second = ptr->first->Run();
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
|
||||
API_BEGIN();
|
||||
auto *ptr = GetTrackerHandle(handle);
|
||||
xgboost_CHECK_C_ARG_PTR(config);
|
||||
auto jconfig = Json::Load(StringView{config});
|
||||
WaitImpl(ptr);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGTrackerFree(TrackerHandle handle) {
|
||||
API_BEGIN();
|
||||
auto *ptr = GetTrackerHandle(handle);
|
||||
WaitImpl(ptr);
|
||||
delete ptr;
|
||||
API_END();
|
||||
}
|
||||
@ -114,6 +114,9 @@ class RabitTracker : public Tracker {
|
||||
// record for how to reach out to workers if error happens.
|
||||
std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_;
|
||||
// listening socket for incoming workers.
|
||||
//
|
||||
// At the moment, the listener calls accept without first polling. We can add an
|
||||
// additional unix domain socket to allow cancelling the accept.
|
||||
TCPSocket listener_;
|
||||
|
||||
Result Bootstrap(std::vector<WorkerProxy>* p_workers);
|
||||
|
||||
@ -97,5 +97,7 @@ constexpr StringView InvalidCUDAOrdinal() {
|
||||
}
|
||||
|
||||
void MismatchedDevices(Context const* booster, Context const* data);
|
||||
|
||||
inline auto NoFederated() { return "XGBoost is not compiled with federated learning support."; }
|
||||
} // namespace xgboost::error
|
||||
#endif // XGBOOST_COMMON_ERROR_MSG_H_
|
||||
|
||||
63
tests/cpp/collective/test_coll_c_api.cc
Normal file
63
tests/cpp/collective/test_coll_c_api.cc
Normal file
@ -0,0 +1,63 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/c_api.h>
|
||||
|
||||
#include <chrono> // for ""s
|
||||
#include <thread> // for thread
|
||||
|
||||
#include "../../../src/collective/tracker.h"
|
||||
#include "test_worker.h" // for SocketTest
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class TrackerAPITest : public SocketTest {};
|
||||
} // namespace
|
||||
|
||||
TEST_F(TrackerAPITest, CAPI) {
|
||||
TrackerHandle handle;
|
||||
Json config{Object{}};
|
||||
config["dmlc_communicator"] = String{"rabit"};
|
||||
config["n_workers"] = 2;
|
||||
config["timeout"] = 1;
|
||||
auto config_str = Json::Dump(config);
|
||||
auto rc = XGTrackerCreate(config_str.c_str(), &handle);
|
||||
ASSERT_EQ(rc, 0);
|
||||
rc = XGTrackerRun(handle);
|
||||
ASSERT_EQ(rc, 0);
|
||||
|
||||
std::thread bg_wait{[&] {
|
||||
Json config{Object{}};
|
||||
auto config_str = Json::Dump(config);
|
||||
auto rc = XGTrackerWait(handle, config_str.c_str());
|
||||
ASSERT_EQ(rc, 0);
|
||||
}};
|
||||
|
||||
char const* cargs;
|
||||
rc = XGTrackerWorkerArgs(handle, &cargs);
|
||||
ASSERT_EQ(rc, 0);
|
||||
auto args = Json::Load(StringView{cargs});
|
||||
|
||||
std::string host;
|
||||
ASSERT_TRUE(GetHostAddress(&host).OK());
|
||||
ASSERT_EQ(host, get<String const>(args["DMLC_TRACKER_URI"]));
|
||||
auto port = get<Integer const>(args["DMLC_TRACKER_PORT"]);
|
||||
ASSERT_NE(port, 0);
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
using namespace std::chrono_literals; // NOLINT
|
||||
for (std::int32_t r = 0; r < 2; ++r) {
|
||||
workers.emplace_back([=] { WorkerForTest w{host, static_cast<std::int32_t>(port), 1s, 2, r}; });
|
||||
}
|
||||
for (auto& w : workers) {
|
||||
w.join();
|
||||
}
|
||||
|
||||
rc = XGTrackerFree(handle);
|
||||
ASSERT_EQ(rc, 0);
|
||||
|
||||
bg_wait.join();
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
Loading…
x
Reference in New Issue
Block a user