Fix multiclass auc with empty dataset. (#6947)
This commit is contained in:
@@ -323,6 +323,7 @@ def run_dask_classifier(
|
||||
y: xgb.dask._DaskCollection,
|
||||
w: xgb.dask._DaskCollection,
|
||||
model: str,
|
||||
tree_method: Optional[str],
|
||||
client: "Client",
|
||||
n_classes,
|
||||
) -> None:
|
||||
@@ -330,11 +331,11 @@ def run_dask_classifier(
|
||||
|
||||
if model == "boosting":
|
||||
classifier = xgb.dask.DaskXGBClassifier(
|
||||
verbosity=1, n_estimators=2, eval_metric=metric
|
||||
verbosity=1, n_estimators=2, eval_metric=metric, tree_method=tree_method
|
||||
)
|
||||
else:
|
||||
classifier = xgb.dask.DaskXGBRFClassifier(
|
||||
verbosity=1, n_estimators=2, eval_metric=metric
|
||||
verbosity=1, n_estimators=2, eval_metric=metric, tree_method=tree_method
|
||||
)
|
||||
|
||||
assert classifier._estimator_type == "classifier"
|
||||
@@ -403,12 +404,12 @@ def run_dask_classifier(
|
||||
def test_dask_classifier(model: str, client: "Client") -> None:
|
||||
X, y, w = generate_array(with_weights=True)
|
||||
y = (y * 10).astype(np.int32)
|
||||
run_dask_classifier(X, y, w, model, client, 10)
|
||||
run_dask_classifier(X, y, w, model, None, client, 10)
|
||||
|
||||
y_bin = y.copy()
|
||||
y_bin[y > 5] = 1.0
|
||||
y_bin[y <= 5] = 0.0
|
||||
run_dask_classifier(X, y_bin, w, model, client, 2)
|
||||
run_dask_classifier(X, y_bin, w, model, None, client, 2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
@@ -574,22 +575,26 @@ def run_empty_dmatrix_auc(client: "Client", tree_method: str, n_workers: int) ->
|
||||
# multiclass
|
||||
X_, y_ = make_classification(
|
||||
n_samples=n_samples,
|
||||
n_classes=10,
|
||||
n_classes=n_workers,
|
||||
n_informative=n_features,
|
||||
n_redundant=0,
|
||||
n_repeated=0
|
||||
)
|
||||
for i in range(y_.shape[0]):
|
||||
y_[i] = i % n_workers
|
||||
X = dd.from_array(X_, chunksize=10)
|
||||
y = dd.from_array(y_, chunksize=10)
|
||||
|
||||
n_samples = n_workers - 1
|
||||
valid_X_, valid_y_ = make_classification(
|
||||
n_samples=n_samples,
|
||||
n_classes=10,
|
||||
n_classes=n_workers,
|
||||
n_informative=n_features,
|
||||
n_redundant=0,
|
||||
n_repeated=0
|
||||
)
|
||||
for i in range(valid_y_.shape[0]):
|
||||
valid_y_[i] = i % n_workers
|
||||
valid_X = dd.from_array(valid_X_, chunksize=n_samples)
|
||||
valid_y = dd.from_array(valid_y_, chunksize=n_samples)
|
||||
|
||||
@@ -600,9 +605,9 @@ def run_empty_dmatrix_auc(client: "Client", tree_method: str, n_workers: int) ->
|
||||
|
||||
|
||||
def test_empty_dmatrix_auc() -> None:
|
||||
with LocalCluster(n_workers=2) as cluster:
|
||||
with LocalCluster(n_workers=8) as cluster:
|
||||
with Client(cluster) as client:
|
||||
run_empty_dmatrix_auc(client, "hist", 2)
|
||||
run_empty_dmatrix_auc(client, "hist", 8)
|
||||
|
||||
|
||||
def run_auc(client: "Client", tree_method: str) -> None:
|
||||
|
||||
Reference in New Issue
Block a user