diff --git a/.gitignore b/.gitignore
index 5533356f5..f84ce1d89 100644
--- a/.gitignore
+++ b/.gitignore
@@ -79,3 +79,5 @@ tags
*.class
target
*.swp
+
+.DS_Store
diff --git a/doc/jvm/index.md b/doc/jvm/index.md
index e3ff666c0..128017509 100644
--- a/doc/jvm/index.md
+++ b/doc/jvm/index.md
@@ -13,7 +13,7 @@ Before you install XGBoost4J, you need to define environment variable `JAVA_HOME
After your `JAVA_HOME` is defined correctly, installing XGBoost4J is as simple as running `mvn package` under the jvm-packages directory. You can also skip the tests by running `mvn -DskipTests=true package` if you are sure about the correctness of your local setup.
-XGBoost4J-Spark which integrates XGBoost with Spark requires to run with Spark 1.6 or newer (the default version is 1.6.1). You can build XGBoost4J-Spark as a component of XGBoost4J by running `mvn package` or specifying the spark version by `mvn -Dspark.version=1.6.0 package`.
+Since integrating with the DataFrame/Dataset APIs of Spark 2.0, XGBoost4J-Spark only supports compiling against Spark 2.x. You can build XGBoost4J-Spark as a component of XGBoost4J by running `mvn package`, and you can specify a particular Spark version with `mvn -Dspark.version=2.0.0 package`. (To keep working with Spark 1.x, users need to update pom.xml by modifying properties such as `spark.version`, `scala.version`, and `scala.binary.version`, and change the implementation by replacing SparkSession with SQLContext and the type of the API parameters from Dataset[_] to DataFrame.)
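+
+To illustrate the code-side difference, here is a minimal sketch (the variable names are illustrative, not part of this codebase):
+
+```scala
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SparkSession
+
+val conf = new SparkConf().setAppName("XGBoost-spark-example")
+// Spark 2.x (this codebase): SparkSession is the entry point and the
+// training/prediction APIs accept Dataset[_]
+val spark = SparkSession.builder().config(conf).getOrCreate()
+// Spark 1.x: fall back to SQLContext and plain DataFrames
+// val sqlContext = new org.apache.spark.sql.SQLContext(new SparkContext(conf))
+```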
Contents
--------
diff --git a/jvm-packages/README.md b/jvm-packages/README.md
index 80e194fd8..596f4fdf6 100644
--- a/jvm-packages/README.md
+++ b/jvm-packages/README.md
@@ -49,12 +49,17 @@ object XGBoostScalaExample {
```
### XGBoost Spark
+
+XGBoost4J-Spark supports training XGBoost models with both RDDs and DataFrames.
+
+RDD Version:
+
```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.MLUtils
import ml.dmlc.xgboost4j.scala.spark.XGBoost
-object DistTrainWithSpark {
+object SparkWithRDD {
def main(args: Array[String]): Unit = {
if (args.length != 3) {
println(
@@ -85,6 +90,52 @@ object DistTrainWithSpark {
}
```
+DataFrame Version:
+
+```scala
+import ml.dmlc.xgboost4j.scala.Booster
+import ml.dmlc.xgboost4j.scala.spark.{DataUtils, XGBoost}
+import org.apache.spark.SparkConf
+import org.apache.spark.mllib.linalg.VectorUDT
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{Row, SparkSession}
+
+object SparkWithDataFrame {
+ def main(args: Array[String]): Unit = {
+ if (args.length != 5) {
+ println(
+ "usage: program num_of_rounds num_workers training_path test_path model_path")
+ sys.exit(1)
+ }
+ // create SparkSession
+ val sparkConf = new SparkConf().setAppName("XGBoost-spark-example")
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ sparkConf.registerKryoClasses(Array(classOf[Booster]))
+ val sparkSession = SparkSession.builder().appName("XGBoost-spark-example").config(sparkConf).
+ getOrCreate()
+ // create training and testing dataframes
+ val inputTrainPath = args(2)
+ val inputTestPath = args(3)
+ val outputModelPath = args(4)
+ // number of iterations
+ val numRound = args(0).toInt
+ import DataUtils._
+ val trainRDDOfRows = MLUtils.loadLibSVMFile(sparkSession.sparkContext, inputTrainPath).
+ map{ labeledPoint => Row(labeledPoint.features, labeledPoint.label)}
+ val trainDF = sparkSession.createDataFrame(trainRDDOfRows, StructType(
+ Array(StructField("features", new VectorUDT), StructField("label", DoubleType))))
+ val testRDDOfRows = MLUtils.loadLibSVMFile(sparkSession.sparkContext, inputTestPath).
+ zipWithIndex().map{ case (labeledPoint, id) =>
+ Row(id, labeledPoint.features, labeledPoint.label)}
+ val testDF = sparkSession.createDataFrame(testRDDOfRows, StructType(
+ Array(StructField("id", LongType),
+ StructField("features", new VectorUDT), StructField("label", DoubleType))))
+ // training parameters
+ val paramMap = List(
+ "eta" -> 0.1f,
+ "max_depth" -> 2,
+ "objective" -> "binary:logistic").toMap
+ val xgboostModel = XGBoost.trainWithDataFrame(
+ trainDF, paramMap, numRound, nWorkers = args(1).toInt, useExternalMemory = true)
+ // xgboost-spark appends the column containing prediction results
+ xgboostModel.transform(testDF).show()
+ }
+}
+```
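+
+`transform` appends a column named `prediction` (the raw xgboost output, an array of floats) to the input DataFrame. A minimal sketch of reading the results back, reusing the names from the example above:
+
+```scala
+// select each row id together with the appended prediction column
+val predictionsDF = xgboostModel.transform(testDF)
+predictionsDF.select("id", "prediction").show(10)
+```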
+
### XGBoost Flink
```scala
import ml.dmlc.xgboost4j.scala.flink.XGBoost
diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index 3c4401c91..d4d1bde7d 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -14,8 +14,6 @@
    <maven.compiler.source>1.7</maven.compiler.source>
    <maven.compiler.target>1.7</maven.compiler.target>
    <maven.version>3.3.9</maven.version>
-    <scala.version>2.10.5</scala.version>
-    <scala.binary.version>2.10</scala.binary.version>
  </properties>

  <modules>
    <module>xgboost4j</module>
@@ -25,13 +23,15 @@
    <profile>
-      <id>spark-1.x</id>
+      <id>spark-2.x</id>
      <activation>
        <activeByDefault>true</activeByDefault>
      </activation>
      <properties>
-        <spark.version>1.6.1</spark.version>
-        <scala.binary.version>2.10</scala.binary.version>
+        <spark.version>2.0.0</spark.version>
+        <flink.suffix>_2.11</flink.suffix>
+        <scala.version>2.11.8</scala.version>
+        <scala.binary.version>2.11</scala.binary.version>
      </properties>
diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithDataFrame.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithDataFrame.scala
new file mode 100644
index 000000000..0130a1daa
--- /dev/null
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithDataFrame.scala
@@ -0,0 +1,65 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.example.spark
+
+import ml.dmlc.xgboost4j.scala.Booster
+import ml.dmlc.xgboost4j.scala.spark.{XGBoost, DataUtils}
+import org.apache.spark.mllib.linalg.VectorUDT
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{SQLContext, Row}
+import org.apache.spark.{SparkContext, SparkConf}
+
+object SparkWithDataFrame {
+ def main(args: Array[String]): Unit = {
+ if (args.length != 5) {
+ println(
+ "usage: program num_of_rounds num_workers training_path test_path model_path")
+ sys.exit(1)
+ }
+ // create SparkSession
+ val sparkConf = new SparkConf().setAppName("XGBoost-spark-example")
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ sparkConf.registerKryoClasses(Array(classOf[Booster]))
+ val sqlContext = new SQLContext(new SparkContext(sparkConf))
+ // create training and testing dataframes
+ val inputTrainPath = args(2)
+ val inputTestPath = args(3)
+ val outputModelPath = args(4)
+ // number of iterations
+ val numRound = args(0).toInt
+ import DataUtils._
+ val trainRDDOfRows = MLUtils.loadLibSVMFile(sqlContext.sparkContext, inputTrainPath).
+ map{ labeledPoint => Row(labeledPoint.features, labeledPoint.label)}
+ val trainDF = sqlContext.createDataFrame(trainRDDOfRows, StructType(
+ Array(StructField("features", new VectorUDT), StructField("label", DoubleType))))
+ val testRDDOfRows = MLUtils.loadLibSVMFile(sqlContext.sparkContext, inputTestPath).
+ zipWithIndex().map{ case (labeledPoint, id) =>
+ Row(id, labeledPoint.features, labeledPoint.label)}
+ val testDF = sqlContext.createDataFrame(testRDDOfRows, StructType(
+ Array(StructField("id", LongType),
+ StructField("features", new VectorUDT), StructField("label", DoubleType))))
+ // training parameters
+ val paramMap = List(
+ "eta" -> 0.1f,
+ "max_depth" -> 2,
+ "objective" -> "binary:logistic").toMap
+ val xgboostModel = XGBoost.trainWithDataFrame(
+ trainDF, paramMap, numRound, nWorkers = args(1).toInt, useExternalMemory = true)
+ // xgboost-spark appends the column containing prediction results
+ xgboostModel.transform(testDF).show()
+ }
+}
diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/DistTrainWithSpark.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithRDD.scala
similarity index 94%
rename from jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/DistTrainWithSpark.scala
rename to jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithRDD.scala
index b96089e42..b731a0b2d 100644
--- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/DistTrainWithSpark.scala
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithRDD.scala
@@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.scala.spark.{DataUtils, XGBoost}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.util.MLUtils
-object DistTrainWithSpark {
+object SparkWithRDD {
def main(args: Array[String]): Unit = {
if (args.length != 5) {
println(
@@ -45,7 +45,7 @@ object DistTrainWithSpark {
"eta" -> 0.1f,
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
- val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
+ val xgboostModel = XGBoost.trainWithRDD(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
useExternalMemory = true)
xgboostModel.booster.predict(new DMatrix(testSet))
// save model to HDFS path
diff --git a/jvm-packages/xgboost4j-flink/pom.xml b/jvm-packages/xgboost4j-flink/pom.xml
index fd9c0be78..0414195ac 100644
--- a/jvm-packages/xgboost4j-flink/pom.xml
+++ b/jvm-packages/xgboost4j-flink/pom.xml
@@ -35,22 +35,17 @@
    <dependency>
      <groupId>org.apache.flink</groupId>
-      <artifactId>flink-java</artifactId>
+      <artifactId>flink-scala${flink.suffix}</artifactId>
      <version>0.10.2</version>
    </dependency>
    <dependency>
      <groupId>org.apache.flink</groupId>
-      <artifactId>flink-scala</artifactId>
+      <artifactId>flink-clients${flink.suffix}</artifactId>
      <version>0.10.2</version>
    </dependency>
    <dependency>
      <groupId>org.apache.flink</groupId>
-      <artifactId>flink-clients</artifactId>
-      <version>0.10.2</version>
-    </dependency>
-    <dependency>
-      <groupId>org.apache.flink</groupId>
-      <artifactId>flink-ml</artifactId>
+      <artifactId>flink-ml${flink.suffix}</artifactId>
      <version>0.10.2</version>
    </dependency>
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala
index 371c59c14..4fae9ccd1 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala
@@ -18,10 +18,9 @@ package ml.dmlc.xgboost4j.scala.spark
import scala.collection.JavaConverters._
-import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector}
-import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
-
import ml.dmlc.xgboost4j.LabeledPoint
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
+import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
object DataUtils extends Serializable {
implicit def fromSparkPointsToXGBoostPointsJava(sps: Iterator[SparkLabeledPoint])
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 8ebf080fd..3491a63cc 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.mllib.linalg.SparseVector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.{SparkContext, TaskContext}
object XGBoost extends Serializable {
@@ -111,6 +112,33 @@ object XGBoost extends Serializable {
}.cache()
}
+ /**
+ *
+ * @param trainingData the training set represented as a DataFrame
+ * @param params Map containing the parameters to configure XGBoost
+ * @param round the number of iterations
+ * @param nWorkers the number of xgboost workers; must be greater than 0
+ * @param obj the user-defined objective function, null by default
+ * @param eval the user-defined evaluation function, null by default
+ * @param useExternalMemory indicates whether to use the external memory cache; setting this flag
+ * to true may save RAM when running XGBoost within Spark
+ * @param missing the value representing missing values in the dataset
+ * @param inputCol the name of the input (features) column, "features" by default
+ * @param labelCol the name of the label column, "label" by default
+ * @throws ml.dmlc.xgboost4j.java.XGBoostError when model training fails
+ * @return the trained XGBoostModel
+ */
+ @throws(classOf[XGBoostError])
+ def trainWithDataFrame(trainingData: Dataset[_],
+ params: Map[String, Any], round: Int,
+ nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
+ useExternalMemory: Boolean = false, missing: Float = Float.NaN,
+ inputCol: String = "features", labelCol: String = "label"): XGBoostModel = {
+ new XGBoostEstimator(inputCol, labelCol, params, round, nWorkers, obj, eval,
+ useExternalMemory, missing).fit(trainingData)
+ }
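+
+ // Illustrative usage (not part of this patch), assuming a DataFrame `trainDF` with a
+ // VectorUDT column "features" and a numeric column "label":
+ // val model = XGBoost.trainWithDataFrame(trainDF,
+ // Map("eta" -> 0.1, "objective" -> "binary:logistic"), round = 5, nWorkers = 4)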
+
/**
*
* @param trainingData the trainingset represented as RDD
@@ -127,9 +155,9 @@ object XGBoost extends Serializable {
* @return XGBoostModel when successful training
*/
@throws(classOf[XGBoostError])
- def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
- nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
- useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = {
+ def trainWithRDD(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
+ nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
+ useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = {
require(nWorkers > 0, "you must specify more than 0 workers")
val tracker = new RabitTracker(nWorkers)
implicit val sc = trainingData.sparkContext
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
new file mode 100644
index 000000000..64ee91f8b
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
@@ -0,0 +1,81 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
+import org.apache.spark.ml.{Predictor, Estimator}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{NumericType, DoubleType, StructType}
+import org.apache.spark.sql.{DataFrame, TypedColumn, Dataset, Row}
+
+/**
+ * The estimator wrapping XGBoost to produce a trained model
+ *
+ * @param inputCol the name of the input column
+ * @param labelCol the name of the label column
+ * @param xgboostParams the parameters configuring XGBoost
+ * @param round the number of iterations to train
+ * @param nWorkers the total number of xgboost workers
+ * @param obj the customized objective function; null by default, which means using the default
+ * objective configured in the model
+ * @param eval the customized evaluation function; null by default, which means using the default
+ * metric configured in the model
+ * @param useExternalMemory whether to use the external memory cache when training
+ * @param missing the value to be treated as missing
+ */
+class XGBoostEstimator(
+ inputCol: String, labelCol: String,
+ xgboostParams: Map[String, Any], round: Int, nWorkers: Int,
+ obj: ObjectiveTrait = null,
+ eval: EvalTrait = null, useExternalMemory: Boolean = false, missing: Float = Float.NaN)
+ extends Estimator[XGBoostModel] {
+
+ override val uid: String = Identifiable.randomUID("XGBoostEstimator")
+
+ /**
+ * Produces an XGBoostModel by fitting the given dataset
+ */
+ def fit(trainingSet: Dataset[_]): XGBoostModel = {
+ val instances = trainingSet.select(
+ col(inputCol), col(labelCol).cast(DoubleType)).rdd.map {
+ case Row(feature: Vector, label: Double) =>
+ LabeledPoint(label, feature)
+ }
+ transformSchema(trainingSet.schema, logging = true)
+ val trainedModel = XGBoost.trainWithRDD(instances, xgboostParams, round, nWorkers, obj,
+ eval, useExternalMemory, missing).setParent(this)
+ copyValues(trainedModel)
+ }
+
+ override def copy(extra: ParamMap): Estimator[XGBoostModel] = {
+ defaultCopy(extra)
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ // check input type; for now we only support VectorUDT as the input feature type
+ val inputType = schema(inputCol).dataType
+ require(inputType.equals(new VectorUDT),
+ s"the type of input column $inputCol has to be VectorUDT")
+ // check label type
+ val labelType = schema(labelCol).dataType
+ require(labelType.isInstanceOf[NumericType],
+ s"the type of label column $labelCol has to be NumericType")
+ schema
+ }
+}
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
index 597f08031..b33bfd33e 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
@@ -16,16 +16,28 @@
package ml.dmlc.xgboost4j.scala.spark
-import org.apache.hadoop.fs.{Path, FileSystem}
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.{TaskContext, SparkContext}
-import org.apache.spark.mllib.linalg.{DenseVector, Vector}
-import org.apache.spark.rdd.RDD
-import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
-import ml.dmlc.xgboost4j.scala.{EvalTrait, Booster, DMatrix}
import scala.collection.JavaConverters._
-class XGBoostModel(_booster: Booster) extends Serializable {
+import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit}
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
+import org.apache.hadoop.fs.Path
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.{Model, PredictionModel}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.linalg.{VectorUDT, DenseVector, Vector}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.{SparkContext, TaskContext}
+
+class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializable {
+
+ var inputCol = "features"
+ var outputCol = "prediction"
+ var outputType: DataType = ArrayType(elementType = FloatType, containsNull = false)
/**
* evaluate XGBoostModel with a RDD-wrapped dataset
@@ -40,6 +52,7 @@ class XGBoostModel(_booster: Booster) extends Serializable {
eval: EvalTrait,
evalName: String,
useExternalCache: Boolean = false): String = {
+ val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
val appName = evalDataset.context.appName
val allEvalMetrics = evalDataset.mapPartitions {
labeledPointsPartition =>
@@ -55,7 +68,7 @@ class XGBoostModel(_booster: Booster) extends Serializable {
}
}
val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
- val predictions = _booster.predict(dMatrix)
+ val predictions = broadcastBooster.value.predict(dMatrix)
Rabit.shutdown()
Iterator(Some(eval.eval(predictions, dMatrix)))
} else {
@@ -152,8 +165,71 @@ class XGBoostModel(_booster: Booster) extends Serializable {
outputStream.close()
}
- /**
- * Get the booster instance of this model
- */
def booster: Booster = _booster
+
+ override val uid: String = Identifiable.randomUID("XGBoostModel")
+
+ override def copy(extra: ParamMap): XGBoostModel = {
+ defaultCopy(extra)
+ }
+
+ /**
+ * Produces the prediction results and appends them as an additional column to the original
+ * dataset.
+ * NOTE: the prediction results are kept in the original output format of xgboost
+ * @return the original dataframe with an additional column containing the prediction results
+ */
+ override def transform(testSet: Dataset[_]): DataFrame = {
+ transform(testSet, None)
+ }
+
+ /**
+ * Produces the prediction results and appends them as an additional column to the original
+ * dataset.
+ * NOTE: the prediction results are transformed by applying the transformation function
+ * predictResultTrans to the original xgboost output
+ * @param predictResultTrans the function to transform the xgboost output into the expected format
+ * @return the original dataframe with an additional column containing the prediction results
+ */
+ def transform(testSet: Dataset[_], predictResultTrans: Option[Array[Float] => DataType]):
+ DataFrame = {
+ transformSchema(testSet.schema, logging = true)
+ val broadcastBooster = testSet.sqlContext.sparkContext.broadcast(_booster)
+ val instances = testSet.rdd.mapPartitions {
+ rowIterator =>
+ if (rowIterator.hasNext) {
+ val (rowItr1, rowItr2) = rowIterator.duplicate
+ val vectorIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector](inputCol)).
+ toList.iterator
+ import DataUtils._
+ val testDataset = new DMatrix(vectorIterator, null)
+ val rowPredictResults = broadcastBooster.value.predict(testDataset)
+ val predictResults = {
+ if (predictResultTrans.isDefined) {
+ rowPredictResults.map(prediction => Row(predictResultTrans.get(prediction))).iterator
+ } else {
+ rowPredictResults.map(prediction => Row(prediction)).iterator
+ }
+ }
+ rowItr1.zip(predictResults).map {
+ case (originalColumns: Row, predictColumn: Row) =>
+ Row.fromSeq(originalColumns.toSeq ++ predictColumn.toSeq)
+ }
+ } else {
+ Iterator[Row]()
+ }
+ }
+ testSet.sqlContext.createDataFrame(instances, testSet.schema.add("prediction", outputType)).
+ cache()
+ }
+
+ @DeveloperApi
+ override def transformSchema(schema: StructType): StructType = {
+ if (schema.fieldNames.contains(outputCol)) {
+ throw new IllegalArgumentException(s"Output column $outputCol already exists.")
+ }
+ val inputType = schema(inputCol).dataType
+ require(inputType.equals(new VectorUDT),
+ s"the type of input column $inputCol has to be VectorUDT")
+ val outputFields = schema.fields :+ StructField(outputCol, outputType, nullable = false)
+ StructType(outputFields)
+ }
}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SharedSparkContext.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SharedSparkContext.scala
new file mode 100644
index 000000000..a73cb9fac
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SharedSparkContext.scala
@@ -0,0 +1,38 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
+trait SharedSparkContext extends FunSuite with BeforeAndAfter {
+
+ protected implicit var sc: SparkContext = null
+
+ before {
+ // build SparkContext
+ val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
+ sc = new SparkContext(sparkConf)
+ sc.setLogLevel("ERROR")
+ }
+
+ after {
+ if (sc != null) {
+ sc.stop()
+ }
+ }
+}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala
new file mode 100644
index 000000000..7c8ac1744
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala
@@ -0,0 +1,107 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import java.io.File
+
+import scala.collection.mutable.ListBuffer
+import scala.io.Source
+
+import ml.dmlc.xgboost4j.java.XGBoostError
+import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
+import org.apache.commons.logging.LogFactory
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.{DenseVector, Vector => SparkVector}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+
+trait Utils extends SharedSparkContext {
+ protected val numWorkers = Runtime.getRuntime().availableProcessors()
+
+ protected class EvalError extends EvalTrait {
+
+ val logger = LogFactory.getLog(classOf[EvalError])
+
+ private[xgboost4j] var evalMetric: String = "custom_error"
+
+ /**
+ * get evaluate metric
+ *
+ * @return evalMetric
+ */
+ override def getMetric: String = evalMetric
+
+ /**
+ * evaluate with predicts and data
+ *
+ * @param predicts predictions as array
+ * @param dmat data matrix to evaluate
+ * @return result of the metric
+ */
+ override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
+ var error: Float = 0f
+ var labels: Array[Float] = null
+ try {
+ labels = dmat.getLabel
+ } catch {
+ case ex: XGBoostError =>
+ logger.error(ex)
+ return -1f
+ }
+ val nrow: Int = predicts.length
+ for (i <- 0 until nrow) {
+ if (labels(i) == 0.0 && predicts(i)(0) > 0) {
+ error += 1
+ } else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
+ error += 1
+ }
+ }
+ error / labels.length
+ }
+ }
+
+ protected def loadLabelPoints(filePath: String): List[LabeledPoint] = {
+ val file = Source.fromFile(new File(filePath))
+ val sampleList = new ListBuffer[LabeledPoint]
+ for (sample <- file.getLines()) {
+ sampleList += fromSVMStringToLabeledPoint(sample)
+ }
+ sampleList.toList
+ }
+
+ protected def fromSVMStringToLabelAndVector(line: String): (Double, SparkVector) = {
+ val labelAndFeatures = line.split(" ")
+ val label = labelAndFeatures(0).toDouble
+ val features = labelAndFeatures.tail
+ val denseFeature = new Array[Double](129)
+ for (feature <- features) {
+ val idAndValue = feature.split(":")
+ denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
+ }
+ (label, new DenseVector(denseFeature))
+ }
+
+ protected def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
+ val (label, sv) = fromSVMStringToLabelAndVector(line)
+ LabeledPoint(label, sv)
+ }
+
+ protected def buildTrainingRDD(sparkContext: Option[SparkContext] = None): RDD[LabeledPoint] = {
+ val sampleList = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile)
+ sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
+ }
+}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
new file mode 100644
index 000000000..527f5bf15
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala
@@ -0,0 +1,129 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import java.io.File
+
+import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
+import scala.io.Source
+
+import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
+import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.VectorUDT
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.sql._
+import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}
+
+class XGBoostDFSuite extends Utils {
+
+ private def loadRow(filePath: String): List[Row] = {
+ val file = Source.fromFile(new File(filePath))
+ val rowList = new ListBuffer[Row]
+ for (rowLine <- file.getLines()) {
+ rowList += fromSVMStringToRow(rowLine)
+ }
+ rowList.toList
+ }
+
+ private def buildTrainingDataframe(sparkContext: Option[SparkContext] = None):
+ DataFrame = {
+ val rowList = loadRow(getClass.getResource("/agaricus.txt.train").getFile)
+ val rowRDD = sparkContext.getOrElse(sc).parallelize(rowList, numWorkers)
+ val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate()
+ sparkSession.createDataFrame(rowRDD,
+ StructType(Array(StructField("label", DoubleType, nullable = false),
+ StructField("features", new VectorUDT, nullable = false))))
+ }
+
+ private def fromSVMStringToRow(line: String): Row = {
+ val (label, sv) = fromSVMStringToLabelAndVector(line)
+ Row(label, sv)
+ }
+
+ test("test consistency between training with dataframe and RDD") {
+ val trainingDF = buildTrainingDataframe()
+ val trainingRDD = buildTrainingRDD()
+ val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
+ "objective" -> "binary:logistic").toMap
+ val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
+ round = 5, nWorkers = numWorkers, useExternalMemory = false)
+ val xgBoostModelWithRDD = XGBoost.trainWithRDD(trainingRDD, paramMap,
+ round = 5, nWorkers = numWorkers, useExternalMemory = false)
+ val eval = new EvalError()
+ val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
+ import DataUtils._
+ val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
+ assert(
+ eval.eval(xgBoostModelWithDF.booster.predict(testSetDMatrix, outPutMargin = true),
+ testSetDMatrix) ===
+ eval.eval(xgBoostModelWithRDD.booster.predict(testSetDMatrix, outPutMargin = true),
+ testSetDMatrix))
+ }
+
+ test("test transform of dataframe-based model") {
+ val trainingDF = buildTrainingDataframe()
+ val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
+ "objective" -> "binary:logistic").toMap
+ val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
+ round = 5, nWorkers = numWorkers, useExternalMemory = false)
+ val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile)
+ val testRowsRDD = sc.parallelize(testSet.zipWithIndex, numWorkers).map{
+ case (instance: LabeledPoint, id: Int) =>
+ Row(id, instance.features, instance.label)
+ }
+ val testDF = trainingDF.sparkSession.createDataFrame(testRowsRDD, StructType(
+ Array(StructField("id", IntegerType),
+ StructField("features", new VectorUDT), StructField("label", DoubleType))))
+ xgBoostModelWithDF.transform(testDF).show()
+ }
+
+ test("test order preservation of dataframe-based model") {
+ val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
+ "objective" -> "binary:logistic").toMap
+ val trainingItr = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile).
+ iterator
+ val (testItr, auxTestItr) =
+ loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator.duplicate
+ import DataUtils._
+ val trainDMatrix = new DMatrix(new JDMatrix(trainingItr, null))
+ val testDMatrix = new DMatrix(new JDMatrix(testItr, null))
+ val xgboostModel = ScalaXGBoost.train(trainDMatrix, paramMap, 5)
+ val predResultFromSeq = xgboostModel.predict(testDMatrix)
+ val testRowsRDD = sc.parallelize(
+ auxTestItr.toList.zipWithIndex, numWorkers).map {
+ case (instance: LabeledPoint, id: Int) =>
+ Row(id, instance.features, instance.label)
+ }
+ val trainingDF = buildTrainingDataframe()
+ val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
+ round = 5, nWorkers = numWorkers, useExternalMemory = false)
+ val testDF = trainingDF.sqlContext.createDataFrame(testRowsRDD, StructType(
+ Array(StructField("id", IntegerType), StructField("features", new VectorUDT),
+ StructField("label", DoubleType))))
+ val predResultsFromDF =
+ xgBoostModelWithDF.transform(testDF).collect().map(row => (row.getAs[Int]("id"),
+ row.getAs[mutable.WrappedArray[Float]]("prediction"))).toMap
+ for (i <- predResultFromSeq.indices) {
+ assert(predResultFromSeq(i).length === predResultsFromDF(i).length)
+ for (j <- predResultFromSeq(i).indices) {
+ assert(predResultFromSeq(i)(j) === predResultsFromDF(i)(j))
+ }
+ }
+ }
+}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
similarity index 68%
rename from jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
rename to jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index 639a19c91..a6877b096 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -20,107 +20,20 @@ import java.io.File
import java.nio.file.Files
import scala.collection.mutable.ListBuffer
-import scala.io.Source
import scala.util.Random
-import org.apache.commons.logging.LogFactory
-import org.apache.spark.mllib.linalg.{Vector => SparkVector, Vectors, DenseVector}
+import ml.dmlc.xgboost4j.java.{Booster => JBooster, DMatrix => JDMatrix}
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => ScalaXGBoost}
+import org.apache.spark.mllib.linalg.{Vector => SparkVector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
-import org.scalatest.{BeforeAndAfter, FunSuite}
-import ml.dmlc.xgboost4j.java.{Booster => JBooster, DMatrix => JDMatrix, XGBoostError}
-import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
-
-class XGBoostSuite extends FunSuite with BeforeAndAfter {
-
- private implicit var sc: SparkContext = null
- private val numWorkers = Runtime.getRuntime().availableProcessors()
-
- private class EvalError extends EvalTrait {
-
- val logger = LogFactory.getLog(classOf[EvalError])
-
- private[xgboost4j] var evalMetric: String = "custom_error"
-
- /**
- * get evaluate metric
- *
- * @return evalMetric
- */
- override def getMetric: String = evalMetric
-
- /**
- * evaluate with predicts and data
- *
- * @param predicts predictions as array
- * @param dmat data matrix to evaluate
- * @return result of the metric
- */
- override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
- var error: Float = 0f
- var labels: Array[Float] = null
- try {
- labels = dmat.getLabel
- } catch {
- case ex: XGBoostError =>
- logger.error(ex)
- return -1f
- }
- val nrow: Int = predicts.length
- for (i <- 0 until nrow) {
- if (labels(i) == 0.0 && predicts(i)(0) > 0) {
- error += 1
- } else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
- error += 1
- }
- }
- error / labels.length
- }
- }
-
- before {
- // build SparkContext
- val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
- sc = new SparkContext(sparkConf)
- }
-
- after {
- if (sc != null) {
- sc.stop()
- }
- }
-
- private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
- val labelAndFeatures = line.split(" ")
- val label = labelAndFeatures(0).toInt
- val features = labelAndFeatures.tail
- val denseFeature = new Array[Double](129)
- for (feature <- features) {
- val idAndValue = feature.split(":")
- denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
- }
- LabeledPoint(label, new DenseVector(denseFeature))
- }
-
- private def readFile(filePath: String): List[LabeledPoint] = {
- val file = Source.fromFile(new File(filePath))
- val sampleList = new ListBuffer[LabeledPoint]
- for (sample <- file.getLines()) {
- sampleList += fromSVMStringToLabeledPoint(sample)
- }
- sampleList.toList
- }
-
- private def buildTrainingRDD(sparkContext: Option[SparkContext] = None): RDD[LabeledPoint] = {
- val sampleList = readFile(getClass.getResource("/agaricus.txt.train").getFile)
- sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
- }
+class XGBoostGeneralSuite extends Utils {
test("build RDD containing boosters with the specified worker number") {
val trainingRDD = buildTrainingRDD()
- val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
+ val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
val boosterRDD = XGBoost.buildDistributedBoosters(
@@ -145,14 +58,15 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
sc = null
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
val customSparkContext = new SparkContext(sparkConf)
+ customSparkContext.setLogLevel("ERROR")
val eval = new EvalError()
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
- val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
+ val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
"objective" -> "binary:logistic").toMap
- val xgBoostModel = XGBoost.train(trainingRDD, paramMap, round = 5,
+ val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = true)
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
@@ -194,13 +108,13 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
val testRDD = buildDenseRDD().repartition(4)
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap
- val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
+ val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
xgBoostModel.predict(testRDD.map(_.features.toDense), missingValue = -0.1f).collect()
}
test("test consistency of prediction functions with RDD") {
val trainingRDD = buildTrainingRDD()
- val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile)
+ val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile)
val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)
val testCollection = testRDD.collect()
for (i <- testSet.indices) {
@@ -208,7 +122,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
}
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap
- val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
+ val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
val predRDD = xgBoostModel.predict(testRDD)
val predResult1 = predRDD.collect()(0)
import DataUtils._
@@ -225,26 +139,25 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
}
val trainingRDD = buildTrainingRDD()
val testRDD = buildEmptyRDD()
- import DataUtils._
val tempDir = Files.createTempDirectory("xgboosttest-")
val tempFile = Files.createTempFile(tempDir, "", "")
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap
- val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
+ val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
println(xgBoostModel.predict(testRDD).collect().length === 0)
}
test("test model consistency after save and load") {
val eval = new EvalError()
val trainingRDD = buildTrainingRDD()
- val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
+ val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
val tempDir = Files.createTempDirectory("xgboosttest-")
val tempFile = Files.createTempFile(tempDir, "", "")
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap
- val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
+ val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
val evalResults = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix)
assert(evalResults < 0.1)
@@ -261,12 +174,13 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite").
set("spark.task.cpus", "4")
val customSparkContext = new SparkContext(sparkConf)
+ customSparkContext.setLogLevel("ERROR")
// start another app
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic", "nthread" -> 6).toMap
intercept[IllegalArgumentException] {
- XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
+ XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
}
customSparkContext.stop()
}
@@ -279,13 +193,14 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
sparkConf.registerKryoClasses(Array(classOf[Booster]))
val customSparkContext = new SparkContext(sparkConf)
+ customSparkContext.setLogLevel("ERROR")
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
- val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
+ val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap
- val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
+ val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
customSparkContext.stop()