[jvm-packages] tiny fix for empty partition in predict (#3014)
* add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * tiny fix for empty partition in predict * further fix
This commit is contained in:
parent
740eba42f7
commit
a187ed6c8f
@ -169,12 +169,12 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
||||
def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Float]] = {
|
||||
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
|
||||
testSet.mapPartitions { testSamples =>
|
||||
val sampleArray = testSamples.toList
|
||||
val numRows = sampleArray.size
|
||||
val numColumns = sampleArray.head.size
|
||||
val sampleArray = testSamples.toArray
|
||||
val numRows = sampleArray.length
|
||||
if (numRows == 0) {
|
||||
Iterator()
|
||||
} else {
|
||||
val numColumns = sampleArray.head.size
|
||||
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
// translate to required format
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user