From 78d65a1928a37ccc5000c846d94413cfa09c769b Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Sat, 3 Dec 2022 09:37:51 -0800 Subject: [PATCH] Initial support for column-wise data split (#8468) --- include/xgboost/data.h | 12 +++++ src/data/data.cc | 27 ++++++++++-- src/data/iterative_dmatrix.h | 4 ++ src/data/proxy_dmatrix.h | 4 ++ src/data/simple_dmatrix.cc | 23 ++++++++++ src/data/simple_dmatrix.h | 1 + src/data/sparse_page_dmatrix.h | 4 ++ tests/cpp/data/test_simple_dmatrix.cc | 63 +++++++++++++++++++++++++++ 8 files changed, 135 insertions(+), 3 deletions(-) diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 9fb643541..d5c89a8cc 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -112,6 +112,9 @@ class MetaInfo { void Validate(int32_t device) const; MetaInfo Slice(common::Span ridxs) const; + + MetaInfo Copy() const; + /*! * \brief Get weight of each instances. * \param i Instance index. @@ -620,6 +623,15 @@ class DMatrix { virtual DMatrix *Slice(common::Span ridxs) = 0; + /** + * \brief Slice a DMatrix by columns. 
+ * + * \param start The position of the first column + * \param size The number of columns in the slice + * \return DMatrix containing the slice of columns + */ + virtual DMatrix *SliceCol(std::size_t start, std::size_t size) = 0; + protected: virtual BatchSet GetRowBatches() = 0; virtual BatchSet GetColumnBatches() = 0; diff --git a/src/data/data.cc b/src/data/data.cc index fb46b487b..ccf5ebc50 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -381,6 +381,12 @@ MetaInfo MetaInfo::Slice(common::Span ridxs) const { return out; } +MetaInfo MetaInfo::Copy() const { + MetaInfo out; + out.Extend(*this, /*accumulate_rows=*/true, /*check_column=*/false); // build the copy by extending an empty MetaInfo with *this + return out; +} + namespace { template void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor* p_out) { @@ -777,8 +783,10 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) { DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode, const std::string& file_format) { - CHECK(data_split_mode == DataSplitMode::kRow || data_split_mode == DataSplitMode::kNone) - << "Precondition violated; data split mode can only be 'row' or 'none'"; + CHECK(data_split_mode == DataSplitMode::kRow || + data_split_mode == DataSplitMode::kCol || + data_split_mode == DataSplitMode::kNone) + << "Precondition violated; data split mode can only be 'row', 'col', or 'none'"; std::string fname, cache_file; size_t dlm_pos = uri.find('#'); if (dlm_pos != std::string::npos) { @@ -878,7 +886,20 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s * partitioned data will fail the train/val validation check * since partitioned data not knowing the real number of features. 
*/ collective::Allreduce(&dmat->Info().num_col_, 1); - return dmat; + + if (data_split_mode == DataSplitMode::kCol) { + if (!cache_file.empty()) { + LOG(FATAL) << "Column-wise data split is not supported for external memory."; + } + auto slice_cols = (dmat->Info().num_col_ + npart - 1) / npart; // ceil(num_col / npart): no trailing columns are dropped + auto slice_start = std::min(slice_cols * partid, dmat->Info().num_col_); // clamp so the subtraction below cannot wrap + auto size = std::min(slice_cols, dmat->Info().num_col_ - slice_start); + auto* sliced = dmat->SliceCol(slice_start, size); + delete dmat; + return sliced; + } else { + return dmat; + } } template GetRowBatches() override { LOG(FATAL) << "Not implemented."; return BatchSet(BatchIterator(nullptr)); diff --git a/src/data/proxy_dmatrix.h b/src/data/proxy_dmatrix.h index 8375c7c8d..2e7fd6f00 100644 --- a/src/data/proxy_dmatrix.h +++ b/src/data/proxy_dmatrix.h @@ -87,6 +87,10 @@ class DMatrixProxy : public DMatrix { LOG(FATAL) << "Slicing DMatrix is not supported for Proxy DMatrix."; return nullptr; } + DMatrix* SliceCol(std::size_t start, std::size_t size) override { + LOG(FATAL) << "Slicing DMatrix columns is not supported for Proxy DMatrix."; + return nullptr; + } BatchSet GetRowBatches() override { LOG(FATAL) << "Not implemented."; return BatchSet(BatchIterator(nullptr)); diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 4679ef543..56185b03e 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -45,6 +45,29 @@ DMatrix* SimpleDMatrix::Slice(common::Span ridxs) { return out; } +DMatrix* SimpleDMatrix::SliceCol(std::size_t start, std::size_t size) { + auto out = new SimpleDMatrix; + SparsePage& out_page = *out->sparse_page_; + for (auto const &page : this->GetBatches()) { + auto batch = page.GetView(); + auto& h_data = out_page.data.HostVector(); + auto& h_offset = out_page.offset.HostVector(); + size_t rptr{0}; + for (std::size_t i = 0; i < this->Info().num_row_; i++) { // size_t index: avoids int overflow / sign-compare on large row counts + auto inst = batch[i]; + auto prev_size = h_data.size(); + std::copy_if(inst.begin(), inst.end(), std::back_inserter(h_data), [&](Entry 
e) { + return e.index >= start && e.index < start + size; // keep entries whose feature index lies in [start, start + size) + }); + rptr += h_data.size() - prev_size; + h_offset.emplace_back(rptr); // row pointer: running count of entries kept for this page + } + out->Info() = this->Info().Copy(); + out->Info().num_nonzero_ = h_offset.back(); // last row pointer == total number of entries kept + } + return out; +} + BatchSet SimpleDMatrix::GetRowBatches() { // since csr is the default data structure so `source_` is always available. auto begin_iter = BatchIterator( diff --git a/src/data/simple_dmatrix.h b/src/data/simple_dmatrix.h index 8a844a5af..9b9b5accf 100644 --- a/src/data/simple_dmatrix.h +++ b/src/data/simple_dmatrix.h @@ -35,6 +35,7 @@ class SimpleDMatrix : public DMatrix { bool SingleColBlock() const override { return true; } DMatrix* Slice(common::Span ridxs) override; + DMatrix* SliceCol(std::size_t start, std::size_t size) override; /*! \brief magic number used to identify SimpleDMatrix binary files */ static const int kMagic = 0xffffab01; diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h index 6770e3ce0..3bbe8fbae 100644 --- a/src/data/sparse_page_dmatrix.h +++ b/src/data/sparse_page_dmatrix.h @@ -107,6 +107,10 @@ class SparsePageDMatrix : public DMatrix { LOG(FATAL) << "Slicing DMatrix is not supported for external memory."; return nullptr; } + DMatrix *SliceCol(std::size_t start, std::size_t size) override { + LOG(FATAL) << "Slicing DMatrix columns is not supported for external memory."; + return nullptr; + } private: BatchSet GetRowBatches() override; diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index ad545ce14..c67c39c0f 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -300,6 +300,69 @@ TEST(SimpleDMatrix, Slice) { ASSERT_EQ(out->Info().num_nonzero_, ridxs.size() * kCols); // dense } +TEST(SimpleDMatrix, SliceCol) { + size_t constexpr kRows {16}; + size_t constexpr kCols {8}; + size_t constexpr kClasses {3}; + auto p_m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true); + auto& weights = 
p_m->Info().weights_.HostVector(); + weights.resize(kRows); + std::iota(weights.begin(), weights.end(), 0.0f); + + auto& lower = p_m->Info().labels_lower_bound_.HostVector(); + auto& upper = p_m->Info().labels_upper_bound_.HostVector(); + lower.resize(kRows); + upper.resize(kRows); + + std::iota(lower.begin(), lower.end(), 0.0f); + std::iota(upper.begin(), upper.end(), 1.0f); + + auto& margin = p_m->Info().base_margin_; + margin = decltype(p_m->Info().base_margin_){{kRows, kClasses}, GenericParameter::kCpuId}; + + size_t constexpr kSlicCols {4}; + for (auto slice = 0; slice < 2; slice++) { // two slices of kSlicCols columns each + auto const slice_start = slice * kSlicCols; + std::unique_ptr out { p_m->SliceCol(slice_start, kSlicCols) }; + ASSERT_EQ(out->Info().labels.Size(), kRows); + ASSERT_EQ(out->Info().labels_lower_bound_.Size(), kRows); + ASSERT_EQ(out->Info().labels_upper_bound_.Size(), kRows); + ASSERT_EQ(out->Info().base_margin_.Size(), kRows * kClasses); + + for (auto const &in_batch : p_m->GetBatches()) { + auto in_page = in_batch.GetView(); + for (auto const &out_batch : out->GetBatches()) { + auto out_page = out_batch.GetView(); + for (size_t i = 0; i < kRows; ++i) { + auto out_inst = out_page[i]; + auto in_inst = in_page[i]; + ASSERT_EQ(out_inst.size() * 2, in_inst.size()) << i; // each slice holds half of every row's entries + for (size_t j = 0; j < kSlicCols; ++j) { + ASSERT_EQ(in_inst[slice_start + j].fvalue, out_inst[j].fvalue); + ASSERT_EQ(in_inst[slice_start + j].index, out_inst[j].index); // column indices are expected to stay un-rebased after slicing + } + + ASSERT_EQ(p_m->Info().labels_lower_bound_.HostVector().at(i), + out->Info().labels_lower_bound_.HostVector().at(i)); + ASSERT_EQ(p_m->Info().labels_upper_bound_.HostVector().at(i), + out->Info().labels_upper_bound_.HostVector().at(i)); + ASSERT_EQ(p_m->Info().weights_.HostVector().at(i), out->Info().weights_.HostVector().at(i)); + + auto out_margin = out->Info().base_margin_.View(GenericParameter::kCpuId); + auto in_margin = margin.View(GenericParameter::kCpuId); + for (size_t j = 0; j < kClasses; ++j) { + ASSERT_EQ(out_margin(i, j), 
in_margin(i, j)); + } + } + } + } + + ASSERT_EQ(out->Info().num_col_, p_m->Info().num_col_); // compare against the source matrix (was a self-comparison that could never fail) + ASSERT_EQ(out->Info().num_row_, kRows); + ASSERT_EQ(out->Info().num_nonzero_, kRows * kSlicCols); // dense + } +} + TEST(SimpleDMatrix, SaveLoadBinary) { dmlc::TemporaryDirectory tempdir; const std::string tmp_file = tempdir.path + "/simple.libsvm";