Set device in device dmatrix. (#5596)
This commit is contained in:
@@ -149,9 +149,17 @@ TEST(MetaInfo, Validate) {
|
||||
info.num_col_ = 3;
|
||||
std::vector<xgboost::bst_group_t> groups (11);
|
||||
info.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, 11);
|
||||
EXPECT_THROW(info.Validate(), dmlc::Error);
|
||||
EXPECT_THROW(info.Validate(0), dmlc::Error);
|
||||
|
||||
std::vector<float> labels(info.num_row_ + 1);
|
||||
info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1);
|
||||
EXPECT_THROW(info.Validate(), dmlc::Error);
|
||||
EXPECT_THROW(info.Validate(0), dmlc::Error);
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
info.group_ptr_.clear();
|
||||
labels.resize(info.num_row_);
|
||||
info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_);
|
||||
info.labels_.SetDevice(0);
|
||||
EXPECT_THROW(info.Validate(1), dmlc::Error);
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
|
||||
@@ -136,3 +136,14 @@ Arrow specification.'''
|
||||
n = 100
|
||||
X = cp.random.random((n, 2))
|
||||
xgb.DeviceQuantileDMatrix(X.toDlpack())
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
@pytest.mark.mgpu
|
||||
def test_specified_device(self):
|
||||
import cupy as cp
|
||||
cp.cuda.runtime.setDevice(0)
|
||||
dtrain = dmatrix_from_cupy(
|
||||
np.float32, xgb.DeviceQuantileDMatrix, np.nan)
|
||||
with pytest.raises(xgb.core.XGBoostError):
|
||||
xgb.train({'tree_method': 'gpu_hist', 'gpu_id': 1},
|
||||
dtrain, num_boost_round=10)
|
||||
|
||||
@@ -121,6 +121,7 @@ class TestGPUPredict(unittest.TestCase):
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_inplace_predict_cupy(self):
|
||||
import cupy as cp
|
||||
cp.cuda.runtime.setDevice(0)
|
||||
rows = 1000
|
||||
cols = 10
|
||||
cp_rng = cp.random.RandomState(1994)
|
||||
|
||||
Reference in New Issue
Block a user