[jvm-packages] Fix "obj_type" error to enable custom objectives and evaluations (#3646)
credits to @mmui
commit 20a9e716bd (parent 7bbb44182a)
@@ -227,9 +227,9 @@ object XGBoost extends Serializable {
     }
     require(nWorkers > 0, "you must specify more than 0 workers")
     if (obj != null) {
-      require(params.get("obj_type").isDefined, "parameter \"obj_type\" is not defined," +
-        " you have to specify the objective type as classification or regression with a" +
-        " customized objective function")
+      require(params.get("objective_type").isDefined, "parameter \"objective_type\" is not" +
+        " defined, you have to specify the objective type as classification or regression" +
+        " with a customized objective function")
     }
     val trackerConf = params.get("tracker_conf") match {
       case None => TrackerConf()
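For illustration (not part of the commit): the later hunks show that the Spark layer passes this value under the key "objective_type", so the old check for "obj_type" could never be satisfied when a custom objective was supplied. A minimal hedged sketch of a parameter map that passes the renamed check; "eta" and "max_depth" are ordinary XGBoost parameters added here only for context:

// Sketch only: with a custom ObjectiveTrait passed (obj != null), the require()
// above now looks for "objective_type" instead of the never-provided "obj_type".
val paramsWithCustomObj: Map[String, Any] = Map(
  "eta" -> 0.1,
  "max_depth" -> 6,
  "objective_type" -> "classification"  // or "regression"
)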
@@ -130,6 +130,8 @@ class XGBoostClassifier (
   // setters for learning params
   def setObjective(value: String): this.type = set(objective, value)
 
+  def setObjectiveType(value: String): this.type = set(objectiveType, value)
+
   def setBaseScore(value: Double): this.type = set(baseScore, value)
 
   def setEvalMetric(value: String): this.type = set(evalMetric, value)
@@ -160,6 +162,10 @@ class XGBoostClassifier (
       set(evalMetric, setupDefaultEvalMetric())
     }
 
+    if (isDefined(customObj) && $(customObj) != null) {
+      set(objectiveType, "classification")
+    }
+
     val _numClasses = getNumClasses(dataset)
     if (isDefined(numClass) && $(numClass) != _numClasses) {
       throw new Exception("The number of classes in dataset doesn't match " +
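A hedged usage sketch for the classifier side (the Map constructor and the commented setCustomObj call are assumed from the surrounding XGBoost4J-Spark API, not shown in this diff): objectiveType can now be set explicitly with the new setter, and the train-time block above fills it in automatically when a custom objective is defined.

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

// Sketch only: explicit use of the new setter; the eval metric is one of supportedEvalMetrics.
val clf = new XGBoostClassifier(Map("eta" -> "0.1", "num_workers" -> 2, "num_round" -> 10))
  .setObjectiveType("classification")  // setter added in this commit
  .setEvalMetric("logloss")
// With a custom objective (hypothetical ObjectiveTrait instance `myObj`), the explicit call
// becomes optional, since train() now performs set(objectiveType, "classification"):
// clf.setCustomObj(myObj)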
@@ -130,6 +130,8 @@ class XGBoostRegressor (
   // setters for learning params
   def setObjective(value: String): this.type = set(objective, value)
 
+  def setObjectiveType(value: String): this.type = set(objectiveType, value)
+
   def setBaseScore(value: Double): this.type = set(baseScore, value)
 
   def setEvalMetric(value: String): this.type = set(evalMetric, value)
@@ -158,6 +160,10 @@ class XGBoostRegressor (
       set(evalMetric, setupDefaultEvalMetric())
     }
 
+    if (isDefined(customObj) && $(customObj) != null) {
+      set(objectiveType, "regression")
+    }
+
     val weight = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val baseMargin = if (!isDefined(baseMarginCol) || $(baseMarginCol).isEmpty) {
       lit(Float.NaN)
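The regressor change mirrors the classifier; the same hedged sketch with "regression":

import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor

// Sketch only, same assumptions as the classifier example above.
val reg = new XGBoostRegressor(Map("eta" -> "0.1", "num_workers" -> 2, "num_round" -> 10))
  .setObjectiveType("regression")  // optional once a custom objective is set: train() assigns it
  .setEvalMetric("rmse")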
@@ -33,6 +33,18 @@ private[spark] trait LearningTaskParams extends Params {
 
   final def getObjective: String = $(objective)
 
+  /**
+   * The learning objective type of the specified custom objective and eval.
+   * Corresponding type will be assigned if custom objective is defined
+   * options: regression, classification. default: null
+   */
+  final val objectiveType = new Param[String](this, "objectiveType", "objective type used for " +
+    s"training, options: {${LearningTaskParams.supportedObjectiveType.mkString(",")}",
+    (value: String) => LearningTaskParams.supportedObjectiveType.contains(value))
+
+  final def getObjectiveType: String = $(objectiveType)
+
   /**
    * the initial prediction score of all instances, global bias. default=0.5
    */
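The new param behaves like any other Spark ML string param; a minimal hedged sketch (the no-argument XGBoostClassifier constructor is assumed from the surrounding API):

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

// Sketch only: the new getter reads the value back once it has been set, either via
// the setter or by the train-time blocks in the classifier/regressor hunks above.
val est = new XGBoostClassifier()
est.setObjectiveType("classification")
assert(est.getObjectiveType == "classification")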
@@ -84,6 +96,8 @@ private[spark] object LearningTaskParams {
     "binary:logitraw", "count:poisson", "multi:softmax", "multi:softprob", "rank:pairwise",
     "rank:ndcg", "rank:map", "reg:gamma", "reg:tweedie")
 
+  val supportedObjectiveType = HashSet("regression", "classification")
+
   val supportedEvalMetrics = HashSet("rmse", "mae", "logloss", "error", "merror", "mlogloss",
     "auc", "aucpr", "ndcg", "map", "gamma-deviance")
 }
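The validator wired into objectiveType is just a membership test against this set, so unsupported values are rejected when the param is set (assuming Spark's usual Param validation). A trivial sketch of the check itself:

import scala.collection.immutable.HashSet

// Mirrors the (value: String) => supportedObjectiveType.contains(value) validator above.
val supportedObjectiveType = HashSet("regression", "classification")
assert(supportedObjectiveType.contains("classification"))
assert(!supportedObjectiveType.contains("ranking"))  // such a value would be rejected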
@@ -140,15 +140,18 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
   }
 
   test("XGBoost and Spark parameters synchronize correctly") {
-    val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
+    val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
+      "objective_type" -> "classification")
     // from xgboost params to spark params
     val xgb = new XGBoostClassifier(xgbParamMap)
     assert(xgb.getEta === 1.0)
     assert(xgb.getObjective === "binary:logistic")
+    assert(xgb.getObjectiveType === "classification")
     // from spark to xgboost params
     val xgbCopy = xgb.copy(ParamMap.empty)
     assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
     assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
+    assert(xgbCopy.MLlib2XGBoostParams("objective_type").toString === "classification")
     val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
     assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
   }