merge latest change from upstream
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2015~2023 by XGBoost Contributors
|
||||
* Copyright 2015-2024, XGBoost Contributors
|
||||
* \file c_api.h
|
||||
* \author Tianqi Chen
|
||||
* \brief C API of XGBoost, used for interfacing to other languages.
|
||||
@@ -639,21 +639,14 @@ XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle,
|
||||
* \param len length of array
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
|
||||
const char *field,
|
||||
const float *array,
|
||||
XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const float *array,
|
||||
bst_ulong len);
|
||||
/*!
|
||||
* \brief set uint32 vector to a content in info
|
||||
* \param handle a instance of data matrix
|
||||
* \param field field name
|
||||
* \param array pointer to unsigned int vector
|
||||
* \param len length of array
|
||||
* \return 0 when success, -1 when failure happens
|
||||
/**
|
||||
* @deprecated since 2.1.0
|
||||
*
|
||||
* Use @ref XGDMatrixSetInfoFromInterface instead.
|
||||
*/
|
||||
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
|
||||
const char *field,
|
||||
const unsigned *array,
|
||||
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const unsigned *array,
|
||||
bst_ulong len);
|
||||
|
||||
/*!
|
||||
@@ -725,42 +718,13 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
|
||||
bst_ulong *size,
|
||||
const char ***out_features);
|
||||
|
||||
/*!
|
||||
* \brief Set meta info from dense matrix. Valid field names are:
|
||||
/**
|
||||
* @deprecated since 2.1.0
|
||||
*
|
||||
* - label
|
||||
* - weight
|
||||
* - base_margin
|
||||
* - group
|
||||
* - label_lower_bound
|
||||
* - label_upper_bound
|
||||
* - feature_weights
|
||||
*
|
||||
* \param handle An instance of data matrix
|
||||
* \param field Field name
|
||||
* \param data Pointer to consecutive memory storing data.
|
||||
* \param size Size of the data, this is relative to size of type. (Meaning NOT number
|
||||
* of bytes.)
|
||||
* \param type Indicator of data type. This is defined in xgboost::DataType enum class.
|
||||
* - float = 1
|
||||
* - double = 2
|
||||
* - uint32_t = 3
|
||||
* - uint64_t = 4
|
||||
* \return 0 when success, -1 when failure happens
|
||||
* Use @ref XGDMatrixSetInfoFromInterface instead.
|
||||
*/
|
||||
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
|
||||
void const *data, bst_ulong size, int type);
|
||||
|
||||
/*!
|
||||
* \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix
|
||||
* \param handle a instance of data matrix
|
||||
* \param group pointer to group size
|
||||
* \param len length of array
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
|
||||
const unsigned *group,
|
||||
bst_ulong len);
|
||||
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void const *data,
|
||||
bst_ulong size, int type);
|
||||
|
||||
/*!
|
||||
* \brief get float info vector from matrix.
|
||||
@@ -1591,7 +1555,7 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
|
||||
|
||||
/**
|
||||
* @brief Get the arguments needed for running workers. This should be called after
|
||||
* XGTrackerRun() and XGTrackerWait()
|
||||
* XGTrackerRun().
|
||||
*
|
||||
* @param handle The handle to the tracker.
|
||||
* @param args The arguments returned as a JSON document.
|
||||
@@ -1601,16 +1565,19 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
|
||||
XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args);
|
||||
|
||||
/**
|
||||
* @brief Run the tracker.
|
||||
* @brief Start the tracker. The tracker runs in the background and this function returns
|
||||
* once the tracker is started.
|
||||
*
|
||||
* @param handle The handle to the tracker.
|
||||
* @param config Unused at the moment, preserved for the future.
|
||||
*
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGTrackerRun(TrackerHandle handle);
|
||||
XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *config);
|
||||
|
||||
/**
|
||||
* @brief Wait for the tracker to finish, should be called after XGTrackerRun().
|
||||
* @brief Wait for the tracker to finish, should be called after XGTrackerRun(). This
|
||||
* function will block until the tracker task is finished or timeout is reached.
|
||||
*
|
||||
* @param handle The handle to the tracker.
|
||||
* @param config JSON encoded configuration. No argument is required yet, preserved for
|
||||
@@ -1618,11 +1585,12 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle);
|
||||
*
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config);
|
||||
XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config);
|
||||
|
||||
/**
|
||||
* @brief Free a tracker instance. XGTrackerWait() is called internally. If the tracker
|
||||
* cannot close properly, manual interruption is required.
|
||||
* @brief Free a tracker instance. This should be called after XGTrackerWaitFor(). If the
|
||||
* tracker is not properly waited, this function will shutdown all connections with
|
||||
* the tracker, potentially leading to undefined behavior.
|
||||
*
|
||||
* @param handle The handle to the tracker.
|
||||
*
|
||||
|
||||
@@ -3,13 +3,11 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <memory> // for unique_ptr
|
||||
#include <sstream> // for stringstream
|
||||
#include <stack> // for stack
|
||||
#include <string> // for string
|
||||
#include <utility> // for move
|
||||
#include <cstdint> // for int32_t
|
||||
#include <memory> // for unique_ptr
|
||||
#include <string> // for string
|
||||
#include <system_error> // for error_code
|
||||
#include <utility> // for move
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace detail {
|
||||
@@ -48,48 +46,19 @@ struct ResultImpl {
|
||||
return cur_eq;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::string Report() {
|
||||
std::stringstream ss;
|
||||
ss << "\n- " << this->message;
|
||||
if (this->errc != std::error_code{}) {
|
||||
ss << " system error:" << this->errc.message();
|
||||
}
|
||||
[[nodiscard]] std::string Report() const;
|
||||
[[nodiscard]] std::error_code Code() const;
|
||||
|
||||
auto ptr = prev.get();
|
||||
while (ptr) {
|
||||
ss << "\n- ";
|
||||
ss << ptr->message;
|
||||
|
||||
if (ptr->errc != std::error_code{}) {
|
||||
ss << " " << ptr->errc.message();
|
||||
}
|
||||
ptr = ptr->prev.get();
|
||||
}
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
[[nodiscard]] auto Code() const {
|
||||
// Find the root error.
|
||||
std::stack<ResultImpl const*> stack;
|
||||
auto ptr = this;
|
||||
while (ptr) {
|
||||
stack.push(ptr);
|
||||
if (ptr->prev) {
|
||||
ptr = ptr->prev.get();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
while (!stack.empty()) {
|
||||
auto frame = stack.top();
|
||||
stack.pop();
|
||||
if (frame->errc != std::error_code{}) {
|
||||
return frame->errc;
|
||||
}
|
||||
}
|
||||
return std::error_code{};
|
||||
}
|
||||
void Concat(std::unique_ptr<ResultImpl> rhs);
|
||||
};
|
||||
|
||||
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
|
||||
#define __builtin_FILE() nullptr
|
||||
#define __builtin_LINE() (-1)
|
||||
std::string MakeMsg(std::string&& msg, char const*, std::int32_t);
|
||||
#else
|
||||
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
|
||||
#endif
|
||||
} // namespace detail
|
||||
|
||||
/**
|
||||
@@ -131,8 +100,21 @@ struct Result {
|
||||
}
|
||||
return *impl_ == *that.impl_;
|
||||
}
|
||||
|
||||
friend Result operator+(Result&& lhs, Result&& rhs);
|
||||
};
|
||||
|
||||
[[nodiscard]] inline Result operator+(Result&& lhs, Result&& rhs) {
|
||||
if (lhs.OK()) {
|
||||
return std::forward<Result>(rhs);
|
||||
}
|
||||
if (rhs.OK()) {
|
||||
return std::forward<Result>(lhs);
|
||||
}
|
||||
lhs.impl_->Concat(std::move(rhs.impl_));
|
||||
return std::forward<Result>(lhs);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return success.
|
||||
*/
|
||||
@@ -140,38 +122,43 @@ struct Result {
|
||||
/**
|
||||
* @brief Return failure.
|
||||
*/
|
||||
[[nodiscard]] inline auto Fail(std::string msg) { return Result{std::move(msg)}; }
|
||||
[[nodiscard]] inline auto Fail(std::string msg, char const* file = __builtin_FILE(),
|
||||
std::int32_t line = __builtin_LINE()) {
|
||||
return Result{detail::MakeMsg(std::move(msg), file, line)};
|
||||
}
|
||||
/**
|
||||
* @brief Return failure with `errno`.
|
||||
*/
|
||||
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc) {
|
||||
return Result{std::move(msg), std::move(errc)};
|
||||
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc,
|
||||
char const* file = __builtin_FILE(),
|
||||
std::int32_t line = __builtin_LINE()) {
|
||||
return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc)};
|
||||
}
|
||||
/**
|
||||
* @brief Return failure with a previous error.
|
||||
*/
|
||||
[[nodiscard]] inline auto Fail(std::string msg, Result&& prev) {
|
||||
return Result{std::move(msg), std::forward<Result>(prev)};
|
||||
[[nodiscard]] inline auto Fail(std::string msg, Result&& prev, char const* file = __builtin_FILE(),
|
||||
std::int32_t line = __builtin_LINE()) {
|
||||
return Result{detail::MakeMsg(std::move(msg), file, line), std::forward<Result>(prev)};
|
||||
}
|
||||
/**
|
||||
* @brief Return failure with a previous error and a new `errno`.
|
||||
*/
|
||||
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) {
|
||||
return Result{std::move(msg), std::move(errc), std::forward<Result>(prev)};
|
||||
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev,
|
||||
char const* file = __builtin_FILE(),
|
||||
std::int32_t line = __builtin_LINE()) {
|
||||
return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc),
|
||||
std::forward<Result>(prev)};
|
||||
}
|
||||
|
||||
// We don't have monad, a simple helper would do.
|
||||
template <typename Fn>
|
||||
[[nodiscard]] Result operator<<(Result&& r, Fn&& fn) {
|
||||
[[nodiscard]] std::enable_if_t<std::is_invocable_v<Fn>, Result> operator<<(Result&& r, Fn&& fn) {
|
||||
if (!r.OK()) {
|
||||
return std::forward<Result>(r);
|
||||
}
|
||||
return fn();
|
||||
}
|
||||
|
||||
inline void SafeColl(Result const& rc) {
|
||||
if (!rc.OK()) {
|
||||
LOG(FATAL) << rc.Report();
|
||||
}
|
||||
}
|
||||
void SafeColl(Result const& rc);
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright (c) 2022-2023, XGBoost Contributors
|
||||
* Copyright (c) 2022-2024, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
#include <cstddef> // std::size_t
|
||||
#include <cstdint> // std::int32_t, std::uint16_t
|
||||
#include <cstring> // memset
|
||||
#include <limits> // std::numeric_limits
|
||||
#include <string> // std::string
|
||||
#include <system_error> // std::error_code, std::system_category
|
||||
#include <utility> // std::swap
|
||||
@@ -125,6 +124,21 @@ inline std::int32_t CloseSocket(SocketT fd) {
|
||||
#endif
|
||||
}
|
||||
|
||||
inline std::int32_t ShutdownSocket(SocketT fd) {
|
||||
#if defined(_WIN32)
|
||||
auto rc = shutdown(fd, SD_BOTH);
|
||||
if (rc != 0 && LastError() == WSANOTINITIALISED) {
|
||||
return 0;
|
||||
}
|
||||
#else
|
||||
auto rc = shutdown(fd, SHUT_RDWR);
|
||||
if (rc != 0 && LastError() == ENOTCONN) {
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
return rc;
|
||||
}
|
||||
|
||||
inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
|
||||
#ifdef _WIN32
|
||||
return errsv == WSAEWOULDBLOCK;
|
||||
@@ -468,19 +482,30 @@ class TCPSocket {
|
||||
*addr = SockAddress{SockAddrV6{caddr}};
|
||||
*out = TCPSocket{newfd};
|
||||
}
|
||||
// On MacOS, this is automatically set to async socket if the parent socket is async
|
||||
// We make sure all socket are blocking by default.
|
||||
//
|
||||
// On Windows, a closed socket is returned during shutdown. We guard against it when
|
||||
// setting non-blocking.
|
||||
if (!out->IsClosed()) {
|
||||
return out->NonBlocking(false);
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
~TCPSocket() {
|
||||
if (!IsClosed()) {
|
||||
Close();
|
||||
auto rc = this->Close();
|
||||
if (!rc.OK()) {
|
||||
LOG(WARNING) << rc.Report();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TCPSocket(TCPSocket const &that) = delete;
|
||||
TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); }
|
||||
TCPSocket &operator=(TCPSocket const &that) = delete;
|
||||
TCPSocket &operator=(TCPSocket &&that) {
|
||||
TCPSocket &operator=(TCPSocket &&that) noexcept(true) {
|
||||
std::swap(this->handle_, that.handle_);
|
||||
return *this;
|
||||
}
|
||||
@@ -489,36 +514,49 @@ class TCPSocket {
|
||||
*/
|
||||
[[nodiscard]] HandleT const &Handle() const { return handle_; }
|
||||
/**
|
||||
* \brief Listen to incoming requests. Should be called after bind.
|
||||
* @brief Listen to incoming requests. Should be called after bind.
|
||||
*/
|
||||
void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); }
|
||||
[[nodiscard]] Result Listen(std::int32_t backlog = 16) {
|
||||
if (listen(handle_, backlog) != 0) {
|
||||
return system::FailWithCode("Failed to listen.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
/**
|
||||
* \brief Bind socket to INADDR_ANY, return the port selected by the OS.
|
||||
* @brief Bind socket to INADDR_ANY, return the port selected by the OS.
|
||||
*/
|
||||
[[nodiscard]] in_port_t BindHost() {
|
||||
[[nodiscard]] Result BindHost(std::int32_t* p_out) {
|
||||
// Use int32 instead of in_port_t for consistency. We take port as parameter from
|
||||
// users using other languages, the port is usually stored and passed around as int.
|
||||
if (Domain() == SockDomain::kV6) {
|
||||
auto addr = SockAddrV6::InaddrAny();
|
||||
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
|
||||
xgboost_CHECK_SYS_CALL(
|
||||
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
|
||||
if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
|
||||
return system::FailWithCode("bind failed.");
|
||||
}
|
||||
|
||||
sockaddr_in6 res_addr;
|
||||
socklen_t addrlen = sizeof(res_addr);
|
||||
xgboost_CHECK_SYS_CALL(
|
||||
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
|
||||
return ntohs(res_addr.sin6_port);
|
||||
if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
|
||||
return system::FailWithCode("getsockname failed.");
|
||||
}
|
||||
*p_out = ntohs(res_addr.sin6_port);
|
||||
} else {
|
||||
auto addr = SockAddrV4::InaddrAny();
|
||||
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
|
||||
xgboost_CHECK_SYS_CALL(
|
||||
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
|
||||
if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
|
||||
return system::FailWithCode("bind failed.");
|
||||
}
|
||||
|
||||
sockaddr_in res_addr;
|
||||
socklen_t addrlen = sizeof(res_addr);
|
||||
xgboost_CHECK_SYS_CALL(
|
||||
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
|
||||
return ntohs(res_addr.sin_port);
|
||||
if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
|
||||
return system::FailWithCode("getsockname failed.");
|
||||
}
|
||||
*p_out = ntohs(res_addr.sin_port);
|
||||
}
|
||||
|
||||
return Success();
|
||||
}
|
||||
|
||||
[[nodiscard]] auto Port() const {
|
||||
@@ -631,26 +669,49 @@ class TCPSocket {
|
||||
*/
|
||||
std::size_t Send(StringView str);
|
||||
/**
|
||||
* \brief Receive string, format is matched with the Python socket wrapper in RABIT.
|
||||
* @brief Receive string, format is matched with the Python socket wrapper in RABIT.
|
||||
*/
|
||||
std::size_t Recv(std::string *p_str);
|
||||
[[nodiscard]] Result Recv(std::string *p_str);
|
||||
/**
|
||||
* \brief Close the socket, called automatically in destructor if the socket is not closed.
|
||||
* @brief Close the socket, called automatically in destructor if the socket is not closed.
|
||||
*/
|
||||
void Close() {
|
||||
[[nodiscard]] Result Close() {
|
||||
if (InvalidSocket() != handle_) {
|
||||
#if defined(_WIN32)
|
||||
auto rc = system::CloseSocket(handle_);
|
||||
#if defined(_WIN32)
|
||||
// it's possible that we close TCP sockets after finalizing WSA due to detached thread.
|
||||
if (rc != 0 && system::LastError() != WSANOTINITIALISED) {
|
||||
system::ThrowAtError("close", rc);
|
||||
return system::FailWithCode("Failed to close the socket.");
|
||||
}
|
||||
#else
|
||||
xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0);
|
||||
if (rc != 0) {
|
||||
return system::FailWithCode("Failed to close the socket.");
|
||||
}
|
||||
#endif
|
||||
handle_ = InvalidSocket();
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
/**
|
||||
* @brief Call shutdown on the socket.
|
||||
*/
|
||||
[[nodiscard]] Result Shutdown() {
|
||||
if (this->IsClosed()) {
|
||||
return Success();
|
||||
}
|
||||
auto rc = system::ShutdownSocket(this->Handle());
|
||||
#if defined(_WIN32)
|
||||
// Windows cannot shutdown a socket if it's not connected.
|
||||
if (rc == -1 && system::LastError() == WSAENOTCONN) {
|
||||
return Success();
|
||||
}
|
||||
#endif
|
||||
if (rc != 0) {
|
||||
return system::FailWithCode("Failed to shutdown socket.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Create a TCP socket on specified domain.
|
||||
*/
|
||||
|
||||
@@ -19,7 +19,6 @@
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@@ -137,14 +136,6 @@ class MetaInfo {
|
||||
* \param fo The output stream.
|
||||
*/
|
||||
void SaveBinary(dmlc::Stream* fo) const;
|
||||
/*!
|
||||
* \brief Set information in the meta info.
|
||||
* \param key The key of the information.
|
||||
* \param dptr The data pointer of the source array.
|
||||
* \param dtype The type of the source data.
|
||||
* \param num Number of elements in the source array.
|
||||
*/
|
||||
void SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype, size_t num);
|
||||
/*!
|
||||
* \brief Set information in the meta info with array interface.
|
||||
* \param key The key of the information.
|
||||
@@ -517,10 +508,6 @@ class DMatrix {
|
||||
DMatrix() = default;
|
||||
/*! \brief meta information of the dataset */
|
||||
virtual MetaInfo& Info() = 0;
|
||||
virtual void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
|
||||
auto const& ctx = *this->Ctx();
|
||||
this->Info().SetInfo(ctx, key, dptr, dtype, num);
|
||||
}
|
||||
virtual void SetInfo(const char* key, std::string const& interface_str) {
|
||||
auto const& ctx = *this->Ctx();
|
||||
this->Info().SetInfo(ctx, key, StringView{interface_str});
|
||||
|
||||
@@ -190,13 +190,14 @@ constexpr auto ArrToTuple(T (&arr)[N]) {
|
||||
// uint division optimization inspired by the CIndexer in cupy. Division operation is
|
||||
// slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64
|
||||
// bit when the index is smaller, then try to avoid division when it's exp of 2.
|
||||
template <typename I, int32_t D>
|
||||
template <typename I, std::int32_t D>
|
||||
LINALG_HD auto UnravelImpl(I idx, common::Span<size_t const, D> shape) {
|
||||
size_t index[D]{0};
|
||||
std::size_t index[D]{0};
|
||||
static_assert(std::is_signed<decltype(D)>::value,
|
||||
"Don't change the type without changing the for loop.");
|
||||
auto const sptr = shape.data();
|
||||
for (int32_t dim = D; --dim > 0;) {
|
||||
auto s = static_cast<std::remove_const_t<std::remove_reference_t<I>>>(shape[dim]);
|
||||
auto s = static_cast<std::remove_const_t<std::remove_reference_t<I>>>(sptr[dim]);
|
||||
if (s & (s - 1)) {
|
||||
auto t = idx / s;
|
||||
index[dim] = idx - t * s;
|
||||
@@ -745,6 +746,14 @@ auto ArrayInterfaceStr(TensorView<T, D> const &t) {
|
||||
return str;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto Make1dInterface(T const *vec, std::size_t len) {
|
||||
Context ctx;
|
||||
auto t = linalg::MakeTensorView(&ctx, common::Span{vec, len}, len);
|
||||
auto str = linalg::ArrayInterfaceStr(t);
|
||||
return str;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief A tensor storage. To use it for other functionality like slicing one needs to
|
||||
* obtain a view first. This way we can use it on both host and device.
|
||||
|
||||
@@ -30,9 +30,8 @@
|
||||
#define XGBOOST_SPAN_H_
|
||||
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <cinttypes> // size_t
|
||||
#include <cstddef> // size_t
|
||||
#include <cstdio>
|
||||
#include <iterator>
|
||||
#include <limits> // numeric_limits
|
||||
@@ -75,8 +74,7 @@
|
||||
|
||||
#endif // defined(_MSC_VER) && _MSC_VER < 1910
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
namespace xgboost::common {
|
||||
|
||||
#if defined(__CUDA_ARCH__)
|
||||
// Usual logging facility is not available inside device code.
|
||||
@@ -744,8 +742,8 @@ class IterSpan {
|
||||
return it_ + size();
|
||||
}
|
||||
};
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::common
|
||||
|
||||
|
||||
#if defined(_MSC_VER) &&_MSC_VER < 1910
|
||||
#undef constexpr
|
||||
|
||||
Reference in New Issue
Block a user