Convert labels into tensor. (#7456)

* Add a new ctor to tensor for `initializer_list`.
* Change labels from host device vector to tensor.
* Rename the field from `labels_` to `labels` since it's a public member.
This commit is contained in:
Jiaming Yuan
2021-12-17 00:58:35 +08:00
committed by GitHub
parent 6f8a4633b7
commit 5b1161bb64
35 changed files with 319 additions and 258 deletions

View File

@@ -176,7 +176,7 @@ uint64_t constexpr MetaInfo::kNumField;
// implementation of inline functions
void MetaInfo::Clear() {
num_row_ = num_col_ = num_nonzero_ = 0;
labels_.HostVector().clear();
labels = decltype(labels){};
group_ptr_.clear();
weights_.HostVector().clear();
base_margin_ = decltype(base_margin_){};
@@ -213,8 +213,7 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
SaveScalarField(fo, u8"num_row", DataType::kUInt64, num_row_); ++field_cnt;
SaveScalarField(fo, u8"num_col", DataType::kUInt64, num_col_); ++field_cnt;
SaveScalarField(fo, u8"num_nonzero", DataType::kUInt64, num_nonzero_); ++field_cnt;
SaveVectorField(fo, u8"labels", DataType::kFloat32,
{labels_.Size(), 1}, labels_); ++field_cnt;
SaveTensorField(fo, u8"labels", DataType::kFloat32, labels); ++field_cnt;
SaveVectorField(fo, u8"group_ptr", DataType::kUInt32,
{group_ptr_.size(), 1}, group_ptr_); ++field_cnt;
SaveVectorField(fo, u8"weights", DataType::kFloat32,
@@ -291,7 +290,7 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
LoadScalarField(fi, u8"num_row", DataType::kUInt64, &num_row_);
LoadScalarField(fi, u8"num_col", DataType::kUInt64, &num_col_);
LoadScalarField(fi, u8"num_nonzero", DataType::kUInt64, &num_nonzero_);
LoadVectorField(fi, u8"labels", DataType::kFloat32, &labels_);
LoadTensorField(fi, u8"labels", DataType::kFloat32, &labels);
LoadVectorField(fi, u8"group_ptr", DataType::kUInt32, &group_ptr_);
LoadVectorField(fi, u8"weights", DataType::kFloat32, &weights_);
LoadTensorField(fi, u8"base_margin", DataType::kFloat32, &base_margin_);
@@ -326,7 +325,19 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
out.num_col_ = this->num_col_;
// Groups is maintained by a higher level Python function. We should aim at deprecating
// the slice function.
out.labels_.HostVector() = Gather(this->labels_.HostVector(), ridxs);
if (this->labels.Size() != this->num_row_) {
auto t_labels = this->labels.View(this->labels.Data()->DeviceIdx());
out.labels.Reshape(ridxs.size(), labels.Shape(1));
out.labels.Data()->HostVector() =
Gather(this->labels.Data()->HostVector(), ridxs, t_labels.Stride(0));
} else {
out.labels.ModifyInplace([&](auto* data, common::Span<size_t, 2> shape) {
data->HostVector() = Gather(this->labels.Data()->HostVector(), ridxs);
shape[0] = data->Size();
shape[1] = 1;
});
}
out.labels_upper_bound_.HostVector() =
Gather(this->labels_upper_bound_.HostVector(), ridxs);
out.labels_lower_bound_.HostVector() =
@@ -343,13 +354,16 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
if (this->base_margin_.Size() != this->num_row_) {
CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0)
<< "Incorrect size of base margin vector.";
auto margin = this->base_margin_.View(this->base_margin_.Data()->DeviceIdx());
out.base_margin_.Reshape(ridxs.size(), margin.Shape()[1]);
size_t stride = margin.Stride(0);
auto t_margin = this->base_margin_.View(this->base_margin_.Data()->DeviceIdx());
out.base_margin_.Reshape(ridxs.size(), t_margin.Shape(1));
out.base_margin_.Data()->HostVector() =
Gather(this->base_margin_.Data()->HostVector(), ridxs, stride);
Gather(this->base_margin_.Data()->HostVector(), ridxs, t_margin.Stride(0));
} else {
out.base_margin_.Data()->HostVector() = Gather(this->base_margin_.Data()->HostVector(), ridxs);
out.base_margin_.ModifyInplace([&](auto* data, common::Span<size_t, 2> shape) {
data->HostVector() = Gather(this->base_margin_.Data()->HostVector(), ridxs);
shape[0] = data->Size();
shape[1] = 1;
});
}
out.feature_weights.Resize(this->feature_weights.Size());
@@ -460,6 +474,17 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
this->base_margin_.Reshape(this->num_row_, n_groups);
}
return;
} else if (key == "label") {
CopyTensorInfoImpl(arr, &this->labels);
if (this->num_row_ != 0 && this->labels.Shape(0) != this->num_row_) {
CHECK_EQ(this->labels.Size() % this->num_row_, 0) << "Incorrect size for labels.";
size_t n_targets = this->labels.Size() / this->num_row_;
this->labels.Reshape(this->num_row_, n_targets);
}
auto const& h_labels = labels.Data()->ConstHostVector();
auto valid = std::none_of(h_labels.cbegin(), h_labels.cend(), data::LabelsCheck{});
CHECK(valid) << "Label contains NaN, infinity or a value too large.";
return;
}
// uint info
if (key == "group") {
@@ -500,12 +525,7 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
// float info
linalg::Tensor<float, 1> t;
CopyTensorInfoImpl<1>(arr, &t);
if (key == "label") {
this->labels_ = std::move(*t.Data());
auto const& h_labels = labels_.ConstHostVector();
auto valid = std::none_of(h_labels.cbegin(), h_labels.cend(), data::LabelsCheck{});
CHECK(valid) << "Label contains NaN, infinity or a value too large.";
} else if (key == "weight") {
if (key == "weight") {
this->weights_ = std::move(*t.Data());
auto const& h_weights = this->weights_.ConstHostVector();
auto valid = std::none_of(h_weights.cbegin(), h_weights.cend(),
@@ -568,7 +588,7 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
if (dtype == DataType::kFloat32) {
const std::vector<bst_float>* vec = nullptr;
if (!std::strcmp(key, "label")) {
vec = &this->labels_.HostVector();
vec = &this->labels.Data()->HostVector();
} else if (!std::strcmp(key, "weight")) {
vec = &this->weights_.HostVector();
} else if (!std::strcmp(key, "base_margin")) {
@@ -649,8 +669,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
}
this->num_col_ = that.num_col_;
this->labels_.SetDevice(that.labels_.DeviceIdx());
this->labels_.Extend(that.labels_);
linalg::Stack(&this->labels, that.labels);
this->weights_.SetDevice(that.weights_.DeviceIdx());
this->weights_.Extend(that.weights_);
@@ -702,7 +721,7 @@ void MetaInfo::Validate(int32_t device) const {
<< "Invalid group structure. Number of rows obtained from groups "
"doesn't equal to actual number of rows given by data.";
}
auto check_device = [device](HostDeviceVector<float> const &v) {
auto check_device = [device](HostDeviceVector<float> const& v) {
CHECK(v.DeviceIdx() == GenericParameter::kCpuId ||
device == GenericParameter::kCpuId ||
v.DeviceIdx() == device)
@@ -717,10 +736,10 @@ void MetaInfo::Validate(int32_t device) const {
check_device(weights_);
return;
}
if (labels_.Size() != 0) {
CHECK_EQ(labels_.Size(), num_row_)
if (labels.Size() != 0) {
CHECK_EQ(labels.Size(), num_row_)
<< "Size of labels must equal to number of rows.";
check_device(labels_);
check_device(*labels.Data());
return;
}
if (labels_lower_bound_.Size() != 0) {

View File

@@ -119,6 +119,12 @@ void MetaInfo::SetInfoFromCUDA(StringView key, Json array) {
if (key == "base_margin") {
CopyTensorInfoImpl(array, &base_margin_);
return;
} else if (key == "label") {
CopyTensorInfoImpl(array, &labels);
auto ptr = labels.Data()->ConstDevicePointer();
auto valid = thrust::none_of(thrust::device, ptr, ptr + labels.Size(), data::LabelsCheck{});
CHECK(valid) << "Label contains NaN, infinity or a value too large.";
return;
}
// uint info
if (key == "group") {
@@ -135,12 +141,7 @@ void MetaInfo::SetInfoFromCUDA(StringView key, Json array) {
// float info
linalg::Tensor<float, 1> t;
CopyTensorInfoImpl(array, &t);
if (key == "label") {
this->labels_ = std::move(*t.Data());
auto ptr = labels_.ConstDevicePointer();
auto valid = thrust::none_of(thrust::device, ptr, ptr + labels_.Size(), data::LabelsCheck{});
CHECK(valid) << "Label contains NaN, infinity or a value too large.";
} else if (key == "weight") {
if (key == "weight") {
this->weights_ = std::move(*t.Data());
auto ptr = weights_.ConstDevicePointer();
auto valid = thrust::none_of(thrust::device, ptr, ptr + weights_.Size(), data::WeightsCheck{});

View File

@@ -153,7 +153,7 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
if (batches == 1) {
this->info_ = std::move(proxy->Info());
this->info_.num_nonzero_ = nnz;
CHECK_EQ(proxy->Info().labels_.Size(), 0);
CHECK_EQ(proxy->Info().labels.Size(), 0);
}
iter.Reset();

View File

@@ -127,14 +127,16 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
total_batch_size += batch.Size();
// Append meta information if available
if (batch.Labels() != nullptr) {
auto& labels = info_.labels_.HostVector();
labels.insert(labels.end(), batch.Labels(),
batch.Labels() + batch.Size());
info_.labels.ModifyInplace([&](auto* data, common::Span<size_t, 2> shape) {
shape[1] = 1;
auto& labels = data->HostVector();
labels.insert(labels.end(), batch.Labels(), batch.Labels() + batch.Size());
shape[0] += batch.Size();
});
}
if (batch.Weights() != nullptr) {
auto& weights = info_.weights_.HostVector();
weights.insert(weights.end(), batch.Weights(),
batch.Weights() + batch.Size());
weights.insert(weights.end(), batch.Weights(), batch.Weights() + batch.Size());
}
if (batch.BaseMargin() != nullptr) {
info_.base_margin_ = decltype(info_.base_margin_){batch.BaseMargin(),