Move Python testing utilities into xgboost module. (#8379)

- Add typehints.
- Fixes for pylint.

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
Jiaming Yuan 2022-10-26 16:56:11 +08:00 committed by GitHub
parent 7e53189e7c
commit cf70864fa3
66 changed files with 652 additions and 595 deletions
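The recurring change across those 66 files is an import-path switch: the shared test helpers move from `tests/python/testing.py` into the `xgboost` package itself. A minimal before/after sketch of the pattern repeated throughout the diff below:

```python
# Before: every test file had to put tests/python on sys.path first.
import sys

sys.path.append("tests/python")
import testing as tm  # noqa

# After: the helpers are importable from the installed package.
from xgboost import testing as tm
```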

View File

@@ -65,7 +65,7 @@ def _check_rf_callback(
 )

-_SklObjective = Optional[
+SklObjective = Optional[
     Union[str, Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]]
 ]
@@ -144,7 +144,7 @@ __model_doc = f"""
        Boosting learning rate (xgb's "eta")
    verbosity : Optional[int]
        The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
-    objective : {_SklObjective}
+    objective : {SklObjective}
        Specify the learning task and the corresponding learning objective or
        a custom objective function to be used (see note below).
    booster: Optional[str]
@@ -546,7 +546,7 @@ class XGBModel(XGBModelBase):
         learning_rate: Optional[float] = None,
         n_estimators: int = 100,
         verbosity: Optional[int] = None,
-        objective: _SklObjective = None,
+        objective: SklObjective = None,
         booster: Optional[str] = None,
         tree_method: Optional[str] = None,
         n_jobs: Optional[int] = None,
@@ -1409,7 +1409,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
     def __init__(
         self,
         *,
-        objective: _SklObjective = "binary:logistic",
+        objective: SklObjective = "binary:logistic",
         use_label_encoder: Optional[bool] = None,
         **kwargs: Any,
     ) -> None:
@@ -1712,7 +1712,7 @@ class XGBRegressor(XGBModel, XGBRegressorBase):
     # pylint: disable=missing-docstring
     @_deprecate_positional_args
     def __init__(
-        self, *, objective: _SklObjective = "reg:squarederror", **kwargs: Any
+        self, *, objective: SklObjective = "reg:squarederror", **kwargs: Any
     ) -> None:
         super().__init__(objective=objective, **kwargs)
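With the leading underscore dropped, `SklObjective` becomes a public name in `xgboost.sklearn`. Per the alias above, an objective is either a builtin task string or a callback returning gradient and hessian arrays; a minimal sketch of the callable form (the hand-written squared-error function is illustrative, not part of this diff):

```python
import numpy as np
import xgboost as xgb


def squared_error(y_true: np.ndarray, y_pred: np.ndarray):
    # Gradient and hessian of 0.5 * (y_pred - y_true) ** 2.
    grad = y_pred - y_true
    hess = np.ones_like(y_pred)
    return grad, hess


# Either form satisfies the SklObjective signature.
reg_builtin = xgb.XGBRegressor(objective="reg:squarederror")
reg_custom = xgb.XGBRegressor(objective=squared_error)
```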

View File

@@ -1,64 +0,0 @@
-"""Utilities for defining Python tests."""
-import socket
-from platform import system
-from typing import Any, TypedDict
-
-PytestSkip = TypedDict("PytestSkip", {"condition": bool, "reason": str})
-
-
-def has_ipv6() -> bool:
-    """Check whether IPv6 is enabled on this host."""
-    # connection error in macos, still need some fixes.
-    if system() not in ("Linux", "Windows"):
-        return False
-
-    if socket.has_ipv6:
-        try:
-            with socket.socket(
-                socket.AF_INET6, socket.SOCK_STREAM
-            ) as server, socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as client:
-                server.bind(("::1", 0))
-                port = server.getsockname()[1]
-                server.listen()
-                client.connect(("::1", port))
-                conn, _ = server.accept()
-                client.sendall("abc".encode())
-                msg = conn.recv(3).decode()
-                # if the code can be executed to this point, the message should be
-                # correct.
-                assert msg == "abc"
-                return True
-        except OSError:
-            pass
-    return False
-
-
-def skip_ipv6() -> PytestSkip:
-    """PyTest skip mark for IPv6."""
-    return {"condition": not has_ipv6(), "reason": "IPv6 is required to be enabled."}
-
-
-def timeout(sec: int, *args: Any, enable: bool = True, **kwargs: Any) -> Any:
-    """Make a pytest mark for the `pytest-timeout` package.
-
-    Parameters
-    ----------
-    sec :
-        Timeout seconds.
-
-    enable :
-        Control whether timeout should be applied, used for debugging.
-
-    Returns
-    -------
-    pytest.mark.timeout
-    """
-    import pytest  # pylint: disable=import-error
-
-    # This is disabled for now due to regression caused by conflicts between federated
-    # learning build and the CI container environment.
-    if enable:
-        return pytest.mark.timeout(sec, *args, **kwargs)
-    return pytest.mark.timeout(None, *args, **kwargs)
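Both helpers from this deleted file survive the move (with `skip_ipv6` renamed to `no_ipv6`, see the new module below) and are consumed as pytest marks. A usage sketch:

```python
import pytest

from xgboost import testing as tm

# A PytestSkip dict unpacks straight into pytest.mark.skipif, while
# timeout() yields a pytest-timeout mark; both fit in pytestmark.
pytestmark = [pytest.mark.skipif(**tm.no_ipv6()), tm.timeout(10)]
```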

View File

@@ -1,192 +1,190 @@
-from concurrent.futures import ThreadPoolExecutor
-import os
+"""Utilities for defining Python tests. The module is private and subject to frequent
+change without notice.
+"""
+# pylint: disable=invalid-name,missing-function-docstring,import-error
+import gc
+import importlib.util
 import multiprocessing
-from typing import Tuple, Union, List, Sequence, Callable
+import os
+import platform
+import socket
+import sys
 import urllib
 import zipfile
-import sys
-from typing import Optional, Dict, Any
+from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from io import StringIO
-from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
-import pytest
-import gc
-import xgboost as xgb
-from xgboost.core import ArrayLike
-import numpy as np
-from scipy import sparse
-import platform
+from platform import system
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    TypedDict,
+    Union,
+)

-hypothesis = pytest.importorskip('hypothesis')
-sklearn = pytest.importorskip('sklearn')
+import numpy as np
+import pytest
+from scipy import sparse
+from xgboost.core import ArrayLike
+from xgboost.sklearn import SklObjective
+
+import xgboost as xgb
+
+hypothesis = pytest.importorskip("hypothesis")
+
+# pylint:disable=wrong-import-position,wrong-import-order
 from hypothesis import strategies
 from hypothesis.extra.numpy import arrays
-from joblib import Memory
-from sklearn import datasets

-try:
-    import cupy as cp
-except ImportError:
-    cp = None
+joblib = pytest.importorskip("joblib")
+datasets = pytest.importorskip("sklearn.datasets")

-memory = Memory('./cachedir', verbose=0)
+Memory = joblib.Memory
+memory = Memory("./cachedir", verbose=0)
+
+PytestSkip = TypedDict("PytestSkip", {"condition": bool, "reason": str})


-def no_ubjson():
-    reason = "ubjson is not intsalled."
-    try:
-        import ubjson  # noqa
-        return {"condition": False, "reason": reason}
-    except ImportError:
-        return {"condition": True, "reason": reason}
+def has_ipv6() -> bool:
+    """Check whether IPv6 is enabled on this host."""
+    # connection error in macos, still need some fixes.
+    if system() not in ("Linux", "Windows"):
+        return False
+
+    if socket.has_ipv6:
+        try:
+            with socket.socket(
+                socket.AF_INET6, socket.SOCK_STREAM
+            ) as server, socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as client:
+                server.bind(("::1", 0))
+                port = server.getsockname()[1]
+                server.listen()
+                client.connect(("::1", port))
+                conn, _ = server.accept()
+                client.sendall("abc".encode())
+                msg = conn.recv(3).decode()
+                # if the code can be executed to this point, the message should be
+                # correct.
+                assert msg == "abc"
+                return True
+        except OSError:
+            pass
+    return False


-def no_sklearn():
-    return {'condition': not SKLEARN_INSTALLED,
-            'reason': 'Scikit-Learn is not installed'}
+def no_mod(name: str) -> PytestSkip:
+    spec = importlib.util.find_spec(name)
+    return {"condition": spec is None, "reason": f"{name} is not installed."}


-def no_dask():
-    try:
-        import pkg_resources
-        pkg_resources.get_distribution("dask")
-        DASK_INSTALLED = True
-    except pkg_resources.DistributionNotFound:
-        DASK_INSTALLED = False
-    return {"condition": not DASK_INSTALLED, "reason": "Dask is not installed"}
+def no_ipv6() -> PytestSkip:
+    """PyTest skip mark for IPv6."""
+    return {"condition": not has_ipv6(), "reason": "IPv6 is required to be enabled."}


-def no_spark():
-    try:
-        import pyspark  # noqa
-        SPARK_INSTALLED = True
-    except ImportError:
-        SPARK_INSTALLED = False
-    return {"condition": not SPARK_INSTALLED, "reason": "Spark is not installed"}
+def no_ubjson() -> PytestSkip:
+    return no_mod("ubjson")


-def no_pandas():
-    return {'condition': not PANDAS_INSTALLED,
-            'reason': 'Pandas is not installed.'}
+def no_sklearn() -> PytestSkip:
+    return no_mod("sklearn")


-def no_arrow():
-    reason = "pyarrow is not installed"
-    try:
-        import pyarrow  # noqa
-        return {"condition": False, "reason": reason}
-    except ImportError:
-        return {"condition": True, "reason": reason}
+def no_dask() -> PytestSkip:
+    return no_mod("dask")


-def no_modin():
-    reason = 'Modin is not installed.'
-    try:
-        import modin.pandas as _  # noqa
-        return {'condition': False, 'reason': reason}
-    except ImportError:
-        return {'condition': True, 'reason': reason}
+def no_spark() -> PytestSkip:
+    return no_mod("pyspark")


-def no_dt():
-    import importlib.util
-    spec = importlib.util.find_spec('datatable')
-    return {'condition': spec is None,
-            'reason': 'Datatable is not installed.'}
+def no_pandas() -> PytestSkip:
+    return no_mod("pandas")


-def no_matplotlib():
-    reason = 'Matplotlib is not installed.'
+def no_arrow() -> PytestSkip:
+    return no_mod("pyarrow")
+
+
+def no_modin() -> PytestSkip:
+    return no_mod("modin")
+
+
+def no_dt() -> PytestSkip:
+    return no_mod("datatable")
+
+
+def no_matplotlib() -> PytestSkip:
+    reason = "Matplotlib is not installed."
     try:
         import matplotlib.pyplot as _  # noqa
-        return {'condition': False,
-                'reason': reason}
+
+        return {"condition": False, "reason": reason}
     except ImportError:
-        return {'condition': True,
-                'reason': reason}
+        return {"condition": True, "reason": reason}


-def no_dask_cuda():
-    reason = 'dask_cuda is not installed.'
-    try:
-        import dask_cuda as _  # noqa
-        return {'condition': False, 'reason': reason}
-    except ImportError:
-        return {'condition': True, 'reason': reason}
+def no_dask_cuda() -> PytestSkip:
+    return no_mod("dask_cuda")


-def no_cudf():
-    try:
-        import cudf  # noqa
-        CUDF_INSTALLED = True
-    except ImportError:
-        CUDF_INSTALLED = False
-    return {'condition': not CUDF_INSTALLED,
-            'reason': 'CUDF is not installed'}
+def no_cudf() -> PytestSkip:
+    return no_mod("cudf")


-def no_cupy():
-    reason = 'cupy is not installed.'
-    try:
-        import cupy as _  # noqa
-        return {'condition': False, 'reason': reason}
-    except ImportError:
-        return {'condition': True, 'reason': reason}
+def no_cupy() -> PytestSkip:
+    return no_mod("cupy")


-def no_dask_cudf():
-    reason = 'dask_cudf is not installed.'
-    try:
-        import dask_cudf as _  # noqa
-        return {'condition': False, 'reason': reason}
-    except ImportError:
-        return {'condition': True, 'reason': reason}
+def no_dask_cudf() -> PytestSkip:
+    return no_mod("dask_cudf")


-def no_json_schema():
-    reason = 'jsonschema is not installed'
-    try:
-        import jsonschema  # noqa
-        return {'condition': False, 'reason': reason}
-    except ImportError:
-        return {'condition': True, 'reason': reason}
+def no_json_schema() -> PytestSkip:
+    return no_mod("jsonschema")


-def no_graphviz():
-    reason = 'graphviz is not installed'
-    try:
-        import graphviz  # noqa
-        return {'condition': False, 'reason': reason}
-    except ImportError:
-        return {'condition': True, 'reason': reason}
+def no_graphviz() -> PytestSkip:
+    return no_mod("graphviz")


-def no_multiple(*args):
+def no_multiple(*args: Any) -> PytestSkip:
     condition = False
-    reason = ''
+    reason = ""
     for arg in args:
-        condition = (condition or arg['condition'])
-        if arg['condition']:
-            reason = arg['reason']
+        condition = condition or arg["condition"]
+        if arg["condition"]:
+            reason = arg["reason"]
             break
-    return {'condition': condition, 'reason': reason}
+    return {"condition": condition, "reason": reason}


-def skip_s390x():
+def skip_s390x() -> PytestSkip:
     condition = platform.machine() == "s390x"
     reason = "Known to fail on s390x"
     return {"condition": condition, "reason": reason}
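The dozen hand-rolled try/except probes above collapse into `no_mod`, which asks the import machinery for a module spec instead of importing the package. A short sketch of the standard-library behaviour this relies on (not part of the diff):

```python
import importlib.util

# find_spec only consults the import finders: the module is neither
# executed nor added to sys.modules, so probing heavy optional
# dependencies such as cupy or pyspark stays cheap and side-effect free.
spec = importlib.util.find_spec("pyspark")
print(spec is None)  # True when the package is not installed
```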


 class IteratorForTest(xgb.core.DataIter):
+    """Iterator for testing streaming DMatrix. (external memory, quantile)"""
+
     def __init__(
         self,
         X: Sequence,
         y: Sequence,
         w: Optional[Sequence],
-        cache: Optional[str] = "./"
+        cache: Optional[str] = "./",
     ) -> None:
         assert len(X) == len(y)
         self.X = X
@@ -242,7 +240,7 @@ def make_batches(
         rng = cupy.random.RandomState(1994)
     else:
         rng = np.random.RandomState(1994)
-    for i in range(n_batches):
+    for _ in range(n_batches):
         _X = rng.randn(n_samples_per_batch, n_features)
         _y = rng.randn(n_samples_per_batch)
         _w = rng.uniform(low=0, high=1, size=n_samples_per_batch)
@@ -259,7 +257,7 @@ def make_batches_sparse(
     y = []
     w = []
     rng = np.random.RandomState(1994)
-    for i in range(n_batches):
+    for _ in range(n_batches):
         _X = sparse.random(
             n_samples_per_batch,
             n_features,
@@ -276,8 +274,9 @@ def make_batches_sparse(
     return X, y, w


-# Contains a dataset in numpy format as well as the relevant objective and metric
 class TestDataset:
+    """Contains a dataset in numpy format as well as the relevant objective and metric."""
+
     def __init__(
         self, name: str, get_dataset: Callable, objective: str, metric: str
     ) -> None:
@@ -289,18 +288,24 @@ class TestDataset:
         self.margin: Optional[np.ndarray] = None

     def set_params(self, params_in: Dict[str, Any]) -> Dict[str, Any]:
-        params_in['objective'] = self.objective
-        params_in['eval_metric'] = self.metric
+        params_in["objective"] = self.objective
+        params_in["eval_metric"] = self.metric
         if self.objective == "multi:softmax":
             params_in["num_class"] = int(np.max(self.y) + 1)
         return params_in

     def get_dmat(self) -> xgb.DMatrix:
         return xgb.DMatrix(
-            self.X, self.y, self.w, base_margin=self.margin, enable_categorical=True
+            self.X,
+            self.y,
+            weight=self.w,
+            base_margin=self.margin,
+            enable_categorical=True,
         )

     def get_device_dmat(self) -> xgb.DeviceQuantileDMatrix:
+        import cupy as cp
+
         w = None if self.w is None else cp.array(self.w)
         X = cp.array(self.X, dtype=np.float32)
         y = cp.array(self.y, dtype=np.float32)
@@ -318,9 +323,9 @@ class TestDataset:
             beg = i * per_batch
             end = min((i + 1) * per_batch, n_samples)
             assert end != beg
-            X = self.X[beg: end, ...]
-            y = self.y[beg: end]
-            w = self.w[beg: end] if self.w is not None else None
+            X = self.X[beg:end, ...]
+            y = self.y[beg:end]
+            w = self.w[beg:end] if self.w is not None else None
             predictor.append(X)
             response.append(y)
             if w is not None:
@@ -334,25 +339,24 @@ class TestDataset:


 @memory.cache
-def get_california_housing():
+def get_california_housing() -> Tuple[np.ndarray, np.ndarray]:
     data = datasets.fetch_california_housing()
     return data.data, data.target


 @memory.cache
-def get_digits():
+def get_digits() -> Tuple[np.ndarray, np.ndarray]:
     data = datasets.load_digits()
     return data.data, data.target


 @memory.cache
-def get_cancer():
-    data = datasets.load_breast_cancer()
-    return data.data, data.target
+def get_cancer() -> Tuple[np.ndarray, np.ndarray]:
+    return datasets.load_breast_cancer(return_X_y=True)


 @memory.cache
-def get_sparse():
+def get_sparse() -> Tuple[np.ndarray, np.ndarray]:
     rng = np.random.RandomState(199)
     n = 2000
     sparsity = 0.75
@@ -366,7 +370,7 @@ def get_sparse():


 @memory.cache
-def get_ames_housing():
+def get_ames_housing() -> Tuple[np.ndarray, np.ndarray]:
     """
     Number of samples: 1460
     Number of features: 20
@@ -374,22 +378,23 @@ def get_ames_housing():
     Number of numerical features: 10
     """
     from sklearn.datasets import fetch_openml
+
     X, y = fetch_openml(data_id=42165, as_frame=True, return_X_y=True)

-    categorical_columns_subset: list[str] = [
+    categorical_columns_subset: List[str] = [
         "BldgType",  # 5 cats, no nan
         "GarageFinish",  # 3 cats, nan
         "LotConfig",  # 5 cats, no nan
         "Functional",  # 7 cats, no nan
         "MasVnrType",  # 4 cats, nan
         "HouseStyle",  # 8 cats, no nan
         "FireplaceQu",  # 5 cats, nan
         "ExterCond",  # 5 cats, no nan
         "ExterQual",  # 4 cats, no nan
         "PoolQC",  # 3 cats, nan
     ]

-    numerical_columns_subset: list[str] = [
+    numerical_columns_subset: List[str] = [
         "3SsnPorch",
         "Fireplaces",
         "BsmtHalfBath",
@@ -408,32 +413,70 @@ def get_ames_housing():


 @memory.cache
-def get_mq2008(dpath):
+def get_mq2008(
+    dpath: str,
+) -> Tuple[
+    sparse.csr_matrix,
+    np.ndarray,
+    np.ndarray,
+    sparse.csr_matrix,
+    np.ndarray,
+    np.ndarray,
+    sparse.csr_matrix,
+    np.ndarray,
+    np.ndarray,
+]:
     from sklearn.datasets import load_svmlight_files

-    src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip'
-    target = dpath + '/MQ2008.zip'
+    src = "https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip"
+    target = dpath + "/MQ2008.zip"
     if not os.path.exists(target):
         urllib.request.urlretrieve(url=src, filename=target)

-    with zipfile.ZipFile(target, 'r') as f:
+    with zipfile.ZipFile(target, "r") as f:
         f.extractall(path=dpath)

-    (x_train, y_train, qid_train, x_test, y_test, qid_test,
-     x_valid, y_valid, qid_valid) = load_svmlight_files(
-         (dpath + "MQ2008/Fold1/train.txt",
-          dpath + "MQ2008/Fold1/test.txt",
-          dpath + "MQ2008/Fold1/vali.txt"),
-         query_id=True, zero_based=False)
+    (
+        x_train,
+        y_train,
+        qid_train,
+        x_test,
+        y_test,
+        qid_test,
+        x_valid,
+        y_valid,
+        qid_valid,
+    ) = load_svmlight_files(
+        (
+            dpath + "MQ2008/Fold1/train.txt",
+            dpath + "MQ2008/Fold1/test.txt",
+            dpath + "MQ2008/Fold1/vali.txt",
+        ),
+        query_id=True,
+        zero_based=False,
+    )

-    return (x_train, y_train, qid_train, x_test, y_test, qid_test,
-            x_valid, y_valid, qid_valid)
+    return (
+        x_train,
+        y_train,
+        qid_train,
+        x_test,
+        y_test,
+        qid_test,
+        x_valid,
+        y_valid,
+        qid_valid,
+    )


 @memory.cache
 def make_categorical(
-    n_samples: int, n_features: int, n_categories: int, onehot: bool, sparsity=0.0,
-):
+    n_samples: int,
+    n_features: int,
+    n_categories: int,
+    onehot: bool,
+    sparsity: float = 0.0,
+) -> Tuple[ArrayLike, np.ndarray]:
     import pandas as pd

     rng = np.random.RandomState(1994)
@@ -457,7 +500,9 @@ def make_categorical(
     if sparsity > 0.0:
         for i in range(n_features):
-            index = rng.randint(low=0, high=n_samples-1, size=int(n_samples * sparsity))
+            index = rng.randint(
+                low=0, high=n_samples - 1, size=int(n_samples * sparsity)
+            )
             df.iloc[index, i] = np.NaN
             assert n_categories == np.unique(df.dtypes[i].categories).size
@@ -466,9 +511,9 @@ def make_categorical(
     return df, label


-def _cat_sampled_from():
+def _cat_sampled_from() -> strategies.SearchStrategy:
     @strategies.composite
-    def _make_cat(draw):
+    def _make_cat(draw: Callable) -> Tuple[int, int, int, float]:
         n_samples = draw(strategies.integers(2, 512))
         n_features = draw(strategies.integers(1, 4))
         n_cats = draw(strategies.integers(1, 128))
@@ -483,7 +528,7 @@ def _cat_sampled_from():
         )
         return n_samples, n_features, n_cats, sparsity

-    def _build(args):
+    def _build(args: Tuple[int, int, int, float]) -> TestDataset:
         n_samples = args[0]
         n_features = args[1]
         n_cats = args[2]
@@ -495,12 +540,13 @@ def _cat_sampled_from():
             "rmse",
         )

-    return _make_cat().map(_build)
+    return _make_cat().map(_build)  # pylint: disable=no-member


-categorical_dataset_strategy = _cat_sampled_from()
+categorical_dataset_strategy: strategies.SearchStrategy = _cat_sampled_from()


+# pylint: disable=too-many-locals
 @memory.cache
 def make_sparse_regression(
     n_samples: int, n_features: int, sparsity: float, as_dense: bool
@@ -530,8 +576,7 @@ def make_sparse_regression(
     # Use multi-thread to speed up the generation, convenient if you use this function
     # for benchmarking.
-    n_threads = multiprocessing.cpu_count()
-    n_threads = min(n_threads, n_features)
+    n_threads = min(multiprocessing.cpu_count(), n_features)

     def random_csc(t_id: int) -> sparse.csc_matrix:
         rng = np.random.default_rng(1994 * t_id)
@@ -653,7 +698,7 @@ _unweighted_datasets_strategy = strategies.sampled_from(


 @strategies.composite
-def _dataset_weight_margin(draw):
+def _dataset_weight_margin(draw: Callable) -> TestDataset:
     data: TestDataset = draw(_unweighted_datasets_strategy)
     if draw(strategies.booleans()):
         data.w = draw(
@@ -673,6 +718,7 @@ def _dataset_weight_margin(draw):
                 elements=strategies.floats(0.5, 1.0),
             )
         )
+        assert data.margin is not None
         if num_class != 1:
             data.margin = data.margin.reshape(data.y.shape[0], num_class)
@@ -684,24 +730,24 @@ def _dataset_weight_margin(draw):
 dataset_strategy = _dataset_weight_margin()


-def non_increasing(L, tolerance=1e-4):
+def non_increasing(L: Sequence[float], tolerance: float = 1e-4) -> bool:
     return all((y - x) < tolerance for x, y in zip(L, L[1:]))


-def eval_error_metric(predt, dtrain: xgb.DMatrix):
+def eval_error_metric(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, np.float64]:
     """Evaluation metric for xgb.train"""
     label = dtrain.get_label()
     r = np.zeros(predt.shape)
     gt = predt > 0.5
     if predt.size == 0:
-        return "CustomErr", 0
+        return "CustomErr", np.float64(0.0)
     r[gt] = 1 - label[gt]
     le = predt <= 0.5
     r[le] = label[le]
-    return 'CustomErr', np.sum(r)
+    return "CustomErr", np.sum(r)


-def eval_error_metric_skl(y_true: np.ndarray, y_score: np.ndarray) -> float:
+def eval_error_metric_skl(y_true: np.ndarray, y_score: np.ndarray) -> np.float64:
     """Evaluation metric that looks like metrics provided by sklearn."""
     r = np.zeros(y_score.shape)
     gt = y_score > 0.5
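`eval_error_metric` is the `xgb.train` flavour of a custom metric: raw predictions plus the `DMatrix` in, a `(name, value)` pair out. A usage sketch with random data (shapes and parameters are illustrative):

```python
import numpy as np
import xgboost as xgb
from xgboost import testing as tm

rng = np.random.RandomState(1994)
dtrain = xgb.DMatrix(rng.randn(128, 4), rng.randint(0, 2, size=128))
booster = xgb.train(
    {"objective": "binary:logistic"},
    dtrain,
    num_boost_round=4,
    evals=[(dtrain, "train")],
    custom_metric=tm.eval_error_metric,
)
```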
@@ -717,13 +763,15 @@ def root_mean_square(y_true: np.ndarray, y_score: np.ndarray) -> float:
     return rmse


-def softmax(x):
+def softmax(x: np.ndarray) -> np.ndarray:
     e = np.exp(x)
     return e / np.sum(e)


-def softprob_obj(classes):
-    def objective(labels, predt):
+def softprob_obj(classes: int) -> SklObjective:
+    def objective(
+        labels: np.ndarray, predt: np.ndarray
+    ) -> Tuple[np.ndarray, np.ndarray]:
         rows = labels.shape[0]
         grad = np.zeros((rows, classes), dtype=float)
         hess = np.zeros((rows, classes), dtype=float)
@@ -746,29 +794,33 @@ def softprob_obj(classes):


 class DirectoryExcursion:
-    def __init__(self, path: os.PathLike, cleanup=False):
-        '''Change directory. Change back and optionally cleaning up the directory when
-        exit.
-        '''
+    """Change directory. Change back and optionally cleaning up the directory when
+    exit.
+    """
+
+    def __init__(self, path: os.PathLike, cleanup: bool = False):
         self.path = path
         self.curdir = os.path.normpath(os.path.abspath(os.path.curdir))
         self.cleanup = cleanup
-        self.files = {}
+        self.files: Set[str] = set()

-    def __enter__(self):
+    def __enter__(self) -> None:
         os.chdir(self.path)
         if self.cleanup:
             self.files = {
                 os.path.join(root, f)
-                for root, subdir, files in os.walk(self.path) for f in files
+                for root, subdir, files in os.walk(os.path.expanduser(self.path))
+                for f in files
             }

-    def __exit__(self, *args):
+    def __exit__(self, *args: Any) -> None:
         os.chdir(self.curdir)
         if self.cleanup:
             files = {
                 os.path.join(root, f)
-                for root, subdir, files in os.walk(self.path) for f in files
+                for root, subdir, files in os.walk(os.path.expanduser(self.path))
+                for f in files
             }
             diff = files.difference(self.files)
             for f in diff:
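For reference, `DirectoryExcursion` is used as a context manager: it changes into `path` on entry, restores the previous working directory on exit, and with `cleanup=True` removes any file created under `path` while the block ran. A sketch (paths illustrative):

```python
import tempfile

from xgboost import testing as tm

with tempfile.TemporaryDirectory() as tmpdir:
    with tm.DirectoryExcursion(tmpdir, cleanup=True):
        # Relative paths now resolve under tmpdir; the file is removed
        # again on exit because cleanup=True.
        with open("model.json", "w") as fd:
            fd.write("{}")
```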
@@ -776,7 +828,7 @@ class DirectoryExcursion:


 @contextmanager
-def captured_output():
+def captured_output() -> Generator[Tuple[StringIO, StringIO], None, None]:
     """Reassign stdout temporarily in order to test printed statements
     Taken from:
     https://stackoverflow.com/questions/4219717/how-to-assert-output-with-nosetest-unittest-in-python
@@ -793,14 +845,46 @@ def captured_output():
     sys.stdout, sys.stderr = old_out, old_err


-try:
-    # Python 3.7+
-    from contextlib import nullcontext as noop_context
-except ImportError:
-    # Python 3.6
-    from contextlib import suppress as noop_context
+def timeout(sec: int, *args: Any, enable: bool = True, **kwargs: Any) -> Any:
+    """Make a pytest mark for the `pytest-timeout` package.
+
+    Parameters
+    ----------
+    sec :
+        Timeout seconds.
+
+    enable :
+        Control whether timeout should be applied, used for debugging.
+
+    Returns
+    -------
+    pytest.mark.timeout
+    """
+    if enable:
+        return pytest.mark.timeout(sec, *args, **kwargs)
+    return pytest.mark.timeout(None, *args, **kwargs)


-CURDIR = os.path.normpath(os.path.abspath(os.path.dirname(__file__)))
-PROJECT_ROOT = os.path.normpath(
-    os.path.join(CURDIR, os.path.pardir, os.path.pardir))
+def demo_dir(path: str) -> str:
+    """Look for the demo directory based on the test file name."""
+    path = normpath(os.path.dirname(path))
+    while True:
+        subdirs = [f.path for f in os.scandir(path) if f.is_dir()]
+        subdirs = [os.path.basename(d) for d in subdirs]
+        if "demo" in subdirs:
+            return os.path.join(path, "demo")
+        new_path = normpath(os.path.join(path, os.path.pardir))
+        assert new_path != path
+        path = new_path
+
+
+def normpath(path: str) -> str:
+    return os.path.normpath(os.path.abspath(path))
+
+
+def data_dir(path: str) -> str:
+    return os.path.join(demo_dir(path), "data")
+
+
+def project_root(path: str) -> str:
+    return normpath(os.path.join(demo_dir(path), os.path.pardir))
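These helpers replace the old module-relative `CURDIR`/`PROJECT_ROOT` constants: now that the module lives inside the installed package, paths have to be anchored at the calling test file instead. `demo_dir` walks upward from that file until it meets a directory containing `demo`, and the other helpers derive from it. A usage sketch (run from a test inside the repository checkout):

```python
from xgboost import testing as tm

# __file__-relative lookup walks up to the checkout that holds demo/.
dpath = tm.data_dir(__file__)     # <checkout>/demo/data
root = tm.project_root(__file__)  # <checkout>
```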

View File

@@ -121,12 +121,14 @@ if __name__ == "__main__":
             "python-package/xgboost/sklearn.py",
             "python-package/xgboost/spark",
             "python-package/xgboost/federated.py",
-            "python-package/xgboost/testing.py",
+            "python-package/xgboost/testing",
             # tests
             "tests/python/test_config.py",
+            "tests/python/test_data_iterator.py",
             "tests/python/test_spark/",
             "tests/python/test_quantile_dmatrix.py",
             "tests/python-gpu/test_gpu_spark/",
+            "tests/python-gpu/test_gpu_data_iterator.py",
             "tests/ci_build/lint_python.py",
             # demo
             "demo/guide-python/cat_in_the_dat.py",

View File

@@ -1,9 +1,7 @@
-import sys
 import pytest
-import logging

-sys.path.append("tests/python")
-import testing as tm  # noqa
+from xgboost import testing as tm  # noqa


 def has_rmm():
     try:
@@ -34,8 +32,8 @@ def local_cuda_client(request, pytestconfig):
         kwargs['rmm_pool_size'] = '2GB'
     if tm.no_dask_cuda()['condition']:
         raise ImportError('The local_cuda_cluster fixture requires dask_cuda package')
-    from dask_cuda import LocalCUDACluster
     from dask.distributed import Client
+    from dask_cuda import LocalCUDACluster
     yield Client(LocalCUDACluster(**kwargs))


 def pytest_addoption(parser):

View File

@@ -1,16 +1,14 @@
 '''Loading a pickled model generated by test_pickling.py, only used by
 `test_gpu_with_dask.py`'''
-import os
-import numpy as np
-import xgboost as xgb
 import json
+import os
+
+import numpy as np
 import pytest
-import sys
+from test_gpu_pickling import build_dataset, load_pickle, model_path

-from test_gpu_pickling import build_dataset, model_path, load_pickle
+import xgboost as xgb
+from xgboost import testing as tm

-sys.path.append("tests/python")
-import testing as tm

 class TestLoadPickle:

View File

@@ -5,10 +5,10 @@ import pytest
 from hypothesis import given, settings, strategies

 import xgboost as xgb
+from xgboost import testing as tm

 sys.path.append("tests/python")

 import test_quantile_dmatrix as tqd
-import testing as tm


 class TestDeviceQuantileDMatrix:

View File

@@ -2,11 +2,12 @@ import json
 import sys

 import numpy as np
-import xgboost as xgb
 import pytest

+import xgboost as xgb
+from xgboost import testing as tm
+
 sys.path.append("tests/python")
-import testing as tm
 from test_dmatrix import set_base_margin_info
@@ -85,8 +86,8 @@ def _test_from_cudf(DMatrixT):


 def _test_cudf_training(DMatrixT):
-    from cudf import DataFrame as df
     import pandas as pd
+    from cudf import DataFrame as df

     np.random.seed(1)
     X = pd.DataFrame(np.random.randn(50, 10))
     y = pd.DataFrame(np.random.randn(50))
@@ -109,8 +110,8 @@ def _test_cudf_training(DMatrixT):


 def _test_cudf_metainfo(DMatrixT):
-    from cudf import DataFrame as df
     import pandas as pd
+    from cudf import DataFrame as df

     n = 100
     X = np.random.random((n, 2))
     dmat_cudf = DMatrixT(df.from_pandas(pd.DataFrame(X)))
@@ -247,9 +248,9 @@ Arrow specification.'''
 @pytest.mark.skipif(**tm.no_sklearn())
 @pytest.mark.skipif(**tm.no_pandas())
 def test_cudf_training_with_sklearn():
+    import pandas as pd
     from cudf import DataFrame as df
     from cudf import Series as ss
-    import pandas as pd

     np.random.seed(1)
     X = pd.DataFrame(np.random.randn(50, 10))
     y = pd.DataFrame((np.random.randn(50) > 0).astype(np.int8))

View File

@@ -1,12 +1,15 @@
-import numpy as np
-import xgboost as xgb
 import sys
+
+import numpy as np
 import pytest

+import xgboost as xgb
+
 sys.path.append("tests/python")
-import testing as tm
 from test_dmatrix import set_base_margin_info

+from xgboost import testing as tm
+

 def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
     '''Test constructing DMatrix from cupy'''

View File

@@ -1,13 +1,18 @@
-import sys
 import os
+import sys
+
 import numpy as np
-import xgboost as xgb
 import pytest

+import xgboost as xgb
+from xgboost import testing as tm
+
 sys.path.append("tests/python")
+import test_basic_models as test_bm
+
 # Don't import the test class, otherwise they will run twice.
 import test_callback as test_cb  # noqa
-import test_basic_models as test_bm
-import testing as tm

 rng = np.random.RandomState(1994)

View File

@@ -1,13 +1,12 @@
-import numpy as np
-import xgboost as xgb
-from hypothesis import given, strategies, settings
-import pytest
 import sys

+import pytest
+from hypothesis import given, settings, strategies
+from xgboost.testing import no_cupy
+
 sys.path.append("tests/python")
-from test_data_iterator import test_single_batch as cpu_single_batch
 from test_data_iterator import run_data_iterator
-from testing import no_cupy
+from test_data_iterator import test_single_batch as cpu_single_batch


 def test_gpu_single_batch() -> None:
@@ -24,7 +23,11 @@ def test_gpu_single_batch() -> None:
 )
 @settings(deadline=None, max_examples=10, print_blob=True)
 def test_gpu_data_iterator(
-    n_samples_per_batch: int, n_features: int, n_batches: int, subsample: bool, use_cupy: bool
+    n_samples_per_batch: int,
+    n_features: int,
+    n_batches: int,
+    subsample: bool,
+    use_cupy: bool,
 ) -> None:
     run_data_iterator(
         n_samples_per_batch, n_features, n_batches, "gpu_hist", subsample, use_cupy

View File

@@ -1,10 +1,13 @@
 import os
 import subprocess
 import sys
+
 import pytest

+from xgboost import testing as tm
+
 sys.path.append("tests/python")
-import testing as tm
-import test_demos as td  # noqa
+import test_demos as td  # noqa


 @pytest.mark.skipif(**tm.no_cupy())
@@ -31,6 +34,6 @@ def test_categorical_demo():
 @pytest.mark.skipif(**tm.no_cupy())
 @pytest.mark.mgpu
 def test_dask_training():
-    script = os.path.join(tm.PROJECT_ROOT, 'demo', 'dask', 'gpu_training.py')
+    script = os.path.join(tm.demo_dir(__file__), 'dask', 'gpu_training.py')
     cmd = ['python', script]
     subprocess.check_call(cmd)

View File

@@ -1,7 +1,9 @@
 import sys
-import xgboost
+
 import pytest

+import xgboost
+
 sys.path.append("tests/python")
 import test_eval_metrics as test_em  # noqa

View File

@@ -1,8 +1,11 @@
-import numpy as np
 import sys

+import numpy as np
+
 sys.path.append("tests/python")
+
 # Don't import the test class, otherwise they will run twice.
 import test_interaction_constraints as test_ic  # noqa

 rng = np.random.RandomState(1994)

View File

@@ -1,15 +1,10 @@
-import sys
 import pytest
 from hypothesis import assume, given, note, settings, strategies

 import xgboost as xgb
-from xgboost import testing
+from xgboost import testing as tm

-sys.path.append("tests/python")
-import testing as tm
-
-pytestmark = testing.timeout(10)
+pytestmark = tm.timeout(10)

 parameter_strategy = strategies.fixed_dictionaries({
     'booster': strategies.just('gblinear'),

View File

@@ -3,20 +3,17 @@ import json
 import os
 import pickle
 import subprocess
-import sys

 import numpy as np
 import pytest

 import xgboost as xgb
-from xgboost import XGBClassifier, testing
+from xgboost import XGBClassifier
+from xgboost import testing as tm

-sys.path.append("tests/python")
-import testing as tm
-
 model_path = './model.pkl'

-pytestmark = testing.timeout(30)
+pytestmark = tm.timeout(30)


 def build_dataset():

View File

@@ -1,10 +1,11 @@
 import sys
+
 import pytest

-sys.path.append("tests/python")
-import testing as tm
-import test_plotting as tp
+from xgboost import testing as tm
+
+sys.path.append("tests/python")
+import test_plotting as tp

 pytestmark = pytest.mark.skipif(**tm.no_multiple(tm.no_matplotlib(), tm.no_graphviz()))

View File

@@ -6,7 +6,7 @@ from hypothesis import assume, given, settings, strategies
 from xgboost.compat import PANDAS_INSTALLED

 import xgboost as xgb
-from xgboost import testing
+from xgboost import testing as tm

 if PANDAS_INSTALLED:
     from hypothesis.extra.pandas import column, data_frames, range_indexes
@@ -16,7 +16,6 @@ else:
     column, data_frames, range_indexes = noop, noop, noop

 sys.path.append("tests/python")
-import testing as tm
 from test_predict import run_predict_leaf  # noqa
 from test_predict import run_threaded_predict  # noqa
@@ -33,7 +32,7 @@ predict_parameter_strategy = strategies.fixed_dictionaries({
     'num_parallel_tree': strategies.sampled_from([1, 4]),
 })

-pytestmark = testing.timeout(20)
+pytestmark = tm.timeout(20)


 class TestGPUPredict:
@@ -227,8 +226,8 @@ class TestGPUPredict:
     @pytest.mark.skipif(**tm.no_cupy())
     @pytest.mark.skipif(**tm.no_cudf())
     def test_inplace_predict_cudf(self):
-        import cupy as cp
         import cudf
+        import cupy as cp
         import pandas as pd

         rows = 1000
         cols = 10
@@ -379,8 +378,8 @@ class TestGPUPredict:
     @pytest.mark.skipif(**tm.no_cupy())
     @pytest.mark.parametrize("n_classes", [2, 3])
     def test_predict_dart(self, n_classes):
-        from sklearn.datasets import make_classification
         import cupy as cp
+        from sklearn.datasets import make_classification

         n_samples = 1000
         X_, y_ = make_classification(
             n_samples=n_samples, n_informative=5, n_classes=n_classes

View File

@@ -1,20 +1,15 @@
 import itertools
 import os
 import shutil
-import sys
 import urllib.request
 import zipfile

 import numpy as np
+
 import xgboost
-from xgboost import testing
+from xgboost import testing as tm

-sys.path.append("tests/python")
-import testing as tm  # noqa
-
-pytestmark = testing.timeout(10)
+pytestmark = tm.timeout(10)


 class TestRanking:
@@ -24,8 +19,9 @@ class TestRanking:
         Download and setup the test fixtures
         """
         from sklearn.datasets import load_svmlight_files
+
         # download the test data
-        cls.dpath = os.path.join(tm.PROJECT_ROOT, "demo/rank/")
+        cls.dpath = os.path.join(tm.demo_dir(__file__), "rank/")
         src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip'
         target = os.path.join(cls.dpath, "MQ2008.zip")

View File

@@ -1,13 +1,8 @@
 import sys
-from typing import List

-import numpy as np
-import pandas as pd
 import pytest

-sys.path.append("tests/python")
-import testing as tm
+from xgboost import testing as tm

 if tm.no_spark()["condition"]:
     pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
@@ -15,6 +10,7 @@ if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
     pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True)

+sys.path.append("tests/python")
 from test_spark.test_data import run_dmatrix_ctor

View File

@@ -6,8 +6,7 @@ import sys
 import pytest
 import sklearn

-sys.path.append("tests/python")
-import testing as tm
+from xgboost import testing as tm

 if tm.no_spark()["condition"]:
     pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)

View File

@@ -1,7 +1,9 @@
-import numpy as np
-import xgboost as xgb
 import json

+import numpy as np
+
+import xgboost as xgb
+
 rng = np.random.RandomState(1994)

View File

@@ -6,13 +6,12 @@ import pytest
 from hypothesis import assume, given, note, settings, strategies

 import xgboost as xgb
-from xgboost import testing
+from xgboost import testing as tm

 sys.path.append("tests/python")
 import test_updaters as test_up
-import testing as tm

-pytestmark = testing.timeout(30)
+pytestmark = tm.timeout(30)

 parameter_strategy = strategies.fixed_dictionaries({
     'max_depth': strategies.integers(0, 11),

View File

@@ -1,52 +1,54 @@
 """Copyright 2019-2022 XGBoost contributors"""
-import sys
-import os
-from typing import Type, TypeVar, Any, Dict, List, Union
-import pytest
-import numpy as np
 import asyncio
-import xgboost
+import os
 import subprocess
+import sys
 from collections import OrderedDict
 from inspect import signature
-from hypothesis import given, strategies, settings, note
+from typing import Any, Dict, Type, TypeVar, Union
+
+import numpy as np
+import pytest
+from hypothesis import given, note, settings, strategies
 from hypothesis._settings import duration
 from test_gpu_updaters import parameter_strategy

+import xgboost
+from xgboost import testing as tm
+
 if sys.platform.startswith("win"):
     pytest.skip("Skipping dask tests on Windows", allow_module_level=True)

 sys.path.append("tests/python")
-import testing as tm  # noqa

 if tm.no_dask_cuda()["condition"]:
     pytest.skip(tm.no_dask_cuda()["reason"], allow_module_level=True)

-from test_with_dask import run_empty_dmatrix_reg  # noqa
-from test_with_dask import run_empty_dmatrix_auc  # noqa
+from test_with_dask import _get_client_workers  # noqa
+from test_with_dask import generate_array  # noqa
+from test_with_dask import make_categorical  # noqa
 from test_with_dask import run_auc  # noqa
 from test_with_dask import run_boost_from_prediction  # noqa
 from test_with_dask import run_boost_from_prediction_multi_class  # noqa
-from test_with_dask import run_dask_classifier  # noqa
-from test_with_dask import run_empty_dmatrix_cls  # noqa
-from test_with_dask import _get_client_workers  # noqa
-from test_with_dask import generate_array  # noqa
-from test_with_dask import kCols as random_cols  # noqa
-from test_with_dask import suppress  # noqa
-from test_with_dask import run_tree_stats  # noqa
 from test_with_dask import run_categorical  # noqa
-from test_with_dask import make_categorical  # noqa
+from test_with_dask import run_dask_classifier  # noqa
+from test_with_dask import run_empty_dmatrix_auc  # noqa
+from test_with_dask import run_empty_dmatrix_cls  # noqa
+from test_with_dask import run_empty_dmatrix_reg  # noqa
+from test_with_dask import run_tree_stats  # noqa
+from test_with_dask import suppress  # noqa
+from test_with_dask import kCols as random_cols  # noqa

 try:
-    import dask.dataframe as dd
-    from xgboost import dask as dxgb
-    import xgboost as xgb
-    from dask.distributed import Client
-    from dask import array as da
-    from dask_cuda import LocalCUDACluster, utils
     import cudf
+    import dask.dataframe as dd
+    from dask import array as da
+    from dask.distributed import Client
+    from dask_cuda import LocalCUDACluster, utils
+
+    import xgboost as xgb
+    from xgboost import dask as dxgb
 except ImportError:
     pass
@@ -334,9 +336,9 @@ class TestDistributedGPU:

     @pytest.mark.skipif(**tm.no_dask_cudf())
     def test_empty_partition(self, local_cuda_client: Client) -> None:
-        import dask_cudf
         import cudf
         import cupy
+        import dask_cudf

         mult = 100
         df = cudf.DataFrame(

View File

@@ -1,13 +1,15 @@
 import json
-import xgboost as xgb
-import pytest
-import tempfile
-import sys
-import numpy as np
 import os
+import sys
+import tempfile
+
+import numpy as np
+import pytest
+
+import xgboost as xgb
+from xgboost import testing as tm

 sys.path.append("tests/python")
-import testing as tm  # noqa
 import test_with_sklearn as twskl  # noqa

 pytestmark = pytest.mark.skipif(**tm.no_sklearn())
@@ -38,9 +40,9 @@ def test_gpu_binary_classification():
 @pytest.mark.skipif(**tm.no_cupy())
 @pytest.mark.skipif(**tm.no_cudf())
 def test_boost_from_prediction_gpu_hist():
-    from sklearn.datasets import load_breast_cancer, load_digits
-    import cupy as cp
     import cudf
+    import cupy as cp
+    from sklearn.datasets import load_breast_cancer, load_digits

     tree_method = "gpu_hist"
     X, y = load_breast_cancer(return_X_y=True)
@@ -68,12 +70,12 @@ def test_num_parallel_tree():
 @pytest.mark.skipif(**tm.no_cudf())
 @pytest.mark.skipif(**tm.no_sklearn())
 def test_categorical():
-    import pandas as pd
     import cudf
     import cupy as cp
+    import pandas as pd
     from sklearn.datasets import load_svmlight_file

-    data_dir = os.path.join(tm.PROJECT_ROOT, "demo", "data")
+    data_dir = tm.data_dir(__file__)
     X, y = load_svmlight_file(os.path.join(data_dir, "agaricus.txt.train"))
     clf = xgb.XGBClassifier(
         tree_method="gpu_hist",
@@ -123,9 +125,9 @@ def test_categorical():
 @pytest.mark.skipif(**tm.no_cupy())
 @pytest.mark.skipif(**tm.no_cudf())
 def test_classififer():
-    from sklearn.datasets import load_digits
-    import cupy as cp
     import cudf
+    import cupy as cp
+    from sklearn.datasets import load_digits

     X, y = load_digits(return_X_y=True)
     y *= 10

View File

@@ -1,23 +1,23 @@
-import numpy as np
-import xgboost as xgb
 import cupy as cp
-import time
+import numpy as np
 import pytest

+import xgboost as xgb
+

 # Test for integer overflow or out of memory exceptions
 def test_large_input():
     available_bytes, _ = cp.cuda.runtime.memGetInfo()
     # 15 GB
     required_bytes = 1.5e+10
     if available_bytes < required_bytes:
         pytest.skip("Not enough memory on this device")
     n = 1000
     m = ((1 << 31) + n - 1) // n
     assert (np.log2(m * n) > 31)
     X = cp.ones((m, n), dtype=np.float32)
     y = cp.ones(m)
     dmat = xgb.DeviceQuantileDMatrix(X, y)
     booster = xgb.train({"tree_method": "gpu_hist", "max_depth": 1}, dmat, 1)
     del y
     booster.inplace_predict(X)

View File

@@ -1,11 +1,12 @@
 import sys
+
 import numpy as np
 import pytest

 import xgboost as xgb
+from xgboost import testing as tm

 sys.path.append("tests/python")
-import testing as tm
 import test_monotone_constraints as tmc

 rng = np.random.RandomState(1994)

View File

@@ -1,7 +1,9 @@
-import xgboost
-import numpy as np
 import os

+import numpy as np
+
+import xgboost
+
 kRounds = 2
 kRows = 1000
 kCols = 4

View File

@@ -1,12 +1,13 @@
-# -*- coding: utf-8 -*-
-import numpy as np
-import os
-import xgboost as xgb
-import pytest
 import json
-from pathlib import Path
+import os
 import tempfile
-import testing as tm
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+import xgboost as xgb
+from xgboost import testing as tm

 dpath = 'demo/data/'
 rng = np.random.RandomState(1994)

View File

@@ -1,13 +1,15 @@
-import numpy as np
-import xgboost as xgb
-import os
 import json
-import testing as tm
-import pytest
 import locale
+import os
 import tempfile

-dpath = os.path.join(tm.PROJECT_ROOT, 'demo/data/')
+import numpy as np
+import pytest
+
+import xgboost as xgb
+from xgboost import testing as tm
+
+dpath = tm.data_dir(__file__)
 rng = np.random.RandomState(1994)
@@ -36,8 +38,8 @@ class TestModels:
         param = {'verbosity': 0, 'objective': 'binary:logistic',
                  'booster': 'gblinear', 'alpha': 0.0001, 'lambda': 1,
                  'nthread': 1}
-        dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
-        dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
+        dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
+        dtest = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.test"))
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]
         num_round = 4
         bst = xgb.train(param, dtrain, num_round, watchlist)
@@ -49,8 +51,8 @@ class TestModels:
         assert err < 0.2

     def test_dart(self):
-        dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
-        dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
+        dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
+        dtest = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.test"))
         param = {'max_depth': 5, 'objective': 'binary:logistic',
                  'eval_metric': 'logloss', 'booster': 'dart', 'verbosity': 1}
         # specify validations set to watch performance
@@ -116,7 +118,7 @@ class TestModels:

     def test_boost_from_prediction(self):
         # Re-construct dtrain here to avoid modification
-        margined = xgb.DMatrix(dpath + 'agaricus.txt.train')
+        margined = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
         bst = xgb.train({'tree_method': 'hist'}, margined, 1)
         predt_0 = bst.predict(margined, output_margin=True)
         margined.set_base_margin(predt_0)
@@ -124,13 +126,13 @@ class TestModels:
         predt_1 = bst.predict(margined)
         assert np.any(np.abs(predt_1 - predt_0) > 1e-6)

-        dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
+        dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
         bst = xgb.train({'tree_method': 'hist'}, dtrain, 2)
         predt_2 = bst.predict(dtrain)
         assert np.all(np.abs(predt_2 - predt_1) < 1e-6)

     def test_boost_from_existing_model(self):
-        X = xgb.DMatrix(dpath + 'agaricus.txt.train')
+        X = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
         booster = xgb.train({'tree_method': 'hist'}, X, num_boost_round=4)
         assert booster.num_boosted_rounds() == 4
         booster = xgb.train({'tree_method': 'hist'}, X, num_boost_round=4,
@@ -150,8 +152,8 @@ class TestModels:
             'objective': 'reg:logistic',
             "tree_method": tree_method
         }
-        dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
-        dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
+        dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
+        dtest = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.test"))
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]
         num_round = 10
@@ -197,8 +199,8 @@ class TestModels:
         self.run_custom_objective()

     def test_multi_eval_metric(self):
-        dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
-        dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
+        dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
+        dtest = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.test"))
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]
         param = {'max_depth': 2, 'eta': 0.2, 'verbosity': 1,
                  'objective': 'binary:logistic'}
@@ -220,7 +222,7 @@ class TestModels:
             param['scale_pos_weight'] = ratio
             return (dtrain, dtest, param)

-        dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
+        dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
         xgb.cv(param, dtrain, num_round, nfold=5,
                metrics={'auc'}, seed=0, fpreproc=fpreproc)
@@ -228,7 +230,7 @@ class TestModels:
         param = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
                  'objective': 'binary:logistic'}
         num_round = 2
-        dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
+        dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
         xgb.cv(param, dtrain, num_round, nfold=5,
                metrics={'error'}, seed=0, show_stdv=False)
@@ -346,7 +348,7 @@ class TestModels:
             os.remove(model_path)

         try:
-            dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
+            dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
             xgb.train({'objective': 'foo'}, dtrain, num_boost_round=1)
         except ValueError as e:
             e_str = str(e)

View File

@ -1,9 +1,12 @@
from typing import Union
import xgboost as xgb
import pytest
import os import os
import testing as tm
import tempfile import tempfile
from contextlib import nullcontext
from typing import Union
import pytest
import xgboost as xgb
from xgboost import testing as tm
# We use the dataset for tests. # We use the dataset for tests.
pytestmark = pytest.mark.skipif(**tm.no_sklearn()) pytestmark = pytest.mark.skipif(**tm.no_sklearn())
@ -271,13 +274,14 @@ class TestCallbacks:
"""Test learning rate scheduler, used by both CPU and GPU tests.""" """Test learning rate scheduler, used by both CPU and GPU tests."""
scheduler = xgb.callback.LearningRateScheduler scheduler = xgb.callback.LearningRateScheduler
dpath = os.path.join(tm.PROJECT_ROOT, 'demo/data/') dpath = tm.data_dir(__file__)
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train') dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train"))
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test') dtest = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.test"))
watchlist = [(dtest, 'eval'), (dtrain, 'train')] watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 4 num_round = 4
warning_check = tm.noop_context() warning_check = nullcontext()
# learning_rates as a list # learning_rates as a list
# init eta with 0 to check whether learning_rates work # init eta with 0 to check whether learning_rates work
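
Note: tm.noop_context() was a tiny do-nothing helper; the standard library already provides this as contextlib.nullcontext, hence the import added at the top of this file. A minimal sketch of the pattern the hunk above uses, where expect_warning is a hypothetical flag standing in for the test's actual condition:

from contextlib import nullcontext

import pytest

expect_warning = False  # hypothetical: set from the test's parameters
# Pair a real pytest.warns() check with a no-op context when no warning is
# expected, so the `with` block stays identical either way.
warning_check = pytest.warns(UserWarning) if expect_warning else nullcontext()
with warning_check:
    pass  # code under test goes here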

View File

@ -1,11 +1,13 @@
import os
import tempfile
import platform
import xgboost
import subprocess
import numpy
import json import json
import testing as tm import os
import platform
import subprocess
import tempfile
import numpy
import xgboost
from xgboost import testing as tm
class TestCLI: class TestCLI:
@ -29,7 +31,7 @@ data = {data_path}
eval[test] = {data_path} eval[test] = {data_path}
''' '''
PROJECT_ROOT = tm.PROJECT_ROOT PROJECT_ROOT = tm.project_root(__file__)
def get_exe(self): def get_exe(self):
if platform.system() == 'Windows': if platform.system() == 'Windows':
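
Note: the module-level tm.PROJECT_ROOT constant gives way to tm.project_root(__file__), which derives the repository root from the calling test file rather than from wherever the testing module is installed. The helper's implementation is not part of this diff; a hypothetical sketch under that assumption:

import os

def project_root(path: str) -> str:
    # Hypothetical sketch only: tests/python/<file>.py sits two levels below
    # the repository root, so walk up twice from the test file's location.
    return os.path.normpath(
        os.path.join(os.path.dirname(os.path.abspath(path)), os.pardir, os.pardir)
    )

# Usage mirrors the hunk above: PROJECT_ROOT = project_root(__file__)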

View File

@ -1,14 +1,16 @@
from typing import Dict, List
import numpy as np import numpy as np
import pytest import pytest
from hypothesis import given, settings, strategies from hypothesis import given, settings, strategies
from scipy.sparse import csr_matrix from scipy.sparse import csr_matrix
from testing import IteratorForTest, make_batches, non_increasing
from xgboost.data import SingleBatchInternalIter as SingleBatch from xgboost.data import SingleBatchInternalIter as SingleBatch
from xgboost.testing import IteratorForTest, make_batches, non_increasing
import xgboost as xgb import xgboost as xgb
from xgboost import testing from xgboost import testing as tm
pytestmark = testing.timeout(30) pytestmark = tm.timeout(30)
def test_single_batch(tree_method: str = "approx") -> None: def test_single_batch(tree_method: str = "approx") -> None:
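
Note: tm.timeout(30) replaces testing.timeout(30) purely through the import rename; assigned to pytestmark, the returned mark applies to every test in the module. A plausible sketch of the helper, assuming the pytest-timeout plugin is available:

import pytest

def timeout(seconds: int) -> pytest.MarkDecorator:
    # Assumed shape of the helper: a plain pytest-timeout mark giving each
    # test in the module a deadline of `seconds`.
    return pytest.mark.timeout(seconds)

pytestmark = timeout(30)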
@ -83,7 +85,7 @@ def run_data_iterator(
if tree_method == "gpu_hist": if tree_method == "gpu_hist":
parameters["sampling_method"] = "gradient_based" parameters["sampling_method"] = "gradient_based"
results_from_it: xgb.callback.EvaluationMonitor.EvalsLog = {} results_from_it: Dict[str, Dict[str, List[float]]] = {}
from_it = xgb.train( from_it = xgb.train(
parameters, parameters,
Xy, Xy,
@ -106,7 +108,7 @@ def run_data_iterator(
assert Xy.num_row() == n_samples_per_batch * n_batches assert Xy.num_row() == n_samples_per_batch * n_batches
assert Xy.num_col() == n_features assert Xy.num_col() == n_features
results_from_arrays: xgb.callback.EvaluationMonitor.EvalsLog = {} results_from_arrays: Dict[str, Dict[str, List[float]]] = {}
from_arrays = xgb.train( from_arrays = xgb.train(
parameters, parameters,
Xy, Xy,
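
Note: the annotation swaps the private alias xgb.callback.EvaluationMonitor.EvalsLog for the explicit Dict[str, Dict[str, List[float]]] imported above: evaluation-set name, then metric name, then one value per boosting round. A self-contained example of that shape (the toy data and parameters are illustrative only):

from typing import Dict, List

import numpy as np
import xgboost as xgb

rng = np.random.RandomState(0)
Xy = xgb.DMatrix(rng.rand(64, 4), label=rng.randint(0, 2, size=64))

# evals_result is filled in-place: set name -> metric -> per-round values.
results: Dict[str, Dict[str, List[float]]] = {}
xgb.train(
    {"objective": "binary:logistic", "eval_metric": "logloss"},
    Xy,
    num_boost_round=2,
    evals=[(Xy, "Train")],
    evals_result=results,
)
print(results["Train"]["logloss"])  # two entries, one per boosting round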

View File

@ -3,14 +3,12 @@ import subprocess
import sys import sys
import pytest import pytest
import testing as tm
from xgboost import testing from xgboost import testing as tm
pytestmark = testing.timeout(30) pytestmark = tm.timeout(30)
ROOT_DIR = tm.PROJECT_ROOT DEMO_DIR = tm.demo_dir(__file__)
DEMO_DIR = os.path.join(ROOT_DIR, 'demo')
PYTHON_DEMO_DIR = os.path.join(DEMO_DIR, 'guide-python') PYTHON_DEMO_DIR = os.path.join(DEMO_DIR, 'guide-python')
CLI_DEMO_DIR = os.path.join(DEMO_DIR, 'CLI') CLI_DEMO_DIR = os.path.join(DEMO_DIR, 'CLI')
@ -156,7 +154,7 @@ def test_cli_regression_demo():
cmd = ['python', script, 'machine.txt', '1'] cmd = ['python', script, 'machine.txt', '1']
subprocess.check_call(cmd, cwd=reg_dir) subprocess.check_call(cmd, cwd=reg_dir)
exe = os.path.join(tm.PROJECT_ROOT, 'xgboost') exe = os.path.join(DEMO_DIR, os.path.pardir, 'xgboost')
conf = os.path.join(reg_dir, 'machine.conf') conf = os.path.join(reg_dir, 'machine.conf')
subprocess.check_call([exe, conf], cwd=reg_dir) subprocess.check_call([exe, conf], cwd=reg_dir)
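
Note: DEMO_DIR now comes from tm.demo_dir(__file__), and the CLI binary is resolved relative to it (one directory above demo/) instead of through PROJECT_ROOT. How demo_dir locates the directory is not shown in this diff; a hypothetical sketch:

import os

def demo_dir(path: str) -> str:
    # Hypothetical sketch: walk upwards from the test file until a directory
    # containing demo/ is found, then return that demo/ path.
    d = os.path.dirname(os.path.abspath(path))
    while not os.path.isdir(os.path.join(d, "demo")):
        parent = os.path.dirname(d)
        assert parent != d, "reached the filesystem root without finding demo/"
        d = parent
    return os.path.join(d, "demo")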

View File

@ -4,11 +4,11 @@ import tempfile
import numpy as np import numpy as np
import pytest import pytest
import scipy.sparse import scipy.sparse
import testing as tm
from hypothesis import given, settings, strategies from hypothesis import given, settings, strategies
from scipy.sparse import csr_matrix, rand from scipy.sparse import csr_matrix, rand
import xgboost as xgb import xgboost as xgb
from xgboost import testing as tm
rng = np.random.RandomState(1) rng = np.random.RandomState(1)

View File

@ -1,9 +1,8 @@
# -*- coding: utf-8 -*-
import pytest
import numpy as np import numpy as np
import pytest
import testing as tm
import xgboost as xgb import xgboost as xgb
from xgboost import testing as tm
try: try:
import datatable as dt import datatable as dt

View File

@ -1,8 +1,9 @@
import xgboost as xgb
import testing as tm
import numpy as np import numpy as np
import pytest import pytest
import xgboost as xgb
from xgboost import testing as tm
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)

View File

@ -1,8 +1,9 @@
import xgboost as xgb
import testing as tm
import numpy as np import numpy as np
import pytest import pytest
import xgboost as xgb
from xgboost import testing as tm
rng = np.random.RandomState(1337) rng = np.random.RandomState(1337)
@ -254,8 +255,8 @@ class TestEvalMetrics:
self.run_roc_auc_multi("hist", n_samples, weighted) self.run_roc_auc_multi("hist", n_samples, weighted)
def run_pr_auc_binary(self, tree_method): def run_pr_auc_binary(self, tree_method):
from sklearn.metrics import precision_recall_curve, auc
from sklearn.datasets import make_classification from sklearn.datasets import make_classification
from sklearn.metrics import auc, precision_recall_curve
X, y = make_classification(128, 4, n_classes=2, random_state=1994) X, y = make_classification(128, 4, n_classes=2, random_state=1994)
clf = xgb.XGBClassifier(tree_method=tree_method, n_estimators=1) clf = xgb.XGBClassifier(tree_method=tree_method, n_estimators=1)
clf.fit(X, y, eval_metric="aucpr", eval_set=[(X, y)]) clf.fit(X, y, eval_metric="aucpr", eval_set=[(X, y)])

View File

@ -1,9 +1,9 @@
# -*- coding: utf-8 -*-
import numpy as np import numpy as np
import xgboost
import testing as tm
import pytest import pytest
import xgboost
from xgboost import testing as tm
dpath = 'demo/data/' dpath = 'demo/data/'
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)

View File

@ -1,10 +1,9 @@
import testing as tm
from hypothesis import given, note, settings, strategies from hypothesis import given, note, settings, strategies
import xgboost as xgb import xgboost as xgb
from xgboost import testing from xgboost import testing as tm
pytestmark = testing.timeout(10) pytestmark = tm.timeout(10)
parameter_strategy = strategies.fixed_dictionaries({ parameter_strategy = strategies.fixed_dictionaries({

View File

@ -1,12 +1,14 @@
import xgboost
import os
import generate_models as gm
import testing as tm
import json
import zipfile
import pytest
import copy import copy
import json
import os
import urllib.request import urllib.request
import zipfile
import generate_models as gm
import pytest
import xgboost
from xgboost import testing as tm
def run_model_param_check(config): def run_model_param_check(config):

View File

@ -1,8 +1,9 @@
import numpy as np import numpy as np
import xgboost as xgb
import testing as tm
import pytest import pytest
import xgboost as xgb
from xgboost import testing as tm
dpath = 'demo/data/' dpath = 'demo/data/'

View File

@ -4,12 +4,11 @@ import tempfile
import numpy as np import numpy as np
import pytest import pytest
import testing as tm
import xgboost as xgb import xgboost as xgb
from xgboost import testing from xgboost import testing as tm
pytestmark = testing.timeout(10) pytestmark = tm.timeout(10)
class TestOMP: class TestOMP:
@ -86,7 +85,7 @@ class TestOMP:
def test_with_omp_thread_limit(self): def test_with_omp_thread_limit(self):
args = [ args = [
"python", os.path.join( "python", os.path.join(
tm.PROJECT_ROOT, "tests", "python", "with_omp_limit.py" os.path.dirname(tm.normpath(__file__)), "with_omp_limit.py"
) )
] ]
results = [] results = []

View File

@ -1,8 +1,8 @@
import xgboost as xgb
import numpy as np import numpy as np
import pytest import pytest
import testing as tm
import xgboost as xgb
from xgboost import testing as tm
pytestmark = pytest.mark.skipif(**tm.no_pandas()) pytestmark = pytest.mark.skipif(**tm.no_pandas())

View File

@ -1,9 +1,10 @@
import pickle
import numpy as np
import xgboost as xgb
import os
import json import json
import os
import pickle
import numpy as np
import xgboost as xgb
kRows = 100 kRows = 100
kCols = 10 kCols = 10

View File

@ -1,15 +1,16 @@
import json import json
import numpy as np
import xgboost as xgb
import testing as tm
import numpy as np
import pytest import pytest
import xgboost as xgb
from xgboost import testing as tm
try: try:
import matplotlib import matplotlib
matplotlib.use('Agg') matplotlib.use('Agg')
from matplotlib.axes import Axes
from graphviz import Source from graphviz import Source
from matplotlib.axes import Axes
except ImportError: except ImportError:
pass pass

View File

@ -1,12 +1,13 @@
'''Tests for running inplace prediction.''' '''Tests for running inplace prediction.'''
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import numpy as np
from scipy import sparse
import pytest
import pandas as pd
import testing as tm import numpy as np
import pandas as pd
import pytest
from scipy import sparse
import xgboost as xgb import xgboost as xgb
from xgboost import testing as tm
def run_threaded_predict(X, rows, predict_func): def run_threaded_predict(X, rows, predict_func):

View File

@ -4,7 +4,7 @@ import numpy as np
import pytest import pytest
from hypothesis import given, settings, strategies from hypothesis import given, settings, strategies
from scipy import sparse from scipy import sparse
from testing import ( from xgboost.testing import (
IteratorForTest, IteratorForTest,
make_batches, make_batches,
make_batches_sparse, make_batches_sparse,

View File

@ -1,13 +1,15 @@
import numpy as np
from scipy.sparse import csr_matrix
import testing as tm
import xgboost
import os
import itertools import itertools
import os
import shutil import shutil
import urllib.request import urllib.request
import zipfile import zipfile
import numpy as np
from scipy.sparse import csr_matrix
import xgboost
from xgboost import testing as tm
def test_ranking_with_unweighted_data(): def test_ranking_with_unweighted_data():
Xrow = np.array([1, 2, 6, 8, 11, 14, 16, 17]) Xrow = np.array([1, 2, 6, 8, 11, 14, 16, 17])

View File

@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-
import numpy as np
import xgboost as xgb
import itertools import itertools
import re import re
import numpy as np
import scipy import scipy
import scipy.special import scipy.special
import xgboost as xgb
dpath = 'demo/data/' dpath = 'demo/data/'
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)

View File

@ -4,7 +4,8 @@ from typing import List
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import pytest import pytest
import testing as tm
from xgboost import testing as tm
if tm.no_spark()["condition"]: if tm.no_spark()["condition"]:
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True) pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)

View File

@ -6,10 +6,9 @@ import uuid
import numpy as np import numpy as np
import pytest import pytest
import testing as tm
import xgboost as xgb import xgboost as xgb
from xgboost import testing from xgboost import testing as tm
if tm.no_spark()["condition"]: if tm.no_spark()["condition"]:
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True) pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
@ -38,7 +37,7 @@ from .utils import SparkTestCase
logging.getLogger("py4j").setLevel(logging.INFO) logging.getLogger("py4j").setLevel(logging.INFO)
pytestmark = testing.timeout(60) pytestmark = tm.timeout(60)
class XgboostLocalTest(SparkTestCase): class XgboostLocalTest(SparkTestCase):

View File

@ -6,7 +6,8 @@ import uuid
import numpy as np import numpy as np
import pytest import pytest
import testing as tm
from xgboost import testing as tm
if tm.no_spark()["condition"]: if tm.no_spark()["condition"]:
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True) pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)

View File

@ -6,9 +6,10 @@ import tempfile
import unittest import unittest
import pytest import pytest
import testing as tm
from six import StringIO from six import StringIO
from xgboost import testing as tm
if tm.no_spark()["condition"]: if tm.no_spark()["condition"]:
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True) pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
if sys.platform.startswith("win") or sys.platform.startswith("darwin"): if sys.platform.startswith("win") or sys.platform.startswith("darwin"):

View File

@ -1,11 +1,13 @@
import testing as tm
import pytest
import numpy as np
import xgboost as xgb
import json import json
import os import os
dpath = os.path.join(tm.PROJECT_ROOT, 'demo', 'data') import numpy as np
import pytest
import xgboost as xgb
from xgboost import testing as tm
dpath = tm.data_dir(__file__)
def test_aft_survival_toy_data(): def test_aft_survival_toy_data():
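
Note: tm.data_dir(__file__) resolves the demo/data directory relative to the calling test file, so dataset paths no longer depend on the current working directory the way the old 'demo/data/' literal did. Usage, following the pattern in these hunks:

import os

import xgboost as xgb
from xgboost import testing as tm

# Build an absolute path to a bundled dataset from the test's own location.
dtrain = xgb.DMatrix(os.path.join(tm.data_dir(__file__), "agaricus.txt.train"))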

View File

@ -3,10 +3,10 @@ import sys
import numpy as np import numpy as np
import pytest import pytest
import testing as tm
import xgboost as xgb import xgboost as xgb
from xgboost import RabitTracker, testing from xgboost import RabitTracker
from xgboost import testing as tm
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True) pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@ -61,7 +61,7 @@ def test_rabit_ops():
run_rabit_ops(client, n_workers) run_rabit_ops(client, n_workers)
@pytest.mark.skipif(**testing.skip_ipv6()) @pytest.mark.skipif(**tm.skip_ipv6())
@pytest.mark.skipif(**tm.no_dask()) @pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops_ipv6(): def test_rabit_ops_ipv6():
import dask import dask
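
Note: the skip helpers (tm.skip_ipv6, tm.no_dask, tm.no_spark, ...) all return a dict of the form {"condition": bool, "reason": str} — see the tm.no_spark()["condition"] lookups earlier in this diff — so the result unpacks straight into pytest.mark.skipif:

import pytest

from xgboost import testing as tm

# ** expands the returned dict into skipif(condition=..., reason=...).
@pytest.mark.skipif(**tm.skip_ipv6())
def test_needs_ipv6() -> None:
    ...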

View File

@ -1,10 +1,11 @@
import xgboost as xgb
import testing as tm
import numpy as np
import pytest
import os import os
import tempfile import tempfile
import numpy as np
import pytest
import xgboost as xgb
from xgboost import testing as tm
rng = np.random.RandomState(1337) rng = np.random.RandomState(1337)

View File

@ -1,8 +1,8 @@
import numpy as np import numpy as np
import xgboost as xgb
from numpy.testing import assert_approx_equal from numpy.testing import assert_approx_equal
import xgboost as xgb
train_data = xgb.DMatrix(np.array([[1]]), label=np.array([1])) train_data = xgb.DMatrix(np.array([[1]]), label=np.array([1]))

View File

@ -1,11 +1,13 @@
import json import json
from string import ascii_lowercase from string import ascii_lowercase
from typing import Dict, Any from typing import Any, Dict
import testing as tm
import pytest
import xgboost as xgb
import numpy as np import numpy as np
from hypothesis import given, strategies, settings, note import pytest
from hypothesis import given, note, settings, strategies
import xgboost as xgb
from xgboost import testing as tm
exact_parameter_strategy = strategies.fixed_dictionaries({ exact_parameter_strategy = strategies.fixed_dictionaries({
'nthread': strategies.integers(1, 4), 'nthread': strategies.integers(1, 4),

View File

@ -1,14 +1,16 @@
import unittest
import pytest
import numpy as np
import testing as tm
import xgboost as xgb
import os import os
import unittest
import numpy as np
import pytest
import xgboost as xgb
from xgboost import testing as tm
try: try:
import pandas as pd
import pyarrow as pa import pyarrow as pa
import pyarrow.csv as pc import pyarrow.csv as pc
import pandas as pd
except ImportError: except ImportError:
pass pass
@ -73,7 +75,7 @@ class TestArrowTable(unittest.TestCase):
np.testing.assert_allclose(preds1, preds2) np.testing.assert_allclose(preds1, preds2)
def test_arrow_survival(self): def test_arrow_survival(self):
data = os.path.join(tm.PROJECT_ROOT, "demo", "data", "veterans_lung_cancer.csv") data = os.path.join(tm.data_dir(__file__), "veterans_lung_cancer.csv")
table = pc.read_csv(data) table = pc.read_csv(data)
y_lower_bound = table["Survival_label_lower_bound"] y_lower_bound = table["Survival_label_lower_bound"]
y_upper_bound = table["Survival_label_upper_bound"] y_upper_bound = table["Survival_label_upper_bound"]

View File

@ -20,7 +20,6 @@ import numpy as np
import pytest import pytest
import scipy import scipy
import sklearn import sklearn
import testing as tm
from hypothesis import HealthCheck, given, note, settings from hypothesis import HealthCheck, given, note, settings
from sklearn.datasets import make_classification, make_regression from sklearn.datasets import make_classification, make_regression
from test_predict import verify_leaf_output from test_predict import verify_leaf_output
@ -29,7 +28,7 @@ from test_with_sklearn import run_data_initialization, run_feature_weights
from xgboost.data import _is_cudf_df from xgboost.data import _is_cudf_df
import xgboost as xgb import xgboost as xgb
from xgboost import testing from xgboost import testing as tm
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True) pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@ -45,7 +44,7 @@ from xgboost.dask import DaskDMatrix
dask.config.set({"distributed.scheduler.allowed-failures": False}) dask.config.set({"distributed.scheduler.allowed-failures": False})
pytestmark = testing.timeout(30) pytestmark = tm.timeout(30)
if hasattr(HealthCheck, 'function_scoped_fixture'): if hasattr(HealthCheck, 'function_scoped_fixture'):
suppress = [HealthCheck.function_scoped_fixture] suppress = [HealthCheck.function_scoped_fixture]
@ -1116,8 +1115,9 @@ def test_predict_with_meta(client: "Client") -> None:
def run_aft_survival(client: "Client", dmatrix_t: Type) -> None: def run_aft_survival(client: "Client", dmatrix_t: Type) -> None:
df = dd.read_csv(os.path.join(tm.PROJECT_ROOT, 'demo', 'data', df = dd.read_csv(
'veterans_lung_cancer.csv')) os.path.join(tm.data_dir(__file__), "veterans_lung_cancer.csv")
)
y_lower_bound = df['Survival_label_lower_bound'] y_lower_bound = df['Survival_label_lower_bound']
y_upper_bound = df['Survival_label_upper_bound'] y_upper_bound = df['Survival_label_upper_bound']
X = df.drop(['Survival_label_lower_bound', X = df.drop(['Survival_label_lower_bound',

View File

@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-
import numpy as np import numpy as np
import xgboost as xgb
import testing as tm
import pytest import pytest
from test_dmatrix import set_base_margin_info from test_dmatrix import set_base_margin_info
import xgboost as xgb
from xgboost import testing as tm
try: try:
import modin.pandas as md import modin.pandas as md
except ImportError: except ImportError:

View File

@ -1,11 +1,13 @@
import os import os
import tempfile import tempfile
import numpy as np import numpy as np
import xgboost as xgb
import testing as tm
import pytest import pytest
from test_dmatrix import set_base_margin_info from test_dmatrix import set_base_margin_info
import xgboost as xgb
from xgboost import testing as tm
try: try:
import pandas as pd import pandas as pd
except ImportError: except ImportError:

View File

@ -1,7 +1,8 @@
import numpy as np import numpy as np
import xgboost as xgb
import pytest import pytest
import xgboost as xgb
try: try:
import shap import shap
except ImportError: except ImportError:

View File

@ -8,14 +8,13 @@ from typing import Callable, Optional
import numpy as np import numpy as np
import pytest import pytest
import testing as tm
from sklearn.utils.estimator_checks import parametrize_with_checks from sklearn.utils.estimator_checks import parametrize_with_checks
import xgboost as xgb import xgboost as xgb
from xgboost import testing from xgboost import testing as tm
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)
pytestmark = [pytest.mark.skipif(**tm.no_sklearn()), testing.timeout(30)] pytestmark = [pytest.mark.skipif(**tm.no_sklearn()), tm.timeout(30)]
def test_binary_classification(): def test_binary_classification():
@ -155,11 +154,10 @@ def test_ranking():
def test_stacking_regression(): def test_stacking_regression():
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_diabetes from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor, StackingRegressor
from sklearn.linear_model import RidgeCV from sklearn.linear_model import RidgeCV
from sklearn.ensemble import RandomForestRegressor from sklearn.model_selection import train_test_split
from sklearn.ensemble import StackingRegressor
X, y = load_diabetes(return_X_y=True) X, y = load_diabetes(return_X_y=True)
estimators = [ estimators = [
@ -177,13 +175,13 @@ def test_stacking_regression():
def test_stacking_classification(): def test_stacking_classification():
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.ensemble import StackingClassifier from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC
X, y = load_iris(return_X_y=True) X, y = load_iris(return_X_y=True)
estimators = [ estimators = [
@ -354,8 +352,8 @@ def test_num_parallel_tree():
def test_regression(): def test_regression():
from sklearn.metrics import mean_squared_error
from sklearn.datasets import fetch_california_housing from sklearn.datasets import fetch_california_housing
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold from sklearn.model_selection import KFold
X, y = fetch_california_housing(return_X_y=True) X, y = fetch_california_housing(return_X_y=True)
@ -383,8 +381,8 @@ def test_regression():
def run_housing_rf_regression(tree_method): def run_housing_rf_regression(tree_method):
from sklearn.metrics import mean_squared_error
from sklearn.datasets import fetch_california_housing from sklearn.datasets import fetch_california_housing
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold from sklearn.model_selection import KFold
X, y = fetch_california_housing(return_X_y=True) X, y = fetch_california_housing(return_X_y=True)
@ -407,8 +405,8 @@ def test_rf_regression():
def test_parameter_tuning(): def test_parameter_tuning():
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import fetch_california_housing from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import GridSearchCV
X, y = fetch_california_housing(return_X_y=True) X, y = fetch_california_housing(return_X_y=True)
xgb_model = xgb.XGBRegressor(learning_rate=0.1) xgb_model = xgb.XGBRegressor(learning_rate=0.1)
@ -421,8 +419,8 @@ def test_parameter_tuning():
def test_regression_with_custom_objective(): def test_regression_with_custom_objective():
from sklearn.metrics import mean_squared_error
from sklearn.datasets import fetch_california_housing from sklearn.datasets import fetch_california_housing
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold from sklearn.model_selection import KFold
def objective_ls(y_true, y_pred): def objective_ls(y_true, y_pred):
@ -539,8 +537,8 @@ def test_sklearn_plotting():
import matplotlib import matplotlib
matplotlib.use('Agg') matplotlib.use('Agg')
from matplotlib.axes import Axes
from graphviz import Source from graphviz import Source
from matplotlib.axes import Axes
ax = xgb.plot_importance(classifier) ax = xgb.plot_importance(classifier)
assert isinstance(ax, Axes) assert isinstance(ax, Axes)
@ -666,8 +664,8 @@ def test_kwargs_error():
def test_kwargs_grid_search(): def test_kwargs_grid_search():
from sklearn.model_selection import GridSearchCV
from sklearn import datasets from sklearn import datasets
from sklearn.model_selection import GridSearchCV
params = {'tree_method': 'hist'} params = {'tree_method': 'hist'}
clf = xgb.XGBClassifier(n_estimators=1, learning_rate=1.0, **params) clf = xgb.XGBClassifier(n_estimators=1, learning_rate=1.0, **params)
@ -841,9 +839,7 @@ def test_save_load_model():
def test_RFECV(): def test_RFECV():
from sklearn.datasets import load_diabetes from sklearn.datasets import load_breast_cancer, load_diabetes, load_iris
from sklearn.datasets import load_breast_cancer
from sklearn.datasets import load_iris
from sklearn.feature_selection import RFECV from sklearn.feature_selection import RFECV
# Regression # Regression
@ -1046,8 +1042,9 @@ def run_feature_weights(X, y, fw, tree_method, model=xgb.XGBRegressor):
with open(model_path) as fd: with open(model_path) as fd:
model = json.load(fd) model = json.load(fd)
parser_path = os.path.join(tm.PROJECT_ROOT, 'demo', 'json-model', parser_path = os.path.join(
'json_parser.py') tm.demo_dir(__file__), "json-model", "json_parser.py"
)
spec = importlib.util.spec_from_file_location("JsonParser", spec = importlib.util.spec_from_file_location("JsonParser",
parser_path) parser_path)
foo = importlib.util.module_from_spec(spec) foo = importlib.util.module_from_spec(spec)
@ -1162,8 +1159,8 @@ def run_boost_from_prediction_multi_clasas(
@pytest.mark.parametrize("tree_method", ["hist", "approx", "exact"]) @pytest.mark.parametrize("tree_method", ["hist", "approx", "exact"])
def test_boost_from_prediction(tree_method): def test_boost_from_prediction(tree_method):
from sklearn.datasets import load_breast_cancer, load_iris, make_regression
import pandas as pd import pandas as pd
from sklearn.datasets import load_breast_cancer, load_iris, make_regression
X, y = load_breast_cancer(return_X_y=True) X, y = load_breast_cancer(return_X_y=True)

View File

@ -1,7 +1,9 @@
import xgboost as xgb import sys
from sklearn.datasets import make_classification from sklearn.datasets import make_classification
from sklearn.metrics import roc_auc_score from sklearn.metrics import roc_auc_score
import sys
import xgboost as xgb
def run_omp(output_path: str): def run_omp(output_path: str):