[dask] Add type hints. (#6519)
* Add validate_features. * Show type hints in doc. Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
from pathlib import Path
|
||||
|
||||
import testing as tm
|
||||
import pytest
|
||||
import xgboost as xgb
|
||||
import sys
|
||||
import numpy as np
|
||||
import json
|
||||
from typing import List, Tuple, Union, Dict, Optional, Callable, Type
|
||||
import asyncio
|
||||
import tempfile
|
||||
from sklearn.datasets import make_classification
|
||||
@@ -19,56 +22,46 @@ if tm.no_dask()['condition']:
|
||||
pytest.skip(msg=tm.no_dask()['reason'], allow_module_level=True)
|
||||
|
||||
|
||||
try:
|
||||
from distributed import LocalCluster, Client, get_client
|
||||
from distributed.utils_test import client, loop, cluster_fixture
|
||||
import dask.dataframe as dd
|
||||
import dask.array as da
|
||||
from xgboost.dask import DaskDMatrix
|
||||
import dask
|
||||
except ImportError:
|
||||
LocalCluster = None
|
||||
Client = None
|
||||
get_client = None
|
||||
client = None
|
||||
loop = None
|
||||
cluster_fixture = None
|
||||
dd = None
|
||||
da = None
|
||||
DaskDMatrix = None
|
||||
dask = None
|
||||
from distributed import LocalCluster, Client, get_client
|
||||
from distributed.utils_test import client, loop, cluster_fixture
|
||||
import dask.dataframe as dd
|
||||
import dask.array as da
|
||||
from xgboost.dask import DaskDMatrix
|
||||
|
||||
|
||||
kRows = 1000
|
||||
kCols = 10
|
||||
kWorkers = 5
|
||||
|
||||
|
||||
def _get_client_workers(client):
|
||||
def _get_client_workers(client: "Client") -> Dict[str, Dict]:
|
||||
workers = client.scheduler_info()['workers']
|
||||
return workers
|
||||
|
||||
|
||||
def generate_array(with_weights=False):
|
||||
def generate_array(
|
||||
with_weights: bool = False
|
||||
) -> Tuple[xgb.dask._DaskCollection, xgb.dask._DaskCollection,
|
||||
Optional[xgb.dask._DaskCollection]]:
|
||||
partition_size = 20
|
||||
X = da.random.random((kRows, kCols), partition_size)
|
||||
y = da.random.random(kRows, partition_size)
|
||||
if with_weights:
|
||||
w = da.random.random(kRows, partition_size)
|
||||
return X, y, w
|
||||
return X, y
|
||||
return X, y, None
|
||||
|
||||
|
||||
def test_from_dask_dataframe():
|
||||
def test_from_dask_dataframe() -> None:
|
||||
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y = generate_array()
|
||||
X, y, _ = generate_array()
|
||||
|
||||
X = dd.from_dask_array(X)
|
||||
y = dd.from_dask_array(y)
|
||||
|
||||
dtrain = DaskDMatrix(client, X, y)
|
||||
booster = xgb.dask.train(
|
||||
client, {}, dtrain, num_boost_round=2)['booster']
|
||||
booster = xgb.dask.train(client, {}, dtrain, num_boost_round=2)['booster']
|
||||
|
||||
prediction = xgb.dask.predict(client, model=booster, data=dtrain)
|
||||
|
||||
@@ -78,7 +71,7 @@ def test_from_dask_dataframe():
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# evals_result is not supported in dask interface.
|
||||
xgb.dask.train(
|
||||
xgb.dask.train( # type:ignore
|
||||
client, {}, dtrain, num_boost_round=2, evals_result={})
|
||||
# force prediction to be computed
|
||||
from_dmatrix = prediction.compute()
|
||||
@@ -96,10 +89,10 @@ def test_from_dask_dataframe():
|
||||
from_dmatrix)
|
||||
|
||||
|
||||
def test_from_dask_array():
|
||||
def test_from_dask_array() -> None:
|
||||
with LocalCluster(n_workers=kWorkers, threads_per_worker=5) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y = generate_array()
|
||||
X, y, _ = generate_array()
|
||||
dtrain = DaskDMatrix(client, X, y)
|
||||
# results is {'booster': Booster, 'history': {...}}
|
||||
result = xgb.dask.train(client, {}, dtrain)
|
||||
@@ -111,7 +104,7 @@ def test_from_dask_array():
|
||||
# force prediction to be computed
|
||||
prediction = prediction.compute()
|
||||
|
||||
booster = result['booster']
|
||||
booster: xgb.Booster = result['booster']
|
||||
single_node_predt = booster.predict(
|
||||
xgb.DMatrix(X.compute())
|
||||
)
|
||||
@@ -127,7 +120,7 @@ def test_from_dask_array():
|
||||
assert np.all(single_node_predt == from_arr.compute())
|
||||
|
||||
|
||||
def test_dask_predict_shape_infer():
|
||||
def test_dask_predict_shape_infer() -> None:
|
||||
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y = make_classification(n_samples=1000, n_informative=5,
|
||||
@@ -148,7 +141,7 @@ def test_dask_predict_shape_infer():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tree_method", ["hist", "approx"])
|
||||
def test_boost_from_prediction(tree_method):
|
||||
def test_boost_from_prediction(tree_method: str) -> None:
|
||||
if tree_method == 'approx':
|
||||
pytest.xfail(reason='test_boost_from_prediction[approx] is flaky')
|
||||
|
||||
@@ -212,7 +205,7 @@ def test_boost_from_prediction(tree_method):
|
||||
np.testing.assert_almost_equal(proba_1.compute(), proba_2.compute())
|
||||
|
||||
|
||||
def test_dask_missing_value_reg():
|
||||
def test_dask_missing_value_reg() -> None:
|
||||
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X_0 = np.ones((20 // 2, kCols))
|
||||
@@ -236,7 +229,7 @@ def test_dask_missing_value_reg():
|
||||
np.testing.assert_allclose(np_predt, dd_predt)
|
||||
|
||||
|
||||
def test_dask_missing_value_cls():
|
||||
def test_dask_missing_value_cls() -> None:
|
||||
with LocalCluster() as cluster:
|
||||
with Client(cluster) as client:
|
||||
X_0 = np.ones((kRows // 2, kCols))
|
||||
@@ -263,7 +256,7 @@ def test_dask_missing_value_cls():
|
||||
assert hasattr(cls, 'missing')
|
||||
|
||||
|
||||
def test_dask_regressor():
|
||||
def test_dask_regressor() -> None:
|
||||
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y, w = generate_array(with_weights=True)
|
||||
@@ -285,7 +278,7 @@ def test_dask_regressor():
|
||||
assert len(history['validation_0']['rmse']) == 2
|
||||
|
||||
|
||||
def test_dask_classifier():
|
||||
def test_dask_classifier() -> None:
|
||||
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y, w = generate_array(with_weights=True)
|
||||
@@ -335,11 +328,11 @@ def test_dask_classifier():
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_sklearn_grid_search():
|
||||
def test_sklearn_grid_search() -> None:
|
||||
from sklearn.model_selection import GridSearchCV
|
||||
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y = generate_array()
|
||||
X, y, _ = generate_array()
|
||||
reg = xgb.dask.DaskXGBRegressor(learning_rate=0.1,
|
||||
tree_method='hist')
|
||||
reg.client = client
|
||||
@@ -353,7 +346,7 @@ def test_sklearn_grid_search():
|
||||
assert len(means) == len(set(means))
|
||||
|
||||
|
||||
def test_empty_dmatrix_training_continuation(client):
|
||||
def test_empty_dmatrix_training_continuation(client: "Client") -> None:
|
||||
kRows, kCols = 1, 97
|
||||
X = dd.from_array(np.random.randn(kRows, kCols))
|
||||
y = dd.from_array(np.random.rand(kRows))
|
||||
@@ -377,8 +370,8 @@ def test_empty_dmatrix_training_continuation(client):
|
||||
assert xgb.dask.predict(client, out, dtrain).compute().shape[0] == 1
|
||||
|
||||
|
||||
def run_empty_dmatrix_reg(client, parameters):
|
||||
def _check_outputs(out, predictions):
|
||||
def run_empty_dmatrix_reg(client: "Client", parameters: dict) -> None:
|
||||
def _check_outputs(out: xgb.dask.TrainReturnT, predictions: np.ndarray) -> None:
|
||||
assert isinstance(out['booster'], xgb.dask.Booster)
|
||||
assert len(out['history']['validation']['rmse']) == 2
|
||||
assert isinstance(predictions, np.ndarray)
|
||||
@@ -426,10 +419,10 @@ def run_empty_dmatrix_reg(client, parameters):
|
||||
_check_outputs(out, predictions)
|
||||
|
||||
|
||||
def run_empty_dmatrix_cls(client, parameters):
|
||||
def run_empty_dmatrix_cls(client: "Client", parameters: dict) -> None:
|
||||
n_classes = 4
|
||||
|
||||
def _check_outputs(out, predictions):
|
||||
def _check_outputs(out: xgb.dask.TrainReturnT, predictions: np.ndarray) -> None:
|
||||
assert isinstance(out['booster'], xgb.dask.Booster)
|
||||
assert len(out['history']['validation']['merror']) == 2
|
||||
assert isinstance(predictions, np.ndarray)
|
||||
@@ -472,7 +465,7 @@ def run_empty_dmatrix_cls(client, parameters):
|
||||
# No test for Exact, as empty DMatrix handling are mostly for distributed
|
||||
# environment and Exact doesn't support it.
|
||||
|
||||
def test_empty_dmatrix_hist():
|
||||
def test_empty_dmatrix_hist() -> None:
|
||||
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
parameters = {'tree_method': 'hist'}
|
||||
@@ -480,7 +473,7 @@ def test_empty_dmatrix_hist():
|
||||
run_empty_dmatrix_cls(client, parameters)
|
||||
|
||||
|
||||
def test_empty_dmatrix_approx():
|
||||
def test_empty_dmatrix_approx() -> None:
|
||||
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
parameters = {'tree_method': 'approx'}
|
||||
@@ -488,9 +481,9 @@ def test_empty_dmatrix_approx():
|
||||
run_empty_dmatrix_cls(client, parameters)
|
||||
|
||||
|
||||
async def run_from_dask_array_asyncio(scheduler_address):
|
||||
async def run_from_dask_array_asyncio(scheduler_address: str) -> xgb.dask.TrainReturnT:
|
||||
async with Client(scheduler_address, asynchronous=True) as client:
|
||||
X, y = generate_array()
|
||||
X, y, _ = generate_array()
|
||||
m = await DaskDMatrix(client, X, y)
|
||||
output = await xgb.dask.train(client, {}, dtrain=m)
|
||||
|
||||
@@ -510,9 +503,9 @@ async def run_from_dask_array_asyncio(scheduler_address):
|
||||
return output
|
||||
|
||||
|
||||
async def run_dask_regressor_asyncio(scheduler_address):
|
||||
async def run_dask_regressor_asyncio(scheduler_address: str) -> None:
|
||||
async with Client(scheduler_address, asynchronous=True) as client:
|
||||
X, y = generate_array()
|
||||
X, y, _ = generate_array()
|
||||
regressor = await xgb.dask.DaskXGBRegressor(verbosity=1,
|
||||
n_estimators=2)
|
||||
regressor.set_params(tree_method='hist')
|
||||
@@ -532,9 +525,9 @@ async def run_dask_regressor_asyncio(scheduler_address):
|
||||
assert len(history['validation_0']['rmse']) == 2
|
||||
|
||||
|
||||
async def run_dask_classifier_asyncio(scheduler_address):
|
||||
async def run_dask_classifier_asyncio(scheduler_address: str) -> None:
|
||||
async with Client(scheduler_address, asynchronous=True) as client:
|
||||
X, y = generate_array()
|
||||
X, y, _ = generate_array()
|
||||
y = (y * 10).astype(np.int32)
|
||||
classifier = await xgb.dask.DaskXGBClassifier(
|
||||
verbosity=1, n_estimators=2, eval_metric='merror')
|
||||
@@ -574,7 +567,7 @@ async def run_dask_classifier_asyncio(scheduler_address):
|
||||
assert prediction.shape[0] == kRows
|
||||
|
||||
|
||||
def test_with_asyncio():
|
||||
def test_with_asyncio() -> None:
|
||||
with LocalCluster() as cluster:
|
||||
with Client(cluster) as client:
|
||||
address = client.scheduler.address
|
||||
@@ -586,10 +579,10 @@ def test_with_asyncio():
|
||||
asyncio.run(run_dask_classifier_asyncio(address))
|
||||
|
||||
|
||||
def test_predict():
|
||||
def test_predict() -> None:
|
||||
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y = generate_array()
|
||||
X, y, _ = generate_array()
|
||||
dtrain = DaskDMatrix(client, X, y)
|
||||
booster = xgb.dask.train(
|
||||
client, {}, dtrain, num_boost_round=2)['booster']
|
||||
@@ -610,13 +603,14 @@ def test_predict():
|
||||
assert shap.shape[1] == kCols + 1
|
||||
|
||||
|
||||
def test_predict_with_meta(client):
|
||||
def test_predict_with_meta(client: "Client") -> None:
|
||||
X, y, w = generate_array(with_weights=True)
|
||||
assert w is not None
|
||||
partition_size = 20
|
||||
margin = da.random.random(kRows, partition_size) + 1e4
|
||||
|
||||
dtrain = DaskDMatrix(client, X, y, weight=w, base_margin=margin)
|
||||
booster = xgb.dask.train(
|
||||
booster: xgb.Booster = xgb.dask.train(
|
||||
client, {}, dtrain, num_boost_round=4)['booster']
|
||||
|
||||
prediction = xgb.dask.predict(client, model=booster, data=dtrain)
|
||||
@@ -632,7 +626,7 @@ def test_predict_with_meta(client):
|
||||
assert np.all(prediction == single)
|
||||
|
||||
|
||||
def run_aft_survival(client, dmatrix_t):
|
||||
def run_aft_survival(client: "Client", dmatrix_t: Type) -> None:
|
||||
df = dd.read_csv(os.path.join(tm.PROJECT_ROOT, 'demo', 'data',
|
||||
'veterans_lung_cancer.csv'))
|
||||
y_lower_bound = df['Survival_label_lower_bound']
|
||||
@@ -669,39 +663,43 @@ def run_aft_survival(client, dmatrix_t):
|
||||
assert nloglik_rec['extreme'][-1] > 4.9
|
||||
|
||||
|
||||
def test_aft_survival():
|
||||
def test_aft_survival() -> None:
|
||||
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
run_aft_survival(client, DaskDMatrix)
|
||||
|
||||
|
||||
class TestWithDask:
|
||||
def test_global_config(self, client):
|
||||
X, y = generate_array()
|
||||
def test_global_config(self, client: "Client") -> None:
|
||||
X, y, _ = generate_array()
|
||||
xgb.config.set_config(verbosity=0)
|
||||
dtrain = DaskDMatrix(client, X, y)
|
||||
before_fname = './before_training-test_global_config'
|
||||
after_fname = './after_training-test_global_config'
|
||||
|
||||
class TestCallback(xgb.callback.TrainingCallback):
|
||||
def write_file(self, fname):
|
||||
def write_file(self, fname: str) -> None:
|
||||
with open(fname, 'w') as fd:
|
||||
fd.write(str(xgb.config.get_config()['verbosity']))
|
||||
|
||||
def before_training(self, model):
|
||||
def before_training(self, model: xgb.Booster) -> xgb.Booster:
|
||||
self.write_file(before_fname)
|
||||
assert xgb.config.get_config()['verbosity'] == 0
|
||||
return model
|
||||
|
||||
def after_training(self, model):
|
||||
def after_training(self, model: xgb.Booster) -> xgb.Booster:
|
||||
assert xgb.config.get_config()['verbosity'] == 0
|
||||
return model
|
||||
|
||||
def before_iteration(self, model, epoch, evals_log):
|
||||
def before_iteration(
|
||||
self, model: xgb.Booster, epoch: int, evals_log: Dict
|
||||
) -> bool:
|
||||
assert xgb.config.get_config()['verbosity'] == 0
|
||||
return False
|
||||
|
||||
def after_iteration(self, model, epoch, evals_log):
|
||||
def after_iteration(
|
||||
self, model: xgb.Booster, epoch: int, evals_log: Dict
|
||||
) -> bool:
|
||||
self.write_file(after_fname)
|
||||
assert xgb.config.get_config()['verbosity'] == 0
|
||||
return False
|
||||
@@ -716,8 +714,14 @@ class TestWithDask:
|
||||
os.remove(before_fname)
|
||||
os.remove(after_fname)
|
||||
|
||||
def run_updater_test(self, client, params, num_rounds, dataset,
|
||||
tree_method):
|
||||
def run_updater_test(
|
||||
self,
|
||||
client: "Client",
|
||||
params: Dict,
|
||||
num_rounds: int,
|
||||
dataset: tm.TestDataset,
|
||||
tree_method: str
|
||||
) -> None:
|
||||
params['tree_method'] = tree_method
|
||||
params = dataset.set_params(params)
|
||||
# It doesn't make sense to distribute a completely
|
||||
@@ -748,22 +752,26 @@ class TestWithDask:
|
||||
@given(params=hist_parameter_strategy,
|
||||
dataset=tm.dataset_strategy)
|
||||
@settings(deadline=None)
|
||||
def test_hist(self, params, dataset, client):
|
||||
def test_hist(
|
||||
self, params: Dict, dataset: tm.TestDataset, client: "Client"
|
||||
) -> None:
|
||||
num_rounds = 30
|
||||
self.run_updater_test(client, params, num_rounds, dataset, 'hist')
|
||||
|
||||
@given(params=exact_parameter_strategy,
|
||||
dataset=tm.dataset_strategy)
|
||||
@settings(deadline=None)
|
||||
def test_approx(self, client, params, dataset):
|
||||
def test_approx(
|
||||
self, client: "Client", params: Dict, dataset: tm.TestDataset
|
||||
) -> None:
|
||||
num_rounds = 30
|
||||
self.run_updater_test(client, params, num_rounds, dataset, 'approx')
|
||||
|
||||
def run_quantile(self, name):
|
||||
def run_quantile(self, name: str) -> None:
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping dask tests on Windows")
|
||||
|
||||
exe = None
|
||||
exe: Optional[str] = None
|
||||
for possible_path in {'./testxgboost', './build/testxgboost',
|
||||
'../build/testxgboost',
|
||||
'../cpu-build/testxgboost'}:
|
||||
@@ -774,16 +782,16 @@ class TestWithDask:
|
||||
|
||||
test = "--gtest_filter=Quantile." + name
|
||||
|
||||
def runit(worker_addr, rabit_args):
|
||||
port = None
|
||||
def runit(worker_addr: str, rabit_args: List[bytes]) -> subprocess.CompletedProcess:
|
||||
port_env = ''
|
||||
# setup environment for running the c++ part.
|
||||
for arg in rabit_args:
|
||||
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
|
||||
port = arg.decode('utf-8')
|
||||
port = port.split('=')
|
||||
port_env = arg.decode('utf-8')
|
||||
port = port_env.split('=')
|
||||
env = os.environ.copy()
|
||||
env[port[0]] = port[1]
|
||||
return subprocess.run([exe, test], env=env, capture_output=True)
|
||||
return subprocess.run([str(exe), test], env=env, capture_output=True)
|
||||
|
||||
with LocalCluster(n_workers=4) as cluster:
|
||||
with Client(cluster) as client:
|
||||
@@ -804,20 +812,20 @@ class TestWithDask:
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.gtest
|
||||
def test_quantile_basic(self):
|
||||
def test_quantile_basic(self) -> None:
|
||||
self.run_quantile('DistributedBasic')
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.gtest
|
||||
def test_quantile(self):
|
||||
def test_quantile(self) -> None:
|
||||
self.run_quantile('Distributed')
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.gtest
|
||||
def test_quantile_same_on_all_workers(self):
|
||||
def test_quantile_same_on_all_workers(self) -> None:
|
||||
self.run_quantile('SameOnAllWorkers')
|
||||
|
||||
def test_n_workers(self):
|
||||
def test_n_workers(self) -> None:
|
||||
with LocalCluster(n_workers=2) as cluster:
|
||||
with Client(cluster) as client:
|
||||
workers = list(_get_client_workers(client).keys())
|
||||
@@ -837,7 +845,7 @@ class TestWithDask:
|
||||
assert len(merged) == 2
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
def test_feature_weights(self, client):
|
||||
def test_feature_weights(self, client: "Client") -> None:
|
||||
kRows = 1024
|
||||
kCols = 64
|
||||
|
||||
@@ -863,7 +871,7 @@ class TestWithDask:
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_custom_objective(self, client):
|
||||
def test_custom_objective(self, client: "Client") -> None:
|
||||
from sklearn.datasets import load_boston
|
||||
X, y = load_boston(return_X_y=True)
|
||||
X, y = da.from_array(X), da.from_array(y)
|
||||
@@ -872,7 +880,7 @@ class TestWithDask:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, 'log')
|
||||
|
||||
def sqr(labels, predts):
|
||||
def sqr(labels: np.ndarray, predts: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
with open(path, 'a') as fd:
|
||||
print('Running sqr', file=fd)
|
||||
grad = predts - labels
|
||||
@@ -898,21 +906,21 @@ class TestWithDask:
|
||||
results_native['validation_0']['rmse'])
|
||||
tm.non_increasing(results_native['validation_0']['rmse'])
|
||||
|
||||
def test_data_initialization(self):
|
||||
def test_data_initialization(self) -> None:
|
||||
'''Assert each worker has the correct amount of data, and DMatrix initialization doesn't
|
||||
generate unnecessary copies of data.
|
||||
|
||||
'''
|
||||
with LocalCluster(n_workers=2) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y = generate_array()
|
||||
X, y, _ = generate_array()
|
||||
n_partitions = X.npartitions
|
||||
m = xgb.dask.DaskDMatrix(client, X, y)
|
||||
workers = list(_get_client_workers(client).keys())
|
||||
rabit_args = client.sync(xgb.dask._get_rabit_args, len(workers), client)
|
||||
n_workers = len(workers)
|
||||
|
||||
def worker_fn(worker_addr, data_ref):
|
||||
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
local_dtrain = xgb.dask._dmatrix_from_list_of_parts(**data_ref)
|
||||
total = np.array([local_dtrain.num_row()])
|
||||
@@ -941,7 +949,7 @@ class TestWithDask:
|
||||
|
||||
class TestDaskCallbacks:
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_early_stopping(self, client):
|
||||
def test_early_stopping(self, client: "Client") -> None:
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
X, y = da.from_array(X), da.from_array(y)
|
||||
@@ -983,7 +991,7 @@ class TestDaskCallbacks:
|
||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_early_stopping_custom_eval(self, client):
|
||||
def test_early_stopping_custom_eval(self, client: "Client") -> None:
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
X, y = da.from_array(X), da.from_array(y)
|
||||
@@ -1015,7 +1023,7 @@ class TestDaskCallbacks:
|
||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_callback(self, client):
|
||||
def test_callback(self, client: "Client") -> None:
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
X, y = da.from_array(X), da.from_array(y)
|
||||
@@ -1025,9 +1033,11 @@ class TestDaskCallbacks:
|
||||
cls.client = client
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cls.fit(X, y, callbacks=[xgb.callback.TrainingCheckPoint(directory=tmpdir,
|
||||
iterations=1,
|
||||
name='model')])
|
||||
cls.fit(X, y, callbacks=[xgb.callback.TrainingCheckPoint(
|
||||
directory=Path(tmpdir),
|
||||
iterations=1,
|
||||
name='model'
|
||||
)])
|
||||
for i in range(1, 10):
|
||||
assert os.path.exists(
|
||||
os.path.join(tmpdir, 'model_' + str(i) + '.json'))
|
||||
|
||||
Reference in New Issue
Block a user