Handle duplicated values in sketching. (#6178)
* Accumulate weights in duplicated values. * Fix device id in iterative dmatrix.
This commit is contained in:
@@ -63,15 +63,17 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
|
||||
size_t accumulated_rows = 0;
|
||||
bst_feature_t cols = 0;
|
||||
int32_t device = GenericParameter::kCpuId;
|
||||
int32_t current_device_;
|
||||
dh::safe_cuda(cudaGetDevice(¤t_device_));
|
||||
int32_t current_device;
|
||||
dh::safe_cuda(cudaGetDevice(¤t_device));
|
||||
auto get_device = [&]() -> int32_t {
|
||||
int32_t d = GenericParameter::kCpuId ? current_device_ : device;
|
||||
int32_t d = (device == GenericParameter::kCpuId) ? current_device : device;
|
||||
CHECK_NE(d, GenericParameter::kCpuId);
|
||||
return d;
|
||||
};
|
||||
|
||||
while (iter.Next()) {
|
||||
device = proxy->DeviceIdx();
|
||||
CHECK_LT(device, common::AllVisibleGPUs());
|
||||
dh::safe_cuda(cudaSetDevice(get_device()));
|
||||
if (cols == 0) {
|
||||
cols = num_cols();
|
||||
|
||||
@@ -66,6 +66,9 @@ class DMatrixProxy : public DMatrix {
|
||||
} else {
|
||||
this->FromCudaArray(interface_str);
|
||||
}
|
||||
if (this->info_.num_row_ == 0) {
|
||||
this->device_ = GenericParameter::kCpuId;
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user