Merge branch 'master'

This commit is contained in:
Hui Liu 2023-11-02 09:05:31 -07:00
commit 3af5dfd546
11 changed files with 207 additions and 67 deletions

View File

@ -6,7 +6,6 @@
#include <grpcpp/security/server_credentials.h> // for InsecureServerCredentials, ... #include <grpcpp/security/server_credentials.h> // for InsecureServerCredentials, ...
#include <grpcpp/server_builder.h> // for ServerBuilder #include <grpcpp/server_builder.h> // for ServerBuilder
#include <chrono> // for ms
#include <cstdint> // for int32_t #include <cstdint> // for int32_t
#include <exception> // for exception #include <exception> // for exception
#include <limits> // for numeric_limits #include <limits> // for numeric_limits
@ -61,7 +60,7 @@ FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
} }
std::future<Result> FederatedTracker::Run() { std::future<Result> FederatedTracker::Run() {
return std::async([this]() { return std::async(std::launch::async, [this]() {
std::string const server_address = "0.0.0.0:" + std::to_string(this->port_); std::string const server_address = "0.0.0.0:" + std::to_string(this->port_);
xgboost::collective::federated::FederatedService service{ xgboost::collective::federated::FederatedService service{
static_cast<std::int32_t>(this->n_workers_)}; static_cast<std::int32_t>(this->n_workers_)};
@ -98,10 +97,13 @@ std::future<Result> FederatedTracker::Run() {
try { try {
server_ = builder.BuildAndStart(); server_ = builder.BuildAndStart();
ready_ = true;
server_->Wait(); server_->Wait();
} catch (std::exception const& e) { } catch (std::exception const& e) {
return collective::Fail(std::string{e.what()}); return collective::Fail(std::string{e.what()});
} }
ready_ = false;
return collective::Success(); return collective::Success();
}); });
} }
@ -109,18 +111,8 @@ std::future<Result> FederatedTracker::Run() {
FederatedTracker::~FederatedTracker() = default; FederatedTracker::~FederatedTracker() = default;
Result FederatedTracker::Shutdown() { Result FederatedTracker::Shutdown() {
common::Timer timer; auto rc = this->WaitUntilReady();
timer.Start(); CHECK(rc.OK()) << rc.Report();
using namespace std::chrono_literals;
while (!server_) {
timer.Stop();
auto ela = timer.ElapsedSeconds();
if (ela > this->Timeout().count()) {
return Fail("Failed to shutdown, timeout:" + std::to_string(this->Timeout().count()) +
" seconds.");
}
std::this_thread::sleep_for(10ms);
}
try { try {
server_->Shutdown(); server_->Shutdown();
@ -130,4 +122,17 @@ Result FederatedTracker::Shutdown() {
return Success(); return Success();
} }
/**
 * @brief Build the connection arguments that workers use to reach this tracker.
 *
 * Waits for the tracker to report ready, then looks up the host address with
 * GetHostAddress. Either step failing trips a CHECK that carries the error report.
 *
 * @return A JSON object with the keys "DMLC_TRACKER_URI" (host address) and
 *         "DMLC_TRACKER_PORT" (value of Port()).
 */
[[nodiscard]] Json FederatedTracker::WorkerArgs() const {
  auto rc = this->WaitUntilReady();
  CHECK(rc.OK()) << rc.Report();

  std::string host;
  rc = GetHostAddress(&host);
  // Surface the underlying error message, consistent with the readiness check above.
  CHECK(rc.OK()) << rc.Report();

  Json args{Object{}};
  args["DMLC_TRACKER_URI"] = String{host};
  args["DMLC_TRACKER_PORT"] = this->Port();
  return args;
}
} // namespace xgboost::collective } // namespace xgboost::collective

View File

@ -57,9 +57,8 @@ class FederatedTracker : public collective::Tracker {
explicit FederatedTracker(Json const& config); explicit FederatedTracker(Json const& config);
~FederatedTracker() override; ~FederatedTracker() override;
std::future<Result> Run() override; std::future<Result> Run() override;
// federated tracker do not provide initialization parameters, users have to provide it
// themseleves. [[nodiscard]] Json WorkerArgs() const override;
[[nodiscard]] Json WorkerArgs() const override { return Json{Null{}}; }
[[nodiscard]] Result Shutdown(); [[nodiscard]] Result Shutdown();
}; };
} // namespace xgboost::collective } // namespace xgboost::collective

View File

@ -248,7 +248,7 @@ __model_doc = f"""
Balancing of positive and negative weights. Balancing of positive and negative weights.
base_score : Optional[float] base_score : Optional[float]
The initial prediction score of all instances, global bias. The initial prediction score of all instances, global bias.
random_state : Optional[Union[numpy.random.RandomState, int]] random_state : Optional[Union[numpy.random.RandomState, numpy.random.Generator, int]]
Random number seed. Random number seed.
.. note:: .. note::
@ -651,7 +651,9 @@ class XGBModel(XGBModelBase):
reg_lambda: Optional[float] = None, reg_lambda: Optional[float] = None,
scale_pos_weight: Optional[float] = None, scale_pos_weight: Optional[float] = None,
base_score: Optional[float] = None, base_score: Optional[float] = None,
random_state: Optional[Union[np.random.RandomState, int]] = None, random_state: Optional[
Union[np.random.RandomState, np.random.Generator, int]
] = None,
missing: float = np.nan, missing: float = np.nan,
num_parallel_tree: Optional[int] = None, num_parallel_tree: Optional[int] = None,
monotone_constraints: Optional[Union[Dict[str, int], str]] = None, monotone_constraints: Optional[Union[Dict[str, int], str]] = None,
@ -789,6 +791,10 @@ class XGBModel(XGBModelBase):
params["random_state"] = params["random_state"].randint( params["random_state"] = params["random_state"].randint(
np.iinfo(np.int32).max np.iinfo(np.int32).max
) )
elif isinstance(params["random_state"], np.random.Generator):
params["random_state"] = int(
params["random_state"].integers(np.iinfo(np.int32).max)
)
return params return params

View File

@ -3,19 +3,35 @@
*/ */
#include "coll.h" #include "coll.h"
#include <algorithm> // for min, max #include <algorithm> // for min, max, copy_n
#include <cstddef> // for size_t #include <cstddef> // for size_t
#include <cstdint> // for int8_t, int64_t #include <cstdint> // for int8_t, int64_t
#include <functional> // for bit_and, bit_or, bit_xor, plus #include <functional> // for bit_and, bit_or, bit_xor, plus
#include <type_traits> // for is_floating_point_v, is_same_v
#include <utility> // for move
#include "allgather.h" // for RingAllgatherV, RingAllgather #include "../data/array_interface.h" // for ArrayInterfaceHandler
#include "allreduce.h" // for Allreduce #include "allgather.h" // for RingAllgatherV, RingAllgather
#include "broadcast.h" // for Broadcast #include "allreduce.h" // for Allreduce
#include "comm.h" // for Comm #include "broadcast.h" // for Broadcast
#include "comm.h" // for Comm
#if defined(XGBOOST_USE_CUDA)
#include "cuda_fp16.h" // for __half
#endif
namespace xgboost::collective { namespace xgboost::collective {
/**
 * @brief Compile-time predicate: whether T should be treated as a floating point type.
 *
 * True for the standard floating point types. When XGBOOST_USE_CUDA is defined, the
 * CUDA `__half` type is accepted as well.
 */
template <typename T>
constexpr bool IsFloatingPointV() {
  bool is_fp = std::is_floating_point_v<T>;
#if defined(XGBOOST_USE_CUDA)
  // __half is not covered by the standard trait, check for it explicitly.
  is_fp = is_fp || std::is_same_v<T, __half>;
#endif  // defined(XGBOOST_USE_CUDA)
  return is_fp;
}
[[nodiscard]] Result Coll::Allreduce(Comm const& comm, common::Span<std::int8_t> data, [[nodiscard]] Result Coll::Allreduce(Comm const& comm, common::Span<std::int8_t> data,
ArrayInterfaceHandler::Type, Op op) { ArrayInterfaceHandler::Type type, Op op) {
namespace coll = ::xgboost::collective; namespace coll = ::xgboost::collective;
auto redop_fn = [](auto lhs, auto out, auto elem_op) { auto redop_fn = [](auto lhs, auto out, auto elem_op) {
@ -25,32 +41,59 @@ namespace xgboost::collective {
p_out[i] = elem_op(p_lhs[i], p_out[i]); p_out[i] = elem_op(p_lhs[i], p_out[i]);
} }
}; };
auto fn = [&](auto elem_op) {
return coll::Allreduce( auto fn = [&](auto elem_op, auto t) {
comm, data, [redop_fn, elem_op](auto lhs, auto rhs) { redop_fn(lhs, rhs, elem_op); }); using T = decltype(t);
auto erased_fn = [redop_fn, elem_op](common::Span<std::int8_t const> lhs,
common::Span<std::int8_t> out) {
CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
auto lhs_t = common::RestoreType<T const>(lhs);
auto rhs_t = common::RestoreType<T>(out);
redop_fn(lhs_t, rhs_t, elem_op);
};
return cpu_impl::RingAllreduce(comm, data, erased_fn, type);
}; };
switch (op) { auto rc = DispatchDType(type, [&](auto t) {
case Op::kMax: { using T = decltype(t);
return fn([](auto l, auto r) { return std::max(l, r); }); switch (op) {
case Op::kMax: {
return fn([](auto l, auto r) { return std::max(l, r); }, t);
}
case Op::kMin: {
return fn([](auto l, auto r) { return std::min(l, r); }, t);
}
case Op::kSum: {
return fn(std::plus<>{}, t);
}
case Op::kBitwiseAND: {
if constexpr (IsFloatingPointV<T>()) {
return Fail("Invalid type.");
} else {
return fn(std::bit_and<>{}, t);
}
}
case Op::kBitwiseOR: {
if constexpr (IsFloatingPointV<T>()) {
return Fail("Invalid type.");
} else {
return fn(std::bit_or<>{}, t);
}
}
case Op::kBitwiseXOR: {
if constexpr (IsFloatingPointV<T>()) {
return Fail("Invalid type.");
} else {
return fn(std::bit_xor<>{}, t);
}
}
} }
case Op::kMin: { return Fail("Invalid op.");
return fn([](auto l, auto r) { return std::min(l, r); }); });
}
case Op::kSum: { return std::move(rc) << [&] { return comm.Block(); };
return fn(std::plus<>{});
}
case Op::kBitwiseAND: {
return fn(std::bit_and<>{});
}
case Op::kBitwiseOR: {
return fn(std::bit_or<>{});
}
case Op::kBitwiseXOR: {
return fn(std::bit_xor<>{});
}
}
return comm.Block();
} }
[[nodiscard]] Result Coll::Broadcast(Comm const& comm, common::Span<std::int8_t> data, [[nodiscard]] Result Coll::Broadcast(Comm const& comm, common::Span<std::int8_t> data,

View File

@ -16,7 +16,7 @@
#endif // defined(_WIN32) #endif // defined(_WIN32)
#include <algorithm> // for sort #include <algorithm> // for sort
#include <chrono> // for seconds #include <chrono> // for seconds, ms
#include <cstdint> // for int32_t #include <cstdint> // for int32_t
#include <string> // for string #include <string> // for string
#include <utility> // for move, forward #include <utility> // for move, forward
@ -37,6 +37,25 @@ Tracker::Tracker(Json const& config)
timeout_{std::chrono::seconds{OptionalArg<Integer const>( timeout_{std::chrono::seconds{OptionalArg<Integer const>(
config, "timeout", static_cast<std::int64_t>(collective::DefaultTimeoutSec()))}} {} config, "timeout", static_cast<std::int64_t>(collective::DefaultTimeoutSec()))}} {}
/**
 * @brief Wait until the tracker reports that it is ready.
 *
 * Polls Ready() and sleeps 100ms between checks. Gives up with a failure once the
 * elapsed time exceeds Timeout().
 *
 * @return Success once Ready() is true, otherwise a failure describing the timeout.
 */
Result Tracker::WaitUntilReady() const {
  using namespace std::chrono_literals;  // NOLINT
  auto const limit_sec = this->Timeout().count();

  // Polling loop. The function is mostly for waiting for the OS to launch an async
  // thread, which should be reasonably fast.
  common::Timer stopwatch;
  stopwatch.Start();
  for (;;) {
    if (this->Ready()) {
      return Success();
    }
    if (stopwatch.Duration().count() > limit_sec) {
      return Fail("Failed to start tracker, timeout:" + std::to_string(limit_sec) +
                  " seconds.");
    }
    std::this_thread::sleep_for(100ms);
  }
}
RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr) RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr)
: sock_{std::move(sock)} { : sock_{std::move(sock)} {
auto host = addr.Addr(); auto host = addr.Addr();
@ -76,6 +95,7 @@ RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
auto rc = collective::GetHostAddress(&self); auto rc = collective::GetHostAddress(&self);
auto host = OptionalArg<String>(config, "host", self); auto host = OptionalArg<String>(config, "host", self);
host_ = host;
listener_ = TCPSocket::Create(SockDomain::kV4); listener_ = TCPSocket::Create(SockDomain::kV4);
rc = listener_.Bind(host, &this->port_); rc = listener_.Bind(host, &this->port_);
CHECK(rc.OK()) << rc.Report(); CHECK(rc.OK()) << rc.Report();
@ -173,6 +193,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
while (state.ShouldContinue()) { while (state.ShouldContinue()) {
TCPSocket sock; TCPSocket sock;
SockAddrV4 addr; SockAddrV4 addr;
this->ready_ = true;
auto rc = listener_.Accept(&sock, &addr); auto rc = listener_.Accept(&sock, &addr);
if (!rc.OK()) { if (!rc.OK()) {
return Fail("Failed to accept connection.", std::move(rc)); return Fail("Failed to accept connection.", std::move(rc));
@ -237,10 +258,21 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
} }
} }
} }
ready_ = false;
return Success(); return Success();
}); });
} }
/**
 * @brief Build the connection arguments that workers use to reach this tracker.
 *
 * Waits for the tracker to report ready first; a failure to become ready trips a
 * CHECK that carries the error report.
 *
 * @return A JSON object with the keys "DMLC_TRACKER_URI" (the stored host) and
 *         "DMLC_TRACKER_PORT" (value of Port()).
 */
[[nodiscard]] Json RabitTracker::WorkerArgs() const {
  auto status = this->WaitUntilReady();
  CHECK(status.OK()) << status.Report();

  Json worker_args{Object{}};
  worker_args["DMLC_TRACKER_URI"] = String{host_};
  worker_args["DMLC_TRACKER_PORT"] = this->Port();
  return worker_args;
}
[[nodiscard]] Result GetHostAddress(std::string* out) { [[nodiscard]] Result GetHostAddress(std::string* out) {
auto rc = GetHostName(out); auto rc = GetHostName(out);
if (!rc.OK()) { if (!rc.OK()) {

View File

@ -40,6 +40,7 @@ class Tracker {
std::int32_t n_workers_{0}; std::int32_t n_workers_{0};
std::int32_t port_{-1}; std::int32_t port_{-1};
std::chrono::seconds timeout_{0}; std::chrono::seconds timeout_{0};
std::atomic<bool> ready_{false};
public: public:
explicit Tracker(Json const& config); explicit Tracker(Json const& config);
@ -47,10 +48,17 @@ class Tracker {
: n_workers_{n_worders}, port_{port}, timeout_{timeout} {} : n_workers_{n_worders}, port_{port}, timeout_{timeout} {}
virtual ~Tracker() noexcept(false){}; // NOLINT virtual ~Tracker() noexcept(false){}; // NOLINT
[[nodiscard]] Result WaitUntilReady() const;
[[nodiscard]] virtual std::future<Result> Run() = 0; [[nodiscard]] virtual std::future<Result> Run() = 0;
[[nodiscard]] virtual Json WorkerArgs() const = 0; [[nodiscard]] virtual Json WorkerArgs() const = 0;
[[nodiscard]] std::chrono::seconds Timeout() const { return timeout_; } [[nodiscard]] std::chrono::seconds Timeout() const { return timeout_; }
[[nodiscard]] virtual std::int32_t Port() const { return port_; } [[nodiscard]] virtual std::int32_t Port() const { return port_; }
/**
* @brief Flag to indicate whether the server is running.
*/
[[nodiscard]] bool Ready() const { return ready_; }
}; };
class RabitTracker : public Tracker { class RabitTracker : public Tracker {
@ -124,13 +132,7 @@ class RabitTracker : public Tracker {
~RabitTracker() noexcept(false) override = default; ~RabitTracker() noexcept(false) override = default;
std::future<Result> Run() override; std::future<Result> Run() override;
[[nodiscard]] Json WorkerArgs() const override;
[[nodiscard]] Json WorkerArgs() const override {
Json args{Object{}};
args["DMLC_TRACKER_URI"] = String{host_};
args["DMLC_TRACKER_PORT"] = this->Port();
return args;
}
}; };
// Prob the public IP address of the host, need a better method. // Prob the public IP address of the host, need a better method.

View File

@ -29,6 +29,7 @@ struct Timer {
void Start() { start = ClockT::now(); } void Start() { start = ClockT::now(); }
void Stop() { elapsed += ClockT::now() - start; } void Stop() { elapsed += ClockT::now() - start; }
double ElapsedSeconds() const { return SecondsT(elapsed).count(); } double ElapsedSeconds() const { return SecondsT(elapsed).count(); }
SecondsT Duration() const { return ClockT::now() - start; }
void PrintElapsed(std::string label) { void PrintElapsed(std::string label) {
char buffer[255]; char buffer[255];
snprintf(buffer, sizeof(buffer), "%s:\t %fs", label.c_str(), snprintf(buffer, sizeof(buffer), "%s:\t %fs", label.c_str(),

View File

@ -12,14 +12,17 @@ sysctl -n machdep.cpu.brand_string
uname -m uname -m
set +x set +x
# Create new Conda env # Build XGBoost4J binary
echo "--- Set up Conda env" echo "--- Build libxgboost4j.dylib"
. $HOME/mambaforge/etc/profile.d/conda.sh set -x
. $HOME/mambaforge/etc/profile.d/mamba.sh mkdir build
conda_env=xgboost_dev_$(uuidgen | tr '[:upper:]' '[:lower:]' | tr -d '-') pushd build
mamba create -y -n ${conda_env} python=3.8 export JAVA_HOME=$(/usr/libexec/java_home)
conda activate ${conda_env} cmake .. -GNinja -DJVM_BINDINGS=ON -DUSE_OPENMP=OFF -DCMAKE_OSX_DEPLOYMENT_TARGET=10.15
mamba env update -n ${conda_env} --file tests/ci_build/conda_env/macos_cpu_test.yml ninja -v
popd
rm -rf build
set +x
# Ensure that XGBoost can be built with Clang 11 # Ensure that XGBoost can be built with Clang 11
echo "--- Build and Test XGBoost with MacOS M1, Clang 11" echo "--- Build and Test XGBoost with MacOS M1, Clang 11"

View File

@ -27,9 +27,15 @@ class PrintWorker : public WorkerForTest {
TEST_F(TrackerTest, Bootstrap) { TEST_F(TrackerTest, Bootstrap) {
RabitTracker tracker{host, n_workers, 0, timeout}; RabitTracker tracker{host, n_workers, 0, timeout};
ASSERT_FALSE(tracker.Ready());
auto fut = tracker.Run(); auto fut = tracker.Run();
std::vector<std::thread> workers; std::vector<std::thread> workers;
auto args = tracker.WorkerArgs();
ASSERT_TRUE(tracker.Ready());
ASSERT_EQ(get<String const>(args["DMLC_TRACKER_URI"]), host);
std::int32_t port = tracker.Port(); std::int32_t port = tracker.Port();
for (std::int32_t i = 0; i < n_workers; ++i) { for (std::int32_t i = 0; i < n_workers; ++i) {
@ -47,6 +53,9 @@ TEST_F(TrackerTest, Print) {
auto fut = tracker.Run(); auto fut = tracker.Run();
std::vector<std::thread> workers; std::vector<std::thread> workers;
auto rc = tracker.WaitUntilReady();
ASSERT_TRUE(rc.OK());
std::int32_t port = tracker.Port(); std::int32_t port = tracker.Port();
for (std::int32_t i = 0; i < n_workers; ++i) { for (std::int32_t i = 0; i < n_workers; ++i) {

View File

@ -0,0 +1,36 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <memory> // for make_unique
#include <string> // for string
#include "../../../../src/collective/tracker.h" // for GetHostAddress
#include "federated_tracker.h"
#include "test_worker.h"
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
// Life-cycle check for the federated tracker: not ready before Run(), ready once
// WorkerArgs() returns, and not ready again after Shutdown() and the run future resolve.
TEST(FederatedTrackerTest, Basic) {
  Json config{Object()};
  config["federated_secure"] = Boolean{false};
  config["n_workers"] = Integer{3};

  auto tracker = std::make_unique<FederatedTracker>(config);
  ASSERT_FALSE(tracker->Ready());
  auto fut = tracker->Run();
  // WorkerArgs waits for the tracker to become ready before returning.
  auto args = tracker->WorkerArgs();
  ASSERT_TRUE(tracker->Ready());

  ASSERT_GE(tracker->Port(), 1);
  std::string host;
  auto rc = GetHostAddress(&host);
  // Make sure the lookup succeeded before comparing against its output.
  ASSERT_TRUE(rc.OK()) << rc.Report();
  ASSERT_EQ(get<String const>(args["DMLC_TRACKER_URI"]), host);

  rc = tracker->Shutdown();
  ASSERT_TRUE(rc.OK());
  ASSERT_TRUE(fut.get().OK());
  ASSERT_FALSE(tracker->Ready());
}
} // namespace xgboost::collective

View File

@ -702,6 +702,10 @@ def test_sklearn_random_state():
clf = xgb.XGBClassifier(random_state=random_state) clf = xgb.XGBClassifier(random_state=random_state)
assert isinstance(clf.get_xgb_params()['random_state'], int) assert isinstance(clf.get_xgb_params()['random_state'], int)
random_state = np.random.default_rng(seed=404)
clf = xgb.XGBClassifier(random_state=random_state)
assert isinstance(clf.get_xgb_params()['random_state'], int)
def test_sklearn_n_jobs(): def test_sklearn_n_jobs():
clf = xgb.XGBClassifier(n_jobs=1) clf = xgb.XGBClassifier(n_jobs=1)