Fix CPU hist init for sparse dataset. (#4625)
* Fix CPU hist init for sparse dataset.
* Implement sparse histogram cut.
* Allow empty features.
* Fix windows build, don't use sparse in distributed environment.
* Comments.
* Smaller threshold.
* Fix windows omp.
* Fix msvc lambda capture.
* Fix MSVC macro.
* Fix MSVC initialization list.
* Fix MSVC initialization list x2.
* Preserve categorical feature behavior.
* Rename matrix to sparse cuts.
* Reuse UseGroup.
* Check for categorical data when adding cut.

Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu>

* Sanity check.
* Fix comments.
* Fix comment.
This commit is contained in:
parent b7a1f22d24
commit d9a47794a5
@@ -1,4 +1,4 @@
-Checks: 'modernize-*,-modernize-make-*,-modernize-use-auto,-modernize-raw-string-literal,google-*,-google-default-arguments,-clang-diagnostic-#pragma-messages,readability-identifier-naming'
+Checks: 'modernize-*,-modernize-make-*,-modernize-use-auto,-modernize-raw-string-literal,-modernize-avoid-c-arrays,google-*,-google-default-arguments,-clang-diagnostic-#pragma-messages,readability-identifier-naming'
 CheckOptions:
   - { key: readability-identifier-naming.ClassCase,  value: CamelCase }
   - { key: readability-identifier-naming.StructCase, value: CamelCase }
@@ -437,6 +437,7 @@ class DMatrix {
                         bool load_row_split,
                         const std::string& file_format = "auto",
                         const size_t page_size = kPageSize);

   /*!
    * \brief create a new DMatrix, by wrapping a row_iterator, and meta info.
    * \param source The source iterator of the data, the create function takes ownership of the source.
@@ -59,7 +59,7 @@ if (USE_CUDA)

   # OpenMP is mandatory for cuda version
   find_package(OpenMP REQUIRED)
   target_compile_options(objxgboost PRIVATE
     $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=${OpenMP_CXX_FLAGS}>
   )

@@ -119,10 +119,9 @@ class NativeDataIter : public dmlc::Parser<uint32_t> {
   }

   bool Next() override {
-    if ((*next_callback_)(
-        data_handle_,
-        XGBoostNativeDataIterSetData,
-        this) != 0) {
+    if ((*next_callback_)(data_handle_,
+                          XGBoostNativeDataIterSetData,
+                          this) != 0) {
       at_first_ = false;
       return true;
     } else {
@@ -75,7 +75,7 @@ class ColumnMatrix {
   // construct column matrix from GHistIndexMatrix
   inline void Init(const GHistIndexMatrix& gmat,
                    double sparse_threshold) {
-    const int32_t nfeature = static_cast<int32_t>(gmat.cut.row_ptr.size() - 1);
+    const int32_t nfeature = static_cast<int32_t>(gmat.cut.Ptrs().size() - 1);
     const size_t nrow = gmat.row_ptr.size() - 1;

     // identify type of each column
@@ -85,7 +85,7 @@ class ColumnMatrix {

     uint32_t max_val = std::numeric_limits<uint32_t>::max();
     for (int32_t fid = 0; fid < nfeature; ++fid) {
-      CHECK_LE(gmat.cut.row_ptr[fid + 1] - gmat.cut.row_ptr[fid], max_val);
+      CHECK_LE(gmat.cut.Ptrs()[fid + 1] - gmat.cut.Ptrs()[fid], max_val);
     }

     gmat.GetFeatureCounts(&feature_counts_[0]);
@@ -123,7 +123,7 @@ class ColumnMatrix {
     // store least bin id for each feature
     index_base_.resize(nfeature);
     for (int32_t fid = 0; fid < nfeature; ++fid) {
-      index_base_[fid] = gmat.cut.row_ptr[fid];
+      index_base_[fid] = gmat.cut.Ptrs()[fid];
     }

     // pre-fill index_ for dense columns
@@ -150,9 +150,9 @@ class ColumnMatrix {
       size_t fid = 0;
       for (size_t i = ibegin; i < iend; ++i) {
         const uint32_t bin_id = gmat.index[i];
-        while (bin_id >= gmat.cut.row_ptr[fid + 1]) {
-          ++fid;
-        }
+        auto iter = std::upper_bound(gmat.cut.Ptrs().cbegin() + fid,
+                                     gmat.cut.Ptrs().cend(), bin_id);
+        fid = std::distance(gmat.cut.Ptrs().cbegin(), iter) - 1;
         if (type_[fid] == kDenseColumn) {
           uint32_t* begin = &index_[boundary_[fid].index_begin];
           begin[rid] = bin_id - index_base_[fid];
@@ -72,6 +72,11 @@ inline std::string ToString(const T& data) {
   return os.str();
 }

+template <typename T1, typename T2>
+XGBOOST_DEVICE T1 DivRoundUp(const T1 a, const T2 b) {
+  return static_cast<T1>(std::ceil(static_cast<double>(a) / b));
+}
+
 /*
  * Range iterator
  */
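As a side note, here is a standalone sketch of what the new helper computes and the grid-size arithmetic it is used for in the CUDA hunks below. It is not part of this commit: the XGBOOST_DEVICE qualifier and xgboost headers are dropped so it builds as plain C++, and the block counts are made-up examples.

#include <cassert>
#include <cmath>
#include <cstddef>

// Host-only copy of the ceiling-division helper added above.
template <typename T1, typename T2>
T1 DivRoundUp(const T1 a, const T2 b) {
  return static_cast<T1>(std::ceil(static_cast<double>(a) / b));
}

int main() {
  // Typical use: how many thread blocks are needed to cover n items.
  const std::size_t n = 1000;
  const int kBlockThreads = 256;
  assert(DivRoundUp(n, kBlockThreads) == 4);                 // 1000 / 256 rounded up
  assert(DivRoundUp(std::size_t{1024}, kBlockThreads) == 4); // exact multiple stays exact
  return 0;
}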
@@ -30,12 +30,13 @@ class ConfigParser {
    * \param path path to configuration file
    */
   explicit ConfigParser(const std::string& path)
-      : line_comment_regex_("^#"),
+      : path_(path),
+        line_comment_regex_("^#"),
         key_regex_(R"rx(^([^#"'=\r\n\t ]+)[\t ]*=)rx"),
         key_regex_escaped_(R"rx(^(["'])([^"'=\r\n]+)\1[\t ]*=)rx"),
         value_regex_(R"rx(^([^#"'=\r\n\t ]+)[\t ]*(?:#.*){0,1}$)rx"),
-        value_regex_escaped_(R"rx(^(["'])([^"'=\r\n]+)\1[\t ]*(?:#.*){0,1}$)rx"),
-        path_(path) {}
+        value_regex_escaped_(R"rx(^(["'])([^"'=\r\n]+)\1[\t ]*(?:#.*){0,1}$)rx")
+  {}

   std::string LoadConfigFile(const std::string& path) {
     std::ifstream fin(path, std::ios_base::in | std::ios_base::binary);
@@ -77,8 +78,6 @@ class ConfigParser {
     content = NormalizeConfigEOL(content);
     std::stringstream ss { content };
     std::vector<std::pair<std::string, std::string>> results;
-    char delimiter = '=';
-    char comment = '#';
     std::string line;
     std::string key, value;
     // Loop over every line of the configuration file
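One plausible reading of the "Fix MSVC initialization list" items is illustrated below: C++ initializes members in declaration order, not in the order they appear in the initializer list, so moving path_(path) to the front keeps the list in declaration order and avoids the compiler's reordering warning. The class is a hypothetical reduction, not the real ConfigParser.

#include <iostream>
#include <string>

class Config {
  std::string path_;        // declared first, therefore initialized first
  std::string line_regex_;  // declared second, initialized second

 public:
  explicit Config(const std::string& path)
      : path_(path),          // listed in declaration order, matching the fix above
        line_regex_("^#") {}
  const std::string& Path() const { return path_; }
};

int main() {
  Config cfg("train.conf");
  std::cout << cfg.Path() << "\n";  // prints: train.conf
  return 0;
}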
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2017 XGBoost contributors
+ * Copyright 2017-2019 XGBoost contributors
  */
 #pragma once
 #include <thrust/device_ptr.h>
@@ -183,11 +183,6 @@ __device__ void BlockFill(IterT begin, size_t n, ValueT value) {
  * Kernel launcher
  */

-template <typename T1, typename T2>
-T1 DivRoundUp(const T1 a, const T2 b) {
-  return static_cast<T1>(ceil(static_cast<double>(a) / b));
-}
-
 template <typename L>
 __global__ void LaunchNKernel(size_t begin, size_t end, L lambda) {
   for (auto i : GridStrideRange(begin, end)) {
@@ -211,7 +206,7 @@ inline void LaunchN(int device_idx, size_t n, cudaStream_t stream, L lambda) {
   safe_cuda(cudaSetDevice(device_idx));

   const int GRID_SIZE =
-      static_cast<int>(DivRoundUp(n, ITEMS_PER_THREAD * BLOCK_THREADS));
+      static_cast<int>(xgboost::common::DivRoundUp(n, ITEMS_PER_THREAD * BLOCK_THREADS));
   LaunchNKernel<<<GRID_SIZE, BLOCK_THREADS, 0, stream>>>(static_cast<size_t>(0),
                                                          n, lambda);
 }
@@ -619,7 +614,7 @@ struct CubMemory {
     if (this->IsAllocated()) {
       XGBDeviceAllocator<uint8_t> allocator;
       allocator.deallocate(thrust::device_ptr<uint8_t>(static_cast<uint8_t *>(d_temp_storage)),
                            temp_storage_bytes);
       d_temp_storage = nullptr;
       temp_storage_bytes = 0;
     }
@@ -738,7 +733,7 @@ void SparseTransformLbs(int device_idx, dh::CubMemory *temp_memory,
   const int BLOCK_THREADS = 256;
   const int ITEMS_PER_THREAD = 1;
   const int TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD;
-  auto num_tiles = dh::DivRoundUp(count + num_segments, BLOCK_THREADS);
+  auto num_tiles = xgboost::common::DivRoundUp(count + num_segments, BLOCK_THREADS);
   CHECK(num_tiles < std::numeric_limits<unsigned int>::max());

   temp_memory->LazyAllocate(sizeof(CoordinateT) * (num_tiles + 1));
@@ -1158,7 +1153,7 @@ class AllReducer {
 };

 /**
  * \brief Synchronizes the device
  *
  * \param device_id Identifier for the device.
  */
@@ -25,25 +25,206 @@
 namespace xgboost {
 namespace common {

-HistCutMatrix::HistCutMatrix() {
-  monitor_.Init("HistCutMatrix");
+HistogramCuts::HistogramCuts() {
+  monitor_.Init(__FUNCTION__);
+  cut_ptrs_.emplace_back(0);
 }

-size_t HistCutMatrix::SearchGroupIndFromBaseRow(
-    std::vector<bst_uint> const& group_ptr, size_t const base_rowid) const {
-  using KIt = std::vector<bst_uint>::const_iterator;
-  KIt res = std::lower_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid);
-  // Cannot use CHECK_NE because it will try to print the iterator.
-  bool const found = res != group_ptr.cend() - 1;
-  if (!found) {
-    LOG(FATAL) << "Row " << base_rowid << " does not lie in any group!\n";
+// Dispatch to specific builder.
+void HistogramCuts::Build(DMatrix* dmat, uint32_t const max_num_bins) {
+  auto const& info = dmat->Info();
+  size_t const total = info.num_row_ * info.num_col_;
+  size_t const nnz = info.num_nonzero_;
+  float const sparsity = static_cast<float>(nnz) / static_cast<float>(total);
+  // Use a small number to avoid calling `dmat->GetColumnBatches'.
+  float constexpr kSparsityThreshold = 0.0005;
+  // FIXME(trivialfis): Distributed environment is not supported.
+  if (sparsity < kSparsityThreshold && (!rabit::IsDistributed())) {
+    LOG(INFO) << "Building quantile cut on a sparse dataset.";
+    SparseCuts cuts(this);
+    cuts.Build(dmat, max_num_bins);
+  } else {
+    LOG(INFO) << "Building quantile cut on a dense dataset or distributed environment.";
+    DenseCuts cuts(this);
+    cuts.Build(dmat, max_num_bins);
   }
-  size_t group_ind = std::distance(group_ptr.cbegin(), res);
-  return group_ind;
 }

-void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
-  monitor_.Start("Init");
+bool CutsBuilder::UseGroup(DMatrix* dmat) {
+  auto& info = dmat->Info();
+  size_t const num_groups = info.group_ptr_.size() == 0 ?
+                            0 : info.group_ptr_.size() - 1;
+  // Use group index for weights?
+  bool const use_group_ind = num_groups != 0 &&
+                             (info.weights_.Size() != info.num_row_);
+  return use_group_ind;
+}
+
+void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
+                                   uint32_t max_num_bins,
+                                   bool const use_group_ind,
+                                   uint32_t beg_col, uint32_t end_col,
+                                   uint32_t thread_id) {
+  using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
+  CHECK_GE(end_col, beg_col);
+  constexpr float kFactor = 8;
+
+  // Data groups, used in ranking.
+  std::vector<bst_uint> const& group_ptr = info.group_ptr_;
+  p_cuts_->min_vals_.resize(end_col - beg_col, 0);
+
+  for (uint32_t col_id = beg_col; col_id < page.Size() && col_id < end_col; ++col_id) {
+    // Using a local variable makes things easier, but at the cost of memory thrashing.
+    WXQSketch sketch;
+    common::Span<xgboost::Entry const> const column = page[col_id];
+    uint32_t const n_bins = std::min(static_cast<uint32_t>(column.size()),
+                                     max_num_bins);
+    if (n_bins == 0) {
+      // cut_ptrs_ is initialized with a zero, so there's always an element at the back
+      p_cuts_->cut_ptrs_.emplace_back(p_cuts_->cut_ptrs_.back());
+      continue;
+    }
+
+    sketch.Init(info.num_row_, 1.0 / (n_bins * kFactor));
+    for (auto const& entry : column) {
+      uint32_t weight_ind = 0;
+      if (use_group_ind) {
+        auto row_idx = entry.index;
+        uint32_t group_ind =
+            this->SearchGroupIndFromRow(group_ptr, page.base_rowid + row_idx);
+        weight_ind = group_ind;
+      } else {
+        weight_ind = entry.index;
+      }
+      sketch.Push(entry.fvalue, info.GetWeight(weight_ind));
+    }
+
+    WXQSketch::SummaryContainer out_summary;
+    sketch.GetSummary(&out_summary);
+    WXQSketch::SummaryContainer summary;
+    summary.Reserve(n_bins);
+    summary.SetPrune(out_summary, n_bins);
+
+    // Can we use data[1] as the min values so that we don't need to
+    // store another array?
+    float mval = summary.data[0].value;
+    p_cuts_->min_vals_[col_id - beg_col] = mval - (fabs(mval) + 1e-5);
+
+    this->AddCutPoint(summary);
+
+    bst_float cpt = (summary.size > 0) ?
+                    summary.data[summary.size - 1].value :
+                    p_cuts_->min_vals_[col_id - beg_col];
+    cpt += fabs(cpt) + 1e-5;
+    p_cuts_->cut_values_.emplace_back(cpt);
+
+    p_cuts_->cut_ptrs_.emplace_back(p_cuts_->cut_values_.size());
+  }
+}
+
+std::vector<size_t> SparseCuts::LoadBalance(SparsePage const& page,
+                                            size_t const nthreads) {
+  /* Some sparse datasets have their mass concentrated on a small
+   * number of features.  To avoid waiting for a few threads running
+   * forever, we distribute different numbers of columns to
+   * different threads according to the number of entries. */
+  size_t const total_entries = page.data.Size();
+  size_t const entries_per_thread = common::DivRoundUp(total_entries, nthreads);
+
+  std::vector<size_t> cols_ptr(nthreads+1, 0);
+  size_t count {0};
+  size_t current_thread {1};
+
+  for (size_t col_id = 0; col_id < page.Size(); ++col_id) {
+    auto const column = page[col_id];
+    cols_ptr[current_thread]++;  // add one column to thread
+    count += column.size();
+    if (count > entries_per_thread + 1) {
+      current_thread++;
+      count = 0;
+      cols_ptr[current_thread] = cols_ptr[current_thread-1];
+    }
+  }
+  // Idle threads.
+  for (; current_thread < cols_ptr.size() - 1; ++current_thread) {
+    cols_ptr[current_thread+1] = cols_ptr[current_thread];
+  }
+
+  return cols_ptr;
+}
+
+void SparseCuts::Build(DMatrix* dmat, uint32_t const max_num_bins) {
+  monitor_.Start(__FUNCTION__);
+  // Use group index for weights?
+  auto use_group = UseGroup(dmat);
+  uint32_t nthreads = omp_get_max_threads();
+  CHECK_GT(nthreads, 0);
+  std::vector<HistogramCuts> cuts_containers(nthreads);
+  std::vector<std::unique_ptr<SparseCuts>> sparse_cuts(nthreads);
+  for (size_t i = 0; i < nthreads; ++i) {
+    sparse_cuts[i].reset(new SparseCuts(&cuts_containers[i]));
+  }
+
+  for (auto const& page : dmat->GetColumnBatches()) {
+    CHECK_LE(page.Size(), dmat->Info().num_col_);
+    monitor_.Start("Load balance");
+    std::vector<size_t> col_ptr = LoadBalance(page, nthreads);
+    monitor_.Stop("Load balance");
+    // We decouple the build logic from the parallelization here
+    // to simplify things a bit.
+#pragma omp parallel for num_threads(nthreads) schedule(static)
+    for (omp_ulong i = 0; i < nthreads; ++i) {
+      common::Monitor t_monitor;
+      t_monitor.Init("SingleThreadBuild: " + std::to_string(i));
+      t_monitor.Start(std::to_string(i));
+      sparse_cuts[i]->SingleThreadBuild(page, dmat->Info(), max_num_bins, use_group,
+                                        col_ptr[i], col_ptr[i+1], i);
+      t_monitor.Stop(std::to_string(i));
+    }
+
+    this->Concat(sparse_cuts, dmat->Info().num_col_);
+  }
+
+  monitor_.Stop(__FUNCTION__);
+}
+
+void SparseCuts::Concat(
+    std::vector<std::unique_ptr<SparseCuts>> const& cuts, uint32_t n_cols) {
+  monitor_.Start(__FUNCTION__);
+  uint32_t nthreads = omp_get_max_threads();
+  p_cuts_->min_vals_.resize(n_cols, std::numeric_limits<float>::max());
+  size_t min_vals_tail = 0;
+
+  for (uint32_t t = 0; t < nthreads; ++t) {
+    // concat csc pointers.
+    size_t const old_ptr_size = p_cuts_->cut_ptrs_.size();
+    p_cuts_->cut_ptrs_.resize(
+        cuts[t]->p_cuts_->cut_ptrs_.size() + p_cuts_->cut_ptrs_.size() - 1);
+    size_t const new_icp_size = p_cuts_->cut_ptrs_.size();
+    auto tail = p_cuts_->cut_ptrs_[old_ptr_size-1];
+    for (size_t j = old_ptr_size; j < new_icp_size; ++j) {
+      p_cuts_->cut_ptrs_[j] = tail + cuts[t]->p_cuts_->cut_ptrs_[j-old_ptr_size+1];
+    }
+    // concat csc values
+    size_t const old_iv_size = p_cuts_->cut_values_.size();
+    p_cuts_->cut_values_.resize(
+        cuts[t]->p_cuts_->cut_values_.size() + p_cuts_->cut_values_.size());
+    size_t const new_iv_size = p_cuts_->cut_values_.size();
+    for (size_t j = old_iv_size; j < new_iv_size; ++j) {
+      p_cuts_->cut_values_[j] = cuts[t]->p_cuts_->cut_values_[j-old_iv_size];
+    }
+    // merge min values
+    for (size_t j = 0; j < cuts[t]->p_cuts_->min_vals_.size(); ++j) {
+      p_cuts_->min_vals_.at(min_vals_tail + j) =
+          std::min(p_cuts_->min_vals_.at(min_vals_tail + j), cuts.at(t)->p_cuts_->min_vals_.at(j));
+    }
+    min_vals_tail += cuts[t]->p_cuts_->min_vals_.size();
+  }
+  monitor_.Stop(__FUNCTION__);
+}
+
+void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {
+  monitor_.Start(__FUNCTION__);
   const MetaInfo& info = p_fmat->Info();

   // safe factor for better accuracy
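For illustration, here is a dependency-free sketch of the dispatch rule in HistogramCuts::Build() above. ChooseBuilder, the stand-in names, and the dataset sizes are made up for the example; only the sparsity arithmetic and the 0.0005 threshold come from the diff.

#include <cstdio>

const char* ChooseBuilder(size_t num_row, size_t num_col, size_t nnz,
                          bool distributed) {
  // density = stored entries / (rows * cols), mirroring Build() above
  const float sparsity =
      static_cast<float>(nnz) / (static_cast<float>(num_row) * num_col);
  constexpr float kSparsityThreshold = 0.0005f;
  return (sparsity < kSparsityThreshold && !distributed) ? "SparseCuts"
                                                         : "DenseCuts";
}

int main() {
  // 1M rows x 100k cols with 20M stored entries: density 2e-4 -> SparseCuts.
  std::printf("%s\n", ChooseBuilder(1000000, 100000, 20000000, false));
  // The same data in a distributed job stays on the dense path.
  std::printf("%s\n", ChooseBuilder(1000000, 100000, 20000000, true));
  // A fully dense 1M x 100 matrix (density 1.0) -> DenseCuts.
  std::printf("%s\n", ChooseBuilder(1000000, 100, 100000000, false));
  return 0;
}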
@@ -60,20 +241,18 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
     s.Init(info.num_row_, 1.0 / (max_num_bins * kFactor));
   }

-  const auto& weights = info.weights_.HostVector();
-
   // Data groups, used in ranking.
   std::vector<bst_uint> const& group_ptr = info.group_ptr_;
   size_t const num_groups = group_ptr.size() == 0 ? 0 : group_ptr.size() - 1;
   // Use group index for weights?
-  bool const use_group_ind = num_groups != 0 && weights.size() != info.num_row_;
+  bool const use_group = UseGroup(p_fmat);

   for (const auto &batch : p_fmat->GetRowBatches()) {
     size_t group_ind = 0;
-    if (use_group_ind) {
-      group_ind = this->SearchGroupIndFromBaseRow(group_ptr, batch.base_rowid);
+    if (use_group) {
+      group_ind = this->SearchGroupIndFromRow(group_ptr, batch.base_rowid);
     }
-#pragma omp parallel num_threads(nthread) firstprivate(group_ind, use_group_ind)
+#pragma omp parallel num_threads(nthread) firstprivate(group_ind, use_group)
     {
       CHECK_EQ(nthread, omp_get_num_threads());
       auto tid = static_cast<unsigned>(omp_get_thread_num());
@@ -85,7 +264,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
       for (size_t i = 0; i < batch.Size(); ++i) { // NOLINT(*)
         size_t const ridx = batch.base_rowid + i;
         SparsePage::Inst const inst = batch[i];
-        if (use_group_ind &&
+        if (use_group &&
             group_ptr[group_ind] == ridx &&
             // maximum equals to weights.size() - 1
             group_ind < num_groups - 1) {
@@ -94,7 +273,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
         }
         for (auto const& entry : inst) {
           if (entry.index >= begin && entry.index < end) {
-            size_t w_idx = use_group_ind ? group_ind : ridx;
+            size_t w_idx = use_group ? group_ind : ridx;
             sketchs[entry.index].Push(entry.fvalue, info.GetWeight(w_idx));
           }
         }
@@ -104,10 +283,10 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
       }

   Init(&sketchs, max_num_bins);
-  monitor_.Stop("Init");
+  monitor_.Stop(__FUNCTION__);
 }

-void HistCutMatrix::Init
+void DenseCuts::Init
 (std::vector<WXQSketch>* in_sketchs, uint32_t max_num_bins) {
   std::vector<WXQSketch>& sketchs = *in_sketchs;
   constexpr int kFactor = 8;
@@ -124,62 +303,34 @@ void HistCutMatrix::Init
   CHECK_EQ(summary_array.size(), in_sketchs->size());
   size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor);
   sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
-  this->min_val.resize(sketchs.size());
-  row_ptr.push_back(0);
+  p_cuts_->min_vals_.resize(sketchs.size());
   for (size_t fid = 0; fid < summary_array.size(); ++fid) {
     WXQSketch::SummaryContainer a;
     a.Reserve(max_num_bins);
     a.SetPrune(summary_array[fid], max_num_bins);
     const bst_float mval = a.data[0].value;
-    this->min_val[fid] = mval - (fabs(mval) + 1e-5);
-    if (a.size > 1 && a.size <= 16) {
-      /* specialized code categorial / ordinal data -- use midpoints */
-      for (size_t i = 1; i < a.size; ++i) {
-        bst_float cpt = (a.data[i].value + a.data[i - 1].value) / 2.0f;
-        if (i == 1 || cpt > cut.back()) {
-          cut.push_back(cpt);
-        }
-      }
-    } else {
-      for (size_t i = 2; i < a.size; ++i) {
-        bst_float cpt = a.data[i - 1].value;
-        if (i == 2 || cpt > cut.back()) {
-          cut.push_back(cpt);
-        }
-      }
-    }
+    p_cuts_->min_vals_[fid] = mval - (fabs(mval) + 1e-5);
+    AddCutPoint(a);
     // push a value that is greater than anything
     const bst_float cpt
-      = (a.size > 0) ? a.data[a.size - 1].value : this->min_val[fid];
+      = (a.size > 0) ? a.data[a.size - 1].value : p_cuts_->min_vals_[fid];
     // this must be bigger than last value in a scale
     const bst_float last = cpt + (fabs(cpt) + 1e-5);
-    cut.push_back(last);
+    p_cuts_->cut_values_.push_back(last);

     // Ensure that every feature gets at least one quantile point
-    CHECK_LE(cut.size(), std::numeric_limits<uint32_t>::max());
-    auto cut_size = static_cast<uint32_t>(cut.size());
-    CHECK_GT(cut_size, row_ptr.back());
-    row_ptr.push_back(cut_size);
+    CHECK_LE(p_cuts_->cut_values_.size(), std::numeric_limits<uint32_t>::max());
+    auto cut_size = static_cast<uint32_t>(p_cuts_->cut_values_.size());
+    CHECK_GT(cut_size, p_cuts_->cut_ptrs_.back());
+    p_cuts_->cut_ptrs_.push_back(cut_size);
   }
 }

-uint32_t HistCutMatrix::GetBinIdx(const Entry& e) {
-  unsigned fid = e.index;
-  auto cbegin = cut.begin() + row_ptr[fid];
-  auto cend = cut.begin() + row_ptr[fid + 1];
-  CHECK(cbegin != cend);
-  auto it = std::upper_bound(cbegin, cend, e.fvalue);
-  if (it == cend) {
-    it = cend - 1;
-  }
-  uint32_t idx = static_cast<uint32_t>(it - cut.begin());
-  return idx;
-}
-
 void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
-  cut.Init(p_fmat, max_num_bins);
+  cut.Build(p_fmat, max_num_bins);
   const int32_t nthread = omp_get_max_threads();
-  const uint32_t nbins = cut.row_ptr.back();
+  const uint32_t nbins = cut.Ptrs().back();
   hit_count.resize(nbins, 0);
   hit_count_tloc_.resize(nthread * nbins, 0);

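For illustration, a standalone sketch of the cut-point rule that the new CutsBuilder::AddCutPoint (defined in the header diff below) preserves from the removed inline code: summaries with 2 to 16 entries are treated as categorical/ordinal and cut at midpoints, larger summaries reuse the sketch values directly. A plain vector stands in for the pruned WXQSketch summary here.

#include <cstdio>
#include <vector>

std::vector<float> AddCutPointSketch(const std::vector<float>& values) {
  std::vector<float> cuts;
  if (values.size() > 1 && values.size() <= 16) {
    // categorical / ordinal looking feature: cut halfway between levels
    for (size_t i = 1; i < values.size(); ++i) {
      float cpt = (values[i] + values[i - 1]) / 2.0f;
      if (i == 1 || cpt > cuts.back()) {
        cuts.push_back(cpt);
      }
    }
  } else {
    // continuous feature: keep the sketch values themselves
    for (size_t i = 2; i < values.size(); ++i) {
      float cpt = values[i - 1];
      if (i == 2 || cpt > cuts.back()) {
        cuts.push_back(cpt);
      }
    }
  }
  return cuts;
}

int main() {
  // A 4-level ordinal feature: the cuts land between the levels.
  for (float c : AddCutPointSketch({0.f, 1.f, 2.f, 3.f})) {
    std::printf("%.1f ", c);  // prints: 0.5 1.5 2.5
  }
  std::printf("\n");
  return 0;
}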
@@ -208,7 +359,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
 #pragma omp parallel num_threads(batch_threads)
     {
 #pragma omp for
-      for (int32_t tid = 0; tid < batch_threads; ++tid) {
+      for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
         size_t ibegin = block_size * tid;
         size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1)));

@@ -222,13 +373,13 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
 #pragma omp single
       {
         p_part[0] = prev_sum;
-        for (int32_t i = 1; i < batch_threads; ++i) {
+        for (size_t i = 1; i < batch_threads; ++i) {
           p_part[i] = p_part[i - 1] + row_ptr[rbegin + i*block_size];
         }
       }

 #pragma omp for
-      for (int32_t tid = 0; tid < batch_threads; ++tid) {
+      for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
         size_t ibegin = block_size * tid;
         size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1)));

@@ -240,7 +391,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {

     index.resize(row_ptr[rbegin + batch.Size()]);

-    CHECK_GT(cut.cut.size(), 0U);
+    CHECK_GT(cut.Values().size(), 0U);

 #pragma omp parallel for num_threads(batch_threads) schedule(static)
     for (omp_ulong i = 0; i < batch.Size(); ++i) { // NOLINT(*)
@@ -251,7 +402,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {

       CHECK_EQ(ibegin + inst.size(), iend);
       for (bst_uint j = 0; j < inst.size(); ++j) {
-        uint32_t idx = cut.GetBinIdx(inst[j]);
+        uint32_t idx = cut.SearchBin(inst[j]);

         index[ibegin + j] = idx;
         ++hit_count_tloc_[tid * nbins + idx];
@@ -382,7 +533,7 @@ FastFeatureGrouping(const GHistIndexMatrix& gmat,
                     const ColumnMatrix& colmat,
                     const tree::TrainParam& param) {
   const size_t nrow = gmat.row_ptr.size() - 1;
-  const size_t nfeature = gmat.cut.row_ptr.size() - 1;
+  const size_t nfeature = gmat.cut.Ptrs().size() - 1;

   std::vector<unsigned> feature_list(nfeature);
   std::iota(feature_list.begin(), feature_list.end(), 0);
@@ -438,7 +589,7 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
   cut_ = &gmat.cut;

   const size_t nrow = gmat.row_ptr.size() - 1;
-  const uint32_t nbins = gmat.cut.row_ptr.back();
+  const uint32_t nbins = gmat.cut.Ptrs().back();

   /* step 1: form feature groups */
   auto groups = FastFeatureGrouping(gmat, colmat, param);
@@ -448,8 +599,8 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
   std::vector<uint32_t> bin2block(nbins);  // lookup table [bin id] => [block id]
   for (uint32_t group_id = 0; group_id < nblock; ++group_id) {
     for (auto& fid : groups[group_id]) {
-      const uint32_t bin_begin = gmat.cut.row_ptr[fid];
-      const uint32_t bin_end = gmat.cut.row_ptr[fid + 1];
+      const uint32_t bin_begin = gmat.cut.Ptrs()[fid];
+      const uint32_t bin_end = gmat.cut.Ptrs()[fid + 1];
       for (uint32_t bin_id = bin_begin; bin_id < bin_end; ++bin_id) {
         bin2block[bin_id] = group_id;
       }
@@ -627,8 +778,8 @@ void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
   const size_t block_size = 1024;  // approximately 1024 values per block
   size_t n_blocks = size/block_size + !!(size%block_size);

 #pragma omp parallel for
-  for (int iblock = 0; iblock < n_blocks; ++iblock) {
+  for (omp_ulong iblock = 0; iblock < n_blocks; ++iblock) {
     const size_t ibegin = iblock*block_size;
     const size_t iend = (((iblock+1)*block_size > size) ? size : ibegin + block_size);
     for (bst_omp_uint bin_id = ibegin; bin_id < iend; bin_id++) {
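For illustration, a dependency-free re-implementation of the column partitioning done by SparseCuts::LoadBalance earlier in this file, using a vector of per-column entry counts in place of a SparsePage. The function name and input numbers are made up for the example; the logic mirrors the diff.

#include <cstdio>
#include <vector>

std::vector<size_t> LoadBalanceSketch(const std::vector<size_t>& col_sizes,
                                      size_t nthreads) {
  size_t total_entries = 0;
  for (size_t s : col_sizes) total_entries += s;
  // Same rounding as DivRoundUp(total_entries, nthreads).
  size_t const entries_per_thread = (total_entries + nthreads - 1) / nthreads;

  std::vector<size_t> cols_ptr(nthreads + 1, 0);
  size_t count = 0;
  size_t current_thread = 1;
  for (size_t col_id = 0; col_id < col_sizes.size(); ++col_id) {
    cols_ptr[current_thread]++;        // assign one more column to this thread
    count += col_sizes[col_id];
    if (count > entries_per_thread + 1) {
      current_thread++;
      count = 0;
      cols_ptr[current_thread] = cols_ptr[current_thread - 1];
    }
  }
  for (; current_thread < cols_ptr.size() - 1; ++current_thread) {
    cols_ptr[current_thread + 1] = cols_ptr[current_thread];  // idle threads
  }
  return cols_ptr;
}

int main() {
  // One heavy column followed by many light ones, split over 3 threads.
  std::vector<size_t> col_sizes = {90, 5, 5, 5, 5, 5, 5};
  for (size_t p : LoadBalanceSketch(col_sizes, 3)) {
    std::printf("%zu ", p);  // prints: 0 1 7 7 (thread 0 -> col 0, thread 1 -> cols 1..6, thread 2 idle)
  }
  std::printf("\n");
  return 0;
}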
@@ -3,6 +3,7 @@
  */

 #include "./hist_util.h"
+#include <xgboost/logging.h>

 #include <thrust/copy.h>
 #include <thrust/functional.h>
@@ -24,7 +25,7 @@
 namespace xgboost {
 namespace common {

-using WXQSketch = HistCutMatrix::WXQSketch;
+using WXQSketch = DenseCuts::WXQSketch;

 __global__ void FindCutsK
 (WXQSketch::Entry* __restrict__ cuts, const bst_float* __restrict__ data,
@@ -92,7 +93,7 @@ __global__ void UnpackFeaturesK
  * across distinct rows.
  */
 struct SketchContainer {
-  std::vector<HistCutMatrix::WXQSketch> sketches_;  // NOLINT
+  std::vector<DenseCuts::WXQSketch> sketches_;  // NOLINT
   std::vector<std::mutex> col_locks_;  // NOLINT
   static constexpr int kOmpNumColsParallelizeLimit = 1000;

@@ -300,7 +301,7 @@ struct GPUSketcher {
       } else if (n_cuts_cur_[icol] > 0) {
         // if more elements than cuts: use binary search on cumulative weights
         int block = 256;
-        FindCutsK<<<dh::DivRoundUp(n_cuts_cur_[icol], block), block>>>
+        FindCutsK<<<common::DivRoundUp(n_cuts_cur_[icol], block), block>>>
           (cuts_d_.data().get() + icol * n_cuts_, fvalues_cur_.data().get(),
            weights2_.data().get(), n_unique, n_cuts_cur_[icol]);
         dh::safe_cuda(cudaGetLastError());  // NOLINT
@@ -342,8 +343,8 @@ struct GPUSketcher {

     dim3 block3(16, 64, 1);
     // NOTE: This will typically support ~ 4M features - 64K*64
-    dim3 grid3(dh::DivRoundUp(batch_nrows, block3.x),
-               dh::DivRoundUp(num_cols_, block3.y), 1);
+    dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
+               common::DivRoundUp(num_cols_, block3.y), 1);
     UnpackFeaturesK<<<grid3, block3>>>
       (fvalues_.data().get(), has_weights_ ? feature_weights_.data().get() : nullptr,
        row_ptrs_.data().get() + batch_row_begin,
@@ -392,7 +393,7 @@ struct GPUSketcher {
     row_ptrs_.resize(n_rows_ + 1);
     thrust::copy(offset_vec.data() + row_begin_,
                  offset_vec.data() + row_end_ + 1, row_ptrs_.begin());
-    size_t gpu_nbatches = dh::DivRoundUp(n_rows_, gpu_batch_nrows_);
+    size_t gpu_nbatches = common::DivRoundUp(n_rows_, gpu_batch_nrows_);
     for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
       SketchBatch(row_batch, info, gpu_batch);
     }
@@ -434,7 +435,7 @@ struct GPUSketcher {

   /* Builds the sketches on the GPU for the dmatrix and returns the row stride
    * for the entire dataset */
-  size_t Sketch(DMatrix *dmat, HistCutMatrix *hmat) {
+  size_t Sketch(DMatrix *dmat, DenseCuts *hmat) {
     const MetaInfo &info = dmat->Info();

     row_stride_ = 0;
@@ -459,9 +460,13 @@ struct GPUSketcher {

 size_t DeviceSketch
 (const tree::TrainParam &param, const LearnerTrainParam &learner_param, int gpu_batch_nrows,
- DMatrix *dmat, HistCutMatrix *hmat) {
+ DMatrix *dmat, HistogramCuts *hmat) {
   GPUSketcher sketcher(param, learner_param, gpu_batch_nrows);
-  return sketcher.Sketch(dmat, hmat);
+  // We only need to return the result in the HistogramCuts container, so it is safe to
+  // use a pointer to a local DenseCuts.
+  DenseCuts dense_cuts(hmat);
+  auto res = sketcher.Sketch(dmat, &dense_cuts);
+  return res;
 }

 }  // namespace common
@@ -12,18 +12,21 @@
 #include <limits>
 #include <vector>
 #include <algorithm>
+#include <memory>
 #include <utility>

 #include "row_set.h"
 #include "../tree/param.h"
 #include "./quantile.h"
 #include "./timer.h"
-#include "../include/rabit/rabit.h"
 #include "random.h"

 namespace xgboost {

 /*!
- * \brief A C-style array with in-stack allocation. As long as the array is smaller than MaxStackSize, it will be allocated inside the stack. Otherwise, it will be heap-allocated.
+ * \brief A C-style array with in-stack allocation. As long as the array is smaller than
+ *        MaxStackSize, it will be allocated inside the stack. Otherwise, it will be
+ *        heap-allocated.
  */
 template<typename T, size_t MaxStackSize>
 class MemStackAllocator {
@@ -122,47 +125,175 @@ struct SimpleArray {
   size_t n_ = 0;
 };

-/*! \brief Cut configuration for all the features. */
-struct HistCutMatrix {
-  /*! \brief Unit pointer to rows by element position */
-  std::vector<uint32_t> row_ptr;
-  /*! \brief minimum value of each feature */
-  std::vector<bst_float> min_val;
-  /*! \brief the cut field */
-  std::vector<bst_float> cut;
-  uint32_t GetBinIdx(const Entry &e);
+/*!
+ * \brief A single row in global histogram index.
+ *  Directly represent the global index in the histogram entry.
+ */
+using GHistIndexRow = Span<uint32_t const>;

-  using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
+// A CSC matrix representing histogram cuts, used in CPU quantile hist.
+class HistogramCuts {
+  // Using friends to avoid creating a virtual class, since HistogramCuts is used as value
+  // object in many places.
+  friend class SparseCuts;
+  friend class DenseCuts;
+  friend class CutsBuilder;

-  // create histogram cut matrix given statistics from data
-  // using approximate quantile sketch approach
-  void Init(DMatrix* p_fmat, uint32_t max_num_bins);
-
-  void Init(std::vector<WXQSketch>* sketchs, uint32_t max_num_bins);
-
-  HistCutMatrix();
-  size_t NumBins() const { return row_ptr.back(); }
-
 protected:
-  virtual size_t SearchGroupIndFromBaseRow(
-      std::vector<bst_uint> const& group_ptr, size_t const base_rowid) const;
-
-  Monitor monitor_;
+  using BinIdx = uint32_t;
+  common::Monitor monitor_;
+
+  std::vector<bst_float> cut_values_;
+  std::vector<uint32_t> cut_ptrs_;
+  std::vector<float> min_vals_;  // storing minimum value in a sketch set.
+
+ public:
+  HistogramCuts();
+  HistogramCuts(HistogramCuts const& that) = delete;
+  HistogramCuts(HistogramCuts&& that) noexcept(true) {
+    *this = std::forward<HistogramCuts&&>(that);
+  }
+  HistogramCuts& operator=(HistogramCuts const& that) = delete;
+  HistogramCuts& operator=(HistogramCuts&& that) noexcept(true) {
+    monitor_ = std::move(that.monitor_);
+    cut_ptrs_ = std::move(that.cut_ptrs_);
+    cut_values_ = std::move(that.cut_values_);
+    min_vals_ = std::move(that.min_vals_);
+    return *this;
+  }
+
+  /* \brief Build histogram cuts. */
+  void Build(DMatrix* dmat, uint32_t const max_num_bins);
+  /* \brief How many bins a feature has. */
+  uint32_t FeatureBins(uint32_t feature) const {
+    return cut_ptrs_.at(feature+1) - cut_ptrs_[feature];
+  }
+
+  // Getters.  Cuts should be of no use after building histogram indices, but currently
+  // it's deeply linked with quantile_hist, gpu sketcher and gpu_hist.  So we preserve
+  // these for now.
+  std::vector<uint32_t> const& Ptrs()      const { return cut_ptrs_; }
+  std::vector<float>    const& Values()    const { return cut_values_; }
+  std::vector<float>    const& MinValues() const { return min_vals_; }
+
+  size_t TotalBins() const { return cut_ptrs_.back(); }
+
+  BinIdx SearchBin(float value, uint32_t column_id) {
+    auto beg = cut_ptrs_.at(column_id);
+    auto end = cut_ptrs_.at(column_id + 1);
+    auto it = std::upper_bound(cut_values_.cbegin() + beg, cut_values_.cbegin() + end, value);
+    if (it == cut_values_.cend()) {
+      it = cut_values_.cend() - 1;
+    }
+    BinIdx idx = it - cut_values_.cbegin();
+    return idx;
+  }
+
+  BinIdx SearchBin(Entry const& e) {
+    return SearchBin(e.fvalue, e.index);
+  }
 };

+/* \brief An interface for building quantile cuts.
+ *
+ * `DenseCuts' always assumes there are `max_bins` for each feature, which makes it not
+ * suitable for sparse dataset.  On the other hand `SparseCuts' uses `GetColumnBatches',
+ * which doubles the memory usage, hence can not be applied to dense dataset.
+ */
+class CutsBuilder {
+ public:
+  using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
+
+ protected:
+  HistogramCuts* p_cuts_;
+  /* \brief return whether group for ranking is used. */
+  static bool UseGroup(DMatrix* dmat);
+
+ public:
+  explicit CutsBuilder(HistogramCuts* p_cuts) : p_cuts_{p_cuts} {}
+  virtual ~CutsBuilder() = default;
+
+  static uint32_t SearchGroupIndFromRow(
+      std::vector<bst_uint> const& group_ptr, size_t const base_rowid) {
+    using KIt = std::vector<bst_uint>::const_iterator;
+    KIt res = std::lower_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid);
+    // Cannot use CHECK_NE because it will try to print the iterator.
+    bool const found = res != group_ptr.cend() - 1;
+    if (!found) {
+      LOG(FATAL) << "Row " << base_rowid << " does not lie in any group!";
+    }
+    uint32_t group_ind = std::distance(group_ptr.cbegin(), res);
+    return group_ind;
+  }
+
+  void AddCutPoint(WXQSketch::SummaryContainer const& summary) {
+    if (summary.size > 1 && summary.size <= 16) {
+      /* specialized code for categorical / ordinal data -- use midpoints */
+      for (size_t i = 1; i < summary.size; ++i) {
+        bst_float cpt = (summary.data[i].value + summary.data[i - 1].value) / 2.0f;
+        if (i == 1 || cpt > p_cuts_->cut_values_.back()) {
+          p_cuts_->cut_values_.push_back(cpt);
+        }
+      }
+    } else {
+      for (size_t i = 2; i < summary.size; ++i) {
+        bst_float cpt = summary.data[i - 1].value;
+        if (i == 2 || cpt > p_cuts_->cut_values_.back()) {
+          p_cuts_->cut_values_.push_back(cpt);
+        }
+      }
+    }
+  }
+
+  /* \brief Build histogram indices. */
+  virtual void Build(DMatrix* dmat, uint32_t const max_num_bins) = 0;
+};
+
+/*! \brief Cut configuration for sparse dataset. */
+class SparseCuts : public CutsBuilder {
+  /* \brief Distribute columns to each thread according to number of entries. */
+  static std::vector<size_t> LoadBalance(SparsePage const& page, size_t const nthreads);
+  Monitor monitor_;
+
+ public:
+  explicit SparseCuts(HistogramCuts* container) :
+      CutsBuilder(container) {
+    monitor_.Init(__FUNCTION__);
+  }
+
+  /* \brief Concatenate the built cuts in each thread. */
+  void Concat(std::vector<std::unique_ptr<SparseCuts>> const& cuts, uint32_t n_cols);
+  /* \brief Build histogram indices in single thread. */
+  void SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
+                         uint32_t max_num_bins,
+                         bool const use_group_ind,
+                         uint32_t beg, uint32_t end, uint32_t thread_id);
+  void Build(DMatrix* dmat, uint32_t const max_num_bins) override;
+};
+
+/*! \brief Cut configuration for dense dataset. */
+class DenseCuts : public CutsBuilder {
+ protected:
+  Monitor monitor_;
+
+ public:
+  explicit DenseCuts(HistogramCuts* container) :
+      CutsBuilder(container) {
+    monitor_.Init(__FUNCTION__);
+  }
+  void Init(std::vector<WXQSketch>* sketchs, uint32_t max_num_bins);
+  void Build(DMatrix* p_fmat, uint32_t max_num_bins) override;
+};
+
+// FIXME(trivialfis): Merge this into generic cut builder.
 /*! \brief Builds the cut matrix on the GPU.
  *
  * \return The row stride across the entire dataset.
  */
 size_t DeviceSketch
 (const tree::TrainParam& param, const LearnerTrainParam &learner_param, int gpu_batch_nrows,
- DMatrix* dmat, HistCutMatrix* hmat);
-
-/*!
- * \brief A single row in global histogram index.
- * Directly represent the global index in the histogram entry.
- */
-using GHistIndexRow = Span<uint32_t const>;
+ DMatrix* dmat, HistogramCuts* hmat);

 /*!
  * \brief preprocessed global index matrix, in CSR format
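For illustration, a dependency-free sketch of the CSC layout and bin lookup that HistogramCuts exposes above (Ptrs(), Values(), SearchBin): cut_ptrs delimits each feature's slice of cut_values, and SearchBin is an upper_bound over that slice. Types are simplified (plain float instead of bst_float, no Entry), and the struct name is made up.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

struct CutsSketch {
  std::vector<uint32_t> cut_ptrs{0};   // one entry per feature boundary, starts at 0
  std::vector<float> cut_values;       // concatenated per-feature cut points

  uint32_t FeatureBins(uint32_t f) const { return cut_ptrs[f + 1] - cut_ptrs[f]; }

  uint32_t SearchBin(float value, uint32_t column_id) const {
    auto beg = cut_ptrs.at(column_id);
    auto end = cut_ptrs.at(column_id + 1);
    auto it = std::upper_bound(cut_values.cbegin() + beg, cut_values.cbegin() + end, value);
    if (it == cut_values.cend()) {   // clamp values beyond the last cut
      it = cut_values.cend() - 1;
    }
    return static_cast<uint32_t>(it - cut_values.cbegin());
  }
};

int main() {
  CutsSketch cuts;
  // Feature 0 owns global bins [0, 3) with cuts 0.5, 1.5, 2.5; feature 1 owns bin [3, 4).
  cuts.cut_values = {0.5f, 1.5f, 2.5f, 10.0f};
  cuts.cut_ptrs = {0, 3, 4};
  assert(cuts.FeatureBins(0) == 3);
  assert(cuts.SearchBin(1.0f, 0) == 1);   // falls into the second bin of feature 0
  assert(cuts.SearchBin(4.0f, 1) == 3);   // feature 1 values map to global bin id 3
  return 0;
}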
@@ -178,7 +309,7 @@ struct GHistIndexMatrix {
   /*! \brief hit count of each index */
   std::vector<size_t> hit_count;
   /*! \brief The corresponding cuts */
-  HistCutMatrix cut;
+  HistogramCuts cut;
   // Create a global histogram matrix, given cut
   void Init(DMatrix* p_fmat, int max_num_bins);
   // get i-th row
@@ -188,10 +319,10 @@ struct GHistIndexMatrix {
                          row_ptr[i + 1] - row_ptr[i])};
   }
   inline void GetFeatureCounts(size_t* counts) const {
-    auto nfeature = cut.row_ptr.size() - 1;
+    auto nfeature = cut.Ptrs().size() - 1;
     for (unsigned fid = 0; fid < nfeature; ++fid) {
-      auto ibegin = cut.row_ptr[fid];
-      auto iend = cut.row_ptr[fid + 1];
+      auto ibegin = cut.Ptrs()[fid];
+      auto iend = cut.Ptrs()[fid + 1];
       for (auto i = ibegin; i < iend; ++i) {
         counts[fid] += hit_count[i];
       }
@@ -234,7 +365,7 @@ class GHistIndexBlockMatrix {
  private:
   std::vector<size_t> row_ptr_;
   std::vector<uint32_t> index_;
-  const HistCutMatrix* cut_;
+  const HistogramCuts* cut_;
   struct Block {
     const size_t* row_ptr_begin;
     const size_t* row_ptr_end;
@@ -549,7 +549,7 @@ class Span {
                       detail::ExtentValue<Extent, Offset, Count>::value> {
     SPAN_CHECK(Offset >= 0 && (Offset < size() || size() == 0));
     SPAN_CHECK(Count == dynamic_extent ||
-               Count >= 0 && Offset + Count <= size());
+               (Count >= 0 && Offset + Count <= size()));

     return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count};
   }
@@ -60,7 +60,7 @@ class Transform {
     Evaluator(Functor func, Range range, GPUSet devices, bool shard) :
         func_(func), range_{std::move(range)},
         shard_{shard},
-        distribution_{std::move(GPUDistribution::Block(devices))} {}
+        distribution_{GPUDistribution::Block(devices)} {}
     Evaluator(Functor func, Range range, GPUDistribution dist,
               bool shard) :
         func_(func), range_{std::move(range)}, shard_{shard},
@@ -142,7 +142,7 @@ class Transform {
       Range shard_range {0, static_cast<Range::DifferenceType>(shard_size)};
       dh::safe_cuda(cudaSetDevice(device));
       const int GRID_SIZE =
-          static_cast<int>(dh::DivRoundUp(*(range_.end()), kBlockThreads));
+          static_cast<int>(DivRoundUp(*(range_.end()), kBlockThreads));
       detail::LaunchCUDAKernel<<<GRID_SIZE, kBlockThreads>>>(
           _func, shard_range, UnpackHDV(_vectors, device)...);
     }
@@ -52,14 +52,14 @@ class SparsePageSource : public DataSource {
    * \param page_size Page size for external memory.
    */
   static void CreateRowPage(dmlc::Parser<uint32_t>* src,
                             const std::string& cache_info,
                             const size_t page_size = DMatrix::kPageSize);
   /*!
    * \brief Create source cache by copy content from DMatrix.
    * \param cache_info The cache_info of cache file location.
    */
   static void CreateRowPage(DMatrix* src,
                             const std::string& cache_info);

   /*!
    * \brief Create source cache by copy content from DMatrix. Creates transposed column page, may be sorted or not.
@@ -67,7 +67,7 @@ class SparsePageSource : public DataSource {
    * \param sorted Whether columns should be pre-sorted
    */
   static void CreateColumnPage(DMatrix* src,
                                const std::string& cache_info, bool sorted);
   /*!
    * \brief Check if the cache file already exists.
    * \param cache_info The cache prefix of files.
@@ -238,7 +238,7 @@ class GPUPredictor : public xgboost::Predictor {
 auto& offsets = *out_offsets;
 size_t n_shards = devices_.Size();
 offsets.resize(n_shards + 2);
-size_t rows_per_shard = dh::DivRoundUp(batch_size, n_shards);
+size_t rows_per_shard = common::DivRoundUp(batch_size, n_shards);
 for (size_t shard = 0; shard < devices_.Size(); ++shard) {
 size_t n_rows = std::min(batch_size, shard * rows_per_shard);
 offsets[shard] = batch_offset + n_rows * n_classes;
@@ -284,7 +284,7 @@ class GPUPredictor : public xgboost::Predictor {
 dh::safe_cuda(cudaSetDevice(device_));
 const int BLOCK_THREADS = 128;
 size_t num_rows = batch.offset.DeviceSize(device_) - 1;
-const int GRID_SIZE = static_cast<int>(dh::DivRoundUp(num_rows, BLOCK_THREADS));
+const int GRID_SIZE = static_cast<int>(common::DivRoundUp(num_rows, BLOCK_THREADS));

 int shared_memory_bytes = static_cast<int>
 (sizeof(float) * num_features * BLOCK_THREADS);
@@ -170,7 +170,7 @@ void FeatureInteractionConstraint::ClearBuffers() {
 CHECK_LE(feature_buffer_.Size(), output_buffer_bits_.Size());
 int constexpr kBlockThreads = 256;
 const int n_grids = static_cast<int>(
-dh::DivRoundUp(input_buffer_bits_.Size(), kBlockThreads));
+common::DivRoundUp(input_buffer_bits_.Size(), kBlockThreads));
 ClearBuffersKernel<<<n_grids, kBlockThreads>>>(
 output_buffer_bits_, input_buffer_bits_);
 }
@@ -227,7 +227,7 @@ common::Span<int32_t> FeatureInteractionConstraint::Query(

 int constexpr kBlockThreads = 256;
 const int n_grids = static_cast<int>(
-dh::DivRoundUp(output_buffer_bits_.Size(), kBlockThreads));
+common::DivRoundUp(output_buffer_bits_.Size(), kBlockThreads));
 SetInputBufferKernel<<<n_grids, kBlockThreads>>>(feature_list, input_buffer_bits_);

 QueryFeatureListKernel<<<n_grids, kBlockThreads>>>(
@@ -328,8 +328,8 @@ void FeatureInteractionConstraint::Split(
 BitField right = s_node_constraints_[right_id];

 dim3 const block3(16, 64, 1);
-dim3 const grid3(dh::DivRoundUp(n_sets_, 16),
-dh::DivRoundUp(s_fconstraints_.size(), 64));
+dim3 const grid3(common::DivRoundUp(n_sets_, 16),
+common::DivRoundUp(s_fconstraints_.size(), 64));
 RestoreFeatureListFromSetsKernel<<<grid3, block3>>>
 (feature_buffer_,
 feature_id,
@@ -339,7 +339,7 @@ void FeatureInteractionConstraint::Split(
 s_sets_ptr_);

 int constexpr kBlockThreads = 256;
-const int n_grids = static_cast<int>(dh::DivRoundUp(node.Size(), kBlockThreads));
+const int n_grids = static_cast<int>(common::DivRoundUp(node.Size(), kBlockThreads));
 InteractionConstraintSplitKernel<<<n_grids, kBlockThreads>>>
 (feature_buffer_,
 feature_id,
@@ -76,7 +76,7 @@ static const int kNoneKey = -100;
 */
 template <int BLKDIM_L1L3 = 256>
 int ScanTempBufferSize(int size) {
-int num_blocks = dh::DivRoundUp(size, BLKDIM_L1L3);
+int num_blocks = common::DivRoundUp(size, BLKDIM_L1L3);
 return num_blocks;
 }

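Many of the hunks in this patch swap dh::DivRoundUp for common::DivRoundUp so that the same ceiling-division helper can be shared by host-side and device-side callers. As a rough sketch of the semantics these call sites rely on (the exact signature and header of common::DivRoundUp are assumptions here, not quoted from the patch), the helper is plain ceiling division used to size CUDA launch grids:

// Hedged sketch only: illustrates the ceiling division behind the grid-size
// computations above; the real common::DivRoundUp declaration is assumed.
template <typename T>
inline T DivRoundUp(T n, T divisor) {
  // smallest k such that k * divisor >= n
  return (n + divisor - 1) / divisor;
}
// Mirrors ScanTempBufferSize above: DivRoundUp(1000, 256) == 4 blocks.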
@@ -250,7 +250,7 @@ void ReduceScanByKey(common::Span<GradientPair> sums,
 common::Span<GradientPair> tmpScans,
 common::Span<int> tmpKeys,
 common::Span<const int> colIds, NodeIdT nodeStart) {
-int nBlks = dh::DivRoundUp(size, BLKDIM_L1L3);
+int nBlks = common::DivRoundUp(size, BLKDIM_L1L3);
 cudaMemset(sums.data(), 0, nUniqKeys * nCols * sizeof(GradientPair));
 CubScanByKeyL1<BLKDIM_L1L3>
 <<<nBlks, BLKDIM_L1L3>>>(scans, vals, instIds, tmpScans, tmpKeys, keys,
@@ -448,7 +448,7 @@ void ArgMaxByKey(common::Span<ExactSplitCandidate> nodeSplits,
 dh::FillConst<ExactSplitCandidate, BLKDIM, ITEMS_PER_THREAD>(
 *(devices.begin()), nodeSplits.data(), nUniqKeys,
 ExactSplitCandidate());
-int nBlks = dh::DivRoundUp(len, ITEMS_PER_THREAD * BLKDIM);
+int nBlks = common::DivRoundUp(len, ITEMS_PER_THREAD * BLKDIM);
 switch (algo) {
 case kAbkGmem:
 AtomicArgMaxByKeyGmem<<<nBlks, BLKDIM>>>(
@@ -793,11 +793,11 @@ class GPUMaker : public TreeUpdater {
 const int BlkDim = 256;
 const int ItemsPerThread = 4;
 // assign default node ids first
-int nBlks = dh::DivRoundUp(n_rows_, BlkDim);
+int nBlks = common::DivRoundUp(n_rows_, BlkDim);
 FillDefaultNodeIds<<<nBlks, BlkDim>>>(node_assigns_per_inst_.data(),
 nodes_.data(), n_rows_);
 // evaluate the correct child indices of non-missing values next
-nBlks = dh::DivRoundUp(n_vals_, BlkDim * ItemsPerThread);
+nBlks = common::DivRoundUp(n_vals_, BlkDim * ItemsPerThread);
 AssignNodeIds<<<nBlks, BlkDim>>>(
 node_assigns_per_inst_.data(), nodeLocations_.Current(),
 nodeAssigns_.Current(), instIds_.Current(), nodes_.data(),
@@ -823,7 +823,7 @@ class GPUMaker : public TreeUpdater {

 void MarkLeaves() {
 const int BlkDim = 128;
-int nBlks = dh::DivRoundUp(maxNodes_, BlkDim);
+int nBlks = common::DivRoundUp(maxNodes_, BlkDim);
 MarkLeavesKernel<<<nBlks, BlkDim>>>(nodes_.data(), maxNodes_);
 }
 };
@@ -480,8 +480,8 @@ __global__ void CompressBinEllpackKernel(
 common::CompressedByteT* __restrict__ buffer, // gidx_buffer
 const size_t* __restrict__ row_ptrs, // row offset of input data
 const Entry* __restrict__ entries, // One batch of input data
-const float* __restrict__ cuts, // HistCutMatrix::cut
-const uint32_t* __restrict__ cut_rows, // HistCutMatrix::row_ptrs
+const float* __restrict__ cuts, // HistogramCuts::cut
+const uint32_t* __restrict__ cut_rows, // HistogramCuts::row_ptrs
 size_t base_row, // batch_row_begin
 size_t n_rows,
 size_t row_stride,
@@ -593,7 +593,7 @@ struct DeviceShard {
 std::unique_ptr<RowPartitioner> row_partitioner;
 DeviceHistogram<GradientSumT> hist;

-/*! \brief row_ptr form HistCutMatrix. */
+/*! \brief row_ptr form HistogramCuts. */
 common::Span<uint32_t> feature_segments;
 /*! \brief minimum value for each feature. */
 common::Span<bst_float> min_fvalue;
@@ -654,10 +654,10 @@ struct DeviceShard {
 }

 void InitCompressedData(
-const common::HistCutMatrix& hmat, size_t row_stride, bool is_dense);
+const common::HistogramCuts& hmat, size_t row_stride, bool is_dense);

 void CreateHistIndices(
-const SparsePage &row_batch, const common::HistCutMatrix &hmat,
+const SparsePage &row_batch, const common::HistogramCuts &hmat,
 const RowStateOnDevice &device_row_state, int rows_per_batch);

 ~DeviceShard() {
@@ -718,7 +718,7 @@ struct DeviceShard {
 // Work out cub temporary memory requirement
 GPUTrainingParam gpu_param(param);
 DeviceSplitCandidateReduceOp op(gpu_param);
-size_t temp_storage_bytes;
+size_t temp_storage_bytes = 0;
 DeviceSplitCandidate*dummy = nullptr;
 cub::DeviceReduce::Reduce(
 nullptr, temp_storage_bytes, dummy,
@@ -806,7 +806,7 @@ struct DeviceShard {
 const int items_per_thread = 8;
 const int block_threads = 256;
 const int grid_size = static_cast<int>(
-dh::DivRoundUp(n_elements, items_per_thread * block_threads));
+common::DivRoundUp(n_elements, items_per_thread * block_threads));
 if (grid_size <= 0) {
 return;
 }
@@ -1106,9 +1106,9 @@ struct DeviceShard {

 template <typename GradientSumT>
 inline void DeviceShard<GradientSumT>::InitCompressedData(
-const common::HistCutMatrix &hmat, size_t row_stride, bool is_dense) {
-n_bins = hmat.row_ptr.back();
-int null_gidx_value = hmat.row_ptr.back();
+const common::HistogramCuts &hmat, size_t row_stride, bool is_dense) {
+n_bins = hmat.Ptrs().back();
+int null_gidx_value = hmat.Ptrs().back();

 CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
 << "Max leaves and max depth cannot both be unconstrained for "
@@ -1121,14 +1121,14 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
 &gpair, n_rows,
 &prediction_cache, n_rows,
 &node_sum_gradients_d, max_nodes,
-&feature_segments, hmat.row_ptr.size(),
-&gidx_fvalue_map, hmat.cut.size(),
-&min_fvalue, hmat.min_val.size(),
+&feature_segments, hmat.Ptrs().size(),
+&gidx_fvalue_map, hmat.Values().size(),
+&min_fvalue, hmat.MinValues().size(),
 &monotone_constraints, param.monotone_constraints.size());

-dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.cut);
-dh::CopyVectorToDeviceSpan(min_fvalue, hmat.min_val);
-dh::CopyVectorToDeviceSpan(feature_segments, hmat.row_ptr);
+dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.Values());
+dh::CopyVectorToDeviceSpan(min_fvalue, hmat.MinValues());
+dh::CopyVectorToDeviceSpan(feature_segments, hmat.Ptrs());
 dh::CopyVectorToDeviceSpan(monotone_constraints, param.monotone_constraints);

 node_sum_gradients.resize(max_nodes);
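The substitutions in this and the neighbouring hunks replace direct member access on the old HistCutMatrix (row_ptr, cut, min_val) with the accessors of the renamed HistogramCuts (Ptrs(), Values(), MinValues()). The layout is CSC-like: Ptrs() gives per-feature offsets into Values(), and MinValues() holds the lower bound of the first bin of each feature. A minimal sketch of how a consumer walks the cuts of one feature, assuming only the accessors that appear in this patch (variable names are illustrative):

// Hedged sketch: iterate the cut points of feature fid via HistogramCuts.
const common::HistogramCuts& cuts = gmat.cut;
uint32_t begin = cuts.Ptrs()[fid];       // first bin of this feature
uint32_t end = cuts.Ptrs()[fid + 1];     // one past its last bin
float lower = cuts.MinValues()[fid];     // lower bound of the first bin
for (uint32_t i = begin; i < end; ++i) {
  float upper = cuts.Values()[i];        // upper bound of bin i
  // bin i of feature fid covers (lower, upper]
  lower = upper;
}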
@@ -1153,26 +1153,26 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
 // check if we can use shared memory for building histograms
 // (assuming atleast we need 2 CTAs per SM to maintain decent latency
 // hiding)
-auto histogram_size = sizeof(GradientSumT) * hmat.row_ptr.back();
+auto histogram_size = sizeof(GradientSumT) * hmat.Ptrs().back();
 auto max_smem = dh::MaxSharedMemory(device_id);
 if (histogram_size <= max_smem) {
 use_shared_memory_histograms = true;
 }

 // Init histogram
-hist.Init(device_id, hmat.NumBins());
+hist.Init(device_id, hmat.Ptrs().back());
 }

 template <typename GradientSumT>
 inline void DeviceShard<GradientSumT>::CreateHistIndices(
 const SparsePage &row_batch,
-const common::HistCutMatrix &hmat,
+const common::HistogramCuts &hmat,
 const RowStateOnDevice &device_row_state,
 int rows_per_batch) {
 // Has any been allocated for me in this batch?
 if (!device_row_state.rows_to_process_from_batch) return;

-unsigned int null_gidx_value = hmat.row_ptr.back();
+unsigned int null_gidx_value = hmat.Ptrs().back();
 size_t row_stride = this->ellpack_matrix.row_stride;

 const auto &offset_vec = row_batch.offset.ConstHostVector();
@@ -1184,8 +1184,8 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
 static_cast<size_t>(device_row_state.rows_to_process_from_batch));
 const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();

-size_t gpu_nbatches = dh::DivRoundUp(device_row_state.rows_to_process_from_batch,
+size_t gpu_nbatches = common::DivRoundUp(device_row_state.rows_to_process_from_batch,
 gpu_batch_nrows);

 for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
 size_t batch_row_begin = gpu_batch * gpu_batch_nrows;
@@ -1216,8 +1216,8 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
 (entries_d.data().get(), data_vec.data() + ent_cnt_begin,
 n_entries * sizeof(Entry), cudaMemcpyDefault));
 const dim3 block3(32, 8, 1); // 256 threads
-const dim3 grid3(dh::DivRoundUp(batch_nrows, block3.x),
-dh::DivRoundUp(row_stride, block3.y), 1);
+const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
+common::DivRoundUp(row_stride, block3.y), 1);
 CompressBinEllpackKernel<<<grid3, block3>>>
 (common::CompressedBufferWriter(num_symbols),
 gidx_buffer.data(),
@@ -1361,13 +1361,13 @@ class GPUHistMakerSpecialised {
 });

 monitor_.StartCuda("Quantiles");
-// Create the quantile sketches for the dmatrix and initialize HistCutMatrix
+// Create the quantile sketches for the dmatrix and initialize HistogramCuts
 size_t row_stride = common::DeviceSketch(param_, *learner_param_,
 hist_maker_param_.gpu_batch_nrows,
 dmat, &hmat_);
 monitor_.StopCuda("Quantiles");

-n_bins_ = hmat_.row_ptr.back();
+n_bins_ = hmat_.Ptrs().back();

 auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_;

@@ -1475,9 +1475,9 @@ class GPUHistMakerSpecialised {
 return true;
 }

 TrainParam param_; // NOLINT
-common::HistCutMatrix hmat_; // NOLINT
+common::HistogramCuts hmat_; // NOLINT
 MetaInfo* info_; // NOLINT

 std::vector<std::unique_ptr<DeviceShard<GradientSumT>>> shards_; // NOLINT

@@ -247,15 +247,15 @@ int32_t QuantileHistMaker::Builder::FindSplitCond(int32_t nid,
 // Categorize member rows
 const bst_uint fid = node.SplitIndex();
 const bst_float split_pt = node.SplitCond();
-const uint32_t lower_bound = gmat.cut.row_ptr[fid];
-const uint32_t upper_bound = gmat.cut.row_ptr[fid + 1];
+const uint32_t lower_bound = gmat.cut.Ptrs()[fid];
+const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1];
 int32_t split_cond = -1;
 // convert floating-point split_pt into corresponding bin_id
 // split_cond = -1 indicates that split_pt is less than all known cut points
 CHECK_LT(upper_bound,
 static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
 for (uint32_t i = lower_bound; i < upper_bound; ++i) {
-if (split_pt == gmat.cut.cut[i]) {
+if (split_pt == gmat.cut.Values()[i]) {
 split_cond = static_cast<int32_t>(i);
 }
 }
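FindSplitCond above recovers the bin index for a node's floating-point split value by scanning only the cut slice that belongs to the split feature. The same lookup, pulled out as a standalone illustration that uses only the accessors visible in this hunk (the helper name is hypothetical):

// Hedged sketch of the lookup in FindSplitCond: return the bin whose cut
// value equals split_pt within feature fid, or -1 if split_pt is below all cuts.
int32_t FindBinForSplit(const common::HistogramCuts& cuts,
                        uint32_t fid, float split_pt) {
  int32_t split_cond = -1;
  for (uint32_t i = cuts.Ptrs()[fid]; i < cuts.Ptrs()[fid + 1]; ++i) {
    if (split_pt == cuts.Values()[i]) {
      split_cond = static_cast<int32_t>(i);
    }
  }
  return split_cond;
}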
@@ -533,7 +533,7 @@ void QuantileHistMaker::Builder::BuildHistsBatch(const std::vector<ExpandEntry>&
 perf_monitor.TickStart();
 const size_t block_size_rows = 256;
 const size_t nthread = static_cast<size_t>(this->nthread_);
-const size_t nbins = gmat.cut.row_ptr.back();
+const size_t nbins = gmat.cut.Ptrs().back();
 const size_t hist_size = 2 * nbins;

 hist_buffers->resize(nodes.size());
@@ -856,8 +856,8 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache(
 }
 }

 #pragma omp parallel for schedule(guided)
-for (int32_t k = 0; k < tasks_elem.size(); ++k) {
+for (omp_ulong k = 0; k < tasks_elem.size(); ++k) {
 const RowSetCollection::Elem rowset = tasks_elem[k];
 if (rowset.begin != nullptr && rowset.end != nullptr && rowset.node_id != -1) {
 const size_t nrows = rowset.Size();
@@ -909,7 +909,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
 // clear local prediction cache
 leaf_value_cache_.clear();
 // initialize histogram collection
-uint32_t nbins = gmat.cut.row_ptr.back();
+uint32_t nbins = gmat.cut.Ptrs().back();
 hist_.Init(nbins);
 hist_buff_.Init(nbins);

@@ -999,7 +999,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
 const size_t ncol = info.num_col_;
 const size_t nnz = info.num_nonzero_;
 // number of discrete bins for feature 0
-const uint32_t nbins_f0 = gmat.cut.row_ptr[1] - gmat.cut.row_ptr[0];
+const uint32_t nbins_f0 = gmat.cut.Ptrs()[1] - gmat.cut.Ptrs()[0];
 if (nrow * ncol == nnz) {
 // dense data with zero-based indexing
 data_layout_ = kDenseDataZeroBased;
@@ -1029,7 +1029,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
 choose the column that has a least positive number of discrete bins.
 For dense data (with no missing value),
 the sum of gradient histogram is equal to snode[nid] */
-const std::vector<uint32_t>& row_ptr = gmat.cut.row_ptr;
+const std::vector<uint32_t>& row_ptr = gmat.cut.Ptrs();
 const auto nfeature = static_cast<bst_uint>(row_ptr.size() - 1);
 uint32_t min_nbins_per_feature = 0;
 for (bst_uint i = 0; i < nfeature; ++i) {
@@ -1079,8 +1079,8 @@ void QuantileHistMaker::Builder::EvaluateSplitsBatch(
 // partial results
 std::vector<std::pair<SplitEntry, SplitEntry>> splits(tasks.size());
 // parallel enumeration
 #pragma omp parallel for schedule(guided)
-for (int32_t i = 0; i < tasks.size(); ++i) {
+for (omp_ulong i = 0; i < tasks.size(); ++i) {
 // node_idx : offset within `nodes` list
 const int32_t node_idx = tasks[i].first;
 const size_t fid = tasks[i].second;
@@ -1098,7 +1098,7 @@ void QuantileHistMaker::Builder::EvaluateSplitsBatch(

 // reduce needed part of a hist here to have it in cache before enumeration
 if (!rabit::IsDistributed()) {
-const std::vector<uint32_t>& cut_ptr = gmat.cut.row_ptr;
+const std::vector<uint32_t>& cut_ptr = gmat.cut.Ptrs();
 const size_t ibegin = 2 * cut_ptr[fid];
 const size_t iend = 2 * cut_ptr[fid + 1];
 ReduceHistograms(hist_data, sibling_hist_data, parent_hist_data, ibegin, iend, node_idx,
@@ -1179,8 +1179,8 @@ bool QuantileHistMaker::Builder::EnumerateSplit(int d_step,
 CHECK(d_step == +1 || d_step == -1);

 // aliases
-const std::vector<uint32_t>& cut_ptr = gmat.cut.row_ptr;
-const std::vector<bst_float>& cut_val = gmat.cut.cut;
+const std::vector<uint32_t>& cut_ptr = gmat.cut.Ptrs();
+const std::vector<bst_float>& cut_val = gmat.cut.Values();

 // statistics on both sides of split
 GradStats c;
@@ -1239,7 +1239,7 @@ bool QuantileHistMaker::Builder::EnumerateSplit(int d_step,

 if (i == imin) {
 // for leftmost bin, left bound is the smallest feature value
-split_pt = gmat.cut.min_val[fid];
+split_pt = gmat.cut.MinValues()[fid];
 } else {
 split_pt = cut_val[i - 1];
 }
@@ -33,7 +33,6 @@ namespace common {
 }
 namespace tree {

-using xgboost::common::HistCutMatrix;
 using xgboost::common::GHistIndexMatrix;
 using xgboost::common::GHistIndexBlockMatrix;
 using xgboost::common::GHistIndexRow;
@@ -53,10 +53,10 @@ TEST(c_api, XGDMatrixCreateFromMat_omp) {
 ASSERT_EQ(info.num_nonzero_, num_cols * row - num_missing);

 for (const auto &batch : (*dmat)->GetRowBatches()) {
-for (int i = 0; i < batch.Size(); i++) {
+for (size_t i = 0; i < batch.Size(); i++) {
 auto inst = batch[i];
-for (int j = 0; i < inst.size(); i++) {
-ASSERT_EQ(inst[j].fvalue, 1.5);
+for (auto e : inst) {
+ASSERT_EQ(e.fvalue, 1.5);
 }
 }
 }
@@ -7,6 +7,7 @@

 namespace xgboost {
 namespace common {

 TEST(DenseColumn, Test) {
 auto dmat = CreateDMatrix(100, 10, 0.0);
 GHistIndexMatrix gmat;
@@ -17,7 +18,7 @@ TEST(DenseColumn, Test) {
 for (auto i = 0ull; i < (*dmat)->Info().num_row_; i++) {
 for (auto j = 0ull; j < (*dmat)->Info().num_col_; j++) {
 auto col = column_matrix.GetColumn(j);
-EXPECT_EQ(gmat.index[i * (*dmat)->Info().num_col_ + j],
+ASSERT_EQ(gmat.index[i * (*dmat)->Info().num_col_ + j],
 col.GetGlobalBinIdx(i));
 }
 }
@@ -33,7 +34,7 @@ TEST(SparseColumn, Test) {
 auto col = column_matrix.GetColumn(0);
 ASSERT_EQ(col.Size(), gmat.index.size());
 for (auto i = 0ull; i < col.Size(); i++) {
-EXPECT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]],
+ASSERT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]],
 col.GetGlobalBinIdx(i));
 }
 delete dmat;
@@ -28,7 +28,7 @@ TEST(CompressedIterator, Test) {

 CompressedIterator<int> ci(buffer.data(), alphabet_size);
 std::vector<int> output(input.size());
-for (int i = 0; i < input.size(); i++) {
+for (size_t i = 0; i < input.size(); i++) {
 output[i] = ci[i];
 }

@@ -38,12 +38,12 @@ TEST(CompressedIterator, Test) {
 std::vector<unsigned char> buffer2(
 CompressedBufferWriter::CalculateBufferSize(input.size(),
 alphabet_size));
-for (int i = 0; i < input.size(); i++) {
+for (size_t i = 0; i < input.size(); i++) {
 cbw.WriteSymbol(buffer2.data(), input[i], i);
 }
 CompressedIterator<int> ci2(buffer.data(), alphabet_size);
 std::vector<int> output2(input.size());
-for (int i = 0; i < input.size(); i++) {
+for (size_t i = 0; i < input.size(); i++) {
 output2[i] = ci2[i];
 }
 ASSERT_TRUE(input == output2);
@@ -48,11 +48,11 @@ void TestDeviceSketch(const GPUSet& devices, bool use_external_memory) {
 int gpu_batch_nrows = 0;

 // find quantiles on the CPU
-HistCutMatrix hmat_cpu;
-hmat_cpu.Init((*dmat).get(), p.max_bin);
+HistogramCuts hmat_cpu;
+hmat_cpu.Build((*dmat).get(), p.max_bin);

 // find the cuts on the GPU
-HistCutMatrix hmat_gpu;
+HistogramCuts hmat_gpu;
 size_t row_stride = DeviceSketch(p, CreateEmptyGenericParam(0, devices.Size()), gpu_batch_nrows,
 dmat->get(), &hmat_gpu);

@@ -69,12 +69,12 @@ void TestDeviceSketch(const GPUSet& devices, bool use_external_memory) {

 // compare the cuts
 double eps = 1e-2;
-ASSERT_EQ(hmat_gpu.min_val.size(), num_cols);
-ASSERT_EQ(hmat_gpu.row_ptr.size(), num_cols + 1);
-ASSERT_EQ(hmat_gpu.cut.size(), hmat_cpu.cut.size());
-ASSERT_LT(fabs(hmat_cpu.min_val[0] - hmat_gpu.min_val[0]), eps * nrows);
-for (int i = 0; i < hmat_gpu.cut.size(); ++i) {
-ASSERT_LT(fabs(hmat_cpu.cut[i] - hmat_gpu.cut[i]), eps * nrows);
+ASSERT_EQ(hmat_gpu.MinValues().size(), num_cols);
+ASSERT_EQ(hmat_gpu.Ptrs().size(), num_cols + 1);
+ASSERT_EQ(hmat_gpu.Values().size(), hmat_cpu.Values().size());
+ASSERT_LT(fabs(hmat_cpu.MinValues()[0] - hmat_gpu.MinValues()[0]), eps * nrows);
+for (int i = 0; i < hmat_gpu.Values().size(); ++i) {
+ASSERT_LT(fabs(hmat_cpu.Values()[i] - hmat_gpu.Values()[i]), eps * nrows);
 }

 delete dmat;
@@ -9,15 +9,7 @@
 namespace xgboost {
 namespace common {

-class HistCutMatrixMock : public HistCutMatrix {
- public:
-size_t SearchGroupIndFromBaseRow(
-std::vector<bst_uint> const& group_ptr, size_t const base_rowid) {
-return HistCutMatrix::SearchGroupIndFromBaseRow(group_ptr, base_rowid);
-}
-};
-
-TEST(HistCutMatrix, SearchGroupInd) {
+TEST(CutsBuilder, SearchGroupInd) {
 size_t constexpr kNumGroups = 4;
 size_t constexpr kNumRows = 17;
 size_t constexpr kNumCols = 15;
@@ -34,18 +26,102 @@ TEST(HistCutMatrix, SearchGroupInd) {
 p_mat->Info().SetInfo(
 "group", group.data(), DataType::kUInt32, kNumGroups);

-HistCutMatrixMock hmat;
+HistogramCuts hmat;

-size_t group_ind = hmat.SearchGroupIndFromBaseRow(p_mat->Info().group_ptr_, 0);
+size_t group_ind = CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 0);
 ASSERT_EQ(group_ind, 0);

-group_ind = hmat.SearchGroupIndFromBaseRow(p_mat->Info().group_ptr_, 5);
+group_ind = CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 5);
 ASSERT_EQ(group_ind, 2);

-EXPECT_ANY_THROW(hmat.SearchGroupIndFromBaseRow(p_mat->Info().group_ptr_, 17));
+EXPECT_ANY_THROW(CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 17));

 delete pp_mat;
 }

+namespace {
+class SparseCutsWrapper : public SparseCuts {
+ public:
+std::vector<uint32_t> const& ColPtrs() const { return p_cuts_->Ptrs(); }
+std::vector<float> const& ColValues() const { return p_cuts_->Values(); }
+};
+}  // anonymous namespace
+
+TEST(SparseCuts, SingleThreadedBuild) {
+size_t constexpr kRows = 267;
+size_t constexpr kCols = 31;
+size_t constexpr kBins = 256;
+
+// Dense matrix.
+auto pp_mat = CreateDMatrix(kRows, kCols, 0);
+DMatrix* p_fmat = (*pp_mat).get();
+
+common::GHistIndexMatrix hmat;
+hmat.Init(p_fmat, kBins);
+
+HistogramCuts cuts;
+SparseCuts indices(&cuts);
+auto const& page = *(p_fmat->GetColumnBatches().begin());
+indices.SingleThreadBuild(page, p_fmat->Info(), kBins, false, 0, page.Size(), 0);
+
+ASSERT_EQ(hmat.cut.Ptrs().size(), cuts.Ptrs().size());
+ASSERT_EQ(hmat.cut.Ptrs(), cuts.Ptrs());
+ASSERT_EQ(hmat.cut.Values(), cuts.Values());
+ASSERT_EQ(hmat.cut.MinValues(), cuts.MinValues());
+
+delete pp_mat;
+}
+
+TEST(SparseCuts, MultiThreadedBuild) {
+size_t constexpr kRows = 17;
+size_t constexpr kCols = 15;
+size_t constexpr kBins = 255;
+
+omp_ulong ori_nthreads = omp_get_max_threads();
+omp_set_num_threads(16);
+
+auto Compare =
+#if defined(_MSC_VER)  // msvc fails to capture
+[kBins](DMatrix* p_fmat) {
+#else
+[](DMatrix* p_fmat) {
+#endif
+HistogramCuts threaded_container;
+SparseCuts threaded_indices(&threaded_container);
+threaded_indices.Build(p_fmat, kBins);
+
+HistogramCuts container;
+SparseCuts indices(&container);
+auto const& page = *(p_fmat->GetColumnBatches().begin());
+indices.SingleThreadBuild(page, p_fmat->Info(), kBins, false, 0, page.Size(), 0);
+
+ASSERT_EQ(container.Ptrs().size(), threaded_container.Ptrs().size());
+ASSERT_EQ(container.Values().size(), threaded_container.Values().size());
+
+for (uint32_t i = 0; i < container.Ptrs().size(); ++i) {
+ASSERT_EQ(container.Ptrs()[i], threaded_container.Ptrs()[i]);
+}
+for (uint32_t i = 0; i < container.Values().size(); ++i) {
+ASSERT_EQ(container.Values()[i], threaded_container.Values()[i]);
+}
+};
+
+{
+auto pp_mat = CreateDMatrix(kRows, kCols, 0);
+DMatrix* p_fmat = (*pp_mat).get();
+Compare(p_fmat);
+delete pp_mat;
+}
+
+{
+auto pp_mat = CreateDMatrix(kRows, kCols, 0.0001);
+DMatrix* p_fmat = (*pp_mat).get();
+Compare(p_fmat);
+delete pp_mat;
+}
+
+omp_set_num_threads(ori_nthreads);
+}
+
 }  // namespace common
 }  // namespace xgboost
@@ -53,8 +53,8 @@ TEST(ColumnSampler, Test) {
 TEST(ColumnSampler, ThreadSynchronisation) {
 const int64_t num_threads = 100;
 int n = 128;
-int iterations = 10;
-int levels = 5;
+size_t iterations = 10;
+size_t levels = 5;
 std::vector<int> reference_result;
 bool success =
 true; // Cannot use google test asserts in multithreaded region
@@ -310,7 +310,7 @@ TEST(Span, FirstLast) {
 ASSERT_EQ(first.size(), 4);
 ASSERT_EQ(first.data(), arr);

-for (size_t i = 0; i < first.size(); ++i) {
+for (int64_t i = 0; i < first.size(); ++i) {
 ASSERT_EQ(first[i], arr[i]);
 }

@@ -329,7 +329,7 @@ TEST(Span, FirstLast) {
 ASSERT_EQ(last.size(), 4);
 ASSERT_EQ(last.data(), arr + 12);

-for (size_t i = 0; i < last.size(); ++i) {
+for (int64_t i = 0; i < last.size(); ++i) {
 ASSERT_EQ(last[i], arr[i+12]);
 }

@@ -348,7 +348,7 @@ TEST(Span, FirstLast) {
 ASSERT_EQ(first.size(), 4);
 ASSERT_EQ(first.data(), s.data());

-for (size_t i = 0; i < first.size(); ++i) {
+for (int64_t i = 0; i < first.size(); ++i) {
 ASSERT_EQ(first[i], s[i]);
 }

@@ -368,7 +368,7 @@ TEST(Span, FirstLast) {
 ASSERT_EQ(last.size(), 4);
 ASSERT_EQ(last.data(), s.data() + 12);

-for (size_t i = 0; i < last.size(); ++i) {
+for (int64_t i = 0; i < last.size(); ++i) {
 ASSERT_EQ(s[12 + i], last[i]);
 }

@@ -50,7 +50,7 @@ TEST(SparsePage, PushCSC) {
 inst = page[1];
 ASSERT_EQ(inst.size(), 6);
 std::vector<size_t> indices_sol {1, 2, 3};
-for (size_t i = 0; i < inst.size(); ++i) {
+for (int64_t i = 0; i < inst.size(); ++i) {
 ASSERT_EQ(inst[i].index, indices_sol[i % 3]);
 }
 }
@@ -21,13 +21,13 @@ TEST(cpu_predictor, Test) {
 HostDeviceVector<float> out_predictions;
 cpu_predictor->PredictBatch((*dmat).get(), &out_predictions, model, 0);
 std::vector<float>& out_predictions_h = out_predictions.HostVector();
-for (int i = 0; i < out_predictions.Size(); i++) {
+for (size_t i = 0; i < out_predictions.Size(); i++) {
 ASSERT_EQ(out_predictions_h[i], 1.5);
 }

 // Test predict instance
 auto &batch = *(*dmat)->GetRowBatches().begin();
-for (int i = 0; i < batch.Size(); i++) {
+for (size_t i = 0; i < batch.Size(); i++) {
 std::vector<float> instance_out_predictions;
 cpu_predictor->PredictInstance(batch[i], &instance_out_predictions, model);
 ASSERT_EQ(instance_out_predictions[0], 1.5);
@@ -94,7 +94,7 @@ void TestUpdatePosition() {
 }

 TEST(RowPartitioner, Basic) { TestUpdatePosition(); }

 void TestFinalise() {
 const int kNumRows = 10;
 RowPartitioner rp(0, kNumRows);
@@ -53,27 +53,43 @@ TEST(GpuHist, DeviceHistogram) {
 }
 }
 };

 }

+namespace {
+class HistogramCutsWrapper : public common::HistogramCuts {
+ public:
+using SuperT = common::HistogramCuts;
+void SetValues(std::vector<float> cuts) {
+SuperT::cut_values_ = cuts;
+}
+void SetPtrs(std::vector<uint32_t> ptrs) {
+SuperT::cut_ptrs_ = ptrs;
+}
+void SetMins(std::vector<float> mins) {
+SuperT::min_vals_ = mins;
+}
+};
+}  // anonymous namespace
+

 template <typename GradientSumT>
 void BuildGidx(DeviceShard<GradientSumT>* shard, int n_rows, int n_cols,
 bst_float sparsity=0) {
 auto dmat = CreateDMatrix(n_rows, n_cols, sparsity, 3);
 const SparsePage& batch = *(*dmat)->GetRowBatches().begin();

-common::HistCutMatrix cmat;
-cmat.row_ptr = {0, 3, 6, 9, 12, 15, 18, 21, 24};
-cmat.min_val = {0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f};
+HistogramCutsWrapper cmat;
+cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
 // 24 cut fields, 3 cut fields for each feature (column).
-cmat.cut = {0.30f, 0.67f, 1.64f,
+cmat.SetValues({0.30f, 0.67f, 1.64f,
 0.32f, 0.77f, 1.95f,
 0.29f, 0.70f, 1.80f,
 0.32f, 0.75f, 1.85f,
 0.18f, 0.59f, 1.69f,
 0.25f, 0.74f, 2.00f,
 0.26f, 0.74f, 1.98f,
-0.26f, 0.71f, 1.83f};
+0.26f, 0.71f, 1.83f});
+cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f});

 auto is_dense = (*dmat)->Info().num_nonzero_ ==
 (*dmat)->Info().num_row_ * (*dmat)->Info().num_col_;
@@ -241,20 +257,20 @@ TEST(GpuHist, BuildHistSharedMem) {
 TestBuildHist<GradientPair>(true);
 }

-common::HistCutMatrix GetHostCutMatrix () {
-common::HistCutMatrix cmat;
-cmat.row_ptr = {0, 3, 6, 9, 12, 15, 18, 21, 24};
-cmat.min_val = {0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f};
+HistogramCutsWrapper GetHostCutMatrix () {
+HistogramCutsWrapper cmat;
+cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
+cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f});
 // 24 cut fields, 3 cut fields for each feature (column).
 // Each row of the cut represents the cuts for a data column.
-cmat.cut = {0.30f, 0.67f, 1.64f,
+cmat.SetValues({0.30f, 0.67f, 1.64f,
 0.32f, 0.77f, 1.95f,
 0.29f, 0.70f, 1.80f,
 0.32f, 0.75f, 1.85f,
 0.18f, 0.59f, 1.69f,
 0.25f, 0.74f, 2.00f,
 0.26f, 0.74f, 1.98f,
-0.26f, 0.71f, 1.83f};
+0.26f, 0.71f, 1.83f});
 return cmat;
 }

@@ -293,21 +309,21 @@ TEST(GpuHist, EvaluateSplits) {
 shard->node_sum_gradients = {{6.4f, 12.8f}};

 // Initialize DeviceShard::cut
-common::HistCutMatrix cmat = GetHostCutMatrix();
+auto cmat = GetHostCutMatrix();

 // Copy cut matrix to device.
 shard->ba.Allocate(0,
-&(shard->feature_segments), cmat.row_ptr.size(),
-&(shard->min_fvalue), cmat.min_val.size(),
+&(shard->feature_segments), cmat.Ptrs().size(),
+&(shard->min_fvalue), cmat.MinValues().size(),
 &(shard->gidx_fvalue_map), 24,
 &(shard->monotone_constraints), kNCols);
-dh::CopyVectorToDeviceSpan(shard->feature_segments, cmat.row_ptr);
-dh::CopyVectorToDeviceSpan(shard->gidx_fvalue_map, cmat.cut);
+dh::CopyVectorToDeviceSpan(shard->feature_segments, cmat.Ptrs());
+dh::CopyVectorToDeviceSpan(shard->gidx_fvalue_map, cmat.Values());
 dh::CopyVectorToDeviceSpan(shard->monotone_constraints,
 param.monotone_constraints);
 shard->ellpack_matrix.feature_segments = shard->feature_segments;
 shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map;
-dh::CopyVectorToDeviceSpan(shard->min_fvalue, cmat.min_val);
+dh::CopyVectorToDeviceSpan(shard->min_fvalue, cmat.MinValues());
 shard->ellpack_matrix.min_fvalue = shard->min_fvalue;

 // Initialize DeviceShard::hist
@@ -13,7 +13,7 @@ namespace xgboost {
 namespace tree {

 TEST(Updater, Prune) {
-int constexpr kNRows = 32, kNCols = 16;
+int constexpr kNCols = 16;

 std::vector<std::pair<std::string, std::string>> cfg;
 cfg.emplace_back(std::pair<std::string, std::string>(
@@ -1,5 +1,5 @@
 /*!
-* Copyright 2018 by Contributors
+* Copyright 2018-2019 by Contributors
 */
 #include "../helpers.h"
 #include "../../../src/tree/param.h"
@@ -46,23 +46,25 @@ class QuantileHistMock : public QuantileHistMaker {
 const size_t num_row = p_fmat->Info().num_row_;
 const size_t num_col = p_fmat->Info().num_col_;
 /* Validate HistCutMatrix */
-ASSERT_EQ(gmat.cut.row_ptr.size(), num_col + 1);
+ASSERT_EQ(gmat.cut.Ptrs().size(), num_col + 1);
 for (size_t fid = 0; fid < num_col; ++fid) {
-// Each feature must have at least one quantile point (cut)
-const size_t ibegin = gmat.cut.row_ptr[fid];
-const size_t iend = gmat.cut.row_ptr[fid + 1];
-ASSERT_LT(ibegin, iend);
+const size_t ibegin = gmat.cut.Ptrs()[fid];
+const size_t iend = gmat.cut.Ptrs()[fid + 1];
+// Ordered, but empty feature is allowed.
+ASSERT_LE(ibegin, iend);
 for (size_t i = ibegin; i < iend - 1; ++i) {
 // Quantile points must be sorted in ascending order
 // No duplicates allowed
-ASSERT_LT(gmat.cut.cut[i], gmat.cut.cut[i + 1]);
+ASSERT_LT(gmat.cut.Values()[i], gmat.cut.Values()[i + 1])
+<< "ibegin: " << ibegin << ", "
+<< "iend: " << iend;
 }
 }

 /* Validate GHistIndexMatrix */
 ASSERT_EQ(gmat.row_ptr.size(), num_row + 1);
 ASSERT_LT(*std::max_element(gmat.index.begin(), gmat.index.end()),
-gmat.cut.row_ptr.back());
+gmat.cut.Ptrs().back());
 for (const auto& batch : p_fmat->GetRowBatches()) {
 for (size_t i = 0; i < batch.Size(); ++i) {
 const size_t rid = batch.base_rowid + i;
@@ -71,20 +73,20 @@ class QuantileHistMock : public QuantileHistMaker {
 ASSERT_LT(gmat_row_offset, gmat.index.size());
 SparsePage::Inst inst = batch[i];
 ASSERT_EQ(gmat.row_ptr[rid] + inst.size(), gmat.row_ptr[rid + 1]);
-for (size_t j = 0; j < inst.size(); ++j) {
+for (int64_t j = 0; j < inst.size(); ++j) {
 // Each entry of GHistIndexMatrix represents a bin ID
 const size_t bin_id = gmat.index[gmat_row_offset + j];
 const size_t fid = inst[j].index;
 // The bin ID must correspond to correct feature
-ASSERT_GE(bin_id, gmat.cut.row_ptr[fid]);
-ASSERT_LT(bin_id, gmat.cut.row_ptr[fid + 1]);
+ASSERT_GE(bin_id, gmat.cut.Ptrs()[fid]);
+ASSERT_LT(bin_id, gmat.cut.Ptrs()[fid + 1]);
 // The bin ID must correspond to a region between two
 // suitable quantile points
-ASSERT_LT(inst[j].fvalue, gmat.cut.cut[bin_id]);
-if (bin_id > gmat.cut.row_ptr[fid]) {
-ASSERT_GE(inst[j].fvalue, gmat.cut.cut[bin_id - 1]);
+ASSERT_LT(inst[j].fvalue, gmat.cut.Values()[bin_id]);
+if (bin_id > gmat.cut.Ptrs()[fid]) {
+ASSERT_GE(inst[j].fvalue, gmat.cut.Values()[bin_id - 1]);
 } else {
-ASSERT_GE(inst[j].fvalue, gmat.cut.min_val[fid]);
+ASSERT_GE(inst[j].fvalue, gmat.cut.MinValues()[fid]);
 }
 }
 }
@@ -106,11 +108,12 @@ class QuantileHistMock : public QuantileHistMaker {
 std::vector<std::vector<uint8_t>> hist_is_init;
 std::vector<ExpandEntry> nodes = {ExpandEntry(nid, -1, -1, tree.GetDepth(0), 0.0, 0)};
 BuildHistsBatch(nodes, const_cast<RegTree*>(&tree), gmat, gpair, &hist_buffers, &hist_is_init);
-RealImpl::InitNewNode(nid, gmat, gpair, fmat, const_cast<RegTree*>(&tree), &snode_[0], tree[0].Parent());
+RealImpl::InitNewNode(nid, gmat, gpair, fmat,
+const_cast<RegTree*>(&tree), &snode_[0], tree[0].Parent());
 EvaluateSplitsBatch(nodes, gmat, fmat, hist_is_init, hist_buffers);

 // Check if number of histogram bins is correct
-ASSERT_EQ(hist_[nid].size(), gmat.cut.row_ptr.back());
+ASSERT_EQ(hist_[nid].size(), gmat.cut.Ptrs().back());
 std::vector<GradientPairPrecise> histogram_expected(hist_[nid].size());

 // Compute the correct histogram (histogram_expected)
@@ -126,7 +129,7 @@ class QuantileHistMock : public QuantileHistMaker {
 }

 // Now validate the computed histogram returned by BuildHist
-for (size_t i = 0; i < hist_[nid].size(); ++i) {
+for (int64_t i = 0; i < hist_[nid].size(); ++i) {
 GradientPairPrecise sol = histogram_expected[i];
 ASSERT_NEAR(sol.GetGrad(), hist_[nid][i].GetGrad(), kEps);
 ASSERT_NEAR(sol.GetHess(), hist_[nid][i].GetHess(), kEps);
@@ -140,7 +143,7 @@ class QuantileHistMock : public QuantileHistMaker {
 {0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f} };
 size_t constexpr kMaxBins = 4;
 auto dmat = CreateDMatrix(kNRows, kNCols, 0, 3);
 // dense, no missing values

 common::GHistIndexMatrix gmat;
 gmat.Init((*dmat).get(), kMaxBins);
@@ -152,7 +155,8 @@ class QuantileHistMock : public QuantileHistMaker {
 std::vector<std::vector<float*>> hist_buffers;
 std::vector<std::vector<uint8_t>> hist_is_init;
 BuildHistsBatch(nodes, const_cast<RegTree*>(&tree), gmat, row_gpairs, &hist_buffers, &hist_is_init);
-RealImpl::InitNewNode(0, gmat, row_gpairs, *(*dmat), const_cast<RegTree*>(&tree), &snode_[0], tree[0].Parent());
+RealImpl::InitNewNode(0, gmat, row_gpairs, *(*dmat),
+const_cast<RegTree*>(&tree), &snode_[0], tree[0].Parent());
 EvaluateSplitsBatch(nodes, gmat, **dmat, hist_is_init, hist_buffers);

 /* Compute correct split (best_split) using the computed histogram */
@@ -178,8 +182,8 @@ class QuantileHistMock : public QuantileHistMaker {
 size_t best_split_feature = std::numeric_limits<size_t>::max();
 // Enumerate all features
 for (size_t fid = 0; fid < num_feature; ++fid) {
-const size_t bin_id_min = gmat.cut.row_ptr[fid];
-const size_t bin_id_max = gmat.cut.row_ptr[fid + 1];
+const size_t bin_id_min = gmat.cut.Ptrs()[fid];
+const size_t bin_id_max = gmat.cut.Ptrs()[fid + 1];
 // Enumerate all bin ID in [bin_id_min, bin_id_max), i.e. every possible
 // choice of thresholds for feature fid
 for (size_t split_thresh = bin_id_min;
@@ -217,7 +221,7 @@ class QuantileHistMock : public QuantileHistMaker {
 EvaluateSplitsBatch(nodes, gmat, **dmat, hist_is_init, hist_buffers);

 ASSERT_EQ(snode_[0].best.SplitIndex(), best_split_feature);
-ASSERT_EQ(snode_[0].best.split_value, gmat.cut.cut[best_split_threshold]);
+ASSERT_EQ(snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);

 delete dmat;
 }