Fix pyspark parameter. (#9460)
- Don't pass the `use_gpu` parameter to the learner.
- Fix GPU approx with PySpark.
This commit is contained in:
@@ -456,7 +456,9 @@ def check_sub_dict_match(
     assert sub_dist[k] == whole_dict[k], f"check on {k} failed"


-def get_params_map(params_kv: dict, estimator: Type) -> dict:
+def get_params_map(
+    params_kv: dict, estimator: xgb.spark.core._SparkXGBEstimator
+) -> dict:
     return {getattr(estimator, k): v for k, v in params_kv.items()}


@@ -870,10 +872,10 @@ class TestPySparkLocal:

    def test_device_param(self, reg_data: RegData, clf_data: ClfData) -> None:
        clf = SparkXGBClassifier(device="cuda", tree_method="exact")
-        with pytest.raises(ValueError, match="not supported on GPU"):
+        with pytest.raises(ValueError, match="not supported for distributed"):
            clf.fit(clf_data.cls_df_train)
        regressor = SparkXGBRegressor(device="cuda", tree_method="exact")
-        with pytest.raises(ValueError, match="not supported on GPU"):
+        with pytest.raises(ValueError, match="not supported for distributed"):
            regressor.fit(reg_data.reg_df_train)

        reg = SparkXGBRegressor(device="cuda", tree_method="gpu_hist")
Reference in New Issue
Block a user