Set device in device dmatrix. (#5596)

This commit is contained in:
Jiaming Yuan
2020-04-25 13:42:53 +08:00
committed by fis
parent 3728855ce9
commit 844d7c1d5b
8 changed files with 41 additions and 5 deletions

View File

@@ -338,7 +338,7 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
}
}
void MetaInfo::Validate() const {
void MetaInfo::Validate(int32_t device) const {
if (group_ptr_.size() != 0 && weights_.Size() != 0) {
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
<< "Size of weights must equal to number of groups when ranking "
@@ -350,30 +350,44 @@ void MetaInfo::Validate() const {
<< "Invalid group structure. Number of rows obtained from groups "
"doesn't equal to actual number of rows given by data.";
}
auto check_device = [device](HostDeviceVector<float> const &v) {
CHECK(v.DeviceIdx() == GenericParameter::kCpuId ||
device == GenericParameter::kCpuId ||
v.DeviceIdx() == device)
<< "Data is resided on a different device than `gpu_id`. "
<< "Device that data is on: " << v.DeviceIdx() << ", "
<< "`gpu_id` for XGBoost: " << device;
};
if (weights_.Size() != 0) {
CHECK_EQ(weights_.Size(), num_row_)
<< "Size of weights must equal to number of rows.";
check_device(weights_);
return;
}
if (labels_.Size() != 0) {
CHECK_EQ(labels_.Size(), num_row_)
<< "Size of labels must equal to number of rows.";
check_device(labels_);
return;
}
if (labels_lower_bound_.Size() != 0) {
CHECK_EQ(labels_lower_bound_.Size(), num_row_)
<< "Size of label_lower_bound must equal to number of rows.";
check_device(labels_lower_bound_);
return;
}
if (labels_upper_bound_.Size() != 0) {
CHECK_EQ(labels_upper_bound_.Size(), num_row_)
<< "Size of label_upper_bound must equal to number of rows.";
check_device(labels_upper_bound_);
return;
}
CHECK_LE(num_nonzero_, num_col_ * num_row_);
if (base_margin_.Size() != 0) {
CHECK_EQ(base_margin_.Size() % num_row_, 0)
<< "Size of base margin must be a multiple of number of rows.";
check_device(base_margin_);
}
}

View File

@@ -201,6 +201,7 @@ template <typename AdapterT>
DeviceDMatrix::DeviceDMatrix(AdapterT* adapter, float missing, int nthread, int max_bin) {
common::HistogramCuts cuts =
common::AdapterDeviceSketch(adapter, max_bin, missing);
dh::safe_cuda(cudaSetDevice(adapter->DeviceIdx()));
auto& batch = adapter->Value();
// Work out how many valid entries we have in each row
dh::caching_device_vector<size_t> row_counts(adapter->NumRows() + 1, 0);

View File

@@ -99,6 +99,7 @@ void CopyDataRowMajor(AdapterT* adapter, common::Span<Entry> data,
// be supported in future. Does not currently support inferring row/column size
template <typename AdapterT>
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
dh::safe_cuda(cudaSetDevice(adapter->DeviceIdx()));
CHECK(adapter->NumRows() != kAdapterUnknownSize);
CHECK(adapter->NumColumns() != kAdapterUnknownSize);

View File

@@ -1052,7 +1052,7 @@ class LearnerImpl : public LearnerIO {
void ValidateDMatrix(DMatrix* p_fmat) const {
MetaInfo const& info = p_fmat->Info();
info.Validate();
info.Validate(generic_parameters_.gpu_id);
auto const row_based_split = [this]() {
return tparam_.dsplit == DataSplitMode::kRow ||