From 3b742dc4f179d5bf349c992c26655e69e7c0c0ab Mon Sep 17 00:00:00 2001 From: austinzh Date: Fri, 21 Apr 2023 07:38:07 -0400 Subject: [PATCH] Stop using Rabit in predition (#9054) --- .../xgboost4j/scala/spark/PreXGBoost.scala | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala index 176a54832..31d58224b 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2021-2022 by Contributors + Copyright (c) 2021-2023 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,7 +22,6 @@ import java.util.ServiceLoader import scala.collection.JavaConverters._ import scala.collection.{AbstractIterator, Iterator, mutable} -import ml.dmlc.xgboost4j.java.Communicator import ml.dmlc.xgboost4j.scala.{Booster, DMatrix} import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon @@ -35,7 +34,6 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} import org.apache.commons.logging.LogFactory import org.apache.spark.TaskContext -import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg.Vector import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType} @@ -263,12 +261,6 @@ object PreXGBoost extends PreXGBoostProvider { private var batchCnt = 0 private val batchIterImpl = rowIterator.grouped(inferBatchSize).flatMap { batchRow => - if (batchCnt == 0) { - val rabitEnv = Array( - "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap - Communicator.init(rabitEnv.asJava) - } - val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol)) import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._ @@ -295,13 +287,8 @@ object PreXGBoost extends PreXGBoostProvider { override def hasNext: Boolean = batchIterImpl.hasNext - override def next(): Row = { - val ret = batchIterImpl.next() - if (!batchIterImpl.hasNext) { - Communicator.shutdown() - } - ret - } + override def next(): Row = batchIterImpl.next() + } }