merge latest changes
This commit is contained in:
@@ -1508,6 +1508,83 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config,
|
||||
* @{
|
||||
*/
|
||||
|
||||
/**
|
||||
* @brief Handle to tracker.
|
||||
*
|
||||
* There are currently two types of tracker in XGBoost, first one is `rabit`, while the
|
||||
* other one is `federated`.
|
||||
*
|
||||
* This is still under development.
|
||||
*/
|
||||
typedef void *TrackerHandle; /* NOLINT */
|
||||
|
||||
/**
|
||||
* @brief Create a new tracker.
|
||||
*
|
||||
* @param config JSON encoded parameters.
|
||||
*
|
||||
* - dmlc_communicator: String, the type of tracker to create. Available options are `rabit`
|
||||
* and `federated`.
|
||||
* - n_workers: Integer, the number of workers.
|
||||
* - port: (Optional) Integer, the port this tracker should listen to.
|
||||
* - timeout: (Optional) Integer, timeout in seconds for various networking operations.
|
||||
*
|
||||
* Some configurations are `rabit` specific:
|
||||
* - host: (Optional) String, Used by the the `rabit` tracker to specify the address of the host.
|
||||
*
|
||||
* Some `federated` specific configurations:
|
||||
* - federated_secure: Boolean, whether this is a secure server.
|
||||
* - server_key_path: Path to the server key. Used only if this is a secure server.
|
||||
* - server_cert_path: Path to the server certificate. Used only if this is a secure server.
|
||||
* - client_cert_path: Path to the client certificate. Used only if this is a secure server.
|
||||
*
|
||||
* @param handle The handle to the created tracker.
|
||||
*
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
|
||||
|
||||
/**
|
||||
* @brief Get the arguments needed for running workers. This should be called after
|
||||
* XGTrackerRun() and XGTrackerWait()
|
||||
*
|
||||
* @param handle The handle to the tracker.
|
||||
* @param args The arguments returned as a JSON document.
|
||||
*
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args);
|
||||
|
||||
/**
|
||||
* @brief Run the tracker.
|
||||
*
|
||||
* @param handle The handle to the tracker.
|
||||
*
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGTrackerRun(TrackerHandle handle);
|
||||
|
||||
/**
|
||||
* @brief Wait for the tracker to finish, should be called after XGTrackerRun().
|
||||
*
|
||||
* @param handle The handle to the tracker.
|
||||
* @param config JSON encoded configuration. No argument is required yet, preserved for
|
||||
* the future.
|
||||
*
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config);
|
||||
|
||||
/**
|
||||
* @brief Free a tracker instance. XGTrackerWait() is called internally. If the tracker
|
||||
* cannot close properly, manual interruption is required.
|
||||
*
|
||||
* @param handle The handle to the tracker.
|
||||
*
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGTrackerFree(TrackerHandle handle);
|
||||
|
||||
/*!
|
||||
* \brief Initialize the collective communicator.
|
||||
*
|
||||
@@ -1536,6 +1613,8 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config,
|
||||
* - DMLC_TRACKER_PORT: Port number of the tracker.
|
||||
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
|
||||
* - dmlc_nccl_path: The path to NCCL shared object. Only used if XGBoost is compiled with
|
||||
* `USE_DLOPEN_NCCL`.
|
||||
* Only applicable to the Federated communicator (use upper case for environment variables, use
|
||||
* lower case for runtime configuration):
|
||||
* - federated_server_address: Address of the federated server.
|
||||
|
||||
@@ -412,19 +412,24 @@ class TCPSocket {
|
||||
return Success();
|
||||
}
|
||||
|
||||
void SetKeepAlive() {
|
||||
[[nodiscard]] Result SetKeepAlive() {
|
||||
std::int32_t keepalive = 1;
|
||||
xgboost_CHECK_SYS_CALL(setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE,
|
||||
reinterpret_cast<char *>(&keepalive), sizeof(keepalive)),
|
||||
0);
|
||||
auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char *>(&keepalive),
|
||||
sizeof(keepalive));
|
||||
if (rc != 0) {
|
||||
return system::FailWithCode("Failed to set TCP keeaplive.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
void SetNoDelay() {
|
||||
[[nodiscard]] Result SetNoDelay() {
|
||||
std::int32_t tcp_no_delay = 1;
|
||||
xgboost_CHECK_SYS_CALL(
|
||||
setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
|
||||
sizeof(tcp_no_delay)),
|
||||
0);
|
||||
auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
|
||||
sizeof(tcp_no_delay));
|
||||
if (rc != 0) {
|
||||
return system::FailWithCode("Failed to set TCP no delay.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -250,9 +250,15 @@ struct Context : public XGBoostParameter<Context> {
|
||||
default:
|
||||
// Do not use the device name as this is likely an internal error, the name
|
||||
// wouldn't be valid.
|
||||
LOG(FATAL) << "Unknown device type:"
|
||||
<< static_cast<std::underlying_type_t<DeviceOrd::Type>>(this->Device().device);
|
||||
break;
|
||||
if (this->Device().IsSycl()) {
|
||||
LOG(WARNING) << "The requested feature doesn't have SYCL specific implementation yet. "
|
||||
<< "CPU implementation is used";
|
||||
return cpu_fn();
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown device type:"
|
||||
<< static_cast<std::underlying_type_t<DeviceOrd::Type>>(this->Device().device);
|
||||
break;
|
||||
}
|
||||
}
|
||||
return std::invoke_result_t<CPUFn>();
|
||||
}
|
||||
@@ -262,7 +268,6 @@ struct Context : public XGBoostParameter<Context> {
|
||||
*/
|
||||
template <typename CPUFn, typename CUDAFn, typename SYCLFn>
|
||||
decltype(auto) DispatchDevice(CPUFn&& cpu_fn, CUDAFn&& cuda_fn, SYCLFn&& sycl_fn) const {
|
||||
static_assert(std::is_same_v<std::invoke_result_t<CPUFn>, std::invoke_result_t<CUDAFn>>);
|
||||
static_assert(std::is_same_v<std::invoke_result_t<CPUFn>, std::invoke_result_t<SYCLFn>>);
|
||||
if (this->Device().IsSycl()) {
|
||||
return sycl_fn();
|
||||
|
||||
@@ -178,7 +178,7 @@ class MetaInfo {
|
||||
* in vertical federated learning, since each worker loads its own list of columns,
|
||||
* we need to sum them.
|
||||
*/
|
||||
void SynchronizeNumberOfColumns();
|
||||
void SynchronizeNumberOfColumns(Context const* ctx);
|
||||
|
||||
/*! \brief Whether the data is split row-wise. */
|
||||
bool IsRowSplit() const {
|
||||
|
||||
@@ -582,20 +582,20 @@ auto MakeTensorView(Context const *ctx, Container &data, S &&...shape) { // NOL
|
||||
return TensorView<T, sizeof...(S)>{data, in_shape, ctx->Device()};
|
||||
}
|
||||
|
||||
template <typename T, typename... S>
|
||||
LINALG_HD auto MakeTensorView(DeviceOrd device, common::Span<T> data, S &&...shape) {
|
||||
template <typename T, decltype(common::dynamic_extent) ext, typename... S>
|
||||
LINALG_HD auto MakeTensorView(DeviceOrd device, common::Span<T, ext> data, S &&...shape) {
|
||||
std::size_t in_shape[sizeof...(S)];
|
||||
detail::IndexToArr(in_shape, std::forward<S>(shape)...);
|
||||
return TensorView<T, sizeof...(S)>{data, in_shape, device};
|
||||
}
|
||||
|
||||
template <typename T, typename... S>
|
||||
auto MakeTensorView(Context const *ctx, common::Span<T> data, S &&...shape) {
|
||||
template <typename T, decltype(common::dynamic_extent) ext, typename... S>
|
||||
auto MakeTensorView(Context const *ctx, common::Span<T, ext> data, S &&...shape) {
|
||||
return MakeTensorView(ctx->Device(), data, std::forward<S>(shape)...);
|
||||
}
|
||||
|
||||
template <typename T, typename... S>
|
||||
auto MakeTensorView(Context const *ctx, Order order, common::Span<T> data, S &&...shape) {
|
||||
template <typename T, decltype(common::dynamic_extent) ext, typename... S>
|
||||
auto MakeTensorView(Context const *ctx, Order order, common::Span<T, ext> data, S &&...shape) {
|
||||
std::size_t in_shape[sizeof...(S)];
|
||||
detail::IndexToArr(in_shape, std::forward<S>(shape)...);
|
||||
return TensorView<T, sizeof...(S)>{data, in_shape, ctx->Device(), order};
|
||||
|
||||
@@ -92,8 +92,8 @@ class Predictor {
|
||||
* \param out_predt Prediction vector to be initialized.
|
||||
* \param model Tree model used for prediction.
|
||||
*/
|
||||
void InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_float>* out_predt,
|
||||
const gbm::GBTreeModel& model) const;
|
||||
virtual void InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_float>* out_predt,
|
||||
const gbm::GBTreeModel& model) const;
|
||||
|
||||
/**
|
||||
* \brief Generate batch predictions for a given feature matrix. May use
|
||||
|
||||
@@ -1,23 +1,24 @@
|
||||
/**
|
||||
* Copyright 2021-2023 by XGBoost Contributors
|
||||
* Copyright 2021-2023, XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_STRING_VIEW_H_
|
||||
#define XGBOOST_STRING_VIEW_H_
|
||||
#include <xgboost/logging.h> // CHECK_LT
|
||||
#include <xgboost/span.h> // Span
|
||||
|
||||
#include <algorithm> // std::equal,std::min
|
||||
#include <iterator> // std::reverse_iterator
|
||||
#include <ostream> // std::ostream
|
||||
#include <string> // std::char_traits,std::string
|
||||
#include <algorithm> // for equal, min
|
||||
#include <cstddef> // for size_t
|
||||
#include <iterator> // for reverse_iterator
|
||||
#include <ostream> // for ostream
|
||||
#include <string> // for char_traits, string
|
||||
|
||||
namespace xgboost {
|
||||
struct StringView {
|
||||
private:
|
||||
using CharT = char; // unsigned char
|
||||
using CharT = char;
|
||||
using Traits = std::char_traits<CharT>;
|
||||
CharT const* str_{nullptr};
|
||||
size_t size_{0};
|
||||
std::size_t size_{0};
|
||||
|
||||
public:
|
||||
using value_type = CharT; // NOLINT
|
||||
@@ -28,40 +29,41 @@ struct StringView {
|
||||
|
||||
public:
|
||||
constexpr StringView() = default;
|
||||
constexpr StringView(CharT const* str, std::size_t size) : str_{str}, size_{size} {}
|
||||
constexpr StringView(value_type const* str, std::size_t size) : str_{str}, size_{size} {}
|
||||
StringView(std::string const& str) : str_{str.c_str()}, size_{str.size()} {} // NOLINT
|
||||
constexpr StringView(CharT const* str) // NOLINT
|
||||
constexpr StringView(value_type const* str) // NOLINT
|
||||
: str_{str}, size_{str == nullptr ? 0ul : Traits::length(str)} {}
|
||||
|
||||
CharT const& operator[](size_t p) const { return str_[p]; }
|
||||
CharT const& at(size_t p) const { // NOLINT
|
||||
[[nodiscard]] value_type const& operator[](std::size_t p) const { return str_[p]; }
|
||||
[[nodiscard]] explicit operator std::string() const { return {this->c_str(), this->size()}; }
|
||||
[[nodiscard]] value_type const& at(std::size_t p) const { // NOLINT
|
||||
CHECK_LT(p, size_);
|
||||
return str_[p];
|
||||
}
|
||||
constexpr std::size_t size() const { return size_; } // NOLINT
|
||||
constexpr bool empty() const { return size() == 0; } // NOLINT
|
||||
StringView substr(size_t beg, size_t n) const { // NOLINT
|
||||
[[nodiscard]] constexpr std::size_t size() const { return size_; } // NOLINT
|
||||
[[nodiscard]] constexpr bool empty() const { return size() == 0; } // NOLINT
|
||||
[[nodiscard]] StringView substr(std::size_t beg, std::size_t n) const { // NOLINT
|
||||
CHECK_LE(beg, size_);
|
||||
size_t len = std::min(n, size_ - beg);
|
||||
std::size_t len = std::min(n, size_ - beg);
|
||||
return {str_ + beg, len};
|
||||
}
|
||||
CharT const* c_str() const { return str_; } // NOLINT
|
||||
[[nodiscard]] value_type const* c_str() const { return str_; } // NOLINT
|
||||
|
||||
constexpr CharT const* cbegin() const { return str_; } // NOLINT
|
||||
constexpr CharT const* cend() const { return str_ + size(); } // NOLINT
|
||||
constexpr CharT const* begin() const { return str_; } // NOLINT
|
||||
constexpr CharT const* end() const { return str_ + size(); } // NOLINT
|
||||
[[nodiscard]] constexpr const_iterator cbegin() const { return str_; } // NOLINT
|
||||
[[nodiscard]] constexpr const_iterator cend() const { return str_ + size(); } // NOLINT
|
||||
[[nodiscard]] constexpr iterator begin() const { return str_; } // NOLINT
|
||||
[[nodiscard]] constexpr iterator end() const { return str_ + size(); } // NOLINT
|
||||
|
||||
const_reverse_iterator rbegin() const noexcept { // NOLINT
|
||||
[[nodiscard]] const_reverse_iterator rbegin() const noexcept { // NOLINT
|
||||
return const_reverse_iterator(this->end());
|
||||
}
|
||||
const_reverse_iterator crbegin() const noexcept { // NOLINT
|
||||
[[nodiscard]] const_reverse_iterator crbegin() const noexcept { // NOLINT
|
||||
return const_reverse_iterator(this->end());
|
||||
}
|
||||
const_reverse_iterator rend() const noexcept { // NOLINT
|
||||
[[nodiscard]] const_reverse_iterator rend() const noexcept { // NOLINT
|
||||
return const_reverse_iterator(this->begin());
|
||||
}
|
||||
const_reverse_iterator crend() const noexcept { // NOLINT
|
||||
[[nodiscard]] const_reverse_iterator crend() const noexcept { // NOLINT
|
||||
return const_reverse_iterator(this->begin());
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user