Update base margin dask (#6155)

* Add `base_margin`.
* Add `output_margin` to regressor.

Co-authored-by: fis <jm.yuan@outlook.com>
This commit is contained in:
Kyle Nicholson
2020-09-26 09:30:52 -04:00
committed by GitHub
parent 03b8fdec74
commit e6a238c020
2 changed files with 160 additions and 48 deletions

View File

@@ -133,6 +133,68 @@ def test_dask_predict_shape_infer():
assert preds.shape[1] == preds.compute().shape[1]
@pytest.mark.parametrize("tree_method", ["hist", "approx"])
def test_boost_from_prediction(tree_method):
    """Check that continuing from a saved margin matches one-shot training.

    A 4-round classifier is trained, its raw margin is fed as ``base_margin``
    into a second 4-round classifier, and the result is compared against an
    8-round classifier trained from scratch.  A further 8-round classifier
    with identical settings is trained to measure how much two identical
    runs differ from each other; that difference sets the tolerance used
    for the comparison.
    """
    from sklearn.datasets import load_breast_cancer

    def make_classifier(rounds):
        # Every model in this test shares the same hyper-parameters and
        # differs only in the number of boosting rounds.
        return xgb.dask.DaskXGBClassifier(
            learning_rate=0.3,
            random_state=123,
            n_estimators=rounds,
            tree_method=tree_method,
        )

    features, labels = load_breast_cancer(return_X_y=True)
    X_ = dd.from_array(features, chunksize=100)
    y_ = dd.from_array(labels, chunksize=100)

    with LocalCluster(n_workers=4) as cluster, Client(cluster) as client:
        # Stage 1: first 4 rounds, then export the untransformed margin.
        first_half = make_classifier(4)
        first_half.fit(X=X_, y=y_)
        margin = first_half.predict_proba(X_, output_margin=True)

        # Stage 2: 4 more rounds starting from the stage-1 margin.
        second_half = make_classifier(4)
        second_half.fit(X=X_, y=y_, base_margin=margin)
        resumed_labels = second_half.predict(X_, base_margin=margin)
        resumed_proba = second_half.predict_proba(X_, base_margin=margin)

        # Reference: all 8 rounds in a single fit.
        reference = make_classifier(8)
        reference.fit(X=X_, y=y_)
        reference_labels = reference.predict(X_)
        reference_proba = reference.predict_proba(X_)

        # Same configuration as the reference, trained again.
        repeat = make_classifier(8)
        repeat.fit(X=X_, y=y_)
        repeat_proba = repeat.predict_proba(X_)

        # Largest element-wise gap between two runs of the *same* model.
        # A non-zero value means training itself is not reproducible for
        # this tree method, so an exact comparison would be meaningless.
        run_to_run_gap = np.abs(repeat_proba - reference_proba)
        expected_variance = np.max(run_to_run_gap).compute()

        if expected_variance > 0:
            resumed_gap = np.abs(resumed_proba - reference_proba)
            margin_variance = np.max(resumed_gap).compute()
            # The resumed model may deviate from the reference by at most
            # the run-to-run gap plus an absolute slack of 0.1.
            tolerance = expected_variance + .1
            assert np.all(margin_variance <= tolerance)
        else:
            # Training is reproducible here: demand matching class labels
            # and (almost) matching probabilities.
            np.testing.assert_equal(
                resumed_labels.compute(), reference_labels.compute()
            )
            np.testing.assert_almost_equal(
                resumed_proba.compute(), reference_proba.compute()
            )
def test_dask_missing_value_reg():
with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client: