[dask] Support all parameters in regressor and classifier. (#6471)

* Add eval_metric.
* Add callback.
* Add feature weights.
* Add custom objective.
This commit is contained in:
Jiaming Yuan 2020-12-14 07:35:56 +08:00 committed by GitHub
parent c31e3efa7c
commit a30461cf87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 348 additions and 91 deletions

View File

@ -326,4 +326,3 @@ addressed yet:
- Label encoding for the ``DaskXGBClassifier`` classifier may not be supported. So users need
to encode their training labels into discrete values first.
- Ranking is not yet supported.
- Callback functions are not tested.

View File

@ -34,7 +34,7 @@ from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter
from .core import _deprecate_positional_args
from .training import train as worker_train
from .tracker import RabitTracker, get_host_ip
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase, _objective_decorator
from .sklearn import xgboost_model_doc
@ -47,8 +47,6 @@ distributed = LazyLoader('distributed', globals(), 'dask.distributed')
# not properly supported yet.
#
# TODOs:
# - Callback.
# - Label encoding.
# - CV
# - Ranking
#
@ -184,6 +182,8 @@ class DaskDMatrix:
Upper bound for survival training.
label_upper_bound : dask.array.Array/dask.dataframe.DataFrame
Lower bound for survival training.
feature_weights : dask.array.Array/dask.dataframe.DataFrame
Weight for features used in column sampling.
feature_names : list, optional
Set names for features.
feature_types : list, optional
@ -200,6 +200,7 @@ class DaskDMatrix:
base_margin=None,
label_lower_bound=None,
label_upper_bound=None,
feature_weights=None,
feature_names=None,
feature_types=None):
_assert_dask_support()
@ -227,6 +228,7 @@ class DaskDMatrix:
self._init = client.sync(self.map_local_data,
client, data, label=label, weights=weight,
base_margin=base_margin,
feature_weights=feature_weights,
label_lower_bound=label_lower_bound,
label_upper_bound=label_upper_bound)
@ -234,7 +236,7 @@ class DaskDMatrix:
return self._init.__await__()
async def map_local_data(self, client, data, label=None, weights=None,
base_margin=None,
base_margin=None, feature_weights=None,
label_lower_bound=None, label_upper_bound=None):
'''Obtain references to local data.'''
@ -328,6 +330,11 @@ class DaskDMatrix:
self.worker_map = worker_map
self.meta_names = meta_names
if feature_weights is None:
self.feature_weights = None
else:
self.feature_weights = await client.compute(feature_weights).result()
return self
def create_fn_args(self, worker_addr: str):
@ -337,6 +344,7 @@ class DaskDMatrix:
'''
return {'feature_names': self.feature_names,
'feature_types': self.feature_types,
'feature_weights': self.feature_weights,
'meta_names': self.meta_names,
'missing': self.missing,
'parts': self.worker_map.get(worker_addr, None),
@ -518,6 +526,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
def _create_device_quantile_dmatrix(feature_names, feature_types,
feature_weights,
meta_names, missing, parts,
max_bin):
worker = distributed.get_worker()
@ -546,10 +555,12 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
feature_types=feature_types,
nthread=worker.nthreads,
max_bin=max_bin)
dmatrix.set_info(feature_weights=feature_weights)
return dmatrix
def _create_dmatrix(feature_names, feature_types, meta_names, missing, parts):
def _create_dmatrix(feature_names, feature_types, feature_weights, meta_names, missing,
parts):
'''Get data that local to worker from DaskDMatrix.
Returns
@ -590,7 +601,8 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing, parts):
nthread=worker.nthreads)
dmatrix.set_info(base_margin=base_margin, weight=weights,
label_lower_bound=label_lower_bound,
label_upper_bound=label_upper_bound)
label_upper_bound=label_upper_bound,
feature_weights=feature_weights)
return dmatrix
@ -627,16 +639,15 @@ def _get_workers_from_data(dtrain: DaskDMatrix, evals=()):
async def _train_async(client,
global_config,
params,
dtrain: DaskDMatrix,
*args,
evals=(),
early_stopping_rounds=None,
**kwargs):
if 'evals_result' in kwargs.keys():
raise ValueError(
'evals_result is not supported in dask interface.',
'The evaluation history is returned as result of training.')
dtrain,
num_boost_round,
evals,
obj,
feval,
early_stopping_rounds,
verbose_eval,
xgb_model,
callbacks):
workers = list(_get_workers_from_data(dtrain, evals))
_rabit_args = await _get_rabit_args(len(workers), client)
@ -668,11 +679,15 @@ async def _train_async(client,
local_param[p] = worker.nthreads
bst = worker_train(params=local_param,
dtrain=local_dtrain,
*args,
num_boost_round=num_boost_round,
evals_result=local_history,
evals=local_evals,
obj=obj,
feval=feval,
early_stopping_rounds=early_stopping_rounds,
**kwargs)
verbose_eval=verbose_eval,
xgb_model=xgb_model,
callbacks=callbacks)
ret = {'booster': bst, 'history': local_history}
if local_dtrain.num_row() == 0:
ret = None
@ -703,8 +718,17 @@ async def _train_async(client,
return list(filter(lambda ret: ret is not None, results))[0]
def train(client, params, dtrain, *args, evals=(), early_stopping_rounds=None,
**kwargs):
def train(client,
params,
dtrain,
num_boost_round=10,
evals=(),
obj=None,
feval=None,
early_stopping_rounds=None,
xgb_model=None,
verbose_eval=True,
callbacks=None):
'''Train XGBoost model.
.. versionadded:: 1.0.0
@ -737,9 +761,19 @@ def train(client, params, dtrain, *args, evals=(), early_stopping_rounds=None,
# Get global configuration before transferring computation to another thread or
# process.
global_config = config.get_config()
return client.sync(
_train_async, client, global_config, params, dtrain=dtrain, *args, evals=evals,
early_stopping_rounds=early_stopping_rounds, **kwargs)
return client.sync(_train_async,
client=client,
global_config=global_config,
num_boost_round=num_boost_round,
obj=obj,
feval=feval,
params=params,
dtrain=dtrain,
evals=evals,
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval,
xgb_model=xgb_model,
callbacks=callbacks)
async def _direct_predict_impl(client, data, predict_fn):
@ -1030,10 +1064,13 @@ class DaskScikitLearnBase(XGBModel):
sample_weight=None,
base_margin=None,
eval_set=None,
eval_metric=None,
sample_weight_eval_set=None,
early_stopping_rounds=None,
verbose=True):
'''Fit the regressor.
verbose=True,
feature_weights=None,
callbacks=None):
'''Fit gradient boosting model
Parameters
----------
@ -1047,6 +1084,7 @@ class DaskScikitLearnBase(XGBModel):
A list of (X, y) tuple pairs to use as validation sets, for which
metrics will be computed.
Validation metrics will help us track the performance of the model.
eval_metric : str, list of str, or callable, optional
sample_weight_eval_set : list, optional
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list
of group weights on the i-th validation set.
@ -1054,7 +1092,23 @@ class DaskScikitLearnBase(XGBModel):
Activates early stopping.
verbose : bool
If `verbose` and an evaluation set is used, writes the evaluation
metric measured on the validation set to stderr.'''
metric measured on the validation set to stderr.
feature_weights: array_like
Weight for each feature, defines the probability of each feature being
selected when colsample is being used. All values must be greater than 0,
otherwise a `ValueError` is thrown. Only available for `hist`, `gpu_hist` and
`exact` tree methods.
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using :ref:`callback_api`.
Example:
.. code-block:: python
callbacks = [xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
save_best=True)]
'''
raise NotImplementedError
def predict(self, data): # pylint: disable=arguments-differ
@ -1089,25 +1143,42 @@ class DaskScikitLearnBase(XGBModel):
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
# pylint: disable=missing-class-docstring
async def _fit_async(self, X, y, sample_weight, base_margin, eval_set,
sample_weight_eval_set, early_stopping_rounds,
verbose):
eval_metric, sample_weight_eval_set,
early_stopping_rounds, verbose, feature_weights,
callbacks):
dtrain = await DaskDMatrix(client=self.client,
data=X,
label=y,
weight=sample_weight,
base_margin=base_margin,
feature_weights=feature_weights,
missing=self.missing)
params = self.get_xgb_params()
evals = await _evaluation_matrices(self.client, eval_set,
sample_weight_eval_set,
self.missing)
if callable(self.objective):
obj = _objective_decorator(self.objective)
else:
obj = None
metric = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
eval_metric = None
else:
params.update({"eval_metric": eval_metric})
results = await train(client=self.client,
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
feval=metric,
obj=obj,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds)
early_stopping_rounds=early_stopping_rounds,
callbacks=callbacks)
self._Booster = results['booster']
# pylint: disable=attribute-defined-outside-init
self.evals_result_ = results['history']
@ -1122,9 +1193,12 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
sample_weight=None,
base_margin=None,
eval_set=None,
eval_metric=None,
sample_weight_eval_set=None,
early_stopping_rounds=None,
verbose=True):
verbose=True,
feature_weights=None,
callbacks=None):
_assert_dask_support()
return self.client.sync(self._fit_async,
X=X,
@ -1132,9 +1206,12 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
sample_weight=sample_weight,
base_margin=base_margin,
eval_set=eval_set,
eval_metric=eval_metric,
sample_weight_eval_set=sample_weight_eval_set,
early_stopping_rounds=early_stopping_rounds,
verbose=verbose)
verbose=verbose,
feature_weights=feature_weights,
callbacks=callbacks)
async def _predict_async(
self, data, output_margin=False, base_margin=None):
@ -1161,13 +1238,15 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
# pylint: disable=missing-class-docstring
async def _fit_async(self, X, y, sample_weight, base_margin, eval_set,
sample_weight_eval_set, early_stopping_rounds,
verbose):
eval_metric, sample_weight_eval_set,
early_stopping_rounds, verbose, feature_weights,
callbacks):
dtrain = await DaskDMatrix(client=self.client,
data=X,
label=y,
weight=sample_weight,
base_margin=base_margin,
feature_weights=feature_weights,
missing=self.missing)
params = self.get_xgb_params()
@ -1187,13 +1266,28 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
evals = await _evaluation_matrices(self.client, eval_set,
sample_weight_eval_set,
self.missing)
if callable(self.objective):
obj = _objective_decorator(self.objective)
else:
obj = None
metric = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
eval_metric = None
else:
params.update({"eval_metric": eval_metric})
results = await train(client=self.client,
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
obj=obj,
feval=metric,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose)
callbacks=callbacks)
self._Booster = results['booster']
# pylint: disable=attribute-defined-outside-init
self.evals_result_ = results['history']
@ -1207,9 +1301,12 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
sample_weight=None,
base_margin=None,
eval_set=None,
eval_metric=None,
sample_weight_eval_set=None,
early_stopping_rounds=None,
verbose=True):
verbose=True,
feature_weights=None,
callbacks=None):
_assert_dask_support()
return self.client.sync(self._fit_async,
X=X,
@ -1217,9 +1314,12 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
sample_weight=sample_weight,
base_margin=base_margin,
eval_set=eval_set,
eval_metric=eval_metric,
sample_weight_eval_set=sample_weight_eval_set,
early_stopping_rounds=early_stopping_rounds,
verbose=verbose)
verbose=verbose,
feature_weights=feature_weights,
callbacks=callbacks)
async def _predict_proba_async(self, data, output_margin=False,
base_margin=None):

View File

@ -184,6 +184,43 @@ class TestDistributedGPU:
run_with_dask_array(dxgb.DaskDMatrix, client)
run_with_dask_array(dxgb.DaskDeviceQuantileDMatrix, client)
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
def test_early_stopping(self, local_cuda_cluster):
from sklearn.datasets import load_breast_cancer
import cupy
with Client(local_cuda_cluster) as client:
X, y = load_breast_cancer(return_X_y=True)
X, y = da.from_array(X), da.from_array(y)
m = dxgb.DaskDMatrix(client, X, y)
valid = dxgb.DaskDMatrix(client, X, y)
early_stopping_rounds = 5
booster = dxgb.train(client, {'objective': 'binary:logistic',
'eval_metric': 'error',
'tree_method': 'gpu_hist'}, m,
evals=[(valid, 'Valid')],
num_boost_round=1000,
early_stopping_rounds=early_stopping_rounds)[
'booster']
assert hasattr(booster, 'best_score')
dump = booster.get_dump(dump_format='json')
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
valid_X = X
valid_y = y
cls = dxgb.DaskXGBClassifier(objective='binary:logistic',
tree_method='gpu_hist',
n_estimators=100)
cls.client = client
cls.fit(X, y, early_stopping_rounds=early_stopping_rounds,
eval_set=[(valid_X, valid_y)])
booster = cls.get_booster()
dump = booster.get_dump(dump_format='json')
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.mgpu

View File

@ -5,11 +5,13 @@ import sys
import numpy as np
import json
import asyncio
import tempfile
from sklearn.datasets import make_classification
import os
import subprocess
from hypothesis import given, settings, note
from test_updaters import hist_parameter_strategy, exact_parameter_strategy
from test_with_sklearn import run_feature_weights
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@ -74,7 +76,7 @@ def test_from_dask_dataframe():
assert isinstance(prediction, da.Array)
assert prediction.shape[0] == kRows
with pytest.raises(ValueError):
with pytest.raises(TypeError):
# evals_result is not supported in dask interface.
xgb.dask.train(
client, {}, dtrain, num_boost_round=2, evals_result={})
@ -815,44 +817,6 @@ class TestWithDask:
def test_quantile_same_on_all_workers(self):
self.run_quantile('SameOnAllWorkers')
class TestDaskCallbacks:
@pytest.mark.skipif(**tm.no_sklearn())
def test_early_stopping(self, client):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
X, y = da.from_array(X), da.from_array(y)
m = xgb.dask.DaskDMatrix(client, X, y)
early_stopping_rounds = 5
booster = xgb.dask.train(client, {'objective': 'binary:logistic',
'eval_metric': 'error',
'tree_method': 'hist'}, m,
evals=[(m, 'Train')],
num_boost_round=1000,
early_stopping_rounds=early_stopping_rounds)['booster']
assert hasattr(booster, 'best_score')
dump = booster.get_dump(dump_format='json')
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
@pytest.mark.skipif(**tm.no_sklearn())
def test_early_stopping_custom_eval(self, client):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
X, y = da.from_array(X), da.from_array(y)
m = xgb.dask.DaskDMatrix(client, X, y)
early_stopping_rounds = 5
booster = xgb.dask.train(
client, {'objective': 'binary:logistic',
'eval_metric': 'error',
'tree_method': 'hist'}, m,
evals=[(m, 'Train')],
feval=tm.eval_error_metric,
num_boost_round=1000,
early_stopping_rounds=early_stopping_rounds)['booster']
assert hasattr(booster, 'best_score')
dump = booster.get_dump(dump_format='json')
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
def test_n_workers(self):
with LocalCluster(n_workers=2) as cluster:
with Client(cluster) as client:
@ -872,6 +836,67 @@ class TestDaskCallbacks:
merged = xgb.dask._get_workers_from_data(train, evals=[(valid, 'Valid')])
assert len(merged) == 2
@pytest.mark.skipif(**tm.no_dask())
def test_feature_weights(self, client):
kRows = 1024
kCols = 64
X = da.random.random((kRows, kCols), chunks=(32, -1))
y = da.random.random(kRows, chunks=32)
fw = np.ones(shape=(kCols,))
for i in range(kCols):
fw[i] *= float(i)
fw = da.from_array(fw)
poly_increasing = run_feature_weights(X, y, fw, model=xgb.dask.DaskXGBRegressor)
fw = np.ones(shape=(kCols,))
for i in range(kCols):
fw[i] *= float(kCols - i)
fw = da.from_array(fw)
poly_decreasing = run_feature_weights(X, y, fw, model=xgb.dask.DaskXGBRegressor)
# Approxmated test, this is dependent on the implementation of random
# number generator in std library.
assert poly_increasing[0] > 0.08
assert poly_decreasing[0] < -0.08
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_sklearn())
def test_custom_objective(self, client):
from sklearn.datasets import load_boston
X, y = load_boston(return_X_y=True)
X, y = da.from_array(X), da.from_array(y)
rounds = 20
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'log')
def sqr(labels, predts):
with open(path, 'a') as fd:
print('Running sqr', file=fd)
grad = predts - labels
hess = np.ones(shape=labels.shape[0])
return grad, hess
reg = xgb.dask.DaskXGBRegressor(n_estimators=rounds, objective=sqr,
tree_method='hist')
reg.fit(X, y, eval_set=[(X, y)])
# Check the obj is ran for rounds.
with open(path, 'r') as fd:
out = fd.readlines()
assert len(out) == rounds
results_custom = reg.evals_result()
reg = xgb.dask.DaskXGBRegressor(n_estimators=rounds, tree_method='hist')
reg.fit(X, y, eval_set=[(X, y)])
results_native = reg.evals_result()
np.testing.assert_allclose(results_custom['validation_0']['rmse'],
results_native['validation_0']['rmse'])
tm.non_increasing(results_native['validation_0']['rmse'])
def test_data_initialization(self):
'''Assert each worker has the correct amount of data, and DMatrix initialization doesn't
@ -912,3 +937,97 @@ class TestDaskCallbacks:
assert len(data) == cnt
# Subtract the on disk resource from each worker
assert cnt - n_workers == n_partitions
class TestDaskCallbacks:
@pytest.mark.skipif(**tm.no_sklearn())
def test_early_stopping(self, client):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
X, y = da.from_array(X), da.from_array(y)
m = xgb.dask.DaskDMatrix(client, X, y)
valid = xgb.dask.DaskDMatrix(client, X, y)
early_stopping_rounds = 5
booster = xgb.dask.train(client, {'objective': 'binary:logistic',
'eval_metric': 'error',
'tree_method': 'hist'}, m,
evals=[(valid, 'Valid')],
num_boost_round=1000,
early_stopping_rounds=early_stopping_rounds)['booster']
assert hasattr(booster, 'best_score')
dump = booster.get_dump(dump_format='json')
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
valid_X, valid_y = load_breast_cancer(return_X_y=True)
valid_X, valid_y = da.from_array(valid_X), da.from_array(valid_y)
cls = xgb.dask.DaskXGBClassifier(objective='binary:logistic', tree_method='hist',
n_estimators=1000)
cls.client = client
cls.fit(X, y, early_stopping_rounds=early_stopping_rounds,
eval_set=[(valid_X, valid_y)])
booster = cls.get_booster()
dump = booster.get_dump(dump_format='json')
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
# Specify the metric
cls = xgb.dask.DaskXGBClassifier(objective='binary:logistic', tree_method='hist',
n_estimators=1000)
cls.client = client
cls.fit(X, y, early_stopping_rounds=early_stopping_rounds,
eval_set=[(valid_X, valid_y)], eval_metric='error')
assert tm.non_increasing(cls.evals_result()['validation_0']['error'])
booster = cls.get_booster()
dump = booster.get_dump(dump_format='json')
assert len(cls.evals_result()['validation_0']['error']) < 20
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
@pytest.mark.skipif(**tm.no_sklearn())
def test_early_stopping_custom_eval(self, client):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
X, y = da.from_array(X), da.from_array(y)
m = xgb.dask.DaskDMatrix(client, X, y)
valid = xgb.dask.DaskDMatrix(client, X, y)
early_stopping_rounds = 5
booster = xgb.dask.train(
client, {'objective': 'binary:logistic',
'eval_metric': 'error',
'tree_method': 'hist'}, m,
evals=[(m, 'Train'), (valid, 'Valid')],
feval=tm.eval_error_metric,
num_boost_round=1000,
early_stopping_rounds=early_stopping_rounds)['booster']
assert hasattr(booster, 'best_score')
dump = booster.get_dump(dump_format='json')
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
valid_X, valid_y = load_breast_cancer(return_X_y=True)
valid_X, valid_y = da.from_array(valid_X), da.from_array(valid_y)
cls = xgb.dask.DaskXGBClassifier(objective='binary:logistic', tree_method='hist',
n_estimators=1000)
cls.client = client
cls.fit(X, y, early_stopping_rounds=early_stopping_rounds,
eval_set=[(valid_X, valid_y)], eval_metric=tm.eval_error_metric)
booster = cls.get_booster()
dump = booster.get_dump(dump_format='json')
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
@pytest.mark.skipif(**tm.no_sklearn())
def test_callback(self, client):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
X, y = da.from_array(X), da.from_array(y)
cls = xgb.dask.DaskXGBClassifier(objective='binary:logistic', tree_method='hist',
n_estimators=10)
cls.client = client
with tempfile.TemporaryDirectory() as tmpdir:
cls.fit(X, y, callbacks=[xgb.callback.TrainingCheckPoint(directory=tmpdir,
iterations=1,
name='model')])
for i in range(1, 10):
assert os.path.exists(
os.path.join(tmpdir, 'model_' + str(i) + '.json'))

View File

@ -984,21 +984,10 @@ def test_pandas_input():
np.array([0, 1]))
def run_feature_weights(increasing):
def run_feature_weights(X, y, fw, model=xgb.XGBRegressor):
with TemporaryDirectory() as tmpdir:
kRows = 512
kCols = 64
colsample_bynode = 0.5
reg = xgb.XGBRegressor(tree_method='hist',
colsample_bynode=colsample_bynode)
X = rng.randn(kRows, kCols)
y = rng.randn(kRows)
fw = np.ones(shape=(kCols,))
for i in range(kCols):
if increasing:
fw[i] *= float(i)
else:
fw[i] *= float(kCols - i)
reg = model(tree_method='hist', colsample_bynode=colsample_bynode)
reg.fit(X, y, feature_weights=fw)
model_path = os.path.join(tmpdir, 'model.json')
@ -1034,8 +1023,21 @@ def run_feature_weights(increasing):
def test_feature_weights():
poly_increasing = run_feature_weights(True)
poly_decreasing = run_feature_weights(False)
kRows = 512
kCols = 64
X = rng.randn(kRows, kCols)
y = rng.randn(kRows)
fw = np.ones(shape=(kCols,))
for i in range(kCols):
fw[i] *= float(i)
poly_increasing = run_feature_weights(X, y, fw, xgb.XGBRegressor)
fw = np.ones(shape=(kCols,))
for i in range(kCols):
fw[i] *= float(kCols - i)
poly_decreasing = run_feature_weights(X, y, fw, xgb.XGBRegressor)
# Approxmated test, this is dependent on the implementation of random
# number generator in std library.
assert poly_increasing[0] > 0.08