parent
0b2a26fa74
commit
f93f1c03fc
@ -381,7 +381,6 @@ object XGBoost extends Serializable {
|
|||||||
val attempt = TaskContext.get().attemptNumber.toString
|
val attempt = TaskContext.get().attemptNumber.toString
|
||||||
rabitEnv.put("DMLC_TASK_ID", taskId)
|
rabitEnv.put("DMLC_TASK_ID", taskId)
|
||||||
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
|
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
|
||||||
rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
|
|
||||||
val numRounds = xgbExecutionParam.numRounds
|
val numRounds = xgbExecutionParam.numRounds
|
||||||
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
|
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
|
||||||
try {
|
try {
|
||||||
@ -997,4 +996,3 @@ private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
|
|||||||
group
|
group
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -308,8 +308,7 @@ class XGBoostClassificationModel private[ml](
|
|||||||
private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
|
private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
|
||||||
if (batchCnt == 0) {
|
if (batchCnt == 0) {
|
||||||
val rabitEnv = Array(
|
val rabitEnv = Array(
|
||||||
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString,
|
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||||
"DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false").toMap
|
|
||||||
Rabit.init(rabitEnv.asJava)
|
Rabit.init(rabitEnv.asJava)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -286,8 +286,7 @@ class XGBoostRegressionModel private[ml] (
|
|||||||
private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
|
private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
|
||||||
if (batchCnt == 0) {
|
if (batchCnt == 0) {
|
||||||
val rabitEnv = Array(
|
val rabitEnv = Array(
|
||||||
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString,
|
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||||
"DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false").toMap
|
|
||||||
Rabit.init(rabitEnv.asJava)
|
Rabit.init(rabitEnv.asJava)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
2
rabit
2
rabit
@ -1 +1 @@
|
|||||||
Subproject commit 74bf00a5ab4594f1695a8ea960394ce89f4a44d0
|
Subproject commit 4acdd7c6f68debe1c39ae07ca75466d74d194dd1
|
||||||
Loading…
x
Reference in New Issue
Block a user