From f768edfede06906d53e30ddebccf3cdef3291ce1 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Sun, 6 Mar 2016 08:39:50 -0500 Subject: [PATCH 1/6] adjust the return values of RabitTracker.waitFor(), remove typesafe.Config --- jvm-packages/pom.xml | 5 -- .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 64 +++++++++---------- .../ml/dmlc/xgboost4j/java/RabitTracker.java | 7 +- 3 files changed, 34 insertions(+), 42 deletions(-) diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index 43f602df6..db6bc8a98 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -161,10 +161,5 @@ 2.2.6 test - - com.typesafe - config - 1.2.1 - diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 8b0d0a71e..ea7ba8563 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -16,10 +16,10 @@ package ml.dmlc.xgboost4j.scala.spark -import scala.collection.immutable.HashMap +import scala.collection.mutable -import com.typesafe.config.Config -import org.apache.spark.{TaskContext, SparkContext} +import org.apache.commons.logging.LogFactory +import org.apache.spark.TaskContext import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -28,6 +28,9 @@ import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} object XGBoost extends Serializable { + var boosters: RDD[Booster] = null + private val logger = LogFactory.getLog("XGBoostSpark") + implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = { new XGBoostModel(booster) } @@ -37,42 +40,33 @@ object XGBoost extends Serializable { xgBoostConfMap: Map[String, AnyRef], numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = { import DataUtils._ - val sc = trainingData.sparkContext - val tracker = new RabitTracker(numWorkers) - if (tracker.start()) { - trainingData.repartition(numWorkers).mapPartitions { - trainingSamples => - Rabit.init(new java.util.HashMap[String, String]() { - put("DMLC_TASK_ID", TaskContext.getPartitionId().toString) - }) - val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null)) - val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round, - watches = new HashMap[String, DMatrix], obj, eval) - Rabit.shutdown() - Iterator(booster) - }.cache() - } else { - null - } + trainingData.repartition(numWorkers).mapPartitions { + trainingSamples => + Rabit.init(new java.util.HashMap[String, String]() { + put("DMLC_TASK_ID", TaskContext.getPartitionId().toString) + }) + val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null)) + val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round, + watches = new mutable.HashMap[String, DMatrix]{put("train", dMatrix)}.toMap, obj, eval) + Rabit.shutdown() + Iterator(booster) + }.cache() } - def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null, - eval: EvalTrait = null): Option[XGBoostModel] = { - import DataUtils._ - val numWorkers = config.getInt("numWorkers") - val round = config.getInt("round") + def train(trainingData: RDD[LabeledPoint], configMap: Map[String, AnyRef], round: Int, + obj: ObjectiveTrait = null, eval: EvalTrait = null): Option[XGBoostModel] = { + val numWorkers = trainingData.partitions.length val sc = trainingData.sparkContext val tracker = new RabitTracker(numWorkers) - if (tracker.start()) { - // TODO: build configuration map from config - val xgBoostConfigMap = new HashMap[String, AnyRef]() - val boosters = buildDistributedBoosters(trainingData, xgBoostConfigMap, numWorkers, round, - obj, eval) - // force the job - sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters) - tracker.waitFor() - // TODO: how to choose best model - Some(boosters.first()) + require(tracker.start(), "FAULT: Failed to start tracker") + boosters = buildDistributedBoosters(trainingData, configMap, numWorkers, round, obj, eval) + // force the job + sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters) + val booster = boosters.first() + val returnVal = tracker.waitFor() + logger.info(s"Rabit returns with exit code $returnVal") + if (returnVal == 0) { + Some(booster) } else { None } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java index 762cff7bf..a5768d6cd 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java @@ -134,15 +134,18 @@ public class RabitTracker { } } - public void waitFor() { + public int waitFor() { try { trackerProcess.get().waitFor(); - logger.info("Tracker Process ends with exit code " + trackerProcess.get().exitValue()); + int returnVal = trackerProcess.get().exitValue(); + logger.info("Tracker Process ends with exit code " + returnVal); stop(); + return returnVal; } catch (InterruptedException e) { // we should not get here as RabitTracker is accessed in the main thread e.printStackTrace(); logger.error("the RabitTracker thread is terminated unexpectedly"); + return 1; } } } From 808e30f9fcfc713aa094a76c1994b3b696773740 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Sun, 6 Mar 2016 10:16:11 -0500 Subject: [PATCH 2/6] example of DistTrainWithSpark and trigger job with foreachPartition --- jvm-packages/xgboost4j-demo/pom.xml | 2 +- .../scala/spark/demo/DistTrainWithSpark.scala | 74 +++ .../xgboost4j/scala/spark/DataUtils.scala | 24 +- .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 3 +- .../java/ml/dmlc/xgboost4j/LabeledPoint.java | 4 +- .../java/ml/dmlc/xgboost4j/java/Booster.java | 461 +++++++++++++-- .../ml/dmlc/xgboost4j/java/DataBatch.java | 3 +- .../dmlc/xgboost4j/java/JavaBoosterImpl.java | 525 ------------------ .../java/ml/dmlc/xgboost4j/java/Rabit.java | 3 +- .../java/ml/dmlc/xgboost4j/java/XGBoost.java | 6 +- .../ml/dmlc/xgboost4j/scala/Booster.scala | 242 +++----- .../xgboost4j/scala/ScalaBoosterImpl.scala | 99 ---- .../ml/dmlc/xgboost4j/scala/XGBoost.scala | 9 +- 13 files changed, 588 insertions(+), 867 deletions(-) create mode 100644 jvm-packages/xgboost4j-demo/src/main/scala/ml/dmlc/xgboost4j/scala/spark/demo/DistTrainWithSpark.scala delete mode 100644 jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/JavaBoosterImpl.java delete mode 100644 jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImpl.scala diff --git a/jvm-packages/xgboost4j-demo/pom.xml b/jvm-packages/xgboost4j-demo/pom.xml index bef184adb..e076af63d 100644 --- a/jvm-packages/xgboost4j-demo/pom.xml +++ b/jvm-packages/xgboost4j-demo/pom.xml @@ -25,7 +25,7 @@ ml.dmlc - xgboost4j + xgboost4j-spark 0.1 diff --git a/jvm-packages/xgboost4j-demo/src/main/scala/ml/dmlc/xgboost4j/scala/spark/demo/DistTrainWithSpark.scala b/jvm-packages/xgboost4j-demo/src/main/scala/ml/dmlc/xgboost4j/scala/spark/demo/DistTrainWithSpark.scala new file mode 100644 index 000000000..8fd794423 --- /dev/null +++ b/jvm-packages/xgboost4j-demo/src/main/scala/ml/dmlc/xgboost4j/scala/spark/demo/DistTrainWithSpark.scala @@ -0,0 +1,74 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.spark.demo + +import java.io.File + +import scala.collection.mutable.ListBuffer +import scala.io.Source + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.DenseVector +import org.apache.spark.mllib.regression.LabeledPoint + +import ml.dmlc.xgboost4j.scala.DMatrix +import ml.dmlc.xgboost4j.scala.spark.XGBoost + + +object DistTrainWithSpark { + + private def readFile(filePath: String): List[LabeledPoint] = { + val file = Source.fromFile(new File(filePath)) + val sampleList = new ListBuffer[LabeledPoint] + for (sample <- file.getLines()) { + sampleList += fromSVMStringToLabeledPoint(sample) + } + sampleList.toList + } + + private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = { + val labelAndFeatures = line.split(" ") + val label = labelAndFeatures(0).toInt + val features = labelAndFeatures.tail + val denseFeature = new Array[Double](129) + for (feature <- features) { + val idAndValue = feature.split(":") + denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble + } + LabeledPoint(label, new DenseVector(denseFeature)) + } + + def main(args: Array[String]): Unit = { + import ml.dmlc.xgboost4j.scala.spark.DataUtils._ + if (args.length != 4) { + println( + "usage: program number_of_trainingset_partitions num_of_rounds training_path test_path") + sys.exit(1) + } + val sc = new SparkContext() + val inputTrainPath = args(2) + val inputTestPath = args(3) + val trainingLabeledPoints = readFile(inputTrainPath) + val trainRDD = sc.parallelize(trainingLabeledPoints, args(0).toInt) + val testLabeledPoints = readFile(inputTestPath).iterator + val testMatrix = new DMatrix(testLabeledPoints, null) + val booster = XGBoost.train(trainRDD, + List("eta" -> "1", "max_depth" -> "2", "silent" -> "0", + "objective" -> "binary:logistic").toMap, args(1).toInt, null, null) + booster.map(boosterInstance => boosterInstance.predict(testMatrix)) + } +} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala index 12fb545c9..d61cb9fc1 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala @@ -23,11 +23,16 @@ import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint} import ml.dmlc.xgboost4j.LabeledPoint -private[spark] object DataUtils extends Serializable { +object DataUtils extends Serializable { + + implicit def fromSparkToXGBoostLabeledPointsAsJava( + sps: Iterator[SparkLabeledPoint]): java.util.Iterator[LabeledPoint] = { + fromSparkToXGBoostLabeledPoints(sps).asJava + } implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]): - java.util.Iterator[LabeledPoint] = { - (for (p <- sps) yield { + Iterator[LabeledPoint] = { + for (p <- sps) yield { p.features match { case denseFeature: DenseVector => LabeledPoint.fromDenseVector(p.label.toFloat, denseFeature.values.map(_.toFloat)) @@ -35,17 +40,6 @@ private[spark] object DataUtils extends Serializable { LabeledPoint.fromSparseVector(p.label.toFloat, sparseFeature.indices, sparseFeature.values.map(_.toFloat)) } - }).asJava - } - - private def fetchUpdateFromSparseVector(sparseFeature: SparseVector): (List[Int], List[Float]) = { - (sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList) - } - - private def fetchUpdateFromVector(feature: Vector) = feature match { - case denseFeature: DenseVector => - fetchUpdateFromSparseVector(denseFeature.toSparse) - case sparseFeature: SparseVector => - fetchUpdateFromSparseVector(sparseFeature) + } } } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index ea7ba8563..7a35b85ec 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -61,7 +61,8 @@ object XGBoost extends Serializable { require(tracker.start(), "FAULT: Failed to start tracker") boosters = buildDistributedBoosters(trainingData, configMap, numWorkers, round, obj, eval) // force the job - sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters) + boosters.foreachPartition(_ => ()) + println("=====finished training=====") val booster = boosters.first() val returnVal = tracker.waitFor() logger.info(s"Rabit returns with exit code $returnVal") diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java index fc14e361e..5f4351eb1 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/LabeledPoint.java @@ -1,10 +1,12 @@ package ml.dmlc.xgboost4j; +import java.io.Serializable; + /** * Labeled data point for training examples. * Represent a sparse training instance. */ -public class LabeledPoint { +public class LabeledPoint implements Serializable { /** Label of the point */ public float label; /** Weight of this data point */ diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index e6e427900..efbbf3c28 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -1,41 +1,140 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ package ml.dmlc.xgboost4j.java; -import java.io.IOException; -import java.io.Serializable; +import java.io.*; +import java.util.HashMap; +import java.util.List; import java.util.Map; -public interface Booster extends Serializable { +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +/** + * Booster for xgboost, similar to the python wrapper xgboost.py + * but custom obj function and eval function not supported at present. + * + * @author hzx + */ +public class Booster implements Serializable { + private static final Log logger = LogFactory.getLog(Booster.class); + + long handle = 0; + + //load native library + static { + try { + NativeLibLoader.initXgBoost(); + } catch (IOException ex) { + logger.error("load native library failed."); + logger.error(ex); + } + } + + /** + * init Booster from dMatrixs + * + * @param params parameters + * @param dMatrixs DMatrix array + * @throws XGBoostError native error + */ + Booster(Map params, DMatrix[] dMatrixs) throws XGBoostError { + init(dMatrixs); + setParam("seed", "0"); + setParams(params); + } + + /** + * load model from modelPath + * + * @param params parameters + * @param modelPath booster modelPath (model generated by booster.saveModel) + * @throws XGBoostError native error + */ + Booster(Map params, String modelPath) throws XGBoostError { + init(null); + if (modelPath == null) { + throw new NullPointerException("modelPath : null"); + } + loadModel(modelPath); + setParam("seed", "0"); + setParams(params); + } + + + private void init(DMatrix[] dMatrixs) throws XGBoostError { + long[] handles = null; + if (dMatrixs != null) { + handles = dmatrixsToHandles(dMatrixs); + } + long[] out = new long[1]; + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterCreate(handles, out)); + + handle = out[0]; + } /** * set parameter * * @param key param name * @param value param value + * @throws XGBoostError native error */ - void setParam(String key, String value) throws XGBoostError; + public final void setParam(String key, String value) throws XGBoostError { + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value)); + } /** * set parameters * * @param params parameters key-value map + * @throws XGBoostError native error */ - void setParams(Map params) throws XGBoostError; + public void setParams(Map params) throws XGBoostError { + if (params != null) { + for (Map.Entry entry : params.entrySet()) { + setParam(entry.getKey(), entry.getValue().toString()); + } + } + } + /** * Update (one iteration) * * @param dtrain training data * @param iter current iteration number + * @throws XGBoostError native error */ - void update(DMatrix dtrain, int iter) throws XGBoostError; + public void update(DMatrix dtrain, int iter) throws XGBoostError { + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle())); + } /** * update with customize obj func * * @param dtrain training data * @param obj customized objective class + * @throws XGBoostError native error */ - void update(DMatrix dtrain, IObjective obj) throws XGBoostError; + public void update(DMatrix dtrain, IObjective obj) throws XGBoostError { + float[][] predicts = predict(dtrain, true); + List gradients = obj.getGradient(predicts, dtrain); + boost(dtrain, gradients.get(0), gradients.get(1)); + } /** * update with give grad and hess @@ -43,8 +142,16 @@ public interface Booster extends Serializable { * @param dtrain training data * @param grad first order of gradient * @param hess seconde order of gradient + * @throws XGBoostError native error */ - void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError; + public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError { + if (grad.length != hess.length) { + throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, + hess.length)); + } + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, + hess)); + } /** * evaluate with given dmatrixs. @@ -53,8 +160,15 @@ public interface Booster extends Serializable { * @param evalNames name for eval dmatrixs, used for check results * @param iter current eval iteration * @return eval information + * @throws XGBoostError native error */ - String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError; + public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError { + long[] handles = dmatrixsToHandles(evalMatrixs); + String[] evalInfo = new String[1]; + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, + evalInfo)); + return evalInfo[0]; + } /** * evaluate with given customized Evaluation class @@ -63,17 +177,64 @@ public interface Booster extends Serializable { * @param evalNames evaluation names * @param eval custom evaluator * @return eval information + * @throws XGBoostError native error */ - String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval) throws XGBoostError; + public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval) + throws XGBoostError { + String evalInfo = ""; + for (int i = 0; i < evalNames.length; i++) { + String evalName = evalNames[i]; + DMatrix evalMat = evalMatrixs[i]; + float evalResult = eval.eval(predict(evalMat), evalMat); + String evalMetric = eval.getMetric(); + evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult); + } + return evalInfo; + } + + /** + * base function for Predict + * + * @param data data + * @param outPutMargin output margin + * @param treeLimit limit number of trees + * @param predLeaf prediction minimum to keep leafs + * @return predict results + */ + private synchronized float[][] pred(DMatrix data, boolean outPutMargin, int treeLimit, + boolean predLeaf) throws XGBoostError { + int optionMask = 0; + if (outPutMargin) { + optionMask = 1; + } + if (predLeaf) { + optionMask = 2; + } + float[][] rawPredicts = new float[1][]; + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, + treeLimit, rawPredicts)); + int row = (int) data.rowNum(); + int col = rawPredicts[0].length / row; + float[][] predicts = new float[row][col]; + int r, c; + for (int i = 0; i < rawPredicts[0].length; i++) { + r = i / col; + c = i % col; + predicts[r][c] = rawPredicts[0][i]; + } + return predicts; + } /** * Predict with data * * @param data dmatrix storing the input * @return predict result + * @throws XGBoostError native error */ - float[][] predict(DMatrix data) throws XGBoostError; - + public float[][] predict(DMatrix data) throws XGBoostError { + return pred(data, false, 0, false); + } /** * Predict with data @@ -81,9 +242,11 @@ public interface Booster extends Serializable { * @param data dmatrix storing the input * @param outPutMargin Whether to output the raw untransformed margin value. * @return predict result + * @throws XGBoostError native error */ - float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError; - + public float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError { + return pred(data, outPutMargin, 0, false); + } /** * Predict with data @@ -92,31 +255,189 @@ public interface Booster extends Serializable { * @param outPutMargin Whether to output the raw untransformed margin value. * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). * @return predict result + * @throws XGBoostError native error */ - float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError; - + public float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError { + return pred(data, outPutMargin, treeLimit, false); + } /** * Predict with data - * @param data dmatrix storing the input + * + * @param data dmatrix storing the input * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). - * @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), - * nsample = data.numRow with each record indicating the predicted leaf index of - * each sample in each tree. Note that the leaf index of a tree is unique per - * tree, so you may find leaf 1 in both tree 1 and tree 0. + * @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), + * nsample = data.numRow with each record indicating the predicted leaf index + * of each sample in each tree. + * Note that the leaf index of a tree is unique per tree, so you may find leaf 1 + * in both tree 1 and tree 0. * @return predict result * @throws XGBoostError native error */ - float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError; + public float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError { + return pred(data, false, treeLimit, predLeaf); + } /** - * save model to modelPath, the model path support depends on the path support - * in libxgboost. For example, if we want to save to hdfs, libxgboost need to be - * compiled with HDFS support. - * See also toByteArray + * save model to modelPath + * * @param modelPath model path */ - void saveModel(String modelPath) throws XGBoostError; + public void saveModel(String modelPath) throws XGBoostError{ + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath)); + } + + private void loadModel(String modelPath) { + XGBoostJNI.XGBoosterLoadModel(handle, modelPath); + } + + /** + * get the dump of the model as a string array + * + * @param withStats Controls whether the split statistics are output. + * @return dumped model information + * @throws XGBoostError native error + */ + private String[] getDumpInfo(boolean withStats) throws XGBoostError { + int statsFlag = 0; + if (withStats) { + statsFlag = 1; + } + String[][] modelInfos = new String[1][]; + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos)); + return modelInfos[0]; + } + + /** + * get the dump of the model as a string array + * + * @param featureMap featureMap file + * @param withStats Controls whether the split statistics are output. + * @return dumped model information + * @throws XGBoostError native error + */ + private String[] getDumpInfo(String featureMap, boolean withStats) throws XGBoostError { + int statsFlag = 0; + if (withStats) { + statsFlag = 1; + } + String[][] modelInfos = new String[1][]; + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, + modelInfos)); + return modelInfos[0]; + } + + /** + * Dump model into a text file. + * + * @param modelPath file to save dumped model info + * @param withStats bool + * Controls whether the split statistics are output. + * @throws FileNotFoundException file not found + * @throws UnsupportedEncodingException unsupported feature + * @throws IOException error with model writing + * @throws XGBoostError native error + */ + public void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError { + File tf = new File(modelPath); + FileOutputStream out = new FileOutputStream(tf); + BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8")); + String[] modelInfos = getDumpInfo(withStats); + + for (int i = 0; i < modelInfos.length; i++) { + writer.write("booster [" + i + "]:\n"); + writer.write(modelInfos[i]); + } + + writer.close(); + out.close(); + } + + + /** + * Dump model into a text file. + * + * @param modelPath file to save dumped model info + * @param featureMap featureMap file + * @param withStats bool + * Controls whether the split statistics are output. + * @throws FileNotFoundException exception + * @throws UnsupportedEncodingException exception + * @throws IOException exception + * @throws XGBoostError native error + */ + public void dumpModel(String modelPath, String featureMap, boolean withStats) throws + IOException, XGBoostError { + File tf = new File(modelPath); + FileOutputStream out = new FileOutputStream(tf); + BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8")); + String[] modelInfos = getDumpInfo(featureMap, withStats); + + for (int i = 0; i < modelInfos.length; i++) { + writer.write("booster [" + i + "]:\n"); + writer.write(modelInfos[i]); + } + + writer.close(); + out.close(); + } + + + /** + * get importance of each feature + * + * @return featureMap key: feature index, value: feature importance score + * @throws XGBoostError native error + */ + public Map getFeatureScore() throws XGBoostError { + String[] modelInfos = getDumpInfo(false); + Map featureScore = new HashMap(); + for (String tree : modelInfos) { + for (String node : tree.split("\n")) { + String[] array = node.split("\\["); + if (array.length == 1) { + continue; + } + String fid = array[1].split("\\]")[0]; + fid = fid.split("<")[0]; + if (featureScore.containsKey(fid)) { + featureScore.put(fid, 1 + featureScore.get(fid)); + } else { + featureScore.put(fid, 1); + } + } + } + return featureScore; + } + + + /** + * get importance of each feature + * + * @param featureMap file to save dumped model info + * @return featureMap key: feature index, value: feature importance score + * @throws XGBoostError native error + */ + public Map getFeatureScore(String featureMap) throws XGBoostError { + String[] modelInfos = getDumpInfo(featureMap, false); + Map featureScore = new HashMap(); + for (String tree : modelInfos) { + for (String node : tree.split("\n")) { + String[] array = node.split("\\["); + if (array.length == 1) { + continue; + } + String fid = array[1].split("\\]")[0]; + fid = fid.split("<")[0]; + if (featureScore.containsKey(fid)) { + featureScore.put(fid, 1 + featureScore.get(fid)); + } else { + featureScore.put(fid, 1); + } + } + } + return featureScore; + } /** * Save the model as byte array representation. @@ -127,41 +448,77 @@ public interface Booster extends Serializable { * @return the saved byte array. * @throws XGBoostError */ - byte[] toByteArray() throws XGBoostError; + public byte[] toByteArray() throws XGBoostError { + byte[][] bytes = new byte[1][]; + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes)); + return bytes[0]; + } + /** - * Dump model into a text file. - * - * @param modelPath file to save dumped model info - * @param withStats bool Controls whether the split statistics are output. + * Load the booster model from thread-local rabit checkpoint. + * This is only used in distributed training. + * @return the stored version number of the checkpoint. + * @throws XGBoostError */ - void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError; + int loadRabitCheckpoint() throws XGBoostError { + int[] out = new int[1]; + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out)); + return out[0]; + } /** - * Dump model into a text file. - * - * @param modelPath file to save dumped model info - * @param featureMap featureMap file - * @param withStats bool - * Controls whether the split statistics are output. + * Save the booster model into thread-local rabit checkpoint. + * This is only used in distributed training. + * @throws XGBoostError */ - void dumpModel(String modelPath, String featureMap, boolean withStats) - throws IOException, XGBoostError; + void saveRabitCheckpoint() throws XGBoostError { + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle)); + } /** - * get importance of each feature + * transfer DMatrix array to handle array (used for native functions) * - * @return featureMap key: feature index, value: feature importance score + * @param dmatrixs + * @return handle array for input dmatrixs */ - Map getFeatureScore() throws XGBoostError ; + private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) { + long[] handles = new long[dmatrixs.length]; + for (int i = 0; i < dmatrixs.length; i++) { + handles[i] = dmatrixs[i].getHandle(); + } + return handles; + } - /** - * get importance of each feature - * - * @param featureMap file to save dumped model info - * @return featureMap key: feature index, value: feature importance score - */ - Map getFeatureScore(String featureMap) throws XGBoostError; + // making Booster serializable + private void writeObject(java.io.ObjectOutputStream out) throws IOException { + try { + out.writeObject(this.toByteArray()); + } catch (XGBoostError ex) { + throw new IOException(ex.toString()); + } + } - void dispose(); + private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException { + try { + this.init(null); + byte[] bytes = (byte[])in.readObject(); + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes)); + } catch (XGBoostError ex) { + throw new IOException(ex.toString()); + } + } + + @Override + protected void finalize() throws Throwable { + super.finalize(); + dispose(); + } + + public synchronized void dispose() { + if (handle != 0L) { + XGBoostJNI.XGBoosterFree(handle); + handle = 0; + } + } } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java index 2a52d0b9b..82ae97ed6 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java @@ -1,5 +1,6 @@ package ml.dmlc.xgboost4j.java; +import java.io.Serializable; import java.util.Iterator; import ml.dmlc.xgboost4j.LabeledPoint; @@ -56,7 +57,7 @@ class DataBatch { return b; } - static class BatchIterator implements Iterator { + static class BatchIterator implements Iterator, Serializable { private Iterator base; private int batchSize; diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/JavaBoosterImpl.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/JavaBoosterImpl.java deleted file mode 100644 index 37860dc6f..000000000 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/JavaBoosterImpl.java +++ /dev/null @@ -1,525 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ -package ml.dmlc.xgboost4j.java; - -import java.io.*; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - -/** - * Booster for xgboost, similar to the python wrapper xgboost.py - * but custom obj function and eval function not supported at present. - * - * @author hzx - */ -class JavaBoosterImpl implements Booster { - private static final Log logger = LogFactory.getLog(JavaBoosterImpl.class); - - long handle = 0; - - //load native library - static { - try { - NativeLibLoader.initXgBoost(); - } catch (IOException ex) { - logger.error("load native library failed."); - logger.error(ex); - } - } - - /** - * init Booster from dMatrixs - * - * @param params parameters - * @param dMatrixs DMatrix array - * @throws XGBoostError native error - */ - JavaBoosterImpl(Map params, DMatrix[] dMatrixs) throws XGBoostError { - init(dMatrixs); - setParam("seed", "0"); - setParams(params); - } - - /** - * load model from modelPath - * - * @param params parameters - * @param modelPath booster modelPath (model generated by booster.saveModel) - * @throws XGBoostError native error - */ - JavaBoosterImpl(Map params, String modelPath) throws XGBoostError { - init(null); - if (modelPath == null) { - throw new NullPointerException("modelPath : null"); - } - loadModel(modelPath); - setParam("seed", "0"); - setParams(params); - } - - - private void init(DMatrix[] dMatrixs) throws XGBoostError { - long[] handles = null; - if (dMatrixs != null) { - handles = dmatrixsToHandles(dMatrixs); - } - long[] out = new long[1]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterCreate(handles, out)); - - handle = out[0]; - } - - /** - * set parameter - * - * @param key param name - * @param value param value - * @throws XGBoostError native error - */ - public final void setParam(String key, String value) throws XGBoostError { - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value)); - } - - /** - * set parameters - * - * @param params parameters key-value map - * @throws XGBoostError native error - */ - public void setParams(Map params) throws XGBoostError { - if (params != null) { - for (Map.Entry entry : params.entrySet()) { - setParam(entry.getKey(), entry.getValue().toString()); - } - } - } - - - /** - * Update (one iteration) - * - * @param dtrain training data - * @param iter current iteration number - * @throws XGBoostError native error - */ - public void update(DMatrix dtrain, int iter) throws XGBoostError { - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle())); - } - - /** - * update with customize obj func - * - * @param dtrain training data - * @param obj customized objective class - * @throws XGBoostError native error - */ - public void update(DMatrix dtrain, IObjective obj) throws XGBoostError { - float[][] predicts = predict(dtrain, true); - List gradients = obj.getGradient(predicts, dtrain); - boost(dtrain, gradients.get(0), gradients.get(1)); - } - - /** - * update with give grad and hess - * - * @param dtrain training data - * @param grad first order of gradient - * @param hess seconde order of gradient - * @throws XGBoostError native error - */ - public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError { - if (grad.length != hess.length) { - throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, - hess.length)); - } - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, - hess)); - } - - /** - * evaluate with given dmatrixs. - * - * @param evalMatrixs dmatrixs for evaluation - * @param evalNames name for eval dmatrixs, used for check results - * @param iter current eval iteration - * @return eval information - * @throws XGBoostError native error - */ - public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError { - long[] handles = dmatrixsToHandles(evalMatrixs); - String[] evalInfo = new String[1]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, - evalInfo)); - return evalInfo[0]; - } - - /** - * evaluate with given customized Evaluation class - * - * @param evalMatrixs evaluation matrix - * @param evalNames evaluation names - * @param eval custom evaluator - * @return eval information - * @throws XGBoostError native error - */ - public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval) - throws XGBoostError { - String evalInfo = ""; - for (int i = 0; i < evalNames.length; i++) { - String evalName = evalNames[i]; - DMatrix evalMat = evalMatrixs[i]; - float evalResult = eval.eval(predict(evalMat), evalMat); - String evalMetric = eval.getMetric(); - evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult); - } - return evalInfo; - } - - /** - * base function for Predict - * - * @param data data - * @param outPutMargin output margin - * @param treeLimit limit number of trees - * @param predLeaf prediction minimum to keep leafs - * @return predict results - */ - private synchronized float[][] pred(DMatrix data, boolean outPutMargin, int treeLimit, - boolean predLeaf) throws XGBoostError { - int optionMask = 0; - if (outPutMargin) { - optionMask = 1; - } - if (predLeaf) { - optionMask = 2; - } - float[][] rawPredicts = new float[1][]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, - treeLimit, rawPredicts)); - int row = (int) data.rowNum(); - int col = rawPredicts[0].length / row; - float[][] predicts = new float[row][col]; - int r, c; - for (int i = 0; i < rawPredicts[0].length; i++) { - r = i / col; - c = i % col; - predicts[r][c] = rawPredicts[0][i]; - } - return predicts; - } - - /** - * Predict with data - * - * @param data dmatrix storing the input - * @return predict result - * @throws XGBoostError native error - */ - public float[][] predict(DMatrix data) throws XGBoostError { - return pred(data, false, 0, false); - } - - /** - * Predict with data - * - * @param data dmatrix storing the input - * @param outPutMargin Whether to output the raw untransformed margin value. - * @return predict result - * @throws XGBoostError native error - */ - public float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError { - return pred(data, outPutMargin, 0, false); - } - - /** - * Predict with data - * - * @param data dmatrix storing the input - * @param outPutMargin Whether to output the raw untransformed margin value. - * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). - * @return predict result - * @throws XGBoostError native error - */ - public float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError { - return pred(data, outPutMargin, treeLimit, false); - } - - /** - * Predict with data - * - * @param data dmatrix storing the input - * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). - * @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), - * nsample = data.numRow with each record indicating the predicted leaf index - * of each sample in each tree. - * Note that the leaf index of a tree is unique per tree, so you may find leaf 1 - * in both tree 1 and tree 0. - * @return predict result - * @throws XGBoostError native error - */ - public float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError { - return pred(data, false, treeLimit, predLeaf); - } - - /** - * save model to modelPath - * - * @param modelPath model path - */ - public void saveModel(String modelPath) throws XGBoostError{ - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath)); - } - - private void loadModel(String modelPath) { - XGBoostJNI.XGBoosterLoadModel(handle, modelPath); - } - - /** - * get the dump of the model as a string array - * - * @param withStats Controls whether the split statistics are output. - * @return dumped model information - * @throws XGBoostError native error - */ - private String[] getDumpInfo(boolean withStats) throws XGBoostError { - int statsFlag = 0; - if (withStats) { - statsFlag = 1; - } - String[][] modelInfos = new String[1][]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos)); - return modelInfos[0]; - } - - /** - * get the dump of the model as a string array - * - * @param featureMap featureMap file - * @param withStats Controls whether the split statistics are output. - * @return dumped model information - * @throws XGBoostError native error - */ - private String[] getDumpInfo(String featureMap, boolean withStats) throws XGBoostError { - int statsFlag = 0; - if (withStats) { - statsFlag = 1; - } - String[][] modelInfos = new String[1][]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, - modelInfos)); - return modelInfos[0]; - } - - /** - * Dump model into a text file. - * - * @param modelPath file to save dumped model info - * @param withStats bool - * Controls whether the split statistics are output. - * @throws FileNotFoundException file not found - * @throws UnsupportedEncodingException unsupported feature - * @throws IOException error with model writing - * @throws XGBoostError native error - */ - public void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError { - File tf = new File(modelPath); - FileOutputStream out = new FileOutputStream(tf); - BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8")); - String[] modelInfos = getDumpInfo(withStats); - - for (int i = 0; i < modelInfos.length; i++) { - writer.write("booster [" + i + "]:\n"); - writer.write(modelInfos[i]); - } - - writer.close(); - out.close(); - } - - - /** - * Dump model into a text file. - * - * @param modelPath file to save dumped model info - * @param featureMap featureMap file - * @param withStats bool - * Controls whether the split statistics are output. - * @throws FileNotFoundException exception - * @throws UnsupportedEncodingException exception - * @throws IOException exception - * @throws XGBoostError native error - */ - public void dumpModel(String modelPath, String featureMap, boolean withStats) throws - IOException, XGBoostError { - File tf = new File(modelPath); - FileOutputStream out = new FileOutputStream(tf); - BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8")); - String[] modelInfos = getDumpInfo(featureMap, withStats); - - for (int i = 0; i < modelInfos.length; i++) { - writer.write("booster [" + i + "]:\n"); - writer.write(modelInfos[i]); - } - - writer.close(); - out.close(); - } - - - /** - * get importance of each feature - * - * @return featureMap key: feature index, value: feature importance score - * @throws XGBoostError native error - */ - public Map getFeatureScore() throws XGBoostError { - String[] modelInfos = getDumpInfo(false); - Map featureScore = new HashMap(); - for (String tree : modelInfos) { - for (String node : tree.split("\n")) { - String[] array = node.split("\\["); - if (array.length == 1) { - continue; - } - String fid = array[1].split("\\]")[0]; - fid = fid.split("<")[0]; - if (featureScore.containsKey(fid)) { - featureScore.put(fid, 1 + featureScore.get(fid)); - } else { - featureScore.put(fid, 1); - } - } - } - return featureScore; - } - - - /** - * get importance of each feature - * - * @param featureMap file to save dumped model info - * @return featureMap key: feature index, value: feature importance score - * @throws XGBoostError native error - */ - public Map getFeatureScore(String featureMap) throws XGBoostError { - String[] modelInfos = getDumpInfo(featureMap, false); - Map featureScore = new HashMap(); - for (String tree : modelInfos) { - for (String node : tree.split("\n")) { - String[] array = node.split("\\["); - if (array.length == 1) { - continue; - } - String fid = array[1].split("\\]")[0]; - fid = fid.split("<")[0]; - if (featureScore.containsKey(fid)) { - featureScore.put(fid, 1 + featureScore.get(fid)); - } else { - featureScore.put(fid, 1); - } - } - } - return featureScore; - } - - /** - * Save the model as byte array representation. - * Write these bytes to a file will give compatible format with other xgboost bindings. - * - * If java natively support HDFS file API, use toByteArray and write the ByteArray, - * - * @return the saved byte array. - * @throws XGBoostError - */ - public byte[] toByteArray() throws XGBoostError { - byte[][] bytes = new byte[1][]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes)); - return bytes[0]; - } - - - /** - * Load the booster model from thread-local rabit checkpoint. - * This is only used in distributed training. - * @return the stored version number of the checkpoint. - * @throws XGBoostError - */ - int loadRabitCheckpoint() throws XGBoostError { - int[] out = new int[1]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out)); - return out[0]; - } - - /** - * Save the booster model into thread-local rabit checkpoint. - * This is only used in distributed training. - * @throws XGBoostError - */ - void saveRabitCheckpoint() throws XGBoostError { - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle)); - } - - /** - * transfer DMatrix array to handle array (used for native functions) - * - * @param dmatrixs - * @return handle array for input dmatrixs - */ - private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) { - long[] handles = new long[dmatrixs.length]; - for (int i = 0; i < dmatrixs.length; i++) { - handles[i] = dmatrixs[i].getHandle(); - } - return handles; - } - - // making Booster serializable - private void writeObject(java.io.ObjectOutputStream out) throws IOException { - try { - out.writeObject(this.toByteArray()); - } catch (XGBoostError ex) { - throw new IOException(ex.toString()); - } - } - - private void readObject(java.io.ObjectInputStream in) - throws IOException, ClassNotFoundException { - try { - this.init(null); - byte[] bytes = (byte[])in.readObject(); - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes)); - } catch (XGBoostError ex) { - throw new IOException(ex.toString()); - } - } - - @Override - protected void finalize() throws Throwable { - super.finalize(); - dispose(); - } - - public synchronized void dispose() { - if (handle != 0L) { - XGBoostJNI.XGBoosterFree(handle); - handle = 0; - } - } -} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java index d8408d26c..711f092f7 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java @@ -1,6 +1,7 @@ package ml.dmlc.xgboost4j.java; import java.io.IOException; +import java.io.Serializable; import java.util.Map; import org.apache.commons.logging.Log; @@ -9,7 +10,7 @@ import org.apache.commons.logging.LogFactory; /** * Rabit global class for synchronization. */ -public class Rabit { +public class Rabit implements Serializable { private static final Log logger = LogFactory.getLog(DMatrix.class); //load native library static { diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index 44b7425a2..5f810b749 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -71,7 +71,7 @@ public class XGBoost { } //initialize booster - JavaBoosterImpl booster = new JavaBoosterImpl(params, allMats); + Booster booster = new Booster(params, allMats); int version = booster.loadRabitCheckpoint(); @@ -115,7 +115,7 @@ public class XGBoost { public static Booster initBoostingModel( Map params, DMatrix[] dMatrixs) throws XGBoostError { - return new JavaBoosterImpl(params, dMatrixs); + return new Booster(params, dMatrixs); } /** @@ -127,7 +127,7 @@ public class XGBoost { */ public static Booster loadBoostModel(Map params, String modelPath) throws XGBoostError { - return new JavaBoosterImpl(params, modelPath); + return new Booster(params, modelPath); } /** diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala index 88032a61a..5fdeb9e2d 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala @@ -16,172 +16,86 @@ package ml.dmlc.xgboost4j.scala -import java.io.IOException - -import ml.dmlc.xgboost4j.java.XGBoostError +import ml.dmlc.xgboost4j.java +import scala.collection.JavaConverters._ import scala.collection.mutable -trait Booster extends Serializable { +class Booster private[xgboost4j](booster: java.Booster) extends Serializable { + + def setParam(key: String, value: String): Unit = { + booster.setParam(key, value) + } + + def update(dtrain: DMatrix, iter: Int): Unit = { + booster.update(dtrain.jDMatrix, iter) + } + + def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = { + booster.update(dtrain.jDMatrix, obj) + } + + def dumpModel(modelPath: String, withStats: Boolean): Unit = { + booster.dumpModel(modelPath, withStats) + } + + def dumpModel(modelPath: String, featureMap: String, withStats: Boolean): Unit = { + booster.dumpModel(modelPath, featureMap, withStats) + } + + def setParams(params: Map[String, AnyRef]): Unit = { + booster.setParams(params.asJava) + } + + def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String = { + booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter) + } + + def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait): + String = { + booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval) + } + + def dispose: Unit = { + booster.dispose() + } + + def predict(data: DMatrix): Array[Array[Float]] = { + booster.predict(data.jDMatrix) + } + + def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]] = { + booster.predict(data.jDMatrix, outPutMargin) + } + + def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): + Array[Array[Float]] = { + booster.predict(data.jDMatrix, outPutMargin, treeLimit) + } + + def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]] = { + booster.predict(data.jDMatrix, treeLimit, predLeaf) + } + + def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = { + booster.boost(dtrain.jDMatrix, grad, hess) + } + + def getFeatureScore: mutable.Map[String, Integer] = { + booster.getFeatureScore.asScala + } + + def getFeatureScore(featureMap: String): mutable.Map[String, Integer] = { + booster.getFeatureScore(featureMap).asScala + } + + def saveModel(modelPath: String): Unit = { + booster.saveModel(modelPath) + } + + override def finalize(): Unit = { + super.finalize() + dispose + } - /** - * set parameter - * - * @param key param name - * @param value param value - */ - @throws(classOf[XGBoostError]) - def setParam(key: String, value: String) - - /** - * set parameters - * - * @param params parameters key-value map - */ - @throws(classOf[XGBoostError]) - def setParams(params: Map[String, AnyRef]) - - /** - * Update (one iteration) - * - * @param dtrain training data - * @param iter current iteration number - */ - @throws(classOf[XGBoostError]) - def update(dtrain: DMatrix, iter: Int) - - /** - * update with customize obj func - * - * @param dtrain training data - * @param obj customized objective class - */ - @throws(classOf[XGBoostError]) - def update(dtrain: DMatrix, obj: ObjectiveTrait) - - /** - * update with give grad and hess - * - * @param dtrain training data - * @param grad first order of gradient - * @param hess seconde order of gradient - */ - @throws(classOf[XGBoostError]) - def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]) - - /** - * evaluate with given dmatrixs. - * - * @param evalMatrixs dmatrixs for evaluation - * @param evalNames name for eval dmatrixs, used for check results - * @param iter current eval iteration - * @return eval information - */ - @throws(classOf[XGBoostError]) - def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String - - /** - * evaluate with given customized Evaluation class - * - * @param evalMatrixs evaluation matrix - * @param evalNames evaluation names - * @param eval custom evaluator - * @return eval information - */ - @throws(classOf[XGBoostError]) - def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait): String - - /** - * Predict with data - * - * @param data dmatrix storing the input - * @return predict result - */ - @throws(classOf[XGBoostError]) - def predict(data: DMatrix): Array[Array[Float]] - - /** - * Predict with data - * - * @param data dmatrix storing the input - * @param outPutMargin Whether to output the raw untransformed margin value. - * @return predict result - */ - @throws(classOf[XGBoostError]) - def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]] - - /** - * Predict with data - * - * @param data dmatrix storing the input - * @param outPutMargin Whether to output the raw untransformed margin value. - * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). - * @return predict result - */ - @throws(classOf[XGBoostError]) - def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): Array[Array[Float]] - - /** - * Predict with data - * - * @param data dmatrix storing the input - * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). - * @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), - * nsample = data.numRow with each record indicating the predicted leaf index of - * each sample in each tree. Note that the leaf index of a tree is unique per - * tree, so you may find leaf 1 in both tree 1 and tree 0. - * @return predict result - * @throws XGBoostError native error - */ - @throws(classOf[XGBoostError]) - def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]] - - /** - * save model to modelPath - * - * @param modelPath model path - */ - @throws(classOf[XGBoostError]) - def saveModel(modelPath: String) - - /** - * Dump model into a text file. - * - * @param modelPath file to save dumped model info - * @param withStats bool Controls whether the split statistics are output. - */ - @throws(classOf[IOException]) - @throws(classOf[XGBoostError]) - def dumpModel(modelPath: String, withStats: Boolean) - - /** - * Dump model into a text file. - * - * @param modelPath file to save dumped model info - * @param featureMap featureMap file - * @param withStats bool - * Controls whether the split statistics are output. - */ - @throws(classOf[IOException]) - @throws(classOf[XGBoostError]) - def dumpModel(modelPath: String, featureMap: String, withStats: Boolean) - - /** - * get importance of each feature - * - * @return featureMap key: feature index, value: feature importance score - */ - @throws(classOf[XGBoostError]) - def getFeatureScore: mutable.Map[String, Integer] - - /** - * get importance of each feature - * - * @param featureMap file to save dumped model info - * @return featureMap key: feature index, value: feature importance score - */ - @throws(classOf[XGBoostError]) - def getFeatureScore(featureMap: String): mutable.Map[String, Integer] - - def dispose } diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImpl.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImpl.scala deleted file mode 100644 index bdb2fb34c..000000000 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImpl.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala - -import ml.dmlc.xgboost4j.java -import scala.collection.JavaConverters._ -import scala.collection.mutable - -private[scala] class ScalaBoosterImpl private[xgboost4j](booster: java.Booster) extends Booster { - - override def setParam(key: String, value: String): Unit = { - booster.setParam(key, value) - } - - override def update(dtrain: DMatrix, iter: Int): Unit = { - booster.update(dtrain.jDMatrix, iter) - } - - override def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = { - booster.update(dtrain.jDMatrix, obj) - } - - override def dumpModel(modelPath: String, withStats: Boolean): Unit = { - booster.dumpModel(modelPath, withStats) - } - - override def dumpModel(modelPath: String, featureMap: String, withStats: Boolean): Unit = { - booster.dumpModel(modelPath, featureMap, withStats) - } - - override def setParams(params: Map[String, AnyRef]): Unit = { - booster.setParams(params.asJava) - } - - override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String = { - booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter) - } - - override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait): - String = { - booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval) - } - - override def dispose: Unit = { - booster.dispose() - } - - override def predict(data: DMatrix): Array[Array[Float]] = { - booster.predict(data.jDMatrix) - } - - override def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]] = { - booster.predict(data.jDMatrix, outPutMargin) - } - - override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): - Array[Array[Float]] = { - booster.predict(data.jDMatrix, outPutMargin, treeLimit) - } - - override def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]] = { - booster.predict(data.jDMatrix, treeLimit, predLeaf) - } - - override def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = { - booster.boost(dtrain.jDMatrix, grad, hess) - } - - override def getFeatureScore: mutable.Map[String, Integer] = { - booster.getFeatureScore.asScala - } - - override def getFeatureScore(featureMap: String): mutable.Map[String, Integer] = { - booster.getFeatureScore(featureMap).asScala - } - - override def saveModel(modelPath: String): Unit = { - booster.saveModel(modelPath) - } - - override def finalize(): Unit = { - super.finalize() - dispose - } -} diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala index 58ed51527..5f81ae8b7 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala @@ -16,9 +16,10 @@ package ml.dmlc.xgboost4j.scala -import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost} import scala.collection.JavaConverters._ +import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost} + object XGBoost { def train( @@ -31,7 +32,7 @@ object XGBoost { val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)} val xgboostInJava = JXGBoost.train(params.asJava, dtrain.jDMatrix, round, jWatches.asJava, obj, eval) - new ScalaBoosterImpl(xgboostInJava) + new Booster(xgboostInJava) } def crossValidation( @@ -47,11 +48,11 @@ object XGBoost { def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = { val xgboostInJava = JXGBoost.initBoostingModel(params.asJava, dMatrixs.map(_.jDMatrix)) - new ScalaBoosterImpl(xgboostInJava) + new Booster(xgboostInJava) } def loadBoostModel(params: Map[String, AnyRef], modelPath: String): Booster = { val xgboostInJava = JXGBoost.loadBoostModel(params.asJava, modelPath) - new ScalaBoosterImpl(xgboostInJava) + new Booster(xgboostInJava) } } From 50337d19068c18ab78ee351747d98f3f959a712b Mon Sep 17 00:00:00 2001 From: CodingCat Date: Sun, 6 Mar 2016 14:56:49 -0500 Subject: [PATCH 3/6] fix rabitEnv --- .../scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 7a35b85ec..4bbef3cde 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -17,6 +17,7 @@ package ml.dmlc.xgboost4j.scala.spark import scala.collection.mutable +import scala.collection.JavaConverters._ import org.apache.commons.logging.LogFactory import org.apache.spark.TaskContext @@ -38,13 +39,13 @@ object XGBoost extends Serializable { private[spark] def buildDistributedBoosters( trainingData: RDD[LabeledPoint], xgBoostConfMap: Map[String, AnyRef], + rabitEnv: mutable.Map[String, String], numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = { import DataUtils._ trainingData.repartition(numWorkers).mapPartitions { trainingSamples => - Rabit.init(new java.util.HashMap[String, String]() { - put("DMLC_TASK_ID", TaskContext.getPartitionId().toString) - }) + rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString) + Rabit.init(rabitEnv.asJava) val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null)) val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round, watches = new mutable.HashMap[String, DMatrix]{put("train", dMatrix)}.toMap, obj, eval) @@ -59,7 +60,8 @@ object XGBoost extends Serializable { val sc = trainingData.sparkContext val tracker = new RabitTracker(numWorkers) require(tracker.start(), "FAULT: Failed to start tracker") - boosters = buildDistributedBoosters(trainingData, configMap, numWorkers, round, obj, eval) + boosters = buildDistributedBoosters(trainingData, configMap, + tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval) // force the job boosters.foreachPartition(_ => ()) println("=====finished training=====") From 6499422e90f8e9a89c1b623b3ccf2a8dab2eed04 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Sun, 6 Mar 2016 15:22:05 -0500 Subject: [PATCH 4/6] fix the merge --- .../scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala | 1 - .../ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala | 1 + .../src/main/java/ml/dmlc/xgboost4j/java/Booster.java | 8 +++++--- .../src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java | 2 +- .../src/main/java/ml/dmlc/xgboost4j/java/Rabit.java | 2 +- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 4bbef3cde..68b887b23 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -64,7 +64,6 @@ object XGBoost extends Serializable { tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval) // force the job boosters.foreachPartition(_ => ()) - println("=====finished training=====") val booster = boosters.first() val returnVal = tracker.waitFor() logger.info(s"Rabit returns with exit code $returnVal") diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala index 23c9924d1..ca1fe9ada 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala @@ -130,6 +130,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll { trainingRDD, List("eta" -> "1", "max_depth" -> "2", "silent" -> "0", "objective" -> "binary:logistic").toMap, + new scala.collection.mutable.HashMap[String, String], numWorker, 2, null, null) val boosterCount = boosterRDD.count() assert(boosterCount === numWorker) diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index 15e242b6d..5778149f2 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -350,7 +350,10 @@ public class Booster implements Serializable { } /** - * get the dump of the model as a string array + * Save the model as byte array representation. + * Write these bytes to a file will give compatible format with other xgboost bindings. + * + * If java natively support HDFS file API, use toByteArray and write the ByteArray * * @param withStats Controls whether the split statistics are output. * @return dumped model information @@ -367,9 +370,8 @@ public class Booster implements Serializable { } /** - * get the dump of the model as a byte array * - * @return dumped model information + * @return the saved byte array. * @throws XGBoostError native error */ public byte[] toByteArray() throws XGBoostError { diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java index 82ae97ed6..d2ff3b612 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java @@ -57,7 +57,7 @@ class DataBatch { return b; } - static class BatchIterator implements Iterator, Serializable { + static class BatchIterator implements Iterator { private Iterator base; private int batchSize; diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java index e336cf807..3429dc3dd 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java @@ -10,7 +10,7 @@ import org.apache.commons.logging.LogFactory; /** * Rabit global class for synchronization. */ -public class Rabit implements Serializable { +public class Rabit { private static final Log logger = LogFactory.getLog(DMatrix.class); //load native library static { From 718a9d8c96ee0d24b4a7eca4cd9d90ea653a1e7b Mon Sep 17 00:00:00 2001 From: CodingCat Date: Sun, 6 Mar 2016 15:32:14 -0500 Subject: [PATCH 5/6] use another thread to control spark job --- .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 68b887b23..a7c802dc1 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -62,14 +62,28 @@ object XGBoost extends Serializable { require(tracker.start(), "FAULT: Failed to start tracker") boosters = buildDistributedBoosters(trainingData, configMap, tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval) - // force the job - boosters.foreachPartition(_ => ()) - val booster = boosters.first() + @volatile var booster: Booster = null + val sparkJobThread = new Thread() { + override def run() { + // force the job + boosters.foreachPartition(_ => ()) + } + } + sparkJobThread.start() val returnVal = tracker.waitFor() logger.info(s"Rabit returns with exit code $returnVal") if (returnVal == 0) { + booster = boosters.first() Some(booster) } else { + try { + if (sparkJobThread.isAlive) { + sparkJobThread.interrupt() + } + } catch { + case ie: InterruptedException => + logger.info("spark job thread is interrupted") + } None } } From c211a8063377311c49c43911cf517ab57a9e30e3 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Sun, 6 Mar 2016 17:02:07 -0500 Subject: [PATCH 6/6] log tracker exit value in logger capture InterruptedException --- .../src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java index a5768d6cd..5b04ac432 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java @@ -47,8 +47,15 @@ public class RabitTracker { while ((line = reader.readLine()) != null) { trackerProcessLogger.info(line); } + trackerProcess.get().waitFor(); + trackerProcessLogger.info("Tracker Process ends with exit code " + + trackerProcess.get().exitValue()); } catch (IOException ex) { trackerProcessLogger.error(ex.toString()); + } catch (InterruptedException ie) { + // we should not get here as RabitTracker is accessed in the main thread + ie.printStackTrace(); + logger.error("the RabitTracker thread is terminated unexpectedly"); } } }