[jvm-packages]support multiple validation datasets in Spark (#3910)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* wrap iterators

* enable copartition training and validationset

* add parameters

* converge code path and have init unit test

* enable multi evals for ranking

* unit test and doc

* update example

* fix early stopping

* address the offline comments

* udpate doc

* test eval metrics

* fix compilation issue

* fix example
This commit is contained in:
Nan Zhu
2018-12-17 21:03:57 -08:00
committed by GitHub
parent c8c7b9649c
commit c055a32609
14 changed files with 477 additions and 136 deletions

View File

@@ -40,7 +40,7 @@ object SparkTraining {
StructField("petal length", DoubleType, true),
StructField("petal width", DoubleType, true),
StructField("class", StringType, true)))
val rawInput = spark.read.schema(schema).csv(args(0))
val rawInput = spark.read.schema(schema).csv(inputPath)
// transform class to index to make xgboost happy
val stringIndexer = new StringIndexer()
@@ -55,6 +55,8 @@ object SparkTraining {
val xgbInput = vectorAssembler.transform(labelTransformed).select("features",
"classIndex")
val Array(train, eval1, eval2, test) = xgbInput.randomSplit(Array(0.6, 0.2, 0.1, 0.1))
/**
* setup "timeout_request_workers" -> 60000L to make this application if it cannot get enough resources
* to get 2 workers within 60000 ms
@@ -67,12 +69,13 @@ object SparkTraining {
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 100,
"num_workers" -> 2)
"num_workers" -> 2,
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
val xgbClassifier = new XGBoostClassifier(xgbParam).
setFeaturesCol("features").
setLabelCol("classIndex")
val xgbClassificationModel = xgbClassifier.fit(xgbInput)
val results = xgbClassificationModel.transform(xgbInput)
val xgbClassificationModel = xgbClassifier.fit(train)
val results = xgbClassificationModel.transform(test)
results.show()
}
}