Remove RABIT CMake targets. (#6275)
* Now it's built as part of libxgboost. * Set correct C API error in RABIT initialization and finalization. * Remove redundant message. * Guard the tracker print C API.
This commit is contained in:
@@ -2,71 +2,16 @@ cmake_minimum_required(VERSION 3.3)
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
add_library(rabit src/allreduce_base.cc src/engine.cc src/c_api.cc)
|
||||
add_library(rabit_mock_static src/allreduce_base.cc src/engine_mock.cc src/c_api.cc)
|
||||
set(RABIT_SOURCES
|
||||
${CMAKE_CURRENT_LIST_DIR}/src/allreduce_base.cc
|
||||
${CMAKE_CURRENT_LIST_DIR}/src/c_api.cc)
|
||||
|
||||
target_link_libraries(rabit Threads::Threads dmlc)
|
||||
target_link_libraries(rabit_mock_static Threads::Threads dmlc)
|
||||
if (RABIT_BUILD_MPI)
|
||||
list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mpi.cc)
|
||||
elseif (RABIT_MOCK)
|
||||
list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mock.cc)
|
||||
else ()
|
||||
list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine.cc)
|
||||
endif ()
|
||||
|
||||
set(rabit_libs rabit rabit_mock_static)
|
||||
set_target_properties(rabit rabit_mock_static
|
||||
PROPERTIES CXX_STANDARD 14
|
||||
CXX_STANDARD_REQUIRED ON
|
||||
POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
if(RABIT_BUILD_MPI)
|
||||
find_package(MPI REQUIRED)
|
||||
if (NOT MPI_CXX_FOUND)
|
||||
message(FATAL_ERROR "CXX Interface for MPI is required for building MPI backend.")
|
||||
endif (NOT MPI_CXX_FOUND)
|
||||
add_library(rabit_mpi src/engine_mpi.cc ${MPI_INCLUDE_PATH})
|
||||
target_link_libraries(rabit_mpi ${MPI_CXX_LIBRARIES})
|
||||
list(APPEND rabit_libs rabit_mpi)
|
||||
endif()
|
||||
|
||||
# place binaries and libraries according to GNU standards
|
||||
include(GNUInstallDirs)
|
||||
|
||||
# we use this to get code coverage
|
||||
if ((CMAKE_CONFIGURATION_TYPES STREQUAL "Debug") AND (CMAKE_CXX_COMPILER_ID MATCHES GNU))
|
||||
foreach(lib ${rabit_libs})
|
||||
target_compile_options(${lib}
|
||||
-fprofile-arcs
|
||||
-ftest-coverage)
|
||||
endforeach()
|
||||
endif((CMAKE_CONFIGURATION_TYPES STREQUAL "Debug") AND (CMAKE_CXX_COMPILER_ID MATCHES GNU))
|
||||
|
||||
foreach(lib ${rabit_libs})
|
||||
target_include_directories(${lib} PUBLIC
|
||||
"$<BUILD_INTERFACE:${xgboost_SOURCE_DIR}/rabit/include>"
|
||||
"$<BUILD_INTERFACE:${xgboost_SOURCE_DIR}/dmlc-core/include>")
|
||||
endforeach()
|
||||
|
||||
if (GOOGLE_TEST AND (NOT WIN32))
|
||||
enable_testing()
|
||||
|
||||
# rabit mock based integration tests
|
||||
list(REMOVE_ITEM rabit_libs "rabit_mock_static") # remove here to avoid installing it
|
||||
set(tests lazy_recover local_recover model_recover)
|
||||
|
||||
foreach(test ${tests})
|
||||
add_executable(${test} test/${test}.cc)
|
||||
target_link_libraries(${test} rabit_mock_static)
|
||||
set_target_properties(${test} PROPERTIES CXX_STANDARD 14 CXX_STANDARD_REQUIRED ON)
|
||||
add_test(NAME ${test} COMMAND ${test} WORKING_DIRECTORY ${xgboost_BINARY_DIR})
|
||||
endforeach()
|
||||
|
||||
if(RABIT_BUILD_MPI)
|
||||
add_executable(speed_test_mpi test/speed_test.cc)
|
||||
target_link_libraries(speed_test_mpi rabit_mpi)
|
||||
add_test(NAME speed_test_mpi COMMAND speed_test_mpi WORKING_DIRECTORY ${xgboost_BINARY_DIR})
|
||||
endif(RABIT_BUILD_MPI)
|
||||
endif (GOOGLE_TEST AND (NOT WIN32))
|
||||
|
||||
# Headers:
|
||||
set(include_install_dir "include")
|
||||
install(
|
||||
DIRECTORY "include/"
|
||||
DESTINATION "${include_install_dir}"
|
||||
FILES_MATCHING PATTERN "*.h"
|
||||
)
|
||||
set(RABIT_SOURCES ${RABIT_SOURCES} PARENT_SCOPE)
|
||||
|
||||
@@ -73,7 +73,7 @@ RABIT_DLL int RabitIsDistributed(void);
|
||||
* the user who monitors the tracker
|
||||
* \param msg the message to be printed
|
||||
*/
|
||||
RABIT_DLL void RabitTrackerPrint(const char *msg);
|
||||
RABIT_DLL int RabitTrackerPrint(const char *msg);
|
||||
/*!
|
||||
* \brief get name of processor
|
||||
* \param out_name hold output string
|
||||
|
||||
@@ -622,7 +622,9 @@ struct PollHelper {
|
||||
fdset.push_back(kv.second);
|
||||
}
|
||||
int ret = PollImpl(fdset.data(), fdset.size(), timeout);
|
||||
if (ret <= 0) {
|
||||
if (ret == 0) {
|
||||
LOG(FATAL) << "Poll timeout";
|
||||
} else if (ret < 0) {
|
||||
Socket::Error("Poll");
|
||||
} else {
|
||||
for (auto& pfd : fdset) {
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#include "dmlc/io.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include <cstdarg>
|
||||
|
||||
#if !defined(__GNUC__) || defined(__FreeBSD__)
|
||||
@@ -73,17 +74,14 @@ inline bool StringToBool(const char* s) {
|
||||
* \param msg error message
|
||||
*/
|
||||
inline void HandleAssertError(const char *msg) {
|
||||
fprintf(stderr,
|
||||
"AssertError:%s, rabit is configured to keep process running\n", msg);
|
||||
throw dmlc::Error(msg);
|
||||
LOG(FATAL) << msg;
|
||||
}
|
||||
/*!
|
||||
* \brief handling of Check error, caused by inappropriate input
|
||||
* \param msg error message
|
||||
*/
|
||||
inline void HandleCheckError(const char *msg) {
|
||||
fprintf(stderr, "%s, rabit is configured to keep process running\n", msg);
|
||||
throw dmlc::Error(msg);
|
||||
LOG(FATAL) << msg;
|
||||
}
|
||||
|
||||
inline void HandlePrint(const char *msg) {
|
||||
@@ -154,13 +152,6 @@ inline void Error(const char *fmt, ...) {
|
||||
HandleCheckError(msg.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
/*! \brief replace fopen, report error when the file open fails */
|
||||
inline std::FILE *FopenCheck(const char *fname, const char *flag) {
|
||||
std::FILE *fp = fopen64(fname, flag);
|
||||
Check(fp != nullptr, "can not open file \"%s\"\n", fname);
|
||||
return fp;
|
||||
}
|
||||
} // namespace utils
|
||||
|
||||
// Can not use std::min on Windows with msvc due to:
|
||||
|
||||
@@ -123,7 +123,9 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
|
||||
bool AllreduceBase::Shutdown() {
|
||||
try {
|
||||
for (auto & all_link : all_links) {
|
||||
all_link.sock.Close();
|
||||
if (!all_link.sock.IsClosed()) {
|
||||
all_link.sock.Close();
|
||||
}
|
||||
}
|
||||
all_links.clear();
|
||||
tree_links.plinks.clear();
|
||||
@@ -136,7 +138,7 @@ bool AllreduceBase::Shutdown() {
|
||||
utils::TCPSocket::Finalize();
|
||||
return true;
|
||||
} catch (const std::exception& e) {
|
||||
fprintf(stderr, "failed to shutdown due to %s\n", e.what());
|
||||
LOG(WARNING) << "Failed to shutdown due to" << e.what();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -217,7 +219,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
rabit_enable_tcp_no_delay = true;
|
||||
} else {
|
||||
rabit_enable_tcp_no_delay = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
/*!
|
||||
@@ -586,7 +588,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
|
||||
// eachreduce size
|
||||
if (max_reduce < total_size) {
|
||||
max_reduce = max_reduce - max_reduce % eachreduce;
|
||||
}
|
||||
}
|
||||
|
||||
// peform reduce, can be at most two rounds
|
||||
while (size_up_reduce < max_reduce) {
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
#include "rabit/rabit.h"
|
||||
#include "rabit/c_api.h"
|
||||
|
||||
#include "../../src/c_api/c_api_error.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace c_api {
|
||||
// helper use to avoid BitOR operator
|
||||
@@ -219,11 +221,19 @@ struct WriteWrapper : public Serializable {
|
||||
} // namespace rabit
|
||||
|
||||
RABIT_DLL bool RabitInit(int argc, char *argv[]) {
|
||||
return rabit::Init(argc, argv);
|
||||
auto ret = rabit::Init(argc, argv);
|
||||
if (!ret) {
|
||||
XGBAPISetLastError("Failed to initialize RABIT.");
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
RABIT_DLL bool RabitFinalize() {
|
||||
return rabit::Finalize();
|
||||
auto ret = rabit::Finalize();
|
||||
if (!ret) {
|
||||
XGBAPISetLastError("Failed to shutdown RABIT worker.");
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
RABIT_DLL int RabitGetRingPrevRank() {
|
||||
@@ -242,9 +252,11 @@ RABIT_DLL int RabitIsDistributed() {
|
||||
return rabit::IsDistributed();
|
||||
}
|
||||
|
||||
RABIT_DLL void RabitTrackerPrint(const char *msg) {
|
||||
RABIT_DLL int RabitTrackerPrint(const char *msg) {
|
||||
API_BEGIN()
|
||||
std::string m(msg);
|
||||
rabit::TrackerPrint(m);
|
||||
API_END()
|
||||
}
|
||||
|
||||
RABIT_DLL void RabitGetProcessorName(char *out_name,
|
||||
|
||||
Reference in New Issue
Block a user