[jvm-packages][spark]Preserve num classes (#2068)

* add back train method but mark as deprecated

* fix scalastyle error

* change class to object in examples

* fix compilation error

* bump spark version to 2.1

* preserve num_class issues

* fix failed test cases

* revising

* add multi class test
This commit is contained in:
Nan Zhu
2017-03-04 14:14:31 -08:00
committed by GitHub
parent a92093388d
commit ac30a0aff5
4 changed files with 424 additions and 15 deletions

View File

@@ -16,6 +16,11 @@
package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import scala.collection.mutable.ListBuffer
import scala.io.Source
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
@@ -60,7 +65,7 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
}
val trainingDF = buildTrainingDataframe()
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = round, nWorkers = numWorkers, useExternalMemory = false)
round = round, nWorkers = numWorkers)
val testDF = trainingDF.sparkSession.createDataFrame(testSetItr.toList).toDF(
"id", "features", "label")
val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
@@ -83,7 +88,7 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
val testItr = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
val trainingDF = buildTrainingDataframe()
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers, useExternalMemory = false)
round = 5, nWorkers = numWorkers)
val testSetItr = testItr.zipWithIndex.map {
case (instance: LabeledPoint, id: Int) =>
(id, instance.features, instance.label)
@@ -184,4 +189,38 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
val xgbEstimatorCopy1 = xgbEstimator.copy(sparkParamMap.put(xgbEstimator.evalMetric, "logloss"))
assert(xgbEstimatorCopy1.fromParamsToXGBParamMap("eval_metric") === "logloss")
}
/**
 * Turns one comma-split record into a LabeledPoint.
 *
 * The second-to-last column is encoded as 1 when its raw text is "?" and 0 otherwise,
 * the last column is parsed and decremented by one to form the label, and every
 * earlier column is parsed as a Double. All columns except the last become the features.
 *
 * @param valueArray the raw string columns of a single record
 * @return the label together with a dense feature vector
 */
private def convertCSVPointToLabelPoint(valueArray: Array[String]): LabeledPoint = {
  val flagIdx = valueArray.length - 2
  val labelIdx = valueArray.length - 1
  val parsed = new Array[Double](valueArray.length)
  // The flag and the label are filled in before the plain numeric columns.
  parsed(flagIdx) = if (valueArray(flagIdx) == "?") 1.0 else 0.0
  parsed(labelIdx) = valueArray(labelIdx).toDouble - 1
  (0 until flagIdx).foreach { col =>
    parsed(col) = valueArray(col).toDouble
  }
  val features = new DenseVector(parsed.take(labelIdx))
  LabeledPoint(parsed(labelIdx), features)
}
/**
 * Reads a comma-separated file and converts each line into a LabeledPoint via
 * convertCSVPointToLabelPoint.
 *
 * @param filePath  path of the file to read
 * @param zeroBased not read by this method; kept so existing call sites keep compiling
 * @return one LabeledPoint per line, in file order
 */
private def loadCSVPoints(filePath: String, zeroBased: Boolean = false): List[LabeledPoint] = {
  val file = Source.fromFile(new File(filePath))
  try {
    // toList forces the lazy line iterator while the underlying file is still open.
    file.getLines().map(line => convertCSVPointToLabelPoint(line.split(","))).toList
  } finally {
    // Release the file handle even when a line fails to parse.
    file.close()
  }
}
// Smoke test: training with the multi:softmax objective and num_class = 6 must complete
// without throwing when fed the records parsed from /dermatology.data.
test("multi_class classification test") {
  val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
    "objective" -> "multi:softmax", "num_class" -> "6")
  // Train on the parsed dermatology records so the multi-class parameters above are
  // exercised against the data this test loads.
  val trainingRows = loadCSVPoints(getClass.getResource("/dermatology.data").getFile).
    zipWithIndex.map {
      case (instance: LabeledPoint, id: Int) =>
        (id, instance.features, instance.label)
    }
  // NOTE(review): the session is borrowed from buildTrainingDataframe(), mirroring how the
  // other tests in this suite reach a SparkSession; swap in a direct handle if one exists.
  val sparkSession = buildTrainingDataframe().sparkSession
  val trainingDF = sparkSession.createDataFrame(trainingRows).toDF("id", "features", "label")
  XGBoost.trainWithDataFrame(trainingDF, paramMap,
    round = 5, nWorkers = numWorkers)
}
}