Define metainfo and other parameters for all DMatrix interfaces. (#6601)

This PR ensures all DMatrix types have a common interface.

* Fix logic in avoiding duplicated DMatrix in sklearn.
* Check for consistency between DMatrix types.
* Add doc for bounds.
This commit is contained in:
Jiaming Yuan
2021-01-25 16:06:06 +08:00
committed by GitHub
parent 561809200a
commit 8942c98054
6 changed files with 365 additions and 158 deletions

View File

@@ -34,3 +34,25 @@ class TestDeviceQuantileDMatrix:
import cupy as cp
data = cp.random.randn(5, 5)
xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
@pytest.mark.skipif(**tm.no_cupy())
def test_metainfo(self) -> None:
import cupy as cp
rng = cp.random.RandomState(1994)
rows = 10
cols = 3
data = rng.randn(rows, cols)
labels = rng.randn(rows)
fw = rng.randn(rows)
fw -= fw.min()
m = xgb.DeviceQuantileDMatrix(data=data, label=labels, feature_weights=fw)
got_fw = m.get_float_info("feature_weights")
got_labels = m.get_label()
cp.testing.assert_allclose(fw, got_fw)
cp.testing.assert_allclose(labels, got_labels)

View File

@@ -6,7 +6,9 @@ import numpy as np
import asyncio
import xgboost
import subprocess
from hypothesis import given, strategies, settings, note, HealthCheck
from collections import OrderedDict
from inspect import signature
from hypothesis import given, strategies, settings, note
from hypothesis._settings import duration
from test_gpu_updaters import parameter_strategy
@@ -18,13 +20,15 @@ from test_with_dask import run_empty_dmatrix_reg # noqa
from test_with_dask import run_empty_dmatrix_cls # noqa
from test_with_dask import _get_client_workers # noqa
from test_with_dask import generate_array # noqa
from test_with_dask import suppress
from test_with_dask import kCols as random_cols # noqa
from test_with_dask import suppress # noqa
import testing as tm # noqa
try:
import dask.dataframe as dd
from xgboost import dask as dxgb
import xgboost as xgb
from dask.distributed import Client
from dask import array as da
from dask_cuda import LocalCUDACluster
@@ -252,6 +256,64 @@ class TestDistributedGPU:
run_empty_dmatrix_reg(client, parameters)
run_empty_dmatrix_cls(client, parameters)
def test_data_initialization(self, local_cuda_cluster: LocalCUDACluster) -> None:
with Client(local_cuda_cluster) as client:
X, y, _ = generate_array()
fw = da.random.random((random_cols, ))
fw = fw - fw.min()
m = dxgb.DaskDMatrix(client, X, y, feature_weights=fw)
workers = list(_get_client_workers(client).keys())
rabit_args = client.sync(dxgb._get_rabit_args, len(workers), client)
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
with dxgb.RabitContext(rabit_args):
local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref)
fw_rows = local_dtrain.get_float_info("feature_weights").shape[0]
assert fw_rows == local_dtrain.num_col()
futures = []
for i in range(len(workers)):
futures.append(client.submit(worker_fn, workers[i],
m.create_fn_args(workers[i]), pure=False,
workers=[workers[i]]))
client.gather(futures)
def test_interface_consistency(self) -> None:
sig = OrderedDict(signature(dxgb.DaskDMatrix).parameters)
del sig["client"]
ddm_names = list(sig.keys())
sig = OrderedDict(signature(dxgb.DaskDeviceQuantileDMatrix).parameters)
del sig["client"]
del sig["max_bin"]
ddqdm_names = list(sig.keys())
assert len(ddm_names) == len(ddqdm_names)
# between dask
for i in range(len(ddm_names)):
assert ddm_names[i] == ddqdm_names[i]
sig = OrderedDict(signature(xgb.DMatrix).parameters)
del sig["nthread"] # no nthread in dask
dm_names = list(sig.keys())
sig = OrderedDict(signature(xgb.DeviceQuantileDMatrix).parameters)
del sig["nthread"]
del sig["max_bin"]
dqdm_names = list(sig.keys())
# between single node
assert len(dm_names) == len(dqdm_names)
for i in range(len(dm_names)):
assert dm_names[i] == dqdm_names[i]
# ddm <-> dm
for i in range(len(ddm_names)):
assert ddm_names[i] == dm_names[i]
# dqdm <-> ddqdm
for i in range(len(ddqdm_names)):
assert ddqdm_names[i] == dqdm_names[i]
def run_quantile(self, name: str, local_cuda_cluster: LocalCUDACluster) -> None:
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows")