[jvm-packages][spark]Preserve num classes (#2068)

* add back train method but mark as deprecated

* fix scalastyle error

* change class to object in examples

* fix compilation error

* bump spark version to 2.1

* preserve num_class issues

* fix failed test cases

* revising

* add multi class test
This commit is contained in:
Nan Zhu
2017-03-04 14:14:31 -08:00
committed by GitHub
parent a92093388d
commit ac30a0aff5
4 changed files with 424 additions and 15 deletions

View File

@@ -97,7 +97,12 @@ class XGBoostEstimator private[spark](
for (param <- params) {
xgbParamMap += param.name -> $(param)
}
xgbParamMap.toMap
val r = xgbParamMap.toMap
if (!XGBoost.isClassificationTask(r) || $(numClasses) == 2) {
r - "num_class"
} else {
r
}
}
/**
@@ -110,19 +115,13 @@ class XGBoostEstimator private[spark](
LabeledPoint(label, feature)
}
transformSchema(trainingSet.schema, logging = true)
val trainedModel = XGBoost.trainWithRDD(instances, fromParamsToXGBParamMap,
val derivedXGBoosterParamMap = fromParamsToXGBParamMap
val trainedModel = XGBoost.trainWithRDD(instances, derivedXGBoosterParamMap,
$(round), $(nWorkers), $(customObj), $(customEval), $(useExternalMemory),
$(missing)).setParent(this)
val returnedModel = copyValues(trainedModel)
if (XGBoost.isClassificationTask(xgboostParams)) {
val numClass = {
if (xgboostParams.contains("num_class")) {
xgboostParams("num_class").asInstanceOf[Int]
} else {
2
}
}
returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = numClass
if (XGBoost.isClassificationTask(derivedXGBoosterParamMap)) {
returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = $(numClasses)
}
returnedModel
}

View File

@@ -18,10 +18,15 @@ package ml.dmlc.xgboost4j.scala.spark.params
import scala.collection.immutable.HashSet
import org.apache.spark.ml.param.{DoubleParam, Param, Params}
import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params}
trait LearningTaskParams extends Params {
/**
* number of classes (exposed as the "num_class" parameter)
*/
val numClasses = new IntParam(this, "num_class", "number of classes")
/**
* Specify the learning task and the corresponding learning objective.
* options: reg:linear, reg:logistic, binary:logistic, binary:logitraw, count:poisson,
@@ -48,7 +53,7 @@ trait LearningTaskParams extends Params {
s" {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}",
(value: String) => LearningTaskParams.supportedEvalMetrics.contains(value))
setDefault(objective -> "reg:linear", baseScore -> 0.5)
setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2)
}
private[spark] object LearningTaskParams {