[jvm-packages] Leverage the Spark ml API to read DataFrame from files in LibSVM format. (#1785)

This commit is contained in:
XianXing Zhang 2016-11-20 18:28:03 -08:00 committed by Nan Zhu
parent ca0069b708
commit ce708c8e7f

View File

@@ -17,11 +17,9 @@
package ml.dmlc.xgboost4j.scala.example.spark package ml.dmlc.xgboost4j.scala.example.spark
import ml.dmlc.xgboost4j.scala.Booster import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.spark.{XGBoost, DataUtils} import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._ import org.apache.spark.SparkConf
import org.apache.spark.sql.{SparkSession, SQLContext, Row}
import org.apache.spark.{SparkContext, SparkConf}
object SparkWithDataFrame { object SparkWithDataFrame {
def main(args: Array[String]): Unit = { def main(args: Array[String]): Unit = {
@@ -41,16 +39,8 @@ object SparkWithDataFrame {
val inputTrainPath = args(2) val inputTrainPath = args(2)
val inputTestPath = args(3) val inputTestPath = args(3)
// build dataset // build dataset
val trainRDDOfRows = MLUtils.loadLibSVMFile(sparkSession.sparkContext, inputTrainPath). val trainDF = sparkSession.sqlContext.read.format("libsvm").load(inputTrainPath)
map{ labeledPoint => Row(labeledPoint.features, labeledPoint.label)} val testDF = sparkSession.sqlContext.read.format("libsvm").load(inputTestPath)
val trainDF = sparkSession.createDataFrame(trainRDDOfRows, StructType(
Array(StructField("features", ArrayType(FloatType)), StructField("label", IntegerType))))
val testRDDOfRows = MLUtils.loadLibSVMFile(sparkSession.sparkContext, inputTestPath).
zipWithIndex().map{ case (labeledPoint, id) =>
Row(id, labeledPoint.features, labeledPoint.label)}
val testDF = sparkSession.createDataFrame(testRDDOfRows, StructType(
Array(StructField("id", LongType),
StructField("features", ArrayType(FloatType)), StructField("label", IntegerType))))
// start training // start training
val paramMap = List( val paramMap = List(
"eta" -> 0.1f, "eta" -> 0.1f,