Update base margin dask (#6155)

* Add `base-margin`
* Add `output_margin` to regressor.

Co-authored-by: fis <jm.yuan@outlook.com>
This commit is contained in:
Kyle Nicholson 2020-09-26 09:30:52 -04:00 committed by GitHub
parent 03b8fdec74
commit e6a238c020
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 160 additions and 48 deletions

View File

@ -337,14 +337,22 @@ class DaskDMatrix:
'is_quantile': self.is_quantile} 'is_quantile': self.is_quantile}
def _get_worker_x_ordered(worker_map, partition_order, worker): def _get_worker_parts_ordered(has_base_margin, worker_map, partition_order,
worker):
list_of_parts = worker_map[worker.address] list_of_parts = worker_map[worker.address]
client = get_client() client = get_client()
list_of_parts_value = client.gather(list_of_parts) list_of_parts_value = client.gather(list_of_parts)
result = [] result = []
for i, part in enumerate(list_of_parts): for i, part in enumerate(list_of_parts):
result.append((list_of_parts_value[i][0], data = list_of_parts_value[i][0]
partition_order[part.key])) if has_base_margin:
base_margin = list_of_parts_value[i][1]
else:
base_margin = None
result.append((data, base_margin, partition_order[part.key]))
return result return result
@ -740,9 +748,7 @@ async def _direct_predict_impl(client, data, predict_fn):
# pylint: disable=too-many-statements # pylint: disable=too-many-statements
async def _predict_async(client: Client, model, data, missing=numpy.nan, async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwargs):
**kwargs):
if isinstance(model, Booster): if isinstance(model, Booster):
booster = model booster = model
elif isinstance(model, dict): elif isinstance(model, dict):
@ -775,22 +781,30 @@ async def _predict_async(client: Client, model, data, missing=numpy.nan,
feature_names = data.feature_names feature_names = data.feature_names
feature_types = data.feature_types feature_types = data.feature_types
missing = data.missing missing = data.missing
has_margin = "base_margin" in data.meta_names
def dispatched_predict(worker_id): def dispatched_predict(worker_id):
'''Perform prediction on each worker.''' '''Perform prediction on each worker.'''
LOGGER.info('Predicting on %d', worker_id) LOGGER.info('Predicting on %d', worker_id)
worker = distributed_get_worker() worker = distributed_get_worker()
list_of_parts = _get_worker_x_ordered(worker_map, partition_order, list_of_parts = _get_worker_parts_ordered(
worker) has_margin, worker_map, partition_order, worker)
predictions = [] predictions = []
booster.set_param({'nthread': worker.nthreads}) booster.set_param({'nthread': worker.nthreads})
for part, order in list_of_parts: for data, base_margin, order in list_of_parts:
local_x = DMatrix(part, feature_names=feature_names, local_part = DMatrix(
feature_types=feature_types, data,
missing=missing, nthread=worker.nthreads) base_margin=base_margin,
predt = booster.predict(data=local_x, feature_names=feature_names,
validate_features=local_x.num_row() != 0, feature_types=feature_types,
**kwargs) missing=missing,
nthread=worker.nthreads
)
predt = booster.predict(
data=local_part,
validate_features=local_part.num_row() != 0,
**kwargs)
columns = 1 if len(predt.shape) == 1 else predt.shape[1] columns = 1 if len(predt.shape) == 1 else predt.shape[1]
ret = ((delayed(predt), columns), order) ret = ((delayed(predt), columns), order)
predictions.append(ret) predictions.append(ret)
@ -800,9 +814,13 @@ async def _predict_async(client: Client, model, data, missing=numpy.nan,
'''Get shape of data in each worker.''' '''Get shape of data in each worker.'''
LOGGER.info('Get shape on %d', worker_id) LOGGER.info('Get shape on %d', worker_id)
worker = distributed_get_worker() worker = distributed_get_worker()
list_of_parts = _get_worker_x_ordered(worker_map, list_of_parts = _get_worker_parts_ordered(
partition_order, worker) False,
shapes = [(part.shape, order) for part, order in list_of_parts] worker_map,
partition_order,
worker
)
shapes = [(part.shape, order) for part, _, order in list_of_parts]
return shapes return shapes
async def map_function(func): async def map_function(func):
@ -984,6 +1002,7 @@ class DaskScikitLearnBase(XGBModel):
# pylint: disable=arguments-differ # pylint: disable=arguments-differ
def fit(self, X, y, def fit(self, X, y,
sample_weights=None, sample_weights=None,
base_margin=None,
eval_set=None, eval_set=None,
sample_weight_eval_set=None, sample_weight_eval_set=None,
verbose=True): verbose=True):
@ -1044,12 +1063,14 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
X, X,
y, y,
sample_weights=None, sample_weights=None,
base_margin=None,
eval_set=None, eval_set=None,
sample_weight_eval_set=None, sample_weight_eval_set=None,
verbose=True): verbose=True):
dtrain = await DaskDMatrix(client=self.client, dtrain = await DaskDMatrix(
data=X, label=y, weight=sample_weights, client=self.client, data=X, label=y, weight=sample_weights,
missing=self.missing) base_margin=base_margin, missing=self.missing
)
params = self.get_xgb_params() params = self.get_xgb_params()
evals = await _evaluation_matrices(self.client, evals = await _evaluation_matrices(self.client,
eval_set, sample_weight_eval_set, eval_set, sample_weight_eval_set,
@ -1065,24 +1086,33 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
# pylint: disable=missing-docstring # pylint: disable=missing-docstring
def fit(self, X, y, def fit(self, X, y,
sample_weights=None, sample_weights=None,
base_margin=None,
eval_set=None, eval_set=None,
sample_weight_eval_set=None, sample_weight_eval_set=None,
verbose=True): verbose=True):
_assert_dask_support() _assert_dask_support()
return self.client.sync(self._fit_async, X, y, sample_weights, return self.client.sync(
eval_set, sample_weight_eval_set, self._fit_async, X, y, sample_weights, base_margin,
verbose) eval_set, sample_weight_eval_set, verbose
)
async def _predict_async(self, data): # pylint: disable=arguments-differ async def _predict_async(
test_dmatrix = await DaskDMatrix(client=self.client, data=data, self, data, output_margin=False, base_margin=None):
missing=self.missing) test_dmatrix = await DaskDMatrix(
client=self.client, data=data, base_margin=base_margin,
missing=self.missing
)
pred_probs = await predict(client=self.client, pred_probs = await predict(client=self.client,
model=self.get_booster(), data=test_dmatrix) model=self.get_booster(), data=test_dmatrix,
output_margin=output_margin)
return pred_probs return pred_probs
def predict(self, data): # pylint: disable=arguments-differ
def predict(self, data, output_margin=False, base_margin=None):
_assert_dask_support() _assert_dask_support()
return self.client.sync(self._predict_async, data) return self.client.sync(self._predict_async, data,
output_margin=output_margin,
base_margin=base_margin)
@xgboost_model_doc( @xgboost_model_doc(
@ -1092,11 +1122,13 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
async def _fit_async(self, X, y, async def _fit_async(self, X, y,
sample_weights=None, sample_weights=None,
base_margin=None,
eval_set=None, eval_set=None,
sample_weight_eval_set=None, sample_weight_eval_set=None,
verbose=True): verbose=True):
dtrain = await DaskDMatrix(client=self.client, dtrain = await DaskDMatrix(client=self.client,
data=X, label=y, weight=sample_weights, data=X, label=y, weight=sample_weights,
base_margin=base_margin,
missing=self.missing) missing=self.missing)
params = self.get_xgb_params() params = self.get_xgb_params()
@ -1126,33 +1158,46 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
def fit(self, X, y, def fit(self, X, y,
sample_weights=None, sample_weights=None,
base_margin=None,
eval_set=None, eval_set=None,
sample_weight_eval_set=None, sample_weight_eval_set=None,
verbose=True): verbose=True):
_assert_dask_support() _assert_dask_support()
return self.client.sync(self._fit_async, X, y, sample_weights, return self.client.sync(
eval_set, sample_weight_eval_set, verbose) self._fit_async, X, y, sample_weights, base_margin, eval_set,
sample_weight_eval_set, verbose
)
async def _predict_proba_async(self, data): async def _predict_proba_async(self, data, output_margin=False,
_assert_dask_support() base_margin=None):
test_dmatrix = await DaskDMatrix(
test_dmatrix = await DaskDMatrix(client=self.client, data=data, client=self.client, data=data, base_margin=base_margin,
missing=self.missing) missing=self.missing
)
pred_probs = await predict(client=self.client, pred_probs = await predict(client=self.client,
model=self.get_booster(), data=test_dmatrix) model=self.get_booster(),
data=test_dmatrix,
output_margin=output_margin)
return pred_probs return pred_probs
def predict_proba(self, data): # pylint: disable=arguments-differ,missing-docstring def predict_proba(self, data, output_margin=False, base_margin=None): # pylint: disable=arguments-differ,missing-docstring
_assert_dask_support() _assert_dask_support()
return self.client.sync(self._predict_proba_async, data) return self.client.sync(
self._predict_proba_async,
data,
output_margin=output_margin,
base_margin=base_margin
)
async def _predict_async(self, data): async def _predict_async(self, data, output_margin=False, base_margin=None):
_assert_dask_support() test_dmatrix = await DaskDMatrix(
client=self.client, data=data, base_margin=base_margin,
test_dmatrix = await DaskDMatrix(client=self.client, data=data, missing=self.missing
missing=self.missing) )
pred_probs = await predict(client=self.client, pred_probs = await predict(client=self.client,
model=self.get_booster(), data=test_dmatrix) model=self.get_booster(),
data=test_dmatrix,
output_margin=output_margin)
if self.n_classes_ == 2: if self.n_classes_ == 2:
preds = (pred_probs > 0.5).astype(int) preds = (pred_probs > 0.5).astype(int)
@ -1161,6 +1206,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
return preds return preds
def predict(self, data): # pylint: disable=arguments-differ def predict(self, data, output_margin=False, base_margin=None): # pylint: disable=arguments-differ
_assert_dask_support() _assert_dask_support()
return self.client.sync(self._predict_async, data) return self.client.sync(
self._predict_async,
data,
output_margin=output_margin,
base_margin=base_margin
)

View File

@ -133,6 +133,68 @@ def test_dask_predict_shape_infer():
assert preds.shape[1] == preds.compute().shape[1] assert preds.shape[1] == preds.compute().shape[1]
@pytest.mark.parametrize("tree_method", ["hist", "approx"])
def test_boost_from_prediction(tree_method):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
X_ = dd.from_array(X, chunksize=100)
y_ = dd.from_array(y, chunksize=100)
with LocalCluster(n_workers=4) as cluster:
with Client(cluster) as client:
model_0 = xgb.dask.DaskXGBClassifier(
learning_rate=0.3,
random_state=123,
n_estimators=4,
tree_method=tree_method,
)
model_0.fit(X=X_, y=y_)
margin = model_0.predict_proba(X_, output_margin=True)
model_1 = xgb.dask.DaskXGBClassifier(
learning_rate=0.3,
random_state=123,
n_estimators=4,
tree_method=tree_method,
)
model_1.fit(X=X_, y=y_, base_margin=margin)
predictions_1 = model_1.predict(X_, base_margin=margin)
proba_1 = model_1.predict_proba(X_, base_margin=margin)
cls_2 = xgb.dask.DaskXGBClassifier(
learning_rate=0.3,
random_state=123,
n_estimators=8,
tree_method=tree_method,
)
cls_2.fit(X=X_, y=y_)
predictions_2 = cls_2.predict(X_)
proba_2 = cls_2.predict_proba(X_)
cls_3 = xgb.dask.DaskXGBClassifier(
learning_rate=0.3,
random_state=123,
n_estimators=8,
tree_method=tree_method,
)
cls_3.fit(X=X_, y=y_)
proba_3 = cls_3.predict_proba(X_)
# compute variance of probability percentages between two of the
# same model, use this to check to make sure approx is functioning
# within normal parameters
expected_variance = np.max(np.abs(proba_3 - proba_2)).compute()
if expected_variance > 0:
margin_variance = np.max(np.abs(proba_1 - proba_2)).compute()
# Ensure the margin variance is less than the expected variance + 10%
assert np.all(margin_variance <= expected_variance + .1)
else:
np.testing.assert_equal(predictions_1.compute(), predictions_2.compute())
np.testing.assert_almost_equal(proba_1.compute(), proba_2.compute())
def test_dask_missing_value_reg(): def test_dask_missing_value_reg():
with LocalCluster(n_workers=kWorkers) as cluster: with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client: with Client(cluster) as client: