Don't split input data in federated mode (#8279)
Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
66fd9f5207
commit
8d4038da57
@ -170,6 +170,12 @@ class FederatedCommunicator : public Communicator {
|
||||
*/
|
||||
bool IsDistributed() const override { return true; }
|
||||
|
||||
/**
|
||||
* \brief Get if the communicator is federated.
|
||||
* \return True.
|
||||
*/
|
||||
bool IsFederated() const override { return true; }
|
||||
|
||||
/**
|
||||
* \brief Perform in-place allreduce.
|
||||
* \param send_receive_buffer Buffer for both sending and receiving data.
|
||||
|
||||
@ -208,16 +208,12 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
|
||||
XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
bool load_row_split = false;
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
|
||||
#else
|
||||
if (collective::IsDistributed()) {
|
||||
LOG(CONSOLE) << "XGBoost distributed mode detected, "
|
||||
<< "will split data among workers";
|
||||
if (collective::IsFederated()) {
|
||||
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
|
||||
} else if (collective::IsDistributed()) {
|
||||
LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers";
|
||||
load_row_split = true;
|
||||
}
|
||||
#endif
|
||||
|
||||
xgboost_CHECK_C_ARG_PTR(fname);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, load_row_split));
|
||||
|
||||
@ -88,6 +88,13 @@ inline int GetWorldSize() { return Communicator::Get()->GetWorldSize(); }
|
||||
*/
|
||||
inline bool IsDistributed() { return Communicator::Get()->IsDistributed(); }
|
||||
|
||||
/*!
|
||||
* \brief Get if the communicator is federated.
|
||||
*
|
||||
* \return True if the communicator is federated.
|
||||
*/
|
||||
inline bool IsFederated() { return Communicator::Get()->IsFederated(); }
|
||||
|
||||
/*!
|
||||
* \brief Print the message to the communicator.
|
||||
*
|
||||
|
||||
@ -78,6 +78,9 @@ class Communicator {
|
||||
/** @brief Whether the communicator is running in distributed mode. */
|
||||
virtual bool IsDistributed() const = 0;
|
||||
|
||||
/** @brief Whether the communicator is running in federated mode. */
|
||||
virtual bool IsFederated() const = 0;
|
||||
|
||||
/**
|
||||
* @brief Combines values from all processes and distributes the result back to all processes.
|
||||
*
|
||||
|
||||
@ -16,6 +16,7 @@ class NoOpCommunicator : public Communicator {
|
||||
public:
|
||||
NoOpCommunicator() : Communicator(1, 0) {}
|
||||
bool IsDistributed() const override { return false; }
|
||||
bool IsFederated() const override { return false; }
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {}
|
||||
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {}
|
||||
|
||||
@ -53,6 +53,8 @@ class RabitCommunicator : public Communicator {
|
||||
|
||||
bool IsDistributed() const override { return rabit::IsDistributed(); }
|
||||
|
||||
bool IsFederated() const override { return false; }
|
||||
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {
|
||||
switch (data_type) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user