Fix slice and get info. (#5552)
This commit is contained in:
@@ -205,6 +205,53 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
|
||||
LoadVectorField(fi, u8"labels_upper_bound", DataType::kFloat32, &labels_upper_bound_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> Gather(const std::vector<T> &in, common::Span<int const> ridxs, size_t stride = 1) {
|
||||
if (in.empty()) {
|
||||
return {};
|
||||
}
|
||||
auto size = ridxs.size();
|
||||
std::vector<T> out(size * stride);
|
||||
for (auto i = 0ull; i < size; i++) {
|
||||
auto ridx = ridxs[i];
|
||||
for (size_t j = 0; j < stride; ++j) {
|
||||
out[i * stride +j] = in[ridx * stride + j];
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
|
||||
MetaInfo out;
|
||||
out.num_row_ = ridxs.size();
|
||||
out.num_col_ = this->num_col_;
|
||||
// Groups is maintained by a higher level Python function. We should aim at deprecating
|
||||
// the slice function.
|
||||
out.labels_.HostVector() = Gather(this->labels_.HostVector(), ridxs);
|
||||
out.labels_upper_bound_.HostVector() =
|
||||
Gather(this->labels_upper_bound_.HostVector(), ridxs);
|
||||
out.labels_lower_bound_.HostVector() =
|
||||
Gather(this->labels_lower_bound_.HostVector(), ridxs);
|
||||
// weights
|
||||
if (this->weights_.Size() + 1 == this->group_ptr_.size()) {
|
||||
auto& h_weights = out.weights_.HostVector();
|
||||
// Assuming all groups are available.
|
||||
out.weights_.HostVector() = h_weights;
|
||||
} else {
|
||||
out.weights_.HostVector() = Gather(this->weights_.HostVector(), ridxs);
|
||||
}
|
||||
|
||||
if (this->base_margin_.Size() != this->num_row_) {
|
||||
CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0)
|
||||
<< "Incorrect size of base margin vector.";
|
||||
size_t stride = this->base_margin_.Size() / this->num_row_;
|
||||
out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs, stride);
|
||||
} else {
|
||||
out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
// try to load group information from file, if exists
|
||||
inline bool MetaTryLoadGroup(const std::string& fname,
|
||||
std::vector<unsigned>* group) {
|
||||
@@ -459,9 +506,6 @@ template DMatrix* DMatrix::Create<data::DataTableAdapter>(
|
||||
template DMatrix* DMatrix::Create<data::FileAdapter>(
|
||||
data::FileAdapter* adapter, float missing, int nthread,
|
||||
const std::string& cache_prefix, size_t page_size);
|
||||
template DMatrix* DMatrix::Create<data::DMatrixSliceAdapter>(
|
||||
data::DMatrixSliceAdapter* adapter, float missing, int nthread,
|
||||
const std::string& cache_prefix, size_t page_size);
|
||||
template DMatrix* DMatrix::Create<data::IteratorAdapter>(
|
||||
data::IteratorAdapter* adapter, float missing, int nthread,
|
||||
const std::string& cache_prefix, size_t page_size);
|
||||
|
||||
Reference in New Issue
Block a user