[jvm-packages] do not use multiple jobs to make checkpoints (#5082)

* temp

* temp

* tep

* address the comments

* fix stylistic issues

* fix

* external checkpoint
This commit is contained in:
Nan Zhu
2020-02-01 19:36:39 -08:00
committed by GitHub
parent fa26313feb
commit d7b45fbcaf
14 changed files with 464 additions and 320 deletions

View File

@@ -13,6 +13,18 @@
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-hdfs</artifactId>
<version>${hadoop.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-common</artifactId>
<version>${hadoop.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>

View File

@@ -0,0 +1,117 @@
package ml.dmlc.xgboost4j.java;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.*;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
/**
 * Stores and retrieves booster checkpoints under a directory of a Hadoop {@link FileSystem}.
 *
 * <p>Every checkpoint is a file named {@code <version>.model} directly under the checkpoint
 * directory, where {@code <version>} is the value returned by {@code Booster.getVersion()} at
 * the time the checkpoint was written. Rounds are derived from versions as {@code version / 2}.
 */
public class ExternalCheckpointManager {

  private final Log logger = LogFactory.getLog("ExternalCheckpointManager");
  private final String modelSuffix = ".model";
  private final Path checkpointPath;
  private final FileSystem fs;

  /**
   * @param checkpointPath directory that holds the checkpoint files
   * @param fs             file system used for every read, write, rename and delete
   * @throws XGBoostError if {@code checkpointPath} is null or empty
   */
  public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
    if (checkpointPath == null || checkpointPath.isEmpty()) {
      throw new XGBoostError("cannot create ExternalCheckpointManager with null or" +
          " empty checkpoint path");
    }
    this.checkpointPath = new Path(checkpointPath);
    this.fs = fs;
  }

  /** Builds the path string of the checkpoint file for the given version. */
  private String getPath(int version) {
    return checkpointPath.toUri().getPath() + "/" + version + modelSuffix;
  }

  /**
   * Extracts the version encoded in a checkpoint file name.
   *
   * @return the parsed version, or {@code null} when the part before the suffix is not an integer
   */
  private Integer parseVersion(String fileName) {
    String stem = fileName.substring(0, fileName.length() - modelSuffix.length());
    try {
      return Integer.valueOf(stem);
    } catch (NumberFormatException e) {
      // A stray "*.model" file must not abort checkpoint handling with an unchecked exception.
      logger.warn("ignoring file with non-numeric checkpoint version: " + fileName);
      return null;
    }
  }

  /**
   * Lists the versions of all checkpoint files currently present.
   *
   * @return the versions found; empty when the checkpoint directory does not exist
   */
  private List<Integer> getExistingVersions() throws IOException {
    if (!fs.exists(checkpointPath)) {
      return new ArrayList<>();
    }
    return Arrays.stream(fs.listStatus(checkpointPath))
        .map(status -> status.getPath().getName())
        .filter(fileName -> fileName.endsWith(modelSuffix))
        .map(this::parseVersion)
        .filter(Objects::nonNull)
        .collect(Collectors.toList());
  }

  /** Recursively deletes the checkpoint directory. */
  public void cleanPath() throws IOException {
    fs.delete(checkpointPath, true);
  }

  /**
   * Loads the checkpoint with the highest version.
   *
   * @return the loaded booster with its version set to the checkpoint's version, or {@code null}
   *         when no checkpoint exists
   */
  public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
    List<Integer> versions = getExistingVersions();
    if (versions.isEmpty()) {
      return null;
    }
    int latestVersion = Collections.max(versions);
    String latestPath = getPath(latestVersion);
    // try-with-resources guarantees the stream is released even when loading fails.
    try (InputStream in = fs.open(new Path(latestPath))) {
      Booster booster = XGBoost.loadModel(in);
      booster.setVersion(latestVersion);
      logger.info("loaded checkpoint from " + latestPath);
      return booster;
    }
  }

  /**
   * Writes {@code boosterToCheckpoint} as the newest checkpoint and removes the older ones.
   *
   * <p>The model is first written to a uniquely named temporary file and then renamed to its
   * final name. Older checkpoints are deleted only after the rename succeeded, so a failed
   * publish never leaves the directory without a usable checkpoint.
   *
   * @throws IOException if writing fails or the temporary file cannot be renamed
   */
  public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
    List<String> prevModelPaths = getExistingVersions().stream()
        .map(this::getPath).collect(Collectors.toList());
    int version = boosterToCheckpoint.getVersion();
    String eventualPath = getPath(version);
    Path tempPath = new Path(eventualPath + "-" + UUID.randomUUID());
    // Close the stream before renaming so the rename never acts on a file still being written.
    try (OutputStream out = fs.create(tempPath, true)) {
      boosterToCheckpoint.saveModel(out);
    }
    // FileSystem.rename reports failure through its return value; do not ignore it.
    if (!fs.rename(tempPath, new Path(eventualPath))) {
      throw new IOException("failed to rename temporary checkpoint " + tempPath +
          " to " + eventualPath);
    }
    logger.info("saving checkpoint with version " + version);
    for (String path : prevModelPaths) {
      if (path.equals(eventualPath)) {
        // Never delete the checkpoint that was just published.
        continue;
      }
      try {
        fs.delete(new Path(path), true);
      } catch (IOException e) {
        // Best effort: a leftover outdated checkpoint is harmless.
        logger.error("failed to delete outdated checkpoint at " + path, e);
      }
    }
  }

  /**
   * Deletes every checkpoint whose round ({@code version / 2}) is at or beyond
   * {@code currentRound}. Individual delete failures are logged and skipped.
   */
  public void cleanUpHigherVersions(int currentRound) throws IOException {
    for (int version : getExistingVersions()) {
      if (version / 2 < currentRound) {
        continue;
      }
      try {
        fs.delete(new Path(getPath(version)), true);
      } catch (IOException e) {
        logger.error("failed to clean checkpoint from other training instance", e);
      }
    }
  }

  /**
   * Computes the rounds at which a checkpoint should be written.
   *
   * <p>With a positive interval the rounds start one interval after the latest checkpointed
   * round (0 when there is none) and advance by the interval up to {@code numOfRounds}.
   * {@code numOfRounds} is always appended as the last element, so it appears twice when it
   * also falls on the interval. With a non-positive interval the result is just
   * {@code numOfRounds} and the file system is not consulted.
   *
   * @param checkpointInterval number of rounds between checkpoints; non-positive disables
   *                           the periodic ones
   * @param numOfRounds        total number of training rounds
   * @return the checkpoint rounds in ascending order
   */
  public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
      throws IOException {
    List<Integer> rounds = new ArrayList<>();
    if (checkpointInterval > 0) {
      int latestRound = 0;
      for (int version : getExistingVersions()) {
        latestRound = Math.max(latestRound, version / 2);
      }
      for (int i = latestRound + checkpointInterval; i <= numOfRounds; i += checkpointInterval) {
        rounds.add(i);
      }
    }
    rounds.add(numOfRounds);
    return rounds;
  }
}

View File

@@ -15,12 +15,16 @@
*/
package ml.dmlc.xgboost4j.java;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.*;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
/**
* trainer for xgboost
@@ -108,35 +112,34 @@ public class XGBoost {
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
}
/**
* Train a booster given parameters.
*
* @param dtrain Data to be trained.
* @param params Parameters.
* @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training, this allows user to watch
* performance on the validation set.
* @param metrics array containing the evaluation metrics for each matrix in watches for each
* iteration
* @param earlyStoppingRounds if non-zero, training would be stopped
* after a specified number of consecutive
* goes to the unexpected direction in any evaluation metric.
* @param obj customized objective
* @param eval customized evaluation
* @param booster train from scratch if set to null; train from an existing booster if not null.
* @return The trained booster.
*/
public static Booster train(
DMatrix dtrain,
Map<String, Object> params,
int round,
Map<String, DMatrix> watches,
float[][] metrics,
IObjective obj,
IEvaluation eval,
int earlyStoppingRounds,
Booster booster) throws XGBoostError {
/**
 * Writes an external checkpoint of {@code booster} when {@code iter} is one of the scheduled
 * checkpoint iterations; does nothing otherwise.
 *
 * @param booster              booster to checkpoint
 * @param iter                 current training iteration
 * @param checkpointIterations iterations at which a checkpoint must be written
 * @param ecm                  manager that performs the write; only dereferenced when
 *                             {@code iter} is contained in {@code checkpointIterations}
 * @throws XGBoostError if the checkpoint could not be written; the original failure is
 *                      attached as the cause
 */
private static void saveCheckpoint(
    Booster booster,
    int iter,
    Set<Integer> checkpointIterations,
    ExternalCheckpointManager ecm) throws XGBoostError {
  try {
    if (checkpointIterations.contains(iter)) {
      ecm.updateCheckpoint(booster);
    }
  } catch (Exception e) {
    // Build the message once so the log entry and the exception agree (the exception text
    // previously lacked the space before the iteration number).
    String message = "failed to save checkpoint in XGBoost4J at iteration " + iter;
    logger.error(message, e);
    throw new XGBoostError(message, e);
  }
}
public static Booster trainAndSaveCheckpoint(
DMatrix dtrain,
Map<String, Object> params,
int numRounds,
Map<String, DMatrix> watches,
float[][] metrics,
IObjective obj,
IEvaluation eval,
int earlyStoppingRounds,
Booster booster,
int checkpointInterval,
String checkpointPath,
FileSystem fs) throws XGBoostError, IOException {
//collect eval matrixs
String[] evalNames;
DMatrix[] evalMats;
@@ -144,6 +147,11 @@ public class XGBoost {
int bestIteration;
List<String> names = new ArrayList<String>();
List<DMatrix> mats = new ArrayList<DMatrix>();
Set<Integer> checkpointIterations = new HashSet<>();
ExternalCheckpointManager ecm = null;
if (checkpointPath != null) {
ecm = new ExternalCheckpointManager(checkpointPath, fs);
}
for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) {
names.add(evalEntry.getKey());
@@ -158,7 +166,7 @@ public class XGBoost {
bestScore = Float.MAX_VALUE;
}
bestIteration = 0;
metrics = metrics == null ? new float[evalNames.length][round] : metrics;
metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics;
//collect all data matrixs
DMatrix[] allMats;
@@ -181,14 +189,19 @@ public class XGBoost {
booster.setParams(params);
}
//begin to train
for (int iter = booster.getVersion() / 2; iter < round; iter++) {
if (ecm != null) {
checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
}
// begin to train
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
if (booster.getVersion() % 2 == 0) {
if (obj != null) {
booster.update(dtrain, obj);
} else {
booster.update(dtrain, iter);
}
saveCheckpoint(booster, iter, checkpointIterations, ecm);
booster.saveRabitCheckpoint();
}
@@ -224,7 +237,7 @@ public class XGBoost {
if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
Rabit.trackerPrint(String.format(
"early stopping after %d rounds away from the best iteration",
earlyStoppingRounds));
earlyStoppingRounds));
break;
}
}
@@ -239,6 +252,44 @@ public class XGBoost {
return booster;
}
/**
 * Trains a booster with the given parameters, without writing external checkpoints.
 *
 * <p>Delegates to {@code trainAndSaveCheckpoint} with no checkpoint path and no file system.
 *
 * @param dtrain              training data
 * @param params              booster parameters
 * @param round               number of boosting iterations
 * @param watches             named matrices evaluated during training, e.g. to monitor
 *                            performance on a validation set
 * @param metrics             array receiving the evaluation metric of each matrix in
 *                            {@code watches} for each iteration
 * @param obj                 customized objective
 * @param eval                customized evaluation
 * @param earlyStoppingRounds if non-zero, training stops after this many consecutive
 *                            iterations in which an evaluation metric moves in the
 *                            unexpected direction
 * @param booster             existing booster to continue training from, or {@code null} to
 *                            train from scratch
 * @return the trained booster
 * @throws XGBoostError if training fails, including when an {@link IOException} is raised
 */
public static Booster train(
    DMatrix dtrain,
    Map<String, Object> params,
    int round,
    Map<String, DMatrix> watches,
    float[][] metrics,
    IObjective obj,
    IEvaluation eval,
    int earlyStoppingRounds,
    Booster booster) throws XGBoostError {
  // Placeholder arguments for the checkpoint-related parameters.
  final int noCheckpointInterval = -1;
  final String noCheckpointPath = null;
  final FileSystem noFileSystem = null;
  Booster trained;
  try {
    trained = trainAndSaveCheckpoint(
        dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRounds, booster,
        noCheckpointInterval, noCheckpointPath, noFileSystem);
  } catch (IOException ioe) {
    logger.error("training failed in xgboost4j", ioe);
    throw new XGBoostError("training failed in xgboost4j ", ioe);
  }
  return trained;
}
private static Integer tryGetIntFromObject(Object o) {
if (o instanceof Integer) {
return (int)o;

View File

@@ -24,4 +24,8 @@ public class XGBoostError extends Exception {
/**
 * Creates an error that carries only a description of the failure.
 *
 * @param message description of the failure, forwarded to the superclass constructor
 */
public XGBoostError(String message) {
super(message);
}
/**
 * Creates an error that wraps an underlying failure, keeping it available as the cause.
 *
 * @param message description of the failure, forwarded to the superclass constructor
 * @param cause   the underlying failure, forwarded to the superclass constructor
 */
public XGBoostError(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -0,0 +1,37 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.java.{ExternalCheckpointManager => JavaECM}
import org.apache.hadoop.fs.FileSystem
/**
 * Scala-facing checkpoint manager: extends the Java manager and adds entry points that
 * accept and return the Scala `Booster` wrapper.
 *
 * @param checkpointPath directory holding the checkpoint files
 * @param fs file system passed through to the Java manager
 */
class ExternalCheckpointManager(checkpointPath: String, fs: FileSystem)
  extends JavaECM(checkpointPath, fs) {

  /** Checkpoints the Java booster wrapped by `booster`. */
  def updateCheckpoint(booster: Booster): Unit =
    super.updateCheckpoint(booster.booster)

  /**
   * Loads the checkpoint via the Java manager and wraps it in a Scala `Booster`.
   *
   * @return the wrapped booster, or `null` when the Java manager returned `null`
   */
  def loadCheckpointAsScalaBooster(): Booster =
    Option(super.loadCheckpointAsBooster()) match {
      case Some(javaBooster) => new Booster(javaBooster)
      case None => null
    }
}

View File

@@ -18,14 +18,60 @@ package ml.dmlc.xgboost4j.scala
import java.io.InputStream
import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError}
import ml.dmlc.xgboost4j.java.{XGBoostError, Booster => JBooster, XGBoost => JXGBoost}
import scala.collection.JavaConverters._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
/**
* XGBoost Scala Training function.
*/
object XGBoost {
/**
 * Trains a booster through the Java API, writing external checkpoints when
 * `checkpointParams` is defined and falling back to plain training otherwise.
 *
 * @param prevBooster booster to continue training from, or `null` to train from scratch
 * @param checkpointParams checkpoint interval and path; `None` disables checkpointing
 * @return `prevBooster` itself when it is non-null, otherwise a new wrapper around the
 *         booster returned by the Java API
 */
private[scala] def trainAndSaveCheckpoint(
    dtrain: DMatrix,
    params: Map[String, Any],
    numRounds: Int,
    watches: Map[String, DMatrix] = Map(),
    metrics: Array[Array[Float]] = null,
    obj: ObjectiveTrait = null,
    eval: EvalTrait = null,
    earlyStoppingRound: Int = 0,
    prevBooster: Booster,
    checkpointParams: Option[ExternalCheckpointParams]): Booster = {
  val javaWatches = watches.mapValues(_.jDMatrix).asJava
  val javaPrevBooster = if (prevBooster != null) prevBooster.booster else null
  // we have to filter null value for customized obj and eval
  val javaParams = params
    .filter { case (_, value) => value != null }
    .mapValues(_.toString.asInstanceOf[AnyRef])
    .asJava

  val trainedInJava = checkpointParams match {
    case Some(cp) =>
      val checkpointFs = new Path(cp.checkpointPath).getFileSystem(new Configuration())
      JXGBoost.trainAndSaveCheckpoint(
        dtrain.jDMatrix, javaParams, numRounds, javaWatches, metrics, obj, eval,
        earlyStoppingRound, javaPrevBooster,
        cp.checkpointInterval, cp.checkpointPath, checkpointFs)
    case None =>
      JXGBoost.train(
        dtrain.jDMatrix, javaParams, numRounds, javaWatches, metrics, obj, eval,
        earlyStoppingRound, javaPrevBooster)
  }

  // Avoid creating a new SBooster with the same JBooster
  if (prevBooster != null) prevBooster else new Booster(trainedInJava)
}
/**
* Train a booster given parameters.
*
@@ -55,23 +101,8 @@ object XGBoost {
eval: EvalTrait = null,
earlyStoppingRound: Int = 0,
booster: Booster = null): Booster = {
val jWatches = watches.mapValues(_.jDMatrix).asJava
val jBooster = if (booster == null) {
null
} else {
booster.booster
}
val xgboostInJava = JXGBoost.train(
dtrain.jDMatrix,
// we have to filter null value for customized obj and eval
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
if (booster == null) {
new Booster(xgboostInJava)
} else {
// Avoid creating a new SBooster with the same JBooster
booster
}
trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound,
booster, None)
}
/**
@@ -126,3 +157,41 @@ object XGBoost {
new Booster(xgboostInJava)
}
}
/**
 * Checkpoint settings extracted from the training parameter map.
 *
 * @param checkpointInterval value of "checkpoint_interval"
 * @param checkpointPath value of "checkpoint_path"
 * @param skipCleanCheckpoint value of "skip_clean_checkpoint"
 */
private[scala] case class ExternalCheckpointParams(
    checkpointInterval: Int,
    checkpointPath: String,
    skipCleanCheckpoint: Boolean)

private[scala] object ExternalCheckpointParams {

  /**
   * Reads "checkpoint_path", "checkpoint_interval" and "skip_clean_checkpoint" from `params`.
   *
   * @return `None` when the path is absent, null or empty, or when the interval is absent
   *         or 0; otherwise the extracted settings
   * @throws IllegalArgumentException when a present value has the wrong type
   */
  def extractParams(params: Map[String, Any]): Option[ExternalCheckpointParams] = {
    val path: String = params.get("checkpoint_path") match {
      case Some(p: String) if p.nonEmpty => p
      case None | Some(null) | Some("") => null
      case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
        s" an instance of String, but current value is ${params("checkpoint_path")}")
    }
    val interval: Int = params.get("checkpoint_interval") match {
      case Some(freq: Int) => freq
      case None => 0
      case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
        " an instance of Int.")
    }
    val skipClean: Boolean = params.get("skip_clean_checkpoint") match {
      case Some(flag: Boolean) => flag
      case None => false
      case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
        " an instance of Boolean")
    }
    // Both a usable path and a non-zero interval are required.
    if (path != null && interval != 0) {
      Some(ExternalCheckpointParams(interval, path, skipClean))
    } else {
      None
    }
  }
}