[breaking] Bump Python requirement to 3.10. (#10434)
- Bump the Python requirement to 3.10.
- Fix type hints.
- Use loky to avoid a deadlock.
- Work around a cupy-numpy compatibility issue on Windows caused by the `safe` casting rule.
- Simplify the repartitioning logic to avoid dask errors.
parent 757aafc131
commit 827d0e8edb
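The cupy workaround mentioned above shows up throughout this diff as wrapping integer seeds in `numpy.uint64` before handing them to cupy. A minimal sketch of the fix, assuming a Windows build where seeding with a plain Python int trips NumPy's `safe` casting rule inside cupy:

```python
import numpy as np
import cupy as cp

# On affected Windows setups, cp.random.RandomState(1994) can fail while
# cupy casts the seed to uint64 under the "safe" casting rule. Converting
# the seed explicitly sidesteps that cast.
rng = cp.random.RandomState(np.uint64(1994))
X = rng.randn(100, 10)
```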
.github/workflows/main.yml (vendored, 6 lines changed)
@@ -74,7 +74,7 @@ jobs:
   fail-fast: false
   matrix:
     os: [ubuntu-latest]
-    python-version: ["3.8"]
+    python-version: ["3.10"]
 steps:
 - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # v4.1.6
   with:

@@ -116,7 +116,7 @@ jobs:
   fail-fast: false
   matrix:
     os: ["ubuntu-latest"]
-    python-version: ["3.8"]
+    python-version: ["3.10"]
 steps:
 - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # v4.1.6
   with:

@@ -182,7 +182,7 @@ jobs:
   submodules: 'true'
 - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1
   with:
-    python-version: "3.8"
+    python-version: "3.10"
     architecture: 'x64'
 - name: Install Python packages
   run: |
.github/workflows/python_tests.yml (vendored, 12 lines changed)
@@ -84,7 +84,7 @@ jobs:
 strategy:
   matrix:
     os: [macos-13, windows-latest]
-    python-version: ["3.8"]
+    python-version: ["3.10"]
 steps:
 - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # v4.1.6
   with:

@@ -174,7 +174,7 @@ jobs:
 strategy:
   matrix:
     config:
-      - {os: windows-latest, python-version: '3.8'}
+      - {os: windows-latest, python-version: '3.10'}

 steps:
 - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # v4.1.6

@@ -218,7 +218,7 @@ jobs:
 strategy:
   matrix:
     config:
-      - {os: ubuntu-latest, python-version: "3.8"}
+      - {os: ubuntu-latest, python-version: "3.10"}

 steps:
 - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # v4.1.6

@@ -271,7 +271,7 @@ jobs:
 strategy:
   matrix:
     config:
-      - {os: ubuntu-latest, python-version: "3.8"}
+      - {os: ubuntu-latest, python-version: "3.10"}

 steps:
 - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # v4.1.6

@@ -318,10 +318,10 @@ jobs:
   with:
     submodules: 'true'

-- name: Set up Python 3.8
+- name: Set up Python 3.10
   uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1
   with:
-    python-version: 3.8
+    python-version: "3.10"

 - name: Install ninja
   run: |
.github/workflows/python_wheels.yml (vendored, 2 lines changed)
@@ -36,7 +36,7 @@ jobs:
 with:
   miniforge-variant: Mambaforge
   miniforge-version: latest
-  python-version: 3.9
+  python-version: ["3.10"]
   use-mamba: true
 - name: Build wheels
   run: bash tests/ci_build/build_python_wheels.sh ${{ matrix.platform_id }} ${{ github.sha }}
.github/workflows/r_tests.yml (vendored, 2 lines changed)
@@ -86,7 +86,7 @@ jobs:

 - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1
   with:
-    python-version: "3.8"
+    python-version: "3.10"
     architecture: 'x64'

 - uses: r-lib/actions/setup-tinytex@v2
@@ -106,8 +106,8 @@ plt.figure(figsize=(12, 13))
 bst = xgb.train(
     params,
     dmat,
-    15,
-    [(dmat, "train")],
+    num_boost_round=15,
+    evals=[(dmat, "train")],
     evals_result=res,
     callbacks=[PlotIntermediateModel()],
 )
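The demo now passes `num_boost_round` and `evals` by keyword instead of by position, which reads better and is robust to signature changes. A small self-contained sketch of the same call with synthetic data:

```python
import numpy as np
import xgboost as xgb

X, y = np.random.randn(100, 4), np.random.randn(100)
dmat = xgb.DMatrix(X, label=y)

res: dict = {}
bst = xgb.train(
    {"objective": "reg:squarederror"},
    dmat,
    num_boost_round=15,       # previously the positional literal 15
    evals=[(dmat, "train")],  # previously a positional list
    evals_result=res,
)
```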
@@ -42,7 +42,7 @@ class IterForDMatrixDemo(xgboost.core.DataIter):
         """
         self.rows = ROWS_PER_BATCH
         self.cols = COLS
-        rng = cupy.random.RandomState(1994)
+        rng = cupy.random.RandomState(numpy.uint64(1994))
         self._data = [rng.randn(self.rows, self.cols)] * BATCHES
         self._labels = [rng.randn(self.rows)] * BATCHES
         self._weights = [rng.uniform(size=self.rows)] * BATCHES
@@ -8,7 +8,7 @@ This directory contains a demo of Horizontal Federated Learning using
 To run the demo, first build XGBoost with the federated learning plugin enabled (see the
 [README](../../../plugin/federated/README.md)).

-Install NVFlare (note that currently NVFlare only supports Python 3.8):
+Install NVFlare:
 ```shell
 pip install nvflare
 ```

@@ -8,7 +8,7 @@ This directory contains a demo of Vertical Federated Learning using
 To run the demo, first build XGBoost with the federated learning plugin enabled (see the
 [README](../../../plugin/federated/README.md)).

-Install NVFlare (note that currently NVFlare only supports Python 3.8):
+Install NVFlare:
 ```shell
 pip install nvflare
 ```
@@ -286,7 +286,7 @@ latex_documents = [
 ]

 intersphinx_mapping = {
-    "python": ("https://docs.python.org/3.8", None),
+    "python": ("https://docs.python.org/3.10", None),
     "numpy": ("https://numpy.org/doc/stable/", None),
     "scipy": ("https://docs.scipy.org/doc/scipy/", None),
     "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
@@ -14,7 +14,7 @@ authors = [
     { name = "Jiaming Yuan", email = "jm.yuan@outlook.com" }
 ]
 version = "2.2.0-dev"
-requires-python = ">=3.8"
+requires-python = ">=3.10"
 license = { text = "Apache-2.0" }
 classifiers = [
     "License :: OSI Approved :: Apache Software License",

@@ -22,8 +22,6 @@ classifiers = [
     "Operating System :: OS Independent",
     "Programming Language :: Python",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.8",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
@@ -14,6 +14,7 @@ from collections.abc import Mapping
 from enum import IntEnum, unique
 from functools import wraps
 from inspect import Parameter, signature
+from types import EllipsisType
 from typing import (
     Any,
     Callable,
@@ -1826,7 +1827,7 @@ class Booster:
         state["handle"] = handle
         self.__dict__.update(state)

-    def __getitem__(self, val: Union[Integer, tuple, slice]) -> "Booster":
+    def __getitem__(self, val: Union[Integer, tuple, slice, EllipsisType]) -> "Booster":
         """Get a slice of the tree-based model.

         .. versionadded:: 1.3.0
@@ -1835,21 +1836,20 @@ class Booster:
         # convert to slice for all other types
         if isinstance(val, (np.integer, int)):
             val = slice(int(val), int(val + 1))
-        if isinstance(val, type(Ellipsis)):
+        if isinstance(val, EllipsisType):
             val = slice(0, 0)
         if isinstance(val, tuple):
             raise ValueError("Only supports slicing through 1 dimension.")
         # All supported types are now slice
-        # FIXME(jiamingy): Use `types.EllipsisType` once Python 3.10 is used.
         if not isinstance(val, slice):
-            msg = _expect((int, slice, np.integer, type(Ellipsis)), type(val))
+            msg = _expect((int, slice, np.integer, EllipsisType), type(val))
             raise TypeError(msg)

-        if isinstance(val.start, type(Ellipsis)) or val.start is None:
+        if isinstance(val.start, EllipsisType) or val.start is None:
             start = 0
         else:
             start = val.start
-        if isinstance(val.stop, type(Ellipsis)) or val.stop is None:
+        if isinstance(val.stop, EllipsisType) or val.stop is None:
             stop = 0
         else:
             stop = val.stop
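With 3.10 as the floor, `types.EllipsisType` replaces the old `type(Ellipsis)` workaround, exactly as the removed FIXME had planned. A standalone sketch of the idiom (not XGBoost's code, just the pattern):

```python
from types import EllipsisType
from typing import Union


def normalize(val: Union[int, slice, EllipsisType]) -> slice:
    """Coerce the supported index types into a slice."""
    if isinstance(val, int):
        return slice(val, val + 1)
    if isinstance(val, EllipsisType):  # `...` means "everything" here
        return slice(0, 0)
    return val


print(normalize(...))  # slice(0, 0)
```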
@@ -292,7 +292,7 @@ class DaskDMatrix:
     @_deprecate_positional_args
     def __init__(
         self,
-        client: "distributed.Client",
+        client: Optional["distributed.Client"],
         data: _DataT,
         label: Optional[_DaskCollection] = None,
         *,

@@ -663,7 +663,7 @@ class DaskQuantileDMatrix(DaskDMatrix):
     @_deprecate_positional_args
     def __init__(
         self,
-        client: "distributed.Client",
+        client: Optional["distributed.Client"],
         data: _DataT,
         label: Optional[_DaskCollection] = None,
         *,

@@ -674,7 +674,7 @@ class DaskQuantileDMatrix(DaskDMatrix):
         feature_names: Optional[FeatureNames] = None,
         feature_types: Optional[Union[Any, List[Any]]] = None,
         max_bin: Optional[int] = None,
-        ref: Optional[DMatrix] = None,
+        ref: Optional[DaskDMatrix] = None,
         group: Optional[_DaskCollection] = None,
         qid: Optional[_DaskCollection] = None,
         label_lower_bound: Optional[_DaskCollection] = None,

@@ -1832,8 +1832,8 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
         sample_weight: Optional[_DaskCollection] = None,
         base_margin: Optional[_DaskCollection] = None,
         eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
-        verbose: Union[int, bool] = True,
-        xgb_model: Optional[Union[Booster, XGBModel]] = None,
+        verbose: Optional[Union[int, bool]] = True,
+        xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
         sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
         base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
         feature_weights: Optional[_DaskCollection] = None,

@@ -1940,8 +1940,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         sample_weight: Optional[_DaskCollection] = None,
         base_margin: Optional[_DaskCollection] = None,
         eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
-        verbose: Union[int, bool] = True,
-        xgb_model: Optional[Union[Booster, XGBModel]] = None,
+        verbose: Optional[Union[int, bool]] = True,
+        xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
         sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
         base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
         feature_weights: Optional[_DaskCollection] = None,

@@ -2122,8 +2122,8 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
         eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
         eval_group: Optional[Sequence[_DaskCollection]] = None,
         eval_qid: Optional[Sequence[_DaskCollection]] = None,
-        verbose: Union[int, bool] = False,
-        xgb_model: Optional[Union[XGBModel, Booster]] = None,
+        verbose: Optional[Union[int, bool]] = False,
+        xgb_model: Optional[Union[XGBModel, str, Booster]] = None,
         sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
         base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
         feature_weights: Optional[_DaskCollection] = None,

@@ -2185,8 +2185,8 @@ class DaskXGBRFRegressor(DaskXGBRegressor):
         sample_weight: Optional[_DaskCollection] = None,
         base_margin: Optional[_DaskCollection] = None,
         eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
-        verbose: Union[int, bool] = True,
-        xgb_model: Optional[Union[Booster, XGBModel]] = None,
+        verbose: Optional[Union[int, bool]] = True,
+        xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
         sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
         base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
         feature_weights: Optional[_DaskCollection] = None,

@@ -2246,8 +2246,8 @@ class DaskXGBRFClassifier(DaskXGBClassifier):
         sample_weight: Optional[_DaskCollection] = None,
         base_margin: Optional[_DaskCollection] = None,
         eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
-        verbose: Union[int, bool] = True,
-        xgb_model: Optional[Union[Booster, XGBModel]] = None,
+        verbose: Optional[Union[int, bool]] = True,
+        xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
         sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
         base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
         feature_weights: Optional[_DaskCollection] = None,
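Widening `client` to `Optional["distributed.Client"]` matches how these constructors are used: callers may omit the client and let XGBoost resolve one from the Dask context. A hedged sketch of the calling pattern (synthetic data; the fallback-to-default-client behavior is my reading of the annotation, not spelled out in this diff):

```python
import dask.dataframe as dd
from dask.distributed import Client

import xgboost as xgb

client = Client()  # registers itself as the current default client
X = dd.from_dict({"f0": range(100), "f1": range(100)}, npartitions=2)
y = X["f0"] > 50

# `None` is now type-correct; the implementation can fall back to the
# active default client instead of requiring it to be threaded through.
dtrain = xgb.dask.DaskDMatrix(None, X, y)
```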
@@ -5,7 +5,17 @@ import ctypes
 import json
 import os
 import warnings
-from typing import Any, Callable, List, Optional, Sequence, Tuple, cast
+from typing import (
+    Any,
+    Callable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeGuard,
+    Union,
+    cast,
+)

 import numpy as np

@@ -212,7 +222,7 @@ def is_scipy_coo(data: DataType) -> bool:
     return is_array or is_matrix


-def _is_np_array_like(data: DataType) -> bool:
+def _is_np_array_like(data: DataType) -> TypeGuard[np.ndarray]:
     return hasattr(data, "__array_interface__")

@@ -241,7 +251,7 @@ def _maybe_np_slice(data: DataType, dtype: Optional[NumpyDType]) -> np.ndarray:


 def _from_numpy_array(
-    data: DataType,
+    data: np.ndarray,
     missing: FloatCompatible,
     nthread: int,
     feature_names: Optional[FeatureNames],

@@ -266,7 +276,7 @@
     return handle, feature_names, feature_types


-def _is_pandas_df(data: DataType) -> bool:
+def _is_pandas_df(data: DataType) -> TypeGuard[DataFrame]:
     try:
         import pandas as pd
     except ImportError:

@@ -1057,12 +1067,12 @@ def _from_dlpack(
     return _from_cupy_array(data, missing, nthread, feature_names, feature_types)


-def _is_uri(data: DataType) -> bool:
+def _is_uri(data: DataType) -> TypeGuard[Union[str, os.PathLike]]:
     return isinstance(data, (str, os.PathLike))


 def _from_uri(
-    data: DataType,
+    data: Union[str, os.PathLike],
     missing: Optional[FloatCompatible],
     feature_names: Optional[FeatureNames],
     feature_types: Optional[FeatureTypes],

@@ -1080,7 +1090,7 @@ def _from_uri(
     return handle, feature_names, feature_types


-def _is_list(data: DataType) -> bool:
+def _is_list(data: DataType) -> TypeGuard[list]:
     return isinstance(data, list)

@@ -1099,7 +1109,7 @@ def _from_list(
     )


-def _is_tuple(data: DataType) -> bool:
+def _is_tuple(data: DataType) -> TypeGuard[tuple]:
     return isinstance(data, tuple)

@@ -1116,7 +1126,7 @@ def _from_tuple(
     )


-def _is_iter(data: DataType) -> bool:
+def _is_iter(data: DataType) -> TypeGuard[DataIter]:
     return isinstance(data, DataIter)
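With 3.10 as the minimum, `typing.TypeGuard` is available without `typing_extensions`, so these `_is_*` predicates can narrow types for the checker instead of merely returning `bool`. A minimal standalone sketch of the idiom:

```python
from typing import TypeGuard, Union


def is_str_list(values: list) -> TypeGuard[list[str]]:
    # A True result tells the type checker every element is a str.
    return all(isinstance(v, str) for v in values)


def join(values: Union[list[str], list[int]]) -> str:
    if is_str_list(values):
        return ",".join(values)  # narrowed to list[str] here
    return ",".join(str(v) for v in values)
```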
@@ -6,14 +6,12 @@ change without notice.
 # pylint: disable=invalid-name,missing-function-docstring,import-error
 import gc
 import importlib.util
-import multiprocessing
 import os
 import platform
 import queue
 import socket
 import sys
 import threading
-from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from io import StringIO
 from platform import system

@@ -46,6 +44,7 @@ from xgboost.testing.data import (
     get_digits,
     get_sparse,
     make_batches,
+    make_sparse_regression,
     memory,
 )
@@ -115,6 +114,10 @@ def no_dask() -> PytestSkip:
     return no_mod("dask")


+def no_loky() -> PytestSkip:
+    return no_mod("loky")
+
+
 def no_dask_ml() -> PytestSkip:
     if sys.platform.startswith("win"):
         return {"reason": "Unsupported platform.", "condition": True}
@@ -136,7 +139,14 @@ def no_arrow() -> PytestSkip:


 def no_modin() -> PytestSkip:
-    return no_mod("modin")
+    try:
+        import modin.pandas as md
+
+        md.DataFrame([[1, 2.0, True], [2, 3.0, False]], columns=["a", "b", "c"])
+
+    except ImportError:
+        return {"reason": "Failed import modin.", "condition": True}
+    return {"reason": "Failed import modin.", "condition": False}


 def no_dt() -> PytestSkip:
@@ -487,94 +497,6 @@ def _cat_sampled_from() -> strategies.SearchStrategy:

 categorical_dataset_strategy: strategies.SearchStrategy = _cat_sampled_from()


-# pylint: disable=too-many-locals
-@memory.cache
-def make_sparse_regression(
-    n_samples: int, n_features: int, sparsity: float, as_dense: bool
-) -> Tuple[Union[sparse.csr_matrix], np.ndarray]:
-    """Make sparse matrix.
-
-    Parameters
-    ----------
-
-    as_dense:
-
-      Return the matrix as np.ndarray with missing values filled by NaN
-
-    """
-    if not hasattr(np.random, "default_rng"):
-        rng = np.random.RandomState(1994)
-        X = sparse.random(
-            m=n_samples,
-            n=n_features,
-            density=1.0 - sparsity,
-            random_state=rng,
-            format="csr",
-        )
-        y = rng.normal(loc=0.0, scale=1.0, size=n_samples)
-        return X, y
-
-    # Use multi-thread to speed up the generation, convenient if you use this function
-    # for benchmarking.
-    n_threads = min(multiprocessing.cpu_count(), n_features)
-
-    def random_csc(t_id: int) -> sparse.csc_matrix:
-        rng = np.random.default_rng(1994 * t_id)
-        thread_size = n_features // n_threads
-        if t_id == n_threads - 1:
-            n_features_tloc = n_features - t_id * thread_size
-        else:
-            n_features_tloc = thread_size
-
-        X = sparse.random(
-            m=n_samples,
-            n=n_features_tloc,
-            density=1.0 - sparsity,
-            random_state=rng,
-        ).tocsc()
-        y = np.zeros((n_samples, 1))
-
-        for i in range(X.shape[1]):
-            size = X.indptr[i + 1] - X.indptr[i]
-            if size != 0:
-                y += X[:, i].toarray() * rng.random((n_samples, 1)) * 0.2
-
-        return X, y
-
-    futures = []
-    with ThreadPoolExecutor(max_workers=n_threads) as executor:
-        for i in range(n_threads):
-            futures.append(executor.submit(random_csc, i))
-
-    X_results = []
-    y_results = []
-    for f in futures:
-        X, y = f.result()
-        X_results.append(X)
-        y_results.append(y)
-
-    assert len(y_results) == n_threads
-
-    csr: sparse.csr_matrix = sparse.hstack(X_results, format="csr")
-    y = np.asarray(y_results)
-    y = y.reshape((y.shape[0], y.shape[1])).T
-    y = np.sum(y, axis=1)
-
-    assert csr.shape[0] == n_samples
-    assert csr.shape[1] == n_features
-    assert y.shape[0] == n_samples
-
-    if as_dense:
-        arr = csr.toarray()
-        assert arr.shape[0] == n_samples
-        assert arr.shape[1] == n_features
-        arr[arr == 0] = np.nan
-        return arr, y
-
-    return csr, y
-
-
 sparse_datasets_strategy = strategies.sampled_from(
     [
         TestDataset(
@@ -1,7 +1,9 @@
 # pylint: disable=invalid-name
 """Utilities for data generation."""
+import multiprocessing
 import os
 import zipfile
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from typing import (
     TYPE_CHECKING,

@@ -523,7 +525,7 @@ def make_batches(  # pylint: disable=too-many-arguments,too-many-locals
     if use_cupy:
         import cupy  # pylint: disable=import-error

-        rng = cupy.random.RandomState(random_state)
+        rng = cupy.random.RandomState(np.uint64(random_state))
     else:
         rng = np.random.RandomState(random_state)
     for i in range(n_batches):

@@ -843,3 +845,90 @@ def run_base_margin_info(
         base_margin = X.reshape(2, 5, 2, 5)
         with pytest.raises(ValueError, match=r".*base_margin.*"):
             Xy.set_base_margin(base_margin)
+
+
+# pylint: disable=too-many-locals
+@memory.cache
+def make_sparse_regression(
+    n_samples: int, n_features: int, sparsity: float, as_dense: bool
+) -> Tuple[Union[sparse.csr_matrix], np.ndarray]:
+    """Make sparse matrix.
+
+    Parameters
+    ----------
+
+    as_dense:
+
+      Return the matrix as np.ndarray with missing values filled by NaN
+
+    """
+    if not hasattr(np.random, "default_rng"):
+        rng = np.random.RandomState(1994)
+        X = sparse.random(
+            m=n_samples,
+            n=n_features,
+            density=1.0 - sparsity,
+            random_state=rng,
+            format="csr",
+        )
+        y = rng.normal(loc=0.0, scale=1.0, size=n_samples)
+        return X, y
+
+    # Use multi-thread to speed up the generation, convenient if you use this function
+    # for benchmarking.
+    n_threads = min(multiprocessing.cpu_count(), n_features)
+
+    def random_csc(t_id: int) -> sparse.csc_matrix:
+        rng = np.random.default_rng(1994 * t_id)
+        thread_size = n_features // n_threads
+        if t_id == n_threads - 1:
+            n_features_tloc = n_features - t_id * thread_size
+        else:
+            n_features_tloc = thread_size
+
+        X = sparse.random(
+            m=n_samples,
+            n=n_features_tloc,
+            density=1.0 - sparsity,
+            random_state=rng,
+        ).tocsc()
+        y = np.zeros((n_samples, 1))
+
+        for i in range(X.shape[1]):
+            size = X.indptr[i + 1] - X.indptr[i]
+            if size != 0:
+                y += X[:, i].toarray() * rng.random((n_samples, 1)) * 0.2
+
+        return X, y
+
+    futures = []
+    with ThreadPoolExecutor(max_workers=n_threads) as executor:
+        for i in range(n_threads):
+            futures.append(executor.submit(random_csc, i))
+
+    X_results = []
+    y_results = []
+    for f in futures:
+        X, y = f.result()
+        X_results.append(X)
+        y_results.append(y)
+
+    assert len(y_results) == n_threads
+
+    csr: sparse.csr_matrix = sparse.hstack(X_results, format="csr")
+    y = np.asarray(y_results)
+    y = y.reshape((y.shape[0], y.shape[1])).T
+    y = np.sum(y, axis=1)
+
+    assert csr.shape[0] == n_samples
+    assert csr.shape[1] == n_features
+    assert y.shape[0] == n_samples
+
+    if as_dense:
+        arr = csr.toarray()
+        assert arr.shape[0] == n_samples
+        assert arr.shape[1] == n_features
+        arr[arr == 0] = np.nan
+        return arr, y
+
+    return csr, y
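`make_sparse_regression` moved from `xgboost.testing` into `xgboost.testing.data`, along with the `multiprocessing` and `ThreadPoolExecutor` imports it needs. Its core trick is generating independent column blocks in parallel and concatenating them; a condensed standalone sketch of that pattern, simplified from the code above (scipy accepting a `Generator` for `random_state` is assumed, as in the moved function):

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np
from scipy import sparse

n_samples, n_features, n_threads = 1000, 64, 4


def block(t_id: int) -> sparse.csc_matrix:
    rng = np.random.default_rng(1994 * t_id)  # per-thread, reproducible
    return sparse.random(
        m=n_samples, n=n_features // n_threads, density=0.1, random_state=rng
    ).tocsc()


with ThreadPoolExecutor(max_workers=n_threads) as ex:
    blocks = list(ex.map(block, range(n_threads)))

X = sparse.hstack(blocks, format="csr")
assert X.shape == (n_samples, n_features)
```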
@@ -198,7 +198,7 @@ class CVPack:
     def __init__(
         self, dtrain: DMatrix, dtest: DMatrix, param: Optional[Union[Dict, List]]
     ) -> None:
-        """ "Initialize the CVPack"""
+        """Initialize the CVPack."""
         self.dtrain = dtrain
         self.dtest = dtest
         self.watchlist = [(dtrain, "train"), (dtest, "test")]

@@ -277,7 +277,7 @@ class _PackedBooster:
         self.set_attr(best_score=score)


-def groups_to_rows(groups: List[np.ndarray], boundaries: np.ndarray) -> np.ndarray:
+def groups_to_rows(groups: np.ndarray, boundaries: np.ndarray) -> np.ndarray:
     """
     Given group row boundaries, convert ground indexes to row indexes
     :param groups: list of groups for testing
@@ -27,7 +27,8 @@ RUN \
     "nccl>=${NCCL_SHORT_VER}" \
     dask \
     dask-cuda=$RAPIDS_VERSION_ARG* dask-cudf=$RAPIDS_VERSION_ARG* cupy \
-    numpy pytest pytest-timeout scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis \
+    numpy pytest pytest-timeout scipy scikit-learn pandas matplotlib wheel \
+    python-kubernetes urllib3 graphviz hypothesis loky \
     "pyspark>=3.4.0" cloudpickle cuda-python && \
     mamba clean --all --yes && \
     conda run --no-capture-output -n gpu_test pip install buildkite-test-collector

@@ -30,7 +30,8 @@ RUN \
    "nccl>=${NCCL_SHORT_VER}" \
    dask \
    "dask-cuda=$RAPIDS_VERSION_ARG.*" "dask-cudf=$RAPIDS_VERSION_ARG.*" cupy \
-    numpy pytest pytest-timeout scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis \
+    numpy pytest pytest-timeout scipy scikit-learn pandas matplotlib wheel \
+    python-kubernetes urllib3 graphviz hypothesis loky \
    "pyspark>=3.4.0" cloudpickle cuda-python && \
    mamba clean --all --yes && \
    conda run --no-capture-output -n gpu_test pip install buildkite-test-collector
@@ -2,7 +2,7 @@ name: aarch64_test
 channels:
 - conda-forge
 dependencies:
-- python=3.8
+- python=3.10
 - pip
 - wheel
 - pytest

@@ -26,7 +26,7 @@ dependencies:
 - awscli
 - numba
 - llvmlite
 - cffi
+- loky
 - pyarrow
 - pyspark>=3.4.0
 - cloudpickle

@@ -2,7 +2,7 @@ name: linux_cpu_test
 channels:
 - conda-forge
 dependencies:
-- python=3.8
+- python=3.10
 - cmake
 - c-compiler
 - cxx-compiler

@@ -33,7 +33,7 @@ dependencies:
 - boto3
 - awscli
 - py-ubjson
 - cffi
+- loky
 - pyarrow
 - protobuf
 - cloudpickle

@@ -3,7 +3,7 @@ channels:
 - conda-forge
 - https://software.repos.intel.com/python/conda/
 dependencies:
-- python=3.8
+- python=3.10
 - cmake
 - c-compiler
 - cxx-compiler

@@ -2,7 +2,7 @@ name: macos_test
 channels:
 - conda-forge
 dependencies:
-- python=3.8
+- python=3.10
 - pip
 - wheel
 - pyyaml

@@ -32,7 +32,7 @@ dependencies:
 - jsonschema
 - boto3
 - awscli
 - cffi
+- loky
 - pyarrow
 - pyspark>=3.4.0
 - cloudpickle

@@ -2,11 +2,11 @@ name: python_lint
 channels:
 - conda-forge
 dependencies:
-- python=3.8
+- python=3.10
 - pylint<3.2.4 # https://github.com/pylint-dev/pylint/issues/9751
 - wheel
 - setuptools
-- mypy>=0.981
+- mypy
 - numpy
 - scipy
 - pandas

@@ -3,7 +3,7 @@ name: sdist_test
 channels:
 - conda-forge
 dependencies:
-- python=3.8
+- python=3.10
 - pip
 - wheel
 - cmake

@@ -2,7 +2,7 @@ name: win64_env
 channels:
 - conda-forge
 dependencies:
-- python=3.8
+- python=3.10
 - wheel
 - numpy
 - scipy

@@ -18,5 +18,5 @@ dependencies:
 - python-graphviz
 - pip
 - py-ubjson
 - cffi
+- loky
 - pyarrow

@@ -2,7 +2,7 @@ name: win64_env
 channels:
 - conda-forge
 dependencies:
-- python=3.8
+- python=3.10
 - numpy
 - scipy
 - matplotlib

@@ -12,9 +12,9 @@ dependencies:
 - boto3
 - hypothesis
 - jsonschema
-- cupy
+- cupy>=13.2
 - python-graphviz
 - pip
 - py-ubjson
 - cffi
+- loky
 - pyarrow
@@ -20,7 +20,7 @@ class TestQuantileDMatrix:
     def test_dmatrix_feature_weights(self) -> None:
         import cupy as cp

-        rng = cp.random.RandomState(1994)
+        rng = cp.random.RandomState(np.uint64(1994))
         data = rng.randn(5, 5)
         m = xgb.DMatrix(data)

@@ -146,7 +146,7 @@ class TestQuantileDMatrix:
     def test_metainfo(self) -> None:
         import cupy as cp

-        rng = cp.random.RandomState(1994)
+        rng = cp.random.RandomState(np.uint64(1994))

         rows = 10
         cols = 3

@@ -170,7 +170,7 @@ class TestQuantileDMatrix:
     def test_ref_dmatrix(self) -> None:
         import cupy as cp

-        rng = cp.random.RandomState(1994)
+        rng = cp.random.RandomState(np.uint64(1994))
         self.cputest.run_ref_dmatrix(rng, "gpu_hist", False)

     @given(

@@ -66,7 +66,7 @@ def _test_from_cupy(DMatrixT):

 def _test_cupy_training(DMatrixT):
     np.random.seed(1)
-    cp.random.seed(1)
+    cp.random.seed(np.uint64(1))
     X = cp.random.randn(50, 10, dtype="float32")
     y = cp.random.randn(50, dtype="float32")
     weights = np.random.random(50) + 1

@@ -131,7 +131,7 @@ def _test_cupy_metainfo(DMatrixT):
 @pytest.mark.skipif(**tm.no_sklearn())
 def test_cupy_training_with_sklearn():
     np.random.seed(1)
-    cp.random.seed(1)
+    cp.random.seed(np.uint64(1))
     X = cp.random.randn(50, 10, dtype="float32")
     y = (cp.random.randn(50, dtype="float32") > 0).astype("int8")
     weights = np.random.random(50) + 1

@@ -210,7 +210,7 @@ class TestFromCupy:

     @pytest.mark.skipif(**tm.no_cupy())
     def test_qid(self):
-        rng = cp.random.RandomState(1994)
+        rng = cp.random.RandomState(np.uint64(1994))
         rows = 100
         cols = 10
         X, y = rng.randn(rows, cols), rng.randn(rows)

@@ -226,7 +226,7 @@ class TestGPUPredict:
     cols = 10
     missing = 11  # set to integer for testing

-    cp_rng = cp.random.RandomState(1994)
+    cp_rng = cp.random.RandomState(np.uint64(1994))
     cp.random.set_random_state(cp_rng)

     X = cp.random.randn(rows, cols)

@@ -546,7 +546,7 @@ class TestGPUPredict:

     rows = 1000
     cols = 10
-    rng = cp.random.RandomState(1994)
+    rng = cp.random.RandomState(np.uint64(1994))
     orig = rng.randint(low=0, high=127, size=rows * cols).reshape(rows, cols)
     y = rng.randint(low=0, high=127, size=rows)
     dtrain = xgb.DMatrix(orig, label=y)

@@ -576,8 +576,8 @@ class TestGPUPredict:
     # boolean
     orig = cp.random.binomial(1, 0.5, size=rows * cols).reshape(rows, cols)
     predt_orig = booster.inplace_predict(orig)
-    for dtype in [cp.bool8, cp.bool_]:
-        X = cp.array(orig, dtype=dtype)
+
+    X = cp.array(orig, dtype=cp.bool_)
     predt = booster.inplace_predict(X)
     cp.testing.assert_allclose(predt, predt_orig)
@@ -425,8 +425,8 @@ class TestModels:
         np.testing.assert_allclose(merged, single, atol=1e-6)

     @pytest.mark.skipif(**tm.no_sklearn())
-    @pytest.mark.parametrize("booster", ["gbtree", "dart"])
-    def test_slice(self, booster):
+    @pytest.mark.parametrize("booster_name", ["gbtree", "dart"])
+    def test_slice(self, booster_name: str) -> None:
         from sklearn.datasets import make_classification

         num_classes = 3

@@ -442,7 +442,7 @@ class TestModels:
                 "num_parallel_tree": num_parallel_tree,
                 "subsample": 0.5,
                 "num_class": num_classes,
-                "booster": booster,
+                "booster": booster_name,
                 "objective": "multi:softprob",
             },
             num_boost_round=num_boost_round,

@@ -452,6 +452,8 @@ class TestModels:

         assert len(booster.get_dump()) == total_trees

+        assert booster[...].num_boosted_rounds() == num_boost_round
+
         self.run_slice(
             booster, dtrain, num_parallel_tree, num_classes, num_boost_round, False
         )
@@ -1,44 +1,46 @@
-import multiprocessing
 import socket
 import sys
 from threading import Thread

 import numpy as np
 import pytest
+from loky import get_reusable_executor

 import xgboost as xgb
 from xgboost import RabitTracker, build_info, federated
 from xgboost import testing as tm


-def run_rabit_worker(rabit_env, world_size):
+def run_rabit_worker(rabit_env: dict, world_size: int) -> int:
     with xgb.collective.CommunicatorContext(**rabit_env):
         assert xgb.collective.get_world_size() == world_size
         assert xgb.collective.is_distributed()
         assert xgb.collective.get_processor_name() == socket.gethostname()
         ret = xgb.collective.broadcast("test1234", 0)
         assert str(ret) == "test1234"
-        ret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM)
-        assert np.array_equal(ret, np.asarray([2, 4, 6]))
+        reduced = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM)
+        assert np.array_equal(reduced, np.asarray([2, 4, 6]))
+    return 0


+@pytest.mark.skipif(**tm.no_loky())
 def test_rabit_communicator() -> None:
     world_size = 2
     tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size)
     tracker.start()
     workers = []
-    for _ in range(world_size):
-        worker = multiprocessing.Process(
-            target=run_rabit_worker, args=(tracker.worker_args(), world_size)
-        )
-        workers.append(worker)
-        worker.start()
-
-    for worker in workers:
-        worker.join()
-        assert worker.exitcode == 0
+    with get_reusable_executor(max_workers=world_size) as pool:
+        for _ in range(world_size):
+            worker = pool.submit(
+                run_rabit_worker, rabit_env=tracker.worker_args(), world_size=world_size
+            )
+            workers.append(worker)
+
+        for worker in workers:
+            assert worker.result() == 0


-def run_federated_worker(port: int, world_size: int, rank: int) -> None:
+def run_federated_worker(port: int, world_size: int, rank: int) -> int:
     with xgb.collective.CommunicatorContext(
         dmlc_communicator="federated",
         federated_server_address=f"localhost:{port}",

@@ -52,30 +54,28 @@ def run_federated_worker(port: int, world_size: int, rank: int) -> None:
         assert str(bret) == "test1234"
         aret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM)
         assert np.array_equal(aret, np.asarray([2, 4, 6]))
+    return 0


 @pytest.mark.skipif(**tm.skip_win())
+@pytest.mark.skipif(**tm.no_loky())
 def test_federated_communicator():
     if not build_info()["USE_FEDERATED"]:
         pytest.skip("XGBoost not built with federated learning enabled")

     port = 9091
     world_size = 2
-    tracker = multiprocessing.Process(
-        target=federated.run_federated_server,
-        kwargs={"port": port, "n_workers": world_size, "blocking": False},
-    )
-    tracker.start()
-    if not tracker.is_alive():
-        raise Exception("Error starting Federated Learning server")
+    with get_reusable_executor(max_workers=world_size + 1) as pool:
+        kwargs = {"port": port, "n_workers": world_size, "blocking": False}
+        tracker = pool.submit(federated.run_federated_server, **kwargs)
+        if not tracker.running():
+            raise RuntimeError("Error starting Federated Learning server")

-    workers = []
-    for rank in range(world_size):
-        worker = multiprocessing.Process(
-            target=run_federated_worker, args=(port, world_size, rank)
-        )
-        workers.append(worker)
-        worker.start()
-    for worker in workers:
-        worker.join()
-        assert worker.exitcode == 0
+        workers = []
+        for rank in range(world_size):
+            worker = pool.submit(
+                run_federated_worker, port=port, world_size=world_size, rank=rank
+            )
+            workers.append(worker)
+        for worker in workers:
+            assert worker.result() == 0
||||
"""Copyright 2019-2023, XGBoost contributors"""
|
||||
"""Copyright 2019-2024, XGBoost contributors"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from copy import copy
|
||||
from inspect import signature
|
||||
from typing import Any, Dict, Type, TypeVar
|
||||
|
||||
@ -53,15 +54,13 @@ except ImportError:
|
||||
|
||||
def run_with_dask_dataframe(DMatrixT: Type, client: Client) -> None:
|
||||
import cupy as cp
|
||||
import dask_cudf
|
||||
|
||||
cp.cuda.runtime.setDevice(0)
|
||||
_X, _y, _ = generate_array()
|
||||
|
||||
X = dd.from_dask_array(_X)
|
||||
y = dd.from_dask_array(_y)
|
||||
|
||||
X = X.map_partitions(cudf.from_pandas)
|
||||
y = y.map_partitions(cudf.from_pandas)
|
||||
X = dd.from_dask_array(_X).to_backend("cudf")
|
||||
y = dd.from_dask_array(_y).to_backend("cudf")
|
||||
|
||||
dtrain = DMatrixT(client, X, y)
|
||||
out = dxgb.train(
|
||||
@ -216,18 +215,22 @@ def test_tree_stats() -> None:
|
||||
class TestDistributedGPU:
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_boost_from_prediction(self, local_cuda_client: Client) -> None:
|
||||
import cudf
|
||||
import dask_cudf
|
||||
from sklearn.datasets import load_breast_cancer, load_iris
|
||||
|
||||
X_, y_ = load_breast_cancer(return_X_y=True)
|
||||
X = dd.from_array(X_, chunksize=100).map_partitions(cudf.from_pandas)
|
||||
y = dd.from_array(y_, chunksize=100).map_partitions(cudf.from_pandas)
|
||||
run_boost_from_prediction(X, y, "hist", "cuda", local_cuda_client)
|
||||
X = dd.from_array(X_, chunksize=100).to_backend("cudf")
|
||||
y = dd.from_array(y_, chunksize=100).to_backend("cudf")
|
||||
divisions = copy(X.divisions)
|
||||
run_boost_from_prediction(X, y, "hist", "cuda", local_cuda_client, divisions)
|
||||
|
||||
X_, y_ = load_iris(return_X_y=True)
|
||||
X = dd.from_array(X_, chunksize=50).map_partitions(cudf.from_pandas)
|
||||
y = dd.from_array(y_, chunksize=50).map_partitions(cudf.from_pandas)
|
||||
run_boost_from_prediction_multi_class(X, y, "hist", "cuda", local_cuda_client)
|
||||
X = dd.from_array(X_, chunksize=50).to_backend("cudf")
|
||||
y = dd.from_array(y_, chunksize=50).to_backend("cudf")
|
||||
divisions = copy(X.divisions)
|
||||
run_boost_from_prediction_multi_class(
|
||||
X, y, "hist", "cuda", local_cuda_client, divisions
|
||||
)
|
||||
|
||||
def test_init_estimation(self, local_cuda_client: Client) -> None:
|
||||
check_init_estimation("hist", "cuda", local_cuda_client)
|
||||
|
||||
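The GPU tests also swap `map_partitions(cudf.from_pandas)` for Dask's collection-level `to_backend("cudf")`, which converts the whole dataframe to the cudf backend in one declared step. A rough sketch of the idiom (requires dask_cudf and a GPU; synthetic data):

```python
import dask.dataframe as dd
import pandas as pd

pdf = pd.DataFrame({"a": range(10), "b": range(10)})
df = dd.from_pandas(pdf, npartitions=2)

# Move the collection to the cudf (GPU) backend; the old spelling was
# df.map_partitions(cudf.from_pandas).
gdf = df.to_backend("cudf")
```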
@@ -1,4 +1,4 @@
-"""Copyright 2019-2022 XGBoost contributors"""
+"""Copyright 2019-2024, XGBoost contributors"""

 import asyncio
 import json

@@ -7,12 +7,24 @@ import pickle
 import socket
 import tempfile
 from concurrent.futures import ThreadPoolExecutor
+from copy import copy
 from functools import partial
-from itertools import starmap
-from math import ceil
-from operator import attrgetter, getitem
 from pathlib import Path
-from typing import Any, Dict, Generator, Literal, Optional, Tuple, Type, TypeVar, Union
+from typing import (
+    Any,
+    Dict,
+    Generator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)

 import hypothesis
 import numpy as np
@@ -133,34 +145,6 @@ def generate_array(
     return X, y, None


-def deterministic_persist_per_worker(
-    df: dd.DataFrame, client: "Client"
-) -> dd.DataFrame:
-    # Got this script from https://github.com/dmlc/xgboost/issues/7927
-    # Query workers
-    n_workers = len(client.cluster.workers)
-    workers = map(attrgetter("worker_address"), client.cluster.workers.values())
-
-    # Slice data into roughly equal partitions
-    subpartition_size = ceil(df.npartitions / n_workers)
-    subpartition_divisions = range(
-        0, df.npartitions + subpartition_size, subpartition_size
-    )
-    subpartition_slices = starmap(slice, sliding_window(2, subpartition_divisions))
-    subpartitions = map(partial(getitem, df.partitions), subpartition_slices)
-
-    # Persist each subpartition on each worker
-    # Rebuild dataframe from persisted subpartitions
-    df2 = dd.concat(
-        [
-            sp.persist(workers=w, allow_other_workers=False)
-            for sp, w in zip(subpartitions, workers)
-        ]
-    )
-
-    return df2
-
-
 Margin = TypeVar("Margin", dd.DataFrame, dd.Series, None)

@@ -169,30 +153,14 @@ def deterministic_repartition(
     X: dd.DataFrame,
     y: dd.Series,
     m: Margin,
+    divisions,
 ) -> Tuple[dd.DataFrame, dd.Series, Margin]:
     # force repartition the data to avoid non-deterministic result
-    if any(X.map_partitions(lambda x: _is_cudf_df(x)).compute()):
-        # dask_cudf seems to be doing fine for now
-        return X, y, m
-
-    X["_y"] = y
-    if m is not None:
-        if isinstance(m, dd.DataFrame):
-            m_columns = m.columns
-            X = dd.concat([X, m], join="outer", axis=1)
-        else:
-            m_columns = ["_m"]
-            X["_m"] = m
-
-    X = deterministic_persist_per_worker(X, client)
-
-    y = X["_y"]
-    X = X[X.columns.difference(["_y"])]
-    if m is not None:
-        m = X[m_columns]
-        X = X[X.columns.difference(m_columns)]
-
-    return X, y, m
+    X, y, margin = (
+        dd.repartition(X, divisions=divisions, force=True),
+        dd.repartition(y, divisions=divisions, force=True),
+        dd.repartition(m, divisions=divisions, force=True) if m is not None else None,
+    )
+    return X, y, margin


 @pytest.mark.parametrize("to_frame", [True, False])
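The rewritten helper simply forces X, y, and the optional margin onto one shared set of divisions with `dd.repartition(..., force=True)`, instead of manually persisting subpartitions per worker. A rough standalone sketch of the same idea (synthetic frame; `dd.repartition` is the public dask API used in the diff itself):

```python
import dask.dataframe as dd
import pandas as pd

pdf = pd.DataFrame({"a": range(100)}, index=range(100))
X = dd.from_pandas(pdf, npartitions=5)
y = X["a"]

# Capture the divisions once, then force both collections onto them so
# their partitions line up deterministically.
divisions = X.divisions
X = dd.repartition(X, divisions=divisions, force=True)
y = dd.repartition(y, divisions=divisions, force=True)
assert X.divisions == y.divisions
```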
@@ -218,10 +186,10 @@ def test_xgbclassifier_classes_type_and_value(to_frame: bool, client: "Client"):
 def test_from_dask_dataframe() -> None:
     with LocalCluster(n_workers=kWorkers, dashboard_address=":0") as cluster:
         with Client(cluster) as client:
-            X, y, _ = generate_array()
+            X_, y_, _ = generate_array()

-            X = dd.from_dask_array(X)
-            y = dd.from_dask_array(y)
+            X = dd.from_dask_array(X_)
+            y = dd.from_dask_array(y_)

             dtrain = DaskDMatrix(client, X, y)
             booster = xgb.dask.train(client, {}, dtrain, num_boost_round=2)["booster"]

@@ -456,6 +424,7 @@ def run_boost_from_prediction_multi_class(
     tree_method: str,
     device: str,
     client: "Client",
+    divisions: List[int],
 ) -> None:
     model_0 = xgb.dask.DaskXGBClassifier(
         learning_rate=0.3,

@@ -464,7 +433,7 @@ def run_boost_from_prediction_multi_class(
         max_bin=768,
         device=device,
     )
-    X, y, _ = deterministic_repartition(client, X, y, None)
+    X, y, _ = deterministic_repartition(client, X, y, None, divisions)
     model_0.fit(X=X, y=y)
     margin = xgb.dask.inplace_predict(
         client, model_0.get_booster(), X, predict_type="margin"

@@ -478,7 +447,7 @@ def run_boost_from_prediction_multi_class(
         max_bin=768,
         device=device,
     )
-    X, y, margin = deterministic_repartition(client, X, y, margin)
+    X, y, margin = deterministic_repartition(client, X, y, margin, divisions)
     model_1.fit(X=X, y=y, base_margin=margin)
     predictions_1 = xgb.dask.predict(
         client,

@@ -494,7 +463,7 @@ def run_boost_from_prediction_multi_class(
         max_bin=768,
         device=device,
     )
-    X, y, _ = deterministic_repartition(client, X, y, None)
+    X, y, _ = deterministic_repartition(client, X, y, None, divisions)
     model_2.fit(X=X, y=y)
     predictions_2 = xgb.dask.inplace_predict(
         client, model_2.get_booster(), X, predict_type="margin"

@@ -517,6 +486,7 @@ def run_boost_from_prediction(
     tree_method: str,
     device: str,
     client: "Client",
+    divisions: List[int],
 ) -> None:
     X, y = client.persist([X, y])

@@ -527,7 +497,7 @@ def run_boost_from_prediction(
         max_bin=512,
         device=device,
     )
-    X, y, _ = deterministic_repartition(client, X, y, None)
+    X, y, _ = deterministic_repartition(client, X, y, None, divisions)
     model_0.fit(X=X, y=y)
     margin: dd.Series = model_0.predict(X, output_margin=True)

@@ -538,9 +508,9 @@ def run_boost_from_prediction(
         max_bin=512,
         device=device,
     )
-    X, y, margin = deterministic_repartition(client, X, y, margin)
+    X, y, margin = deterministic_repartition(client, X, y, margin, divisions)
     model_1.fit(X=X, y=y, base_margin=margin)
-    X, y, margin = deterministic_repartition(client, X, y, margin)
+    X, y, margin = deterministic_repartition(client, X, y, margin, divisions)
     predictions_1: dd.Series = model_1.predict(X, base_margin=margin)

     model_2 = xgb.dask.DaskXGBClassifier(

@@ -550,7 +520,7 @@ def run_boost_from_prediction(
         max_bin=512,
         device=device,
     )
-    X, y, _ = deterministic_repartition(client, X, y, None)
+    X, y, _ = deterministic_repartition(client, X, y, None, divisions)
     model_2.fit(X=X, y=y)
     predictions_2: dd.Series = model_2.predict(X)

@@ -563,13 +533,13 @@ def run_boost_from_prediction(
     np.testing.assert_allclose(predt_1, predt_2, atol=1e-5)

     margined = xgb.dask.DaskXGBClassifier(n_estimators=4)
-    X, y, margin = deterministic_repartition(client, X, y, margin)
+    X, y, margin = deterministic_repartition(client, X, y, margin, divisions)
     margined.fit(
         X=X, y=y, base_margin=margin, eval_set=[(X, y)], base_margin_eval_set=[margin]
     )

     unmargined = xgb.dask.DaskXGBClassifier(n_estimators=4)
-    X, y, margin = deterministic_repartition(client, X, y, margin)
+    X, y, margin = deterministic_repartition(client, X, y, margin, divisions)
     unmargined.fit(X=X, y=y, eval_set=[(X, y)], base_margin=margin)

     margined_res = margined.evals_result()["validation_0"]["logloss"]

@@ -587,11 +557,13 @@ def test_boost_from_prediction(tree_method: str, client: "Client") -> None:

     X_, y_ = load_breast_cancer(return_X_y=True)
     X, y = dd.from_array(X_, chunksize=200), dd.from_array(y_, chunksize=200)
-    run_boost_from_prediction(X, y, tree_method, "cpu", client)
+    divisions = copy(X.divisions)
+    run_boost_from_prediction(X, y, tree_method, "cpu", client, divisions)

     X_, y_ = load_digits(return_X_y=True)
     X, y = dd.from_array(X_, chunksize=100), dd.from_array(y_, chunksize=100)
-    run_boost_from_prediction_multi_class(X, y, tree_method, "cpu", client)
+    divisions = copy(X.divisions)
+    run_boost_from_prediction_multi_class(X, y, tree_method, "cpu", client, divisions)


 def test_inplace_predict(client: "Client") -> None:

@@ -1594,7 +1566,7 @@ class TestWithDask:
     def test_empty_quantile_dmatrix(self, client: Client) -> None:
         X, y = make_categorical(client, 2, 30, 13)
         X_valid, y_valid = make_categorical(client, 10000, 30, 13)
-        X_valid, y_valid, _ = deterministic_repartition(client, X_valid, y_valid, None)
+        divisions = copy(X_valid.divisions)

         Xy = xgb.dask.DaskQuantileDMatrix(client, X, y, enable_categorical=True)
         Xy_valid = xgb.dask.DaskQuantileDMatrix(