sync Mar 29

This commit is contained in:
amdsc21 2023-03-30 00:46:50 +02:00
commit acad01afc9
20 changed files with 335 additions and 115 deletions

View File

@ -180,6 +180,22 @@ class MetaInfo {
*/ */
void SynchronizeNumberOfColumns(); void SynchronizeNumberOfColumns();
/*! \brief Whether the data is split row-wise. */
bool IsRowSplit() const {
return data_split_mode == DataSplitMode::kRow;
}
/*! \brief Whether the data is split column-wise. */
bool IsColumnSplit() const {
return data_split_mode == DataSplitMode::kCol;
}
/*!
* \brief A convenient method to check if we are doing vertical federated learning, which requires
* some special processing.
*/
bool IsVerticalFederated() const;
private: private:
void SetInfoFromHost(Context const& ctx, StringView key, Json arr); void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr); void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);
@ -542,16 +558,6 @@ class DMatrix {
return Info().num_nonzero_ == Info().num_row_ * Info().num_col_; return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
} }
/*! \brief Whether the data is split row-wise. */
bool IsRowSplit() const {
return Info().data_split_mode == DataSplitMode::kRow;
}
/*! \brief Whether the data is split column-wise. */
bool IsColumnSplit() const {
return Info().data_split_mode == DataSplitMode::kCol;
}
/*! /*!
* \brief Load DMatrix from URI. * \brief Load DMatrix from URI.
* \param uri The URI of input. * \param uri The URI of input.

View File

@ -888,6 +888,29 @@ def _get_workers_from_data(
return list(X_worker_map) return list(X_worker_map)
def _filter_empty(
booster: Booster, local_history: TrainingCallback.EvalsLog, is_valid: bool
) -> Optional[TrainReturnT]:
n_workers = collective.get_world_size()
non_empty = numpy.zeros(shape=(n_workers,), dtype=numpy.int32)
rank = collective.get_rank()
non_empty[rank] = int(is_valid)
non_empty = collective.allreduce(non_empty, collective.Op.SUM)
non_empty = non_empty.astype(bool)
ret: Optional[TrainReturnT] = {
"booster": booster,
"history": local_history,
}
for i in range(non_empty.size):
# This is the first valid worker
if non_empty[i] and i == rank:
return ret
if non_empty[i]:
return None
raise ValueError("None of the workers can provide a valid result.")
async def _train_async( async def _train_async(
client: "distributed.Client", client: "distributed.Client",
global_config: Dict[str, Any], global_config: Dict[str, Any],
@ -973,14 +996,10 @@ async def _train_async(
xgb_model=xgb_model, xgb_model=xgb_model,
callbacks=callbacks, callbacks=callbacks,
) )
if Xy.num_row() != 0: # Don't return the boosters from empty workers. It's quite difficult to
ret: Optional[TrainReturnT] = { # guarantee everything is in sync in the present of empty workers,
"booster": booster, # especially with complex objectives like quantile.
"history": local_history, return _filter_empty(booster, local_history, Xy.num_row() != 0)
}
else:
ret = None
return ret
async with distributed.MultiLock(workers, client): async with distributed.MultiLock(workers, client):
if evals is not None: if evals is not None:

View File

@ -46,7 +46,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
if (!use_sorted) { if (!use_sorted) {
HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced, HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info), HostSketchContainer::UseGroup(info),
m->IsColumnSplit(), n_threads); m->Info().IsColumnSplit(), n_threads);
for (auto const& page : m->GetBatches<SparsePage>()) { for (auto const& page : m->GetBatches<SparsePage>()) {
container.PushRowPage(page, info, hessian); container.PushRowPage(page, info, hessian);
} }
@ -54,7 +54,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
} else { } else {
SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced, SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info), HostSketchContainer::UseGroup(info),
m->IsColumnSplit(), n_threads}; m->Info().IsColumnSplit(), n_threads};
for (auto const& page : m->GetBatches<SortedCSCPage>()) { for (auto const& page : m->GetBatches<SortedCSCPage>()) {
container.PushColPage(page, info, hessian); container.PushColPage(page, info, hessian);
} }

View File

@ -704,7 +704,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
} }
void MetaInfo::SynchronizeNumberOfColumns() { void MetaInfo::SynchronizeNumberOfColumns() {
if (collective::IsFederated() && data_split_mode == DataSplitMode::kCol) { if (IsVerticalFederated()) {
collective::Allreduce<collective::Operation::kSum>(&num_col_, 1); collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
} else { } else {
collective::Allreduce<collective::Operation::kMax>(&num_col_, 1); collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
@ -770,6 +770,10 @@ void MetaInfo::Validate(std::int32_t device) const {
void MetaInfo::SetInfoFromCUDA(Context const&, StringView, Json) { common::AssertGPUSupport(); } void MetaInfo::SetInfoFromCUDA(Context const&, StringView, Json) { common::AssertGPUSupport(); }
#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP) #endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
bool MetaInfo::IsVerticalFederated() const {
return collective::IsFederated() && IsColumnSplit();
}
using DMatrixThreadLocal = using DMatrixThreadLocal =
dmlc::ThreadLocalStore<std::map<DMatrix const *, XGBAPIThreadLocalEntry>>; dmlc::ThreadLocalStore<std::map<DMatrix const *, XGBAPIThreadLocalEntry>>;

View File

@ -213,7 +213,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
SyncFeatureType(&h_ft); SyncFeatureType(&h_ft);
p_sketch.reset(new common::HostSketchContainer{ p_sketch.reset(new common::HostSketchContainer{
batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(), batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(),
proxy->IsColumnSplit(), ctx_.Threads()}); proxy->Info().IsColumnSplit(), ctx_.Threads()});
} }
HostAdapterDispatch(proxy, [&](auto const& batch) { HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i]; proxy->Info().num_nonzero_ = batch_nnz[i];

View File

@ -74,7 +74,7 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
} }
void SimpleDMatrix::ReindexFeatures() { void SimpleDMatrix::ReindexFeatures() {
if (collective::IsFederated() && info_.data_split_mode == DataSplitMode::kCol) { if (info_.IsVerticalFederated()) {
std::vector<uint64_t> buffer(collective::GetWorldSize()); std::vector<uint64_t> buffer(collective::GetWorldSize());
buffer[collective::GetRank()] = info_.num_col_; buffer[collective::GetRank()] = info_.num_col_;
collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t)); collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t));

View File

@ -860,9 +860,9 @@ class LearnerConfiguration : public Learner {
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) { void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
// Special handling for vertical federated learning. // Special handling for vertical federated learning.
if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) { if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the estimation is calculated there // We assume labels are only available on worker 0, so the estimation is calculated there
// and added to other workers. // and broadcast to other workers.
if (collective::GetRank() == 0) { if (collective::GetRank() == 0) {
UsePtr(obj_)->InitEstimation(info, base_score); UsePtr(obj_)->InitEstimation(info, base_score);
collective::Broadcast(base_score->Data()->HostPointer(), collective::Broadcast(base_score->Data()->HostPointer(),
@ -1487,7 +1487,7 @@ class LearnerImpl : public LearnerIO {
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration, void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
HostDeviceVector<GradientPair>* out_gpair) { HostDeviceVector<GradientPair>* out_gpair) {
// Special handling for vertical federated learning. // Special handling for vertical federated learning.
if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) { if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the gradients are calculated there // We assume labels are only available on worker 0, so the gradients are calculated there
// and broadcast to other workers. // and broadcast to other workers.
if (collective::GetRank() == 0) { if (collective::GetRank() == 0) {

View File

@ -605,7 +605,7 @@ class CPUPredictor : public Predictor {
protected: protected:
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds, void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const { gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const {
if (p_fmat->IsColumnSplit()) { if (p_fmat->Info().IsColumnSplit()) {
ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end); ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
helper.PredictDMatrix(p_fmat, out_preds); helper.PredictDMatrix(p_fmat, out_preds);
return; return;

View File

@ -45,8 +45,7 @@ void FitStump(Context const* ctx, MetaInfo const& info,
} }
CHECK(h_sum.CContiguous()); CHECK(h_sum.CContiguous());
// In vertical federated learning, only worker 0 needs to call this, no need to do an allreduce. if (info.IsRowSplit()) {
if (!collective::IsFederated() || info.data_split_mode != DataSplitMode::kCol) {
collective::Allreduce<collective::Operation::kSum>( collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2); reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
} }

View File

@ -449,7 +449,7 @@ class HistEvaluator {
param_{param}, param_{param},
column_sampler_{std::move(sampler)}, column_sampler_{std::move(sampler)},
tree_evaluator_{*param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId}, tree_evaluator_{*param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId},
is_col_split_{info.data_split_mode == DataSplitMode::kCol} { is_col_split_{info.IsColumnSplit()} {
interaction_constraints_.Configure(*param, info.num_col_); interaction_constraints_.Configure(*param, info.num_col_);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(), column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_->colsample_bynode, param_->colsample_bylevel, param_->colsample_bynode, param_->colsample_bylevel,

View File

@ -72,12 +72,13 @@ class GloablApproxBuilder {
} else { } else {
CHECK_EQ(n_total_bins, page.cut.TotalBins()); CHECK_EQ(n_total_bins, page.cut.TotalBins());
} }
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit()); partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid,
p_fmat->Info().IsColumnSplit());
n_batches_++; n_batches_++;
} }
histogram_builder_.Reset(n_total_bins, BatchSpec(*param_, hess), ctx_->Threads(), n_batches_, histogram_builder_.Reset(n_total_bins, BatchSpec(*param_, hess), ctx_->Threads(), n_batches_,
collective::IsDistributed(), p_fmat->IsColumnSplit()); collective::IsDistributed(), p_fmat->Info().IsColumnSplit());
monitor_->Stop(__func__); monitor_->Stop(__func__);
} }
@ -91,7 +92,7 @@ class GloablApproxBuilder {
for (auto const &g : gpair) { for (auto const &g : gpair) {
root_sum.Add(g); root_sum.Add(g);
} }
if (p_fmat->IsRowSplit()) { if (p_fmat->Info().IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2); collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2);
} }
std::vector<CPUExpandEntry> nodes{best}; std::vector<CPUExpandEntry> nodes{best};

View File

@ -158,7 +158,7 @@ class MultiTargetHistBuilder {
} else { } else {
CHECK_EQ(n_total_bins, page.cut.TotalBins()); CHECK_EQ(n_total_bins, page.cut.TotalBins());
} }
partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit()); partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->Info().IsColumnSplit());
page_id++; page_id++;
} }
@ -167,7 +167,7 @@ class MultiTargetHistBuilder {
for (std::size_t i = 0; i < n_targets; ++i) { for (std::size_t i = 0; i < n_targets; ++i) {
histogram_builder_.emplace_back(); histogram_builder_.emplace_back();
histogram_builder_.back().Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id, histogram_builder_.back().Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
collective::IsDistributed(), p_fmat->IsColumnSplit()); collective::IsDistributed(), p_fmat->Info().IsColumnSplit());
} }
evaluator_ = std::make_unique<HistMultiEvaluator>(ctx_, p_fmat->Info(), param_, col_sampler_); evaluator_ = std::make_unique<HistMultiEvaluator>(ctx_, p_fmat->Info(), param_, col_sampler_);
@ -388,11 +388,12 @@ class HistBuilder {
} else { } else {
CHECK_EQ(n_total_bins, page.cut.TotalBins()); CHECK_EQ(n_total_bins, page.cut.TotalBins());
} }
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, fmat->IsColumnSplit()); partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid,
fmat->Info().IsColumnSplit());
++page_id; ++page_id;
} }
histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id, histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
collective::IsDistributed(), fmat->IsColumnSplit()); collective::IsDistributed(), fmat->Info().IsColumnSplit());
evaluator_ = std::make_unique<HistEvaluator<CPUExpandEntry>>(ctx_, this->param_, fmat->Info(), evaluator_ = std::make_unique<HistEvaluator<CPUExpandEntry>>(ctx_, this->param_, fmat->Info(),
col_sampler_); col_sampler_);
p_last_tree_ = p_tree; p_last_tree_ = p_tree;

View File

@ -1,19 +1,102 @@
import argparse
import base64
import glob import glob
import hashlib
import os
import pathlib
import re import re
import sys import shutil
import zipfile import tempfile
if len(sys.argv) != 2: VCOMP140_PATH = "C:\\Windows\\System32\\vcomp140.dll"
print('Usage: {} [wheel]'.format(sys.argv[0]))
sys.exit(1)
vcomp140_path = 'C:\\Windows\\System32\\vcomp140.dll'
for wheel_path in sorted(glob.glob(sys.argv[1])): def get_sha256sum(path):
m = re.search(r'xgboost-(.*)-py3', wheel_path) return (
assert m, f'wheel_path = {wheel_path}' base64.urlsafe_b64encode(hashlib.sha256(open(path, "rb").read()).digest())
.decode("latin1")
.rstrip("=")
)
def update_record(*, wheel_content_dir, xgboost_version):
vcomp140_size = os.path.getsize(VCOMP140_PATH)
vcomp140_hash = get_sha256sum(VCOMP140_PATH)
record_path = wheel_content_dir / pathlib.Path(
f"xgboost-{xgboost_version}.dist-info/RECORD"
)
with open(record_path, "r") as f:
record_content = f.read()
record_content += f"xgboost-{xgboost_version}.data/data/xgboost/vcomp140.dll,"
record_content += f"sha256={vcomp140_hash},{vcomp140_size}\n"
with open(record_path, "w") as f:
f.write(record_content)
def main(args):
candidates = list(sorted(glob.glob(args.wheel_path)))
for wheel_path in candidates:
print(f"Processing wheel {wheel_path}")
m = re.search(r"xgboost-(.*)\+.*-py3", wheel_path)
if not m:
raise ValueError(f"Wheel {wheel_path} has unexpected name")
version = m.group(1) version = m.group(1)
print(f" Detected version for {wheel_path}: {version}")
print(f" Inserting vcomp140.dll into {wheel_path}...") print(f" Inserting vcomp140.dll into {wheel_path}...")
with zipfile.ZipFile(wheel_path, 'a') as f: with tempfile.TemporaryDirectory() as tempdir:
f.write(vcomp140_path, 'xgboost-{}.data/data/xgboost/vcomp140.dll'.format(version)) wheel_content_dir = pathlib.Path(tempdir) / "wheel_content"
print(f" Extract {wheel_path} into {wheel_content_dir}")
shutil.unpack_archive(
wheel_path, extract_dir=wheel_content_dir, format="zip"
)
data_dir = wheel_content_dir / pathlib.Path(
f"xgboost-{version}.data/data/xgboost"
)
data_dir.mkdir(parents=True, exist_ok=True)
print(f" Copy {VCOMP140_PATH} -> {data_dir}")
shutil.copy(VCOMP140_PATH, data_dir)
print(f" Update RECORD")
update_record(wheel_content_dir=wheel_content_dir, xgboost_version=version)
print(f" Content of {wheel_content_dir}:")
for e in sorted(wheel_content_dir.rglob("*")):
if e.is_file():
r = e.relative_to(wheel_content_dir)
print(f" {r}")
print(f" Create new wheel...")
new_wheel_tmp_path = pathlib.Path(tempdir) / "new_wheel"
shutil.make_archive(
str(new_wheel_tmp_path.resolve()),
format="zip",
root_dir=wheel_content_dir,
)
new_wheel_tmp_path = new_wheel_tmp_path.resolve().with_suffix(".zip")
new_wheel_tmp_path = new_wheel_tmp_path.rename(
new_wheel_tmp_path.with_suffix(".whl")
)
print(f" Created new wheel {new_wheel_tmp_path}")
# Rename the old wheel with suffix .bak
# The new wheel takes the name of the old wheel
wheel_path_obj = pathlib.Path(wheel_path).resolve()
backup_path = wheel_path_obj.with_suffix(".whl.bak")
print(f" Rename {wheel_path_obj} -> {backup_path}")
wheel_path_obj.replace(backup_path)
print(f" Rename {new_wheel_tmp_path} -> {wheel_path_obj}")
new_wheel_tmp_path.replace(wheel_path_obj)
shutil.rmtree(wheel_content_dir)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"wheel_path", type=str, help="Path to wheel (wildcard permitted)"
)
args = parser.parse_args()
main(args)

View File

@ -191,15 +191,9 @@ double GetMultiMetricEval(xgboost::Metric* metric,
} }
namespace xgboost { namespace xgboost {
bool IsNear(std::vector<xgboost::bst_float>::const_iterator _beg1,
std::vector<xgboost::bst_float>::const_iterator _end1, float GetBaseScore(Json const &config) {
std::vector<xgboost::bst_float>::const_iterator _beg2) { return std::stof(get<String const>(config["learner"]["learner_model_param"]["base_score"]));
for (auto iter1 = _beg1, iter2 = _beg2; iter1 != _end1; ++iter1, ++iter2) {
if (std::abs(*iter1 - *iter2) > xgboost::kRtEps){
return false;
}
}
return true;
} }
SimpleLCG::StateType SimpleLCG::operator()() { SimpleLCG::StateType SimpleLCG::operator()() {

View File

@ -101,9 +101,8 @@ double GetMultiMetricEval(xgboost::Metric* metric,
std::vector<xgboost::bst_uint> groups = {}); std::vector<xgboost::bst_uint> groups = {});
namespace xgboost { namespace xgboost {
bool IsNear(std::vector<xgboost::bst_float>::const_iterator _beg1,
std::vector<xgboost::bst_float>::const_iterator _end1, float GetBaseScore(Json const &config);
std::vector<xgboost::bst_float>::const_iterator _beg2);
/*! /*!
* \brief Linear congruential generator. * \brief Linear congruential generator.

View File

@ -52,18 +52,33 @@ class BaseFederatedTest : public ::testing::Test {
server_thread_->join(); server_thread_->join();
} }
void InitCommunicator(int rank) {
Json config{JsonObject()};
config["xgboost_communicator"] = String("federated");
config["federated_server_address"] = String(server_address_);
config["federated_world_size"] = kWorldSize;
config["federated_rank"] = rank;
xgboost::collective::Init(config);
}
static int const kWorldSize{3}; static int const kWorldSize{3};
std::string server_address_; std::string server_address_;
std::unique_ptr<std::thread> server_thread_; std::unique_ptr<std::thread> server_thread_;
std::unique_ptr<grpc::Server> server_; std::unique_ptr<grpc::Server> server_;
}; };
template <typename Function, typename... Args>
void RunWithFederatedCommunicator(int32_t world_size, std::string const& server_address,
Function&& function, Args&&... args) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < world_size; rank++) {
threads.emplace_back([&, rank]() {
Json config{JsonObject()};
config["xgboost_communicator"] = String("federated");
config["federated_server_address"] = String(server_address);
config["federated_world_size"] = world_size;
config["federated_rank"] = rank;
xgboost::collective::Init(config);
std::forward<Function>(function)(std::forward<Args>(args)...);
xgboost::collective::Finalize();
});
}
for (auto& thread : threads) {
thread.join();
}
}
} // namespace xgboost } // namespace xgboost

View File

@ -1,12 +1,9 @@
/*! /*!
* Copyright 2023 XGBoost contributors * Copyright 2023 XGBoost contributors
*/ */
#include <dmlc/parameter.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/data.h> #include <xgboost/data.h>
#include <fstream>
#include <iostream>
#include <thread> #include <thread>
#include "../../../plugin/federated/federated_server.h" #include "../../../plugin/federated/federated_server.h"
@ -17,10 +14,10 @@
namespace xgboost { namespace xgboost {
class FederatedDataTest : public BaseFederatedTest { class FederatedDataTest : public BaseFederatedTest {};
public:
void VerifyLoadUri(int rank) { void VerifyLoadUri() {
InitCommunicator(rank); auto const rank = collective::GetRank();
size_t constexpr kRows{16}; size_t constexpr kRows{16};
size_t const kCols = 8 + rank; size_t const kCols = 8 + rank;
@ -33,7 +30,7 @@ class FederatedDataTest : public BaseFederatedTest {
std::string uri = path + "?format=csv"; std::string uri = path + "?format=csv";
dmat.reset(DMatrix::Load(uri, false, DataSplitMode::kCol)); dmat.reset(DMatrix::Load(uri, false, DataSplitMode::kCol));
ASSERT_EQ(dmat->Info().num_col_, 8 * kWorldSize + 3); ASSERT_EQ(dmat->Info().num_col_, 8 * collective::GetWorldSize() + 3);
ASSERT_EQ(dmat->Info().num_row_, kRows); ASSERT_EQ(dmat->Info().num_row_, kRows);
for (auto const& page : dmat->GetBatches<SparsePage>()) { for (auto const& page : dmat->GetBatches<SparsePage>()) {
@ -48,18 +45,9 @@ class FederatedDataTest : public BaseFederatedTest {
} }
} }
} }
xgboost::collective::Finalize();
} }
};
TEST_F(FederatedDataTest, LoadUri) { TEST_F(FederatedDataTest, LoadUri) {
std::vector<std::thread> threads; RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyLoadUri);
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(&FederatedDataTest_LoadUri_Test::VerifyLoadUri, this, rank);
}
for (auto& thread : threads) {
thread.join();
}
} }
} // namespace xgboost } // namespace xgboost

View File

@ -0,0 +1,78 @@
/*!
* Copyright 2023 XGBoost contributors
*/
#include <dmlc/parameter.h>
#include <gtest/gtest.h>
#include <xgboost/data.h>
#include <xgboost/objective.h>
#include "../../../plugin/federated/federated_server.h"
#include "../../../src/collective/communicator-inl.h"
#include "../helpers.h"
#include "helpers.h"
namespace xgboost {
class FederatedLearnerTest : public BaseFederatedTest {
protected:
static auto constexpr kRows{16};
static auto constexpr kCols{16};
};
void VerifyBaseScore(size_t rows, size_t cols, float expected_base_score) {
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
std::shared_ptr<DMatrix> Xy_{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)};
std::shared_ptr<DMatrix> sliced{Xy_->SliceCol(world_size, rank)};
std::unique_ptr<Learner> learner{Learner::Create({sliced})};
learner->SetParam("tree_method", "approx");
learner->SetParam("objective", "binary:logistic");
learner->UpdateOneIter(0, sliced);
Json config{Object{}};
learner->SaveConfig(&config);
auto base_score = GetBaseScore(config);
ASSERT_EQ(base_score, expected_base_score);
}
void VerifyModel(size_t rows, size_t cols, Json const& expected_model) {
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
std::shared_ptr<DMatrix> Xy_{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)};
std::shared_ptr<DMatrix> sliced{Xy_->SliceCol(world_size, rank)};
std::unique_ptr<Learner> learner{Learner::Create({sliced})};
learner->SetParam("tree_method", "approx");
learner->SetParam("objective", "binary:logistic");
learner->UpdateOneIter(0, sliced);
Json model{Object{}};
learner->SaveModel(&model);
ASSERT_EQ(model, expected_model);
}
TEST_F(FederatedLearnerTest, BaseScore) {
std::shared_ptr<DMatrix> Xy_{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)};
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
learner->SetParam("tree_method", "approx");
learner->SetParam("objective", "binary:logistic");
learner->UpdateOneIter(0, Xy_);
Json config{Object{}};
learner->SaveConfig(&config);
auto base_score = GetBaseScore(config);
ASSERT_NE(base_score, ObjFunction::DefaultBaseScore());
RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyBaseScore, kRows, kCols,
base_score);
}
TEST_F(FederatedLearnerTest, Model) {
std::shared_ptr<DMatrix> Xy_{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)};
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
learner->SetParam("tree_method", "approx");
learner->SetParam("objective", "binary:logistic");
learner->UpdateOneIter(0, Xy_);
Json model{Object{}};
learner->SaveModel(&model);
RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyModel, kRows, kCols,
std::cref(model));
}
} // namespace xgboost

View File

@ -460,10 +460,6 @@ class InitBaseScore : public ::testing::Test {
void SetUp() override { Xy_ = RandomDataGenerator{10, Cols(), 0}.GenerateDMatrix(true); } void SetUp() override { Xy_ = RandomDataGenerator{10, Cols(), 0}.GenerateDMatrix(true); }
static float GetBaseScore(Json const &config) {
return std::stof(get<String const>(config["learner"]["learner_model_param"]["base_score"]));
}
public: public:
void TestUpdateConfig() { void TestUpdateConfig() {
std::unique_ptr<Learner> learner{Learner::Create({Xy_})}; std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
@ -611,4 +607,32 @@ TEST_F(InitBaseScore, InitAfterLoad) { this->TestInitAfterLoad(); }
TEST_F(InitBaseScore, InitWithPredict) { this->TestInitWithPredt(); } TEST_F(InitBaseScore, InitWithPredict) { this->TestInitWithPredt(); }
TEST_F(InitBaseScore, UpdateProcess) { this->TestUpdateProcess(); } TEST_F(InitBaseScore, UpdateProcess) { this->TestUpdateProcess(); }
void TestColumnSplitBaseScore(std::shared_ptr<DMatrix> Xy_, float expected_base_score) {
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
std::shared_ptr<DMatrix> sliced{Xy_->SliceCol(world_size, rank)};
std::unique_ptr<Learner> learner{Learner::Create({sliced})};
learner->SetParam("tree_method", "approx");
learner->SetParam("objective", "binary:logistic");
learner->UpdateOneIter(0, sliced);
Json config{Object{}};
learner->SaveConfig(&config);
auto base_score = GetBaseScore(config);
ASSERT_EQ(base_score, expected_base_score);
}
TEST_F(InitBaseScore, ColumnSplit) {
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
learner->SetParam("tree_method", "approx");
learner->SetParam("objective", "binary:logistic");
learner->UpdateOneIter(0, Xy_);
Json config{Object{}};
learner->SaveConfig(&config);
auto base_score = GetBaseScore(config);
ASSERT_NE(base_score, ObjFunction::DefaultBaseScore());
auto constexpr kWorldSize{3};
RunWithInMemoryCommunicator(kWorldSize, &TestColumnSplitBaseScore, Xy_, base_score);
}
} // namespace xgboost } // namespace xgboost

View File

@ -6,11 +6,12 @@
#include "../../src/common/linalg_op.h" #include "../../src/common/linalg_op.h"
#include "../../src/tree/fit_stump.h" #include "../../src/tree/fit_stump.h"
#include "../helpers.h"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
namespace { namespace {
void TestFitStump(Context const *ctx) { void TestFitStump(Context const *ctx, DataSplitMode split = DataSplitMode::kRow) {
std::size_t constexpr kRows = 16, kTargets = 2; std::size_t constexpr kRows = 16, kTargets = 2;
HostDeviceVector<GradientPair> gpair; HostDeviceVector<GradientPair> gpair;
auto &h_gpair = gpair.HostVector(); auto &h_gpair = gpair.HostVector();
@ -22,6 +23,7 @@ void TestFitStump(Context const *ctx) {
} }
linalg::Vector<float> out; linalg::Vector<float> out;
MetaInfo info; MetaInfo info;
info.data_split_mode = split;
FitStump(ctx, info, gpair, kTargets, &out); FitStump(ctx, info, gpair, kTargets, &out);
auto h_out = out.HostView(); auto h_out = out.HostView();
for (auto it = linalg::cbegin(h_out); it != linalg::cend(h_out); ++it) { for (auto it = linalg::cbegin(h_out); it != linalg::cend(h_out); ++it) {
@ -45,5 +47,12 @@ TEST(InitEstimation, GPUFitStump) {
TestFitStump(&ctx); TestFitStump(&ctx);
} }
#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP) #endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
TEST(InitEstimation, FitStumpColumnSplit) {
Context ctx;
auto constexpr kWorldSize{3};
RunWithInMemoryCommunicator(kWorldSize, &TestFitStump, &ctx, DataSplitMode::kCol);
}
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost