From 965091c4bbc8e33f06f0b39d8e299ff454efc9e9 Mon Sep 17 00:00:00 2001 From: Nan Zhu Date: Sun, 20 Nov 2016 22:49:18 -0500 Subject: [PATCH] [jvm-packages] update methods in test cases to be consistent (#1780) * add back train method but mark as deprecated * fix scalastyle error * change class to object in examples * fix compilation error * update methods in test cases to be consistent * add blank lines * fix --- .../ml/dmlc/xgboost4j/scala/spark/Utils.scala | 25 ++++++++++++------- .../scala/spark/XGBoostDFSuite.scala | 7 ++++-- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala index 56c373e4e..f50c8011d 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala @@ -38,38 +38,45 @@ trait Utils extends Serializable { } } - protected def loadLabelPoints(filePath: String): List[LabeledPoint] = { + protected def loadLabelPoints(filePath: String, zeroBased: Boolean = false): + List[LabeledPoint] = { val file = Source.fromFile(new File(filePath)) val sampleList = new ListBuffer[LabeledPoint] for (sample <- file.getLines()) { - sampleList += fromSVMStringToLabeledPoint(sample) + sampleList += fromColValueStringToLabeledPoint(sample, zeroBased) } sampleList.toList } - protected def loadLabelAndVector(filePath: String): List[(Double, SparkVector)] = { + protected def loadLabelAndVector(filePath: String, zeroBased: Boolean = false): + List[(Double, SparkVector)] = { val file = Source.fromFile(new File(filePath)) val sampleList = new ListBuffer[(Double, SparkVector)] for (sample <- file.getLines()) { - sampleList += fromSVMStringToLabelAndVector(sample) + sampleList += fromColValueStringToLabelAndVector(sample, zeroBased) } sampleList.toList } - protected def fromSVMStringToLabelAndVector(line: String): (Double, SparkVector) = { + protected def fromColValueStringToLabelAndVector(line: String, zeroBased: Boolean): + (Double, SparkVector) = { val labelAndFeatures = line.split(" ") val label = labelAndFeatures(0).toDouble val features = labelAndFeatures.tail - val denseFeature = new Array[Double](129) + val denseFeature = new Array[Double](126) for (feature <- features) { val idAndValue = feature.split(":") - denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble + if (!zeroBased) { + denseFeature(idAndValue(0).toInt - 1) = idAndValue(1).toDouble + } else { + denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble + } } (label, new DenseVector(denseFeature)) } - protected def fromSVMStringToLabeledPoint(line: String): LabeledPoint = { - val (label, sv) = fromSVMStringToLabelAndVector(line) + protected def fromColValueStringToLabeledPoint(line: String, zeroBased: Boolean): LabeledPoint = { + val (label, sv) = fromColValueStringToLabelAndVector(line, zeroBased) LabeledPoint(label, sv) } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala index e23fb79b1..8a0bed92b 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala @@ -95,12 +95,14 @@ class XGBoostDFSuite extends SharedSparkContext with Utils { test("test schema of XGBoostRegressionModel") { val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", "objective" -> "reg:linear") - val testItr = loadLabelPoints(getClass.getResource("/machine.txt.test").getFile).iterator. + val testItr = loadLabelPoints(getClass.getResource("/machine.txt.test").getFile, + zeroBased = true).iterator. zipWithIndex.map { case (instance: LabeledPoint, id: Int) => (id, instance.features, instance.label) } val trainingDF = { - val rowList = loadLabelPoints(getClass.getResource("/machine.txt.train").getFile) + val rowList = loadLabelPoints(getClass.getResource("/machine.txt.train").getFile, + zeroBased = true) val labeledPointsRDD = sc.parallelize(rowList, numWorkers) val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate() import sparkSession.implicits._ @@ -183,4 +185,5 @@ class XGBoostDFSuite extends SharedSparkContext with Utils { } + }