fix the merge

This commit is contained in:
CodingCat 2016-03-06 15:22:05 -05:00
parent 16008ebfb8
commit 6499422e90
5 changed files with 8 additions and 6 deletions

View File

@ -64,7 +64,6 @@ object XGBoost extends Serializable {
tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
// force the job
boosters.foreachPartition(_ => ())
println("=====finished training=====")
val booster = boosters.first()
val returnVal = tracker.waitFor()
logger.info(s"Rabit returns with exit code $returnVal")

View File

@ -130,6 +130,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
trainingRDD,
List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap,
new scala.collection.mutable.HashMap[String, String],
numWorker, 2, null, null)
val boosterCount = boosterRDD.count()
assert(boosterCount === numWorker)

View File

@ -350,7 +350,10 @@ public class Booster implements Serializable {
}
/**
* get the dump of the model as a string array
* Save the model as byte array representation.
* Write these bytes to a file will give compatible format with other xgboost bindings.
*
* If java natively support HDFS file API, use toByteArray and write the ByteArray
*
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
@ -367,9 +370,8 @@ public class Booster implements Serializable {
}
/**
* get the dump of the model as a byte array
*
* @return dumped model information
* @return the saved byte array.
* @throws XGBoostError native error
*/
public byte[] toByteArray() throws XGBoostError {

View File

@ -57,7 +57,7 @@ class DataBatch {
return b;
}
static class BatchIterator implements Iterator<DataBatch>, Serializable {
static class BatchIterator implements Iterator<DataBatch> {
private Iterator<LabeledPoint> base;
private int batchSize;

View File

@ -10,7 +10,7 @@ import org.apache.commons.logging.LogFactory;
/**
* Rabit global class for synchronization.
*/
public class Rabit implements Serializable {
public class Rabit {
private static final Log logger = LogFactory.getLog(DMatrix.class);
//load native library
static {