[jvm-packages] update methods in test cases to be consistent (#1780)

* add back train method but mark as deprecated

* fix scalastyle error

* change class to object in examples

* fix compilation error

* update methods in test cases to be consistent

* add blank lines

* fix
Nan Zhu 2016-11-20 22:49:18 -05:00 committed by GitHub
parent ce708c8e7f
commit 965091c4bb
2 changed files with 21 additions and 11 deletions


@@ -38,38 +38,45 @@ trait Utils extends Serializable {
     }
   }
 
-  protected def loadLabelPoints(filePath: String): List[LabeledPoint] = {
+  protected def loadLabelPoints(filePath: String, zeroBased: Boolean = false):
+      List[LabeledPoint] = {
     val file = Source.fromFile(new File(filePath))
     val sampleList = new ListBuffer[LabeledPoint]
     for (sample <- file.getLines()) {
-      sampleList += fromSVMStringToLabeledPoint(sample)
+      sampleList += fromColValueStringToLabeledPoint(sample, zeroBased)
     }
     sampleList.toList
   }
 
-  protected def loadLabelAndVector(filePath: String): List[(Double, SparkVector)] = {
+  protected def loadLabelAndVector(filePath: String, zeroBased: Boolean = false):
+      List[(Double, SparkVector)] = {
     val file = Source.fromFile(new File(filePath))
     val sampleList = new ListBuffer[(Double, SparkVector)]
     for (sample <- file.getLines()) {
-      sampleList += fromSVMStringToLabelAndVector(sample)
+      sampleList += fromColValueStringToLabelAndVector(sample, zeroBased)
     }
     sampleList.toList
   }
 
-  protected def fromSVMStringToLabelAndVector(line: String): (Double, SparkVector) = {
+  protected def fromColValueStringToLabelAndVector(line: String, zeroBased: Boolean):
+      (Double, SparkVector) = {
     val labelAndFeatures = line.split(" ")
     val label = labelAndFeatures(0).toDouble
     val features = labelAndFeatures.tail
-    val denseFeature = new Array[Double](129)
+    val denseFeature = new Array[Double](126)
     for (feature <- features) {
       val idAndValue = feature.split(":")
-      denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
+      if (!zeroBased) {
+        denseFeature(idAndValue(0).toInt - 1) = idAndValue(1).toDouble
+      } else {
+        denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
+      }
     }
     (label, new DenseVector(denseFeature))
   }
 
-  protected def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
-    val (label, sv) = fromSVMStringToLabelAndVector(line)
+  protected def fromColValueStringToLabeledPoint(line: String, zeroBased: Boolean): LabeledPoint = {
+    val (label, sv) = fromColValueStringToLabelAndVector(line, zeroBased)
     LabeledPoint(label, sv)
   }

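For context, a minimal standalone sketch of the index handling introduced above: the helpers now accept a zeroBased flag, shifting 1-based LibSVM-style feature indices down by one unless the file is already zero-based. The object and method names here are illustrative only, the 126-wide dense array and the space/colon separators follow the diff, and the ml.linalg vector package is an assumption (the hunk does not show which Spark vector import Utils uses):

import org.apache.spark.ml.linalg.{DenseVector, Vector => SparkVector}

object ColValueParsingSketch {
  // Parse one "label idx:value idx:value ..." line into (label, dense vector).
  // When zeroBased is false (the LibSVM convention), indices start at 1 and are shifted down.
  def parseLine(line: String, zeroBased: Boolean): (Double, SparkVector) = {
    val labelAndFeatures = line.split(" ")
    val label = labelAndFeatures(0).toDouble
    val denseFeature = new Array[Double](126)
    for (feature <- labelAndFeatures.tail) {
      val Array(id, value) = feature.split(":")
      val index = if (zeroBased) id.toInt else id.toInt - 1
      denseFeature(index) = value.toDouble
    }
    (label, new DenseVector(denseFeature))
  }

  def main(args: Array[String]): Unit = {
    // With 1-based input, "1:0.5" lands at position 0; with zeroBased = true it would land at 1.
    val (label, vector) = parseLine("1 1:0.5 3:2.0", zeroBased = false)
    println(s"label=$label, first=${vector(0)}, third=${vector(2)}")
  }
}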

@@ -95,12 +95,14 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
   test("test schema of XGBoostRegressionModel") {
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
       "objective" -> "reg:linear")
-    val testItr = loadLabelPoints(getClass.getResource("/machine.txt.test").getFile).iterator.
+    val testItr = loadLabelPoints(getClass.getResource("/machine.txt.test").getFile,
+      zeroBased = true).iterator.
       zipWithIndex.map { case (instance: LabeledPoint, id: Int) =>
         (id, instance.features, instance.label)
       }
     val trainingDF = {
-      val rowList = loadLabelPoints(getClass.getResource("/machine.txt.train").getFile)
+      val rowList = loadLabelPoints(getClass.getResource("/machine.txt.train").getFile,
+        zeroBased = true)
       val labeledPointsRDD = sc.parallelize(rowList, numWorkers)
       val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate()
       import sparkSession.implicits._
@@ -183,4 +185,5 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
   }
 }
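And a correspondingly hedged sketch of how the updated helper is consumed when the test builds its training DataFrame: load zero-based labeled points, parallelize them, and convert to a (label, features) DataFrame. The local master, object name, hard-coded sample points, and column names are assumptions made to keep the example standalone; they stand in for loadLabelPoints(..., zeroBased = true) and the suite's SharedSparkContext/Utils scaffolding shown in the diff:

import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object TrainingDFSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("XGBoostDFSuiteSketch").getOrCreate()
    import spark.implicits._

    // Stand-in for loadLabelPoints(getClass.getResource("/machine.txt.train").getFile, zeroBased = true):
    // two hand-written 126-feature points instead of reading the resource file.
    val rowList = List(
      LabeledPoint(15.0, Vectors.dense(Array.fill(126)(0.0))),
      LabeledPoint(42.0, Vectors.dense(Array.fill(126)(1.0)))
    )
    val trainingDF = spark.sparkContext.parallelize(rowList, numSlices = 2)
      .map(lp => (lp.label, lp.features))
      .toDF("label", "features")
    trainingDF.printSchema()
    spark.stop()
  }
}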