[jvm-packages][spark]Preserve num classes (#2068)
* add back train method but mark as deprecated * fix scalastyle error * change class to object in examples * fix compilation error * bump spark version to 2.1 * preserve num_class issues * fix failed test cases * revising * add multi class test
This commit is contained in:
@@ -16,6 +16,11 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.io.File
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.io.Source
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
|
||||
@@ -60,7 +65,7 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
}
|
||||
val trainingDF = buildTrainingDataframe()
|
||||
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||
round = round, nWorkers = numWorkers, useExternalMemory = false)
|
||||
round = round, nWorkers = numWorkers)
|
||||
val testDF = trainingDF.sparkSession.createDataFrame(testSetItr.toList).toDF(
|
||||
"id", "features", "label")
|
||||
val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
|
||||
@@ -83,7 +88,7 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
val testItr = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||
val trainingDF = buildTrainingDataframe()
|
||||
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||
round = 5, nWorkers = numWorkers, useExternalMemory = false)
|
||||
round = 5, nWorkers = numWorkers)
|
||||
val testSetItr = testItr.zipWithIndex.map {
|
||||
case (instance: LabeledPoint, id: Int) =>
|
||||
(id, instance.features, instance.label)
|
||||
@@ -184,4 +189,38 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
val xgbEstimatorCopy1 = xgbEstimator.copy(sparkParamMap.put(xgbEstimator.evalMetric, "logloss"))
|
||||
assert(xgbEstimatorCopy1.fromParamsToXGBParamMap("eval_metric") === "logloss")
|
||||
}
|
||||
|
||||
/**
 * Converts one comma-split CSV record into a [[LabeledPoint]].
 *
 * Layout of `valueArray` as consumed here:
 *  - the last field is parsed as a number and shifted down by one to form the label;
 *  - the second-to-last field is reduced to an indicator: 1.0 when it is exactly "?",
 *    0.0 otherwise (its original value is not kept);
 *  - every earlier field is parsed as a Double feature.
 *
 * The resulting feature vector has `valueArray.length - 1` entries.
 *
 * @param valueArray the raw string fields of one record; needs at least two entries
 * @return the labeled point built from the record
 */
private def convertCSVPointToLabelPoint(valueArray: Array[String]): LabeledPoint = {
  val fieldCount = valueArray.length
  val flagIdx = fieldCount - 2
  val labelIdx = fieldCount - 1
  // Evaluate flag, then label, then leading features so that malformed input
  // fails on the same field it always did.
  val missingFlag = if (valueArray(flagIdx) == "?") 1.0 else 0.0
  val label = valueArray(labelIdx).toDouble - 1
  val leadingFeatures = valueArray.take(flagIdx).map(_.toDouble)
  LabeledPoint(label, new DenseVector(leadingFeatures :+ missingFlag))
}
|
||||
|
||||
/**
 * Reads a comma-separated file and converts every line into a [[LabeledPoint]]
 * via `convertCSVPointToLabelPoint`.
 *
 * The underlying source is always closed before returning, including when a
 * line fails to parse.
 *
 * @param filePath  path of the CSV file to read
 * @param zeroBased not read by this method; retained so existing call sites keep compiling
 * @return one labeled point per line, in file order
 */
private def loadCSVPoints(filePath: String, zeroBased: Boolean = false): List[LabeledPoint] = {
  val file = Source.fromFile(new File(filePath))
  try {
    // toList forces the lazy line iterator while the file is still open.
    file.getLines().map(sample => convertCSVPointToLabelPoint(sample.split(","))).toList
  } finally {
    file.close()
  }
}
|
||||
|
||||
// Trains with a six-class softmax objective for five rounds; the test passes as long as
// loading the data and training complete without throwing.
test("multi_class classification test") {
  val multiClassParams = Map(
    "eta" -> "0.1",
    "max_depth" -> "6",
    "silent" -> "1",
    "objective" -> "multi:softmax",
    "num_class" -> "6")
  // NOTE(review): the dermatology points are parsed here but never handed to training;
  // the DataFrame below comes from buildTrainingDataframe(). Confirm whether this test
  // was meant to train on the dermatology data instead.
  val dermatologyPath = getClass.getResource("/dermatology.data").getFile
  val dermatologyPoints = loadCSVPoints(dermatologyPath).iterator
  val trainingDF = buildTrainingDataframe()
  XGBoost.trainWithDataFrame(trainingDF, multiClassParams, round = 5, nWorkers = numWorkers)
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user