[jvm-packages] update methods in test cases to be consistent (#1780)
* add back train method but mark as deprecated * fix scalastyle error * change class to object in examples * fix compilation error * update methods in test cases to be consistent * add blank lines * fix
This commit is contained in:
parent
ce708c8e7f
commit
965091c4bb
@ -38,38 +38,45 @@ trait Utils extends Serializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected def loadLabelPoints(filePath: String): List[LabeledPoint] = {
|
protected def loadLabelPoints(filePath: String, zeroBased: Boolean = false):
|
||||||
|
List[LabeledPoint] = {
|
||||||
val file = Source.fromFile(new File(filePath))
|
val file = Source.fromFile(new File(filePath))
|
||||||
val sampleList = new ListBuffer[LabeledPoint]
|
val sampleList = new ListBuffer[LabeledPoint]
|
||||||
for (sample <- file.getLines()) {
|
for (sample <- file.getLines()) {
|
||||||
sampleList += fromSVMStringToLabeledPoint(sample)
|
sampleList += fromColValueStringToLabeledPoint(sample, zeroBased)
|
||||||
}
|
}
|
||||||
sampleList.toList
|
sampleList.toList
|
||||||
}
|
}
|
||||||
|
|
||||||
protected def loadLabelAndVector(filePath: String): List[(Double, SparkVector)] = {
|
protected def loadLabelAndVector(filePath: String, zeroBased: Boolean = false):
|
||||||
|
List[(Double, SparkVector)] = {
|
||||||
val file = Source.fromFile(new File(filePath))
|
val file = Source.fromFile(new File(filePath))
|
||||||
val sampleList = new ListBuffer[(Double, SparkVector)]
|
val sampleList = new ListBuffer[(Double, SparkVector)]
|
||||||
for (sample <- file.getLines()) {
|
for (sample <- file.getLines()) {
|
||||||
sampleList += fromSVMStringToLabelAndVector(sample)
|
sampleList += fromColValueStringToLabelAndVector(sample, zeroBased)
|
||||||
}
|
}
|
||||||
sampleList.toList
|
sampleList.toList
|
||||||
}
|
}
|
||||||
|
|
||||||
protected def fromSVMStringToLabelAndVector(line: String): (Double, SparkVector) = {
|
protected def fromColValueStringToLabelAndVector(line: String, zeroBased: Boolean):
|
||||||
|
(Double, SparkVector) = {
|
||||||
val labelAndFeatures = line.split(" ")
|
val labelAndFeatures = line.split(" ")
|
||||||
val label = labelAndFeatures(0).toDouble
|
val label = labelAndFeatures(0).toDouble
|
||||||
val features = labelAndFeatures.tail
|
val features = labelAndFeatures.tail
|
||||||
val denseFeature = new Array[Double](129)
|
val denseFeature = new Array[Double](126)
|
||||||
for (feature <- features) {
|
for (feature <- features) {
|
||||||
val idAndValue = feature.split(":")
|
val idAndValue = feature.split(":")
|
||||||
denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
|
if (!zeroBased) {
|
||||||
|
denseFeature(idAndValue(0).toInt - 1) = idAndValue(1).toDouble
|
||||||
|
} else {
|
||||||
|
denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
|
||||||
|
}
|
||||||
}
|
}
|
||||||
(label, new DenseVector(denseFeature))
|
(label, new DenseVector(denseFeature))
|
||||||
}
|
}
|
||||||
|
|
||||||
protected def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
|
protected def fromColValueStringToLabeledPoint(line: String, zeroBased: Boolean): LabeledPoint = {
|
||||||
val (label, sv) = fromSVMStringToLabelAndVector(line)
|
val (label, sv) = fromColValueStringToLabelAndVector(line, zeroBased)
|
||||||
LabeledPoint(label, sv)
|
LabeledPoint(label, sv)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -95,12 +95,14 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
|
|||||||
test("test schema of XGBoostRegressionModel") {
|
test("test schema of XGBoostRegressionModel") {
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "reg:linear")
|
"objective" -> "reg:linear")
|
||||||
val testItr = loadLabelPoints(getClass.getResource("/machine.txt.test").getFile).iterator.
|
val testItr = loadLabelPoints(getClass.getResource("/machine.txt.test").getFile,
|
||||||
|
zeroBased = true).iterator.
|
||||||
zipWithIndex.map { case (instance: LabeledPoint, id: Int) =>
|
zipWithIndex.map { case (instance: LabeledPoint, id: Int) =>
|
||||||
(id, instance.features, instance.label)
|
(id, instance.features, instance.label)
|
||||||
}
|
}
|
||||||
val trainingDF = {
|
val trainingDF = {
|
||||||
val rowList = loadLabelPoints(getClass.getResource("/machine.txt.train").getFile)
|
val rowList = loadLabelPoints(getClass.getResource("/machine.txt.train").getFile,
|
||||||
|
zeroBased = true)
|
||||||
val labeledPointsRDD = sc.parallelize(rowList, numWorkers)
|
val labeledPointsRDD = sc.parallelize(rowList, numWorkers)
|
||||||
val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate()
|
val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate()
|
||||||
import sparkSession.implicits._
|
import sparkSession.implicits._
|
||||||
@ -183,4 +185,5 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user