[jvm-packages] Leverage the Spark ml API to read DataFrame from files in LibSVM format. (#1785)
This commit is contained in:
parent
ca0069b708
commit
ce708c8e7f
@ -17,11 +17,9 @@
|
||||
package ml.dmlc.xgboost4j.scala.example.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.Booster
|
||||
import ml.dmlc.xgboost4j.scala.spark.{XGBoost, DataUtils}
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.{SparkSession, SQLContext, Row}
|
||||
import org.apache.spark.{SparkContext, SparkConf}
|
||||
import ml.dmlc.xgboost4j.scala.spark.XGBoost
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.SparkConf
|
||||
|
||||
object SparkWithDataFrame {
|
||||
def main(args: Array[String]): Unit = {
|
||||
@ -41,16 +39,8 @@ object SparkWithDataFrame {
|
||||
val inputTrainPath = args(2)
|
||||
val inputTestPath = args(3)
|
||||
// build dataset
|
||||
val trainRDDOfRows = MLUtils.loadLibSVMFile(sparkSession.sparkContext, inputTrainPath).
|
||||
map{ labeledPoint => Row(labeledPoint.features, labeledPoint.label)}
|
||||
val trainDF = sparkSession.createDataFrame(trainRDDOfRows, StructType(
|
||||
Array(StructField("features", ArrayType(FloatType)), StructField("label", IntegerType))))
|
||||
val testRDDOfRows = MLUtils.loadLibSVMFile(sparkSession.sparkContext, inputTestPath).
|
||||
zipWithIndex().map{ case (labeledPoint, id) =>
|
||||
Row(id, labeledPoint.features, labeledPoint.label)}
|
||||
val testDF = sparkSession.createDataFrame(testRDDOfRows, StructType(
|
||||
Array(StructField("id", LongType),
|
||||
StructField("features", ArrayType(FloatType)), StructField("label", IntegerType))))
|
||||
val trainDF = sparkSession.sqlContext.read.format("libsvm").load(inputTrainPath)
|
||||
val testDF = sparkSession.sqlContext.read.format("libsvm").load(inputTestPath)
|
||||
// start training
|
||||
val paramMap = List(
|
||||
"eta" -> 0.1f,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user