[jvm-packages] Leverage the Spark ml API to read DataFrame from files in LibSVM format. (#1785)
This commit is contained in:
parent
ca0069b708
commit
ce708c8e7f
@ -17,11 +17,9 @@
|
|||||||
package ml.dmlc.xgboost4j.scala.example.spark
|
package ml.dmlc.xgboost4j.scala.example.spark
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.Booster
|
import ml.dmlc.xgboost4j.scala.Booster
|
||||||
import ml.dmlc.xgboost4j.scala.spark.{XGBoost, DataUtils}
|
import ml.dmlc.xgboost4j.scala.spark.XGBoost
|
||||||
import org.apache.spark.mllib.util.MLUtils
|
import org.apache.spark.sql.SparkSession
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.SparkConf
|
||||||
import org.apache.spark.sql.{SparkSession, SQLContext, Row}
|
|
||||||
import org.apache.spark.{SparkContext, SparkConf}
|
|
||||||
|
|
||||||
object SparkWithDataFrame {
|
object SparkWithDataFrame {
|
||||||
def main(args: Array[String]): Unit = {
|
def main(args: Array[String]): Unit = {
|
||||||
@ -41,16 +39,8 @@ object SparkWithDataFrame {
|
|||||||
val inputTrainPath = args(2)
|
val inputTrainPath = args(2)
|
||||||
val inputTestPath = args(3)
|
val inputTestPath = args(3)
|
||||||
// build dataset
|
// build dataset
|
||||||
val trainRDDOfRows = MLUtils.loadLibSVMFile(sparkSession.sparkContext, inputTrainPath).
|
val trainDF = sparkSession.sqlContext.read.format("libsvm").load(inputTrainPath)
|
||||||
map{ labeledPoint => Row(labeledPoint.features, labeledPoint.label)}
|
val testDF = sparkSession.sqlContext.read.format("libsvm").load(inputTestPath)
|
||||||
val trainDF = sparkSession.createDataFrame(trainRDDOfRows, StructType(
|
|
||||||
Array(StructField("features", ArrayType(FloatType)), StructField("label", IntegerType))))
|
|
||||||
val testRDDOfRows = MLUtils.loadLibSVMFile(sparkSession.sparkContext, inputTestPath).
|
|
||||||
zipWithIndex().map{ case (labeledPoint, id) =>
|
|
||||||
Row(id, labeledPoint.features, labeledPoint.label)}
|
|
||||||
val testDF = sparkSession.createDataFrame(testRDDOfRows, StructType(
|
|
||||||
Array(StructField("id", LongType),
|
|
||||||
StructField("features", ArrayType(FloatType)), StructField("label", IntegerType))))
|
|
||||||
// start training
|
// start training
|
||||||
val paramMap = List(
|
val paramMap = List(
|
||||||
"eta" -> 0.1f,
|
"eta" -> 0.1f,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user