Various bug fixes (#2825)
* Fatal error if GPU algorithm selected without GPU support compiled
* Resolve type conversion warnings
* Fix gpu unit test failure
* Fix compressed iterator edge case
* Fix python unit test failures due to flake8 update on pip
parent c71b62d48d
commit 13e7a2cff0
@@ -171,19 +171,19 @@ class bst_gpair_internal {

 template<>
 inline XGBOOST_DEVICE float bst_gpair_internal<int64_t>::GetGrad() const {
-  return grad_ * 1e-5;
+  return grad_ * 1e-5f;
 }
 template<>
 inline XGBOOST_DEVICE float bst_gpair_internal<int64_t>::GetHess() const {
-  return hess_ * 1e-5;
+  return hess_ * 1e-5f;
 }
 template<>
 inline XGBOOST_DEVICE void bst_gpair_internal<int64_t>::SetGrad(float g) {
-  grad_ = std::round(g * 1e5);
+  grad_ = static_cast<int64_t>(std::round(g * 1e5));
 }
 template<>
 inline XGBOOST_DEVICE void bst_gpair_internal<int64_t>::SetHess(float h) {
-  hess_ = std::round(h * 1e5);
+  hess_ = static_cast<int64_t>(std::round(h * 1e5));
 }

 }  // namespace detail
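Note: the four specializations above implement a fixed-point encoding, storing each gradient as an int64_t count of 1e-5 units. A minimal standalone sketch of that round-trip (hypothetical struct name, not the XGBoost class):

    #include <cmath>
    #include <cstdint>

    // Illustration only: mirrors the 1e5 fixed-point scheme in the diff above.
    struct FixedPointGrad {
      int64_t grad_ = 0;
      // Store as integer multiples of 1e-5; the explicit cast is what silences
      // the double -> int64_t conversion warning this commit addresses.
      void SetGrad(float g) { grad_ = static_cast<int64_t>(std::round(g * 1e5)); }
      // Recover an approximate float; the 1e-5f literal keeps the math in
      // single precision instead of promoting to double.
      float GetGrad() const { return grad_ * 1e-5f; }
    };
    // Example: SetGrad(0.1234567f) stores 12346; GetGrad() then returns ~0.12346f.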
@@ -328,7 +328,7 @@ class DMatrix {

 // implementation of inline functions
 inline bst_uint RowSet::operator[](size_t i) const {
-  return rows_.size() == 0 ? i : rows_[i];
+  return rows_.size() == 0 ? static_cast<bst_uint>(i) : rows_[i];
 }

 inline size_t RowSet::size() const {
@@ -651,7 +651,7 @@ inline void ExtendPath(PathElement *unique_path, unsigned unique_depth,
   unique_path[unique_depth].feature_index = feature_index;
   unique_path[unique_depth].zero_fraction = zero_fraction;
   unique_path[unique_depth].one_fraction = one_fraction;
-  unique_path[unique_depth].pweight = (unique_depth == 0 ? 1 : 0);
+  unique_path[unique_depth].pweight = static_cast<bst_float>(unique_depth == 0 ? 1 : 0);
   for (int i = unique_depth-1; i >= 0; i--) {
     unique_path[i+1].pweight += one_fraction*unique_path[i].pweight*(i+1)
                                 / static_cast<bst_float>(unique_depth+1);
@@ -679,7 +679,7 @@ inline void UnwindPath(PathElement *unique_path, unsigned unique_depth, unsigned
     }
   }

-  for (int i = path_index; i < unique_depth; ++i) {
+  for (auto i = path_index; i < unique_depth; ++i) {
     unique_path[i].feature_index = unique_path[i+1].feature_index;
     unique_path[i].zero_fraction = unique_path[i+1].zero_fraction;
     unique_path[i].one_fraction = unique_path[i+1].one_fraction;
@@ -725,7 +725,7 @@ inline void RegTree::TreeShap(const RegTree::FVec& feat, bst_float *phi,

   // leaf node
   if (node.is_leaf()) {
-    for (int i = 1; i <= unique_depth; ++i) {
+    for (unsigned i = 1; i <= unique_depth; ++i) {
       const bst_float w = UnwoundPathSum(unique_path, unique_depth, i);
       const PathElement &el = unique_path[i];
       phi[el.feature_index] += w*(el.one_fraction-el.zero_fraction)*node.leaf_value();
@@ -775,7 +775,7 @@ inline void RegTree::CalculateContributions(const RegTree::FVec& feat, unsigned
   // find the expected value of the tree's predictions
   bst_float base_value = 0.0;
   bst_float total_cover = 0;
-  for (unsigned i = 0; i < (*this).param.num_nodes; ++i) {
+  for (int i = 0; i < (*this).param.num_nodes; ++i) {
     const auto node = (*this)[i];
     if (node.is_leaf()) {
       const auto cover = this->stat(i).sum_hess;
@@ -6,7 +6,7 @@
 #include <xgboost/base.h>
 #include <cmath>
 #include <cstddef>
-#include "dmlc/logging.h"
+#include <algorithm>

 namespace xgboost {
 namespace common {
@@ -28,8 +28,9 @@ static const int padding = 4;  // Assign padding so we can read slightly off
 // the beginning of the array

 // The number of bits required to represent a given unsigned range
-static int SymbolBits(int num_symbols) {
-  return std::ceil(std::log2(num_symbols));
+static size_t SymbolBits(size_t num_symbols) {
+  auto bits = std::ceil(std::log2(num_symbols));
+  return std::max(static_cast<size_t>(bits), size_t(1));
 }
 }  // namespace detail
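This is the "compressed iterator edge case" from the commit message: std::log2(1) is 0, so the old SymbolBits returned 0 bits per symbol for a single-symbol alphabet and nothing could be encoded; the new version clamps to at least one bit. A small self-contained check, assuming the signature shown above:

    #include <algorithm>
    #include <cassert>
    #include <cmath>
    #include <cstddef>

    // Copy of the clamped helper above, for a standalone sanity check.
    static size_t SymbolBits(size_t num_symbols) {
      auto bits = std::ceil(std::log2(num_symbols));
      return std::max(static_cast<size_t>(bits), size_t(1));
    }

    int main() {
      assert(SymbolBits(1) == 1);    // previously 0: the edge case being fixed
      assert(SymbolBits(150) == 8);  // matches the CompressedIterator test further down
      assert(SymbolBits(256) == 8);
      return 0;
    }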
@@ -72,9 +73,9 @@ class CompressedBufferWriter {

   static size_t CalculateBufferSize(size_t num_elements, size_t num_symbols) {
     const int bits_per_byte = 8;
-    size_t compressed_size = std::ceil(
+    size_t compressed_size = static_cast<size_t>(std::ceil(
         static_cast<double>(detail::SymbolBits(num_symbols) * num_elements) /
-        bits_per_byte);
+        bits_per_byte));
     return compressed_size + detail::padding;
   }
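As a worked example of the formula above: 1000 elements drawn from 150 symbols need SymbolBits(150) = 8 bits each, so compressed_size = ceil(8 * 1000 / 8) = 1000 bytes, and the returned buffer size is 1000 plus the 4 padding bytes, i.e. 1004.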
@@ -98,8 +99,8 @@ class CompressedBufferWriter {
   template <typename iter_t>
   void Write(compressed_byte_t *buffer, iter_t input_begin, iter_t input_end) {
     uint64_t tmp = 0;
-    int stored_bits = 0;
-    const int max_stored_bits = 64 - symbol_bits_;
+    size_t stored_bits = 0;
+    const size_t max_stored_bits = 64 - symbol_bits_;
     size_t buffer_position = detail::padding;
     const size_t num_symbols = input_end - input_begin;
     for (size_t i = 0; i < num_symbols; i++) {
@@ -108,7 +109,8 @@ class CompressedBufferWriter {
         // Eject only full bytes
         size_t tmp_bytes = stored_bits / 8;
         for (size_t j = 0; j < tmp_bytes; j++) {
-          buffer[buffer_position] = tmp >> (stored_bits - (j + 1) * 8);
+          buffer[buffer_position] = static_cast<compressed_byte_t>(
+              tmp >> (stored_bits - (j + 1) * 8));
           buffer_position++;
         }
         stored_bits -= tmp_bytes * 8;
@@ -121,13 +123,16 @@ class CompressedBufferWriter {
     }

     // Eject all bytes
-    size_t tmp_bytes = std::ceil(static_cast<float>(stored_bits) / 8);
-    for (size_t j = 0; j < tmp_bytes; j++) {
-      int shift_bits = stored_bits - (j + 1) * 8;
+    int tmp_bytes =
+        static_cast<int>(std::ceil(static_cast<float>(stored_bits) / 8));
+    for (int j = 0; j < tmp_bytes; j++) {
+      int shift_bits = static_cast<int>(stored_bits) - (j + 1) * 8;
       if (shift_bits >= 0) {
-        buffer[buffer_position] = tmp >> shift_bits;
+        buffer[buffer_position] =
+            static_cast<compressed_byte_t>(tmp >> shift_bits);
       } else {
-        buffer[buffer_position] = tmp << std::abs(shift_bits);
+        buffer[buffer_position] =
+            static_cast<compressed_byte_t>(tmp << std::abs(shift_bits));
       }
       buffer_position++;
     }
@@ -125,7 +125,7 @@ inline size_t available_memory(int device_idx) {
  * \param device_idx Zero-based index of the device.
  */

-inline int max_shared_memory(int device_idx) {
+inline size_t max_shared_memory(int device_idx) {
   cudaDeviceProp prop;
   dh::safe_cuda(cudaGetDeviceProperties(&prop, device_idx));
   return prop.sharedMemPerBlock;
@@ -241,8 +241,7 @@ inline void launch_n(int device_idx, size_t n, L lambda) {
   }

   safe_cuda(cudaSetDevice(device_idx));
-  // TODO: Template on n so GRID_SIZE always fits into int.
-  const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS);
+  const int GRID_SIZE = static_cast<int>(div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS));
   launch_n_kernel<<<GRID_SIZE, BLOCK_THREADS>>>(static_cast<size_t>(0), n,
                                                 lambda);
 }
@@ -428,74 +427,66 @@ class bulk_allocator {

   const int align = 256;

-  template <typename SizeT>
-  size_t align_round_up(SizeT n) {
+  size_t align_round_up(size_t n) const {
     n = (n + align - 1) / align;
     return n * align;
   }

-  template <typename T, typename SizeT>
-  size_t get_size_bytes(dvec<T> *first_vec, SizeT first_size) {
-    return align_round_up<SizeT>(first_size * sizeof(T));
+  template <typename T>
+  size_t get_size_bytes(dvec<T> *first_vec, size_t first_size) {
+    return align_round_up(first_size * sizeof(T));
   }

-  template <typename T, typename SizeT, typename... Args>
-  size_t get_size_bytes(dvec<T> *first_vec, SizeT first_size, Args... args) {
-    return get_size_bytes<T, SizeT>(first_vec, first_size) +
-           get_size_bytes(args...);
+  template <typename T, typename... Args>
+  size_t get_size_bytes(dvec<T> *first_vec, size_t first_size, Args... args) {
+    return get_size_bytes<T>(first_vec, first_size) + get_size_bytes(args...);
   }

-  template <typename T, typename SizeT>
+  template <typename T>
   void allocate_dvec(int device_idx, char *ptr, dvec<T> *first_vec,
-                     SizeT first_size) {
+                     size_t first_size) {
     first_vec->external_allocate(device_idx, static_cast<void *>(ptr),
                                  first_size);
   }

-  template <typename T, typename SizeT, typename... Args>
+  template <typename T, typename... Args>
   void allocate_dvec(int device_idx, char *ptr, dvec<T> *first_vec,
-                     SizeT first_size, Args... args) {
-    first_vec->external_allocate(device_idx, static_cast<void *>(ptr),
-                                 first_size);
+                     size_t first_size, Args... args) {
+    allocate_dvec<T>(device_idx, ptr, first_vec, first_size);
     ptr += align_round_up(first_size * sizeof(T));
     allocate_dvec(device_idx, ptr, args...);
   }

   // template <memory_type MemoryT>
   char *allocate_device(int device_idx, size_t bytes, memory_type t) {
     char *ptr;
     if (t == memory_type::DEVICE) {
       safe_cuda(cudaSetDevice(device_idx));
       safe_cuda(cudaMalloc(&ptr, bytes));
     } else {
       safe_cuda(cudaMallocManaged(&ptr, bytes));
     }
     return ptr;
   }
-  template <typename T, typename SizeT>
-  size_t get_size_bytes(dvec2<T> *first_vec, SizeT first_size) {
+  template <typename T>
+  size_t get_size_bytes(dvec2<T> *first_vec, size_t first_size) {
     return 2 * align_round_up(first_size * sizeof(T));
   }

-  template <typename T, typename SizeT, typename... Args>
-  size_t get_size_bytes(dvec2<T> *first_vec, SizeT first_size, Args... args) {
-    return get_size_bytes<T, SizeT>(first_vec, first_size) +
+  template <typename T, typename... Args>
+  size_t get_size_bytes(dvec2<T> *first_vec, size_t first_size, Args... args) {
+    return get_size_bytes<T>(first_vec, first_size) +
            get_size_bytes(args...);
   }

-  template <typename T, typename SizeT>
+  template <typename T>
   void allocate_dvec(int device_idx, char *ptr, dvec2<T> *first_vec,
-                     SizeT first_size) {
+                     size_t first_size) {
     first_vec->external_allocate(
         device_idx, static_cast<void *>(ptr),
         static_cast<void *>(ptr + align_round_up(first_size * sizeof(T))),
         first_size);
   }

-  template <typename T, typename SizeT, typename... Args>
+  template <typename T, typename... Args>
   void allocate_dvec(int device_idx, char *ptr, dvec2<T> *first_vec,
-                     SizeT first_size, Args... args) {
-    allocate_dvec<T, SizeT>(device_idx, ptr, first_vec, first_size);
+                     size_t first_size, Args... args) {
+    allocate_dvec<T>(device_idx, ptr, first_vec, first_size);
     ptr += (align_round_up(first_size * sizeof(T)) * 2);
     allocate_dvec(device_idx, ptr, args...);
   }
@@ -544,13 +535,12 @@ struct CubMemory {
   // Thrust
   typedef char value_type;

-  CubMemory() : d_temp_storage(NULL), temp_storage_bytes(0) {}
+  CubMemory() : d_temp_storage(nullptr), temp_storage_bytes(0) {}

   ~CubMemory() { Free(); }

   template <typename T>
-  T* Pointer()
-  {
+  T *Pointer() {
     return static_cast<T *>(d_temp_storage);
   }
@@ -611,7 +601,7 @@ void print(const dvec<T> &v, size_t max_items = 10) {

 template <typename coordinate_t, typename segments_t, typename offset_t>
 void FindMergePartitions(int device_idx, coordinate_t *d_tile_coordinates,
-                         int num_tiles, int tile_size, segments_t segments,
+                         size_t num_tiles, int tile_size, segments_t segments,
                          offset_t num_rows, offset_t num_elements) {
   dh::launch_n(device_idx, num_tiles + 1, [=] __device__(int idx) {
     offset_t diagonal = idx * tile_size;
@@ -692,7 +682,8 @@ void SparseTransformLbs(int device_idx, dh::CubMemory *temp_memory,
   const int BLOCK_THREADS = 256;
   const int ITEMS_PER_THREAD = 1;
   const int TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD;
-  int num_tiles = dh::div_round_up(count + num_segments, BLOCK_THREADS);
+  auto num_tiles = dh::div_round_up(count + num_segments, BLOCK_THREADS);
+  CHECK(num_tiles < std::numeric_limits<unsigned int>::max());

   temp_memory->LazyAllocate(sizeof(coordinate_t) * (num_tiles + 1));
   coordinate_t *tmp_tile_coordinates =
@@ -702,7 +693,7 @@ void SparseTransformLbs(int device_idx, dh::CubMemory *temp_memory,
                       BLOCK_THREADS, segments, num_segments, count);

   LbsKernel<TILE_SIZE, ITEMS_PER_THREAD, BLOCK_THREADS, offset_t>
-      <<<num_tiles, BLOCK_THREADS>>>(tmp_tile_coordinates, segments + 1, f,
+      <<<uint32_t(num_tiles), BLOCK_THREADS>>>(tmp_tile_coordinates, segments + 1, f,
                                      num_segments);
 }

@@ -26,7 +26,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {

   const int nthread = omp_get_max_threads();

-  unsigned nstep = (info.num_col + nthread - 1) / nthread;
+  unsigned nstep = static_cast<unsigned>((info.num_col + nthread - 1) / nthread);
   unsigned ncol = static_cast<unsigned>(info.num_col);
   sketchs.resize(info.num_col);
   for (auto& s : sketchs) {
@@ -79,7 +79,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
       if (a.size > 1 && a.size <= 16) {
         /* specialized code categorial / ordinal data -- use midpoints */
         for (size_t i = 1; i < a.size; ++i) {
-          bst_float cpt = (a.data[i].value + a.data[i - 1].value) / 2.0;
+          bst_float cpt = (a.data[i].value + a.data[i - 1].value) / 2.0f;
           if (i == 1 || cpt > cut.back()) {
             cut.push_back(cpt);
           }
@@ -99,7 +99,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
         bst_float last = cpt + fabs(cpt);
         cut.push_back(last);
       }
-      row_ptr.push_back(cut.size());
+      row_ptr.push_back(static_cast<bst_uint>(cut.size()));
     }
   }

@@ -148,7 +148,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
   }

 #pragma omp parallel for num_threads(nthread) schedule(static)
-  for (bst_omp_uint idx = 0; idx < nbins; ++idx) {
+  for (bst_omp_uint idx = 0; idx < bst_omp_uint(nbins); ++idx) {
     for (int tid = 0; tid < nthread; ++tid) {
       hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
     }
@@ -226,7 +226,7 @@ FindGroups_(const std::vector<unsigned>& feature_list,
     bool need_new_group = true;

     // randomly choose some of existing groups as candidates
-    std::vector<unsigned> search_groups;
+    std::vector<size_t> search_groups;
     for (size_t gid = 0; gid < groups.size(); ++gid) {
       if (group_nnz[gid] + cur_fid_nnz <= nrow + max_conflict_cnt) {
         search_groups.push_back(gid);
@@ -434,7 +434,7 @@ void GHistBuilder::BuildHist(const std::vector<bst_gpair>& gpair,
       }
     }
   }
-  for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
+  for (size_t i = nrows - rest; i < nrows; ++i) {
     const size_t rid = row_indices.begin[i];
     const size_t ibegin = gmat.row_ptr[rid];
     const size_t iend = gmat.row_ptr[rid + 1];
@@ -448,7 +448,7 @@ void GHistBuilder::BuildHist(const std::vector<bst_gpair>& gpair,
   /* reduction */
   const uint32_t nbins = nbins_;
 #pragma omp parallel for num_threads(nthread) schedule(static)
-  for (bst_omp_uint bin_id = 0; bin_id < nbins; ++bin_id) {
+  for (bst_omp_uint bin_id = 0; bin_id < bst_omp_uint(nbins); ++bin_id) {
     for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
       hist.begin[bin_id].Add(data_[tid * nbins_ + bin_id]);
     }
@@ -462,7 +462,7 @@ void GHistBuilder::BuildBlockHist(const std::vector<bst_gpair>& gpair,
                                   GHistRow hist) {
   const int K = 8;  // loop unrolling factor
   const bst_omp_uint nthread = static_cast<bst_omp_uint>(this->nthread_);
-  const uint32_t nblock = gmatb.GetNumBlock();
+  const size_t nblock = gmatb.GetNumBlock();
   const size_t nrows = row_indices.end - row_indices.begin;
   const size_t rest = nrows % K;

@@ -492,7 +492,7 @@ void GHistBuilder::BuildBlockHist(const std::vector<bst_gpair>& gpair,
       }
     }
   }
-  for (bst_omp_uint i = nrows - rest; i < nrows; ++i) {
+  for (size_t i = nrows - rest; i < nrows; ++i) {
     const size_t rid = row_indices.begin[i];
     const size_t ibegin = gmat.row_ptr[rid];
     const size_t iend = gmat.row_ptr[rid + 1];
@@ -511,7 +511,7 @@ void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow pa
   const int K = 8;  // loop unrolling factor
   const uint32_t rest = nbins % K;
 #pragma omp parallel for num_threads(nthread) schedule(static)
-  for (bst_omp_uint bin_id = 0; bin_id < nbins - rest; bin_id += K) {
+  for (bst_omp_uint bin_id = 0; bin_id < static_cast<bst_omp_uint>(nbins - rest); bin_id += K) {
     GHistEntry pb[K];
     GHistEntry sb[K];
     for (int k = 0; k < K; ++k) {
@@ -118,11 +118,11 @@ struct GHistIndexMatrix {
     return GHistIndexRow(&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]);
   }
   inline void GetFeatureCounts(size_t* counts) const {
-    const unsigned nfeature = cut->row_ptr.size() - 1;
+    auto nfeature = cut->row_ptr.size() - 1;
     for (unsigned fid = 0; fid < nfeature; ++fid) {
-      const unsigned ibegin = cut->row_ptr[fid];
-      const unsigned iend = cut->row_ptr[fid + 1];
-      for (unsigned i = ibegin; i < iend; ++i) {
+      auto ibegin = cut->row_ptr[fid];
+      auto iend = cut->row_ptr[fid + 1];
+      for (auto i = ibegin; i < iend; ++i) {
         counts[fid] += hit_count[i];
       }
     }
@@ -235,7 +235,7 @@ class HistCollection {
   std::vector<GHistEntry> data_;

   /*! \brief row_ptr_[nid] locates bin for historgram of node nid */
-  std::vector<uint32_t> row_ptr_;
+  std::vector<size_t> row_ptr_;
 };

 /*!
@@ -680,12 +680,12 @@ class QuantileSketchTemplate {
       nlevel = 1;
       while (true) {
         limit_size = static_cast<size_t>(ceil(nlevel / eps)) + 1;
-        size_t n = (1UL << nlevel);
+        size_t n = (1ULL << nlevel);
         if (n * limit_size >= maxn) break;
         ++nlevel;
       }
       // check invariant
-      size_t n = (1UL << nlevel);
+      size_t n = (1ULL << nlevel);
       CHECK(n * limit_size >= maxn) << "invalid init parameter";
       CHECK(nlevel <= limit_size * eps) << "invalid init parameter";
       // lazy reserve the space, if there is only one value, no need to allocate space
@@ -88,7 +88,7 @@ class RowSetCollection {
                            unsigned left_node_id,
                            unsigned right_node_id) {
     const Elem e = elem_of_each_node_[node_id];
-    const unsigned nthread = row_split_tloc.size();
+    const bst_omp_uint nthread = static_cast<bst_omp_uint>(row_split_tloc.size());
     CHECK(e.begin != nullptr);
     size_t* all_begin = dmlc::BeginPtr(row_indices_);
     size_t* begin = all_begin + (e.begin - all_begin);
@@ -120,7 +120,7 @@ void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
   double tstart = dmlc::GetTime();
   // print every 4 sec.
   const double kStep = 4.0;
-  size_t tick_expected = kStep;
+  size_t tick_expected = static_cast<double>(kStep);

   while (src->Next()) {
     const dmlc::RowBlock<uint32_t>& batch = src->Value();
@@ -149,7 +149,7 @@ void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
         LOG(CONSOLE) << "Writing row.page to " << cache_info << " in "
                      << ((bytes_write >> 20UL) / tdiff) << " MB/s, "
                      << (bytes_write >> 20UL) << " written";
-        tick_expected += kStep;
+        tick_expected += static_cast<size_t>(kStep);
       }
     }
   }
@@ -146,6 +146,12 @@ class LearnerImpl : public Learner {
     name_gbm_ = "gbtree";
   }

+  static void AssertGPUSupport() {
+#ifndef XGBOOST_USE_CUDA
+    LOG(FATAL) << "XGBoost version not compiled with GPU support.";
+#endif
+  }
+
   void ConfigureUpdaters() {
     if (tparam.tree_method == 0 || tparam.tree_method == 1 ||
         tparam.tree_method == 2) {
@@ -166,6 +172,7 @@ class LearnerImpl : public Learner {
                    << "grow_fast_histmaker.";
       cfg_["updater"] = "grow_fast_histmaker";
     } else if (tparam.tree_method == 4) {
+      this->AssertGPUSupport();
       if (cfg_.count("updater") == 0) {
         cfg_["updater"] = "grow_gpu,prune";
       }
@@ -173,6 +180,7 @@ class LearnerImpl : public Learner {
         cfg_["predictor"] = "gpu_predictor";
       }
     } else if (tparam.tree_method == 5) {
+      this->AssertGPUSupport();
       if (cfg_.count("updater") == 0) {
         cfg_["updater"] = "grow_gpu_hist";
       }
@@ -180,6 +188,7 @@ class LearnerImpl : public Learner {
         cfg_["predictor"] = "gpu_predictor";
       }
     } else if (tparam.tree_method == 6) {
+      this->AssertGPUSupport();
       if (cfg_.count("updater") == 0) {
         cfg_["updater"] = "grow_gpu_hist_experimental,prune";
       }
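Taken together, the learner changes above make GPU tree methods (tree_method 4, 5, and 6) fail fast when the binary was built without CUDA, instead of failing later when no grow_gpu* updater can be created. A standalone paraphrase of the guard (simplified sketch; LOG(FATAL) is replaced with a plain abort so it compiles on its own):

    #include <cstdlib>
    #include <iostream>

    #ifndef XGBOOST_USE_CUDA
    static void AssertGPUSupport() {
      // Mirrors the fatal log message in the diff above.
      std::cerr << "XGBoost version not compiled with GPU support." << std::endl;
      std::abort();
    }
    #else
    static void AssertGPUSupport() {}
    #endif

    static void ConfigureUpdaters(int tree_method) {
      if (tree_method >= 4 && tree_method <= 6) {
        AssertGPUSupport();  // fatal on CPU-only builds
        // ... then pick "grow_gpu", "grow_gpu_hist", or
        // "grow_gpu_hist_experimental,prune" as in the diff above.
      }
    }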
@@ -216,11 +216,11 @@ __device__ float GetLeafWeight(bst_uint ridx, const DevicePredictionNode* tree,

 template <int BLOCK_THREADS>
 __global__ void PredictKernel(const DevicePredictionNode* d_nodes,
-                              float* d_out_predictions, int* d_tree_segments,
+                              float* d_out_predictions, size_t* d_tree_segments,
                               int* d_tree_group, size_t* d_row_ptr,
-                              SparseBatch::Entry* d_data, int tree_begin,
-                              int tree_end, int num_features, bst_uint num_rows,
-                              bool use_shared, int num_group) {
+                              SparseBatch::Entry* d_data, size_t tree_begin,
+                              size_t tree_end, size_t num_features,
+                              size_t num_rows, bool use_shared, int num_group) {
   extern __shared__ float smem[];
   bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
   ElementLoader loader(use_shared, d_row_ptr, d_data, num_features, smem,
@@ -249,8 +249,8 @@ __global__ void PredictKernel(const DevicePredictionNode* d_nodes,
 class GPUPredictor : public xgboost::Predictor {
  private:
   void DevicePredictInternal(DMatrix* dmat, std::vector<bst_float>* out_preds,
-                             const gbm::GBTreeModel& model, int tree_begin,
-                             int tree_end) {
+                             const gbm::GBTreeModel& model, size_t tree_begin,
+                             size_t tree_end) {
     if (tree_end - tree_begin == 0) {
       return;
     }
@@ -267,17 +267,17 @@ class GPUPredictor : public xgboost::Predictor {
     dh::safe_cuda(cudaSetDevice(param.gpu_id));
     CHECK_EQ(model.param.size_leaf_vector, 0);
     // Copy decision trees to device
-    thrust::host_vector<int> h_tree_segments;
+    thrust::host_vector<size_t> h_tree_segments;
     h_tree_segments.reserve((tree_end - tree_end) + 1);
-    int sum = 0;
+    size_t sum = 0;
     h_tree_segments.push_back(sum);
-    for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
+    for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
       sum += model.trees[tree_idx]->GetNodes().size();
       h_tree_segments.push_back(sum);
     }

     thrust::host_vector<DevicePredictionNode> h_nodes(h_tree_segments.back());
-    for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
+    for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
       auto& src_nodes = model.trees[tree_idx]->GetNodes();
       std::copy(src_nodes.begin(), src_nodes.end(),
                 h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
@@ -299,11 +299,11 @@ class GPUPredictor : public xgboost::Predictor {
     }

     const int BLOCK_THREADS = 128;
-    const int GRID_SIZE =
-        dh::div_round_up(device_matrix->row_ptr.size() - 1, BLOCK_THREADS);
+    const int GRID_SIZE = static_cast<int>(
+        dh::div_round_up(device_matrix->row_ptr.size() - 1, BLOCK_THREADS));

-    int shared_memory_bytes =
-        sizeof(float) * device_matrix->p_mat->info().num_col * BLOCK_THREADS;
+    int shared_memory_bytes = static_cast<int>(
+        sizeof(float) * device_matrix->p_mat->info().num_col * BLOCK_THREADS);
     bool use_shared = true;
     if (shared_memory_bytes > dh::max_shared_memory(param.gpu_id)) {
       shared_memory_bytes = 0;
@@ -347,8 +347,7 @@ class GPUPredictor : public xgboost::Predictor {
                             const gbm::GBTreeModel& model,
                             std::vector<std::unique_ptr<TreeUpdater>>* updaters,
                             int num_new_trees) override {
-    // dh::Timer t;
-    int old_ntree = model.trees.size() - num_new_trees;
+    auto old_ntree = model.trees.size() - num_new_trees;
     // update cache entry
     for (auto& kv : cache_) {
       PredictionCacheEntry& e = kv.second;
@@ -356,7 +355,7 @@ class GPUPredictor : public xgboost::Predictor {

       if (e.predictions.size() == 0) {
         cpu_predictor->PredictBatch(dmat, &(e.predictions), model, 0,
-                                    model.trees.size());
+                                    static_cast<bst_uint>(model.trees.size()));
       } else if (model.param.num_output_group == 1 && updaters->size() > 0 &&
                  num_new_trees == 1 &&
                  updaters->back()->UpdatePredictionCache(e.data.get(),
@@ -383,11 +382,10 @@ class GPUPredictor : public xgboost::Predictor {

   void PredictContribution(DMatrix* p_fmat,
                            std::vector<bst_float>* out_contribs,
-                           const gbm::GBTreeModel& model,
-                           unsigned ntree_limit,
+                           const gbm::GBTreeModel& model, unsigned ntree_limit,
                            bool approximate) override {
-    cpu_predictor->PredictContribution(p_fmat, out_contribs, model,
-                                       ntree_limit, approximate);
+    cpu_predictor->PredictContribution(p_fmat, out_contribs, model, ntree_limit,
+                                       approximate);
   }

   void Init(const std::vector<std::pair<std::string, std::string>>& cfg,
@@ -403,7 +401,7 @@ class GPUPredictor : public xgboost::Predictor {
   std::unordered_map<DMatrix*, std::unique_ptr<DeviceMatrix>>
       device_matrix_cache_;
   thrust::device_vector<DevicePredictionNode> nodes;
-  thrust::device_vector<int> tree_segments;
+  thrust::device_vector<size_t> tree_segments;
   thrust::device_vector<int> tree_group;
 };
 XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
@@ -241,7 +241,7 @@ XGBOOST_DEVICE inline T CalcGainGivenWeight(const TrainingParams &p, T sum_grad,
 template <typename TrainingParams, typename T>
 XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess) {
   if (sum_hess < p.min_child_weight)
-    return 0.0;
+    return T(0.0);
   if (p.max_delta_step == 0.0f) {
     if (p.reg_alpha == 0.0f) {
       return Sqr(sum_grad) / (sum_hess + p.reg_lambda);
@@ -251,11 +251,11 @@ XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess
     }
   } else {
     T w = CalcWeight(p, sum_grad, sum_hess);
-    T ret = sum_grad * w + 0.5 * (sum_hess + p.reg_lambda) * Sqr(w);
+    T ret = sum_grad * w + T(0.5) * (sum_hess + p.reg_lambda) * Sqr(w);
     if (p.reg_alpha == 0.0f) {
-      return -2.0 * ret;
+      return T(-2.0) * ret;
     } else {
-      return -2.0 * (ret + p.reg_alpha * std::abs(w));
+      return T(-2.0) * (ret + p.reg_alpha * std::abs(w));
     }
   }
 }
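The T(...) wrappers above keep the arithmetic in the template's own precision: with T = float, a bare 0.5 or -2.0 literal promotes the whole expression to double, and the result is then narrowed back to T, which is exactly the kind of conversion warning this commit removes. A minimal illustration with a hypothetical helper:

    template <typename T>
    T HalfOf(T x) {
      // With a plain 0.5 the product would be computed in double and narrowed
      // back to T; T(0.5) keeps every operand at T's precision.
      return T(0.5) * x;
    }
    // HalfOf(3.0f) == 1.5f, computed entirely in float.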
@@ -630,7 +630,8 @@ class GPUMaker : public TreeUpdater {
       throw std::runtime_error("exact::GPUBuilder - must have 1 column block");
     }
     std::vector<float> fval;
-    std::vector<int> fId, offset;
+    std::vector<int> fId;
+    std::vector<size_t> offset;
     convertToCsc(dmat, &fval, &fId, &offset);
     allocateAllData(static_cast<int>(offset.size()));
     transferAndSortData(fval, fId, offset);
@@ -638,10 +639,12 @@ class GPUMaker : public TreeUpdater {
   }

   void convertToCsc(DMatrix* dmat, std::vector<float>* fval,
-                    std::vector<int>* fId, std::vector<int>* offset) {
+                    std::vector<int>* fId, std::vector<size_t>* offset) {
     MetaInfo info = dmat->info();
-    nRows = info.num_row;
-    nCols = info.num_col;
+    CHECK(info.num_col < std::numeric_limits<int>::max());
+    CHECK(info.num_row < std::numeric_limits<int>::max());
+    nRows = static_cast<int>(info.num_row);
+    nCols = static_cast<int>(info.num_col);
     offset->reserve(nCols + 1);
     offset->push_back(0);
     fval->reserve(nCols * nRows);
@@ -667,12 +670,13 @@ class GPUMaker : public TreeUpdater {
         offset->push_back(fval->size());
       }
     }
-    nVals = fval->size();
+    CHECK(fval->size() < std::numeric_limits<int>::max());
+    nVals = static_cast<int>(fval->size());
   }

   void transferAndSortData(const std::vector<float>& fval,
                            const std::vector<int>& fId,
-                           const std::vector<int>& offset) {
+                           const std::vector<size_t>& offset) {
     vals.current_dvec() = fval;
     instIds.current_dvec() = fId;
     colOffsets = offset;
@@ -104,7 +104,7 @@ struct DeviceHist {
 template <int BLOCK_THREADS>
 __global__ void find_split_kernel(
     const gpair_sum_t* d_level_hist, int* d_feature_segments, int depth,
-    int n_features, int n_bins, DeviceNodeStats* d_nodes,
+    uint64_t n_features, int n_bins, DeviceNodeStats* d_nodes,
     int nodes_offset_device, float* d_fidx_min_map, float* d_gidx_fvalue_map,
     GPUTrainingParam gpu_param, bool* d_left_child_smallest_temp,
     bool colsample, int* d_feature_flags) {
@@ -293,7 +293,8 @@ class GPUHistMaker : public TreeUpdater {
     dh::Timer time1;
     // set member num_rows and n_devices for rest of GPUHistBuilder members
     info = &fmat.info();
-    num_rows = info->num_row;
+    CHECK(info->num_row < std::numeric_limits<bst_uint>::max());
+    num_rows = static_cast<bst_uint>(info->num_row);
     n_devices = dh::n_devices(param.n_gpus, num_rows);

     if (!initialised) {
@@ -396,15 +397,15 @@ class GPUHistMaker : public TreeUpdater {
       fflush(stdout);
     }

-    int n_bins = hmat_.row_ptr.back();
-    int n_features = hmat_.row_ptr.size() - 1;
+    int n_bins = static_cast<int >(hmat_.row_ptr.back());
+    int n_features = static_cast<int >(hmat_.row_ptr.size() - 1);

     // deliniate data onto multiple gpus
     device_row_segments.push_back(0);
     device_element_segments.push_back(0);
     bst_uint offset = 0;
-    bst_uint shard_size =
-        std::ceil(static_cast<double>(num_rows) / n_devices);
+    bst_uint shard_size = static_cast<bst_uint>(
+        std::ceil(static_cast<double>(num_rows) / n_devices));
     for (int d_idx = 0; d_idx < n_devices; d_idx++) {
       int device_idx = dList[d_idx];
       offset += shard_size;
@@ -425,7 +426,7 @@ class GPUHistMaker : public TreeUpdater {
     // Construct feature map
     std::vector<int> h_gidx_feature_map(n_bins);
     for (int fidx = 0; fidx < n_features; fidx++) {
-      for (int i = hmat_.row_ptr[fidx]; i < hmat_.row_ptr[fidx + 1]; i++) {
+      for (auto i = hmat_.row_ptr[fidx]; i < hmat_.row_ptr[fidx + 1]; i++) {
         h_gidx_feature_map[i] = fidx;
       }
     }
@@ -456,7 +457,7 @@ class GPUHistMaker : public TreeUpdater {
     gidx_feature_map.resize(n_devices);
     gidx_fvalue_map.resize(n_devices);

-    int find_split_n_devices = std::pow(2, std::floor(std::log2(n_devices)));
+    int find_split_n_devices = static_cast<int >(std::pow(2, std::floor(std::log2(n_devices))));
     find_split_n_devices =
         std::min(n_nodes_level(param.max_depth), find_split_n_devices);
     int max_num_nodes_device =
@@ -707,7 +708,7 @@ class GPUHistMaker : public TreeUpdater {
       int nodes_offset_device = 0;
       find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
           hist_vec[d_idx].GetLevelPtr(depth), feature_segments[d_idx].data(),
-          depth, (info->num_col), (hmat_.row_ptr.back()), nodes[d_idx].data(),
+          depth, info->num_col, hmat_.row_ptr.back(), nodes[d_idx].data(),
           nodes_offset_device, fidx_min_map[d_idx].data(),
           gidx_fvalue_map[d_idx].data(), GPUTrainingParam(param),
           left_child_smallest[d_idx].data(), colsample,
@@ -769,7 +770,7 @@ class GPUHistMaker : public TreeUpdater {
       DeviceNodeStats* d_nodes = nodes[d_idx].data();
       auto d_gidx_fvalue_map = gidx_fvalue_map[d_idx].data();
       auto d_gidx = device_matrix[d_idx].gidx;
-      int n_columns = info->num_col;
+      auto n_columns = info->num_col;
       size_t begin = device_row_segments[d_idx];
       size_t end = device_row_segments[d_idx + 1];

@@ -113,13 +113,11 @@ __device__ void EvaluateFeature(int fidx, const bst_gpair_integer* hist,
 }

 template <int BLOCK_THREADS>
-__global__ void evaluate_split_kernel(const bst_gpair_integer* d_hist, int nidx,
-                                      int n_features, DeviceNodeStats nodes,
-                                      const int* d_feature_segments,
-                                      const float* d_fidx_min_map,
-                                      const float* d_gidx_fvalue_map,
-                                      GPUTrainingParam gpu_param,
-                                      DeviceSplitCandidate* d_split) {
+__global__ void evaluate_split_kernel(
+    const bst_gpair_integer* d_hist, int nidx, uint64_t n_features,
+    DeviceNodeStats nodes, const int* d_feature_segments,
+    const float* d_fidx_min_map, const float* d_gidx_fvalue_map,
+    GPUTrainingParam gpu_param, DeviceSplitCandidate* d_split) {
   typedef cub::KeyValuePair<int, float> ArgMaxT;
   typedef cub::BlockScan<bst_gpair_integer, BLOCK_THREADS,
                          cub::BLOCK_SCAN_WARP_SCANS>
@@ -190,24 +188,6 @@ __device__ int BinarySearchRow(bst_uint begin, bst_uint end, gidx_iter_t data,
   return -1;
 }

-template <int BLOCK_THREADS>
-__global__ void RadixSortSmall(bst_uint* d_ridx, int* d_position, bst_uint n) {
-  typedef cub::BlockRadixSort<int, BLOCK_THREADS, 1, bst_uint> BlockRadixSort;
-  __shared__ typename BlockRadixSort::TempStorage temp_storage;
-
-  bool thread_active = threadIdx.x < n;
-  int thread_key[1];
-  bst_uint thread_value[1];
-  thread_key[0] = thread_active ? d_position[threadIdx.x] : INT_MAX;
-  thread_value[0] = thread_active ? d_ridx[threadIdx.x] : UINT_MAX;
-  BlockRadixSort(temp_storage).Sort(thread_key, thread_value);
-
-  if (thread_active) {
-    d_position[threadIdx.x] = thread_key[0];
-    d_ridx[threadIdx.x] = thread_value[0];
-  }
-}
-
 struct DeviceHistogram {
   dh::bulk_allocator<dh::memory_type::DEVICE> ba;
   dh::dvec<bst_gpair_integer> data;
@@ -269,7 +249,7 @@ struct DeviceShard {
         null_gidx_value(n_bins) {
     // Convert to ELLPACK matrix representation
     int max_elements_row = 0;
-    for (int i = row_begin; i < row_end; i++) {
+    for (auto i = row_begin; i < row_end; i++) {
       max_elements_row =
           (std::max)(max_elements_row,
                      static_cast<int>(gmat.row_ptr[i + 1] - gmat.row_ptr[i]));
@@ -277,9 +257,9 @@ struct DeviceShard {
     row_stride = max_elements_row;
     std::vector<int> ellpack_matrix(row_stride * n_rows, null_gidx_value);

-    for (int i = row_begin; i < row_end; i++) {
+    for (auto i = row_begin; i < row_end; i++) {
       int row_count = 0;
-      for (int j = gmat.row_ptr[i]; j < gmat.row_ptr[i + 1]; j++) {
+      for (auto j = gmat.row_ptr[i]; j < gmat.row_ptr[i + 1]; j++) {
         ellpack_matrix[i * row_stride + row_count] = gmat.index[j];
         row_count++;
       }
@@ -394,13 +374,8 @@ struct DeviceShard {
                       int right_nidx) {
     auto n = segment.second - segment.first;
     int min_bits = 0;
-    int max_bits = std::ceil(std::log2((std::max)(left_nidx, right_nidx) + 1));
-    // const int SINGLE_TILE_SIZE = 1024;
-    // if (n < SINGLE_TILE_SIZE) {
-    //  RadixSortSmall<SINGLE_TILE_SIZE>
-    //      <<<1, SINGLE_TILE_SIZE>>>(ridx.current() + segment.first,
-    //                                position.current() + segment.first, n);
-    //} else {
+    int max_bits = static_cast<int>(
+        std::ceil(std::log2((std::max)(left_nidx, right_nidx) + 1)));

     size_t temp_storage_bytes = 0;
     cub::DeviceRadixSort::SortPairs(
@@ -509,7 +484,7 @@ class GPUHistMakerExperimental : public TreeUpdater {
                                   nidx_set.size());
     auto d_split = shard.temp_memory.Pointer<DeviceSplitCandidate>();

-    auto& streams = shard.GetStreams(nidx_set.size());
+    auto& streams = shard.GetStreams(static_cast<int>(nidx_set.size()));

     // Use streams to process nodes concurrently
     for (auto i = 0; i < nidx_set.size(); i++) {
@@ -518,7 +493,7 @@ class GPUHistMakerExperimental : public TreeUpdater {

       const int BLOCK_THREADS = 256;
       evaluate_split_kernel<BLOCK_THREADS>
-          <<<columns, BLOCK_THREADS, 0, streams[i]>>>(
+          <<<uint32_t(columns), BLOCK_THREADS, 0, streams[i]>>>(
               shard.hist.node_map[nidx], nidx, info->num_col, node,
               shard.feature_segments.data(), shard.min_fvalue.data(),
               shard.gidx_fvalue_map.data(), GPUTrainingParam(param),
@@ -573,10 +548,11 @@ class GPUHistMakerExperimental : public TreeUpdater {
   __host__ __device__ int operator()(int x) const { return x == val; }
 };

-__device__ void CountLeft(bst_uint* d_count, int val, int left_nidx) {
+__device__ void CountLeft(int64_t* d_count, int val, int left_nidx) {
   unsigned ballot = __ballot(val == left_nidx);
   if (threadIdx.x % 32 == 0) {
-    atomicAdd(d_count, __popc(ballot));
+    atomicAdd(reinterpret_cast<unsigned long long*>(d_count),    // NOLINT
+              static_cast<unsigned long long>(__popc(ballot)));  // NOLINT
   }
 }

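The CountLeft change above is needed because CUDA's atomicAdd has no signed 64-bit overload: the counter is now an int64_t, and the device-side add goes through the unsigned long long overload via reinterpret_cast, which is fine for a monotonically increasing count and also avoids overflowing the old 32-bit bst_uint counter on very large datasets. A hypothetical kernel using the same pattern (illustration only, not the updater code):

    #include <cstdint>

    // Count how many entries equal `target`, accumulating into a 64-bit counter.
    __global__ void CountMatches(const int* values, int n, int target,
                                 int64_t* d_count) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n && values[idx] == target) {
        // atomicAdd has no int64_t overload; go through the
        // unsigned long long overload instead.
        atomicAdd(reinterpret_cast<unsigned long long*>(d_count), 1ULL);
      }
    }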
@@ -601,9 +577,9 @@ class GPUHistMakerExperimental : public TreeUpdater {

     for (auto& shard : shards) {
       monitor.Start("update position kernel");
-      shard.temp_memory.LazyAllocate(sizeof(bst_uint));
-      auto d_left_count = shard.temp_memory.Pointer<bst_uint>();
-      dh::safe_cuda(cudaMemset(d_left_count, 0, sizeof(bst_uint)));
+      shard.temp_memory.LazyAllocate(sizeof(int64_t));
+      auto d_left_count = shard.temp_memory.Pointer<int64_t>();
+      dh::safe_cuda(cudaMemset(d_left_count, 0, sizeof(int64_t)));
       dh::safe_cuda(cudaSetDevice(shard.device_idx));
       auto segment = shard.ridx_segments[nidx];
       CHECK_GT(segment.second - segment.first, 0);
@@ -639,8 +615,8 @@ class GPUHistMakerExperimental : public TreeUpdater {
             d_position[idx] = position;
           });

-      bst_uint left_count;
-      dh::safe_cuda(cudaMemcpy(&left_count, d_left_count, sizeof(bst_uint),
+      int64_t left_count;
+      dh::safe_cuda(cudaMemcpy(&left_count, d_left_count, sizeof(int64_t),
                                cudaMemcpyDeviceToHost));
       monitor.Stop("update position kernel");

@@ -722,7 +698,7 @@ class GPUHistMakerExperimental : public TreeUpdater {
     this->InitRoot(gpair, p_tree);
     monitor.Stop("InitRoot");

-    unsigned timestamp = qexpand_->size();
+    auto timestamp = qexpand_->size();
     auto num_leaves = 1;

     while (!qexpand_->empty()) {
@@ -764,9 +740,9 @@ class GPUHistMakerExperimental : public TreeUpdater {
     int nid;
     int depth;
     DeviceSplitCandidate split;
-    unsigned timestamp;
+    uint64_t timestamp;
     ExpandEntry(int nid, int depth, const DeviceSplitCandidate& split,
-                unsigned timestamp)
+                uint64_t timestamp)
         : nid(nid), depth(depth), split(split), timestamp(timestamp) {}
     bool IsValid(const TrainParam& param, int num_leaves) const {
       if (split.loss_chg <= rt_eps) return false;
@@ -7,7 +7,7 @@ namespace common {
 TEST(CompressedIterator, Test) {
   ASSERT_TRUE(detail::SymbolBits(256) == 8);
   ASSERT_TRUE(detail::SymbolBits(150) == 8);
-  std::vector<int> test_cases = {3, 426, 21, 64, 256, 100000, INT32_MAX};
+  std::vector<int> test_cases = {1, 3, 426, 21, 64, 256, 100000, INT32_MAX};
   int num_elements = 1000;
   int repetitions = 1000;
   srand(9);
@@ -12,7 +12,7 @@ void CreateTestData(xgboost::bst_uint num_rows, int max_row_size,
                     thrust::host_vector<xgboost::bst_uint> *rows) {
   row_ptr->resize(num_rows + 1);
   int sum = 0;
-  for (int i = 0; i <= num_rows; i++) {
+  for (xgboost::bst_uint i = 0; i <= num_rows; i++) {
     (*row_ptr)[i] = sum;
     sum += rand() % max_row_size;  // NOLINT
@@ -16,7 +16,7 @@ TEST(gpu_predictor, Test) {
       std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor"));

   std::vector<std::unique_ptr<RegTree>> trees;
-  trees.push_back(std::unique_ptr<RegTree>());
+  trees.push_back(std::unique_ptr<RegTree>(new RegTree()));
   trees.back()->InitModel();
   (*trees.back())[0].set_leaf(1.5f);
   (*trees.back()).stat(0).sum_hess = 1.0f;
@@ -39,7 +39,6 @@ TEST(gpu_predictor, Test) {
     ASSERT_LT(std::abs(gpu_out_predictions[i] - cpu_out_predictions[i]),
               abs_tolerance);
   }
-
   // Test predict instance
   auto batch = dmat->RowIterator()->Value();
   for (int i = 0; i < batch.size; i++) {
@@ -16,7 +16,7 @@ TEST(gpu_hist_experimental, TestSparseShard) {
   int rows = 100;
   int columns = 80;
   int max_bins = 4;
-  auto dmat = CreateDMatrix(rows, columns, 0.9);
+  auto dmat = CreateDMatrix(rows, columns, 0.9f);
   common::HistCutMatrix hmat;
   common::GHistIndexMatrix gmat;
   hmat.Init(dmat.get(), max_bins);
@@ -33,7 +33,7 @@ TEST(gpu_hist_experimental, TestSparseShard) {

   for (int i = 0; i < rows; i++) {
     int row_offset = 0;
-    for (int j = gmat.row_ptr[i]; j < gmat.row_ptr[i + 1]; j++) {
+    for (auto j = gmat.row_ptr[i]; j < gmat.row_ptr[i + 1]; j++) {
       ASSERT_EQ(gidx[i * shard.row_stride + row_offset], gmat.index[j]);
       row_offset++;
     }
@@ -61,7 +61,7 @@ if [ ${TASK} == "python_lightweight_test" ]; then
     conda install numpy scipy nose
     python -m pip install graphviz
     python -m nose tests/python || exit -1
-    python -m pip install flake8
+    python -m pip install flake8==3.4.1
     flake8 --ignore E501 python-package || exit -1
     flake8 --ignore E501 tests/python || exit -1
     exit 0