diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index d71364527..a1355c4d0 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -23,6 +23,7 @@ #include "../data/simple_csr_source.h" #include "../common/io.h" #include "../data/adapter.h" +#include "../data/simple_dmatrix.h" namespace xgboost { // declare the data callback. @@ -287,53 +288,20 @@ XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle, API_BEGIN(); CHECK_HANDLE(); - data::SimpleCSRSource src; - src.CopyFrom(static_cast<std::shared_ptr<DMatrix>*>(handle)->get()); - data::SimpleCSRSource& ret = *source; - if (!allow_groups) { - CHECK_EQ(src.info.group_ptr_.size(), 0U) - << "slice does not support group structure"; + CHECK_EQ(static_cast<std::shared_ptr<DMatrix>*>(handle) + ->get() + ->Info() + .group_ptr_.size(), + 0U) + << "slice does not support group structure"; } - - ret.Clear(); - ret.info.num_row_ = len; - ret.info.num_col_ = src.info.num_col_; - - auto iter = &src; - iter->BeforeFirst(); - CHECK(iter->Next()); - - const auto& batch = iter->Value(); - const auto& src_labels = src.info.labels_.ConstHostVector(); - const auto& src_weights = src.info.weights_.ConstHostVector(); - const auto& src_base_margin = src.info.base_margin_.ConstHostVector(); - auto& ret_labels = ret.info.labels_.HostVector(); - auto& ret_weights = ret.info.weights_.HostVector(); - auto& ret_base_margin = ret.info.base_margin_.HostVector(); - auto& offset_vec = ret.page_.offset.HostVector(); - auto& data_vec = ret.page_.data.HostVector(); - - for (xgboost::bst_ulong i = 0; i < len; ++i) { - const int ridx = idxset[i]; - auto inst = batch[ridx]; - CHECK_LT(static_cast<xgboost::bst_ulong>(ridx), batch.Size()); - data_vec.insert(data_vec.end(), inst.data(), - inst.data() + inst.size()); - offset_vec.push_back(offset_vec.back() + inst.size()); - ret.info.num_nonzero_ += inst.size(); - - if (src_labels.size() != 0) { - ret_labels.push_back(src_labels[ridx]); - } - if (src_weights.size() != 0) { - ret_weights.push_back(src_weights[ridx]); - } - if (src_base_margin.size() != 0) { -
ret_base_margin.push_back(src_base_margin[ridx]); - } - } - *out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source))); + DMatrix* dmat = static_cast<std::shared_ptr<DMatrix>*>(handle)->get(); + CHECK(dynamic_cast<data::SimpleDMatrix*>(dmat)) + << "Slice only supported for SimpleDMatrix currently."; + data::DMatrixSliceAdapter adapter(dmat, {idxset, len}); + *out = new std::shared_ptr<DMatrix>( + DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1)); API_END(); } diff --git a/src/data/adapter.h b/src/data/adapter.h index ea89ba4fa..896a44b94 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -8,6 +8,8 @@ #include <limits> #include <memory> #include <string> +#include <utility> +#include <vector> namespace xgboost { namespace data { @@ -94,6 +96,7 @@ class NoMetaInfo { const float* Labels() const { return nullptr; } const float* Weights() const { return nullptr; } const uint64_t* Qid() const { return nullptr; } + const float* BaseMargin() const { return nullptr; } }; }; // namespace detail @@ -446,6 +449,7 @@ class FileAdapterBatch { const float* Labels() const { return block->label; } const float* Weights() const { return block->weight; } const uint64_t* Qid() const { return block->qid; } + const float* BaseMargin() const { return nullptr; } size_t Size() const { return block->size; } @@ -481,6 +485,92 @@ class FileAdapter : dmlc::DataIter<FileAdapterBatch> { std::unique_ptr<FileAdapterBatch> batch; dmlc::Parser<uint32_t>* parser; }; + +class DMatrixSliceAdapterBatch { + public: + // Fetch metainfo values according to sliced rows + template <typename T> + std::vector<T> Gather(const std::vector<T>& in) { + if (in.empty()) return {}; + + std::vector<T> out(this->Size()); + for (auto i = 0ull; i < this->Size(); i++) { + out[i] = in[ridx_set[i]]; + } + return out; + } + DMatrixSliceAdapterBatch(const SparsePage& batch, DMatrix* dmat, + common::Span<const int> ridx_set) + : batch(batch), ridx_set(ridx_set) { + batch_labels = this->Gather(dmat->Info().labels_.HostVector()); + batch_weights = this->Gather(dmat->Info().weights_.HostVector()); + batch_base_margin = this->Gather(dmat->Info().base_margin_.HostVector()); + } +
class Line { + public: + Line(const SparsePage::Inst& inst, size_t row_idx) + : inst(inst), row_idx(row_idx) {} + + size_t Size() { return inst.size(); } + COOTuple GetElement(size_t idx) { + return COOTuple(row_idx, inst[idx].index, inst[idx].fvalue); + } + + private: + SparsePage::Inst inst; + size_t row_idx; + }; + Line GetLine(size_t idx) const { return Line(batch[ridx_set[idx]], idx); } + const float* Labels() const { + if (batch_labels.empty()) { + return nullptr; + } + return batch_labels.data(); + } + const float* Weights() const { + if (batch_weights.empty()) { + return nullptr; + } + return batch_weights.data(); + } + const uint64_t* Qid() const { return nullptr; } + const float* BaseMargin() const { + if (batch_base_margin.empty()) { + return nullptr; + } + return batch_base_margin.data(); + } + + size_t Size() const { return ridx_set.size(); } + const SparsePage& batch; + common::Span<const int> ridx_set; + std::vector<float> batch_labels; + std::vector<float> batch_weights; + std::vector<float> batch_base_margin; +}; + +// Group pointer is not exposed +// This is because external bindings currently manipulate the group values +// manually when slicing This could potentially be moved to internal C++ code if +// needed + +class DMatrixSliceAdapter + : public detail::SingleBatchDataIter<DMatrixSliceAdapterBatch> { + public: + DMatrixSliceAdapter(DMatrix* dmat, common::Span<const int> ridx_set) + : dmat(dmat), + ridx_set(ridx_set), + batch(*dmat->GetBatches<SparsePage>().begin(), dmat, ridx_set) {} + const DMatrixSliceAdapterBatch& Value() const override { return batch; } + // Indicates a number of rows/columns must be inferred + size_t NumRows() const { return ridx_set.size(); } + size_t NumColumns() const { return dmat->Info().num_col_; } + DMatrix* dmat; + DMatrixSliceAdapterBatch batch; + bool before_first{true}; + common::Span<const int> ridx_set; +}; }; // namespace data } // namespace xgboost #endif // XGBOOST_DATA_ADAPTER_H_ diff --git a/src/data/data.cc b/src/data/data.cc index 9f0cbae69..9f158851c 100644 --- a/src/data/data.cc +++
b/src/data/data.cc @@ -340,6 +340,9 @@ template DMatrix* DMatrix::Create<data::DataTableAdapter>( template DMatrix* DMatrix::Create<data::FileAdapter>( data::FileAdapter* adapter, float missing, int nthread, const std::string& cache_prefix, size_t page_size); +template DMatrix* DMatrix::Create<data::DMatrixSliceAdapter>( + data::DMatrixSliceAdapter* adapter, float missing, int nthread, + const std::string& cache_prefix, size_t page_size); SparsePage SparsePage::GetTranspose(int num_columns) const { SparsePage transpose; diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index bb83400a0..229f4ab58 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -111,6 +111,11 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { weights.insert(weights.end(), batch.Weights(), batch.Weights() + batch.Size()); } + if (batch.BaseMargin() != nullptr) { + auto& base_margin = mat.info.base_margin_.HostVector(); + base_margin.insert(base_margin.end(), batch.BaseMargin(), + batch.BaseMargin() + batch.Size()); + } if (batch.Qid() != nullptr) { qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size()); // get group @@ -166,5 +171,7 @@ template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing, int nthread); template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing, int nthread); +template SimpleDMatrix::SimpleDMatrix(DMatrixSliceAdapter* adapter, float missing, + int nthread); } // namespace data } // namespace xgboost diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 56fa07921..172eb14bc 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -222,6 +222,11 @@ class SparsePageSource : public DataSource { weights.insert(weights.end(), batch.Weights(), batch.Weights() + batch.Size()); } + if (batch.BaseMargin() != nullptr) { + auto& base_margin = info.base_margin_.HostVector(); + base_margin.insert(base_margin.end(), batch.BaseMargin(), + batch.BaseMargin() + batch.Size()); + } if (batch.Qid()
!= nullptr) { qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size()); // get group diff --git a/tests/cpp/data/test_adapter.cc b/tests/cpp/data/test_adapter.cc index 5c73030a3..d785ab5d1 100644 --- a/tests/cpp/data/test_adapter.cc +++ b/tests/cpp/data/test_adapter.cc @@ -61,3 +61,31 @@ TEST(adapter, CSCAdapterColsMoreThanRows) { EXPECT_EQ(inst[3].fvalue, 8); EXPECT_EQ(inst[3].index, 3); } + +TEST(c_api, DMatrixSliceAdapterFromSimpleDMatrix) { + auto pp_dmat = CreateDMatrix(6, 2, 1.0); + auto p_dmat = *pp_dmat; + + std::vector<int> ridx_set = {1, 3, 5}; + data::DMatrixSliceAdapter adapter(p_dmat.get(), + {ridx_set.data(), ridx_set.size()}); + EXPECT_EQ(adapter.NumRows(), ridx_set.size()); + + adapter.BeforeFirst(); + for (auto &batch : p_dmat->GetBatches<SparsePage>()) { + adapter.Next(); + auto &adapter_batch = adapter.Value(); + for (auto i = 0ull; i < adapter_batch.Size(); i++) { + auto inst = batch[ridx_set[i]]; + auto line = adapter_batch.GetLine(i); + ASSERT_EQ(inst.size(), line.Size()); + for (auto j = 0ull; j < line.Size(); j++) { + EXPECT_EQ(inst[j].fvalue, line.GetElement(j).value); + EXPECT_EQ(inst[j].index, line.GetElement(j).column_idx); + EXPECT_EQ(i, line.GetElement(j).row_idx); + } + } + } + + delete pp_dmat; +} diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index 03bd2c349..0524165cb 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -210,3 +210,47 @@ TEST(SimpleDMatrix, FromFile) { } } } + +TEST(SimpleDMatrix, Slice) { + const int kRows = 6; + const int kCols = 2; + auto pp_dmat = CreateDMatrix(kRows, kCols, 1.0); + auto p_dmat = *pp_dmat; + auto &labels = p_dmat->Info().labels_.HostVector(); + auto &weights = p_dmat->Info().weights_.HostVector(); + auto &base_margin = p_dmat->Info().base_margin_.HostVector(); + weights.resize(kRows); + labels.resize(kRows); + base_margin.resize(kRows); + std::iota(labels.begin(), labels.end(), 0); + std::iota(weights.begin(),
weights.end(), 0); + std::iota(base_margin.begin(), base_margin.end(), 0); + + std::vector<int> ridx_set = {1, 3, 5}; + data::DMatrixSliceAdapter adapter(p_dmat.get(), + {ridx_set.data(), ridx_set.size()}); + EXPECT_EQ(adapter.NumRows(), ridx_set.size()); + data::SimpleDMatrix new_dmat(&adapter, + std::numeric_limits<float>::quiet_NaN(), 1); + + EXPECT_EQ(new_dmat.Info().num_row_, ridx_set.size()); + + auto &old_batch = *p_dmat->GetBatches<SparsePage>().begin(); + auto &new_batch = *new_dmat.GetBatches<SparsePage>().begin(); + for (auto i = 0ull; i < ridx_set.size(); i++) { + EXPECT_EQ(new_dmat.Info().labels_.HostVector()[i], + p_dmat->Info().labels_.HostVector()[ridx_set[i]]); + EXPECT_EQ(new_dmat.Info().weights_.HostVector()[i], + p_dmat->Info().weights_.HostVector()[ridx_set[i]]); + EXPECT_EQ(new_dmat.Info().base_margin_.HostVector()[i], + p_dmat->Info().base_margin_.HostVector()[ridx_set[i]]); + auto old_inst = old_batch[ridx_set[i]]; + auto new_inst = new_batch[i]; + ASSERT_EQ(old_inst.size(), new_inst.size()); + for (auto j = 0ull; j < old_inst.size(); j++) { + EXPECT_EQ(old_inst[j], new_inst[j]); + } + } + + delete pp_dmat; +};