[jvm-packages] fix checkpoint save/load (#3614)
* add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * fix update checkpoint func
This commit is contained in:
parent
57f3c2f252
commit
4912c1f9c6
@ -63,8 +63,9 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
|
||||
if (versions.nonEmpty) {
|
||||
val version = versions.max
|
||||
val fullPath = getPath(version)
|
||||
val inputStream = FileSystem.get(sc.hadoopConfiguration).open(new Path(fullPath))
|
||||
logger.info(s"Start training from previous booster at $fullPath")
|
||||
val booster = SXGBoost.loadModel(fullPath)
|
||||
val booster = SXGBoost.loadModel(inputStream)
|
||||
booster.booster.setVersion(version)
|
||||
booster
|
||||
} else {
|
||||
@ -81,8 +82,9 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
|
||||
val fs = FileSystem.get(sc.hadoopConfiguration)
|
||||
val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
|
||||
val fullPath = getPath(checkpoint.getVersion)
|
||||
val outputStream = fs.create(new Path(fullPath), true)
|
||||
logger.info(s"Saving checkpoint model with version ${checkpoint.getVersion} to $fullPath")
|
||||
checkpoint.saveModel(fullPath)
|
||||
checkpoint.saveModel(outputStream)
|
||||
prevModelPaths.foreach(path => fs.delete(path, true))
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user