[jvm-packages] Fix "obj_type" error to enable custom objectives and evaluations (#3646)

credits to @mmui
This commit is contained in:
Michael Mui 2018-09-14 12:06:33 -07:00 committed by Nan Zhu
parent 7bbb44182a
commit 20a9e716bd
5 changed files with 33 additions and 4 deletions

View File

@ -227,9 +227,9 @@ object XGBoost extends Serializable {
}
require(nWorkers > 0, "you must specify more than 0 workers")
if (obj != null) {
require(params.get("obj_type").isDefined, "parameter \"obj_type\" is not defined," +
" you have to specify the objective type as classification or regression with a" +
" customized objective function")
require(params.get("objective_type").isDefined, "parameter \"objective_type\" is not" +
" defined, you have to specify the objective type as classification or regression" +
" with a customized objective function")
}
val trackerConf = params.get("tracker_conf") match {
case None => TrackerConf()

View File

@ -130,6 +130,8 @@ class XGBoostClassifier (
// Setters for learning params. Each one stores the given value in the matching param
// through `set` and is declared to return `this.type`, so calls can be chained.
/** Sets the `objective` param. */
def setObjective(value: String): this.type = set(objective, value)
/**
 * Sets the `objectiveType` param.
 * NOTE(review): presumably restricted to the supported objective types by the param's
 * validator -- confirm against the `objectiveType` definition.
 */
def setObjectiveType(value: String): this.type = set(objectiveType, value)
/** Sets the `baseScore` param. */
def setBaseScore(value: Double): this.type = set(baseScore, value)
/** Sets the `evalMetric` param. */
def setEvalMetric(value: String): this.type = set(evalMetric, value)
@ -160,6 +162,10 @@ class XGBoostClassifier (
set(evalMetric, setupDefaultEvalMetric())
}
if (isDefined(customObj) && $(customObj) != null) {
set(objectiveType, "classification")
}
val _numClasses = getNumClasses(dataset)
if (isDefined(numClass) && $(numClass) != _numClasses) {
throw new Exception("The number of classes in dataset doesn't match " +

View File

@ -130,6 +130,8 @@ class XGBoostRegressor (
// Setters for learning params. Each one stores the given value in the matching param
// through `set` and is declared to return `this.type`, so calls can be chained.
/** Sets the `objective` param. */
def setObjective(value: String): this.type = set(objective, value)
/**
 * Sets the `objectiveType` param.
 * NOTE(review): presumably restricted to the supported objective types by the param's
 * validator -- confirm against the `objectiveType` definition.
 */
def setObjectiveType(value: String): this.type = set(objectiveType, value)
/** Sets the `baseScore` param. */
def setBaseScore(value: Double): this.type = set(baseScore, value)
/** Sets the `evalMetric` param. */
def setEvalMetric(value: String): this.type = set(evalMetric, value)
@ -158,6 +160,10 @@ class XGBoostRegressor (
set(evalMetric, setupDefaultEvalMetric())
}
if (isDefined(customObj) && $(customObj) != null) {
set(objectiveType, "regression")
}
val weight = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val baseMargin = if (!isDefined(baseMarginCol) || $(baseMarginCol).isEmpty) {
lit(Float.NaN)

View File

@ -33,6 +33,18 @@ private[spark] trait LearningTaskParams extends Params {
final def getObjective: String = $(objective)
/**
 * The learning objective type of the specified custom objective and eval.
 * Corresponding type will be assigned if custom objective is defined
 * options: regression, classification. default: null
 *
 * The param's doc string lists the supported options inside a pair of braces, and the
 * validator accepts only values contained in `LearningTaskParams.supportedObjectiveType`.
 */
final val objectiveType = new Param[String](this, "objectiveType", "objective type used for " +
  // The trailing "}" closes the literal "{" that opens the option list in the help text.
  s"training, options: {${LearningTaskParams.supportedObjectiveType.mkString(",")}}",
  (value: String) => LearningTaskParams.supportedObjectiveType.contains(value))
/**
 * Reads the `objectiveType` param via `$`.
 * NOTE(review): no default for `objectiveType` is set in the visible code -- confirm what
 * this returns (or throws) when the param has not been set.
 */
final def getObjectiveType: String = $(objectiveType)
/**
* the initial prediction score of all instances, global bias. default=0.5
*/
@ -84,6 +96,8 @@ private[spark] object LearningTaskParams {
"binary:logitraw", "count:poisson", "multi:softmax", "multi:softprob", "rank:pairwise",
"rank:ndcg", "rank:map", "reg:gamma", "reg:tweedie")
// Objective types accepted for the `objectiveType` param; used both to validate a supplied
// value and to list the available options in the param's doc string.
val supportedObjectiveType = HashSet("regression", "classification")
val supportedEvalMetrics = HashSet("rmse", "mae", "logloss", "error", "merror", "mlogloss",
"auc", "aucpr", "ndcg", "map", "gamma-deviance")
}

View File

@ -140,15 +140,18 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
}
// Checks that parameters given as an XGBoost-style map show up through the Spark param
// getters, and that a copied estimator exposes them again through MLlib2XGBoostParams.
test("XGBoost and Spark parameters synchronize correctly") {
// NOTE(review): the two xgbParamMap definitions below look like the pre-change and
// post-change lines of the diff rendered together (diff markers are not shown here);
// presumably only the second one, with "objective_type", is in the committed file -- confirm.
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
"objective_type" -> "classification")
// from xgboost params to spark params
val xgb = new XGBoostClassifier(xgbParamMap)
assert(xgb.getEta === 1.0)
assert(xgb.getObjective === "binary:logistic")
assert(xgb.getObjectiveType === "classification")
// from spark to xgboost params
val xgbCopy = xgb.copy(ParamMap.empty)
assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
assert(xgbCopy.MLlib2XGBoostParams("objective_type").toString === "classification")
// a param supplied through the copy's extra ParamMap must also be exported
val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
}