[jvm-packages] fix bug doing rabit call after finalize (#2079)
[jvm-packages]fix bug doing rabit call after finalize
This commit is contained in:
parent
fd19b7a188
commit
a92093388d
@ -57,9 +57,13 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
|||||||
import DataUtils._
|
import DataUtils._
|
||||||
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
|
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
|
||||||
testSet.mapPartitions { testSamples =>
|
testSet.mapPartitions { testSamples =>
|
||||||
|
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
|
||||||
|
Rabit.init(rabitEnv.asJava)
|
||||||
if (testSamples.hasNext) {
|
if (testSamples.hasNext) {
|
||||||
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
|
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
|
||||||
Iterator(broadcastBooster.value.predictLeaf(dMatrix))
|
val res = broadcastBooster.value.predictLeaf(dMatrix)
|
||||||
|
Rabit.shutdown()
|
||||||
|
Iterator(res)
|
||||||
} else {
|
} else {
|
||||||
Iterator()
|
Iterator()
|
||||||
}
|
}
|
||||||
@ -145,8 +149,9 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
|||||||
flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat
|
flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat
|
||||||
}
|
}
|
||||||
val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
|
val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
|
||||||
|
val res = broadcastBooster.value.predict(dMatrix)
|
||||||
Rabit.shutdown()
|
Rabit.shutdown()
|
||||||
Iterator(broadcastBooster.value.predict(dMatrix))
|
Iterator(res)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user