Rabit update. (#5978)

* Remove parameter on JVM Packages.
This commit is contained in:
Jiaming Yuan 2020-08-11 09:17:32 +08:00 committed by GitHub
parent 0b2a26fa74
commit f93f1c03fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 3 additions and 7 deletions

View File

@ -381,7 +381,6 @@ object XGBoost extends Serializable {
val attempt = TaskContext.get().attemptNumber.toString val attempt = TaskContext.get().attemptNumber.toString
rabitEnv.put("DMLC_TASK_ID", taskId) rabitEnv.put("DMLC_TASK_ID", taskId)
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt) rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
val numRounds = xgbExecutionParam.numRounds val numRounds = xgbExecutionParam.numRounds
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0 val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
try { try {
@ -997,4 +996,3 @@ private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
group group
} }
} }

View File

@ -308,8 +308,7 @@ class XGBoostClassificationModel private[ml](
private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow => private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
if (batchCnt == 0) { if (batchCnt == 0) {
val rabitEnv = Array( val rabitEnv = Array(
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString, "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
"DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false").toMap
Rabit.init(rabitEnv.asJava) Rabit.init(rabitEnv.asJava)
} }

View File

@ -286,8 +286,7 @@ class XGBoostRegressionModel private[ml] (
private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow => private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
if (batchCnt == 0) { if (batchCnt == 0) {
val rabitEnv = Array( val rabitEnv = Array(
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString, "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
"DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false").toMap
Rabit.init(rabitEnv.asJava) Rabit.init(rabitEnv.asJava)
} }

2
rabit

@ -1 +1 @@
Subproject commit 74bf00a5ab4594f1695a8ea960394ce89f4a44d0 Subproject commit 4acdd7c6f68debe1c39ae07ca75466d74d194dd1