[jvm-packages]support multiple validation datasets in Spark (#3910)
* add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * fix scalastyle error * fix scalastyle error * wrap iterators * enable copartition training and validation set * add parameters * converge code path and have init unit test * enable multi evals for ranking * unit test and doc * update example * fix early stopping * address the offline comments * update doc * test eval metrics * fix compilation issue * fix example
This commit is contained in:
@@ -40,7 +40,7 @@ object SparkTraining {
|
||||
StructField("petal length", DoubleType, true),
|
||||
StructField("petal width", DoubleType, true),
|
||||
StructField("class", StringType, true)))
|
||||
val rawInput = spark.read.schema(schema).csv(args(0))
|
||||
val rawInput = spark.read.schema(schema).csv(inputPath)
|
||||
|
||||
// transform class to index to make xgboost happy
|
||||
val stringIndexer = new StringIndexer()
|
||||
@@ -55,6 +55,8 @@ object SparkTraining {
|
||||
val xgbInput = vectorAssembler.transform(labelTransformed).select("features",
|
||||
"classIndex")
|
||||
|
||||
val Array(train, eval1, eval2, test) = xgbInput.randomSplit(Array(0.6, 0.2, 0.1, 0.1))
|
||||
|
||||
/**
|
||||
* setup "timeout_request_workers" -> 60000L to make this application fail if it cannot get enough resources
|
||||
* to get 2 workers within 60000 ms
|
||||
@@ -67,12 +69,13 @@ object SparkTraining {
|
||||
"objective" -> "multi:softprob",
|
||||
"num_class" -> 3,
|
||||
"num_round" -> 100,
|
||||
"num_workers" -> 2)
|
||||
"num_workers" -> 2,
|
||||
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
|
||||
val xgbClassifier = new XGBoostClassifier(xgbParam).
|
||||
setFeaturesCol("features").
|
||||
setLabelCol("classIndex")
|
||||
val xgbClassificationModel = xgbClassifier.fit(xgbInput)
|
||||
val results = xgbClassificationModel.transform(xgbInput)
|
||||
val xgbClassificationModel = xgbClassifier.fit(train)
|
||||
val results = xgbClassificationModel.transform(test)
|
||||
results.show()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user