Fix slice and get info. (#5552)

This commit is contained in:
Jiaming Yuan
2020-04-18 18:00:13 +08:00
committed by GitHub
parent c245eb8755
commit e1f22baf8c
14 changed files with 177 additions and 163 deletions

View File

@@ -599,93 +599,6 @@ class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
dmlc::RowBlock<uint32_t> block_;
std::unique_ptr<FileAdapterBatch> batch_;
};
class DMatrixSliceAdapterBatch {
public:
// Fetch metainfo values according to sliced rows
template <typename T>
std::vector<T> Gather(const std::vector<T>& in) {
if (in.empty()) return {};
std::vector<T> out(this->Size());
for (auto i = 0ull; i < this->Size(); i++) {
out[i] = in[ridx_set[i]];
}
return out;
}
DMatrixSliceAdapterBatch(const SparsePage& batch, DMatrix* dmat,
common::Span<const int> ridx_set)
: batch(batch), ridx_set(ridx_set) {
batch_labels = this->Gather(dmat->Info().labels_.HostVector());
batch_weights = this->Gather(dmat->Info().weights_.HostVector());
batch_base_margin = this->Gather(dmat->Info().base_margin_.HostVector());
}
class Line {
public:
Line(const SparsePage::Inst& inst, size_t row_idx)
: inst_(inst), row_idx_(row_idx) {}
size_t Size() { return inst_.size(); }
COOTuple GetElement(size_t idx) {
return COOTuple{row_idx_, inst_[idx].index, inst_[idx].fvalue};
}
private:
SparsePage::Inst inst_;
size_t row_idx_;
};
Line GetLine(size_t idx) const { return Line(batch[ridx_set[idx]], idx); }
const float* Labels() const {
if (batch_labels.empty()) {
return nullptr;
}
return batch_labels.data();
}
const float* Weights() const {
if (batch_weights.empty()) {
return nullptr;
}
return batch_weights.data();
}
const uint64_t* Qid() const { return nullptr; }
const float* BaseMargin() const {
if (batch_base_margin.empty()) {
return nullptr;
}
return batch_base_margin.data();
}
size_t Size() const { return ridx_set.size(); }
const SparsePage& batch;
common::Span<const int> ridx_set;
std::vector<float> batch_labels;
std::vector<float> batch_weights;
std::vector<float> batch_base_margin;
};
// Group pointer is not exposed
// This is because external bindings currently manipulate the group values
// manually when slicing This could potentially be moved to internal C++ code if
// needed
class DMatrixSliceAdapter
: public detail::SingleBatchDataIter<DMatrixSliceAdapterBatch> {
public:
DMatrixSliceAdapter(DMatrix* dmat, common::Span<const int> ridx_set)
: dmat_(dmat),
ridx_set_(ridx_set),
batch_(*dmat_->GetBatches<SparsePage>().begin(), dmat_, ridx_set) {}
const DMatrixSliceAdapterBatch& Value() const override { return batch_; }
// Indicates a number of rows/columns must be inferred
size_t NumRows() const { return ridx_set_.size(); }
size_t NumColumns() const { return dmat_->Info().num_col_; }
private:
DMatrix* dmat_;
common::Span<const int> ridx_set_;
DMatrixSliceAdapterBatch batch_;
};
}; // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_ADAPTER_H_