[breaking] [jvm-packages] Remove rabit check point. (#9599)
- Add `numBoostedRound` to jvm packages - Remove rabit checkpoint version. - Change the starting version of training continuation in JVM [breaking]. - Redefine the checkpoint version policy in jvm package. [breaking] - Rename the Python check point callback parameter. [breaking] - Unifies the checkpoint policy between Python and JVM.
This commit is contained in:
parent
7901a299b2
commit
c75a3bc0a9
@ -104,7 +104,7 @@ def check_point_callback():
|
|||||||
# Use callback class from xgboost.callback
|
# Use callback class from xgboost.callback
|
||||||
# Feel free to subclass/customize it to suit your need.
|
# Feel free to subclass/customize it to suit your need.
|
||||||
check_point = xgb.callback.TrainingCheckPoint(
|
check_point = xgb.callback.TrainingCheckPoint(
|
||||||
directory=tmpdir, iterations=rounds, name="model"
|
directory=tmpdir, interval=rounds, name="model"
|
||||||
)
|
)
|
||||||
xgb.train(
|
xgb.train(
|
||||||
{"objective": "binary:logistic"},
|
{"objective": "binary:logistic"},
|
||||||
@ -118,7 +118,7 @@ def check_point_callback():
|
|||||||
# This version of checkpoint saves everything including parameters and
|
# This version of checkpoint saves everything including parameters and
|
||||||
# model. See: doc/tutorials/saving_model.rst
|
# model. See: doc/tutorials/saving_model.rst
|
||||||
check_point = xgb.callback.TrainingCheckPoint(
|
check_point = xgb.callback.TrainingCheckPoint(
|
||||||
directory=tmpdir, iterations=rounds, as_pickle=True, name="model"
|
directory=tmpdir, interval=rounds, as_pickle=True, name="model"
|
||||||
)
|
)
|
||||||
xgb.train(
|
xgb.train(
|
||||||
{"objective": "binary:logistic"},
|
{"objective": "binary:logistic"},
|
||||||
|
|||||||
@ -1308,24 +1308,6 @@ XGB_DLL int XGBoosterSerializeToBuffer(BoosterHandle handle, bst_ulong *out_len,
|
|||||||
XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
|
XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
|
||||||
const void *buf, bst_ulong len);
|
const void *buf, bst_ulong len);
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief Initialize the booster from rabit checkpoint.
|
|
||||||
* This is used in distributed training API.
|
|
||||||
* \param handle handle
|
|
||||||
* \param version The output version of the model.
|
|
||||||
* \return 0 when success, -1 when failure happens
|
|
||||||
*/
|
|
||||||
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
|
|
||||||
int* version);
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief Save the current checkpoint to rabit.
|
|
||||||
* \param handle handle
|
|
||||||
* \return 0 when success, -1 when failure happens
|
|
||||||
*/
|
|
||||||
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle);
|
|
||||||
|
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Save XGBoost's internal configuration into a JSON document. Currently the
|
* \brief Save XGBoost's internal configuration into a JSON document. Currently the
|
||||||
* support is experimental, function signature may change in the future without
|
* support is experimental, function signature may change in the future without
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014-2022 by Contributors
|
Copyright (c) 2014-2023 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -32,57 +32,53 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
|
|||||||
}
|
}
|
||||||
|
|
||||||
private def createNewModels():
|
private def createNewModels():
|
||||||
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
|
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
|
||||||
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
|
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
|
||||||
val (model4, model8) = {
|
val (model2, model4) = {
|
||||||
val training = buildDataFrame(Classification.train)
|
val training = buildDataFrame(Classification.train)
|
||||||
val paramMap = produceParamMap(tmpPath, 2)
|
val paramMap = produceParamMap(tmpPath, 2)
|
||||||
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
|
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
|
||||||
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
|
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
|
||||||
}
|
}
|
||||||
(tmpPath, model4, model8)
|
(tmpPath, model2, model4)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("test update/load models") {
|
test("test update/load models") {
|
||||||
val (tmpPath, model4, model8) = createNewModels()
|
val (tmpPath, model2, model4) = createNewModels()
|
||||||
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||||
|
|
||||||
manager.updateCheckpoint(model4._booster.booster)
|
manager.updateCheckpoint(model2._booster.booster)
|
||||||
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||||
assert(files.length == 1)
|
assert(files.length == 1)
|
||||||
assert(files.head.getPath.getName == "4.model")
|
assert(files.head.getPath.getName == "1.model")
|
||||||
assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)
|
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2)
|
||||||
|
|
||||||
manager.updateCheckpoint(model8._booster)
|
manager.updateCheckpoint(model4._booster)
|
||||||
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||||
assert(files.length == 1)
|
assert(files.length == 1)
|
||||||
assert(files.head.getPath.getName == "8.model")
|
assert(files.head.getPath.getName == "3.model")
|
||||||
assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
|
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("test cleanUpHigherVersions") {
|
test("test cleanUpHigherVersions") {
|
||||||
val (tmpPath, model4, model8) = createNewModels()
|
val (tmpPath, model2, model4) = createNewModels()
|
||||||
|
|
||||||
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||||
manager.updateCheckpoint(model8._booster)
|
manager.updateCheckpoint(model4._booster)
|
||||||
manager.cleanUpHigherVersions(8)
|
manager.cleanUpHigherVersions(3)
|
||||||
assert(new File(s"$tmpPath/8.model").exists())
|
assert(new File(s"$tmpPath/3.model").exists())
|
||||||
|
|
||||||
manager.cleanUpHigherVersions(4)
|
manager.cleanUpHigherVersions(2)
|
||||||
assert(!new File(s"$tmpPath/8.model").exists())
|
assert(!new File(s"$tmpPath/3.model").exists())
|
||||||
}
|
}
|
||||||
|
|
||||||
test("test checkpoint rounds") {
|
test("test checkpoint rounds") {
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
val (tmpPath, model4, model8) = createNewModels()
|
val (tmpPath, model2, model4) = createNewModels()
|
||||||
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||||
assertResult(Seq(7))(
|
assertResult(Seq(2))(manager.getCheckpointRounds(0, 0, 3).asScala)
|
||||||
manager.getCheckpointRounds(0, 7).asScala)
|
assertResult(Seq(0, 2, 4, 6))(manager.getCheckpointRounds(0, 2, 7).asScala)
|
||||||
assertResult(Seq(2, 4, 6, 7))(
|
assertResult(Seq(0, 2, 4, 6, 7))(manager.getCheckpointRounds(0, 2, 8).asScala)
|
||||||
manager.getCheckpointRounds(2, 7).asScala)
|
|
||||||
manager.updateCheckpoint(model4._booster)
|
|
||||||
assertResult(Seq(4, 6, 7))(
|
|
||||||
manager.getCheckpointRounds(2, 7).asScala)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -109,8 +105,8 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
|
|||||||
// Check only one model is kept after training
|
// Check only one model is kept after training
|
||||||
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||||
assert(files.length == 1)
|
assert(files.length == 1)
|
||||||
assert(files.head.getPath.getName == "8.model")
|
assert(files.head.getPath.getName == "4.model")
|
||||||
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
|
val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.model")
|
||||||
// Train next model based on prev model
|
// Train next model based on prev model
|
||||||
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
|
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
|
||||||
assert(error(tmpModel) >= error(prevModel._booster))
|
assert(error(tmpModel) >= error(prevModel._booster))
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014-2022 by Contributors
|
Copyright (c) 2014-2023 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -787,35 +787,6 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
return importanceMap;
|
return importanceMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Save the model as byte array representation.
|
|
||||||
* Write these bytes to a file will give compatible format with other xgboost bindings.
|
|
||||||
*
|
|
||||||
* If java natively support HDFS file API, use toByteArray and write the ByteArray
|
|
||||||
*
|
|
||||||
* @param withStats Controls whether the split statistics are output.
|
|
||||||
* @return dumped model information
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
private String[] getDumpInfo(boolean withStats) throws XGBoostError {
|
|
||||||
int statsFlag = 0;
|
|
||||||
if (withStats) {
|
|
||||||
statsFlag = 1;
|
|
||||||
}
|
|
||||||
String[][] modelInfos = new String[1][];
|
|
||||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelEx(handle, "", statsFlag, "text",
|
|
||||||
modelInfos));
|
|
||||||
return modelInfos[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
public int getVersion() {
|
|
||||||
return this.version;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setVersion(int version) {
|
|
||||||
this.version = version;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Save model into raw byte array. Currently it's using the deprecated format as
|
* Save model into raw byte array. Currently it's using the deprecated format as
|
||||||
* default, which will be changed into `ubj` in future releases.
|
* default, which will be changed into `ubj` in future releases.
|
||||||
@ -841,29 +812,6 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
return bytes[0];
|
return bytes[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Load the booster model from thread-local rabit checkpoint.
|
|
||||||
* This is only used in distributed training.
|
|
||||||
* @return the stored version number of the checkpoint.
|
|
||||||
* @throws XGBoostError
|
|
||||||
*/
|
|
||||||
int loadRabitCheckpoint() throws XGBoostError {
|
|
||||||
int[] out = new int[1];
|
|
||||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
|
|
||||||
version = out[0];
|
|
||||||
return version;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Save the booster model into thread-local rabit checkpoint and increment the version.
|
|
||||||
* This is only used in distributed training.
|
|
||||||
* @throws XGBoostError
|
|
||||||
*/
|
|
||||||
void saveRabitCheckpoint() throws XGBoostError {
|
|
||||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
|
||||||
version += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get number of model features.
|
* Get number of model features.
|
||||||
* @return the number of features.
|
* @return the number of features.
|
||||||
@ -874,6 +822,11 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature));
|
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature));
|
||||||
return numFeature[0];
|
return numFeature[0];
|
||||||
}
|
}
|
||||||
|
public int getNumBoostedRound() throws XGBoostError {
|
||||||
|
int[] numRound = new int[1];
|
||||||
|
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumBoostedRound(this.handle, numRound));
|
||||||
|
return numRound[0];
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Internal initialization function.
|
* Internal initialization function.
|
||||||
|
|||||||
@ -1,3 +1,18 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014-2023 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
@ -15,7 +30,7 @@ public class ExternalCheckpointManager {
|
|||||||
|
|
||||||
private Log logger = LogFactory.getLog("ExternalCheckpointManager");
|
private Log logger = LogFactory.getLog("ExternalCheckpointManager");
|
||||||
private String modelSuffix = ".model";
|
private String modelSuffix = ".model";
|
||||||
private Path checkpointPath;
|
private Path checkpointPath; // directory for checkpoints
|
||||||
private FileSystem fs;
|
private FileSystem fs;
|
||||||
|
|
||||||
public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
|
public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
|
||||||
@ -35,6 +50,7 @@ public class ExternalCheckpointManager {
|
|||||||
if (!fs.exists(checkpointPath)) {
|
if (!fs.exists(checkpointPath)) {
|
||||||
return new ArrayList<>();
|
return new ArrayList<>();
|
||||||
} else {
|
} else {
|
||||||
|
// Get integer versions from a list of checkpoint files.
|
||||||
return Arrays.stream(fs.listStatus(checkpointPath))
|
return Arrays.stream(fs.listStatus(checkpointPath))
|
||||||
.map(path -> path.getPath().getName())
|
.map(path -> path.getPath().getName())
|
||||||
.filter(fileName -> fileName.endsWith(modelSuffix))
|
.filter(fileName -> fileName.endsWith(modelSuffix))
|
||||||
@ -44,6 +60,11 @@ public class ExternalCheckpointManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Integer latest(List<Integer> versions) {
|
||||||
|
return versions.stream()
|
||||||
|
.max(Comparator.comparing(Integer::valueOf)).get();
|
||||||
|
}
|
||||||
|
|
||||||
public void cleanPath() throws IOException {
|
public void cleanPath() throws IOException {
|
||||||
fs.delete(checkpointPath, true);
|
fs.delete(checkpointPath, true);
|
||||||
}
|
}
|
||||||
@ -51,12 +72,11 @@ public class ExternalCheckpointManager {
|
|||||||
public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
|
public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
|
||||||
List<Integer> versions = getExistingVersions();
|
List<Integer> versions = getExistingVersions();
|
||||||
if (versions.size() > 0) {
|
if (versions.size() > 0) {
|
||||||
int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get();
|
int latestVersion = this.latest(versions);
|
||||||
String checkpointPath = getPath(latestVersion);
|
String checkpointPath = getPath(latestVersion);
|
||||||
InputStream in = fs.open(new Path(checkpointPath));
|
InputStream in = fs.open(new Path(checkpointPath));
|
||||||
logger.info("loaded checkpoint from " + checkpointPath);
|
logger.info("loaded checkpoint from " + checkpointPath);
|
||||||
Booster booster = XGBoost.loadModel(in);
|
Booster booster = XGBoost.loadModel(in);
|
||||||
booster.setVersion(latestVersion);
|
|
||||||
return booster;
|
return booster;
|
||||||
} else {
|
} else {
|
||||||
return null;
|
return null;
|
||||||
@ -65,13 +85,16 @@ public class ExternalCheckpointManager {
|
|||||||
|
|
||||||
public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
|
public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
|
||||||
List<String> prevModelPaths = getExistingVersions().stream()
|
List<String> prevModelPaths = getExistingVersions().stream()
|
||||||
.map(this::getPath).collect(Collectors.toList());
|
.map(this::getPath).collect(Collectors.toList());
|
||||||
String eventualPath = getPath(boosterToCheckpoint.getVersion());
|
// checkpointing is done after update, so n_rounds - 1 is the current iteration
|
||||||
|
// accounting for training continuation.
|
||||||
|
Integer iter = boosterToCheckpoint.getNumBoostedRound() - 1;
|
||||||
|
String eventualPath = getPath(iter);
|
||||||
String tempPath = eventualPath + "-" + UUID.randomUUID();
|
String tempPath = eventualPath + "-" + UUID.randomUUID();
|
||||||
try (OutputStream out = fs.create(new Path(tempPath), true)) {
|
try (OutputStream out = fs.create(new Path(tempPath), true)) {
|
||||||
boosterToCheckpoint.saveModel(out);
|
boosterToCheckpoint.saveModel(out);
|
||||||
fs.rename(new Path(tempPath), new Path(eventualPath));
|
fs.rename(new Path(tempPath), new Path(eventualPath));
|
||||||
logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
|
logger.info("saving checkpoint with version " + iter);
|
||||||
prevModelPaths.stream().forEach(path -> {
|
prevModelPaths.stream().forEach(path -> {
|
||||||
try {
|
try {
|
||||||
fs.delete(new Path(path), true);
|
fs.delete(new Path(path), true);
|
||||||
@ -83,7 +106,7 @@ public class ExternalCheckpointManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public void cleanUpHigherVersions(int currentRound) throws IOException {
|
public void cleanUpHigherVersions(int currentRound) throws IOException {
|
||||||
getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
|
getExistingVersions().stream().filter(v -> v > currentRound).forEach(v -> {
|
||||||
try {
|
try {
|
||||||
fs.delete(new Path(getPath(v)), true);
|
fs.delete(new Path(getPath(v)), true);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
@ -91,27 +114,26 @@ public class ExternalCheckpointManager {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
// Get a list of iterations that need checkpointing.
|
||||||
public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
|
public List<Integer> getCheckpointRounds(
|
||||||
|
int firstRound, int checkpointInterval, int numOfRounds)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
int end = firstRound + numOfRounds; // exclusive
|
||||||
|
int lastRound = end - 1;
|
||||||
|
if (end - 1 < 0) {
|
||||||
|
throw new IllegalArgumentException("Inavlid `numOfRounds`.");
|
||||||
|
}
|
||||||
|
|
||||||
|
List<Integer> arr = new ArrayList<>();
|
||||||
if (checkpointInterval > 0) {
|
if (checkpointInterval > 0) {
|
||||||
List<Integer> prevRounds =
|
for (int i = firstRound; i < end; i += checkpointInterval) {
|
||||||
getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
|
|
||||||
prevRounds.add(0);
|
|
||||||
int firstCheckpointRound = prevRounds.stream()
|
|
||||||
.max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
|
|
||||||
List<Integer> arr = new ArrayList<>();
|
|
||||||
for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) {
|
|
||||||
arr.add(i);
|
arr.add(i);
|
||||||
}
|
}
|
||||||
arr.add(numOfRounds);
|
|
||||||
return arr;
|
|
||||||
} else if (checkpointInterval <= 0) {
|
|
||||||
List<Integer> l = new ArrayList<Integer>();
|
|
||||||
l.add(numOfRounds);
|
|
||||||
return l;
|
|
||||||
} else {
|
|
||||||
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!arr.contains(lastRound)) {
|
||||||
|
arr.add(lastRound);
|
||||||
|
}
|
||||||
|
return arr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014,2021 by Contributors
|
Copyright (c) 2014-2023 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -133,7 +133,7 @@ public class XGBoost {
|
|||||||
int earlyStoppingRound) throws XGBoostError {
|
int earlyStoppingRound) throws XGBoostError {
|
||||||
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
|
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
|
||||||
}
|
}
|
||||||
|
// save checkpoint if iter is in checkpointIterations
|
||||||
private static void saveCheckpoint(
|
private static void saveCheckpoint(
|
||||||
Booster booster,
|
Booster booster,
|
||||||
int iter,
|
int iter,
|
||||||
@ -169,7 +169,6 @@ public class XGBoost {
|
|||||||
int bestIteration;
|
int bestIteration;
|
||||||
List<String> names = new ArrayList<String>();
|
List<String> names = new ArrayList<String>();
|
||||||
List<DMatrix> mats = new ArrayList<DMatrix>();
|
List<DMatrix> mats = new ArrayList<DMatrix>();
|
||||||
Set<Integer> checkpointIterations = new HashSet<>();
|
|
||||||
ExternalCheckpointManager ecm = null;
|
ExternalCheckpointManager ecm = null;
|
||||||
if (checkpointPath != null) {
|
if (checkpointPath != null) {
|
||||||
ecm = new ExternalCheckpointManager(checkpointPath, fs);
|
ecm = new ExternalCheckpointManager(checkpointPath, fs);
|
||||||
@ -203,32 +202,30 @@ public class XGBoost {
|
|||||||
booster = new Booster(params, allMats);
|
booster = new Booster(params, allMats);
|
||||||
booster.setFeatureNames(dtrain.getFeatureNames());
|
booster.setFeatureNames(dtrain.getFeatureNames());
|
||||||
booster.setFeatureTypes(dtrain.getFeatureTypes());
|
booster.setFeatureTypes(dtrain.getFeatureTypes());
|
||||||
booster.loadRabitCheckpoint();
|
|
||||||
} else {
|
} else {
|
||||||
// Start training on an existing booster
|
// Start training on an existing booster
|
||||||
booster.setParams(params);
|
booster.setParams(params);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Set<Integer> checkpointIterations = new HashSet<>();
|
||||||
if (ecm != null) {
|
if (ecm != null) {
|
||||||
checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
|
checkpointIterations = new HashSet<>(
|
||||||
|
ecm.getCheckpointRounds(booster.getNumBoostedRound(), checkpointInterval, numRounds));
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean initial_best_score_flag = false;
|
boolean initial_best_score_flag = false;
|
||||||
boolean max_direction = false;
|
boolean max_direction = false;
|
||||||
|
|
||||||
// begin to train
|
// begin to train
|
||||||
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
|
for (int iter = 0; iter < numRounds; iter++) {
|
||||||
if (booster.getVersion() % 2 == 0) {
|
if (obj != null) {
|
||||||
if (obj != null) {
|
booster.update(dtrain, iter, obj);
|
||||||
booster.update(dtrain, obj);
|
} else {
|
||||||
} else {
|
booster.update(dtrain, iter);
|
||||||
booster.update(dtrain, iter);
|
|
||||||
}
|
|
||||||
saveCheckpoint(booster, iter, checkpointIterations, ecm);
|
|
||||||
booster.saveRabitCheckpoint();
|
|
||||||
}
|
}
|
||||||
|
saveCheckpoint(booster, iter, checkpointIterations, ecm);
|
||||||
|
|
||||||
//evaluation
|
// evaluation
|
||||||
if (evalMats.length > 0) {
|
if (evalMats.length > 0) {
|
||||||
float[] metricsOut = new float[evalMats.length];
|
float[] metricsOut = new float[evalMats.length];
|
||||||
String evalInfo;
|
String evalInfo;
|
||||||
@ -285,7 +282,6 @@ public class XGBoost {
|
|||||||
Communicator.communicatorPrint(evalInfo + '\n');
|
Communicator.communicatorPrint(evalInfo + '\n');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
booster.saveRabitCheckpoint();
|
|
||||||
}
|
}
|
||||||
return booster;
|
return booster;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -140,10 +140,11 @@ class XGBoostJNI {
|
|||||||
public final static native int XGBoosterGetAttrNames(long handle, String[][] out_strings);
|
public final static native int XGBoosterGetAttrNames(long handle, String[][] out_strings);
|
||||||
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
|
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
|
||||||
public final static native int XGBoosterSetAttr(long handle, String key, String value);
|
public final static native int XGBoosterSetAttr(long handle, String key, String value);
|
||||||
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
|
|
||||||
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
|
|
||||||
public final static native int XGBoosterGetNumFeature(long handle, long[] feature);
|
public final static native int XGBoosterGetNumFeature(long handle, long[] feature);
|
||||||
|
|
||||||
|
public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds);
|
||||||
|
|
||||||
// communicator functions
|
// communicator functions
|
||||||
public final static native int CommunicatorInit(String[] args);
|
public final static native int CommunicatorInit(String[] args);
|
||||||
public final static native int CommunicatorFinalize();
|
public final static native int CommunicatorFinalize();
|
||||||
|
|||||||
@ -326,7 +326,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
|||||||
@throws(classOf[XGBoostError])
|
@throws(classOf[XGBoostError])
|
||||||
def getNumFeature: Long = booster.getNumFeature
|
def getNumFeature: Long = booster.getNumFeature
|
||||||
|
|
||||||
def getVersion: Int = booster.getVersion
|
def getNumBoostedRound: Long = booster.getNumBoostedRound
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Save model into a raw byte array. Available options are "json", "ubj" and "deprecated".
|
* Save model into a raw byte array. Available options are "json", "ubj" and "deprecated".
|
||||||
|
|||||||
@ -984,33 +984,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
|
||||||
* Method: XGBoosterLoadRabitCheckpoint
|
|
||||||
* Signature: (J[I)I
|
|
||||||
*/
|
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint
|
|
||||||
(JNIEnv *jenv , jclass jcls, jlong jhandle, jintArray jout) {
|
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
|
||||||
int version;
|
|
||||||
int ret = XGBoosterLoadRabitCheckpoint(handle, &version);
|
|
||||||
JVM_CHECK_CALL(ret);
|
|
||||||
jint jversion = version;
|
|
||||||
jenv->SetIntArrayRegion(jout, 0, 1, &jversion);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
|
||||||
* Method: XGBoosterSaveRabitCheckpoint
|
|
||||||
* Signature: (J)I
|
|
||||||
*/
|
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
|
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
|
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
|
||||||
return XGBoosterSaveRabitCheckpoint(handle);
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: XGBoosterGetNumFeature
|
* Method: XGBoosterGetNumFeature
|
||||||
@ -1027,6 +1000,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFea
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound(
|
||||||
|
JNIEnv *jenv, jclass, jlong jhandle, jintArray jout) {
|
||||||
|
BoosterHandle handle = (BoosterHandle)jhandle;
|
||||||
|
std::int32_t n_rounds{0};
|
||||||
|
auto ret = XGBoosterBoostedRounds(handle, &n_rounds);
|
||||||
|
JVM_CHECK_CALL(ret);
|
||||||
|
jint jn_rounds = n_rounds;
|
||||||
|
jenv->SetIntArrayRegion(jout, 0, 1, &jn_rounds);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: CommunicatorInit
|
* Method: CommunicatorInit
|
||||||
|
|||||||
@ -287,22 +287,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr
|
|||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
|
||||||
(JNIEnv *, jclass, jlong, jstring, jstring);
|
(JNIEnv *, jclass, jlong, jstring, jstring);
|
||||||
|
|
||||||
/*
|
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
|
||||||
* Method: XGBoosterLoadRabitCheckpoint
|
|
||||||
* Signature: (J[I)I
|
|
||||||
*/
|
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint
|
|
||||||
(JNIEnv *, jclass, jlong, jintArray);
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
|
||||||
* Method: XGBoosterSaveRabitCheckpoint
|
|
||||||
* Signature: (J)I
|
|
||||||
*/
|
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
|
|
||||||
(JNIEnv *, jclass, jlong);
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: XGBoosterGetNumFeature
|
* Method: XGBoosterGetNumFeature
|
||||||
@ -311,6 +295,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabit
|
|||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
|
||||||
(JNIEnv *, jclass, jlong, jlongArray);
|
(JNIEnv *, jclass, jlong, jlongArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
* Method: XGBoosterGetNumBoostedRound
|
||||||
|
* Signature: (J[I)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound
|
||||||
|
(JNIEnv *, jclass, jlong, jintArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: CommunicatorInit
|
* Method: CommunicatorInit
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014-2022 by Contributors
|
Copyright (c) 2014-2023 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -16,7 +16,6 @@
|
|||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
import junit.framework.TestCase;
|
import junit.framework.TestCase;
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import java.io.ByteArrayInputStream;
|
import java.io.ByteArrayInputStream;
|
||||||
@ -845,14 +844,12 @@ public class BoosterImplTest {
|
|||||||
float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat);
|
float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat);
|
||||||
|
|
||||||
// Save tempBooster to bytestream and load back
|
// Save tempBooster to bytestream and load back
|
||||||
int prevVersion = tempBooster.getVersion();
|
|
||||||
ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray());
|
ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray());
|
||||||
tempBooster = XGBoost.loadModel(in);
|
tempBooster = XGBoost.loadModel(in);
|
||||||
in.close();
|
in.close();
|
||||||
tempBooster.setVersion(prevVersion);
|
|
||||||
|
|
||||||
// Continue training using tempBooster
|
// Continue training using tempBooster
|
||||||
round = 4;
|
round = 2;
|
||||||
Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster);
|
Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster);
|
||||||
float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat);
|
float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat);
|
||||||
TestCase.assertTrue(booster1error == booster2error);
|
TestCase.assertTrue(booster1error == booster2error);
|
||||||
|
|||||||
@ -540,7 +540,10 @@ class EvaluationMonitor(TrainingCallback):
|
|||||||
|
|
||||||
|
|
||||||
class TrainingCheckPoint(TrainingCallback):
|
class TrainingCheckPoint(TrainingCallback):
|
||||||
"""Checkpointing operation.
|
"""Checkpointing operation. Users are encouraged to create their own callbacks for
|
||||||
|
checkpoint as XGBoost doesn't handle distributed file systems. When checkpointing on
|
||||||
|
distributed systems, be sure to know the rank of the worker to avoid multiple
|
||||||
|
workers checkpointing to the same place.
|
||||||
|
|
||||||
.. versionadded:: 1.3.0
|
.. versionadded:: 1.3.0
|
||||||
|
|
||||||
@ -553,9 +556,9 @@ class TrainingCheckPoint(TrainingCallback):
|
|||||||
pattern of output model file. Models will be saved as name_0.json, name_1.json,
|
pattern of output model file. Models will be saved as name_0.json, name_1.json,
|
||||||
name_2.json ....
|
name_2.json ....
|
||||||
as_pickle :
|
as_pickle :
|
||||||
When set to True, all training parameters will be saved in pickle format, instead
|
When set to True, all training parameters will be saved in pickle format,
|
||||||
of saving only the model.
|
instead of saving only the model.
|
||||||
iterations :
|
interval :
|
||||||
Interval of checkpointing. Checkpointing is slow so setting a larger number can
|
Interval of checkpointing. Checkpointing is slow so setting a larger number can
|
||||||
reduce performance hit.
|
reduce performance hit.
|
||||||
|
|
||||||
@ -566,15 +569,20 @@ class TrainingCheckPoint(TrainingCallback):
|
|||||||
directory: Union[str, os.PathLike],
|
directory: Union[str, os.PathLike],
|
||||||
name: str = "model",
|
name: str = "model",
|
||||||
as_pickle: bool = False,
|
as_pickle: bool = False,
|
||||||
iterations: int = 100,
|
interval: int = 100,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._path = os.fspath(directory)
|
self._path = os.fspath(directory)
|
||||||
self._name = name
|
self._name = name
|
||||||
self._as_pickle = as_pickle
|
self._as_pickle = as_pickle
|
||||||
self._iterations = iterations
|
self._iterations = interval
|
||||||
self._epoch = 0
|
self._epoch = 0 # counter for iterval
|
||||||
|
self._start = 0 # beginning iteration
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
def before_training(self, model: _Model) -> _Model:
|
||||||
|
self._start = model.num_boosted_rounds()
|
||||||
|
return model
|
||||||
|
|
||||||
def after_iteration(
|
def after_iteration(
|
||||||
self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
|
self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
|
||||||
) -> bool:
|
) -> bool:
|
||||||
@ -583,11 +591,12 @@ class TrainingCheckPoint(TrainingCallback):
|
|||||||
self._path,
|
self._path,
|
||||||
self._name
|
self._name
|
||||||
+ "_"
|
+ "_"
|
||||||
+ str(epoch)
|
+ (str(epoch + self._start))
|
||||||
+ (".pkl" if self._as_pickle else ".json"),
|
+ (".pkl" if self._as_pickle else ".json"),
|
||||||
)
|
)
|
||||||
self._epoch = 0
|
self._epoch = 0 # reset counter
|
||||||
if collective.get_rank() == 0:
|
if collective.get_rank() == 0:
|
||||||
|
# checkpoint using the first worker
|
||||||
if self._as_pickle:
|
if self._as_pickle:
|
||||||
with open(path, "wb") as fd:
|
with open(path, "wb") as fd:
|
||||||
pickle.dump(model, fd)
|
pickle.dump(model, fd)
|
||||||
|
|||||||
@ -1430,36 +1430,13 @@ XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
|
|||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
|
XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer, int end_layer, int step,
|
||||||
int* version) {
|
|
||||||
API_BEGIN();
|
|
||||||
CHECK_HANDLE();
|
|
||||||
auto* bst = static_cast<Learner*>(handle);
|
|
||||||
xgboost_CHECK_C_ARG_PTR(version);
|
|
||||||
*version = rabit::LoadCheckPoint();
|
|
||||||
if (*version != 0) {
|
|
||||||
bst->Configure();
|
|
||||||
}
|
|
||||||
API_END();
|
|
||||||
}
|
|
||||||
|
|
||||||
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
|
|
||||||
API_BEGIN();
|
|
||||||
CHECK_HANDLE();
|
|
||||||
auto *learner = static_cast<Learner *>(handle);
|
|
||||||
learner->Configure();
|
|
||||||
rabit::CheckPoint();
|
|
||||||
API_END();
|
|
||||||
}
|
|
||||||
|
|
||||||
XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
|
|
||||||
int end_layer, int step,
|
|
||||||
BoosterHandle *out) {
|
BoosterHandle *out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
xgboost_CHECK_C_ARG_PTR(out);
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
|
|
||||||
auto* learner = static_cast<Learner*>(handle);
|
auto *learner = static_cast<Learner *>(handle);
|
||||||
bool out_of_bound = false;
|
bool out_of_bound = false;
|
||||||
auto p_out = learner->Slice(begin_layer, end_layer, step, &out_of_bound);
|
auto p_out = learner->Slice(begin_layer, end_layer, step, &out_of_bound);
|
||||||
if (out_of_bound) {
|
if (out_of_bound) {
|
||||||
|
|||||||
@ -443,7 +443,7 @@ class TestCallbacks:
|
|||||||
m = xgb.DMatrix(X, y)
|
m = xgb.DMatrix(X, y)
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
check_point = xgb.callback.TrainingCheckPoint(
|
check_point = xgb.callback.TrainingCheckPoint(
|
||||||
directory=tmpdir, iterations=1, name="model"
|
directory=tmpdir, interval=1, name="model"
|
||||||
)
|
)
|
||||||
xgb.train(
|
xgb.train(
|
||||||
{"objective": "binary:logistic"},
|
{"objective": "binary:logistic"},
|
||||||
@ -456,7 +456,7 @@ class TestCallbacks:
|
|||||||
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json"))
|
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json"))
|
||||||
|
|
||||||
check_point = xgb.callback.TrainingCheckPoint(
|
check_point = xgb.callback.TrainingCheckPoint(
|
||||||
directory=tmpdir, iterations=1, as_pickle=True, name="model"
|
directory=tmpdir, interval=1, as_pickle=True, name="model"
|
||||||
)
|
)
|
||||||
xgb.train(
|
xgb.train(
|
||||||
{"objective": "binary:logistic"},
|
{"objective": "binary:logistic"},
|
||||||
|
|||||||
@ -2238,7 +2238,7 @@ class TestDaskCallbacks:
|
|||||||
y,
|
y,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
xgb.callback.TrainingCheckPoint(
|
xgb.callback.TrainingCheckPoint(
|
||||||
directory=Path(tmpdir), iterations=1, name="model"
|
directory=Path(tmpdir), interval=1, name="model"
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user