[jvm-packages][spark]Preserve num classes (#2068)
* add back train method but mark as deprecated * fix scalastyle error * change class to object in examples * fix compilation error * bump spark version to 2.1 * preserve num_class issues * fix failed test cases * revising * add multi class test
This commit is contained in:
@@ -97,7 +97,12 @@ class XGBoostEstimator private[spark](
|
||||
for (param <- params) {
|
||||
xgbParamMap += param.name -> $(param)
|
||||
}
|
||||
xgbParamMap.toMap
|
||||
val r = xgbParamMap.toMap
|
||||
if (!XGBoost.isClassificationTask(r) || $(numClasses) == 2) {
|
||||
r - "num_class"
|
||||
} else {
|
||||
r
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -110,19 +115,13 @@ class XGBoostEstimator private[spark](
|
||||
LabeledPoint(label, feature)
|
||||
}
|
||||
transformSchema(trainingSet.schema, logging = true)
|
||||
val trainedModel = XGBoost.trainWithRDD(instances, fromParamsToXGBParamMap,
|
||||
val derivedXGBoosterParamMap = fromParamsToXGBParamMap
|
||||
val trainedModel = XGBoost.trainWithRDD(instances, derivedXGBoosterParamMap,
|
||||
$(round), $(nWorkers), $(customObj), $(customEval), $(useExternalMemory),
|
||||
$(missing)).setParent(this)
|
||||
val returnedModel = copyValues(trainedModel)
|
||||
if (XGBoost.isClassificationTask(xgboostParams)) {
|
||||
val numClass = {
|
||||
if (xgboostParams.contains("num_class")) {
|
||||
xgboostParams("num_class").asInstanceOf[Int]
|
||||
} else {
|
||||
2
|
||||
}
|
||||
}
|
||||
returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = numClass
|
||||
if (XGBoost.isClassificationTask(derivedXGBoosterParamMap)) {
|
||||
returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = $(numClasses)
|
||||
}
|
||||
returnedModel
|
||||
}
|
||||
|
||||
@@ -18,10 +18,15 @@ package ml.dmlc.xgboost4j.scala.spark.params
|
||||
|
||||
import scala.collection.immutable.HashSet
|
||||
|
||||
import org.apache.spark.ml.param.{DoubleParam, Param, Params}
|
||||
import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params}
|
||||
|
||||
trait LearningTaskParams extends Params {
|
||||
|
||||
/**
|
||||
* number of classes
|
||||
*/
|
||||
val numClasses = new IntParam(this, "num_class", "number of classes")
|
||||
|
||||
/**
|
||||
* Specify the learning task and the corresponding learning objective.
|
||||
* options: reg:linear, reg:logistic, binary:logistic, binary:logitraw, count:poisson,
|
||||
@@ -48,7 +53,7 @@ trait LearningTaskParams extends Params {
|
||||
s" {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}",
|
||||
(value: String) => LearningTaskParams.supportedEvalMetrics.contains(value))
|
||||
|
||||
setDefault(objective -> "reg:linear", baseScore -> 0.5)
|
||||
setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2)
|
||||
}
|
||||
|
||||
private[spark] object LearningTaskParams {
|
||||
|
||||
Reference in New Issue
Block a user