Implement slice via adapters (#5198)

This commit is contained in:
Rory Mitchell
2020-01-14 12:55:41 +13:00
committed by GitHub
parent f100b8d878
commit a73e25e15f
7 changed files with 190 additions and 45 deletions

View File

@@ -8,6 +8,8 @@
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace xgboost {
namespace data {
@@ -94,6 +96,7 @@ class NoMetaInfo {
// Always returns nullptr: no label pointer is provided by this type.
const float* Labels() const { return nullptr; }
// Always returns nullptr: no weight pointer is provided by this type.
const float* Weights() const { return nullptr; }
// Always returns nullptr: no query-id pointer is provided by this type.
const uint64_t* Qid() const { return nullptr; }
// Always returns nullptr: no base-margin pointer is provided by this type.
const float* BaseMargin() const { return nullptr; }
};
}; // namespace detail
@@ -446,6 +449,7 @@ class FileAdapterBatch {
// Returns the `label` pointer of the current block, unmodified.
// NOTE(review): presumably null when the parsed data has no labels - confirm.
const float* Labels() const { return block->label; }
// Returns the `weight` pointer of the current block, unmodified.
// NOTE(review): presumably null when the parsed data has no weights - confirm.
const float* Weights() const { return block->weight; }
// Returns the `qid` pointer of the current block, unmodified.
// NOTE(review): presumably null when the parsed data has no query ids - confirm.
const uint64_t* Qid() const { return block->qid; }
// Always returns nullptr: this batch does not read a base margin from `block`.
const float* BaseMargin() const { return nullptr; }
// Returns the `size` field of the current block.
// NOTE(review): presumably the number of rows in the block - confirm.
size_t Size() const { return block->size; }
@@ -481,6 +485,92 @@ class FileAdapter : dmlc::DataIter<FileAdapterBatch> {
std::unique_ptr<FileAdapterBatch> batch;
dmlc::Parser<uint32_t>* parser;
};
// Batch view over a subset of the rows of an existing SparsePage.
//
// Row `i` of this batch is row `ridx_set[i]` of the wrapped page. Labels,
// weights and base margins are copied out of the source DMatrix once, at
// construction time, in the order given by `ridx_set`. The page itself is
// only referenced, not copied, so it must outlive this object.
class DMatrixSliceAdapterBatch {
 public:
  // Fetch metainfo values according to sliced rows.
  //
  // Returns an empty vector when `in` is empty; otherwise returns a vector of
  // length Size() whose i-th element is `in[ridx_set[i]]`.
  //
  // NOTE(review): indices are not range-checked. Every entry of `ridx_set`
  // must lie in [0, in.size()); this presumably holds when `in` has one entry
  // per source row - confirm for per-group weights.
  template <typename T>
  std::vector<T> Gather(const std::vector<T>& in) const {
    if (in.empty()) return {};
    std::vector<T> out(this->Size());
    for (size_t i = 0; i < out.size(); ++i) {
      out[i] = in[ridx_set[i]];
    }
    return out;
  }

  // Wraps `batch` and gathers the per-row meta information of `dmat` for the
  // rows listed in `ridx_set`. `batch` and the storage behind `ridx_set` are
  // referenced, not copied.
  DMatrixSliceAdapterBatch(const SparsePage& batch, DMatrix* dmat,
                           common::Span<const int> ridx_set)
      : batch(batch), ridx_set(ridx_set) {
    // Gather() reads `this->ridx_set`, which is already initialised here.
    batch_labels = this->Gather(dmat->Info().labels_.HostVector());
    batch_weights = this->Gather(dmat->Info().weights_.HostVector());
    batch_base_margin = this->Gather(dmat->Info().base_margin_.HostVector());
  }

  // One row of the slice: the entries of a source row, reported under the
  // row's position inside the slice.
  class Line {
   public:
    Line(const SparsePage::Inst& inst, size_t row_idx)
        : inst(inst), row_idx(row_idx) {}

    // Number of entries in this row.
    size_t Size() const { return inst.size(); }

    // Entry `idx` of this row as a (row, column, value) tuple, where the row
    // is the index within the slice.
    COOTuple GetElement(size_t idx) const {
      return COOTuple(row_idx, inst[idx].index, inst[idx].fvalue);
    }

   private:
    SparsePage::Inst inst;
    size_t row_idx;
  };

  // Row `idx` of the slice, i.e. row `ridx_set[idx]` of the wrapped page.
  Line GetLine(size_t idx) const { return Line(batch[ridx_set[idx]], idx); }

  // Gathered labels, or nullptr when the source had none.
  const float* Labels() const {
    if (batch_labels.empty()) {
      return nullptr;
    }
    return batch_labels.data();
  }

  // Gathered weights, or nullptr when the source had none.
  const float* Weights() const {
    if (batch_weights.empty()) {
      return nullptr;
    }
    return batch_weights.data();
  }

  // Query ids are not sliced by this batch: always nullptr.
  const uint64_t* Qid() const { return nullptr; }

  // Gathered base margins, or nullptr when the source had none.
  const float* BaseMargin() const {
    if (batch_base_margin.empty()) {
      return nullptr;
    }
    return batch_base_margin.data();
  }

  // Number of rows in the slice.
  size_t Size() const { return ridx_set.size(); }

  // Declaration order matters: `ridx_set` must be initialised before the
  // constructor body calls Gather().
  const SparsePage& batch;
  common::Span<const int> ridx_set;
  std::vector<float> batch_labels;
  std::vector<float> batch_weights;
  std::vector<float> batch_base_margin;
};
// Group pointer is not exposed.
// This is because external bindings currently manipulate the group values
// manually when slicing. This could potentially be moved to internal C++ code
// if needed.
// Single-batch adapter that presents the rows `ridx_set` of `dmat` as a new
// data source. Neither `dmat` nor the storage behind `ridx_set` is owned;
// both must outlive the adapter.
class DMatrixSliceAdapter
    : public detail::SingleBatchDataIter<DMatrixSliceAdapterBatch> {
 public:
  // The initializer list follows the member declaration order, which is the
  // order in which members are actually initialised. `batch` is built from
  // the constructor parameters (they shadow the members of the same name),
  // so it does not depend on `this->ridx_set` being initialised first.
  //
  // NOTE(review): only the first SparsePage batch of `dmat` is wrapped, and
  // `batch` keeps a reference to it. Presumably the page is owned by `dmat`
  // and stays valid - confirm for DMatrix types with more than one batch.
  DMatrixSliceAdapter(DMatrix* dmat, common::Span<const int> ridx_set)
      : dmat(dmat),
        batch(*dmat->GetBatches<SparsePage>().begin(), dmat, ridx_set),
        ridx_set(ridx_set) {}

  // The one and only batch of this adapter.
  const DMatrixSliceAdapterBatch& Value() const override { return batch; }

  // Row and column counts are known up front and need not be inferred:
  // the number of selected rows, and the column count of the source matrix.
  size_t NumRows() const { return ridx_set.size(); }
  size_t NumColumns() const { return dmat->Info().num_col_; }

  DMatrix* dmat;
  DMatrixSliceAdapterBatch batch;
  // NOTE(review): not read anywhere in this class; kept because it is a
  // public member.
  bool before_first{true};
  common::Span<const int> ridx_set;
};
}; // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_ADAPTER_H_