Set device in device dmatrix. (#5596)
This commit is contained in:
@@ -338,7 +338,7 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
|
||||
}
|
||||
}
|
||||
|
||||
void MetaInfo::Validate() const {
|
||||
void MetaInfo::Validate(int32_t device) const {
|
||||
if (group_ptr_.size() != 0 && weights_.Size() != 0) {
|
||||
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
|
||||
<< "Size of weights must equal to number of groups when ranking "
|
||||
@@ -350,30 +350,44 @@ void MetaInfo::Validate() const {
|
||||
<< "Invalid group structure. Number of rows obtained from groups "
|
||||
"doesn't equal to actual number of rows given by data.";
|
||||
}
|
||||
auto check_device = [device](HostDeviceVector<float> const &v) {
|
||||
CHECK(v.DeviceIdx() == GenericParameter::kCpuId ||
|
||||
device == GenericParameter::kCpuId ||
|
||||
v.DeviceIdx() == device)
|
||||
<< "Data is resided on a different device than `gpu_id`. "
|
||||
<< "Device that data is on: " << v.DeviceIdx() << ", "
|
||||
<< "`gpu_id` for XGBoost: " << device;
|
||||
};
|
||||
|
||||
if (weights_.Size() != 0) {
|
||||
CHECK_EQ(weights_.Size(), num_row_)
|
||||
<< "Size of weights must equal to number of rows.";
|
||||
check_device(weights_);
|
||||
return;
|
||||
}
|
||||
if (labels_.Size() != 0) {
|
||||
CHECK_EQ(labels_.Size(), num_row_)
|
||||
<< "Size of labels must equal to number of rows.";
|
||||
check_device(labels_);
|
||||
return;
|
||||
}
|
||||
if (labels_lower_bound_.Size() != 0) {
|
||||
CHECK_EQ(labels_lower_bound_.Size(), num_row_)
|
||||
<< "Size of label_lower_bound must equal to number of rows.";
|
||||
check_device(labels_lower_bound_);
|
||||
return;
|
||||
}
|
||||
if (labels_upper_bound_.Size() != 0) {
|
||||
CHECK_EQ(labels_upper_bound_.Size(), num_row_)
|
||||
<< "Size of label_upper_bound must equal to number of rows.";
|
||||
check_device(labels_upper_bound_);
|
||||
return;
|
||||
}
|
||||
CHECK_LE(num_nonzero_, num_col_ * num_row_);
|
||||
if (base_margin_.Size() != 0) {
|
||||
CHECK_EQ(base_margin_.Size() % num_row_, 0)
|
||||
<< "Size of base margin must be a multiple of number of rows.";
|
||||
check_device(base_margin_);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user