From a837fa96202f570848454f01dd4cc2f29803f7dc Mon Sep 17 00:00:00 2001 From: Nan Zhu Date: Tue, 11 Apr 2017 06:12:49 -0700 Subject: [PATCH] [jvm-packages] rdds containing boosters should be cleaned once we got boosters to driver (#2183) --- .../scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 7a2d8df04..d9dbcd543 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -123,10 +123,8 @@ object XGBoost extends Serializable { } val partitionItr = fromDenseToSparseLabeledPoints(trainingSamples, missing) val trainingSet = new DMatrix(new JDMatrix(partitionItr, cacheFileName)) - if (xgBoostConfMap.isDefinedAt("groupData") - && xgBoostConfMap.get("groupData").get != null) { - trainingSet.setGroup( - xgBoostConfMap.get("groupData").get.asInstanceOf[Seq[Seq[Int]]]( + if (xgBoostConfMap.contains("groupData") && xgBoostConfMap("groupData") != null) { + trainingSet.setGroup(xgBoostConfMap("groupData").asInstanceOf[Seq[Seq[Int]]]( TaskContext.getPartitionId()).toArray) } booster = SXGBoost.train(trainingSet, xgBoostConfMap, round, @@ -309,7 +307,10 @@ object XGBoost extends Serializable { configMap: Map[String, Any], sparkJobThread: Thread, isClassificationTask: Boolean): XGBoostModel = { if (trackerReturnVal == 0) { - convertBoosterToXGBoostModel(distributedBoosters.first(), isClassificationTask) + val xgboostModel = convertBoosterToXGBoostModel(distributedBoosters.first(), + isClassificationTask) + distributedBoosters.unpersist(false) + xgboostModel } else { try { if (sparkJobThread.isAlive) {