From 46f2b820f1cbc7b20c7cc2b7dc4212c89174d3b1 Mon Sep 17 00:00:00 2001 From: ebernhardson Date: Mon, 30 Oct 2017 09:36:03 -0700 Subject: [PATCH] [jvm-packages] Objectives starting with rank: are never classification (#2837) Training a model with the experimental rank:ndcg objective incorrectly returns a Classification model. Adjust the classification check to not recognize rank:* objectives as classification. Writing tests for isClassificationTask also turned up that obj_type -> regression was incorrectly identified as a classification task, so the function was slightly adjusted to pass the new tests. --- .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 4 ++-- .../scala/spark/XGBoostGeneralSuite.scala | 24 +++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index d5ed85d0d..3e8736370 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -194,8 +194,8 @@ object XGBoost extends Serializable { val objective = params.getOrElse("objective", params.getOrElse("obj_type", null)) objective != null && { val objStr = objective.toString - objStr == "classification" || (!objStr.startsWith("reg:") && objStr != "count:poisson" && - objStr != "rank:pairwise") + objStr != "regression" && !objStr.startsWith("reg:") && objStr != "count:poisson" && + !objStr.startsWith("rank:") } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala index dc2ef9672..86958552e 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala +++ 
b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala @@ -389,4 +389,28 @@ class XGBoostGeneralSuite extends FunSuite with PerTest { assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true), testSetDMatrix) < 0.1) } + + test("isClassificationTask correctly classifies supported objectives") { + import org.scalatest.prop.TableDrivenPropertyChecks._ + + val objectives = Table( + ("isClassificationTask", "params"), + (true, Map("obj_type" -> "classification")), + (false, Map("obj_type" -> "regression")), + (false, Map("objective" -> "rank:ndcg")), + (false, Map("objective" -> "rank:pairwise")), + (false, Map("objective" -> "rank:map")), + (false, Map("objective" -> "count:poisson")), + (true, Map("objective" -> "binary:logistic")), + (true, Map("objective" -> "binary:logitraw")), + (true, Map("objective" -> "multi:softmax")), + (true, Map("objective" -> "multi:softprob")), + (false, Map("objective" -> "reg:linear")), + (false, Map("objective" -> "reg:logistic")), + (false, Map("objective" -> "reg:gamma")), + (false, Map("objective" -> "reg:tweedie"))) + forAll (objectives) { (isClassificationTask: Boolean, params: Map[String, String]) => + assert(XGBoost.isClassificationTask(params) == isClassificationTask) + } + } }