diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py index ee304fe6b..d756b3c5a 100644 --- a/tests/test_distributed/test_with_spark/test_spark_local.py +++ b/tests/test_distributed/test_with_spark/test_spark_local.py @@ -697,13 +697,14 @@ class XgboostLocalTest(SparkTestCase): self.assert_model_compatible(model.stages[0], tmp_dir) def test_classifier_with_cross_validator(self): - xgb_classifer = SparkXGBClassifier() + xgb_classifer = SparkXGBClassifier(n_estimators=1) paramMaps = ParamGridBuilder().addGrid(xgb_classifer.max_depth, [1, 2]).build() cvBin = CrossValidator( estimator=xgb_classifer, estimatorParamMaps=paramMaps, evaluator=BinaryClassificationEvaluator(), seed=1, + parallelism=4, numFolds=2, ) cvBinModel = cvBin.fit(self.cls_df_train_large)