Initial support for column-wise data split (#8468)
This commit is contained in:
@@ -381,6 +381,12 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
|
||||
return out;
|
||||
}
|
||||
|
||||
/*!
 * \brief Build and return a new MetaInfo populated from this one.
 *
 * Starts from a default-constructed MetaInfo and calls Extend() on it with
 * this object as the source, requesting row accumulation and disabling the
 * column check.
 *
 * NOTE(review): presumably the result is a complete copy of every field, but
 * that depends on what Extend() transfers, which is not visible here - confirm.
 *
 * \return The newly built MetaInfo, by value.
 */
MetaInfo MetaInfo::Copy() const {
  // Flags forwarded to Extend(); named so the call site reads without
  // inline argument comments.
  constexpr bool kAccumulateRows = true;
  constexpr bool kCheckColumn = false;

  MetaInfo duplicate;
  duplicate.Extend(*this, kAccumulateRows, kCheckColumn);
  return duplicate;
}
|
||||
|
||||
namespace {
|
||||
template <int32_t D, typename T>
|
||||
void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T, D>* p_out) {
|
||||
@@ -777,8 +783,10 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) {
|
||||
|
||||
DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode,
|
||||
const std::string& file_format) {
|
||||
CHECK(data_split_mode == DataSplitMode::kRow || data_split_mode == DataSplitMode::kNone)
|
||||
<< "Precondition violated; data split mode can only be 'row' or 'none'";
|
||||
CHECK(data_split_mode == DataSplitMode::kRow ||
|
||||
data_split_mode == DataSplitMode::kCol ||
|
||||
data_split_mode == DataSplitMode::kNone)
|
||||
<< "Precondition violated; data split mode can only be 'row', 'col', or 'none'";
|
||||
std::string fname, cache_file;
|
||||
size_t dlm_pos = uri.find('#');
|
||||
if (dlm_pos != std::string::npos) {
|
||||
@@ -878,7 +886,20 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
|
||||
* partitioned data will fail the train/val validation check
|
||||
* since partitioned data does not know the real number of features. */
|
||||
collective::Allreduce<collective::Operation::kMax>(&dmat->Info().num_col_, 1);
|
||||
return dmat;
|
||||
|
||||
if (data_split_mode == DataSplitMode::kCol) {
|
||||
if (!cache_file.empty()) {
|
||||
LOG(FATAL) << "Column-wise data split is not support for external memory.";
|
||||
}
|
||||
auto slice_cols = (dmat->Info().num_col_ + 1) / npart;
|
||||
auto slice_start = slice_cols * partid;
|
||||
auto size = std::min(slice_cols, dmat->Info().num_col_ - slice_start);
|
||||
auto* sliced = dmat->SliceCol(slice_start, size);
|
||||
delete dmat;
|
||||
return sliced;
|
||||
} else {
|
||||
return dmat;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
|
||||
|
||||
Reference in New Issue
Block a user