[jvm-packages] fix bug doing rabit call after finalize (#2079)
[jvm-packages]fix bug doing rabit call after finalize
This commit is contained in:
parent
fd19b7a188
commit
a92093388d
@ -57,9 +57,13 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
||||
import DataUtils._
|
||||
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
|
||||
testSet.mapPartitions { testSamples =>
|
||||
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
if (testSamples.hasNext) {
|
||||
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
|
||||
Iterator(broadcastBooster.value.predictLeaf(dMatrix))
|
||||
val res = broadcastBooster.value.predictLeaf(dMatrix)
|
||||
Rabit.shutdown()
|
||||
Iterator(res)
|
||||
} else {
|
||||
Iterator()
|
||||
}
|
||||
@ -145,8 +149,9 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
||||
flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat
|
||||
}
|
||||
val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
|
||||
val res = broadcastBooster.value.predict(dMatrix)
|
||||
Rabit.shutdown()
|
||||
Iterator(broadcastBooster.value.predict(dMatrix))
|
||||
Iterator(res)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user