[GPU-Plugin] Resolve double compilation issue (#2479)
parent 5f1b0bb386
commit ed8bc4521e
@@ -80,10 +80,7 @@ set(RABIT_SOURCES
   rabit/src/c_api.cc
 )
 
-set(NCCL_SOURCES
-  nccl/src/*.cu
-)
-set(UPDATER_GPU_SOURCES
+file(GLOB CUDA_SOURCES
   plugin/updater_gpu/src/*.cu
   plugin/updater_gpu/src/exact/*.cu
 )
@@ -110,7 +107,6 @@ if(PLUGIN_UPDATER_GPU)
   find_package(CUDA REQUIRED)
 
   # nccl
-  set(LINK_LIBRARIES ${LINK_LIBRARIES} nccl)
   add_subdirectory(nccl)
   set(NCCL_DIRECTORY ${PROJECT_SOURCE_DIR}/nccl)
   include_directories(${NCCL_DIRECTORY}/src)
@@ -135,14 +131,11 @@ if(PLUGIN_UPDATER_GPU)
   if(NOT MSVC)
     set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-Xcompiler -fPIC")
   endif()
-  set(CUDA_SOURCES
-    plugin/updater_gpu/src/updater_gpu.cu
-    plugin/updater_gpu/src/gpu_hist_builder.cu
-  )
-  # use below for forcing specific arch
-  cuda_compile(CUDA_OBJS ${CUDA_SOURCES} ${CUDA_NVCC_FLAGS})
-
 
+  cuda_add_library(gpuxgboost ${CUDA_SOURCES} STATIC)
+  target_link_libraries(gpuxgboost nccl)
+  list(APPEND LINK_LIBRARIES gpuxgboost)
+  list(APPEND SOURCES plugin/updater_gpu/src/register_updater_gpu.cc)
 else()
   set(CUDA_OBJS "")
 endif()
@@ -150,13 +143,16 @@ endif()
 add_library(objxgboost OBJECT ${SOURCES})
 set_target_properties(${objxgboost} PROPERTIES POSITION_INDEPENDENT_CODE 1)
 
-add_executable(runxgboost $<TARGET_OBJECTS:objxgboost> ${CUDA_OBJS})
+add_executable(runxgboost $<TARGET_OBJECTS:objxgboost>)
 set_target_properties(runxgboost PROPERTIES OUTPUT_NAME xgboost)
 target_link_libraries(runxgboost ${LINK_LIBRARIES})
 
-add_library(xgboost SHARED $<TARGET_OBJECTS:objxgboost> ${CUDA_OBJS})
+add_library(xgboost SHARED $<TARGET_OBJECTS:objxgboost>)
 target_link_libraries(xgboost ${LINK_LIBRARIES})
 
+# Ensure these two targets do not build simultaneously, as they produce outputs with conflicting names
+add_dependencies(xgboost runxgboost)
+
 option(JVM_BINDINGS "Build JVM bindings" OFF)
 
 if(JVM_BINDINGS)
@@ -166,7 +162,6 @@ if(JVM_BINDINGS)
 
   add_library(xgboost4j SHARED
     $<TARGET_OBJECTS:objxgboost>
-    ${CUDA_OBJS}
     jvm-packages/xgboost4j/src/native/xgboost4j.cpp)
   target_link_libraries(xgboost4j
     ${LINK_LIBRARIES}
@@ -1,4 +1,5 @@
 
-PLUGIN_OBJS += build_plugin/updater_gpu/src/updater_gpu.o \
+PLUGIN_OBJS += build_plugin/updater_gpu/src/register_updater_gpu.o \
+               build_plugin/updater_gpu/src/updater_gpu.o \
                build_plugin/updater_gpu/src/gpu_hist_builder.o
 PLUGIN_LDFLAGS += -L$(CUDA_ROOT)/lib64 -lcudart
@@ -10,7 +10,6 @@
 #include "../../../src/tree/param.h"
 #include "cub/cub.cuh"
 #include "device_helpers.cuh"
-#include "device_helpers.cuh"
 #include "types.cuh"
 
 namespace xgboost {
@@ -169,7 +168,7 @@ inline void subsample_gpair(dh::dvec<bst_gpair>* p_gpair, float subsample) {
 
 inline std::vector<int> col_sample(std::vector<int> features, float colsample) {
   CHECK_GT(features.size(), 0);
-  int n = std::max(1,static_cast<int>(colsample * features.size()));
+  int n = std::max(1, static_cast<int>(colsample * features.size()));
 
   std::shuffle(features.begin(), features.end(), common::GlobalRandom());
   features.resize(n);
@@ -198,17 +197,17 @@ struct GpairCallbackOp {
  * @param offsets the segments
  */
 template <typename T1, typename T2>
-void segmentedSort(dh::CubMemory& tmp_mem, dh::dvec2<T1>& keys,
-                   dh::dvec2<T2>& vals, int nVals, int nSegs,
-                   dh::dvec<int>& offsets, int start = 0,
+void segmentedSort(dh::CubMemory* tmp_mem, dh::dvec2<T1>* keys,
+                   dh::dvec2<T2>* vals, int nVals, int nSegs,
+                   const dh::dvec<int>& offsets, int start = 0,
                    int end = sizeof(T1) * 8) {
   size_t tmpSize;
   dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
-      NULL, tmpSize, keys.buff(), vals.buff(), nVals, nSegs, offsets.data(),
+      NULL, tmpSize, keys->buff(), vals->buff(), nVals, nSegs, offsets.data(),
       offsets.data() + 1, start, end));
-  tmp_mem.LazyAllocate(tmpSize);
+  tmp_mem->LazyAllocate(tmpSize);
   dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
-      tmp_mem.d_temp_storage, tmpSize, keys.buff(), vals.buff(), nVals, nSegs,
+      tmp_mem->d_temp_storage, tmpSize, keys->buff(), vals->buff(), nVals, nSegs,
       offsets.data(), offsets.data() + 1, start, end));
 }
 
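For context, the call pattern above is CUB's usual two-phase idiom: the first SortPairs call with a NULL scratch pointer only reports how many temporary bytes it needs, and the second call performs the actual segmented radix sort. A minimal standalone sketch of that idiom with raw device pointers (the helper function and buffer names are illustrative, not taken from the patch; only the cub::DeviceSegmentedRadixSort::SortPairs API is real):

    #include <cub/cub.cuh>

    void segmented_sort_example(float* d_keys_in, float* d_keys_out,
                                int* d_vals_in, int* d_vals_out,
                                int n_items, int n_segments, int* d_offsets) {
      void* d_temp = nullptr;
      size_t temp_bytes = 0;
      // Phase 1: with a null scratch pointer, only the required byte count is written.
      cub::DeviceSegmentedRadixSort::SortPairs(d_temp, temp_bytes,
          d_keys_in, d_keys_out, d_vals_in, d_vals_out,
          n_items, n_segments, d_offsets, d_offsets + 1);
      cudaMalloc(&d_temp, temp_bytes);
      // Phase 2: run the segmented sort for real.
      cub::DeviceSegmentedRadixSort::SortPairs(d_temp, temp_bytes,
          d_keys_in, d_keys_out, d_vals_in, d_vals_out,
          n_items, n_segments, d_offsets, d_offsets + 1);
      cudaFree(d_temp);
    }

The signature change itself (pointers for the arguments the function mutates, a const reference for the read-only offsets) appears to follow the Google C++ style convention the codebase lints against.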
@@ -6,6 +6,7 @@
 #include <thrust/random.h>
 #include <thrust/system/cuda/error.h>
 #include <thrust/system_error.h>
+#include "nccl.h"
 #include <algorithm>
 #include <chrono>
 #include <ctime>
@@ -15,13 +16,6 @@
 #include <string>
 #include <vector>
 
-#ifndef NCCL
-#define NCCL 1
-#endif
-
-#if (NCCL)
-#include "nccl.h"
-#endif
 
 // Uncomment to enable
 // #define DEVICE_TIMER
@@ -53,7 +47,6 @@ inline cudaError_t throw_on_cuda_error(cudaError_t code, const char *file,
 
 #define safe_nccl(ans) throw_on_nccl_error((ans), __FILE__, __LINE__)
 
-#if (NCCL)
 inline ncclResult_t throw_on_nccl_error(ncclResult_t code, const char *file,
                                         int line) {
   if (code != ncclSuccess) {
@@ -65,7 +58,6 @@ inline ncclResult_t throw_on_nccl_error(ncclResult_t code, const char *file,
 
   return code;
 }
-#endif
 
 #define gpuErrchk(ans) \
   { gpuAssert((ans), __FILE__, __LINE__); }
@@ -87,13 +79,6 @@ inline int n_visible_devices() {
 }
 
 inline int n_devices_all(int n_gpus) {
-  if (NCCL == 0 && n_gpus > 1 || NCCL == 0 && n_gpus != 0) {
-    if (n_gpus != 1 && n_gpus != 0) {
-      fprintf(stderr, "NCCL=0, so forcing n_gpus=1\n");
-      fflush(stderr);
-    }
-    n_gpus = 1;
-  }
   int n_devices_visible = dh::n_visible_devices();
   int n_devices = n_gpus < 0 ? n_devices_visible : n_gpus;
   return (n_devices);
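The surviving logic keeps a small convention worth spelling out: a negative n_gpus means "use every visible device", while any positive value is taken literally. A minimal host-side sketch of the same rule using the CUDA runtime directly (the function and variable names here are illustrative):

    #include <cuda_runtime.h>

    int resolve_device_count(int n_gpus) {
      int visible = 0;
      cudaGetDeviceCount(&visible);        // number of GPUs this process can see
      return n_gpus < 0 ? visible : n_gpus;
    }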
@@ -344,6 +329,8 @@ class dvec {
 
   T *data() { return _ptr; }
 
+  const T *data() const { return _ptr; }
+
   std::vector<T> as_vector() const {
     std::vector<T> h_vector(size());
     safe_cuda(cudaSetDevice(_device_idx));
@@ -60,7 +60,7 @@ DEV_INLINE void atomicArgMax(Split* address, Split val) {
   do {
     assumed = old;
     Split res = maxSplit(val, *(Split*)&assumed);
-    old = atomicCAS(intAddress, assumed, *(unsigned long long*)&res);
+    old = atomicCAS(intAddress, assumed, *(uint64_t*)&res);
   } while (assumed != old);
 }
 
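The loop above is the standard compare-and-swap emulation of an atomic "arg-max" on an 8-byte struct: the candidate and the current value are reinterpreted as 64-bit integers so that atomicCAS can install the winner, retrying until no other thread has raced in between. A self-contained sketch of the same technique, with an assumed 8-byte Pair type standing in for the patch's Split/maxSplit:

    // assumes Pair is exactly 8 bytes so it can alias an unsigned long long
    struct Pair { float key; int idx; };

    __device__ void atomic_arg_max(Pair* address, Pair val) {
      unsigned long long* addr_as_ull = reinterpret_cast<unsigned long long*>(address);
      unsigned long long old = *addr_as_ull, assumed;
      do {
        assumed = old;
        Pair cur = *reinterpret_cast<Pair*>(&assumed);
        Pair best = (val.key > cur.key) ? val : cur;   // keep the larger key
        old = atomicCAS(addr_as_ull, assumed,
                        *reinterpret_cast<unsigned long long*>(&best));
      } while (assumed != old);                        // retry if another thread won
    }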
@@ -115,7 +115,7 @@ __global__ void atomicArgMaxByKeySmem(
     const Node<node_id_t>* nodes, int nUniqKeys, node_id_t nodeStart, int len,
     const TrainParam param) {
   extern __shared__ char sArr[];
-  Split* sNodeSplits = (Split*)sArr;
+  Split* sNodeSplits = reinterpret_cast<Split*>(sArr);
   int tid = threadIdx.x;
   Split defVal;
 #pragma unroll 1
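sArr here is CUDA's dynamically sized shared-memory buffer: it is declared as raw bytes and cast to the element type, and its byte count is supplied as the third kernel-launch parameter. A minimal sketch of that pattern (kernel and variable names are illustrative):

    __global__ void scale_in_smem(const float* in, float* out, float factor, int n) {
      extern __shared__ char s_raw[];                   // sized at launch time
      float* s_vals = reinterpret_cast<float*>(s_raw);  // typed view of the bytes
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        s_vals[threadIdx.x] = in[i] * factor;
        out[i] = s_vals[threadIdx.x];
      }
    }

    // launch with the per-block byte count as the third <<<>>> argument:
    // scale_in_smem<<<blocks, threads, threads * sizeof(float)>>>(in, out, 2.0f, n);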
@@ -176,7 +176,7 @@ void argMaxByKey(Split* nodeSplits, const bst_gpair* gradScans,
       break;
     default:
       throw std::runtime_error("argMaxByKey: Bad algo passed!");
-  };
+  }
 }
 
 }  // namespace exact
@@ -143,7 +143,7 @@ __global__ void cubScanByKeyL3(bst_gpair* sums, bst_gpair* scans,
   // (potential race between threads)
   __shared__ char gradBuff[sizeof(bst_gpair)];
   __shared__ int s_mKeys;
-  bst_gpair* s_mScans = (bst_gpair*)gradBuff;
+  bst_gpair* s_mScans = reinterpret_cast<bst_gpair*>(gradBuff);
   if (tid >= size) return;
   // cache block-wide partial scan info
   if (relId == 0) {
@@ -16,14 +16,14 @@
  */
 #pragma once
 
+#include <string>
 #include <vector>
 #include "../../../../src/tree/param.h"
 #include "../common.cuh"
-#include <vector>
-#include "node.cuh"
-#include "split2node.cuh"
 #include "argmax_by_key.cuh"
 #include "fused_scan_reduce_by_key.cuh"
+#include "node.cuh"
+#include "split2node.cuh"
 #include "xgboost/tree_updater.h"
 
 namespace xgboost {
@@ -36,8 +36,8 @@ __global__ void initRootNode(Node<node_id_t>* nodes, const bst_gpair* sums,
   // gradients already evaluated inside transferGrads
   Node<node_id_t> n;
   n.gradSum = sums[0];
-  n.score = CalcGain(param, n.gradSum.grad , n.gradSum.hess);
-  n.weight = CalcWeight(param, n.gradSum.grad , n.gradSum.hess);
+  n.score = CalcGain(param, n.gradSum.grad, n.gradSum.hess);
+  n.weight = CalcWeight(param, n.gradSum.grad, n.gradSum.hess);
   n.id = 0;
   nodes[0] = n;
 }
@@ -173,7 +168,7 @@ class GPUBuilder {
     }
     // mark all the used nodes with unused children as leaf nodes
     markLeaves();
-    dense2sparse(*hTree);
+    dense2sparse(hTree);
   }
 
  private:
@@ -299,7 +299,8 @@ class GPUBuilder {
     vals.current_dvec() = fval;
     instIds.current_dvec() = fId;
     colOffsets = offset;
-    segmentedSort<float, int>(tmp_mem, vals, instIds, nVals, nCols, colOffsets);
+    segmentedSort<float, int>(&tmp_mem, &vals, &instIds, nVals, nCols,
+                              colOffsets);
     vals_cached = vals.current_dvec();
     instIds_cached = instIds.current_dvec();
     assignColIds<node_id_t><<<nCols, 512>>>(colIds.data(), colOffsets.data());
@@ -347,8 +348,8 @@ class GPUBuilder {
   void sortKeys(int level) {
     // segmented-sort the arrays based on node-id's
     // but we don't need more than level+1 bits for sorting!
-    segmentedSort(tmp_mem, nodeAssigns, nodeLocations, nVals, nCols, colOffsets,
-                  0, level + 1);
+    segmentedSort(&tmp_mem, &nodeAssigns, &nodeLocations, nVals, nCols,
+                  colOffsets, 0, level + 1);
     gather<float, int>(dh::get_device_idx(param.gpu_id), vals.other(),
                        vals.current(), instIds.other(), instIds.current(),
                        nodeLocations.current(), nVals);
@@ -362,7 +363,8 @@ class GPUBuilder {
     markLeavesKernel<<<nBlks, BlkDim>>>(nodes.data(), maxNodes);
   }
 
-  void dense2sparse(RegTree& tree) {
+  void dense2sparse(RegTree* p_tree) {
+    RegTree& tree = *p_tree;
     std::vector<Node<node_id_t>> hNodes = nodes.as_vector();
     int nodeId = 0;
     for (int i = 0; i < maxNodes; ++i) {
@@ -23,7 +23,7 @@ void DeviceGMat::Init(int device_idx, const common::GHistIndexMatrix& gmat,
   CHECK_EQ(gidx.size(), end - begin) << "gidx must be externally allocated";
   CHECK_EQ(ridx.size(), end - begin) << "ridx must be externally allocated";
 
-  thrust::copy(&gmat.index[begin], &gmat.index[end], gidx.tbegin());
+  thrust::copy(gmat.index.data() + begin, gmat.index.data() + end, gidx.tbegin());
   thrust::device_vector<int> row_ptr = gmat.row_ptr;
 
   auto counting = thrust::make_counting_iterator(begin);
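The switch from &gmat.index[end] to gmat.index.data() + end matters when end equals the vector's size: operator[] at the past-the-end position is undefined behaviour (and can trip assertions in checked STL builds), whereas data() plus an offset yields a valid past-the-end pointer. A small self-contained illustration of the same copy (the container contents here are made up):

    #include <thrust/copy.h>
    #include <thrust/device_vector.h>
    #include <vector>

    int main() {
      std::vector<unsigned int> index = {7, 3, 9, 1, 4};
      size_t begin = 2, end = index.size();    // end == size() is fine with data()
      thrust::device_vector<unsigned int> gidx(end - begin);
      // host-to-device copy of the half-open range [begin, end)
      thrust::copy(index.data() + begin, index.data() + end, gidx.begin());
      return 0;
    }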
@@ -77,7 +77,6 @@ GPUHistBuilder::GPUHistBuilder()
       prediction_cache_initialised(false) {}
 
 GPUHistBuilder::~GPUHistBuilder() {
-#if (NCCL)
   if (initialised) {
     for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
       ncclCommDestroy(comms[d_idx]);
@@ -92,7 +91,6 @@ GPUHistBuilder::~GPUHistBuilder() {
       }
     }
   }
-#endif
 }
 
 void GPUHistBuilder::Init(const TrainParam& param) {
@@ -103,7 +101,7 @@ void GPUHistBuilder::Init(const TrainParam& param) {
 
   CHECK(param.n_gpus != 0) << "Must have at least one device";
   int n_devices_all = dh::n_devices_all(param.n_gpus);
-  for (int device_idx = 0; device_idx < n_devices; device_idx++) {
+  for (int device_idx = 0; device_idx < n_devices_all; device_idx++) {
     if (!param.silent) {
       size_t free_memory = dh::available_memory(device_idx);
       const int mb_size = 1048576;
@@ -129,7 +127,6 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
     dList[d_idx] = device_idx;
   }
 
-#if (NCCL)
   // initialize nccl
 
   comms.resize(n_devices);
@@ -173,7 +170,6 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
   // process)
   }
 
-#endif
 
   CHECK(fmat.SingleColBlock()) << "grow_gpu_hist: must have single column "
                                   "block. Try setting 'tree_method' "
@@ -376,7 +372,6 @@ void GPUHistBuilder::BuildHist(int depth) {
 
   // time.printElapsed("Add Time");
 
-#if (NCCL)
   // (in-place) reduce each element of histogram (for only current level) across
   // multiple gpus
   // TODO(JCM): use out of place with pre-allocated buffer, but then have to
@@ -398,9 +393,7 @@ void GPUHistBuilder::BuildHist(int depth) {
     dh::safe_cuda(cudaSetDevice(device_idx));
     dh::safe_cuda(cudaStreamSynchronize(*(streams[d_idx])));
   }
-#else
   // if no NCCL, then presume only 1 GPU, then already correct
-#endif
 
   // time.printElapsed("Reduce-Add Time");
 
@@ -626,13 +619,12 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
 
   int dosimuljob = 1;
 
-#if (NCCL)
   int simuljob = 1;  // whether to do job on single GPU and broadcast (0) or to
                      // do same job on each GPU (1) (could make user parameter,
                      // but too fine-grained maybe)
   int findsplit_shardongpus = 0;  // too expensive generally, disable for now
 
-  if (NCCL && findsplit_shardongpus) {
+  if (findsplit_shardongpus) {
     dosimuljob = 0;
     // use power of 2 for split finder because nodes are power of 2 (broadcast
     // result to remaining devices)
@@ -739,7 +731,7 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
       dh::safe_cuda(cudaStreamSynchronize(*(streams[d_idx])));
     }
   }
-  } else if (simuljob == 0 && NCCL == 1) {
+  } else if (simuljob == 0) {
    dosimuljob = 0;
    int num_nodes_device = n_nodes_level(depth);
    const int GRID_SIZE = num_nodes_device;
@@ -792,7 +784,6 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
   } else {
     dosimuljob = 1;
   }
-#endif
 
   if (dosimuljob) {  // if no NCCL or simuljob==1, do this
     int num_nodes_device = n_nodes_level(depth);
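With the NCCL guards removed, the multi-GPU histogram path is unconditional: each device keeps a partial histogram and the comments above describe summing them in place across GPUs. A hedged sketch of that in-place reduction; ncclAllReduce, ncclFloat and ncclSum are NCCL's real API, but the helper, buffer names and the float element type are assumptions for illustration:

    #include <nccl.h>
    #include <cuda_runtime.h>
    #include <vector>

    // Sum each device's partial histogram into every device's copy, in place.
    void allreduce_histograms(std::vector<ncclComm_t>& comms,
                              std::vector<cudaStream_t>& streams,
                              std::vector<float*>& d_hist, size_t hist_size,
                              const std::vector<int>& device_list) {
      for (size_t d = 0; d < comms.size(); ++d) {
        cudaSetDevice(device_list[d]);
        ncclAllReduce(d_hist[d], d_hist[d], hist_size, ncclFloat, ncclSum,
                      comms[d], streams[d]);
      }
      for (size_t d = 0; d < comms.size(); ++d) {   // wait for every device
        cudaSetDevice(device_list[d]);
        cudaStreamSynchronize(streams[d]);
      }
    }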
plugin/updater_gpu/src/register_updater_gpu.cc (new file, 20 lines)
@@ -0,0 +1,20 @@
+/*!
+ * Copyright 2017 XGBoost contributors
+ */
+#include <xgboost/tree_updater.h>
+#include "updater_gpu.cuh"
+
+namespace xgboost {
+namespace tree {
+DMLC_REGISTRY_FILE_TAG(updater_gpumaker);
+
+
+XGBOOST_REGISTER_TREE_UPDATER(GPUMaker, "grow_gpu")
+    .describe("Grow tree with GPU.")
+    .set_body([]() { return new GPUMaker(); });
+
+XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
+    .describe("Grow tree with GPU.")
+    .set_body([]() { return new GPUHistMaker(); });
+}  // namespace tree
+}  // namespace xgboost
@@ -1,8 +1,11 @@
 /*!
  * Copyright 2017 XGBoost contributors
  */
+#include "updater_gpu.cuh"
 #include <xgboost/tree_updater.h>
 #include <vector>
+#include <utility>
+#include <string>
 #include "../../../src/common/random.h"
 #include "../../../src/common/sync.h"
 #include "../../../src/tree/param.h"
@@ -11,87 +14,64 @@
 
 namespace xgboost {
 namespace tree {
-DMLC_REGISTRY_FILE_TAG(updater_gpumaker);
-
-/*! \brief column-wise update to construct a tree */
-template <typename TStats>
-class GPUMaker : public TreeUpdater {
- public:
-  void Init(
-      const std::vector<std::pair<std::string, std::string>>& args) override {
-    param.InitAllowUnknown(args);
-    builder.Init(param);
-  }
-
-  void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
-              const std::vector<RegTree*>& trees) override {
-    TStats::CheckInfo(dmat->info());
-    // rescale learning rate according to size of trees
-    float lr = param.learning_rate;
-    param.learning_rate = lr / trees.size();
-    builder.UpdateParam(param);
-
-    try {
-      // build tree
-      for (size_t i = 0; i < trees.size(); ++i) {
-        builder.Update(gpair, dmat, trees[i]);
-      }
-    } catch (const std::exception& e) {
-      LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
-    }
-    param.learning_rate = lr;
-  }
-
- protected:
-  // training parameter
-  TrainParam param;
-  exact::GPUBuilder<int16_t> builder;
-};
-
-template <typename TStats>
-class GPUHistMaker : public TreeUpdater {
- public:
-  void Init(
-      const std::vector<std::pair<std::string, std::string>>& args) override {
-    param.InitAllowUnknown(args);
-    builder.Init(param);
-  }
-
-  void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
-              const std::vector<RegTree*>& trees) override {
-    TStats::CheckInfo(dmat->info());
-    // rescale learning rate according to size of trees
-    float lr = param.learning_rate;
-    param.learning_rate = lr / trees.size();
-    builder.UpdateParam(param);
-    // build tree
-    try {
-      for (size_t i = 0; i < trees.size(); ++i) {
-        builder.Update(gpair, dmat, trees[i]);
-      }
-    } catch (const std::exception& e) {
-      LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
-    }
-    param.learning_rate = lr;
-  }
-
-  bool UpdatePredictionCache(const DMatrix* data,
-                             std::vector<bst_float>* out_preds) override {
-    return builder.UpdatePredictionCache(data, out_preds);
-  }
-
- protected:
-  // training parameter
-  TrainParam param;
-  GPUHistBuilder builder;
-};
-
-XGBOOST_REGISTER_TREE_UPDATER(GPUMaker, "grow_gpu")
-    .describe("Grow tree with GPU.")
-    .set_body([]() { return new GPUMaker<GradStats>(); });
-
-XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
-    .describe("Grow tree with GPU.")
-    .set_body([]() { return new GPUHistMaker<GradStats>(); });
+GPUMaker::GPUMaker() : builder(new exact::GPUBuilder<int16_t>()) {}
+
+void GPUMaker::Init(
+    const std::vector<std::pair<std::string, std::string>>& args) {
+  param.InitAllowUnknown(args);
+  builder->Init(param);
+}
+
+void GPUMaker::Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
+                      const std::vector<RegTree*>& trees) {
+  GradStats::CheckInfo(dmat->info());
+  // rescale learning rate according to size of trees
+  float lr = param.learning_rate;
+  param.learning_rate = lr / trees.size();
+  builder->UpdateParam(param);
+
+  try {
+    // build tree
+    for (size_t i = 0; i < trees.size(); ++i) {
+      builder->Update(gpair, dmat, trees[i]);
+    }
+  } catch (const std::exception& e) {
+    LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
+  }
+  param.learning_rate = lr;
+}
+
+GPUHistMaker::GPUHistMaker() : builder(new GPUHistBuilder()) {}
+
+void GPUHistMaker::Init(
+    const std::vector<std::pair<std::string, std::string>>& args) {
+  param.InitAllowUnknown(args);
+  builder->Init(param);
+}
+
+void GPUHistMaker::Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
+                          const std::vector<RegTree*>& trees) {
+  GradStats::CheckInfo(dmat->info());
+  // rescale learning rate according to size of trees
+  float lr = param.learning_rate;
+  param.learning_rate = lr / trees.size();
+  builder->UpdateParam(param);
+  // build tree
+  try {
+    for (size_t i = 0; i < trees.size(); ++i) {
+      builder->Update(gpair, dmat, trees[i]);
+    }
+  } catch (const std::exception& e) {
+    LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
+  }
+  param.learning_rate = lr;
+}
+
+bool GPUHistMaker::UpdatePredictionCache(const DMatrix* data,
+                                         std::vector<bst_float>* out_preds) {
+  return builder->UpdatePredictionCache(data, out_preds);
+}
 }  // namespace tree
 }  // namespace xgboost
plugin/updater_gpu/src/updater_gpu.cuh (new file, 48 lines)
@@ -0,0 +1,48 @@
+
+/*!
+ * Copyright 2017 XGBoost contributors
+ */
+#pragma once
+#include <xgboost/tree_updater.h>
+#include <memory>
+#include "../../../src/tree/param.h"
+
+namespace xgboost {
+namespace tree {
+
+// Forward declare builder classes
+class GPUHistBuilder;
+namespace exact {
+template <typename node_id_t>
+class GPUBuilder;
+}
+
+class GPUMaker : public TreeUpdater {
+ protected:
+  TrainParam param;
+  std::unique_ptr<exact::GPUBuilder<int16_t>> builder;
+
+ public:
+  GPUMaker();
+  void Init(
+      const std::vector<std::pair<std::string, std::string>>& args) override;
+  void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
+              const std::vector<RegTree*>& trees);
+};
+
+class GPUHistMaker : public TreeUpdater {
+ public:
+  GPUHistMaker();
+  void Init(
+      const std::vector<std::pair<std::string, std::string>>& args) override;
+  void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
+              const std::vector<RegTree*>& trees) override;
+  bool UpdatePredictionCache(const DMatrix* data,
+                             std::vector<bst_float>* out_preds) override;
+
+ protected:
+  TrainParam param;
+  std::unique_ptr<GPUHistBuilder> builder;
+};
+}  // namespace tree
+}  // namespace xgboost
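Splitting the updaters this way is essentially the pImpl idiom: the header only forward-declares the CUDA-side builders and holds them through std::unique_ptr, so plain translation units such as register_updater_gpu.cc can include it and be built by the host compiler alone, while the heavy .cu files are compiled exactly once into the gpuxgboost library. A generic, hedged sketch of the same pattern (Widget and CudaImpl are made-up names, not from the patch):

    // widget.h -- safe to include from ordinary .cc files
    #include <memory>

    class CudaImpl;                    // defined only in a .cu translation unit

    class Widget {
     public:
      Widget();
      ~Widget();                       // defined where CudaImpl is complete
      void Run();

     private:
      std::unique_ptr<CudaImpl> impl_;
    };

    // widget.cu -- the only file nvcc ever sees
    // #include "widget.h"
    // class CudaImpl { /* kernels and device buffers */ };
    // Widget::Widget() : impl_(new CudaImpl()) {}
    // Widget::~Widget() = default;
    // void Widget::Run() { /* launch kernels via impl_ */ }

One note on the sketch: declaring the destructor in the header and defining it out of line, where CudaImpl is complete, is what lets std::unique_ptr hold the incomplete type.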