Fix slice and get info. (#5552)
This commit is contained in:
@@ -599,93 +599,6 @@ class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
|
||||
dmlc::RowBlock<uint32_t> block_;
|
||||
std::unique_ptr<FileAdapterBatch> batch_;
|
||||
};
|
||||
|
||||
class DMatrixSliceAdapterBatch {
|
||||
public:
|
||||
// Fetch metainfo values according to sliced rows
|
||||
template <typename T>
|
||||
std::vector<T> Gather(const std::vector<T>& in) {
|
||||
if (in.empty()) return {};
|
||||
|
||||
std::vector<T> out(this->Size());
|
||||
for (auto i = 0ull; i < this->Size(); i++) {
|
||||
out[i] = in[ridx_set[i]];
|
||||
}
|
||||
return out;
|
||||
}
|
||||
DMatrixSliceAdapterBatch(const SparsePage& batch, DMatrix* dmat,
|
||||
common::Span<const int> ridx_set)
|
||||
: batch(batch), ridx_set(ridx_set) {
|
||||
batch_labels = this->Gather(dmat->Info().labels_.HostVector());
|
||||
batch_weights = this->Gather(dmat->Info().weights_.HostVector());
|
||||
batch_base_margin = this->Gather(dmat->Info().base_margin_.HostVector());
|
||||
}
|
||||
|
||||
class Line {
|
||||
public:
|
||||
Line(const SparsePage::Inst& inst, size_t row_idx)
|
||||
: inst_(inst), row_idx_(row_idx) {}
|
||||
|
||||
size_t Size() { return inst_.size(); }
|
||||
COOTuple GetElement(size_t idx) {
|
||||
return COOTuple{row_idx_, inst_[idx].index, inst_[idx].fvalue};
|
||||
}
|
||||
|
||||
private:
|
||||
SparsePage::Inst inst_;
|
||||
size_t row_idx_;
|
||||
};
|
||||
Line GetLine(size_t idx) const { return Line(batch[ridx_set[idx]], idx); }
|
||||
const float* Labels() const {
|
||||
if (batch_labels.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
return batch_labels.data();
|
||||
}
|
||||
const float* Weights() const {
|
||||
if (batch_weights.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
return batch_weights.data();
|
||||
}
|
||||
const uint64_t* Qid() const { return nullptr; }
|
||||
const float* BaseMargin() const {
|
||||
if (batch_base_margin.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
return batch_base_margin.data();
|
||||
}
|
||||
|
||||
size_t Size() const { return ridx_set.size(); }
|
||||
const SparsePage& batch;
|
||||
common::Span<const int> ridx_set;
|
||||
std::vector<float> batch_labels;
|
||||
std::vector<float> batch_weights;
|
||||
std::vector<float> batch_base_margin;
|
||||
};
|
||||
|
||||
// Group pointer is not exposed
|
||||
// This is because external bindings currently manipulate the group values
|
||||
// manually when slicing This could potentially be moved to internal C++ code if
|
||||
// needed
|
||||
|
||||
class DMatrixSliceAdapter
|
||||
: public detail::SingleBatchDataIter<DMatrixSliceAdapterBatch> {
|
||||
public:
|
||||
DMatrixSliceAdapter(DMatrix* dmat, common::Span<const int> ridx_set)
|
||||
: dmat_(dmat),
|
||||
ridx_set_(ridx_set),
|
||||
batch_(*dmat_->GetBatches<SparsePage>().begin(), dmat_, ridx_set) {}
|
||||
const DMatrixSliceAdapterBatch& Value() const override { return batch_; }
|
||||
// Indicates a number of rows/columns must be inferred
|
||||
size_t NumRows() const { return ridx_set_.size(); }
|
||||
size_t NumColumns() const { return dmat_->Info().num_col_; }
|
||||
|
||||
private:
|
||||
DMatrix* dmat_;
|
||||
common::Span<const int> ridx_set_;
|
||||
DMatrixSliceAdapterBatch batch_;
|
||||
};
|
||||
}; // namespace data
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_ADAPTER_H_
|
||||
|
||||
Reference in New Issue
Block a user