diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala index 176a54832..31d58224b 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2021-2022 by Contributors + Copyright (c) 2021-2023 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,7 +22,6 @@ import java.util.ServiceLoader import scala.collection.JavaConverters._ import scala.collection.{AbstractIterator, Iterator, mutable} -import ml.dmlc.xgboost4j.java.Communicator import ml.dmlc.xgboost4j.scala.{Booster, DMatrix} import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon @@ -35,7 +34,6 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} import org.apache.commons.logging.LogFactory import org.apache.spark.TaskContext -import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg.Vector import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType} @@ -263,12 +261,6 @@ object PreXGBoost extends PreXGBoostProvider { private var batchCnt = 0 private val batchIterImpl = rowIterator.grouped(inferBatchSize).flatMap { batchRow => - if (batchCnt == 0) { - val rabitEnv = Array( - "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap - Communicator.init(rabitEnv.asJava) - } - val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol)) import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._ @@ -295,13 +287,8 @@ object PreXGBoost extends PreXGBoostProvider { override def hasNext: Boolean = batchIterImpl.hasNext - override def next(): Row = { - val ret = batchIterImpl.next() - if (!batchIterImpl.hasNext) { - Communicator.shutdown() - } - ret - } + override def next(): Row = batchIterImpl.next() + } }