Stop using Rabit in predition (#9054)
This commit is contained in:
parent
39b0fde0e7
commit
3b742dc4f1
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2021-2022 by Contributors
|
Copyright (c) 2021-2023 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -22,7 +22,6 @@ import java.util.ServiceLoader
|
|||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
import scala.collection.{AbstractIterator, Iterator, mutable}
|
import scala.collection.{AbstractIterator, Iterator, mutable}
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.Communicator
|
|
||||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||||
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
|
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
|
||||||
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
|
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
|
||||||
@ -35,7 +34,6 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
|||||||
import org.apache.commons.logging.LogFactory
|
import org.apache.commons.logging.LogFactory
|
||||||
|
|
||||||
import org.apache.spark.TaskContext
|
import org.apache.spark.TaskContext
|
||||||
import org.apache.spark.broadcast.Broadcast
|
|
||||||
import org.apache.spark.ml.{Estimator, Model}
|
import org.apache.spark.ml.{Estimator, Model}
|
||||||
import org.apache.spark.ml.linalg.Vector
|
import org.apache.spark.ml.linalg.Vector
|
||||||
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
|
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
|
||||||
@ -263,12 +261,6 @@ object PreXGBoost extends PreXGBoostProvider {
|
|||||||
private var batchCnt = 0
|
private var batchCnt = 0
|
||||||
|
|
||||||
private val batchIterImpl = rowIterator.grouped(inferBatchSize).flatMap { batchRow =>
|
private val batchIterImpl = rowIterator.grouped(inferBatchSize).flatMap { batchRow =>
|
||||||
if (batchCnt == 0) {
|
|
||||||
val rabitEnv = Array(
|
|
||||||
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
|
||||||
Communicator.init(rabitEnv.asJava)
|
|
||||||
}
|
|
||||||
|
|
||||||
val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
|
val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
|
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
|
||||||
@ -295,13 +287,8 @@ object PreXGBoost extends PreXGBoostProvider {
|
|||||||
|
|
||||||
override def hasNext: Boolean = batchIterImpl.hasNext
|
override def hasNext: Boolean = batchIterImpl.hasNext
|
||||||
|
|
||||||
override def next(): Row = {
|
override def next(): Row = batchIterImpl.next()
|
||||||
val ret = batchIterImpl.next()
|
|
||||||
if (!batchIterImpl.hasNext) {
|
|
||||||
Communicator.shutdown()
|
|
||||||
}
|
|
||||||
ret
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user