Reduce base margin to 2 dim for now. (#7455)

This commit is contained in:
Jiaming Yuan
2021-11-27 00:46:13 +08:00
committed by GitHub
parent bf7bb575b4
commit 557ffc4bf5
7 changed files with 33 additions and 33 deletions

View File

@@ -185,20 +185,20 @@ void MetaInfo::Clear() {
/*
* Binary serialization format for MetaInfo:
*
* | name | type | is_scalar | num_row | num_col | dim3 | value |
* |--------------------+----------+-----------+-------------+-------------+-------------+------------------------|
* | num_row | kUInt64 | True | NA | NA | NA | ${num_row_} |
* | num_col | kUInt64 | True | NA | NA | NA | ${num_col_} |
* | num_nonzero | kUInt64 | True | NA | NA | NA | ${num_nonzero_} |
* | labels | kFloat32 | False | ${size} | 1 | NA | ${labels_} |
* | group_ptr | kUInt32 | False | ${size} | 1 | NA | ${group_ptr_} |
* | weights | kFloat32 | False | ${size} | 1 | NA | ${weights_} |
* | base_margin | kFloat32 | False | ${Shape(0)} | ${Shape(1)} | ${Shape(2)} | ${base_margin_} |
* | labels_lower_bound | kFloat32 | False | ${size} | 1 | NA | ${labels_lower_bound_} |
* | labels_upper_bound | kFloat32 | False | ${size} | 1 | NA | ${labels_upper_bound_} |
* | feature_names | kStr | False | ${size} | 1 | NA | ${feature_names} |
* | feature_types | kStr | False | ${size} | 1 | NA | ${feature_types} |
* | feature_types | kFloat32 | False | ${size} | 1 | NA | ${feature_weights} |
* | name | type | is_scalar | num_row | num_col | value |
* |--------------------+----------+-----------+-------------+-------------+------------------------|
* | num_row | kUInt64 | True | NA | NA | ${num_row_} |
* | num_col | kUInt64 | True | NA | NA | ${num_col_} |
* | num_nonzero | kUInt64 | True | NA | NA | ${num_nonzero_} |
* | labels | kFloat32 | False | ${size} | 1 | ${labels_} |
* | group_ptr | kUInt32 | False | ${size} | 1 | ${group_ptr_} |
* | weights | kFloat32 | False | ${size} | 1 | ${weights_} |
* | base_margin | kFloat32 | False | ${Shape(0)} | ${Shape(1)} | ${base_margin_} |
* | labels_lower_bound | kFloat32 | False | ${size} | 1 | ${labels_lower_bound_} |
* | labels_upper_bound | kFloat32 | False | ${size} | 1 | ${labels_upper_bound_} |
* | feature_names | kStr | False | ${size} | 1 | ${feature_names} |
* | feature_types | kStr | False | ${size} | 1 | ${feature_types} |
* | feature_weights | kFloat32 | False | ${size} | 1 | ${feature_weights} |
*
* Note that the scalar fields (is_scalar=True) will have num_row and num_col missing.
* Also notice the difference between the saved name and the name used in `SetInfo':
@@ -344,7 +344,7 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0)
<< "Incorrect size of base margin vector.";
auto margin = this->base_margin_.View(this->base_margin_.Data()->DeviceIdx());
out.base_margin_.Reshape(ridxs.size(), margin.Shape()[1], margin.Shape()[2]);
out.base_margin_.Reshape(ridxs.size(), margin.Shape()[1]);
size_t stride = margin.Stride(0);
out.base_margin_.Data()->HostVector() =
Gather(this->base_margin_.Data()->HostVector(), ridxs, stride);
@@ -447,7 +447,7 @@ void MetaInfo::SetInfo(StringView key, StringView interface_str) {
void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
// multi-dim float info
if (key == "base_margin") {
CopyTensorInfoImpl<3>(arr, &this->base_margin_);
CopyTensorInfoImpl(arr, &this->base_margin_);
// FIXME(jiamingy): Remove the deprecated API and let all language bindings aware of
// input shape. This issue is CPU only since CUDA uses array interface from day 1.
//

View File

@@ -137,10 +137,10 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
batch.Weights() + batch.Size());
}
if (batch.BaseMargin() != nullptr) {
info_.base_margin_ = linalg::Tensor<float, 3>{batch.BaseMargin(),
batch.BaseMargin() + batch.Size(),
{batch.Size()},
GenericParameter::kCpuId};
info_.base_margin_ = decltype(info_.base_margin_){batch.BaseMargin(),
batch.BaseMargin() + batch.Size(),
{batch.Size()},
GenericParameter::kCpuId};
}
if (batch.Qid() != nullptr) {
qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size());