[jvm-packages] Release dmatrix when no longer needed (#2436)
When using xgboost4j-spark I had executors getting killed by YARN for overrunning their memory limits (based on the provided memoryOverhead) much more often than I would expect. A significant amount of this appears to be because DMatrix instances were being created but not released: they were only freed when the GC decided it was time to clean up the references. Rather than waiting for the GC, release the DMatrix instances as soon as we know they are no longer needed.
This commit is contained in:
parent
1899f9e744
commit
169c983b5f
@ -125,15 +125,19 @@ object XGBoost extends Serializable {
|
|||||||
}
|
}
|
||||||
val partitionItr = fromDenseToSparseLabeledPoints(trainingSamples, missing)
|
val partitionItr = fromDenseToSparseLabeledPoints(trainingSamples, missing)
|
||||||
val trainingSet = new DMatrix(new JDMatrix(partitionItr, cacheFileName))
|
val trainingSet = new DMatrix(new JDMatrix(partitionItr, cacheFileName))
|
||||||
if (xgBoostConfMap.contains("groupData") && xgBoostConfMap("groupData") != null) {
|
try {
|
||||||
trainingSet.setGroup(xgBoostConfMap("groupData").asInstanceOf[Seq[Seq[Int]]](
|
if (xgBoostConfMap.contains("groupData") && xgBoostConfMap("groupData") != null) {
|
||||||
TaskContext.getPartitionId()).toArray)
|
trainingSet.setGroup(xgBoostConfMap("groupData").asInstanceOf[Seq[Seq[Int]]](
|
||||||
|
TaskContext.getPartitionId()).toArray)
|
||||||
|
}
|
||||||
|
booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
|
||||||
|
watches = new mutable.HashMap[String, DMatrix] {
|
||||||
|
put("train", trainingSet)
|
||||||
|
}.toMap, obj, eval)
|
||||||
|
Rabit.shutdown()
|
||||||
|
} finally {
|
||||||
|
trainingSet.delete()
|
||||||
}
|
}
|
||||||
booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
|
|
||||||
watches = new mutable.HashMap[String, DMatrix] {
|
|
||||||
put("train", trainingSet)
|
|
||||||
}.toMap, obj, eval)
|
|
||||||
Rabit.shutdown()
|
|
||||||
} else {
|
} else {
|
||||||
Rabit.shutdown()
|
Rabit.shutdown()
|
||||||
throw new XGBoostError(s"detect the empty partition in training dataset, partition ID:" +
|
throw new XGBoostError(s"detect the empty partition in training dataset, partition ID:" +
|
||||||
|
|||||||
@ -67,9 +67,13 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
|||||||
Rabit.init(rabitEnv.asJava)
|
Rabit.init(rabitEnv.asJava)
|
||||||
if (testSamples.hasNext) {
|
if (testSamples.hasNext) {
|
||||||
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
|
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
|
||||||
val res = broadcastBooster.value.predictLeaf(dMatrix)
|
try {
|
||||||
Rabit.shutdown()
|
val res = broadcastBooster.value.predictLeaf(dMatrix)
|
||||||
Iterator(res)
|
Rabit.shutdown()
|
||||||
|
Iterator(res)
|
||||||
|
} finally {
|
||||||
|
dMatrix.delete()
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
Iterator()
|
Iterator()
|
||||||
}
|
}
|
||||||
@ -113,21 +117,25 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
|||||||
}
|
}
|
||||||
import DataUtils._
|
import DataUtils._
|
||||||
val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
|
val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
|
||||||
if (groupData != null) {
|
try {
|
||||||
dMatrix.setGroup(groupData(TaskContext.getPartitionId()).toArray)
|
if (groupData != null) {
|
||||||
}
|
dMatrix.setGroup(groupData(TaskContext.getPartitionId()).toArray)
|
||||||
(evalFunc, iter) match {
|
|
||||||
case (null, _) => {
|
|
||||||
val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
|
|
||||||
val Array(evName, predNumeric) = predStr.split(":")
|
|
||||||
Rabit.shutdown()
|
|
||||||
Iterator(Some(evName, predNumeric.toFloat))
|
|
||||||
}
|
}
|
||||||
case _ => {
|
(evalFunc, iter) match {
|
||||||
val predictions = broadcastBooster.value.predict(dMatrix)
|
case (null, _) => {
|
||||||
Rabit.shutdown()
|
val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
|
||||||
Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
|
val Array(evName, predNumeric) = predStr.split(":")
|
||||||
|
Rabit.shutdown()
|
||||||
|
Iterator(Some(evName, predNumeric.toFloat))
|
||||||
|
}
|
||||||
|
case _ => {
|
||||||
|
val predictions = broadcastBooster.value.predict(dMatrix)
|
||||||
|
Rabit.shutdown()
|
||||||
|
Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
} finally {
|
||||||
|
dMatrix.delete()
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Iterator(None)
|
Iterator(None)
|
||||||
@ -161,9 +169,13 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
|||||||
flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat
|
flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat
|
||||||
}
|
}
|
||||||
val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
|
val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
|
||||||
val res = broadcastBooster.value.predict(dMatrix)
|
try {
|
||||||
Rabit.shutdown()
|
val res = broadcastBooster.value.predict(dMatrix)
|
||||||
Iterator(res)
|
Rabit.shutdown()
|
||||||
|
Iterator(res)
|
||||||
|
} finally {
|
||||||
|
dMatrix.delete()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -191,9 +203,13 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName))
|
val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName))
|
||||||
val res = broadcastBooster.value.predict(dMatrix)
|
try {
|
||||||
Rabit.shutdown()
|
val res = broadcastBooster.value.predict(dMatrix)
|
||||||
Iterator(res)
|
Rabit.shutdown()
|
||||||
|
Iterator(res)
|
||||||
|
} finally {
|
||||||
|
dMatrix.delete()
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
Iterator()
|
Iterator()
|
||||||
}
|
}
|
||||||
@ -236,18 +252,22 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
val testDataset = new DMatrix(vectorIterator, cachePrefix)
|
val testDataset = new DMatrix(vectorIterator, cachePrefix)
|
||||||
val rawPredictResults = {
|
try {
|
||||||
if (!predLeaf) {
|
val rawPredictResults = {
|
||||||
broadcastBooster.value.predict(testDataset, outputMargin).map(Row(_)).iterator
|
if (!predLeaf) {
|
||||||
} else {
|
broadcastBooster.value.predict(testDataset, outputMargin).map(Row(_)).iterator
|
||||||
broadcastBooster.value.predictLeaf(testDataset).map(Row(_)).iterator
|
} else {
|
||||||
|
broadcastBooster.value.predictLeaf(testDataset).map(Row(_)).iterator
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
Rabit.shutdown()
|
||||||
Rabit.shutdown()
|
// concatenate original data partition and predictions
|
||||||
// concatenate original data partition and predictions
|
rowItr1.zip(rawPredictResults).map {
|
||||||
rowItr1.zip(rawPredictResults).map {
|
case (originalColumns: Row, predictColumn: Row) =>
|
||||||
case (originalColumns: Row, predictColumn: Row) =>
|
Row.fromSeq(originalColumns.toSeq ++ predictColumn.toSeq)
|
||||||
Row.fromSeq(originalColumns.toSeq ++ predictColumn.toSeq)
|
}
|
||||||
|
} finally {
|
||||||
|
testDataset.delete()
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Iterator[Row]()
|
Iterator[Row]()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user