Implement slice via adapters (#5198)
This commit is contained in:
parent
f100b8d878
commit
a73e25e15f
@ -23,6 +23,7 @@
|
|||||||
#include "../data/simple_csr_source.h"
|
#include "../data/simple_csr_source.h"
|
||||||
#include "../common/io.h"
|
#include "../common/io.h"
|
||||||
#include "../data/adapter.h"
|
#include "../data/adapter.h"
|
||||||
|
#include "../data/simple_dmatrix.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
// declare the data callback.
|
// declare the data callback.
|
||||||
@ -287,53 +288,20 @@ XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle,
|
|||||||
|
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
data::SimpleCSRSource src;
|
|
||||||
src.CopyFrom(static_cast<std::shared_ptr<DMatrix>*>(handle)->get());
|
|
||||||
data::SimpleCSRSource& ret = *source;
|
|
||||||
|
|
||||||
if (!allow_groups) {
|
if (!allow_groups) {
|
||||||
CHECK_EQ(src.info.group_ptr_.size(), 0U)
|
CHECK_EQ(static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||||
<< "slice does not support group structure";
|
->get()
|
||||||
|
->Info()
|
||||||
|
.group_ptr_.size(),
|
||||||
|
0U)
|
||||||
|
<< "slice does not support group structure";
|
||||||
}
|
}
|
||||||
|
DMatrix* dmat = static_cast<std::shared_ptr<DMatrix>*>(handle)->get();
|
||||||
ret.Clear();
|
CHECK(dynamic_cast<data::SimpleDMatrix*>(dmat))
|
||||||
ret.info.num_row_ = len;
|
<< "Slice only supported for SimpleDMatrix currently.";
|
||||||
ret.info.num_col_ = src.info.num_col_;
|
data::DMatrixSliceAdapter adapter(dmat, {idxset, len});
|
||||||
|
*out = new std::shared_ptr<DMatrix>(
|
||||||
auto iter = &src;
|
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1));
|
||||||
iter->BeforeFirst();
|
|
||||||
CHECK(iter->Next());
|
|
||||||
|
|
||||||
const auto& batch = iter->Value();
|
|
||||||
const auto& src_labels = src.info.labels_.ConstHostVector();
|
|
||||||
const auto& src_weights = src.info.weights_.ConstHostVector();
|
|
||||||
const auto& src_base_margin = src.info.base_margin_.ConstHostVector();
|
|
||||||
auto& ret_labels = ret.info.labels_.HostVector();
|
|
||||||
auto& ret_weights = ret.info.weights_.HostVector();
|
|
||||||
auto& ret_base_margin = ret.info.base_margin_.HostVector();
|
|
||||||
auto& offset_vec = ret.page_.offset.HostVector();
|
|
||||||
auto& data_vec = ret.page_.data.HostVector();
|
|
||||||
|
|
||||||
for (xgboost::bst_ulong i = 0; i < len; ++i) {
|
|
||||||
const int ridx = idxset[i];
|
|
||||||
auto inst = batch[ridx];
|
|
||||||
CHECK_LT(static_cast<xgboost::bst_ulong>(ridx), batch.Size());
|
|
||||||
data_vec.insert(data_vec.end(), inst.data(),
|
|
||||||
inst.data() + inst.size());
|
|
||||||
offset_vec.push_back(offset_vec.back() + inst.size());
|
|
||||||
ret.info.num_nonzero_ += inst.size();
|
|
||||||
|
|
||||||
if (src_labels.size() != 0) {
|
|
||||||
ret_labels.push_back(src_labels[ridx]);
|
|
||||||
}
|
|
||||||
if (src_weights.size() != 0) {
|
|
||||||
ret_weights.push_back(src_weights[ridx]);
|
|
||||||
}
|
|
||||||
if (src_base_margin.size() != 0) {
|
|
||||||
ret_base_margin.push_back(src_base_margin[ridx]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
|
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,8 @@
|
|||||||
#include <limits>
|
#include <limits>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace data {
|
namespace data {
|
||||||
@ -94,6 +96,7 @@ class NoMetaInfo {
|
|||||||
const float* Labels() const { return nullptr; }
|
const float* Labels() const { return nullptr; }
|
||||||
const float* Weights() const { return nullptr; }
|
const float* Weights() const { return nullptr; }
|
||||||
const uint64_t* Qid() const { return nullptr; }
|
const uint64_t* Qid() const { return nullptr; }
|
||||||
|
const float* BaseMargin() const { return nullptr; }
|
||||||
};
|
};
|
||||||
|
|
||||||
}; // namespace detail
|
}; // namespace detail
|
||||||
@ -446,6 +449,7 @@ class FileAdapterBatch {
|
|||||||
const float* Labels() const { return block->label; }
|
const float* Labels() const { return block->label; }
|
||||||
const float* Weights() const { return block->weight; }
|
const float* Weights() const { return block->weight; }
|
||||||
const uint64_t* Qid() const { return block->qid; }
|
const uint64_t* Qid() const { return block->qid; }
|
||||||
|
const float* BaseMargin() const { return nullptr; }
|
||||||
|
|
||||||
size_t Size() const { return block->size; }
|
size_t Size() const { return block->size; }
|
||||||
|
|
||||||
@ -481,6 +485,92 @@ class FileAdapter : dmlc::DataIter<FileAdapterBatch> {
|
|||||||
std::unique_ptr<FileAdapterBatch> batch;
|
std::unique_ptr<FileAdapterBatch> batch;
|
||||||
dmlc::Parser<uint32_t>* parser;
|
dmlc::Parser<uint32_t>* parser;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class DMatrixSliceAdapterBatch {
|
||||||
|
public:
|
||||||
|
// Fetch metainfo values according to sliced rows
|
||||||
|
template <typename T>
|
||||||
|
std::vector<T> Gather(const std::vector<T>& in) {
|
||||||
|
if (in.empty()) return {};
|
||||||
|
|
||||||
|
std::vector<T> out(this->Size());
|
||||||
|
for (auto i = 0ull; i < this->Size(); i++) {
|
||||||
|
out[i] = in[ridx_set[i]];
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
DMatrixSliceAdapterBatch(const SparsePage& batch, DMatrix* dmat,
|
||||||
|
common::Span<const int> ridx_set)
|
||||||
|
: batch(batch), ridx_set(ridx_set) {
|
||||||
|
batch_labels = this->Gather(dmat->Info().labels_.HostVector());
|
||||||
|
batch_weights = this->Gather(dmat->Info().weights_.HostVector());
|
||||||
|
batch_base_margin = this->Gather(dmat->Info().base_margin_.HostVector());
|
||||||
|
}
|
||||||
|
|
||||||
|
class Line {
|
||||||
|
public:
|
||||||
|
Line(const SparsePage::Inst& inst, size_t row_idx)
|
||||||
|
: inst(inst), row_idx(row_idx) {}
|
||||||
|
|
||||||
|
size_t Size() { return inst.size(); }
|
||||||
|
COOTuple GetElement(size_t idx) {
|
||||||
|
return COOTuple(row_idx, inst[idx].index, inst[idx].fvalue);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
SparsePage::Inst inst;
|
||||||
|
size_t row_idx;
|
||||||
|
};
|
||||||
|
Line GetLine(size_t idx) const { return Line(batch[ridx_set[idx]], idx); }
|
||||||
|
const float* Labels() const {
|
||||||
|
if (batch_labels.empty()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return batch_labels.data();
|
||||||
|
}
|
||||||
|
const float* Weights() const {
|
||||||
|
if (batch_weights.empty()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return batch_weights.data();
|
||||||
|
}
|
||||||
|
const uint64_t* Qid() const { return nullptr; }
|
||||||
|
const float* BaseMargin() const {
|
||||||
|
if (batch_base_margin.empty()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return batch_base_margin.data();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t Size() const { return ridx_set.size(); }
|
||||||
|
const SparsePage& batch;
|
||||||
|
common::Span<const int> ridx_set;
|
||||||
|
std::vector<float> batch_labels;
|
||||||
|
std::vector<float> batch_weights;
|
||||||
|
std::vector<float> batch_base_margin;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Group pointer is not exposed
|
||||||
|
// This is because external bindings currently manipulate the group values
|
||||||
|
// manually when slicing This could potentially be moved to internal C++ code if
|
||||||
|
// needed
|
||||||
|
|
||||||
|
class DMatrixSliceAdapter
|
||||||
|
: public detail::SingleBatchDataIter<DMatrixSliceAdapterBatch> {
|
||||||
|
public:
|
||||||
|
DMatrixSliceAdapter(DMatrix* dmat, common::Span<const int> ridx_set)
|
||||||
|
: dmat(dmat),
|
||||||
|
ridx_set(ridx_set),
|
||||||
|
batch(*dmat->GetBatches<SparsePage>().begin(), dmat, ridx_set) {}
|
||||||
|
const DMatrixSliceAdapterBatch& Value() const override { return batch; }
|
||||||
|
// Indicates a number of rows/columns must be inferred
|
||||||
|
size_t NumRows() const { return ridx_set.size(); }
|
||||||
|
size_t NumColumns() const { return dmat->Info().num_col_; }
|
||||||
|
DMatrix* dmat;
|
||||||
|
DMatrixSliceAdapterBatch batch;
|
||||||
|
bool before_first{true};
|
||||||
|
common::Span<const int> ridx_set;
|
||||||
|
};
|
||||||
}; // namespace data
|
}; // namespace data
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
#endif // XGBOOST_DATA_ADAPTER_H_
|
#endif // XGBOOST_DATA_ADAPTER_H_
|
||||||
|
|||||||
@ -340,6 +340,9 @@ template DMatrix* DMatrix::Create<data::DataTableAdapter>(
|
|||||||
template DMatrix* DMatrix::Create<data::FileAdapter>(
|
template DMatrix* DMatrix::Create<data::FileAdapter>(
|
||||||
data::FileAdapter* adapter, float missing, int nthread,
|
data::FileAdapter* adapter, float missing, int nthread,
|
||||||
const std::string& cache_prefix, size_t page_size);
|
const std::string& cache_prefix, size_t page_size);
|
||||||
|
template DMatrix* DMatrix::Create<data::DMatrixSliceAdapter>(
|
||||||
|
data::DMatrixSliceAdapter* adapter, float missing, int nthread,
|
||||||
|
const std::string& cache_prefix, size_t page_size);
|
||||||
|
|
||||||
SparsePage SparsePage::GetTranspose(int num_columns) const {
|
SparsePage SparsePage::GetTranspose(int num_columns) const {
|
||||||
SparsePage transpose;
|
SparsePage transpose;
|
||||||
|
|||||||
@ -111,6 +111,11 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
|||||||
weights.insert(weights.end(), batch.Weights(),
|
weights.insert(weights.end(), batch.Weights(),
|
||||||
batch.Weights() + batch.Size());
|
batch.Weights() + batch.Size());
|
||||||
}
|
}
|
||||||
|
if (batch.BaseMargin() != nullptr) {
|
||||||
|
auto& base_margin = mat.info.base_margin_.HostVector();
|
||||||
|
base_margin.insert(base_margin.end(), batch.BaseMargin(),
|
||||||
|
batch.BaseMargin() + batch.Size());
|
||||||
|
}
|
||||||
if (batch.Qid() != nullptr) {
|
if (batch.Qid() != nullptr) {
|
||||||
qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size());
|
qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size());
|
||||||
// get group
|
// get group
|
||||||
@ -166,5 +171,7 @@ template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing,
|
|||||||
int nthread);
|
int nthread);
|
||||||
template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing,
|
template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing,
|
||||||
int nthread);
|
int nthread);
|
||||||
|
template SimpleDMatrix::SimpleDMatrix(DMatrixSliceAdapter* adapter, float missing,
|
||||||
|
int nthread);
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -222,6 +222,11 @@ class SparsePageSource : public DataSource<T> {
|
|||||||
weights.insert(weights.end(), batch.Weights(),
|
weights.insert(weights.end(), batch.Weights(),
|
||||||
batch.Weights() + batch.Size());
|
batch.Weights() + batch.Size());
|
||||||
}
|
}
|
||||||
|
if (batch.BaseMargin() != nullptr) {
|
||||||
|
auto& base_margin = info.base_margin_.HostVector();
|
||||||
|
base_margin.insert(base_margin.end(), batch.BaseMargin(),
|
||||||
|
batch.BaseMargin() + batch.Size());
|
||||||
|
}
|
||||||
if (batch.Qid() != nullptr) {
|
if (batch.Qid() != nullptr) {
|
||||||
qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size());
|
qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size());
|
||||||
// get group
|
// get group
|
||||||
|
|||||||
@ -61,3 +61,31 @@ TEST(adapter, CSCAdapterColsMoreThanRows) {
|
|||||||
EXPECT_EQ(inst[3].fvalue, 8);
|
EXPECT_EQ(inst[3].fvalue, 8);
|
||||||
EXPECT_EQ(inst[3].index, 3);
|
EXPECT_EQ(inst[3].index, 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(c_api, DMatrixSliceAdapterFromSimpleDMatrix) {
|
||||||
|
auto pp_dmat = CreateDMatrix(6, 2, 1.0);
|
||||||
|
auto p_dmat = *pp_dmat;
|
||||||
|
|
||||||
|
std::vector<int> ridx_set = {1, 3, 5};
|
||||||
|
data::DMatrixSliceAdapter adapter(p_dmat.get(),
|
||||||
|
{ridx_set.data(), ridx_set.size()});
|
||||||
|
EXPECT_EQ(adapter.NumRows(), ridx_set.size());
|
||||||
|
|
||||||
|
adapter.BeforeFirst();
|
||||||
|
for (auto &batch : p_dmat->GetBatches<SparsePage>()) {
|
||||||
|
adapter.Next();
|
||||||
|
auto &adapter_batch = adapter.Value();
|
||||||
|
for (auto i = 0ull; i < adapter_batch.Size(); i++) {
|
||||||
|
auto inst = batch[ridx_set[i]];
|
||||||
|
auto line = adapter_batch.GetLine(i);
|
||||||
|
ASSERT_EQ(inst.size(), line.Size());
|
||||||
|
for (auto j = 0ull; j < line.Size(); j++) {
|
||||||
|
EXPECT_EQ(inst[j].fvalue, line.GetElement(j).value);
|
||||||
|
EXPECT_EQ(inst[j].index, line.GetElement(j).column_idx);
|
||||||
|
EXPECT_EQ(i, line.GetElement(j).row_idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
delete pp_dmat;
|
||||||
|
}
|
||||||
|
|||||||
@ -210,3 +210,47 @@ TEST(SimpleDMatrix, FromFile) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(SimpleDMatrix, Slice) {
|
||||||
|
const int kRows = 6;
|
||||||
|
const int kCols = 2;
|
||||||
|
auto pp_dmat = CreateDMatrix(kRows, kCols, 1.0);
|
||||||
|
auto p_dmat = *pp_dmat;
|
||||||
|
auto &labels = p_dmat->Info().labels_.HostVector();
|
||||||
|
auto &weights = p_dmat->Info().weights_.HostVector();
|
||||||
|
auto &base_margin = p_dmat->Info().base_margin_.HostVector();
|
||||||
|
weights.resize(kRows);
|
||||||
|
labels.resize(kRows);
|
||||||
|
base_margin.resize(kRows);
|
||||||
|
std::iota(labels.begin(), labels.end(), 0);
|
||||||
|
std::iota(weights.begin(), weights.end(), 0);
|
||||||
|
std::iota(base_margin.begin(), base_margin.end(), 0);
|
||||||
|
|
||||||
|
std::vector<int> ridx_set = {1, 3, 5};
|
||||||
|
data::DMatrixSliceAdapter adapter(p_dmat.get(),
|
||||||
|
{ridx_set.data(), ridx_set.size()});
|
||||||
|
EXPECT_EQ(adapter.NumRows(), ridx_set.size());
|
||||||
|
data::SimpleDMatrix new_dmat(&adapter,
|
||||||
|
std::numeric_limits<float>::quiet_NaN(), 1);
|
||||||
|
|
||||||
|
EXPECT_EQ(new_dmat.Info().num_row_, ridx_set.size());
|
||||||
|
|
||||||
|
auto &old_batch = *p_dmat->GetBatches<SparsePage>().begin();
|
||||||
|
auto &new_batch = *new_dmat.GetBatches<SparsePage>().begin();
|
||||||
|
for (auto i = 0ull; i < ridx_set.size(); i++) {
|
||||||
|
EXPECT_EQ(new_dmat.Info().labels_.HostVector()[i],
|
||||||
|
p_dmat->Info().labels_.HostVector()[ridx_set[i]]);
|
||||||
|
EXPECT_EQ(new_dmat.Info().weights_.HostVector()[i],
|
||||||
|
p_dmat->Info().weights_.HostVector()[ridx_set[i]]);
|
||||||
|
EXPECT_EQ(new_dmat.Info().base_margin_.HostVector()[i],
|
||||||
|
p_dmat->Info().base_margin_.HostVector()[ridx_set[i]]);
|
||||||
|
auto old_inst = old_batch[ridx_set[i]];
|
||||||
|
auto new_inst = new_batch[i];
|
||||||
|
ASSERT_EQ(old_inst.size(), new_inst.size());
|
||||||
|
for (auto j = 0ull; j < old_inst.size(); j++) {
|
||||||
|
EXPECT_EQ(old_inst[j], new_inst[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
delete pp_dmat;
|
||||||
|
};
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user