[jvm-packages] update rabit, surface new changes to spark, add parity and failure tests (#4876)

* Expose sets of rabit configurations to spark layer
This commit is contained in:
Chen Qin
2019-10-18 12:07:31 -07:00
committed by Jiaming Yuan
parent 31030a8d3a
commit 86ed01c4bb
73 changed files with 343 additions and 115 deletions

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -16,7 +16,10 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import scala.collection.JavaConverters._
import org.apache.spark.sql._
import org.scalatest.FunSuite
@@ -28,7 +31,7 @@ class XGBoostConfigureSuite extends FunSuite with PerTest {
test("nthread configuration must be no larger than spark.task.cpus") {
val training = buildDataFrame(Classification.train)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
"objective" -> "binary:logistic", "num_workers" -> numWorkers,
"nthread" -> (sc.getConf.getInt("spark.task.cpus", 1) + 1))
intercept[IllegalArgumentException] {
@@ -40,7 +43,7 @@ class XGBoostConfigureSuite extends FunSuite with PerTest {
// TODO write an isolated test for Booster.
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator, null)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
val model = new XGBoostClassifier(paramMap).fit(training)
@@ -52,7 +55,7 @@ class XGBoostConfigureSuite extends FunSuite with PerTest {
val originalSslConfOpt = ss.conf.getOption("spark.ssl.enabled")
ss.conf.set("spark.ssl.enabled", true)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
"objective" -> "binary:logistic", "num_round" -> 2, "num_workers" -> numWorkers)
val training = buildDataFrame(Classification.train)

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -0,0 +1,110 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.{Rabit, XGBoostError}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import scala.collection.JavaConverters._
import org.apache.spark.sql._
import org.scalatest.FunSuite
/**
 * Tests for the Rabit settings that can be passed through the XGBoost4J-Spark parameter map.
 *
 * The parity tests train one model with default settings and one with a set of Rabit overrides,
 * then check that both produce the same predictions. The failure test injects a mocked worker
 * failure and expects training to surface an [[XGBoostError]].
 */
class XGBoostRabitRegressionSuite extends FunSuite with PerTest {

  override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.kryo.classesToRegister", classOf[Booster].getName)

  /** Rabit-related overrides given to the second model of each parity test. */
  private val rabitOverrides: Map[String, Any] = Map(
    "rabit_bootstrap_cache" -> true,
    "rabit_debug" -> true,
    "rabit_reduce_ring_mincount" -> 100,
    "rabit_reduce_buffer" -> "2MB",
    "DMLC_WORKER_CONNECT_RETRY" -> 1,
    "rabit_timeout" -> true,
    "rabit_timeout_sec" -> 5)

  /**
   * Values expected in `Rabit.rabitEnvs` after training with [[rabitOverrides]].
   * Key spelling is kept exactly as the checks were originally written.
   */
  private val expectedRabitEnvs: Map[String, String] = Map(
    "rabit_bootstrap_cache" -> "true",
    "rabit_debug" -> "true",
    "rabit_reduce_ring_mincount" -> "100",
    "rabit_reduce_buffer" -> "2MB",
    "dmlc_worker_connect_retry" -> "1",
    "rabit_timeout" -> "true",
    "rabit_timeout_sec" -> "5")

  /** Training parameters shared by every model in this suite, apart from the objective. */
  private def baseParams(objective: String): Map[String, Any] = Map(
    "eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
    "objective" -> objective, "num_round" -> 5, "num_workers" -> numWorkers)

  /**
   * Checks the environment recorded in `Rabit.rabitEnvs`.
   *
   * Requires more than seven entries overall. Every entry whose key appears in `expected`
   * must carry the expected value; keys that are not present in the environment are not
   * reported as failures.
   */
  private def assertRabitEnvs(expected: Map[String, String]): Unit = {
    val envs = Rabit.rabitEnvs.asScala
    assert(envs.size > 7)
    envs.foreach { case (key, value) =>
      expected.get(key.toString).foreach(want => assert(value == want))
    }
  }

  test("test parity classification prediction") {
    val training = buildDataFrame(Classification.train)
    val testDF = buildDataFrame(Classification.test)

    val model1 = new XGBoostClassifier(baseParams("binary:logistic")).fit(training)
    val prediction1 = model1.transform(testDF).select("prediction").collect()

    val model2 =
      new XGBoostClassifier(baseParams("binary:logistic") ++ rabitOverrides).fit(training)
    assertRabitEnvs(expectedRabitEnvs)
    val prediction2 = model2.transform(testDF).select("prediction").collect()

    // check parity w/o rabit cache; zip would silently drop rows if the sizes differed
    assert(prediction1.length == prediction2.length)
    prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
      assert(p1 == p2)
    }
  }

  test("test parity regression prediction") {
    val training = buildDataFrame(Regression.train)
    // score the regression models on the regression test split
    val testDF = buildDataFrame(Regression.test)

    val model1 = new XGBoostRegressor(baseParams("reg:squarederror")).fit(training)
    val prediction1 = model1.transform(testDF).select("prediction").collect()

    val model2 =
      new XGBoostRegressor(baseParams("reg:squarederror") ++ rabitOverrides).fit(training)
    assertRabitEnvs(expectedRabitEnvs + ("DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false"))
    val prediction2 = model2.transform(testDF).select("prediction").collect()

    // check parity w/o rabit cache; zip would silently drop rows if the sizes differed
    assert(prediction1.length == prediction2.length)
    prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
      assert(math.abs(p1 - p2) < 0.00001f)
    }
  }

  test("test graceful failure handle") {
    val training = buildDataFrame(Classification.train)
    // remember the current mock configuration so it can be put back afterwards
    val previousMockList = Rabit.mockList
    // mock rank 0 failure during 4th allreduce synchronization
    Rabit.mockList = Array("0,4,0,0").toList.asJava
    try {
      intercept[XGBoostError] {
        new XGBoostClassifier(baseParams("binary:logistic") ++ Map(
          "rabit_timeout" -> true, "rabit_timeout_sec" -> 1,
          "DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> false)).fit(training)
      }
    } finally {
      // do not leak the injected failure into tests that run later
      Rabit.mockList = previousMockList
    }
  }
}

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.