Stop using Rabit in prediction (#9054)

This commit is contained in:
austinzh 2023-04-21 07:38:07 -04:00 committed by GitHub
parent 39b0fde0e7
commit 3b742dc4f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,5 @@
/* /*
Copyright (c) 2021-2022 by Contributors Copyright (c) 2021-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -22,7 +22,6 @@ import java.util.ServiceLoader
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
import scala.collection.{AbstractIterator, Iterator, mutable} import scala.collection.{AbstractIterator, Iterator, mutable}
import ml.dmlc.xgboost4j.java.Communicator
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix} import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
@ -35,7 +34,6 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.logging.LogFactory import org.apache.commons.logging.LogFactory
import org.apache.spark.TaskContext import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType} import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
@ -263,12 +261,6 @@ object PreXGBoost extends PreXGBoostProvider {
private var batchCnt = 0 private var batchCnt = 0
private val batchIterImpl = rowIterator.grouped(inferBatchSize).flatMap { batchRow => private val batchIterImpl = rowIterator.grouped(inferBatchSize).flatMap { batchRow =>
if (batchCnt == 0) {
val rabitEnv = Array(
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Communicator.init(rabitEnv.asJava)
}
val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol)) val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._ import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
@ -295,13 +287,8 @@ object PreXGBoost extends PreXGBoostProvider {
override def hasNext: Boolean = batchIterImpl.hasNext override def hasNext: Boolean = batchIterImpl.hasNext
override def next(): Row = { override def next(): Row = batchIterImpl.next()
val ret = batchIterImpl.next()
if (!batchIterImpl.hasNext) {
Communicator.shutdown()
}
ret
}
} }
} }