Add an `objectiveType` param (options: regression, classification) to xgboost4j-spark.

The training-parameter check for custom objectives now requires "objective_type"
instead of "obj_type"; XGBoostClassifier and XGBoostRegressor set it to
"classification" / "regression" when a custom objective is defined.

NOTE(review): line breaks and context-line indentation were reconstructed from a
whitespace-collapsed copy of this patch -- confirm with `git apply --check`.

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 6302c35e4..fa1dccc53 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -227,9 +227,9 @@ object XGBoost extends Serializable {
     }
     require(nWorkers > 0, "you must specify more than 0 workers")
     if (obj != null) {
-      require(params.get("obj_type").isDefined, "parameter \"obj_type\" is not defined," +
-        " you have to specify the objective type as classification or regression with a" +
-        " customized objective function")
+      require(params.get("objective_type").isDefined, "parameter \"objective_type\" is not" +
+        " defined, you have to specify the objective type as classification or regression" +
+        " with a customized objective function")
     }
     val trackerConf = params.get("tracker_conf") match {
       case None => TrackerConf()
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
index 47b489c22..c8ac28eb4 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
@@ -130,6 +130,8 @@ class XGBoostClassifier (
   // setters for learning params
   def setObjective(value: String): this.type = set(objective, value)
 
+  def setObjectiveType(value: String): this.type = set(objectiveType, value)
+
   def setBaseScore(value: Double): this.type = set(baseScore, value)
 
   def setEvalMetric(value: String): this.type = set(evalMetric, value)
@@ -160,6 +162,10 @@ class XGBoostClassifier (
       set(evalMetric, setupDefaultEvalMetric())
     }
 
+    if (isDefined(customObj) && $(customObj) != null) {
+      set(objectiveType, "classification")
+    }
+
     val _numClasses = getNumClasses(dataset)
     if (isDefined(numClass) && $(numClass) != _numClasses) {
       throw new Exception("The number of classes in dataset doesn't match " +
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
index 0fe3452d6..277d55669 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
@@ -130,6 +130,8 @@ class XGBoostRegressor (
   // setters for learning params
   def setObjective(value: String): this.type = set(objective, value)
 
+  def setObjectiveType(value: String): this.type = set(objectiveType, value)
+
   def setBaseScore(value: Double): this.type = set(baseScore, value)
 
   def setEvalMetric(value: String): this.type = set(evalMetric, value)
@@ -158,6 +160,10 @@ class XGBoostRegressor (
       set(evalMetric, setupDefaultEvalMetric())
     }
 
+    if (isDefined(customObj) && $(customObj) != null) {
+      set(objectiveType, "regression")
+    }
+
     val weight = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val baseMargin = if (!isDefined(baseMarginCol) || $(baseMarginCol).isEmpty) {
       lit(Float.NaN)
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
index 6f7a32653..5d3106a02 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
@@ -33,6 +33,18 @@ private[spark] trait LearningTaskParams extends Params {
 
   final def getObjective: String = $(objective)
 
+  /**
+   * The learning objective type of the specified custom objective and eval.
+   * The corresponding type is assigned automatically if a custom objective is defined.
+   * options: regression, classification. default: null
+   */
+  final val objectiveType = new Param[String](this, "objectiveType", "objective type used for " +
+    s"training, options: {${LearningTaskParams.supportedObjectiveType.mkString(",")}}",
+    (value: String) => LearningTaskParams.supportedObjectiveType.contains(value))
+
+  final def getObjectiveType: String = $(objectiveType)
+
+
   /**
    * the initial prediction score of all instances, global bias. default=0.5
    */
@@ -84,6 +96,8 @@ private[spark] object LearningTaskParams {
     "binary:logitraw", "count:poisson", "multi:softmax", "multi:softprob", "rank:pairwise",
     "rank:ndcg", "rank:map", "reg:gamma", "reg:tweedie")
 
+  val supportedObjectiveType = HashSet("regression", "classification")
+
   val supportedEvalMetrics = HashSet("rmse", "mae", "logloss", "error", "merror", "mlogloss",
     "auc", "aucpr", "ndcg", "map", "gamma-deviance")
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
index 86f9b575a..9c92f8810 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
@@ -140,15 +140,18 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
   }
 
   test("XGBoost and Spark parameters synchronize correctly") {
-    val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
+    val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
+      "objective_type" -> "classification")
     // from xgboost params to spark params
     val xgb = new XGBoostClassifier(xgbParamMap)
     assert(xgb.getEta === 1.0)
     assert(xgb.getObjective === "binary:logistic")
+    assert(xgb.getObjectiveType === "classification")
     // from spark to xgboost params
     val xgbCopy = xgb.copy(ParamMap.empty)
     assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
     assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
+    assert(xgbCopy.MLlib2XGBoostParams("objective_type").toString === "classification")
     val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
     assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
   }