Initial support for column-wise data split (#8468)

This commit is contained in:
Rong Ou
2022-12-03 09:37:51 -08:00
committed by GitHub
parent c0609b98f1
commit 78d65a1928
8 changed files with 135 additions and 3 deletions

View File

@@ -381,6 +381,12 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
return out;
}
/*!
 * \brief Build a new MetaInfo populated from this one.
 *
 * A default-constructed MetaInfo is filled through Extend(), which is asked to
 * accumulate rows and to skip the column check.
 *
 * \return The newly populated MetaInfo, returned by value.
 */
MetaInfo MetaInfo::Copy() const {
  constexpr bool kAccumulateRows = true;
  constexpr bool kCheckColumn = false;

  MetaInfo duplicate;
  duplicate.Extend(*this, kAccumulateRows, kCheckColumn);
  return duplicate;
}
namespace {
template <int32_t D, typename T>
void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T, D>* p_out) {
@@ -777,8 +783,10 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) {
DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode,
const std::string& file_format) {
CHECK(data_split_mode == DataSplitMode::kRow || data_split_mode == DataSplitMode::kNone)
<< "Precondition violated; data split mode can only be 'row' or 'none'";
CHECK(data_split_mode == DataSplitMode::kRow ||
data_split_mode == DataSplitMode::kCol ||
data_split_mode == DataSplitMode::kNone)
<< "Precondition violated; data split mode can only be 'row', 'col', or 'none'";
std::string fname, cache_file;
size_t dlm_pos = uri.find('#');
if (dlm_pos != std::string::npos) {
@@ -878,7 +886,20 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
* partitioned data will fail the train/val validation check
* since partitioned data not knowing the real number of features. */
collective::Allreduce<collective::Operation::kMax>(&dmat->Info().num_col_, 1);
return dmat;
if (data_split_mode == DataSplitMode::kCol) {
  if (!cache_file.empty()) {
    LOG(FATAL) << "Column-wise data split is not support for external memory.";
  }
  // Give every worker a contiguous range of columns. The range width is the ceiling of
  // num_col_ / npart so that the ranges together cover every column; `(num_col_ + 1) / npart`
  // only equals that ceiling when npart == 2 and otherwise leaves trailing columns unassigned
  // (e.g. 10 columns over 3 workers would drop column 9).
  auto const n_cols = dmat->Info().num_col_;
  auto const slice_cols = (n_cols + npart - 1) / npart;
  // Clamp the start so that workers past the last column receive an empty range instead of
  // letting the unsigned subtraction below wrap around.
  auto const slice_start = std::min(slice_cols * partid, n_cols);
  auto const size = std::min(slice_cols, n_cols - slice_start);
  auto* sliced = dmat->SliceCol(slice_start, size);
  // The full matrix is no longer needed once this worker's column range has been extracted.
  delete dmat;
  return sliced;
} else {
  return dmat;
}
}
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,

View File

@@ -86,6 +86,10 @@ class IterativeDMatrix : public DMatrix {
LOG(FATAL) << "Slicing DMatrix is not supported for Quantile DMatrix.";
return nullptr;
}
// Column slicing is not available on this DMatrix type: the call only reports a fatal error.
// `start` and `size` are not read.
// NOTE(review): presumably LOG(FATAL) does not return normally, which would make the
// `return nullptr` below a formality to satisfy the return type - confirm against the macro.
DMatrix *SliceCol(std::size_t start, std::size_t size) override {
LOG(FATAL) << "Slicing DMatrix columns is not supported for Quantile DMatrix.";
return nullptr;
}
BatchSet<SparsePage> GetRowBatches() override {
LOG(FATAL) << "Not implemented.";
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));

View File

@@ -87,6 +87,10 @@ class DMatrixProxy : public DMatrix {
LOG(FATAL) << "Slicing DMatrix is not supported for Proxy DMatrix.";
return nullptr;
}
// Column slicing is not available on the proxy DMatrix: the call only reports a fatal error.
// `start` and `size` are not read.
// NOTE(review): presumably LOG(FATAL) does not return normally, which would make the
// `return nullptr` below a formality to satisfy the return type - confirm against the macro.
DMatrix* SliceCol(std::size_t start, std::size_t size) override {
LOG(FATAL) << "Slicing DMatrix columns is not supported for Proxy DMatrix.";
return nullptr;
}
BatchSet<SparsePage> GetRowBatches() override {
LOG(FATAL) << "Not implemented.";
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));

View File

@@ -45,6 +45,29 @@ DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
return out;
}
/*!
 * \brief Create a new SimpleDMatrix holding only the entries whose feature index lies in
 *        [start, start + size).
 *
 * Every row of the source is kept (one offset is appended per row, even when the row ends
 * up empty), and the kept entries are copied unchanged, so they retain their original
 * feature index. The meta info is taken from MetaInfo::Copy() and its num_nonzero_ is then
 * overwritten with the final row offset.
 *
 * \param start First feature index to keep.
 * \param size  Number of consecutive feature indices to keep.
 * \return The new matrix, allocated with `new` and returned as a raw pointer; nothing in
 *         this function releases it.
 */
DMatrix* SimpleDMatrix::SliceCol(std::size_t start, std::size_t size) {
  auto out = new SimpleDMatrix;
  SparsePage& out_page = *out->sparse_page_;
  // Exclusive upper bound of the kept feature range, computed once instead of per entry.
  auto const end = start + size;
  for (auto const& page : this->GetBatches<SparsePage>()) {
    auto batch = page.GetView();
    auto& h_data = out_page.data.HostVector();
    // NOTE(review): assumes a freshly constructed page already holds the leading 0 offset,
    // since only the per-row end offsets are appended here - confirm against SparsePage.
    auto& h_offset = out_page.offset.HostVector();
    std::size_t rptr{0};
    auto const n_rows = this->Info().num_row_;
    // The counter is std::size_t rather than a deduced int: an int index would be compared
    // against the unsigned row count and would overflow for more than INT_MAX rows.
    for (std::size_t i = 0; i < n_rows; ++i) {
      auto inst = batch[i];
      auto const prev_size = h_data.size();
      std::copy_if(inst.begin(), inst.end(), std::back_inserter(h_data),
                   [start, end](Entry e) { return e.index >= start && e.index < end; });
      // Advance the running offset by the number of entries kept from this row.
      rptr += h_data.size() - prev_size;
      h_offset.emplace_back(rptr);
    }
    out->Info() = this->Info().Copy();
    out->Info().num_nonzero_ = h_offset.back();
  }
  return out;
}
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
// since csr is the default data structure so `source_` is always available.
auto begin_iter = BatchIterator<SparsePage>(

View File

@@ -35,6 +35,7 @@ class SimpleDMatrix : public DMatrix {
bool SingleColBlock() const override { return true; }
DMatrix* Slice(common::Span<int32_t const> ridxs) override;
/*! \brief Column-range counterpart of Slice(): takes the first column index to keep and the number of columns. */
DMatrix* SliceCol(std::size_t start, std::size_t size) override;
/*! \brief magic number used to identify SimpleDMatrix binary files */
static const int kMagic = 0xffffab01;

View File

@@ -107,6 +107,10 @@ class SparsePageDMatrix : public DMatrix {
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
return nullptr;
}
// Column slicing is not available for external memory: the call only reports a fatal error.
// `start` and `size` are not read.
// NOTE(review): presumably LOG(FATAL) does not return normally, which would make the
// `return nullptr` below a formality to satisfy the return type - confirm against the macro.
DMatrix *SliceCol(std::size_t start, std::size_t size) override {
LOG(FATAL) << "Slicing DMatrix columns is not supported for external memory.";
return nullptr;
}
private:
BatchSet<SparsePage> GetRowBatches() override;