diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index d5ed85d0d..3e8736370 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -194,8 +194,8 @@ object XGBoost extends Serializable { val objective = params.getOrElse("objective", params.getOrElse("obj_type", null)) objective != null && { val objStr = objective.toString - objStr == "classification" || (!objStr.startsWith("reg:") && objStr != "count:poisson" && - objStr != "rank:pairwise") + objStr != "regression" && !objStr.startsWith("reg:") && objStr != "count:poisson" && + !objStr.startsWith("rank:") } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala index dc2ef9672..86958552e 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala @@ -389,4 +389,28 @@ class XGBoostGeneralSuite extends FunSuite with PerTest { assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true), testSetDMatrix) < 0.1) } + + test("isClassificationTask correctly classifies supported objectives") { + import org.scalatest.prop.TableDrivenPropertyChecks._ + + val objectives = Table( + ("isClassificationTask", "params"), + (true, Map("obj_type" -> "classification")), + (false, Map("obj_type" -> "regression")), + (false, Map("objective" -> "rank:ndcg")), + (false, Map("objective" -> "rank:pairwise")), + (false, Map("objective" -> "rank:map")), + (false, Map("objective" -> "count:poisson")), + (true, Map("objective" -> "binary:logistic")), + (true, Map("objective" -> "binary:logitraw")), + (true, Map("objective" -> "multi:softmax")), + (true, Map("objective" -> "multi:softprob")), + (false, Map("objective" -> "reg:linear")), + (false, Map("objective" -> "reg:logistic")), + (false, Map("objective" -> "reg:gamma")), + (false, Map("objective" -> "reg:tweedie"))) + forAll (objectives) { (isClassificationTask: Boolean, params: Map[String, String]) => + assert(XGBoost.isClassificationTask(params) == isClassificationTask) + } + } }