[jvm-packages] Fix "obj_type" error to enable custom objectives and evaluations (#3646)
credits to @mmui
This commit is contained in:
parent
7bbb44182a
commit
20a9e716bd
@ -227,9 +227,9 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
require(nWorkers > 0, "you must specify more than 0 workers")
|
||||
if (obj != null) {
|
||||
require(params.get("obj_type").isDefined, "parameter \"obj_type\" is not defined," +
|
||||
" you have to specify the objective type as classification or regression with a" +
|
||||
" customized objective function")
|
||||
require(params.get("objective_type").isDefined, "parameter \"objective_type\" is not" +
|
||||
" defined, you have to specify the objective type as classification or regression" +
|
||||
" with a customized objective function")
|
||||
}
|
||||
val trackerConf = params.get("tracker_conf") match {
|
||||
case None => TrackerConf()
|
||||
|
||||
@ -130,6 +130,8 @@ class XGBoostClassifier (
|
||||
// setters for learning params
|
||||
def setObjective(value: String): this.type = set(objective, value)
|
||||
|
||||
def setObjectiveType(value: String): this.type = set(objectiveType, value)
|
||||
|
||||
def setBaseScore(value: Double): this.type = set(baseScore, value)
|
||||
|
||||
def setEvalMetric(value: String): this.type = set(evalMetric, value)
|
||||
@ -160,6 +162,10 @@ class XGBoostClassifier (
|
||||
set(evalMetric, setupDefaultEvalMetric())
|
||||
}
|
||||
|
||||
if (isDefined(customObj) && $(customObj) != null) {
|
||||
set(objectiveType, "classification")
|
||||
}
|
||||
|
||||
val _numClasses = getNumClasses(dataset)
|
||||
if (isDefined(numClass) && $(numClass) != _numClasses) {
|
||||
throw new Exception("The number of classes in dataset doesn't match " +
|
||||
|
||||
@ -130,6 +130,8 @@ class XGBoostRegressor (
|
||||
// setters for learning params
|
||||
def setObjective(value: String): this.type = set(objective, value)
|
||||
|
||||
def setObjectiveType(value: String): this.type = set(objectiveType, value)
|
||||
|
||||
def setBaseScore(value: Double): this.type = set(baseScore, value)
|
||||
|
||||
def setEvalMetric(value: String): this.type = set(evalMetric, value)
|
||||
@ -158,6 +160,10 @@ class XGBoostRegressor (
|
||||
set(evalMetric, setupDefaultEvalMetric())
|
||||
}
|
||||
|
||||
if (isDefined(customObj) && $(customObj) != null) {
|
||||
set(objectiveType, "regression")
|
||||
}
|
||||
|
||||
val weight = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
|
||||
val baseMargin = if (!isDefined(baseMarginCol) || $(baseMarginCol).isEmpty) {
|
||||
lit(Float.NaN)
|
||||
|
||||
@ -33,6 +33,18 @@ private[spark] trait LearningTaskParams extends Params {
|
||||
|
||||
final def getObjective: String = $(objective)
|
||||
|
||||
/**
|
||||
* The learning objective type of the specified custom objective and eval.
|
||||
* Corresponding type will be assigned if custom objective is defined
|
||||
* options: regression, classification. default: null
|
||||
*/
|
||||
final val objectiveType = new Param[String](this, "objectiveType", "objective type used for " +
|
||||
s"training, options: {${LearningTaskParams.supportedObjectiveType.mkString(",")}",
|
||||
(value: String) => LearningTaskParams.supportedObjectiveType.contains(value))
|
||||
|
||||
final def getObjectiveType: String = $(objectiveType)
|
||||
|
||||
|
||||
/**
|
||||
* the initial prediction score of all instances, global bias. default=0.5
|
||||
*/
|
||||
@ -84,6 +96,8 @@ private[spark] object LearningTaskParams {
|
||||
"binary:logitraw", "count:poisson", "multi:softmax", "multi:softprob", "rank:pairwise",
|
||||
"rank:ndcg", "rank:map", "reg:gamma", "reg:tweedie")
|
||||
|
||||
val supportedObjectiveType = HashSet("regression", "classification")
|
||||
|
||||
val supportedEvalMetrics = HashSet("rmse", "mae", "logloss", "error", "merror", "mlogloss",
|
||||
"auc", "aucpr", "ndcg", "map", "gamma-deviance")
|
||||
}
|
||||
|
||||
@ -140,15 +140,18 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
||||
}
|
||||
|
||||
test("XGBoost and Spark parameters synchronize correctly") {
|
||||
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
|
||||
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
|
||||
"objective_type" -> "classification")
|
||||
// from xgboost params to spark params
|
||||
val xgb = new XGBoostClassifier(xgbParamMap)
|
||||
assert(xgb.getEta === 1.0)
|
||||
assert(xgb.getObjective === "binary:logistic")
|
||||
assert(xgb.getObjectiveType === "classification")
|
||||
// from spark to xgboost params
|
||||
val xgbCopy = xgb.copy(ParamMap.empty)
|
||||
assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
|
||||
assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
|
||||
assert(xgbCopy.MLlib2XGBoostParams("objective_type").toString === "classification")
|
||||
val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
|
||||
assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user