Fixes for numpy 2.0. (#10252)

This commit is contained in:
Jiaming Yuan
2024-05-07 03:54:32 +08:00
committed by GitHub
parent dcc9639b91
commit 73afef1a6e
12 changed files with 35 additions and 34 deletions

View File

@@ -1653,9 +1653,9 @@ def ltr_data(spark: SparkSession) -> Generator[LTRData, None, None]:
[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[9.0, 4.0, 8.0],
[np.NaN, 1.0, 5.5],
[np.NaN, 6.0, 7.5],
[np.NaN, 8.0, 9.5],
[np.nan, 1.0, 5.5],
[np.nan, 6.0, 7.5],
[np.nan, 8.0, 9.5],
]
)
qid_train = np.array([0, 0, 0, 1, 1, 1])
@@ -1666,9 +1666,9 @@ def ltr_data(spark: SparkSession) -> Generator[LTRData, None, None]:
[1.5, 2.0, 3.0],
[4.5, 5.0, 6.0],
[9.0, 4.5, 8.0],
[np.NaN, 1.0, 6.0],
[np.NaN, 6.0, 7.0],
[np.NaN, 8.0, 10.5],
[np.nan, 1.0, 6.0],
[np.nan, 6.0, 7.0],
[np.nan, 8.0, 10.5],
]
)