* Use `output_margin` for `softmax`.
* Add test for dask binary cls.

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
parent 357a78b3de
commit c6a0bdbb5a
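An illustrative sketch (not part of the diff; data and parameters are made up) of the behaviour this change relies on: with the `multi:softmax` objective the regular prediction is already argmax'ed into class labels, so `predict_proba` has to request the raw margin to recover one score per class.

import numpy as np
import xgboost as xgb

X = np.random.rand(32, 4)
y = np.random.randint(0, 3, size=32)

clf = xgb.XGBClassifier(objective="multi:softmax", n_estimators=2)
clf.fit(X, y)
booster = clf.get_booster()

dmat = xgb.DMatrix(X)
labels = booster.predict(dmat)                       # shape (32,): argmax'ed class labels
margins = booster.predict(dmat, output_margin=True)  # shape (32, 3): one raw score per class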
@@ -21,6 +21,7 @@ from contextlib import contextmanager
from collections import defaultdict
from collections.abc import Sequence
from threading import Thread
from functools import partial, update_wrapper
from typing import TYPE_CHECKING, List, Tuple, Callable, Optional, Any, Union, Dict, Set
from typing import Awaitable, Generator, TypeVar
@@ -967,7 +968,7 @@ def _can_output_df(is_df: bool, output_shape: Tuple) -> bool:
    return is_df and len(output_shape) <= 2


-async def _direct_predict_impl(
+async def _direct_predict_impl(  # pylint: disable=too-many-branches
    mapped_predict: Callable,
    booster: "distributed.Future",
    data: _DaskCollection,
@@ -1022,6 +1023,14 @@ async def _direct_predict_impl(
            new_axis = list(range(len(output_shape) - 2))
        else:
            new_axis = [i + 2 for i in range(len(output_shape) - 2)]
+       if len(output_shape) == 2:
+           # Somehow dask fails to infer the output shape change for 2-dim prediction, and
+           # `chunks = (None, output_shape[1])` doesn't work because None is not
+           # supported in map_blocks.
+           chunks = list(data.chunks)
+           chunks[1] = (output_shape[1], )
+       else:
+           chunks = None
        predictions = da.map_blocks(
            mapped_predict,
            booster,
@@ -1029,6 +1038,8 @@ async def _direct_predict_impl(
            False,
            columns,
            base_margin_array,
+           chunks=chunks,
            drop_axis=drop_axis,
            new_axis=new_axis,
            dtype=numpy.float32,
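A self-contained sketch (shapes made up, illustrative only) of the `map_blocks` behaviour the new `chunks` handling works around: when a per-block function changes the number of columns, dask cannot infer the new chunk widths on its own, so the second dimension has to be spelled out explicitly.

import numpy as np
import dask.array as da

def widen(block: np.ndarray) -> np.ndarray:
    # Per-block "prediction" that changes the column count from 4 to 3.
    return np.ones((block.shape[0], 3), dtype=np.float32)

data = da.random.random((100, 4), chunks=(25, 4))
chunks = list(data.chunks)
chunks[1] = (3,)  # the new width must be given explicitly; None is not accepted here
out = da.map_blocks(widen, data, chunks=chunks, dtype=np.float32)
assert out.compute().shape == (100, 3)

`drop_axis`/`new_axis` cover changes in the number of dimensions, while `chunks` declares the resulting block sizes within an existing axis.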
@@ -1776,20 +1787,20 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
        self,
        X: _DaskCollection,
        validate_features: bool,
        output_margin: bool,
        base_margin: Optional[_DaskCollection],
        iteration_range: Optional[Tuple[int, int]],
    ) -> _DaskCollection:
        if iteration_range is None:
            iteration_range = (0, 0)
        predts = await super()._predict_async(
            data=X,
-           output_margin=output_margin,
+           output_margin=self.objective == "multi:softmax",
            validate_features=validate_features,
            base_margin=base_margin,
            iteration_range=iteration_range,
        )
-       return _cls_predict_proba(self.objective, predts, da.vstack)
+       vstack = update_wrapper(
+           partial(da.vstack, allow_unknown_chunksizes=True), da.vstack
+       )
+       return _cls_predict_proba(getattr(self, "n_classes_", None), predts, vstack)
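A standalone sketch (plain dask arrays with made-up values, not part of the diff) of the wrapper built above: `da.vstack` raises on arrays whose chunk sizes are unknown unless `allow_unknown_chunksizes=True` is passed, and `functools.update_wrapper` keeps `da.vstack`'s name and docstring on the partial object.

from functools import partial, update_wrapper

import dask.array as da
import numpy as np

vstack = update_wrapper(
    partial(da.vstack, allow_unknown_chunksizes=True), da.vstack
)

classone = da.from_array(np.array([0.2, 0.9, 0.5]), chunks=2)
classzero = 1.0 - classone
# Called exactly like da.vstack; the flag only matters when chunk sizes are unknown.
proba = vstack((classzero, classone)).transpose()
assert proba.compute().shape == (3, 2)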

    # pylint: disable=missing-function-docstring
    def predict_proba(
@@ -1797,7 +1808,6 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
        X: _DaskCollection,
        ntree_limit: Optional[int] = None,
        validate_features: bool = True,
-       output_margin: bool = False,
        base_margin: Optional[_DaskCollection] = None,
        iteration_range: Optional[Tuple[int, int]] = None,
    ) -> Any:
@@ -1808,7 +1818,6 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
            self._predict_proba_async,
            X=X,
            validate_features=validate_features,
-           output_margin=output_margin,
            base_margin=base_margin,
            iteration_range=iteration_range,
        )
@@ -4,7 +4,7 @@
import copy
import warnings
import json
-from typing import Union, Optional, List, Dict, Callable, Tuple, Any
+from typing import Union, Optional, List, Dict, Callable, Tuple, Any, TypeVar
import numpy as np
from .core import Booster, DMatrix, XGBoostError
from .core import _deprecate_positional_args, _convert_ntree_limit
@@ -561,6 +561,8 @@ class XGBModel(XGBModelBase):
        self._Booster.load_model(fname)
        meta = self._Booster.attr('scikit_learn')
        if meta is None:
+           # FIXME(jiaming): This doesn't have to be a problem, as most of the needed
+           # information like num_class and objective is in the Learner class.
            warnings.warn(
                'Loading a native XGBoost model with Scikit-Learn interface.')
            return
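An illustrative sketch of the situation this warning covers (file name and parameters are made up): a model saved through the native Booster API carries no 'scikit_learn' attribute, so loading it via the sklearn wrapper can only recover what the Learner itself stores, such as the objective and the number of classes.

import numpy as np
import xgboost as xgb

X = np.random.rand(20, 3)
y = np.random.randint(0, 2, size=20)

# Train and save through the native API: no scikit-learn metadata is attached.
booster = xgb.train({"objective": "binary:logistic"}, xgb.DMatrix(X, label=y), num_boost_round=2)
booster.save_model("native_model.json")

clf = xgb.XGBClassifier()
clf.load_model("native_model.json")  # emits the warning above, then returns early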
@@ -571,6 +573,8 @@ class XGBModel(XGBModelBase):
                self._le = XGBoostLabelEncoder()
                self._le.from_json(v)
                continue
+           # FIXME(jiaming): This can be removed once the label encoder is gone, since we
+           # can generate it from `np.arange(self.n_classes_)`
            if k == 'classes_':
                self.classes_ = np.array(v)
                continue
@@ -1024,17 +1028,14 @@ class XGBModel(XGBModelBase):
        return np.array(json.loads(b.get_dump(dump_format='json')[0])['bias'])


-def _cls_predict_proba(
-    objective: Union[str, Callable], prediction: Any, vstack: Callable
-) -> Any:
-    if objective == 'multi:softmax':
-        raise ValueError('multi:softmax objective does not support predict_proba,'
-                         ' use `multi:softprob` or `binary:logistic` instead.')
-    if objective == 'multi:softprob' or callable(objective):
-        # Return prediction directly if objective is defined by user since we don't
-        # know how to perform the transformation
+PredtT = TypeVar("PredtT")
+
+
+def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) -> PredtT:
+    assert len(prediction.shape) <= 2
+    if len(prediction.shape) == 2 and prediction.shape[1] == n_classes:
+        return prediction
-    # Lastly the binary logistic function
+    # binary logistic function
    classone_probs = prediction
    classzero_probs = 1.0 - classone_probs
    return vstack((classzero_probs, classone_probs)).transpose()
@@ -1218,8 +1219,10 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
            return class_probs

        if len(class_probs.shape) > 1:
+           # turns softprob into softmax
            column_indexes = np.argmax(class_probs, axis=1)
        else:
+           # turns soft logit into class label
            column_indexes = np.repeat(0, class_probs.shape[0])
            column_indexes[class_probs > 0.5] = 1
@@ -1262,15 +1265,23 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
            a numpy array of shape (n_samples, n_classes) with the
            probability of each data example being of a given class.
        """
+       # custom obj: Do nothing as we don't know what to do.
+       # softprob: Do nothing, output is proba.
+       # softmax: Use output margin to remove the argmax in PredTransform.
+       # binary:logistic: Expand the prob vector into a 2-class matrix after predict.
+       # binary:logitraw: Unsupported by predict_proba()
        class_probs = super().predict(
            X=X,
-           output_margin=False,
+           output_margin=self.objective == "multi:softmax",
            ntree_limit=ntree_limit,
            validate_features=validate_features,
            base_margin=base_margin,
            iteration_range=iteration_range
        )
-       return _cls_predict_proba(self.objective, class_probs, np.vstack)
+       # If the model is loaded from a raw booster there's no `n_classes_`
+       return _cls_predict_proba(
+           getattr(self, "n_classes_", None), class_probs, np.vstack
+       )
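A hedged usage sketch (synthetic data, not taken from the test suite) of what the rewritten path yields with this change in place: `multi:softmax` no longer raises from predict_proba, and binary:logistic output is expanded into two columns. Note that, as the code reads, the softmax branch returns the raw margin matrix rather than normalised probabilities.

import numpy as np
import xgboost as xgb

X = np.random.rand(64, 5)
y_multi = np.random.randint(0, 3, size=64)
y_bin = np.random.randint(0, 2, size=64)

softmax_clf = xgb.XGBClassifier(objective="multi:softmax", n_estimators=2).fit(X, y_multi)
assert softmax_clf.predict_proba(X).shape == (64, 3)   # per-class scores, no ValueError

bin_clf = xgb.XGBClassifier(objective="binary:logistic", n_estimators=2).fit(X, y_bin)
assert bin_clf.predict_proba(X).shape == (64, 2)       # expanded from P(y=1)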

    def evals_result(self):
        """Return the evaluation results.
@@ -173,6 +173,7 @@ def run_gpu_hist(
    assert tm.non_increasing(history["train"][dataset.metric])


@pytest.mark.skipif(**tm.no_cudf())
def test_boost_from_prediction(local_cuda_cluster: LocalCUDACluster) -> None:
    from sklearn.datasets import load_breast_cancer
    with Client(local_cuda_cluster) as client:
@@ -201,6 +202,7 @@ class TestDistributedGPU:
    @settings(deadline=duration(seconds=120), suppress_health_check=suppress)
    @pytest.mark.skipif(**tm.no_dask())
    @pytest.mark.skipif(**tm.no_dask_cuda())
    @pytest.mark.skipif(**tm.no_cupy())
    @pytest.mark.parametrize(
        "local_cuda_cluster", [{"n_workers": 2}], indirect=["local_cuda_cluster"]
    )
@@ -275,7 +277,7 @@ class TestDistributedGPU:
            X = dask_cudf.from_dask_dataframe(dd.from_dask_array(X_))
            y = dask_cudf.from_dask_dataframe(dd.from_dask_array(y_))
            w = dask_cudf.from_dask_dataframe(dd.from_dask_array(w_))
-           run_dask_classifier(X, y, w, model, client)
+           run_dask_classifier(X, y, w, model, client, 10)

    @pytest.mark.skipif(**tm.no_dask())
    @pytest.mark.skipif(**tm.no_dask_cuda())
@@ -453,6 +455,7 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> dxgb.TrainReturnT:

@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.mgpu
def test_with_asyncio(local_cuda_cluster: LocalCUDACluster) -> None:
    with Client(local_cuda_cluster) as client:
@@ -318,14 +318,17 @@ def run_dask_classifier(
    w: xgb.dask._DaskCollection,
    model: str,
    client: "Client",
+   n_classes,
) -> None:
+   metric = "merror" if n_classes > 2 else "logloss"

    if model == "boosting":
        classifier = xgb.dask.DaskXGBClassifier(
-           verbosity=1, n_estimators=2, eval_metric="merror"
+           verbosity=1, n_estimators=2, eval_metric=metric
        )
    else:
        classifier = xgb.dask.DaskXGBRFClassifier(
-           verbosity=1, n_estimators=2, eval_metric="merror"
+           verbosity=1, n_estimators=2, eval_metric=metric
        )

    assert classifier._estimator_type == "classifier"
@@ -343,7 +346,7 @@ def run_dask_classifier(
    assert isinstance(history, dict)

    assert list(history.keys())[0] == "validation_0"
-   assert list(history["validation_0"].keys())[0] == "merror"
+   assert list(history["validation_0"].keys())[0] == metric
    assert len(list(history["validation_0"])) == 1
    forest = int(
        json.loads(classifier.get_booster().save_config())["learner"][
@@ -351,34 +354,35 @@ def run_dask_classifier(
        ]["gbtree_train_param"]["num_parallel_tree"]
    )
    if model == "boosting":
-       assert len(history["validation_0"]["merror"]) == 2
+       assert len(history["validation_0"][metric]) == 2
        assert forest == 1
    else:
-       assert len(history["validation_0"]["merror"]) == 1
+       assert len(history["validation_0"][metric]) == 1
        assert forest == 2

    # Test .predict_proba()
    probas = classifier.predict_proba(X).compute()
-   assert classifier.n_classes_ == 10
+   assert classifier.n_classes_ == n_classes
    assert probas.ndim == 2
    assert probas.shape[0] == kRows
-   assert probas.shape[1] == 10
+   assert probas.shape[1] == n_classes

-   cls_booster = classifier.get_booster()
-   single_node_proba = cls_booster.inplace_predict(X.compute())
+   if n_classes > 2:
+       cls_booster = classifier.get_booster()
+       single_node_proba = cls_booster.inplace_predict(X.compute())

-   # test shared by CPU and GPU
-   if isinstance(single_node_proba, np.ndarray):
-       np.testing.assert_allclose(single_node_proba, probas)
-   else:
-       import cupy
-       cupy.testing.assert_allclose(single_node_proba, probas)
+       # test shared by CPU and GPU
+       if isinstance(single_node_proba, np.ndarray):
+           np.testing.assert_allclose(single_node_proba, probas)
+       else:
+           import cupy
+           cupy.testing.assert_allclose(single_node_proba, probas)

    # Test with dataframe, not shared with GPU as cupy doesn't work well with da.unique.
-   if isinstance(X, da.Array):
+   if isinstance(X, da.Array) and n_classes > 2:
        X_d: dd.DataFrame = X.to_dask_dataframe()

-       assert classifier.n_classes_ == 10
+       assert classifier.n_classes_ == n_classes
        prediction_df = classifier.predict(X_d).compute()

        assert prediction_df.ndim == 1
@@ -393,7 +397,12 @@ def run_dask_classifier(
def test_dask_classifier(model: str, client: "Client") -> None:
    X, y, w = generate_array(with_weights=True)
    y = (y * 10).astype(np.int32)
-   run_dask_classifier(X, y, w, model, client)
+   run_dask_classifier(X, y, w, model, client, 10)

+   y_bin = y.copy()
+   y_bin[y > 5] = 1.0
+   y_bin[y <= 5] = 0.0
+   run_dask_classifier(X, y_bin, w, model, client, 2)
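A rough standalone sketch (cluster setup and sizes are made up, not taken from the test file) of what the new binary branch exercises: a two-class DaskXGBClassifier whose predict_proba comes back as an (n_samples, 2) matrix built from the logistic output.

import numpy as np
import dask.array as da
from distributed import Client, LocalCluster
import xgboost as xgb

if __name__ == "__main__":
    with Client(LocalCluster(n_workers=2)) as client:
        X = da.random.random((1000, 10), chunks=(100, 10))
        y = (da.random.random(1000, chunks=100) > 0.5).astype(np.int32)

        clf = xgb.dask.DaskXGBClassifier(n_estimators=2)
        clf.fit(X, y)

        proba = clf.predict_proba(X).compute()
        assert clf.n_classes_ == 2
        assert proba.shape == (1000, 2)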

@pytest.mark.skipif(**tm.no_sklearn())