[dask] Support all parameters in regressor and classifier. (#6471)
* Add eval_metric. * Add callback. * Add feature weights. * Add custom objective.
This commit is contained in:
parent
c31e3efa7c
commit
a30461cf87
@ -326,4 +326,3 @@ addressed yet:
|
||||
- Label encoding for the ``DaskXGBClassifier`` classifier may not be supported. So users need
|
||||
to encode their training labels into discrete values first.
|
||||
- Ranking is not yet supported.
|
||||
- Callback functions are not tested.
|
||||
|
||||
@ -34,7 +34,7 @@ from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter
|
||||
from .core import _deprecate_positional_args
|
||||
from .training import train as worker_train
|
||||
from .tracker import RabitTracker, get_host_ip
|
||||
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
|
||||
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase, _objective_decorator
|
||||
from .sklearn import xgboost_model_doc
|
||||
|
||||
|
||||
@ -47,8 +47,6 @@ distributed = LazyLoader('distributed', globals(), 'dask.distributed')
|
||||
# not properly supported yet.
|
||||
#
|
||||
# TODOs:
|
||||
# - Callback.
|
||||
# - Label encoding.
|
||||
# - CV
|
||||
# - Ranking
|
||||
#
|
||||
@ -184,6 +182,8 @@ class DaskDMatrix:
|
||||
Upper bound for survival training.
|
||||
label_upper_bound : dask.array.Array/dask.dataframe.DataFrame
|
||||
Lower bound for survival training.
|
||||
feature_weights : dask.array.Array/dask.dataframe.DataFrame
|
||||
Weight for features used in column sampling.
|
||||
feature_names : list, optional
|
||||
Set names for features.
|
||||
feature_types : list, optional
|
||||
@ -200,6 +200,7 @@ class DaskDMatrix:
|
||||
base_margin=None,
|
||||
label_lower_bound=None,
|
||||
label_upper_bound=None,
|
||||
feature_weights=None,
|
||||
feature_names=None,
|
||||
feature_types=None):
|
||||
_assert_dask_support()
|
||||
@ -227,6 +228,7 @@ class DaskDMatrix:
|
||||
self._init = client.sync(self.map_local_data,
|
||||
client, data, label=label, weights=weight,
|
||||
base_margin=base_margin,
|
||||
feature_weights=feature_weights,
|
||||
label_lower_bound=label_lower_bound,
|
||||
label_upper_bound=label_upper_bound)
|
||||
|
||||
@ -234,7 +236,7 @@ class DaskDMatrix:
|
||||
return self._init.__await__()
|
||||
|
||||
async def map_local_data(self, client, data, label=None, weights=None,
|
||||
base_margin=None,
|
||||
base_margin=None, feature_weights=None,
|
||||
label_lower_bound=None, label_upper_bound=None):
|
||||
'''Obtain references to local data.'''
|
||||
|
||||
@ -328,6 +330,11 @@ class DaskDMatrix:
|
||||
self.worker_map = worker_map
|
||||
self.meta_names = meta_names
|
||||
|
||||
if feature_weights is None:
|
||||
self.feature_weights = None
|
||||
else:
|
||||
self.feature_weights = await client.compute(feature_weights).result()
|
||||
|
||||
return self
|
||||
|
||||
def create_fn_args(self, worker_addr: str):
|
||||
@ -337,6 +344,7 @@ class DaskDMatrix:
|
||||
'''
|
||||
return {'feature_names': self.feature_names,
|
||||
'feature_types': self.feature_types,
|
||||
'feature_weights': self.feature_weights,
|
||||
'meta_names': self.meta_names,
|
||||
'missing': self.missing,
|
||||
'parts': self.worker_map.get(worker_addr, None),
|
||||
@ -518,6 +526,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
||||
|
||||
|
||||
def _create_device_quantile_dmatrix(feature_names, feature_types,
|
||||
feature_weights,
|
||||
meta_names, missing, parts,
|
||||
max_bin):
|
||||
worker = distributed.get_worker()
|
||||
@ -546,10 +555,12 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
|
||||
feature_types=feature_types,
|
||||
nthread=worker.nthreads,
|
||||
max_bin=max_bin)
|
||||
dmatrix.set_info(feature_weights=feature_weights)
|
||||
return dmatrix
|
||||
|
||||
|
||||
def _create_dmatrix(feature_names, feature_types, meta_names, missing, parts):
|
||||
def _create_dmatrix(feature_names, feature_types, feature_weights, meta_names, missing,
|
||||
parts):
|
||||
'''Get data that local to worker from DaskDMatrix.
|
||||
|
||||
Returns
|
||||
@ -590,7 +601,8 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing, parts):
|
||||
nthread=worker.nthreads)
|
||||
dmatrix.set_info(base_margin=base_margin, weight=weights,
|
||||
label_lower_bound=label_lower_bound,
|
||||
label_upper_bound=label_upper_bound)
|
||||
label_upper_bound=label_upper_bound,
|
||||
feature_weights=feature_weights)
|
||||
return dmatrix
|
||||
|
||||
|
||||
@ -627,16 +639,15 @@ def _get_workers_from_data(dtrain: DaskDMatrix, evals=()):
|
||||
async def _train_async(client,
|
||||
global_config,
|
||||
params,
|
||||
dtrain: DaskDMatrix,
|
||||
*args,
|
||||
evals=(),
|
||||
early_stopping_rounds=None,
|
||||
**kwargs):
|
||||
if 'evals_result' in kwargs.keys():
|
||||
raise ValueError(
|
||||
'evals_result is not supported in dask interface.',
|
||||
'The evaluation history is returned as result of training.')
|
||||
|
||||
dtrain,
|
||||
num_boost_round,
|
||||
evals,
|
||||
obj,
|
||||
feval,
|
||||
early_stopping_rounds,
|
||||
verbose_eval,
|
||||
xgb_model,
|
||||
callbacks):
|
||||
workers = list(_get_workers_from_data(dtrain, evals))
|
||||
_rabit_args = await _get_rabit_args(len(workers), client)
|
||||
|
||||
@ -668,11 +679,15 @@ async def _train_async(client,
|
||||
local_param[p] = worker.nthreads
|
||||
bst = worker_train(params=local_param,
|
||||
dtrain=local_dtrain,
|
||||
*args,
|
||||
num_boost_round=num_boost_round,
|
||||
evals_result=local_history,
|
||||
evals=local_evals,
|
||||
obj=obj,
|
||||
feval=feval,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
**kwargs)
|
||||
verbose_eval=verbose_eval,
|
||||
xgb_model=xgb_model,
|
||||
callbacks=callbacks)
|
||||
ret = {'booster': bst, 'history': local_history}
|
||||
if local_dtrain.num_row() == 0:
|
||||
ret = None
|
||||
@ -703,8 +718,17 @@ async def _train_async(client,
|
||||
return list(filter(lambda ret: ret is not None, results))[0]
|
||||
|
||||
|
||||
def train(client, params, dtrain, *args, evals=(), early_stopping_rounds=None,
|
||||
**kwargs):
|
||||
def train(client,
|
||||
params,
|
||||
dtrain,
|
||||
num_boost_round=10,
|
||||
evals=(),
|
||||
obj=None,
|
||||
feval=None,
|
||||
early_stopping_rounds=None,
|
||||
xgb_model=None,
|
||||
verbose_eval=True,
|
||||
callbacks=None):
|
||||
'''Train XGBoost model.
|
||||
|
||||
.. versionadded:: 1.0.0
|
||||
@ -737,9 +761,19 @@ def train(client, params, dtrain, *args, evals=(), early_stopping_rounds=None,
|
||||
# Get global configuration before transferring computation to another thread or
|
||||
# process.
|
||||
global_config = config.get_config()
|
||||
return client.sync(
|
||||
_train_async, client, global_config, params, dtrain=dtrain, *args, evals=evals,
|
||||
early_stopping_rounds=early_stopping_rounds, **kwargs)
|
||||
return client.sync(_train_async,
|
||||
client=client,
|
||||
global_config=global_config,
|
||||
num_boost_round=num_boost_round,
|
||||
obj=obj,
|
||||
feval=feval,
|
||||
params=params,
|
||||
dtrain=dtrain,
|
||||
evals=evals,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
xgb_model=xgb_model,
|
||||
callbacks=callbacks)
|
||||
|
||||
|
||||
async def _direct_predict_impl(client, data, predict_fn):
|
||||
@ -1030,10 +1064,13 @@ class DaskScikitLearnBase(XGBModel):
|
||||
sample_weight=None,
|
||||
base_margin=None,
|
||||
eval_set=None,
|
||||
eval_metric=None,
|
||||
sample_weight_eval_set=None,
|
||||
early_stopping_rounds=None,
|
||||
verbose=True):
|
||||
'''Fit the regressor.
|
||||
verbose=True,
|
||||
feature_weights=None,
|
||||
callbacks=None):
|
||||
'''Fit gradient boosting model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -1047,6 +1084,7 @@ class DaskScikitLearnBase(XGBModel):
|
||||
A list of (X, y) tuple pairs to use as validation sets, for which
|
||||
metrics will be computed.
|
||||
Validation metrics will help us track the performance of the model.
|
||||
eval_metric : str, list of str, or callable, optional
|
||||
sample_weight_eval_set : list, optional
|
||||
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list
|
||||
of group weights on the i-th validation set.
|
||||
@ -1054,7 +1092,23 @@ class DaskScikitLearnBase(XGBModel):
|
||||
Activates early stopping.
|
||||
verbose : bool
|
||||
If `verbose` and an evaluation set is used, writes the evaluation
|
||||
metric measured on the validation set to stderr.'''
|
||||
metric measured on the validation set to stderr.
|
||||
feature_weights: array_like
|
||||
Weight for each feature, defines the probability of each feature being
|
||||
selected when colsample is being used. All values must be greater than 0,
|
||||
otherwise a `ValueError` is thrown. Only available for `hist`, `gpu_hist` and
|
||||
`exact` tree methods.
|
||||
callbacks : list of callback functions
|
||||
List of callback functions that are applied at end of each iteration.
|
||||
It is possible to use predefined callbacks by using :ref:`callback_api`.
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
callbacks = [xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
|
||||
save_best=True)]
|
||||
|
||||
'''
|
||||
raise NotImplementedError
|
||||
|
||||
def predict(self, data): # pylint: disable=arguments-differ
|
||||
@ -1089,25 +1143,42 @@ class DaskScikitLearnBase(XGBModel):
|
||||
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
# pylint: disable=missing-class-docstring
|
||||
async def _fit_async(self, X, y, sample_weight, base_margin, eval_set,
|
||||
sample_weight_eval_set, early_stopping_rounds,
|
||||
verbose):
|
||||
eval_metric, sample_weight_eval_set,
|
||||
early_stopping_rounds, verbose, feature_weights,
|
||||
callbacks):
|
||||
dtrain = await DaskDMatrix(client=self.client,
|
||||
data=X,
|
||||
label=y,
|
||||
weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
feature_weights=feature_weights,
|
||||
missing=self.missing)
|
||||
params = self.get_xgb_params()
|
||||
evals = await _evaluation_matrices(self.client, eval_set,
|
||||
sample_weight_eval_set,
|
||||
self.missing)
|
||||
|
||||
if callable(self.objective):
|
||||
obj = _objective_decorator(self.objective)
|
||||
else:
|
||||
obj = None
|
||||
metric = eval_metric if callable(eval_metric) else None
|
||||
if eval_metric is not None:
|
||||
if callable(eval_metric):
|
||||
eval_metric = None
|
||||
else:
|
||||
params.update({"eval_metric": eval_metric})
|
||||
|
||||
results = await train(client=self.client,
|
||||
params=params,
|
||||
dtrain=dtrain,
|
||||
num_boost_round=self.get_num_boosting_rounds(),
|
||||
evals=evals,
|
||||
feval=metric,
|
||||
obj=obj,
|
||||
verbose_eval=verbose,
|
||||
early_stopping_rounds=early_stopping_rounds)
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
callbacks=callbacks)
|
||||
self._Booster = results['booster']
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
self.evals_result_ = results['history']
|
||||
@ -1122,9 +1193,12 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
sample_weight=None,
|
||||
base_margin=None,
|
||||
eval_set=None,
|
||||
eval_metric=None,
|
||||
sample_weight_eval_set=None,
|
||||
early_stopping_rounds=None,
|
||||
verbose=True):
|
||||
verbose=True,
|
||||
feature_weights=None,
|
||||
callbacks=None):
|
||||
_assert_dask_support()
|
||||
return self.client.sync(self._fit_async,
|
||||
X=X,
|
||||
@ -1132,9 +1206,12 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
sample_weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
eval_set=eval_set,
|
||||
eval_metric=eval_metric,
|
||||
sample_weight_eval_set=sample_weight_eval_set,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
feature_weights=feature_weights,
|
||||
callbacks=callbacks)
|
||||
|
||||
async def _predict_async(
|
||||
self, data, output_margin=False, base_margin=None):
|
||||
@ -1161,13 +1238,15 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
# pylint: disable=missing-class-docstring
|
||||
async def _fit_async(self, X, y, sample_weight, base_margin, eval_set,
|
||||
sample_weight_eval_set, early_stopping_rounds,
|
||||
verbose):
|
||||
eval_metric, sample_weight_eval_set,
|
||||
early_stopping_rounds, verbose, feature_weights,
|
||||
callbacks):
|
||||
dtrain = await DaskDMatrix(client=self.client,
|
||||
data=X,
|
||||
label=y,
|
||||
weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
feature_weights=feature_weights,
|
||||
missing=self.missing)
|
||||
params = self.get_xgb_params()
|
||||
|
||||
@ -1187,13 +1266,28 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
evals = await _evaluation_matrices(self.client, eval_set,
|
||||
sample_weight_eval_set,
|
||||
self.missing)
|
||||
|
||||
if callable(self.objective):
|
||||
obj = _objective_decorator(self.objective)
|
||||
else:
|
||||
obj = None
|
||||
metric = eval_metric if callable(eval_metric) else None
|
||||
if eval_metric is not None:
|
||||
if callable(eval_metric):
|
||||
eval_metric = None
|
||||
else:
|
||||
params.update({"eval_metric": eval_metric})
|
||||
|
||||
results = await train(client=self.client,
|
||||
params=params,
|
||||
dtrain=dtrain,
|
||||
num_boost_round=self.get_num_boosting_rounds(),
|
||||
evals=evals,
|
||||
obj=obj,
|
||||
feval=metric,
|
||||
verbose_eval=verbose,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose)
|
||||
callbacks=callbacks)
|
||||
self._Booster = results['booster']
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
self.evals_result_ = results['history']
|
||||
@ -1207,9 +1301,12 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
sample_weight=None,
|
||||
base_margin=None,
|
||||
eval_set=None,
|
||||
eval_metric=None,
|
||||
sample_weight_eval_set=None,
|
||||
early_stopping_rounds=None,
|
||||
verbose=True):
|
||||
verbose=True,
|
||||
feature_weights=None,
|
||||
callbacks=None):
|
||||
_assert_dask_support()
|
||||
return self.client.sync(self._fit_async,
|
||||
X=X,
|
||||
@ -1217,9 +1314,12 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
sample_weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
eval_set=eval_set,
|
||||
eval_metric=eval_metric,
|
||||
sample_weight_eval_set=sample_weight_eval_set,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
feature_weights=feature_weights,
|
||||
callbacks=callbacks)
|
||||
|
||||
async def _predict_proba_async(self, data, output_margin=False,
|
||||
base_margin=None):
|
||||
|
||||
@ -184,6 +184,43 @@ class TestDistributedGPU:
|
||||
run_with_dask_array(dxgb.DaskDMatrix, client)
|
||||
run_with_dask_array(dxgb.DaskDeviceQuantileDMatrix, client)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.skipif(**tm.no_dask_cuda())
|
||||
def test_early_stopping(self, local_cuda_cluster):
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
import cupy
|
||||
with Client(local_cuda_cluster) as client:
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
X, y = da.from_array(X), da.from_array(y)
|
||||
|
||||
m = dxgb.DaskDMatrix(client, X, y)
|
||||
|
||||
valid = dxgb.DaskDMatrix(client, X, y)
|
||||
early_stopping_rounds = 5
|
||||
booster = dxgb.train(client, {'objective': 'binary:logistic',
|
||||
'eval_metric': 'error',
|
||||
'tree_method': 'gpu_hist'}, m,
|
||||
evals=[(valid, 'Valid')],
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=early_stopping_rounds)[
|
||||
'booster']
|
||||
assert hasattr(booster, 'best_score')
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||
|
||||
valid_X = X
|
||||
valid_y = y
|
||||
cls = dxgb.DaskXGBClassifier(objective='binary:logistic',
|
||||
tree_method='gpu_hist',
|
||||
n_estimators=100)
|
||||
cls.client = client
|
||||
cls.fit(X, y, early_stopping_rounds=early_stopping_rounds,
|
||||
eval_set=[(valid_X, valid_y)])
|
||||
booster = cls.get_booster()
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.skipif(**tm.no_dask_cuda())
|
||||
@pytest.mark.mgpu
|
||||
|
||||
@ -5,11 +5,13 @@ import sys
|
||||
import numpy as np
|
||||
import json
|
||||
import asyncio
|
||||
import tempfile
|
||||
from sklearn.datasets import make_classification
|
||||
import os
|
||||
import subprocess
|
||||
from hypothesis import given, settings, note
|
||||
from test_updaters import hist_parameter_strategy, exact_parameter_strategy
|
||||
from test_with_sklearn import run_feature_weights
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||
@ -74,7 +76,7 @@ def test_from_dask_dataframe():
|
||||
assert isinstance(prediction, da.Array)
|
||||
assert prediction.shape[0] == kRows
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
# evals_result is not supported in dask interface.
|
||||
xgb.dask.train(
|
||||
client, {}, dtrain, num_boost_round=2, evals_result={})
|
||||
@ -815,44 +817,6 @@ class TestWithDask:
|
||||
def test_quantile_same_on_all_workers(self):
|
||||
self.run_quantile('SameOnAllWorkers')
|
||||
|
||||
|
||||
class TestDaskCallbacks:
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_early_stopping(self, client):
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
X, y = da.from_array(X), da.from_array(y)
|
||||
m = xgb.dask.DaskDMatrix(client, X, y)
|
||||
early_stopping_rounds = 5
|
||||
booster = xgb.dask.train(client, {'objective': 'binary:logistic',
|
||||
'eval_metric': 'error',
|
||||
'tree_method': 'hist'}, m,
|
||||
evals=[(m, 'Train')],
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=early_stopping_rounds)['booster']
|
||||
assert hasattr(booster, 'best_score')
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_early_stopping_custom_eval(self, client):
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
X, y = da.from_array(X), da.from_array(y)
|
||||
m = xgb.dask.DaskDMatrix(client, X, y)
|
||||
early_stopping_rounds = 5
|
||||
booster = xgb.dask.train(
|
||||
client, {'objective': 'binary:logistic',
|
||||
'eval_metric': 'error',
|
||||
'tree_method': 'hist'}, m,
|
||||
evals=[(m, 'Train')],
|
||||
feval=tm.eval_error_metric,
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=early_stopping_rounds)['booster']
|
||||
assert hasattr(booster, 'best_score')
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||
|
||||
def test_n_workers(self):
|
||||
with LocalCluster(n_workers=2) as cluster:
|
||||
with Client(cluster) as client:
|
||||
@ -872,6 +836,67 @@ class TestDaskCallbacks:
|
||||
merged = xgb.dask._get_workers_from_data(train, evals=[(valid, 'Valid')])
|
||||
assert len(merged) == 2
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
def test_feature_weights(self, client):
|
||||
kRows = 1024
|
||||
kCols = 64
|
||||
|
||||
X = da.random.random((kRows, kCols), chunks=(32, -1))
|
||||
y = da.random.random(kRows, chunks=32)
|
||||
|
||||
fw = np.ones(shape=(kCols,))
|
||||
for i in range(kCols):
|
||||
fw[i] *= float(i)
|
||||
fw = da.from_array(fw)
|
||||
poly_increasing = run_feature_weights(X, y, fw, model=xgb.dask.DaskXGBRegressor)
|
||||
|
||||
fw = np.ones(shape=(kCols,))
|
||||
for i in range(kCols):
|
||||
fw[i] *= float(kCols - i)
|
||||
fw = da.from_array(fw)
|
||||
poly_decreasing = run_feature_weights(X, y, fw, model=xgb.dask.DaskXGBRegressor)
|
||||
|
||||
# Approxmated test, this is dependent on the implementation of random
|
||||
# number generator in std library.
|
||||
assert poly_increasing[0] > 0.08
|
||||
assert poly_decreasing[0] < -0.08
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_custom_objective(self, client):
|
||||
from sklearn.datasets import load_boston
|
||||
X, y = load_boston(return_X_y=True)
|
||||
X, y = da.from_array(X), da.from_array(y)
|
||||
rounds = 20
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, 'log')
|
||||
|
||||
def sqr(labels, predts):
|
||||
with open(path, 'a') as fd:
|
||||
print('Running sqr', file=fd)
|
||||
grad = predts - labels
|
||||
hess = np.ones(shape=labels.shape[0])
|
||||
return grad, hess
|
||||
|
||||
reg = xgb.dask.DaskXGBRegressor(n_estimators=rounds, objective=sqr,
|
||||
tree_method='hist')
|
||||
reg.fit(X, y, eval_set=[(X, y)])
|
||||
|
||||
# Check the obj is ran for rounds.
|
||||
with open(path, 'r') as fd:
|
||||
out = fd.readlines()
|
||||
assert len(out) == rounds
|
||||
|
||||
results_custom = reg.evals_result()
|
||||
|
||||
reg = xgb.dask.DaskXGBRegressor(n_estimators=rounds, tree_method='hist')
|
||||
reg.fit(X, y, eval_set=[(X, y)])
|
||||
results_native = reg.evals_result()
|
||||
|
||||
np.testing.assert_allclose(results_custom['validation_0']['rmse'],
|
||||
results_native['validation_0']['rmse'])
|
||||
tm.non_increasing(results_native['validation_0']['rmse'])
|
||||
|
||||
def test_data_initialization(self):
|
||||
'''Assert each worker has the correct amount of data, and DMatrix initialization doesn't
|
||||
@ -912,3 +937,97 @@ class TestDaskCallbacks:
|
||||
assert len(data) == cnt
|
||||
# Subtract the on disk resource from each worker
|
||||
assert cnt - n_workers == n_partitions
|
||||
|
||||
|
||||
class TestDaskCallbacks:
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_early_stopping(self, client):
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
X, y = da.from_array(X), da.from_array(y)
|
||||
m = xgb.dask.DaskDMatrix(client, X, y)
|
||||
|
||||
valid = xgb.dask.DaskDMatrix(client, X, y)
|
||||
early_stopping_rounds = 5
|
||||
booster = xgb.dask.train(client, {'objective': 'binary:logistic',
|
||||
'eval_metric': 'error',
|
||||
'tree_method': 'hist'}, m,
|
||||
evals=[(valid, 'Valid')],
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=early_stopping_rounds)['booster']
|
||||
assert hasattr(booster, 'best_score')
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||
|
||||
valid_X, valid_y = load_breast_cancer(return_X_y=True)
|
||||
valid_X, valid_y = da.from_array(valid_X), da.from_array(valid_y)
|
||||
cls = xgb.dask.DaskXGBClassifier(objective='binary:logistic', tree_method='hist',
|
||||
n_estimators=1000)
|
||||
cls.client = client
|
||||
cls.fit(X, y, early_stopping_rounds=early_stopping_rounds,
|
||||
eval_set=[(valid_X, valid_y)])
|
||||
booster = cls.get_booster()
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||
|
||||
# Specify the metric
|
||||
cls = xgb.dask.DaskXGBClassifier(objective='binary:logistic', tree_method='hist',
|
||||
n_estimators=1000)
|
||||
cls.client = client
|
||||
cls.fit(X, y, early_stopping_rounds=early_stopping_rounds,
|
||||
eval_set=[(valid_X, valid_y)], eval_metric='error')
|
||||
assert tm.non_increasing(cls.evals_result()['validation_0']['error'])
|
||||
booster = cls.get_booster()
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(cls.evals_result()['validation_0']['error']) < 20
|
||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_early_stopping_custom_eval(self, client):
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
X, y = da.from_array(X), da.from_array(y)
|
||||
m = xgb.dask.DaskDMatrix(client, X, y)
|
||||
|
||||
valid = xgb.dask.DaskDMatrix(client, X, y)
|
||||
early_stopping_rounds = 5
|
||||
booster = xgb.dask.train(
|
||||
client, {'objective': 'binary:logistic',
|
||||
'eval_metric': 'error',
|
||||
'tree_method': 'hist'}, m,
|
||||
evals=[(m, 'Train'), (valid, 'Valid')],
|
||||
feval=tm.eval_error_metric,
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=early_stopping_rounds)['booster']
|
||||
assert hasattr(booster, 'best_score')
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||
|
||||
valid_X, valid_y = load_breast_cancer(return_X_y=True)
|
||||
valid_X, valid_y = da.from_array(valid_X), da.from_array(valid_y)
|
||||
cls = xgb.dask.DaskXGBClassifier(objective='binary:logistic', tree_method='hist',
|
||||
n_estimators=1000)
|
||||
cls.client = client
|
||||
cls.fit(X, y, early_stopping_rounds=early_stopping_rounds,
|
||||
eval_set=[(valid_X, valid_y)], eval_metric=tm.eval_error_metric)
|
||||
booster = cls.get_booster()
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_callback(self, client):
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
X, y = da.from_array(X), da.from_array(y)
|
||||
|
||||
cls = xgb.dask.DaskXGBClassifier(objective='binary:logistic', tree_method='hist',
|
||||
n_estimators=10)
|
||||
cls.client = client
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cls.fit(X, y, callbacks=[xgb.callback.TrainingCheckPoint(directory=tmpdir,
|
||||
iterations=1,
|
||||
name='model')])
|
||||
for i in range(1, 10):
|
||||
assert os.path.exists(
|
||||
os.path.join(tmpdir, 'model_' + str(i) + '.json'))
|
||||
|
||||
@ -984,21 +984,10 @@ def test_pandas_input():
|
||||
np.array([0, 1]))
|
||||
|
||||
|
||||
def run_feature_weights(increasing):
|
||||
def run_feature_weights(X, y, fw, model=xgb.XGBRegressor):
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
kRows = 512
|
||||
kCols = 64
|
||||
colsample_bynode = 0.5
|
||||
reg = xgb.XGBRegressor(tree_method='hist',
|
||||
colsample_bynode=colsample_bynode)
|
||||
X = rng.randn(kRows, kCols)
|
||||
y = rng.randn(kRows)
|
||||
fw = np.ones(shape=(kCols,))
|
||||
for i in range(kCols):
|
||||
if increasing:
|
||||
fw[i] *= float(i)
|
||||
else:
|
||||
fw[i] *= float(kCols - i)
|
||||
reg = model(tree_method='hist', colsample_bynode=colsample_bynode)
|
||||
|
||||
reg.fit(X, y, feature_weights=fw)
|
||||
model_path = os.path.join(tmpdir, 'model.json')
|
||||
@ -1034,8 +1023,21 @@ def run_feature_weights(increasing):
|
||||
|
||||
|
||||
def test_feature_weights():
|
||||
poly_increasing = run_feature_weights(True)
|
||||
poly_decreasing = run_feature_weights(False)
|
||||
kRows = 512
|
||||
kCols = 64
|
||||
X = rng.randn(kRows, kCols)
|
||||
y = rng.randn(kRows)
|
||||
|
||||
fw = np.ones(shape=(kCols,))
|
||||
for i in range(kCols):
|
||||
fw[i] *= float(i)
|
||||
poly_increasing = run_feature_weights(X, y, fw, xgb.XGBRegressor)
|
||||
|
||||
fw = np.ones(shape=(kCols,))
|
||||
for i in range(kCols):
|
||||
fw[i] *= float(kCols - i)
|
||||
poly_decreasing = run_feature_weights(X, y, fw, xgb.XGBRegressor)
|
||||
|
||||
# Approxmated test, this is dependent on the implementation of random
|
||||
# number generator in std library.
|
||||
assert poly_increasing[0] > 0.08
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user