Revert "[jvm-packages] update rabit, surface new changes to spark, add parity and failure tests (#4876)" (#4965)
This reverts commit 86ed01c4bbecef66e1bc4d02fb13116bd6130fae.
This commit is contained in:
parent
86ed01c4bb
commit
010b8f1428
@ -94,16 +94,36 @@ set_target_properties(dmlc PROPERTIES
|
|||||||
list(APPEND LINKED_LIBRARIES_PRIVATE dmlc)
|
list(APPEND LINKED_LIBRARIES_PRIVATE dmlc)
|
||||||
|
|
||||||
# rabit
|
# rabit
|
||||||
set(RABIT_BUILD_DMLC OFF)
|
# full rabit doesn't build on windows, so we can't import it as subdirectory
|
||||||
set(DMLC_ROOT ${xgboost_SOURCE_DIR}/dmlc-core)
|
if(MINGW OR R_LIB OR WIN32)
|
||||||
set(RABIT_WITH_R_LIB ${R_LIB})
|
set(RABIT_SOURCES
|
||||||
add_subdirectory(rabit)
|
rabit/src/engine_empty.cc
|
||||||
|
rabit/src/c_api.cc)
|
||||||
if (RABIT_MOCK)
|
|
||||||
list(APPEND LINKED_LIBRARIES_PRIVATE rabit_mock_static)
|
|
||||||
else ()
|
else ()
|
||||||
list(APPEND LINKED_LIBRARIES_PRIVATE rabit)
|
if(RABIT_MOCK)
|
||||||
|
set(RABIT_SOURCES
|
||||||
|
rabit/src/allreduce_base.cc
|
||||||
|
rabit/src/allreduce_robust.cc
|
||||||
|
rabit/src/engine_mock.cc
|
||||||
|
rabit/src/c_api.cc)
|
||||||
|
else()
|
||||||
|
set(RABIT_SOURCES
|
||||||
|
rabit/src/allreduce_base.cc
|
||||||
|
rabit/src/allreduce_robust.cc
|
||||||
|
rabit/src/engine.cc
|
||||||
|
rabit/src/c_api.cc)
|
||||||
endif(RABIT_MOCK)
|
endif(RABIT_MOCK)
|
||||||
|
endif (MINGW OR R_LIB OR WIN32)
|
||||||
|
add_library(rabit STATIC ${RABIT_SOURCES})
|
||||||
|
target_include_directories(rabit PRIVATE
|
||||||
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}/dmlc-core/include>
|
||||||
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}/rabit/include/rabit>)
|
||||||
|
set_target_properties(rabit
|
||||||
|
PROPERTIES
|
||||||
|
CXX_STANDARD 11
|
||||||
|
CXX_STANDARD_REQUIRED ON
|
||||||
|
POSITION_INDEPENDENT_CODE ON)
|
||||||
|
list(APPEND LINKED_LIBRARIES_PRIVATE rabit)
|
||||||
|
|
||||||
# Exports some R specific definitions and objects
|
# Exports some R specific definitions and objects
|
||||||
if (R_LIB)
|
if (R_LIB)
|
||||||
|
|||||||
1
jvm-packages/.gitignore
vendored
1
jvm-packages/.gitignore
vendored
@ -1,3 +1,2 @@
|
|||||||
tracker.py
|
tracker.py
|
||||||
build.sh
|
build.sh
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,7 @@ CONFIG = {
|
|||||||
"USE_HDFS": "OFF",
|
"USE_HDFS": "OFF",
|
||||||
"USE_AZURE": "OFF",
|
"USE_AZURE": "OFF",
|
||||||
"USE_S3": "OFF",
|
"USE_S3": "OFF",
|
||||||
"RABIT_MOCK": "OFF",
|
|
||||||
"USE_CUDA": "OFF",
|
"USE_CUDA": "OFF",
|
||||||
"JVM_BINDINGS": "ON"
|
"JVM_BINDINGS": "ON"
|
||||||
}
|
}
|
||||||
@ -68,7 +68,6 @@ def normpath(path):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
CONFIG["RABIT_MOCK"] = str(sys.argv[1])
|
|
||||||
if sys.platform == "darwin":
|
if sys.platform == "darwin":
|
||||||
# Enable of your compiler supports OpenMP.
|
# Enable of your compiler supports OpenMP.
|
||||||
CONFIG["USE_OPENMP"] = "OFF"
|
CONFIG["USE_OPENMP"] = "OFF"
|
||||||
|
|||||||
@ -37,7 +37,6 @@
|
|||||||
<spark.version>2.4.3</spark.version>
|
<spark.version>2.4.3</spark.version>
|
||||||
<scala.version>2.12.8</scala.version>
|
<scala.version>2.12.8</scala.version>
|
||||||
<scala.binary.version>2.12</scala.binary.version>
|
<scala.binary.version>2.12</scala.binary.version>
|
||||||
<rabit.mock>OFF</rabit.mock>
|
|
||||||
</properties>
|
</properties>
|
||||||
<repositories>
|
<repositories>
|
||||||
<repository>
|
<repository>
|
||||||
@ -53,47 +52,6 @@
|
|||||||
<module>xgboost4j-flink</module>
|
<module>xgboost4j-flink</module>
|
||||||
</modules>
|
</modules>
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
|
||||||
<id>dev</id>
|
|
||||||
<properties>
|
|
||||||
<rabit.mock>ON</rabit.mock>
|
|
||||||
</properties>
|
|
||||||
<build>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-jar-plugin</artifactId>
|
|
||||||
<version>3.0.2</version>
|
|
||||||
<executions>
|
|
||||||
<execution>
|
|
||||||
<id>empty-javadoc-jar</id>
|
|
||||||
<phase>package</phase>
|
|
||||||
<goals>
|
|
||||||
<goal>jar</goal>
|
|
||||||
</goals>
|
|
||||||
<configuration>
|
|
||||||
<classifier>javadoc</classifier>
|
|
||||||
<classesDirectory>${basedir}/javadoc</classesDirectory>
|
|
||||||
</configuration>
|
|
||||||
</execution>
|
|
||||||
</executions>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-source-plugin</artifactId>
|
|
||||||
<version>2.2.1</version>
|
|
||||||
<executions>
|
|
||||||
<execution>
|
|
||||||
<id>attach-sources</id>
|
|
||||||
<goals>
|
|
||||||
<goal>jar-no-fork</goal>
|
|
||||||
</goals>
|
|
||||||
</execution>
|
|
||||||
</executions>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</build>
|
|
||||||
</profile>
|
|
||||||
<profile>
|
<profile>
|
||||||
<id>release</id>
|
<id>release</id>
|
||||||
<build>
|
<build>
|
||||||
|
|||||||
@ -49,7 +49,7 @@ This file is divided into 3 sections:
|
|||||||
<check level="error" class="org.scalastyle.file.HeaderMatchesChecker" enabled="true">
|
<check level="error" class="org.scalastyle.file.HeaderMatchesChecker" enabled="true">
|
||||||
<parameters>
|
<parameters>
|
||||||
<parameter name="header"><![CDATA[/*
|
<parameter name="header"><![CDATA[/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -21,11 +21,11 @@ import java.nio.file.Files
|
|||||||
|
|
||||||
import scala.collection.{AbstractIterator, mutable}
|
import scala.collection.{AbstractIterator, mutable}
|
||||||
import scala.util.Random
|
import scala.util.Random
|
||||||
import scala.collection.JavaConverters._
|
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
||||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||||
import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam
|
import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam
|
||||||
|
import ml.dmlc.xgboost4j.scala.spark.XGBoost.logger
|
||||||
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
|
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
|
||||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||||
@ -155,7 +155,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
overridedParams
|
overridedParams
|
||||||
}
|
}
|
||||||
|
|
||||||
private[spark] def buildXGBRuntimeParams: XGBoostExecutionParams = {
|
def buildXGBRuntimeParams: XGBoostExecutionParams = {
|
||||||
val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
|
val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
|
||||||
val round = overridedParams("num_round").asInstanceOf[Int]
|
val round = overridedParams("num_round").asInstanceOf[Int]
|
||||||
val useExternalMemory = overridedParams("use_external_memory").asInstanceOf[Boolean]
|
val useExternalMemory = overridedParams("use_external_memory").asInstanceOf[Boolean]
|
||||||
@ -221,25 +221,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
xgbExecParam.setRawParamMap(overridedParams)
|
xgbExecParam.setRawParamMap(overridedParams)
|
||||||
xgbExecParam
|
xgbExecParam
|
||||||
}
|
}
|
||||||
|
|
||||||
private[spark] def buildRabitParams: Map[String, String] = Map(
|
|
||||||
"rabit_reduce_ring_mincount" ->
|
|
||||||
overridedParams.getOrElse("rabit_reduce_ring_mincount", 32<<10).toString,
|
|
||||||
"rabit_reduce_buffer" ->
|
|
||||||
overridedParams.getOrElse("rabit_reduce_buffer", "256MB").toString,
|
|
||||||
"rabit_bootstrap_cache" ->
|
|
||||||
overridedParams.getOrElse("rabit_bootstrap_cache", false).toString,
|
|
||||||
"rabit_debug" ->
|
|
||||||
overridedParams.getOrElse("rabit_debug", false).toString,
|
|
||||||
"rabit_timeout" ->
|
|
||||||
overridedParams.getOrElse("rabit_timeout", false).toString,
|
|
||||||
"rabit_timeout_sec" ->
|
|
||||||
overridedParams.getOrElse("rabit_timeout_sec", 1800).toString,
|
|
||||||
"DMLC_WORKER_CONNECT_RETRY" ->
|
|
||||||
overridedParams.getOrElse("dmlc_worker_connect_retry", 5).toString,
|
|
||||||
"DMLC_WORKER_STOP_PROCESS_ON_ERROR" ->
|
|
||||||
overridedParams.getOrElse("dmlc_worker_stop_process_on_error", false).toString
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -340,6 +321,7 @@ object XGBoost extends Serializable {
|
|||||||
}
|
}
|
||||||
val taskId = TaskContext.getPartitionId().toString
|
val taskId = TaskContext.getPartitionId().toString
|
||||||
rabitEnv.put("DMLC_TASK_ID", taskId)
|
rabitEnv.put("DMLC_TASK_ID", taskId)
|
||||||
|
rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
|
||||||
|
|
||||||
try {
|
try {
|
||||||
Rabit.init(rabitEnv)
|
Rabit.init(rabitEnv)
|
||||||
@ -508,9 +490,8 @@ object XGBoost extends Serializable {
|
|||||||
evalSetsMap: Map[String, RDD[XGBLabeledPoint]] = Map()):
|
evalSetsMap: Map[String, RDD[XGBLabeledPoint]] = Map()):
|
||||||
(Booster, Map[String, Array[Float]]) = {
|
(Booster, Map[String, Array[Float]]) = {
|
||||||
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
|
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
|
||||||
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)
|
val xgbExecParams = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext).
|
||||||
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
|
buildXGBRuntimeParams
|
||||||
val xgbRabitParams = xgbParamsFactory.buildRabitParams
|
|
||||||
val sc = trainingData.sparkContext
|
val sc = trainingData.sparkContext
|
||||||
val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam.
|
val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam.
|
||||||
checkpointPath)
|
checkpointPath)
|
||||||
@ -529,12 +510,12 @@ object XGBoost extends Serializable {
|
|||||||
val parallelismTracker = new SparkParallelismTracker(sc,
|
val parallelismTracker = new SparkParallelismTracker(sc,
|
||||||
xgbExecParams.timeoutRequestWorkers,
|
xgbExecParams.timeoutRequestWorkers,
|
||||||
xgbExecParams.numWorkers)
|
xgbExecParams.numWorkers)
|
||||||
val rabitEnv = tracker.getWorkerEnvs.asScala ++ xgbRabitParams
|
val rabitEnv = tracker.getWorkerEnvs
|
||||||
val boostersAndMetrics = if (hasGroup) {
|
val boostersAndMetrics = if (hasGroup) {
|
||||||
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv.asJava,
|
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv,
|
||||||
checkpointRound, prevBooster, evalSetsMap)
|
checkpointRound, prevBooster, evalSetsMap)
|
||||||
} else {
|
} else {
|
||||||
trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv.asJava,
|
trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv,
|
||||||
checkpointRound, prevBooster, evalSetsMap)
|
checkpointRound, prevBooster, evalSetsMap)
|
||||||
}
|
}
|
||||||
val sparkJobThread = new Thread() {
|
val sparkJobThread = new Thread() {
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -50,7 +50,7 @@ class XGBoostClassifier (
|
|||||||
def this(xgboostParams: Map[String, Any]) = this(
|
def this(xgboostParams: Map[String, Any]) = this(
|
||||||
Identifiable.randomUID("xgbc"), xgboostParams)
|
Identifiable.randomUID("xgbc"), xgboostParams)
|
||||||
|
|
||||||
XGBoost2MLlibParams(xgboostParams)
|
XGBoostToMLlibParams(xgboostParams)
|
||||||
|
|
||||||
def setWeightCol(value: String): this.type = set(weightCol, value)
|
def setWeightCol(value: String): this.type = set(weightCol, value)
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.scala.spark.params._
|
|||||||
import org.apache.spark.ml.param.shared.HasWeightCol
|
import org.apache.spark.ml.param.shared.HasWeightCol
|
||||||
|
|
||||||
private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
|
private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
|
||||||
with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables {
|
with BoosterParams with ParamMapFuncs with NonParamVariables {
|
||||||
|
|
||||||
def needDeterministicRepartitioning: Boolean = {
|
def needDeterministicRepartitioning: Boolean = {
|
||||||
getCheckpointPath.nonEmpty && getCheckpointInterval > 0
|
getCheckpointPath.nonEmpty && getCheckpointInterval > 0
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -54,7 +54,7 @@ class XGBoostRegressor (
|
|||||||
def this(xgboostParams: Map[String, Any]) = this(
|
def this(xgboostParams: Map[String, Any]) = this(
|
||||||
Identifiable.randomUID("xgbr"), xgboostParams)
|
Identifiable.randomUID("xgbr"), xgboostParams)
|
||||||
|
|
||||||
XGBoost2MLlibParams(xgboostParams)
|
XGBoostToMLlibParams(xgboostParams)
|
||||||
|
|
||||||
def setWeightCol(value: String): this.type = set(weightCol, value)
|
def setWeightCol(value: String): this.type = set(weightCol, value)
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -241,7 +241,7 @@ trait HasNumClass extends Params {
|
|||||||
|
|
||||||
private[spark] trait ParamMapFuncs extends Params {
|
private[spark] trait ParamMapFuncs extends Params {
|
||||||
|
|
||||||
def XGBoost2MLlibParams(xgboostParams: Map[String, Any]): Unit = {
|
def XGBoostToMLlibParams(xgboostParams: Map[String, Any]): Unit = {
|
||||||
for ((paramName, paramValue) <- xgboostParams) {
|
for ((paramName, paramValue) <- xgboostParams) {
|
||||||
if ((paramName == "booster" && paramValue != "gbtree") ||
|
if ((paramName == "booster" && paramValue != "gbtree") ||
|
||||||
(paramName == "updater" && paramValue != "grow_histmaker,prune" &&
|
(paramName == "updater" && paramValue != "grow_histmaker,prune" &&
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,62 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark.params
|
|
||||||
|
|
||||||
import org.apache.spark.ml.param._
|
|
||||||
|
|
||||||
private[spark] trait RabitParams extends Params {
|
|
||||||
/**
|
|
||||||
* Rabit worker configurations. These parameters were passed to Rabit.Init and decide
|
|
||||||
* rabit_reduce_ring_mincount - threshold of enable ring based allreduce/broadcast operations.
|
|
||||||
* rabit_reduce_buffer - buffer size to recv and run reduction
|
|
||||||
* rabit_bootstrap_cache - enable save allreduce cache before loadcheckpoint
|
|
||||||
* rabit_debug - enable more verbose rabit logging to stdout
|
|
||||||
* rabit_timeout - enable sidecar thread after rabit observed failures
|
|
||||||
* rabit_timeout_sec - wait interval before exit after rabit observed failures
|
|
||||||
* dmlc_worker_connect_retry - number of retrys to tracker
|
|
||||||
* dmlc_worker_stop_process_on_error - exit process when rabit see assert/error
|
|
||||||
*/
|
|
||||||
final val ringReduceMin = new IntParam(this, "rabitReduceRingMincount",
|
|
||||||
"minimal counts of enable allreduce/broadcast with ring based topology",
|
|
||||||
ParamValidators.gtEq(1))
|
|
||||||
|
|
||||||
final def reduceBuffer: Param[String] = new Param[String](this, "rabitReduceBuffer",
|
|
||||||
"buffer size (MB/GB) allocated to each xgb trainner recv and run reduction",
|
|
||||||
(buf: String) => buf.contains("MB") || buf.contains("GB"))
|
|
||||||
|
|
||||||
final def bootstrapCache: BooleanParam = new BooleanParam(this, "rabitBootstrapCache",
|
|
||||||
"enable save allreduce cache before loadcheckpoint, used to allow failed task retry")
|
|
||||||
|
|
||||||
final def rabitDebug: BooleanParam = new BooleanParam(this, "rabitDebug",
|
|
||||||
"enable more verbose rabit logging to stdout")
|
|
||||||
|
|
||||||
final def rabitTimeout: BooleanParam = new BooleanParam(this, "rabitTimeout",
|
|
||||||
"enable failure timeout sidecar threads")
|
|
||||||
|
|
||||||
final def timeoutInterval: IntParam = new IntParam(this, "rabitTimeoutSec",
|
|
||||||
"timeout threshold after rabit observed failures", (interval: Int) => interval > 0)
|
|
||||||
|
|
||||||
final def connectRetry: IntParam = new IntParam(this, "dmlcWorkerConnectRetry",
|
|
||||||
"number of retry worker do before fail", ParamValidators.gtEq(1))
|
|
||||||
|
|
||||||
final def exitOnError: BooleanParam = new BooleanParam(this, "dmlcWorkerStopProcessOnError",
|
|
||||||
"exit process when rabit see assert error")
|
|
||||||
|
|
||||||
setDefault(ringReduceMin -> (32 << 10), reduceBuffer -> "256MB", bootstrapCache -> false,
|
|
||||||
rabitDebug -> false, connectRetry -> 5, rabitTimeout -> false, timeoutInterval -> 1800,
|
|
||||||
exitOnError -> false)
|
|
||||||
}
|
|
||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -16,10 +16,7 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.Rabit
|
|
||||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||||
|
|
||||||
import scala.collection.JavaConverters._
|
|
||||||
import org.apache.spark.sql._
|
import org.apache.spark.sql._
|
||||||
import org.scalatest.FunSuite
|
import org.scalatest.FunSuite
|
||||||
|
|
||||||
@ -31,7 +28,7 @@ class XGBoostConfigureSuite extends FunSuite with PerTest {
|
|||||||
|
|
||||||
test("nthread configuration must be no larger than spark.task.cpus") {
|
test("nthread configuration must be no larger than spark.task.cpus") {
|
||||||
val training = buildDataFrame(Classification.train)
|
val training = buildDataFrame(Classification.train)
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic", "num_workers" -> numWorkers,
|
"objective" -> "binary:logistic", "num_workers" -> numWorkers,
|
||||||
"nthread" -> (sc.getConf.getInt("spark.task.cpus", 1) + 1))
|
"nthread" -> (sc.getConf.getInt("spark.task.cpus", 1) + 1))
|
||||||
intercept[IllegalArgumentException] {
|
intercept[IllegalArgumentException] {
|
||||||
@ -43,7 +40,7 @@ class XGBoostConfigureSuite extends FunSuite with PerTest {
|
|||||||
// TODO write an isolated test for Booster.
|
// TODO write an isolated test for Booster.
|
||||||
val training = buildDataFrame(Classification.train)
|
val training = buildDataFrame(Classification.train)
|
||||||
val testDM = new DMatrix(Classification.test.iterator, null)
|
val testDM = new DMatrix(Classification.test.iterator, null)
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
|
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
|
||||||
|
|
||||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
val model = new XGBoostClassifier(paramMap).fit(training)
|
||||||
@ -55,7 +52,7 @@ class XGBoostConfigureSuite extends FunSuite with PerTest {
|
|||||||
val originalSslConfOpt = ss.conf.getOption("spark.ssl.enabled")
|
val originalSslConfOpt = ss.conf.getOption("spark.ssl.enabled")
|
||||||
ss.conf.set("spark.ssl.enabled", true)
|
ss.conf.set("spark.ssl.enabled", true)
|
||||||
|
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic", "num_round" -> 2, "num_workers" -> numWorkers)
|
"objective" -> "binary:logistic", "num_round" -> 2, "num_workers" -> numWorkers)
|
||||||
val training = buildDataFrame(Classification.train)
|
val training = buildDataFrame(Classification.train)
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,110 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.{Rabit, XGBoostError}
|
|
||||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
|
||||||
|
|
||||||
import scala.collection.JavaConverters._
|
|
||||||
import org.apache.spark.sql._
|
|
||||||
import org.scalatest.FunSuite
|
|
||||||
|
|
||||||
class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
|
|
||||||
override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
|
|
||||||
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
|
|
||||||
.config("spark.kryo.classesToRegister", classOf[Booster].getName)
|
|
||||||
|
|
||||||
test("test parity classification prediction") {
|
|
||||||
val training = buildDataFrame(Classification.train)
|
|
||||||
val testDF = buildDataFrame(Classification.test)
|
|
||||||
|
|
||||||
val model1 = new XGBoostClassifier(Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
|
|
||||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
|
|
||||||
).fit(training)
|
|
||||||
val prediction1 = model1.transform(testDF).select("prediction").collect()
|
|
||||||
|
|
||||||
val model2 = new XGBoostClassifier(Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
|
|
||||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
|
|
||||||
"rabit_bootstrap_cache" -> true, "rabit_debug" -> true, "rabit_reduce_ring_mincount" -> 100,
|
|
||||||
"rabit_reduce_buffer" -> "2MB", "DMLC_WORKER_CONNECT_RETRY" -> 1,
|
|
||||||
"rabit_timeout" -> true, "rabit_timeout_sec" -> 5)).fit(training)
|
|
||||||
|
|
||||||
assert(Rabit.rabitEnvs.asScala.size > 7)
|
|
||||||
Rabit.rabitEnvs.asScala.foreach( item => {
|
|
||||||
if (item._1.toString == "rabit_bootstrap_cache") assert(item._2 == "true")
|
|
||||||
if (item._1.toString == "rabit_debug") assert(item._2 == "true")
|
|
||||||
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "100")
|
|
||||||
if (item._1.toString == "rabit_reduce_buffer") assert(item._2 == "2MB")
|
|
||||||
if (item._1.toString == "dmlc_worker_connect_retry") assert(item._2 == "1")
|
|
||||||
if (item._1.toString == "rabit_timeout") assert(item._2 == "true")
|
|
||||||
if (item._1.toString == "rabit_timeout_sec") assert(item._2 == "5")
|
|
||||||
})
|
|
||||||
|
|
||||||
val prediction2 = model2.transform(testDF).select("prediction").collect()
|
|
||||||
// check parity w/o rabit cache
|
|
||||||
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
|
|
||||||
assert(p1 == p2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
test("test parity regression prediction") {
|
|
||||||
val training = buildDataFrame(Regression.train)
|
|
||||||
val testDM = new DMatrix(Regression.test.iterator, null)
|
|
||||||
val testDF = buildDataFrame(Classification.test)
|
|
||||||
|
|
||||||
val model1 = new XGBoostRegressor(Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
|
|
||||||
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
|
|
||||||
).fit(training)
|
|
||||||
val prediction1 = model1.transform(testDF).select("prediction").collect()
|
|
||||||
|
|
||||||
val model2 = new XGBoostRegressor(Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
|
|
||||||
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
|
|
||||||
"rabit_bootstrap_cache" -> true, "rabit_debug" -> true, "rabit_reduce_ring_mincount" -> 100,
|
|
||||||
"rabit_reduce_buffer" -> "2MB", "DMLC_WORKER_CONNECT_RETRY" -> 1,
|
|
||||||
"rabit_timeout" -> true, "rabit_timeout_sec" -> 5)).fit(training)
|
|
||||||
assert(Rabit.rabitEnvs.asScala.size > 7)
|
|
||||||
Rabit.rabitEnvs.asScala.foreach( item => {
|
|
||||||
if (item._1.toString == "rabit_bootstrap_cache") assert(item._2 == "true")
|
|
||||||
if (item._1.toString == "rabit_debug") assert(item._2 == "true")
|
|
||||||
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "100")
|
|
||||||
if (item._1.toString == "rabit_reduce_buffer") assert(item._2 == "2MB")
|
|
||||||
if (item._1.toString == "dmlc_worker_connect_retry") assert(item._2 == "true")
|
|
||||||
if (item._1.toString == "rabit_timeout") assert(item._2 == "true")
|
|
||||||
if (item._1.toString == "rabit_timeout_sec") assert(item._2 == "5")
|
|
||||||
if (item._1.toString == "DMLC_WORKER_STOP_PROCESS_ON_ERROR") assert(item._2 == "false")
|
|
||||||
})
|
|
||||||
// check the equality of single instance prediction
|
|
||||||
val prediction2 = model2.transform(testDF).select("prediction").collect()
|
|
||||||
// check parity w/o rabit cache
|
|
||||||
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
|
|
||||||
assert(math.abs(p1 - p2) < 0.00001f)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
test("test graceful failure handle") {
|
|
||||||
val training = buildDataFrame(Classification.train)
|
|
||||||
val testDF = buildDataFrame(Classification.test)
|
|
||||||
// mock rank 0 failure during 4th allreduce synchronization
|
|
||||||
Rabit.mockList = Array("0,4,0,0").toList.asJava
|
|
||||||
intercept[XGBoostError] {
|
|
||||||
new XGBoostClassifier(Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
|
|
||||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
|
|
||||||
"rabit_timeout" -> true, "rabit_timeout_sec" -> 1,
|
|
||||||
"DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> false)).fit(training)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -66,7 +66,6 @@
|
|||||||
<executable>python</executable>
|
<executable>python</executable>
|
||||||
<arguments>
|
<arguments>
|
||||||
<argument>create_jni.py</argument>
|
<argument>create_jni.py</argument>
|
||||||
<argument>${rabit.mock}</argument>
|
|
||||||
</arguments>
|
</arguments>
|
||||||
<workingDirectory>${user.dir}</workingDirectory>
|
<workingDirectory>${user.dir}</workingDirectory>
|
||||||
</configuration>
|
</configuration>
|
||||||
|
|||||||
@ -3,8 +3,6 @@ package ml.dmlc.xgboost4j.java;
|
|||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.nio.ByteOrder;
|
import java.nio.ByteOrder;
|
||||||
import java.util.LinkedList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -53,25 +51,18 @@ public class Rabit {
|
|||||||
throw new XGBoostError(XGBoostJNI.XGBGetLastError());
|
throw new XGBoostError(XGBoostJNI.XGBGetLastError());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// used as way to test/debug passed rabit init parameters
|
|
||||||
public static Map<String, String> rabitEnvs;
|
|
||||||
public static List<String> mockList = new LinkedList<>();
|
|
||||||
/**
|
/**
|
||||||
* Initialize the rabit library on current working thread.
|
* Initialize the rabit library on current working thread.
|
||||||
* @param envs The additional environment variables to pass to rabit.
|
* @param envs The additional environment variables to pass to rabit.
|
||||||
* @throws XGBoostError
|
* @throws XGBoostError
|
||||||
*/
|
*/
|
||||||
public static void init(Map<String, String> envs) throws XGBoostError {
|
public static void init(Map<String, String> envs) throws XGBoostError {
|
||||||
rabitEnvs = envs;
|
String[] args = new String[envs.size()];
|
||||||
String[] args = new String[envs.size() + mockList.size()];
|
|
||||||
int idx = 0;
|
int idx = 0;
|
||||||
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
|
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
|
||||||
args[idx++] = e.getKey() + '=' + e.getValue();
|
args[idx++] = e.getKey() + '=' + e.getValue();
|
||||||
}
|
}
|
||||||
// pass list of rabit mock strings eg mock=0,1,0,0
|
|
||||||
for(String mock : mockList) {
|
|
||||||
args[idx++] = "mock=" + mock;
|
|
||||||
}
|
|
||||||
checkCall(XGBoostJNI.RabitInit(args));
|
checkCall(XGBoostJNI.RabitInit(args));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 - 2019 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
2
rabit
2
rabit
@ -1 +1 @@
|
|||||||
Subproject commit d22e0809a890ce0bc7af8d76c3c504b333d62f49
|
Subproject commit 9a7ac85d7eb65b1a0b904e1fa8d5a01b910adda4
|
||||||
@ -16,7 +16,7 @@ spark_version=$1
|
|||||||
rm -rf build/
|
rm -rf build/
|
||||||
cd jvm-packages
|
cd jvm-packages
|
||||||
|
|
||||||
mvn --no-transfer-progress package -Dspark.version=${spark_version} -Pdev
|
mvn --no-transfer-progress package -Dspark.version=${spark_version}
|
||||||
|
|
||||||
set +x
|
set +x
|
||||||
set +e
|
set +e
|
||||||
|
|||||||
@ -27,8 +27,8 @@ fi
|
|||||||
if [ ${TASK} == "java_test" ]; then
|
if [ ${TASK} == "java_test" ]; then
|
||||||
set -e
|
set -e
|
||||||
cd jvm-packages
|
cd jvm-packages
|
||||||
mvn -q clean install -DskipTests -Dmaven.test.skip -Pdev
|
mvn -q clean install -DskipTests -Dmaven.test.skip
|
||||||
mvn -q test -Pdev
|
mvn -q test
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ ${TASK} == "cmake_test" ]; then
|
if [ ${TASK} == "cmake_test" ]; then
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user