Reduce base margin to 2 dim for now. (#7455)
This commit is contained in:
@@ -185,20 +185,20 @@ void MetaInfo::Clear() {
|
||||
/*
|
||||
* Binary serialization format for MetaInfo:
|
||||
*
|
||||
* | name | type | is_scalar | num_row | num_col | dim3 | value |
|
||||
* |--------------------+----------+-----------+-------------+-------------+-------------+------------------------|
|
||||
* | num_row | kUInt64 | True | NA | NA | NA | ${num_row_} |
|
||||
* | num_col | kUInt64 | True | NA | NA | NA | ${num_col_} |
|
||||
* | num_nonzero | kUInt64 | True | NA | NA | NA | ${num_nonzero_} |
|
||||
* | labels | kFloat32 | False | ${size} | 1 | NA | ${labels_} |
|
||||
* | group_ptr | kUInt32 | False | ${size} | 1 | NA | ${group_ptr_} |
|
||||
* | weights | kFloat32 | False | ${size} | 1 | NA | ${weights_} |
|
||||
* | base_margin | kFloat32 | False | ${Shape(0)} | ${Shape(1)} | ${Shape(2)} | ${base_margin_} |
|
||||
* | labels_lower_bound | kFloat32 | False | ${size} | 1 | NA | ${labels_lower_bound_} |
|
||||
* | labels_upper_bound | kFloat32 | False | ${size} | 1 | NA | ${labels_upper_bound_} |
|
||||
* | feature_names | kStr | False | ${size} | 1 | NA | ${feature_names} |
|
||||
* | feature_types | kStr | False | ${size} | 1 | NA | ${feature_types} |
|
||||
* | feature_types | kFloat32 | False | ${size} | 1 | NA | ${feature_weights} |
|
||||
* | name | type | is_scalar | num_row | num_col | value |
|
||||
* |--------------------+----------+-----------+-------------+-------------+------------------------|
|
||||
* | num_row | kUInt64 | True | NA | NA | ${num_row_} |
|
||||
* | num_col | kUInt64 | True | NA | NA | ${num_col_} |
|
||||
* | num_nonzero | kUInt64 | True | NA | NA | ${num_nonzero_} |
|
||||
* | labels | kFloat32 | False | ${size} | 1 | ${labels_} |
|
||||
* | group_ptr | kUInt32 | False | ${size} | 1 | ${group_ptr_} |
|
||||
* | weights | kFloat32 | False | ${size} | 1 | ${weights_} |
|
||||
* | base_margin | kFloat32 | False | ${Shape(0)} | ${Shape(1)} | ${base_margin_} |
|
||||
* | labels_lower_bound | kFloat32 | False | ${size} | 1 | ${labels_lower_bound_} |
|
||||
* | labels_upper_bound | kFloat32 | False | ${size} | 1 | ${labels_upper_bound_} |
|
||||
* | feature_names | kStr | False | ${size} | 1 | ${feature_names} |
|
||||
* | feature_types | kStr | False | ${size} | 1 | ${feature_types} |
|
||||
* | feature_weights | kFloat32 | False | ${size} | 1 | ${feature_weights} |
|
||||
*
|
||||
* Note that the scalar fields (is_scalar=True) will have num_row and num_col missing.
|
||||
* Also notice the difference between the saved name and the name used in `SetInfo':
|
||||
@@ -344,7 +344,7 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
|
||||
CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0)
|
||||
<< "Incorrect size of base margin vector.";
|
||||
auto margin = this->base_margin_.View(this->base_margin_.Data()->DeviceIdx());
|
||||
out.base_margin_.Reshape(ridxs.size(), margin.Shape()[1], margin.Shape()[2]);
|
||||
out.base_margin_.Reshape(ridxs.size(), margin.Shape()[1]);
|
||||
size_t stride = margin.Stride(0);
|
||||
out.base_margin_.Data()->HostVector() =
|
||||
Gather(this->base_margin_.Data()->HostVector(), ridxs, stride);
|
||||
@@ -447,7 +447,7 @@ void MetaInfo::SetInfo(StringView key, StringView interface_str) {
|
||||
void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
|
||||
// multi-dim float info
|
||||
if (key == "base_margin") {
|
||||
CopyTensorInfoImpl<3>(arr, &this->base_margin_);
|
||||
CopyTensorInfoImpl(arr, &this->base_margin_);
|
||||
// FIXME(jiamingy): Remove the deprecated API and let all language bindings aware of
|
||||
// input shape. This issue is CPU only since CUDA uses array interface from day 1.
|
||||
//
|
||||
|
||||
@@ -137,10 +137,10 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
||||
batch.Weights() + batch.Size());
|
||||
}
|
||||
if (batch.BaseMargin() != nullptr) {
|
||||
info_.base_margin_ = linalg::Tensor<float, 3>{batch.BaseMargin(),
|
||||
batch.BaseMargin() + batch.Size(),
|
||||
{batch.Size()},
|
||||
GenericParameter::kCpuId};
|
||||
info_.base_margin_ = decltype(info_.base_margin_){batch.BaseMargin(),
|
||||
batch.BaseMargin() + batch.Size(),
|
||||
{batch.Size()},
|
||||
GenericParameter::kCpuId};
|
||||
}
|
||||
if (batch.Qid() != nullptr) {
|
||||
qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size());
|
||||
|
||||
@@ -61,7 +61,8 @@ Predictor* Predictor::Create(
|
||||
return p_predictor;
|
||||
}
|
||||
|
||||
void ValidateBaseMarginShape(linalg::Tensor<float, 3> const& margin, bst_row_t n_samples,
|
||||
template <int32_t D>
|
||||
void ValidateBaseMarginShape(linalg::Tensor<float, D> const& margin, bst_row_t n_samples,
|
||||
bst_group_t n_groups) {
|
||||
// FIXME: Bindings other than Python doesn't have shape.
|
||||
std::string expected{"Invalid shape of base_margin. Expected: (" + std::to_string(n_samples) +
|
||||
|
||||
Reference in New Issue
Block a user