Fix slice and get info. (#5552)

This commit is contained in:
Jiaming Yuan
2020-04-18 18:00:13 +08:00
committed by GitHub
parent c245eb8755
commit e1f22baf8c
14 changed files with 177 additions and 163 deletions

View File

@@ -205,6 +205,53 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
LoadVectorField(fi, u8"labels_upper_bound", DataType::kFloat32, &labels_upper_bound_);
}
template <typename T>
std::vector<T> Gather(const std::vector<T> &in, common::Span<int const> ridxs, size_t stride = 1) {
if (in.empty()) {
return {};
}
auto size = ridxs.size();
std::vector<T> out(size * stride);
for (auto i = 0ull; i < size; i++) {
auto ridx = ridxs[i];
for (size_t j = 0; j < stride; ++j) {
out[i * stride +j] = in[ridx * stride + j];
}
}
return out;
}
MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
MetaInfo out;
out.num_row_ = ridxs.size();
out.num_col_ = this->num_col_;
// Groups is maintained by a higher level Python function. We should aim at deprecating
// the slice function.
out.labels_.HostVector() = Gather(this->labels_.HostVector(), ridxs);
out.labels_upper_bound_.HostVector() =
Gather(this->labels_upper_bound_.HostVector(), ridxs);
out.labels_lower_bound_.HostVector() =
Gather(this->labels_lower_bound_.HostVector(), ridxs);
// weights
if (this->weights_.Size() + 1 == this->group_ptr_.size()) {
auto& h_weights = out.weights_.HostVector();
// Assuming all groups are available.
out.weights_.HostVector() = h_weights;
} else {
out.weights_.HostVector() = Gather(this->weights_.HostVector(), ridxs);
}
if (this->base_margin_.Size() != this->num_row_) {
CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0)
<< "Incorrect size of base margin vector.";
size_t stride = this->base_margin_.Size() / this->num_row_;
out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs, stride);
} else {
out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs);
}
return out;
}
// try to load group information from file, if exists
inline bool MetaTryLoadGroup(const std::string& fname,
std::vector<unsigned>* group) {
@@ -459,9 +506,6 @@ template DMatrix* DMatrix::Create<data::DataTableAdapter>(
template DMatrix* DMatrix::Create<data::FileAdapter>(
data::FileAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
template DMatrix* DMatrix::Create<data::DMatrixSliceAdapter>(
data::DMatrixSliceAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
template DMatrix* DMatrix::Create<data::IteratorAdapter>(
data::IteratorAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);