diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 8e730667d..05cbee80d 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -125,15 +125,19 @@ object XGBoost extends Serializable {
         }
         val partitionItr = fromDenseToSparseLabeledPoints(trainingSamples, missing)
         val trainingSet = new DMatrix(new JDMatrix(partitionItr, cacheFileName))
-        if (xgBoostConfMap.contains("groupData") && xgBoostConfMap("groupData") != null) {
-          trainingSet.setGroup(xgBoostConfMap("groupData").asInstanceOf[Seq[Seq[Int]]](
-            TaskContext.getPartitionId()).toArray)
+        try {
+          if (xgBoostConfMap.contains("groupData") && xgBoostConfMap("groupData") != null) {
+            trainingSet.setGroup(xgBoostConfMap("groupData").asInstanceOf[Seq[Seq[Int]]](
+              TaskContext.getPartitionId()).toArray)
+          }
+          booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
+            watches = new mutable.HashMap[String, DMatrix] {
+              put("train", trainingSet)
+            }.toMap, obj, eval)
+          Rabit.shutdown()
+        } finally {
+          trainingSet.delete()
         }
-        booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
-          watches = new mutable.HashMap[String, DMatrix] {
-            put("train", trainingSet)
-          }.toMap, obj, eval)
-        Rabit.shutdown()
       } else {
         Rabit.shutdown()
         throw new XGBoostError(s"detect the empty partition in training dataset, partition ID:" +
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
index fe281bff6..b4d405364 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
@@ -67,9 +67,13 @@ abstract class XGBoostModel(protected var _booster: Booster)
       Rabit.init(rabitEnv.asJava)
       if (testSamples.hasNext) {
         val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
-        val res = broadcastBooster.value.predictLeaf(dMatrix)
-        Rabit.shutdown()
-        Iterator(res)
+        try {
+          val res = broadcastBooster.value.predictLeaf(dMatrix)
+          Rabit.shutdown()
+          Iterator(res)
+        } finally {
+          dMatrix.delete()
+        }
       } else {
         Iterator()
       }
@@ -113,21 +117,25 @@ abstract class XGBoostModel(protected var _booster: Booster)
         }
         import DataUtils._
         val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
-        if (groupData != null) {
-          dMatrix.setGroup(groupData(TaskContext.getPartitionId()).toArray)
-        }
-        (evalFunc, iter) match {
-          case (null, _) => {
-            val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
-            val Array(evName, predNumeric) = predStr.split(":")
-            Rabit.shutdown()
-            Iterator(Some(evName, predNumeric.toFloat))
+        try {
+          if (groupData != null) {
+            dMatrix.setGroup(groupData(TaskContext.getPartitionId()).toArray)
           }
-          case _ => {
-            val predictions = broadcastBooster.value.predict(dMatrix)
-            Rabit.shutdown()
-            Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
+          (evalFunc, iter) match {
+            case (null, _) => {
+              val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
+              val Array(evName, predNumeric) = predStr.split(":")
+              Rabit.shutdown()
+              Iterator(Some(evName, predNumeric.toFloat))
+            }
+            case _ => {
+              val predictions = broadcastBooster.value.predict(dMatrix)
+              Rabit.shutdown()
+              Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
+            }
           }
+        } finally {
+          dMatrix.delete()
         }
       } else {
         Iterator(None)
@@ -161,9 +169,13 @@ abstract class XGBoostModel(protected var _booster: Booster)
           flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat
         }
         val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
-        val res = broadcastBooster.value.predict(dMatrix)
-        Rabit.shutdown()
-        Iterator(res)
+        try {
+          val res = broadcastBooster.value.predict(dMatrix)
+          Rabit.shutdown()
+          Iterator(res)
+        } finally {
+          dMatrix.delete()
+        }
       }
     }
   }
@@ -191,9 +203,13 @@ abstract class XGBoostModel(protected var _booster: Booster)
           }
         }
         val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName))
-        val res = broadcastBooster.value.predict(dMatrix)
-        Rabit.shutdown()
-        Iterator(res)
+        try {
+          val res = broadcastBooster.value.predict(dMatrix)
+          Rabit.shutdown()
+          Iterator(res)
+        } finally {
+          dMatrix.delete()
+        }
       } else {
         Iterator()
       }
@@ -236,18 +252,22 @@ abstract class XGBoostModel(protected var _booster: Booster)
           }
         }
         val testDataset = new DMatrix(vectorIterator, cachePrefix)
-        val rawPredictResults = {
-          if (!predLeaf) {
-            broadcastBooster.value.predict(testDataset, outputMargin).map(Row(_)).iterator
-          } else {
-            broadcastBooster.value.predictLeaf(testDataset).map(Row(_)).iterator
+        try {
+          val rawPredictResults = {
+            if (!predLeaf) {
+              broadcastBooster.value.predict(testDataset, outputMargin).map(Row(_)).iterator
+            } else {
+              broadcastBooster.value.predictLeaf(testDataset).map(Row(_)).iterator
+            }
          }
-        }
-        Rabit.shutdown()
-        // concatenate original data partition and predictions
-        rowItr1.zip(rawPredictResults).map {
-          case (originalColumns: Row, predictColumn: Row) =>
-            Row.fromSeq(originalColumns.toSeq ++ predictColumn.toSeq)
+          Rabit.shutdown()
+          // concatenate original data partition and predictions
+          rowItr1.zip(rawPredictResults).map {
+            case (originalColumns: Row, predictColumn: Row) =>
+              Row.fromSeq(originalColumns.toSeq ++ predictColumn.toSeq)
+          }
+        } finally {
+          testDataset.delete()
         }
       } else {
         Iterator[Row]()
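
Every hunk above applies the same resource-cleanup pattern: each `DMatrix` use is wrapped in `try`/`finally` so that `delete()` always runs, because a `DMatrix` holds native, off-heap memory that the JVM garbage collector cannot reclaim. Below is a minimal standalone sketch of that pattern; the `withDMatrix` helper is hypothetical (not part of this patch), and it assumes only the `DMatrix` constructor and `delete()` method already used in the diff.

```scala
import ml.dmlc.xgboost4j.scala.DMatrix

object DMatrixCleanup {
  // Hypothetical helper (not in this patch): runs `body` against a DMatrix
  // and guarantees the native handle is freed on both success and failure.
  def withDMatrix[T](dMatrix: DMatrix)(body: DMatrix => T): T = {
    try {
      body(dMatrix) // e.g. broadcastBooster.value.predict(dMatrix)
    } finally {
      dMatrix.delete() // frees native memory the JVM GC never tracks
    }
  }

  // Example usage, mirroring the dense-prediction hunk above:
  // val res = withDMatrix(new DMatrix(flatSampleArray, numRows, numColumns, missingValue)) {
  //   m => broadcastBooster.value.predict(m)
  // }
}
```

Note that the patch keeps `Rabit.shutdown()` inside the `try` block rather than the `finally`: a failed prediction still releases the matrix but skips the shutdown, so only the native-memory release is made unconditional.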