|
|
|
|
@@ -176,7 +176,7 @@ uint64_t constexpr MetaInfo::kNumField;
|
|
|
|
|
// implementation of inline functions
|
|
|
|
|
void MetaInfo::Clear() {
|
|
|
|
|
num_row_ = num_col_ = num_nonzero_ = 0;
|
|
|
|
|
labels_.HostVector().clear();
|
|
|
|
|
labels = decltype(labels){};
|
|
|
|
|
group_ptr_.clear();
|
|
|
|
|
weights_.HostVector().clear();
|
|
|
|
|
base_margin_ = decltype(base_margin_){};
|
|
|
|
|
@@ -213,8 +213,7 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
|
|
|
|
|
SaveScalarField(fo, u8"num_row", DataType::kUInt64, num_row_); ++field_cnt;
|
|
|
|
|
SaveScalarField(fo, u8"num_col", DataType::kUInt64, num_col_); ++field_cnt;
|
|
|
|
|
SaveScalarField(fo, u8"num_nonzero", DataType::kUInt64, num_nonzero_); ++field_cnt;
|
|
|
|
|
SaveVectorField(fo, u8"labels", DataType::kFloat32,
|
|
|
|
|
{labels_.Size(), 1}, labels_); ++field_cnt;
|
|
|
|
|
SaveTensorField(fo, u8"labels", DataType::kFloat32, labels); ++field_cnt;
|
|
|
|
|
SaveVectorField(fo, u8"group_ptr", DataType::kUInt32,
|
|
|
|
|
{group_ptr_.size(), 1}, group_ptr_); ++field_cnt;
|
|
|
|
|
SaveVectorField(fo, u8"weights", DataType::kFloat32,
|
|
|
|
|
@@ -291,7 +290,7 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
|
|
|
|
|
LoadScalarField(fi, u8"num_row", DataType::kUInt64, &num_row_);
|
|
|
|
|
LoadScalarField(fi, u8"num_col", DataType::kUInt64, &num_col_);
|
|
|
|
|
LoadScalarField(fi, u8"num_nonzero", DataType::kUInt64, &num_nonzero_);
|
|
|
|
|
LoadVectorField(fi, u8"labels", DataType::kFloat32, &labels_);
|
|
|
|
|
LoadTensorField(fi, u8"labels", DataType::kFloat32, &labels);
|
|
|
|
|
LoadVectorField(fi, u8"group_ptr", DataType::kUInt32, &group_ptr_);
|
|
|
|
|
LoadVectorField(fi, u8"weights", DataType::kFloat32, &weights_);
|
|
|
|
|
LoadTensorField(fi, u8"base_margin", DataType::kFloat32, &base_margin_);
|
|
|
|
|
@@ -326,7 +325,19 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
|
|
|
|
|
out.num_col_ = this->num_col_;
|
|
|
|
|
// Groups is maintained by a higher level Python function. We should aim at deprecating
|
|
|
|
|
// the slice function.
|
|
|
|
|
out.labels_.HostVector() = Gather(this->labels_.HostVector(), ridxs);
|
|
|
|
|
if (this->labels.Size() != this->num_row_) {
|
|
|
|
|
auto t_labels = this->labels.View(this->labels.Data()->DeviceIdx());
|
|
|
|
|
out.labels.Reshape(ridxs.size(), labels.Shape(1));
|
|
|
|
|
out.labels.Data()->HostVector() =
|
|
|
|
|
Gather(this->labels.Data()->HostVector(), ridxs, t_labels.Stride(0));
|
|
|
|
|
} else {
|
|
|
|
|
out.labels.ModifyInplace([&](auto* data, common::Span<size_t, 2> shape) {
|
|
|
|
|
data->HostVector() = Gather(this->labels.Data()->HostVector(), ridxs);
|
|
|
|
|
shape[0] = data->Size();
|
|
|
|
|
shape[1] = 1;
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
out.labels_upper_bound_.HostVector() =
|
|
|
|
|
Gather(this->labels_upper_bound_.HostVector(), ridxs);
|
|
|
|
|
out.labels_lower_bound_.HostVector() =
|
|
|
|
|
@@ -343,13 +354,16 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
|
|
|
|
|
if (this->base_margin_.Size() != this->num_row_) {
|
|
|
|
|
CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0)
|
|
|
|
|
<< "Incorrect size of base margin vector.";
|
|
|
|
|
auto margin = this->base_margin_.View(this->base_margin_.Data()->DeviceIdx());
|
|
|
|
|
out.base_margin_.Reshape(ridxs.size(), margin.Shape()[1]);
|
|
|
|
|
size_t stride = margin.Stride(0);
|
|
|
|
|
auto t_margin = this->base_margin_.View(this->base_margin_.Data()->DeviceIdx());
|
|
|
|
|
out.base_margin_.Reshape(ridxs.size(), t_margin.Shape(1));
|
|
|
|
|
out.base_margin_.Data()->HostVector() =
|
|
|
|
|
Gather(this->base_margin_.Data()->HostVector(), ridxs, stride);
|
|
|
|
|
Gather(this->base_margin_.Data()->HostVector(), ridxs, t_margin.Stride(0));
|
|
|
|
|
} else {
|
|
|
|
|
out.base_margin_.Data()->HostVector() = Gather(this->base_margin_.Data()->HostVector(), ridxs);
|
|
|
|
|
out.base_margin_.ModifyInplace([&](auto* data, common::Span<size_t, 2> shape) {
|
|
|
|
|
data->HostVector() = Gather(this->base_margin_.Data()->HostVector(), ridxs);
|
|
|
|
|
shape[0] = data->Size();
|
|
|
|
|
shape[1] = 1;
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
out.feature_weights.Resize(this->feature_weights.Size());
|
|
|
|
|
@@ -460,6 +474,17 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
|
|
|
|
|
this->base_margin_.Reshape(this->num_row_, n_groups);
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
} else if (key == "label") {
|
|
|
|
|
CopyTensorInfoImpl(arr, &this->labels);
|
|
|
|
|
if (this->num_row_ != 0 && this->labels.Shape(0) != this->num_row_) {
|
|
|
|
|
CHECK_EQ(this->labels.Size() % this->num_row_, 0) << "Incorrect size for labels.";
|
|
|
|
|
size_t n_targets = this->labels.Size() / this->num_row_;
|
|
|
|
|
this->labels.Reshape(this->num_row_, n_targets);
|
|
|
|
|
}
|
|
|
|
|
auto const& h_labels = labels.Data()->ConstHostVector();
|
|
|
|
|
auto valid = std::none_of(h_labels.cbegin(), h_labels.cend(), data::LabelsCheck{});
|
|
|
|
|
CHECK(valid) << "Label contains NaN, infinity or a value too large.";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
// uint info
|
|
|
|
|
if (key == "group") {
|
|
|
|
|
@@ -500,12 +525,7 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
|
|
|
|
|
// float info
|
|
|
|
|
linalg::Tensor<float, 1> t;
|
|
|
|
|
CopyTensorInfoImpl<1>(arr, &t);
|
|
|
|
|
if (key == "label") {
|
|
|
|
|
this->labels_ = std::move(*t.Data());
|
|
|
|
|
auto const& h_labels = labels_.ConstHostVector();
|
|
|
|
|
auto valid = std::none_of(h_labels.cbegin(), h_labels.cend(), data::LabelsCheck{});
|
|
|
|
|
CHECK(valid) << "Label contains NaN, infinity or a value too large.";
|
|
|
|
|
} else if (key == "weight") {
|
|
|
|
|
if (key == "weight") {
|
|
|
|
|
this->weights_ = std::move(*t.Data());
|
|
|
|
|
auto const& h_weights = this->weights_.ConstHostVector();
|
|
|
|
|
auto valid = std::none_of(h_weights.cbegin(), h_weights.cend(),
|
|
|
|
|
@@ -568,7 +588,7 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
|
|
|
|
|
if (dtype == DataType::kFloat32) {
|
|
|
|
|
const std::vector<bst_float>* vec = nullptr;
|
|
|
|
|
if (!std::strcmp(key, "label")) {
|
|
|
|
|
vec = &this->labels_.HostVector();
|
|
|
|
|
vec = &this->labels.Data()->HostVector();
|
|
|
|
|
} else if (!std::strcmp(key, "weight")) {
|
|
|
|
|
vec = &this->weights_.HostVector();
|
|
|
|
|
} else if (!std::strcmp(key, "base_margin")) {
|
|
|
|
|
@@ -649,8 +669,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
|
|
|
|
|
}
|
|
|
|
|
this->num_col_ = that.num_col_;
|
|
|
|
|
|
|
|
|
|
this->labels_.SetDevice(that.labels_.DeviceIdx());
|
|
|
|
|
this->labels_.Extend(that.labels_);
|
|
|
|
|
linalg::Stack(&this->labels, that.labels);
|
|
|
|
|
|
|
|
|
|
this->weights_.SetDevice(that.weights_.DeviceIdx());
|
|
|
|
|
this->weights_.Extend(that.weights_);
|
|
|
|
|
@@ -702,7 +721,7 @@ void MetaInfo::Validate(int32_t device) const {
|
|
|
|
|
<< "Invalid group structure. Number of rows obtained from groups "
|
|
|
|
|
"doesn't equal to actual number of rows given by data.";
|
|
|
|
|
}
|
|
|
|
|
auto check_device = [device](HostDeviceVector<float> const &v) {
|
|
|
|
|
auto check_device = [device](HostDeviceVector<float> const& v) {
|
|
|
|
|
CHECK(v.DeviceIdx() == GenericParameter::kCpuId ||
|
|
|
|
|
device == GenericParameter::kCpuId ||
|
|
|
|
|
v.DeviceIdx() == device)
|
|
|
|
|
@@ -717,10 +736,10 @@ void MetaInfo::Validate(int32_t device) const {
|
|
|
|
|
check_device(weights_);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if (labels_.Size() != 0) {
|
|
|
|
|
CHECK_EQ(labels_.Size(), num_row_)
|
|
|
|
|
if (labels.Size() != 0) {
|
|
|
|
|
CHECK_EQ(labels.Size(), num_row_)
|
|
|
|
|
<< "Size of labels must equal to number of rows.";
|
|
|
|
|
check_device(labels_);
|
|
|
|
|
check_device(*labels.Data());
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if (labels_lower_bound_.Size() != 0) {
|
|
|
|
|
|