More tests for column split and vertical federated learning (#8985)

Added some more tests for the learner and fit_stump, for both column-wise distributed learning and vertical federated learning.

Also moved the `IsRowSplit` and `IsColumnSplit` methods from the `DMatrix` to the `MetaInfo` since in some places we only have access to the `MetaInfo`. Added a new convenience method `IsVerticalFederatedLearning`.

Some refactoring of the testing fixtures.
This commit is contained in:
Rong Ou
2023-03-28 01:40:26 -07:00
committed by GitHub
parent 401ce5cf5e
commit ff26cd3212
18 changed files with 212 additions and 94 deletions

View File

@@ -180,6 +180,22 @@ class MetaInfo {
*/
void SynchronizeNumberOfColumns();
/*! \brief Whether the data is split row-wise. */
bool IsRowSplit() const {
return data_split_mode == DataSplitMode::kRow;
}
/*! \brief Whether the data is split column-wise. */
bool IsColumnSplit() const {
return data_split_mode == DataSplitMode::kCol;
}
/*!
* \brief A convenient method to check if we are doing vertical federated learning, which requires
* some special processing.
*/
bool IsVerticalFederated() const;
private:
void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);
@@ -542,16 +558,6 @@ class DMatrix {
return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
}
/*! \brief Whether the data is split row-wise. */
bool IsRowSplit() const {
return Info().data_split_mode == DataSplitMode::kRow;
}
/*! \brief Whether the data is split column-wise. */
bool IsColumnSplit() const {
return Info().data_split_mode == DataSplitMode::kCol;
}
/*!
* \brief Load DMatrix from URI.
* \param uri The URI of input.