force the user to set number of workers
This commit is contained in:
@@ -72,10 +72,7 @@ object XGBoost {
|
||||
*/
|
||||
def loadModelFromHadoopFile(modelPath: String) : XGBoostModel = {
|
||||
new XGBoostModel(
|
||||
XGBoostScala.loadModel(
|
||||
FileSystem
|
||||
.get(new Configuration)
|
||||
.open(new Path(modelPath))))
|
||||
XGBoostScala.loadModel(FileSystem.get(new Configuration).open(new Path(modelPath))))
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -85,11 +82,9 @@ object XGBoost {
|
||||
* @param params The parameters to XGBoost.
|
||||
* @param round Number of rounds to train.
|
||||
*/
|
||||
def train(
|
||||
dtrain: DataSet[LabeledVector],
|
||||
params: Map[String, Any],
|
||||
round: Int): XGBoostModel = {
|
||||
val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
|
||||
def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int, nWorkers: Int):
|
||||
XGBoostModel = {
|
||||
val tracker = new RabitTracker(nWorkers)
|
||||
if (tracker.start()) {
|
||||
dtrain
|
||||
.mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))
|
||||
|
||||
Reference in New Issue
Block a user