[jvm-packages] tiny fix for empty partition in predict (#3014)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* tiny fix for empty partition in predict

* further fix
This commit is contained in:
Nan Zhu 2018-01-07 08:34:18 -08:00 committed by GitHub
parent 740eba42f7
commit a187ed6c8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -169,12 +169,12 @@ abstract class XGBoostModel(protected var _booster: Booster)
def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Float]] = {
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
testSet.mapPartitions { testSamples =>
val sampleArray = testSamples.toList
val numRows = sampleArray.size
val numColumns = sampleArray.head.size
val sampleArray = testSamples.toArray
val numRows = sampleArray.length
if (numRows == 0) {
Iterator()
} else {
val numColumns = sampleArray.head.size
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava)
// translate to required format