[jvm-packages] do not use multiple jobs to make checkpoints (#5082)
* temp * temp * tep * address the comments * fix stylistic issues * fix * external checkpoint
This commit is contained in:
@@ -13,6 +13,18 @@
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.apache.hadoop</groupId>
|
||||
<artifactId>hadoop-hdfs</artifactId>
|
||||
<version>${hadoop.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.hadoop</groupId>
|
||||
<artifactId>hadoop-common</artifactId>
|
||||
<version>${hadoop.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.apache.hadoop.fs.FileSystem;
|
||||
import org.apache.hadoop.fs.Path;
|
||||
|
||||
public class ExternalCheckpointManager {
|
||||
|
||||
private Log logger = LogFactory.getLog("ExternalCheckpointManager");
|
||||
private String modelSuffix = ".model";
|
||||
private Path checkpointPath;
|
||||
private FileSystem fs;
|
||||
|
||||
public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
|
||||
if (checkpointPath == null || checkpointPath.isEmpty()) {
|
||||
throw new XGBoostError("cannot create ExternalCheckpointManager with null or" +
|
||||
" empty checkpoint path");
|
||||
}
|
||||
this.checkpointPath = new Path(checkpointPath);
|
||||
this.fs = fs;
|
||||
}
|
||||
|
||||
private String getPath(int version) {
|
||||
return checkpointPath.toUri().getPath() + "/" + version + modelSuffix;
|
||||
}
|
||||
|
||||
private List<Integer> getExistingVersions() throws IOException {
|
||||
if (!fs.exists(checkpointPath)) {
|
||||
return new ArrayList<>();
|
||||
} else {
|
||||
return Arrays.stream(fs.listStatus(checkpointPath))
|
||||
.map(path -> path.getPath().getName())
|
||||
.filter(fileName -> fileName.endsWith(modelSuffix))
|
||||
.map(fileName -> Integer.valueOf(
|
||||
fileName.substring(0, fileName.length() - modelSuffix.length())))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
|
||||
public void cleanPath() throws IOException {
|
||||
fs.delete(checkpointPath, true);
|
||||
}
|
||||
|
||||
public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
|
||||
List<Integer> versions = getExistingVersions();
|
||||
if (versions.size() > 0) {
|
||||
int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get();
|
||||
String checkpointPath = getPath(latestVersion);
|
||||
InputStream in = fs.open(new Path(checkpointPath));
|
||||
logger.info("loaded checkpoint from " + checkpointPath);
|
||||
Booster booster = XGBoost.loadModel(in);
|
||||
booster.setVersion(latestVersion);
|
||||
return booster;
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
|
||||
List<String> prevModelPaths = getExistingVersions().stream()
|
||||
.map(this::getPath).collect(Collectors.toList());
|
||||
String eventualPath = getPath(boosterToCheckpoint.getVersion());
|
||||
String tempPath = eventualPath + "-" + UUID.randomUUID();
|
||||
try (OutputStream out = fs.create(new Path(tempPath), true)) {
|
||||
boosterToCheckpoint.saveModel(out);
|
||||
fs.rename(new Path(tempPath), new Path(eventualPath));
|
||||
logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
|
||||
prevModelPaths.stream().forEach(path -> {
|
||||
try {
|
||||
fs.delete(new Path(path), true);
|
||||
} catch (IOException e) {
|
||||
logger.error("failed to delete outdated checkpoint at " + path, e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
public void cleanUpHigherVersions(int currentRound) throws IOException {
|
||||
getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
|
||||
try {
|
||||
fs.delete(new Path(getPath(v)), true);
|
||||
} catch (IOException e) {
|
||||
logger.error("failed to clean checkpoint from other training instance", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
|
||||
throws IOException {
|
||||
if (checkpointInterval > 0) {
|
||||
List<Integer> prevRounds =
|
||||
getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
|
||||
prevRounds.add(0);
|
||||
int firstCheckpointRound = prevRounds.stream()
|
||||
.max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
|
||||
List<Integer> arr = new ArrayList<>();
|
||||
for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) {
|
||||
arr.add(i);
|
||||
}
|
||||
arr.add(numOfRounds);
|
||||
return arr;
|
||||
} else if (checkpointInterval <= 0) {
|
||||
List<Integer> l = new ArrayList<Integer>();
|
||||
l.add(numOfRounds);
|
||||
return l;
|
||||
} else {
|
||||
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -15,12 +15,16 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.util.*;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.apache.hadoop.fs.FileSystem;
|
||||
import org.apache.hadoop.fs.Path;
|
||||
|
||||
/**
|
||||
* trainer for xgboost
|
||||
@@ -108,35 +112,34 @@ public class XGBoost {
|
||||
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Train a booster given parameters.
|
||||
*
|
||||
* @param dtrain Data to be trained.
|
||||
* @param params Parameters.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||
* performance on the validation set.
|
||||
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
||||
* iteration
|
||||
* @param earlyStoppingRounds if non-zero, training would be stopped
|
||||
* after a specified number of consecutive
|
||||
* goes to the unexpected direction in any evaluation metric.
|
||||
* @param obj customized objective
|
||||
* @param eval customized evaluation
|
||||
* @param booster train from scratch if set to null; train from an existing booster if not null.
|
||||
* @return The trained booster.
|
||||
*/
|
||||
public static Booster train(
|
||||
DMatrix dtrain,
|
||||
Map<String, Object> params,
|
||||
int round,
|
||||
Map<String, DMatrix> watches,
|
||||
float[][] metrics,
|
||||
IObjective obj,
|
||||
IEvaluation eval,
|
||||
int earlyStoppingRounds,
|
||||
Booster booster) throws XGBoostError {
|
||||
private static void saveCheckpoint(
|
||||
Booster booster,
|
||||
int iter,
|
||||
Set<Integer> checkpointIterations,
|
||||
ExternalCheckpointManager ecm) throws XGBoostError {
|
||||
try {
|
||||
if (checkpointIterations.contains(iter)) {
|
||||
ecm.updateCheckpoint(booster);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
logger.error("failed to save checkpoint in XGBoost4J at iteration " + iter, e);
|
||||
throw new XGBoostError("failed to save checkpoint in XGBoost4J at iteration" + iter, e);
|
||||
}
|
||||
}
|
||||
|
||||
public static Booster trainAndSaveCheckpoint(
|
||||
DMatrix dtrain,
|
||||
Map<String, Object> params,
|
||||
int numRounds,
|
||||
Map<String, DMatrix> watches,
|
||||
float[][] metrics,
|
||||
IObjective obj,
|
||||
IEvaluation eval,
|
||||
int earlyStoppingRounds,
|
||||
Booster booster,
|
||||
int checkpointInterval,
|
||||
String checkpointPath,
|
||||
FileSystem fs) throws XGBoostError, IOException {
|
||||
//collect eval matrixs
|
||||
String[] evalNames;
|
||||
DMatrix[] evalMats;
|
||||
@@ -144,6 +147,11 @@ public class XGBoost {
|
||||
int bestIteration;
|
||||
List<String> names = new ArrayList<String>();
|
||||
List<DMatrix> mats = new ArrayList<DMatrix>();
|
||||
Set<Integer> checkpointIterations = new HashSet<>();
|
||||
ExternalCheckpointManager ecm = null;
|
||||
if (checkpointPath != null) {
|
||||
ecm = new ExternalCheckpointManager(checkpointPath, fs);
|
||||
}
|
||||
|
||||
for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) {
|
||||
names.add(evalEntry.getKey());
|
||||
@@ -158,7 +166,7 @@ public class XGBoost {
|
||||
bestScore = Float.MAX_VALUE;
|
||||
}
|
||||
bestIteration = 0;
|
||||
metrics = metrics == null ? new float[evalNames.length][round] : metrics;
|
||||
metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics;
|
||||
|
||||
//collect all data matrixs
|
||||
DMatrix[] allMats;
|
||||
@@ -181,14 +189,19 @@ public class XGBoost {
|
||||
booster.setParams(params);
|
||||
}
|
||||
|
||||
//begin to train
|
||||
for (int iter = booster.getVersion() / 2; iter < round; iter++) {
|
||||
if (ecm != null) {
|
||||
checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
|
||||
}
|
||||
|
||||
// begin to train
|
||||
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
|
||||
if (booster.getVersion() % 2 == 0) {
|
||||
if (obj != null) {
|
||||
booster.update(dtrain, obj);
|
||||
} else {
|
||||
booster.update(dtrain, iter);
|
||||
}
|
||||
saveCheckpoint(booster, iter, checkpointIterations, ecm);
|
||||
booster.saveRabitCheckpoint();
|
||||
}
|
||||
|
||||
@@ -224,7 +237,7 @@ public class XGBoost {
|
||||
if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
|
||||
Rabit.trackerPrint(String.format(
|
||||
"early stopping after %d rounds away from the best iteration",
|
||||
earlyStoppingRounds));
|
||||
earlyStoppingRounds));
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -239,6 +252,44 @@ public class XGBoost {
|
||||
return booster;
|
||||
}
|
||||
|
||||
/**
|
||||
* Train a booster given parameters.
|
||||
*
|
||||
* @param dtrain Data to be trained.
|
||||
* @param params Parameters.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||
* performance on the validation set.
|
||||
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
||||
* iteration
|
||||
* @param earlyStoppingRounds if non-zero, training would be stopped
|
||||
* after a specified number of consecutive
|
||||
* goes to the unexpected direction in any evaluation metric.
|
||||
* @param obj customized objective
|
||||
* @param eval customized evaluation
|
||||
* @param booster train from scratch if set to null; train from an existing booster if not null.
|
||||
* @return The trained booster.
|
||||
*/
|
||||
public static Booster train(
|
||||
DMatrix dtrain,
|
||||
Map<String, Object> params,
|
||||
int round,
|
||||
Map<String, DMatrix> watches,
|
||||
float[][] metrics,
|
||||
IObjective obj,
|
||||
IEvaluation eval,
|
||||
int earlyStoppingRounds,
|
||||
Booster booster) throws XGBoostError {
|
||||
try {
|
||||
return trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval,
|
||||
earlyStoppingRounds, booster,
|
||||
-1, null, null);
|
||||
} catch (IOException e) {
|
||||
logger.error("training failed in xgboost4j", e);
|
||||
throw new XGBoostError("training failed in xgboost4j ", e);
|
||||
}
|
||||
}
|
||||
|
||||
private static Integer tryGetIntFromObject(Object o) {
|
||||
if (o instanceof Integer) {
|
||||
return (int)o;
|
||||
|
||||
@@ -24,4 +24,8 @@ public class XGBoostError extends Exception {
|
||||
public XGBoostError(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public XGBoostError(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{ExternalCheckpointManager => JavaECM}
|
||||
import org.apache.hadoop.fs.FileSystem
|
||||
|
||||
class ExternalCheckpointManager(checkpointPath: String, fs: FileSystem)
|
||||
extends JavaECM(checkpointPath, fs) {
|
||||
|
||||
def updateCheckpoint(booster: Booster): Unit = {
|
||||
super.updateCheckpoint(booster.booster)
|
||||
}
|
||||
|
||||
def loadCheckpointAsScalaBooster(): Booster = {
|
||||
val loadedBooster = super.loadCheckpointAsBooster()
|
||||
if (loadedBooster == null) {
|
||||
null
|
||||
} else {
|
||||
new Booster(loadedBooster)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -18,14 +18,60 @@ package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import java.io.InputStream
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError}
|
||||
import ml.dmlc.xgboost4j.java.{XGBoostError, Booster => JBooster, XGBoost => JXGBoost}
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
|
||||
/**
|
||||
* XGBoost Scala Training function.
|
||||
*/
|
||||
object XGBoost {
|
||||
|
||||
private[scala] def trainAndSaveCheckpoint(
|
||||
dtrain: DMatrix,
|
||||
params: Map[String, Any],
|
||||
numRounds: Int,
|
||||
watches: Map[String, DMatrix] = Map(),
|
||||
metrics: Array[Array[Float]] = null,
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null,
|
||||
earlyStoppingRound: Int = 0,
|
||||
prevBooster: Booster,
|
||||
checkpointParams: Option[ExternalCheckpointParams]): Booster = {
|
||||
val jWatches = watches.mapValues(_.jDMatrix).asJava
|
||||
val jBooster = if (prevBooster == null) {
|
||||
null
|
||||
} else {
|
||||
prevBooster.booster
|
||||
}
|
||||
val xgboostInJava = checkpointParams.
|
||||
map(cp => {
|
||||
JXGBoost.trainAndSaveCheckpoint(
|
||||
dtrain.jDMatrix,
|
||||
// we have to filter null value for customized obj and eval
|
||||
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
||||
numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster,
|
||||
cp.checkpointInterval,
|
||||
cp.checkpointPath,
|
||||
new Path(cp.checkpointPath).getFileSystem(new Configuration()))
|
||||
}).
|
||||
getOrElse(
|
||||
JXGBoost.train(
|
||||
dtrain.jDMatrix,
|
||||
// we have to filter null value for customized obj and eval
|
||||
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
||||
numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
|
||||
)
|
||||
if (prevBooster == null) {
|
||||
new Booster(xgboostInJava)
|
||||
} else {
|
||||
// Avoid creating a new SBooster with the same JBooster
|
||||
prevBooster
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Train a booster given parameters.
|
||||
*
|
||||
@@ -55,23 +101,8 @@ object XGBoost {
|
||||
eval: EvalTrait = null,
|
||||
earlyStoppingRound: Int = 0,
|
||||
booster: Booster = null): Booster = {
|
||||
val jWatches = watches.mapValues(_.jDMatrix).asJava
|
||||
val jBooster = if (booster == null) {
|
||||
null
|
||||
} else {
|
||||
booster.booster
|
||||
}
|
||||
val xgboostInJava = JXGBoost.train(
|
||||
dtrain.jDMatrix,
|
||||
// we have to filter null value for customized obj and eval
|
||||
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
||||
round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
|
||||
if (booster == null) {
|
||||
new Booster(xgboostInJava)
|
||||
} else {
|
||||
// Avoid creating a new SBooster with the same JBooster
|
||||
booster
|
||||
}
|
||||
trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound,
|
||||
booster, None)
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -126,3 +157,41 @@ object XGBoost {
|
||||
new Booster(xgboostInJava)
|
||||
}
|
||||
}
|
||||
|
||||
private[scala] case class ExternalCheckpointParams(
|
||||
checkpointInterval: Int,
|
||||
checkpointPath: String,
|
||||
skipCleanCheckpoint: Boolean)
|
||||
|
||||
private[scala] object ExternalCheckpointParams {
|
||||
|
||||
def extractParams(params: Map[String, Any]): Option[ExternalCheckpointParams] = {
|
||||
val checkpointPath: String = params.get("checkpoint_path") match {
|
||||
case None | Some(null) | Some("") => null
|
||||
case Some(path: String) => path
|
||||
case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
|
||||
s" an instance of String, but current value is ${params("checkpoint_path")}")
|
||||
}
|
||||
|
||||
val checkpointInterval: Int = params.get("checkpoint_interval") match {
|
||||
case None => 0
|
||||
case Some(freq: Int) => freq
|
||||
case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
|
||||
" an instance of Int.")
|
||||
}
|
||||
|
||||
val skipCleanCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
|
||||
case None => false
|
||||
case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
|
||||
case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
|
||||
" an instance of Boolean")
|
||||
}
|
||||
if (checkpointPath == null || checkpointInterval == 0) {
|
||||
None
|
||||
} else {
|
||||
Some(ExternalCheckpointParams(checkpointInterval, checkpointPath, skipCleanCheckpointFile))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user