Stop using Rabit in predition (#9054)
This commit is contained in:
parent
39b0fde0e7
commit
3b742dc4f1
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2021-2022 by Contributors
|
||||
Copyright (c) 2021-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -22,7 +22,6 @@ import java.util.ServiceLoader
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.{AbstractIterator, Iterator, mutable}
|
||||
|
||||
import ml.dmlc.xgboost4j.java.Communicator
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
|
||||
@ -35,7 +34,6 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
import org.apache.commons.logging.LogFactory
|
||||
|
||||
import org.apache.spark.TaskContext
|
||||
import org.apache.spark.broadcast.Broadcast
|
||||
import org.apache.spark.ml.{Estimator, Model}
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
|
||||
@ -263,12 +261,6 @@ object PreXGBoost extends PreXGBoostProvider {
|
||||
private var batchCnt = 0
|
||||
|
||||
private val batchIterImpl = rowIterator.grouped(inferBatchSize).flatMap { batchRow =>
|
||||
if (batchCnt == 0) {
|
||||
val rabitEnv = Array(
|
||||
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||
Communicator.init(rabitEnv.asJava)
|
||||
}
|
||||
|
||||
val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
|
||||
@ -295,13 +287,8 @@ object PreXGBoost extends PreXGBoostProvider {
|
||||
|
||||
override def hasNext: Boolean = batchIterImpl.hasNext
|
||||
|
||||
override def next(): Row = {
|
||||
val ret = batchIterImpl.next()
|
||||
if (!batchIterImpl.hasNext) {
|
||||
Communicator.shutdown()
|
||||
}
|
||||
ret
|
||||
}
|
||||
override def next(): Row = batchIterImpl.next()
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user