fix the merge
This commit is contained in:
parent
16008ebfb8
commit
6499422e90
@ -64,7 +64,6 @@ object XGBoost extends Serializable {
|
|||||||
tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
|
tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
|
||||||
// force the job
|
// force the job
|
||||||
boosters.foreachPartition(_ => ())
|
boosters.foreachPartition(_ => ())
|
||||||
println("=====finished training=====")
|
|
||||||
val booster = boosters.first()
|
val booster = boosters.first()
|
||||||
val returnVal = tracker.waitFor()
|
val returnVal = tracker.waitFor()
|
||||||
logger.info(s"Rabit returns with exit code $returnVal")
|
logger.info(s"Rabit returns with exit code $returnVal")
|
||||||
|
|||||||
@ -130,6 +130,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
|
|||||||
trainingRDD,
|
trainingRDD,
|
||||||
List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
||||||
"objective" -> "binary:logistic").toMap,
|
"objective" -> "binary:logistic").toMap,
|
||||||
|
new scala.collection.mutable.HashMap[String, String],
|
||||||
numWorker, 2, null, null)
|
numWorker, 2, null, null)
|
||||||
val boosterCount = boosterRDD.count()
|
val boosterCount = boosterRDD.count()
|
||||||
assert(boosterCount === numWorker)
|
assert(boosterCount === numWorker)
|
||||||
|
|||||||
@ -350,7 +350,10 @@ public class Booster implements Serializable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* get the dump of the model as a string array
|
* Save the model as byte array representation.
|
||||||
|
* Write these bytes to a file will give compatible format with other xgboost bindings.
|
||||||
|
*
|
||||||
|
* If java natively support HDFS file API, use toByteArray and write the ByteArray
|
||||||
*
|
*
|
||||||
* @param withStats Controls whether the split statistics are output.
|
* @param withStats Controls whether the split statistics are output.
|
||||||
* @return dumped model information
|
* @return dumped model information
|
||||||
@ -367,9 +370,8 @@ public class Booster implements Serializable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* get the dump of the model as a byte array
|
|
||||||
*
|
*
|
||||||
* @return dumped model information
|
* @return the saved byte array.
|
||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
public byte[] toByteArray() throws XGBoostError {
|
public byte[] toByteArray() throws XGBoostError {
|
||||||
|
|||||||
@ -57,7 +57,7 @@ class DataBatch {
|
|||||||
return b;
|
return b;
|
||||||
}
|
}
|
||||||
|
|
||||||
static class BatchIterator implements Iterator<DataBatch>, Serializable {
|
static class BatchIterator implements Iterator<DataBatch> {
|
||||||
private Iterator<LabeledPoint> base;
|
private Iterator<LabeledPoint> base;
|
||||||
private int batchSize;
|
private int batchSize;
|
||||||
|
|
||||||
|
|||||||
@ -10,7 +10,7 @@ import org.apache.commons.logging.LogFactory;
|
|||||||
/**
|
/**
|
||||||
* Rabit global class for synchronization.
|
* Rabit global class for synchronization.
|
||||||
*/
|
*/
|
||||||
public class Rabit implements Serializable {
|
public class Rabit {
|
||||||
private static final Log logger = LogFactory.getLog(DMatrix.class);
|
private static final Log logger = LogFactory.getLog(DMatrix.class);
|
||||||
//load native library
|
//load native library
|
||||||
static {
|
static {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user