[GPU-Plugin] Resolve double compilation issue (#2479)

This commit is contained in:
Rory Mitchell 2017-07-03 13:29:10 +12:00 committed by GitHub
parent 5f1b0bb386
commit ed8bc4521e
11 changed files with 161 additions and 138 deletions

View File

@ -80,10 +80,7 @@ set(RABIT_SOURCES
rabit/src/c_api.cc
)
set(NCCL_SOURCES
nccl/src/*.cu
)
set(UPDATER_GPU_SOURCES
file(GLOB CUDA_SOURCES
plugin/updater_gpu/src/*.cu
plugin/updater_gpu/src/exact/*.cu
)
@ -110,7 +107,6 @@ if(PLUGIN_UPDATER_GPU)
find_package(CUDA REQUIRED)
# nccl
set(LINK_LIBRARIES ${LINK_LIBRARIES} nccl)
add_subdirectory(nccl)
set(NCCL_DIRECTORY ${PROJECT_SOURCE_DIR}/nccl)
include_directories(${NCCL_DIRECTORY}/src)
@ -135,14 +131,11 @@ if(PLUGIN_UPDATER_GPU)
if(NOT MSVC)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-Xcompiler -fPIC")
endif()
set(CUDA_SOURCES
plugin/updater_gpu/src/updater_gpu.cu
plugin/updater_gpu/src/gpu_hist_builder.cu
)
# use below for forcing specific arch
cuda_compile(CUDA_OBJS ${CUDA_SOURCES} ${CUDA_NVCC_FLAGS})
cuda_add_library(gpuxgboost ${CUDA_SOURCES} STATIC)
target_link_libraries(gpuxgboost nccl)
list(APPEND LINK_LIBRARIES gpuxgboost)
list(APPEND SOURCES plugin/updater_gpu/src/register_updater_gpu.cc)
else()
set(CUDA_OBJS "")
endif()
@ -150,13 +143,16 @@ endif()
add_library(objxgboost OBJECT ${SOURCES})
set_target_properties(${objxgboost} PROPERTIES POSITION_INDEPENDENT_CODE 1)
add_executable(runxgboost $<TARGET_OBJECTS:objxgboost> ${CUDA_OBJS})
add_executable(runxgboost $<TARGET_OBJECTS:objxgboost>)
set_target_properties(runxgboost PROPERTIES OUTPUT_NAME xgboost)
target_link_libraries(runxgboost ${LINK_LIBRARIES})
add_library(xgboost SHARED $<TARGET_OBJECTS:objxgboost> ${CUDA_OBJS})
add_library(xgboost SHARED $<TARGET_OBJECTS:objxgboost>)
target_link_libraries(xgboost ${LINK_LIBRARIES})
#Ensure these two targets do not build simultaneously, as they produce outputs with conflicting names
add_dependencies(xgboost runxgboost)
option(JVM_BINDINGS "Build JVM bindings" OFF)
if(JVM_BINDINGS)
@ -166,7 +162,6 @@ if(JVM_BINDINGS)
add_library(xgboost4j SHARED
$<TARGET_OBJECTS:objxgboost>
${CUDA_OBJS}
jvm-packages/xgboost4j/src/native/xgboost4j.cpp)
target_link_libraries(xgboost4j
${LINK_LIBRARIES}

View File

@ -1,4 +1,5 @@
PLUGIN_OBJS += build_plugin/updater_gpu/src/updater_gpu.o \
PLUGIN_OBJS += build_plugin/updater_gpu/src/register_updater_gpu.o \
build_plugin/updater_gpu/src/updater_gpu.o \
build_plugin/updater_gpu/src/gpu_hist_builder.o
PLUGIN_LDFLAGS += -L$(CUDA_ROOT)/lib64 -lcudart

View File

@ -10,7 +10,6 @@
#include "../../../src/tree/param.h"
#include "cub/cub.cuh"
#include "device_helpers.cuh"
#include "device_helpers.cuh"
#include "types.cuh"
namespace xgboost {
@ -169,7 +168,7 @@ inline void subsample_gpair(dh::dvec<bst_gpair>* p_gpair, float subsample) {
inline std::vector<int> col_sample(std::vector<int> features, float colsample) {
CHECK_GT(features.size(), 0);
int n = std::max(1,static_cast<int>(colsample * features.size()));
int n = std::max(1, static_cast<int>(colsample * features.size()));
std::shuffle(features.begin(), features.end(), common::GlobalRandom());
features.resize(n);
@ -198,17 +197,17 @@ struct GpairCallbackOp {
* @param offsets the segments
*/
template <typename T1, typename T2>
void segmentedSort(dh::CubMemory& tmp_mem, dh::dvec2<T1>& keys,
dh::dvec2<T2>& vals, int nVals, int nSegs,
dh::dvec<int>& offsets, int start = 0,
void segmentedSort(dh::CubMemory* tmp_mem, dh::dvec2<T1>* keys,
dh::dvec2<T2>* vals, int nVals, int nSegs,
const dh::dvec<int>& offsets, int start = 0,
int end = sizeof(T1) * 8) {
size_t tmpSize;
dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
NULL, tmpSize, keys.buff(), vals.buff(), nVals, nSegs, offsets.data(),
NULL, tmpSize, keys->buff(), vals->buff(), nVals, nSegs, offsets.data(),
offsets.data() + 1, start, end));
tmp_mem.LazyAllocate(tmpSize);
tmp_mem->LazyAllocate(tmpSize);
dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
tmp_mem.d_temp_storage, tmpSize, keys.buff(), vals.buff(), nVals, nSegs,
tmp_mem->d_temp_storage, tmpSize, keys->buff(), vals->buff(), nVals, nSegs,
offsets.data(), offsets.data() + 1, start, end));
}

View File

@ -6,6 +6,7 @@
#include <thrust/random.h>
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#include "nccl.h"
#include <algorithm>
#include <chrono>
#include <ctime>
@ -15,13 +16,6 @@
#include <string>
#include <vector>
#ifndef NCCL
#define NCCL 1
#endif
#if (NCCL)
#include "nccl.h"
#endif
// Uncomment to enable
// #define DEVICE_TIMER
@ -53,7 +47,6 @@ inline cudaError_t throw_on_cuda_error(cudaError_t code, const char *file,
#define safe_nccl(ans) throw_on_nccl_error((ans), __FILE__, __LINE__)
#if (NCCL)
inline ncclResult_t throw_on_nccl_error(ncclResult_t code, const char *file,
int line) {
if (code != ncclSuccess) {
@ -65,7 +58,6 @@ inline ncclResult_t throw_on_nccl_error(ncclResult_t code, const char *file,
return code;
}
#endif
#define gpuErrchk(ans) \
{ gpuAssert((ans), __FILE__, __LINE__); }
@ -87,13 +79,6 @@ inline int n_visible_devices() {
}
inline int n_devices_all(int n_gpus) {
if (NCCL == 0 && n_gpus > 1 || NCCL == 0 && n_gpus != 0) {
if (n_gpus != 1 && n_gpus != 0) {
fprintf(stderr, "NCCL=0, so forcing n_gpus=1\n");
fflush(stderr);
}
n_gpus = 1;
}
int n_devices_visible = dh::n_visible_devices();
int n_devices = n_gpus < 0 ? n_devices_visible : n_gpus;
return (n_devices);
@ -344,6 +329,8 @@ class dvec {
T *data() { return _ptr; }
const T *data() const { return _ptr; }
std::vector<T> as_vector() const {
std::vector<T> h_vector(size());
safe_cuda(cudaSetDevice(_device_idx));

View File

@ -60,7 +60,7 @@ DEV_INLINE void atomicArgMax(Split* address, Split val) {
do {
assumed = old;
Split res = maxSplit(val, *(Split*)&assumed);
old = atomicCAS(intAddress, assumed, *(unsigned long long*)&res);
old = atomicCAS(intAddress, assumed, *(uint64_t*)&res);
} while (assumed != old);
}
@ -115,7 +115,7 @@ __global__ void atomicArgMaxByKeySmem(
const Node<node_id_t>* nodes, int nUniqKeys, node_id_t nodeStart, int len,
const TrainParam param) {
extern __shared__ char sArr[];
Split* sNodeSplits = (Split*)sArr;
Split* sNodeSplits = reinterpret_cast<Split*>(sArr);
int tid = threadIdx.x;
Split defVal;
#pragma unroll 1
@ -176,7 +176,7 @@ void argMaxByKey(Split* nodeSplits, const bst_gpair* gradScans,
break;
default:
throw std::runtime_error("argMaxByKey: Bad algo passed!");
};
}
}
} // namespace exact

View File

@ -143,7 +143,7 @@ __global__ void cubScanByKeyL3(bst_gpair* sums, bst_gpair* scans,
// (potential race between threads)
__shared__ char gradBuff[sizeof(bst_gpair)];
__shared__ int s_mKeys;
bst_gpair* s_mScans = (bst_gpair*)gradBuff;
bst_gpair* s_mScans = reinterpret_cast<bst_gpair*>(gradBuff);
if (tid >= size) return;
// cache block-wide partial scan info
if (relId == 0) {

View File

@ -16,14 +16,14 @@
*/
#pragma once
#include <string>
#include <vector>
#include "../../../../src/tree/param.h"
#include "../common.cuh"
#include <vector>
#include "node.cuh"
#include "split2node.cuh"
#include "argmax_by_key.cuh"
#include "fused_scan_reduce_by_key.cuh"
#include "node.cuh"
#include "split2node.cuh"
#include "xgboost/tree_updater.h"
namespace xgboost {
@ -36,8 +36,8 @@ __global__ void initRootNode(Node<node_id_t>* nodes, const bst_gpair* sums,
// gradients already evaluated inside transferGrads
Node<node_id_t> n;
n.gradSum = sums[0];
n.score = CalcGain(param, n.gradSum.grad , n.gradSum.hess);
n.weight = CalcWeight(param, n.gradSum.grad , n.gradSum.hess);
n.score = CalcGain(param, n.gradSum.grad, n.gradSum.hess);
n.weight = CalcWeight(param, n.gradSum.grad, n.gradSum.hess);
n.id = 0;
nodes[0] = n;
}
@ -173,7 +173,7 @@ class GPUBuilder {
}
// mark all the used nodes with unused children as leaf nodes
markLeaves();
dense2sparse(*hTree);
dense2sparse(hTree);
}
private:
@ -299,7 +299,8 @@ class GPUBuilder {
vals.current_dvec() = fval;
instIds.current_dvec() = fId;
colOffsets = offset;
segmentedSort<float, int>(tmp_mem, vals, instIds, nVals, nCols, colOffsets);
segmentedSort<float, int>(&tmp_mem, &vals, &instIds, nVals, nCols,
colOffsets);
vals_cached = vals.current_dvec();
instIds_cached = instIds.current_dvec();
assignColIds<node_id_t><<<nCols, 512>>>(colIds.data(), colOffsets.data());
@ -347,8 +348,8 @@ class GPUBuilder {
void sortKeys(int level) {
// segmented-sort the arrays based on node-id's
// but we don't need more than level+1 bits for sorting!
segmentedSort(tmp_mem, nodeAssigns, nodeLocations, nVals, nCols, colOffsets,
0, level + 1);
segmentedSort(&tmp_mem, &nodeAssigns, &nodeLocations, nVals, nCols,
colOffsets, 0, level + 1);
gather<float, int>(dh::get_device_idx(param.gpu_id), vals.other(),
vals.current(), instIds.other(), instIds.current(),
nodeLocations.current(), nVals);
@ -362,7 +363,8 @@ class GPUBuilder {
markLeavesKernel<<<nBlks, BlkDim>>>(nodes.data(), maxNodes);
}
void dense2sparse(RegTree& tree) {
void dense2sparse(RegTree* p_tree) {
RegTree& tree = *p_tree;
std::vector<Node<node_id_t>> hNodes = nodes.as_vector();
int nodeId = 0;
for (int i = 0; i < maxNodes; ++i) {

View File

@ -23,7 +23,7 @@ void DeviceGMat::Init(int device_idx, const common::GHistIndexMatrix& gmat,
CHECK_EQ(gidx.size(), end - begin) << "gidx must be externally allocated";
CHECK_EQ(ridx.size(), end - begin) << "ridx must be externally allocated";
thrust::copy(&gmat.index[begin], &gmat.index[end], gidx.tbegin());
thrust::copy(gmat.index.data() + begin, gmat.index.data() + end, gidx.tbegin());
thrust::device_vector<int> row_ptr = gmat.row_ptr;
auto counting = thrust::make_counting_iterator(begin);
@ -77,7 +77,6 @@ GPUHistBuilder::GPUHistBuilder()
prediction_cache_initialised(false) {}
GPUHistBuilder::~GPUHistBuilder() {
#if (NCCL)
if (initialised) {
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
ncclCommDestroy(comms[d_idx]);
@ -92,7 +91,6 @@ GPUHistBuilder::~GPUHistBuilder() {
}
}
}
#endif
}
void GPUHistBuilder::Init(const TrainParam& param) {
@ -103,7 +101,7 @@ void GPUHistBuilder::Init(const TrainParam& param) {
CHECK(param.n_gpus != 0) << "Must have at least one device";
int n_devices_all = dh::n_devices_all(param.n_gpus);
for (int device_idx = 0; device_idx < n_devices; device_idx++) {
for (int device_idx = 0; device_idx < n_devices_all; device_idx++) {
if (!param.silent) {
size_t free_memory = dh::available_memory(device_idx);
const int mb_size = 1048576;
@ -129,7 +127,6 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
dList[d_idx] = device_idx;
}
#if (NCCL)
// initialize nccl
comms.resize(n_devices);
@ -173,7 +170,6 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
// process)
}
#endif
CHECK(fmat.SingleColBlock()) << "grow_gpu_hist: must have single column "
"block. Try setting 'tree_method' "
@ -376,7 +372,6 @@ void GPUHistBuilder::BuildHist(int depth) {
// time.printElapsed("Add Time");
#if (NCCL)
// (in-place) reduce each element of histogram (for only current level) across
// multiple gpus
// TODO(JCM): use out of place with pre-allocated buffer, but then have to
@ -398,9 +393,7 @@ void GPUHistBuilder::BuildHist(int depth) {
dh::safe_cuda(cudaSetDevice(device_idx));
dh::safe_cuda(cudaStreamSynchronize(*(streams[d_idx])));
}
#else
// if no NCCL, then presume only 1 GPU, then already correct
#endif
// time.printElapsed("Reduce-Add Time");
@ -626,13 +619,12 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
int dosimuljob = 1;
#if (NCCL)
int simuljob = 1; // whether to do job on single GPU and broadcast (0) or to
// do same job on each GPU (1) (could make user parameter,
// but too fine-grained maybe)
int findsplit_shardongpus = 0; // too expensive generally, disable for now
if (NCCL && findsplit_shardongpus) {
if (findsplit_shardongpus) {
dosimuljob = 0;
// use power of 2 for split finder because nodes are power of 2 (broadcast
// result to remaining devices)
@ -739,7 +731,7 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
dh::safe_cuda(cudaStreamSynchronize(*(streams[d_idx])));
}
}
} else if (simuljob == 0 && NCCL == 1) {
} else if (simuljob == 0) {
dosimuljob = 0;
int num_nodes_device = n_nodes_level(depth);
const int GRID_SIZE = num_nodes_device;
@ -792,7 +784,6 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
} else {
dosimuljob = 1;
}
#endif
if (dosimuljob) { // if no NCCL or simuljob==1, do this
int num_nodes_device = n_nodes_level(depth);

View File

@ -0,0 +1,20 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#include <xgboost/tree_updater.h>
#include "updater_gpu.cuh"
namespace xgboost {
namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_gpumaker);
XGBOOST_REGISTER_TREE_UPDATER(GPUMaker, "grow_gpu")
.describe("Grow tree with GPU.")
.set_body([]() { return new GPUMaker(); });
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
.describe("Grow tree with GPU.")
.set_body([]() { return new GPUHistMaker(); });
} // namespace tree
} // namespace xgboost

View File

@ -1,8 +1,11 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#include "updater_gpu.cuh"
#include <xgboost/tree_updater.h>
#include <vector>
#include <utility>
#include <string>
#include "../../../src/common/random.h"
#include "../../../src/common/sync.h"
#include "../../../src/tree/param.h"
@ -11,87 +14,64 @@
namespace xgboost {
namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_gpumaker);
/*! \brief column-wise update to construct a tree */
template <typename TStats>
class GPUMaker : public TreeUpdater {
public:
void Init(
const std::vector<std::pair<std::string, std::string>>& args) override {
param.InitAllowUnknown(args);
builder.Init(param);
}
GPUMaker::GPUMaker() : builder(new exact::GPUBuilder<int16_t>()) {}
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
TStats::CheckInfo(dmat->info());
// rescale learning rate according to size of trees
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
builder.UpdateParam(param);
void GPUMaker::Init(
const std::vector<std::pair<std::string, std::string>>& args) {
param.InitAllowUnknown(args);
builder->Init(param);
}
try {
// build tree
for (size_t i = 0; i < trees.size(); ++i) {
builder.Update(gpair, dmat, trees[i]);
}
} catch (const std::exception& e) {
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
}
param.learning_rate = lr;
}
void GPUMaker::Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) {
GradStats::CheckInfo(dmat->info());
// rescale learning rate according to size of trees
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
builder->UpdateParam(param);
protected:
// training parameter
TrainParam param;
exact::GPUBuilder<int16_t> builder;
};
template <typename TStats>
class GPUHistMaker : public TreeUpdater {
public:
void Init(
const std::vector<std::pair<std::string, std::string>>& args) override {
param.InitAllowUnknown(args);
builder.Init(param);
}
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
TStats::CheckInfo(dmat->info());
// rescale learning rate according to size of trees
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
builder.UpdateParam(param);
try {
// build tree
try {
for (size_t i = 0; i < trees.size(); ++i) {
builder.Update(gpair, dmat, trees[i]);
}
} catch (const std::exception& e) {
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
for (size_t i = 0; i < trees.size(); ++i) {
builder->Update(gpair, dmat, trees[i]);
}
param.learning_rate = lr;
} catch (const std::exception& e) {
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
}
param.learning_rate = lr;
}
bool UpdatePredictionCache(const DMatrix* data,
std::vector<bst_float>* out_preds) override {
return builder.UpdatePredictionCache(data, out_preds);
GPUHistMaker::GPUHistMaker() : builder(new GPUHistBuilder()) {}
void GPUHistMaker::Init(
const std::vector<std::pair<std::string, std::string>>& args) {
param.InitAllowUnknown(args);
builder->Init(param);
}
void GPUHistMaker::Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) {
GradStats::CheckInfo(dmat->info());
// rescale learning rate according to size of trees
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
builder->UpdateParam(param);
// build tree
try {
for (size_t i = 0; i < trees.size(); ++i) {
builder->Update(gpair, dmat, trees[i]);
}
} catch (const std::exception& e) {
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
}
param.learning_rate = lr;
}
protected:
// training parameter
TrainParam param;
GPUHistBuilder builder;
};
bool GPUHistMaker::UpdatePredictionCache(const DMatrix* data,
std::vector<bst_float>* out_preds) {
return builder->UpdatePredictionCache(data, out_preds);
}
XGBOOST_REGISTER_TREE_UPDATER(GPUMaker, "grow_gpu")
.describe("Grow tree with GPU.")
.set_body([]() { return new GPUMaker<GradStats>(); });
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
.describe("Grow tree with GPU.")
.set_body([]() { return new GPUHistMaker<GradStats>(); });
} // namespace tree
} // namespace xgboost

View File

@ -0,0 +1,48 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#pragma once
#include <xgboost/tree_updater.h>
#include <memory>
#include "../../../src/tree/param.h"
namespace xgboost {
namespace tree {
// Forward declare builder classes
class GPUHistBuilder;
namespace exact {
template <typename node_id_t>
class GPUBuilder;
}
class GPUMaker : public TreeUpdater {
protected:
TrainParam param;
std::unique_ptr<exact::GPUBuilder<int16_t>> builder;
public:
GPUMaker();
void Init(
const std::vector<std::pair<std::string, std::string>>& args) override;
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees);
};
class GPUHistMaker : public TreeUpdater {
public:
GPUHistMaker();
void Init(
const std::vector<std::pair<std::string, std::string>>& args) override;
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override;
bool UpdatePredictionCache(const DMatrix* data,
std::vector<bst_float>* out_preds) override;
protected:
TrainParam param;
std::unique_ptr<GPUHistBuilder> builder;
};
} // namespace tree
} // namespace xgboost