[jvm-packages] update methods in test cases to be consistent (#1780)
* add back train method but mark as deprecated
* fix scalastyle error
* change class to object in examples
* fix compilation error
* update methods in test cases to be consistent
* add blank lines
* fix
This commit is contained in:
parent
ce708c8e7f
commit
965091c4bb
@ -38,38 +38,45 @@ trait Utils extends Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
protected def loadLabelPoints(filePath: String): List[LabeledPoint] = {
|
||||
protected def loadLabelPoints(filePath: String, zeroBased: Boolean = false):
|
||||
List[LabeledPoint] = {
|
||||
val file = Source.fromFile(new File(filePath))
|
||||
val sampleList = new ListBuffer[LabeledPoint]
|
||||
for (sample <- file.getLines()) {
|
||||
sampleList += fromSVMStringToLabeledPoint(sample)
|
||||
sampleList += fromColValueStringToLabeledPoint(sample, zeroBased)
|
||||
}
|
||||
sampleList.toList
|
||||
}
|
||||
|
||||
protected def loadLabelAndVector(filePath: String): List[(Double, SparkVector)] = {
|
||||
protected def loadLabelAndVector(filePath: String, zeroBased: Boolean = false):
|
||||
List[(Double, SparkVector)] = {
|
||||
val file = Source.fromFile(new File(filePath))
|
||||
val sampleList = new ListBuffer[(Double, SparkVector)]
|
||||
for (sample <- file.getLines()) {
|
||||
sampleList += fromSVMStringToLabelAndVector(sample)
|
||||
sampleList += fromColValueStringToLabelAndVector(sample, zeroBased)
|
||||
}
|
||||
sampleList.toList
|
||||
}
|
||||
|
||||
protected def fromSVMStringToLabelAndVector(line: String): (Double, SparkVector) = {
|
||||
protected def fromColValueStringToLabelAndVector(line: String, zeroBased: Boolean):
|
||||
(Double, SparkVector) = {
|
||||
val labelAndFeatures = line.split(" ")
|
||||
val label = labelAndFeatures(0).toDouble
|
||||
val features = labelAndFeatures.tail
|
||||
val denseFeature = new Array[Double](129)
|
||||
val denseFeature = new Array[Double](126)
|
||||
for (feature <- features) {
|
||||
val idAndValue = feature.split(":")
|
||||
denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
|
||||
if (!zeroBased) {
|
||||
denseFeature(idAndValue(0).toInt - 1) = idAndValue(1).toDouble
|
||||
} else {
|
||||
denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
|
||||
}
|
||||
}
|
||||
(label, new DenseVector(denseFeature))
|
||||
}
|
||||
|
||||
protected def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
|
||||
val (label, sv) = fromSVMStringToLabelAndVector(line)
|
||||
protected def fromColValueStringToLabeledPoint(line: String, zeroBased: Boolean): LabeledPoint = {
|
||||
val (label, sv) = fromColValueStringToLabelAndVector(line, zeroBased)
|
||||
LabeledPoint(label, sv)
|
||||
}
|
||||
|
||||
|
||||
@ -95,12 +95,14 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
test("test schema of XGBoostRegressionModel") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "reg:linear")
|
||||
val testItr = loadLabelPoints(getClass.getResource("/machine.txt.test").getFile).iterator.
|
||||
val testItr = loadLabelPoints(getClass.getResource("/machine.txt.test").getFile,
|
||||
zeroBased = true).iterator.
|
||||
zipWithIndex.map { case (instance: LabeledPoint, id: Int) =>
|
||||
(id, instance.features, instance.label)
|
||||
}
|
||||
val trainingDF = {
|
||||
val rowList = loadLabelPoints(getClass.getResource("/machine.txt.train").getFile)
|
||||
val rowList = loadLabelPoints(getClass.getResource("/machine.txt.train").getFile,
|
||||
zeroBased = true)
|
||||
val labeledPointsRDD = sc.parallelize(rowList, numWorkers)
|
||||
val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate()
|
||||
import sparkSession.implicits._
|
||||
@ -183,4 +185,5 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user