Don't split input data in federated mode (#8279)

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Rong Ou 2022-10-05 19:19:28 -07:00 committed by GitHub
parent 66fd9f5207
commit 8d4038da57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 23 additions and 8 deletions

View File

@@ -170,6 +170,12 @@ class FederatedCommunicator : public Communicator {
*/ */
bool IsDistributed() const override { return true; } bool IsDistributed() const override { return true; }
/**
* \brief Get if the communicator is federated.
* \return True.
*/
bool IsFederated() const override { return true; }
/** /**
* \brief Perform in-place allreduce. * \brief Perform in-place allreduce.
* \param send_receive_buffer Buffer for both sending and receiving data. * \param send_receive_buffer Buffer for both sending and receiving data.

View File

@@ -208,16 +208,12 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out) { XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out) {
API_BEGIN(); API_BEGIN();
bool load_row_split = false; bool load_row_split = false;
#if defined(XGBOOST_USE_FEDERATED) if (collective::IsFederated()) {
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers"; LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
#else } else if (collective::IsDistributed()) {
if (collective::IsDistributed()) { LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers";
LOG(CONSOLE) << "XGBoost distributed mode detected, "
<< "will split data among workers";
load_row_split = true; load_row_split = true;
} }
#endif
xgboost_CHECK_C_ARG_PTR(fname); xgboost_CHECK_C_ARG_PTR(fname);
xgboost_CHECK_C_ARG_PTR(out); xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, load_row_split)); *out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, load_row_split));

View File

@@ -88,6 +88,13 @@ inline int GetWorldSize() { return Communicator::Get()->GetWorldSize(); }
*/ */
inline bool IsDistributed() { return Communicator::Get()->IsDistributed(); } inline bool IsDistributed() { return Communicator::Get()->IsDistributed(); }
/*!
* \brief Get if the communicator is federated.
*
* \return True if the communicator is federated.
*/
inline bool IsFederated() { return Communicator::Get()->IsFederated(); }
/*! /*!
* \brief Print the message to the communicator. * \brief Print the message to the communicator.
* *

View File

@@ -78,6 +78,9 @@ class Communicator {
/** @brief Whether the communicator is running in distributed mode. */ /** @brief Whether the communicator is running in distributed mode. */
virtual bool IsDistributed() const = 0; virtual bool IsDistributed() const = 0;
/** @brief Whether the communicator is running in federated mode. */
virtual bool IsFederated() const = 0;
/** /**
* @brief Combines values from all processes and distributes the result back to all processes. * @brief Combines values from all processes and distributes the result back to all processes.
* *

View File

@@ -16,6 +16,7 @@ class NoOpCommunicator : public Communicator {
public: public:
NoOpCommunicator() : Communicator(1, 0) {} NoOpCommunicator() : Communicator(1, 0) {}
bool IsDistributed() const override { return false; } bool IsDistributed() const override { return false; }
bool IsFederated() const override { return false; }
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {} Operation op) override {}
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {} void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {}

View File

@@ -53,6 +53,8 @@ class RabitCommunicator : public Communicator {
bool IsDistributed() const override { return rabit::IsDistributed(); } bool IsDistributed() const override { return rabit::IsDistributed(); }
bool IsFederated() const override { return false; }
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override { Operation op) override {
switch (data_type) { switch (data_type) {