Support vertical federated learning (#8932)

This commit is contained in:
Rong Ou
2023-03-21 23:25:26 -07:00
committed by GitHub
parent 8dc1e4b3ea
commit b240f055d3
23 changed files with 371 additions and 249 deletions

View File

@@ -171,6 +171,15 @@ class MetaInfo {
*/
void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);
/**
* @brief Synchronize the number of columns across all workers.
*
* Normally we just need to find the maximum number of columns across all workers, but
* in vertical federated learning, since each worker loads its own list of columns,
* we need to sum them.
*/
void SynchronizeNumberOfColumns();
private:
void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);
@@ -325,6 +334,10 @@ class SparsePage {
* \brief Check wether the column index is sorted.
*/
bool IsIndicesSorted(int32_t n_threads) const;
/**
* \brief Reindex the column index with an offset.
*/
void Reindex(uint64_t feature_offset, int32_t n_threads);
void SortRows(int32_t n_threads);
@@ -559,17 +572,18 @@ class DMatrix {
* \brief Creates a new DMatrix from an external data adapter.
*
* \tparam AdapterT Type of the adapter.
* \param [in,out] adapter View onto an external data.
* \param missing Values to count as missing.
* \param nthread Number of threads for construction.
* \param cache_prefix (Optional) The cache prefix for external memory.
* \param page_size (Optional) Size of the page.
* \param [in,out] adapter View onto an external data.
* \param missing Values to count as missing.
* \param nthread Number of threads for construction.
* \param cache_prefix (Optional) The cache prefix for external memory.
* \param data_split_mode (Optional) Data split mode.
*
* \return a Created DMatrix.
*/
template <typename AdapterT>
static DMatrix* Create(AdapterT* adapter, float missing, int nthread,
const std::string& cache_prefix = "");
const std::string& cache_prefix = "",
DataSplitMode data_split_mode = DataSplitMode::kRow);
/**
* \brief Create a new Quantile based DMatrix used for histogram based algorithm.