Support vertical federated learning (#8932)
parent 8dc1e4b3ea
commit b240f055d3
@@ -171,6 +171,15 @@ class MetaInfo {
   */
  void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);
+
+  /**
+   * @brief Synchronize the number of columns across all workers.
+   *
+   * Normally we just need to find the maximum number of columns across all workers, but
+   * in vertical federated learning, since each worker loads its own list of columns,
+   * we need to sum them.
+   */
+  void SynchronizeNumberOfColumns();
 
  private:
  void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
  void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);
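The distinction in the new doc comment is the core of the column-count change: in a row-wise (horizontal) split every worker is expected to see the same columns, so taking the maximum is enough, while in a vertical federated split each worker owns a disjoint slice of the features and the counts must be summed. A minimal standalone sketch of the two aggregation rules (plain C++, not the XGBoost API, assuming three workers holding 8, 9, and 10 columns):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical per-worker column counts for a three-worker setup.
  std::vector<std::uint64_t> num_cols{8, 9, 10};

  // Row-wise split: all workers share the same schema, so take the max.
  auto row_wise = *std::max_element(num_cols.begin(), num_cols.end());

  // Vertical federated (column-wise) split: each worker owns distinct columns, so sum.
  auto col_wise = std::accumulate(num_cols.begin(), num_cols.end(), std::uint64_t{0});

  std::cout << "row-wise num_col_: " << row_wise << "\n";     // 10
  std::cout << "column-wise num_col_: " << col_wise << "\n";  // 27
  return 0;
}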
@@ -325,6 +334,10 @@ class SparsePage {
   * \brief Check wether the column index is sorted.
   */
  bool IsIndicesSorted(int32_t n_threads) const;
+  /**
+   * \brief Reindex the column index with an offset.
+   */
+  void Reindex(uint64_t feature_offset, int32_t n_threads);
 
  void SortRows(int32_t n_threads);
 
@@ -559,17 +572,18 @@ class DMatrix {
   * \brief Creates a new DMatrix from an external data adapter.
   *
   * \tparam AdapterT Type of the adapter.
   * \param [in,out] adapter View onto an external data.
   * \param missing Values to count as missing.
   * \param nthread Number of threads for construction.
   * \param cache_prefix (Optional) The cache prefix for external memory.
-  * \param page_size (Optional) Size of the page.
+  * \param data_split_mode (Optional) Data split mode.
   *
   * \return a Created DMatrix.
   */
  template <typename AdapterT>
  static DMatrix* Create(AdapterT* adapter, float missing, int nthread,
-                        const std::string& cache_prefix = "");
+                        const std::string& cache_prefix = "",
+                        DataSplitMode data_split_mode = DataSplitMode::kRow);
 
  /**
   * \brief Create a new Quantile based DMatrix used for histogram based algorithm.
@@ -703,6 +703,14 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
  }
 }
 
+void MetaInfo::SynchronizeNumberOfColumns() {
+  if (collective::IsFederated() && data_split_mode == DataSplitMode::kCol) {
+    collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
+  } else {
+    collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
+  }
+}
+
 void MetaInfo::Validate(std::int32_t device) const {
  if (group_ptr_.size() != 0 && weights_.Size() != 0) {
    CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
@@ -870,7 +878,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
        dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str()));
    data::FileAdapter adapter(parser.get());
    dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(),
-                          cache_file);
+                          cache_file, data_split_mode);
  } else {
    data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart),
                            file_format};
@@ -906,11 +914,6 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
    LOG(FATAL) << "Encountered parser error:\n" << e.what();
  }
 
-  /* sync up number of features after matrix loaded.
-   * partitioned data will fail the train/val validation check
-   * since partitioned data not knowing the real number of features. */
-  collective::Allreduce<collective::Operation::kMax>(&dmat->Info().num_col_, 1);
-
  if (need_split && data_split_mode == DataSplitMode::kCol) {
    if (!cache_file.empty()) {
      LOG(FATAL) << "Column-wise data split is not support for external memory.";
@@ -920,7 +923,6 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
    delete dmat;
    return sliced;
  } else {
-    dmat->Info().data_split_mode = data_split_mode;
    return dmat;
  }
 }
@@ -957,39 +959,49 @@ template DMatrix *DMatrix::Create<DataIterHandle, DMatrixHandle,
                                   XGDMatrixCallbackNext *next, float missing, int32_t n_threads, std::string);
 
 template <typename AdapterT>
-DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&) {
-  return new data::SimpleDMatrix(adapter, missing, nthread);
+DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&,
+                         DataSplitMode data_split_mode) {
+  return new data::SimpleDMatrix(adapter, missing, nthread, data_split_mode);
 }
 
 template DMatrix* DMatrix::Create<data::DenseAdapter>(data::DenseAdapter* adapter, float missing,
                                                       std::int32_t nthread,
-                                                      const std::string& cache_prefix);
+                                                      const std::string& cache_prefix,
+                                                      DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create<data::ArrayAdapter>(data::ArrayAdapter* adapter, float missing,
                                                       std::int32_t nthread,
-                                                      const std::string& cache_prefix);
+                                                      const std::string& cache_prefix,
+                                                      DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create<data::CSRAdapter>(data::CSRAdapter* adapter, float missing,
                                                     std::int32_t nthread,
-                                                    const std::string& cache_prefix);
+                                                    const std::string& cache_prefix,
+                                                    DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create<data::CSCAdapter>(data::CSCAdapter* adapter, float missing,
                                                     std::int32_t nthread,
-                                                    const std::string& cache_prefix);
+                                                    const std::string& cache_prefix,
+                                                    DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create<data::DataTableAdapter>(data::DataTableAdapter* adapter,
                                                           float missing, std::int32_t nthread,
-                                                          const std::string& cache_prefix);
+                                                          const std::string& cache_prefix,
+                                                          DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create<data::FileAdapter>(data::FileAdapter* adapter, float missing,
                                                      std::int32_t nthread,
-                                                     const std::string& cache_prefix);
+                                                     const std::string& cache_prefix,
+                                                     DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create<data::CSRArrayAdapter>(data::CSRArrayAdapter* adapter,
                                                          float missing, std::int32_t nthread,
-                                                         const std::string& cache_prefix);
+                                                         const std::string& cache_prefix,
+                                                         DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create<data::CSCArrayAdapter>(data::CSCArrayAdapter* adapter,
                                                          float missing, std::int32_t nthread,
-                                                         const std::string& cache_prefix);
+                                                         const std::string& cache_prefix,
+                                                         DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create(
    data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>* adapter,
-    float missing, int nthread, const std::string& cache_prefix);
+    float missing, int nthread, const std::string& cache_prefix, DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create<data::RecordBatchesIterAdapter>(
-    data::RecordBatchesIterAdapter* adapter, float missing, int nthread, const std::string&);
+    data::RecordBatchesIterAdapter* adapter, float missing, int nthread, const std::string&,
+    DataSplitMode data_split_mode);
 
 SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
  SparsePage transpose;
@@ -1051,6 +1063,13 @@ void SparsePage::SortIndices(int32_t n_threads) {
  });
 }
 
+void SparsePage::Reindex(uint64_t feature_offset, int32_t n_threads) {
+  auto& h_data = this->data.HostVector();
+  common::ParallelFor(h_data.size(), n_threads, [&](auto i) {
+    h_data[i].index += feature_offset;
+  });
+}
+
 void SparsePage::SortRows(int32_t n_threads) {
  auto& h_offset = this->offset.HostVector();
  auto& h_data = this->data.HostVector();
@@ -170,17 +170,17 @@ void MetaInfo::SetInfoFromCUDA(Context const& ctx, StringView key, Json array) {
 
 template <typename AdapterT>
 DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
-                         const std::string& cache_prefix) {
+                         const std::string& cache_prefix, DataSplitMode data_split_mode) {
  CHECK_EQ(cache_prefix.size(), 0)
      << "Device memory construction is not currently supported with external "
         "memory.";
-  return new data::SimpleDMatrix(adapter, missing, nthread);
+  return new data::SimpleDMatrix(adapter, missing, nthread, data_split_mode);
 }
 
 template DMatrix* DMatrix::Create<data::CudfAdapter>(
    data::CudfAdapter* adapter, float missing, int nthread,
-    const std::string& cache_prefix);
+    const std::string& cache_prefix, DataSplitMode data_split_mode);
 template DMatrix* DMatrix::Create<data::CupyAdapter>(
    data::CupyAdapter* adapter, float missing, int nthread,
-    const std::string& cache_prefix);
+    const std::string& cache_prefix, DataSplitMode data_split_mode);
 }  // namespace xgboost
@@ -190,7 +190,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
  // From here on Info() has the correct data shape
  Info().num_row_ = accumulated_rows;
  Info().num_nonzero_ = nnz;
-  collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
+  Info().SynchronizeNumberOfColumns();
  CHECK(std::none_of(column_sizes.cbegin(), column_sizes.cend(), [&](auto f) {
    return f > accumulated_rows;
  })) << "Something went wrong during iteration.";
@@ -166,7 +166,7 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
 
  iter.Reset();
  // Synchronise worker columns
-  collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
+  info_.SynchronizeNumberOfColumns();
 }
 
 BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(BatchParam const& param) {
@@ -73,6 +73,19 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
  return out;
 }
 
+void SimpleDMatrix::ReindexFeatures() {
+  if (collective::IsFederated() && info_.data_split_mode == DataSplitMode::kCol) {
+    std::vector<uint64_t> buffer(collective::GetWorldSize());
+    buffer[collective::GetRank()] = info_.num_col_;
+    collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t));
+    auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0);
+    if (offset == 0) {
+      return;
+    }
+    sparse_page_->Reindex(offset, ctx_.Threads());
+  }
+}
+
 BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
  // since csr is the default data structure so `source_` is always available.
  auto begin_iter = BatchIterator<SparsePage>(
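A worked view of the offset computation in the new ReindexFeatures() above (plain C++, with the allgather replaced by a local vector purely for illustration): each worker's offset is the sum of the column counts of all lower-ranked workers, so rank 0 keeps its indices and higher ranks shift theirs.

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Stand-in for the allgathered per-worker column counts (world size 3).
  std::vector<std::uint64_t> buffer{8, 9, 10};

  for (int rank = 0; rank < static_cast<int>(buffer.size()); ++rank) {
    // Offset = columns owned by all workers with a lower rank.
    auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + rank, std::uint64_t{0});
    std::cout << "rank " << rank << " adds offset " << offset
              << " to every local feature index\n";  // prints 0, 8, 17
  }
  return 0;
}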
@@ -151,7 +164,8 @@ BatchSet<ExtSparsePage> SimpleDMatrix::GetExtBatches(BatchParam const&) {
 }
 
 template <typename AdapterT>
-SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
+SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
+                             DataSplitMode data_split_mode) {
  this->ctx_.nthread = nthread;
 
  std::vector<uint64_t> qids;
@@ -217,7 +231,9 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
 
 
  // Synchronise worker columns
-  collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
+  info_.data_split_mode = data_split_mode;
+  ReindexFeatures();
+  info_.SynchronizeNumberOfColumns();
 
  if (adapter->NumRows() == kAdapterUnknownSize) {
    using IteratorAdapterT
@@ -272,22 +288,31 @@ void SimpleDMatrix::SaveToLocalFile(const std::string& fname) {
  fo->Write(sparse_page_->data.HostVector());
 }
 
-template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing, int nthread);
-template SimpleDMatrix::SimpleDMatrix(ArrayAdapter* adapter, float missing, int nthread);
-template SimpleDMatrix::SimpleDMatrix(CSRAdapter* adapter, float missing, int nthread);
-template SimpleDMatrix::SimpleDMatrix(CSRArrayAdapter* adapter, float missing, int nthread);
-template SimpleDMatrix::SimpleDMatrix(CSCArrayAdapter* adapter, float missing, int nthread);
-template SimpleDMatrix::SimpleDMatrix(CSCAdapter* adapter, float missing, int nthread);
-template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing, int nthread);
-template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing, int nthread);
+template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing, int nthread,
+                                      DataSplitMode data_split_mode);
+template SimpleDMatrix::SimpleDMatrix(ArrayAdapter* adapter, float missing, int nthread,
+                                      DataSplitMode data_split_mode);
+template SimpleDMatrix::SimpleDMatrix(CSRAdapter* adapter, float missing, int nthread,
+                                      DataSplitMode data_split_mode);
+template SimpleDMatrix::SimpleDMatrix(CSRArrayAdapter* adapter, float missing, int nthread,
+                                      DataSplitMode data_split_mode);
+template SimpleDMatrix::SimpleDMatrix(CSCArrayAdapter* adapter, float missing, int nthread,
+                                      DataSplitMode data_split_mode);
+template SimpleDMatrix::SimpleDMatrix(CSCAdapter* adapter, float missing, int nthread,
+                                      DataSplitMode data_split_mode);
+template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing, int nthread,
+                                      DataSplitMode data_split_mode);
+template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing, int nthread,
+                                      DataSplitMode data_split_mode);
 template SimpleDMatrix::SimpleDMatrix(
    IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>
        *adapter,
-    float missing, int nthread);
+    float missing, int nthread, DataSplitMode data_split_mode);
 
 template <>
-SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, int nthread) {
+SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, int nthread,
+                             DataSplitMode data_split_mode) {
  ctx_.nthread = nthread;
 
  auto& offset_vec = sparse_page_->offset.HostVector();
  auto& data_vec = sparse_page_->data.HostVector();
@@ -346,7 +371,10 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, i
  }
  // Synchronise worker columns
  info_.num_col_ = adapter->NumColumns();
-  collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
+  info_.data_split_mode = data_split_mode;
+  ReindexFeatures();
+  info_.SynchronizeNumberOfColumns();
+
  info_.num_row_ = total_batch_size;
  info_.num_nonzero_ = data_vec.size();
  CHECK_EQ(offset_vec.back(), info_.num_nonzero_);
@@ -15,7 +15,10 @@ namespace data {
 // Current implementation assumes a single batch. More batches can
 // be supported in future. Does not currently support inferring row/column size
 template <typename AdapterT>
-SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread*/) {
+SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread*/,
+                             DataSplitMode data_split_mode) {
+  CHECK(data_split_mode != DataSplitMode::kCol)
+      << "Column-wise data split is currently not supported on the GPU.";
  auto device = (adapter->DeviceIdx() < 0 || adapter->NumRows() == 0) ? dh::CurrentDevice()
                                                                       : adapter->DeviceIdx();
  CHECK_GE(device, 0);
@@ -35,12 +38,13 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread
  info_.num_col_ = adapter->NumColumns();
  info_.num_row_ = adapter->NumRows();
  // Synchronise worker columns
-  collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
+  info_.data_split_mode = data_split_mode;
+  info_.SynchronizeNumberOfColumns();
 }
 
 template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing,
-                                      int nthread);
+                                      int nthread, DataSplitMode data_split_mode);
 template SimpleDMatrix::SimpleDMatrix(CupyAdapter* adapter, float missing,
-                                      int nthread);
+                                      int nthread, DataSplitMode data_split_mode);
 }  // namespace data
 }  // namespace xgboost
@@ -22,7 +22,8 @@ class SimpleDMatrix : public DMatrix {
  public:
  SimpleDMatrix() = default;
  template <typename AdapterT>
-  explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread);
+  explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
+                         DataSplitMode data_split_mode = DataSplitMode::kRow);
 
  explicit SimpleDMatrix(dmlc::Stream* in_stream);
  ~SimpleDMatrix() override = default;
@@ -61,6 +62,15 @@ class SimpleDMatrix : public DMatrix {
  bool GHistIndexExists() const override { return static_cast<bool>(gradient_index_); }
  bool SparsePageExists() const override { return true; }
+
+  /**
+   * \brief Reindex the features based on a global view.
+   *
+   * In some cases (e.g. vertical federated learning), features are loaded locally with indices
+   * starting from 0. However, all the algorithms assume the features are globally indexed, so we
+   * reindex the features based on the offset needed to obtain the global view.
+   */
+  void ReindexFeatures();
 
  private:
  Context ctx_;
 };
@@ -96,7 +96,7 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
  this->info_.num_col_ = n_features;
  this->info_.num_nonzero_ = nnz;
 
-  collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
+  info_.SynchronizeNumberOfColumns();
  CHECK_NE(info_.num_col_, 0);
 }
 
@@ -440,7 +440,7 @@ class LearnerConfiguration : public Learner {
    info.Validate(Ctx()->gpu_id);
    // We estimate it from input data.
    linalg::Tensor<float, 1> base_score;
-    UsePtr(obj_)->InitEstimation(info, &base_score);
+    InitEstimation(info, &base_score);
    CHECK_EQ(base_score.Size(), 1);
    mparam_.base_score = base_score(0);
    CHECK(!std::isnan(mparam_.base_score));
@@ -857,6 +857,25 @@ class LearnerConfiguration : public Learner {
      mparam_.num_target = n_targets;
    }
  }
+
+  void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
+    // Special handling for vertical federated learning.
+    if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
+      // We assume labels are only available on worker 0, so the estimation is calculated there
+      // and added to other workers.
+      if (collective::GetRank() == 0) {
+        UsePtr(obj_)->InitEstimation(info, base_score);
+        collective::Broadcast(base_score->Data()->HostPointer(),
+                              sizeof(bst_float) * base_score->Size(), 0);
+      } else {
+        base_score->Reshape(1);
+        collective::Broadcast(base_score->Data()->HostPointer(),
+                              sizeof(bst_float) * base_score->Size(), 0);
+      }
+    } else {
+      UsePtr(obj_)->InitEstimation(info, base_score);
+    }
+  }
 };
 
 std::string const LearnerConfiguration::kEvalMetric {"eval_metric"};  // NOLINT
@@ -1307,7 +1326,7 @@ class LearnerImpl : public LearnerIO {
    monitor_.Stop("PredictRaw");
 
    monitor_.Start("GetGradient");
-    obj_->GetGradient(predt.predictions, train->Info(), iter, &gpair_);
+    GetGradient(predt.predictions, train->Info(), iter, &gpair_);
    monitor_.Stop("GetGradient");
    TrainingObserver::Instance().Observe(gpair_, "Gradients");
 
@@ -1486,6 +1505,28 @@ class LearnerImpl : public LearnerIO {
  }
 
  private:
+  void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
+                   HostDeviceVector<GradientPair>* out_gpair) {
+    // Special handling for vertical federated learning.
+    if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
+      // We assume labels are only available on worker 0, so the gradients are calculated there
+      // and broadcast to other workers.
+      if (collective::GetRank() == 0) {
+        obj_->GetGradient(preds, info, iteration, out_gpair);
+        collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
+                              0);
+      } else {
+        CHECK_EQ(info.labels.Size(), 0)
+            << "In vertical federated learning, labels should only be on the first worker";
+        out_gpair->Resize(preds.Size());
+        collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
+                              0);
+      }
+    } else {
+      obj_->GetGradient(preds, info, iteration, out_gpair);
+    }
+  }
+
  /*! \brief random number transformation seed. */
  static int32_t constexpr kRandSeedMagic = 127;
  // gradient pairs
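The GetGradient() wrapper above leaves the objective itself untouched and only changes who computes: under a vertical split the labels live on worker 0, so worker 0 runs the objective and every other worker resizes its buffer and receives the same gradient pairs. A simplified stand-in for that flow (plain C++; the real code uses collective::Broadcast, which is replaced here by a local copy purely for illustration):

#include <iostream>
#include <utility>
#include <vector>

// Hypothetical gradient pair, standing in for xgboost::GradientPair.
using GradPair = std::pair<float, float>;  // (grad, hess)

int main() {
  constexpr int kWorldSize = 3;

  // Worker 0 owns the labels and computes the gradients for all rows.
  std::vector<GradPair> rank0_gpair{{0.5f, 1.0f}, {-0.25f, 1.0f}, {0.75f, 1.0f}};

  // "Broadcast": every worker ends up holding rank 0's result.
  std::vector<std::vector<GradPair>> all_workers(kWorldSize);
  for (int rank = 0; rank < kWorldSize; ++rank) {
    all_workers[rank] = rank0_gpair;
  }

  for (int rank = 0; rank < kWorldSize; ++rank) {
    std::cout << "rank " << rank << " first grad: " << all_workers[rank][0].first << "\n";
  }
  return 0;
}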
@@ -33,7 +33,7 @@ void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector<float>* b
  new_obj->GetGradient(dummy_predt, info, 0, &gpair);
  bst_target_t n_targets = this->Targets(info);
  linalg::Vector<float> leaf_weight;
-  tree::FitStump(this->ctx_, gpair, n_targets, &leaf_weight);
+  tree::FitStump(this->ctx_, info, gpair, n_targets, &leaf_weight);
 
  // workaround, we don't support multi-target due to binary model serialization for
  // base margin.
@@ -21,7 +21,8 @@
 namespace xgboost {
 namespace tree {
 namespace cpu_impl {
-void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
+void FitStump(Context const* ctx, MetaInfo const& info,
+              linalg::TensorView<GradientPair const, 2> gpair,
              linalg::VectorView<float> out) {
  auto n_targets = out.Size();
  CHECK_EQ(n_targets, gpair.Shape(1));
@@ -43,8 +44,12 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
    }
  }
  CHECK(h_sum.CContiguous());
-  collective::Allreduce<collective::Operation::kSum>(
-      reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
+
+  // In vertical federated learning, only worker 0 needs to call this, no need to do an allreduce.
+  if (!collective::IsFederated() || info.data_split_mode != DataSplitMode::kCol) {
+    collective::Allreduce<collective::Operation::kSum>(
+        reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
+  }
 
  for (std::size_t i = 0; i < h_sum.Size(); ++i) {
    out(i) = static_cast<float>(CalcUnregularizedWeight(h_sum(i).GetGrad(), h_sum(i).GetHess()));
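The new guard reads as a truth table over (federated, split mode): the cross-worker sum still runs for any non-federated job and for horizontal federated jobs, and is skipped only in the vertical federated case, where, per the comment above, only worker 0 calls FitStump in the first place. A minimal sketch of the three cases (plain C++ with hypothetical names, not the XGBoost API):

#include <iostream>

enum class DataSplitMode { kRow, kCol };

// Mirrors the new condition: the allreduce only happens when the data is not
// split column-wise under federated learning.
bool NeedsAllreduce(bool is_federated, DataSplitMode split) {
  return !is_federated || split != DataSplitMode::kCol;
}

int main() {
  std::cout << std::boolalpha;
  std::cout << NeedsAllreduce(false, DataSplitMode::kRow) << "\n";  // true: regular distributed run
  std::cout << NeedsAllreduce(true, DataSplitMode::kRow) << "\n";   // true: horizontal federated
  std::cout << NeedsAllreduce(true, DataSplitMode::kCol) << "\n";   // false: vertical federated
  return 0;
}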
@@ -64,7 +69,7 @@ inline void FitStump(Context const*, linalg::TensorView<GradientPair const, 2>,
 #endif  // !defined(XGBOOST_USE_CUDA)
 }  // namespace cuda_impl
 
-void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
+void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector<GradientPair> const& gpair,
              bst_target_t n_targets, linalg::Vector<float>* out) {
  out->SetDevice(ctx->gpu_id);
  out->Reshape(n_targets);
@@ -72,7 +77,7 @@ void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
 
  gpair.SetDevice(ctx->gpu_id);
  auto gpair_t = linalg::MakeTensorView(ctx, &gpair, n_samples, n_targets);
-  ctx->IsCPU() ? cpu_impl::FitStump(ctx, gpair_t, out->HostView())
+  ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
               : cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id));
 }
 }  // namespace tree
@@ -16,6 +16,7 @@
 #include "../common/common.h"            // AssertGPUSupport
 #include "xgboost/base.h"                // GradientPair
 #include "xgboost/context.h"             // Context
+#include "xgboost/data.h"                // MetaInfo
 #include "xgboost/host_device_vector.h"  // HostDeviceVector
 #include "xgboost/linalg.h"              // TensorView
 
@@ -30,7 +31,7 @@ XGBOOST_DEVICE inline double CalcUnregularizedWeight(T sum_grad, T sum_hess) {
 /**
  * @brief Fit a tree stump as an estimation of base_score.
  */
-void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
+void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector<GradientPair> const& gpair,
              bst_target_t n_targets, linalg::Vector<float>* out);
 }  // namespace tree
 }  // namespace xgboost
@@ -112,31 +112,12 @@ TEST(SparsePage, SortIndices) {
 }
 
 TEST(DMatrix, Uri) {
-  size_t constexpr kRows {16};
-  size_t constexpr kCols {8};
-  std::vector<float> data (kRows * kCols);
-
-  for (size_t i = 0; i < kRows * kCols; ++i) {
-    data[i] = i;
-  }
-
+  auto constexpr kRows {16};
+  auto constexpr kCols {8};
  dmlc::TemporaryDirectory tmpdir;
-  std::string path = tmpdir.path + "/small.csv";
-
-  std::ofstream fout(path);
-  size_t i = 0;
-  for (size_t r = 0; r < kRows; ++r) {
-    for (size_t c = 0; c < kCols; ++c) {
-      fout << data[i];
-      i++;
-      if (c != kCols - 1) {
-        fout << ",";
-      }
-    }
-    fout << "\n";
-  }
-  fout.flush();
-  fout.close();
+  auto const path = tmpdir.path + "/small.csv";
+  CreateTestCSV(path, kRows, kCols);
 
  std::unique_ptr<DMatrix> dmat;
  // FIXME(trivialfis): Enable the following test by restricting csv parser in dmlc-core.
@@ -65,6 +65,29 @@ void CreateBigTestData(const std::string& filename, size_t n_entries, bool zero_
  }
 }
 
+void CreateTestCSV(std::string const& path, size_t rows, size_t cols) {
+  std::vector<float> data(rows * cols);
+
+  for (size_t i = 0; i < rows * cols; ++i) {
+    data[i] = i;
+  }
+
+  std::ofstream fout(path);
+  size_t i = 0;
+  for (size_t r = 0; r < rows; ++r) {
+    for (size_t c = 0; c < cols; ++c) {
+      fout << data[i];
+      i++;
+      if (c != cols - 1) {
+        fout << ",";
+      }
+    }
+    fout << "\n";
+  }
+  fout.flush();
+  fout.close();
+}
+
 void CheckObjFunctionImpl(std::unique_ptr<xgboost::ObjFunction> const& obj,
                           std::vector<xgboost::bst_float> preds,
                           std::vector<xgboost::bst_float> labels,
@@ -59,6 +59,8 @@ void CreateSimpleTestData(const std::string& filename);
 // 0-based indexing.
 void CreateBigTestData(const std::string& filename, size_t n_entries, bool zero_based = true);
 
+void CreateTestCSV(std::string const& path, size_t rows, size_t cols);
+
 void CheckObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
                       std::vector<xgboost::bst_float> preds,
                       std::vector<xgboost::bst_float> labels,
@@ -1,19 +0,0 @@
-#include <chrono>
-#include <thread>
-#include <random>
-#include <cstdint>
-
-#include "helpers.h"
-
-using namespace std::chrono_literals;
-
-int GenerateRandomPort(int low, int high) {
-  // Ensure unique timestamp by introducing a small artificial delay
-  std::this_thread::sleep_for(100ms);
-  auto timestamp = static_cast<uint64_t>(std::chrono::duration_cast<std::chrono::milliseconds>(
-      std::chrono::system_clock::now().time_since_epoch()).count());
-  std::mt19937_64 rng(timestamp);
-  std::uniform_int_distribution<int> dist(low, high);
-  int port = dist(rng);
-  return port;
-}
@@ -1,10 +1,69 @@
 /*!
- * Copyright 2022 XGBoost contributors
+ * Copyright 2022-2023 XGBoost contributors
  */
-#ifndef XGBOOST_TESTS_CPP_PLUGIN_HELPERS_H_
-#define XGBOOST_TESTS_CPP_PLUGIN_HELPERS_H_
-
-int GenerateRandomPort(int low, int high);
-
-#endif  // XGBOOST_TESTS_CPP_PLUGIN_HELPERS_H_
+#pragma once
+
+#include <grpcpp/server_builder.h>
+#include <gtest/gtest.h>
+#include <xgboost/json.h>
+
+#include <random>
+
+#include "../../../plugin/federated/federated_server.h"
+#include "../../../src/collective/communicator-inl.h"
+
+inline int GenerateRandomPort(int low, int high) {
+  using namespace std::chrono_literals;
+  // Ensure unique timestamp by introducing a small artificial delay
+  std::this_thread::sleep_for(100ms);
+  auto timestamp = static_cast<uint64_t>(std::chrono::duration_cast<std::chrono::milliseconds>(
+                                             std::chrono::system_clock::now().time_since_epoch())
+                                             .count());
+  std::mt19937_64 rng(timestamp);
+  std::uniform_int_distribution<int> dist(low, high);
+  int port = dist(rng);
+  return port;
+}
+
+inline std::string GetServerAddress() {
+  int port = GenerateRandomPort(50000, 60000);
+  std::string address = std::string("localhost:") + std::to_string(port);
+  return address;
+}
+
+namespace xgboost {
+
+class BaseFederatedTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    server_address_ = GetServerAddress();
+    server_thread_.reset(new std::thread([this] {
+      grpc::ServerBuilder builder;
+      xgboost::federated::FederatedService service{kWorldSize};
+      builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials());
+      builder.RegisterService(&service);
+      server_ = builder.BuildAndStart();
+      server_->Wait();
+    }));
+  }
+
+  void TearDown() override {
+    server_->Shutdown();
+    server_thread_->join();
+  }
+
+  void InitCommunicator(int rank) {
+    Json config{JsonObject()};
+    config["xgboost_communicator"] = String("federated");
+    config["federated_server_address"] = String(server_address_);
+    config["federated_world_size"] = kWorldSize;
+    config["federated_rank"] = rank;
+    xgboost::collective::Init(config);
+  }
+
+  static int const kWorldSize{3};
+  std::string server_address_;
+  std::unique_ptr<std::thread> server_thread_;
+  std::unique_ptr<grpc::Server> server_;
+};
+}  // namespace xgboost
@@ -1,56 +1,20 @@
 /*!
  * Copyright 2022 XGBoost contributors
  */
-#include <grpcpp/server_builder.h>
 #include <gtest/gtest.h>
 #include <thrust/host_vector.h>
 
+#include <ctime>
 #include <iostream>
 #include <thread>
-#include <ctime>
-
-#include "./helpers.h"
 #include "../../../plugin/federated/federated_communicator.h"
-#include "../../../plugin/federated/federated_server.h"
 #include "../../../src/collective/device_communicator_adapter.cuh"
+#include "./helpers.h"
 
-namespace {
-
-std::string GetServerAddress() {
-  int port = GenerateRandomPort(50000, 60000);
-  std::string address = std::string("localhost:") + std::to_string(port);
-  return address;
-}
-
-}  // anonymous namespace
-
-namespace xgboost {
-namespace collective {
-
-class FederatedAdapterTest : public ::testing::Test {
- protected:
-  void SetUp() override {
-    server_address_ = GetServerAddress();
-    server_thread_.reset(new std::thread([this] {
-      grpc::ServerBuilder builder;
-      federated::FederatedService service{kWorldSize};
-      builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials());
-      builder.RegisterService(&service);
-      server_ = builder.BuildAndStart();
-      server_->Wait();
-    }));
-  }
-
-  void TearDown() override {
-    server_->Shutdown();
-    server_thread_->join();
-  }
-
-  static int const kWorldSize{2};
-  std::string server_address_;
-  std::unique_ptr<std::thread> server_thread_;
-  std::unique_ptr<grpc::Server> server_;
-};
+namespace xgboost::collective {
+
+class FederatedAdapterTest : public BaseFederatedTest {};
 
 TEST(FederatedAdapterSimpleTest, ThrowOnInvalidDeviceOrdinal) {
  auto construct = []() { DeviceCommunicatorAdapter adapter{-1, nullptr}; };
@@ -65,20 +29,20 @@ TEST(FederatedAdapterSimpleTest, ThrowOnInvalidCommunicator) {
 TEST_F(FederatedAdapterTest, DeviceAllReduceSum) {
  std::vector<std::thread> threads;
  for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back(std::thread([rank, server_address=server_address_] {
+    threads.emplace_back([rank, server_address = server_address_] {
      FederatedCommunicator comm{kWorldSize, rank, server_address};
      // Assign device 0 to all workers, since we run gtest in a single-GPU machine
      DeviceCommunicatorAdapter adapter{0, &comm};
-      int const count = 3;
+      int count = 3;
      thrust::device_vector<double> buffer(count, 0);
      thrust::sequence(buffer.begin(), buffer.end());
      adapter.AllReduceSum(buffer.data().get(), count);
      thrust::host_vector<double> host_buffer = buffer;
      EXPECT_EQ(host_buffer.size(), count);
      for (auto i = 0; i < count; i++) {
-        EXPECT_EQ(host_buffer[i], i * 2);
+        EXPECT_EQ(host_buffer[i], i * kWorldSize);
      }
-    }));
+    });
  }
  for (auto& thread : threads) {
    thread.join();
@@ -88,7 +52,7 @@ TEST_F(FederatedAdapterTest, DeviceAllReduceSum) {
 TEST_F(FederatedAdapterTest, DeviceAllGatherV) {
  std::vector<std::thread> threads;
  for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back(std::thread([rank, server_address=server_address_] {
+    threads.emplace_back([rank, server_address = server_address_] {
      FederatedCommunicator comm{kWorldSize, rank, server_address};
      // Assign device 0 to all workers, since we run gtest in a single-GPU machine
      DeviceCommunicatorAdapter adapter{0, &comm};
@@ -104,17 +68,16 @@ TEST_F(FederatedAdapterTest, DeviceAllGatherV) {
      EXPECT_EQ(segments[0], 2);
      EXPECT_EQ(segments[1], 3);
      thrust::host_vector<char> host_buffer = receive_buffer;
-      EXPECT_EQ(host_buffer.size(), 5);
-      int expected[] = {0, 1, 0, 1, 2};
-      for (auto i = 0; i < 5; i++) {
+      EXPECT_EQ(host_buffer.size(), 9);
+      int expected[] = {0, 1, 0, 1, 2, 0, 1, 2, 3};
+      for (auto i = 0; i < 9; i++) {
        EXPECT_EQ(host_buffer[i], expected[i]);
      }
-    }));
+    });
  }
  for (auto& thread : threads) {
    thread.join();
  }
 }
 
-}  // namespace collective
-}  // namespace xgboost
+}  // namespace xgboost::collective
@@ -2,65 +2,34 @@
  * Copyright 2022 XGBoost contributors
  */
 #include <dmlc/parameter.h>
-#include <grpcpp/server_builder.h>
 #include <gtest/gtest.h>
 
 #include <iostream>
 #include <thread>
-#include <ctime>
-
-#include "helpers.h"
 #include "../../../plugin/federated/federated_communicator.h"
-#include "../../../plugin/federated/federated_server.h"
+#include "helpers.h"
 
-namespace {
-
-std::string GetServerAddress() {
-  int port = GenerateRandomPort(50000, 60000);
-  std::string address = std::string("localhost:") + std::to_string(port);
-  return address;
-}
-
-}  // anonymous namespace
-
-namespace xgboost {
-namespace collective {
-
-class FederatedCommunicatorTest : public ::testing::Test {
+namespace xgboost::collective {
+
+class FederatedCommunicatorTest : public BaseFederatedTest {
  public:
-  static void VerifyAllgather(int rank, const std::string& server_address) {
+  static void VerifyAllgather(int rank, const std::string &server_address) {
    FederatedCommunicator comm{kWorldSize, rank, server_address};
    CheckAllgather(comm, rank);
  }
 
-  static void VerifyAllreduce(int rank, const std::string& server_address) {
+  static void VerifyAllreduce(int rank, const std::string &server_address) {
    FederatedCommunicator comm{kWorldSize, rank, server_address};
    CheckAllreduce(comm);
  }
 
-  static void VerifyBroadcast(int rank, const std::string& server_address) {
+  static void VerifyBroadcast(int rank, const std::string &server_address) {
    FederatedCommunicator comm{kWorldSize, rank, server_address};
    CheckBroadcast(comm, rank);
  }
 
 protected:
-  void SetUp() override {
-    server_address_ = GetServerAddress();
-    server_thread_.reset(new std::thread([this] {
-      grpc::ServerBuilder builder;
-      federated::FederatedService service{kWorldSize};
-      builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials());
-      builder.RegisterService(&service);
-      server_ = builder.BuildAndStart();
-      server_->Wait();
-    }));
-  }
-
-  void TearDown() override {
-    server_->Shutdown();
-    server_thread_->join();
-  }
-
  static void CheckAllgather(FederatedCommunicator &comm, int rank) {
    int buffer[kWorldSize] = {0, 0, 0};
    buffer[rank] = rank;
@@ -90,11 +59,6 @@ class FederatedCommunicatorTest : public ::testing::Test {
      EXPECT_EQ(buffer, "hello");
    }
  }
-
-  static int const kWorldSize{3};
-  std::string server_address_;
-  std::unique_ptr<std::thread> server_thread_;
-  std::unique_ptr<grpc::Server> server_;
 };
 
 TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) {
@@ -161,8 +125,7 @@ TEST(FederatedCommunicatorSimpleTest, IsDistributed) {
 TEST_F(FederatedCommunicatorTest, Allgather) {
  std::vector<std::thread> threads;
  for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back(
-        std::thread(&FederatedCommunicatorTest::VerifyAllgather, rank, server_address_));
+    threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgather, rank, server_address_);
  }
  for (auto &thread : threads) {
    thread.join();
@@ -172,8 +135,7 @@ TEST_F(FederatedCommunicatorTest, Allgather) {
 TEST_F(FederatedCommunicatorTest, Allreduce) {
  std::vector<std::thread> threads;
  for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back(
-        std::thread(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_address_));
+    threads.emplace_back(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_address_);
  }
  for (auto &thread : threads) {
    thread.join();
@@ -183,12 +145,10 @@ TEST_F(FederatedCommunicatorTest, Allreduce) {
 TEST_F(FederatedCommunicatorTest, Broadcast) {
  std::vector<std::thread> threads;
  for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back(
-        std::thread(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_address_));
+    threads.emplace_back(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_address_);
  }
  for (auto &thread : threads) {
    thread.join();
  }
 }
-}  // namespace collective
-}  // namespace xgboost
+}  // namespace xgboost::collective
tests/cpp/plugin/test_federated_data.cc (new file)
@@ -0,0 +1,65 @@
+/*!
+ * Copyright 2023 XGBoost contributors
+ */
+#include <dmlc/parameter.h>
+#include <gtest/gtest.h>
+#include <xgboost/data.h>
+
+#include <fstream>
+#include <iostream>
+#include <thread>
+
+#include "../../../plugin/federated/federated_server.h"
+#include "../../../src/collective/communicator-inl.h"
+#include "../filesystem.h"
+#include "../helpers.h"
+#include "helpers.h"
+
+namespace xgboost {
+
+class FederatedDataTest : public BaseFederatedTest {
+ public:
+  void VerifyLoadUri(int rank) {
+    InitCommunicator(rank);
+
+    size_t constexpr kRows{16};
+    size_t const kCols = 8 + rank;
+
+    dmlc::TemporaryDirectory tmpdir;
+    std::string path = tmpdir.path + "/small" + std::to_string(rank) + ".csv";
+    CreateTestCSV(path, kRows, kCols);
+
+    std::unique_ptr<DMatrix> dmat;
+    std::string uri = path + "?format=csv";
+    dmat.reset(DMatrix::Load(uri, false, DataSplitMode::kCol));
+
+    ASSERT_EQ(dmat->Info().num_col_, 8 * kWorldSize + 3);
+    ASSERT_EQ(dmat->Info().num_row_, kRows);
+
+    for (auto const& page : dmat->GetBatches<SparsePage>()) {
+      auto entries = page.GetView().data;
+      auto index = 0;
+      int offsets[] = {0, 8, 17};
+      int offset = offsets[rank];
+      for (auto row = 0; row < kRows; row++) {
+        for (auto col = 0; col < kCols; col++) {
+          EXPECT_EQ(entries[index].index, col + offset);
+          index++;
+        }
+      }
+    }
+
+    xgboost::collective::Finalize();
+  }
+};
+
+TEST_F(FederatedDataTest, LoadUri) {
+  std::vector<std::thread> threads;
+  for (auto rank = 0; rank < kWorldSize; rank++) {
+    threads.emplace_back(&FederatedDataTest_LoadUri_Test::VerifyLoadUri, this, rank);
+  }
+  for (auto& thread : threads) {
+    thread.join();
+  }
+}
+}  // namespace xgboost
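The expectations in the new test follow directly from the per-rank column counts: each of the three workers writes a CSV with kCols = 8 + rank columns, so the synchronized column count is 8 + 9 + 10 = 27, which is exactly what 8 * kWorldSize + 3 evaluates to, and the reindexing offsets are the prefix sums {0, 8, 17}. A small sketch of that arithmetic (plain C++, assuming kWorldSize = 3 as in the fixture):

#include <iostream>
#include <numeric>
#include <vector>

int main() {
  constexpr int kWorldSize = 3;
  std::vector<int> cols(kWorldSize);
  for (int rank = 0; rank < kWorldSize; ++rank) {
    cols[rank] = 8 + rank;  // 8, 9, 10 columns per worker
  }

  int total = std::accumulate(cols.begin(), cols.end(), 0);
  std::cout << "total columns: " << total << "\n";  // 27 == 8 * kWorldSize + 3

  int offset = 0;
  for (int rank = 0; rank < kWorldSize; ++rank) {
    std::cout << "rank " << rank << " offset: " << offset << "\n";  // 0, 8, 17
    offset += cols[rank];
  }
  return 0;
}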
@@ -1,30 +1,17 @@
 /*!
  * Copyright 2017-2020 XGBoost contributors
  */
-#include <grpcpp/server_builder.h>
 #include <gtest/gtest.h>
 
-#include <ctime>
 #include <iostream>
 #include <thread>
 
 #include "federated_client.h"
-#include "federated_server.h"
 #include "helpers.h"
 
-namespace {
-
-std::string GetServerAddress() {
-  int port = GenerateRandomPort(50000, 60000);
-  std::string address = std::string("localhost:") + std::to_string(port);
-  return address;
-}
-
-}  // anonymous namespace
-
 namespace xgboost {
 
-class FederatedServerTest : public ::testing::Test {
+class FederatedServerTest : public BaseFederatedTest {
  public:
  static void VerifyAllgather(int rank, const std::string& server_address) {
    federated::FederatedClient client{server_address, rank};
@@ -51,23 +38,6 @@ class FederatedServerTest : public ::testing::Test {
  }
 
 protected:
-  void SetUp() override {
-    server_address_ = GetServerAddress();
-    server_thread_.reset(new std::thread([this] {
-      grpc::ServerBuilder builder;
-      federated::FederatedService service{kWorldSize};
-      builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials());
-      builder.RegisterService(&service);
-      server_ = builder.BuildAndStart();
-      server_->Wait();
-    }));
-  }
-
-  void TearDown() override {
-    server_->Shutdown();
-    server_thread_->join();
-  }
-
  static void CheckAllgather(federated::FederatedClient& client, int rank) {
    int data[kWorldSize] = {0, 0, 0};
    data[rank] = rank;
@@ -98,17 +68,12 @@ class FederatedServerTest : public ::testing::Test {
    auto reply = client.Broadcast(send_buffer, 0);
    EXPECT_EQ(reply, "hello broadcast") << "rank " << rank;
  }
-
-  static int const kWorldSize{3};
-  std::string server_address_;
-  std::unique_ptr<std::thread> server_thread_;
-  std::unique_ptr<grpc::Server> server_;
 };
 
 TEST_F(FederatedServerTest, Allgather) {
  std::vector<std::thread> threads;
  for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllgather, rank, server_address_));
+    threads.emplace_back(&FederatedServerTest::VerifyAllgather, rank, server_address_);
  }
  for (auto& thread : threads) {
    thread.join();
@@ -118,7 +83,7 @@ TEST_F(FederatedServerTest, Allgather) {
 TEST_F(FederatedServerTest, Allreduce) {
  std::vector<std::thread> threads;
  for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllreduce, rank, server_address_));
+    threads.emplace_back(&FederatedServerTest::VerifyAllreduce, rank, server_address_);
  }
  for (auto& thread : threads) {
    thread.join();
@@ -128,7 +93,7 @@ TEST_F(FederatedServerTest, Allreduce) {
 TEST_F(FederatedServerTest, Broadcast) {
  std::vector<std::thread> threads;
  for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back(std::thread(&FederatedServerTest::VerifyBroadcast, rank, server_address_));
+    threads.emplace_back(&FederatedServerTest::VerifyBroadcast, rank, server_address_);
  }
  for (auto& thread : threads) {
    thread.join();
@@ -138,7 +103,7 @@ TEST_F(FederatedServerTest, Broadcast) {
 TEST_F(FederatedServerTest, Mixture) {
  std::vector<std::thread> threads;
  for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back(std::thread(&FederatedServerTest::VerifyMixture, rank, server_address_));
+    threads.emplace_back(&FederatedServerTest::VerifyMixture, rank, server_address_);
  }
  for (auto& thread : threads) {
    thread.join();
@@ -21,7 +21,8 @@ void TestFitStump(Context const *ctx) {
    }
  }
  linalg::Vector<float> out;
-  FitStump(ctx, gpair, kTargets, &out);
+  MetaInfo info;
+  FitStump(ctx, info, gpair, kTargets, &out);
  auto h_out = out.HostView();
  for (auto it = linalg::cbegin(h_out); it != linalg::cend(h_out); ++it) {
    // sum_hess == kRows