[jvm-packages] Objectives starting with rank: are never classification (#2837)
Training a model with the experimental rank:ndcg objective incorrectly returns a Classification model. Adjust the classification check to not recognize rank:* objectives as classification. While writing tests for isClassificationTask also turned up that obj_type -> regression was incorrectly identified as a classification task so the function was slightly adjusted to pass the new tests.
This commit is contained in:
parent
91af8f7106
commit
46f2b820f1
@ -194,8 +194,8 @@ object XGBoost extends Serializable {
|
||||
val objective = params.getOrElse("objective", params.getOrElse("obj_type", null))
|
||||
objective != null && {
|
||||
val objStr = objective.toString
|
||||
objStr == "classification" || (!objStr.startsWith("reg:") && objStr != "count:poisson" &&
|
||||
objStr != "rank:pairwise")
|
||||
objStr != "regression" && !objStr.startsWith("reg:") && objStr != "count:poisson" &&
|
||||
!objStr.startsWith("rank:")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -389,4 +389,28 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||
testSetDMatrix) < 0.1)
|
||||
}
|
||||
|
||||
test("isClassificationTask correctly classifies supported objectives") {
|
||||
import org.scalatest.prop.TableDrivenPropertyChecks._
|
||||
|
||||
val objectives = Table(
|
||||
("isClassificationTask", "params"),
|
||||
(true, Map("obj_type" -> "classification")),
|
||||
(false, Map("obj_type" -> "regression")),
|
||||
(false, Map("objective" -> "rank:ndcg")),
|
||||
(false, Map("objective" -> "rank:pairwise")),
|
||||
(false, Map("objective" -> "rank:map")),
|
||||
(false, Map("objective" -> "count:poisson")),
|
||||
(true, Map("objective" -> "binary:logistic")),
|
||||
(true, Map("objective" -> "binary:logitraw")),
|
||||
(true, Map("objective" -> "multi:softmax")),
|
||||
(true, Map("objective" -> "multi:softprob")),
|
||||
(false, Map("objective" -> "reg:linear")),
|
||||
(false, Map("objective" -> "reg:logistic")),
|
||||
(false, Map("objective" -> "reg:gamma")),
|
||||
(false, Map("objective" -> "reg:tweedie")))
|
||||
forAll (objectives) { (isClassificationTask: Boolean, params: Map[String, String]) =>
|
||||
assert(XGBoost.isClassificationTask(params) == isClassificationTask)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user