Fixes for numpy 2.0. (#10252)

This commit is contained in:
Jiaming Yuan 2024-05-07 03:54:32 +08:00 committed by GitHub
parent dcc9639b91
commit 73afef1a6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 35 additions and 34 deletions

View File

@ -84,7 +84,7 @@ def main(tmpdir: str) -> xgboost.Booster:
it = Iterator(files)
# For non-data arguments, specify it here once instead of passing them by the `next`
# method.
missing = np.NaN
missing = np.nan
Xy = xgboost.DMatrix(it, missing=missing, enable_categorical=False)
# ``approx`` is also supported, but less efficient due to sketching. GPU behaves

View File

@ -233,9 +233,9 @@ def _maybe_np_slice(data: DataType, dtype: Optional[NumpyDType]) -> np.ndarray:
if not data.flags.c_contiguous:
data = np.array(data, copy=True, dtype=dtype)
else:
data = np.array(data, copy=False, dtype=dtype)
data = np.asarray(data, dtype=dtype)
except AttributeError:
data = np.array(data, copy=False, dtype=dtype)
data = np.asarray(data, dtype=dtype)
data, dtype = _ensure_np_dtype(data, dtype)
return data
@ -483,7 +483,7 @@ def pandas_transform_data(data: DataFrame) -> List[np.ndarray]:
if is_pd_cat_dtype(ser.dtype):
return _ensure_np_dtype(
ser.cat.codes.astype(np.float32)
.replace(-1.0, np.NaN)
.replace(-1.0, np.nan)
.to_numpy(na_value=np.nan),
np.float32,
)[0]
@ -495,7 +495,7 @@ def pandas_transform_data(data: DataFrame) -> List[np.ndarray]:
.combine_chunks()
.dictionary_encode()
.indices.astype(np.float32)
.replace(-1.0, np.NaN)
.replace(-1.0, np.nan)
)
def nu_type(ser: pd.Series) -> np.ndarray:

View File

@ -437,7 +437,7 @@ def make_categorical(
index = rng.randint(
low=0, high=n_samples - 1, size=int(n_samples * sparsity)
)
df.iloc[index, i] = np.NaN
df.iloc[index, i] = np.nan
if is_categorical_dtype(df.dtypes[i]):
assert n_categories == np.unique(df.dtypes[i].categories).size

View File

@ -66,7 +66,7 @@ def check_uneven_nan(client: Client, tree_method: str, n_workers: int) -> None:
X = pd.DataFrame({"a": range(10000), "b": range(10000, 0, -1)})
y = pd.Series([*[0] * 5000, *[1] * 5000])
X["a"][:3000:1000] = np.NaN
X["a"][:3000:1000] = np.nan
client.wait_for_workers(n_workers=n_workers)

View File

@ -10,7 +10,7 @@ from xgboost.testing.data import run_base_margin_info
cudf = pytest.importorskip("cudf")
def dmatrix_from_cudf(input_type, DMatrixT, missing=np.NAN):
def dmatrix_from_cudf(input_type, DMatrixT, missing=np.nan):
"""Test constructing DMatrix from cudf"""
import pandas as pd
@ -38,8 +38,8 @@ def dmatrix_from_cudf(input_type, DMatrixT, missing=np.NAN):
def _test_from_cudf(DMatrixT):
"""Test constructing DMatrix from cudf"""
dmatrix_from_cudf(np.float32, DMatrixT, np.NAN)
dmatrix_from_cudf(np.float64, DMatrixT, np.NAN)
dmatrix_from_cudf(np.float32, DMatrixT, np.nan)
dmatrix_from_cudf(np.float64, DMatrixT, np.nan)
dmatrix_from_cudf(np.int8, DMatrixT, 2)
dmatrix_from_cudf(np.int32, DMatrixT, -2)
@ -66,7 +66,7 @@ def _test_from_cudf(DMatrixT):
)
# Test when number of elements is less than 8
X = cudf.DataFrame({"x": cudf.Series([0, 1, 2, np.NAN, 4], dtype=np.int32)})
X = cudf.DataFrame({"x": cudf.Series([0, 1, 2, np.nan, 4], dtype=np.int32)})
dtrain = DMatrixT(X)
assert dtrain.num_col() == 1
assert dtrain.num_row() == 5
@ -225,7 +225,7 @@ class TestFromColumnar:
assert len(interfaces) == X.shape[1]
# test missing value
X = cudf.DataFrame({"f0": ["a", "b", np.NaN]})
X = cudf.DataFrame({"f0": ["a", "b", np.nan]})
X["f0"] = X["f0"].astype("category")
df, cat_codes, _, _ = xgb.data._transform_cudf_df(
X, None, None, enable_categorical=True

View File

@ -18,7 +18,7 @@ def test_array_interface() -> None:
np.testing.assert_equal(cp.asnumpy(arr), cp.asnumpy(ret))
def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
def dmatrix_from_cupy(input_type, DMatrixT, missing=np.nan):
"""Test constructing DMatrix from cupy"""
kRows = 80
kCols = 3
@ -46,9 +46,9 @@ def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
def _test_from_cupy(DMatrixT):
"""Test constructing DMatrix from cupy"""
dmatrix_from_cupy(np.float16, DMatrixT, np.NAN)
dmatrix_from_cupy(np.float32, DMatrixT, np.NAN)
dmatrix_from_cupy(np.float64, DMatrixT, np.NAN)
dmatrix_from_cupy(np.float16, DMatrixT, np.nan)
dmatrix_from_cupy(np.float32, DMatrixT, np.nan)
dmatrix_from_cupy(np.float64, DMatrixT, np.nan)
dmatrix_from_cupy(np.uint8, DMatrixT, 2)
dmatrix_from_cupy(np.uint32, DMatrixT, 3)

View File

@ -147,7 +147,7 @@ class TestDMatrix:
assert dm.slice([0, 1]).num_col() == dm.num_col()
assert dm.slice([0, 1]).feature_names == dm.feature_names
with pytest.raises(ValueError, match=r"Duplicates found: \['bar'\]"):
with pytest.raises(ValueError, match=r"Duplicates found: \[.*'bar'.*\]"):
dm.feature_names = ["bar"] * (data.shape[1] - 2) + ["a", "b"]
dm.feature_types = list("qiqiq")
@ -264,7 +264,7 @@ class TestDMatrix:
assert (dtrain.num_row(), dtrain.num_col()) == (nrow, ncol)
watchlist = [(dtrain, "train")]
param = {"max_depth": 3, "objective": "binary:logistic"}
bst = xgb.train(param, dtrain, 5, watchlist)
bst = xgb.train(param, dtrain, 5, evals=watchlist)
bst.predict(dtrain)
i32 = csr_matrix((x.data.astype(np.int32), x.indices, x.indptr), shape=x.shape)
@ -302,7 +302,7 @@ class TestDMatrix:
assert (dtrain.num_row(), dtrain.num_col()) == (nrow, ncol)
watchlist = [(dtrain, "train")]
param = {"max_depth": 3, "objective": "binary:logistic"}
bst = xgb.train(param, dtrain, 5, watchlist)
bst = xgb.train(param, dtrain, 5, evals=watchlist)
bst.predict(dtrain)
def test_unknown_data(self):
@ -320,6 +320,7 @@ class TestDMatrix:
X = rng.rand(10, 10)
y = rng.rand(10)
X = sparse.dok_matrix(X)
with pytest.warns(UserWarning, match="dok_matrix"):
Xy = xgb.DMatrix(X, y)
assert Xy.num_row() == 10
assert Xy.num_col() == 10
@ -343,8 +344,8 @@ class TestDMatrix:
X = X.values.astype(np.float32)
feature_types = ["c"] * n_features
X[1, 3] = np.NAN
X[2, 4] = np.NAN
X[1, 3] = np.nan
X[2, 4] = np.nan
X = sparse.csr_matrix(X)
Xy = xgb.DMatrix(X, y, feature_types=feature_types)

View File

@ -241,7 +241,7 @@ class TestInplacePredict:
# unsupported types
for dtype in [
np.string_,
np.bytes_,
np.complex64,
np.complex128,
]:

View File

@ -333,7 +333,7 @@ class TestQuantileDMatrix:
# unsupported types
for dtype in [
np.string_,
np.bytes_,
np.complex64,
np.complex128,
]:

View File

@ -248,7 +248,7 @@ class TestPandas:
assert transformed.columns[0].min() == 0
# test missing value
X = pd.DataFrame({"f0": ["a", "b", np.NaN]})
X = pd.DataFrame({"f0": ["a", "b", np.nan]})
X["f0"] = X["f0"].astype("category")
arr, _, _ = xgb.data._transform_pandas_df(X, enable_categorical=True)
for c in arr.columns:

View File

@ -1098,7 +1098,7 @@ def test_pandas_input():
np.testing.assert_equal(model.feature_names_in_, np.array(feature_names))
columns = list(train.columns)
random.shuffle(columns, lambda: 0.1)
random.shuffle(columns)
df_incorrect = df[columns]
with pytest.raises(ValueError):
model.predict(df_incorrect)

View File

@ -1653,9 +1653,9 @@ def ltr_data(spark: SparkSession) -> Generator[LTRData, None, None]:
[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[9.0, 4.0, 8.0],
[np.NaN, 1.0, 5.5],
[np.NaN, 6.0, 7.5],
[np.NaN, 8.0, 9.5],
[np.nan, 1.0, 5.5],
[np.nan, 6.0, 7.5],
[np.nan, 8.0, 9.5],
]
)
qid_train = np.array([0, 0, 0, 1, 1, 1])
@ -1666,9 +1666,9 @@ def ltr_data(spark: SparkSession) -> Generator[LTRData, None, None]:
[1.5, 2.0, 3.0],
[4.5, 5.0, 6.0],
[9.0, 4.5, 8.0],
[np.NaN, 1.0, 6.0],
[np.NaN, 6.0, 7.0],
[np.NaN, 8.0, 10.5],
[np.nan, 1.0, 6.0],
[np.nan, 6.0, 7.0],
[np.nan, 8.0, 10.5],
]
)