diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithDataFrame.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithDataFrame.scala index 8d5b6d0bf..c2efcc6fe 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithDataFrame.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithDataFrame.scala @@ -17,11 +17,9 @@ package ml.dmlc.xgboost4j.scala.example.spark import ml.dmlc.xgboost4j.scala.Booster -import ml.dmlc.xgboost4j.scala.spark.{XGBoost, DataUtils} -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{SparkSession, SQLContext, Row} -import org.apache.spark.{SparkContext, SparkConf} +import ml.dmlc.xgboost4j.scala.spark.XGBoost +import org.apache.spark.sql.SparkSession +import org.apache.spark.SparkConf object SparkWithDataFrame { def main(args: Array[String]): Unit = { @@ -41,16 +39,8 @@ object SparkWithDataFrame { val inputTrainPath = args(2) val inputTestPath = args(3) // build dataset - val trainRDDOfRows = MLUtils.loadLibSVMFile(sparkSession.sparkContext, inputTrainPath). - map{ labeledPoint => Row(labeledPoint.features, labeledPoint.label)} - val trainDF = sparkSession.createDataFrame(trainRDDOfRows, StructType( - Array(StructField("features", ArrayType(FloatType)), StructField("label", IntegerType)))) - val testRDDOfRows = MLUtils.loadLibSVMFile(sparkSession.sparkContext, inputTestPath). - zipWithIndex().map{ case (labeledPoint, id) => - Row(id, labeledPoint.features, labeledPoint.label)} - val testDF = sparkSession.createDataFrame(testRDDOfRows, StructType( - Array(StructField("id", LongType), - StructField("features", ArrayType(FloatType)), StructField("label", IntegerType)))) + val trainDF = sparkSession.sqlContext.read.format("libsvm").load(inputTrainPath) + val testDF = sparkSession.sqlContext.read.format("libsvm").load(inputTestPath) // start training val paramMap = List( "eta" -> 0.1f,