Common interface for collective communication (#8057)
* implement broadcast for federated communicator * implement allreduce * add communicator factory * add device adapter * add device communicator to factory * add rabit communicator * add rabit communicator to the factory * add nccl device communicator * add synchronize to device communicator * add back print and getprocessorname * add python wrapper and c api * clean up types * fix non-gpu build * try to fix ci * fix std::size_t * portable string compare ignore case * c style size_t * fix lint errors * cross platform setenv * fix memory leak * fix lint errors * address review feedback * add python test for rabit communicator * fix failing gtest * use json to configure communicators * fix lint error * get rid of factories * fix cpu build * fix include * fix python import * don't export collective.py yet * skip collective communicator pytest on windows * add review feedback * update documentation * remove mpi communicator type * fix tests * shutdown the communicator separately Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -9,10 +9,12 @@
|
||||
|
||||
#ifdef __cplusplus
|
||||
#define XGB_EXTERN_C extern "C"
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
#include <cstdint>
|
||||
#else
|
||||
#define XGB_EXTERN_C
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#endif // __cplusplus
|
||||
@@ -1386,4 +1388,135 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *json_config,
|
||||
bst_ulong *out_dim,
|
||||
bst_ulong const **out_shape,
|
||||
float const **out_scores);
|
||||
|
||||
/*!
|
||||
* \brief Initialize the collective communicator.
|
||||
*
|
||||
* Currently the communicator API is experimental, function signatures may change in the future
|
||||
* without notice.
|
||||
*
|
||||
* Call this once before using any other communicator function.
|
||||
*
|
||||
* The additional configuration is not required. Usually the communicator will detect settings
|
||||
* from environment variables.
|
||||
*
|
||||
* \param json_config JSON encoded configuration. Accepted JSON keys are:
|
||||
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
|
||||
* * rabit: Use Rabit. This is the default if the type is unspecified.
|
||||
* * mpi: Use MPI.
|
||||
* * federated: Use the gRPC interface for Federated Learning.
|
||||
* Only applicable to the Rabit communicator (these are case-sensitive):
|
||||
* - rabit_tracker_uri: Hostname of the tracker.
|
||||
* - rabit_tracker_port: Port number of the tracker.
|
||||
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - rabit_world_size: Total number of workers.
|
||||
* - rabit_hadoop_mode: Enable Hadoop support.
|
||||
* - rabit_tree_reduce_minsize: Minimal size for tree reduce.
|
||||
* - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
|
||||
* - rabit_reduce_buffer: Size of the reduce buffer.
|
||||
* - rabit_bootstrap_cache: Size of the bootstrap cache.
|
||||
* - rabit_debug: Enable debugging.
|
||||
* - rabit_timeout: Enable timeout.
|
||||
* - rabit_timeout_sec: Timeout in seconds.
|
||||
* - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
|
||||
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
|
||||
* environment variables):
|
||||
* - DMLC_TRACKER_URI: Hostname of the tracker.
|
||||
* - DMLC_TRACKER_PORT: Port number of the tracker.
|
||||
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - DMLC_ROLE: Role of the current task, "worker" or "server".
|
||||
* - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
|
||||
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
|
||||
* Only applicable to the Federated communicator (use upper case for environment variables, use
|
||||
* lower case for runtime configuration):
|
||||
* - federated_server_address: Address of the federated server.
|
||||
* - federated_world_size: Number of federated workers.
|
||||
* - federated_rank: Rank of the current worker.
|
||||
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
|
||||
* - federated_client_key: Client key file path. Only needed for the SSL mode.
|
||||
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
|
||||
* \return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorInit(char const* json_config);
|
||||
|
||||
/*!
|
||||
* \brief Finalize the collective communicator.
|
||||
*
|
||||
* Call this function after you have finished all jobs.
|
||||
*
|
||||
* \return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorFinalize(void);
|
||||
|
||||
/*!
|
||||
* \brief Get rank of current process.
|
||||
*
|
||||
* \return Rank of the worker.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorGetRank(void);
|
||||
|
||||
/*!
|
||||
* \brief Get total number of processes.
|
||||
*
|
||||
* \return Total world size.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorGetWorldSize(void);
|
||||
|
||||
/*!
|
||||
* \brief Get if the communicator is distributed.
|
||||
*
|
||||
* \return Non-zero (true) if the communicator is distributed, zero otherwise.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorIsDistributed(void);
|
||||
|
||||
/*!
|
||||
* \brief Print the message to the communicator.
|
||||
*
|
||||
* This function can be used to communicate progress information to the user who monitors
|
||||
* the communicator.
|
||||
*
|
||||
* \param message The message to be printed.
|
||||
* \return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorPrint(char const *message);
|
||||
|
||||
/*!
|
||||
* \brief Get the name of the processor.
|
||||
*
|
||||
* \param name_str Pointer that receives the returned processor name.
|
||||
* \return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorGetProcessorName(const char** name_str);
|
||||
|
||||
/*!
|
||||
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
|
||||
*
|
||||
* Example:
|
||||
* int a = 1;
|
||||
* Broadcast(&a, sizeof(a), root);
|
||||
*
|
||||
* \param send_receive_buffer Pointer to the send or receive buffer.
|
||||
* \param size Size of the data.
|
||||
* \param root The process rank to broadcast from.
|
||||
* \return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root);
|
||||
|
||||
/*!
|
||||
* \brief Perform in-place allreduce. This function is NOT thread-safe.
|
||||
*
|
||||
* Example usage: the following code sums the data across all processes:
|
||||
* vector<int> data(10);
|
||||
* ...
|
||||
* Allreduce(&data[0], data.size(), DataType::kInt32, Operation::kSum);
|
||||
* ...
|
||||
* \param send_receive_buffer Buffer for both sending and receiving data.
|
||||
* \param count Number of elements to be reduced.
|
||||
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
|
||||
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
|
||||
* \return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int data_type, int op);
|
||||
|
||||
|
||||
#endif // XGBOOST_C_API_H_
|
||||
|
||||
Reference in New Issue
Block a user