Quantile DMatrix for CPU. (#8130)
- Add a new `QuantileDMatrix` that works for both CPU and GPU. - Deprecate `DeviceQuantileDMatrix`.
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import pytest
|
||||
@@ -6,16 +5,14 @@ import sys
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
import test_quantile_dmatrix as tqd
|
||||
|
||||
|
||||
class TestDeviceQuantileDMatrix:
|
||||
def test_dmatrix_numpy_init(self):
|
||||
data = np.random.randn(5, 5)
|
||||
with pytest.raises(TypeError, match='is not supported'):
|
||||
xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64))
|
||||
cputest = tqd.TestQuantileDMatrix()
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_dmatrix_feature_weights(self):
|
||||
def test_dmatrix_feature_weights(self) -> None:
|
||||
import cupy as cp
|
||||
rng = cp.random.RandomState(1994)
|
||||
data = rng.randn(5, 5)
|
||||
@@ -29,7 +26,7 @@ class TestDeviceQuantileDMatrix:
|
||||
feature_weights.astype(np.float32))
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_dmatrix_cupy_init(self):
|
||||
def test_dmatrix_cupy_init(self) -> None:
|
||||
import cupy as cp
|
||||
data = cp.random.randn(5, 5)
|
||||
xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
|
||||
@@ -55,3 +52,10 @@ class TestDeviceQuantileDMatrix:
|
||||
|
||||
cp.testing.assert_allclose(fw, got_fw)
|
||||
cp.testing.assert_allclose(labels, got_labels)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_ref_dmatrix(self) -> None:
|
||||
import cupy as cp
|
||||
rng = cp.random.RandomState(1994)
|
||||
self.cputest.run_ref_dmatrix(rng, "gpu_hist", False)
|
||||
|
||||
@@ -429,9 +429,10 @@ class TestDistributedGPU:
|
||||
sig = OrderedDict(signature(dxgb.DaskDMatrix).parameters)
|
||||
del sig["client"]
|
||||
ddm_names = list(sig.keys())
|
||||
sig = OrderedDict(signature(dxgb.DaskDeviceQuantileDMatrix).parameters)
|
||||
sig = OrderedDict(signature(dxgb.DaskQuantileDMatrix).parameters)
|
||||
del sig["client"]
|
||||
del sig["max_bin"]
|
||||
del sig["ref"]
|
||||
ddqdm_names = list(sig.keys())
|
||||
assert len(ddm_names) == len(ddqdm_names)
|
||||
|
||||
@@ -442,9 +443,10 @@ class TestDistributedGPU:
|
||||
sig = OrderedDict(signature(xgb.DMatrix).parameters)
|
||||
del sig["nthread"] # no nthread in dask
|
||||
dm_names = list(sig.keys())
|
||||
sig = OrderedDict(signature(xgb.DeviceQuantileDMatrix).parameters)
|
||||
sig = OrderedDict(signature(xgb.QuantileDMatrix).parameters)
|
||||
del sig["nthread"]
|
||||
del sig["max_bin"]
|
||||
del sig["ref"]
|
||||
dqdm_names = list(sig.keys())
|
||||
|
||||
# between single node
|
||||
@@ -499,7 +501,6 @@ class TestDistributedGPU:
|
||||
for arg in rabit_args:
|
||||
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
|
||||
port_env = arg.decode('utf-8')
|
||||
port_env = arg.decode('utf-8')
|
||||
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
|
||||
uri_env = arg.decode("utf-8")
|
||||
port = port_env.split('=')
|
||||
|
||||
Reference in New Issue
Block a user