[pyspark] Add param validation for "objective" and "eval_metric" param, and remove invalid booster params (#8173)
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
This commit is contained in:
@@ -433,6 +433,7 @@ class XgboostLocalTest(SparkTestCase):
|
||||
self.assertEqual(py_reg.n_estimators.parent, py_reg.uid)
|
||||
self.assertFalse(hasattr(py_reg, "gpu_id"))
|
||||
self.assertEqual(py_reg.getOrDefault(py_reg.n_estimators), 100)
|
||||
self.assertEqual(py_reg.getOrDefault(py_reg.objective), "reg:squarederror")
|
||||
py_reg2 = SparkXGBRegressor(n_estimators=200)
|
||||
self.assertEqual(py_reg2.getOrDefault(py_reg2.n_estimators), 200)
|
||||
py_reg3 = py_reg2.copy({py_reg2.max_depth: 10})
|
||||
@@ -445,6 +446,7 @@ class XgboostLocalTest(SparkTestCase):
|
||||
self.assertEqual(py_cls.n_estimators.parent, py_cls.uid)
|
||||
self.assertFalse(hasattr(py_cls, "gpu_id"))
|
||||
self.assertEqual(py_cls.getOrDefault(py_cls.n_estimators), 100)
|
||||
self.assertEqual(py_cls.getOrDefault(py_cls.objective), None)
|
||||
py_cls2 = SparkXGBClassifier(n_estimators=200)
|
||||
self.assertEqual(py_cls2.getOrDefault(py_cls2.n_estimators), 200)
|
||||
py_cls3 = py_cls2.copy({py_cls2.max_depth: 10})
|
||||
|
||||
Reference in New Issue
Block a user