[jvm-packages] Integration with Spark Dataframe/Dataset (#1559)
* bump up to scala 2.11
* framework of data frame integration
* test consistency between RDD and DataFrame
* order preservation
* test order preservation
* example code and fix makefile
* improve type checking
* improve APIs
* user docs
* work around travis CI's limitation on log length
* adjust test structure
* integrate with Spark 1.x
* spark 2.x integration
* remove spark 1.x implementation but provide instructions on how to downgrade
parent 7ff742ebf7
commit fb02797e2a

.gitignore (vendored)
@@ -79,3 +79,5 @@ tags
 *.class
 target
 *.swp
+
+.DS_Store
@@ -13,7 +13,7 @@ Before you install XGBoost4J, you need to define environment variable `JAVA_HOME`
 
 After your `JAVA_HOME` is defined correctly, installing XGBoost4J is as simple as running `mvn package` under the jvm-packages directory. You can also skip the tests by running `mvn -DskipTests=true package` if you are sure about the correctness of your local setup.
 
-XGBoost4J-Spark, which integrates XGBoost with Spark, requires Spark 1.6 or newer (the default version is 1.6.1). You can build XGBoost4J-Spark as a component of XGBoost4J by running `mvn package`, or specify the Spark version with `mvn -Dspark.version=1.6.0 package`.
+After integrating with the Dataframe/Dataset APIs of Spark 2.0, XGBoost4J-Spark only supports compiling with Spark 2.x. You can build XGBoost4J-Spark as a component of XGBoost4J by running `mvn package`, and you can specify the Spark version with `mvn -Dspark.version=2.0.0 package`. (To continue working with Spark 1.x, users should update pom.xml by modifying properties such as `spark.version`, `scala.version`, and `scala.binary.version`; they also need to change the implementation by replacing SparkSession with SQLContext and the type of API parameters from Dataset[_] to DataFrame.)
 
 Contents
 --------
@@ -49,12 +49,17 @@ object XGBoostScalaExample {
 ```
 
 ### XGBoost Spark
 
+XGBoost4J-Spark supports training an XGBoost model through both the RDD and the Dataframe APIs.
+
+RDD Version:
+
 ```scala
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.util.MLUtils
 import ml.dmlc.xgboost4j.scala.spark.XGBoost
 
-object DistTrainWithSpark {
+object SparkWithRDD {
   def main(args: Array[String]): Unit = {
     if (args.length != 3) {
       println(
@@ -85,6 +90,52 @@ object DistTrainWithSpark {
   }
 ```
 
+Dataframe Version:
+
+```scala
+object SparkWithDataFrame {
+  def main(args: Array[String]): Unit = {
+    if (args.length != 5) {
+      println(
+        "usage: program num_of_rounds num_workers training_path test_path model_path")
+      sys.exit(1)
+    }
+    // create SparkSession
+    val sparkConf = new SparkConf().setAppName("XGBoost-spark-example")
+      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+    sparkConf.registerKryoClasses(Array(classOf[Booster]))
+    val sparkSession = SparkSession.builder().appName("XGBoost-spark-example").config(sparkConf).
+      getOrCreate()
+    // create training and testing dataframes
+    val inputTrainPath = args(2)
+    val inputTestPath = args(3)
+    val outputModelPath = args(4)
+    // number of iterations
+    val numRound = args(0).toInt
+    import DataUtils._
+    val trainRDDOfRows = MLUtils.loadLibSVMFile(sparkSession.sparkContext, inputTrainPath).
+      map{ labeledPoint => Row(labeledPoint.features, labeledPoint.label)}
+    val trainDF = sparkSession.createDataFrame(trainRDDOfRows, StructType(
+      Array(StructField("features", ArrayType(FloatType)), StructField("label", IntegerType))))
+    val testRDDOfRows = MLUtils.loadLibSVMFile(sparkSession.sparkContext, inputTestPath).
+      zipWithIndex().map{ case (labeledPoint, id) =>
+        Row(id, labeledPoint.features, labeledPoint.label)}
+    val testDF = sparkSession.createDataFrame(testRDDOfRows, StructType(
+      Array(StructField("id", LongType),
+        StructField("features", ArrayType(FloatType)), StructField("label", IntegerType))))
+    // training parameters
+    val paramMap = List(
+      "eta" -> 0.1f,
+      "max_depth" -> 2,
+      "objective" -> "binary:logistic").toMap
+    val xgboostModel = XGBoost.trainWithDataFrame(
+      trainDF, paramMap, numRound, nWorkers = args(1).toInt, useExternalMemory = true)
+    // xgboost-spark appends the column containing prediction results
+    xgboostModel.transform(testDF).show()
+  }
+}
+```
+
 ### XGBoost Flink
 ```scala
 import ml.dmlc.xgboost4j.scala.flink.XGBoost
@@ -14,8 +14,6 @@
     <maven.compiler.source>1.7</maven.compiler.source>
     <maven.compiler.target>1.7</maven.compiler.target>
     <maven.version>3.3.9</maven.version>
-    <scala.version>2.10.5</scala.version>
-    <scala.binary.version>2.10</scala.binary.version>
   </properties>
   <modules>
     <module>xgboost4j</module>
@@ -25,13 +23,15 @@
   </modules>
   <profiles>
    <profile>
-      <id>spark-1.x</id>
+      <id>spark-2.x</id>
      <activation>
        <activeByDefault>true</activeByDefault>
      </activation>
      <properties>
-        <spark.version>1.6.1</spark.version>
-        <scala.binary.version>2.10</scala.binary.version>
+        <spark.version>2.0.0</spark.version>
+        <flink.suffix>_2.11</flink.suffix>
+        <scala.version>2.11.8</scala.version>
+        <scala.binary.version>2.11</scala.binary.version>
      </properties>
    </profile>
  </profiles>
@@ -0,0 +1,65 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.example.spark
+
+import ml.dmlc.xgboost4j.scala.Booster
+import ml.dmlc.xgboost4j.scala.spark.{XGBoost, DataUtils}
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{SQLContext, Row}
+import org.apache.spark.{SparkContext, SparkConf}
+
+object SparkWithDataFrame {
+  def main(args: Array[String]): Unit = {
+    if (args.length != 5) {
+      println(
+        "usage: program num_of_rounds num_workers training_path test_path model_path")
+      sys.exit(1)
+    }
+    // create SparkContext and SQLContext
+    val sparkConf = new SparkConf().setAppName("XGBoost-spark-example")
+      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+    sparkConf.registerKryoClasses(Array(classOf[Booster]))
+    val sqlContext = new SQLContext(new SparkContext(sparkConf))
+    // create training and testing dataframes
+    val inputTrainPath = args(2)
+    val inputTestPath = args(3)
+    val outputModelPath = args(4)
+    // number of iterations
+    val numRound = args(0).toInt
+    import DataUtils._
+    val trainRDDOfRows = MLUtils.loadLibSVMFile(sqlContext.sparkContext, inputTrainPath).
+      map{ labeledPoint => Row(labeledPoint.features, labeledPoint.label)}
+    val trainDF = sqlContext.createDataFrame(trainRDDOfRows, StructType(
+      Array(StructField("features", ArrayType(FloatType)), StructField("label", IntegerType))))
+    val testRDDOfRows = MLUtils.loadLibSVMFile(sqlContext.sparkContext, inputTestPath).
+      zipWithIndex().map{ case (labeledPoint, id) =>
+        Row(id, labeledPoint.features, labeledPoint.label)}
+    val testDF = sqlContext.createDataFrame(testRDDOfRows, StructType(
+      Array(StructField("id", LongType),
+        StructField("features", ArrayType(FloatType)), StructField("label", IntegerType))))
+    // training parameters
+    val paramMap = List(
+      "eta" -> 0.1f,
+      "max_depth" -> 2,
+      "objective" -> "binary:logistic").toMap
+    val xgboostModel = XGBoost.trainWithDataFrame(
+      trainDF, paramMap, numRound, nWorkers = args(1).toInt, useExternalMemory = true)
+    // xgboost-spark appends the column containing prediction results
+    xgboostModel.transform(testDF).show()
+  }
+}
@@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.scala.spark.{DataUtils, XGBoost}
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.mllib.util.MLUtils
 
-object DistTrainWithSpark {
+object SparkWithRDD {
   def main(args: Array[String]): Unit = {
     if (args.length != 5) {
       println(
@@ -45,7 +45,7 @@ object DistTrainWithSpark {
       "eta" -> 0.1f,
       "max_depth" -> 2,
       "objective" -> "binary:logistic").toMap
-    val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
+    val xgboostModel = XGBoost.trainWithRDD(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
       useExternalMemory = true)
     xgboostModel.booster.predict(new DMatrix(testSet))
     // save model to HDFS path
@@ -35,22 +35,17 @@
     </dependency>
     <dependency>
       <groupId>org.apache.flink</groupId>
-      <artifactId>flink-java</artifactId>
+      <artifactId>flink-scala${flink.suffix}</artifactId>
       <version>0.10.2</version>
     </dependency>
     <dependency>
       <groupId>org.apache.flink</groupId>
-      <artifactId>flink-scala</artifactId>
+      <artifactId>flink-clients${flink.suffix}</artifactId>
       <version>0.10.2</version>
     </dependency>
     <dependency>
       <groupId>org.apache.flink</groupId>
-      <artifactId>flink-clients</artifactId>
-      <version>0.10.2</version>
-    </dependency>
-    <dependency>
-      <groupId>org.apache.flink</groupId>
-      <artifactId>flink-ml</artifactId>
+      <artifactId>flink-ml${flink.suffix}</artifactId>
       <version>0.10.2</version>
     </dependency>
   </dependencies>
@@ -18,10 +18,9 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import scala.collection.JavaConverters._
 
-import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector}
-import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
-
 import ml.dmlc.xgboost4j.LabeledPoint
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
+import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
 
 object DataUtils extends Serializable {
   implicit def fromSparkPointsToXGBoostPointsJava(sps: Iterator[SparkLabeledPoint])
@@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.mllib.linalg.SparseVector
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.{SparkContext, TaskContext}
 
 object XGBoost extends Serializable {
@@ -111,6 +112,33 @@ object XGBoost extends Serializable {
     }.cache()
   }
 
+  /**
+   * Train an XGBoost model with a DataFrame-based training set.
+   *
+   * @param trainingData the training set represented as a DataFrame
+   * @param params Map containing the parameters to configure XGBoost
+   * @param round the number of iterations
+   * @param nWorkers the number of xgboost workers; 0 by default, which means the number of
+   *                 workers equals the partition number of the trainingData RDD
+   * @param obj the user-defined objective function, null by default
+   * @param eval the user-defined evaluation function, null by default
+   * @param useExternalMemory whether to use the external memory cache; by setting this flag to
+   *                          true, the user may save RAM cost when running XGBoost within Spark
+   * @param missing the value representing a missing value in the dataset
+   * @param inputCol the name of the input (feature) column, "features" by default
+   * @param labelCol the name of the label column, "label" by default
+   * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training fails
+   * @return the trained XGBoostModel
+   */
+  @throws(classOf[XGBoostError])
+  def trainWithDataFrame(trainingData: Dataset[_],
+      params: Map[String, Any], round: Int,
+      nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
+      useExternalMemory: Boolean = false, missing: Float = Float.NaN,
+      inputCol: String = "features", labelCol: String = "label"): XGBoostModel = {
+    new XGBoostEstimator(inputCol, labelCol, params, round, nWorkers, obj, eval,
+      useExternalMemory, missing).fit(trainingData)
+  }
+
   /**
    *
    * @param trainingData the trainingset represented as RDD
@@ -127,9 +155,9 @@ object XGBoost extends Serializable {
    * @return XGBoostModel when successful training
    */
   @throws(classOf[XGBoostError])
-  def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
+  def trainWithRDD(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
       nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
       useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = {
     require(nWorkers > 0, "you must specify more than 0 workers")
     val tracker = new RabitTracker(nWorkers)
     implicit val sc = trainingData.sparkContext
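A minimal usage sketch of the two entry points above. The names `trainDF`, `trainRDD`, and all parameter values are illustrative assumptions and not part of this commit.

```scala
// Hedged sketch: assumes `trainDF` is a DataFrame with a VectorUDT "features"
// column and a numeric "label" column, and `trainRDD` is an RDD[LabeledPoint].
val paramMap = List(
  "eta" -> "1",
  "max_depth" -> "6",
  "objective" -> "binary:logistic").toMap

// DataFrame-based training goes through XGBoostEstimator under the hood.
val modelFromDF = XGBoost.trainWithDataFrame(
  trainDF, paramMap, round = 5, nWorkers = 4, useExternalMemory = false)

// RDD-based training is the pre-existing path, now renamed to trainWithRDD.
val modelFromRDD = XGBoost.trainWithRDD(
  trainRDD, paramMap, round = 5, nWorkers = 4, useExternalMemory = false)
```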
@@ -0,0 +1,81 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
+import org.apache.spark.ml.{Predictor, Estimator}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{NumericType, DoubleType, StructType}
+import org.apache.spark.sql.{DataFrame, TypedColumn, Dataset, Row}
+
+/**
+ * the estimator wrapping XGBoost to produce a training model
+ *
+ * @param inputCol the name of the input column
+ * @param labelCol the name of the label column
+ * @param xgboostParams the parameters configuring XGBoost
+ * @param round the number of iterations to train
+ * @param nWorkers the total number of xgboost workers
+ * @param obj the customized objective function, null by default so the model default is used
+ * @param eval the customized eval function, null by default so the model default is used
+ * @param useExternalMemory whether to use external memory when training
+ * @param missing the value taken as missing
+ */
+class XGBoostEstimator(
+    inputCol: String, labelCol: String,
+    xgboostParams: Map[String, Any], round: Int, nWorkers: Int,
+    obj: ObjectiveTrait = null,
+    eval: EvalTrait = null, useExternalMemory: Boolean = false, missing: Float = Float.NaN)
+  extends Estimator[XGBoostModel] {
+
+  override val uid: String = Identifiable.randomUID("XGBoostEstimator")
+
+  /**
+   * produce an XGBoostModel by fitting the given dataset
+   */
+  def fit(trainingSet: Dataset[_]): XGBoostModel = {
+    val instances = trainingSet.select(
+      col(inputCol), col(labelCol).cast(DoubleType)).rdd.map {
+      case Row(feature: Vector, label: Double) =>
+        LabeledPoint(label, feature)
+    }
+    transformSchema(trainingSet.schema, logging = true)
+    val trainedModel = XGBoost.trainWithRDD(instances, xgboostParams, round, nWorkers, obj,
+      eval, useExternalMemory, missing).setParent(this)
+    copyValues(trainedModel)
+  }
+
+  override def copy(extra: ParamMap): Estimator[XGBoostModel] = {
+    defaultCopy(extra)
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    // check the input type; for now we only support VectorUDT as the input feature type
+    val inputType = schema(inputCol).dataType
+    require(inputType.equals(new VectorUDT), s"the type of input column $inputCol has to be VectorUDT")
+    // check the label type
+    val labelType = schema(labelCol).dataType
+    require(labelType.isInstanceOf[NumericType], s"the type of label column $labelCol has to" +
+      s" be NumericType")
+    schema
+  }
+}
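For readers who prefer the Spark ML style, the estimator added above can also be constructed directly. A minimal sketch, assuming a DataFrame `trainDF` with a VectorUDT "features" column and a numeric "label" column; the name and parameter values are illustrative only.

```scala
// Hedged sketch: `trainDF` is assumed to exist with "features" and "label" columns.
val estimator = new XGBoostEstimator(
  inputCol = "features", labelCol = "label",
  xgboostParams = Map("eta" -> "1", "max_depth" -> "6", "objective" -> "binary:logistic"),
  round = 5, nWorkers = 4)
// fit() converts the selected columns to an RDD[LabeledPoint] and delegates
// to XGBoost.trainWithRDD, returning an XGBoostModel.
val model: XGBoostModel = estimator.fit(trainDF)
```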
@@ -16,16 +16,28 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import org.apache.hadoop.fs.{Path, FileSystem}
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.{TaskContext, SparkContext}
-import org.apache.spark.mllib.linalg.{DenseVector, Vector}
-import org.apache.spark.rdd.RDD
-import ml.dmlc.xgboost4j.java.{Rabit, DMatrix => JDMatrix}
-import ml.dmlc.xgboost4j.scala.{EvalTrait, Booster, DMatrix}
 import scala.collection.JavaConverters._
 
-class XGBoostModel(_booster: Booster) extends Serializable {
+import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit}
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
+import org.apache.hadoop.fs.Path
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.{Model, PredictionModel}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.linalg.{VectorUDT, DenseVector, Vector}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.{SparkContext, TaskContext}
+
+class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializable {
+
+  var inputCol = "features"
+  var outputCol = "prediction"
+  var outputType: DataType = ArrayType(elementType = FloatType, containsNull = false)
+
   /**
    * evaluate XGBoostModel with a RDD-wrapped dataset
@@ -40,6 +52,7 @@ class XGBoostModel(_booster: Booster) extends Serializable {
       eval: EvalTrait,
       evalName: String,
       useExternalCache: Boolean = false): String = {
+    val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
     val appName = evalDataset.context.appName
     val allEvalMetrics = evalDataset.mapPartitions {
       labeledPointsPartition =>
@@ -55,7 +68,7 @@ class XGBoostModel(_booster: Booster) extends Serializable {
           }
         }
         val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
-        val predictions = _booster.predict(dMatrix)
+        val predictions = broadcastBooster.value.predict(dMatrix)
         Rabit.shutdown()
         Iterator(Some(eval.eval(predictions, dMatrix)))
       } else {
@@ -152,8 +165,71 @@ class XGBoostModel(_booster: Booster) extends Serializable {
     outputStream.close()
   }
 
-  /**
-   * Get the booster instance of this model
-   */
   def booster: Booster = _booster
+
+  override val uid: String = Identifiable.randomUID("XGBoostModel")
+
+  override def copy(extra: ParamMap): XGBoostModel = {
+    defaultCopy(extra)
+  }
+
+  /**
+   * produces the prediction results and appends them as an additional column in the original
+   * dataset
+   * NOTE: the prediction results are kept in the original format of xgboost
+   * @return the original dataframe with an additional column containing prediction results
+   */
+  override def transform(testSet: Dataset[_]): DataFrame = {
+    transform(testSet, None)
+  }
+
+  /**
+   * produces the prediction results and appends them as an additional column in the original
+   * dataset
+   * NOTE: the prediction results are transformed by applying the transformation function
+   * predictResultTrans to the original xgboost output
+   * @param predictResultTrans the function to transform xgboost output to the expected format
+   * @return the original dataframe with an additional column containing prediction results
+   */
+  def transform(testSet: Dataset[_], predictResultTrans: Option[Array[Float] => DataType]):
+      DataFrame = {
+    transformSchema(testSet.schema, logging = true)
+    val broadcastBooster = testSet.sqlContext.sparkContext.broadcast(_booster)
+    val instances = testSet.rdd.mapPartitions {
+      rowIterator =>
+        if (rowIterator.hasNext) {
+          val (rowItr1, rowItr2) = rowIterator.duplicate
+          val vectorIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector](inputCol)).
+            toList.iterator
+          import DataUtils._
+          val testDataset = new DMatrix(vectorIterator, null)
+          val rowPredictResults = broadcastBooster.value.predict(testDataset)
+          val predictResults = {
+            if (predictResultTrans.isDefined) {
+              rowPredictResults.map(prediction => Row(predictResultTrans.get(prediction))).iterator
+            } else {
+              rowPredictResults.map(prediction => Row(prediction)).iterator
+            }
+          }
+          rowItr1.zip(predictResults).map {
+            case (originalColumns: Row, predictColumn: Row) =>
+              Row.fromSeq(originalColumns.toSeq ++ predictColumn.toSeq)
+          }
+        } else {
+          Iterator[Row]()
+        }
+    }
+    testSet.sqlContext.createDataFrame(instances, testSet.schema.add("prediction", outputType)).
+      cache()
+  }
+
+  @DeveloperApi
+  override def transformSchema(schema: StructType): StructType = {
+    if (schema.fieldNames.contains(outputCol)) {
+      throw new IllegalArgumentException(s"Output column $outputCol already exists.")
+    }
+    val inputType = schema(inputCol).dataType
+    require(inputType.equals(new VectorUDT),
+      s"the type of input column $inputCol has to be VectorUDT")
+    val outputFields = schema.fields :+ StructField(outputCol, outputType, nullable = false)
+    StructType(outputFields)
+  }
 }
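The transform methods above append the raw xgboost output as a "prediction" column of type Array[Float]. A minimal sketch of reading it back, assuming `xgboostModel` and `testDF` exist as in the earlier examples; the column access mirrors what the DataFrame test suite further below does.

```scala
// Hedged sketch: `xgboostModel` is an XGBoostModel, `testDF` a DataFrame whose
// schema matches the one used at training time.
import scala.collection.mutable

val predictionDF = xgboostModel.transform(testDF)   // appends the "prediction" column
val firstPrediction = predictionDF.collect().map { row =>
  // each prediction is the raw xgboost output, stored as an array of floats
  row.getAs[mutable.WrappedArray[Float]]("prediction")
}.head
println(firstPrediction.mkString(","))
```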
@@ -0,0 +1,38 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
+trait SharedSparkContext extends FunSuite with BeforeAndAfter {
+
+  protected implicit var sc: SparkContext = null
+
+  before {
+    // build SparkContext
+    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
+    sc = new SparkContext(sparkConf)
+    sc.setLogLevel("ERROR")
+  }
+
+  after {
+    if (sc != null) {
+      sc.stop()
+    }
+  }
+}
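Any suite mixing in this trait gets a fresh local SparkContext per test through the before/after hooks above. A minimal sketch; the suite name and assertion are hypothetical and only illustrate the intended usage.

```scala
// Hedged sketch: hypothetical suite, assuming ScalaTest is on the test classpath.
class MyXGBoostSuite extends SharedSparkContext {
  test("a local SparkContext is available") {
    assert(sc.parallelize(1 to 10).count() == 10)
  }
}
```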
@@ -0,0 +1,107 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import java.io.File
+
+import scala.collection.mutable.ListBuffer
+import scala.io.Source
+
+import ml.dmlc.xgboost4j.java.XGBoostError
+import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
+import org.apache.commons.logging.LogFactory
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.{DenseVector, Vector => SparkVector}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+
+trait Utils extends SharedSparkContext {
+  protected val numWorkers = Runtime.getRuntime().availableProcessors()
+
+  protected class EvalError extends EvalTrait {
+
+    val logger = LogFactory.getLog(classOf[EvalError])
+
+    private[xgboost4j] var evalMetric: String = "custom_error"
+
+    /**
+     * get evaluate metric
+     *
+     * @return evalMetric
+     */
+    override def getMetric: String = evalMetric
+
+    /**
+     * evaluate with predicts and data
+     *
+     * @param predicts predictions as array
+     * @param dmat data matrix to evaluate
+     * @return result of the metric
+     */
+    override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
+      var error: Float = 0f
+      var labels: Array[Float] = null
+      try {
+        labels = dmat.getLabel
+      } catch {
+        case ex: XGBoostError =>
+          logger.error(ex)
+          return -1f
+      }
+      val nrow: Int = predicts.length
+      for (i <- 0 until nrow) {
+        if (labels(i) == 0.0 && predicts(i)(0) > 0) {
+          error += 1
+        } else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
+          error += 1
+        }
+      }
+      error / labels.length
+    }
+  }
+
+  protected def loadLabelPoints(filePath: String): List[LabeledPoint] = {
+    val file = Source.fromFile(new File(filePath))
+    val sampleList = new ListBuffer[LabeledPoint]
+    for (sample <- file.getLines()) {
+      sampleList += fromSVMStringToLabeledPoint(sample)
+    }
+    sampleList.toList
+  }
+
+  protected def fromSVMStringToLabelAndVector(line: String): (Double, SparkVector) = {
+    val labelAndFeatures = line.split(" ")
+    val label = labelAndFeatures(0).toDouble
+    val features = labelAndFeatures.tail
+    val denseFeature = new Array[Double](129)
+    for (feature <- features) {
+      val idAndValue = feature.split(":")
+      denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
+    }
+    (label, new DenseVector(denseFeature))
+  }
+
+  protected def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
+    val (label, sv) = fromSVMStringToLabelAndVector(line)
+    LabeledPoint(label, sv)
+  }
+
+  protected def buildTrainingRDD(sparkContext: Option[SparkContext] = None): RDD[LabeledPoint] = {
+    val sampleList = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile)
+    sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
+  }
+}
@@ -0,0 +1,129 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import java.io.File
+
+import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
+import scala.io.Source
+
+import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
+import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.VectorUDT
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.sql._
+import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}
+
+class XGBoostDFSuite extends Utils {
+
+  private def loadRow(filePath: String): List[Row] = {
+    val file = Source.fromFile(new File(filePath))
+    val rowList = new ListBuffer[Row]
+    for (rowLine <- file.getLines()) {
+      rowList += fromSVMStringToRow(rowLine)
+    }
+    rowList.toList
+  }
+
+  private def buildTrainingDataframe(sparkContext: Option[SparkContext] = None): DataFrame = {
+    val rowList = loadRow(getClass.getResource("/agaricus.txt.train").getFile)
+    val rowRDD = sparkContext.getOrElse(sc).parallelize(rowList, numWorkers)
+    val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate()
+    sparkSession.createDataFrame(rowRDD,
+      StructType(Array(StructField("label", DoubleType, nullable = false),
+        StructField("features", new VectorUDT, nullable = false))))
+  }
+
+  private def fromSVMStringToRow(line: String): Row = {
+    val (label, sv) = fromSVMStringToLabelAndVector(line)
+    Row(label, sv)
+  }
+
+  test("test consistency between training with dataframe and RDD") {
+    val trainingDF = buildTrainingDataframe()
+    val trainingRDD = buildTrainingRDD()
+    val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
+      "objective" -> "binary:logistic").toMap
+    val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
+      round = 5, nWorkers = numWorkers, useExternalMemory = false)
+    val xgBoostModelWithRDD = XGBoost.trainWithRDD(trainingRDD, paramMap,
+      round = 5, nWorkers = numWorkers, useExternalMemory = false)
+    val eval = new EvalError()
+    val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
+    import DataUtils._
+    val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
+    assert(
+      eval.eval(xgBoostModelWithDF.booster.predict(testSetDMatrix, outPutMargin = true),
+        testSetDMatrix) ===
+      eval.eval(xgBoostModelWithRDD.booster.predict(testSetDMatrix, outPutMargin = true),
+        testSetDMatrix))
+  }
+
+  test("test transform of dataframe-based model") {
+    val trainingDF = buildTrainingDataframe()
+    val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
+      "objective" -> "binary:logistic").toMap
+    val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
+      round = 5, nWorkers = numWorkers, useExternalMemory = false)
+    val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile)
+    val testRowsRDD = sc.parallelize(testSet.zipWithIndex, numWorkers).map{
+      case (instance: LabeledPoint, id: Int) =>
+        Row(id, instance.features, instance.label)
+    }
+    val testDF = trainingDF.sparkSession.createDataFrame(testRowsRDD, StructType(
+      Array(StructField("id", IntegerType),
+        StructField("features", new VectorUDT), StructField("label", DoubleType))))
+    xgBoostModelWithDF.transform(testDF).show()
+  }
+
+  test("test order preservation of dataframe-based model") {
+    val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
+      "objective" -> "binary:logistic").toMap
+    val trainingItr = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile).
+      iterator
+    val (testItr, auxTestItr) =
+      loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator.duplicate
+    import DataUtils._
+    val trainDMatrix = new DMatrix(new JDMatrix(trainingItr, null))
+    val testDMatrix = new DMatrix(new JDMatrix(testItr, null))
+    val xgboostModel = ScalaXGBoost.train(trainDMatrix, paramMap, 5)
+    val predResultFromSeq = xgboostModel.predict(testDMatrix)
+    val testRowsRDD = sc.parallelize(
+      auxTestItr.toList.zipWithIndex, numWorkers).map {
+      case (instance: LabeledPoint, id: Int) =>
+        Row(id, instance.features, instance.label)
+    }
+    val trainingDF = buildTrainingDataframe()
+    val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
+      round = 5, nWorkers = numWorkers, useExternalMemory = false)
+    val testDF = trainingDF.sqlContext.createDataFrame(testRowsRDD, StructType(
+      Array(StructField("id", IntegerType), StructField("features", new VectorUDT),
+        StructField("label", DoubleType))))
+    val predResultsFromDF =
+      xgBoostModelWithDF.transform(testDF).collect().map(row => (row.getAs[Int]("id"),
+        row.getAs[mutable.WrappedArray[Float]]("prediction"))).toMap
+    for (i <- predResultFromSeq.indices) {
+      assert(predResultFromSeq(i).length === predResultsFromDF(i).length)
+      for (j <- predResultFromSeq(i).indices) {
+        assert(predResultFromSeq(i)(j) === predResultsFromDF(i)(j))
+      }
+    }
+  }
+}
@@ -20,107 +20,20 @@ import java.io.File
 import java.nio.file.Files
 
 import scala.collection.mutable.ListBuffer
-import scala.io.Source
 import scala.util.Random
 
-import org.apache.commons.logging.LogFactory
-import org.apache.spark.mllib.linalg.{Vector => SparkVector, Vectors, DenseVector}
+import ml.dmlc.xgboost4j.java.{Booster => JBooster, DMatrix => JDMatrix}
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => ScalaXGBoost}
+import org.apache.spark.mllib.linalg.{Vector => SparkVector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
 import org.apache.spark.{SparkConf, SparkContext}
-import org.scalatest.{BeforeAndAfter, FunSuite}
-
-import ml.dmlc.xgboost4j.java.{Booster => JBooster, DMatrix => JDMatrix, XGBoostError}
-import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
-
-class XGBoostSuite extends FunSuite with BeforeAndAfter {
-
-  private implicit var sc: SparkContext = null
-  private val numWorkers = Runtime.getRuntime().availableProcessors()
-
-  private class EvalError extends EvalTrait {
-
-    val logger = LogFactory.getLog(classOf[EvalError])
-
-    private[xgboost4j] var evalMetric: String = "custom_error"
-
-    /**
-     * get evaluate metric
-     *
-     * @return evalMetric
-     */
-    override def getMetric: String = evalMetric
-
-    /**
-     * evaluate with predicts and data
-     *
-     * @param predicts predictions as array
-     * @param dmat data matrix to evaluate
-     * @return result of the metric
-     */
-    override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
-      var error: Float = 0f
-      var labels: Array[Float] = null
-      try {
-        labels = dmat.getLabel
-      } catch {
-        case ex: XGBoostError =>
-          logger.error(ex)
-          return -1f
-      }
-      val nrow: Int = predicts.length
-      for (i <- 0 until nrow) {
-        if (labels(i) == 0.0 && predicts(i)(0) > 0) {
-          error += 1
-        } else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
-          error += 1
-        }
-      }
-      error / labels.length
-    }
-  }
-
-  before {
-    // build SparkContext
-    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
-    sc = new SparkContext(sparkConf)
-  }
-
-  after {
-    if (sc != null) {
-      sc.stop()
-    }
-  }
-
-  private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
-    val labelAndFeatures = line.split(" ")
-    val label = labelAndFeatures(0).toInt
-    val features = labelAndFeatures.tail
-    val denseFeature = new Array[Double](129)
-    for (feature <- features) {
-      val idAndValue = feature.split(":")
-      denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
-    }
-    LabeledPoint(label, new DenseVector(denseFeature))
-  }
-
-  private def readFile(filePath: String): List[LabeledPoint] = {
-    val file = Source.fromFile(new File(filePath))
-    val sampleList = new ListBuffer[LabeledPoint]
-    for (sample <- file.getLines()) {
-      sampleList += fromSVMStringToLabeledPoint(sample)
-    }
-    sampleList.toList
-  }
-
-  private def buildTrainingRDD(sparkContext: Option[SparkContext] = None): RDD[LabeledPoint] = {
-    val sampleList = readFile(getClass.getResource("/agaricus.txt.train").getFile)
-    sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
-  }
+class XGBoostGeneralSuite extends Utils {
 
   test("build RDD containing boosters with the specified worker number") {
     val trainingRDD = buildTrainingRDD()
-    val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
+    val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
     import DataUtils._
     val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
     val boosterRDD = XGBoost.buildDistributedBoosters(
@@ -145,14 +58,15 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
     sc = null
     val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
     val customSparkContext = new SparkContext(sparkConf)
+    customSparkContext.setLogLevel("ERROR")
     val eval = new EvalError()
     val trainingRDD = buildTrainingRDD(Some(customSparkContext))
-    val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
+    val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
     import DataUtils._
     val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
     val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
       "objective" -> "binary:logistic").toMap
-    val xgBoostModel = XGBoost.train(trainingRDD, paramMap, round = 5,
+    val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
       nWorkers = numWorkers, useExternalMemory = true)
     assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
       testSetDMatrix) < 0.1)
@@ -194,13 +108,13 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
     val testRDD = buildDenseRDD().repartition(4)
     val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
       "objective" -> "binary:logistic").toMap
-    val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
+    val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
     xgBoostModel.predict(testRDD.map(_.features.toDense), missingValue = -0.1f).collect()
   }
 
   test("test consistency of prediction functions with RDD") {
     val trainingRDD = buildTrainingRDD()
-    val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile)
+    val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile)
     val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)
     val testCollection = testRDD.collect()
     for (i <- testSet.indices) {
@@ -208,7 +122,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
     }
     val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
       "objective" -> "binary:logistic").toMap
-    val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
+    val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
     val predRDD = xgBoostModel.predict(testRDD)
     val predResult1 = predRDD.collect()(0)
     import DataUtils._
@@ -225,26 +139,25 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
     }
     val trainingRDD = buildTrainingRDD()
     val testRDD = buildEmptyRDD()
-    import DataUtils._
     val tempDir = Files.createTempDirectory("xgboosttest-")
     val tempFile = Files.createTempFile(tempDir, "", "")
     val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
       "objective" -> "binary:logistic").toMap
-    val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
+    val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
     println(xgBoostModel.predict(testRDD).collect().length === 0)
   }
 
   test("test model consistency after save and load") {
     val eval = new EvalError()
     val trainingRDD = buildTrainingRDD()
-    val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
+    val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
     import DataUtils._
     val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
     val tempDir = Files.createTempDirectory("xgboosttest-")
     val tempFile = Files.createTempFile(tempDir, "", "")
     val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
       "objective" -> "binary:logistic").toMap
-    val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
+    val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
     val evalResults = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
       testSetDMatrix)
     assert(evalResults < 0.1)
@@ -261,12 +174,13 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
     val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite").
       set("spark.task.cpus", "4")
     val customSparkContext = new SparkContext(sparkConf)
+    customSparkContext.setLogLevel("ERROR")
     // start another app
     val trainingRDD = buildTrainingRDD(Some(customSparkContext))
     val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
       "objective" -> "binary:logistic", "nthread" -> 6).toMap
     intercept[IllegalArgumentException] {
-      XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
+      XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
     }
     customSparkContext.stop()
   }
@@ -279,13 +193,14 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
       .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
     sparkConf.registerKryoClasses(Array(classOf[Booster]))
     val customSparkContext = new SparkContext(sparkConf)
+    customSparkContext.setLogLevel("ERROR")
     val trainingRDD = buildTrainingRDD(Some(customSparkContext))
-    val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
+    val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
     import DataUtils._
     val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
     val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
       "objective" -> "binary:logistic").toMap
-    val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers)
+    val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
     assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
       testSetDMatrix) < 0.1)
     customSparkContext.stop()