force the user to set number of workers

This commit is contained in:
CodingCat
2016-03-12 13:33:57 -05:00
parent 980898f3fb
commit 16b9e92328
6 changed files with 20 additions and 26 deletions

View File

@@ -72,10 +72,7 @@ object XGBoost {
*/
def loadModelFromHadoopFile(modelPath: String) : XGBoostModel = {
new XGBoostModel(
XGBoostScala.loadModel(
FileSystem
.get(new Configuration)
.open(new Path(modelPath))))
XGBoostScala.loadModel(FileSystem.get(new Configuration).open(new Path(modelPath))))
}
/**
@@ -85,11 +82,9 @@ object XGBoost {
* @param params The parameters to XGBoost.
* @param round Number of rounds to train.
*/
def train(
dtrain: DataSet[LabeledVector],
params: Map[String, Any],
round: Int): XGBoostModel = {
val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int, nWorkers: Int):
XGBoostModel = {
val tracker = new RabitTracker(nWorkers)
if (tracker.start()) {
dtrain
.mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))