[coll] Add C API for the tracker. (#9773)
This commit is contained in:
parent
06bdc15e9b
commit
44099f585d
@ -1508,6 +1508,83 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config,
|
|||||||
* @{
|
* @{
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Handle to tracker.
|
||||||
|
*
|
||||||
|
* There are currently two types of tracker in XGBoost, first one is `rabit`, while the
|
||||||
|
* other one is `federated`.
|
||||||
|
*
|
||||||
|
* This is still under development.
|
||||||
|
*/
|
||||||
|
typedef void *TrackerHandle; /* NOLINT */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Create a new tracker.
|
||||||
|
*
|
||||||
|
* @param config JSON encoded parameters.
|
||||||
|
*
|
||||||
|
* - dmlc_communicator: String, the type of tracker to create. Available options are `rabit`
|
||||||
|
* and `federated`.
|
||||||
|
* - n_workers: Integer, the number of workers.
|
||||||
|
* - port: (Optional) Integer, the port this tracker should listen to.
|
||||||
|
* - timeout: (Optional) Integer, timeout in seconds for various networking operations.
|
||||||
|
*
|
||||||
|
* Some configurations are `rabit` specific:
|
||||||
|
* - host: (Optional) String, used by the `rabit` tracker to specify the address of the host.
|
||||||
|
*
|
||||||
|
* Some `federated` specific configurations:
|
||||||
|
* - federated_secure: Boolean, whether this is a secure server.
|
||||||
|
* - server_key_path: Path to the server key. Used only if this is a secure server.
|
||||||
|
* - server_cert_path: Path to the server certificate. Used only if this is a secure server.
|
||||||
|
* - client_cert_path: Path to the client certificate. Used only if this is a secure server.
|
||||||
|
*
|
||||||
|
* @param handle The handle to the created tracker.
|
||||||
|
*
|
||||||
|
* @return 0 for success, -1 for failure.
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Get the arguments needed for running workers. This should be called after
|
||||||
|
* XGTrackerRun().
|
||||||
|
*
|
||||||
|
* @param handle The handle to the tracker.
|
||||||
|
* @param args The arguments returned as a JSON document.
|
||||||
|
*
|
||||||
|
* @return 0 for success, -1 for failure.
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Run the tracker.
|
||||||
|
*
|
||||||
|
* @param handle The handle to the tracker.
|
||||||
|
*
|
||||||
|
* @return 0 for success, -1 for failure.
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGTrackerRun(TrackerHandle handle);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Wait for the tracker to finish, should be called after XGTrackerRun().
|
||||||
|
*
|
||||||
|
* @param handle The handle to the tracker.
|
||||||
|
* @param config JSON encoded configuration. No argument is required yet, preserved for
|
||||||
|
* the future.
|
||||||
|
*
|
||||||
|
* @return 0 for success, -1 for failure.
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Free a tracker instance. XGTrackerWait() is called internally. If the tracker
|
||||||
|
* cannot close properly, manual interruption is required.
|
||||||
|
*
|
||||||
|
* @param handle The handle to the tracker.
|
||||||
|
*
|
||||||
|
* @return 0 for success, -1 for failure.
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGTrackerFree(TrackerHandle handle);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Initialize the collective communicator.
|
* \brief Initialize the collective communicator.
|
||||||
*
|
*
|
||||||
|
|||||||
119
src/c_api/coll_c_api.cc
Normal file
119
src/c_api/coll_c_api.cc
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include <chrono> // for seconds
|
||||||
|
#include <cstddef> // for size_t
|
||||||
|
#include <future> // for future
|
||||||
|
#include <memory> // for unique_ptr
|
||||||
|
#include <string> // for string
|
||||||
|
#include <type_traits> // for is_same_v, remove_pointer_t
|
||||||
|
#include <utility> // for pair
|
||||||
|
|
||||||
|
#include "../collective/tracker.h" // for RabitTracker
|
||||||
|
#include "c_api_error.h" // for API_BEGIN
|
||||||
|
#include "xgboost/c_api.h"
|
||||||
|
#include "xgboost/collective/result.h" // for Result
|
||||||
|
#include "xgboost/json.h" // for Json
|
||||||
|
#include "xgboost/string_view.h" // for StringView
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_FEDERATED)
|
||||||
|
#include "../../plugin/federated/federated_tracker.h" // for FederatedTracker
|
||||||
|
#else
|
||||||
|
#include "../common/error_msg.h" // for NoFederated
|
||||||
|
#endif
|
||||||
|
|
||||||
|
using namespace xgboost; // NOLINT
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
using TrackerHandleT =
|
||||||
|
std::pair<std::unique_ptr<collective::Tracker>, std::shared_future<collective::Result>>;
|
||||||
|
|
||||||
|
TrackerHandleT *GetTrackerHandle(TrackerHandle handle) {
|
||||||
|
xgboost_CHECK_C_ARG_PTR(handle);
|
||||||
|
auto *ptr = static_cast<TrackerHandleT *>(handle);
|
||||||
|
CHECK(ptr);
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CollAPIEntry {
|
||||||
|
std::string ret_str;
|
||||||
|
};
|
||||||
|
using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
|
||||||
|
|
||||||
|
/**
 * Block until the tracker task stored in the handle has finished, then check its result.
 *
 * Returns immediately when the stored future is invalid (no task has been assigned to it).
 * A failed result is raised through `CHECK` together with the result's report.
 *
 * @param ptr Non-null tracker handle.
 */
void WaitImpl(TrackerHandleT *ptr) {
  std::chrono::seconds wait_for{100};
  // Take a private copy of the shared future and use only that copy from here on. The
  // handle may be waited on (and freed) from another thread while this call is blocked,
  // so the handle itself must not be dereferenced again after the wait starts.
  auto fut = ptr->second;
  while (fut.valid()) {
    // Wake up periodically instead of blocking forever in a single call.
    auto res = fut.wait_for(wait_for);
    CHECK(res != std::future_status::deferred);
    if (res == std::future_status::ready) {
      auto const &rc = fut.get();
      CHECK(rc.OK()) << rc.Report();
      break;
    }
  }
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
/**
 * Create a tracker from a JSON configuration and hand it to the caller through `handle`.
 *
 * The `dmlc_communicator` entry selects the implementation: `rabit` or `federated` (the
 * latter only when compiled with XGBOOST_USE_FEDERATED). Any other value is an error.
 *
 * @param config JSON encoded parameters; must not be null.
 * @param handle Output pointer that receives the new tracker handle; must not be null.
 */
XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle) {
  API_BEGIN();
  // Validate both arguments before anything is allocated, so that an invalid output
  // pointer cannot leak a freshly created tracker.
  xgboost_CHECK_C_ARG_PTR(config);
  xgboost_CHECK_C_ARG_PTR(handle);

  Json jconfig = Json::Load(config);

  auto type = RequiredArg<String>(jconfig, "dmlc_communicator", __func__);
  std::unique_ptr<collective::Tracker> tptr;
  if (type == "federated") {
#if defined(XGBOOST_USE_FEDERATED)
    tptr = std::make_unique<collective::FederatedTracker>(jconfig);
#else
    LOG(FATAL) << error::NoFederated();
#endif  // defined(XGBOOST_USE_FEDERATED)
  } else if (type == "rabit") {
    tptr = std::make_unique<collective::RabitTracker>(jconfig);
  } else {
    LOG(FATAL) << "Unknown communicator:" << type;
  }

  // The future stays invalid until XGTrackerRun() assigns the running task to it.
  auto ptr = std::make_unique<TrackerHandleT>(std::move(tptr),
                                              std::shared_future<collective::Result>{});
  // Guard the type that GetTrackerHandle() casts the opaque handle back to.
  static_assert(std::is_same_v<decltype(ptr)::element_type, TrackerHandleT>);

  // Ownership passes to the caller, who releases it with XGTrackerFree().
  *handle = ptr.release();
  API_END();
}
|
||||||
|
|
||||||
|
/**
 * Serialize the tracker's worker arguments to JSON and return them as a C string.
 *
 * @param handle The handle to the tracker.
 * @param args   Output pointer that receives the JSON document; must not be null.
 */
XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args) {
  API_BEGIN();
  TrackerHandleT *tracker = GetTrackerHandle(handle);
  // The text is kept in the API's store entry rather than a local, so the pointer handed
  // out below stays usable after this function returns (presumably until the next call
  // that overwrites the same entry).
  std::string &buffer = CollAPIThreadLocalStore::Get()->ret_str;
  buffer = Json::Dump(tracker->first->WorkerArgs());
  xgboost_CHECK_C_ARG_PTR(args);
  *args = buffer.c_str();
  API_END();
}
|
||||||
|
|
||||||
|
/**
 * Start the tracker and remember the future for its result inside the handle.
 *
 * Fails if the handle already holds a valid future, i.e. the tracker was started before.
 *
 * @param handle The handle to the tracker.
 */
XGB_DLL int XGTrackerRun(TrackerHandle handle) {
  API_BEGIN();
  auto &[tracker, task] = *GetTrackerHandle(handle);
  CHECK(!task.valid()) << "Tracker is already running.";
  task = tracker->Run();
  API_END();
}
|
||||||
|
|
||||||
|
/**
 * Wait for the tracker task to finish.
 *
 * @param handle The handle to the tracker.
 * @param config JSON encoded configuration; must not be null. It is parsed, but no entry
 *               of it is read by this function.
 */
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
  API_BEGIN();
  TrackerHandleT *tracker = GetTrackerHandle(handle);
  xgboost_CHECK_C_ARG_PTR(config);
  // Parsing still rejects malformed JSON even though the document is otherwise unused.
  [[maybe_unused]] auto const parsed = Json::Load(StringView{config});
  WaitImpl(tracker);
  API_END();
}
|
||||||
|
|
||||||
|
/**
 * Wait for the tracker to finish and then destroy the handle.
 *
 * @param handle The handle to the tracker.
 */
XGB_DLL int XGTrackerFree(TrackerHandle handle) {
  API_BEGIN();
  TrackerHandleT *tracker = GetTrackerHandle(handle);
  // The handle is destroyed only once the wait below has returned.
  WaitImpl(tracker);
  delete tracker;
  API_END();
}
|
||||||
@ -114,6 +114,9 @@ class RabitTracker : public Tracker {
|
|||||||
// record for how to reach out to workers if error happens.
|
// record for how to reach out to workers if error happens.
|
||||||
std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_;
|
std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_;
|
||||||
// listening socket for incoming workers.
|
// listening socket for incoming workers.
|
||||||
|
//
|
||||||
|
// At the moment, the listener calls accept without first polling. We can add an
|
||||||
|
// additional unix domain socket to allow cancelling the accept.
|
||||||
TCPSocket listener_;
|
TCPSocket listener_;
|
||||||
|
|
||||||
Result Bootstrap(std::vector<WorkerProxy>* p_workers);
|
Result Bootstrap(std::vector<WorkerProxy>* p_workers);
|
||||||
|
|||||||
@ -97,5 +97,7 @@ constexpr StringView InvalidCUDAOrdinal() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void MismatchedDevices(Context const* booster, Context const* data);
|
void MismatchedDevices(Context const* booster, Context const* data);
|
||||||
|
|
||||||
|
/** Error message used when federated learning is requested from a build without it. */
inline char const* NoFederated() {
  return "XGBoost is not compiled with federated learning support.";
}
|
||||||
} // namespace xgboost::error
|
} // namespace xgboost::error
|
||||||
#endif // XGBOOST_COMMON_ERROR_MSG_H_
|
#endif // XGBOOST_COMMON_ERROR_MSG_H_
|
||||||
|
|||||||
63
tests/cpp/collective/test_coll_c_api.cc
Normal file
63
tests/cpp/collective/test_coll_c_api.cc
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <xgboost/c_api.h>
|
||||||
|
|
||||||
|
#include <chrono> // for ""s
|
||||||
|
#include <thread> // for thread
|
||||||
|
|
||||||
|
#include "../../../src/collective/tracker.h"
|
||||||
|
#include "test_worker.h" // for SocketTest
|
||||||
|
#include "xgboost/json.h" // for Json
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
namespace {
|
||||||
|
class TrackerAPITest : public SocketTest {};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// End-to-end check of the tracker C API: create and run a rabit tracker, wait on it from
// a background thread, connect two workers using the arguments it reports, then free it.
TEST_F(TrackerAPITest, CAPI) {
  TrackerHandle handle{nullptr};
  Json config{Object{}};
  config["dmlc_communicator"] = String{"rabit"};
  config["n_workers"] = 2;
  config["timeout"] = 1;
  auto config_str = Json::Dump(config);
  auto rc = XGTrackerCreate(config_str.c_str(), &handle);
  ASSERT_EQ(rc, 0);
  rc = XGTrackerRun(handle);
  ASSERT_EQ(rc, 0);

  // Wait on the tracker concurrently with the workers started below.
  std::thread bg_wait{[&] {
    Json wait_config{Object{}};
    auto wait_config_str = Json::Dump(wait_config);
    auto wait_rc = XGTrackerWait(handle, wait_config_str.c_str());
    ASSERT_EQ(wait_rc, 0);
  }};

  char const* cargs{nullptr};
  rc = XGTrackerWorkerArgs(handle, &cargs);
  ASSERT_EQ(rc, 0);
  auto args = Json::Load(StringView{cargs});

  std::string host;
  ASSERT_TRUE(GetHostAddress(&host).OK());
  ASSERT_EQ(host, get<String const>(args["DMLC_TRACKER_URI"]));
  auto port = get<Integer const>(args["DMLC_TRACKER_PORT"]);
  ASSERT_NE(port, 0);

  std::vector<std::thread> workers;
  using namespace std::chrono_literals;  // NOLINT
  for (std::int32_t r = 0; r < 2; ++r) {
    workers.emplace_back([=] { WorkerForTest w{host, static_cast<std::int32_t>(port), 1s, 2, r}; });
  }
  for (auto& w : workers) {
    w.join();
  }

  // Join the background waiter before freeing: it still uses the handle, and
  // XGTrackerFree() destroys it.
  bg_wait.join();

  rc = XGTrackerFree(handle);
  ASSERT_EQ(rc, 0);
}
|
||||||
|
} // namespace xgboost::collective
|
||||||
Loading…
x
Reference in New Issue
Block a user