fix the merge
This commit is contained in:
parent
16008ebfb8
commit
6499422e90
@ -64,7 +64,6 @@ object XGBoost extends Serializable {
|
||||
tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
|
||||
// force the job
|
||||
boosters.foreachPartition(_ => ())
|
||||
println("=====finished training=====")
|
||||
val booster = boosters.first()
|
||||
val returnVal = tracker.waitFor()
|
||||
logger.info(s"Rabit returns with exit code $returnVal")
|
||||
|
||||
@ -130,6 +130,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
|
||||
trainingRDD,
|
||||
List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
||||
"objective" -> "binary:logistic").toMap,
|
||||
new scala.collection.mutable.HashMap[String, String],
|
||||
numWorker, 2, null, null)
|
||||
val boosterCount = boosterRDD.count()
|
||||
assert(boosterCount === numWorker)
|
||||
|
||||
@ -350,7 +350,10 @@ public class Booster implements Serializable {
|
||||
}
|
||||
|
||||
/**
|
||||
* get the dump of the model as a string array
|
||||
* Save the model as byte array representation.
|
||||
* Write these bytes to a file will give compatible format with other xgboost bindings.
|
||||
*
|
||||
* If java natively support HDFS file API, use toByteArray and write the ByteArray
|
||||
*
|
||||
* @param withStats Controls whether the split statistics are output.
|
||||
* @return dumped model information
|
||||
@ -367,9 +370,8 @@ public class Booster implements Serializable {
|
||||
}
|
||||
|
||||
/**
|
||||
* get the dump of the model as a byte array
|
||||
*
|
||||
* @return dumped model information
|
||||
* @return the saved byte array.
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public byte[] toByteArray() throws XGBoostError {
|
||||
|
||||
@ -57,7 +57,7 @@ class DataBatch {
|
||||
return b;
|
||||
}
|
||||
|
||||
static class BatchIterator implements Iterator<DataBatch>, Serializable {
|
||||
static class BatchIterator implements Iterator<DataBatch> {
|
||||
private Iterator<LabeledPoint> base;
|
||||
private int batchSize;
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ import org.apache.commons.logging.LogFactory;
|
||||
/**
|
||||
* Rabit global class for synchronization.
|
||||
*/
|
||||
public class Rabit implements Serializable {
|
||||
public class Rabit {
|
||||
private static final Log logger = LogFactory.getLog(DMatrix.class);
|
||||
//load native library
|
||||
static {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user