Fix multiclass auc with empty dataset. (#6947)

This commit is contained in:
Jiaming Yuan
2021-05-12 15:01:14 +08:00
committed by GitHub
parent 05ac415780
commit 44cc9c04ea
5 changed files with 83 additions and 48 deletions

View File

@@ -323,6 +323,7 @@ def run_dask_classifier(
y: xgb.dask._DaskCollection,
w: xgb.dask._DaskCollection,
model: str,
+tree_method: Optional[str],
client: "Client",
n_classes,
) -> None:
@@ -330,11 +331,11 @@ def run_dask_classifier(
if model == "boosting":
classifier = xgb.dask.DaskXGBClassifier(
-verbosity=1, n_estimators=2, eval_metric=metric
+verbosity=1, n_estimators=2, eval_metric=metric, tree_method=tree_method
)
else:
classifier = xgb.dask.DaskXGBRFClassifier(
-verbosity=1, n_estimators=2, eval_metric=metric
+verbosity=1, n_estimators=2, eval_metric=metric, tree_method=tree_method
)
assert classifier._estimator_type == "classifier"
@@ -403,12 +404,12 @@ def run_dask_classifier(
def test_dask_classifier(model: str, client: "Client") -> None:
X, y, w = generate_array(with_weights=True)
y = (y * 10).astype(np.int32)
-run_dask_classifier(X, y, w, model, client, 10)
+run_dask_classifier(X, y, w, model, None, client, 10)
y_bin = y.copy()
y_bin[y > 5] = 1.0
y_bin[y <= 5] = 0.0
-run_dask_classifier(X, y_bin, w, model, client, 2)
+run_dask_classifier(X, y_bin, w, model, None, client, 2)
@pytest.mark.skipif(**tm.no_sklearn())
@@ -574,22 +575,26 @@ def run_empty_dmatrix_auc(client: "Client", tree_method: str, n_workers: int) ->
# multiclass
X_, y_ = make_classification(
n_samples=n_samples,
-n_classes=10,
+n_classes=n_workers,
n_informative=n_features,
n_redundant=0,
n_repeated=0
)
+for i in range(y_.shape[0]):
+y_[i] = i % n_workers
X = dd.from_array(X_, chunksize=10)
y = dd.from_array(y_, chunksize=10)
n_samples = n_workers - 1
valid_X_, valid_y_ = make_classification(
n_samples=n_samples,
-n_classes=10,
+n_classes=n_workers,
n_informative=n_features,
n_redundant=0,
n_repeated=0
)
+for i in range(valid_y_.shape[0]):
+valid_y_[i] = i % n_workers
valid_X = dd.from_array(valid_X_, chunksize=n_samples)
valid_y = dd.from_array(valid_y_, chunksize=n_samples)
@@ -600,9 +605,9 @@ def run_empty_dmatrix_auc(client: "Client", tree_method: str, n_workers: int) ->
def test_empty_dmatrix_auc() -> None:
-with LocalCluster(n_workers=2) as cluster:
+with LocalCluster(n_workers=8) as cluster:
with Client(cluster) as client:
-run_empty_dmatrix_auc(client, "hist", 2)
+run_empty_dmatrix_auc(client, "hist", 8)
def run_auc(client: "Client", tree_method: str) -> None: