Implement slice via adapters (#5198)
This commit is contained in:
@@ -8,6 +8,8 @@
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
@@ -94,6 +96,7 @@ class NoMetaInfo {
|
||||
const float* Labels() const { return nullptr; }
|
||||
const float* Weights() const { return nullptr; }
|
||||
const uint64_t* Qid() const { return nullptr; }
|
||||
const float* BaseMargin() const { return nullptr; }
|
||||
};
|
||||
|
||||
}; // namespace detail
|
||||
@@ -446,6 +449,7 @@ class FileAdapterBatch {
|
||||
const float* Labels() const { return block->label; }
|
||||
const float* Weights() const { return block->weight; }
|
||||
const uint64_t* Qid() const { return block->qid; }
|
||||
const float* BaseMargin() const { return nullptr; }
|
||||
|
||||
size_t Size() const { return block->size; }
|
||||
|
||||
@@ -481,6 +485,92 @@ class FileAdapter : dmlc::DataIter<FileAdapterBatch> {
|
||||
std::unique_ptr<FileAdapterBatch> batch;
|
||||
dmlc::Parser<uint32_t>* parser;
|
||||
};
|
||||
|
||||
class DMatrixSliceAdapterBatch {
|
||||
public:
|
||||
// Fetch metainfo values according to sliced rows
|
||||
template <typename T>
|
||||
std::vector<T> Gather(const std::vector<T>& in) {
|
||||
if (in.empty()) return {};
|
||||
|
||||
std::vector<T> out(this->Size());
|
||||
for (auto i = 0ull; i < this->Size(); i++) {
|
||||
out[i] = in[ridx_set[i]];
|
||||
}
|
||||
return out;
|
||||
}
|
||||
DMatrixSliceAdapterBatch(const SparsePage& batch, DMatrix* dmat,
|
||||
common::Span<const int> ridx_set)
|
||||
: batch(batch), ridx_set(ridx_set) {
|
||||
batch_labels = this->Gather(dmat->Info().labels_.HostVector());
|
||||
batch_weights = this->Gather(dmat->Info().weights_.HostVector());
|
||||
batch_base_margin = this->Gather(dmat->Info().base_margin_.HostVector());
|
||||
}
|
||||
|
||||
class Line {
|
||||
public:
|
||||
Line(const SparsePage::Inst& inst, size_t row_idx)
|
||||
: inst(inst), row_idx(row_idx) {}
|
||||
|
||||
size_t Size() { return inst.size(); }
|
||||
COOTuple GetElement(size_t idx) {
|
||||
return COOTuple(row_idx, inst[idx].index, inst[idx].fvalue);
|
||||
}
|
||||
|
||||
private:
|
||||
SparsePage::Inst inst;
|
||||
size_t row_idx;
|
||||
};
|
||||
Line GetLine(size_t idx) const { return Line(batch[ridx_set[idx]], idx); }
|
||||
const float* Labels() const {
|
||||
if (batch_labels.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
return batch_labels.data();
|
||||
}
|
||||
const float* Weights() const {
|
||||
if (batch_weights.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
return batch_weights.data();
|
||||
}
|
||||
const uint64_t* Qid() const { return nullptr; }
|
||||
const float* BaseMargin() const {
|
||||
if (batch_base_margin.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
return batch_base_margin.data();
|
||||
}
|
||||
|
||||
size_t Size() const { return ridx_set.size(); }
|
||||
const SparsePage& batch;
|
||||
common::Span<const int> ridx_set;
|
||||
std::vector<float> batch_labels;
|
||||
std::vector<float> batch_weights;
|
||||
std::vector<float> batch_base_margin;
|
||||
};
|
||||
|
||||
// Group pointer is not exposed.
// This is because external bindings currently manipulate the group values
// manually when slicing. This could potentially be moved to internal C++ code
// if needed.
|
||||
|
||||
class DMatrixSliceAdapter
|
||||
: public detail::SingleBatchDataIter<DMatrixSliceAdapterBatch> {
|
||||
public:
|
||||
DMatrixSliceAdapter(DMatrix* dmat, common::Span<const int> ridx_set)
|
||||
: dmat(dmat),
|
||||
ridx_set(ridx_set),
|
||||
batch(*dmat->GetBatches<SparsePage>().begin(), dmat, ridx_set) {}
|
||||
const DMatrixSliceAdapterBatch& Value() const override { return batch; }
|
||||
// Indicates a number of rows/columns must be inferred
|
||||
size_t NumRows() const { return ridx_set.size(); }
|
||||
size_t NumColumns() const { return dmat->Info().num_col_; }
|
||||
DMatrix* dmat;
|
||||
DMatrixSliceAdapterBatch batch;
|
||||
bool before_first{true};
|
||||
common::Span<const int> ridx_set;
|
||||
};
|
||||
}; // namespace data
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_ADAPTER_H_
|
||||
|
||||
Reference in New Issue
Block a user