Merge branch 'master'
This commit is contained in:
@@ -12,14 +12,17 @@ sysctl -n machdep.cpu.brand_string
|
||||
uname -m
|
||||
set +x
|
||||
|
||||
# Create new Conda env
|
||||
echo "--- Set up Conda env"
|
||||
. $HOME/mambaforge/etc/profile.d/conda.sh
|
||||
. $HOME/mambaforge/etc/profile.d/mamba.sh
|
||||
conda_env=xgboost_dev_$(uuidgen | tr '[:upper:]' '[:lower:]' | tr -d '-')
|
||||
mamba create -y -n ${conda_env} python=3.8
|
||||
conda activate ${conda_env}
|
||||
mamba env update -n ${conda_env} --file tests/ci_build/conda_env/macos_cpu_test.yml
|
||||
# Build XGBoost4J binary
|
||||
echo "--- Build libxgboost4j.dylib"
|
||||
set -x
|
||||
mkdir build
|
||||
pushd build
|
||||
export JAVA_HOME=$(/usr/libexec/java_home)
|
||||
cmake .. -GNinja -DJVM_BINDINGS=ON -DUSE_OPENMP=OFF -DCMAKE_OSX_DEPLOYMENT_TARGET=10.15
|
||||
ninja -v
|
||||
popd
|
||||
rm -rf build
|
||||
set +x
|
||||
|
||||
# Ensure that XGBoost can be built with Clang 11
|
||||
echo "--- Build and Test XGBoost with MacOS M1, Clang 11"
|
||||
|
||||
@@ -27,9 +27,15 @@ class PrintWorker : public WorkerForTest {
|
||||
|
||||
TEST_F(TrackerTest, Bootstrap) {
|
||||
RabitTracker tracker{host, n_workers, 0, timeout};
|
||||
ASSERT_FALSE(tracker.Ready());
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
|
||||
auto args = tracker.WorkerArgs();
|
||||
ASSERT_TRUE(tracker.Ready());
|
||||
ASSERT_EQ(get<String const>(args["DMLC_TRACKER_URI"]), host);
|
||||
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
@@ -47,6 +53,9 @@ TEST_F(TrackerTest, Print) {
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
auto rc = tracker.WaitUntilReady();
|
||||
ASSERT_TRUE(rc.OK());
|
||||
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
|
||||
36
tests/cpp/plugin/federated/test_federated_tracker.cc
Normal file
36
tests/cpp/plugin/federated/test_federated_tracker.cc
Normal file
@@ -0,0 +1,36 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <memory> // for make_unique
|
||||
#include <string> // for string
|
||||
|
||||
#include "../../../../src/collective/tracker.h" // for GetHostAddress
|
||||
#include "federated_tracker.h"
|
||||
#include "test_worker.h"
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
TEST(FederatedTrackerTest, Basic) {
|
||||
Json config{Object()};
|
||||
config["federated_secure"] = Boolean{false};
|
||||
config["n_workers"] = Integer{3};
|
||||
|
||||
auto tracker = std::make_unique<FederatedTracker>(config);
|
||||
ASSERT_FALSE(tracker->Ready());
|
||||
auto fut = tracker->Run();
|
||||
auto args = tracker->WorkerArgs();
|
||||
ASSERT_TRUE(tracker->Ready());
|
||||
|
||||
ASSERT_GE(tracker->Port(), 1);
|
||||
std::string host;
|
||||
auto rc = GetHostAddress(&host);
|
||||
ASSERT_EQ(get<String const>(args["DMLC_TRACKER_URI"]), host);
|
||||
|
||||
rc = tracker->Shutdown();
|
||||
ASSERT_TRUE(rc.OK());
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
ASSERT_FALSE(tracker->Ready());
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
@@ -702,6 +702,10 @@ def test_sklearn_random_state():
|
||||
clf = xgb.XGBClassifier(random_state=random_state)
|
||||
assert isinstance(clf.get_xgb_params()['random_state'], int)
|
||||
|
||||
random_state = np.random.default_rng(seed=404)
|
||||
clf = xgb.XGBClassifier(random_state=random_state)
|
||||
assert isinstance(clf.get_xgb_params()['random_state'], int)
|
||||
|
||||
|
||||
def test_sklearn_n_jobs():
|
||||
clf = xgb.XGBClassifier(n_jobs=1)
|
||||
|
||||
Reference in New Issue
Block a user