Implement slice via adapters (#5198)

This commit is contained in:
Rory Mitchell 2020-01-14 12:55:41 +13:00 committed by GitHub
parent f100b8d878
commit a73e25e15f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 190 additions and 45 deletions

View File

@ -23,6 +23,7 @@
#include "../data/simple_csr_source.h"
#include "../common/io.h"
#include "../data/adapter.h"
#include "../data/simple_dmatrix.h"
namespace xgboost {
// declare the data callback.
@ -287,53 +288,20 @@ XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle,
API_BEGIN();
CHECK_HANDLE();
data::SimpleCSRSource src;
src.CopyFrom(static_cast<std::shared_ptr<DMatrix>*>(handle)->get());
data::SimpleCSRSource& ret = *source;
if (!allow_groups) {
CHECK_EQ(src.info.group_ptr_.size(), 0U)
CHECK_EQ(static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()
->Info()
.group_ptr_.size(),
0U)
<< "slice does not support group structure";
}
ret.Clear();
ret.info.num_row_ = len;
ret.info.num_col_ = src.info.num_col_;
auto iter = &src;
iter->BeforeFirst();
CHECK(iter->Next());
const auto& batch = iter->Value();
const auto& src_labels = src.info.labels_.ConstHostVector();
const auto& src_weights = src.info.weights_.ConstHostVector();
const auto& src_base_margin = src.info.base_margin_.ConstHostVector();
auto& ret_labels = ret.info.labels_.HostVector();
auto& ret_weights = ret.info.weights_.HostVector();
auto& ret_base_margin = ret.info.base_margin_.HostVector();
auto& offset_vec = ret.page_.offset.HostVector();
auto& data_vec = ret.page_.data.HostVector();
for (xgboost::bst_ulong i = 0; i < len; ++i) {
const int ridx = idxset[i];
auto inst = batch[ridx];
CHECK_LT(static_cast<xgboost::bst_ulong>(ridx), batch.Size());
data_vec.insert(data_vec.end(), inst.data(),
inst.data() + inst.size());
offset_vec.push_back(offset_vec.back() + inst.size());
ret.info.num_nonzero_ += inst.size();
if (src_labels.size() != 0) {
ret_labels.push_back(src_labels[ridx]);
}
if (src_weights.size() != 0) {
ret_weights.push_back(src_weights[ridx]);
}
if (src_base_margin.size() != 0) {
ret_base_margin.push_back(src_base_margin[ridx]);
}
}
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
DMatrix* dmat = static_cast<std::shared_ptr<DMatrix>*>(handle)->get();
CHECK(dynamic_cast<data::SimpleDMatrix*>(dmat))
<< "Slice only supported for SimpleDMatrix currently.";
data::DMatrixSliceAdapter adapter(dmat, {idxset, len});
*out = new std::shared_ptr<DMatrix>(
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1));
API_END();
}

View File

@ -8,6 +8,8 @@
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace xgboost {
namespace data {
@ -94,6 +96,7 @@ class NoMetaInfo {
const float* Labels() const { return nullptr; }
const float* Weights() const { return nullptr; }
const uint64_t* Qid() const { return nullptr; }
const float* BaseMargin() const { return nullptr; }
};
}; // namespace detail
@ -446,6 +449,7 @@ class FileAdapterBatch {
const float* Labels() const { return block->label; }
const float* Weights() const { return block->weight; }
const uint64_t* Qid() const { return block->qid; }
const float* BaseMargin() const { return nullptr; }
size_t Size() const { return block->size; }
@ -481,6 +485,92 @@ class FileAdapter : dmlc::DataIter<FileAdapterBatch> {
std::unique_ptr<FileAdapterBatch> batch;
dmlc::Parser<uint32_t>* parser;
};
class DMatrixSliceAdapterBatch {
public:
// Fetch metainfo values according to sliced rows
template <typename T>
std::vector<T> Gather(const std::vector<T>& in) {
if (in.empty()) return {};
std::vector<T> out(this->Size());
for (auto i = 0ull; i < this->Size(); i++) {
out[i] = in[ridx_set[i]];
}
return out;
}
DMatrixSliceAdapterBatch(const SparsePage& batch, DMatrix* dmat,
common::Span<const int> ridx_set)
: batch(batch), ridx_set(ridx_set) {
batch_labels = this->Gather(dmat->Info().labels_.HostVector());
batch_weights = this->Gather(dmat->Info().weights_.HostVector());
batch_base_margin = this->Gather(dmat->Info().base_margin_.HostVector());
}
class Line {
public:
Line(const SparsePage::Inst& inst, size_t row_idx)
: inst(inst), row_idx(row_idx) {}
size_t Size() { return inst.size(); }
COOTuple GetElement(size_t idx) {
return COOTuple(row_idx, inst[idx].index, inst[idx].fvalue);
}
private:
SparsePage::Inst inst;
size_t row_idx;
};
Line GetLine(size_t idx) const { return Line(batch[ridx_set[idx]], idx); }
const float* Labels() const {
if (batch_labels.empty()) {
return nullptr;
}
return batch_labels.data();
}
const float* Weights() const {
if (batch_weights.empty()) {
return nullptr;
}
return batch_weights.data();
}
const uint64_t* Qid() const { return nullptr; }
const float* BaseMargin() const {
if (batch_base_margin.empty()) {
return nullptr;
}
return batch_base_margin.data();
}
size_t Size() const { return ridx_set.size(); }
const SparsePage& batch;
common::Span<const int> ridx_set;
std::vector<float> batch_labels;
std::vector<float> batch_weights;
std::vector<float> batch_base_margin;
};
// Group pointer is not exposed
// This is because external bindings currently manipulate the group values
// manually when slicing This could potentially be moved to internal C++ code if
// needed
class DMatrixSliceAdapter
: public detail::SingleBatchDataIter<DMatrixSliceAdapterBatch> {
public:
DMatrixSliceAdapter(DMatrix* dmat, common::Span<const int> ridx_set)
: dmat(dmat),
ridx_set(ridx_set),
batch(*dmat->GetBatches<SparsePage>().begin(), dmat, ridx_set) {}
const DMatrixSliceAdapterBatch& Value() const override { return batch; }
// Indicates a number of rows/columns must be inferred
size_t NumRows() const { return ridx_set.size(); }
size_t NumColumns() const { return dmat->Info().num_col_; }
DMatrix* dmat;
DMatrixSliceAdapterBatch batch;
bool before_first{true};
common::Span<const int> ridx_set;
};
}; // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_ADAPTER_H_

View File

@ -340,6 +340,9 @@ template DMatrix* DMatrix::Create<data::DataTableAdapter>(
template DMatrix* DMatrix::Create<data::FileAdapter>(
data::FileAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
template DMatrix* DMatrix::Create<data::DMatrixSliceAdapter>(
data::DMatrixSliceAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
SparsePage SparsePage::GetTranspose(int num_columns) const {
SparsePage transpose;

View File

@ -111,6 +111,11 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
weights.insert(weights.end(), batch.Weights(),
batch.Weights() + batch.Size());
}
if (batch.BaseMargin() != nullptr) {
auto& base_margin = mat.info.base_margin_.HostVector();
base_margin.insert(base_margin.end(), batch.BaseMargin(),
batch.BaseMargin() + batch.Size());
}
if (batch.Qid() != nullptr) {
qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size());
// get group
@ -166,5 +171,7 @@ template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(DMatrixSliceAdapter* adapter, float missing,
int nthread);
} // namespace data
} // namespace xgboost

View File

@ -222,6 +222,11 @@ class SparsePageSource : public DataSource<T> {
weights.insert(weights.end(), batch.Weights(),
batch.Weights() + batch.Size());
}
if (batch.BaseMargin() != nullptr) {
auto& base_margin = info.base_margin_.HostVector();
base_margin.insert(base_margin.end(), batch.BaseMargin(),
batch.BaseMargin() + batch.Size());
}
if (batch.Qid() != nullptr) {
qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size());
// get group

View File

@ -61,3 +61,31 @@ TEST(adapter, CSCAdapterColsMoreThanRows) {
EXPECT_EQ(inst[3].fvalue, 8);
EXPECT_EQ(inst[3].index, 3);
}
TEST(c_api, DMatrixSliceAdapterFromSimpleDMatrix) {
auto pp_dmat = CreateDMatrix(6, 2, 1.0);
auto p_dmat = *pp_dmat;
std::vector<int> ridx_set = {1, 3, 5};
data::DMatrixSliceAdapter adapter(p_dmat.get(),
{ridx_set.data(), ridx_set.size()});
EXPECT_EQ(adapter.NumRows(), ridx_set.size());
adapter.BeforeFirst();
for (auto &batch : p_dmat->GetBatches<SparsePage>()) {
adapter.Next();
auto &adapter_batch = adapter.Value();
for (auto i = 0ull; i < adapter_batch.Size(); i++) {
auto inst = batch[ridx_set[i]];
auto line = adapter_batch.GetLine(i);
ASSERT_EQ(inst.size(), line.Size());
for (auto j = 0ull; j < line.Size(); j++) {
EXPECT_EQ(inst[j].fvalue, line.GetElement(j).value);
EXPECT_EQ(inst[j].index, line.GetElement(j).column_idx);
EXPECT_EQ(i, line.GetElement(j).row_idx);
}
}
}
delete pp_dmat;
}

View File

@ -210,3 +210,47 @@ TEST(SimpleDMatrix, FromFile) {
}
}
}
TEST(SimpleDMatrix, Slice) {
const int kRows = 6;
const int kCols = 2;
auto pp_dmat = CreateDMatrix(kRows, kCols, 1.0);
auto p_dmat = *pp_dmat;
auto &labels = p_dmat->Info().labels_.HostVector();
auto &weights = p_dmat->Info().weights_.HostVector();
auto &base_margin = p_dmat->Info().base_margin_.HostVector();
weights.resize(kRows);
labels.resize(kRows);
base_margin.resize(kRows);
std::iota(labels.begin(), labels.end(), 0);
std::iota(weights.begin(), weights.end(), 0);
std::iota(base_margin.begin(), base_margin.end(), 0);
std::vector<int> ridx_set = {1, 3, 5};
data::DMatrixSliceAdapter adapter(p_dmat.get(),
{ridx_set.data(), ridx_set.size()});
EXPECT_EQ(adapter.NumRows(), ridx_set.size());
data::SimpleDMatrix new_dmat(&adapter,
std::numeric_limits<float>::quiet_NaN(), 1);
EXPECT_EQ(new_dmat.Info().num_row_, ridx_set.size());
auto &old_batch = *p_dmat->GetBatches<SparsePage>().begin();
auto &new_batch = *new_dmat.GetBatches<SparsePage>().begin();
for (auto i = 0ull; i < ridx_set.size(); i++) {
EXPECT_EQ(new_dmat.Info().labels_.HostVector()[i],
p_dmat->Info().labels_.HostVector()[ridx_set[i]]);
EXPECT_EQ(new_dmat.Info().weights_.HostVector()[i],
p_dmat->Info().weights_.HostVector()[ridx_set[i]]);
EXPECT_EQ(new_dmat.Info().base_margin_.HostVector()[i],
p_dmat->Info().base_margin_.HostVector()[ridx_set[i]]);
auto old_inst = old_batch[ridx_set[i]];
auto new_inst = new_batch[i];
ASSERT_EQ(old_inst.size(), new_inst.size());
for (auto j = 0ull; j < old_inst.size(); j++) {
EXPECT_EQ(old_inst[j], new_inst[j]);
}
}
delete pp_dmat;
};