[jvm-packages] tiny fix for empty partition in predict (#3014)
* add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * tiny fix for empty partition in predict * further fix
This commit is contained in:
parent
740eba42f7
commit
a187ed6c8f
@ -169,12 +169,12 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
|||||||
def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Float]] = {
|
def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Float]] = {
|
||||||
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
|
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
|
||||||
testSet.mapPartitions { testSamples =>
|
testSet.mapPartitions { testSamples =>
|
||||||
val sampleArray = testSamples.toList
|
val sampleArray = testSamples.toArray
|
||||||
val numRows = sampleArray.size
|
val numRows = sampleArray.length
|
||||||
val numColumns = sampleArray.head.size
|
|
||||||
if (numRows == 0) {
|
if (numRows == 0) {
|
||||||
Iterator()
|
Iterator()
|
||||||
} else {
|
} else {
|
||||||
|
val numColumns = sampleArray.head.size
|
||||||
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
|
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
|
||||||
Rabit.init(rabitEnv.asJava)
|
Rabit.init(rabitEnv.asJava)
|
||||||
// translate to required format
|
// translate to required format
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user