fix the merge

This commit is contained in:
CodingCat 2016-03-06 15:22:05 -05:00
parent 16008ebfb8
commit 6499422e90
5 changed files with 8 additions and 6 deletions

View File

@ -64,7 +64,6 @@ object XGBoost extends Serializable {
tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval) tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
// force the job // force the job
boosters.foreachPartition(_ => ()) boosters.foreachPartition(_ => ())
println("=====finished training=====")
val booster = boosters.first() val booster = boosters.first()
val returnVal = tracker.waitFor() val returnVal = tracker.waitFor()
logger.info(s"Rabit returns with exit code $returnVal") logger.info(s"Rabit returns with exit code $returnVal")

View File

@ -130,6 +130,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
trainingRDD, trainingRDD,
List("eta" -> "1", "max_depth" -> "2", "silent" -> "0", List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap, "objective" -> "binary:logistic").toMap,
new scala.collection.mutable.HashMap[String, String],
numWorker, 2, null, null) numWorker, 2, null, null)
val boosterCount = boosterRDD.count() val boosterCount = boosterRDD.count()
assert(boosterCount === numWorker) assert(boosterCount === numWorker)

View File

@ -350,7 +350,10 @@ public class Booster implements Serializable {
} }
/** /**
* get the dump of the model as a string array * Save the model as byte array representation.
* Write these bytes to a file will give compatible format with other xgboost bindings.
*
* If java natively support HDFS file API, use toByteArray and write the ByteArray
* *
* @param withStats Controls whether the split statistics are output. * @param withStats Controls whether the split statistics are output.
* @return dumped model information * @return dumped model information
@ -367,9 +370,8 @@ public class Booster implements Serializable {
} }
/** /**
* get the dump of the model as a byte array
* *
* @return dumped model information * @return the saved byte array.
* @throws XGBoostError native error * @throws XGBoostError native error
*/ */
public byte[] toByteArray() throws XGBoostError { public byte[] toByteArray() throws XGBoostError {

View File

@ -57,7 +57,7 @@ class DataBatch {
return b; return b;
} }
static class BatchIterator implements Iterator<DataBatch>, Serializable { static class BatchIterator implements Iterator<DataBatch> {
private Iterator<LabeledPoint> base; private Iterator<LabeledPoint> base;
private int batchSize; private int batchSize;

View File

@ -10,7 +10,7 @@ import org.apache.commons.logging.LogFactory;
/** /**
* Rabit global class for synchronization. * Rabit global class for synchronization.
*/ */
public class Rabit implements Serializable { public class Rabit {
private static final Log logger = LogFactory.getLog(DMatrix.class); private static final Log logger = LogFactory.getLog(DMatrix.class);
//load native library //load native library
static { static {