Work with IPv6 in the new tracker. (#10125)
parent 53fc17578f
commit ca4801f81d
@@ -436,28 +436,38 @@ class TCPSocket {
   * \brief Accept new connection, returns a new TCP socket for the new connection.
   */
  TCPSocket Accept() {
-    HandleT newfd = accept(Handle(), nullptr, nullptr);
+    SockAddress addr;
+    TCPSocket newsock;
+    auto rc = this->Accept(&newsock, &addr);
+    SafeColl(rc);
+    return newsock;
+  }
+
+  [[nodiscard]] Result Accept(TCPSocket *out, SockAddress *addr) {
 #if defined(_WIN32)
     auto interrupt = WSAEINTR;
 #else
     auto interrupt = EINTR;
 #endif
-    if (newfd == InvalidSocket() && system::LastError() != interrupt) {
-      system::ThrowAtError("accept");
+    if (this->Domain() == SockDomain::kV4) {
+      struct sockaddr_in caddr;
+      socklen_t caddr_len = sizeof(caddr);
+      HandleT newfd = accept(Handle(), reinterpret_cast<sockaddr *>(&caddr), &caddr_len);
+      if (newfd == InvalidSocket() && system::LastError() != interrupt) {
+        return system::FailWithCode("Failed to accept.");
+      }
+      *addr = SockAddress{SockAddrV4{caddr}};
+      *out = TCPSocket{newfd};
+    } else {
+      struct sockaddr_in6 caddr;
+      socklen_t caddr_len = sizeof(caddr);
+      HandleT newfd = accept(Handle(), reinterpret_cast<sockaddr *>(&caddr), &caddr_len);
+      if (newfd == InvalidSocket() && system::LastError() != interrupt) {
+        return system::FailWithCode("Failed to accept.");
+      }
+      *addr = SockAddress{SockAddrV6{caddr}};
+      *out = TCPSocket{newfd};
     }
-    TCPSocket newsock{newfd};
-    return newsock;
-  }
-
-  [[nodiscard]] Result Accept(TCPSocket *out, SockAddrV4 *addr) {
-    struct sockaddr_in caddr;
-    socklen_t caddr_len = sizeof(caddr);
-    HandleT newfd = accept(Handle(), reinterpret_cast<sockaddr *>(&caddr), &caddr_len);
-    if (newfd == InvalidSocket()) {
-      return system::FailWithCode("Failed to accept.");
-    }
-    *addr = SockAddrV4{caddr};
-    *out = TCPSocket{newfd};
     return Success();
   }
 
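Aside, for readers outside the codebase: the hunk above works because sockaddr_in6 is larger than sockaddr_in, so the accept buffer must match the socket's family. A self-contained POSIX sketch of the same idea (hypothetical helper, not XGBoost's API) can sidestep the branch with sockaddr_storage, which is guaranteed large enough for either family:

    // Minimal POSIX sketch, not XGBoost code: accept a peer whose address may
    // be IPv4 or IPv6 using one sockaddr_storage buffer, then format it as
    // text with getnameinfo (NI_NUMERICHOST avoids a DNS lookup).
    #include <netdb.h>       // for getnameinfo, NI_MAXHOST, NI_NUMERICHOST
    #include <sys/socket.h>  // for accept, sockaddr, sockaddr_storage, socklen_t
    #include <string>

    // Returns the numeric peer address ("127.0.0.1" or "::1"), or "" on error.
    std::string AcceptPeer(int listen_fd, int *out_fd) {
      sockaddr_storage peer{};
      socklen_t len = sizeof(peer);
      *out_fd = accept(listen_fd, reinterpret_cast<sockaddr *>(&peer), &len);
      if (*out_fd == -1) {
        return {};
      }
      char host[NI_MAXHOST];
      if (getnameinfo(reinterpret_cast<sockaddr *>(&peer), len, host,
                      sizeof(host), nullptr, 0, NI_NUMERICHOST) != 0) {
        return {};
      }
      return host;
    }

XGBoost branches on Domain() instead so the result can be wrapped in its typed SockAddrV4/SockAddrV6 values.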
@@ -429,8 +429,8 @@ def make_categorical(
     categories = np.arange(0, n_categories)
     for col in df.columns:
         if rng.binomial(1, cat_ratio, size=1)[0] == 1:
-            df[col] = df[col].astype("category")
-            df[col] = df[col].cat.set_categories(categories)
+            df.loc[:, col] = df[col].astype("category")
+            df.loc[:, col] = df[col].cat.set_categories(categories)
 
     if sparsity > 0.0:
         for i in range(n_features):
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
 */
 #include "coll.h"
 
@@ -7,6 +7,7 @@
 #include <cstddef>      // for size_t
 #include <cstdint>      // for int8_t, int64_t
 #include <functional>   // for bit_and, bit_or, bit_xor, plus
+#include <string>       // for string
 #include <type_traits>  // for is_floating_point_v, is_same_v
 #include <utility>      // for move
 
@@ -56,6 +57,8 @@ bool constexpr IsFloatingPointV() {
     return cpu_impl::RingAllreduce(comm, data, erased_fn, type);
   };
 
+  std::string msg{"Floating point is not supported for bit wise collective operations."};
+
   auto rc = DispatchDType(type, [&](auto t) {
     using T = decltype(t);
     switch (op) {
@@ -70,21 +73,21 @@ bool constexpr IsFloatingPointV() {
       }
       case Op::kBitwiseAND: {
         if constexpr (IsFloatingPointV<T>()) {
-          return Fail("Invalid type.");
+          return Fail(msg);
         } else {
           return fn(std::bit_and<>{}, t);
         }
       }
       case Op::kBitwiseOR: {
         if constexpr (IsFloatingPointV<T>()) {
-          return Fail("Invalid type.");
+          return Fail(msg);
         } else {
           return fn(std::bit_or<>{}, t);
         }
       }
       case Op::kBitwiseXOR: {
         if constexpr (IsFloatingPointV<T>()) {
-          return Fail("Invalid type.");
+          return Fail(msg);
         } else {
           return fn(std::bit_xor<>{}, t);
         }
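The only change in this hunk is the shared error message, but the surrounding if constexpr is what makes the dispatch compile at all: the floating-point branch is discarded at compile time, so std::bit_and<> is never instantiated for float or double. A standalone sketch of that guard (simplified names, not the XGBoost dispatch machinery):

    // The floating-point branch is a compile-time dead end, so std::bit_and<>
    // is only ever instantiated for integral T.
    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <string>
    #include <type_traits>

    template <typename T>
    std::string BitwiseAnd(T lhs, T rhs) {
      if constexpr (std::is_floating_point_v<T>) {
        return "Floating point is not supported for bit wise collective operations.";
      } else {
        return std::to_string(std::bit_and<>{}(lhs, rhs));
      }
    }

    int main() {
      std::cout << BitwiseAnd<std::int32_t>(6, 3) << "\n";  // prints "2"
      std::cout << BitwiseAnd<double>(6.0, 3.0) << "\n";    // prints the message
    }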
@@ -75,9 +75,11 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
   } << [&] {
     return next->NonBlocking(true);
   } << [&] {
-    SockAddrV4 addr;
+    SockAddress addr;
     return listener->Accept(prev.get(), &addr);
-  } << [&] { return prev->NonBlocking(true); };
+  } << [&] {
+    return prev->NonBlocking(true);
+  };
   if (!rc.OK()) {
     return rc;
   }
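These hunks mostly reformat the Result << lambda chains, but the pattern is worth spelling out: each lambda runs only while the accumulated Result is still OK, so the first failure short-circuits everything after it. A toy analogue (hypothetical Res type, not xgboost::collective::Result):

    // Short-circuiting chain: a step only runs if the previous result is OK,
    // so the first error is the one that propagates.
    #include <iostream>
    #include <string>
    #include <utility>

    struct Res {
      bool ok{true};
      std::string msg;
    };

    template <typename Fn>
    Res operator<<(Res &&prev, Fn &&next) {
      if (!prev.ok) {
        return std::move(prev);  // keep the first error, skip later steps
      }
      return next();
    }

    int main() {
      auto rc = Res{} << [] { return Res{true, ""}; }
                      << [] { return Res{false, "bind failed"}; }
                      << [] { return Res{false, "never runs"}; };
      std::cout << rc.msg << "\n";  // prints "bind failed"
    }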
@@ -157,10 +159,13 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
   }
 
   for (std::int32_t r = 0; r < comm.Rank(); ++r) {
-    SockAddrV4 addr;
     auto peer = std::shared_ptr<TCPSocket>(TCPSocket::CreatePtr(comm.Domain()));
-    rc = std::move(rc) << [&] { return listener->Accept(peer.get(), &addr); }
-                       << [&] { return peer->RecvTimeout(timeout); };
+    rc = std::move(rc) << [&] {
+      SockAddress addr;
+      return listener->Accept(peer.get(), &addr);
+    } << [&] {
+      return peer->RecvTimeout(timeout);
+    };
     if (!rc.OK()) {
       return rc;
     }
@@ -187,7 +192,9 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
     : HostComm{std::move(host), port, timeout, retry, std::move(task_id)},
       nccl_path_{std::move(nccl_path)} {
   auto rc = this->Bootstrap(timeout_, retry_, task_id_);
-  CHECK(rc.OK()) << rc.Report();
+  if (!rc.OK()) {
+    SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc)));
+  }
 }
 
 #if !defined(XGBOOST_USE_NCCL)
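The constructor no longer aborts through CHECK; it wraps the callee's error with caller-side context before reporting. A toy version of that wrapping (hypothetical types, not XGBoost's Fail):

    // Error wrapping sketch: the caller's message is prepended while the
    // original cause stays in the report.
    #include <iostream>
    #include <string>
    #include <utility>

    struct Res {
      bool ok{true};
      std::string msg;
    };

    Res Fail(std::string msg, Res &&cause) {
      return Res{false, std::move(msg) + " caused by: " + std::move(cause.msg)};
    }

    int main() {
      Res inner{false, "bind failed"};
      Res rc = Fail("Failed to bootstrap the communication group.", std::move(inner));
      std::cout << rc.msg << "\n";
      // "Failed to bootstrap the communication group. caused by: bind failed"
    }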
@@ -247,10 +254,12 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
   // get ring neighbors
   std::string snext;
   tracker.Recv(&snext);
+  if (!rc.OK()) {
+    return Fail("Failed to receive the rank for the next worker.", std::move(rc));
+  }
   auto jnext = Json::Load(StringView{snext});
 
   proto::PeerInfo ninfo{jnext};
 
   // get the rank of this worker
   this->rank_ = BootstrapPrev(ninfo.rank, world);
   this->tracker_.rank = rank_;
@@ -258,7 +267,7 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
   std::vector<std::shared_ptr<TCPSocket>> workers;
   rc = ConnectWorkers(*this, &listener, lport, ninfo, timeout, retry, &workers);
   if (!rc.OK()) {
-    return rc;
+    return Fail("Failed to connect to other workers.", std::move(rc));
   }
 
   CHECK(this->channels_.empty());
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
 */
 #if defined(__unix__) || defined(__APPLE__)
 #include <netdb.h>  // gethostbyname
@@ -27,12 +27,14 @@
 #include "tracker.h"
 #include "xgboost/collective/result.h"  // for Result, Fail, Success
 #include "xgboost/collective/socket.h"  // for GetHostName, FailWithCode, MakeSockAddress, ...
-#include "xgboost/json.h"
+#include "xgboost/json.h"               // for Json
 
 namespace xgboost::collective {
 Tracker::Tracker(Json const& config)
-    : n_workers_{static_cast<std::int32_t>(
-          RequiredArg<Integer const>(config, "n_workers", __func__))},
+    : sortby_{static_cast<SortBy>(
+          OptionalArg<Integer const>(config, "sortby", static_cast<Integer::Int>(SortBy::kHost)))},
+      n_workers_{
+          static_cast<std::int32_t>(RequiredArg<Integer const>(config, "n_workers", __func__))},
       port_{static_cast<std::int32_t>(OptionalArg<Integer const>(config, "port", Integer::Int{0}))},
       timeout_{std::chrono::seconds{OptionalArg<Integer const>(
           config, "timeout", static_cast<std::int64_t>(collective::DefaultTimeoutSec()))}} {}
@@ -56,13 +58,15 @@ Result Tracker::WaitUntilReady() const {
   return Success();
 }
 
-RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr)
+RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddress addr)
     : sock_{std::move(sock)} {
   std::int32_t rank{0};
   Json jcmd;
   std::int32_t port{0};
 
-  rc_ = Success() << [&] { return proto::Magic{}.Verify(&sock_); } << [&] {
+  rc_ = Success() << [&] {
+    return proto::Magic{}.Verify(&sock_);
+  } << [&] {
     return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_);
   } << [&] {
     std::string cmd;
@@ -83,8 +87,13 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA
     }
     return Success();
   } << [&] {
-    auto host = addr.Addr();
-    info_ = proto::PeerInfo{host, port, rank};
+    if (addr.IsV4()) {
+      auto host = addr.V4().Addr();
+      info_ = proto::PeerInfo{host, port, rank};
+    } else {
+      auto host = addr.V6().Addr();
+      info_ = proto::PeerInfo{host, port, rank};
+    }
     return Success();
   };
 }
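SockAddress here acts as a small tagged union over SockAddrV4/SockAddrV6. A rough standalone analogue of the IsV4()/V4()/V6() access pattern, sketched with std::variant (my own types, not the real class):

    // Variant-based sketch of a two-family address wrapper; the branch shape
    // mirrors the WorkerProxy constructor above.
    #include <iostream>
    #include <string>
    #include <variant>

    struct AddrV4 { std::string Addr() const { return "127.0.0.1"; } };
    struct AddrV6 { std::string Addr() const { return "::1"; } };

    struct Address {
      std::variant<AddrV4, AddrV6> impl;
      bool IsV4() const { return std::holds_alternative<AddrV4>(impl); }
      AddrV4 const &V4() const { return std::get<AddrV4>(impl); }
      AddrV6 const &V6() const { return std::get<AddrV6>(impl); }
    };

    int main() {
      Address addr{AddrV6{}};
      std::string host = addr.IsV4() ? addr.V4().Addr() : addr.V6().Addr();
      std::cout << host << "\n";  // prints "::1"
    }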
@@ -92,19 +101,19 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA
 RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
   std::string self;
   auto rc = collective::GetHostAddress(&self);
-  auto host = OptionalArg<String>(config, "host", self);
+  host_ = OptionalArg<String>(config, "host", self);
 
-  host_ = host;
-  listener_ = TCPSocket::Create(SockDomain::kV4);
-  rc = listener_.Bind(host, &this->port_);
-  CHECK(rc.OK()) << rc.Report();
+  auto addr = MakeSockAddress(xgboost::StringView{host_}, 0);
+  listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6);
+  rc = listener_.Bind(host_, &this->port_);
+  SafeColl(rc);
   listener_.Listen();
 }
 
 Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
   auto& workers = *p_workers;
 
-  std::sort(workers.begin(), workers.end(), WorkerCmp{});
+  std::sort(workers.begin(), workers.end(), WorkerCmp{this->sortby_});
 
   std::vector<std::thread> bootstrap_threads;
   for (std::int32_t r = 0; r < n_workers_; ++r) {
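The listener's domain is now derived from the configured host string rather than hard-coded to IPv4. A minimal sketch of such a family probe using plain POSIX inet_pton (XGBoost's MakeSockAddress is its own implementation; this is only an illustration):

    // A string is an IPv6 literal iff inet_pton can parse it as one;
    // otherwise fall back to treating it as IPv4.
    #include <arpa/inet.h>   // for inet_pton, AF_INET6
    #include <netinet/in.h>  // for in6_addr
    #include <cassert>

    bool IsV6Literal(char const *host) {
      in6_addr out6;
      return inet_pton(AF_INET6, host, &out6) == 1;
    }

    int main() {
      assert(IsV6Literal("::1"));
      assert(!IsV6Literal("127.0.0.1"));
    }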
@@ -224,7 +233,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
 
   while (state.ShouldContinue()) {
     TCPSocket sock;
-    SockAddrV4 addr;
+    SockAddress addr;
     this->ready_ = true;
     auto rc = listener_.Accept(&sock, &addr);
     if (!rc.OK()) {
@@ -291,7 +300,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
 
 [[nodiscard]] Json RabitTracker::WorkerArgs() const {
   auto rc = this->WaitUntilReady();
-  CHECK(rc.OK()) << rc.Report();
+  SafeColl(rc);
 
   Json args{Object{}};
   args["DMLC_TRACKER_URI"] = String{host_};
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
 */
 #pragma once
 #include <chrono>  // for seconds
@@ -36,6 +36,16 @@ namespace xgboost::collective {
  * signal an error to the tracker and the tracker will notify other workers.
  */
 class Tracker {
+ protected:
+  // How to sort the workers, either by host name or by task ID. When using a multi-GPU
+  // setting, multiple workers can occupy the same host, in which case one should sort
+  // workers by task. Due to compatibility reason, the task ID is not always available, so
+  // we use host as the default.
+  enum class SortBy : std::int8_t {
+    kHost = 0,
+    kTask = 1,
+  } sortby_;
+
  protected:
   std::int32_t n_workers_{0};
   std::int32_t port_{-1};
@@ -76,7 +86,7 @@ class RabitTracker : public Tracker {
     Result rc_;
 
    public:
-    explicit WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr);
+    explicit WorkerProxy(std::int32_t world, TCPSocket sock, SockAddress addr);
     WorkerProxy(WorkerProxy const& that) = delete;
     WorkerProxy(WorkerProxy&& that) = default;
     WorkerProxy& operator=(WorkerProxy const&) = delete;
@@ -96,11 +106,14 @@ class RabitTracker : public Tracker {
 
     void Send(StringView value) { this->sock_.Send(value); }
   };
-  // provide an ordering for workers, this helps us get deterministic topology.
+  // Provide an ordering for workers, this helps us get deterministic topology.
   struct WorkerCmp {
+    SortBy sortby;
+    explicit WorkerCmp(SortBy sortby) : sortby{sortby} {}
+
     [[nodiscard]] bool operator()(WorkerProxy const& lhs, WorkerProxy const& rhs) {
-      auto const& lh = lhs.Host();
-      auto const& rh = rhs.Host();
+      auto const& lh = sortby == Tracker::SortBy::kHost ? lhs.Host() : lhs.TaskID();
+      auto const& rh = sortby == Tracker::SortBy::kHost ? rhs.Host() : rhs.TaskID();
 
       if (lh != rh) {
         return lh < rh;
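With sortby added, WorkerCmp selects the comparison key per worker instead of always using the host. A minimal standalone illustration (hypothetical Worker type) of why sorting by task ID matters when several workers share one host:

    // Two workers on the same host: sorting by host alone cannot order them
    // deterministically, sorting by task ID can.
    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    enum class SortBy : int { kHost = 0, kTask = 1 };

    struct Worker {
      std::string host;
      std::string task_id;
    };

    int main() {
      std::vector<Worker> workers{{"node-a", "task-1"}, {"node-a", "task-0"}};
      SortBy sortby = SortBy::kTask;
      std::sort(workers.begin(), workers.end(),
                [&](Worker const &l, Worker const &r) {
                  auto const &lk = sortby == SortBy::kHost ? l.host : l.task_id;
                  auto const &rk = sortby == SortBy::kHost ? r.host : r.task_id;
                  return lk < rk;
                });
      std::cout << workers[0].task_id << "\n";  // prints "task-0"
    }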
@@ -18,7 +18,6 @@
 #include <cstdint>     // for int32_t, uint32_t, int64_t, uint64_t
 #include <cstdlib>     // for atoi
 #include <cstring>     // for memcpy, size_t, memset
-#include <functional>  // for less
 #include <iomanip>     // for operator<<, setiosflags
 #include <iterator>    // for back_insert_iterator, distance, back_inserter
 #include <limits>      // for numeric_limits
@@ -25,7 +25,7 @@ RUN \
     mamba create -n gpu_test -c rapidsai-nightly -c rapidsai -c nvidia -c conda-forge -c defaults \
         python=3.10 cudf=$RAPIDS_VERSION_ARG* rmm=$RAPIDS_VERSION_ARG* cudatoolkit=$CUDA_VERSION_ARG \
         nccl>=$(cut -d "-" -f 1 <<< $NCCL_VERSION_ARG) \
-        dask \
+        dask=2024.1.1 \
        dask-cuda=$RAPIDS_VERSION_ARG* dask-cudf=$RAPIDS_VERSION_ARG* cupy \
        numpy pytest pytest-timeout scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis \
        pyspark>=3.4.0 cloudpickle cuda-python && \
@@ -252,7 +252,7 @@ class TestDistributedGPU:
 
         X_onehot, _ = make_categorical(local_cuda_client, 10000, 30, 13, True)
         X_onehot = dask_cudf.from_dask_dataframe(X_onehot)
-        run_categorical(local_cuda_client, "gpu_hist", X, X_onehot, y)
+        run_categorical(local_cuda_client, "hist", "cuda", X, X_onehot, y)
 
     @given(
         params=hist_parameter_strategy,
@@ -315,8 +315,15 @@ def test_dask_sparse(client: "Client") -> None:
     )
 
 
-def run_categorical(client: "Client", tree_method: str, X, X_onehot, y) -> None:
-    parameters = {"tree_method": tree_method, "max_cat_to_onehot": 9999}  # force onehot
+def run_categorical(
+    client: "Client", tree_method: str, device: str, X, X_onehot, y
+) -> None:
+    # Force onehot
+    parameters = {
+        "tree_method": tree_method,
+        "device": device,
+        "max_cat_to_onehot": 9999,
+    }
     rounds = 10
     m = xgb.dask.DaskDMatrix(client, X_onehot, y, enable_categorical=True)
     by_etl_results = xgb.dask.train(
@@ -364,6 +371,7 @@ def run_categorical(client: "Client", tree_method: str, X, X_onehot, y) -> None:
         enable_categorical=True,
         n_estimators=10,
         tree_method=tree_method,
+        device=device,
         # force onehot
         max_cat_to_onehot=9999,
     )
@@ -378,7 +386,10 @@ def run_categorical(client: "Client", tree_method: str, X, X_onehot, y) -> None:
     reg.fit(X, y)
     # check partition based
     reg = xgb.dask.DaskXGBRegressor(
-        enable_categorical=True, n_estimators=10, tree_method=tree_method
+        enable_categorical=True,
+        n_estimators=10,
+        tree_method=tree_method,
+        device=device,
     )
     reg.fit(X, y, eval_set=[(X, y)])
     assert tm.non_increasing(reg.evals_result()["validation_0"]["rmse"])
@@ -398,8 +409,8 @@ def run_categorical(client: "Client", tree_method: str, X, X_onehot, y) -> None:
 def test_categorical(client: "Client") -> None:
     X, y = make_categorical(client, 10000, 30, 13)
     X_onehot, _ = make_categorical(client, 10000, 30, 13, True)
-    run_categorical(client, "approx", X, X_onehot, y)
-    run_categorical(client, "hist", X, X_onehot, y)
+    run_categorical(client, "approx", "cpu", X, X_onehot, y)
+    run_categorical(client, "hist", "cpu", X, X_onehot, y)
 
     ft = ["c"] * X.shape[1]
     reg = xgb.dask.DaskXGBRegressor(