[jvm-packages] Integration with Spark Dataframe/Dataset (#1559)
* bump up to scala 2.11 * framework of data frame integration * test consistency between RDD and DataFrame * order preservation * test order preservation * example code and fix makefile * improve type checking * improve APIs * user docs * work around travis CI's limitation on log length * adjust test structure * integrate with Spark 1.x * spark 2.x integration * remove spark 1.x implementation but provide instructions on how to downgrade
This commit is contained in:
@@ -0,0 +1,65 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.example.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.Booster
|
||||
import ml.dmlc.xgboost4j.scala.spark.{XGBoost, DataUtils}
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.{SQLContext, Row}
|
||||
import org.apache.spark.{SparkContext, SparkConf}
|
||||
|
||||
/**
 * Example: train an XGBoost model from a Spark DataFrame, predict on a test
 * DataFrame, and persist the trained model.
 *
 * Command-line arguments (positional):
 *   args(0) num_of_rounds  - number of boosting iterations
 *   args(1) num_workers    - number of distributed XGBoost workers
 *   args(2) training_path  - LibSVM-format training data
 *   args(3) test_path      - LibSVM-format test data
 *   args(4) model_path     - HDFS path where the trained model is written
 */
object SparkWithDataFrame {
  def main(args: Array[String]): Unit = {
    if (args.length != 5) {
      println(
        "usage: program num_of_rounds num_workers training_path test_path model_path")
      sys.exit(1)
    }
    // Create the SparkContext wrapped in a SQLContext; register Booster with
    // Kryo so trained boosters serialize efficiently between driver/executors.
    val sparkConf = new SparkConf().setAppName("XGBoost-spark-example")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    sparkConf.registerKryoClasses(Array(classOf[Booster]))
    val sqlContext = new SQLContext(new SparkContext(sparkConf))
    // Positional arguments; see the usage string above.
    val inputTrainPath = args(2)
    val inputTestPath = args(3)
    val outputModelPath = args(4)
    // number of boosting iterations
    val numRound = args(0).toInt
    // DataUtils supplies the implicit conversions needed by trainWithDataFrame.
    import DataUtils._
    // LibSVM -> Row(features, label) for the training DataFrame.
    val trainRDDOfRows = MLUtils.loadLibSVMFile(sqlContext.sparkContext, inputTrainPath).
      map{ labeledPoint => Row(labeledPoint.features, labeledPoint.label)}
    val trainDF = sqlContext.createDataFrame(trainRDDOfRows, StructType(
      Array(StructField("features", ArrayType(FloatType)), StructField("label", IntegerType))))
    // Test rows carry a stable zipWithIndex id so predictions can be traced
    // back to the original input order.
    val testRDDOfRows = MLUtils.loadLibSVMFile(sqlContext.sparkContext, inputTestPath).
      zipWithIndex().map{ case (labeledPoint, id) =>
      Row(id, labeledPoint.features, labeledPoint.label)}
    val testDF = sqlContext.createDataFrame(testRDDOfRows, StructType(
      Array(StructField("id", LongType),
        StructField("features", ArrayType(FloatType)), StructField("label", IntegerType))))
    // training parameters
    val paramMap = List(
      "eta" -> 0.1f,
      "max_depth" -> 2,
      "objective" -> "binary:logistic").toMap
    val xgboostModel = XGBoost.trainWithDataFrame(
      trainDF, paramMap, numRound, nWorkers = args(1).toInt, useExternalMemory = true)
    // xgboost-spark appends the column containing prediction results
    xgboostModel.transform(testDF).show()
    // Fix: outputModelPath (args(4), "model_path" in the usage string) was
    // parsed but never used — the example never saved the model. Persist it.
    // NOTE(review): saveModelAsHadoopFile takes the SparkContext in a second
    // (implicit) parameter list in this xgboost4j-spark version — confirm.
    xgboostModel.saveModelAsHadoopFile(outputModelPath)(sqlContext.sparkContext)
  }
}
|
||||
@@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.scala.spark.{DataUtils, XGBoost}
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
|
||||
object DistTrainWithSpark {
|
||||
object SparkWithRDD {
|
||||
def main(args: Array[String]): Unit = {
|
||||
if (args.length != 5) {
|
||||
println(
|
||||
@@ -45,7 +45,7 @@ object DistTrainWithSpark {
|
||||
"eta" -> 0.1f,
|
||||
"max_depth" -> 2,
|
||||
"objective" -> "binary:logistic").toMap
|
||||
val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
|
||||
val xgboostModel = XGBoost.trainWithRDD(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
|
||||
useExternalMemory = true)
|
||||
xgboostModel.booster.predict(new DMatrix(testSet))
|
||||
// save model to HDFS path
|
||||
Reference in New Issue
Block a user