Fix slice and get info. (#5552)
This commit is contained in:
@@ -181,11 +181,7 @@ XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle,
|
||||
<< "slice does not support group structure";
|
||||
}
|
||||
DMatrix* dmat = static_cast<std::shared_ptr<DMatrix>*>(handle)->get();
|
||||
CHECK(dynamic_cast<data::SimpleDMatrix*>(dmat))
|
||||
<< "Slice only supported for SimpleDMatrix currently.";
|
||||
data::DMatrixSliceAdapter adapter(dmat, {idxset, static_cast<size_t>(len)});
|
||||
*out = new std::shared_ptr<DMatrix>(
|
||||
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1));
|
||||
*out = new std::shared_ptr<DMatrix>(dmat->Slice({idxset, len}));
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
@@ -599,93 +599,6 @@ class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
|
||||
dmlc::RowBlock<uint32_t> block_;
|
||||
std::unique_ptr<FileAdapterBatch> batch_;
|
||||
};
|
||||
|
||||
class DMatrixSliceAdapterBatch {
|
||||
public:
|
||||
// Fetch metainfo values according to sliced rows
|
||||
template <typename T>
|
||||
std::vector<T> Gather(const std::vector<T>& in) {
|
||||
if (in.empty()) return {};
|
||||
|
||||
std::vector<T> out(this->Size());
|
||||
for (auto i = 0ull; i < this->Size(); i++) {
|
||||
out[i] = in[ridx_set[i]];
|
||||
}
|
||||
return out;
|
||||
}
|
||||
DMatrixSliceAdapterBatch(const SparsePage& batch, DMatrix* dmat,
|
||||
common::Span<const int> ridx_set)
|
||||
: batch(batch), ridx_set(ridx_set) {
|
||||
batch_labels = this->Gather(dmat->Info().labels_.HostVector());
|
||||
batch_weights = this->Gather(dmat->Info().weights_.HostVector());
|
||||
batch_base_margin = this->Gather(dmat->Info().base_margin_.HostVector());
|
||||
}
|
||||
|
||||
class Line {
|
||||
public:
|
||||
Line(const SparsePage::Inst& inst, size_t row_idx)
|
||||
: inst_(inst), row_idx_(row_idx) {}
|
||||
|
||||
size_t Size() { return inst_.size(); }
|
||||
COOTuple GetElement(size_t idx) {
|
||||
return COOTuple{row_idx_, inst_[idx].index, inst_[idx].fvalue};
|
||||
}
|
||||
|
||||
private:
|
||||
SparsePage::Inst inst_;
|
||||
size_t row_idx_;
|
||||
};
|
||||
Line GetLine(size_t idx) const { return Line(batch[ridx_set[idx]], idx); }
|
||||
const float* Labels() const {
|
||||
if (batch_labels.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
return batch_labels.data();
|
||||
}
|
||||
const float* Weights() const {
|
||||
if (batch_weights.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
return batch_weights.data();
|
||||
}
|
||||
const uint64_t* Qid() const { return nullptr; }
|
||||
const float* BaseMargin() const {
|
||||
if (batch_base_margin.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
return batch_base_margin.data();
|
||||
}
|
||||
|
||||
size_t Size() const { return ridx_set.size(); }
|
||||
const SparsePage& batch;
|
||||
common::Span<const int> ridx_set;
|
||||
std::vector<float> batch_labels;
|
||||
std::vector<float> batch_weights;
|
||||
std::vector<float> batch_base_margin;
|
||||
};
|
||||
|
||||
// Group pointer is not exposed
|
||||
// This is because external bindings currently manipulate the group values
|
||||
// manually when slicing This could potentially be moved to internal C++ code if
|
||||
// needed
|
||||
|
||||
class DMatrixSliceAdapter
|
||||
: public detail::SingleBatchDataIter<DMatrixSliceAdapterBatch> {
|
||||
public:
|
||||
DMatrixSliceAdapter(DMatrix* dmat, common::Span<const int> ridx_set)
|
||||
: dmat_(dmat),
|
||||
ridx_set_(ridx_set),
|
||||
batch_(*dmat_->GetBatches<SparsePage>().begin(), dmat_, ridx_set) {}
|
||||
const DMatrixSliceAdapterBatch& Value() const override { return batch_; }
|
||||
// Indicates a number of rows/columns must be inferred
|
||||
size_t NumRows() const { return ridx_set_.size(); }
|
||||
size_t NumColumns() const { return dmat_->Info().num_col_; }
|
||||
|
||||
private:
|
||||
DMatrix* dmat_;
|
||||
common::Span<const int> ridx_set_;
|
||||
DMatrixSliceAdapterBatch batch_;
|
||||
};
|
||||
}; // namespace data
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_ADAPTER_H_
|
||||
|
||||
@@ -205,6 +205,53 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
|
||||
LoadVectorField(fi, u8"labels_upper_bound", DataType::kFloat32, &labels_upper_bound_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> Gather(const std::vector<T> &in, common::Span<int const> ridxs, size_t stride = 1) {
|
||||
if (in.empty()) {
|
||||
return {};
|
||||
}
|
||||
auto size = ridxs.size();
|
||||
std::vector<T> out(size * stride);
|
||||
for (auto i = 0ull; i < size; i++) {
|
||||
auto ridx = ridxs[i];
|
||||
for (size_t j = 0; j < stride; ++j) {
|
||||
out[i * stride +j] = in[ridx * stride + j];
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
|
||||
MetaInfo out;
|
||||
out.num_row_ = ridxs.size();
|
||||
out.num_col_ = this->num_col_;
|
||||
// Groups is maintained by a higher level Python function. We should aim at deprecating
|
||||
// the slice function.
|
||||
out.labels_.HostVector() = Gather(this->labels_.HostVector(), ridxs);
|
||||
out.labels_upper_bound_.HostVector() =
|
||||
Gather(this->labels_upper_bound_.HostVector(), ridxs);
|
||||
out.labels_lower_bound_.HostVector() =
|
||||
Gather(this->labels_lower_bound_.HostVector(), ridxs);
|
||||
// weights
|
||||
if (this->weights_.Size() + 1 == this->group_ptr_.size()) {
|
||||
auto& h_weights = out.weights_.HostVector();
|
||||
// Assuming all groups are available.
|
||||
out.weights_.HostVector() = h_weights;
|
||||
} else {
|
||||
out.weights_.HostVector() = Gather(this->weights_.HostVector(), ridxs);
|
||||
}
|
||||
|
||||
if (this->base_margin_.Size() != this->num_row_) {
|
||||
CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0)
|
||||
<< "Incorrect size of base margin vector.";
|
||||
size_t stride = this->base_margin_.Size() / this->num_row_;
|
||||
out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs, stride);
|
||||
} else {
|
||||
out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
// try to load group information from file, if exists
|
||||
inline bool MetaTryLoadGroup(const std::string& fname,
|
||||
std::vector<unsigned>* group) {
|
||||
@@ -459,9 +506,6 @@ template DMatrix* DMatrix::Create<data::DataTableAdapter>(
|
||||
template DMatrix* DMatrix::Create<data::FileAdapter>(
|
||||
data::FileAdapter* adapter, float missing, int nthread,
|
||||
const std::string& cache_prefix, size_t page_size);
|
||||
template DMatrix* DMatrix::Create<data::DMatrixSliceAdapter>(
|
||||
data::DMatrixSliceAdapter* adapter, float missing, int nthread,
|
||||
const std::string& cache_prefix, size_t page_size);
|
||||
template DMatrix* DMatrix::Create<data::IteratorAdapter>(
|
||||
data::IteratorAdapter* adapter, float missing, int nthread,
|
||||
const std::string& cache_prefix, size_t page_size);
|
||||
|
||||
@@ -31,6 +31,10 @@ class DeviceDMatrix : public DMatrix {
|
||||
|
||||
bool EllpackExists() const override { return true; }
|
||||
bool SparsePageExists() const override { return false; }
|
||||
DMatrix *Slice(common::Span<int32_t const> ridxs) override {
|
||||
LOG(FATAL) << "Slicing DMatrix is not supported for Device DMatrix.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
BatchSet<SparsePage> GetRowBatches() override {
|
||||
|
||||
@@ -16,6 +16,27 @@ MetaInfo& SimpleDMatrix::Info() { return info_; }
|
||||
|
||||
const MetaInfo& SimpleDMatrix::Info() const { return info_; }
|
||||
|
||||
DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
|
||||
auto out = new SimpleDMatrix;
|
||||
SparsePage& out_page = out->sparse_page_;
|
||||
for (auto const &page : this->GetBatches<SparsePage>()) {
|
||||
page.data.HostVector();
|
||||
page.offset.HostVector();
|
||||
auto& h_data = out_page.data.HostVector();
|
||||
auto& h_offset = out_page.offset.HostVector();
|
||||
size_t rptr{0};
|
||||
for (auto ridx : ridxs) {
|
||||
auto inst = page[ridx];
|
||||
rptr += inst.size();
|
||||
std::copy(inst.begin(), inst.end(), std::back_inserter(h_data));
|
||||
h_offset.emplace_back(rptr);
|
||||
}
|
||||
out->Info() = this->Info().Slice(ridxs);
|
||||
out->Info().num_nonzero_ = h_offset.back();
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
|
||||
// since csr is the default data structure so `source_` is always available.
|
||||
auto begin_iter = BatchIterator<SparsePage>(
|
||||
@@ -174,8 +195,6 @@ template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing,
|
||||
int nthread);
|
||||
template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing,
|
||||
int nthread);
|
||||
template SimpleDMatrix::SimpleDMatrix(DMatrixSliceAdapter* adapter, float missing,
|
||||
int nthread);
|
||||
template SimpleDMatrix::SimpleDMatrix(IteratorAdapter* adapter, float missing,
|
||||
int nthread);
|
||||
} // namespace data
|
||||
|
||||
@@ -19,6 +19,7 @@ namespace data {
|
||||
// Used for single batch data.
|
||||
class SimpleDMatrix : public DMatrix {
|
||||
public:
|
||||
SimpleDMatrix() = default;
|
||||
template <typename AdapterT>
|
||||
explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread);
|
||||
|
||||
@@ -32,6 +33,7 @@ class SimpleDMatrix : public DMatrix {
|
||||
const MetaInfo& Info() const override;
|
||||
|
||||
bool SingleColBlock() const override { return true; }
|
||||
DMatrix* Slice(common::Span<int32_t const> ridxs) override;
|
||||
|
||||
/*! \brief magic number used to identify SimpleDMatrix binary files */
|
||||
static const int kMagic = 0xffffab01;
|
||||
|
||||
@@ -37,6 +37,10 @@ class SparsePageDMatrix : public DMatrix {
|
||||
const MetaInfo& Info() const override;
|
||||
|
||||
bool SingleColBlock() const override { return false; }
|
||||
DMatrix *Slice(common::Span<int32_t const> ridxs) override {
|
||||
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
BatchSet<SparsePage> GetRowBatches() override;
|
||||
|
||||
Reference in New Issue
Block a user