[GPU-Plugin] Multi-GPU gpu_id bug fixes for grow_gpu_hist and grow_gpu methods, and additional documentation for the gpu plugin. (#2463)

This commit is contained in:
PSEUDOTENSOR / Jonathan McKinney 2017-06-30 01:04:17 -07:00 committed by Rory Mitchell
parent 91dae84a00
commit 6b287177c8
21 changed files with 578 additions and 449 deletions

2
.gitignore vendored
View File

@ -15,7 +15,7 @@
*.Rcheck
*.rds
*.tar.gz
*txt*
#*txt*
*conf
*buffer
*model

View File

@ -3,11 +3,8 @@ project (xgboost)
find_package(OpenMP)
option(PLUGIN_UPDATER_GPU "Build GPU accelerated tree construction plugin")
set(GPU_COMPUTE_VER 35;50;52;60;61 CACHE STRING
"Space separated list of compute versions to be built against")
if(PLUGIN_UPDATER_GPU)
cmake_minimum_required (VERSION 3.5)
find_package(CUDA REQUIRED)
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
@ -83,6 +80,14 @@ set(RABIT_SOURCES
rabit/src/c_api.cc
)
set(NCCL_SOURCES
nccl/src/*.cu
)
set(UPDATER_GPU_SOURCES
plugin/updater_gpu/src/*.cu
plugin/updater_gpu/src/exact/*.cu
)
add_subdirectory(dmlc-core)
add_library(rabit STATIC ${RABIT_SOURCES})
@ -102,19 +107,26 @@ endif()
set(LINK_LIBRARIES dmlccore rabit)
if(PLUGIN_UPDATER_GPU)
find_package(CUDA REQUIRED)
# nccl
set(LINK_LIBRARIES ${LINK_LIBRARIES} nccl)
add_subdirectory(nccl)
set(NCCL_DIRECTORY ${PROJECT_SOURCE_DIR}/nccl)
include_directories(${NCCL_DIRECTORY}/src)
set(LINK_LIBRARIES ${LINK_LIBRARIES} ${CUDA_LIBRARIES})
#Find cub
set(CUB_DIRECTORY ${PROJECT_SOURCE_DIR}/cub/)
include_directories(${CUB_DIRECTORY})
#Find googletest
set(GTEST_DIRECTORY "${CACHE_PREFIX}" CACHE PATH "Googletest directory")
include_directories(${GTEST_DIRECTORY}/include)
#gencode flags
set(GPU_COMPUTE_VER 35;50;52;60;61 CACHE STRING
"Space separated list of compute versions to be built against")
set(GENCODE_FLAGS "")
foreach(ver ${GPU_COMPUTE_VER})
set(GENCODE_FLAGS "${GENCODE_FLAGS}-gencode arch=compute_${ver},code=sm_${ver};")
@ -129,6 +141,8 @@ if(PLUGIN_UPDATER_GPU)
)
# use below for forcing specific arch
cuda_compile(CUDA_OBJS ${CUDA_SOURCES} ${CUDA_NVCC_FLAGS})
else()
set(CUDA_OBJS "")
endif()

View File

@ -96,7 +96,7 @@ endif
CFLAGS += $(OPENMP_FLAGS)
# for using GPUs
GPU_COMPUTE_VER ?= 50 52 60 61
GPU_COMPUTE_VER ?= 35 50 52 60 61
NVCC = nvcc
INCLUDES = -Iinclude -I$(DMLC_CORE)/include -I$(RABIT)/include
INCLUDES += -I$(CUB_PATH)
@ -106,14 +106,13 @@ NVCC_FLAGS = --std=c++11 $(CODE) $(INCLUDES) -lineinfo --expt-extended-lambda
NVCC_FLAGS += -Xcompiler=$(OPENMP_FLAGS) -Xcompiler=-fPIC
ifeq ($(PLUGIN_UPDATER_GPU),ON)
CUDA_ROOT = $(shell dirname $(shell dirname $(shell which $(NVCC))))
INCLUDES += -I$(CUDA_ROOT)/include
LDFLAGS += -L$(CUDA_ROOT)/lib64 -lcudart
INCLUDES += -I$(CUDA_ROOT)/include -Inccl/src/
LDFLAGS += -L$(CUDA_ROOT)/lib64 -lcudart -lcudadevrt -Lnccl/build/lib/ -lnccl_static -lm -ldl -lrt
endif
# specify tensor path
.PHONY: clean all lint clean_all doxygen rcpplint pypack Rpack Rbuild Rcheck java pylint
all: lib/libxgboost.a $(XGBOOST_DYLIB) xgboost
$(DMLC_CORE)/libdmlc.a: $(wildcard $(DMLC_CORE)/src/*.cc $(DMLC_CORE)/src/*/*.cc)
@ -143,7 +142,7 @@ build/%.o: src/%.cc
$(CXX) -c $(CFLAGS) $< -o $@
# order of this rule matters wrt %.cc rule below!
build_plugin/%.o: plugin/%.cu
build_plugin/%.o: plugin/%.cu build_nccl
@mkdir -p $(@D)
$(NVCC) -c $(NVCC_FLAGS) $< -o $@
@ -152,6 +151,11 @@ build_plugin/%.o: plugin/%.cc
$(CXX) $(CFLAGS) -MM -MT build_plugin/$*.o $< >build_plugin/$*.d
$(CXX) -c $(CFLAGS) $< -o $@
build_nccl:
@mkdir -p build/include
cd build/include ; ln -sf ../../nccl/src/nccl.h .
cd nccl ; make -j ; cd ..
# The should be equivalent to $(ALL_OBJ) except for build/cli_main.o
amalgamation/xgboost-all0.o: amalgamation/xgboost-all0.cc
$(CXX) -c $(CFLAGS) $< -o $@
@ -173,6 +177,7 @@ jvm-packages/lib/libxgboost4j.so: jvm-packages/xgboost4j/src/native/xgboost4j.cp
@mkdir -p $(@D)
$(CXX) $(CFLAGS) $(JAVAINCFLAGS) -shared -o $@ $(filter %.cpp %.o %.a, $^) $(LDFLAGS)
xgboost: $(CLI_OBJ) $(ALL_DEP)
$(CXX) $(CFLAGS) -o $@ $(filter %.o %.a, $^) $(LDFLAGS)

2
cub

@ -1 +1 @@
Subproject commit 89de7ab20167909bc2c4f8acd397671c47cf3c0d
Subproject commit f3937a96fdec78a73446aaaa114c112ff31f5503

2
nccl

@ -1 +1 @@
Subproject commit 93183bca921b2e8e1754e27e1b43d73cf6caec9d
Subproject commit 8ec6c27a33a900fb92f7e39acc73cc0f43e8539b

View File

@ -63,6 +63,22 @@ submodule: The plugin also depends on CUB 1.6.4 - https://nvlabs.github.io/cub/
submodule: NVIDIA NCCL from https://github.com/NVIDIA/nccl with windows port allowed by git@github.com:h2oai/nccl.git
## Download full repo + full submodules for your choice (or empty) path <mypath>
git clone --recursive https://github.com/dmlc/xgboost.git <mypath>
## Download with shallow submodules for much quicker download:
git 2.9.0+ (assumes only HEAD used for all submodules, but not true currently for dmlc-core and rabbit)
git clone --recursive --shallow-submodules https://github.com/dmlc/xgboost.git <mypath>
git 2.9.0-: (only cub is shallow, as largest repo)
git clone https://github.com/dmlc/xgboost.git <mypath>
cd <mypath>
bash plugin/updater/gpu/gitshallow_submodules.sh
## Build
From the command line on Linux starting from the xgboost directory:
@ -84,12 +100,18 @@ $ mkdir build
$ cd build
$ cmake .. -G"Visual Studio 14 2015 Win64" -DPLUGIN_UPDATER_GPU=ON
```
Cmake will generate an xgboost.sln solution file in the build directory. Build this solution in release mode as a x64 build.
Cmake will create an xgboost.sln solution file in the build directory. Build this solution in release mode as a x64 build.
Visual studio community 2015, supported by cuda toolkit (http://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/#axzz4isREr2nS), can be downloaded from: https://my.visualstudio.com/Downloads?q=Visual%20Studio%20Community%202015 . You may also be able to use a later version of visual studio depending on whether the CUDA toolkit supports it. Note that Mingw cannot be used with cuda.
### For other nccl libraries
On some systems, nccl libraries are specific to a particular system (IBM Power or nvidia-docker) and can enable use of nvlink (between GPUs or even between GPUs and system memory). In that case, one wants to avoid the static nccl library by changing "STATIC" to "SHARED" in nccl/CMakeLists.txt and deleting the shared nccl library created (so that the system one is used).
### For Developers!
In case you want to build only for a specific GPU(s), for eg. GP100 and GP102,
whose compute capability are 60 and 61 respectively:
```bash
@ -101,12 +123,12 @@ By default, the versions will include support for all GPUs in Maxwell and Pascal
Now, it also supports the usual 'make' flow to build gpu-enabled tree construction plugins. It's currently only tested on Linux. From the xgboost directory
```bash
# make sure CUDA SDK bin directory is in the 'PATH' env variable
$ make PLUGIN_UPDATER_GPU=ON
$ make -j PLUGIN_UPDATER_GPU=ON
```
Similar to cmake, if you want to build only for a specific GPU(s):
```bash
$ make PLUGIN_UPDATER_GPU=ON GPU_COMPUTE_VER="60 61"
$ make -j PLUGIN_UPDATER_GPU=ON GPU_COMPUTE_VER="60 61"
```
### For Developers!

View File

@ -16,6 +16,8 @@ def run_benchmark(args, gpu_algorithm, cpu_algorithm):
param = {'objective': 'binary:logistic',
'max_depth': 6,
'silent': 1,
'n_gpus': 1,
'gpu_id': 0,
'eval_metric': 'auc'}
param['tree_method'] = gpu_algorithm
@ -41,9 +43,9 @@ args = parser.parse_args()
if 'gpu_hist' in args.algorithm:
run_benchmark(args, args.algorithm, 'hist')
if 'gpu_exact' in args.algorithm:
elif 'gpu_exact' in args.algorithm:
run_benchmark(args, args.algorithm, 'exact')
if 'all' in args.algorithm:
elif 'all' in args.algorithm:
run_benchmark(args, 'gpu_exact', 'exact')
run_benchmark(args, 'gpu_hist', 'hist')

View File

@ -0,0 +1,12 @@
#!/bin/bash
git submodule init
for i in $(git submodule | awk '{print $2}'); do
spath=$(git config -f .gitmodules --get submodule.$i.path)
surl=$(git config -f .gitmodules --get submodule.$i.url)
if [ $spath == "cub" ]
then
git submodule update --depth 3 $spath
else
git submodule update $spath
fi
done

View File

@ -2,16 +2,16 @@
* Copyright 2017 XGBoost contributors
*/
#pragma once
#include <cstdio>
#include <stdexcept>
#include <string>
#include <vector>
#include "../../../src/common/random.h"
#include "../../../src/tree/param.h"
#include "device_helpers.cuh"
#include "types.cuh"
#include <string>
#include <stdexcept>
#include <cstdio>
#include "cub/cub.cuh"
#include "device_helpers.cuh"
#include "device_helpers.cuh"
#include "types.cuh"
namespace xgboost {
namespace tree {
@ -172,8 +172,8 @@ inline void subsample_gpair(dh::dvec<gpu_gpair>* p_gpair, float subsample) {
}
inline std::vector<int> col_sample(std::vector<int> features, float colsample) {
int n = colsample * features.size();
CHECK_GT(n, 0);
CHECK_GT(features.size(), 0);
int n = std::max(1,static_cast<int>(colsample * features.size()));
std::shuffle(features.begin(), features.end(), common::GlobalRandom());
features.resize(n);
@ -202,17 +202,18 @@ struct GpairCallbackOp {
* @param offsets the segments
*/
template <typename T1, typename T2>
void segmentedSort(dh::CubMemory &tmp_mem, dh::dvec2<T1> &keys, dh::dvec2<T2> &vals,
int nVals, int nSegs, dh::dvec<int> &offsets, int start=0,
int end=sizeof(T1)*8) {
void segmentedSort(dh::CubMemory& tmp_mem, dh::dvec2<T1>& keys,
dh::dvec2<T2>& vals, int nVals, int nSegs,
dh::dvec<int>& offsets, int start = 0,
int end = sizeof(T1) * 8) {
size_t tmpSize;
dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
NULL, tmpSize, keys.buff(), vals.buff(), nVals, nSegs,
offsets.data(), offsets.data()+1, start, end));
NULL, tmpSize, keys.buff(), vals.buff(), nVals, nSegs, offsets.data(),
offsets.data() + 1, start, end));
tmp_mem.LazyAllocate(tmpSize);
dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
tmp_mem.d_temp_storage, tmpSize, keys.buff(), vals.buff(),
nVals, nSegs, offsets.data(), offsets.data()+1, start, end));
tmp_mem.d_temp_storage, tmpSize, keys.buff(), vals.buff(), nVals, nSegs,
offsets.data(), offsets.data() + 1, start, end));
}
/**
@ -223,11 +224,11 @@ void segmentedSort(dh::CubMemory &tmp_mem, dh::dvec2<T1> &keys, dh::dvec2<T2> &v
* @param nVals number of elements in the input array
*/
template <typename T>
void sumReduction(dh::CubMemory &tmp_mem, dh::dvec<T> &in, dh::dvec<T> &out,
void sumReduction(dh::CubMemory& tmp_mem, dh::dvec<T>& in, dh::dvec<T>& out,
int nVals) {
size_t tmpSize;
dh::safe_cuda(cub::DeviceReduce::Sum(NULL, tmpSize, in.data(), out.data(),
nVals));
dh::safe_cuda(
cub::DeviceReduce::Sum(NULL, tmpSize, in.data(), out.data(), nVals));
tmp_mem.LazyAllocate(tmpSize);
dh::safe_cuda(cub::DeviceReduce::Sum(tmp_mem.d_temp_storage, tmpSize,
in.data(), out.data(), nVals));
@ -239,9 +240,10 @@ void sumReduction(dh::CubMemory &tmp_mem, dh::dvec<T> &in, dh::dvec<T> &out,
* @param len number of elements i the buffer
* @param def default value to be filled
*/
template <typename T, int BlkDim=256, int ItemsPerThread=4>
template <typename T, int BlkDim = 256, int ItemsPerThread = 4>
void fillConst(int device_idx, T* out, int len, T def) {
dh::launch_n<ItemsPerThread,BlkDim>(device_idx, len, [=] __device__(int i) { out[i] = def; });
dh::launch_n<ItemsPerThread, BlkDim>(device_idx, len,
[=] __device__(int i) { out[i] = def; });
}
/**
@ -253,11 +255,11 @@ void fillConst(int device_idx, T* out, int len, T def) {
* @param instId gather indices
* @param nVals length of the buffers
*/
template <typename T1, typename T2, int BlkDim=256, int ItemsPerThread=4>
void gather(int device_idx, T1* out1, const T1* in1, T2* out2, const T2* in2, const int* instId,
int nVals) {
dh::launch_n<ItemsPerThread,BlkDim>
(device_idx, nVals, [=] __device__(int i) {
template <typename T1, typename T2, int BlkDim = 256, int ItemsPerThread = 4>
void gather(int device_idx, T1* out1, const T1* in1, T2* out2, const T2* in2,
const int* instId, int nVals) {
dh::launch_n<ItemsPerThread, BlkDim>(device_idx, nVals,
[=] __device__(int i) {
int iid = instId[i];
T1 v1 = in1[iid];
T2 v2 = in2[iid];
@ -273,10 +275,10 @@ void gather(int device_idx, T1* out1, const T1* in1, T2* out2, const T2* in2, co
* @param instId gather indices
* @param nVals length of the buffers
*/
template <typename T, int BlkDim=256, int ItemsPerThread=4>
template <typename T, int BlkDim = 256, int ItemsPerThread = 4>
void gather(int device_idx, T* out, const T* in, const int* instId, int nVals) {
dh::launch_n<ItemsPerThread,BlkDim>
(device_idx, nVals, [=] __device__(int i) {
dh::launch_n<ItemsPerThread, BlkDim>(device_idx, nVals,
[=] __device__(int i) {
int iid = instId[i];
out[i] = in[iid];
});

View File

@ -9,11 +9,11 @@
#include <algorithm>
#include <chrono>
#include <ctime>
#include <cub/cub.cuh>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>
#include <numeric>
#include <cub/cub.cuh>
#ifndef NCCL
#define NCCL 1
@ -126,6 +126,11 @@ inline std::string device_name(int device_idx) {
return std::string(prop.name);
}
// ensure gpu_id is correct, so not dependent upon user knowing details
inline int get_device_idx(int gpu_id) {
// protect against overrun for gpu_id
return (std::abs(gpu_id) + 0) % dh::n_visible_devices();
}
/*
* Timers
@ -309,11 +314,13 @@ enum memory_type { DEVICE, DEVICE_MANAGED };
template <memory_type MemoryT>
class bulk_allocator;
template <typename T> class dvec2;
template <typename T>
class dvec2;
template <typename T>
class dvec {
friend class dvec2<T>;
private:
T *_ptr;
size_t _size;
@ -327,9 +334,10 @@ class dvec {
_ptr = static_cast<T *>(ptr);
_size = size;
_device_idx = device_idx;
safe_cuda(cudaSetDevice(_device_idx));
}
dvec() : _ptr(NULL), _size(0), _device_idx(0) {}
dvec() : _ptr(NULL), _size(0), _device_idx(-1) {}
size_t size() const { return _size; }
int device_idx() const { return _device_idx; }
bool empty() const { return _ptr == NULL || _size == 0; }
@ -378,6 +386,10 @@ class dvec {
if (other.device_idx() == this->device_idx()) {
thrust::copy(other.tbegin(), other.tend(), this->tbegin());
} else {
std::cout << "deviceother: " << other.device_idx()
<< " devicethis: " << this->device_idx() << std::endl;
std::cout << "size deviceother: " << other.size()
<< " devicethis: " << this->device_idx() << std::endl;
throw std::runtime_error("Cannot copy to/from different devices");
}
@ -401,26 +413,24 @@ class dvec {
*/
template <typename T>
class dvec2 {
private:
dvec<T> _d1, _d2;
cub::DoubleBuffer<T> _buff;
int _device_idx;
public:
void external_allocate(int device_idx, void *ptr1, void *ptr2, size_t size) {
if (!empty()) {
throw std::runtime_error("Tried to allocate dvec2 but already allocated");
}
_device_idx = device_idx;
_d1.external_allocate(_device_idx, ptr1, size);
_d2.external_allocate(_device_idx, ptr2, size);
_buff.d_buffers[0] = static_cast<T *>(ptr1);
_buff.d_buffers[1] = static_cast<T *>(ptr2);
_buff.selector = 0;
_device_idx = device_idx;
}
dvec2() : _d1(), _d2(), _buff(), _device_idx(0) {}
dvec2() : _d1(), _d2(), _buff(), _device_idx(-1) {}
size_t size() const { return _d1.size(); }
int device_idx() const { return _device_idx; }
@ -433,7 +443,7 @@ class dvec2 {
T *current() { return _buff.Current(); }
dvec<T> &current_dvec() { return _buff.selector == 0? d1() : d2(); }
dvec<T> &current_dvec() { return _buff.selector == 0 ? d1() : d2(); }
T *other() { return _buff.Alternate(); }
};
@ -459,7 +469,8 @@ class bulk_allocator {
template <typename T, typename SizeT, typename... Args>
size_t get_size_bytes(dvec<T> *first_vec, SizeT first_size, Args... args) {
return get_size_bytes<T,SizeT>(first_vec, first_size) + get_size_bytes(args...);
return get_size_bytes<T, SizeT>(first_vec, first_size) +
get_size_bytes(args...);
}
template <typename T, typename SizeT>
@ -496,20 +507,23 @@ class bulk_allocator {
template <typename T, typename SizeT, typename... Args>
size_t get_size_bytes(dvec2<T> *first_vec, SizeT first_size, Args... args) {
return get_size_bytes<T,SizeT>(first_vec, first_size) + get_size_bytes(args...);
return get_size_bytes<T, SizeT>(first_vec, first_size) +
get_size_bytes(args...);
}
template <typename T, typename SizeT>
void allocate_dvec(int device_idx, char *ptr, dvec2<T> *first_vec, SizeT first_size) {
first_vec->external_allocate(device_idx, static_cast<void *>(ptr),
static_cast<void *>(ptr+align_round_up(first_size * sizeof(T))),
void allocate_dvec(int device_idx, char *ptr, dvec2<T> *first_vec,
SizeT first_size) {
first_vec->external_allocate(
device_idx, static_cast<void *>(ptr),
static_cast<void *>(ptr + align_round_up(first_size * sizeof(T))),
first_size);
}
template <typename T, typename SizeT, typename... Args>
void allocate_dvec(int device_idx, char *ptr, dvec2<T> *first_vec, SizeT first_size,
Args... args) {
allocate_dvec<T,SizeT>(device_idx, ptr, first_vec, first_size);
void allocate_dvec(int device_idx, char *ptr, dvec2<T> *first_vec,
SizeT first_size, Args... args) {
allocate_dvec<T, SizeT>(device_idx, ptr, first_vec, first_size);
ptr += (align_round_up(first_size * sizeof(T)) * 2);
allocate_dvec(device_idx, ptr, args...);
}
@ -711,6 +725,6 @@ struct BernoulliRng {
dh::Timer t1234; \
call; \
t1234.printElapsed(name); \
} while(0)
} while (0)
} // namespace dh

View File

@ -17,8 +17,8 @@
#include "../../../../src/tree/param.h"
#include "../common.cuh"
#include "node.cuh"
#include "loss_functions.cuh"
#include "node.cuh"
namespace xgboost {
namespace tree {
@ -45,7 +45,7 @@ HOST_DEV_INLINE Split maxSplit(Split a, Split b) {
out.index = b.index;
} else if (a.score == b.score) {
out.score = a.score;
out.index = (a.index < b.index)? a.index : b.index;
out.index = (a.index < b.index) ? a.index : b.index;
} else {
out.score = a.score;
out.index = a.index;
@ -54,7 +54,7 @@ HOST_DEV_INLINE Split maxSplit(Split a, Split b) {
}
DEV_INLINE void atomicArgMax(Split* address, Split val) {
unsigned long long* intAddress = (unsigned long long*) address;
unsigned long long* intAddress = (unsigned long long*)address;
unsigned long long old = *intAddress;
unsigned long long assumed;
do {
@ -65,23 +65,19 @@ DEV_INLINE void atomicArgMax(Split* address, Split val) {
}
template <typename node_id_t>
DEV_INLINE void argMaxWithAtomics(int id, Split* nodeSplits,
const gpu_gpair* gradScans,
const gpu_gpair* gradSums, const float* vals,
const int* colIds,
const node_id_t* nodeAssigns,
const Node<node_id_t>* nodes, int nUniqKeys,
node_id_t nodeStart, int len,
const TrainParam &param) {
DEV_INLINE void argMaxWithAtomics(
int id, Split* nodeSplits, const gpu_gpair* gradScans,
const gpu_gpair* gradSums, const float* vals, const int* colIds,
const node_id_t* nodeAssigns, const Node<node_id_t>* nodes, int nUniqKeys,
node_id_t nodeStart, int len, const TrainParam& param) {
int nodeId = nodeAssigns[id];
///@todo: this is really a bad check! but will be fixed when we move
/// to key-based reduction
if ((id == 0) || !((nodeId == nodeAssigns[id-1]) &&
(colIds[id] == colIds[id-1]) &&
(vals[id] == vals[id-1]))) {
if ((id == 0) ||
!((nodeId == nodeAssigns[id - 1]) && (colIds[id] == colIds[id - 1]) &&
(vals[id] == vals[id - 1]))) {
if (nodeId != UNUSED_NODE) {
int sumId = abs2uniqKey(id, nodeAssigns, colIds, nodeStart,
nUniqKeys);
int sumId = abs2uniqKey(id, nodeAssigns, colIds, nodeStart, nUniqKeys);
gpu_gpair colSum = gradSums[sumId];
int uid = nodeId - nodeStart;
Node<node_id_t> n = nodes[nodeId];
@ -90,22 +86,19 @@ DEV_INLINE void argMaxWithAtomics(int id, Split* nodeSplits,
bool tmp;
Split s;
gpu_gpair missing = parentSum - colSum;
s.score = loss_chg_missing(gradScans[id], missing, parentSum,
parentGain, param, tmp);
s.score = loss_chg_missing(gradScans[id], missing, parentSum, parentGain,
param, tmp);
s.index = id;
atomicArgMax(nodeSplits+uid, s);
atomicArgMax(nodeSplits + uid, s);
} // end if nodeId != UNUSED_NODE
} // end if id == 0 ...
}
template <typename node_id_t>
__global__ void atomicArgMaxByKeyGmem(Split* nodeSplits,
const gpu_gpair* gradScans,
const gpu_gpair* gradSums,
const float* vals, const int* colIds,
const node_id_t* nodeAssigns,
const Node<node_id_t>* nodes, int nUniqKeys,
node_id_t nodeStart, int len,
__global__ void atomicArgMaxByKeyGmem(
Split* nodeSplits, const gpu_gpair* gradScans, const gpu_gpair* gradSums,
const float* vals, const int* colIds, const node_id_t* nodeAssigns,
const Node<node_id_t>* nodes, int nUniqKeys, node_id_t nodeStart, int len,
const TrainParam param) {
int id = threadIdx.x + (blockIdx.x * blockDim.x);
const int stride = blockDim.x * gridDim.x;
@ -116,19 +109,16 @@ __global__ void atomicArgMaxByKeyGmem(Split* nodeSplits,
}
template <typename node_id_t>
__global__ void atomicArgMaxByKeySmem(Split* nodeSplits,
const gpu_gpair* gradScans,
const gpu_gpair* gradSums,
const float* vals, const int* colIds,
const node_id_t* nodeAssigns,
const Node<node_id_t>* nodes, int nUniqKeys,
node_id_t nodeStart, int len,
__global__ void atomicArgMaxByKeySmem(
Split* nodeSplits, const gpu_gpair* gradScans, const gpu_gpair* gradSums,
const float* vals, const int* colIds, const node_id_t* nodeAssigns,
const Node<node_id_t>* nodes, int nUniqKeys, node_id_t nodeStart, int len,
const TrainParam param) {
extern __shared__ char sArr[];
Split* sNodeSplits = (Split*)sArr;
int tid = threadIdx.x;
Split defVal;
#pragma unroll 1
#pragma unroll 1
for (int i = tid; i < nUniqKeys; i += blockDim.x) {
sNodeSplits[i] = defVal;
}
@ -142,7 +132,7 @@ __global__ void atomicArgMaxByKeySmem(Split* nodeSplits,
__syncthreads();
for (int i = tid; i < nUniqKeys; i += blockDim.x) {
Split s = sNodeSplits[i];
atomicArgMax(nodeSplits+i, s);
atomicArgMax(nodeSplits + i, s);
}
}
@ -162,24 +152,26 @@ __global__ void atomicArgMaxByKeySmem(Split* nodeSplits,
* @param param training parameters
* @param algo which algorithm to use for argmax_by_key
*/
template <typename node_id_t, int BLKDIM=256, int ITEMS_PER_THREAD=4>
template <typename node_id_t, int BLKDIM = 256, int ITEMS_PER_THREAD = 4>
void argMaxByKey(Split* nodeSplits, const gpu_gpair* gradScans,
const gpu_gpair* gradSums, const float* vals, const int* colIds,
const node_id_t* nodeAssigns, const Node<node_id_t>* nodes, int nUniqKeys,
const gpu_gpair* gradSums, const float* vals,
const int* colIds, const node_id_t* nodeAssigns,
const Node<node_id_t>* nodes, int nUniqKeys,
node_id_t nodeStart, int len, const TrainParam param,
ArgMaxByKeyAlgo algo) {
fillConst<Split,BLKDIM,ITEMS_PER_THREAD>(param.gpu_id, nodeSplits, nUniqKeys, Split());
int nBlks = dh::div_round_up(len, ITEMS_PER_THREAD*BLKDIM);
switch(algo) {
fillConst<Split, BLKDIM, ITEMS_PER_THREAD>(dh::get_device_idx(param.gpu_id),
nodeSplits, nUniqKeys, Split());
int nBlks = dh::div_round_up(len, ITEMS_PER_THREAD * BLKDIM);
switch (algo) {
case ABK_GMEM:
atomicArgMaxByKeyGmem<node_id_t><<<nBlks,BLKDIM>>>
(nodeSplits, gradScans, gradSums, vals, colIds, nodeAssigns, nodes,
atomicArgMaxByKeyGmem<node_id_t><<<nBlks, BLKDIM>>>(
nodeSplits, gradScans, gradSums, vals, colIds, nodeAssigns, nodes,
nUniqKeys, nodeStart, len, param);
break;
case ABK_SMEM:
atomicArgMaxByKeySmem<node_id_t>
<<<nBlks,BLKDIM,sizeof(Split)*nUniqKeys>>>
(nodeSplits, gradScans, gradSums, vals, colIds, nodeAssigns, nodes,
atomicArgMaxByKeySmem<
node_id_t><<<nBlks, BLKDIM, sizeof(Split) * nUniqKeys>>>(
nodeSplits, gradScans, gradSums, vals, colIds, nodeAssigns, nodes,
nUniqKeys, nodeStart, len, param);
break;
default:

View File

@ -18,7 +18,6 @@
#include "../common.cuh"
#include "gradients.cuh"
namespace xgboost {
namespace tree {
namespace exact {
@ -41,7 +40,7 @@ static const int NONE_KEY = -100;
* @param tmpKeys keys buffer
* @param size number of elements that will be scanned
*/
template <int BLKDIM_L1L3=256>
template <int BLKDIM_L1L3 = 256>
int scanTempBufferSize(int size) {
int nBlks = dh::div_round_up(size, BLKDIM_L1L3);
return nBlks;
@ -49,7 +48,7 @@ int scanTempBufferSize(int size) {
struct AddByKey {
template <typename T>
HOST_DEV_INLINE T operator()(const T &first, const T &second) const {
HOST_DEV_INLINE T operator()(const T& first, const T& second) const {
T result;
if (first.key == second.key) {
result.key = first.key;
@ -74,7 +73,7 @@ __global__ void cubScanByKeyL1(gpu_gpair* scans, const gpu_gpair* vals,
typedef cub::BlockScan<Pair, BLKDIM_L1L3> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
Pair threadData;
int tid = blockIdx.x*BLKDIM_L1L3 + threadIdx.x;
int tid = blockIdx.x * BLKDIM_L1L3 + threadIdx.x;
if (tid < size) {
myKey = abs2uniqKey(tid, keys, colIds, nodeStart, nUniqKeys);
myValue = get(tid, vals, instIds);
@ -90,17 +89,16 @@ __global__ void cubScanByKeyL1(gpu_gpair* scans, const gpu_gpair* vals,
// else, the result of this shuffle operation will be undefined
int previousKey = __shfl_up(myKey, 1);
// Collectively compute the block-wide exclusive prefix sum
BlockScan(temp_storage).ExclusiveScan(threadData, threadData, rootPair,
AddByKey());
BlockScan(temp_storage)
.ExclusiveScan(threadData, threadData, rootPair, AddByKey());
if (tid < size) {
scans[tid] = threadData.value;
} else {
return;
}
if (threadIdx.x == BLKDIM_L1L3 - 1) {
threadData.value = (myKey == previousKey)?
threadData.value :
gpu_gpair(0.0f, 0.0f);
threadData.value =
(myKey == previousKey) ? threadData.value : gpu_gpair(0.0f, 0.0f);
mKeys[blockIdx.x] = myKey;
mScans[blockIdx.x] = threadData.value + myValue;
}
@ -111,11 +109,10 @@ __global__ void cubScanByKeyL2(gpu_gpair* mScans, int* mKeys, int mLength) {
typedef cub::BlockScan<Pair, BLKSIZE, cub::BLOCK_SCAN_WARP_SCANS> BlockScan;
Pair threadData;
__shared__ typename BlockScan::TempStorage temp_storage;
for (int i = threadIdx.x; i < mLength; i += BLKSIZE-1) {
for (int i = threadIdx.x; i < mLength; i += BLKSIZE - 1) {
threadData.key = mKeys[i];
threadData.value = mScans[i];
BlockScan(temp_storage).InclusiveScan(threadData, threadData,
AddByKey());
BlockScan(temp_storage).InclusiveScan(threadData, threadData, AddByKey());
mScans[i] = threadData.value;
__syncthreads();
}
@ -136,15 +133,14 @@ __global__ void cubScanByKeyL3(gpu_gpair* sums, gpu_gpair* scans,
__shared__ char gradBuff[sizeof(gpu_gpair)];
__shared__ int s_mKeys;
gpu_gpair* s_mScans = (gpu_gpair*)gradBuff;
if(tid >= size)
return;
if (tid >= size) return;
// cache block-wide partial scan info
if (relId == 0) {
s_mKeys = (blockIdx.x > 0)? mKeys[blockIdx.x-1] : NONE_KEY;
s_mScans[0] = (blockIdx.x > 0)? mScans[blockIdx.x-1] : gpu_gpair();
s_mKeys = (blockIdx.x > 0) ? mKeys[blockIdx.x - 1] : NONE_KEY;
s_mScans[0] = (blockIdx.x > 0) ? mScans[blockIdx.x - 1] : gpu_gpair();
}
int myKey = abs2uniqKey(tid, keys, colIds, nodeStart, nUniqKeys);
int previousKey = tid == 0 ? NONE_KEY : abs2uniqKey(tid-1, keys, colIds,
int previousKey = tid == 0 ? NONE_KEY : abs2uniqKey(tid - 1, keys, colIds,
nodeStart, nUniqKeys);
gpu_gpair myValue = scans[tid];
__syncthreads();
@ -162,9 +158,11 @@ __global__ void cubScanByKeyL3(gpu_gpair* sums, gpu_gpair* scans,
}
/**
* @brief Performs fused reduce and scan by key functionality. It is assumed that
* @brief Performs fused reduce and scan by key functionality. It is assumed
* that
* the keys occur contiguously!
* @param sums the output gradient reductions for each element performed key-wise
* @param sums the output gradient reductions for each element performed
* key-wise
* @param scans the output gradient scans for each element performed key-wise
* @param vals the gradients evaluated for each observation.
* @param instIds instance ids for each element
@ -179,19 +177,19 @@ __global__ void cubScanByKeyL3(gpu_gpair* sums, gpu_gpair* scans,
* @param colIds column indices for each element in the array
* @param nodeStart index of the leftmost node in the current level
*/
template <typename node_id_t, int BLKDIM_L1L3=256, int BLKDIM_L2=512>
template <typename node_id_t, int BLKDIM_L1L3 = 256, int BLKDIM_L2 = 512>
void reduceScanByKey(gpu_gpair* sums, gpu_gpair* scans, const gpu_gpair* vals,
const int* instIds, const node_id_t* keys, int size,
int nUniqKeys, int nCols, gpu_gpair* tmpScans,
int* tmpKeys, const int* colIds, node_id_t nodeStart) {
int nBlks = dh::div_round_up(size, BLKDIM_L1L3);
cudaMemset(sums, 0, nUniqKeys*nCols*sizeof(gpu_gpair));
cubScanByKeyL1<node_id_t,BLKDIM_L1L3><<<nBlks, BLKDIM_L1L3>>>
(scans, vals, instIds, tmpScans, tmpKeys, keys, nUniqKeys, colIds,
cudaMemset(sums, 0, nUniqKeys * nCols * sizeof(gpu_gpair));
cubScanByKeyL1<node_id_t, BLKDIM_L1L3><<<nBlks, BLKDIM_L1L3>>>(
scans, vals, instIds, tmpScans, tmpKeys, keys, nUniqKeys, colIds,
nodeStart, size);
cubScanByKeyL2<BLKDIM_L2><<<1, BLKDIM_L2>>>(tmpScans, tmpKeys, nBlks);
cubScanByKeyL3<node_id_t,BLKDIM_L1L3><<<nBlks, BLKDIM_L1L3>>>
(sums, scans, vals, instIds, tmpScans, tmpKeys, keys, nUniqKeys, colIds,
cubScanByKeyL3<node_id_t, BLKDIM_L1L3><<<nBlks, BLKDIM_L1L3>>>(
sums, scans, vals, instIds, tmpScans, tmpKeys, keys, nUniqKeys, colIds,
nodeStart, size);
}

View File

@ -1,5 +1,6 @@
/*
* Copyright (c) 2017, NVIDIA CORPORATION, Xgboost contributors. All rights reserved.
* Copyright (c) 2017, NVIDIA CORPORATION, Xgboost contributors. All rights
* reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,18 +16,17 @@
*/
#pragma once
#include "../../../../src/tree/param.h"
#include "xgboost/tree_updater.h"
#include "cub/cub.cuh"
#include "../common.cuh"
#include <vector>
#include "loss_functions.cuh"
#include "gradients.cuh"
#include "node.cuh"
#include "../../../../src/tree/param.h"
#include "../common.cuh"
#include "argmax_by_key.cuh"
#include "split2node.cuh"
#include "cub/cub.cuh"
#include "fused_scan_reduce_by_key.cuh"
#include "gradients.cuh"
#include "loss_functions.cuh"
#include "node.cuh"
#include "split2node.cuh"
#include "xgboost/tree_updater.h"
namespace xgboost {
namespace tree {
@ -48,8 +48,8 @@ template <typename node_id_t>
__global__ void assignColIds(int* colIds, const int* colOffsets) {
int myId = blockIdx.x;
int start = colOffsets[myId];
int end = colOffsets[myId+1];
for (int id = start+threadIdx.x; id < end; id += blockDim.x) {
int end = colOffsets[myId + 1];
for (int id = start + threadIdx.x; id < end; id += blockDim.x) {
colIds[id] = myId;
}
}
@ -70,7 +70,7 @@ __global__ void fillDefaultNodeIds(node_id_t* nodeIdsPerInst,
node_id_t result;
if (n.isLeaf() || n.isUnused()) {
result = UNUSED_NODE;
} else if(n.isDefaultLeft()) {
} else if (n.isDefaultLeft()) {
result = (2 * n.id) + 1;
} else {
result = (2 * n.id) + 2;
@ -81,8 +81,9 @@ __global__ void fillDefaultNodeIds(node_id_t* nodeIdsPerInst,
template <typename node_id_t>
__global__ void assignNodeIds(node_id_t* nodeIdsPerInst, int* nodeLocations,
const node_id_t* nodeIds, const int* instId,
const Node<node_id_t>* nodes, const int* colOffsets,
const float* vals, int nVals, int nCols) {
const Node<node_id_t>* nodes,
const int* colOffsets, const float* vals,
int nVals, int nCols) {
int id = threadIdx.x + (blockIdx.x * blockDim.x);
const int stride = blockDim.x * gridDim.x;
for (; id < nVals; id += stride) {
@ -95,7 +96,7 @@ __global__ void assignNodeIds(node_id_t* nodeIdsPerInst, int* nodeLocations,
if (nId != UNUSED_NODE) {
const Node<node_id_t> n = nodes[nId];
int colId = n.colIdx;
//printf("nid=%d colId=%d id=%d\n", nId, colId, id);
// printf("nid=%d colId=%d id=%d\n", nId, colId, id);
int start = colOffsets[colId];
int end = colOffsets[colId + 1];
///@todo: too much wasteful threads!!
@ -122,12 +123,18 @@ __global__ void markLeavesKernel(Node<node_id_t>* nodes, int len) {
}
// unit test forward declaration for friend function access
template <typename node_id_t> void testSmallData();
template <typename node_id_t> void testLargeData();
template <typename node_id_t> void testAllocate();
template <typename node_id_t> void testMarkLeaves();
template <typename node_id_t> void testDense2Sparse();
template <typename node_id_t> class GPUBuilder;
template <typename node_id_t>
void testSmallData();
template <typename node_id_t>
void testLargeData();
template <typename node_id_t>
void testAllocate();
template <typename node_id_t>
void testMarkLeaves();
template <typename node_id_t>
void testDense2Sparse();
template <typename node_id_t>
class GPUBuilder;
template <typename node_id_t>
std::shared_ptr<xgboost::DMatrix> setupGPUBuilder(
const std::string& file,
@ -136,7 +143,7 @@ std::shared_ptr<xgboost::DMatrix> setupGPUBuilder(
template <typename node_id_t>
class GPUBuilder {
public:
GPUBuilder(): allocated(false) {}
GPUBuilder() : allocated(false) {}
~GPUBuilder() {}
@ -146,10 +153,10 @@ class GPUBuilder {
maxLeaves = 1 << param.max_depth;
}
void UpdateParam(const TrainParam &param) { this->param = param; }
void UpdateParam(const TrainParam& param) { this->param = param; }
/// @note: Update should be only after Init!!
void Update(const std::vector<bst_gpair>& gpair, DMatrix *hMat,
void Update(const std::vector<bst_gpair>& gpair, DMatrix* hMat,
RegTree* hTree) {
if (!allocated) {
setupOneTimeData(*hMat);
@ -171,7 +178,7 @@ class GPUBuilder {
dense2sparse(*hTree);
}
private:
private:
friend void testSmallData<node_id_t>();
friend void testLargeData<node_id_t>();
friend void testAllocate<node_id_t>();
@ -194,7 +201,7 @@ private:
dh::dvec<gpu_gpair> gradsInst;
dh::dvec2<node_id_t> nodeAssigns;
dh::dvec2<int> nodeLocations;
dh::dvec<Node<node_id_t> > nodes;
dh::dvec<Node<node_id_t>> nodes;
dh::dvec<node_id_t> nodeAssignsPerInst;
dh::dvec<gpu_gpair> gradSums;
dh::dvec<gpu_gpair> gradScans;
@ -218,35 +225,26 @@ private:
argMaxByKey(nodeSplits.data(), gradScans.data(), gradSums.data(),
vals.current(), colIds.data(), nodeAssigns.current(),
nodes.data(), nNodes, nodeStart, nVals, param,
level<=MAX_ABK_LEVELS? ABK_SMEM : ABK_GMEM);
level <= MAX_ABK_LEVELS ? ABK_SMEM : ABK_GMEM);
split2node(nodes.data(), nodeSplits.data(), gradScans.data(),
gradSums.data(), vals.current(), colIds.data(), colOffsets.data(),
nodeAssigns.current(), nNodes, nodeStart, nCols, param);
gradSums.data(), vals.current(), colIds.data(),
colOffsets.data(), nodeAssigns.current(), nNodes, nodeStart,
nCols, param);
}
void allocateAllData(int offsetSize) {
int tmpBuffSize = scanTempBufferSize(nVals);
ba.allocate(param.gpu_id,
&vals, nVals,
&vals_cached, nVals,
&instIds, nVals,
&instIds_cached, nVals,
&colOffsets, offsetSize,
&gradsInst, nRows,
&nodeAssigns, nVals,
&nodeLocations, nVals,
&nodes, maxNodes,
&nodeAssignsPerInst, nRows,
&gradSums, maxLeaves*nCols,
&gradScans, nVals,
&nodeSplits, maxLeaves,
&tmpScanGradBuff, tmpBuffSize,
&tmpScanKeyBuff, tmpBuffSize,
&colIds, nVals);
ba.allocate(dh::get_device_idx(param.gpu_id), &vals, nVals, &vals_cached,
nVals, &instIds, nVals, &instIds_cached, nVals, &colOffsets,
offsetSize, &gradsInst, nRows, &nodeAssigns, nVals,
&nodeLocations, nVals, &nodes, maxNodes, &nodeAssignsPerInst,
nRows, &gradSums, maxLeaves * nCols, &gradScans, nVals,
&nodeSplits, maxLeaves, &tmpScanGradBuff, tmpBuffSize,
&tmpScanKeyBuff, tmpBuffSize, &colIds, nVals);
}
void setupOneTimeData(DMatrix& hMat) {
size_t free_memory = dh::available_memory(param.gpu_id);
size_t free_memory = dh::available_memory(dh::get_device_idx(param.gpu_id));
if (!hMat.SingleColBlock()) {
throw std::runtime_error("exact::GPUBuilder - must have 1 column block");
}
@ -259,7 +257,8 @@ private:
if (!param.silent) {
const int mb_size = 1048576;
LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << "/"
<< free_memory / mb_size << " MB on " << dh::device_name(param.gpu_id);
<< free_memory / mb_size << " MB on "
<< dh::device_name(dh::get_device_idx(param.gpu_id));
}
}
@ -282,9 +281,10 @@ private:
iter->BeforeFirst();
while (iter->Next()) {
const ColBatch& batch = iter->Value();
for (int i=0;i<batch.size;i++) {
for (int i = 0; i < batch.size; i++) {
const ColBatch::Inst& col = batch[i];
for (const ColBatch::Entry* it=col.data;it!=col.data+col.length;it++) {
for (const ColBatch::Entry* it = col.data; it != col.data + col.length;
it++) {
int inst_id = static_cast<int>(it->index);
fval.push_back(it->fvalue);
fId.push_back(inst_id);
@ -301,16 +301,17 @@ private:
vals.current_dvec() = fval;
instIds.current_dvec() = fId;
colOffsets = offset;
segmentedSort<float,int>(tmp_mem, vals, instIds, nVals, nCols, colOffsets);
segmentedSort<float, int>(tmp_mem, vals, instIds, nVals, nCols, colOffsets);
vals_cached = vals.current_dvec();
instIds_cached = instIds.current_dvec();
assignColIds<node_id_t><<<nCols,512>>>(colIds.data(), colOffsets.data());
assignColIds<node_id_t><<<nCols, 512>>>(colIds.data(), colOffsets.data());
}
void transferGrads(const std::vector<bst_gpair>& gpair) {
// HACK
dh::safe_cuda(cudaMemcpy(gradsInst.data(), &(gpair[0]),
sizeof(gpu_gpair)*nRows, cudaMemcpyHostToDevice));
sizeof(gpu_gpair) * nRows,
cudaMemcpyHostToDevice));
// evaluate the full-grad reduction for the root node
sumReduction<gpu_gpair>(tmp_mem, gradsInst, gradSums, nRows);
}
@ -324,25 +325,23 @@ private:
// for root node, just update the gradient/score/weight/id info
// before splitting it! Currently all data is on GPU, hence this
// stupid little kernel
initRootNode<<<1,1>>>(nodes.data(), gradSums.data(), param);
initRootNode<<<1, 1>>>(nodes.data(), gradSums.data(), param);
} else {
const int BlkDim = 256;
const int ItemsPerThread = 4;
// assign default node ids first
int nBlks = dh::div_round_up(nRows, BlkDim);
fillDefaultNodeIds<<<nBlks,BlkDim>>>(nodeAssignsPerInst.data(),
fillDefaultNodeIds<<<nBlks, BlkDim>>>(nodeAssignsPerInst.data(),
nodes.data(), nRows);
// evaluate the correct child indices of non-missing values next
nBlks = dh::div_round_up(nVals, BlkDim*ItemsPerThread);
assignNodeIds<<<nBlks,BlkDim>>>(nodeAssignsPerInst.data(),
nodeLocations.current(),
nodeAssigns.current(),
instIds.current(), nodes.data(),
colOffsets.data(), vals.current(),
nVals, nCols);
nBlks = dh::div_round_up(nVals, BlkDim * ItemsPerThread);
assignNodeIds<<<nBlks, BlkDim>>>(
nodeAssignsPerInst.data(), nodeLocations.current(),
nodeAssigns.current(), instIds.current(), nodes.data(),
colOffsets.data(), vals.current(), nVals, nCols);
// gather the node assignments across all other columns too
gather<node_id_t>(param.gpu_id, nodeAssigns.current(), nodeAssignsPerInst.data(),
instIds.current(), nVals);
gather<node_id_t>(dh::get_device_idx(param.gpu_id), nodeAssigns.current(),
nodeAssignsPerInst.data(), instIds.current(), nVals);
sortKeys(level);
}
}
@ -351,9 +350,10 @@ private:
// segmented-sort the arrays based on node-id's
// but we don't need more than level+1 bits for sorting!
segmentedSort(tmp_mem, nodeAssigns, nodeLocations, nVals, nCols, colOffsets,
0, level+1);
gather<float,int>(param.gpu_id, vals.other(), vals.current(), instIds.other(),
instIds.current(), nodeLocations.current(), nVals);
0, level + 1);
gather<float, int>(dh::get_device_idx(param.gpu_id), vals.other(),
vals.current(), instIds.other(), instIds.current(),
nodeLocations.current(), nVals);
vals.buff().selector ^= 1;
instIds.buff().selector ^= 1;
}
@ -361,11 +361,11 @@ private:
void markLeaves() {
const int BlkDim = 128;
int nBlks = dh::div_round_up(maxNodes, BlkDim);
markLeavesKernel<<<nBlks,BlkDim>>>(nodes.data(), maxNodes);
markLeavesKernel<<<nBlks, BlkDim>>>(nodes.data(), maxNodes);
}
void dense2sparse(RegTree &tree) {
std::vector<Node<node_id_t> > hNodes = nodes.as_vector();
void dense2sparse(RegTree& tree) {
std::vector<Node<node_id_t>> hNodes = nodes.as_vector();
int nodeId = 0;
for (int i = 0; i < maxNodes; ++i) {
const Node<node_id_t>& n = hNodes[i];
@ -375,7 +375,7 @@ private:
++nodeId;
} else if (!hNodes[i].isUnused()) {
tree.AddChilds(nodeId);
tree[nodeId].set_split(n.colIdx, n.threshold, n.dir==LeftDir);
tree[nodeId].set_split(n.colIdx, n.threshold, n.dir == LeftDir);
tree.stat(nodeId).loss_chg = n.score;
tree.stat(nodeId).sum_hess = n.gradSum.h;
tree.stat(nodeId).base_weight = n.weight;

View File

@ -1,5 +1,6 @@
/*
* Copyright (c) 2017, NVIDIA CORPORATION, Xgboost contributors. All rights reserved.
* Copyright (c) 2017, NVIDIA CORPORATION, Xgboost contributors. All rights
* reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,7 +18,6 @@
#include "../common.cuh"
namespace xgboost {
namespace tree {
namespace exact {
@ -32,9 +32,9 @@ struct gpu_gpair {
/** the 'h_i' as it appears in the xgboost paper */
float h;
HOST_DEV_INLINE gpu_gpair(): g(0.f), h(0.f) {}
HOST_DEV_INLINE gpu_gpair(const float& _g, const float& _h): g(_g), h(_h) {}
HOST_DEV_INLINE gpu_gpair(const gpu_gpair& a): g(a.g), h(a.h) {}
HOST_DEV_INLINE gpu_gpair() : g(0.f), h(0.f) {}
HOST_DEV_INLINE gpu_gpair(const float& _g, const float& _h) : g(_g), h(_h) {}
HOST_DEV_INLINE gpu_gpair(const gpu_gpair& a) : g(a.g), h(a.h) {}
/**
* @brief Checks whether the hessian is more than the defined weight
@ -60,12 +60,12 @@ struct gpu_gpair {
HOST_DEV_INLINE friend gpu_gpair operator+(const gpu_gpair& a,
const gpu_gpair& b) {
return gpu_gpair(a.g+b.g, a.h+b.h);
return gpu_gpair(a.g + b.g, a.h + b.h);
}
HOST_DEV_INLINE friend gpu_gpair operator-(const gpu_gpair& a,
const gpu_gpair& b) {
return gpu_gpair(a.g-b.g, a.h-b.h);
return gpu_gpair(a.g - b.g, a.h - b.h);
}
HOST_DEV_INLINE gpu_gpair(int value) {
@ -73,7 +73,6 @@ struct gpu_gpair {
}
};
/**
* @brief Gradient value getter function
* @param id the index into the vals or instIds array to which to fetch
@ -81,7 +80,8 @@ struct gpu_gpair {
* @param instIds instance index buffer
* @return the expected gradient value
*/
HOST_DEV_INLINE gpu_gpair get(int id, const gpu_gpair* vals, const int* instIds) {
HOST_DEV_INLINE gpu_gpair get(int id, const gpu_gpair* vals,
const int* instIds) {
id = instIds[id];
return vals[id];
}

View File

@ -1,5 +1,6 @@
/*
* Copyright (c) 2017, NVIDIA CORPORATION, Xgboost contributors. All rights reserved.
* Copyright (c) 2017, NVIDIA CORPORATION, Xgboost contributors. All rights
* reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,17 +19,13 @@
#include "../common.cuh"
#include "gradients.cuh"
namespace xgboost {
namespace tree {
namespace exact {
HOST_DEV_INLINE float device_calc_loss_chg(const TrainParam &param,
const gpu_gpair &scan,
const gpu_gpair &missing,
const gpu_gpair &parent_sum,
const float &parent_gain,
bool missing_left) {
HOST_DEV_INLINE float device_calc_loss_chg(
const TrainParam &param, const gpu_gpair &scan, const gpu_gpair &missing,
const gpu_gpair &parent_sum, const float &parent_gain, bool missing_left) {
gpu_gpair left = scan;
if (missing_left) {
left += missing;

View File

@ -1,5 +1,6 @@
/*
* Copyright (c) 2017, NVIDIA CORPORATION, Xgboost contributors. All rights reserved.
* Copyright (c) 2017, NVIDIA CORPORATION, Xgboost contributors. All rights
* reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,9 +16,8 @@
*/
#pragma once
#include "gradients.cuh"
#include "../common.cuh"
#include "gradients.cuh"
namespace xgboost {
namespace tree {
@ -34,11 +34,9 @@ enum DefaultDirection {
RightDir
};
/** used to assign default id to a Node */
static const int UNUSED_NODE = -1;
/**
* @struct Split node.cuh
* @brief Abstraction of a possible split in the decision tree
@ -49,7 +47,7 @@ struct Split {
/** index where to split in the DMatrix */
int index;
HOST_DEV_INLINE Split(): score(-FLT_MAX), index(INT_MAX) {}
HOST_DEV_INLINE Split() : score(-FLT_MAX), index(INT_MAX) {}
/**
* @brief Whether the split info is valid to be used to create a new child
@ -61,7 +59,6 @@ struct Split {
}
};
/**
* @struct Node node.cuh
* @brief Abstraction of a node in the decision tree
@ -84,8 +81,13 @@ class Node {
/** node id (used as key for reduce/scan) */
node_id_t id;
HOST_DEV_INLINE Node(): gradSum(), score(-FLT_MAX), weight(-FLT_MAX),
dir(LeftDir), threshold(0.f), colIdx(UNUSED_NODE),
HOST_DEV_INLINE Node()
: gradSum(),
score(-FLT_MAX),
weight(-FLT_MAX),
dir(LeftDir),
threshold(0.f),
colIdx(UNUSED_NODE),
id(UNUSED_NODE) {}
/** Tells whether this node is part of the decision tree */
@ -100,7 +102,6 @@ class Node {
HOST_DEV_INLINE bool isDefaultLeft() const { return (dir == LeftDir); }
};
/**
* @struct Segment node.cuh
* @brief Space inefficient, but super easy to implement structure to define
@ -112,7 +113,7 @@ struct Segment {
/** end index of the segment */
int end;
HOST_DEV_INLINE Segment(): start(-1), end(-1) {}
HOST_DEV_INLINE Segment() : start(-1), end(-1) {}
/** Checks whether the current structure defines a valid segment */
HOST_DEV_INLINE bool isValid() const {
@ -120,7 +121,6 @@ struct Segment {
}
};
/**
* @enum NodeType node.cuh
* @brief Useful to decribe the node type in a dense BFS-order tree array
@ -134,7 +134,6 @@ enum NodeType {
UNUSED
};
/**
* @brief Absolute BFS order IDs to col-wise unique IDs based on user input
* @param tid the index of the element that this thread should access

View File

@ -17,9 +17,8 @@
#include "../../../../src/tree/param.h"
#include "gradients.cuh"
#include "node.cuh"
#include "loss_functions.cuh"
#include "node.cuh"
namespace xgboost {
namespace tree {
@ -39,7 +38,7 @@ namespace exact {
template <typename node_id_t>
DEV_INLINE void updateOneChildNode(Node<node_id_t>* nodes, int nid,
const gpu_gpair& grad,
const TrainParam &param) {
const TrainParam& param) {
nodes[nid].gradSum = grad;
nodes[nid].score = CalcGain(param, grad.g, grad.h);
nodes[nid].weight = CalcWeight(param, grad.g, grad.h);
@ -58,18 +57,18 @@ DEV_INLINE void updateOneChildNode(Node<node_id_t>* nodes, int nid,
template <typename node_id_t>
DEV_INLINE void updateChildNodes(Node<node_id_t>* nodes, int pid,
const gpu_gpair& gradL, const gpu_gpair& gradR,
const TrainParam &param) {
const TrainParam& param) {
int childId = (pid * 2) + 1;
updateOneChildNode(nodes, childId, gradL, param);
updateOneChildNode(nodes, childId+1, gradR, param);
updateOneChildNode(nodes, childId + 1, gradR, param);
}
template <typename node_id_t>
DEV_INLINE void updateNodeAndChildren(Node<node_id_t>* nodes, const Split& s,
const Node<node_id_t>& n, int absNodeId, int colId,
const gpu_gpair& gradScan,
const Node<node_id_t>& n, int absNodeId,
int colId, const gpu_gpair& gradScan,
const gpu_gpair& colSum, float thresh,
const TrainParam &param) {
const TrainParam& param) {
bool missingLeft = true;
// get the default direction for the current node
gpu_gpair missing = n.gradSum - colSum;
@ -84,19 +83,17 @@ DEV_INLINE void updateNodeAndChildren(Node<node_id_t>* nodes, const Split& s,
rGradSum = n.gradSum - lGradSum;
updateChildNodes(nodes, absNodeId, lGradSum, rGradSum, param);
// update default-dir, threshold and feature id for current node
nodes[absNodeId].dir = missingLeft? LeftDir : RightDir;
nodes[absNodeId].dir = missingLeft ? LeftDir : RightDir;
nodes[absNodeId].colIdx = colId;
nodes[absNodeId].threshold = thresh;
}
template <typename node_id_t, int BLKDIM=256>
__global__ void split2nodeKernel(Node<node_id_t>* nodes, const Split* nodeSplits,
const gpu_gpair* gradScans,
const gpu_gpair* gradSums, const float* vals,
const int* colIds, const int* colOffsets,
const node_id_t* nodeAssigns, int nUniqKeys,
node_id_t nodeStart, int nCols,
const TrainParam param) {
template <typename node_id_t, int BLKDIM = 256>
__global__ void split2nodeKernel(
Node<node_id_t>* nodes, const Split* nodeSplits, const gpu_gpair* gradScans,
const gpu_gpair* gradSums, const float* vals, const int* colIds,
const int* colOffsets, const node_id_t* nodeAssigns, int nUniqKeys,
node_id_t nodeStart, int nCols, const TrainParam param) {
int uid = (blockIdx.x * blockDim.x) + threadIdx.x;
if (uid >= nUniqKeys) {
return;
@ -105,11 +102,11 @@ __global__ void split2nodeKernel(Node<node_id_t>* nodes, const Split* nodeSplits
Split s = nodeSplits[uid];
if (s.isSplittable(param.min_split_loss)) {
int idx = s.index;
int nodeInstId = abs2uniqKey(idx, nodeAssigns, colIds, nodeStart,
nUniqKeys);
updateNodeAndChildren(nodes, s, nodes[absNodeId], absNodeId,
colIds[idx], gradScans[idx],
gradSums[nodeInstId], vals[idx], param);
int nodeInstId =
abs2uniqKey(idx, nodeAssigns, colIds, nodeStart, nUniqKeys);
updateNodeAndChildren(nodes, s, nodes[absNodeId], absNodeId, colIds[idx],
gradScans[idx], gradSums[nodeInstId], vals[idx],
param);
} else {
// cannot be split further, so this node is a leaf!
nodes[absNodeId].score = -FLT_MAX;
@ -129,20 +126,20 @@ __global__ void split2nodeKernel(Node<node_id_t>* nodes, const Split* nodeSplits
* @param nUniqKeys number of nodes that we are currently working on
* @param nodeStart start offset of the nodes in the overall BFS tree
* @param nCols number of columns
* @param preUniquifiedKeys whether to uniquify the keys from inside kernel or not
* @param preUniquifiedKeys whether to uniquify the keys from inside kernel or
* not
* @param param the training parameter struct
*/
template <typename node_id_t, int BLKDIM=256>
void split2node(Node<node_id_t>* nodes, const Split* nodeSplits, const gpu_gpair* gradScans,
const gpu_gpair* gradSums, const float* vals, const int* colIds,
const int* colOffsets, const node_id_t* nodeAssigns,
int nUniqKeys, node_id_t nodeStart, int nCols,
const TrainParam param) {
template <typename node_id_t, int BLKDIM = 256>
void split2node(Node<node_id_t>* nodes, const Split* nodeSplits,
const gpu_gpair* gradScans, const gpu_gpair* gradSums,
const float* vals, const int* colIds, const int* colOffsets,
const node_id_t* nodeAssigns, int nUniqKeys,
node_id_t nodeStart, int nCols, const TrainParam param) {
int nBlks = dh::div_round_up(nUniqKeys, BLKDIM);
split2nodeKernel<<<nBlks,BLKDIM>>>(nodes, nodeSplits, gradScans, gradSums,
split2nodeKernel<<<nBlks, BLKDIM>>>(nodes, nodeSplits, gradScans, gradSums,
vals, colIds, colOffsets, nodeAssigns,
nUniqKeys, nodeStart, nCols,
param);
nUniqKeys, nodeStart, nCols, param);
}
} // namespace exact

View File

@ -73,11 +73,12 @@ struct GPUData {
n_features, foffsets.data(), foffsets.data() + 1);
// Allocate memory
size_t free_memory = dh::available_memory(param_in.gpu_id);
ba.allocate(param_in.gpu_id,
&fvalues, in_fvalues.size(), &fvalues_temp,
in_fvalues.size(), &fvalues_cached, in_fvalues.size(), &foffsets,
in_foffsets.size(), &instance_id, in_instance_id.size(),
size_t free_memory =
dh::available_memory(dh::get_device_idx(param_in.gpu_id));
ba.allocate(
dh::get_device_idx(param_in.gpu_id), &fvalues, in_fvalues.size(),
&fvalues_temp, in_fvalues.size(), &fvalues_cached, in_fvalues.size(),
&foffsets, in_foffsets.size(), &instance_id, in_instance_id.size(),
&instance_id_temp, in_instance_id.size(), &instance_id_cached,
in_instance_id.size(), &feature_id, in_feature_id.size(), &node_id,
in_fvalues.size(), &node_id_temp, in_fvalues.size(), &node_id_instance,
@ -91,7 +92,7 @@ struct GPUData {
const int mb_size = 1048576;
LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << "/"
<< free_memory / mb_size << " MB on "
<< dh::device_name(param_in.gpu_id);
<< dh::device_name(dh::get_device_idx(param_in.gpu_id));
}
fvalues_cached = in_fvalues;

View File

@ -125,7 +125,7 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
// set dList member
dList.resize(n_devices);
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
int device_idx = (param.gpu_id + d_idx) % n_devices;
int device_idx = (param.gpu_id + d_idx) % dh::n_visible_devices();
dList[d_idx] = device_idx;
}
@ -141,7 +141,8 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
// printf("# NCCL: Using devices\n");
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
streams[d_idx] = reinterpret_cast<cudaStream_t*>(malloc(sizeof(cudaStream_t)));
streams[d_idx] =
reinterpret_cast<cudaStream_t*>(malloc(sizeof(cudaStream_t)));
dh::safe_cuda(cudaSetDevice(dList[d_idx]));
dh::safe_cuda(cudaStreamCreate(streams[d_idx]));
@ -159,7 +160,8 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
// local find_split group of comms for each case of reduced number of GPUs
// to use
find_split_comms.resize(
n_devices, std::vector<ncclComm_t>(n_devices)); // TODO(JCM): Excessive, but
n_devices,
std::vector<ncclComm_t>(n_devices)); // TODO(JCM): Excessive, but
// ok, and best to do
// here instead of
// repeatedly
@ -377,7 +379,8 @@ void GPUHistBuilder::BuildHist(int depth) {
#if (NCCL)
// (in-place) reduce each element of histogram (for only current level) across
// multiple gpus
// TODO(JCM): use out of place with pre-allocated buffer, but then have to copy
// TODO(JCM): use out of place with pre-allocated buffer, but then have to
// copy
// back on device
// fprintf(stderr,"sizeof(gpu_gpair)/sizeof(float)=%d\n",sizeof(gpu_gpair)/sizeof(float));
for (int d_idx = 0; d_idx < n_devices; d_idx++) {
@ -621,17 +624,28 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
bool colsample =
param.colsample_bylevel < 1.0 || param.colsample_bytree < 1.0;
int dosimuljob = 1;
#if (NCCL)
int simuljob = 1; // whether to do job on single GPU and broadcast (0) or to
// do same job on each GPU (1) (could make user parameter,
// but too fine-grained maybe)
int findsplit_shardongpus = 0; // too expensive generally, disable for now
if (NCCL && findsplit_shardongpus) {
dosimuljob = 0;
// use power of 2 for split finder because nodes are power of 2 (broadcast
// result to remaining devices)
int find_split_n_devices = std::pow(2, std::floor(std::log2(n_devices)));
find_split_n_devices = std::min(n_nodes_level(depth), find_split_n_devices);
int num_nodes_device = n_nodes_level(depth) / find_split_n_devices;
int num_nodes_child_device = n_nodes_level(depth + 1) / find_split_n_devices;
int num_nodes_child_device =
n_nodes_level(depth + 1) / find_split_n_devices;
const int GRID_SIZE = num_nodes_device;
#if (NCCL)
// NOTE: No need to scatter before gather as all devices have same copy of
// nodes, and within find_split_kernel() nodes_temp is given values from nodes
// nodes, and within find_split_kernel() nodes_temp is given values from
// nodes
// for all nodes (split among devices) find best split per node
for (int d_idx = 0; d_idx < find_split_n_devices; d_idx++) {
@ -642,10 +656,9 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
(const gpu_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
feature_segments[d_idx].data(), depth, (info->num_col),
(hmat_.row_ptr.back()), nodes[d_idx].data(),
nodes_temp[d_idx].data(), nodes_child_temp[d_idx].data(),
nodes_offset_device, fidx_min_map[d_idx].data(),
gidx_fvalue_map[d_idx].data(), gpu_param,
(hmat_.row_ptr.back()), nodes[d_idx].data(), nodes_temp[d_idx].data(),
nodes_child_temp[d_idx].data(), nodes_offset_device,
fidx_min_map[d_idx].data(), gidx_fvalue_map[d_idx].data(), gpu_param,
left_child_smallest_temp[d_idx].data(), colsample,
feature_flags[d_idx].data());
}
@ -661,14 +674,15 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
reinterpret_cast<const void*>(nodes_temp[d_idx].data()),
num_nodes_device * sizeof(Node) / sizeof(char), ncclChar,
reinterpret_cast<void*>(nodes[d_idx].data() + n_nodes(depth - 1)),
find_split_comms[find_split_n_devices - 1][d_idx], *(streams[d_idx])));
find_split_comms[find_split_n_devices - 1][d_idx],
*(streams[d_idx])));
if (depth !=
param.max_depth) { // don't copy over children nodes if no more nodes
dh::safe_nccl(
ncclAllGather(reinterpret_cast<const void*>(nodes_child_temp[d_idx].data()),
num_nodes_child_device * sizeof(Node) / sizeof(char),
ncclChar, reinterpret_cast<void*>(nodes[d_idx].data() + n_nodes(depth)),
dh::safe_nccl(ncclAllGather(
reinterpret_cast<const void*>(nodes_child_temp[d_idx].data()),
num_nodes_child_device * sizeof(Node) / sizeof(char), ncclChar,
reinterpret_cast<void*>(nodes[d_idx].data() + n_nodes(depth)),
find_split_comms[find_split_n_devices - 1][d_idx],
*(streams[d_idx]))); // Note offset by n_nodes(depth)
// for recvbuff for child nodes
@ -677,8 +691,10 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
dh::safe_nccl(ncclAllGather(
reinterpret_cast<const void*>(left_child_smallest_temp[d_idx].data()),
num_nodes_device * sizeof(bool) / sizeof(char), ncclChar,
reinterpret_cast<void*>(left_child_smallest[d_idx].data() + n_nodes(depth - 1)),
find_split_comms[find_split_n_devices - 1][d_idx], *(streams[d_idx])));
reinterpret_cast<void*>(left_child_smallest[d_idx].data() +
n_nodes(depth - 1)),
find_split_comms[find_split_n_devices - 1][d_idx],
*(streams[d_idx])));
}
for (int d_idx = 0; d_idx < find_split_n_devices; d_idx++) {
@ -689,20 +705,21 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
if (n_devices > find_split_n_devices && n_devices > 1) {
// if n_devices==1, no need to Bcast
// if find_split_n_devices==1, this is just a copy operation, else it copies
// if find_split_n_devices==1, this is just a copy operation, else it
// copies
// from master to all nodes in case extra devices not involved in split
for (int d_idx = 0; d_idx < n_devices; d_idx++) {
int device_idx = dList[d_idx];
dh::safe_cuda(cudaSetDevice(device_idx));
int master_device = dList[0];
dh::safe_nccl(
ncclBcast(reinterpret_cast<void*>(nodes[d_idx].data() + n_nodes(depth - 1)),
n_nodes_level(depth) * sizeof(Node) / sizeof(char),
ncclChar, master_device, comms[d_idx], *(streams[d_idx])));
dh::safe_nccl(ncclBcast(
reinterpret_cast<void*>(nodes[d_idx].data() + n_nodes(depth - 1)),
n_nodes_level(depth) * sizeof(Node) / sizeof(char), ncclChar,
master_device, comms[d_idx], *(streams[d_idx])));
if (depth !=
param.max_depth) { // don't copy over children nodes if no more nodes
if (depth != param.max_depth) { // don't copy over children nodes if no
// more nodes
dh::safe_nccl(ncclBcast(
reinterpret_cast<void*>(nodes[d_idx].data() + n_nodes(depth)),
n_nodes_level(depth + 1) * sizeof(Node) / sizeof(char), ncclChar,
@ -710,7 +727,8 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
}
dh::safe_nccl(ncclBcast(
reinterpret_cast<void*>(left_child_smallest[d_idx].data() + n_nodes(depth - 1)),
reinterpret_cast<void*>(left_child_smallest[d_idx].data() +
n_nodes(depth - 1)),
n_nodes_level(depth) * sizeof(bool) / sizeof(char), ncclChar,
master_device, comms[d_idx], *(streams[d_idx])));
}
@ -721,10 +739,13 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
dh::safe_cuda(cudaStreamSynchronize(*(streams[d_idx])));
}
}
} else if (simuljob == 0 && NCCL == 1) {
dosimuljob = 0;
int num_nodes_device = n_nodes_level(depth);
const int GRID_SIZE = num_nodes_device;
#else
{
int d_idx = 0;
int master_device = dList[d_idx];
int device_idx = dList[d_idx];
dh::safe_cuda(cudaSetDevice(device_idx));
@ -737,9 +758,63 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
gidx_fvalue_map[d_idx].data(), gpu_param,
left_child_smallest[d_idx].data(), colsample,
feature_flags[d_idx].data());
// broadcast result
for (int d_idx = 0; d_idx < n_devices; d_idx++) {
int device_idx = dList[d_idx];
dh::safe_cuda(cudaSetDevice(device_idx));
dh::safe_nccl(ncclBcast(
reinterpret_cast<void*>(nodes[d_idx].data() + n_nodes(depth - 1)),
n_nodes_level(depth) * sizeof(Node) / sizeof(char), ncclChar,
master_device, comms[d_idx], *(streams[d_idx])));
if (depth !=
param.max_depth) { // don't copy over children nodes if no more nodes
dh::safe_nccl(ncclBcast(
reinterpret_cast<void*>(nodes[d_idx].data() + n_nodes(depth)),
n_nodes_level(depth + 1) * sizeof(Node) / sizeof(char), ncclChar,
master_device, comms[d_idx], *(streams[d_idx])));
}
dh::safe_nccl(
ncclBcast(reinterpret_cast<void*>(left_child_smallest[d_idx].data() +
n_nodes(depth - 1)),
n_nodes_level(depth) * sizeof(bool) / sizeof(char),
ncclChar, master_device, comms[d_idx], *(streams[d_idx])));
}
for (int d_idx = 0; d_idx < n_devices; d_idx++) {
int device_idx = dList[d_idx];
dh::safe_cuda(cudaSetDevice(device_idx));
dh::safe_cuda(cudaStreamSynchronize(*(streams[d_idx])));
}
} else {
dosimuljob = 1;
}
#endif
if (dosimuljob) { // if no NCCL or simuljob==1, do this
int num_nodes_device = n_nodes_level(depth);
const int GRID_SIZE = num_nodes_device;
// all GPUs do same work
for (int d_idx = 0; d_idx < n_devices; d_idx++) {
int device_idx = dList[d_idx];
dh::safe_cuda(cudaSetDevice(device_idx));
int nodes_offset_device = 0;
find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
(const gpu_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
feature_segments[d_idx].data(), depth, (info->num_col),
(hmat_.row_ptr.back()), nodes[d_idx].data(), NULL, NULL,
nodes_offset_device, fidx_min_map[d_idx].data(),
gidx_fvalue_map[d_idx].data(), gpu_param,
left_child_smallest[d_idx].data(), colsample,
feature_flags[d_idx].data());
}
}
// NOTE: No need to syncrhonize with host as all above pure P2P ops or
// on-device ops
}
@ -776,17 +851,15 @@ void GPUHistBuilder::InitFirstNode(const std::vector<bst_gpair>& gpair) {
std::vector<std::future<gpu_gpair>> future_results(n_devices);
for (int d_idx = 0; d_idx < n_devices; d_idx++) {
// std::async captures the algorithm parameters by value
// use std::launch::async to ensure the creation of a new thread
future_results[d_idx] = std::async(std::launch::async, [=] {
int device_idx = dList[d_idx];
dh::safe_cuda(cudaSetDevice(device_idx));
auto begin = device_gpair[d_idx].tbegin();
auto end = device_gpair[d_idx].tend();
gpu_gpair init = gpu_gpair();
auto binary_op = thrust::plus<gpu_gpair>();
// std::async captures the algorithm parameters by value
// use std::launch::async to ensure the creation of a new thread
future_results[d_idx] = std::async(std::launch::async, [=] {
dh::safe_cuda(cudaSetDevice(device_idx));
return thrust::reduce(begin, end, init, binary_op);
});
}
@ -1047,8 +1120,8 @@ void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
// done with multi-GPU, pass back result from master to tree on host
int master_device = dList[0];
dense2sparse_tree(p_tree, nodes[master_device].tbegin(),
nodes[master_device].tend(), param);
dh::safe_cuda(cudaSetDevice(master_device));
dense2sparse_tree(p_tree, nodes[0].tbegin(), nodes[0].tend(), param);
}
} // namespace tree
} // namespace xgboost

View File

@ -157,11 +157,11 @@ class ColMaker: public TreeUpdater {
feat_index.push_back(i);
}
}
unsigned n = static_cast<unsigned>(param.colsample_bytree * feat_index.size());
unsigned n = std::max(static_cast<unsigned>(1),
static_cast<unsigned>(param.colsample_bytree * feat_index.size()));
std::shuffle(feat_index.begin(), feat_index.end(), common::GlobalRandom());
CHECK_GT(n, 0U)
<< "colsample_bytree=" << param.colsample_bytree
<< " is too small that no feature can be included";
CHECK_GT(param.colsample_bytree, 0U)
<< "colsample_bytree cannot be zero.";
feat_index.resize(n);
}
{
@ -627,9 +627,10 @@ class ColMaker: public TreeUpdater {
std::vector<bst_uint> feat_set = feat_index;
if (param.colsample_bylevel != 1.0f) {
std::shuffle(feat_set.begin(), feat_set.end(), common::GlobalRandom());
unsigned n = static_cast<unsigned>(param.colsample_bylevel * feat_index.size());
CHECK_GT(n, 0U)
<< "colsample_bylevel is too small that no feature can be included";
unsigned n = std::max(static_cast<unsigned>(1),
static_cast<unsigned>(param.colsample_bylevel * feat_index.size()));
CHECK_GT(param.colsample_bylevel, 0U)
<< "colsample_bylevel cannot be zero.";
feat_set.resize(n);
}
dmlc::DataIter<ColBatch>* iter = p_fmat->ColIterator(feat_set);

View File

@ -409,11 +409,11 @@ class FastHistMaker: public TreeUpdater {
feat_index.push_back(i);
}
}
unsigned n = static_cast<unsigned>(param.colsample_bytree * feat_index.size());
unsigned n = std::max(static_cast<unsigned>(1),
static_cast<unsigned>(param.colsample_bytree * feat_index.size()));
std::shuffle(feat_index.begin(), feat_index.end(), common::GlobalRandom());
CHECK_GT(n, 0U)
<< "colsample_bytree=" << param.colsample_bytree
<< " is too small that no feature can be included";
CHECK_GT(param.colsample_bytree, 0U)
<< "colsample_bytree cannot be zero.";
feat_index.resize(n);
}
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {